feat(api): expose APG eta and momentum on /release_task HTTP API#1123

Open
FlexOr2 wants to merge 4 commits into ace-step:main from FlexOr2:feat/http-api-expose-apg-eta-momentum

Conversation

@FlexOr2

@FlexOr2 FlexOr2 commented Apr 21, 2026

Summary

Surfaces the two remaining APG (Adaptive Projected Guidance) knobs — eta and momentum — on the /release_task HTTP API, following the pattern of #1092.

Stacking note

This branch is based on the tip of #1092, not on main, because both PRs touch release_task_models.py, release_task_param_parser.py, release_task_request_builder.py, and job_generation_setup.py. Filing as a Draft — I'll rebase onto main and mark ready-for-review once #1092 merges. If you'd prefer I close this and refile after #1092 lands, happy to.

Context

Public community workflows (e.g. the Civitai ACE-Step 1.5 XL workflow) tune APG using eta (parallel-component scaling, currently hardcoded to 0.0 at acestep/models/common/apg_guidance.py:38) and momentum (the MomentumBuffer decay factor, currently hardcoded to -0.75 at acestep/models/common/apg_guidance.py:7). Both are frozen at the MomentumBuffer() / apg_forward() call sites inside each DiT model, so HTTP API clients can't tune them.
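
For readers unfamiliar with the two knobs, a minimal pure-Python sketch of where they act may help. It assumes the buffer update rule described in the review of this PR (running_average = momentum * running_average + diff) and the usual APG split of the guidance difference into parts parallel and orthogonal to the conditional prediction. The names mirror the PR, but the bodies are illustrative and work on flat lists; the real apg_guidance.py operates on tensors and may include steps omitted here.

```python
class MomentumBuffer:
    """Illustrative stand-in: running_average = momentum * running_average + diff."""

    def __init__(self, momentum: float = -0.75):
        self.momentum = momentum
        self.running_average = None  # treated as zero until the first update

    def update(self, diff):
        if self.running_average is None:
            self.running_average = [0.0] * len(diff)
        self.running_average = [
            self.momentum * avg + d for avg, d in zip(self.running_average, diff)
        ]
        return self.running_average


def apg_forward(pred_cond, pred_uncond, guidance_scale, momentum_buffer=None, eta=0.0):
    """Sketch of APG on flat vectors (assumed form, not the repository code)."""
    diff = [c - u for c, u in zip(pred_cond, pred_uncond)]
    if momentum_buffer is not None:
        diff = momentum_buffer.update(diff)
    # Split diff into components parallel and orthogonal to pred_cond.
    norm_sq = sum(c * c for c in pred_cond)
    coeff = sum(d * c for d, c in zip(diff, pred_cond)) / norm_sq if norm_sq else 0.0
    parallel = [coeff * c for c in pred_cond]
    orthogonal = [d - p for d, p in zip(diff, parallel)]
    # eta scales the parallel part; eta=0.0 keeps only the orthogonal part.
    update = [o + eta * p for o, p in zip(orthogonal, parallel)]
    return [c + (guidance_scale - 1.0) * u for c, u in zip(pred_cond, update)]


out_orth = apg_forward([1.0, 0.0], [0.0, -1.0], guidance_scale=3.0, eta=0.0)
out_full = apg_forward([1.0, 0.0], [0.0, -1.0], guidance_scale=3.0, eta=1.0)
```

With eta=0.0 only the orthogonal part of the difference moves the prediction; raising eta lets the parallel part back in.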

#1092 established the pattern for exposing internal DiT params on /release_task. This PR reuses that pattern for eta and momentum so HTTP clients reach parity with ComfyUI users for APG tuning.

Changes

  • GenerationParams gains eta: float = 0.0 and momentum: float = -0.75 (both defaults match the previous hardcoded values — see Backward compatibility below).
  • generate_audio on the four DiT model variants that use APG — base, sft, xl_base, xl_sft — accepts eta and momentum as keyword-only params and forwards them to MomentumBuffer(momentum=...) and apg_forward(..., eta=...).
  • Turbo variants (turbo, xl_turbo) are distilled without CFG → no APG path → not affected.
  • The kwargs thread through generate_music → _run_generate_music_service_with_progress → service_generate → _build_service_generate_kwargs.
  • HTTP surface: GenerateMusicRequest, PARAM_ALIASES, build_generate_music_request, and build_generation_setup all accept and forward the new fields.
  • Unit tests cover alias resolution, Pydantic defaults, request-builder forwarding, GenerationSetup wiring, and a model-level invariant that omitting the new kwargs yields the same output as passing the historical defaults explicitly.
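
The HTTP-side plumbing listed above can be sketched roughly as follows. This is a simplified illustration, not the repository code: PARAM_ALIASES and GenerationParams are names from the PR, but the alias table contents, the `resolve` helper, the `build_params` helper, and the flat-before-nested lookup order are all assumptions made for the example.

```python
from dataclasses import dataclass
from types import SimpleNamespace

# Hypothetical alias table; the PR describes identity aliases for both fields.
PARAM_ALIASES = {"eta": ["eta"], "momentum": ["momentum"]}


@dataclass
class GenerationParams:
    """Only the two new fields are shown; the real dataclass has many more."""

    eta: float = 0.0
    momentum: float = -0.75


def resolve(body: dict, name: str, default: float) -> float:
    """Look up `name` via its aliases in the flat body, then in nested param_obj."""
    nested = body.get("param_obj") or {}
    for source in (body, nested):
        for key in PARAM_ALIASES.get(name, [name]):
            if key in source and source[key] is not None:
                return float(source[key])
    return default


def build_params(req) -> GenerationParams:
    """getattr with defaults keeps request objects that predate the fields working."""
    return GenerationParams(
        eta=getattr(req, "eta", 0.0),
        momentum=getattr(req, "momentum", -0.75),
    )


body = {"param_obj": {"eta": "0.5"}}
req = SimpleNamespace(
    eta=resolve(body, "eta", 0.0),
    momentum=resolve(body, "momentum", -0.75),
)
params = build_params(req)
legacy = build_params(SimpleNamespace())  # request object without the new fields
```

Keeping the defaults in one place per layer means a client that never sends the fields ends up with the same GenerationParams as before.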

One unrelated fix bundled

Commit 6e95ec3 also repairs _FakeParser.get() in release_task_request_builder_test.py, which was missing a default= parameter — 5 pre-existing tests were red because of it. Bundled here because the new tests added by this PR exercise the same helper and need the fix to pass.

Per CONTRIBUTING.md's "Solve One Problem at a Time" principle I'd normally split this into its own PR. Happy to do so if you prefer — the parser fix is one commit, isolatable with git cherry-pick.

Backward compatibility

All new parameters are optional. Defaults (eta=0.0, momentum=-0.75) match the values the code used before this PR. Clients that omit the new fields see byte-identical behavior.

MLX scope

PyTorch only. MLX's _mlx_apg_forward in acestep/models/mlx/dit_generate.py:102 has no eta parameter at all, and its momentum state is an unbounded running sum (no decay factor). Exposing these knobs on the MLX backend is a non-trivial separate change — not in scope here per "minimize blast radius."

Test plan

  • python -m unittest acestep.api.http.release_task_param_parser_test acestep.api.http.release_task_models_test acestep.api.http.release_task_request_builder_test acestep.api.job_generation_setup_test acestep.models.common.apg_guidance_test — 38 tests green
  • Smoke test via live /release_task with {"eta": 0.5, "momentum": -0.5} and confirm the values round-trip into the stored GenerationParams
  • Confirm existing /release_task clients still work without specifying the new fields
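
A smoke-test client along these lines might look like the sketch below. The base URL and port are assumptions, as is the use of a JSON POST body; a real request would probably also need the usual prompt fields, which are omitted here. Only the eta and momentum values come from the test plan above.

```python
import json
import urllib.request

# Assumed base URL; adjust to wherever the API server is listening.
BASE_URL = "http://127.0.0.1:8001"


def build_release_task_request(payload: dict) -> urllib.request.Request:
    """Build (but do not send) a JSON POST to /release_task."""
    data = json.dumps(payload).encode("utf-8")
    return urllib.request.Request(
        f"{BASE_URL}/release_task",
        data=data,
        headers={"Content-Type": "application/json"},
        method="POST",
    )


request = build_release_task_request({"eta": 0.5, "momentum": -0.5})
# To actually submit it against a running server:
# with urllib.request.urlopen(request) as response:
#     print(response.read().decode("utf-8"))
```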

Related

Summary by CodeRabbit

  • New Features

    • Added advanced diffusion/APG controls: sampler selection, velocity tuning (threshold/EMA), APG hyperparameters (eta, momentum), and latent-space shift/rescale — all defaulted to preserve existing behavior.
  • Tests

    • Expanded test coverage to validate parsing, defaulting, and propagation of the new generation parameters.

@coderabbitai
Contributor

coderabbitai Bot commented Apr 21, 2026

📝 Walkthrough

Walkthrough

This pull request adds new generation control parameters to the music generation API and threads them through the entire generation pipeline. New fields enable caller control over diffusion sampling (sampler_mode), velocity post-processing (velocity_norm_threshold, velocity_ema_factor), adaptive projected guidance (eta, momentum), and latent preprocessing (latent_shift, latent_rescale). Parameters flow from HTTP requests through request models, parameter parsing, generation handlers, and into model generation methods.

Changes

Sampling control propagation

| Layer | File(s) | Summary |
|---|---|---|
| Data Shape / Public API | acestep/api/http/release_task_models.py, acestep/inference.py | Adds new request fields and GenerationParams fields for sampler, velocity, APG (eta, momentum), and latent adjustments with explicit defaults and descriptions. |
| Parameter Aliases / Parser | acestep/api/http/release_task_param_parser.py | Extends PARAM_ALIASES to map camelCase/synonymous keys to the new canonical parameter names for parser resolution. |
| Request Builder & Tests | acestep/api/http/release_task_request_builder.py, acestep/api/http/release_task_models_test.py | Builder reads parser values (with defaults) and constructs GenerateMusicRequest; tests added for request-model defaults and overrides for APG parameters. |
| Request Parsing Tests | acestep/api/http/release_task_param_parser_test.py | Unit tests added to assert alias resolution and float coercion for the new sampler/APG/latent fields from top-level and nested param_obj payloads. |
| Generation Setup & Tests | acestep/api/job_generation_setup.py, acestep/api/job_generation_setup_test.py | build_generation_setup maps request attributes into GenerationParams using getattr with defaults; tests verify forwarding and defaulting behavior. |
| Handler Chain | acestep/core/generation/handler/generate_music.py, acestep/core/generation/handler/generate_music_execute.py, acestep/core/generation/handler/service_generate.py, acestep/core/generation/handler/service_generate_execute.py | Signatures updated to accept eta and momentum and thread them through executor and service calls; service kwargs builder injects these into generate_kwargs. |
| Model Integration | acestep/models/*/modeling_acestep_v15_*.py | Model generate_audio methods updated to accept eta and momentum, initialize MomentumBuffer(momentum=...), and forward eta into apg_forward in relevant CFG paths. |
| APG Guidance Tests | acestep/models/common/apg_guidance_test.py | New unit tests for MomentumBuffer defaults, update rule, and apg_forward behavior with and without eta. |

Sequence Diagram(s)

sequenceDiagram
    participant Client
    participant RequestParser
    participant RequestBuilder
    participant GenerationSetup
    participant GenerationHandler
    participant Model
    
    Client->>RequestParser: HTTP request with eta, momentum, sampler_mode, etc.
    RequestParser->>RequestParser: Resolve parameters via PARAM_ALIASES
    RequestParser-->>RequestBuilder: Return parsed parameters
    RequestBuilder->>RequestBuilder: Build GenerateMusicRequest with new fields
    RequestBuilder-->>GenerationSetup: Pass request object
    GenerationSetup->>GenerationSetup: Create GenerationParams with eta, momentum, sampler_mode
    GenerationSetup-->>GenerationHandler: Pass setup configuration
    GenerationHandler->>GenerationHandler: Extract eta, momentum from setup.params
    GenerationHandler-->>Model: Call generate_audio(eta=..., momentum=...)
    Model->>Model: Initialize MomentumBuffer(momentum=momentum)
    Model->>Model: Call apg_forward(..., eta=eta, momentum_buffer=...)
    Model-->>Client: Return generated audio

Estimated code review effort

🎯 3 (Moderate) | ⏱️ ~20 minutes

Possibly related PRs

Suggested reviewers

  • ChuxiJ

Poem

🐰 I hop on keys and thread the stream,
Eta and momentum join the scheme,
Sampler, shift, and velocity play,
Through request to model they find their way,
A rabbit cheers — now generate away!

🚥 Pre-merge checks | ✅ 5
✅ Passed checks (5 passed)
| Check name | Status | Explanation |
|---|---|---|
| Description Check | ✅ Passed | Check skipped - CodeRabbit’s high-level summary is enabled. |
| Title check | ✅ Passed | The title clearly and specifically describes the main change: exposing APG eta and momentum parameters on the /release_task HTTP API. |
| Docstring Coverage | ✅ Passed | Docstring coverage is 86.84% which is sufficient. The required threshold is 80.00%. |
| Linked Issues check | ✅ Passed | Check skipped because no linked issues were found for this pull request. |
| Out of Scope Changes check | ✅ Passed | Check skipped because no linked issues were found for this pull request. |


@FlexOr2 FlexOr2 marked this pull request as ready for review April 25, 2026 15:53
Contributor

@coderabbitai coderabbitai Bot left a comment

Actionable comments posted: 1

🧹 Nitpick comments (3)
acestep/core/generation/handler/service_generate_execute.py (1)

131-132: Optional: warn when eta/momentum are set on the MLX path.

eta and momentum are accepted into generate_kwargs but the MLX branch (lines 175-236) never reads them — by design per the PR description ("MLX backend is out of scope"). However, a user passing non-default eta/momentum via /release_task while the MLX backend is active will get silent no-op semantics. A one-line logger.info (similar to the existing DCW/haar notice on lines 176-183) when (eta != 0.0 or momentum != -0.75) and self.use_mlx_dit is true would prevent confusion and reduce future support load.

♻️ Suggested log on the MLX branch
                 if self.use_mlx_dit and self.mlx_decoder is not None:
+                    if generate_kwargs.get("eta", 0.0) != 0.0 or generate_kwargs.get("momentum", -0.75) != -0.75:
+                        logger.info(
+                            "[service_generate] APG eta/momentum overrides "
+                            "(eta={}, momentum={}) are ignored on the MLX backend; "
+                            "these are PyTorch-only.",
+                            generate_kwargs.get("eta"),
+                            generate_kwargs.get("momentum"),
+                        )
                     if generate_kwargs.get("dcw_enabled") and generate_kwargs.get("dcw_wavelet", "haar") != "haar":

Also applies to: 175-236

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@acestep/core/generation/handler/service_generate_execute.py` around lines 131
- 132, The MLX code path silently ignores eta and momentum passed through
generate_kwargs; update the MLX branch in service_generate_execute.py (the block
guarded by self.use_mlx_dit) to log a one-line informational warning when
non-default values are provided: check (eta != 0.0 or momentum != -0.75) and, if
true and self.use_mlx_dit is set, call logger.info with a concise message
stating that eta/momentum are ignored on the MLX backend; place this adjacent to
the existing DCW/haar notice so users see the same kind of feedback as in the
other branch and reference the variables eta, momentum, generate_kwargs,
self.use_mlx_dit, and logger.info when making the change.
acestep/api/http/release_task_models.py (2)

100-103: Recommend constraining sampler_mode to Literal["euler", "heun"] for consistency with other enum-like fields.

The description restricts the value to "euler" or "heun", and other enum-style fields in this model use Literal[...] (chunk_mask_mode L65, repaint_mode L74). Today, an arbitrary string passes Pydantic validation and is then silently coerced to euler at the model layer (use_heun = sampler_mode == "heun" in modeling_acestep_v15_base.py), so a typo like "heum" would produce a wrong-but-quiet sampler choice rather than a clean 422. Tightening the type at the API boundary surfaces the mistake to the caller.

♻️ Suggested change
-    sampler_mode: str = Field(
+    sampler_mode: Literal["euler", "heun"] = Field(
         default="euler",
         description="Diffusion sampler mode: 'euler' (first-order) or 'heun' (second-order predictor-corrector).",
     )
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@acestep/api/http/release_task_models.py` around lines 100 - 103, The
sampler_mode field currently declared as sampler_mode: str = Field(...) should
be constrained to a Literal to enforce only "euler" or "heun"; update the
annotation to sampler_mode: Literal["euler", "heun"] = Field(default="euler",
description=...) and add an import for Literal (from typing import Literal) if
not already present so Pydantic will return a 422 for invalid values instead of
silently accepting arbitrary strings; keep the default "euler" and preserve the
existing description.

104-127: Optional: add range validation to the new numeric knobs.

These fields propagate directly into the diffusion loop without further validation. A few sensible Pydantic constraints would catch obvious misuse early:

  • velocity_norm_threshold: ge=0.0 (negative makes no sense — use_norm_clamp checks > 0.0).
  • velocity_ema_factor: typically ge=0.0, le=1.0; values outside this break the convex blend in vt = (1 - f) * vt + f * prev_vt.
  • latent_rescale: gt=0.0 (zero rescale would null out latents; negative flips sign).
  • eta / momentum are passed through APG math where any float is technically defined; leaving them unconstrained is fine, but you may want a soft sanity range (e.g. eta in [0, 1]) if matching upstream APG conventions.

Not a blocker — purely defensive.
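
To see why the [0, 1] bound matters for velocity_ema_factor, the blend formula quoted above can be tried on scalars. This is a toy illustration only: the helper name `ema_blend` is hypothetical and the real code blends tensors inside the diffusion loop.

```python
def ema_blend(vt: float, prev_vt: float, f: float) -> float:
    """Blend from the review comment: vt = (1 - f) * vt + f * prev_vt."""
    return (1.0 - f) * vt + f * prev_vt


vt, prev_vt = 2.0, 4.0
inside = ema_blend(vt, prev_vt, 0.25)    # stays between vt and prev_vt
outside = ema_blend(vt, prev_vt, 1.5)    # overshoots past prev_vt
negative = ema_blend(vt, prev_vt, -0.5)  # extrapolates away from prev_vt
```

Only factors in [0, 1] keep the result between the two inputs; anything else turns the smoothing step into an extrapolation.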

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@acestep/api/http/release_task_models.py` around lines 104 - 127, Add Pydantic
range constraints to the new numeric knobs so invalid values are caught early:
set velocity_norm_threshold Field to ge=0.0, set velocity_ema_factor Field to
ge=0.0 and le=1.0, and set latent_rescale Field to gt=0.0; leave eta and
momentum unconstrained (or optionally add soft ranges like eta between 0 and 1)
but ensure the changes are applied to the Field declarations for
velocity_norm_threshold, velocity_ema_factor, latent_rescale (and optionally
eta/momentum) in the ReleaseTask model.
ℹ️ Review info
⚙️ Run configuration

Configuration used: Organization UI

Review profile: CHILL

Plan: Pro

Run ID: 19c8b5d8-a49e-46f7-b003-7b422952cd7b

📥 Commits

Reviewing files that changed from the base of the PR and between d5d958e and a2a397a.

📒 Files selected for processing (18)
  • acestep/api/http/release_task_models.py
  • acestep/api/http/release_task_models_test.py
  • acestep/api/http/release_task_param_parser.py
  • acestep/api/http/release_task_param_parser_test.py
  • acestep/api/http/release_task_request_builder.py
  • acestep/api/http/release_task_request_builder_test.py
  • acestep/api/job_generation_setup.py
  • acestep/api/job_generation_setup_test.py
  • acestep/core/generation/handler/generate_music.py
  • acestep/core/generation/handler/generate_music_execute.py
  • acestep/core/generation/handler/service_generate.py
  • acestep/core/generation/handler/service_generate_execute.py
  • acestep/inference.py
  • acestep/models/base/modeling_acestep_v15_base.py
  • acestep/models/common/apg_guidance_test.py
  • acestep/models/sft/modeling_acestep_v15_base.py
  • acestep/models/xl_base/modeling_acestep_v15_xl_base.py
  • acestep/models/xl_sft/modeling_acestep_v15_xl_base.py

Comment on lines +89 to +107
    def test_default_momentum_buffer_matches_explicit_neg_075(self):
        """Default MomentumBuffer() must produce the same trajectory as momentum=-0.75."""

        pred_cond, pred_uncond = self._make_inputs()
        implicit_out = apg_forward(
            pred_cond=pred_cond,
            pred_uncond=pred_uncond,
            guidance_scale=3.0,
            momentum_buffer=MomentumBuffer(),
            dims=[1],
        )
        explicit_out = apg_forward(
            pred_cond=pred_cond,
            pred_uncond=pred_uncond,
            guidance_scale=3.0,
            momentum_buffer=MomentumBuffer(momentum=-0.75),
            dims=[1],
        )
        self.assertTrue(torch.allclose(implicit_out, explicit_out))
Contributor

⚠️ Potential issue | 🟡 Minor

This default-momentum test is a tautology and does not actually verify the -0.75 default.

Each MomentumBuffer is freshly constructed with running_average = 0, and update() computes running_average = momentum * 0 + diff = diff on the first call regardless of momentum. Because apg_forward is invoked exactly once per buffer, both branches produce the same output for any momentum value (including, e.g., +0.5). If someone changed MomentumBuffer's default to a different value, this test would still pass.

To actually exercise the default, either (a) invoke apg_forward twice on each buffer so the second update sees a non-zero running average, or (b) contrast against a buffer with a clearly different momentum and assert non-equality.
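
The point can be checked with a scalar stand-in for the buffer. The class below is a simplified illustration that assumes only the update rule stated above (running_average = momentum * running_average + diff); it is not the repository's MomentumBuffer, which works on tensors.

```python
class MomentumBuffer:
    """Illustrative stand-in using the update rule described above."""

    def __init__(self, momentum: float = -0.75):
        self.momentum = momentum
        self.running_average = 0.0

    def update(self, diff: float) -> float:
        self.running_average = self.momentum * self.running_average + diff
        return self.running_average


default_buf = MomentumBuffer()
other_buf = MomentumBuffer(momentum=0.5)

# First update: momentum multiplies a zero running average, so it has no effect.
first_default = default_buf.update(2.0)
first_other = other_buf.update(2.0)

# Second update: the previous average is non-zero, so momentum now matters.
second_default = default_buf.update(2.0)
second_other = other_buf.update(2.0)
```

After one update both buffers agree regardless of momentum; they diverge only from the second update onward, which is why a single apg_forward call cannot distinguish the defaults.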

🛠️ Suggested fix — drive the buffers through two updates and contrast against a mismatched momentum
     def test_default_momentum_buffer_matches_explicit_neg_075(self):
         """Default MomentumBuffer() must produce the same trajectory as momentum=-0.75."""

         pred_cond, pred_uncond = self._make_inputs()
-        implicit_out = apg_forward(
-            pred_cond=pred_cond,
-            pred_uncond=pred_uncond,
-            guidance_scale=3.0,
-            momentum_buffer=MomentumBuffer(),
-            dims=[1],
-        )
-        explicit_out = apg_forward(
-            pred_cond=pred_cond,
-            pred_uncond=pred_uncond,
-            guidance_scale=3.0,
-            momentum_buffer=MomentumBuffer(momentum=-0.75),
-            dims=[1],
-        )
-        self.assertTrue(torch.allclose(implicit_out, explicit_out))
+        implicit_buf = MomentumBuffer()
+        explicit_buf = MomentumBuffer(momentum=-0.75)
+        mismatched_buf = MomentumBuffer(momentum=0.5)
+        for buf in (implicit_buf, explicit_buf, mismatched_buf):
+            apg_forward(
+                pred_cond=pred_cond,
+                pred_uncond=pred_uncond,
+                guidance_scale=3.0,
+                momentum_buffer=buf,
+                dims=[1],
+            )
+        # Second call exercises momentum * prev term so the default value matters.
+        implicit_out = apg_forward(
+            pred_cond=pred_cond,
+            pred_uncond=pred_uncond,
+            guidance_scale=3.0,
+            momentum_buffer=implicit_buf,
+            dims=[1],
+        )
+        explicit_out = apg_forward(
+            pred_cond=pred_cond,
+            pred_uncond=pred_uncond,
+            guidance_scale=3.0,
+            momentum_buffer=explicit_buf,
+            dims=[1],
+        )
+        mismatched_out = apg_forward(
+            pred_cond=pred_cond,
+            pred_uncond=pred_uncond,
+            guidance_scale=3.0,
+            momentum_buffer=mismatched_buf,
+            dims=[1],
+        )
+        self.assertTrue(torch.allclose(implicit_out, explicit_out))
+        self.assertFalse(torch.allclose(implicit_out, mismatched_out))
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@acestep/models/common/apg_guidance_test.py` around lines 89 - 107, The test
test_default_momentum_buffer_matches_explicit_neg_075 is tautological because
each MomentumBuffer is fresh and the first update is momentum-independent;
modify the test so the buffers see a non-zero running_average before the
equality check: e.g., call apg_forward twice for each buffer (using the same
pred_cond/pred_uncond/guidance_scale/dims) so the second update depends on
momentum, then assert the default MomentumBuffer() output equals
MomentumBuffer(momentum=-0.75) output on that second call; additionally add a
negative test that the default does not match a clearly different momentum
(e.g., MomentumBuffer(momentum=+0.5)) to ensure the default is actually -0.75.
Ensure changes are made inside the
test_default_momentum_buffer_matches_explicit_neg_075 function and reference
apg_forward and MomentumBuffer.

FlexOr2 and others added 4 commits May 7, 2026 20:42
… API

Five GenerationParams fields that were only accessible via Gradio's
direct Python path are now accepted by the /release_task HTTP endpoint:

- sampler_mode: "euler" (default) or "heun" sampler selection
- velocity_norm_threshold: velocity prediction norm clamping (0=off)
- velocity_ema_factor: velocity EMA smoothing (0=off)
- latent_shift: additive shift on DiT latents before VAE decode
- latent_rescale: multiplicative rescale on DiT latents before VAE decode

All five already exist on the internal GenerationParams dataclass and
are wired through the diffusion loop. This change adds them to:
- GenerateMusicRequest (Pydantic request model)
- PARAM_ALIASES (camelCase alias support)
- build_generate_music_request (request builder)
- build_generation_setup (GenerationParams wiring)

Tests added for alias resolution, request builder forwarding, and
generation setup wiring (including getattr fallback for older callers).

Default values match the existing GenerationParams defaults, so
omitting them preserves current behavior. Non-target code paths
(Gradio UI, training, CLI) are unchanged.

The APG (Adaptive Projected Guidance) `eta` and `momentum` knobs were
hardcoded at the MomentumBuffer() and apg_forward() call sites inside
each DiT model. Lift them into the generation plumbing so callers can
tune APG without patching model code:

- apg_forward now receives `eta` from the model's generate_audio call
  (same default 0.0 preserved).
- MomentumBuffer is instantiated with the caller's `momentum` value
  (same default -0.75 preserved).
- generate_audio on all four DiT model variants (base, sft, xl_base,
  xl_sft) accepts `eta` and `momentum` as keyword-only args.
- GenerationParams gains two new fields (eta, momentum) that thread
  down through generate_music -> _run_generate_music_service_with_progress
  -> service_generate -> _build_service_generate_kwargs ->
  model.generate_audio.

MLX path is intentionally not touched: `_mlx_apg_forward` has no slot
for `eta`, and the MLX MomentumBuffer equivalent uses a fixed
running-sum (no decay factor). Exposing these knobs on MLX is a
separate upstream change. PyTorch path is unaffected by the MLX gap
since the new kwargs are consumed by the PyTorch model's
generate_audio directly.

All new parameters are keyword-only with defaults equal to the
pre-existing hardcoded values, so omitting them preserves current
behavior. Unit test (acestep/models/common/apg_guidance_test.py)
verifies this invariant at the apg_forward / MomentumBuffer level.

Co-Authored-By: Claude Opus 4.7 (1M context) <[email protected]>

The two APG guidance knobs exposed here (`eta` and `momentum`) were
already plumbed through GenerationParams and the DiT model stack in
the previous commit, but could not be set from the HTTP surface. This
change wires them into the /release_task request path, mirroring how
PR ace-step#1092 exposed sampler_mode and the four DiT latent params:

- GenerateMusicRequest: new `eta: float = 0.0` and `momentum: float =
  -0.75` fields (defaults match apg_guidance.py hardcoded values).
- PARAM_ALIASES: identity aliases for `eta` and `momentum` (both names
  are single lowercase words, so snake_case == camelCase).
- build_generate_music_request: forwards parsed values into the
  Pydantic request model.
- build_generation_setup: uses `getattr(req, ...)` with the same
  defaults, matching the existing sampler_mode/latent_shift pattern so
  older callers that don't set these fields still work.

The _FakeParser helper in release_task_request_builder_test.py gained
a `default=None` parameter on `.get()` to match the real parser's
contract — this was a pre-existing mismatch that broke several
unrelated tests; fixing it here lets the full builder suite run green.

Tests added alongside the changes cover:
- alias resolution for `eta` and `momentum` (flat body + nested
  param_obj),
- default values on the Pydantic model,
- request-builder forwarding and default-fallback behavior,
- generation-setup forwarding into GenerationParams,
- the getattr-default path when a request object lacks the fields.

All new fields are optional with defaults matching the pre-existing
hardcoded APG values; omitting them preserves current HTTP behavior.

Co-Authored-By: Claude Opus 4.7 (1M context) <[email protected]>
@FlexOr2 FlexOr2 force-pushed the feat/http-api-expose-apg-eta-momentum branch from a2a397a to 931be76 Compare May 7, 2026 18:54
Contributor

@coderabbitai coderabbitai Bot left a comment

Actionable comments posted: 1

🧹 Nitpick comments (2)
acestep/models/xl_base/modeling_acestep_v15_xl_base.py (1)

1848-1895: ⚡ Quick win

generate_audio is missing a docstring — guideline violation for a modified public method.

The function signature spans ~47 lines with ~35 parameters, and is now modified by this PR, but has no docstring at all. The new eta and momentum parameters are undocumented.

📝 Suggested docstring skeleton (add after the signature, before line 1896)
         **kwargs,
     ):
+        """Generate audio latents via flow-matching diffusion sampling.
+
+        Args:
+            text_hidden_states: Text encoder hidden states [B, T_txt, D].
+            text_attention_mask: Attention mask for text [B, T_txt].
+            lyric_hidden_states: Lyric encoder hidden states [B, T_lyr, D].
+            lyric_attention_mask: Attention mask for lyrics [B, T_lyr].
+            refer_audio_acoustic_hidden_states_packed: Packed timbre features.
+            refer_audio_order_mask: Batch-assignment indices for packed timbre.
+            src_latents: Source latents (context / cover reference) [B, T, D].
+            chunk_masks: Binary chunk masks [B, T, D].
+            is_covers: Per-sample cover flag tensor [B].
+            eta: Scale on the parallel component of the APG update, forwarded to
+                ``apg_forward``. 0.0 drops the parallel component; defaults to ``0.0``.
+            momentum: Exponential decay factor for ``MomentumBuffer``; negative
+                values invert the momentum direction.  Defaults to ``-0.75``.
+            ... (remaining args omitted for brevity — fill in as needed)
+
+        Returns:
+            Dict with keys ``target_latents`` and ``time_costs``.
+        """
         # Backward-compat: accept the old misspelled key ...

As per coding guidelines: "Docstrings are mandatory for all new or modified Python modules, classes, and functions."

🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@acestep/models/xl_base/modeling_acestep_v15_xl_base.py` around lines 1848 -
1895, The public method generate_audio is missing a docstring; add a concise
docstring immediately after the generate_audio(...) signature describing the
method purpose, summary of inputs and outputs, and document key parameters
(including newly added eta and momentum) with their types and default values,
any important behavior (e.g., infer_method, infer_steps,
diffusion_guidance_scale, use_adg, sampler_mode, timesteps, and return value),
and note side effects or expected tensor shapes where relevant so the modified
public API is fully documented for callers and linters.
acestep/core/generation/handler/generate_music.py (1)

239-275: ⚡ Quick win

eta and momentum are missing from the generate_music docstring.

The Args section documents dcw_* and other sampler params added in prior PRs but omits the two new APG parameters.

📝 Proposed docstring addition
             dcw_wavelet: PyWavelets basis, e.g. ``"haar"`` / ``"db4"`` / ``"sym8"``.
                 On the MLX path only ``"haar"`` is implemented natively; other
                 bases warn once and fall back to Haar.
+            eta: APG parallel-component scale passed to ``apg_forward``.
+                0.0 drops the parallel component; defaults to ``0.0``.
+            momentum: Exponential decay factor for the APG momentum buffer.
+                Negative values invert the momentum direction.
+                Defaults to ``-0.75``.
             timesteps: Optional custom timestep schedule.

As per coding guidelines: "Docstrings are mandatory for all new or modified Python modules, classes, and functions."
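
A minimal sketch of how the docstring guideline could be checked mechanically for parameters such as eta and momentum. The helper `undocumented_params` and the stub `generate_stub` are hypothetical names introduced for illustration; neither exists in the repository, and the check assumes Google-style `name:` entries in the docstring.

```python
import inspect
import re


def undocumented_params(func, names):
    """Return the names that are in func's signature but not in its docstring.

    Assumes Google-style entries, i.e. a line starting with ``name:``.
    """
    doc = inspect.getdoc(func) or ""
    params = inspect.signature(func).parameters
    missing = []
    for name in names:
        documented = re.search(rf"^\s*{re.escape(name)}\s*:", doc, re.MULTILINE)
        if name in params and not documented:
            missing.append(name)
    return missing


def generate_stub(prompt, *, eta=0.0, momentum=-0.75):
    """Hypothetical stand-in for a generation entry point.

    Args:
        prompt: Text prompt.
        eta: Scale for the parallel guidance component.
    """
    return prompt


# momentum is accepted by the stub but has no docstring entry.
missing = undocumented_params(generate_stub, ["eta", "momentum"])
```

A check along these lines could run in a unit test next to the existing `*_test.py` files, so a newly added keyword argument without a docstring entry fails fast.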

🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@acestep/core/generation/handler/generate_music.py` around lines 239 - 275,
The generate_music docstring omits the new APG sampler parameters eta and
momentum; update the Args section of the generate_music function to document
both eta (scale applied to the parallel component of the APG guidance update)
and momentum (decay factor of the APG momentum buffer), describing their types
and defaults (float, 0.0 and -0.75 respectively) and how they affect sampling,
and place them near the other sampler/DCW params (dcw_*, infer_method,
timesteps) so readers can find all sampler-related options together.
🤖 Prompt for all review comments with AI agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

Inline comments:
In `@acestep/models/xl_base/modeling_acestep_v15_xl_base.py`:
- Around line 1892-1893: Add a concise docstring to the public method
generate_audio describing its purpose (what kind of audio it generates and any
high-level behavior), listing key parameters (including eta and momentum if they
influence generation or are passed through; mention their types and typical
ranges), the return value (type and meaning, e.g., waveform tensor or file
path), and any exceptions that can be raised (e.g., ValueError for invalid
inputs, RuntimeError for inference failures). Keep it brief and follow the style
of existing docstrings in the class so it satisfies the project's
mandatory-docstring guideline for public functions.

---

Nitpick comments:
In `@acestep/core/generation/handler/generate_music.py`:
- Around line 239-275: The generate_music docstring omits the new APG sampler
parameters eta and momentum; update the Args section of the generate_music
function to document both eta (scale applied to the parallel component of the
APG guidance update) and momentum (decay factor of the APG momentum buffer),
describing their types and defaults (float, 0.0 and -0.75 respectively) and how
they affect sampling, and place them near the other sampler/DCW params (dcw_*,
infer_method, timesteps) so readers can find all sampler-related options
together.

In `@acestep/models/xl_base/modeling_acestep_v15_xl_base.py`:
- Around line 1848-1895: The public method generate_audio is missing a
docstring; add a concise docstring immediately after the generate_audio(...)
signature describing the method purpose, summary of inputs and outputs, and
document key parameters (including newly added eta and momentum) with their
types and default values, any important behavior (e.g., infer_method,
infer_steps, diffusion_guidance_scale, use_adg, sampler_mode, timesteps, and
return value), and note side effects or expected tensor shapes where relevant so
the modified public API is fully documented for callers and linters.
ℹ️ Review info
⚙️ Run configuration

Configuration used: Organization UI

Review profile: CHILL

Plan: Pro

Run ID: a47dd89c-3a88-4cd6-975e-c9f6542affdc

📥 Commits

Reviewing files that changed from the base of the PR and between a2a397a and 931be76.

📒 Files selected for processing (18)
  • acestep/api/http/release_task_models.py
  • acestep/api/http/release_task_models_test.py
  • acestep/api/http/release_task_param_parser.py
  • acestep/api/http/release_task_param_parser_test.py
  • acestep/api/http/release_task_request_builder.py
  • acestep/api/http/release_task_request_builder_test.py
  • acestep/api/job_generation_setup.py
  • acestep/api/job_generation_setup_test.py
  • acestep/core/generation/handler/generate_music.py
  • acestep/core/generation/handler/generate_music_execute.py
  • acestep/core/generation/handler/service_generate.py
  • acestep/core/generation/handler/service_generate_execute.py
  • acestep/inference.py
  • acestep/models/base/modeling_acestep_v15_base.py
  • acestep/models/common/apg_guidance_test.py
  • acestep/models/sft/modeling_acestep_v15_base.py
  • acestep/models/xl_base/modeling_acestep_v15_xl_base.py
  • acestep/models/xl_sft/modeling_acestep_v15_xl_base.py
✅ Files skipped from review due to trivial changes (1)
  • acestep/api/http/release_task_param_parser.py
🚧 Files skipped from review as they are similar to previous changes (13)
  • acestep/api/job_generation_setup.py
  • acestep/core/generation/handler/service_generate_execute.py
  • acestep/inference.py
  • acestep/core/generation/handler/generate_music_execute.py
  • acestep/api/http/release_task_request_builder_test.py
  • acestep/api/http/release_task_models.py
  • acestep/api/job_generation_setup_test.py
  • acestep/models/base/modeling_acestep_v15_base.py
  • acestep/models/common/apg_guidance_test.py
  • acestep/core/generation/handler/service_generate.py
  • acestep/models/xl_sft/modeling_acestep_v15_xl_base.py
  • acestep/api/http/release_task_param_parser_test.py
  • acestep/models/sft/modeling_acestep_v15_base.py

Comment on lines +1892 to +1893
eta: float = 0.0,
momentum: float = -0.75,

⚠️ Potential issue | 🟠 Major

🧩 Analysis chain

🏁 Script executed:

#!/bin/bash
# Confirm apg_forward and MomentumBuffer signatures in both locations
rg -n "def apg_forward|class MomentumBuffer" --type py -A 5

Repository: ace-step/ACE-Step-1.5

Length of output: 1460


🏁 Script executed:

#!/bin/bash
# Check imports in the xl_base model file and locate the apg_guidance module it uses
head -50 acestep/models/xl_base/modeling_acestep_v15_xl_base.py | grep -E "^import|^from"

Repository: ace-step/ACE-Step-1.5

Length of output: 940


🏁 Script executed:

#!/bin/bash
# Check if there's an apg_guidance.py in the xl_base directory
fd "apg_guidance.py" acestep/models/xl_base/

Repository: ace-step/ACE-Step-1.5

Length of output: 104


🏁 Script executed:

#!/bin/bash
# Check the docstring of generate_audio method in the xl_base model
rg -n "def generate_audio" acestep/models/xl_base/modeling_acestep_v15_xl_base.py -A 15

Repository: ace-step/ACE-Step-1.5

Length of output: 838


🏁 Script executed:

#!/bin/bash
# Search for apg_guidance import in the xl_base model file
rg "apg_guidance" acestep/models/xl_base/modeling_acestep_v15_xl_base.py -B 2 -A 2

Repository: ace-step/ACE-Step-1.5

Length of output: 434


🏁 Script executed:

#!/bin/bash
# Check docstring in generate_audio method - look at lines 1848-1870 to see if docstring is present
sed -n '1848,1875p' acestep/models/xl_base/modeling_acestep_v15_xl_base.py

Repository: ace-step/ACE-Step-1.5

Length of output: 1342


🏁 Script executed:

#!/bin/bash
# Compare signatures in both apg_guidance files
echo "=== xl_base/apg_guidance.py ===" && \
head -50 acestep/models/xl_base/apg_guidance.py && \
echo -e "\n=== common/apg_guidance.py ===" && \
head -50 acestep/models/common/apg_guidance.py

Repository: ace-step/ACE-Step-1.5

Length of output: 2185


🏁 Script executed:

#!/bin/bash
# Check lines immediately after parameter list to see docstring
sed -n '1870,1885p' acestep/models/xl_base/modeling_acestep_v15_xl_base.py

Repository: ace-step/ACE-Step-1.5

Length of output: 828


🏁 Script executed:

#!/bin/bash
# Check further down to find where parameters end and function body begins
sed -n '1880,1905p' acestep/models/xl_base/modeling_acestep_v15_xl_base.py

Repository: ace-step/ACE-Step-1.5

Length of output: 1310


Add docstring to generate_audio method.

The function signatures for apg_forward and MomentumBuffer.__init__ correctly accept eta and momentum parameters respectively, matching the calls at lines 1983 and 2080. However, the generate_audio method (starting line 1848) lacks a docstring, violating the guideline requirement for mandatory docstrings on all public functions. Add a concise docstring documenting purpose, key parameters, return value, and any raised exceptions.
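
For readers unfamiliar with the two knobs, the following is a rough sketch of how eta and momentum act in the published APG formulation. It is written in plain Python over lists, omits norm thresholding, and is not the repository's `apg_forward` or `MomentumBuffer` implementation; `apg_step`, `project`, and `SketchMomentumBuffer` are hypothetical names.

```python
class SketchMomentumBuffer:
    """Running average of guidance differences (illustrative stand-in)."""

    def __init__(self, momentum=-0.75):
        self.momentum = momentum
        self.running_average = None

    def update(self, value):
        # new_average = value + momentum * previous_average
        if self.running_average is None:
            self.running_average = list(value)
        else:
            self.running_average = [
                v + self.momentum * r
                for v, r in zip(value, self.running_average)
            ]


def project(v, onto):
    """Split v into components parallel and orthogonal to `onto`."""
    norm_sq = sum(x * x for x in onto)
    if norm_sq == 0.0:
        return [0.0] * len(v), list(v)
    coeff = sum(a * b for a, b in zip(v, onto)) / norm_sq
    parallel = [coeff * x for x in onto]
    orthogonal = [a - p for a, p in zip(v, parallel)]
    return parallel, orthogonal


def apg_step(pred_cond, pred_uncond, guidance_scale, buffer=None, eta=0.0):
    """One guidance step: keep the orthogonal part, scale the parallel part by eta."""
    diff = [c - u for c, u in zip(pred_cond, pred_uncond)]
    if buffer is not None:
        buffer.update(diff)
        diff = buffer.running_average
    parallel, orthogonal = project(diff, pred_cond)
    update = [o + eta * p for o, p in zip(orthogonal, parallel)]
    return [c + (guidance_scale - 1.0) * u for c, u in zip(pred_cond, update)]


cond, uncond = [1.0, 0.0], [0.0, 1.0]
no_parallel = apg_step(cond, uncond, 3.0, eta=0.0)    # parallel part dropped
full_parallel = apg_step(cond, uncond, 3.0, eta=1.0)  # matches plain CFG here

buf = SketchMomentumBuffer(momentum=-0.75)
buf.update([1.0, -1.0])
buf.update([1.0, -1.0])
```

In this sketch, eta = 1.0 without a buffer reduces to ordinary classifier-free guidance, while a negative momentum subtracts a fraction of the previous running average from each new difference.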

🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@acestep/models/xl_base/modeling_acestep_v15_xl_base.py` around lines 1892 -
1893, Add a concise docstring to the public method generate_audio describing its
purpose (what kind of audio it generates and any high-level behavior), listing
key parameters (including eta and momentum if they influence generation or are
passed through; mention their types and typical ranges), the return value (type
and meaning, e.g., waveform tensor or file path), and any exceptions that can be
raised (e.g., ValueError for invalid inputs, RuntimeError for inference
failures). Keep it brief and follow the style of existing docstrings in the
class so it satisfies the project's mandatory-docstring guideline for public
functions.
