@@ -347,7 +347,7 @@ def create_idle_input(
347347 verified_id = jnp .empty ((0 ,), dtype = jnp .int32 ),
348348 hidden_states = jnp .empty ((0 , hidden_size ), dtype = dtype ),
349349 topk_p = jnp .empty ((0 , topk ), dtype = jnp .float32 ),
350- topk_index = jnp .empty ((0 , topk ), dtype = jnp .int64 ),
350+ topk_index = jnp .empty ((0 , topk ), dtype = jnp .int32 ),
351351 capture_hidden_mode = capture_hidden_mode ,
352352 accept_length = jnp .empty ((0 ,), dtype = jnp .int32 ),
353353 accept_length_cpu = [],
@@ -817,7 +817,7 @@ def verify(
817817 accept_length_cpu = accept_length .tolist ()
818818 if len (unfinished_accept_index ) > 0 :
819819 unfinished_accept_index = jnp .concatenate (unfinished_accept_index )
820- unfinished_index_device = jnp .array (unfinished_index , dtype = jnp .int64 )
820+ unfinished_index_device = jnp .array (unfinished_index , dtype = jnp .int32 )
821821 draft_input_accept_length_cpu = [
822822 accept_length_cpu [i ] for i in unfinished_index
823823 ]
@@ -826,7 +826,7 @@ def verify(
826826 else :
827827 batch .out_cache_loc = jnp .empty (
828828 len (unfinished_index ) + sum (draft_input_accept_length_cpu ),
829- dtype = jnp .int64 ,
829+ dtype = jnp .int32 ,
830830 )
831831 accept_length_filter = create_accept_length_filter (
832832 accept_length ,
@@ -903,16 +903,16 @@ def _generate_simulated_accept_index(
903903 weight_upper = simulate_acc_len_float - lower
904904 weight_lower = 1.0 - weight_upper
905905 # here, data is on cpu
906- probs = numpy .array ([weight_lower , weight_upper ])
907- sampled_index = jax .random .multinomial (rng , probs , shape = ( 1 , ))
906+ probs = jnp .array ([weight_lower , weight_upper ])
907+ sampled_index = jax .random .categorical (rng , jnp . log ( probs ))
908908 simulate_acc_len = lower if sampled_index == 0 else upper
909909 else :
910910 raise ValueError (f"Invalid simulate_acc_method: { SIMULATE_ACC_METHOD } " )
911911
912912 accept_indx_first_col = accept_index [:, 0 ].reshape (- 1 , 1 )
913913 sim_accept_index = jnp .full ((bs , spec_steps + 1 ), - 1 , dtype = jnp .int32 )
914- sim_accept_index [:, :simulate_acc_len ] = accept_indx_first_col + jnp . arange (
915- simulate_acc_len
914+ sim_accept_index = sim_accept_index . at [:, :simulate_acc_len ]. set (
915+ accept_indx_first_col + jnp . arange ( simulate_acc_len )
916916 )
917917 accept_length = accept_length .at [:].set (simulate_acc_len - 1 )
918918 predict = predict .at [:].set (100 ) # some legit token id
0 commit comments