Skip to content

Commit a7dcb11

Browse files
committed
feat(api): expose sampler_mode and 4 DiT params on /release_task HTTP API
Five GenerationParams fields that were previously accessible only via Gradio's direct Python path are now accepted by the /release_task HTTP endpoint:

- sampler_mode: "euler" (default) or "heun" sampler selection
- velocity_norm_threshold: velocity prediction norm clamping (0 = off)
- velocity_ema_factor: velocity EMA smoothing (0 = off)
- latent_shift: additive shift on DiT latents before VAE decode
- latent_rescale: multiplicative rescale on DiT latents before VAE decode

All five already exist on the internal GenerationParams dataclass and are wired through the diffusion loop. This change adds them to:

- GenerateMusicRequest (Pydantic request model)
- PARAM_ALIASES (camelCase alias support)
- build_generate_music_request (request builder)
- build_generation_setup (GenerationParams wiring)

Default values match the existing GenerationParams defaults, so omitting them preserves current behavior. Non-target code paths (Gradio UI, training, CLI) are unchanged.
1 parent 82252c2 commit a7dcb11

4 files changed

Lines changed: 35 additions & 0 deletions

File tree

acestep/api/http/release_task_models.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -97,6 +97,26 @@ class GenerateMusicRequest(BaseModel):
9797
default=None,
9898
description="Custom timesteps (comma-separated, e.g., '0.97,0.76,0.615,0.5,0.395,0.28,0.18,0.085,0'). Overrides inference_steps and shift.",
9999
)
100+
sampler_mode: str = Field(
101+
default="euler",
102+
description="Diffusion sampler mode: 'euler' (first-order) or 'heun' (second-order predictor-corrector).",
103+
)
104+
velocity_norm_threshold: float = Field(
105+
default=0.0,
106+
description="Clamp velocity prediction norms during diffusion (0 = disabled, try 2.0).",
107+
)
108+
velocity_ema_factor: float = Field(
109+
default=0.0,
110+
description="Velocity EMA smoothing during diffusion (0 = disabled, try 0.1).",
111+
)
112+
latent_shift: float = Field(
113+
default=0.0,
114+
description="Additive shift applied to DiT latents before VAE decode (0 = no shift).",
115+
)
116+
latent_rescale: float = Field(
117+
default=1.0,
118+
description="Multiplicative rescale applied to DiT latents before VAE decode (1.0 = no rescale).",
119+
)
100120

101121
audio_format: str = Field(
102122
default="mp3",

acestep/api/http/release_task_param_parser.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,11 @@
4343
"instruction": ["instruction"],
4444
"track_name": ["track_name", "trackName"],
4545
"track_classes": ["track_classes", "trackClasses", "instruments"],
46+
"sampler_mode": ["sampler_mode", "samplerMode"],
47+
"velocity_norm_threshold": ["velocity_norm_threshold", "velocityNormThreshold"],
48+
"velocity_ema_factor": ["velocity_ema_factor", "velocityEmaFactor"],
49+
"latent_shift": ["latent_shift", "latentShift"],
50+
"latent_rescale": ["latent_rescale", "latentRescale"],
4651
}
4752

4853

acestep/api/http/release_task_request_builder.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -81,6 +81,11 @@ def build_generate_music_request(
8181
cfg_interval_end=parser.float("cfg_interval_end", 1.0),
8282
infer_method=parser.str("infer_method", "ode"),
8383
shift=parser.float("shift", 3.0),
84+
sampler_mode=parser.str("sampler_mode", "euler"),
85+
velocity_norm_threshold=parser.float("velocity_norm_threshold", 0.0),
86+
velocity_ema_factor=parser.float("velocity_ema_factor", 0.0),
87+
latent_shift=parser.float("latent_shift", 0.0),
88+
latent_rescale=parser.float("latent_rescale", 1.0),
8489
audio_format=parser.str("audio_format", "mp3"),
8590
use_tiled_decode=parser.bool("use_tiled_decode", True),
8691
lm_model_path=parser.str("lm_model_path") or None,

acestep/api/job_generation_setup.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -174,6 +174,11 @@ def build_generation_setup(
174174
cfg_interval_end=req.cfg_interval_end,
175175
shift=req.shift,
176176
infer_method=req.infer_method,
177+
sampler_mode=getattr(req, "sampler_mode", "euler"),
178+
velocity_norm_threshold=getattr(req, "velocity_norm_threshold", 0.0),
179+
velocity_ema_factor=getattr(req, "velocity_ema_factor", 0.0),
180+
latent_shift=getattr(req, "latent_shift", 0.0),
181+
latent_rescale=getattr(req, "latent_rescale", 1.0),
177182
timesteps=parsed_timesteps,
178183
repainting_start=req.repainting_start,
179184
repainting_end=req.repainting_end if req.repainting_end else -1,

0 commit comments

Comments
 (0)