feat(api): expose APG eta and momentum on /release_task HTTP API #1123
FlexOr2 wants to merge 5 commits into ace-step:main from
Conversation
… API

Five GenerationParams fields that were only accessible via Gradio's direct Python path are now accepted by the /release_task HTTP endpoint:

- sampler_mode: "euler" (default) or "heun" sampler selection
- velocity_norm_threshold: velocity prediction norm clamping (0=off)
- velocity_ema_factor: velocity EMA smoothing (0=off)
- latent_shift: additive shift on DiT latents before VAE decode
- latent_rescale: multiplicative rescale on DiT latents before VAE decode

All five already exist on the internal GenerationParams dataclass and are wired through the diffusion loop. This change adds them to:

- GenerateMusicRequest (Pydantic request model)
- PARAM_ALIASES (camelCase alias support)
- build_generate_music_request (request builder)
- build_generation_setup (GenerationParams wiring)

Tests added for alias resolution, request builder forwarding, and generation setup wiring (including getattr fallback for older callers). Default values match the existing GenerationParams defaults, so omitting them preserves current behavior. Non-target code paths (Gradio UI, training, CLI) are unchanged.
Co-Authored-By: Claude Opus 4.6 (1M context) <[email protected]>
The APG (Adaptive Projected Guidance) `eta` and `momentum` knobs were hardcoded at the MomentumBuffer() and apg_forward() call sites inside each DiT model. Lift them into the generation plumbing so callers can tune APG without patching model code:

- apg_forward now receives `eta` from the model's generate_audio call (same default 0.0 preserved).
- MomentumBuffer is instantiated with the caller's `momentum` value (same default -0.75 preserved).
- generate_audio on all four DiT model variants (base, sft, xl_base, xl_sft) accepts `eta` and `momentum` as keyword-only args.
- GenerationParams gains two new fields (eta, momentum) that thread down through generate_music -> _run_generate_music_service_with_progress -> service_generate -> _build_service_generate_kwargs -> model.generate_audio.

MLX path is intentionally not touched: `_mlx_apg_forward` has no slot for `eta`, and the MLX MomentumBuffer equivalent uses a fixed running sum (no decay factor). Exposing these knobs on MLX is a separate upstream change. The PyTorch path is unaffected by the MLX gap since the new kwargs are consumed by the PyTorch model's generate_audio directly.

All new parameters are keyword-only with defaults equal to the pre-existing hardcoded values, so omitting them preserves current behavior. A unit test (acestep/models/common/apg_guidance_test.py) verifies this invariant at the apg_forward / MomentumBuffer level.

Co-Authored-By: Claude Opus 4.7 (1M context) <[email protected]>
The two APG guidance knobs exposed here (`eta` and `momentum`) were already plumbed through GenerationParams and the DiT model stack in the previous commit, but could not be set from the HTTP surface. This change wires them into the /release_task request path, mirroring how PR ace-step#1092 exposed sampler_mode and the four DiT latent params:

- GenerateMusicRequest: new `eta: float = 0.0` and `momentum: float = -0.75` fields (defaults match the apg_guidance.py hardcoded values).
- PARAM_ALIASES: identity aliases for `eta` and `momentum` (both names are single lowercase words, so snake_case == camelCase).
- build_generate_music_request: forwards parsed values into the Pydantic request model.
- build_generation_setup: uses `getattr(req, ...)` with the same defaults, matching the existing sampler_mode/latent_shift pattern so older callers that don't set these fields still work.

The _FakeParser helper in release_task_request_builder_test.py gained a `default=None` parameter on `.get()` to match the real parser's contract; this was a pre-existing mismatch that broke several unrelated tests, and fixing it here lets the full builder suite run green.

Tests added alongside the changes cover:

- alias resolution for `eta` and `momentum` (flat body + nested param_obj),
- default values on the Pydantic model,
- request-builder forwarding and default-fallback behavior,
- generation-setup forwarding into GenerationParams,
- the getattr-default path when a request object lacks the fields.

All new fields are optional with defaults matching the pre-existing hardcoded APG values; omitting them preserves current HTTP behavior.

Co-Authored-By: Claude Opus 4.7 (1M context) <[email protected]>
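The getattr-with-default pattern described above can be sketched in isolation. The helper below is a hypothetical stand-in, not the project's real `build_generation_setup` or `GenerateMusicRequest`; it only illustrates why older request objects that lack the new fields keep working.

```python
from types import SimpleNamespace

def extract_apg_knobs(req) -> dict:
    # getattr with the historical hardcoded defaults (0.0 / -0.75) means a
    # request object missing the new fields behaves exactly as before.
    return {
        "eta": getattr(req, "eta", 0.0),
        "momentum": getattr(req, "momentum", -0.75),
    }

new_req = SimpleNamespace(eta=0.5, momentum=-0.5)  # caller sets both knobs
old_req = SimpleNamespace()                        # older caller: fields absent

print(extract_apg_knobs(new_req))  # {'eta': 0.5, 'momentum': -0.5}
print(extract_apg_knobs(old_req))  # {'eta': 0.0, 'momentum': -0.75}
```

The same pattern covers any future optional field: the default lives in one place next to the lookup, and absent attributes never raise.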
📝 Walkthrough

This pull request adds new generation control parameters to the music generation API and threads them through the entire generation pipeline. New fields enable caller control over diffusion sampling.

Changes
Sequence Diagram(s)

sequenceDiagram
participant Client
participant RequestParser
participant RequestBuilder
participant GenerationSetup
participant GenerationHandler
participant Model
Client->>RequestParser: HTTP request with eta, momentum, sampler_mode, etc.
RequestParser->>RequestParser: Resolve parameters via PARAM_ALIASES
RequestParser-->>RequestBuilder: Return parsed parameters
RequestBuilder->>RequestBuilder: Build GenerateMusicRequest with new fields
RequestBuilder-->>GenerationSetup: Pass request object
GenerationSetup->>GenerationSetup: Create GenerationParams with eta, momentum, sampler_mode
GenerationSetup-->>GenerationHandler: Pass setup configuration
GenerationHandler->>GenerationHandler: Extract eta, momentum from setup.params
GenerationHandler-->>Model: Call generate_audio(eta=..., momentum=...)
Model->>Model: Initialize MomentumBuffer(momentum=momentum)
Model->>Model: Call apg_forward(..., eta=eta, momentum_buffer=...)
Model-->>Client: Return generated audio
Estimated code review effort: 🎯 3 (Moderate) | ⏱️ ~20 minutes

Possibly related PRs
🚥 Pre-merge checks: ✅ 5 passed checks
Resolves conflicts in generate_audio() / handler signatures, kwargs forwarding, and inference.py GenerationParams where the upstream DCW work (dcw_enabled, dcw_mode, dcw_scaler, dcw_high_scaler, dcw_wavelet, plus timesteps in base) and the upstream task_type kwarg landed alongside this branch's eta / momentum APG additions. Conflicts were all in adjacent kwarg blocks, dict literals, and docstrings; resolution keeps both sides.
Actionable comments posted: 1
🧹 Nitpick comments (3)
acestep/core/generation/handler/service_generate_execute.py (1)
131-132: Optional: warn when `eta`/`momentum` are set on the MLX path.

`eta` and `momentum` are accepted into `generate_kwargs` but the MLX branch (lines 175-236) never reads them, by design per the PR description ("MLX backend is out of scope"). However, a user passing non-default `eta`/`momentum` via `/release_task` while the MLX backend is active will get silent no-op semantics. A one-line `logger.info` (similar to the existing DCW/haar notice on lines 176-183) when `eta != 0.0 or momentum != -0.75` and `self.use_mlx_dit` is true would prevent confusion and reduce future support load.

♻️ Suggested log on the MLX branch
     if self.use_mlx_dit and self.mlx_decoder is not None:
+        if generate_kwargs.get("eta", 0.0) != 0.0 or generate_kwargs.get("momentum", -0.75) != -0.75:
+            logger.info(
+                "[service_generate] APG eta/momentum overrides "
+                "(eta={}, momentum={}) are ignored on the MLX backend; "
+                "these are PyTorch-only.",
+                generate_kwargs.get("eta"),
+                generate_kwargs.get("momentum"),
+            )
         if generate_kwargs.get("dcw_enabled") and generate_kwargs.get("dcw_wavelet", "haar") != "haar":

Also applies to: 175-236
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@acestep/core/generation/handler/service_generate_execute.py` around lines 131 - 132, The MLX code path silently ignores eta and momentum passed through generate_kwargs; update the MLX branch in service_generate_execute.py (the block guarded by self.use_mlx_dit) to log a one-line informational warning when non-default values are provided: check (eta != 0.0 or momentum != -0.75) and, if true and self.use_mlx_dit is set, call logger.info with a concise message stating that eta/momentum are ignored on the MLX backend; place this adjacent to the existing DCW/haar notice so users see the same kind of feedback as in the other branch and reference the variables eta, momentum, generate_kwargs, self.use_mlx_dit, and logger.info when making the change.

acestep/api/http/release_task_models.py (2)
100-103: Recommend constraining `sampler_mode` to `Literal["euler", "heun"]` for consistency with other enum-like fields.

The description restricts the value to "euler" or "heun", and other enum-style fields in this model use `Literal[...]` (`chunk_mask_mode` L65, `repaint_mode` L74). Today, an arbitrary string passes Pydantic validation and is then silently coerced to euler at the model layer (`use_heun = sampler_mode == "heun"` in modeling_acestep_v15_base.py), so a typo like "heum" would produce a wrong-but-quiet sampler choice rather than a clean 422. Tightening the type at the API boundary surfaces the mistake to the caller.

♻️ Suggested change
-    sampler_mode: str = Field(
+    sampler_mode: Literal["euler", "heun"] = Field(
         default="euler",
         description=...,
     )

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@acestep/api/http/release_task_models.py` around lines 100 - 103, The sampler_mode field currently declared as sampler_mode: str = Field(...) should be constrained to a Literal to enforce only "euler" or "heun"; update the annotation to sampler_mode: Literal["euler", "heun"] = Field(default="euler", description=...) and add an import for Literal (from typing import Literal) if not already present so Pydantic will return a 422 for invalid values instead of silently accepting arbitrary strings; keep the default "euler" and preserve the existing description.
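The effect the reviewer is after can be sketched without Pydantic using `typing.Literal` plus `get_args`; in the real change the `Literal` annotation would sit on the Pydantic `Field` so invalid input fails with a 422. The parser function here is a hypothetical stand-in.

```python
from typing import Literal, get_args

SamplerMode = Literal["euler", "heun"]

def parse_sampler_mode(value: str) -> str:
    # Reject anything outside the Literal's allowed values, so a typo
    # like "heum" fails loudly instead of silently falling back to euler.
    if value not in get_args(SamplerMode):
        raise ValueError(f"invalid sampler_mode: {value!r}")
    return value

print(parse_sampler_mode("heun"))  # heun
```

`get_args(SamplerMode)` returns `("euler", "heun")`, so the allowed set is derived from the annotation rather than duplicated in the check.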
104-127: Optional: add range validation to the new numeric knobs.

These fields propagate directly into the diffusion loop without further validation. A few sensible Pydantic constraints would catch obvious misuse early:

- `velocity_norm_threshold`: `ge=0.0` (negative makes no sense; `use_norm_clamp` checks `> 0.0`).
- `velocity_ema_factor`: typically `ge=0.0, le=1.0`; values outside this break the convex blend in `vt = (1 - f) * vt + f * prev_vt`.
- `latent_rescale`: `gt=0.0` (zero rescale would null out latents; negative flips sign).
- `eta`/`momentum` are passed through APG math where any float is technically defined; leaving them unconstrained is fine, but you may want a soft sanity range (e.g. `eta` in `[0, 1]`) if matching upstream APG conventions.

Not a blocker; purely defensive.
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@acestep/api/http/release_task_models.py` around lines 104 - 127, Add Pydantic range constraints to the new numeric knobs so invalid values are caught early: set velocity_norm_threshold Field to ge=0.0, set velocity_ema_factor Field to ge=0.0 and le=1.0, and set latent_rescale Field to gt=0.0; leave eta and momentum unconstrained (or optionally add soft ranges like eta between 0 and 1) but ensure the changes are applied to the Field declarations for velocity_norm_threshold, velocity_ema_factor, latent_rescale (and optionally eta/momentum) in the ReleaseTask model.
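The suggested ranges can be mirrored in plain Python to make the intent concrete. The real fix would pass `ge`/`le`/`gt` to Pydantic's `Field(...)`; this standalone validator is a hypothetical stand-in that encodes the same bounds.

```python
def validate_latent_knobs(velocity_norm_threshold: float,
                          velocity_ema_factor: float,
                          latent_rescale: float) -> None:
    # ge=0.0: a negative threshold is meaningless (use_norm_clamp checks > 0.0)
    if velocity_norm_threshold < 0.0:
        raise ValueError("velocity_norm_threshold must be >= 0.0")
    # ge=0.0, le=1.0: outside [0, 1] the EMA blend is no longer convex
    if not 0.0 <= velocity_ema_factor <= 1.0:
        raise ValueError("velocity_ema_factor must be in [0.0, 1.0]")
    # gt=0.0: zero nulls out the latents, negative flips their sign
    if latent_rescale <= 0.0:
        raise ValueError("latent_rescale must be > 0.0")

validate_latent_knobs(0.0, 0.0, 1.0)  # the stated defaults pass
```

Putting the same bounds on the Pydantic model instead means invalid requests are rejected at the API boundary with a structured 422 rather than deep in the diffusion loop.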
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.
Inline comments:
In `@acestep/models/common/apg_guidance_test.py`:
- Around line 89-107: The test
test_default_momentum_buffer_matches_explicit_neg_075 is tautological because
each MomentumBuffer is fresh and the first update is momentum-independent;
modify the test so the buffers see a non-zero running_average before the
equality check: e.g., call apg_forward twice for each buffer (using the same
pred_cond/pred_uncond/guidance_scale/dims) so the second update depends on
momentum, then assert the default MomentumBuffer() output equals
MomentumBuffer(momentum=-0.75) output on that second call; additionally add a
negative test that the default does not match a clearly different momentum
(e.g., MomentumBuffer(momentum=+0.5)) to ensure the default is actually -0.75.
Ensure changes are made inside the
test_default_momentum_buffer_matches_explicit_neg_075 function and reference
apg_forward and MomentumBuffer.
---
Nitpick comments:
In `@acestep/api/http/release_task_models.py`:
- Around line 100-103: The sampler_mode field currently declared as
sampler_mode: str = Field(...) should be constrained to a Literal to enforce
only "euler" or "heun"; update the annotation to sampler_mode: Literal["euler",
"heun"] = Field(default="euler", description=...) and add an import for Literal
(from typing import Literal) if not already present so Pydantic will return a
422 for invalid values instead of silently accepting arbitrary strings; keep the
default "euler" and preserve the existing description.
- Around line 104-127: Add Pydantic range constraints to the new numeric knobs
so invalid values are caught early: set velocity_norm_threshold Field to ge=0.0,
set velocity_ema_factor Field to ge=0.0 and le=1.0, and set latent_rescale Field
to gt=0.0; leave eta and momentum unconstrained (or optionally add soft ranges
like eta between 0 and 1) but ensure the changes are applied to the Field
declarations for velocity_norm_threshold, velocity_ema_factor, latent_rescale
(and optionally eta/momentum) in the ReleaseTask model.
In `@acestep/core/generation/handler/service_generate_execute.py`:
- Around line 131-132: The MLX code path silently ignores eta and momentum
passed through generate_kwargs; update the MLX branch in
service_generate_execute.py (the block guarded by self.use_mlx_dit) to log a
one-line informational warning when non-default values are provided: check (eta
!= 0.0 or momentum != -0.75) and, if true and self.use_mlx_dit is set, call
logger.info with a concise message stating that eta/momentum are ignored on the
MLX backend; place this adjacent to the existing DCW/haar notice so users see
the same kind of feedback as in the other branch and reference the variables
eta, momentum, generate_kwargs, self.use_mlx_dit, and logger.info when making
the change.
ℹ️ Review info
⚙️ Run configuration
Configuration used: Organization UI
Review profile: CHILL
Plan: Pro
Run ID: 19c8b5d8-a49e-46f7-b003-7b422952cd7b
📒 Files selected for processing (18)
- acestep/api/http/release_task_models.py
- acestep/api/http/release_task_models_test.py
- acestep/api/http/release_task_param_parser.py
- acestep/api/http/release_task_param_parser_test.py
- acestep/api/http/release_task_request_builder.py
- acestep/api/http/release_task_request_builder_test.py
- acestep/api/job_generation_setup.py
- acestep/api/job_generation_setup_test.py
- acestep/core/generation/handler/generate_music.py
- acestep/core/generation/handler/generate_music_execute.py
- acestep/core/generation/handler/service_generate.py
- acestep/core/generation/handler/service_generate_execute.py
- acestep/inference.py
- acestep/models/base/modeling_acestep_v15_base.py
- acestep/models/common/apg_guidance_test.py
- acestep/models/sft/modeling_acestep_v15_base.py
- acestep/models/xl_base/modeling_acestep_v15_xl_base.py
- acestep/models/xl_sft/modeling_acestep_v15_xl_base.py
    def test_default_momentum_buffer_matches_explicit_neg_075(self):
        """Default MomentumBuffer() must produce the same trajectory as momentum=-0.75."""

        pred_cond, pred_uncond = self._make_inputs()
        implicit_out = apg_forward(
            pred_cond=pred_cond,
            pred_uncond=pred_uncond,
            guidance_scale=3.0,
            momentum_buffer=MomentumBuffer(),
            dims=[1],
        )
        explicit_out = apg_forward(
            pred_cond=pred_cond,
            pred_uncond=pred_uncond,
            guidance_scale=3.0,
            momentum_buffer=MomentumBuffer(momentum=-0.75),
            dims=[1],
        )
        self.assertTrue(torch.allclose(implicit_out, explicit_out))
This default-momentum test is a tautology and does not actually verify the -0.75 default.
Each MomentumBuffer is freshly constructed with `running_average = 0`, and `update()` computes `running_average = momentum * 0 + diff = diff` on the first call regardless of momentum. Because `apg_forward` is invoked exactly once per buffer, both branches produce the same output for any momentum value (including, e.g., +0.5). If someone changed MomentumBuffer's default to a different value, this test would still pass.
To actually exercise the default, either (a) invoke apg_forward twice on each buffer so the second update sees a non-zero running average, or (b) contrast against a buffer with a clearly different momentum and assert non-equality.
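The momentum-independence of the first update is easy to see with a scalar stand-in for MomentumBuffer. The real buffer operates on tensors; this toy class only mirrors the update rule quoted above (`running_average = momentum * running_average + diff`).

```python
class ScalarMomentumBuffer:
    """Toy scalar version of the buffer: running_average = momentum * running_average + diff."""

    def __init__(self, momentum: float = -0.75):
        self.momentum = momentum
        self.running_average = 0.0

    def update(self, diff: float) -> None:
        self.running_average = self.momentum * self.running_average + diff

default_buf = ScalarMomentumBuffer(momentum=-0.75)
other_buf = ScalarMomentumBuffer(momentum=0.5)

default_buf.update(1.0)
other_buf.update(1.0)
# First update: momentum * 0 + 1.0 == 1.0 for ANY momentum -> indistinguishable.
assert default_buf.running_average == other_buf.running_average == 1.0

default_buf.update(1.0)
other_buf.update(1.0)
# Second update: -0.75 * 1.0 + 1.0 = 0.25 vs 0.5 * 1.0 + 1.0 = 1.5 -> momentum matters.
assert default_buf.running_average == 0.25
assert other_buf.running_average == 1.5
```

This is exactly why the test needs a second `apg_forward` call per buffer: only then does the decayed history enter the output.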
🛠️ Suggested fix: drive the buffers through two updates and contrast against a mismatched momentum
def test_default_momentum_buffer_matches_explicit_neg_075(self):
"""Default MomentumBuffer() must produce the same trajectory as momentum=-0.75."""
pred_cond, pred_uncond = self._make_inputs()
- implicit_out = apg_forward(
- pred_cond=pred_cond,
- pred_uncond=pred_uncond,
- guidance_scale=3.0,
- momentum_buffer=MomentumBuffer(),
- dims=[1],
- )
- explicit_out = apg_forward(
- pred_cond=pred_cond,
- pred_uncond=pred_uncond,
- guidance_scale=3.0,
- momentum_buffer=MomentumBuffer(momentum=-0.75),
- dims=[1],
- )
- self.assertTrue(torch.allclose(implicit_out, explicit_out))
+ implicit_buf = MomentumBuffer()
+ explicit_buf = MomentumBuffer(momentum=-0.75)
+ mismatched_buf = MomentumBuffer(momentum=0.5)
+ for buf in (implicit_buf, explicit_buf, mismatched_buf):
+ apg_forward(
+ pred_cond=pred_cond,
+ pred_uncond=pred_uncond,
+ guidance_scale=3.0,
+ momentum_buffer=buf,
+ dims=[1],
+ )
+ # Second call exercises momentum * prev term so the default value matters.
+ implicit_out = apg_forward(
+ pred_cond=pred_cond,
+ pred_uncond=pred_uncond,
+ guidance_scale=3.0,
+ momentum_buffer=implicit_buf,
+ dims=[1],
+ )
+ explicit_out = apg_forward(
+ pred_cond=pred_cond,
+ pred_uncond=pred_uncond,
+ guidance_scale=3.0,
+ momentum_buffer=explicit_buf,
+ dims=[1],
+ )
+ mismatched_out = apg_forward(
+ pred_cond=pred_cond,
+ pred_uncond=pred_uncond,
+ guidance_scale=3.0,
+ momentum_buffer=mismatched_buf,
+ dims=[1],
+ )
+ self.assertTrue(torch.allclose(implicit_out, explicit_out))
+ self.assertFalse(torch.allclose(implicit_out, mismatched_out))🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.
In `@acestep/models/common/apg_guidance_test.py` around lines 89 - 107, The test
test_default_momentum_buffer_matches_explicit_neg_075 is tautological because
each MomentumBuffer is fresh and the first update is momentum-independent;
modify the test so the buffers see a non-zero running_average before the
equality check: e.g., call apg_forward twice for each buffer (using the same
pred_cond/pred_uncond/guidance_scale/dims) so the second update depends on
momentum, then assert the default MomentumBuffer() output equals
MomentumBuffer(momentum=-0.75) output on that second call; additionally add a
negative test that the default does not match a clearly different momentum
(e.g., MomentumBuffer(momentum=+0.5)) to ensure the default is actually -0.75.
Ensure changes are made inside the
test_default_momentum_buffer_matches_explicit_neg_075 function and reference
apg_forward and MomentumBuffer.
Summary
Surfaces the two remaining APG (Adaptive Projected Guidance) knobs, `eta` and `momentum`, on the `/release_task` HTTP API, following the pattern of #1092.

Stacking note
This branch is based on the tip of #1092, not on `main`, because both PRs touch release_task_models.py, release_task_param_parser.py, release_task_request_builder.py, and job_generation_setup.py. Filing as a Draft; I'll rebase onto `main` and mark ready-for-review once #1092 merges. If you'd prefer I close this and refile after #1092 lands, happy to.

Context
Public community workflows (e.g. the Civitai ACE-Step 1.5 XL workflow) tune APG using `eta` (parallel-component scaling, currently hardcoded to 0.0 at acestep/models/common/apg_guidance.py:38) and `momentum` (the MomentumBuffer decay factor, currently hardcoded to -0.75 at acestep/models/common/apg_guidance.py:7). Both are frozen at the MomentumBuffer()/apg_forward() call sites inside each DiT model, so HTTP API clients can't tune them.

#1092 established the pattern for exposing internal DiT params on /release_task. This PR reuses that pattern for `eta` and `momentum` so HTTP clients reach parity with ComfyUI users for APG tuning.

Changes
- GenerationParams gains `eta: float = 0.0` and `momentum: float = -0.75` (both defaults match the previous hardcoded values; see Backward compatibility below).
- generate_audio on the four DiT model variants that use APG (base, sft, xl_base, xl_sft) accepts `eta` and `momentum` as keyword-only params and forwards them to MomentumBuffer(momentum=...) and apg_forward(..., eta=...).
- The turbo variants (turbo, xl_turbo) are distilled without CFG → no APG path → not affected.
- The values thread through generate_music → _run_generate_music_service_with_progress → service_generate → _build_service_generate_kwargs.
- GenerateMusicRequest, PARAM_ALIASES, build_generate_music_request, and build_generation_setup all accept and forward the new fields.
- Tests cover GenerationSetup wiring, and a model-level invariant that omitting the new kwargs yields the same output as passing the historical defaults explicitly.

One unrelated fix bundled
Commit 6e95ec3 also repairs _FakeParser.get() in release_task_request_builder_test.py, which was missing a `default=` parameter; 5 pre-existing tests were red because of it. Bundled here because the new tests added by this PR exercise the same helper and need the fix to pass.

Per CONTRIBUTING.md's "Solve One Problem at a Time" principle I'd normally split this into its own PR. Happy to do so if you prefer; the parser fix is one commit, isolatable with `git cherry-pick`.

Backward compatibility
All new parameters are optional. Defaults (`eta=0.0`, `momentum=-0.75`) match the values the code used before this PR. Clients that omit the new fields see byte-identical behavior.

MLX scope
PyTorch only. MLX's `_mlx_apg_forward` in acestep/models/mlx/dit_generate.py:102 has no `eta` parameter at all, and its momentum state is an unbounded running sum (no decay factor). Exposing these knobs on the MLX backend is a non-trivial separate change, not in scope here per "minimize blast radius."

Test plan
- `python -m unittest acestep.api.http.release_task_param_parser_test acestep.api.http.release_task_models_test acestep.api.http.release_task_request_builder_test acestep.api.job_generation_setup_test acestep.models.common.apg_guidance_test` (38 tests green)
- `/release_task` with `{"eta": 0.5, "momentum": -0.5}`: confirm the values round-trip into the stored GenerationParams
- Existing `/release_task` clients still work without specifying the new fields

Related
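The manual round-trip check in the test plan can be sketched with the stdlib `json` module; the payload keys come from the PR text, while the actual HTTP call and any other required request fields are elided.

```python
import json

# Body a client would send to /release_task; other request fields omitted.
payload = {"eta": 0.5, "momentum": -0.5}
body = json.dumps(payload)

# Round-trip: the server-side parser should see the same float values.
parsed = json.loads(body)
assert parsed["eta"] == 0.5 and parsed["momentum"] == -0.5

# Omitting the fields must fall back to the historical defaults.
defaults = {"eta": 0.0, "momentum": -0.75}
merged = {**defaults, **json.loads("{}")}
assert merged == defaults
```

The second assertion is the backward-compatibility invariant in miniature: an empty body yields exactly the pre-PR values.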
Summary by CodeRabbit
Release Notes
New Features
Tests