diff --git a/acestep/api/http/release_task_models.py b/acestep/api/http/release_task_models.py index 87042c01d..1817f3083 100644 --- a/acestep/api/http/release_task_models.py +++ b/acestep/api/http/release_task_models.py @@ -97,6 +97,26 @@ class GenerateMusicRequest(BaseModel): default=None, description="Custom timesteps (comma-separated, e.g., '0.97,0.76,0.615,0.5,0.395,0.28,0.18,0.085,0'). Overrides inference_steps and shift.", ) + sampler_mode: str = Field( + default="euler", + description="Diffusion sampler mode: 'euler' (first-order) or 'heun' (second-order predictor-corrector).", + ) + velocity_norm_threshold: float = Field( + default=0.0, + description="Clamp velocity prediction norms during diffusion (0 = disabled, try 2.0).", + ) + velocity_ema_factor: float = Field( + default=0.0, + description="Velocity EMA smoothing during diffusion (0 = disabled, try 0.1).", + ) + latent_shift: float = Field( + default=0.0, + description="Additive shift applied to DiT latents before VAE decode (0 = no shift).", + ) + latent_rescale: float = Field( + default=1.0, + description="Multiplicative rescale applied to DiT latents before VAE decode (1.0 = no rescale).", + ) audio_format: str = Field( default="mp3", diff --git a/acestep/api/http/release_task_param_parser.py b/acestep/api/http/release_task_param_parser.py index 3b94d530f..d83585771 100644 --- a/acestep/api/http/release_task_param_parser.py +++ b/acestep/api/http/release_task_param_parser.py @@ -43,6 +43,11 @@ "instruction": ["instruction"], "track_name": ["track_name", "trackName"], "track_classes": ["track_classes", "trackClasses", "instruments"], + "sampler_mode": ["sampler_mode", "samplerMode"], + "velocity_norm_threshold": ["velocity_norm_threshold", "velocityNormThreshold"], + "velocity_ema_factor": ["velocity_ema_factor", "velocityEmaFactor"], + "latent_shift": ["latent_shift", "latentShift"], + "latent_rescale": ["latent_rescale", "latentRescale"], } diff --git a/acestep/api/http/release_task_param_parser_test.py b/acestep/api/http/release_task_param_parser_test.py index 889475103..232489a54 100644 --- a/acestep/api/http/release_task_param_parser_test.py +++ b/acestep/api/http/release_task_param_parser_test.py @@ -65,5 +65,24 @@ def test_non_dict_param_obj_json_is_ignored(self): self.assertEqual("meta-caption", parser.str("prompt")) + def test_sampler_and_dit_param_aliases_are_resolved(self): + """Parser should resolve camelCase aliases for sampler/DiT params.""" + + parser = RequestParser( + { + "samplerMode": "heun", + "velocityNormThreshold": "2.5", + "latentShift": "0.1", + "velocityEmaFactor": "0.95", + "latentRescale": "1.2", + } + ) + self.assertEqual("heun", parser.str("sampler_mode")) + self.assertAlmostEqual(2.5, parser.float("velocity_norm_threshold")) + self.assertAlmostEqual(0.1, parser.float("latent_shift")) + self.assertAlmostEqual(0.95, parser.float("velocity_ema_factor")) + self.assertAlmostEqual(1.2, parser.float("latent_rescale")) + + if __name__ == "__main__": unittest.main() diff --git a/acestep/api/http/release_task_request_builder.py b/acestep/api/http/release_task_request_builder.py index a22862a06..eb1a8e4ce 100644 --- a/acestep/api/http/release_task_request_builder.py +++ b/acestep/api/http/release_task_request_builder.py @@ -81,6 +81,11 @@ def build_generate_music_request( cfg_interval_end=parser.float("cfg_interval_end", 1.0), infer_method=parser.str("infer_method", "ode"), shift=parser.float("shift", 3.0), + sampler_mode=parser.str("sampler_mode", "euler"), + velocity_norm_threshold=parser.float("velocity_norm_threshold", 0.0), + velocity_ema_factor=parser.float("velocity_ema_factor", 0.0), + latent_shift=parser.float("latent_shift", 0.0), + latent_rescale=parser.float("latent_rescale", 1.0), audio_format=parser.str("audio_format", "mp3"), use_tiled_decode=parser.bool("use_tiled_decode", True), lm_model_path=parser.str("lm_model_path") or None, diff --git a/acestep/api/http/release_task_request_builder_test.py b/acestep/api/http/release_task_request_builder_test.py index 72a0a0682..c9515eb08 100644 --- a/acestep/api/http/release_task_request_builder_test.py +++ b/acestep/api/http/release_task_request_builder_test.py @@ -135,5 +135,33 @@ def test_build_request_forwards_audio_code_string_and_cover_noise_strength(self) self.assertAlmostEqual(0.6, request.cover_noise_strength) + def test_build_request_forwards_sampler_and_dit_latent_params(self): + """Builder should include sampler_mode and latent post-processing params.""" + + parser = _FakeParser( + { + "sampler_mode": "heun", + "velocity_norm_threshold": 2.0, + "velocity_ema_factor": 0.1, + "latent_shift": 0.05, + "latent_rescale": 1.2, + } + ) + request = build_generate_music_request( + parser=parser, + request_model_cls=lambda **kwargs: SimpleNamespace(**kwargs), + default_dit_instruction="default-instruction", + lm_default_temperature=0.85, + lm_default_cfg_scale=2.5, + lm_default_top_p=0.9, + ) + + self.assertEqual("heun", request.sampler_mode) + self.assertAlmostEqual(2.0, request.velocity_norm_threshold) + self.assertAlmostEqual(0.1, request.velocity_ema_factor) + self.assertAlmostEqual(0.05, request.latent_shift) + self.assertAlmostEqual(1.2, request.latent_rescale) + + if __name__ == "__main__": unittest.main() diff --git a/acestep/api/job_generation_setup.py b/acestep/api/job_generation_setup.py index 34ee446f1..89c05868b 100644 --- a/acestep/api/job_generation_setup.py +++ b/acestep/api/job_generation_setup.py @@ -174,6 +174,11 @@ def build_generation_setup( cfg_interval_end=req.cfg_interval_end, shift=req.shift, infer_method=req.infer_method, + sampler_mode=getattr(req, "sampler_mode", "euler"), + velocity_norm_threshold=getattr(req, "velocity_norm_threshold", 0.0), + velocity_ema_factor=getattr(req, "velocity_ema_factor", 0.0), + latent_shift=getattr(req, "latent_shift", 0.0), + latent_rescale=getattr(req, "latent_rescale", 1.0), timesteps=parsed_timesteps, repainting_start=req.repainting_start, repainting_end=req.repainting_end if req.repainting_end else -1, diff --git a/acestep/api/job_generation_setup_test.py b/acestep/api/job_generation_setup_test.py index 3ec650b94..6c82075bb 100644 --- a/acestep/api/job_generation_setup_test.py +++ b/acestep/api/job_generation_setup_test.py @@ -263,5 +263,73 @@ def test_use_cot_metas_enabled_when_format_has_duration(self) -> None: self.assertTrue(setup.params.use_cot_metas) + def test_build_generation_setup_forwards_sampler_and_dit_latent_params(self) -> None: + """Sampler mode and latent post-processing params should be forwarded.""" + + req = _base_req() + req.sampler_mode = "heun" + req.velocity_norm_threshold = 2.0 + req.velocity_ema_factor = 0.1 + req.latent_shift = 0.05 + req.latent_rescale = 1.2 + setup = build_generation_setup( + req=req, + caption="cap", + lyrics="lyr", + bpm=None, + key_scale="", + time_signature="", + audio_duration=None, + thinking=False, + sample_mode=False, + format_has_duration=False, + use_cot_caption=False, + use_cot_language=False, + lm_top_k=0, + lm_top_p=0.9, + parse_timesteps=lambda _value: None, + is_instrumental=lambda _lyrics: False, + default_dit_instruction="default instruction", + task_instructions={}, + ) + + self.assertEqual("heun", setup.params.sampler_mode) + self.assertAlmostEqual(2.0, setup.params.velocity_norm_threshold) + self.assertAlmostEqual(0.1, setup.params.velocity_ema_factor) + self.assertAlmostEqual(0.05, setup.params.latent_shift) + self.assertAlmostEqual(1.2, setup.params.latent_rescale) + + def test_build_generation_setup_defaults_sampler_params_when_missing(self) -> None: + """When req lacks sampler/latent fields, getattr defaults should apply.""" + + req = _base_req() + setup = build_generation_setup( + req=req, + caption="cap", + lyrics="lyr", + bpm=None, + key_scale="", + time_signature="", + audio_duration=None, + thinking=False, + sample_mode=False, + format_has_duration=False, + use_cot_caption=False, + use_cot_language=False, + lm_top_k=0, + lm_top_p=0.9, + parse_timesteps=lambda _value: None, + is_instrumental=lambda _lyrics: False, + default_dit_instruction="default instruction", + task_instructions={}, + ) + + self.assertEqual("euler", setup.params.sampler_mode) + self.assertAlmostEqual(0.0, setup.params.velocity_norm_threshold) + self.assertAlmostEqual(0.0, setup.params.velocity_ema_factor) + self.assertAlmostEqual(0.0, setup.params.latent_shift) + self.assertAlmostEqual(1.0, setup.params.latent_rescale) + + if __name__ == "__main__": unittest.main()