[Feat] eagle support #8
base: main
Conversation
Summary of Changes

Hello @SiqiLi-Fighting, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed. This pull request introduces a foundational set of changes to integrate Multi-Token Prediction (MTP) with the EAGLE speculative decoding algorithm into the SGLang framework. The primary objective is to significantly enhance inference throughput by enabling models to predict and verify multiple tokens concurrently, thereby reducing the sequential bottlenecks inherent in traditional token generation. The modifications span documentation, core scheduling and model execution components, model configurations, and command-line arguments, laying the groundwork for a more efficient and scalable language model inference system.
Code Review
This pull request introduces support for EAGLE speculative decoding, a significant feature to improve inference throughput. The changes include a detailed RFC design document, new server configurations, and the core implementation of the EAGLE worker and its utilities.
My review has identified several key areas for improvement. Most critically, the EAGLEWorker implementation is incomplete, with several key methods like verify being placeholders. This will cause runtime errors and prevents the feature from functioning. Additionally, there's a significant performance concern in eagle_util.py where a core function uses Python loops instead of vectorized JAX operations. I've also left some minor feedback on the RFC document for clarity and correctness.
Addressing the incomplete implementation is crucial before this PR can be considered for merging.
```python
def verify(self, batch: ScheduleBatch, spec_info: EagleDraftInput):
    pass
```
The verify method is a critical part of the speculative decoding loop, but it's currently implemented as a pass statement. This makes the EAGLEWorker incomplete. Without the verification step, the speculative decoding cannot function correctly. This method needs to be fully implemented. The call to this method on line 80 will lead to a runtime error as it expects multiple return values.
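To make the expected contract concrete, here is a minimal sketch of a greedy verification step such a method would build on. This is an illustration only, not the PR's implementation: the function name `greedy_verify` and the argument shapes are assumptions, and a real EAGLE verify must also handle tree-structured drafts and sampling-based acceptance.

```python
import jax.numpy as jnp


def greedy_verify(draft_tokens: jnp.ndarray, target_logits: jnp.ndarray):
    """Accept the longest prefix of draft tokens matching the target
    model's greedy predictions (hypothetical helper, linear chain only).

    draft_tokens:  [num_draft] proposed token ids from the draft model
    target_logits: [num_draft, vocab] target-model logits at each position
    Returns (output_tokens, num_accepted).
    """
    # Target model's greedy choice at every draft position.
    target_tokens = jnp.argmax(target_logits, axis=-1)  # [num_draft]
    matches = draft_tokens == target_tokens
    if bool(matches.all()):
        num_accepted = matches.shape[0]
    else:
        # Index of the first mismatch = length of the accepted prefix.
        num_accepted = int(jnp.argmin(matches))
    # Emit the accepted prefix plus the target's correction at the first
    # mismatch (slice clips harmlessly when everything was accepted).
    return target_tokens[: num_accepted + 1], num_accepted
```

A full implementation would additionally return the KV-cache locations of accepted tokens so the scheduler can free the rejected branches.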
```python
def forward_draft_extend_after_decode(self, batch: ScheduleBatch):
    pass
```
```python
def build_eagle_tree_structure(
    parent_list: jax.Array,
    selected_index: jax.Array,
    verified_seq_len: jax.Array,
    bs: int,
    draft_token_num: int,
    topk: int,
    depth: int,
    tree_mask_mode: int = 0,  # FULL_MASK = 0
) -> tuple[jax.Array, jax.Array, jax.Array, jax.Array]:
    """
    Args:
        parent_list: Parent indices array [bs, topk * (depth-1) + 1]
        selected_index: Selected token indices [bs, draft_token_num - 1]
        verified_seq_len: Sequence lengths [bs]
        bs: Batch size
        draft_token_num: Number of draft tokens
        topk: Top-k value
        depth: Tree depth
        tree_mask_mode: Tree mask mode (0=FULL_MASK)

    Returns:
        tuple of (positions, retrive_index, retrive_next_token, retrive_next_sibling)
    """
    # Initialize arrays
    positions = jnp.zeros((bs * draft_token_num,), dtype=jnp.int32)
    retrive_index = jnp.full((bs, draft_token_num), -1, dtype=jnp.int32)
    retrive_next_token = jnp.full((bs, draft_token_num), -1, dtype=jnp.int32)
    retrive_next_sibling = jnp.full((bs, draft_token_num), -1, dtype=jnp.int32)

    for bid in range(bs):
        seq_len = verified_seq_len[bid]

        # selected_index[bid * (draft_token_num - 1) + index]
        # parent_list[bid * (topk * (depth - 1) + 1) + parent_tb_idx]

        for tid in range(draft_token_num):
            global_token_idx = bid * draft_token_num + tid

            if tid == 0:
                # Verified token (tid == 0)
                positions = positions.at[global_token_idx].set(seq_len)
                retrive_index = retrive_index.at[bid, tid].set(global_token_idx)

                # Build retrive_next_token and retrive_next_sibling (backwards iteration)
                retrive_index_offset = bid * draft_token_num
                for i in range(draft_token_num - 1, 0, -1):  # i from draft_token_num-1 to 1
                    current_token_idx = retrive_index_offset + i
                    retrive_index = retrive_index.at[bid, i].set(current_token_idx)

                    selected_idx = bid * (draft_token_num - 1) + i - 1
                    parent_tb_idx = selected_index.flatten()[selected_idx] // topk
                    parent_position = 0

                    if parent_tb_idx > 0:
                        parent_list_idx = bid * (topk * (depth - 1) + 1) + parent_tb_idx
                        if parent_list_idx < parent_list.size:
                            parent_token_idx = parent_list.flatten()[parent_list_idx]

                            for parent_pos in range(draft_token_num - 1):
                                check_idx = bid * (draft_token_num - 1) + parent_pos
                                if (
                                    check_idx < selected_index.size
                                    and selected_index.flatten()[check_idx] == parent_token_idx
                                ):
                                    parent_position = parent_pos + 1  # +1 to convert to 1-indexed
                                    break
                            else:
                                parent_position = draft_token_num  # Not found
                        else:
                            parent_position = draft_token_num  # Invalid parent_list_idx
                    else:
                        parent_position = 0  # Root node

                    if parent_position >= draft_token_num:
                        # Invalid parent, skip
                        continue

                    next_token_idx = bid * draft_token_num + parent_position
                    if retrive_next_token.flatten()[next_token_idx] == -1:
                        retrive_next_token = retrive_next_token.at[bid, parent_position].set(i)
                    else:
                        # There's already a next_token, so set sibling
                        origin_next_token = retrive_next_token.flatten()[next_token_idx]
                        retrive_next_token = retrive_next_token.at[bid, parent_position].set(i)
                        retrive_next_sibling = retrive_next_sibling.at[bid, i].set(origin_next_token)

                retrive_index = retrive_index.at[bid, 0].set(bid * draft_token_num)

            else:
                # Draft token (tid > 0): calculate position by tracing back to root
                position = 0
                cur_position = tid - 1  # Convert to 0-indexed for selected_index

                while True:
                    position += 1
                    selected_idx = bid * (draft_token_num - 1) + cur_position
                    parent_tb_idx = selected_index.flatten()[selected_idx] // topk

                    if parent_tb_idx == 0:
                        # Reached root
                        break

                    parent_list_idx = bid * (topk * (depth - 1) + 1) + parent_tb_idx
                    if parent_list_idx < parent_list.size:
                        token_idx = parent_list.flatten()[parent_list_idx]

                        found = False
                        for cur_pos in range(draft_token_num - 1):
                            check_idx = bid * (draft_token_num - 1) + cur_pos
                            if (
                                check_idx < selected_index.size
                                and selected_index.flatten()[check_idx] == token_idx
                            ):
                                cur_position = cur_pos
                                found = True
                                break

                        if not found:
                            break  # Invalid tree structure
                    else:
                        break  # Invalid parent_list_idx

                positions = positions.at[global_token_idx].set(position + seq_len)
                retrive_index = retrive_index.at[bid, tid].set(global_token_idx)

    return positions, retrive_index, retrive_next_token, retrive_next_sibling
```
The function build_eagle_tree_structure contains a for loop that iterates over the batch dimension (for bid in range(bs)). In JAX, Python-level loops over batch dimensions are not JIT-compiled efficiently and can lead to significant performance degradation, as the loop body is executed sequentially. For better performance, this logic should be vectorized using jax.vmap. This would allow JAX to process the entire batch in parallel on the accelerator.
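As a minimal illustration of the suggested pattern (not the actual refactor), `jax.vmap` can replace a Python loop over the batch axis with a single batched operation. The function `compute_positions` below is a hypothetical stand-in for the per-sequence body; the real per-batch logic would need to be expressed with `jnp.where`/`lax` control flow before it can be vmapped.

```python
import jax
import jax.numpy as jnp


def compute_positions(seq_len, depths):
    # Per-sequence work: offset each draft token's tree depth by the
    # verified sequence length, mirroring `position + seq_len` above.
    return depths + seq_len


# vmap maps over the leading (batch) axis of both arguments, so the
# whole batch runs as one fused, JIT-compilable operation instead of
# a sequential Python loop over `bid`.
batched_positions = jax.vmap(compute_positions)

seq_lens = jnp.array([5, 9])             # [bs]
depths = jnp.array([[0, 1, 1, 2],        # [bs, draft_token_num]
                    [0, 1, 2, 2]])
print(batched_positions(seq_lens, depths))
# → [[ 5  6  6  7]
#    [ 9 10 11 11]]
```

Wrapping the vmapped function in `jax.jit` then compiles the whole batch to a single XLA computation.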
```python
return MTPVerifyResult(
    accepted_sequences=accepted_sequences,
    acceptance_rate=len(accepted_sequences) / len(mtp_tree.sequences),
    next_tokens=self._extract_accepted_tokens(accepted_sequences),
)
```
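For readers following along, a container consistent with the fields constructed above might look like the sketch below. This is a hypothetical definition inferred from the call site; the PR's actual `MTPVerifyResult` may carry different types or extra fields.

```python
from dataclasses import dataclass


@dataclass
class MTPVerifyResult:
    """Outcome of one MTP/EAGLE verification step (illustrative only)."""
    accepted_sequences: list   # draft sequences that passed verification
    acceptance_rate: float     # accepted / proposed, a key speedup metric
    next_tokens: list          # tokens to append to each request
```

Tracking `acceptance_rate` per step is useful in practice, since speculative decoding only pays off when the draft model's proposals are accepted often enough to offset the verification cost.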
## References

1. [EAGLE: Speculative Sampling Requires Rethinking Feature Uncertainty](https://arxiv.org/abs/2401.15077)
2. [Multi-Token Prediction Paper](https://arxiv.org/abs/2412.19437)
```python
def _draft_preprocess_idle(self, batch: ScheduleBatch):
    pass
```
- Add eagle worker and utilities for speculative decoding
- Add support for draft model path configuration
- Update scheduler and model runner for speculative execution
- Add hidden state capture functionality
- Various bug fixes and import corrections
impl draft autoregressive decode
* fix tree mask calculation
feat: implement eagle verify phase
* fix donate bug
* fix out-cache loc
* fix some dtype bugs / simulated acceptance bug
add non-causal mask for paged attention