feat(api): expose APG eta and momentum on /release_task HTTP API #1123
FlexOr2 wants to merge 4 commits into ace-step:main from …
📝 Walkthrough
This pull request adds new generation control parameters to the music generation API and threads them through the entire generation pipeline. New fields enable caller control over diffusion sampling (…).
Changes: Sampling control propagation
Sequence Diagram(s)
sequenceDiagram
participant Client
participant RequestParser
participant RequestBuilder
participant GenerationSetup
participant GenerationHandler
participant Model
Client->>RequestParser: HTTP request with eta, momentum, sampler_mode, etc.
RequestParser->>RequestParser: Resolve parameters via PARAM_ALIASES
RequestParser-->>RequestBuilder: Return parsed parameters
RequestBuilder->>RequestBuilder: Build GenerateMusicRequest with new fields
RequestBuilder-->>GenerationSetup: Pass request object
GenerationSetup->>GenerationSetup: Create GenerationParams with eta, momentum, sampler_mode
GenerationSetup-->>GenerationHandler: Pass setup configuration
GenerationHandler->>GenerationHandler: Extract eta, momentum from setup.params
GenerationHandler-->>Model: Call generate_audio(eta=..., momentum=...)
Model->>Model: Initialize MomentumBuffer(momentum=momentum)
Model->>Model: Call apg_forward(..., eta=eta, momentum_buffer=...)
Model-->>Client: Return generated audio
Estimated code review effort: 🎯 3 (Moderate) | ⏱️ ~20 minutes
🚥 Pre-merge checks: ✅ 5 passed
Actionable comments posted: 1
🧹 Nitpick comments (3)
acestep/core/generation/handler/service_generate_execute.py (1)
131-132: Optional: warn when `eta`/`momentum` are set on the MLX path.
`eta` and `momentum` are accepted into `generate_kwargs` but the MLX branch (lines 175-236) never reads them, by design per the PR description ("MLX backend is out of scope"). However, a user passing non-default `eta`/`momentum` via `/release_task` while the MLX backend is active will get silent no-op semantics. A one-line `logger.info` (similar to the existing DCW/`haar` notice on lines 176-183) when `(eta != 0.0 or momentum != -0.75)` and `self.use_mlx_dit` is true would prevent confusion and reduce future support load.
♻️ Suggested log on the MLX branch
 if self.use_mlx_dit and self.mlx_decoder is not None:
+    if generate_kwargs.get("eta", 0.0) != 0.0 or generate_kwargs.get("momentum", -0.75) != -0.75:
+        logger.info(
+            "[service_generate] APG eta/momentum overrides "
+            "(eta={}, momentum={}) are ignored on the MLX backend; "
+            "these are PyTorch-only.",
+            generate_kwargs.get("eta"),
+            generate_kwargs.get("momentum"),
+        )
     if generate_kwargs.get("dcw_enabled") and generate_kwargs.get("dcw_wavelet", "haar") != "haar":
Also applies to: 175-236
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@acestep/core/generation/handler/service_generate_execute.py` around lines 131 - 132, The MLX code path silently ignores eta and momentum passed through generate_kwargs; update the MLX branch in service_generate_execute.py (the block guarded by self.use_mlx_dit) to log a one-line informational warning when non-default values are provided: check (eta != 0.0 or momentum != -0.75) and, if true and self.use_mlx_dit is set, call logger.info with a concise message stating that eta/momentum are ignored on the MLX backend; place this adjacent to the existing DCW/haar notice so users see the same kind of feedback as in the other branch and reference the variables eta, momentum, generate_kwargs, self.use_mlx_dit, and logger.info when making the change.
acestep/api/http/release_task_models.py (2)
100-103: Recommend constraining `sampler_mode` to `Literal["euler", "heun"]` for consistency with other enum-like fields.
The description restricts the value to `"euler"` or `"heun"`, and other enum-style fields in this model use `Literal[...]` (`chunk_mask_mode` L65, `repaint_mode` L74). Today, an arbitrary string passes Pydantic validation and is then silently coerced to euler at the model layer (`use_heun = sampler_mode == "heun"` in `modeling_acestep_v15_base.py`), so a typo like `"heum"` would produce a wrong-but-quiet sampler choice rather than a clean 422. Tightening the type at the API boundary surfaces the mistake to the caller.
♻️ Suggested change
-    sampler_mode: str = Field(
+    sampler_mode: Literal["euler", "heun"] = Field(
         default="euler",
         description="Diffusion sampler mode: 'euler' (first-order) or 'heun' (second-order predictor-corrector).",
     )
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@acestep/api/http/release_task_models.py` around lines 100 - 103, The sampler_mode field currently declared as sampler_mode: str = Field(...) should be constrained to a Literal to enforce only "euler" or "heun"; update the annotation to sampler_mode: Literal["euler", "heun"] = Field(default="euler", description=...) and add an import for Literal (from typing import Literal) if not already present so Pydantic will return a 422 for invalid values instead of silently accepting arbitrary strings; keep the default "euler" and preserve the existing description.
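To make the failure mode concrete, here is a minimal stdlib-only sketch. `pick_sampler_unchecked` mimics the `use_heun = sampler_mode == "heun"` comparison quoted in the comment, and `validate_sampler_mode` is a hypothetical stand-in for the membership check that a `Literal["euler", "heun"]` annotation would let Pydantic perform at the API boundary. Neither function exists in the repository; both names are assumptions made for illustration.

```python
from typing import Literal, get_args

SamplerMode = Literal["euler", "heun"]


def pick_sampler_unchecked(sampler_mode: str) -> str:
    """Mimic the model-layer comparison: anything that is not 'heun' becomes euler."""
    use_heun = sampler_mode == "heun"
    return "heun" if use_heun else "euler"


def validate_sampler_mode(value: str) -> str:
    """Hypothetical boundary check: reject values outside the Literal's members."""
    allowed = get_args(SamplerMode)
    if value not in allowed:
        raise ValueError(f"sampler_mode must be one of {allowed}, got {value!r}")
    return value


if __name__ == "__main__":
    # A typo silently selects euler when only the equality check is applied.
    print(pick_sampler_unchecked("heum"))
    try:
        validate_sampler_mode("heum")
    except ValueError as exc:
        print("rejected:", exc)
```

The point of the sketch is only the contrast: the unchecked path accepts the typo and quietly falls back, while the boundary check turns it into an explicit error the caller can see.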
104-127: Optional: add range validation to the new numeric knobs.
These fields propagate directly into the diffusion loop without further validation. A few sensible Pydantic constraints would catch obvious misuse early:
- `velocity_norm_threshold`: `ge=0.0` (negative makes no sense; `use_norm_clamp` checks `> 0.0`).
- `velocity_ema_factor`: typically `ge=0.0, le=1.0`; values outside this break the convex blend in `vt = (1 - f) * vt + f * prev_vt`.
- `latent_rescale`: `gt=0.0` (zero rescale would null out latents; negative flips sign).
- `eta`/`momentum` are passed through APG math where any float is technically defined; leaving them unconstrained is fine, but you may want a soft sanity range (e.g. `eta` in `[0, 1]`) if matching upstream APG conventions.
Not a blocker; purely defensive.
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@acestep/api/http/release_task_models.py` around lines 104 - 127, Add Pydantic range constraints to the new numeric knobs so invalid values are caught early: set velocity_norm_threshold Field to ge=0.0, set velocity_ema_factor Field to ge=0.0 and le=1.0, and set latent_rescale Field to gt=0.0; leave eta and momentum unconstrained (or optionally add soft ranges like eta between 0 and 1) but ensure the changes are applied to the Field declarations for velocity_norm_threshold, velocity_ema_factor, latent_rescale (and optionally eta/momentum) in the ReleaseTask model.
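The convexity argument for `velocity_ema_factor` can be checked numerically. The sketch below applies the blend `vt = (1 - f) * vt + f * prev_vt` quoted in the comment to plain floats. The helper names `ema_blend` and `is_between` and the sample values are illustrative assumptions, not code from the diffusion loop.

```python
def ema_blend(vt: float, prev_vt: float, f: float) -> float:
    """Blend the current and previous velocity with factor f."""
    return (1 - f) * vt + f * prev_vt


def is_between(x: float, a: float, b: float) -> bool:
    """True when x lies in the closed interval spanned by a and b."""
    lo, hi = min(a, b), max(a, b)
    return lo <= x <= hi


if __name__ == "__main__":
    vt, prev_vt = 2.0, 4.0
    for f in (0.0, 0.25, 1.0, 1.5, -0.5):
        blended = ema_blend(vt, prev_vt, f)
        # For f in [0, 1] the result stays between vt and prev_vt;
        # outside that range the blend extrapolates past one of them.
        print(f, blended, is_between(blended, vt, prev_vt))
```

This is why `ge=0.0, le=1.0` is a natural constraint: it is exactly the range in which the blend remains an interpolation.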
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.
Inline comments:
In `@acestep/models/common/apg_guidance_test.py`:
- Around line 89-107: The test
test_default_momentum_buffer_matches_explicit_neg_075 is tautological because
each MomentumBuffer is fresh and the first update is momentum-independent;
modify the test so the buffers see a non-zero running_average before the
equality check: e.g., call apg_forward twice for each buffer (using the same
pred_cond/pred_uncond/guidance_scale/dims) so the second update depends on
momentum, then assert the default MomentumBuffer() output equals
MomentumBuffer(momentum=-0.75) output on that second call; additionally add a
negative test that the default does not match a clearly different momentum
(e.g., MomentumBuffer(momentum=+0.5)) to ensure the default is actually -0.75.
Ensure changes are made inside the
test_default_momentum_buffer_matches_explicit_neg_075 function and reference
apg_forward and MomentumBuffer.
---
Nitpick comments:
In `@acestep/api/http/release_task_models.py`:
- Around line 100-103: The sampler_mode field currently declared as
sampler_mode: str = Field(...) should be constrained to a Literal to enforce
only "euler" or "heun"; update the annotation to sampler_mode: Literal["euler",
"heun"] = Field(default="euler", description=...) and add an import for Literal
(from typing import Literal) if not already present so Pydantic will return a
422 for invalid values instead of silently accepting arbitrary strings; keep the
default "euler" and preserve the existing description.
- Around line 104-127: Add Pydantic range constraints to the new numeric knobs
so invalid values are caught early: set velocity_norm_threshold Field to ge=0.0,
set velocity_ema_factor Field to ge=0.0 and le=1.0, and set latent_rescale Field
to gt=0.0; leave eta and momentum unconstrained (or optionally add soft ranges
like eta between 0 and 1) but ensure the changes are applied to the Field
declarations for velocity_norm_threshold, velocity_ema_factor, latent_rescale
(and optionally eta/momentum) in the ReleaseTask model.
In `@acestep/core/generation/handler/service_generate_execute.py`:
- Around line 131-132: The MLX code path silently ignores eta and momentum
passed through generate_kwargs; update the MLX branch in
service_generate_execute.py (the block guarded by self.use_mlx_dit) to log a
one-line informational warning when non-default values are provided: check (eta
!= 0.0 or momentum != -0.75) and, if true and self.use_mlx_dit is set, call
logger.info with a concise message stating that eta/momentum are ignored on the
MLX backend; place this adjacent to the existing DCW/haar notice so users see
the same kind of feedback as in the other branch and reference the variables
eta, momentum, generate_kwargs, self.use_mlx_dit, and logger.info when making
the change.
ℹ️ Review info
⚙️ Run configuration
Configuration used: Organization UI
Review profile: CHILL
Plan: Pro
Run ID: 19c8b5d8-a49e-46f7-b003-7b422952cd7b
📒 Files selected for processing (18)
- acestep/api/http/release_task_models.py
- acestep/api/http/release_task_models_test.py
- acestep/api/http/release_task_param_parser.py
- acestep/api/http/release_task_param_parser_test.py
- acestep/api/http/release_task_request_builder.py
- acestep/api/http/release_task_request_builder_test.py
- acestep/api/job_generation_setup.py
- acestep/api/job_generation_setup_test.py
- acestep/core/generation/handler/generate_music.py
- acestep/core/generation/handler/generate_music_execute.py
- acestep/core/generation/handler/service_generate.py
- acestep/core/generation/handler/service_generate_execute.py
- acestep/inference.py
- acestep/models/base/modeling_acestep_v15_base.py
- acestep/models/common/apg_guidance_test.py
- acestep/models/sft/modeling_acestep_v15_base.py
- acestep/models/xl_base/modeling_acestep_v15_xl_base.py
- acestep/models/xl_sft/modeling_acestep_v15_xl_base.py
    def test_default_momentum_buffer_matches_explicit_neg_075(self):
        """Default MomentumBuffer() must produce the same trajectory as momentum=-0.75."""
        pred_cond, pred_uncond = self._make_inputs()
        implicit_out = apg_forward(
            pred_cond=pred_cond,
            pred_uncond=pred_uncond,
            guidance_scale=3.0,
            momentum_buffer=MomentumBuffer(),
            dims=[1],
        )
        explicit_out = apg_forward(
            pred_cond=pred_cond,
            pred_uncond=pred_uncond,
            guidance_scale=3.0,
            momentum_buffer=MomentumBuffer(momentum=-0.75),
            dims=[1],
        )
        self.assertTrue(torch.allclose(implicit_out, explicit_out))
This default-momentum test is a tautology and does not actually verify the -0.75 default.
Each MomentumBuffer is freshly constructed with running_average = 0, and update() computes running_average = momentum * 0 + diff = diff on the first call regardless of momentum. Because apg_forward is invoked exactly once per buffer, both branches produce the same output for any momentum value (including, e.g., +0.5). If someone changed MomentumBuffer's default to a different value, this test would still pass.
To actually exercise the default, either (a) invoke apg_forward twice on each buffer so the second update sees a non-zero running average, or (b) contrast against a buffer with a clearly different momentum and assert non-equality.
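A minimal sketch of why the first update hides the momentum value. `SketchMomentumBuffer` is a hypothetical scalar reimplementation of the update rule described above (`running_average = momentum * running_average + diff`), not the class from `apg_guidance.py`, and the input value `2.0` is arbitrary.

```python
class SketchMomentumBuffer:
    """Scalar stand-in for the running-average update described in the review."""

    def __init__(self, momentum: float = -0.75) -> None:
        self.momentum = momentum
        self.running_average = 0.0

    def update(self, diff: float) -> float:
        # With running_average == 0 the momentum term vanishes on the first call.
        self.running_average = self.momentum * self.running_average + diff
        return self.running_average


if __name__ == "__main__":
    default_buf = SketchMomentumBuffer()
    other_buf = SketchMomentumBuffer(momentum=0.5)

    first = (default_buf.update(2.0), other_buf.update(2.0))
    second = (default_buf.update(2.0), other_buf.update(2.0))
    print("first update:", first)    # identical for both momentum values
    print("second update:", second)  # differs once the running average is non-zero
```

Only the second update separates the two buffers, which is the behavior the reviewer's two-call suggestion relies on.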
🛠️ Suggested fix — drive the buffers through two updates and contrast against a mismatched momentum
     def test_default_momentum_buffer_matches_explicit_neg_075(self):
         """Default MomentumBuffer() must produce the same trajectory as momentum=-0.75."""
         pred_cond, pred_uncond = self._make_inputs()
-        implicit_out = apg_forward(
-            pred_cond=pred_cond,
-            pred_uncond=pred_uncond,
-            guidance_scale=3.0,
-            momentum_buffer=MomentumBuffer(),
-            dims=[1],
-        )
-        explicit_out = apg_forward(
-            pred_cond=pred_cond,
-            pred_uncond=pred_uncond,
-            guidance_scale=3.0,
-            momentum_buffer=MomentumBuffer(momentum=-0.75),
-            dims=[1],
-        )
-        self.assertTrue(torch.allclose(implicit_out, explicit_out))
+        implicit_buf = MomentumBuffer()
+        explicit_buf = MomentumBuffer(momentum=-0.75)
+        mismatched_buf = MomentumBuffer(momentum=0.5)
+        for buf in (implicit_buf, explicit_buf, mismatched_buf):
+            apg_forward(
+                pred_cond=pred_cond,
+                pred_uncond=pred_uncond,
+                guidance_scale=3.0,
+                momentum_buffer=buf,
+                dims=[1],
+            )
+        # Second call exercises momentum * prev term so the default value matters.
+        implicit_out = apg_forward(
+            pred_cond=pred_cond,
+            pred_uncond=pred_uncond,
+            guidance_scale=3.0,
+            momentum_buffer=implicit_buf,
+            dims=[1],
+        )
+        explicit_out = apg_forward(
+            pred_cond=pred_cond,
+            pred_uncond=pred_uncond,
+            guidance_scale=3.0,
+            momentum_buffer=explicit_buf,
+            dims=[1],
+        )
+        mismatched_out = apg_forward(
+            pred_cond=pred_cond,
+            pred_uncond=pred_uncond,
+            guidance_scale=3.0,
+            momentum_buffer=mismatched_buf,
+            dims=[1],
+        )
+        self.assertTrue(torch.allclose(implicit_out, explicit_out))
+        self.assertFalse(torch.allclose(implicit_out, mismatched_out))
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.
In `@acestep/models/common/apg_guidance_test.py` around lines 89 - 107, The test
test_default_momentum_buffer_matches_explicit_neg_075 is tautological because
each MomentumBuffer is fresh and the first update is momentum-independent;
modify the test so the buffers see a non-zero running_average before the
equality check: e.g., call apg_forward twice for each buffer (using the same
pred_cond/pred_uncond/guidance_scale/dims) so the second update depends on
momentum, then assert the default MomentumBuffer() output equals
MomentumBuffer(momentum=-0.75) output on that second call; additionally add a
negative test that the default does not match a clearly different momentum
(e.g., MomentumBuffer(momentum=+0.5)) to ensure the default is actually -0.75.
Ensure changes are made inside the
test_default_momentum_buffer_matches_explicit_neg_075 function and reference
apg_forward and MomentumBuffer.
… API
Five GenerationParams fields that were only accessible via Gradio's direct Python path are now accepted by the /release_task HTTP endpoint:
- sampler_mode: "euler" (default) or "heun" sampler selection
- velocity_norm_threshold: velocity prediction norm clamping (0=off)
- velocity_ema_factor: velocity EMA smoothing (0=off)
- latent_shift: additive shift on DiT latents before VAE decode
- latent_rescale: multiplicative rescale on DiT latents before VAE decode
All five already exist on the internal GenerationParams dataclass and are wired through the diffusion loop. This change adds them to:
- GenerateMusicRequest (Pydantic request model)
- PARAM_ALIASES (camelCase alias support)
- build_generate_music_request (request builder)
- build_generation_setup (GenerationParams wiring)
Tests added for alias resolution, request builder forwarding, and generation setup wiring (including getattr fallback for older callers). Default values match the existing GenerationParams defaults, so omitting them preserves current behavior. Non-target code paths (Gradio UI, training, CLI) are unchanged.
Co-Authored-By: Claude Opus 4.6 (1M context) <[email protected]>
The APG (Adaptive Projected Guidance) `eta` and `momentum` knobs were hardcoded at the MomentumBuffer() and apg_forward() call sites inside each DiT model. Lift them into the generation plumbing so callers can tune APG without patching model code:
- apg_forward now receives `eta` from the model's generate_audio call (same default 0.0 preserved).
- MomentumBuffer is instantiated with the caller's `momentum` value (same default -0.75 preserved).
- generate_audio on all four DiT model variants (base, sft, xl_base, xl_sft) accepts `eta` and `momentum` as keyword-only args.
- GenerationParams gains two new fields (eta, momentum) that thread down through generate_music -> _run_generate_music_service_with_progress -> service_generate -> _build_service_generate_kwargs -> model.generate_audio.
MLX path is intentionally not touched: `_mlx_apg_forward` has no slot for `eta`, and the MLX MomentumBuffer equivalent uses a fixed running-sum (no decay factor). Exposing these knobs on MLX is a separate upstream change. PyTorch path is unaffected by the MLX gap since the new kwargs are consumed by the PyTorch model's generate_audio directly.
All new parameters are keyword-only with defaults equal to the pre-existing hardcoded values, so omitting them preserves current behavior. Unit test (acestep/models/common/apg_guidance_test.py) verifies this invariant at the apg_forward / MomentumBuffer level.
Co-Authored-By: Claude Opus 4.7 (1M context) <[email protected]>
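As a rough illustration of what `eta` controls, the sketch below follows the published APG formulation, in which the guidance difference is split into components parallel and orthogonal to the conditional prediction and `eta` scales the parallel part. It uses plain Python lists and a hypothetical helper `apg_update_sketch`; norm clamping, the momentum buffer, and tensor `dims` handling are omitted, and the repository's `apg_forward` may differ in detail.

```python
from typing import List, Tuple


def project(v0: List[float], v1: List[float]) -> Tuple[List[float], List[float]]:
    """Split v0 into parts parallel and orthogonal to v1."""
    norm_sq = sum(x * x for x in v1)
    if norm_sq == 0.0:
        # No direction to project onto: treat everything as orthogonal.
        return [0.0] * len(v0), list(v0)
    scale = sum(a * b for a, b in zip(v0, v1)) / norm_sq
    parallel = [scale * x for x in v1]
    orthogonal = [a - p for a, p in zip(v0, parallel)]
    return parallel, orthogonal


def apg_update_sketch(
    pred_cond: List[float],
    pred_uncond: List[float],
    guidance_scale: float,
    *,
    eta: float = 0.0,
) -> List[float]:
    """Guided prediction with the parallel component scaled by eta (assumed form)."""
    diff = [c - u for c, u in zip(pred_cond, pred_uncond)]
    parallel, orthogonal = project(diff, pred_cond)
    update = [o + eta * p for o, p in zip(orthogonal, parallel)]
    return [c + (guidance_scale - 1.0) * u for c, u in zip(pred_cond, update)]


if __name__ == "__main__":
    cond, uncond = [1.0, 0.0], [0.0, 1.0]
    print(apg_update_sketch(cond, uncond, 3.0))           # eta=0.0: parallel part dropped
    print(apg_update_sketch(cond, uncond, 3.0, eta=1.0))  # eta=1.0: full difference kept
```

In this simplified form, `eta=1.0` reduces to ordinary classifier-free guidance, which makes a convenient sanity check when experimenting with the knob.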
The two APG guidance knobs exposed here (`eta` and `momentum`) were already plumbed through GenerationParams and the DiT model stack in the previous commit, but could not be set from the HTTP surface. This change wires them into the /release_task request path, mirroring how PR ace-step#1092 exposed sampler_mode and the four DiT latent params:
- GenerateMusicRequest: new `eta: float = 0.0` and `momentum: float = -0.75` fields (defaults match apg_guidance.py hardcoded values).
- PARAM_ALIASES: identity aliases for `eta` and `momentum` (both names are single lowercase words, so snake_case == camelCase).
- build_generate_music_request: forwards parsed values into the Pydantic request model.
- build_generation_setup: uses `getattr(req, ...)` with the same defaults, matching the existing sampler_mode/latent_shift pattern so older callers that don't set these fields still work.
The _FakeParser helper in release_task_request_builder_test.py gained a `default=None` parameter on `.get()` to match the real parser's contract. This was a pre-existing mismatch that broke several unrelated tests; fixing it here lets the full builder suite run green.
Tests added alongside the changes cover:
- alias resolution for `eta` and `momentum` (flat body + nested param_obj),
- default values on the Pydantic model,
- request-builder forwarding and default-fallback behavior,
- generation-setup forwarding into GenerationParams,
- the getattr-default path when a request object lacks the fields.
All new fields are optional with defaults matching the pre-existing hardcoded APG values; omitting them preserves current HTTP behavior.
Co-Authored-By: Claude Opus 4.7 (1M context) <[email protected]>
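The `getattr`-with-default pattern mentioned for `build_generation_setup` can be sketched as follows. `read_apg_params`, `APG_DEFAULTS`, and the `SimpleNamespace` request objects are hypothetical stand-ins; only the field names and default values come from the commit message.

```python
from types import SimpleNamespace
from typing import Any, Dict

# Defaults taken from the commit message (previously hardcoded APG values).
APG_DEFAULTS: Dict[str, float] = {"eta": 0.0, "momentum": -0.75}


def read_apg_params(req: Any) -> Dict[str, float]:
    """Read eta/momentum from a request-like object, falling back to defaults."""
    return {name: getattr(req, name, default) for name, default in APG_DEFAULTS.items()}


if __name__ == "__main__":
    legacy_req = SimpleNamespace(prompt="older caller without APG fields")
    tuned_req = SimpleNamespace(eta=0.5, momentum=-0.5)
    print(read_apg_params(legacy_req))
    print(read_apg_params(tuned_req))
```

The fallback keeps request objects that predate the new fields working without any change on the caller's side.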
a2a397a to 931be76
Actionable comments posted: 1
🧹 Nitpick comments (2)
acestep/models/xl_base/modeling_acestep_v15_xl_base.py (1)
1848-1895: ⚡ Quick win
`generate_audio` is missing a docstring: guideline violation for a modified public method.
The function signature spans ~47 lines with ~35 parameters, and is now modified by this PR, but has no docstring at all. The new `eta` and `momentum` parameters are undocumented.
📝 Suggested docstring skeleton (add after the signature, before line 1896)
         **kwargs,
     ):
+        """Generate audio latents via flow-matching diffusion sampling.
+
+        Args:
+            text_hidden_states: Text encoder hidden states [B, T_txt, D].
+            text_attention_mask: Attention mask for text [B, T_txt].
+            lyric_hidden_states: Lyric encoder hidden states [B, T_lyr, D].
+            lyric_attention_mask: Attention mask for lyrics [B, T_lyr].
+            refer_audio_acoustic_hidden_states_packed: Packed timbre features.
+            refer_audio_order_mask: Batch-assignment indices for packed timbre.
+            src_latents: Source latents (context / cover reference) [B, T, D].
+            chunk_masks: Binary chunk masks [B, T, D].
+            is_covers: Per-sample cover flag tensor [B].
+            eta: APG parallel-component scale forwarded to ``apg_forward``.
+                0.0 drops the parallel component; defaults to ``0.0``.
+            momentum: Exponential decay factor for ``MomentumBuffer``; negative
+                values invert the momentum direction. Defaults to ``-0.75``.
+            ... (remaining args omitted for brevity; fill in as needed)
+
+        Returns:
+            Dict with keys ``target_latents`` and ``time_costs``.
+        """
         # Backward-compat: accept the old misspelled key ...
As per coding guidelines: "Docstrings are mandatory for all new or modified Python modules, classes, and functions."
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the rest with a brief reason, keep changes minimal, and validate. In `@acestep/models/xl_base/modeling_acestep_v15_xl_base.py` around lines 1848 - 1895, The public method generate_audio is missing a docstring; add a concise docstring immediately after the generate_audio(...) signature describing the method purpose, summary of inputs and outputs, and document key parameters (including newly added eta and momentum) with their types and default values, any important behavior (e.g., infer_method, infer_steps, diffusion_guidance_scale, use_adg, sampler_mode, timesteps, and return value), and note side effects or expected tensor shapes where relevant so the modified public API is fully documented for callers and linters.
acestep/core/generation/handler/generate_music.py (1)
239-275: ⚡ Quick win
`eta` and `momentum` are missing from the `generate_music` docstring.
The Args section documents `dcw_*` and other sampler params added in prior PRs but omits the two new APG parameters.
📝 Proposed docstring addition
             dcw_wavelet: PyWavelets basis, e.g. ``"haar"`` / ``"db4"`` / ``"sym8"``.
                 On the MLX path only ``"haar"`` is implemented natively; other
                 bases warn once and fall back to Haar.
+            eta: APG parallel-component scale passed to ``apg_forward``.
+                0.0 drops the parallel component; defaults to ``0.0``.
+            momentum: Exponential decay factor for the APG momentum buffer.
+                Negative values invert the momentum direction.
+                Defaults to ``-0.75``.
             timesteps: Optional custom timestep schedule.
As per coding guidelines: "Docstrings are mandatory for all new or modified Python modules, classes, and functions."
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the rest with a brief reason, keep changes minimal, and validate. In `@acestep/core/generation/handler/generate_music.py` around lines 239 - 275, The generate_music docstring omits the new APG sampler parameters eta and momentum; update the Args section of the generate_music function to document both eta (APG parallel-component scale) and momentum (APG momentum coefficient), describing their types, default behavior, and how they affect sampling (e.g., float, optional, influence the parallel guidance component and the momentum term during APG sampling), and place them near the other sampler/DCW params (dcw_*, infer_method, timesteps) so readers can find all sampler-related options together.
🤖 Prompt for all review comments with AI agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.
Inline comments:
In `@acestep/models/xl_base/modeling_acestep_v15_xl_base.py`:
- Around line 1892-1893: Add a concise docstring to the public method
generate_audio describing its purpose (what kind of audio it generates and any
high-level behavior), listing key parameters (including eta and momentum if they
influence generation or are passed through; mention their types and typical
ranges), the return value (type and meaning, e.g., waveform tensor or file
path), and any exceptions that can be raised (e.g., ValueError for invalid
inputs, RuntimeError for inference failures). Keep it brief and follow the style
of existing docstrings in the class so it satisfies the project's
mandatory-docstring guideline for public functions.
---
Nitpick comments:
In `@acestep/core/generation/handler/generate_music.py`:
- Around line 239-275: The generate_music docstring omits the new APG sampler
parameters eta and momentum; update the Args section of the generate_music
function to document both eta (APG parallel-component scale) and momentum (APG
momentum coefficient), describing their types, default behavior, and how they
affect sampling (e.g., float, optional, influence the parallel guidance
component and the momentum term during APG sampling), and place them near the
other sampler/DCW params (dcw_*, infer_method, timesteps) so readers can find
all sampler-related options together.
In `@acestep/models/xl_base/modeling_acestep_v15_xl_base.py`:
- Around line 1848-1895: The public method generate_audio is missing a
docstring; add a concise docstring immediately after the generate_audio(...)
signature describing the method purpose, summary of inputs and outputs, and
document key parameters (including newly added eta and momentum) with their
types and default values, any important behavior (e.g., infer_method,
infer_steps, diffusion_guidance_scale, use_adg, sampler_mode, timesteps, and
return value), and note side effects or expected tensor shapes where relevant so
the modified public API is fully documented for callers and linters.
ℹ️ Review info
⚙️ Run configuration
Configuration used: Organization UI
Review profile: CHILL
Plan: Pro
Run ID: a47dd89c-3a88-4cd6-975e-c9f6542affdc
📒 Files selected for processing (18)
- acestep/api/http/release_task_models.py
- acestep/api/http/release_task_models_test.py
- acestep/api/http/release_task_param_parser.py
- acestep/api/http/release_task_param_parser_test.py
- acestep/api/http/release_task_request_builder.py
- acestep/api/http/release_task_request_builder_test.py
- acestep/api/job_generation_setup.py
- acestep/api/job_generation_setup_test.py
- acestep/core/generation/handler/generate_music.py
- acestep/core/generation/handler/generate_music_execute.py
- acestep/core/generation/handler/service_generate.py
- acestep/core/generation/handler/service_generate_execute.py
- acestep/inference.py
- acestep/models/base/modeling_acestep_v15_base.py
- acestep/models/common/apg_guidance_test.py
- acestep/models/sft/modeling_acestep_v15_base.py
- acestep/models/xl_base/modeling_acestep_v15_xl_base.py
- acestep/models/xl_sft/modeling_acestep_v15_xl_base.py
✅ Files skipped from review due to trivial changes (1)
- acestep/api/http/release_task_param_parser.py
🚧 Files skipped from review as they are similar to previous changes (13)
- acestep/api/job_generation_setup.py
- acestep/core/generation/handler/service_generate_execute.py
- acestep/inference.py
- acestep/core/generation/handler/generate_music_execute.py
- acestep/api/http/release_task_request_builder_test.py
- acestep/api/http/release_task_models.py
- acestep/api/job_generation_setup_test.py
- acestep/models/base/modeling_acestep_v15_base.py
- acestep/models/common/apg_guidance_test.py
- acestep/core/generation/handler/service_generate.py
- acestep/models/xl_sft/modeling_acestep_v15_xl_base.py
- acestep/api/http/release_task_param_parser_test.py
- acestep/models/sft/modeling_acestep_v15_base.py
eta: float = 0.0,
momentum: float = -0.75,
🧩 Analysis chain
🏁 Script executed:
#!/bin/bash
# Confirm apg_forward and MomentumBuffer signatures in both locations
rg -n "def apg_forward|class MomentumBuffer" --type py -A 5
Repository: ace-step/ACE-Step-1.5
Length of output: 1460
🏁 Script executed:
#!/bin/bash
# Check imports in the xl_base model file and locate the apg_guidance module it uses
head -50 acestep/models/xl_base/modeling_acestep_v15_xl_base.py | grep -E "^import|^from"
Repository: ace-step/ACE-Step-1.5
Length of output: 940
🏁 Script executed:
#!/bin/bash
# Check if there's an apg_guidance.py in the xl_base directory
fd "apg_guidance.py" acestep/models/xl_base/
Repository: ace-step/ACE-Step-1.5
Length of output: 104
🏁 Script executed:
#!/bin/bash
# Check the docstring of generate_audio method in the xl_base model
rg -n "def generate_audio" acestep/models/xl_base/modeling_acestep_v15_xl_base.py -A 15
Repository: ace-step/ACE-Step-1.5
Length of output: 838
🏁 Script executed:
#!/bin/bash
# Search for apg_guidance import in the xl_base model file
rg "apg_guidance" acestep/models/xl_base/modeling_acestep_v15_xl_base.py -B 2 -A 2
Repository: ace-step/ACE-Step-1.5
Length of output: 434
🏁 Script executed:
#!/bin/bash
# Check docstring in generate_audio method - look at lines 1848-1870 to see if docstring is present
sed -n '1848,1875p' acestep/models/xl_base/modeling_acestep_v15_xl_base.py
Repository: ace-step/ACE-Step-1.5
Length of output: 1342
🏁 Script executed:
#!/bin/bash
# Compare signatures in both apg_guidance files
echo "=== xl_base/apg_guidance.py ===" && \
head -50 acestep/models/xl_base/apg_guidance.py && \
echo -e "\n=== common/apg_guidance.py ===" && \
head -50 acestep/models/common/apg_guidance.py
Repository: ace-step/ACE-Step-1.5
Length of output: 2185
🏁 Script executed:
#!/bin/bash
# Check lines immediately after parameter list to see docstring
sed -n '1870,1885p' acestep/models/xl_base/modeling_acestep_v15_xl_base.py
Repository: ace-step/ACE-Step-1.5
Length of output: 828
🏁 Script executed:
#!/bin/bash
# Check further down to find where parameters end and function body begins
sed -n '1880,1905p' acestep/models/xl_base/modeling_acestep_v15_xl_base.py
Repository: ace-step/ACE-Step-1.5
Length of output: 1310
Add docstring to generate_audio method.
The function signatures for apg_forward and MomentumBuffer.__init__ correctly accept eta and momentum parameters respectively, matching the calls at lines 1983 and 2080. However, the generate_audio method (starting line 1848) lacks a docstring, violating the guideline requirement for mandatory docstrings on all public functions. Add a concise docstring documenting purpose, key parameters, return value, and any raised exceptions.
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.
In `@acestep/models/xl_base/modeling_acestep_v15_xl_base.py` around lines 1892 -
1893, Add a concise docstring to the public method generate_audio describing its
purpose (what kind of audio it generates and any high-level behavior), listing
key parameters (including eta and momentum if they influence generation or are
passed through; mention their types and typical ranges), the return value (type
and meaning, e.g., waveform tensor or file path), and any exceptions that can be
raised (e.g., ValueError for invalid inputs, RuntimeError for inference
failures). Keep it brief and follow the style of existing docstrings in the
class so it satisfies the project's mandatory-docstring guideline for public
functions.
Summary
Surfaces the two remaining APG (Adaptive Projected Guidance) knobs, `eta` and `momentum`, on the `/release_task` HTTP API, following the pattern of #1092.
Stacking note
This branch is based on the tip of #1092, not on `main`, because both PRs touch `release_task_models.py`, `release_task_param_parser.py`, `release_task_request_builder.py`, and `job_generation_setup.py`. Filing as a Draft; I'll rebase onto `main` and mark ready-for-review once #1092 merges. If you'd prefer I close this and refile after #1092 lands, happy to.
Context
Public community workflows (e.g. the Civitai ACE-Step 1.5 XL workflow) tune APG using `eta` (parallel-component scaling, currently hardcoded to `0.0` at `acestep/models/common/apg_guidance.py:38`) and `momentum` (the `MomentumBuffer` decay factor, currently hardcoded to `-0.75` at `acestep/models/common/apg_guidance.py:7`). Both are frozen at the `MomentumBuffer()` / `apg_forward()` call sites inside each DiT model, so HTTP API clients can't tune them.
#1092 established the pattern for exposing internal DiT params on `/release_task`. This PR reuses that pattern for `eta` and `momentum` so HTTP clients reach parity with ComfyUI users for APG tuning.
Changes
- `GenerationParams` gains `eta: float = 0.0` and `momentum: float = -0.75` (both defaults match the previous hardcoded values; see Backward compatibility below).
- `generate_audio` on the four DiT model variants that use APG (`base`, `sft`, `xl_base`, `xl_sft`) accepts `eta` and `momentum` as keyword-only params and forwards them to `MomentumBuffer(momentum=...)` and `apg_forward(..., eta=...)`.
- The turbo variants (`turbo`, `xl_turbo`) are distilled without CFG → no APG path → not affected.
- Plumbing: `generate_music` → `_run_generate_music_service_with_progress` → `service_generate` → `_build_service_generate_kwargs`.
- `GenerateMusicRequest`, `PARAM_ALIASES`, `build_generate_music_request`, and `build_generation_setup` all accept and forward the new fields.
- Tests cover alias resolution, request-builder forwarding, `GenerationSetup` wiring, and a model-level invariant that omitting the new kwargs yields the same output as passing the historical defaults explicitly.
One unrelated fix bundled
Commit `6e95ec3` also repairs `_FakeParser.get()` in `release_task_request_builder_test.py`, which was missing a `default=` parameter; 5 pre-existing tests were red because of it. Bundled here because the new tests added by this PR exercise the same helper and need the fix to pass.
Per CONTRIBUTING.md's "Solve One Problem at a Time" principle I'd normally split this into its own PR. Happy to do so if you prefer; the parser fix is one commit, isolatable with `git cherry-pick`.
Backward compatibility
All new parameters are optional. Defaults (`eta=0.0`, `momentum=-0.75`) match the values the code used before this PR. Clients that omit the new fields see byte-identical behavior.
MLX scope
PyTorch only. MLX's `_mlx_apg_forward` in `acestep/models/mlx/dit_generate.py:102` has no `eta` parameter at all, and its momentum state is an unbounded running sum (no decay factor). Exposing these knobs on the MLX backend is a non-trivial separate change, not in scope here per "minimize blast radius."
Test plan
- `python -m unittest acestep.api.http.release_task_param_parser_test acestep.api.http.release_task_models_test acestep.api.http.release_task_request_builder_test acestep.api.job_generation_setup_test acestep.models.common.apg_guidance_test`: 38 tests green
- Call `/release_task` with `{"eta": 0.5, "momentum": -0.5}` and confirm the values round-trip into the stored `GenerationParams`
- Existing `/release_task` clients still work without specifying the new fields
Related