Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
37 changes: 8 additions & 29 deletions python/sglang/multimodal_gen/configs/sample/sampling_params.py
Original file line number Diff line number Diff line change
Expand Up @@ -138,7 +138,7 @@ class SamplingParams:
return_trajectory_latents: bool = False # returns all latents for each timestep
return_trajectory_decoded: bool = False # returns decoded latents for each timestep
# if True, allow user params to override subclass-defined protected fields
override_protected_fields: bool = False
no_override_protected_fields: bool = True
# whether to adjust num_frames for multi-GPU friendly splitting (default: True)
adjust_frames: bool = True

Expand Down Expand Up @@ -290,15 +290,6 @@ def _adjust(
self._set_output_file_name()
self.log(server_args=server_args)

def update(self, source_dict: dict[str, Any]) -> None:
for key, value in source_dict.items():
if hasattr(self, key):
setattr(self, key, value)
else:
logger.exception("%s has no attribute %s", type(self).__name__, key)

self.__post_init__()

@classmethod
def from_pretrained(cls, model_path: str, **kwargs) -> "SamplingParams":
from sglang.multimodal_gen.registry import get_model_info
Expand Down Expand Up @@ -522,12 +513,11 @@ def add_cli_args(parser: Any) -> Any:
help="Whether to return the decoded trajectory",
)
parser.add_argument(
"--override-protected-fields",
"--no-override-protected-fields",
action="store_true",
default=SamplingParams.override_protected_fields,
default=SamplingParams.no_override_protected_fields,
help=(
"If set, allow user params to override fields defined in subclasses "
"(protected by default)."
"If set, disallow user params to override fields defined in subclasses."
),
)
parser.add_argument(
Expand Down Expand Up @@ -583,32 +573,21 @@ def _merge_with_user_params(self, user_params: "SamplingParams"):
subclass_defined_fields = set(type(self).__annotations__.keys())

# global switch: if True, allow overriding protected fields
allow_override_protected = bool(
user_params.override_protected_fields or self.override_protected_fields
)

# Compare against current instance to avoid constructing a default instance
default_params = SamplingParams()
allow_override_protected = not user_params.no_override_protected_fields

for field in dataclasses.fields(user_params):
field_name = field.name
user_value = getattr(user_params, field_name)
default_value = getattr(default_params, field_name)
default_value = getattr(self, field_name)

# A field is considered user-modified if its value is different from
# the default, with an exception for `output_file_name` which is
# auto-generated with a random component.
is_user_modified = (
user_value != default_value
if field_name != "output_file_name"
else user_params.output_file_path is not None
)
# the default
is_user_modified = user_value != default_value
if is_user_modified and (
allow_override_protected or field_name not in subclass_defined_fields
):
if hasattr(self, field_name):
setattr(self, field_name, user_value)

self.height_not_provided = user_params.height_not_provided
self.width_not_provided = user_params.width_not_provided
self.__post_init__()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,10 +10,7 @@

import sglang.multimodal_gen.envs as envs
from sglang.multimodal_gen import DiffGenerator
from sglang.multimodal_gen.configs.sample.sampling_params import (
SamplingParams,
generate_request_id,
)
from sglang.multimodal_gen.configs.sample.sampling_params import SamplingParams
from sglang.multimodal_gen.runtime.entrypoints.cli.cli_types import CLISubcommand
from sglang.multimodal_gen.runtime.entrypoints.cli.utils import (
RaiseNotImplementedAction,
Expand Down Expand Up @@ -89,8 +86,7 @@ def maybe_dump_performance(args: argparse.Namespace, server_args, prompt: str, r

def generate_cmd(args: argparse.Namespace):
"""The entry point for the generate command."""
# FIXME(mick): do not hard code
args.request_id = generate_request_id()
args.request_id = "mocked_fake_id_for_offline_generate"

# Auto-enable stage logging if dump path is provided
if args.perf_dump_path:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -297,6 +297,30 @@ def _postprocess_latents_for_ti2v(self, z, reserved_frames_masks, batch):

return reserved_frames_mask_sp, z_sp

def _handle_boundary_ratio(
self,
server_args,
batch,
):
"""
(Wan2.2) Calculate timestep to switch from high noise expert to low noise expert
"""
boundary_ratio = server_args.pipeline_config.dit_config.boundary_ratio
if batch.boundary_ratio is not None:
logger.info(
"Overriding boundary ratio from %s to %s",
boundary_ratio,
batch.boundary_ratio,
)
boundary_ratio = batch.boundary_ratio

if boundary_ratio is not None:
boundary_timestep = boundary_ratio * self.scheduler.num_train_timesteps
else:
boundary_timestep = None

return boundary_timestep

def _prepare_denoising_loop(self, batch: Req, server_args: ServerArgs):
"""
Prepare all necessary invariant variables for the denoising loop.
Expand Down Expand Up @@ -362,20 +386,7 @@ def _prepare_denoising_loop(self, batch: Req, server_args: ServerArgs):
assert neg_prompt_embeds is not None
# Removed Tensor truthiness assert to avoid GPU sync

# (Wan2.2) Calculate timestep to switch from high noise expert to low noise expert
boundary_ratio = server_args.pipeline_config.dit_config.boundary_ratio
if batch.boundary_ratio is not None:
logger.info(
"Overriding boundary ratio from %s to %s",
boundary_ratio,
batch.boundary_ratio,
)
boundary_ratio = batch.boundary_ratio

if boundary_ratio is not None:
boundary_timestep = boundary_ratio * self.scheduler.num_train_timesteps
else:
boundary_timestep = None
boundary_timestep = self._handle_boundary_ratio(server_args, batch)

# specifically for Wan2_2_TI2V_5B_Config, not applicable for FastWan2_2_TI2V_5B_Config
should_preprocess_for_wan_ti2v = (
Expand Down
25 changes: 13 additions & 12 deletions python/sglang/multimodal_gen/test/server/testcase_configs.py
Original file line number Diff line number Diff line change
Expand Up @@ -269,18 +269,19 @@ def from_req_perf_record(
),
),
# === Text and Image to Image (TI2I) ===
# TODO: Timeout with Torch2.9. Add back when it can pass CI
# DiffusionTestCase(
# id="qwen_image_edit_ti2i",
# model_path="Qwen/Qwen-Image-Edit",
# modality="image",
# prompt=None, # not used for editing
# output_size="1024x1536",
# warmup_text=0,
# warmup_edit=1,
# edit_prompt="Convert 2D style to 3D style",
# image_path="https://github.com/lm-sys/lm-sys.github.io/releases/download/test/TI2I_Qwen_Image_Edit_Input.jpg",
# ),
DiffusionTestCase(
"qwen_image_edit_ti2i",
DiffusionServerArgs(
model_path="Qwen/Qwen-Image-Edit",
modality="image",
warmup_text=0,
warmup_edit=1,
),
DiffusionSamplingParams(
prompt="Convert 2D style to 3D style",
image_path="https://github.com/lm-sys/lm-sys.github.io/releases/download/test/TI2I_Qwen_Image_Edit_Input.jpg",
),
),
]

ONE_GPU_CASES_B: list[DiffusionTestCase] = [
Expand Down
2 changes: 1 addition & 1 deletion python/sglang/multimodal_gen/test/slack_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@

def _get_status_message(run_id, current_case_id, thread_messages=None):
date_str = datetime.now().strftime("%d/%m")
base_header = f""""*🧵 for nightly test of {date_str}*
base_header = f"""🧵 for nightly test of {date_str}*
*Git Revision:* {get_git_commit_hash()}
*GitHub Run ID:* {run_id}
*Total Tasks:* {len(ALL_CASES)}
Expand Down
Loading