diff --git a/optimum/habana/transformers/models/xlm_roberta/modeling_xlm_roberta.py b/optimum/habana/transformers/models/xlm_roberta/modeling_xlm_roberta.py
index d5bd2127fd..1c95b836ce 100644
--- a/optimum/habana/transformers/models/xlm_roberta/modeling_xlm_roberta.py
+++ b/optimum/habana/transformers/models/xlm_roberta/modeling_xlm_roberta.py
@@ -26,6 +26,15 @@
 logger = logging.get_logger(__name__)
 
 
+def transpose_for_scores(self, x: torch.Tensor) -> torch.Tensor:
+    """
+    Backported from https://github.com/huggingface/transformers/blob/v4.51.0/src/transformers/models/xlm_roberta/modeling_xlm_roberta.py#L180
+    """
+    new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
+    x = x.view(new_x_shape)
+    return x.permute(0, 2, 1, 3)
+
+
 def gaudi_XLMRoberta_Sdpa_SelfAttention_forward(
     self,
     hidden_states: torch.Tensor,
@@ -74,8 +83,8 @@ def gaudi_XLMRoberta_Sdpa_SelfAttention_forward(
         if is_cross_attention and past_key_value and past_key_value[0].shape[2] == current_states.shape[1]:
             key_layer, value_layer = past_key_value
         else:
-            key_layer = self.transpose_for_scores(self.key(current_states))
-            value_layer = self.transpose_for_scores(self.value(current_states))
+            key_layer = transpose_for_scores(self, self.key(current_states))
+            value_layer = transpose_for_scores(self, self.value(current_states))
            if past_key_value is not None and not is_cross_attention:
                 key_layer = torch.cat([past_key_value[0], key_layer], dim=2)
                 value_layer = torch.cat([past_key_value[1], value_layer], dim=2)
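
For context (not part of the patch): the backported helper only reshapes the projected tensor from [batch, seq_len, hidden_size] into [batch, num_heads, seq_len, head_size] so it can feed scaled-dot-product attention. Below is a minimal, self-contained sketch of that reshape; the _FakeSelfAttention class and its sizes are hypothetical stand-ins for the attention module's real attributes, which come from the XLM-RoBERTa config.

import torch

def transpose_for_scores(self, x: torch.Tensor) -> torch.Tensor:
    # Split the last (hidden) dimension into (num_heads, head_size),
    # then move the heads dimension ahead of the sequence dimension.
    new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
    x = x.view(new_x_shape)
    return x.permute(0, 2, 1, 3)

class _FakeSelfAttention:
    # Hypothetical sizes for illustration only.
    num_attention_heads = 12
    attention_head_size = 64

x = torch.randn(2, 128, 12 * 64)                     # [batch, seq_len, hidden_size]
out = transpose_for_scores(_FakeSelfAttention(), x)  # [batch, num_heads, seq_len, head_size]
print(out.shape)                                     # torch.Size([2, 12, 128, 64])

Calling the helper as a plain function with the module passed explicitly, rather than as a bound method, is what the hunk at line 83 switches to.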