Skip to content

Commit f6b0a2d

Browse files
ENH Small speedups to adapter injection (#2785)
See huggingface/diffusers#11816 (comment) This PR implements two small improvements to the speed of adapter injection. On a benchmark based on the linked issue, the first change leads to a speedup of 21% and the second change of another 3%. It's not that much, but as the changes don't make the code more complicated, there is really no reason not to take them. The optimizations don't add any functional change but are simply based on not recomputing the same values multiple times. Therefore, unless I'm missing something, they should strictly improve runtime.
1 parent f1b8364 commit f6b0a2d

File tree

1 file changed

+6
-7
lines changed

1 file changed

+6
-7
lines changed

src/peft/tuners/tuners_utils.py

Lines changed: 6 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -668,10 +668,9 @@ def inject_adapter(
668668
and (len(peft_config.target_modules) >= MIN_TARGET_MODULES_FOR_OPTIMIZATION)
669669
and (peft_config.peft_type != PeftType.IA3)
670670
):
671+
suffixes = tuple("." + suffix for suffix in peft_config.target_modules)
671672
names_no_target = [
672-
name
673-
for name in key_list
674-
if not any((name == suffix) or name.endswith("." + suffix) for suffix in peft_config.target_modules)
673+
name for name in key_list if (name not in peft_config.target_modules) and not name.endswith(suffixes)
675674
]
676675
new_target_modules = _find_minimal_target_modules(peft_config.target_modules, names_no_target)
677676
if len(new_target_modules) < len(peft_config.target_modules):
@@ -681,10 +680,10 @@ def inject_adapter(
681680
# MATCHING & CREATING MODULES #
682681
###############################
683682

684-
existing_adapter_map = {}
683+
existing_adapter_prefixes = []
685684
for key, module in named_modules:
686685
if isinstance(module, BaseTunerLayer):
687-
existing_adapter_map[key] = module
686+
existing_adapter_prefixes.append(key + ".")
688687

689688
# TODO: check if this the most robust way
690689
module_names: set[str] = set()
@@ -698,8 +697,8 @@ def inject_adapter(
698697

699698
# It is possible that we're adding an additional adapter, so if we encounter a key that clearly belongs to a
700699
# previous adapter we can skip here since we don't want to interfere with adapter internals.
701-
for adapter_key in existing_adapter_map:
702-
if key.startswith(adapter_key + "."):
700+
for adapter_key in existing_adapter_prefixes:
701+
if key.startswith(adapter_key):
703702
excluded_modules.append(key)
704703
break
705704

0 commit comments

Comments
 (0)