Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 28 additions & 0 deletions acestep/api/http/release_task_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,6 +97,34 @@ class GenerateMusicRequest(BaseModel):
default=None,
description="Custom timesteps (comma-separated, e.g., '0.97,0.76,0.615,0.5,0.395,0.28,0.18,0.085,0'). Overrides inference_steps and shift.",
)
sampler_mode: str = Field(
default="euler",
description="Diffusion sampler mode: 'euler' (first-order) or 'heun' (second-order predictor-corrector).",
)
velocity_norm_threshold: float = Field(
default=0.0,
description="Clamp velocity prediction norms during diffusion (0 = disabled, try 2.0).",
)
velocity_ema_factor: float = Field(
default=0.0,
description="Velocity EMA smoothing during diffusion (0 = disabled, try 0.1).",
)
eta: float = Field(
default=0.0,
description="APG guidance eta: parallel-component scaling factor (0.0 = pure orthogonal guidance).",
)
momentum: float = Field(
default=-0.75,
description="APG MomentumBuffer decay factor for accumulated guidance diff (default -0.75).",
)
latent_shift: float = Field(
default=0.0,
description="Additive shift applied to DiT latents before VAE decode (0 = no shift).",
)
latent_rescale: float = Field(
default=1.0,
description="Multiplicative rescale applied to DiT latents before VAE decode (1.0 = no rescale).",
)

audio_format: str = Field(
default="mp3",
Expand Down
14 changes: 14 additions & 0 deletions acestep/api/http/release_task_models_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,20 @@ def test_audio_code_string_and_cover_noise_strength_are_accepted(self):
self.assertEqual("<|audio_code_1|>", req.audio_code_string)
self.assertAlmostEqual(0.75, req.cover_noise_strength)

def test_apg_eta_and_momentum_have_expected_defaults(self):
    """A request built with no arguments exposes eta=0.0 and momentum=-0.75.

    NOTE(review): these values are meant to mirror the APG guidance
    defaults; that module is not visible here, so confirm they stay in sync.
    """

    default_request = GenerateMusicRequest()
    for expected, field_name in ((0.0, "eta"), (-0.75, "momentum")):
        self.assertAlmostEqual(expected, getattr(default_request, field_name))

def test_apg_eta_and_momentum_are_accepted(self):
    """Caller-supplied eta and momentum values are readable back from the model."""

    custom_eta = 0.5
    custom_momentum = -0.5
    req = GenerateMusicRequest(eta=custom_eta, momentum=custom_momentum)

    observed_eta = req.eta
    self.assertAlmostEqual(0.5, observed_eta)
    observed_momentum = req.momentum
    self.assertAlmostEqual(-0.5, observed_momentum)


# Run this module's unittest suite when the file is executed directly.
if __name__ == "__main__":
    unittest.main()
7 changes: 7 additions & 0 deletions acestep/api/http/release_task_param_parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,13 @@
"instruction": ["instruction"],
"track_name": ["track_name", "trackName"],
"track_classes": ["track_classes", "trackClasses", "instruments"],
"sampler_mode": ["sampler_mode", "samplerMode"],
"velocity_norm_threshold": ["velocity_norm_threshold", "velocityNormThreshold"],
"velocity_ema_factor": ["velocity_ema_factor", "velocityEmaFactor"],
"eta": ["eta"],
"momentum": ["momentum"],
"latent_shift": ["latent_shift", "latentShift"],
"latent_rescale": ["latent_rescale", "latentRescale"],
}


Expand Down
33 changes: 33 additions & 0 deletions acestep/api/http/release_task_param_parser_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,5 +65,38 @@ def test_non_dict_param_obj_json_is_ignored(self):
self.assertEqual("meta-caption", parser.str("prompt"))


def test_sampler_and_dit_param_aliases_are_resolved(self):
    """camelCase keys for sampler/DiT params are readable via snake_case names."""

    camel_case_payload = {
        "samplerMode": "heun",
        "velocityNormThreshold": "2.5",
        "latentShift": "0.1",
        "velocityEmaFactor": "0.95",
        "latentRescale": "1.2",
    }
    parser = RequestParser(camel_case_payload)

    # The string-valued parameter is checked for exact equality.
    self.assertEqual("heun", parser.str("sampler_mode"))

    # Numeric parameters arrive as strings and are read back as floats.
    expected_floats = (
        ("velocity_norm_threshold", 2.5),
        ("latent_shift", 0.1),
        ("velocity_ema_factor", 0.95),
        ("latent_rescale", 1.2),
    )
    for canonical_name, expected in expected_floats:
        self.assertAlmostEqual(expected, parser.float(canonical_name))

