diff --git a/QEfficient/transformers/models/llama_swiftkv/modeling_llama_swiftkv.py b/QEfficient/transformers/models/llama_swiftkv/modeling_llama_swiftkv.py
index f6cf2de49..7b96aefcc 100644
--- a/QEfficient/transformers/models/llama_swiftkv/modeling_llama_swiftkv.py
+++ b/QEfficient/transformers/models/llama_swiftkv/modeling_llama_swiftkv.py
@@ -371,8 +371,8 @@ def forward(
             hidden_states = orig_hidden_states[torch.arange(orig_hidden_states.shape[0]).reshape(-1, 1), last_pos_id, :]
             causal_mask = causal_mask[torch.arange(orig_hidden_states.shape[0]).reshape(-1, 1), :, last_pos_id, :]
         else:
-            hidden_states = orig_hidden_states[torch.arange(bsz), last_pos_id, :]
-            causal_mask = causal_mask[torch.arange(bsz), :, last_pos_id, :]
+            hidden_states = orig_hidden_states[torch.arange(bsz).reshape(-1, 1), last_pos_id, :]
+            causal_mask = causal_mask[torch.arange(bsz).reshape(-1, 1), :, last_pos_id, :]

         hidden_states, next_decoder_cache = self._run_swiftkv_layers(
             hidden_states, position_ids, past_key_values, causal_mask, batch_index
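
For context, here is a minimal standalone sketch (not part of the patch) of why the `.reshape(-1, 1)` matters. It assumes `last_pos_id` is a `(bsz, 1)` tensor of last token positions, which is implied by the `if` branch above but not shown in this hunk; the sizes used are illustrative only.

```python
import torch

bsz, seq_len, hidden = 2, 4, 3
orig_hidden_states = torch.arange(bsz * seq_len * hidden, dtype=torch.float32).reshape(bsz, seq_len, hidden)
last_pos_id = torch.tensor([[1], [3]])  # assumed shape (bsz, 1): last position per batch row

# Without the reshape: torch.arange(bsz) has shape (bsz,), which broadcasts
# against the (bsz, 1) position index to (bsz, bsz), so every batch row is
# gathered at every row's position -> wrong shape.
wrong = orig_hidden_states[torch.arange(bsz), last_pos_id, :]
print(wrong.shape)  # torch.Size([2, 2, 3])

# With the reshape: (bsz, 1) broadcasts against (bsz, 1) to (bsz, 1), so row b
# is gathered only at its own last_pos_id[b] -> the intended shape.
right = orig_hidden_states[torch.arange(bsz).reshape(-1, 1), last_pos_id, :]
print(right.shape)  # torch.Size([2, 1, 3])
```

This makes the `else` branch consistent with the `if` branch above it, which already uses the column-shaped batch index.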