Skip to content
This repository was archived by the owner on Feb 7, 2025. It is now read-only.

Commit 269d61d

Browse files
virginiafdezvirginiafdez
andauthored
Added function to load the state_dict from the diffusion model into t… (#478)
* Added function to load the state_dict from the diffusion model into the controlnet, informing the user - if required - of matched and unmatched layers. * Modify formatting: removed return statement, return in args description, and formatted the print with f-Strings. * Formatting of the function --------- Co-authored-by: virginiafdez <[email protected]>
1 parent 0db685f commit 269d61d

File tree

1 file changed

+20
-1
lines changed

1 file changed

+20
-1
lines changed

generative/networks/nets/controlnet.py

Lines changed: 20 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,6 @@
4141

4242
from generative.networks.nets.diffusion_model_unet import get_down_block, get_mid_block, get_timestep_embedding
4343

44-
4544
class ControlNetConditioningEmbedding(nn.Module):
4645
"""
4746
Network to encode the conditioning into a latent space.
@@ -121,6 +120,26 @@ def zero_module(module):
121120
nn.init.zeros_(p)
122121
return module
123122

123+
def copy_weights_to_controlnet(controlnet : nn.Module,
124+
diffusion_model: nn.Module,
125+
verbose: bool = True) -> None:
126+
'''
127+
Copy the state dict from the input diffusion model to the ControlNet, printing, if user requires it, the output
128+
keys that have matched and those that haven't.
129+
130+
Args:
131+
controlnet: instance of ControlNet
132+
diffusion_model: instance of DiffusionModelUnet or SPADEDiffusionModelUnet
133+
verbose: if True, the matched and unmatched keys will be printed.
134+
'''
135+
136+
output = controlnet.load_state_dict(diffusion_model.state_dict(), strict = False)
137+
if verbose:
138+
dm_keys = [p[0] for p in list(diffusion_model.named_parameters()) if p[0] not in output.unexpected_keys]
139+
print(f"Copied weights from {len(dm_keys)} keys of the diffusion model into the ControlNet:"
140+
f"\n{'; '.join(dm_keys)}\nControlNet missing keys: {len(output.missing_keys)}:"
141+
f"\n{'; '.join(output.missing_keys)}\nDiffusion model incompatible keys: {len(output.unexpected_keys)}:"
142+
f"\n{'; '.join(output.unexpected_keys)}")
124143

125144
class ControlNet(nn.Module):
126145
"""

0 commit comments

Comments
 (0)