
Commit ef2226d

pipeline updated
1 parent b4a24ba commit ef2226d

File tree

2 files changed: +62 -71 lines changed


src/diffusers/models/controlnets/controlnet_qwenimage_blockwise.py

Lines changed: 51 additions & 68 deletions
@@ -25,6 +25,7 @@
 from ..cache_utils import CacheMixin
 from .controlnet import zero_module
 from ..modeling_outputs import Transformer2DModelOutput
+from ..normalization import AdaLayerNormContinuous, RMSNorm
 from ..modeling_utils import ModelMixin
 from ..transformers.qwenimage_dit import (
     QwenEmbedRope,
@@ -109,12 +110,16 @@ def __init__(
 
         # controlnet_blocks
         self.controlnet_blocks = nn.ModuleList([])
-        for _ in range(len(self.transformer_blocks)):
+        for _ in range(num_layers):
             self.controlnet_blocks.append(zero_module(BlockWiseControlBlock(self.inner_dim)))
         self.controlnet_x_embedder = zero_module(
             torch.nn.Linear(in_channels + extra_condition_channels, self.inner_dim)
         )
 
+        self.norm_out = AdaLayerNormContinuous(self.inner_dim, self.inner_dim, elementwise_affine=False, eps=1e-6)
+        self.proj_out = nn.Linear(self.inner_dim, patch_size * patch_size * self.out_channels, bias=True)
+
         self.gradient_checkpointing = False
 
     @property
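For context, `zero_module` (imported above from `.controlnet`) zeroes a layer's parameters so each control block and the `controlnet_x_embedder` start as no-ops and learn their contribution during training. A minimal self-contained sketch of that pattern, matching the helper's behavior:

    import torch
    import torch.nn as nn

    def zero_module(module: nn.Module) -> nn.Module:
        # Zero every parameter so the module's initial output is zero.
        for p in module.parameters():
            nn.init.zeros_(p)
        return module

    # A zero-initialized projection leaves the base model unchanged at step 0:
    proj = zero_module(nn.Linear(64, 64))
    x = torch.randn(2, 10, 64)
    assert proj(x).abs().sum() == 0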
@@ -265,10 +270,6 @@ def forward(
 
         hidden_states_seq_len = hidden_states.shape[1]
         hidden_states = self.img_in(hidden_states)
-
-        # add
-        hidden_states = hidden_states + self.controlnet_x_embedder(controlnet_cond)
-
         temb = self.time_text_embed(timestep, hidden_states)
 
         image_rotary_emb = self.pos_embed(img_shapes, txt_seq_lens, device=hidden_states.device)
@@ -277,7 +278,6 @@ def forward(
         encoder_hidden_states = self.txt_norm(encoder_hidden_states)
         encoder_hidden_states = self.txt_in(encoder_hidden_states)
 
-        block_samples = ()
         for index_block, block in enumerate(self.transformer_blocks):
             if torch.is_grad_enabled() and self.gradient_checkpointing:
                 encoder_hidden_states, hidden_states = self._gradient_checkpointing_func(
@@ -298,26 +298,33 @@ def forward(
                     image_rotary_emb=image_rotary_emb,
                     joint_attention_kwargs=joint_attention_kwargs,
                 )
 
-        # controlnet block
-        controlnet_block_samples = ()
-        for block_sample, controlnet_block in zip(block_samples, self.controlnet_blocks):
-            block_sample = controlnet_block(block_sample)
-            controlnet_block_samples = controlnet_block_samples + (block_sample,)
-
-        # scaling
-        controlnet_block_samples = [sample * conditioning_scale for sample in controlnet_block_samples]
-        controlnet_block_samples = None if len(controlnet_block_samples) == 0 else controlnet_block_samples
-
+            # controlnet block: run this layer's control block on the image-token slice
+            controlnet_block_samples = []
+            if controlnet_cond is not None:
+                hidden_states_slice = hidden_states[:, :hidden_states_seq_len].clone()
+                for conditioning in controlnet_cond:
+                    controlnet_block = self.controlnet_blocks[index_block]
+                    sample = controlnet_block(hidden_states_slice, conditioning)
+                    controlnet_block_samples.append(sample)
+                # scaling
+                controlnet_block_samples = [sample * conditioning_scale for sample in controlnet_block_samples]
+                if controlnet_block_samples:
+                    hidden_states[:, :hidden_states_seq_len] = hidden_states_slice + sum(controlnet_block_samples)
+
+        hidden_states = self.norm_out(hidden_states, temb)
+        hidden_states = self.proj_out(hidden_states)
+        hidden_states = hidden_states[:, :hidden_states_seq_len]
         if USE_PEFT_BACKEND:
             # remove `lora_scale` from each PEFT layer
             unscale_lora_layers(self, lora_scale)
 
         if not return_dict:
-            return controlnet_block_samples
+            return hidden_states
 
         return QwenImageBlockControlNetOutput(
-            controlnet_block_samples=controlnet_block_samples,
+            controlnet_block_samples=hidden_states,
         )
 
 class QwenImageBlockwiseMultiControlNetModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin, CacheMixin):
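The rewritten forward applies each layer's control block to the image-token slice and adds the scaled result back in place. A self-contained sketch of that injection pattern, with a hypothetical ControlBlock standing in for BlockWiseControlBlock (whose definition is not shown in this diff):

    import torch
    import torch.nn as nn

    class ControlBlock(nn.Module):
        # Hypothetical stand-in: fuses a conditioning tensor into hidden states.
        def __init__(self, dim: int):
            super().__init__()
            self.proj = nn.Linear(dim, dim)
            nn.init.zeros_(self.proj.weight)  # zero-init, as with zero_module above
            nn.init.zeros_(self.proj.bias)

        def forward(self, hidden_states: torch.Tensor, cond: torch.Tensor) -> torch.Tensor:
            return self.proj(hidden_states + cond)

    batch, img_len, txt_len, dim = 2, 16, 8, 64
    blocks = nn.ModuleList(ControlBlock(dim) for _ in range(3))  # one per transformer block
    hidden = torch.randn(batch, img_len + txt_len, dim)          # image tokens first
    conds = [torch.randn(batch, img_len, dim)]                   # one or more conditions
    scale = 1.0

    for block in blocks:
        # ... the corresponding transformer block would run here ...
        image_slice = hidden[:, :img_len].clone()
        samples = [block(image_slice, c) * scale for c in conds]
        hidden[:, :img_len] = image_slice + sum(samples)  # inject into image tokens only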
@@ -350,60 +357,36 @@ def forward(
         joint_attention_kwargs: Optional[Dict[str, Any]] = None,
         return_dict: bool = True,
     ) -> Union[QwenImageBlockControlNetOutput, Tuple]:
-        # ControlNet-Union with multiple conditions
-        # only load one ControlNet for saving memories
-        if len(self.nets) == 1:
-            controlnet = self.nets[0]
-
-            for i, (image, scale) in enumerate(zip(controlnet_cond, conditioning_scale)):
-                block_samples = controlnet(
-                    hidden_states=hidden_states,
-                    controlnet_cond=image,
-                    conditioning_scale=scale,
-                    encoder_hidden_states=encoder_hidden_states,
-                    encoder_hidden_states_mask=encoder_hidden_states_mask,
-                    timestep=timestep,
-                    img_shapes=img_shapes,
-                    txt_seq_lens=txt_seq_lens,
-                    joint_attention_kwargs=joint_attention_kwargs,
-                    return_dict=return_dict,
-                )
-
-                # merge samples
-                if i == 0:
-                    control_block_samples = block_samples
-                else:
-                    if block_samples is not None and control_block_samples is not None:
-                        control_block_samples = [
-                            control_block_sample + block_sample
-                            for control_block_sample, block_sample in zip(control_block_samples, block_samples)
-                        ]
-        # Regular Multi-ControlNets
-        # load all ControlNets into memories
-        else:
-            for i, (image, scale, controlnet) in enumerate(
-                zip(controlnet_cond, conditioning_scale, self.nets)
-            ):
-                block_samples = controlnet(
-                    hidden_states=hidden_states,
-                    controlnet_cond=image,
-                    conditioning_scale=scale,
-                    timestep=timestep,
-                    encoder_hidden_states=encoder_hidden_states,
-                    joint_attention_kwargs=joint_attention_kwargs,
-                    return_dict=return_dict,
-                )
-
-                # merge samples
-                if i == 0:
-                    control_block_samples = block_samples
-                else:
-                    if block_samples is not None and control_block_samples is not None:
-                        control_block_samples = [
-                            control_block_sample + block_sample
-                            for control_block_sample, block_sample in zip(control_block_samples, block_samples)
-                        ]
-
+        # If len(self.nets) == 1, act as a ControlNet-Union with multiple conditions:
+        # only one ControlNet is loaded, to save memory.
+        # Otherwise, act as regular Multi-ControlNets: all ControlNets are loaded into memory.
+        nets_to_use = [self.nets[0]] * len(controlnet_cond) if len(self.nets) == 1 else self.nets
+        controlnet_calls = list(zip(controlnet_cond, conditioning_scale, nets_to_use))
+
+        # Process and merge outputs
+        for image, scale, controlnet in controlnet_calls:
+            control_block_samples = controlnet(
+                hidden_states=hidden_states,
+                controlnet_cond=image,
+                conditioning_scale=scale,
+                timestep=timestep,
+                encoder_hidden_states=encoder_hidden_states,
+                encoder_hidden_states_mask=encoder_hidden_states_mask if len(self.nets) == 1 else None,
+                img_shapes=img_shapes if len(self.nets) == 1 else None,
+                txt_seq_lens=txt_seq_lens if len(self.nets) == 1 else None,
+                joint_attention_kwargs=joint_attention_kwargs,
+                return_dict=return_dict,
+            )
 
         return control_block_samples
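The refactor above folds the union branch (one net, many conditions) and the multi branch (one net per condition) into a single loop by padding the net list. A standalone sketch of that dispatch, using plain callables as hypothetical stand-ins for the ControlNets:

    from typing import Callable, List, Sequence

    def run_controlnets(
        nets: Sequence[Callable[[float], float]],
        conds: Sequence[float],
        scales: Sequence[float],
    ) -> List[float]:
        # Union mode reuses the single net for every condition;
        # multi mode pairs each condition with its own net.
        nets_to_use = [nets[0]] * len(conds) if len(nets) == 1 else list(nets)
        return [net(c) * s for c, s, net in zip(conds, scales, nets_to_use)]

    print(run_controlnets([lambda x: x + 1.0], [1.0, 2.0], [0.5, 1.0]))              # union: [1.0, 3.0]
    print(run_controlnets([lambda x: 2 * x, lambda x: -x], [1.0, 2.0], [1.0, 1.0]))  # multi: [2.0, -2.0]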
src/diffusers/pipelines/qwenimage/pipeline_qwenimage_blockcontrolnet.py

Lines changed: 11 additions & 3 deletions
@@ -721,8 +721,7 @@ def __call__(
                     height=control_image_.shape[3],
                     width=control_image_.shape[4],
                 ).to(dtype=prompt_embeds.dtype, device=device)
-
-                self.controlnet.img_in(control_image)
+                control_image_ = self.controlnet.controlnet_x_embedder(control_image_)
                 control_images.append(control_image_)
 
             control_image = control_images
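This hoists the control-image projection out of the denoising loop: each packed control latent is passed through `controlnet_x_embedder` exactly once, up front. A minimal sketch of the hoist, with hypothetical dimensions:

    import torch
    import torch.nn as nn

    # Hypothetical sizes: packed latent channels -> model inner_dim.
    embedder = nn.Linear(68, 128)
    packed_latents = [torch.randn(1, 256, 68) for _ in range(2)]

    # Project once before the loop ...
    control_images = [embedder(c) for c in packed_latents]

    # ... and reuse the embedded tensors at every denoising step.
    for step in range(4):
        assert all(c.shape == (1, 256, 128) for c in control_images)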
@@ -787,6 +786,7 @@ def __call__(
 
         # 6. Denoising loop
         self.scheduler.set_begin_index(0)
+        controlnet_keep = []
         with self.progress_bar(total=num_inference_steps) as progress_bar:
             for i, t in enumerate(timesteps):
                 if self.interrupt:
@@ -795,7 +795,15 @@ def __call__(
                 self._current_timestep = t
                 # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
                 timestep = t.expand(latents.shape[0]).to(latents.dtype)
-
+                # keep schedule: progress runs from 1.0 (first step) down to 0.0 (last step)
+                progress = (num_inference_steps - 1 - i) / max(num_inference_steps - 1, 1)
+                keeps = [
+                    1.0 - float(progress > s + 1e-4 or progress < e - 1e-4)
+                    for s, e in zip(control_guidance_start, control_guidance_end)
+                ]
+                controlnet_keep.append(
+                    keeps if isinstance(self.controlnet, QwenImageBlockwiseMultiControlNetModel) else keeps[0]
+                )
+
                 if isinstance(controlnet_keep[i], list):
                     cond_scale = [c * s for c, s in zip(controlnet_conditioning_scale, controlnet_keep[i])]
                 else:
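A standalone sketch of the keep schedule these added lines compute: `progress` runs from 1.0 down to 0.0 over the run, and a ControlNet's scale is kept only while `progress` lies inside its `(start, end)` window:

    from typing import List, Sequence

    def keep_flags(i: int, num_steps: int, starts: Sequence[float], ends: Sequence[float]) -> List[float]:
        # progress decreases from 1.0 (first step) to 0.0 (last step)
        progress = (num_steps - 1 - i) / max(num_steps - 1, 1)
        return [
            1.0 - float(progress > s + 1e-4 or progress < e - 1e-4)
            for s, e in zip(starts, ends)
        ]

    # With start=0.5, end=0.0 the ControlNet engages only in the second half of sampling:
    print([keep_flags(i, 10, [0.5], [0.0])[0] for i in range(10)])
    # -> [0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 1.0, 1.0, 1.0, 1.0]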
