@@ -184,6 +184,30 @@ def parallelism_type(self) -> StageParallelismType:
         # return StageParallelismType.CFG_PARALLEL if get_global_server_args().enable_cfg_parallel else StageParallelismType.REPLICATED
         return StageParallelismType.REPLICATED
 
+    def _handle_boundary_ratio(
+        self,
+        server_args,
+        batch,
+    ):
+        """
+        (Wan2.2) Calculate timestep to switch from high noise expert to low noise expert
+        """
+        boundary_ratio = server_args.pipeline_config.dit_config.boundary_ratio
+        if batch.boundary_ratio is not None:
+            logger.info(
+                "Overriding boundary ratio from %s to %s",
+                boundary_ratio,
+                batch.boundary_ratio,
+            )
+            boundary_ratio = batch.boundary_ratio
+
+        if boundary_ratio is not None:
+            boundary_timestep = boundary_ratio * self.scheduler.num_train_timesteps
+        else:
+            boundary_timestep = None
+
+        return boundary_timestep
+
     def _prepare_denoising_loop(self, batch: Req, server_args: ServerArgs):
         """
         Prepare all necessary invariant variables for the denoising loop.
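
Roughly, the new helper maps a ratio into the scheduler's training horizon, and the resulting timestep picks the active expert at each denoising step. A minimal sketch of that arithmetic, assuming the usual high-t-to-low-t loop; the numbers and pick_expert below are illustrative, not part of this commit:

def pick_expert(t, boundary_timestep, high_noise_expert, low_noise_expert):
    # Early, high-noise (large-t) steps use the high noise expert;
    # later refinement steps fall through to the low noise expert.
    if boundary_timestep is not None and t >= boundary_timestep:
        return high_noise_expert
    return low_noise_expert

num_train_timesteps = 1000  # assumed scheduler training horizon
boundary_ratio = 0.9        # illustrative config/request value
boundary_timestep = boundary_ratio * num_train_timesteps  # 900.0

for t in (999, 900, 899, 0):
    print(t, pick_expert(t, boundary_timestep, "high", "low"))
# 999 high, 900 high, 899 low, 0 low
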
@@ -250,19 +274,7 @@ def _prepare_denoising_loop(self, batch: Req, server_args: ServerArgs):
             # Removed Tensor truthiness assert to avoid GPU sync
 
         # (Wan2.2) Calculate timestep to switch from high noise expert to low noise expert
-        boundary_ratio = server_args.pipeline_config.dit_config.boundary_ratio
-        if batch.boundary_ratio is not None:
-            logger.info(
-                "Overriding boundary ratio from %s to %s",
-                boundary_ratio,
-                batch.boundary_ratio,
-            )
-            boundary_ratio = batch.boundary_ratio
-
-        if boundary_ratio is not None:
-            boundary_timestep = boundary_ratio * self.scheduler.num_train_timesteps
-        else:
-            boundary_timestep = None
+        boundary_timestep = self._handle_boundary_ratio(server_args, batch)
 
         # TI2V specific preparations - BEFORE SP sharding
         z, z_sp, reserved_frames_masks, reserved_frames_mask_sp, seq_len = (
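
The next hunk re-indents the reserved_frames_mask sharding as it moves inline. A self-contained sketch of the time-dimension split that rearrange performs there, with made-up shapes and SP group values (einops and torch are the real dependencies):

import torch
from einops import rearrange

sp_world_size, rank_in_sp_group = 4, 1       # hypothetical SP group of 4, rank 1
mask = torch.arange(8).reshape(1, 8, 1, 1)   # toy mask, [C=1, T=8, H=1, W=1]

# T must divide evenly across ranks, mirroring the time_dim % sp_world_size check
assert mask.shape[1] % sp_world_size == 0
shards = rearrange(mask, "c (n t) h w -> c n t h w", n=sp_world_size)
local = shards[:, rank_in_sp_group].contiguous()  # this rank's T/4 frames
print(local.flatten().tolist())               # [2, 3]: rank 1 holds frames 2-3
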
@@ -363,143 +375,39 @@ def _prepare_denoising_loop(self, batch: Req, server_args: ServerArgs):
                 # Should not happen for TI2V
                 z_sp = z
 
-            # Shard reserved_frames_mask along time dimension to match sharded latents
-            # reserved_frames_mask is a list from masks_like, extract reserved_frames_mask[0] first
-            # reserved_frames_mask[0] shape: [C, T, H, W]
-            # All ranks need their portion of reserved_frames_mask for timestep calculation
-            if reserved_frames_masks is not None:
-                reserved_frames_mask = reserved_frames_masks[
-                    0
-                ]  # Extract tensor from list
-                time_dim = reserved_frames_mask.shape[1]  # [C, T, H, W]
-                if time_dim > 0 and time_dim % sp_world_size == 0:
-                    reserved_frames_mask_sp_tensor = rearrange(
-                        reserved_frames_mask,
-                        "c (n t) h w -> c n t h w",
-                        n=sp_world_size,
-                    ).contiguous()
-                    reserved_frames_mask_sp_tensor = reserved_frames_mask_sp_tensor[
-                        :, rank_in_sp_group, :, :, :
-                    ]
-                    reserved_frames_mask_sp = (
-                        reserved_frames_mask_sp_tensor  # Store as tensor, not list
-                    )
+                # Shard reserved_frames_mask along time dimension to match sharded latents
+                # reserved_frames_mask is a list from masks_like, extract reserved_frames_mask[0] first
+                # reserved_frames_mask[0] shape: [C, T, H, W]
+                # All ranks need their portion of reserved_frames_mask for timestep calculation
+                if reserved_frames_masks is not None:
+                    reserved_frames_mask = reserved_frames_masks[
+                        0
+                    ]  # Extract tensor from list
+                    time_dim = reserved_frames_mask.shape[1]  # [C, T, H, W]
+                    if time_dim > 0 and time_dim % sp_world_size == 0:
+                        reserved_frames_mask_sp_tensor = rearrange(
+                            reserved_frames_mask,
+                            "c (n t) h w -> c n t h w",
+                            n=sp_world_size,
+                        ).contiguous()
+                        reserved_frames_mask_sp_tensor = reserved_frames_mask_sp_tensor[
+                            :, rank_in_sp_group, :, :, :
+                        ]
+                        reserved_frames_mask_sp = (
+                            reserved_frames_mask_sp_tensor  # Store as tensor, not list
+                        )
+                    else:
+                        reserved_frames_mask_sp = reserved_frames_mask
                 else:
-                    reserved_frames_mask_sp = reserved_frames_mask
+                    reserved_frames_mask_sp = None
             else:
-                reserved_frames_mask_sp = None
-        else:
-            # SP not enabled or latents not sharded
-            z_sp = z
-            reserved_frames_mask_sp = (
-                reserved_frames_masks[0] if reserved_frames_masks is not None else None
-            )  # Extract tensor
-
-        return reserved_frames_mask_sp, z_sp
-
-    def _handle_boundary_ratio(
-        self,
-        server_args,
-        batch,
-    ):
-        """
-        (Wan2.2) Calculate timestep to switch from high noise expert to low noise expert
-        """
-        boundary_ratio = server_args.pipeline_config.dit_config.boundary_ratio
-        if batch.boundary_ratio is not None:
-            logger.info(
-                "Overriding boundary ratio from %s to %s",
-                boundary_ratio,
-                batch.boundary_ratio,
-            )
-            boundary_ratio = batch.boundary_ratio
-
-        if boundary_ratio is not None:
-            boundary_timestep = boundary_ratio * self.scheduler.num_train_timesteps
-        else:
-            boundary_timestep = None
-
-        return boundary_timestep
-
-    def _prepare_denoising_loop(self, batch: Req, server_args: ServerArgs):
-        """
-        Prepare all necessary invariant variables for the denoising loop.
-
-        Args:
-            batch: The current batch information.
-            server_args: The inference arguments.
-
-        Returns:
-            A dictionary containing all the prepared variables for the denoising loop.
-        """
-        pipeline = self.pipeline() if self.pipeline else None
-        if not server_args.model_loaded["transformer"]:
-            loader = TransformerLoader()
-            self.transformer = loader.load(
-                server_args.model_paths["transformer"], server_args
-            )
-            if self.server_args.enable_torch_compile:
-                self.transformer = torch.compile(
-                    self.transformer, mode="max-autotune", fullgraph=True
-                )
-            if pipeline:
-                pipeline.add_module("transformer", self.transformer)
-            server_args.model_loaded["transformer"] = True
-
-        # Prepare extra step kwargs for scheduler
-        extra_step_kwargs = self.prepare_extra_func_kwargs(
-            self.scheduler.step,
-            {"generator": batch.generator, "eta": batch.eta},
-        )
-
-        # Setup precision and autocast settings
-        target_dtype = torch.bfloat16
-        autocast_enabled = (
-            target_dtype != torch.float32
-        ) and not server_args.disable_autocast
-
-        # Get timesteps and calculate warmup steps
-        timesteps = batch.timesteps
-        if timesteps is None:
-            raise ValueError("Timesteps must be provided")
-        num_inference_steps = batch.num_inference_steps
-        num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
-
-        # Prepare image latents and embeddings for I2V generation
-        image_embeds = batch.image_embeds
-        if len(image_embeds) > 0:
-            image_embeds = [
-                image_embed.to(target_dtype) for image_embed in image_embeds
-            ]
-
-        # Prepare STA parameters
-        if st_attn_available and self.attn_backend == SlidingTileAttentionBackend:
-            self.prepare_sta_param(batch, server_args)
-
-        # Get latents and embeddings
-        latents = batch.latents
-        prompt_embeds = batch.prompt_embeds
-        # Removed Tensor truthiness assert to avoid GPU sync
-        neg_prompt_embeds = None
-        if batch.do_classifier_free_guidance:
-            neg_prompt_embeds = batch.negative_prompt_embeds
-            assert neg_prompt_embeds is not None
-            # Removed Tensor truthiness assert to avoid GPU sync
-
-        boundary_timestep = self._handle_boundary_ratio(server_args, batch)
-
-        # specifically for Wan2_2_TI2V_5B_Config, not applicable for FastWan2_2_TI2V_5B_Config
-        should_preprocess_for_wan_ti2v = (
-            server_args.pipeline_config.task_type == ModelTaskType.TI2V
-            and batch.condition_image is not None
-            and type(server_args.pipeline_config) is Wan2_2_TI2V_5B_Config
-        )
-
-        # TI2V specific preparations - before SP sharding
-        if should_preprocess_for_wan_ti2v:
-            seq_len, z, reserved_frames_masks = self._preprocess_latents_for_ti2v(
-                latents, target_dtype, batch, server_args
-            )
+            else:
+                # SP not enabled or latents not sharded
+                z_sp = z
+                reserved_frames_mask_sp = (
+                    reserved_frames_masks[0]
+                    if reserved_frames_masks is not None
+                    else None
+                )  # Extract tensor
         else:
             # TI2V not enabled or SP not enabled
             z_sp = z
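
One note on the prepare_extra_func_kwargs call in the removed duplicate above (the surviving copy of _prepare_denoising_loop keeps it): the pattern, as commonly seen in diffusers-style pipelines, forwards eta/generator to scheduler.step only when the signature accepts them. The helper below is a guess at the idea, not the repo's implementation:

import inspect

def prepare_extra_func_kwargs(func, kwargs):
    # Keep only kwargs that are not None and that func actually accepts.
    accepted = set(inspect.signature(func).parameters)
    return {k: v for k, v in kwargs.items() if v is not None and k in accepted}

def step(model_output, timestep, sample, generator=None):  # toy scheduler.step
    return sample

print(prepare_extra_func_kwargs(step, {"generator": None, "eta": 0.0}))
# {} -- eta is not in the signature and generator is None, so neither is passed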