def test_apg_eta_and_momentum_aliases_are_resolved(self):
    """Top-level ``eta``/``momentum`` string values are readable as floats."""

    raw_payload = {"eta": "0.5", "momentum": "-0.5"}
    parser = RequestParser(raw_payload)

    eta_value = parser.float("eta")
    self.assertAlmostEqual(0.5, eta_value)
    momentum_value = parser.float("momentum")
    self.assertAlmostEqual(-0.5, momentum_value)

def test_apg_eta_and_momentum_resolve_from_nested_param_obj(self):
    """``eta``/``momentum`` nested under ``param_obj`` are readable as floats."""

    nested_params = {"eta": 0.75, "momentum": -0.9}
    parser = RequestParser({"param_obj": nested_params})

    eta_value = parser.float("eta")
    self.assertAlmostEqual(0.75, eta_value)
    momentum_value = parser.float("momentum")
    self.assertAlmostEqual(-0.9, momentum_value)


# Run this module's unittest suite when the file is executed directly.
if __name__ == "__main__":
    unittest.main()
7 changes: 7 additions & 0 deletions acestep/api/http/release_task_request_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,13 @@ def build_generate_music_request(
cfg_interval_end=parser.float("cfg_interval_end", 1.0),
infer_method=parser.str("infer_method", "ode"),
shift=parser.float("shift", 3.0),
sampler_mode=parser.str("sampler_mode", "euler"),
velocity_norm_threshold=parser.float("velocity_norm_threshold", 0.0),
velocity_ema_factor=parser.float("velocity_ema_factor", 0.0),
eta=parser.float("eta", 0.0),
momentum=parser.float("momentum", -0.75),
latent_shift=parser.float("latent_shift", 0.0),
latent_rescale=parser.float("latent_rescale", 1.0),
audio_format=parser.str("audio_format", "mp3"),
use_tiled_decode=parser.bool("use_tiled_decode", True),
lm_model_path=parser.str("lm_model_path") or None,
Expand Down
66 changes: 63 additions & 3 deletions acestep/api/http/release_task_request_builder_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,10 +14,10 @@ def __init__(self, values: dict) -> None:

self._values = values

def get(self, key: str, default=None):
    """Return raw value for ``key`` from parser payload with optional default.

    Args:
        key: Name to look up in the stored ``_values`` mapping.
        default: Value returned when ``key`` is absent. Defaults to ``None``,
            so single-argument callers behave as before.

    Returns:
        The stored value for ``key``, or ``default`` when the key is missing.
        A key that is present with value ``None`` yields ``None``, not
        ``default``.
    """

    return self._values.get(key, default)

def str(self, key: str, default: str = "") -> str:
"""Return string value for ``key`` with default fallback."""
Expand Down Expand Up @@ -135,5 +135,65 @@ def test_build_request_forwards_audio_code_string_and_cover_noise_strength(self)
self.assertAlmostEqual(0.6, request.cover_noise_strength)


def test_build_request_forwards_sampler_and_dit_latent_params(self):
    """Sampler mode and velocity/latent tuning values reach the built request."""

    def _as_namespace(**kwargs):
        # Stand-in request model: exposes every keyword as an attribute.
        return SimpleNamespace(**kwargs)

    fake_parser = _FakeParser(
        {
            "sampler_mode": "heun",
            "velocity_norm_threshold": 2.0,
            "velocity_ema_factor": 0.1,
            "latent_shift": 0.05,
            "latent_rescale": 1.2,
        }
    )
    built = build_generate_music_request(
        parser=fake_parser,
        request_model_cls=_as_namespace,
        default_dit_instruction="default-instruction",
        lm_default_temperature=0.85,
        lm_default_cfg_scale=2.5,
        lm_default_top_p=0.9,
    )

    self.assertEqual("heun", built.sampler_mode)
    expected_floats = (
        ("velocity_norm_threshold", 2.0),
        ("velocity_ema_factor", 0.1),
        ("latent_shift", 0.05),
        ("latent_rescale", 1.2),
    )
    for field_name, expected in expected_floats:
        self.assertAlmostEqual(expected, getattr(built, field_name))

def test_build_request_forwards_apg_eta_and_momentum(self):
    """Explicit eta/momentum values from the parser reach the built request."""

    def _as_namespace(**kwargs):
        # Stand-in request model: exposes every keyword as an attribute.
        return SimpleNamespace(**kwargs)

    fake_parser = _FakeParser({"eta": 0.5, "momentum": -0.5})
    built = build_generate_music_request(
        parser=fake_parser,
        request_model_cls=_as_namespace,
        default_dit_instruction="default-instruction",
        lm_default_temperature=0.85,
        lm_default_cfg_scale=2.5,
        lm_default_top_p=0.9,
    )

    for expected, field_name in ((0.5, "eta"), (-0.5, "momentum")):
        self.assertAlmostEqual(expected, getattr(built, field_name))

def test_build_request_defaults_apg_eta_and_momentum_when_absent(self):
    """With an empty parser payload the request gets eta=0.0 and momentum=-0.75."""

    def _as_namespace(**kwargs):
        # Stand-in request model: exposes every keyword as an attribute.
        return SimpleNamespace(**kwargs)

    empty_parser = _FakeParser({})
    built = build_generate_music_request(
        parser=empty_parser,
        request_model_cls=_as_namespace,
        default_dit_instruction="default-instruction",
        lm_default_temperature=0.85,
        lm_default_cfg_scale=2.5,
        lm_default_top_p=0.9,
    )

    for expected, field_name in ((0.0, "eta"), (-0.75, "momentum")):
        self.assertAlmostEqual(expected, getattr(built, field_name))


# Run this module's unittest suite when the file is executed directly.
if __name__ == "__main__":
    unittest.main()
7 changes: 7 additions & 0 deletions acestep/api/job_generation_setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -174,6 +174,13 @@ def build_generation_setup(
cfg_interval_end=req.cfg_interval_end,
shift=req.shift,
infer_method=req.infer_method,
sampler_mode=getattr(req, "sampler_mode", "euler"),
velocity_norm_threshold=getattr(req, "velocity_norm_threshold", 0.0),
velocity_ema_factor=getattr(req, "velocity_ema_factor", 0.0),
eta=getattr(req, "eta", 0.0),
momentum=getattr(req, "momentum", -0.75),
latent_shift=getattr(req, "latent_shift", 0.0),
latent_rescale=getattr(req, "latent_rescale", 1.0),
timesteps=parsed_timesteps,
repainting_start=req.repainting_start,
repainting_end=req.repainting_end if req.repainting_end else -1,
Expand Down
126 changes: 126 additions & 0 deletions acestep/api/job_generation_setup_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -263,5 +263,131 @@ def test_use_cot_metas_enabled_when_format_has_duration(self) -> None:
self.assertTrue(setup.params.use_cot_metas)


def test_build_generation_setup_forwards_sampler_and_dit_latent_params(self) -> None:
    """Sampler mode and velocity/latent values set on ``req`` appear on ``setup.params``."""

    req = _base_req()
    overrides = (
        ("sampler_mode", "heun"),
        ("velocity_norm_threshold", 2.0),
        ("velocity_ema_factor", 0.1),
        ("latent_shift", 0.05),
        ("latent_rescale", 1.2),
    )
    for attr_name, value in overrides:
        setattr(req, attr_name, value)

    call_kwargs = {
        "req": req,
        "caption": "cap",
        "lyrics": "lyr",
        "bpm": None,
        "key_scale": "",
        "time_signature": "",
        "audio_duration": None,
        "thinking": False,
        "sample_mode": False,
        "format_has_duration": False,
        "use_cot_caption": False,
        "use_cot_language": False,
        "lm_top_k": 0,
        "lm_top_p": 0.9,
        "parse_timesteps": lambda _value: None,
        "is_instrumental": lambda _lyrics: False,
        "default_dit_instruction": "default instruction",
        "task_instructions": {},
    }
    setup = build_generation_setup(**call_kwargs)

    self.assertEqual("heun", setup.params.sampler_mode)
    expected_floats = (
        ("velocity_norm_threshold", 2.0),
        ("velocity_ema_factor", 0.1),
        ("latent_shift", 0.05),
        ("latent_rescale", 1.2),
    )
    for attr_name, expected in expected_floats:
        self.assertAlmostEqual(expected, getattr(setup.params, attr_name))

def test_build_generation_setup_defaults_sampler_params_when_missing(self) -> None:
    """An unmodified base request yields the default sampler/latent params.

    NOTE(review): presumably ``_base_req()`` does not define the sampler or
    latent fields, so the fallback values are what is exercised here; confirm
    against the helper.
    """

    call_kwargs = {
        "req": _base_req(),
        "caption": "cap",
        "lyrics": "lyr",
        "bpm": None,
        "key_scale": "",
        "time_signature": "",
        "audio_duration": None,
        "thinking": False,
        "sample_mode": False,
        "format_has_duration": False,
        "use_cot_caption": False,
        "use_cot_language": False,
        "lm_top_k": 0,
        "lm_top_p": 0.9,
        "parse_timesteps": lambda _value: None,
        "is_instrumental": lambda _lyrics: False,
        "default_dit_instruction": "default instruction",
        "task_instructions": {},
    }
    setup = build_generation_setup(**call_kwargs)

    self.assertEqual("euler", setup.params.sampler_mode)
    expected_floats = (
        ("velocity_norm_threshold", 0.0),
        ("velocity_ema_factor", 0.0),
        ("latent_shift", 0.0),
        ("latent_rescale", 1.0),
    )
    for attr_name, expected in expected_floats:
        self.assertAlmostEqual(expected, getattr(setup.params, attr_name))

def test_build_generation_setup_forwards_apg_eta_and_momentum(self) -> None:
    """``eta`` and ``momentum`` set on ``req`` appear on ``setup.params``."""

    req = _base_req()
    for attr_name, value in (("eta", 0.5), ("momentum", -0.5)):
        setattr(req, attr_name, value)

    call_kwargs = {
        "req": req,
        "caption": "cap",
        "lyrics": "lyr",
        "bpm": None,
        "key_scale": "",
        "time_signature": "",
        "audio_duration": None,
        "thinking": False,
        "sample_mode": False,
        "format_has_duration": False,
        "use_cot_caption": False,
        "use_cot_language": False,
        "lm_top_k": 0,
        "lm_top_p": 0.9,
        "parse_timesteps": lambda _value: None,
        "is_instrumental": lambda _lyrics: False,
        "default_dit_instruction": "default instruction",
        "task_instructions": {},
    }
    setup = build_generation_setup(**call_kwargs)

    for expected, attr_name in ((0.5, "eta"), (-0.5, "momentum")):
        self.assertAlmostEqual(expected, getattr(setup.params, attr_name))

def test_build_generation_setup_defaults_apg_params_when_missing(self) -> None:
    """An unmodified base request yields eta=0.0 and momentum=-0.75.

    NOTE(review): presumably ``_base_req()`` does not define ``eta`` or
    ``momentum``, so the fallback values are what is exercised here; confirm
    against the helper.
    """

    call_kwargs = {
        "req": _base_req(),
        "caption": "cap",
        "lyrics": "lyr",
        "bpm": None,
        "key_scale": "",
        "time_signature": "",
        "audio_duration": None,
        "thinking": False,
        "sample_mode": False,
        "format_has_duration": False,
        "use_cot_caption": False,
        "use_cot_language": False,
        "lm_top_k": 0,
        "lm_top_p": 0.9,
        "parse_timesteps": lambda _value: None,
        "is_instrumental": lambda _lyrics: False,
        "default_dit_instruction": "default instruction",
        "task_instructions": {},
    }
    setup = build_generation_setup(**call_kwargs)

    for expected, attr_name in ((0.0, "eta"), (-0.75, "momentum")):
        self.assertAlmostEqual(expected, getattr(setup.params, attr_name))


# Run this module's unittest suite when the file is executed directly.
if __name__ == "__main__":
    unittest.main()
4 changes: 4 additions & 0 deletions acestep/core/generation/handler/generate_music.py
Original file line number Diff line number Diff line change
Expand Up @@ -214,6 +214,8 @@ def generate_music(
dcw_scaler: float = 0.05,
dcw_high_scaler: float = 0.02,
dcw_wavelet: str = "haar",
eta: float = 0.0,
momentum: float = -0.75,
use_tiled_decode: bool = True,
timesteps: Optional[List[float]] = None,
latent_shift: float = 0.0,
Expand Down Expand Up @@ -408,6 +410,8 @@ def generate_music(
dcw_scaler=dcw_scaler,
dcw_high_scaler=dcw_high_scaler,
dcw_wavelet=dcw_wavelet,
eta=eta,
momentum=momentum,
repaint_crossfade_frames=resolved_cf_frames,
repaint_injection_ratio=injection_ratio,
source_repaint_latents=source_repaint_latents,
Expand Down
4 changes: 4 additions & 0 deletions acestep/core/generation/handler/generate_music_execute.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,8 @@ def _run_generate_music_service_with_progress(
dcw_scaler: float = 0.05,
dcw_high_scaler: float = 0.02,
dcw_wavelet: str = "haar",
eta: float = 0.0,
momentum: float = -0.75,
repaint_crossfade_frames: int = 10,
repaint_injection_ratio: float = 0.5,
source_repaint_latents: Any = None,
Expand Down Expand Up @@ -105,6 +107,8 @@ def _service_target():
dcw_scaler=dcw_scaler,
dcw_high_scaler=dcw_high_scaler,
dcw_wavelet=dcw_wavelet,
eta=eta,
momentum=momentum,
audio_code_hints=service_inputs["audio_code_hints_batch"],
return_intermediate=service_inputs["should_return_intermediate"],
timesteps=timesteps,
Expand Down
Loading