diff --git a/acestep/api/http/release_task_models.py b/acestep/api/http/release_task_models.py index 87042c01d..ef75c942c 100644 --- a/acestep/api/http/release_task_models.py +++ b/acestep/api/http/release_task_models.py @@ -97,6 +97,34 @@ class GenerateMusicRequest(BaseModel): default=None, description="Custom timesteps (comma-separated, e.g., '0.97,0.76,0.615,0.5,0.395,0.28,0.18,0.085,0'). Overrides inference_steps and shift.", ) + sampler_mode: str = Field( + default="euler", + description="Diffusion sampler mode: 'euler' (first-order) or 'heun' (second-order predictor-corrector).", + ) + velocity_norm_threshold: float = Field( + default=0.0, + description="Clamp velocity prediction norms during diffusion (0 = disabled, try 2.0).", + ) + velocity_ema_factor: float = Field( + default=0.0, + description="Velocity EMA smoothing during diffusion (0 = disabled, try 0.1).", + ) + eta: float = Field( + default=0.0, + description="APG guidance eta: parallel-component scaling factor (0.0 = pure orthogonal guidance).", + ) + momentum: float = Field( + default=-0.75, + description="APG MomentumBuffer decay factor for accumulated guidance diff (default -0.75).", + ) + latent_shift: float = Field( + default=0.0, + description="Additive shift applied to DiT latents before VAE decode (0 = no shift).", + ) + latent_rescale: float = Field( + default=1.0, + description="Multiplicative rescale applied to DiT latents before VAE decode (1.0 = no rescale).", + ) audio_format: str = Field( default="mp3", diff --git a/acestep/api/http/release_task_models_test.py b/acestep/api/http/release_task_models_test.py index 238b0fe6f..0b30e2cac 100644 --- a/acestep/api/http/release_task_models_test.py +++ b/acestep/api/http/release_task_models_test.py @@ -33,6 +33,20 @@ def test_audio_code_string_and_cover_noise_strength_are_accepted(self): self.assertEqual("<|audio_code_1|>", req.audio_code_string) self.assertAlmostEqual(0.75, req.cover_noise_strength) + def test_apg_eta_and_momentum_have_expected_defaults(self): + """APG eta/momentum defaults must match the hardcoded values in apg_guidance.""" + + req = GenerateMusicRequest() + self.assertAlmostEqual(0.0, req.eta) + self.assertAlmostEqual(-0.75, req.momentum) + + def test_apg_eta_and_momentum_are_accepted(self): + """Model should accept user-supplied eta and momentum values.""" + + req = GenerateMusicRequest(eta=0.5, momentum=-0.5) + self.assertAlmostEqual(0.5, req.eta) + self.assertAlmostEqual(-0.5, req.momentum) + if __name__ == "__main__": unittest.main() diff --git a/acestep/api/http/release_task_param_parser.py b/acestep/api/http/release_task_param_parser.py index 3b94d530f..2bd2b4bc5 100644 --- a/acestep/api/http/release_task_param_parser.py +++ b/acestep/api/http/release_task_param_parser.py @@ -43,6 +43,13 @@ "instruction": ["instruction"], "track_name": ["track_name", "trackName"], "track_classes": ["track_classes", "trackClasses", "instruments"], + "sampler_mode": ["sampler_mode", "samplerMode"], + "velocity_norm_threshold": ["velocity_norm_threshold", "velocityNormThreshold"], + "velocity_ema_factor": ["velocity_ema_factor", "velocityEmaFactor"], + "eta": ["eta"], + "momentum": ["momentum"], + "latent_shift": ["latent_shift", "latentShift"], + "latent_rescale": ["latent_rescale", "latentRescale"], } diff --git a/acestep/api/http/release_task_param_parser_test.py b/acestep/api/http/release_task_param_parser_test.py index 889475103..54e003afd 100644 --- a/acestep/api/http/release_task_param_parser_test.py +++ b/acestep/api/http/release_task_param_parser_test.py @@ -65,5 +65,38 @@ def test_non_dict_param_obj_json_is_ignored(self): self.assertEqual("meta-caption", parser.str("prompt")) + def test_sampler_and_dit_param_aliases_are_resolved(self): + """Parser should resolve camelCase aliases for sampler/DiT params.""" + + parser = RequestParser( + { + "samplerMode": "heun", + "velocityNormThreshold": "2.5", + "latentShift": "0.1", + "velocityEmaFactor": "0.95", + "latentRescale": "1.2", + } + ) + self.assertEqual("heun", parser.str("sampler_mode")) + self.assertAlmostEqual(2.5, parser.float("velocity_norm_threshold")) + self.assertAlmostEqual(0.1, parser.float("latent_shift")) + self.assertAlmostEqual(0.95, parser.float("velocity_ema_factor")) + self.assertAlmostEqual(1.2, parser.float("latent_rescale")) + + def test_apg_eta_and_momentum_aliases_are_resolved(self): + """Parser should resolve eta/momentum APG params from snake_case keys.""" + + parser = RequestParser({"eta": "0.5", "momentum": "-0.5"}) + self.assertAlmostEqual(0.5, parser.float("eta")) + self.assertAlmostEqual(-0.5, parser.float("momentum")) + + def test_apg_eta_and_momentum_resolve_from_nested_param_obj(self): + """Parser should resolve eta/momentum from nested param_obj payload.""" + + parser = RequestParser({"param_obj": {"eta": 0.75, "momentum": -0.9}}) + self.assertAlmostEqual(0.75, parser.float("eta")) + self.assertAlmostEqual(-0.9, parser.float("momentum")) + + if __name__ == "__main__": unittest.main() diff --git a/acestep/api/http/release_task_request_builder.py b/acestep/api/http/release_task_request_builder.py index a22862a06..643305b2f 100644 --- a/acestep/api/http/release_task_request_builder.py +++ b/acestep/api/http/release_task_request_builder.py @@ -81,6 +81,13 @@ def build_generate_music_request( cfg_interval_end=parser.float("cfg_interval_end", 1.0), infer_method=parser.str("infer_method", "ode"), shift=parser.float("shift", 3.0), + sampler_mode=parser.str("sampler_mode", "euler"), + velocity_norm_threshold=parser.float("velocity_norm_threshold", 0.0), + velocity_ema_factor=parser.float("velocity_ema_factor", 0.0), + eta=parser.float("eta", 0.0), + momentum=parser.float("momentum", -0.75), + latent_shift=parser.float("latent_shift", 0.0), + latent_rescale=parser.float("latent_rescale", 1.0), audio_format=parser.str("audio_format", "mp3"), use_tiled_decode=parser.bool("use_tiled_decode", True), lm_model_path=parser.str("lm_model_path") or None, diff --git a/acestep/api/http/release_task_request_builder_test.py b/acestep/api/http/release_task_request_builder_test.py index 72a0a0682..6554cc72c 100644 --- a/acestep/api/http/release_task_request_builder_test.py +++ b/acestep/api/http/release_task_request_builder_test.py @@ -14,10 +14,10 @@ def __init__(self, values: dict) -> None: self._values = values - def get(self, key: str): - """Return raw value for ``key`` from parser payload.""" + def get(self, key: str, default=None): + """Return raw value for ``key`` from parser payload with optional default.""" - return self._values.get(key) + return self._values.get(key, default) def str(self, key: str, default: str = "") -> str: """Return string value for ``key`` with default fallback.""" @@ -135,5 +135,65 @@ def test_build_request_forwards_audio_code_string_and_cover_noise_strength(self) self.assertAlmostEqual(0.6, request.cover_noise_strength) + def test_build_request_forwards_sampler_and_dit_latent_params(self): + """Builder should include sampler_mode and latent post-processing params.""" + + parser = _FakeParser( + { + "sampler_mode": "heun", + "velocity_norm_threshold": 2.0, + "velocity_ema_factor": 0.1, + "latent_shift": 0.05, + "latent_rescale": 1.2, + } + ) + request = build_generate_music_request( + parser=parser, + request_model_cls=lambda **kwargs: SimpleNamespace(**kwargs), + default_dit_instruction="default-instruction", + lm_default_temperature=0.85, + lm_default_cfg_scale=2.5, + lm_default_top_p=0.9, + ) + + self.assertEqual("heun", request.sampler_mode) + self.assertAlmostEqual(2.0, request.velocity_norm_threshold) + self.assertAlmostEqual(0.1, request.velocity_ema_factor) + self.assertAlmostEqual(0.05, request.latent_shift) + self.assertAlmostEqual(1.2, request.latent_rescale) + + def test_build_request_forwards_apg_eta_and_momentum(self): + """Builder should include APG eta/momentum params in the payload.""" + + parser = _FakeParser({"eta": 0.5, "momentum": -0.5}) + request = build_generate_music_request( + parser=parser, + request_model_cls=lambda **kwargs: SimpleNamespace(**kwargs), + default_dit_instruction="default-instruction", + lm_default_temperature=0.85, + lm_default_cfg_scale=2.5, + lm_default_top_p=0.9, + ) + + self.assertAlmostEqual(0.5, request.eta) + self.assertAlmostEqual(-0.5, request.momentum) + + def test_build_request_defaults_apg_eta_and_momentum_when_absent(self): + """Builder should fall back to APG hardcoded defaults when params are absent.""" + + parser = _FakeParser({}) + request = build_generate_music_request( + parser=parser, + request_model_cls=lambda **kwargs: SimpleNamespace(**kwargs), + default_dit_instruction="default-instruction", + lm_default_temperature=0.85, + lm_default_cfg_scale=2.5, + lm_default_top_p=0.9, + ) + + self.assertAlmostEqual(0.0, request.eta) + self.assertAlmostEqual(-0.75, request.momentum) + + if __name__ == "__main__": unittest.main() diff --git a/acestep/api/job_generation_setup.py b/acestep/api/job_generation_setup.py index 34ee446f1..b5e0ffccb 100644 --- a/acestep/api/job_generation_setup.py +++ b/acestep/api/job_generation_setup.py @@ -174,6 +174,13 @@ def build_generation_setup( cfg_interval_end=req.cfg_interval_end, shift=req.shift, infer_method=req.infer_method, + sampler_mode=getattr(req, "sampler_mode", "euler"), + velocity_norm_threshold=getattr(req, "velocity_norm_threshold", 0.0), + velocity_ema_factor=getattr(req, "velocity_ema_factor", 0.0), + eta=getattr(req, "eta", 0.0), + momentum=getattr(req, "momentum", -0.75), + latent_shift=getattr(req, "latent_shift", 0.0), + latent_rescale=getattr(req, "latent_rescale", 1.0), timesteps=parsed_timesteps, repainting_start=req.repainting_start, repainting_end=req.repainting_end if req.repainting_end else -1, diff --git a/acestep/api/job_generation_setup_test.py b/acestep/api/job_generation_setup_test.py index 3ec650b94..20eef9956 100644 --- a/acestep/api/job_generation_setup_test.py +++ b/acestep/api/job_generation_setup_test.py @@ -263,5 +263,131 @@ def test_use_cot_metas_enabled_when_format_has_duration(self) -> None: self.assertTrue(setup.params.use_cot_metas) + def test_build_generation_setup_forwards_sampler_and_dit_latent_params(self) -> None: + """Sampler mode and latent post-processing params should be forwarded.""" + + req = _base_req() + req.sampler_mode = "heun" + req.velocity_norm_threshold = 2.0 + req.velocity_ema_factor = 0.1 + req.latent_shift = 0.05 + req.latent_rescale = 1.2 + setup = build_generation_setup( + req=req, + caption="cap", + lyrics="lyr", + bpm=None, + key_scale="", + time_signature="", + audio_duration=None, + thinking=False, + sample_mode=False, + format_has_duration=False, + use_cot_caption=False, + use_cot_language=False, + lm_top_k=0, + lm_top_p=0.9, + parse_timesteps=lambda _value: None, + is_instrumental=lambda _lyrics: False, + default_dit_instruction="default instruction", + task_instructions={}, + ) + + self.assertEqual("heun", setup.params.sampler_mode) + self.assertAlmostEqual(2.0, setup.params.velocity_norm_threshold) + self.assertAlmostEqual(0.1, setup.params.velocity_ema_factor) + self.assertAlmostEqual(0.05, setup.params.latent_shift) + self.assertAlmostEqual(1.2, setup.params.latent_rescale) + + def test_build_generation_setup_defaults_sampler_params_when_missing(self) -> None: + """When req lacks sampler/latent fields, getattr defaults should apply.""" + + req = _base_req() + setup = build_generation_setup( + req=req, + caption="cap", + lyrics="lyr", + bpm=None, + key_scale="", + time_signature="", + audio_duration=None, + thinking=False, + sample_mode=False, + format_has_duration=False, + use_cot_caption=False, + use_cot_language=False, + lm_top_k=0, + lm_top_p=0.9, + parse_timesteps=lambda _value: None, + is_instrumental=lambda _lyrics: False, + default_dit_instruction="default instruction", + task_instructions={}, + ) + + self.assertEqual("euler", setup.params.sampler_mode) + self.assertAlmostEqual(0.0, setup.params.velocity_norm_threshold) + self.assertAlmostEqual(0.0, setup.params.velocity_ema_factor) + self.assertAlmostEqual(0.0, setup.params.latent_shift) + self.assertAlmostEqual(1.0, setup.params.latent_rescale) + + def test_build_generation_setup_forwards_apg_eta_and_momentum(self) -> None: + """APG eta and momentum should be forwarded to GenerationParams.""" + + req = _base_req() + req.eta = 0.5 + req.momentum = -0.5 + setup = build_generation_setup( + req=req, + caption="cap", + lyrics="lyr", + bpm=None, + key_scale="", + time_signature="", + audio_duration=None, + thinking=False, + sample_mode=False, + format_has_duration=False, + use_cot_caption=False, + use_cot_language=False, + lm_top_k=0, + lm_top_p=0.9, + parse_timesteps=lambda _value: None, + is_instrumental=lambda _lyrics: False, + default_dit_instruction="default instruction", + task_instructions={}, + ) + + self.assertAlmostEqual(0.5, setup.params.eta) + self.assertAlmostEqual(-0.5, setup.params.momentum) + + def test_build_generation_setup_defaults_apg_params_when_missing(self) -> None: + """When req lacks eta/momentum fields, the APG hardcoded defaults should apply.""" + + req = _base_req() + setup = build_generation_setup( + req=req, + caption="cap", + lyrics="lyr", + bpm=None, + key_scale="", + time_signature="", + audio_duration=None, + thinking=False, + sample_mode=False, + format_has_duration=False, + use_cot_caption=False, + use_cot_language=False, + lm_top_k=0, + lm_top_p=0.9, + parse_timesteps=lambda _value: None, + is_instrumental=lambda _lyrics: False, + default_dit_instruction="default instruction", + task_instructions={}, + ) + + self.assertAlmostEqual(0.0, setup.params.eta) + self.assertAlmostEqual(-0.75, setup.params.momentum) + + if __name__ == "__main__": unittest.main() diff --git a/acestep/core/generation/handler/generate_music.py b/acestep/core/generation/handler/generate_music.py index d7dbdc950..5002705c2 100644 --- a/acestep/core/generation/handler/generate_music.py +++ b/acestep/core/generation/handler/generate_music.py @@ -214,6 +214,8 @@ def generate_music( dcw_scaler: float = 0.05, dcw_high_scaler: float = 0.02, dcw_wavelet: str = "haar", + eta: float = 0.0, + momentum: float = -0.75, use_tiled_decode: bool = True, timesteps: Optional[List[float]] = None, latent_shift: float = 0.0, @@ -380,6 +382,8 @@ def generate_music( dcw_scaler=dcw_scaler, dcw_high_scaler=dcw_high_scaler, dcw_wavelet=dcw_wavelet, + eta=eta, + momentum=momentum, repaint_crossfade_frames=resolved_cf_frames, repaint_injection_ratio=injection_ratio, task_type=task_type, diff --git a/acestep/core/generation/handler/generate_music_execute.py b/acestep/core/generation/handler/generate_music_execute.py index 7077ad796..af35afc5c 100644 --- a/acestep/core/generation/handler/generate_music_execute.py +++ b/acestep/core/generation/handler/generate_music_execute.py @@ -41,6 +41,8 @@ def _run_generate_music_service_with_progress( dcw_scaler: float = 0.05, dcw_high_scaler: float = 0.02, dcw_wavelet: str = "haar", + eta: float = 0.0, + momentum: float = -0.75, repaint_crossfade_frames: int = 10, repaint_injection_ratio: float = 0.5, task_type: str = "", @@ -96,6 +98,8 @@ def _service_target(): dcw_scaler=dcw_scaler, dcw_high_scaler=dcw_high_scaler, dcw_wavelet=dcw_wavelet, + eta=eta, + momentum=momentum, audio_code_hints=service_inputs["audio_code_hints_batch"], return_intermediate=service_inputs["should_return_intermediate"], timesteps=timesteps, diff --git a/acestep/core/generation/handler/service_generate.py b/acestep/core/generation/handler/service_generate.py index 05591ea56..7d2a85009 100644 --- a/acestep/core/generation/handler/service_generate.py +++ b/acestep/core/generation/handler/service_generate.py @@ -56,6 +56,8 @@ def service_generate( dcw_high_scaler: float = 0.02, dcw_wavelet: str = "haar", task_type: str = "", + eta: float = 0.0, + momentum: float = -0.75, ) -> Dict[str, Any]: """Generate music latents and metadata from text/audio conditioning inputs. @@ -102,6 +104,8 @@ def service_generate( ``"sym8"``. task_type: Generation task selector used when preparing conditioning masks. + eta: APG parallel-component scaling factor (0.0 = pure orthogonal guidance). + momentum: APG MomentumBuffer decay factor (default -0.75). Returns: Dict[str, Any]: Service output payload containing generated latents, @@ -169,6 +173,8 @@ def service_generate( dcw_scaler=dcw_scaler, dcw_high_scaler=dcw_high_scaler, dcw_wavelet=dcw_wavelet, + eta=eta, + momentum=momentum, ) outputs, encoder_hidden_states, encoder_attention_mask, context_latents = ( self._execute_service_generate_diffusion( diff --git a/acestep/core/generation/handler/service_generate_execute.py b/acestep/core/generation/handler/service_generate_execute.py index d94e8b945..8009bbd7f 100644 --- a/acestep/core/generation/handler/service_generate_execute.py +++ b/acestep/core/generation/handler/service_generate_execute.py @@ -85,6 +85,8 @@ def _build_service_generate_kwargs( dcw_scaler: float = 0.05, dcw_high_scaler: float = 0.02, dcw_wavelet: str = "haar", + eta: float = 0.0, + momentum: float = -0.75, ) -> Dict[str, Any]: """Build kwargs passed to model generation backends.""" repaint_mask = payload.get("repaint_mask") @@ -126,6 +128,8 @@ def _build_service_generate_kwargs( "dcw_scaler": dcw_scaler, "dcw_high_scaler": dcw_high_scaler, "dcw_wavelet": dcw_wavelet, + "eta": eta, + "momentum": momentum, } if timesteps is not None: kwargs["timesteps"] = torch.tensor(timesteps, dtype=torch.float32, device=self.device) diff --git a/acestep/inference.py b/acestep/inference.py index e4183f29a..a1d455749 100644 --- a/acestep/inference.py +++ b/acestep/inference.py @@ -59,6 +59,8 @@ class GenerationParams: normalization_db: Target loudness in dB for normalization (e.g., -1.0 for -1 dBFS peak). latent_shift: Additive shift applied to DiT latents before VAE decode (default 0, no shift). latent_rescale: Multiplicative rescale applied to DiT latents before VAE decode (default 1.0, no rescale). + eta: APG guidance eta (parallel-component scaling factor, default 0.0 = pure orthogonal guidance). + momentum: APG MomentumBuffer decay factor (default -0.75). # Generation Parameters inference_steps: Number of diffusion steps (e.g., 8 for turbo, 32–100 for base model). @@ -138,6 +140,8 @@ class GenerationParams: sampler_mode: str = "euler" # "euler" (first-order) or "heun" (second-order predictor-corrector) velocity_norm_threshold: float = 0.0 # Clamp velocity prediction norms (0 = disabled, try 2.0) velocity_ema_factor: float = 0.0 # Velocity EMA smoothing (0 = disabled, try 0.1) + eta: float = 0.0 # APG eta: parallel-component scaling in adaptive projected guidance (default 0.0, pure orthogonal). + momentum: float = -0.75 # APG momentum: MomentumBuffer decay factor for accumulated guidance diff (default -0.75). # DCW — Differential Correction in Wavelet domain (CVPR 2026, arXiv:2604.16044). # On by default to mitigate SNR-t bias via per-band wavelet-domain correction # at each sampler step. Uses `pytorch_wavelets` + `PyWavelets` (managed deps). @@ -661,6 +665,8 @@ def generate_music( "sampler_mode": params.sampler_mode, "velocity_norm_threshold": params.velocity_norm_threshold, "velocity_ema_factor": params.velocity_ema_factor, + "eta": params.eta, + "momentum": params.momentum, "dcw_enabled": params.dcw_enabled, "dcw_mode": params.dcw_mode, "dcw_scaler": params.dcw_scaler, diff --git a/acestep/models/base/modeling_acestep_v15_base.py b/acestep/models/base/modeling_acestep_v15_base.py index cf1b86deb..b14fa243b 100644 --- a/acestep/models/base/modeling_acestep_v15_base.py +++ b/acestep/models/base/modeling_acestep_v15_base.py @@ -1875,6 +1875,8 @@ def generate_audio( dcw_scaler: float = 0.05, dcw_high_scaler: float = 0.02, dcw_wavelet: str = "haar", + eta: float = 0.0, + momentum: float = -0.75, **kwargs, ): # Backward-compat: accept the old misspelled key "diffusion_guidance_sale" @@ -1956,7 +1958,7 @@ def generate_audio( noise = self.prepare_noise(context_latents, seed) bsz, device, dtype = context_latents.shape[0], context_latents.device, context_latents.dtype past_key_values = EncoderDecoderCache(DynamicCache(), DynamicCache()) - momentum_buffer = MomentumBuffer() + momentum_buffer = MomentumBuffer(momentum=momentum) # Cover noise initialization: blend noise with src_latents if cover_noise_strength > 0.0: @@ -2053,6 +2055,7 @@ def generate_audio( pred_uncond=pred_null_cond, guidance_scale=diffusion_guidance_scale, momentum_buffer=momentum_buffer, + eta=eta, dims=[1], ) else: diff --git a/acestep/models/common/apg_guidance_test.py b/acestep/models/common/apg_guidance_test.py new file mode 100644 index 000000000..4d8f27416 --- /dev/null +++ b/acestep/models/common/apg_guidance_test.py @@ -0,0 +1,111 @@ +"""Unit tests for APG guidance primitives (``MomentumBuffer`` and ``apg_forward``).""" + +import unittest + +import torch + +from acestep.models.common.apg_guidance import MomentumBuffer, apg_forward + + +class MomentumBufferDefaultsTests(unittest.TestCase): + """Tests that MomentumBuffer preserves its historical default behavior.""" + + def test_default_momentum_matches_hardcoded_recipe_value(self): + """Default constructor should produce momentum=-0.75 (the ACE-Step default).""" + + buffer = MomentumBuffer() + self.assertAlmostEqual(-0.75, buffer.momentum) + + def test_explicit_momentum_is_honored(self): + """Caller-supplied momentum value must replace the default.""" + + buffer = MomentumBuffer(momentum=0.25) + self.assertAlmostEqual(0.25, buffer.momentum) + + def test_update_uses_momentum_scale_on_running_average(self): + """Running average must be updated as ``momentum * prev + new``.""" + + buffer = MomentumBuffer(momentum=0.5) + buffer.update(torch.tensor([1.0, 2.0])) + buffer.update(torch.tensor([3.0, 4.0])) + expected = torch.tensor([0.5, 1.0]) + torch.tensor([3.0, 4.0]) + self.assertTrue(torch.allclose(buffer.running_average, expected)) + + +class ApgForwardDefaultsTests(unittest.TestCase): + """Tests that ``apg_forward`` preserves its default eta/norm_threshold semantics.""" + + def _make_inputs(self): + """Return reproducible conditional/unconditional prediction tensors.""" + + torch.manual_seed(0) + pred_cond = torch.randn(2, 4, 8) + pred_uncond = torch.randn(2, 4, 8) + return pred_cond, pred_uncond + + def test_default_eta_produces_same_output_as_explicit_zero(self): + """Omitting eta must match eta=0.0 exactly (backward compatibility).""" + + pred_cond, pred_uncond = self._make_inputs() + implicit = apg_forward( + pred_cond=pred_cond, + pred_uncond=pred_uncond, + guidance_scale=3.0, + momentum_buffer=None, + dims=[1], + ) + explicit = apg_forward( + pred_cond=pred_cond, + pred_uncond=pred_uncond, + guidance_scale=3.0, + momentum_buffer=None, + eta=0.0, + dims=[1], + ) + self.assertTrue(torch.allclose(implicit, explicit)) + + def test_nonzero_eta_changes_output(self): + """A nonzero eta must change the guided prediction (sanity check).""" + + pred_cond, pred_uncond = self._make_inputs() + baseline = apg_forward( + pred_cond=pred_cond, + pred_uncond=pred_uncond, + guidance_scale=3.0, + momentum_buffer=None, + eta=0.0, + dims=[1], + ) + tweaked = apg_forward( + pred_cond=pred_cond, + pred_uncond=pred_uncond, + guidance_scale=3.0, + momentum_buffer=None, + eta=1.0, + dims=[1], + ) + self.assertFalse(torch.allclose(baseline, tweaked)) + + def test_default_momentum_buffer_matches_explicit_neg_075(self): + """Default MomentumBuffer() must produce the same trajectory as momentum=-0.75.""" + + pred_cond, pred_uncond = self._make_inputs() + implicit_out = apg_forward( + pred_cond=pred_cond, + pred_uncond=pred_uncond, + guidance_scale=3.0, + momentum_buffer=MomentumBuffer(), + dims=[1], + ) + explicit_out = apg_forward( + pred_cond=pred_cond, + pred_uncond=pred_uncond, + guidance_scale=3.0, + momentum_buffer=MomentumBuffer(momentum=-0.75), + dims=[1], + ) + self.assertTrue(torch.allclose(implicit_out, explicit_out)) + + +if __name__ == "__main__": + unittest.main() diff --git a/acestep/models/sft/modeling_acestep_v15_base.py b/acestep/models/sft/modeling_acestep_v15_base.py index 935fc6656..8ba5e3632 100644 --- a/acestep/models/sft/modeling_acestep_v15_base.py +++ b/acestep/models/sft/modeling_acestep_v15_base.py @@ -1875,6 +1875,8 @@ def generate_audio( dcw_scaler: float = 0.05, dcw_high_scaler: float = 0.02, dcw_wavelet: str = "haar", + eta: float = 0.0, + momentum: float = -0.75, **kwargs, ): # Backward-compat: accept the old misspelled key "diffusion_guidance_sale" @@ -1958,7 +1960,7 @@ def generate_audio( noise = self.prepare_noise(context_latents, seed) bsz, device, dtype = context_latents.shape[0], context_latents.device, context_latents.dtype past_key_values = EncoderDecoderCache(DynamicCache(), DynamicCache()) - momentum_buffer = MomentumBuffer() + momentum_buffer = MomentumBuffer(momentum=momentum) # Cover noise initialization: blend noise with src_latents if cover_noise_strength > 0.0: @@ -2055,6 +2057,7 @@ def generate_audio( pred_uncond=pred_null_cond, guidance_scale=diffusion_guidance_scale, momentum_buffer=momentum_buffer, + eta=eta, dims=[1], ) else: diff --git a/acestep/models/xl_base/modeling_acestep_v15_xl_base.py b/acestep/models/xl_base/modeling_acestep_v15_xl_base.py index ae38b9336..61d28d158 100644 --- a/acestep/models/xl_base/modeling_acestep_v15_xl_base.py +++ b/acestep/models/xl_base/modeling_acestep_v15_xl_base.py @@ -1887,6 +1887,8 @@ def generate_audio( dcw_scaler: float = 0.05, dcw_high_scaler: float = 0.02, dcw_wavelet: str = "haar", + eta: float = 0.0, + momentum: float = -0.75, **kwargs, ): # Backward-compat: accept the old misspelled key "diffusion_guidance_sale" @@ -1970,7 +1972,7 @@ def generate_audio( noise = self.prepare_noise(context_latents, seed) bsz, device, dtype = context_latents.shape[0], context_latents.device, context_latents.dtype past_key_values = EncoderDecoderCache(DynamicCache(), DynamicCache()) - momentum_buffer = MomentumBuffer() + momentum_buffer = MomentumBuffer(momentum=momentum) # Cover noise initialization: blend noise with src_latents if cover_noise_strength > 0.0: @@ -2067,6 +2069,7 @@ def generate_audio( pred_uncond=pred_null_cond, guidance_scale=diffusion_guidance_scale, momentum_buffer=momentum_buffer, + eta=eta, dims=[1], ) else: diff --git a/acestep/models/xl_sft/modeling_acestep_v15_xl_base.py b/acestep/models/xl_sft/modeling_acestep_v15_xl_base.py index ae38b9336..61d28d158 100644 --- a/acestep/models/xl_sft/modeling_acestep_v15_xl_base.py +++ b/acestep/models/xl_sft/modeling_acestep_v15_xl_base.py @@ -1887,6 +1887,8 @@ def generate_audio( dcw_scaler: float = 0.05, dcw_high_scaler: float = 0.02, dcw_wavelet: str = "haar", + eta: float = 0.0, + momentum: float = -0.75, **kwargs, ): # Backward-compat: accept the old misspelled key "diffusion_guidance_sale" @@ -1970,7 +1972,7 @@ def generate_audio( noise = self.prepare_noise(context_latents, seed) bsz, device, dtype = context_latents.shape[0], context_latents.device, context_latents.dtype past_key_values = EncoderDecoderCache(DynamicCache(), DynamicCache()) - momentum_buffer = MomentumBuffer() + momentum_buffer = MomentumBuffer(momentum=momentum) # Cover noise initialization: blend noise with src_latents if cover_noise_strength > 0.0: @@ -2067,6 +2069,7 @@ def generate_audio( pred_uncond=pred_null_cond, guidance_scale=diffusion_guidance_scale, momentum_buffer=momentum_buffer, + eta=eta, dims=[1], ) else: