Skip to content

⚡️ Speed up function state_dict_prefix_replace by 127% #7743

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 3 commits into
base: master
Choose a base branch
from

Conversation

aseembits93
Copy link

@aseembits93 aseembits93 commented Apr 23, 2025

📄 127% (1.27x) speedup for state_dict_prefix_replace in comfy/utils.py

⏱️ Runtime : 1.61 millisecond 710 microseconds (best of 398 runs)

📝 Explanation and details

Here's an optimized version of your Python function. The primary changes are to minimize the creation of intermediate lists and to use dictionary comprehensions for more efficient data manipulation.

Changes and Optimizations

  1. Avoid Unneeded List Creation:

    • Instead of mapping and filtering the keys in a separate step (map and filter), it is done directly in the list comprehension.
  2. Dictionary Comprehension:

    • By directly assigning out to {} or state_dict, it forgoes unnecessary intermediate steps in the conditional initialization.
  3. In-Loop Item Assignment.

    • Keys to be replaced and corresponding operations are now handled directly within loops, reducing intermediate variable assignments.

This rewritten function should perform better, especially with large dictionaries, due to reduced overhead from list operations and more efficient key manipulation.

Correctness verification report:

Test Status
⚙️ Existing Unit Tests 🔘 None Found
🌀 Generated Regression Tests 29 Passed
⏪ Replay Tests 🔘 None Found
🔎 Concolic Coverage Tests 🔘 None Found
📊 Tests Coverage 100.0%
🌀 Generated Regression Tests Details
import logging

# imports
import pytest  # used for our unit tests
import torch
from _codecs import encode
from comfy.utils import state_dict_prefix_replace
from numpy import dtype
from numpy.core.multiarray import scalar
from numpy.dtypes import Float64DType

# unit tests

def test_basic_single_prefix_replacement():
    state_dict = {"layer1.weight": 1, "layer1.bias": 2}
    replace_prefix = {"layer1.": "layer2."}
    expected_output = {"layer2.weight": 1, "layer2.bias": 2}
    codeflash_output = state_dict_prefix_replace(state_dict.copy(), replace_prefix)

def test_basic_multiple_prefix_replacement():
    state_dict = {"layer1.weight": 1, "layer2.bias": 2}
    replace_prefix = {"layer1.": "new_layer1.", "layer2.": "new_layer2."}
    expected_output = {"new_layer1.weight": 1, "new_layer2.bias": 2}
    codeflash_output = state_dict_prefix_replace(state_dict.copy(), replace_prefix)

def test_no_replacement_needed():
    state_dict = {"layer1.weight": 1, "layer1.bias": 2}
    replace_prefix = {"layer2.": "layer3."}
    expected_output = {"layer1.weight": 1, "layer1.bias": 2}
    codeflash_output = state_dict_prefix_replace(state_dict.copy(), replace_prefix)

def test_empty_prefix_dictionary():
    state_dict = {"layer1.weight": 1, "layer1.bias": 2}
    replace_prefix = {}
    expected_output = {"layer1.weight": 1, "layer1.bias": 2}
    codeflash_output = state_dict_prefix_replace(state_dict.copy(), replace_prefix)

def test_empty_state_dictionary():
    state_dict = {}
    replace_prefix = {"layer1.": "layer2."}
    expected_output = {}
    codeflash_output = state_dict_prefix_replace(state_dict.copy(), replace_prefix)

def test_prefix_overlap():
    state_dict = {"layer1.weight": 1, "layer1.bias": 2, "layer1.weight_extra": 3}
    replace_prefix = {"layer1.": "layer2.", "layer1.weight": "layer3.weight"}
    expected_output = {"layer2.weight": 1, "layer2.bias": 2, "layer2.weight_extra": 3}
    codeflash_output = state_dict_prefix_replace(state_dict.copy(), replace_prefix)

def test_filter_keys_enabled_basic_replacement():
    state_dict = {"layer1.weight": 1, "layer1.bias": 2}
    replace_prefix = {"layer1.": "layer2."}
    expected_output = {"layer2.weight": 1, "layer2.bias": 2}
    codeflash_output = state_dict_prefix_replace(state_dict.copy(), replace_prefix, filter_keys=True)

