Commit 06c4c29

Commit message: upd
1 parent 0e28edf · commit 06c4c29

File tree: 2 files changed (+59, -149 lines)

python/sglang/multimodal_gen/runtime/pipelines_core/stages/denoising.py

Lines changed: 56 additions & 148 deletions
@@ -184,6 +184,30 @@ def parallelism_type(self) -> StageParallelismType:
         # return StageParallelismType.CFG_PARALLEL if get_global_server_args().enable_cfg_parallel else StageParallelismType.REPLICATED
         return StageParallelismType.REPLICATED

+    def _handle_boundary_ratio(
+        self,
+        server_args,
+        batch,
+    ):
+        """
+        (Wan2.2) Calculate timestep to switch from high noise expert to low noise expert
+        """
+        boundary_ratio = server_args.pipeline_config.dit_config.boundary_ratio
+        if batch.boundary_ratio is not None:
+            logger.info(
+                "Overriding boundary ratio from %s to %s",
+                boundary_ratio,
+                batch.boundary_ratio,
+            )
+            boundary_ratio = batch.boundary_ratio
+
+        if boundary_ratio is not None:
+            boundary_timestep = boundary_ratio * self.scheduler.num_train_timesteps
+        else:
+            boundary_timestep = None
+
+        return boundary_timestep
+
     def _prepare_denoising_loop(self, batch: Req, server_args: ServerArgs):
         """
         Prepare all necessary invariant variables for the denoising loop.
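Note (illustrative sketch, not part of this commit): the new _handle_boundary_ratio helper maps a boundary ratio to the timestep at which Wan2.2 hands off from the high-noise expert to the low-noise expert, preferring a per-request override over the pipeline config value. A minimal standalone version of that logic, where the function name and the default of 1000 training timesteps are assumed example values (the real code reads self.scheduler.num_train_timesteps and batch.boundary_ratio):

# Illustrative sketch (not repository code).
def boundary_timestep_from_ratio(config_ratio, request_ratio=None, num_train_timesteps=1000):
    ratio = request_ratio if request_ratio is not None else config_ratio
    if ratio is None:
        return None  # no expert hand-off configured
    return ratio * num_train_timesteps

# With a per-request override of 0.875 and 1000 training timesteps, the hand-off
# point is timestep 875; timesteps above it go to the high-noise expert.
assert boundary_timestep_from_ratio(0.9, request_ratio=0.875) == 875.0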
@@ -250,19 +274,7 @@ def _prepare_denoising_loop(self, batch: Req, server_args: ServerArgs):
         # Removed Tensor truthiness assert to avoid GPU sync

         # (Wan2.2) Calculate timestep to switch from high noise expert to low noise expert
-        boundary_ratio = server_args.pipeline_config.dit_config.boundary_ratio
-        if batch.boundary_ratio is not None:
-            logger.info(
-                "Overriding boundary ratio from %s to %s",
-                boundary_ratio,
-                batch.boundary_ratio,
-            )
-            boundary_ratio = batch.boundary_ratio
-
-        if boundary_ratio is not None:
-            boundary_timestep = boundary_ratio * self.scheduler.num_train_timesteps
-        else:
-            boundary_timestep = None
+        boundary_timestep = self._handle_boundary_ratio(server_args, batch)

         # TI2V specific preparations - BEFORE SP sharding
         z, z_sp, reserved_frames_masks, reserved_frames_mask_sp, seq_len = (
@@ -363,143 +375,39 @@ def _prepare_denoising_loop(self, batch: Req, server_args: ServerArgs):
                 # Should not happen for TI2V
                 z_sp = z

-            # Shard reserved_frames_mask along time dimension to match sharded latents
-            # reserved_frames_mask is a list from masks_like, extract reserved_frames_mask[0] first
-            # reserved_frames_mask[0] shape: [C, T, H, W]
-            # All ranks need their portion of reserved_frames_mask for timestep calculation
-            if reserved_frames_masks is not None:
-                reserved_frames_mask = reserved_frames_masks[
-                    0
-                ]  # Extract tensor from list
-                time_dim = reserved_frames_mask.shape[1]  # [C, T, H, W]
-                if time_dim > 0 and time_dim % sp_world_size == 0:
-                    reserved_frames_mask_sp_tensor = rearrange(
-                        reserved_frames_mask,
-                        "c (n t) h w -> c n t h w",
-                        n=sp_world_size,
-                    ).contiguous()
-                    reserved_frames_mask_sp_tensor = reserved_frames_mask_sp_tensor[
-                        :, rank_in_sp_group, :, :, :
-                    ]
-                    reserved_frames_mask_sp = (
-                        reserved_frames_mask_sp_tensor  # Store as tensor, not list
-                    )
+                # Shard reserved_frames_mask along time dimension to match sharded latents
+                # reserved_frames_mask is a list from masks_like, extract reserved_frames_mask[0] first
+                # reserved_frames_mask[0] shape: [C, T, H, W]
+                # All ranks need their portion of reserved_frames_mask for timestep calculation
+                if reserved_frames_masks is not None:
+                    reserved_frames_mask = reserved_frames_masks[
+                        0
+                    ]  # Extract tensor from list
+                    time_dim = reserved_frames_mask.shape[1]  # [C, T, H, W]
+                    if time_dim > 0 and time_dim % sp_world_size == 0:
+                        reserved_frames_mask_sp_tensor = rearrange(
+                            reserved_frames_mask,
+                            "c (n t) h w -> c n t h w",
+                            n=sp_world_size,
+                        ).contiguous()
+                        reserved_frames_mask_sp_tensor = reserved_frames_mask_sp_tensor[
+                            :, rank_in_sp_group, :, :, :
+                        ]
+                        reserved_frames_mask_sp = (
+                            reserved_frames_mask_sp_tensor  # Store as tensor, not list
+                        )
+                    else:
+                        reserved_frames_mask_sp = reserved_frames_mask
                 else:
-                    reserved_frames_mask_sp = reserved_frames_mask
+                    reserved_frames_mask_sp = None
             else:
-                reserved_frames_mask_sp = None
-        else:
-            # SP not enabled or latents not sharded
-            z_sp = z
-            reserved_frames_mask_sp = (
-                reserved_frames_masks[0] if reserved_frames_masks is not None else None
-            )  # Extract tensor
-
-        return reserved_frames_mask_sp, z_sp
-
-    def _handle_boundary_ratio(
-        self,
-        server_args,
-        batch,
-    ):
-        """
-        (Wan2.2) Calculate timestep to switch from high noise expert to low noise expert
-        """
-        boundary_ratio = server_args.pipeline_config.dit_config.boundary_ratio
-        if batch.boundary_ratio is not None:
-            logger.info(
-                "Overriding boundary ratio from %s to %s",
-                boundary_ratio,
-                batch.boundary_ratio,
-            )
-            boundary_ratio = batch.boundary_ratio
-
-        if boundary_ratio is not None:
-            boundary_timestep = boundary_ratio * self.scheduler.num_train_timesteps
-        else:
-            boundary_timestep = None
-
-        return boundary_timestep
-
-    def _prepare_denoising_loop(self, batch: Req, server_args: ServerArgs):
-        """
-        Prepare all necessary invariant variables for the denoising loop.
-
-        Args:
-            batch: The current batch information.
-            server_args: The inference arguments.
-
-        Returns:
-            A dictionary containing all the prepared variables for the denoising loop.
-        """
-        pipeline = self.pipeline() if self.pipeline else None
-        if not server_args.model_loaded["transformer"]:
-            loader = TransformerLoader()
-            self.transformer = loader.load(
-                server_args.model_paths["transformer"], server_args
-            )
-            if self.server_args.enable_torch_compile:
-                self.transformer = torch.compile(
-                    self.transformer, mode="max-autotune", fullgraph=True
-                )
-            if pipeline:
-                pipeline.add_module("transformer", self.transformer)
-            server_args.model_loaded["transformer"] = True
-
-        # Prepare extra step kwargs for scheduler
-        extra_step_kwargs = self.prepare_extra_func_kwargs(
-            self.scheduler.step,
-            {"generator": batch.generator, "eta": batch.eta},
-        )
-
-        # Setup precision and autocast settings
-        target_dtype = torch.bfloat16
-        autocast_enabled = (
-            target_dtype != torch.float32
-        ) and not server_args.disable_autocast
-
-        # Get timesteps and calculate warmup steps
-        timesteps = batch.timesteps
-        if timesteps is None:
-            raise ValueError("Timesteps must be provided")
-        num_inference_steps = batch.num_inference_steps
-        num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
-
-        # Prepare image latents and embeddings for I2V generation
-        image_embeds = batch.image_embeds
-        if len(image_embeds) > 0:
-            image_embeds = [
-                image_embed.to(target_dtype) for image_embed in image_embeds
-            ]
-
-        # Prepare STA parameters
-        if st_attn_available and self.attn_backend == SlidingTileAttentionBackend:
-            self.prepare_sta_param(batch, server_args)
-
-        # Get latents and embeddings
-        latents = batch.latents
-        prompt_embeds = batch.prompt_embeds
-        # Removed Tensor truthiness assert to avoid GPU sync
-        neg_prompt_embeds = None
-        if batch.do_classifier_free_guidance:
-            neg_prompt_embeds = batch.negative_prompt_embeds
-            assert neg_prompt_embeds is not None
-            # Removed Tensor truthiness assert to avoid GPU sync
-
-        boundary_timestep = self._handle_boundary_ratio(server_args, batch)
-
-        # specifically for Wan2_2_TI2V_5B_Config, not applicable for FastWan2_2_TI2V_5B_Config
-        should_preprocess_for_wan_ti2v = (
-            server_args.pipeline_config.task_type == ModelTaskType.TI2V
-            and batch.condition_image is not None
-            and type(server_args.pipeline_config) is Wan2_2_TI2V_5B_Config
-        )
-
-        # TI2V specific preparations - before SP sharding
-        if should_preprocess_for_wan_ti2v:
-            seq_len, z, reserved_frames_masks = self._preprocess_latents_for_ti2v(
-                latents, target_dtype, batch, server_args
-            )
+                # SP not enabled or latents not sharded
+                z_sp = z
+                reserved_frames_mask_sp = (
+                    reserved_frames_masks[0]
+                    if reserved_frames_masks is not None
+                    else None
+                )  # Extract tensor
     else:
         # TI2V not enabled or SP not enabled
         z_sp = z
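Note (illustrative sketch, not part of this commit): the block above shards reserved_frames_mask along the time dimension so each sequence-parallel rank keeps only its own frames, using einops.rearrange followed by a rank-indexed slice. A small standalone version of the same pattern, where sp_world_size, rank_in_sp_group, and the [C, T, H, W] shape are assumed example values:

# Illustrative sketch (not repository code).
import torch
from einops import rearrange

sp_world_size = 4
rank_in_sp_group = 1
mask = torch.zeros(16, 8, 30, 52)  # [C, T, H, W]; T must be divisible by sp_world_size

# Split T into (sp_world_size, T // sp_world_size) groups and keep this rank's slice.
sharded = rearrange(mask, "c (n t) h w -> c n t h w", n=sp_world_size).contiguous()
local_mask = sharded[:, rank_in_sp_group, :, :, :]
print(local_mask.shape)  # torch.Size([16, 2, 30, 52]): each rank holds T / sp_world_size frames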

python/sglang/multimodal_gen/test/slack_utils.py

Lines changed: 3 additions & 1 deletion
@@ -9,6 +9,8 @@
 from urllib.parse import urlparse
 from urllib.request import urlopen

+from sglang.multimodal_gen.runtime.utils.perf_logger import get_git_commit_hash
+
 logging.basicConfig(level=logging.INFO)
 logger = logging.getLogger(__name__)

@@ -47,7 +49,7 @@

 def _get_status_message(run_id, current_case_id, thread_messages=None):
     date_str = datetime.now().strftime("%d/%m")
-    base_header = f""""🧵 for nightly test of {date_str}*
+    base_header = f"""🧵 for nightly test of {date_str}*
 *Git Revision:* {get_git_commit_hash()}
 *GitHub Run ID:* {run_id}
 *Total Tasks:* {len(ALL_CASES)}
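Note (illustrative sketch, not part of this commit): the one-character change above drops a stray fourth quote, which had opened the triple-quoted f-string and then left a literal double quote at the start of the Slack header. A simplified single-line illustration, where date_str is an assumed example value (the real code uses datetime.now() and a multi-line string):

# Illustrative sketch (not repository code).
date_str = "01/01"  # assumed example value
old_header = f""""🧵 for nightly test of {date_str}*"""  # before the fix: leading literal '"'
new_header = f"""🧵 for nightly test of {date_str}*"""  # after the fix
assert old_header == '"' + new_header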
