feat(api): expose APG eta and momentum on /release_task HTTP API #1123

Open
FlexOr2 wants to merge 5 commits into ace-step:main from FlexOr2:feat/http-api-expose-apg-eta-momentum

Conversation

@FlexOr2 FlexOr2 commented Apr 21, 2026

Summary

Surfaces the two remaining APG (Adaptive Projected Guidance) knobs — eta and momentum — on the /release_task HTTP API, following the pattern of #1092.

Stacking note

This branch is based on the tip of #1092, not on main, because both PRs touch release_task_models.py, release_task_param_parser.py, release_task_request_builder.py, and job_generation_setup.py. Filing as a Draft — I'll rebase onto main and mark ready-for-review once #1092 merges. If you'd prefer I close this and refile after #1092 lands, happy to.

Context

Public community workflows (e.g. the Civitai ACE-Step 1.5 XL workflow) tune APG using eta (parallel-component scaling, currently hardcoded to 0.0 at acestep/models/common/apg_guidance.py:38) and momentum (the MomentumBuffer decay factor, currently hardcoded to -0.75 at acestep/models/common/apg_guidance.py:7). Both are frozen at the MomentumBuffer() / apg_forward() call sites inside each DiT model, so HTTP API clients can't tune them.

#1092 established the pattern for exposing internal DiT params on /release_task. This PR reuses that pattern for eta and momentum so HTTP clients reach parity with ComfyUI users for APG tuning.
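
To make the roles of the two knobs concrete, here is a minimal sketch of APG on flat Python lists. It follows the published APG formulation as I understand it, not the code in apg_guidance.py: the function name apg_sketch, the list-based tensors, and the (guidance_scale - 1) convention are assumptions for illustration, and the real apg_forward may add norm thresholding or per-dimension projection.

```python
class MomentumBuffer:
    """Running average of guidance differences; `momentum` scales the previous state."""

    def __init__(self, momentum: float = -0.75):
        self.momentum = momentum
        self.running_average = None  # no history before the first update

    def update(self, diff):
        if self.running_average is None:
            self.running_average = list(diff)
        else:
            self.running_average = [
                d + self.momentum * r for d, r in zip(diff, self.running_average)
            ]
        return self.running_average


def project(v, onto):
    """Split v into components parallel and orthogonal to `onto`."""
    norm_sq = sum(x * x for x in onto)
    if norm_sq == 0.0:
        return [0.0] * len(v), list(v)
    scale = sum(a * b for a, b in zip(v, onto)) / norm_sq
    parallel = [scale * x for x in onto]
    orthogonal = [a - p for a, p in zip(v, parallel)]
    return parallel, orthogonal


def apg_sketch(pred_cond, pred_uncond, guidance_scale, buffer, eta=0.0):
    """Hypothetical stand-in for apg_forward (assumed formulation, see lead-in)."""
    diff = [c - u for c, u in zip(pred_cond, pred_uncond)]
    diff = buffer.update(diff)  # momentum acts here
    parallel, orthogonal = project(diff, pred_cond)
    # eta scales how much of the parallel component is kept (0.0 drops it).
    update = [o + eta * p for o, p in zip(orthogonal, parallel)]
    return [c + (guidance_scale - 1.0) * u for c, u in zip(pred_cond, update)]
```

With eta=0.0 only the component orthogonal to the conditional prediction contributes to the update; raising eta reintroduces the parallel component.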

Changes

  • GenerationParams gains eta: float = 0.0 and momentum: float = -0.75 (both defaults match the previous hardcoded values — see Backward compatibility below).
  • generate_audio on the four DiT model variants that use APG — base, sft, xl_base, xl_sft — accepts eta and momentum as keyword-only params and forwards them to MomentumBuffer(momentum=...) and apg_forward(..., eta=...).
  • Turbo variants (turbo, xl_turbo) are distilled without CFG → no APG path → not affected.
  • The kwargs thread through generate_music → _run_generate_music_service_with_progress → service_generate → _build_service_generate_kwargs.
  • HTTP surface: GenerateMusicRequest, PARAM_ALIASES, build_generate_music_request, and build_generation_setup all accept and forward the new fields.
  • Unit tests cover alias resolution, Pydantic defaults, request-builder forwarding, GenerationSetup wiring, and a model-level invariant that omitting the new kwargs yields the same output as passing the historical defaults explicitly.
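
The plumbing described in the list above can be sketched roughly as follows. This is an illustration only: the two-field GenerationParams and the function bodies are stand-ins I made up, and the real handler chain has more layers and many more parameters.

```python
from dataclasses import dataclass


@dataclass
class GenerationParams:
    # Only the two fields discussed here; the real dataclass has many more.
    eta: float = 0.0
    momentum: float = -0.75


def generate_audio(*, eta: float = 0.0, momentum: float = -0.75) -> dict:
    """Stand-in for a DiT model's generate_audio; returns what it received."""
    return {"eta": eta, "momentum": momentum}


def service_generate(params: GenerationParams) -> dict:
    """Stand-in for the handler chain: forwards the fields as keyword arguments."""
    return generate_audio(eta=params.eta, momentum=params.momentum)


# Omitting the new fields is equivalent to passing the historical defaults.
assert service_generate(GenerationParams()) == generate_audio()
```

Keyword-only parameters (the bare `*` in the signature) keep existing positional call sites from silently binding a value to eta or momentum.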

One unrelated fix bundled

Commit 6e95ec3 also repairs _FakeParser.get() in release_task_request_builder_test.py, which was missing a default= parameter — 5 pre-existing tests were red because of it. Bundled here because the new tests added by this PR exercise the same helper and need the fix to pass.

Per CONTRIBUTING.md's "Solve One Problem at a Time" principle I'd normally split this into its own PR. Happy to do so if you prefer — the parser fix is one commit, isolatable with git cherry-pick.

Backward compatibility

All new parameters are optional. Defaults (eta=0.0, momentum=-0.75) match the values the code used before this PR. Clients that omit the new fields see byte-identical behavior.

MLX scope

PyTorch only. MLX's _mlx_apg_forward in acestep/models/mlx/dit_generate.py:102 has no eta parameter at all, and its momentum state is an unbounded running sum (no decay factor). Exposing these knobs on the MLX backend is a non-trivial separate change — not in scope here per "minimize blast radius."

Test plan

  • python -m unittest acestep.api.http.release_task_param_parser_test acestep.api.http.release_task_models_test acestep.api.http.release_task_request_builder_test acestep.api.job_generation_setup_test acestep.models.common.apg_guidance_test — 38 tests green
  • Smoke test via live /release_task with {"eta": 0.5, "momentum": -0.5} and confirm the values round-trip into the stored GenerationParams
  • Confirm existing /release_task clients still work without specifying the new fields
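
A smoke-test request along the lines of the second item might be built like this. The host and port are placeholders I chose, and a real request will likely need other fields (prompt, duration, and so on) that are omitted here; only the /release_task path and the two field names come from this PR.

```python
import json
import urllib.request

# Assumption: placeholder address; substitute your server's host and port.
BASE_URL = "http://localhost:8001"


def build_release_task_request(payload: dict) -> urllib.request.Request:
    """Build (but do not send) a JSON POST to /release_task."""
    body = json.dumps(payload).encode("utf-8")
    return urllib.request.Request(
        f"{BASE_URL}/release_task",
        data=body,
        headers={"Content-Type": "application/json"},
        method="POST",
    )


request = build_release_task_request({"eta": 0.5, "momentum": -0.5})
# To actually send it against a running server:
#     urllib.request.urlopen(request)
```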

Related

Summary by CodeRabbit

Release Notes

  • New Features

    • Added advanced generation control parameters for enhanced music creation customization, including sampler algorithm selection, velocity normalization and smoothing, guidance tuning controls, and latent space post-processing adjustments. All parameters include sensible defaults that preserve current behavior when not configured.
  • Tests

    • Expanded test suite to validate parsing, defaults, and application of the new generation parameters.

FlexOr2 and others added 4 commits April 12, 2026 14:51
… API

Five GenerationParams fields that were only accessible via Gradio's
direct Python path are now accepted by the /release_task HTTP endpoint:

- sampler_mode: "euler" (default) or "heun" sampler selection
- velocity_norm_threshold: velocity prediction norm clamping (0=off)
- velocity_ema_factor: velocity EMA smoothing (0=off)
- latent_shift: additive shift on DiT latents before VAE decode
- latent_rescale: multiplicative rescale on DiT latents before VAE decode

All five already exist on the internal GenerationParams dataclass and
are wired through the diffusion loop. This change adds them to:
- GenerateMusicRequest (Pydantic request model)
- PARAM_ALIASES (camelCase alias support)
- build_generate_music_request (request builder)
- build_generation_setup (GenerationParams wiring)

Tests added for alias resolution, request builder forwarding, and
generation setup wiring (including getattr fallback for older callers).

Default values match the existing GenerationParams defaults, so
omitting them preserves current behavior. Non-target code paths
(Gradio UI, training, CLI) are unchanged.

The APG (Adaptive Projected Guidance) `eta` and `momentum` knobs were
hardcoded at the MomentumBuffer() and apg_forward() call sites inside
each DiT model. Lift them into the generation plumbing so callers can
tune APG without patching model code:

- apg_forward now receives `eta` from the model's generate_audio call
  (same default 0.0 preserved).
- MomentumBuffer is instantiated with the caller's `momentum` value
  (same default -0.75 preserved).
- generate_audio on all four DiT model variants (base, sft, xl_base,
  xl_sft) accepts `eta` and `momentum` as keyword-only args.
- GenerationParams gains two new fields (eta, momentum) that thread
  down through generate_music -> _run_generate_music_service_with_progress
  -> service_generate -> _build_service_generate_kwargs ->
  model.generate_audio.

MLX path is intentionally not touched: `_mlx_apg_forward` has no slot
for `eta`, and the MLX MomentumBuffer equivalent uses a fixed
running-sum (no decay factor). Exposing these knobs on MLX is a
separate upstream change. PyTorch path is unaffected by the MLX gap
since the new kwargs are consumed by the PyTorch model's
generate_audio directly.

All new parameters are keyword-only with defaults equal to the
pre-existing hardcoded values, so omitting them preserves current
behavior. Unit test (acestep/models/common/apg_guidance_test.py)
verifies this invariant at the apg_forward / MomentumBuffer level.

Co-Authored-By: Claude Opus 4.7 (1M context) <[email protected]>

The two APG guidance knobs exposed here (`eta` and `momentum`) were
already plumbed through GenerationParams and the DiT model stack in
the previous commit, but could not be set from the HTTP surface. This
change wires them into the /release_task request path, mirroring how
PR ace-step#1092 exposed sampler_mode and the four DiT latent params:

- GenerateMusicRequest: new `eta: float = 0.0` and `momentum: float =
  -0.75` fields (defaults match apg_guidance.py hardcoded values).
- PARAM_ALIASES: identity aliases for `eta` and `momentum` (both names
  are single lowercase words, so snake_case == camelCase).
- build_generate_music_request: forwards parsed values into the
  Pydantic request model.
- build_generation_setup: uses `getattr(req, ...)` with the same
  defaults, matching the existing sampler_mode/latent_shift pattern so
  older callers that don't set these fields still work.

The _FakeParser helper in release_task_request_builder_test.py gained
a `default=None` parameter on `.get()` to match the real parser's
contract — this was a pre-existing mismatch that broke several
unrelated tests; fixing it here lets the full builder suite run green.

Tests added alongside the changes cover:
- alias resolution for `eta` and `momentum` (flat body + nested
  param_obj),
- default values on the Pydantic model,
- request-builder forwarding and default-fallback behavior,
- generation-setup forwarding into GenerationParams,
- the getattr-default path when a request object lacks the fields.

All new fields are optional with defaults matching the pre-existing
hardcoded APG values; omitting them preserves current HTTP behavior.

Co-Authored-By: Claude Opus 4.7 (1M context) <[email protected]>
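
The alias resolution and getattr fallback described in the commit message above could look roughly like this. The shape of the alias table and the helper names resolve_params and apg_settings are hypothetical; the real PARAM_ALIASES and build_generation_setup may be structured differently.

```python
from types import SimpleNamespace

# Hypothetical alias table: accepted request key -> canonical field name.
PARAM_ALIASES = {
    "eta": "eta",            # identity alias: snake_case == camelCase
    "momentum": "momentum",  # identity alias
    "sampler_mode": "sampler_mode",
    "samplerMode": "sampler_mode",
}


def resolve_params(body: dict) -> dict:
    """Map accepted keys to canonical names, ignoring unknown keys."""
    resolved = {}
    for key, value in body.items():
        canonical = PARAM_ALIASES.get(key)
        if canonical is not None:
            resolved[canonical] = value
    return resolved


def apg_settings(req) -> dict:
    """Read eta/momentum with defaults so older request objects still work."""
    return {
        "eta": getattr(req, "eta", 0.0),
        "momentum": getattr(req, "momentum", -0.75),
    }
```

A request object that predates the new fields, modeled here as an empty SimpleNamespace, falls back to the historical values instead of raising AttributeError.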
Contributor

coderabbitai Bot commented Apr 21, 2026

Walkthrough

This pull request adds new generation control parameters to the music generation API and threads them through the entire generation pipeline. New fields enable caller control over diffusion sampling (sampler_mode), velocity post-processing (velocity_norm_threshold, velocity_ema_factor), adaptive projected guidance (eta, momentum), and latent post-processing before VAE decode (latent_shift, latent_rescale). Parameters flow from HTTP requests through request models, parameter parsing, generation handlers, and into model generation methods.

Changes

| Cohort / File(s) | Summary |
| --- | --- |
| Request Model & HTTP API: acestep/api/http/release_task_models.py, acestep/api/http/release_task_param_parser.py | New request model fields added for sampler mode, velocity controls, APG guidance, and latent adjustments; parameter alias mapping extended to recognize and parse all new canonical fields. |
| Request Building & Setup: acestep/api/http/release_task_request_builder.py, acestep/api/job_generation_setup.py | Request builders now extract and forward new parameters to GenerateMusicRequest; job generation setup populates corresponding GenerationParams fields with explicit defaults when request omits them. |
| Generation Handler Chain: acestep/core/generation/handler/generate_music.py, acestep/core/generation/handler/generate_music_execute.py, acestep/core/generation/handler/service_generate.py, acestep/core/generation/handler/service_generate_execute.py | Handler methods updated to accept and thread eta and momentum parameters through the generation execution path, with defaults set to eta=0.0 and momentum=-0.75. |
| Inference & Model Integration: acestep/inference.py, acestep/models/base/modeling_acestep_v15_base.py, acestep/models/sft/modeling_acestep_v15_base.py, acestep/models/xl_base/modeling_acestep_v15_xl_base.py, acestep/models/xl_sft/modeling_acestep_v15_xl_base.py | GenerationParams dataclass extended with new fields; all model generate_audio methods updated to accept and apply eta and momentum to MomentumBuffer and APG guidance calls. |
| Request & Parameter Tests: acestep/api/http/release_task_models_test.py, acestep/api/http/release_task_param_parser_test.py, acestep/api/http/release_task_request_builder_test.py, acestep/api/job_generation_setup_test.py | Unit tests added for request model field defaults, parameter alias resolution with type coercion, request builder parameter forwarding and default fallback, and generation setup parameter propagation. |
| Guidance Logic Tests: acestep/models/common/apg_guidance_test.py | New test module verifying MomentumBuffer initialization, momentum override behavior, running-average update rule, and apg_forward backward compatibility with and without eta parameter. |

Sequence Diagram(s)

sequenceDiagram
    participant Client
    participant RequestParser
    participant RequestBuilder
    participant GenerationSetup
    participant GenerationHandler
    participant Model
    
    Client->>RequestParser: HTTP request with eta, momentum, sampler_mode, etc.
    RequestParser->>RequestParser: Resolve parameters via PARAM_ALIASES
    RequestParser-->>RequestBuilder: Return parsed parameters
    RequestBuilder->>RequestBuilder: Build GenerateMusicRequest with new fields
    RequestBuilder-->>GenerationSetup: Pass request object
    GenerationSetup->>GenerationSetup: Create GenerationParams with eta, momentum, sampler_mode
    GenerationSetup-->>GenerationHandler: Pass setup configuration
    GenerationHandler->>GenerationHandler: Extract eta, momentum from setup.params
    GenerationHandler-->>Model: Call generate_audio(eta=..., momentum=...)
    Model->>Model: Initialize MomentumBuffer(momentum=momentum)
    Model->>Model: Call apg_forward(..., eta=eta, momentum_buffer=...)
    Model-->>Client: Return generated audio

Estimated code review effort

🎯 3 (Moderate) | ⏱️ ~20 minutes

Possibly related PRs

  • PR #452: Adds DiT latent post-processing parameters (latent_shift, latent_rescale) to the same generation pipeline and threads them through identical call paths.
  • PR #1120: Modifies the generation handler chain (generate_music, execution/service handlers, GenerationParams) by adding sampler-related keyword parameters with consistent threading pattern.
  • PR #978: Threads sampler configuration parameters (sampler_mode, velocity controls) through the same generation call stack and handlers.

Suggested reviewers

  • ChuxiJ
  • ElWalki

Poem

🐰 Through pipelines deep, new parameters hop,
From requests down to models' top,
With eta, momentum, and sampler's might,
The guidance flows with diffusion's light! ✨
Each layer threads the control with care,
Making music generation beyond compare! 🎵

🚥 Pre-merge checks | ✅ 5
✅ Passed checks (5 passed)
| Check name | Status | Explanation |
| --- | --- | --- |
| Description Check | ✅ Passed | Check skipped - CodeRabbit’s high-level summary is enabled. |
| Title check | ✅ Passed | The title clearly and concisely describes the main feature: exposing APG eta and momentum parameters on the HTTP API endpoint. |
| Docstring Coverage | ✅ Passed | Docstring coverage is 86.84% which is sufficient. The required threshold is 80.00%. |
| Linked Issues check | ✅ Passed | Check skipped because no linked issues were found for this pull request. |
| Out of Scope Changes check | ✅ Passed | Check skipped because no linked issues were found for this pull request. |


Resolves conflicts in generate_audio() / handler signatures, kwargs
forwarding, and inference.py GenerationParams where the upstream DCW
work (dcw_enabled, dcw_mode, dcw_scaler, dcw_high_scaler, dcw_wavelet,
plus timesteps in base) and the upstream task_type kwarg landed
alongside this branch's eta / momentum APG additions. Conflicts were
all in adjacent kwarg blocks, dict literals, and docstrings; resolution
keeps both sides.
@FlexOr2 FlexOr2 marked this pull request as ready for review April 25, 2026 15:53
Contributor

@coderabbitai coderabbitai Bot left a comment


Actionable comments posted: 1

🧹 Nitpick comments (3)
acestep/core/generation/handler/service_generate_execute.py (1)

131-132: Optional: warn when eta/momentum are set on the MLX path.

eta and momentum are accepted into generate_kwargs but the MLX branch (lines 175-236) never reads them — by design per the PR description ("MLX backend is out of scope"). However, a user passing non-default eta/momentum via /release_task while the MLX backend is active will get silent no-op semantics. A one-line logger.info (similar to the existing DCW/haar notice on lines 176-183) when (eta != 0.0 or momentum != -0.75) and self.use_mlx_dit is true would prevent confusion and reduce future support load.

♻️ Suggested log on the MLX branch
                 if self.use_mlx_dit and self.mlx_decoder is not None:
+                    if generate_kwargs.get("eta", 0.0) != 0.0 or generate_kwargs.get("momentum", -0.75) != -0.75:
+                        logger.info(
+                            "[service_generate] APG eta/momentum overrides "
+                            "(eta={}, momentum={}) are ignored on the MLX backend; "
+                            "these are PyTorch-only.",
+                            generate_kwargs.get("eta"),
+                            generate_kwargs.get("momentum"),
+                        )
                     if generate_kwargs.get("dcw_enabled") and generate_kwargs.get("dcw_wavelet", "haar") != "haar":

Also applies to: 175-236

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@acestep/core/generation/handler/service_generate_execute.py` around lines 131
- 132, The MLX code path silently ignores eta and momentum passed through
generate_kwargs; update the MLX branch in service_generate_execute.py (the block
guarded by self.use_mlx_dit) to log a one-line informational warning when
non-default values are provided: check (eta != 0.0 or momentum != -0.75) and, if
true and self.use_mlx_dit is set, call logger.info with a concise message
stating that eta/momentum are ignored on the MLX backend; place this adjacent to
the existing DCW/haar notice so users see the same kind of feedback as in the
other branch and reference the variables eta, momentum, generate_kwargs,
self.use_mlx_dit, and logger.info when making the change.
acestep/api/http/release_task_models.py (2)

100-103: Recommend constraining sampler_mode to Literal["euler", "heun"] for consistency with other enum-like fields.

The description restricts the value to "euler" or "heun", and other enum-style fields in this model use Literal[...] (chunk_mask_mode L65, repaint_mode L74). Today, an arbitrary string passes Pydantic validation and is then silently coerced to euler at the model layer (use_heun = sampler_mode == "heun" in modeling_acestep_v15_base.py), so a typo like "heum" would produce a wrong-but-quiet sampler choice rather than a clean 422. Tightening the type at the API boundary surfaces the mistake to the caller.

♻️ Suggested change
-    sampler_mode: str = Field(
+    sampler_mode: Literal["euler", "heun"] = Field(
         default="euler",
         description="Diffusion sampler mode: 'euler' (first-order) or 'heun' (second-order predictor-corrector).",
     )
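
The failure mode described above can be reproduced without Pydantic. In this sketch, pick_sampler_lenient mirrors the quoted `sampler_mode == "heun"` check, while validate_sampler_mode is a hypothetical stand-in for what a Literal-typed field would reject at the API boundary; neither function exists in the repository.

```python
from typing import Literal, get_args

SamplerMode = Literal["euler", "heun"]


def pick_sampler_lenient(sampler_mode: str) -> str:
    # Anything that is not exactly "heun" falls back to Euler, typos included.
    return "heun" if sampler_mode == "heun" else "euler"


def validate_sampler_mode(sampler_mode: str) -> str:
    # Reject values outside the Literal's allowed set instead of coercing them.
    allowed = get_args(SamplerMode)
    if sampler_mode not in allowed:
        raise ValueError(
            f"sampler_mode must be one of {allowed}, got {sampler_mode!r}"
        )
    return sampler_mode
```

With the lenient check, the typo "heum" quietly selects Euler; with validation, the caller gets an error it can act on.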
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@acestep/api/http/release_task_models.py` around lines 100 - 103, The
sampler_mode field currently declared as sampler_mode: str = Field(...) should
be constrained to a Literal to enforce only "euler" or "heun"; update the
annotation to sampler_mode: Literal["euler", "heun"] = Field(default="euler",
description=...) and add an import for Literal (from typing import Literal) if
not already present so Pydantic will return a 422 for invalid values instead of
silently accepting arbitrary strings; keep the default "euler" and preserve the
existing description.

104-127: Optional: add range validation to the new numeric knobs.

These fields propagate directly into the diffusion loop without further validation. A few sensible Pydantic constraints would catch obvious misuse early:

  • velocity_norm_threshold: ge=0.0 (negative makes no sense — use_norm_clamp checks > 0.0).
  • velocity_ema_factor: typically ge=0.0, le=1.0; values outside this break the convex blend in vt = (1 - f) * vt + f * prev_vt.
  • latent_rescale: gt=0.0 (zero rescale would null out latents; negative flips sign).
  • eta / momentum are passed through APG math where any float is technically defined; leaving them unconstrained is fine, but you may want a soft sanity range (e.g. eta in [0, 1]) if matching upstream APG conventions.

Not a blocker — purely defensive.
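
As a small worked example of why the [0, 1] bound matters for velocity_ema_factor, the blend quoted above can be evaluated on scalars. The helper names are illustrative, not taken from the diffusion loop.

```python
def ema_blend(vt: float, prev_vt: float, f: float) -> float:
    """The blend quoted above: vt = (1 - f) * vt + f * prev_vt."""
    return (1.0 - f) * vt + f * prev_vt


def check_ema_factor(f: float) -> float:
    """Hypothetical range check equivalent to ge=0.0, le=1.0."""
    if not 0.0 <= f <= 1.0:
        raise ValueError(f"velocity_ema_factor must be in [0, 1], got {f}")
    return f


inside = ema_blend(1.0, 3.0, 0.25)   # 1.5, between the two inputs
above = ema_blend(1.0, 3.0, 1.5)     # 4.0, overshoots prev_vt
below = ema_blend(1.0, 3.0, -0.5)    # 0.0, undershoots vt
```

For f in [0, 1] the result stays between vt and prev_vt; outside that range the blend extrapolates instead of smoothing.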

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@acestep/api/http/release_task_models.py` around lines 104 - 127, Add Pydantic
range constraints to the new numeric knobs so invalid values are caught early:
set velocity_norm_threshold Field to ge=0.0, set velocity_ema_factor Field to
ge=0.0 and le=1.0, and set latent_rescale Field to gt=0.0; leave eta and
momentum unconstrained (or optionally add soft ranges like eta between 0 and 1)
but ensure the changes are applied to the Field declarations for
velocity_norm_threshold, velocity_ema_factor, latent_rescale (and optionally
eta/momentum) in the ReleaseTask model.

ℹ️ Review info
⚙️ Run configuration

Configuration used: Organization UI

Review profile: CHILL

Plan: Pro

Run ID: 19c8b5d8-a49e-46f7-b003-7b422952cd7b

📥 Commits

Reviewing files that changed from the base of the PR and between d5d958e and a2a397a.

📒 Files selected for processing (18)
  • acestep/api/http/release_task_models.py
  • acestep/api/http/release_task_models_test.py
  • acestep/api/http/release_task_param_parser.py
  • acestep/api/http/release_task_param_parser_test.py
  • acestep/api/http/release_task_request_builder.py
  • acestep/api/http/release_task_request_builder_test.py
  • acestep/api/job_generation_setup.py
  • acestep/api/job_generation_setup_test.py
  • acestep/core/generation/handler/generate_music.py
  • acestep/core/generation/handler/generate_music_execute.py
  • acestep/core/generation/handler/service_generate.py
  • acestep/core/generation/handler/service_generate_execute.py
  • acestep/inference.py
  • acestep/models/base/modeling_acestep_v15_base.py
  • acestep/models/common/apg_guidance_test.py
  • acestep/models/sft/modeling_acestep_v15_base.py
  • acestep/models/xl_base/modeling_acestep_v15_xl_base.py
  • acestep/models/xl_sft/modeling_acestep_v15_xl_base.py

Comment on lines +89 to +107
    def test_default_momentum_buffer_matches_explicit_neg_075(self):
        """Default MomentumBuffer() must produce the same trajectory as momentum=-0.75."""

        pred_cond, pred_uncond = self._make_inputs()
        implicit_out = apg_forward(
            pred_cond=pred_cond,
            pred_uncond=pred_uncond,
            guidance_scale=3.0,
            momentum_buffer=MomentumBuffer(),
            dims=[1],
        )
        explicit_out = apg_forward(
            pred_cond=pred_cond,
            pred_uncond=pred_uncond,
            guidance_scale=3.0,
            momentum_buffer=MomentumBuffer(momentum=-0.75),
            dims=[1],
        )
        self.assertTrue(torch.allclose(implicit_out, explicit_out))
Contributor


⚠️ Potential issue | 🟡 Minor

This default-momentum test is a tautology and does not actually verify the -0.75 default.

Each MomentumBuffer is freshly constructed with running_average = 0, and update() computes running_average = momentum * 0 + diff = diff on the first call regardless of momentum. Because apg_forward is invoked exactly once per buffer, both branches produce the same output for any momentum value (including, e.g., +0.5). If someone changed MomentumBuffer's default to a different value, this test would still pass.

To actually exercise the default, either (a) invoke apg_forward twice on each buffer so the second update sees a non-zero running average, or (b) contrast against a buffer with a clearly different momentum and assert non-equality.
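
The arithmetic behind this can be checked with a scalar stand-in for the buffer. The class below is a simplified sketch based on the update rule quoted above (running_average = momentum * previous + diff), not the repository's MomentumBuffer, and trajectory is a helper made up for the example.

```python
class MomentumBuffer:
    """Scalar sketch of the running-average update rule."""

    def __init__(self, momentum: float = -0.75):
        self.momentum = momentum
        self.running_average = 0.0

    def update(self, diff: float) -> float:
        self.running_average = diff + self.momentum * self.running_average
        return self.running_average


def trajectory(momentum: float, diffs):
    """Feed the same diffs through a fresh buffer and record each state."""
    buf = MomentumBuffer(momentum)
    return [buf.update(d) for d in diffs]


default_like = trajectory(-0.75, [2.0, 2.0])  # [2.0, 0.5]
mismatched = trajectory(0.5, [2.0, 2.0])      # [2.0, 3.0]
```

Both trajectories agree on the first step and diverge only on the second, which is why a single call per buffer cannot distinguish the default from any other momentum value.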

🛠️ Suggested fix — drive the buffers through two updates and contrast against a mismatched momentum
     def test_default_momentum_buffer_matches_explicit_neg_075(self):
         """Default MomentumBuffer() must produce the same trajectory as momentum=-0.75."""

         pred_cond, pred_uncond = self._make_inputs()
-        implicit_out = apg_forward(
-            pred_cond=pred_cond,
-            pred_uncond=pred_uncond,
-            guidance_scale=3.0,
-            momentum_buffer=MomentumBuffer(),
-            dims=[1],
-        )
-        explicit_out = apg_forward(
-            pred_cond=pred_cond,
-            pred_uncond=pred_uncond,
-            guidance_scale=3.0,
-            momentum_buffer=MomentumBuffer(momentum=-0.75),
-            dims=[1],
-        )
-        self.assertTrue(torch.allclose(implicit_out, explicit_out))
+        implicit_buf = MomentumBuffer()
+        explicit_buf = MomentumBuffer(momentum=-0.75)
+        mismatched_buf = MomentumBuffer(momentum=0.5)
+        for buf in (implicit_buf, explicit_buf, mismatched_buf):
+            apg_forward(
+                pred_cond=pred_cond,
+                pred_uncond=pred_uncond,
+                guidance_scale=3.0,
+                momentum_buffer=buf,
+                dims=[1],
+            )
+        # Second call exercises momentum * prev term so the default value matters.
+        implicit_out = apg_forward(
+            pred_cond=pred_cond,
+            pred_uncond=pred_uncond,
+            guidance_scale=3.0,
+            momentum_buffer=implicit_buf,
+            dims=[1],
+        )
+        explicit_out = apg_forward(
+            pred_cond=pred_cond,
+            pred_uncond=pred_uncond,
+            guidance_scale=3.0,
+            momentum_buffer=explicit_buf,
+            dims=[1],
+        )
+        mismatched_out = apg_forward(
+            pred_cond=pred_cond,
+            pred_uncond=pred_uncond,
+            guidance_scale=3.0,
+            momentum_buffer=mismatched_buf,
+            dims=[1],
+        )
+        self.assertTrue(torch.allclose(implicit_out, explicit_out))
+        self.assertFalse(torch.allclose(implicit_out, mismatched_out))
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@acestep/models/common/apg_guidance_test.py` around lines 89 - 107, The test
test_default_momentum_buffer_matches_explicit_neg_075 is tautological because
each MomentumBuffer is fresh and the first update is momentum-independent;
modify the test so the buffers see a non-zero running_average before the
equality check: e.g., call apg_forward twice for each buffer (using the same
pred_cond/pred_uncond/guidance_scale/dims) so the second update depends on
momentum, then assert the default MomentumBuffer() output equals
MomentumBuffer(momentum=-0.75) output on that second call; additionally add a
negative test that the default does not match a clearly different momentum
(e.g., MomentumBuffer(momentum=+0.5)) to ensure the default is actually -0.75.
Ensure changes are made inside the
test_default_momentum_buffer_matches_explicit_neg_075 function and reference
apg_forward and MomentumBuffer.
