
Commit d137832

in-vm-agent, lujangus
authored and committed
feat(speculative): add typical-acceptance verify mode for Eagle3 draft
Squeeze pipeline Track B B1: drop-in flag-gated alternative to strict
rejection-sampling verification.

Adds two server-args flags to ServerArgs in sglang/srt/server_args.py:

    --speculative-verify-mode {rejection_sampling, typical_acceptance}
    --speculative-typical-acceptance-alpha FLOAT (default 0.8)

When `speculative_verify_mode == "typical_acceptance"`, the Eagle3
verification path in `sglang/srt/speculative/eagle_info.py` overrides both
`threshold_single` and `threshold_acc` with the alpha value before calling
the existing `tree_speculative_sampling_target_only` kernel.

The kernel acceptance condition

    if (coin <= prob_acc / threshold_acc ||
        target_prob_single >= threshold_single) {
      // accept token
    }

(in `sgl-kernel/csrc/speculative/speculative_sampling.cuh:80`) is the Medusa
typical-acceptance formula when threshold_single == threshold_acc == alpha
and 0 < alpha <= 1. So the kernel math is already correct; this commit just
exposes the alpha knob.

Defaults preserve existing behavior: rejection_sampling is the default mode
and the existing `--speculative-accept-threshold-{single,acc}` flags continue
to work unchanged. alpha=1.0 in typical_acceptance mode also reproduces
strict rejection sampling.

Scope intentionally narrow per the squeeze B1 preflight at
`experiments/MiniMax-M2.5/squeeze/relaxed/B1-typical-acceptance/preflight.md`:

- Eagle3 path only (eagle_info.py). ngram_info.py and dflash_utils.py also
  call tree_speculative_sampling_target_only but are not in the squeeze
  experiment scope; they continue to use the strict thresholds.
- Global server-args flag, not per-request. Avoids mixed-mode KV-cache
  state. Per-request override deferred to a future revision if needed.

To use:

    python -m sglang.launch_server \
      --model-path <target> \
      --speculative-algorithm EAGLE3 \
      --speculative-draft-model-path thoughtworks/<Model>-Eagle3 \
      --speculative-verify-mode typical_acceptance \
      --speculative-typical-acceptance-alpha 0.8 \
      ...
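For readers who want to see how alpha changes the outcome, the acceptance condition quoted above can be mirrored in plain Python. This is only an illustrative sketch, not the kernel: the function names `accepts` and `accepts_typical` are hypothetical, and `prob_acc` and `target_prob_single` are taken as given inputs rather than derived the way the CUDA kernel derives them.

```python
def accepts(coin, prob_acc, target_prob_single, threshold_single, threshold_acc):
    """Mirror of the acceptance predicate quoted in the commit message.

    Hypothetical helper for illustration; the real check lives in the
    CUDA kernel and operates on per-token probabilities.
    """
    return (
        coin <= prob_acc / threshold_acc
        or target_prob_single >= threshold_single
    )


def accepts_typical(coin, prob_acc, target_prob_single, alpha):
    """typical_acceptance mode: both thresholds collapse to a single alpha."""
    if not 0.0 < alpha <= 1.0:
        raise ValueError("alpha must be in (0, 1]")
    return accepts(coin, prob_acc, target_prob_single, alpha, alpha)
```

With `coin=0.9`, `prob_acc=0.5` and `target_prob_single=0.85`, the token is rejected at alpha=1.0 (neither clause holds) but accepted at alpha=0.8 through the `target_prob_single >= threshold_single` clause, which is the relaxation the flag is meant to expose.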
Squeeze B1 alpha-sweep protocol: alpha in {0.7, 0.8, 0.9}, with alpha=1.0 as
a control reproducing rejection-sampling baseline. Per-dataset quality must
stay within 3% of the lossless Exp F baseline at every concurrency point per
the squeeze plan §187 quality floor.

Branch sits on top of `fix/llama-eagle3-fp8-aux-dtype-cast` (commit 71e0bf0)
so it can run end-to-end on FP8 targets like MiniMaxAI/MiniMax-M2.5
immediately.
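The quality floor in the sweep protocol could be checked with a small helper along these lines. This is a sketch under assumptions: the commit does not say whether "within 3%" is relative or absolute, so a relative lower bound is assumed here, and the names `SWEEP_ALPHAS`, `CONTROL_ALPHA` and `within_quality_floor`, as well as the `(dataset, concurrency)` key layout, are hypothetical.

```python
SWEEP_ALPHAS = (0.7, 0.8, 0.9)  # sweep points named in the protocol
CONTROL_ALPHA = 1.0             # control reproducing the baseline


def within_quality_floor(baseline, candidate, tolerance=0.03):
    """Return True if every point stays within `tolerance` of the baseline.

    `baseline` and `candidate` map (dataset, concurrency) -> quality score.
    Assumption: the floor is a relative lower bound, i.e. a candidate
    score may not fall below baseline * (1 - tolerance).
    """
    for key, base_score in baseline.items():
        if candidate[key] < base_score * (1.0 - tolerance):
            return False
    return True
```

A sweep driver would call `within_quality_floor` once per alpha in `SWEEP_ALPHAS`, and an alpha fails as soon as any single dataset and concurrency point drops below the floor.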
1 parent ea6c448 commit d137832

2 files changed

Lines changed: 49 additions & 2 deletions


python/sglang/srt/server_args.py

Lines changed: 36 additions & 0 deletions
@@ -514,6 +514,14 @@ class ServerArgs:
     speculative_adaptive_topk_recovery: float = 0.0
     speculative_adaptive_topk_window: int = 32
     speculative_adaptive_topk_alpha: float = 0.1  # EMA smoothing
+    # Track-B Squeeze: relaxed verification via Medusa-style typical-acceptance.
+    # When `speculative_verify_mode == "typical_acceptance"`, both threshold knobs above
+    # are overridden by `speculative_typical_acceptance_alpha` at verify time.
+    # Default `rejection_sampling` preserves existing strict behavior.
+    speculative_verify_mode: Literal[
+        "rejection_sampling", "typical_acceptance"
+    ] = "rejection_sampling"
+    speculative_typical_acceptance_alpha: float = 0.8
     speculative_token_map: Optional[str] = None
     speculative_attention_mode: str = "prefill"
     speculative_draft_attention_backend: Optional[str] = None
@@ -5132,6 +5140,34 @@ def add_cli_args(parser: argparse.ArgumentParser):
             help="A3.1 — EMA smoothing factor in [0, 1] for accept-length tracking. Higher = more responsive to recent batches.",
             default=ServerArgs.speculative_adaptive_topk_alpha,
         )
+        parser.add_argument(
+            "--speculative-verify-mode",
+            type=str,
+            choices=["rejection_sampling", "typical_acceptance"],
+            default=ServerArgs.speculative_verify_mode,
+            help=(
+                "Verification regime for speculative decoding. "
+                "'rejection_sampling' (default) is strict: a draft token is accepted "
+                "iff coin <= prob_acc/threshold_acc OR target_prob_single >= "
+                "threshold_single, with both thresholds defaulting to 1.0. "
+                "'typical_acceptance' is the Medusa-style alpha-tunable mode: both "
+                "thresholds are set to --speculative-typical-acceptance-alpha at verify "
+                "time, trading a small amount of distributional fidelity for higher "
+                "accept rate (and hence higher throughput) on long-tail tokens. The "
+                "trade-off is alpha-tunable; alpha=1.0 reproduces rejection sampling, "
+                "alpha~0.7-0.9 is the typical Squeeze Track-B operating range."
+            ),
+        )
+        parser.add_argument(
+            "--speculative-typical-acceptance-alpha",
+            type=float,
+            default=ServerArgs.speculative_typical_acceptance_alpha,
+            help=(
+                "Alpha threshold for typical_acceptance verify mode. Ignored when "
+                "--speculative-verify-mode is rejection_sampling. Range (0, 1]. "
+                "Smaller values -> higher accept rate, lower distributional fidelity."
+            ),
+        )
         parser.add_argument(
             "--speculative-token-map",
             type=str,
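The help text documents the alpha range as (0, 1], but the diff registers the flag with a plain `type=float`, so out-of-range values are not rejected at parse time. One possible follow-up, which is not part of this commit, is a custom argparse type. The sketch below uses only the standard library; the names `alpha_in_unit_interval` and `build_parser` are hypothetical and the parser is a standalone stand-in for `add_cli_args`.

```python
import argparse


def alpha_in_unit_interval(text: str) -> float:
    """argparse `type=` callable enforcing the documented range (0, 1].

    Hypothetical validator, not present in the commit.
    """
    value = float(text)
    if not 0.0 < value <= 1.0:
        raise argparse.ArgumentTypeError(
            f"alpha must be in (0, 1], got {value}"
        )
    return value


def build_parser() -> argparse.ArgumentParser:
    """Standalone parser carrying only the alpha flag, for illustration."""
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--speculative-typical-acceptance-alpha",
        type=alpha_in_unit_interval,
        default=0.8,
    )
    return parser
```

Rejecting alpha=0 at parse time matters because the acceptance condition divides by `threshold_acc`; with the validator in place argparse exits with a usage error instead of passing the value through to the kernel.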

python/sglang/srt/speculative/eagle_info.py

Lines changed: 13 additions & 2 deletions
@@ -381,6 +381,17 @@ def verify(
             coins_for_final_sampling = torch.rand(
                 (bs,), dtype=torch.float32, device=batch.device
             )
+            # Track-B Squeeze: typical-acceptance verify mode overrides both
+            # threshold knobs with a single alpha. alpha=1.0 reproduces
+            # rejection sampling (the strict default).
+            _server_args = get_global_server_args()
+            if _server_args.speculative_verify_mode == "typical_acceptance":
+                _alpha = _server_args.speculative_typical_acceptance_alpha
+                _threshold_single = _alpha
+                _threshold_acc = _alpha
+            else:
+                _threshold_single = _server_args.speculative_accept_threshold_single
+                _threshold_acc = _server_args.speculative_accept_threshold_acc
             tree_speculative_sampling_target_only(
                 predicts=predict,  # mutable
                 accept_index=accept_index,  # mutable
@@ -393,8 +404,8 @@ def verify(
                 uniform_samples_for_final_sampling=coins_for_final_sampling,
                 target_probs=target_probs,
                 draft_probs=draft_probs,
-                threshold_single=get_global_server_args().speculative_accept_threshold_single,
-                threshold_acc=get_global_server_args().speculative_accept_threshold_acc,
+                threshold_single=_threshold_single,
+                threshold_acc=_threshold_acc,
                 deterministic=True,
             )
 
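The threshold selection added in `verify` can be read as a pure function of the server args, which makes it easy to exercise without a GPU. The sketch below restates that branch under assumptions: `resolve_thresholds` is a hypothetical name, and `SimpleNamespace` stands in for the real `ServerArgs` object with only the four fields the branch reads.

```python
from types import SimpleNamespace


def resolve_thresholds(server_args):
    """Return (threshold_single, threshold_acc) as the diff's branch does.

    Hypothetical helper mirroring the logic in the hunk above; the commit
    itself inlines this in `verify`.
    """
    if server_args.speculative_verify_mode == "typical_acceptance":
        alpha = server_args.speculative_typical_acceptance_alpha
        return alpha, alpha
    return (
        server_args.speculative_accept_threshold_single,
        server_args.speculative_accept_threshold_acc,
    )


# Stand-in for ServerArgs with the defaults described in the commit.
example_args = SimpleNamespace(
    speculative_verify_mode="typical_acceptance",
    speculative_typical_acceptance_alpha=0.8,
    speculative_accept_threshold_single=1.0,
    speculative_accept_threshold_acc=1.0,
)
```

Calling `resolve_thresholds(example_args)` yields `(0.8, 0.8)`; switching `speculative_verify_mode` back to `"rejection_sampling"` yields the two existing threshold values unchanged, matching the commit's claim that defaults preserve current behavior.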
