Commit b3e73ab
fix(llama_eagle3): cast embeds to hidden_states dtype + allow multi-layer Eagle3
Two small extensions for FP8 + depth-ablation work:
1. LlamaDecoderLayer.forward: cast embeds to hidden_states.dtype before
concat. FP8 target models can produce float32 embeds while draft
hidden_states is bfloat16 (post fc-cast in LlamaModel.forward), so
torch.cat upcasts to float32 and downstream linear ops fail.
2. LlamaForCausalLMEagle3: relax the num_hidden_layers!=1 hard error to
a logger.warning so depth-ablation drafters (L=2,3,4) load. This is
the tails-mpt fork's multi-layer Eagle3 path.
Also tightens the FC-dtype-cast comment in LlamaModel.forward to point
at the canonical fork commit (71e0bf0) instead of the long historical
note.
Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>1 parent d137832 commit b3e73ab
1 file changed
Lines changed: 15 additions & 8 deletions
| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
20 | 20 | | |
21 | 21 | | |
22 | 22 | | |
| 23 | + | |
23 | 24 | | |
24 | 25 | | |
25 | 26 | | |
| 27 | + | |
| 28 | + | |
26 | 29 | | |
27 | 30 | | |
28 | 31 | | |
| |||
81 | 84 | | |
82 | 85 | | |
83 | 86 | | |
| 87 | + | |
| 88 | + | |
| 89 | + | |
| 90 | + | |
| 91 | + | |
| 92 | + | |
| 93 | + | |
84 | 94 | | |
85 | 95 | | |
86 | 96 | | |
| |||
173 | 183 | | |
174 | 184 | | |
175 | 185 | | |
176 | | - | |
177 | | - | |
178 | | - | |
179 | | - | |
180 | | - | |
181 | | - | |
182 | | - | |
| 186 | + | |
| 187 | + | |
| 188 | + | |
| 189 | + | |
183 | 190 | | |
184 | 191 | | |
185 | 192 | | |
| |||
219 | 226 | | |
220 | 227 | | |
221 | 228 | | |
222 | | - | |
| 229 | + | |
223 | 230 | | |
224 | 231 | | |
225 | 232 | | |
| |||
0 commit comments