@@ -297,6 +297,30 @@ def _postprocess_latents_for_ti2v(self, z, reserved_frames_masks, batch):
297297
298298 return reserved_frames_mask_sp , z_sp
299299
300+ def _handle_boundary_ratio (
301+ self ,
302+ server_args ,
303+ batch ,
304+ ):
305+ """
306+ (Wan2.2) Calculate timestep to switch from high noise expert to low noise expert
307+ """
308+ boundary_ratio = server_args .pipeline_config .dit_config .boundary_ratio
309+ if batch .boundary_ratio is not None :
310+ logger .info (
311+ "Overriding boundary ratio from %s to %s" ,
312+ boundary_ratio ,
313+ batch .boundary_ratio ,
314+ )
315+ boundary_ratio = batch .boundary_ratio
316+
317+ if boundary_ratio is not None :
318+ boundary_timestep = boundary_ratio * self .scheduler .num_train_timesteps
319+ else :
320+ boundary_timestep = None
321+
322+ return boundary_timestep
323+
300324 def _prepare_denoising_loop (self , batch : Req , server_args : ServerArgs ):
301325 """
302326 Prepare all necessary invariant variables for the denoising loop.
@@ -362,20 +386,7 @@ def _prepare_denoising_loop(self, batch: Req, server_args: ServerArgs):
362386 assert neg_prompt_embeds is not None
363387 # Removed Tensor truthiness assert to avoid GPU sync
364388
365- # (Wan2.2) Calculate timestep to switch from high noise expert to low noise expert
366- boundary_ratio = server_args .pipeline_config .dit_config .boundary_ratio
367- if batch .boundary_ratio is not None :
368- logger .info (
369- "Overriding boundary ratio from %s to %s" ,
370- boundary_ratio ,
371- batch .boundary_ratio ,
372- )
373- boundary_ratio = batch .boundary_ratio
374-
375- if boundary_ratio is not None :
376- boundary_timestep = boundary_ratio * self .scheduler .num_train_timesteps
377- else :
378- boundary_timestep = None
389+ boundary_timestep = self ._handle_boundary_ratio (server_args , batch )
379390
380391 # specifically for Wan2_2_TI2V_5B_Config, not applicable for FastWan2_2_TI2V_5B_Config
381392 should_preprocess_for_wan_ti2v = (
0 commit comments