Skip to content

Commit 5f768ee

Browse files
[BUG]fix outcache loc (#15)
1 parent 31a27d3 commit 5f768ee

File tree

1 file changed

+3
-3
lines changed

1 file changed

+3
-3
lines changed

python/sgl_jax/srt/speculative/eagle_worker.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -513,9 +513,9 @@ def draft_forward(self, schedule_batch: ScheduleBatch):
513513
)
514514
if self.hot_token_id is not None:
515515
topk_index = self.hot_token_id[topk_index]
516-
out_cache_loc = out_cache_loc.reshape(
517-
schedule_batch.batch_size(), self.topk, self.speculative_num_steps
518-
)
516+
out_cache_loc = out_cache_loc[
517+
: (schedule_batch.batch_size() * self.topk * self.speculative_num_steps)
518+
].reshape(schedule_batch.batch_size(), self.topk, self.speculative_num_steps)
519519
out_cache_loc = jnp.transpose(out_cache_loc, (2, 0, 1)).reshape(
520520
self.speculative_num_steps, -1
521521
)

0 commit comments

Comments
 (0)