From 56dc85f74eeb45fdcc6303704929d212281d7eb5 Mon Sep 17 00:00:00 2001 From: Felix Hummert Date: Sun, 12 Apr 2026 14:44:29 +0200 Subject: [PATCH 1/2] feat(api): expose sampler_mode and 4 DiT params on /release_task HTTP API Five GenerationParams fields that were only accessible via Gradio's direct Python path are now accepted by the /release_task HTTP endpoint: - sampler_mode: "euler" (default) or "heun" sampler selection - velocity_norm_threshold: velocity prediction norm clamping (0=off) - velocity_ema_factor: velocity EMA smoothing (0=off) - latent_shift: additive shift on DiT latents before VAE decode - latent_rescale: multiplicative rescale on DiT latents before VAE decode All five already exist on the internal GenerationParams dataclass and are wired through the diffusion loop. This change adds them to: - GenerateMusicRequest (Pydantic request model) - PARAM_ALIASES (camelCase alias support) - build_generate_music_request (request builder) - build_generation_setup (GenerationParams wiring) Tests added for alias resolution, request builder forwarding, and generation setup wiring (including getattr fallback for older callers). Default values match the existing GenerationParams defaults, so omitting them preserves current behavior. Non-target code paths (Gradio UI, training, CLI) are unchanged. --- acestep/api/http/release_task_models.py | 20 ++++++ acestep/api/http/release_task_param_parser.py | 5 ++ .../http/release_task_param_parser_test.py | 11 +++ .../api/http/release_task_request_builder.py | 5 ++ .../http/release_task_request_builder_test.py | 28 ++++++++ acestep/api/job_generation_setup.py | 5 ++ acestep/api/job_generation_setup_test.py | 68 +++++++++++++++++++ 7 files changed, 142 insertions(+) diff --git a/acestep/api/http/release_task_models.py b/acestep/api/http/release_task_models.py index 87042c01d..1817f3083 100644 --- a/acestep/api/http/release_task_models.py +++ b/acestep/api/http/release_task_models.py @@ -97,6 +97,26 @@ class GenerateMusicRequest(BaseModel): default=None, description="Custom timesteps (comma-separated, e.g., '0.97,0.76,0.615,0.5,0.395,0.28,0.18,0.085,0'). Overrides inference_steps and shift.", ) + sampler_mode: str = Field( + default="euler", + description="Diffusion sampler mode: 'euler' (first-order) or 'heun' (second-order predictor-corrector).", + ) + velocity_norm_threshold: float = Field( + default=0.0, + description="Clamp velocity prediction norms during diffusion (0 = disabled, try 2.0).", + ) + velocity_ema_factor: float = Field( + default=0.0, + description="Velocity EMA smoothing during diffusion (0 = disabled, try 0.1).", + ) + latent_shift: float = Field( + default=0.0, + description="Additive shift applied to DiT latents before VAE decode (0 = no shift).", + ) + latent_rescale: float = Field( + default=1.0, + description="Multiplicative rescale applied to DiT latents before VAE decode (1.0 = no rescale).", + ) audio_format: str = Field( default="mp3", diff --git a/acestep/api/http/release_task_param_parser.py b/acestep/api/http/release_task_param_parser.py index 3b94d530f..d83585771 100644 --- a/acestep/api/http/release_task_param_parser.py +++ b/acestep/api/http/release_task_param_parser.py @@ -43,6 +43,11 @@ "instruction": ["instruction"], "track_name": ["track_name", "trackName"], "track_classes": ["track_classes", "trackClasses", "instruments"], + "sampler_mode": ["sampler_mode", "samplerMode"], + "velocity_norm_threshold": ["velocity_norm_threshold", "velocityNormThreshold"], + "velocity_ema_factor": ["velocity_ema_factor", "velocityEmaFactor"], + "latent_shift": ["latent_shift", "latentShift"], + "latent_rescale": ["latent_rescale", "latentRescale"], } diff --git a/acestep/api/http/release_task_param_parser_test.py b/acestep/api/http/release_task_param_parser_test.py index 889475103..ef14857d6 100644 --- a/acestep/api/http/release_task_param_parser_test.py +++ b/acestep/api/http/release_task_param_parser_test.py @@ -65,5 +65,16 @@ def test_non_dict_param_obj_json_is_ignored(self): self.assertEqual("meta-caption", parser.str("prompt")) + def test_sampler_and_dit_param_aliases_are_resolved(self): + """Parser should resolve camelCase aliases for sampler/DiT params.""" + + parser = RequestParser( + {"samplerMode": "heun", "velocityNormThreshold": "2.5", "latentShift": "0.1"} + ) + self.assertEqual("heun", parser.str("sampler_mode")) + self.assertAlmostEqual(2.5, parser.float("velocity_norm_threshold")) + self.assertAlmostEqual(0.1, parser.float("latent_shift")) + + if __name__ == "__main__": unittest.main() diff --git a/acestep/api/http/release_task_request_builder.py b/acestep/api/http/release_task_request_builder.py index a22862a06..eb1a8e4ce 100644 --- a/acestep/api/http/release_task_request_builder.py +++ b/acestep/api/http/release_task_request_builder.py @@ -81,6 +81,11 @@ def build_generate_music_request( cfg_interval_end=parser.float("cfg_interval_end", 1.0), infer_method=parser.str("infer_method", "ode"), shift=parser.float("shift", 3.0), + sampler_mode=parser.str("sampler_mode", "euler"), + velocity_norm_threshold=parser.float("velocity_norm_threshold", 0.0), + velocity_ema_factor=parser.float("velocity_ema_factor", 0.0), + latent_shift=parser.float("latent_shift", 0.0), + latent_rescale=parser.float("latent_rescale", 1.0), audio_format=parser.str("audio_format", "mp3"), use_tiled_decode=parser.bool("use_tiled_decode", True), lm_model_path=parser.str("lm_model_path") or None, diff --git a/acestep/api/http/release_task_request_builder_test.py b/acestep/api/http/release_task_request_builder_test.py index 72a0a0682..c9515eb08 100644 --- a/acestep/api/http/release_task_request_builder_test.py +++ b/acestep/api/http/release_task_request_builder_test.py @@ -135,5 +135,33 @@ def test_build_request_forwards_audio_code_string_and_cover_noise_strength(self) self.assertAlmostEqual(0.6, request.cover_noise_strength) + def test_build_request_forwards_sampler_and_dit_latent_params(self): + """Builder should include sampler_mode and latent post-processing params.""" + + parser = _FakeParser( + { + "sampler_mode": "heun", + "velocity_norm_threshold": 2.0, + "velocity_ema_factor": 0.1, + "latent_shift": 0.05, + "latent_rescale": 1.2, + } + ) + request = build_generate_music_request( + parser=parser, + request_model_cls=lambda **kwargs: SimpleNamespace(**kwargs), + default_dit_instruction="default-instruction", + lm_default_temperature=0.85, + lm_default_cfg_scale=2.5, + lm_default_top_p=0.9, + ) + + self.assertEqual("heun", request.sampler_mode) + self.assertAlmostEqual(2.0, request.velocity_norm_threshold) + self.assertAlmostEqual(0.1, request.velocity_ema_factor) + self.assertAlmostEqual(0.05, request.latent_shift) + self.assertAlmostEqual(1.2, request.latent_rescale) + + if __name__ == "__main__": unittest.main() diff --git a/acestep/api/job_generation_setup.py b/acestep/api/job_generation_setup.py index 34ee446f1..89c05868b 100644 --- a/acestep/api/job_generation_setup.py +++ b/acestep/api/job_generation_setup.py @@ -174,6 +174,11 @@ def build_generation_setup( cfg_interval_end=req.cfg_interval_end, shift=req.shift, infer_method=req.infer_method, + sampler_mode=getattr(req, "sampler_mode", "euler"), + velocity_norm_threshold=getattr(req, "velocity_norm_threshold", 0.0), + velocity_ema_factor=getattr(req, "velocity_ema_factor", 0.0), + latent_shift=getattr(req, "latent_shift", 0.0), + latent_rescale=getattr(req, "latent_rescale", 1.0), timesteps=parsed_timesteps, repainting_start=req.repainting_start, repainting_end=req.repainting_end if req.repainting_end else -1, diff --git a/acestep/api/job_generation_setup_test.py b/acestep/api/job_generation_setup_test.py index 3ec650b94..6c82075bb 100644 --- a/acestep/api/job_generation_setup_test.py +++ b/acestep/api/job_generation_setup_test.py @@ -263,5 +263,73 @@ def test_use_cot_metas_enabled_when_format_has_duration(self) -> None: self.assertTrue(setup.params.use_cot_metas) + def test_build_generation_setup_forwards_sampler_and_dit_latent_params(self) -> None: + """Sampler mode and latent post-processing params should be forwarded.""" + + req = _base_req() + req.sampler_mode = "heun" + req.velocity_norm_threshold = 2.0 + req.velocity_ema_factor = 0.1 + req.latent_shift = 0.05 + req.latent_rescale = 1.2 + setup = build_generation_setup( + req=req, + caption="cap", + lyrics="lyr", + bpm=None, + key_scale="", + time_signature="", + audio_duration=None, + thinking=False, + sample_mode=False, + format_has_duration=False, + use_cot_caption=False, + use_cot_language=False, + lm_top_k=0, + lm_top_p=0.9, + parse_timesteps=lambda _value: None, + is_instrumental=lambda _lyrics: False, + default_dit_instruction="default instruction", + task_instructions={}, + ) + + self.assertEqual("heun", setup.params.sampler_mode) + self.assertAlmostEqual(2.0, setup.params.velocity_norm_threshold) + self.assertAlmostEqual(0.1, setup.params.velocity_ema_factor) + self.assertAlmostEqual(0.05, setup.params.latent_shift) + self.assertAlmostEqual(1.2, setup.params.latent_rescale) + + def test_build_generation_setup_defaults_sampler_params_when_missing(self) -> None: + """When req lacks sampler/latent fields, getattr defaults should apply.""" + + req = _base_req() + setup = build_generation_setup( + req=req, + caption="cap", + lyrics="lyr", + bpm=None, + key_scale="", + time_signature="", + audio_duration=None, + thinking=False, + sample_mode=False, + format_has_duration=False, + use_cot_caption=False, + use_cot_language=False, + lm_top_k=0, + lm_top_p=0.9, + parse_timesteps=lambda _value: None, + is_instrumental=lambda _lyrics: False, + default_dit_instruction="default instruction", + task_instructions={}, + ) + + self.assertEqual("euler", setup.params.sampler_mode) + self.assertAlmostEqual(0.0, setup.params.velocity_norm_threshold) + self.assertAlmostEqual(0.0, setup.params.velocity_ema_factor) + self.assertAlmostEqual(0.0, setup.params.latent_shift) + self.assertAlmostEqual(1.0, setup.params.latent_rescale) + + if __name__ == "__main__": unittest.main() From de84457334c5fe77825596fd94529f289a258e11 Mon Sep 17 00:00:00 2001 From: Felix Hummert Date: Mon, 13 Apr 2026 13:06:56 +0200 Subject: [PATCH 2/2] test: expand alias coverage to all 5 exposed DiT params Co-Authored-By: Claude Opus 4.6 (1M context) --- acestep/api/http/release_task_param_parser_test.py | 10 +++++++++- 1 file changed, 9 insertions(+), 1 deletion(-) diff --git a/acestep/api/http/release_task_param_parser_test.py b/acestep/api/http/release_task_param_parser_test.py index ef14857d6..232489a54 100644 --- a/acestep/api/http/release_task_param_parser_test.py +++ b/acestep/api/http/release_task_param_parser_test.py @@ -69,11 +69,19 @@ def test_sampler_and_dit_param_aliases_are_resolved(self): """Parser should resolve camelCase aliases for sampler/DiT params.""" parser = RequestParser( - {"samplerMode": "heun", "velocityNormThreshold": "2.5", "latentShift": "0.1"} + { + "samplerMode": "heun", + "velocityNormThreshold": "2.5", + "latentShift": "0.1", + "velocityEmaFactor": "0.95", + "latentRescale": "1.2", + } ) self.assertEqual("heun", parser.str("sampler_mode")) self.assertAlmostEqual(2.5, parser.float("velocity_norm_threshold")) self.assertAlmostEqual(0.1, parser.float("latent_shift")) + self.assertAlmostEqual(0.95, parser.float("velocity_ema_factor")) + self.assertAlmostEqual(1.2, parser.float("latent_rescale")) if __name__ == "__main__":