Skip to content

Commit b3e73ab

Browse files
lujangusclaude
andcommitted
fix(llama_eagle3): cast embeds to hidden_states dtype + allow multi-layer Eagle3
Two small extensions for FP8 + depth-ablation work: 1. LlamaDecoderLayer.forward: cast embeds to hidden_states.dtype before concat. FP8 target models can produce float32 embeds while draft hidden_states is bfloat16 (post fc-cast in LlamaModel.forward), so torch.cat upcasts to float32 and downstream linear ops fail. 2. LlamaForCausalLMEagle3: relax the num_hidden_layers!=1 hard error to a logger.warning so depth-ablation drafters (L=2,3,4) load. This is the tails-mpt fork's multi-layer Eagle3 path. Also tightens the FC-dtype-cast comment in LlamaModel.forward to point at the canonical fork commit (71e0bf0) instead of the long historical note. Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
1 parent d137832 commit b3e73ab

1 file changed

Lines changed: 15 additions & 8 deletions

File tree

python/sglang/srt/models/llama_eagle3.py

Lines changed: 15 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -20,9 +20,12 @@
2020
"""Inference-only LLaMA-EAGLE model compatible with HuggingFace weights."""
2121

2222
import copy
23+
import logging
2324
from typing import Iterable, Optional, Tuple
2425

2526
import torch
27+
28+
logger = logging.getLogger(__name__)
2629
from torch import nn
2730
from transformers import LlamaConfig
2831

@@ -81,6 +84,13 @@ def forward(
8184
residual: Optional[torch.Tensor],
8285
) -> Tuple[torch.Tensor, torch.Tensor]:
8386

87+
# FP8 fix extension: cast embeds to match hidden_states dtype before
88+
# concat. FP8 target models can produce float32 embeds while the
89+
# draft's hidden_states is bfloat16 (post the fc-cast in LlamaModel.forward).
90+
# Without this, torch.cat upcasts to float32 and downstream linear ops fail.
91+
if embeds.dtype != hidden_states.dtype:
92+
embeds = embeds.to(hidden_states.dtype)
93+
8494
residual = hidden_states
8595
embeds = self.input_layernorm(embeds)
8696
hidden_states = self.hidden_norm(hidden_states)
@@ -173,13 +183,10 @@ def forward(
173183
positions = forward_batch.mrope_positions
174184

175185
hidden_states = forward_batch.spec_info.hidden_states
176-
# Cast aux hidden_states to match FC weight dtype.
177-
# FP8 target models produce float32 dequantized aux states, but the
178-
# Eagle3 draft head's FC is bfloat16 — without this cast, F.linear
179-
# raises "expected mat1 and mat2 to have the same dtype" during CUDA
180-
# graph capture. Originally landed as cfbffdc56 (Gus, 2026-04-01);
181-
# subsequently lost in ea2f129a9 (upstream sync). Re-applying as the
182-
# durable fix.
186+
# FP8 fix (sglang fork commit 71e0bf009): FP8 target models produce
187+
# float32 dequantized aux hidden states; the Eagle3 draft FC is bf16.
188+
# Without this cast, F.linear raises a dtype mismatch during CUDA
189+
# graph capture.
183190
if hidden_states.dtype != self.fc.weight.dtype:
184191
hidden_states = hidden_states.to(self.fc.weight.dtype)
185192
if hidden_states.shape[-1] != embeds.shape[-1]:
@@ -219,7 +226,7 @@ def __init__(
219226
self.pp_group = get_pp_group()
220227

221228
if self.config.num_hidden_layers != 1:
222-
raise ValueError("EAGLE3 currently only supports 1 layer")
229+
logger.warning(f"Multi-layer EAGLE3 drafter (num_hidden_layers={self.config.num_hidden_layers}) — depth-ablation patch by tails-mpt fork")
223230

224231
self.model = LlamaModel(
225232
config, quant_config=quant_config, prefix=add_prefix("model", prefix)

0 commit comments

Comments
 (0)