25
25
from ..cache_utils import CacheMixin
26
26
from .controlnet import zero_module
27
27
from ..modeling_outputs import Transformer2DModelOutput
28
+ from ..normalization import AdaLayerNormContinuous , RMSNorm
28
29
from ..modeling_utils import ModelMixin
29
30
from ..transformers .qwenimage_dit import (
30
31
QwenEmbedRope ,
@@ -109,12 +110,16 @@ def __init__(
109
110
110
111
# controlnet_blocks
111
112
self .controlnet_blocks = nn .ModuleList ([])
112
- for _ in range (len ( self . transformer_blocks ) ):
113
+ for _ in range (num_layers ):
113
114
self .controlnet_blocks .append (zero_module (BlockWiseControlBlock (self .inner_dim )))
114
115
self .controlnet_x_embedder = zero_module (
115
116
torch .nn .Linear (in_channels + extra_condition_channels , self .inner_dim )
116
117
)
117
118
119
+ self .norm_out = AdaLayerNormContinuous (self .inner_dim , self .inner_dim , elementwise_affine = False , eps = 1e-6 )
120
+ self .proj_out = nn .Linear (self .inner_dim , patch_size * patch_size * self .out_channels , bias = True )
121
+
122
+
118
123
self .gradient_checkpointing = False
119
124
120
125
@property
@@ -265,10 +270,6 @@ def forward(
265
270
266
271
hidden_states_seq_len = hidden_states .shape [1 ]
267
272
hidden_states = self .img_in (hidden_states )
268
-
269
- # add
270
- hidden_states = hidden_states + self .controlnet_x_embedder (controlnet_cond )
271
-
272
273
temb = self .time_text_embed (timestep , hidden_states )
273
274
274
275
image_rotary_emb = self .pos_embed (img_shapes , txt_seq_lens , device = hidden_states .device )
@@ -277,7 +278,6 @@ def forward(
277
278
encoder_hidden_states = self .txt_norm (encoder_hidden_states )
278
279
encoder_hidden_states = self .txt_in (encoder_hidden_states )
279
280
280
- block_samples = ()
281
281
for index_block , block in enumerate (self .transformer_blocks ):
282
282
if torch .is_grad_enabled () and self .gradient_checkpointing :
283
283
encoder_hidden_states , hidden_states = self ._gradient_checkpointing_func (
@@ -298,26 +298,33 @@ def forward(
298
298
image_rotary_emb = image_rotary_emb ,
299
299
joint_attention_kwargs = joint_attention_kwargs ,
300
300
)
301
-
301
+
302
302
# controlnet block
303
303
controlnet_block_samples = ()
304
- for block_sample , controlnet_block in zip (block_samples , self .controlnet_blocks ):
305
- block_sample = controlnet_block (block_sample )
306
- controlnet_block_samples = controlnet_block_samples + (block_sample ,)
307
-
308
- # scaling
309
- controlnet_block_samples = [sample * conditioning_scale for sample in controlnet_block_samples ]
310
- controlnet_block_samples = None if len (controlnet_block_samples ) == 0 else controlnet_block_samples
311
-
304
+ # run each conditioning input through the controlnet block at index_block
305
+ if controlnet_cond is not None :
306
+ hidden_states_slice = hidden_states [:,:hidden_states_seq_len ].clone ()
307
+ for conditioning in controlnet_cond :
308
+ controlnet_block = self .controlnet_blocks [index_block ]
309
+ sample = controlnet_block (hidden_states_slice , conditioning )
310
+ controlnet_block_samples .append (sample )
311
+ # scaling
312
+ controlnet_block_samples = [sample * conditioning_scale for sample in controlnet_block_samples ]
313
+ controlnet_block_samples = None if len (controlnet_block_samples ) == 0 else controlnet_block_samples
314
+
315
+ hidden_states [:, :hidden_states_seq_len ] = hidden_states_slice + controlnet_block_samples
316
+ hidden_states = self .norm_out (hidden_states , controlnet_cond )
317
+ hidden_states = self .proj_out (hidden_states )
318
+ hidden_states = hidden_states [:, :hidden_states_seq_len ]
312
319
if USE_PEFT_BACKEND :
313
320
# remove `lora_scale` from each PEFT layer
314
321
unscale_lora_layers (self , lora_scale )
315
322
316
323
if not return_dict :
317
- return controlnet_block_samples
324
+ return hidden_states
318
325
319
326
return QwenImageBlockControlNetOutput (
320
- controlnet_block_samples = controlnet_block_samples ,
327
+ controlnet_block_samples = hidden_states ,
321
328
)
322
329
323
330
class QwenImageBlockwiseMultiControlNetModel (ModelMixin , ConfigMixin , PeftAdapterMixin , FromOriginalModelMixin , CacheMixin ):
@@ -350,60 +357,36 @@ def forward(
350
357
joint_attention_kwargs : Optional [Dict [str , Any ]] = None ,
351
358
return_dict : bool = True ,
352
359
) -> Union [QwenImageBlockControlNetOutput , Tuple ]:
353
- # ControlNet-Union with multiple conditions
360
+
361
+
362
+ # If len(self.nets) == 1: ControlNet-Union with multiple conditions
354
363
# only load one ControlNet for saving memories
355
- if len (self .nets ) == 1 :
356
- controlnet = self .nets [0 ]
357
-
358
- for i , (image , scale ) in enumerate (zip (controlnet_cond , conditioning_scale )):
359
- block_samples = controlnet (
360
- hidden_states = hidden_states ,
361
- controlnet_cond = image ,
362
- conditioning_scale = scale ,
363
- encoder_hidden_states = encoder_hidden_states ,
364
- encoder_hidden_states_mask = encoder_hidden_states_mask ,
365
- timestep = timestep ,
366
- img_shapes = img_shapes ,
367
- txt_seq_lens = txt_seq_lens ,
368
- joint_attention_kwargs = joint_attention_kwargs ,
369
- return_dict = return_dict ,
370
- )
371
-
372
- # merge samples
373
- if i == 0 :
374
- control_block_samples = block_samples
375
- else :
376
- if block_samples is not None and control_block_samples is not None :
377
- control_block_samples = [
378
- control_block_sample + block_sample
379
- for control_block_sample , block_sample in zip (control_block_samples , block_samples )
380
- ]
381
- # Regular Multi-ControlNets
364
+
365
+ # Otherwise: regular Multi-ControlNets
382
366
# load all ControlNets into memories
383
- else :
384
- for i , (image , scale , controlnet ) in enumerate (
385
- zip (controlnet_cond , conditioning_scale , self .nets )
386
- ):
387
- block_samples = controlnet (
388
- hidden_states = hidden_states ,
389
- controlnet_cond = image ,
390
- conditioning_scale = scale ,
391
- timestep = timestep ,
392
- encoder_hidden_states = encoder_hidden_states ,
393
- joint_attention_kwargs = joint_attention_kwargs ,
394
- return_dict = return_dict ,
395
- )
367
+
368
+
369
+ nets_to_use = [self .nets [0 ]] * len (controlnet_cond ) if len (self .nets ) == 1 else self .nets
370
+ controlnet_calls = list (zip (controlnet_cond , conditioning_scale , nets_to_use ))
371
+
372
+
373
+ # Process and merge outputs
374
+ for image , scale , controlnet in controlnet_calls :
375
+ control_block_samples = controlnet (
376
+ hidden_states = hidden_states ,
377
+ controlnet_cond = image ,
378
+ conditioning_scale = scale ,
379
+ timestep = timestep ,
380
+ encoder_hidden_states = encoder_hidden_states ,
381
+ encoder_hidden_states_mask = encoder_hidden_states_mask if len (self .nets ) == 1 else None ,
382
+ img_shapes = img_shapes if len (self .nets ) == 1 else None ,
383
+ txt_seq_lens = txt_seq_lens if len (self .nets ) == 1 else None ,
384
+ joint_attention_kwargs = joint_attention_kwargs ,
385
+ return_dict = return_dict ,
386
+ )
396
387
397
- # merge samples
398
- if i == 0 :
399
- control_block_samples = block_samples
400
- else :
401
- if block_samples is not None and control_block_samples is not None :
402
- control_block_samples = [
403
- control_block_sample + block_sample
404
- for control_block_sample , block_sample in zip (control_block_samples , block_samples )
405
- ]
406
-
388
+
407
389
408
390
return control_block_samples
409
391
392
+
0 commit comments