From f501981f9e1660ec60078c68c3b3a1f23a64b6e6 Mon Sep 17 00:00:00 2001
From: quic-shagun
Date: Thu, 5 Jun 2025 03:36:08 -0700
Subject: [PATCH] BugFix: Fix reshape error for llama swiftkv models

Signed-off-by: quic-shagun
---
 .../models/llama_swiftkv/modeling_llama_swiftkv.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/QEfficient/transformers/models/llama_swiftkv/modeling_llama_swiftkv.py b/QEfficient/transformers/models/llama_swiftkv/modeling_llama_swiftkv.py
index f6cf2de49..7b96aefcc 100644
--- a/QEfficient/transformers/models/llama_swiftkv/modeling_llama_swiftkv.py
+++ b/QEfficient/transformers/models/llama_swiftkv/modeling_llama_swiftkv.py
@@ -371,8 +371,8 @@ def forward(
             hidden_states = orig_hidden_states[torch.arange(orig_hidden_states.shape[0]).reshape(-1, 1), last_pos_id, :]
             causal_mask = causal_mask[torch.arange(orig_hidden_states.shape[0]).reshape(-1, 1), :, last_pos_id, :]
         else:
-            hidden_states = orig_hidden_states[torch.arange(bsz), last_pos_id, :]
-            causal_mask = causal_mask[torch.arange(bsz), :, last_pos_id, :]
+            hidden_states = orig_hidden_states[torch.arange(bsz).reshape(-1, 1), last_pos_id, :]
+            causal_mask = causal_mask[torch.arange(bsz).reshape(-1, 1), :, last_pos_id, :]
         hidden_states, next_decoder_cache = self._run_swiftkv_layers(
             hidden_states, position_ids, past_key_values, causal_mask, batch_index
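
For context, a minimal standalone sketch (not part of the patch) of why the `.reshape(-1, 1)` matters, assuming `last_pos_id` carries one last-token position per batch entry with shape (bsz, 1); the concrete shapes and values below are illustrative only:

    # Illustrative sketch of the advanced-indexing shapes the patch fixes.
    # Assumption (not taken from the patch): last_pos_id has shape (bsz, 1).
    import torch

    bsz, seq_len, hidden = 2, 5, 4
    orig_hidden_states = torch.randn(bsz, seq_len, hidden)
    last_pos_id = torch.tensor([[3], [4]])  # shape (bsz, 1)

    # Before the fix: a 1-D row index of shape (bsz,) broadcasts against
    # last_pos_id's (bsz, 1) into a (bsz, bsz) index grid, selecting
    # bsz * bsz positions instead of one per batch entry.
    before = orig_hidden_states[torch.arange(bsz), last_pos_id, :]
    print(before.shape)  # torch.Size([2, 2, 4])

    # After the fix: reshaping the row index to (bsz, 1) keeps the
    # broadcasted index grid at (bsz, 1), so exactly one position is
    # gathered per sequence, matching the branch above the else.
    after = orig_hidden_states[torch.arange(bsz).reshape(-1, 1), last_pos_id, :]
    print(after.shape)  # torch.Size([2, 1, 4])

This mirrors the pattern already used in the if-branch, where the row index is built from orig_hidden_states.shape[0] and reshaped the same way.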