Skip to content

Commit b602c04

Browse files
[BUG] Fix some tiny bugs (#16)
* fix out_cache_loc * fix some dtype bugs / simulate-acc bug
1 parent 5f768ee commit b602c04

File tree

3 files changed

+10
-10
lines changed

3 files changed

+10
-10
lines changed

python/sgl_jax/srt/layers/logits_processor.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -297,11 +297,11 @@ def __call__(
297297
sample_indices = device_array(
298298
np.array(
299299
sample_indices,
300-
dtype=jnp.int64,
300+
dtype=np.int64,
301301
),
302302
)
303303
input_logprob_indices = device_array(
304-
np.array(input_logprob_indices, dtype=jnp.int64),
304+
np.array(input_logprob_indices, dtype=np.int64),
305305
)
306306

307307
# Compute logits for both input and sampled tokens.

python/sgl_jax/srt/managers/schedule_batch.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1159,7 +1159,7 @@ def get_model_worker_batch(
11591159
if positions_cpu is None:
11601160
# For decode: each sequence contributes one token, at 0-based position seq_len - 1 (clamped at 0)
11611161
# Create positions for actual tokens (one per sequence, at seq_len - 1)
1162-
batch_positions = max(0, seq_lens_cpu - 1)
1162+
batch_positions = np.maximum(0, seq_lens_cpu - 1)
11631163
# Create positions array matching the length of input_ids (including padding)
11641164
positions_cpu = np.zeros(
11651165
len(input_ids_cpu), dtype=batch_positions.dtype

python/sgl_jax/srt/speculative/eagle_util.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -347,7 +347,7 @@ def create_idle_input(
347347
verified_id=jnp.empty((0,), dtype=jnp.int32),
348348
hidden_states=jnp.empty((0, hidden_size), dtype=dtype),
349349
topk_p=jnp.empty((0, topk), dtype=jnp.float32),
350-
topk_index=jnp.empty((0, topk), dtype=jnp.int64),
350+
topk_index=jnp.empty((0, topk), dtype=jnp.int32),
351351
capture_hidden_mode=capture_hidden_mode,
352352
accept_length=jnp.empty((0,), dtype=jnp.int32),
353353
accept_length_cpu=[],
@@ -817,7 +817,7 @@ def verify(
817817
accept_length_cpu = accept_length.tolist()
818818
if len(unfinished_accept_index) > 0:
819819
unfinished_accept_index = jnp.concatenate(unfinished_accept_index)
820-
unfinished_index_device = jnp.array(unfinished_index, dtype=jnp.int64)
820+
unfinished_index_device = jnp.array(unfinished_index, dtype=jnp.int32)
821821
draft_input_accept_length_cpu = [
822822
accept_length_cpu[i] for i in unfinished_index
823823
]
@@ -826,7 +826,7 @@ def verify(
826826
else:
827827
batch.out_cache_loc = jnp.empty(
828828
len(unfinished_index) + sum(draft_input_accept_length_cpu),
829-
dtype=jnp.int64,
829+
dtype=jnp.int32,
830830
)
831831
accept_length_filter = create_accept_length_filter(
832832
accept_length,
@@ -903,16 +903,16 @@ def _generate_simulated_accept_index(
903903
weight_upper = simulate_acc_len_float - lower
904904
weight_lower = 1.0 - weight_upper
905905
# here, data is on cpu
906-
probs = numpy.array([weight_lower, weight_upper])
907-
sampled_index = jax.random.multinomial(rng, probs, shape=(1,))
906+
probs = jnp.array([weight_lower, weight_upper])
907+
sampled_index = jax.random.categorical(rng, jnp.log(probs))
908908
simulate_acc_len = lower if sampled_index == 0 else upper
909909
else:
910910
raise ValueError(f"Invalid simulate_acc_method: {SIMULATE_ACC_METHOD}")
911911

912912
accept_indx_first_col = accept_index[:, 0].reshape(-1, 1)
913913
sim_accept_index = jnp.full((bs, spec_steps + 1), -1, dtype=jnp.int32)
914-
sim_accept_index[:, :simulate_acc_len] = accept_indx_first_col + jnp.arange(
915-
simulate_acc_len
914+
sim_accept_index = sim_accept_index.at[:, :simulate_acc_len].set(
915+
accept_indx_first_col + jnp.arange(simulate_acc_len)
916916
)
917917
accept_length = accept_length.at[:].set(simulate_acc_len - 1)
918918
predict = predict.at[:].set(100) # some legit token id

0 commit comments

Comments
 (0)