@@ -548,7 +548,9 @@ def _precompile_rejection_sampler(self) -> None:
548548 def _precompile_eagle3_helpers (self ) -> None :
549549 logger .info (
550550 "Compiling eagle3 jitted helpers with different input shapes." )
551- hidden_size = self .runner .model_config .get_hidden_size ()
551+ target_hidden_size = self .runner .model_config .get_hidden_size ()
552+ draft_hidden_size = self .runner .speculative_config .draft_model_config .get_hidden_size (
553+ )
552554 dtype = self .runner .model_config .dtype
553555
554556 num_kv_cache_groups = len (self .runner .kv_cache_config .kv_cache_groups )
@@ -595,7 +597,7 @@ def _precompile_eagle3_helpers(self) -> None:
595597
596598 for num_logits in self .runner .num_logits_paddings :
597599 hidden_states = self ._create_dummy_tensor (
598- (num_logits , hidden_size ), jnp .bfloat16 )
600+ (num_logits , draft_hidden_size ), jnp .bfloat16 )
599601 self ._run_compilation (
600602 "eagle3_get_draft_token_ids" ,
601603 self .runner .drafter ._get_draft_token_ids ,
@@ -606,18 +608,21 @@ def _precompile_eagle3_helpers(self) -> None:
606608 input_ids_loop = self ._create_dummy_tensor (
607609 (self .runner .max_num_reqs , ), jnp .int32 ,
608610 NamedSharding (self .runner .mesh , PartitionSpec ()))
609- target_hidden_state_loop = self ._create_dummy_tensor (
610- (self .runner .max_num_reqs , hidden_size ), dtype ,
611+ draft_hidden_state_loop = self ._create_dummy_tensor (
612+ (self .runner .max_num_reqs , draft_hidden_size ), dtype ,
611613 NamedSharding (self .runner .mesh , PartitionSpec (None , None )))
612614 next_token_ids = self ._create_dummy_tensor (
613615 (self .runner .max_num_reqs , ), jnp .int32 )
614616 last_token_indices = self ._create_dummy_tensor (
615617 (self .runner .max_num_reqs , ), jnp .int32 )
616618 for num_tokens in self .runner .num_tokens_paddings :
617619 aux_hidden_states = [
618- self ._create_dummy_tensor ((num_tokens , hidden_size ), dtype ),
619- self ._create_dummy_tensor ((num_tokens , hidden_size ), dtype ),
620- self ._create_dummy_tensor ((num_tokens , hidden_size ), dtype ),
620+ self ._create_dummy_tensor ((num_tokens , target_hidden_size ),
621+ dtype ),
622+ self ._create_dummy_tensor ((num_tokens , target_hidden_size ),
623+ dtype ),
624+ self ._create_dummy_tensor ((num_tokens , target_hidden_size ),
625+ dtype ),
621626 ]
622627
623628 positions = self ._create_dummy_tensor ((num_tokens , ), jnp .int32 )
@@ -648,15 +653,15 @@ def filter_token_and_prepare_initial_inputs_wrapper(
648653 input_ids = self ._create_dummy_tensor ((num_tokens , ), jnp .int32 )
649654 aux_hidden_states = [
650655 self ._create_dummy_tensor (
651- (num_tokens , hidden_size ), jnp .bfloat16 ,
656+ (num_tokens , target_hidden_size ), jnp .bfloat16 ,
652657 NamedSharding (self .runner .mesh , PartitionSpec (None ,
653658 None ))),
654659 self ._create_dummy_tensor (
655- (num_tokens , hidden_size ), jnp .bfloat16 ,
660+ (num_tokens , target_hidden_size ), jnp .bfloat16 ,
656661 NamedSharding (self .runner .mesh , PartitionSpec (None ,
657662 None ))),
658663 self ._create_dummy_tensor (
659- (num_tokens , hidden_size ), jnp .bfloat16 ,
664+ (num_tokens , target_hidden_size ), jnp .bfloat16 ,
660665 NamedSharding (self .runner .mesh , PartitionSpec (None ,
661666 None ))),
662667 ]
@@ -688,17 +693,17 @@ def draft_model_fn_wrapper(
688693 state ,
689694 kv_caches ,
690695 input_ids ,
691- target_hidden_states ,
696+ draft_hidden_states ,
692697 attention_metadata ,
693698 ):
694699 kv_caches , hidden_states , _ = self .runner .drafter .model_fn (
695- state , kv_caches , input_ids , target_hidden_states ,
700+ state , kv_caches , input_ids , draft_hidden_states ,
696701 attention_metadata )
697702 self .runner .kv_caches = kv_caches
698703 return hidden_states
699704
700- target_hidden_states = self ._create_dummy_tensor (
701- (num_tokens , hidden_size ), dtype ,
705+ draft_hidden_states = self ._create_dummy_tensor (
706+ (num_tokens , draft_hidden_size ), dtype ,
702707 NamedSharding (self .runner .mesh , PartitionSpec (None , "model" )))
703708 input_ids = self ._create_dummy_tensor (
704709 (num_tokens , ), jnp .int32 ,
@@ -709,7 +714,7 @@ def draft_model_fn_wrapper(
709714 self .runner .drafter .state ,
710715 self .runner .kv_caches ,
711716 input_ids ,
712- target_hidden_states ,
717+ draft_hidden_states ,
713718 attention_metadata ,
714719 num_tokens = num_tokens ,
715720 )
@@ -741,13 +746,13 @@ def draft_model_fn_wrapper(
741746 self .runner .drafter .state ,
742747 self .runner .kv_caches ,
743748 input_ids_loop ,
744- target_hidden_state_loop ,
749+ draft_hidden_state_loop ,
745750 attention_metadata ,
746751 num_tokens = num_tokens ,
747752 )
748753
749754 hidden_states = self ._create_dummy_tensor (
750- (num_tokens , hidden_size ), jnp .bfloat16 ,
755+ (num_tokens , draft_hidden_size ), jnp .bfloat16 ,
751756 NamedSharding (self .runner .mesh , PartitionSpec (None , None )))
752757
753758 self ._run_compilation (
0 commit comments