Commit 691ce91

Pooya Moradi authored
[Spec Decoding][Eagle3] Fix bug of eagle-3 not being compatible with non-8b models. (#1165)
Signed-off-by: Pooya Moradi <[email protected]> Co-authored-by: Pooya Moradi <[email protected]>
1 parent ddd5471 commit 691ce91

File tree

4 files changed: +56 -49 lines

tpu_inference/models/jax/llama_eagle3.py

Lines changed: 4 additions & 5 deletions

@@ -304,15 +304,15 @@ def load_weights(self, rng_key: jax.Array):
             "fc": "model.fc.kernel",
             "lm_head": "lm_head.kernel",
             "d2t": "draft_id_to_target_id",
+            "embed_tokens":
+            "model.embed_tokens.embedding",  # Some checkpoints need this
         }
 
         # Define keys to keep in original dtype (e.g., float32 for stability)
         keep_original_dtype_keys_regex = [
             r".*d2t.*",
         ]
 
-        # `embed_tokens` is shared between target and draft.
-        exclude_regex = [r".*embed_tokens.*"]
         metadata_map = get_default_maps(
             self.vllm_config.speculative_config.draft_model_config, self.mesh,
             mappings)
@@ -325,10 +325,9 @@ def load_weights(self, rng_key: jax.Array):
             metadata_map=metadata_map,
             mesh=self.mesh,
             is_draft_model=True,
-            keep_original_dtype_keys_regex=keep_original_dtype_keys_regex,
-            exclude_regex=exclude_regex if exclude_regex else None)
+            keep_original_dtype_keys_regex=keep_original_dtype_keys_regex)
 
-        # If the embedding is not initialized, initialize it with a dummpy array here to pass jit compilation. The real weights will be shared from the target model in eagle3 class.
+        # If the embedding is not initialized, initialize it with a dummy array here to pass jit compilation. The real weights will be shared from the target model in eagle3 class.
         if isinstance(self.model.embed_tokens.embedding.value,
                       jax.ShapeDtypeStruct):
             self.model.embed_tokens.embedding.value = jnp.zeros(
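
The dummy-initialization pattern at the end of this hunk, shown in isolation. A minimal sketch assuming a toy nnx parameter; the names vocab and hidden are illustrative and not from the repository:

import jax
import jax.numpy as jnp
from flax import nnx

vocab, hidden = 32, 8  # hypothetical sizes

# An abstract placeholder, as left behind when the checkpoint provides no
# `embed_tokens` weight: shape and dtype are known, but no buffer exists.
embedding = nnx.Param(jax.ShapeDtypeStruct((vocab, hidden), jnp.bfloat16))

# If the value is still abstract, materialize zeros so jit compilation can
# trace the module; the real weights are shared in from the target model
# later (see eagle3.py below).
if isinstance(embedding.value, jax.ShapeDtypeStruct):
    embedding.value = jnp.zeros(embedding.value.shape, embedding.value.dtype)

print(embedding.value.shape)  # (32, 8)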

tpu_inference/models/jax/utils/weight_utils.py

Lines changed: 11 additions & 24 deletions

@@ -402,7 +402,6 @@ def _load_hf_weights_on_thread(
     weights_file: str,
     filter_regex: Optional[str] = None,
     keep_original_dtype_keys_regex: Optional[list[str]] = None,
-    exclude_regex: Optional[list[str]] = None,
 ):
     """Loads weights from a single weights file."""
     try:
@@ -412,17 +411,6 @@ def _load_hf_weights_on_thread(
 
     for hf_key, hf_weight in model_weights_single_file_generator(
             weights_file, framework="flax", filter_regex=filter_regex):
-        # Check if the key should be excluded
-        if exclude_regex:
-            should_exclude = False
-            for pattern in exclude_regex:
-                if re.search(pattern, hf_key):
-                    logger.info(
-                        f"Excluding {hf_key} based on pattern {pattern}")
-                    should_exclude = True
-                    break
-            if should_exclude:
-                continue
         _load_and_shard_weight(
             vllm_config,
             params,
@@ -443,7 +431,6 @@ def load_hf_weights(
     filter_regex: Optional[str] = None,
     is_draft_model: bool = False,
     keep_original_dtype_keys_regex: Optional[list[str]] = None,
-    exclude_regex: Optional[list[str]] = None,
 ):
     """Load weights into a JAX model from either an iterator or files."""
     params = nnx.state(model)
@@ -491,17 +478,17 @@ def load_hf_weights(
         max_workers = 1
     with ThreadPoolExecutor(max_workers=max_workers) as executor:
         futures = [
-            executor.submit(_load_hf_weights_on_thread,
-                            vllm_config,
-                            params,
-                            metadata_map,
-                            mesh,
-                            weights_file,
-                            filter_regex=filter_regex,
-                            keep_original_dtype_keys_regex=
-                            keep_original_dtype_keys_regex,
-                            exclude_regex=exclude_regex)
-            for weights_file in weights_files
+            executor.submit(
+                _load_hf_weights_on_thread,
+                vllm_config,
+                params,
+                metadata_map,
+                mesh,
+                weights_file,
+                filter_regex=filter_regex,
+                keep_original_dtype_keys_regex=
+                keep_original_dtype_keys_regex,
+            ) for weights_file in weights_files
         ]
         for future in futures:
             future.result()
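
The submission pattern above in isolation: one worker task per weights file, with future.result() re-raising any exception from a worker thread so a failed shard load fails loudly. A minimal runnable sketch; load_one and the file names are illustrative stand-ins for _load_hf_weights_on_thread and real checkpoint shards:

from concurrent.futures import ThreadPoolExecutor

def load_one(weights_file: str, *, filter_regex=None) -> str:
    # Stand-in for _load_hf_weights_on_thread: load and shard one file.
    return f"loaded {weights_file} (filter={filter_regex})"

weights_files = [
    "model-00001-of-00002.safetensors",
    "model-00002-of-00002.safetensors",
]

with ThreadPoolExecutor(max_workers=2) as executor:
    futures = [
        executor.submit(load_one, weights_file, filter_regex=None)
        for weights_file in weights_files
    ]
    # result() blocks until the task finishes and re-raises any exception
    # raised inside the worker thread.
    for future in futures:
        print(future.result())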

tpu_inference/runner/compilation_manager.py

Lines changed: 22 additions & 17 deletions

@@ -548,7 +548,9 @@ def _precompile_rejection_sampler(self) -> None:
     def _precompile_eagle3_helpers(self) -> None:
         logger.info(
             "Compiling eagle3 jitted helpers with different input shapes.")
-        hidden_size = self.runner.model_config.get_hidden_size()
+        target_hidden_size = self.runner.model_config.get_hidden_size()
+        draft_hidden_size = self.runner.speculative_config.draft_model_config.get_hidden_size(
+        )
         dtype = self.runner.model_config.dtype
 
         num_kv_cache_groups = len(self.runner.kv_cache_config.kv_cache_groups)
@@ -595,7 +597,7 @@ def _precompile_eagle3_helpers(self) -> None:
 
         for num_logits in self.runner.num_logits_paddings:
             hidden_states = self._create_dummy_tensor(
-                (num_logits, hidden_size), jnp.bfloat16)
+                (num_logits, draft_hidden_size), jnp.bfloat16)
             self._run_compilation(
                 "eagle3_get_draft_token_ids",
                 self.runner.drafter._get_draft_token_ids,
@@ -606,18 +608,21 @@ def _precompile_eagle3_helpers(self) -> None:
         input_ids_loop = self._create_dummy_tensor(
             (self.runner.max_num_reqs, ), jnp.int32,
             NamedSharding(self.runner.mesh, PartitionSpec()))
-        target_hidden_state_loop = self._create_dummy_tensor(
-            (self.runner.max_num_reqs, hidden_size), dtype,
+        draft_hidden_state_loop = self._create_dummy_tensor(
+            (self.runner.max_num_reqs, draft_hidden_size), dtype,
             NamedSharding(self.runner.mesh, PartitionSpec(None, None)))
         next_token_ids = self._create_dummy_tensor(
             (self.runner.max_num_reqs, ), jnp.int32)
         last_token_indices = self._create_dummy_tensor(
             (self.runner.max_num_reqs, ), jnp.int32)
         for num_tokens in self.runner.num_tokens_paddings:
             aux_hidden_states = [
-                self._create_dummy_tensor((num_tokens, hidden_size), dtype),
-                self._create_dummy_tensor((num_tokens, hidden_size), dtype),
-                self._create_dummy_tensor((num_tokens, hidden_size), dtype),
+                self._create_dummy_tensor((num_tokens, target_hidden_size),
+                                          dtype),
+                self._create_dummy_tensor((num_tokens, target_hidden_size),
+                                          dtype),
+                self._create_dummy_tensor((num_tokens, target_hidden_size),
+                                          dtype),
             ]
 
             positions = self._create_dummy_tensor((num_tokens, ), jnp.int32)
@@ -648,15 +653,15 @@ def filter_token_and_prepare_initial_inputs_wrapper(
             input_ids = self._create_dummy_tensor((num_tokens, ), jnp.int32)
             aux_hidden_states = [
                 self._create_dummy_tensor(
-                    (num_tokens, hidden_size), jnp.bfloat16,
+                    (num_tokens, target_hidden_size), jnp.bfloat16,
                     NamedSharding(self.runner.mesh, PartitionSpec(None,
                                                                   None))),
                 self._create_dummy_tensor(
-                    (num_tokens, hidden_size), jnp.bfloat16,
+                    (num_tokens, target_hidden_size), jnp.bfloat16,
                     NamedSharding(self.runner.mesh, PartitionSpec(None,
                                                                   None))),
                 self._create_dummy_tensor(
-                    (num_tokens, hidden_size), jnp.bfloat16,
+                    (num_tokens, target_hidden_size), jnp.bfloat16,
                     NamedSharding(self.runner.mesh, PartitionSpec(None,
                                                                   None))),
             ]
@@ -688,17 +693,17 @@ def draft_model_fn_wrapper(
                 state,
                 kv_caches,
                 input_ids,
-                target_hidden_states,
+                draft_hidden_states,
                 attention_metadata,
             ):
                 kv_caches, hidden_states, _ = self.runner.drafter.model_fn(
-                    state, kv_caches, input_ids, target_hidden_states,
+                    state, kv_caches, input_ids, draft_hidden_states,
                     attention_metadata)
                 self.runner.kv_caches = kv_caches
                 return hidden_states
 
-            target_hidden_states = self._create_dummy_tensor(
-                (num_tokens, hidden_size), dtype,
+            draft_hidden_states = self._create_dummy_tensor(
+                (num_tokens, draft_hidden_size), dtype,
                 NamedSharding(self.runner.mesh, PartitionSpec(None, "model")))
             input_ids = self._create_dummy_tensor(
                 (num_tokens, ), jnp.int32,
@@ -709,7 +714,7 @@ def draft_model_fn_wrapper(
                 self.runner.drafter.state,
                 self.runner.kv_caches,
                 input_ids,
-                target_hidden_states,
+                draft_hidden_states,
                 attention_metadata,
                 num_tokens=num_tokens,
             )
@@ -741,13 +746,13 @@ def draft_model_fn_wrapper(
                 self.runner.drafter.state,
                 self.runner.kv_caches,
                 input_ids_loop,
-                target_hidden_state_loop,
+                draft_hidden_state_loop,
                 attention_metadata,
                 num_tokens=num_tokens,
             )
 
             hidden_states = self._create_dummy_tensor(
-                (num_tokens, hidden_size), jnp.bfloat16,
+                (num_tokens, draft_hidden_size), jnp.bfloat16,
                 NamedSharding(self.runner.mesh, PartitionSpec(None, None)))
 
             self._run_compilation(
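
Why the split matters: the target and draft models need not share a hidden size (the two happened to match for the 8B pairing), so each dummy tensor must be shaped from the matching config, or the precompiled programs will not cover the shapes seen at serving time. A toy sketch with hypothetical sizes; combine stands in for the jitted eagle3 helpers:

import jax
import jax.numpy as jnp

target_hidden_size = 8192  # hypothetical large target model
draft_hidden_size = 4096   # hypothetical Eagle3 draft head

@jax.jit
def combine(aux_hidden_states):
    # Eagle3-style helper: concatenate auxiliary hidden states taken from
    # several target-model layers along the feature axis.
    return jnp.concatenate(aux_hidden_states, axis=-1)

num_tokens = 16
# Dummy inputs must use the *target* hidden size here; compiling with the
# draft size instead (the old single `hidden_size`) precompiles for shapes
# that never occur once the two sizes differ.
aux = [
    jnp.zeros((num_tokens, target_hidden_size), jnp.bfloat16)
    for _ in range(3)
]
print(combine(aux).shape)  # (16, 24576)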

tpu_inference/spec_decode/jax/eagle3.py

Lines changed: 19 additions & 3 deletions

@@ -9,10 +9,13 @@
 from vllm.config import VllmConfig
 
 from tpu_inference.layers.common.attention_metadata import AttentionMetadata
+from tpu_inference.logger import init_logger
 from tpu_inference.models.common.model_loader import get_model
 from tpu_inference.runner import utils as runner_utils
 from tpu_inference.utils import device_array
 
+logger = init_logger(__name__)
+
 
 class Eagle3Proposer:
     """A proposer for speculative decoding using the Eagle3 method.
@@ -51,9 +54,22 @@ def load_model(self, target_model: Any) -> None:
         """Loads the draft model."""
         self.model_fn, self.compute_logits_fn, self.combine_hidden_states_fn, _, self.state, _, _ = get_model(
             self.vllm_config, self.rng_key, self.mesh, is_draft_model=True)
-        if 'embed_tokens' in self.state.model:
-            del self.state.model['embed_tokens']
-        self.state.model.embed_tokens = target_model.model.embed
+
+        draft_embed_tokens = getattr(self.state.model, 'embed_tokens', None)
+        if draft_embed_tokens is None or ~jnp.any(
+                draft_embed_tokens.embedding):
+            logger.info(
+                "Draft model does not have embedding. Setting draft model's embed_tokens to target model's embed"
+            )
+            self.state.model.embed_tokens = target_model.model.embed
+        elif jnp.array_equal(draft_embed_tokens.embedding,
+                             target_model.model.embed.embedding):
+            logger.info(
+                "Draft model's embed_tokens is identical to target model's embed. Sharing the embedding."
+            )
+            self.state.model.embed_tokens = target_model.model.embed
+        else:
+            logger.info("Draft model has its own embed_tokens.")
 
     @functools.partial(jax.jit, static_argnums=(0, ))
     def _prepare_input_ids(
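
The new three-way sharing decision in isolation. A minimal sketch with toy arrays; choose_embedding is an illustrative stand-in for the branches in load_model. Note that jnp.any on a float array is True iff some element is nonzero, so ~jnp.any(w) detects the all-zero dummy left behind by the draft weight loader:

import jax.numpy as jnp

target_embed = jnp.arange(12.0).reshape(4, 3)  # toy target embedding

def choose_embedding(draft_embed, target_embed):
    # Mirrors the three branches in Eagle3Proposer.load_model.
    if draft_embed is None or ~jnp.any(draft_embed):
        return target_embed, "shared: draft had no (or all-zero) embedding"
    if jnp.array_equal(draft_embed, target_embed):
        return target_embed, "shared: draft embedding equals target's"
    return draft_embed, "kept: draft has its own embedding"

for draft in (None, jnp.zeros((4, 3)), target_embed, target_embed + 1.0):
    _, reason = choose_embedding(draft, target_embed)
    print(reason)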