def test_filter_keys_enabled_no_matching_prefix():
    state_dict = {"layer1.weight": 1, "layer1.bias": 2}
    replace_prefix = {"layer2.": "layer3."}
    expected_output = {}
    codeflash_output = state_dict_prefix_replace(state_dict.copy(), replace_prefix, filter_keys=True)

def test_large_state_dictionary():
    state_dict = {f"layer{i}.weight": i for i in range(1000)}
    replace_prefix = {"layer": "new_layer"}
    expected_output = {f"new_layer{i}.weight": i for i in range(1000)}
    codeflash_output = state_dict_prefix_replace(state_dict.copy(), replace_prefix)

def test_complex_nested_prefixes():
    state_dict = {f"block{i}.layer{j}.weight": i*j for i in range(10) for j in range(10)}
    replace_prefix = {"block": "module", "layer": "submodule"}
    expected_output = {f"module{i}.submodule{j}.weight": i*j for i in range(10) for j in range(10)}
    codeflash_output = state_dict_prefix_replace(state_dict.copy(), replace_prefix)

def test_partial_matching_prefixes():
    state_dict = {"layer1.weight": 1, "layer1.bias": 2, "layer2.weight": 3}
    replace_prefix = {"layer1.": "layerA.", "layer2.": "layerB."}
    expected_output = {"layerA.weight": 1, "layerA.bias": 2, "layerB.weight": 3}
    codeflash_output = state_dict_prefix_replace(state_dict.copy(), replace_prefix)

def test_overlapping_prefixes():
    state_dict = {"layer1.weight": 1, "layer1.bias": 2, "layer1.weight_extra": 3}
    replace_prefix = {"layer1.": "layer2.", "layer1.weight": "layer3.weight"}
    expected_output = {"layer2.weight": 1, "layer2.bias": 2, "layer2.weight_extra": 3}
    codeflash_output = state_dict_prefix_replace(state_dict.copy(), replace_prefix)

def test_nested_dictionaries():
    state_dict = {"block1.layer1.weight": 1, "block1.layer1.bias": 2, "block2.layer1.weight": 3}
    replace_prefix = {"block1.layer1.": "blockA.layerA.", "block2.layer1.": "blockB.layerA."}
    expected_output = {"blockA.layerA.weight": 1, "blockA.layerA.bias": 2, "blockB.layerA.weight": 3}
    codeflash_output = state_dict_prefix_replace(state_dict.copy(), replace_prefix)
# codeflash_output is used to check that the output of the original code is the same as that of the optimized code.

import logging

# imports
import pytest  # used for our unit tests
import torch
from _codecs import encode
from comfy.utils import state_dict_prefix_replace
from numpy import dtype
from numpy.core.multiarray import scalar
from numpy.dtypes import Float64DType

# unit tests

# Basic Functionality Tests
def test_single_prefix_replacement():
    state_dict = {'a1': 1, 'a2': 2}
    replace_prefix = {'a': 'b'}
    codeflash_output = state_dict_prefix_replace(state_dict, replace_prefix, filter_keys=False); result = codeflash_output

def test_multiple_prefix_replacement():
    state_dict = {'a1': 1, 'b2': 2}
    replace_prefix = {'a': 'c', 'b': 'd'}
    codeflash_output = state_dict_prefix_replace(state_dict, replace_prefix, filter_keys=False); result = codeflash_output

# Edge Case Tests
def test_empty_state_dict():
    state_dict = {}
    replace_prefix = {'a': 'b'}
    codeflash_output = state_dict_prefix_replace(state_dict, replace_prefix, filter_keys=False); result = codeflash_output

def test_empty_replace_prefix():
    state_dict = {'a1': 1, 'a2': 2}
    replace_prefix = {}
    codeflash_output = state_dict_prefix_replace(state_dict, replace_prefix, filter_keys=False); result = codeflash_output

def test_prefix_not_in_state_dict():
    state_dict = {'a1': 1, 'a2': 2}
    replace_prefix = {'b': 'c'}
    codeflash_output = state_dict_prefix_replace(state_dict, replace_prefix, filter_keys=False); result = codeflash_output

# Filtering Keys Tests
def test_filter_keys_true_with_matching_prefix():
    state_dict = {'a1': 1, 'a2': 2}
    replace_prefix = {'a': 'b'}
    codeflash_output = state_dict_prefix_replace(state_dict, replace_prefix, filter_keys=True); result = codeflash_output

