fix(distributed): align mistral attention forward
dacorvo committed Jan 30, 2025
1 parent 894fb2b commit e84e0c6
Showing 1 changed file with 4 additions and 21 deletions.
optimum/neuron/distributed/decoder_models.py: 25 changes (4 additions, 21 deletions)
@@ -720,12 +720,10 @@ def patch_for_sequence_parallelism(cls, model: "PreTrainedModel", sequence_paral
         def attention_forward(
             self,
             hidden_states: torch.Tensor,
+            position_embeddings: Tuple[torch.Tensor, torch.Tensor],
             attention_mask: Optional[torch.Tensor] = None,
-            position_ids: Optional[torch.LongTensor] = None,
             past_key_value: Optional[Cache] = None,
             output_attentions: bool = False,
-            use_cache: bool = False,
-            cache_position: Optional[torch.LongTensor] = None,
         ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
             query_states = self.q_proj(hidden_states)
             key_states = self.k_proj(hidden_states)
@@ -748,12 +746,8 @@ def attention_forward(
             key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
             value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)

-            kv_seq_len = key_states.shape[-2]
-            if past_key_value is not None:
-                kv_seq_len += cache_position[0]
-
-            cos, sin = self.rotary_emb(value_states, position_ids)
-            query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
+            cos, sin = position_embeddings
+            query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)

             if past_key_value is not None:
                 cache_kwargs = {"sin": sin, "cos": cos}  # Specific to RoPE models
@@ -767,18 +761,7 @@ def attention_forward(

             attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)

-            if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len):
-                raise ValueError(
-                    f"Attention weights should be of size {(bsz, self.num_heads, q_len, kv_seq_len)}, but is"
-                    f" {attn_weights.size()}"
-                )
-
             if attention_mask is not None:
-                if attention_mask.size() != (bsz, 1, q_len, kv_seq_len):
-                    raise ValueError(
-                        f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}"
-                    )
-
                 attn_weights = attn_weights + attention_mask

             # upcast attention to fp32
@@ -804,7 +787,7 @@ def attention_forward(
             if not output_attentions:
                 attn_weights = None

-            return attn_output, attn_weights, past_key_value
+            return attn_output, attn_weights

         for module in model.modules():
             if isinstance(module, MistralAttention):
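
Illustrative note (not part of the commit): the aligned signature follows the attention API of recent transformers releases, where the model computes the rotary cos/sin tensors once per forward pass and hands them to every layer as position_embeddings, rather than each attention module recomputing them from position_ids. Below is a minimal sketch of that calling convention using the stock transformers Mistral classes (assuming a 4.48-era install); the config sizes are made-up toy values, not Mistral-7B defaults.

import torch
from transformers import MistralConfig
from transformers.models.mistral.modeling_mistral import MistralAttention, MistralRotaryEmbedding

# Deliberately tiny config so the snippet runs as a quick smoke test.
config = MistralConfig(hidden_size=64, intermediate_size=128, num_attention_heads=4, num_key_value_heads=2)

attn = MistralAttention(config, layer_idx=0)
rotary = MistralRotaryEmbedding(config=config)

bsz, seq_len = 1, 16
hidden_states = torch.randn(bsz, seq_len, config.hidden_size)
position_ids = torch.arange(seq_len).unsqueeze(0)

# cos/sin are computed once at the model level ...
position_embeddings = rotary(hidden_states, position_ids)  # tuple of (cos, sin)

# ... and passed down to every attention layer, matching the patched signature above.
attn_output, attn_weights = attn(
    hidden_states=hidden_states,
    position_embeddings=position_embeddings,
    attention_mask=None,
)
print(attn_output.shape)  # torch.Size([1, 16, 64])

Computing the embeddings once upstream is what makes position_ids, use_cache and cache_position unnecessary in the per-layer signature, which is exactly what this commit removes.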
