
Conversation

@SiqiLi-Fighting

No description provided.

@SiqiLi-Fighting SiqiLi-Fighting added the enhancement (New feature or request) label Sep 24, 2025
@gemini-code-assist

Summary of Changes

Hello @SiqiLi-Fighting, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request introduces a foundational set of changes that integrate Multi-Token Prediction (MTP) with the EAGLE speculative decoding algorithm in the SGLang framework. The primary objective is to significantly raise inference throughput by letting models predict and verify multiple tokens concurrently, reducing the sequential bottleneck inherent in token-by-token generation. The modifications span documentation, core scheduling and model execution components, model configurations, and command-line arguments, laying the groundwork for a more efficient and scalable language model inference system.

Highlights

  • Multi-Token Prediction (MTP) RFC: A new Request for Comments (RFC) document has been added, detailing the proposal for Multi-Token Prediction (MTP) as an enhancement to the EAGLE speculative decoding algorithm. This aims to significantly improve inference throughput by allowing models to predict multiple tokens simultaneously.
  • Speculative Decoding Integration: Core components like ScheduleBatch, ModelWorkerBatch, and ForwardBatch have been updated to include spec_algorithm and spec_info fields, enabling the integration of speculative decoding algorithms like EAGLE into the scheduling and model execution flow.
  • New EAGLE Worker and Utilities: New files eagle_util.py, eagle_worker.py, and spec_info.py have been introduced. EAGLEWorker extends ModelWorker to manage the speculative decoding process, while eagle_util.py provides JAX-based utility functions for tree construction and cache management specific to EAGLE.
  • Model and Server Argument Enhancements: Model classes (Qwen, Qwen3, Qwen3_moe) now include methods to get and set embedding and LM head weights. ServerArgs has been extended with new CLI arguments to configure speculative decoding parameters such as algorithm type, draft model path, number of steps, and top-k values.
  • Refactoring and Type Hinting: The FlashAttention class was renamed to FlashAttentionBackend for consistency. LogitsProcessor received updates to handle auxiliary hidden states and improved type hinting with TYPE_CHECKING for ModelWorkerBatch.
  • Overlap Scheduling Restriction: A new validation check in ServerArgs explicitly disallows overlap scheduling when speculative decoding is enabled, indicating potential incompatibilities or performance considerations (a sketch of such a check appears after this list).
  • EAGLE Tree Building Tests: A dedicated test file test_eagle_tree_build.py has been added to validate the correctness of the JAX implementation for EAGLE tree construction, ensuring compatibility with PyTorch's expected behavior.
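To make the overlap-scheduling restriction concrete, here is a minimal sketch of what such a validation check could look like. The field names (`speculative_algorithm`, `disable_overlap_schedule`) are assumptions inferred from the highlights above, not confirmed against this PR's actual code:

```python
# Hypothetical sketch of a ServerArgs validation check; field names are
# assumptions, not confirmed against this PR.
def check_server_args(self):
    if self.speculative_algorithm is not None and not self.disable_overlap_schedule:
        raise ValueError(
            "Overlap scheduling is not supported together with speculative "
            "decoding; please disable it when a speculative algorithm is set."
        )
```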


@gemini-code-assist gemini-code-assist bot left a comment


Code Review

This pull request introduces support for EAGLE speculative decoding, a significant feature to improve inference throughput. The changes include a detailed RFC design document, new server configurations, and the core implementation of the EAGLE worker and its utilities.

My review identified several key areas for improvement. Most critically, the EAGLEWorker implementation is incomplete: several key methods, such as verify, are placeholders. This will cause runtime errors and prevent the feature from functioning. Additionally, there is a significant performance concern in eagle_util.py, where a core function uses Python loops instead of vectorized JAX operations. I've also left some minor feedback on the RFC document for clarity and correctness.

Addressing the incomplete implementation is crucial before this PR can be considered for merging.

Comment on lines 247 to 248
def verify(self, batch: ScheduleBatch, spec_info: EagleDraftInput):
    pass


critical

The verify method is a critical part of the speculative decoding loop, but it's currently implemented as a pass statement, leaving the EAGLEWorker incomplete. Without the verification step, speculative decoding cannot function correctly. This method needs to be fully implemented; the call to it on line 80 will raise a runtime error, since the call site expects multiple return values.
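For illustration, a minimal skeleton of the shape a verify implementation could take, assuming the call site unpacks three values. The worker attribute and helper below are hypothetical, not this PR's API:

```python
def verify(self, batch: ScheduleBatch, spec_info: EagleDraftInput):
    # Score every draft-tree token with the target model in one forward pass.
    logits_output = self.target_worker.forward(batch)  # hypothetical call

    # Walk each tree path and keep the longest prefix of draft tokens that
    # the target model accepts, yielding the verified ids per request.
    verified_ids, accept_lengths = self._accept_draft_tokens(  # hypothetical helper
        logits_output, spec_info
    )
    return logits_output, verified_ids, accept_lengths
```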

Comment on lines 250 to 251
def forward_draft_extend_after_decode(self, batch: ScheduleBatch):
    pass


critical

The forward_draft_extend_after_decode method is currently a placeholder with a pass statement. This method seems to be part of the speculative decoding cycle and needs to be implemented for the feature to be complete.

Comment on lines 355 to 494
def build_eagle_tree_structure(
    parent_list: jax.Array,
    selected_index: jax.Array,
    verified_seq_len: jax.Array,
    bs: int,
    draft_token_num: int,
    topk: int,
    depth: int,
    tree_mask_mode: int = 0,  # FULL_MASK = 0
) -> tuple[jax.Array, jax.Array, jax.Array, jax.Array]:
    """
    Args:
        parent_list: Parent indices array [bs, topk * (depth-1) + 1]
        selected_index: Selected token indices [bs, draft_token_num - 1]
        verified_seq_len: Sequence lengths [bs]
        bs: Batch size
        draft_token_num: Number of draft tokens
        topk: Top-k value
        depth: Tree depth
        tree_mask_mode: Tree mask mode (0=FULL_MASK)

    Returns:
        tuple of (positions, retrive_index, retrive_next_token, retrive_next_sibling)
    """

    # Initialize arrays
    positions = jnp.zeros((bs * draft_token_num,), dtype=jnp.int32)
    retrive_index = jnp.full((bs, draft_token_num), -1, dtype=jnp.int32)
    retrive_next_token = jnp.full((bs, draft_token_num), -1, dtype=jnp.int32)
    retrive_next_sibling = jnp.full((bs, draft_token_num), -1, dtype=jnp.int32)

    for bid in range(bs):
        seq_len = verified_seq_len[bid]

        # selected_index[bid * (draft_token_num - 1) + index]
        # parent_list[bid * (topk * (depth - 1) + 1) + parent_tb_idx]

        for tid in range(draft_token_num):
            global_token_idx = bid * draft_token_num + tid

            if tid == 0:
                # Verified token (tid == 0)
                positions = positions.at[global_token_idx].set(seq_len)
                retrive_index = retrive_index.at[bid, tid].set(global_token_idx)

                # Build retrive_next_token and retrive_next_sibling (backwards iteration)
                retrive_index_offset = bid * draft_token_num
                for i in range(draft_token_num - 1, 0, -1):  # i from draft_token_num-1 to 1
                    current_token_idx = retrive_index_offset + i
                    retrive_index = retrive_index.at[bid, i].set(current_token_idx)

                    selected_idx = bid * (draft_token_num - 1) + i - 1
                    parent_tb_idx = selected_index.flatten()[selected_idx] // topk
                    parent_position = 0

                    if parent_tb_idx > 0:
                        parent_list_idx = bid * (topk * (depth - 1) + 1) + parent_tb_idx
                        if parent_list_idx < parent_list.size:
                            parent_token_idx = parent_list.flatten()[parent_list_idx]

                            for parent_pos in range(draft_token_num - 1):
                                check_idx = bid * (draft_token_num - 1) + parent_pos
                                if (
                                    check_idx < selected_index.size
                                    and selected_index.flatten()[check_idx]
                                    == parent_token_idx
                                ):
                                    parent_position = parent_pos + 1  # +1 to convert to 1-indexed
                                    break
                            else:
                                parent_position = draft_token_num  # Not found
                        else:
                            parent_position = draft_token_num  # Invalid parent_list_idx
                    else:
                        parent_position = 0  # Root node

                    if parent_position >= draft_token_num:
                        # Invalid parent, skip
                        continue

                    next_token_idx = bid * draft_token_num + parent_position
                    if retrive_next_token.flatten()[next_token_idx] == -1:
                        retrive_next_token = retrive_next_token.at[bid, parent_position].set(i)
                    else:
                        # There's already a next_token, so set sibling
                        origin_next_token = retrive_next_token.flatten()[next_token_idx]
                        retrive_next_token = retrive_next_token.at[bid, parent_position].set(i)
                        retrive_next_sibling = retrive_next_sibling.at[bid, i].set(origin_next_token)

                retrive_index = retrive_index.at[bid, 0].set(bid * draft_token_num)

            else:
                # Draft token (tid > 0)
                # Calculate position by tracing back to root
                position = 0
                cur_position = tid - 1  # Convert to 0-indexed for selected_index

                while True:
                    position += 1
                    selected_idx = bid * (draft_token_num - 1) + cur_position
                    parent_tb_idx = selected_index.flatten()[selected_idx] // topk

                    if parent_tb_idx == 0:
                        # Reached root
                        break

                    parent_list_idx = bid * (topk * (depth - 1) + 1) + parent_tb_idx
                    if parent_list_idx < parent_list.size:
                        token_idx = parent_list.flatten()[parent_list_idx]

                        found = False
                        for cur_pos in range(draft_token_num - 1):
                            check_idx = bid * (draft_token_num - 1) + cur_pos
                            if (
                                check_idx < selected_index.size
                                and selected_index.flatten()[check_idx] == token_idx
                            ):
                                cur_position = cur_pos
                                found = True
                                break

                        if not found:
                            break  # Invalid tree structure
                    else:
                        break  # Invalid parent_list_idx

                positions = positions.at[global_token_idx].set(position + seq_len)
                retrive_index = retrive_index.at[bid, tid].set(global_token_idx)

    return positions, retrive_index, retrive_next_token, retrive_next_sibling


high

The function build_eagle_tree_structure iterates over the batch dimension with a Python loop (for bid in range(bs)). In JAX, Python-level loops are unrolled at trace time rather than compiled as parallel operations, so the loop body runs sequentially and cannot be JIT-compiled efficiently, which can lead to significant performance degradation. For better performance, this logic should be vectorized with jax.vmap, allowing JAX to process the entire batch in parallel on the accelerator.
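To make the suggestion concrete, the general pattern looks like the sketch below. This is not a drop-in rewrite: the per-element body is a placeholder, and the real function's data-dependent control flow would first need to be expressed with jax.lax primitives (fori_loop, while_loop, jnp.where) before vmap and jit apply.

```python
import jax
import jax.numpy as jnp

def build_one_tree(selected_index_row: jax.Array, seq_len: jax.Array) -> jax.Array:
    # Placeholder per-example body: compute a position for the verified token
    # plus one slot per draft token, written without any Python batch loop.
    tree_depth = jnp.cumsum(jnp.ones_like(selected_index_row))
    return jnp.concatenate([seq_len[None], seq_len + tree_depth])

# vmap maps the single-example function over the leading batch axis,
# letting XLA run all examples in parallel instead of sequentially.
batched_build = jax.jit(jax.vmap(build_one_tree, in_axes=(0, 0)))

selected_index = jnp.zeros((4, 7), dtype=jnp.int32)           # [bs, draft_token_num - 1]
verified_seq_len = jnp.array([10, 3, 8, 5], dtype=jnp.int32)  # [bs]
positions = batched_build(selected_index, verified_seq_len)   # [bs, draft_token_num]
```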

Comment on lines +182 to +186
return MTPVerifyResult(
    accepted_sequences=accepted_sequences,
    acceptance_rate=len(accepted_sequences) / len(mtp_tree.sequences),
    next_tokens=self._extract_accepted_tokens(accepted_sequences),
)


medium

In the verify_mtp_sequences function's example, the acceptance_rate is calculated as len(accepted_sequences) / len(mtp_tree.sequences). This raises a ZeroDivisionError if mtp_tree.sequences is empty; the implementation should guard this edge case.
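For example, guarding the denominator keeps the computation well-defined for empty trees (names as in the RFC's own snippet):

```python
total = len(mtp_tree.sequences)
return MTPVerifyResult(
    accepted_sequences=accepted_sequences,
    acceptance_rate=len(accepted_sequences) / total if total > 0 else 0.0,
    next_tokens=self._extract_accepted_tokens(accepted_sequences),
)
```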

## References

1. [EAGLE: Speculative Sampling Requires Rethinking Feature Uncertainty](https://arxiv.org/abs/2401.15077)
2. [Multi-Token Prediction Paper](https://arxiv.org/abs/2412.19437)


medium

The reference titled "Multi-Token Prediction Paper" points to arXiv:2412.19437, which is the DeepSeek-V3 technical report, under a generic title. Please verify that this is the intended citation and, if so, cite the paper by its actual title.

Comment on lines +329 to +330
def _draft_preprocess_idle(self, batch: ScheduleBatch):
    pass


medium

The _draft_preprocess_idle method is a placeholder. While it might not be as critical as verify, it should either be implemented or carry a comment explaining why it is intentionally empty, as in the sketch below.
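If the empty body is intentional, a short comment (or an explicit early return) is enough to record that; a minimal sketch, assuming idle batches carry no requests:

```python
def _draft_preprocess_idle(self, batch: ScheduleBatch):
    # Idle batches contain no requests, so there is no draft input to prepare.
    # Deliberately a no-op.
    return
```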

- Add eagle worker and utilities for speculative decoding
- Add support for draft model path configuration
- Update scheduler and model runner for speculative execution
- Add hidden state capture functionality
- Various bug fixes and import corrections
Implement draft autoregressive decode
@Iamleos Iamleos marked this pull request as draft November 10, 2025 06:57