def test_filter_keys_true_with_non_matching_prefix():
    state_dict = {'a1': 1, 'a2': 2}
    replace_prefix = {'b': 'c'}
    codeflash_output = state_dict_prefix_replace(state_dict, replace_prefix, filter_keys=True); result = codeflash_output

# Complex Scenarios Tests
def test_nested_prefix_replacement():
    state_dict = {'a1': 1, 'a2': 2, 'b3': 3}
    replace_prefix = {'a': 'b', 'b': 'c'}
    codeflash_output = state_dict_prefix_replace(state_dict, replace_prefix, filter_keys=False); result = codeflash_output

def test_overlapping_prefixes():
    state_dict = {'a1': 1, 'ab2': 2}
    replace_prefix = {'a': 'b', 'ab': 'c'}
    codeflash_output = state_dict_prefix_replace(state_dict, replace_prefix, filter_keys=False); result = codeflash_output

# Large Scale Test Cases
def test_large_state_dict():
    state_dict = {f'a{i}': i for i in range(1000)}
    replace_prefix = {'a': 'b'}
    codeflash_output = state_dict_prefix_replace(state_dict, replace_prefix, filter_keys=False); result = codeflash_output
    expected = {f'b{i}': i for i in range(1000)}

def test_large_replace_prefix_dict():
    state_dict = {'a1': 1, 'b2': 2}
    replace_prefix = {f'a{i}': f'b{i}' for i in range(1000)}
    codeflash_output = state_dict_prefix_replace(state_dict, replace_prefix, filter_keys=False); result = codeflash_output

# Special Characters in Prefixes Tests
def test_special_characters_in_state_dict_keys():
    state_dict = {'a!1': 1, 'a@2': 2}
    replace_prefix = {'a': 'b'}
    codeflash_output = state_dict_prefix_replace(state_dict, replace_prefix, filter_keys=False); result = codeflash_output

def test_special_characters_in_replace_prefix():
    state_dict = {'a1': 1, 'a2': 2}
    replace_prefix = {'a': 'b!', 'a': 'b@'}
    codeflash_output = state_dict_prefix_replace(state_dict, replace_prefix, filter_keys=False); result = codeflash_output

# Mixed Data Types Tests
def test_mixed_data_types():
    state_dict = {'a1': 1, 'a2': 'value', 'a3': [1, 2, 3]}
    replace_prefix = {'a': 'b'}
    codeflash_output = state_dict_prefix_replace(state_dict, replace_prefix, filter_keys=False); result = codeflash_output

# Prefixes with Different Lengths Tests
def test_shorter_prefix_replacement():
    state_dict = {'abc1': 1, 'abc2': 2}
    replace_prefix = {'abc': 'a'}
    codeflash_output = state_dict_prefix_replace(state_dict, replace_prefix, filter_keys=False); result = codeflash_output

def test_longer_prefix_replacement():
    state_dict = {'a1': 1, 'a2': 2}
    replace_prefix = {'a': 'abc'}
    codeflash_output = state_dict_prefix_replace(state_dict, replace_prefix, filter_keys=False); result = codeflash_output
# codeflash_output is used to check that the output of the original code is the same as that of the optimized code.

To edit these changes git checkout codeflash/optimize-state_dict_prefix_replace-m9js6ztz and push.

Codeflash

Here's an optimized version of your Python function. The primary changes are to minimize the creation of intermediate lists and to use dictionary comprehensions for more efficient data manipulation.



### Changes and Optimizations

1. **Avoid Unneeded List Creation:** 
   - Instead of mapping and filtering the keys in a separate step (`map` and `filter`), it is done directly in the list comprehension.
   
2. **Dictionary Comprehension**: 
   - By directly assigning `out` to `{}` or `state_dict`, it forgoes unnecessary intermediate steps in the conditional initialization.
   
3. **In-Loop Item Assignment**.
   - Keys to be replaced and corresponding operations are now handled directly within loops, reducing intermediate variable assignments.

This rewritten function should perform better, especially with large dictionaries, due to reduced overhead from list operations and more efficient key manipulation.
comfy/utils.py Outdated
out[x[1]] = w
out = {} if filter_keys else state_dict

for old_prefix, new_prefix in replace_prefix.items():
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This code should be skipped if filter_keys is false.

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

double confirming this as the for rp in replace_prefix: loop is outside the if-else block

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I have made the necessary changes, it's ready for review @ltdrdata

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants