Skip to content
Open
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 20 additions & 0 deletions acestep/api/http/release_task_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,6 +97,26 @@ class GenerateMusicRequest(BaseModel):
default=None,
description="Custom timesteps (comma-separated, e.g., '0.97,0.76,0.615,0.5,0.395,0.28,0.18,0.085,0'). Overrides inference_steps and shift.",
)
sampler_mode: str = Field(
default="euler",
description="Diffusion sampler mode: 'euler' (first-order) or 'heun' (second-order predictor-corrector).",
)
velocity_norm_threshold: float = Field(
default=0.0,
description="Clamp velocity prediction norms during diffusion (0 = disabled, try 2.0).",
)
velocity_ema_factor: float = Field(
default=0.0,
description="Velocity EMA smoothing during diffusion (0 = disabled, try 0.1).",
)
latent_shift: float = Field(
default=0.0,
description="Additive shift applied to DiT latents before VAE decode (0 = no shift).",
)
latent_rescale: float = Field(
default=1.0,
description="Multiplicative rescale applied to DiT latents before VAE decode (1.0 = no rescale).",
)

audio_format: str = Field(
default="mp3",
Expand Down
5 changes: 5 additions & 0 deletions acestep/api/http/release_task_param_parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,11 @@
"instruction": ["instruction"],
"track_name": ["track_name", "trackName"],
"track_classes": ["track_classes", "trackClasses", "instruments"],
"sampler_mode": ["sampler_mode", "samplerMode"],
"velocity_norm_threshold": ["velocity_norm_threshold", "velocityNormThreshold"],
"velocity_ema_factor": ["velocity_ema_factor", "velocityEmaFactor"],
"latent_shift": ["latent_shift", "latentShift"],
"latent_rescale": ["latent_rescale", "latentRescale"],
}


Expand Down
11 changes: 11 additions & 0 deletions acestep/api/http/release_task_param_parser_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,5 +65,16 @@ def test_non_dict_param_obj_json_is_ignored(self):
self.assertEqual("meta-caption", parser.str("prompt"))


def test_sampler_and_dit_param_aliases_are_resolved(self):
"""Parser should resolve camelCase aliases for sampler/DiT params."""

parser = RequestParser(
{"samplerMode": "heun", "velocityNormThreshold": "2.5", "latentShift": "0.1"}
)
self.assertEqual("heun", parser.str("sampler_mode"))
self.assertAlmostEqual(2.5, parser.float("velocity_norm_threshold"))
self.assertAlmostEqual(0.1, parser.float("latent_shift"))

Comment thread
coderabbitai[bot] marked this conversation as resolved.

if __name__ == "__main__":
unittest.main()
5 changes: 5 additions & 0 deletions acestep/api/http/release_task_request_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,11 @@ def build_generate_music_request(
cfg_interval_end=parser.float("cfg_interval_end", 1.0),
infer_method=parser.str("infer_method", "ode"),
shift=parser.float("shift", 3.0),
sampler_mode=parser.str("sampler_mode", "euler"),
velocity_norm_threshold=parser.float("velocity_norm_threshold", 0.0),
velocity_ema_factor=parser.float("velocity_ema_factor", 0.0),
latent_shift=parser.float("latent_shift", 0.0),
latent_rescale=parser.float("latent_rescale", 1.0),
audio_format=parser.str("audio_format", "mp3"),
use_tiled_decode=parser.bool("use_tiled_decode", True),
lm_model_path=parser.str("lm_model_path") or None,
Expand Down
28 changes: 28 additions & 0 deletions acestep/api/http/release_task_request_builder_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -135,5 +135,33 @@ def test_build_request_forwards_audio_code_string_and_cover_noise_strength(self)
self.assertAlmostEqual(0.6, request.cover_noise_strength)


def test_build_request_forwards_sampler_and_dit_latent_params(self):
"""Builder should include sampler_mode and latent post-processing params."""

parser = _FakeParser(
{
"sampler_mode": "heun",
"velocity_norm_threshold": 2.0,
"velocity_ema_factor": 0.1,
"latent_shift": 0.05,
"latent_rescale": 1.2,
}
)
request = build_generate_music_request(
parser=parser,
request_model_cls=lambda **kwargs: SimpleNamespace(**kwargs),
default_dit_instruction="default-instruction",
lm_default_temperature=0.85,
lm_default_cfg_scale=2.5,
lm_default_top_p=0.9,
)

self.assertEqual("heun", request.sampler_mode)
self.assertAlmostEqual(2.0, request.velocity_norm_threshold)
self.assertAlmostEqual(0.1, request.velocity_ema_factor)
self.assertAlmostEqual(0.05, request.latent_shift)
self.assertAlmostEqual(1.2, request.latent_rescale)


if __name__ == "__main__":
unittest.main()
5 changes: 5 additions & 0 deletions acestep/api/job_generation_setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -174,6 +174,11 @@ def build_generation_setup(
cfg_interval_end=req.cfg_interval_end,
shift=req.shift,
infer_method=req.infer_method,
sampler_mode=getattr(req, "sampler_mode", "euler"),
velocity_norm_threshold=getattr(req, "velocity_norm_threshold", 0.0),
velocity_ema_factor=getattr(req, "velocity_ema_factor", 0.0),
latent_shift=getattr(req, "latent_shift", 0.0),
latent_rescale=getattr(req, "latent_rescale", 1.0),
timesteps=parsed_timesteps,
repainting_start=req.repainting_start,
repainting_end=req.repainting_end if req.repainting_end else -1,
Expand Down
68 changes: 68 additions & 0 deletions acestep/api/job_generation_setup_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -263,5 +263,73 @@ def test_use_cot_metas_enabled_when_format_has_duration(self) -> None:
self.assertTrue(setup.params.use_cot_metas)


def test_build_generation_setup_forwards_sampler_and_dit_latent_params(self) -> None:
"""Sampler mode and latent post-processing params should be forwarded."""

req = _base_req()
req.sampler_mode = "heun"
req.velocity_norm_threshold = 2.0
req.velocity_ema_factor = 0.1
req.latent_shift = 0.05
req.latent_rescale = 1.2
setup = build_generation_setup(
req=req,
caption="cap",
lyrics="lyr",
bpm=None,
key_scale="",
time_signature="",
audio_duration=None,
thinking=False,
sample_mode=False,
format_has_duration=False,
use_cot_caption=False,
use_cot_language=False,
lm_top_k=0,
lm_top_p=0.9,
parse_timesteps=lambda _value: None,
is_instrumental=lambda _lyrics: False,
default_dit_instruction="default instruction",
task_instructions={},
)

self.assertEqual("heun", setup.params.sampler_mode)
self.assertAlmostEqual(2.0, setup.params.velocity_norm_threshold)
self.assertAlmostEqual(0.1, setup.params.velocity_ema_factor)
self.assertAlmostEqual(0.05, setup.params.latent_shift)
self.assertAlmostEqual(1.2, setup.params.latent_rescale)

def test_build_generation_setup_defaults_sampler_params_when_missing(self) -> None:
"""When req lacks sampler/latent fields, getattr defaults should apply."""

req = _base_req()
setup = build_generation_setup(
req=req,
caption="cap",
lyrics="lyr",
bpm=None,
key_scale="",
time_signature="",
audio_duration=None,
thinking=False,
sample_mode=False,
format_has_duration=False,
use_cot_caption=False,
use_cot_language=False,
lm_top_k=0,
lm_top_p=0.9,
parse_timesteps=lambda _value: None,
is_instrumental=lambda _lyrics: False,
default_dit_instruction="default instruction",
task_instructions={},
)

self.assertEqual("euler", setup.params.sampler_mode)
self.assertAlmostEqual(0.0, setup.params.velocity_norm_threshold)
self.assertAlmostEqual(0.0, setup.params.velocity_ema_factor)
self.assertAlmostEqual(0.0, setup.params.latent_shift)
self.assertAlmostEqual(1.0, setup.params.latent_rescale)


if __name__ == "__main__":
unittest.main()