diff --git a/optimum/habana/transformers/models/llama/modeling_llama.py b/optimum/habana/transformers/models/llama/modeling_llama.py
index 18867ff8a4..6ab636f565 100755
--- a/optimum/habana/transformers/models/llama/modeling_llama.py
+++ b/optimum/habana/transformers/models/llama/modeling_llama.py
@@ -430,7 +430,68 @@ def forward(self, cur, dim, idx):
         return self.update(self.cache, cur, dim, idx, self.inp_seq_len)
 
 
-def GaudiDistributedAttention(fused_scaled_dot_product_attention, fused_scaled_dot_product_attention_distributed):
+class GaudiDistributedAttention(torch.nn.Module):
+    def __init__(
+        self, hpu_module_fsdpa: ModuleFusedSDPA, scale, attention_dropout, enable_recompute, flash_attention_fp8
+    ):
+        super().__init__()
+        self._hpu_module_fsdpa = hpu_module_fsdpa
+        if parallel_state.sequence_parallel_is_initialized() and parallel_state.get_sequence_parallel_world_size() > 1:
+            from deepspeed.sequence.layer import DistributedAttention
+
+            self._hpu_module_fsdpa_distributed = DistributedAttention(
+                self._hpu_module_fsdpa, parallel_state.get_sequence_parallel_group(), 1, 2
+            )
+
+    def forward(
+        self,
+        query: torch.Tensor,
+        key: torch.Tensor,
+        value: torch.Tensor,
+        attn_mask: torch.Tensor,
+        dropout_p: float,
+        is_casual,
+        scale,
+        softmax_mode,
+        recompute_mode,
+        valid_sequence_lengths,
+        padding_side="left",
+    ):
+        if parallel_state.sequence_parallel_is_initialized() and parallel_state.get_sequence_parallel_world_size() > 1:
+            return self._hpu_module_fsdpa_distributed(
+                query,
+                key,
+                value,
+                0,  # As the shape for inputs is [B, N, S, H]
+                None,
+                attn_mask,
+                dropout_p,
+                is_casual,
+                scale,
+                softmax_mode,
+                recompute_mode,
+                valid_sequence_lengths,
+                padding_side,
+            )
+        else:
+            return self._hpu_module_fsdpa(
+                query,
+                key,
+                value,
+                attn_mask,
+                dropout_p,
+                is_casual,
+                scale,
+                softmax_mode,
+                recompute_mode,
+                valid_sequence_lengths,
+                padding_side,
+            )
+
+
+def get_gaudi_distributed_attention(
+    fused_scaled_dot_product_attention, fused_scaled_dot_product_attention_distributed
+):
     if parallel_state.sequence_parallel_is_initialized() and parallel_state.get_sequence_parallel_world_size() > 1:
         return fused_scaled_dot_product_attention_distributed
     else:
@@ -472,14 +533,19 @@ def __init__(self, config: LlamaConfig, layer_idx: Optional[int] = None):
             if FusedSDPA
             else None
         )
-        # https://github.com/microsoft/DeepSpeed/issues/4359
         # for all2all comm, Distributed Attention cares about sequence (s) and number of heads (h) dimensions. In HPU, they are at 1 and 2 indices
         self.fused_scaled_dot_product_attention_distributed = None
         if parallel_state.sequence_parallel_is_initialized() and parallel_state.get_sequence_parallel_world_size() > 1:
-            from deepspeed.sequence.layer import DistributedAttention
-
-            self.fused_scaled_dot_product_attention_distributed = DistributedAttention(
-                self.fused_scaled_dot_product_attention, parallel_state.get_sequence_parallel_group(), 1, 2
+            self.fused_scaled_dot_product_attention_distributed = (
+                GaudiDistributedAttention(
+                    self.fused_scaled_dot_product_attention,
+                    scale=self.norm_factor,
+                    attention_dropout=self.attention_dropout,
+                    enable_recompute=False,
+                    flash_attention_fp8=getattr(config, "flash_attention_fp8", False),
+                )
+                if FusedSDPA
+                else None
             )
 
     def get_k_proj_weight(self):
@@ -696,7 +762,7 @@ def pre_attn_forward(
             kv_seq_len = key_states.shape[-2]
         else:
             past_key_value = None
-        fused_scaled_dot_product_attention = GaudiDistributedAttention(
+        fused_scaled_dot_product_attention = get_gaudi_distributed_attention(
            self.fused_scaled_dot_product_attention, self.fused_scaled_dot_product_attention_distributed
         )
         if use_flash_attention and FusedSDPA is not None:
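
For context outside the Gaudi/DeepSpeed stack, below is a minimal, self-contained sketch of the dispatch pattern the new GaudiDistributedAttention class and get_gaudi_distributed_attention helper implement: route to DeepSpeed's all2all DistributedAttention only when sequence parallelism is initialized with a world size greater than 1, and otherwise call the local fused SDPA module directly. The names LocalSDPA, DistributedAttentionDispatch, and sequence_parallel_world_size are hypothetical stand-ins (torch's built-in SDPA replaces Habana's ModuleFusedSDPA, and the sequence-parallel check is stubbed); they are not part of this PR.

import torch
import torch.nn.functional as F


def sequence_parallel_world_size() -> int:
    # Stand-in for parallel_state.get_sequence_parallel_world_size();
    # returning 1 means sequence parallelism is not active.
    return 1


class LocalSDPA(torch.nn.Module):
    # Stand-in for the local fused kernel (ModuleFusedSDPA in the diff),
    # here backed by torch's built-in scaled_dot_product_attention.
    def forward(self, q, k, v, attn_mask=None):
        return F.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask)


class DistributedAttentionDispatch(torch.nn.Module):
    # Same dispatch idea as the diff: use the distributed (all2all) path only
    # when sequence parallelism is active, else fall back to the local kernel.
    def __init__(self, local_attention, distributed_attention=None):
        super().__init__()
        self.local_attention = local_attention
        self.distributed_attention = distributed_attention

    def forward(self, q, k, v, attn_mask=None):
        if self.distributed_attention is not None and sequence_parallel_world_size() > 1:
            return self.distributed_attention(q, k, v, attn_mask)
        return self.local_attention(q, k, v, attn_mask=attn_mask)


if __name__ == "__main__":
    attn = DistributedAttentionDispatch(LocalSDPA())
    q = k = v = torch.randn(1, 8, 16, 64)  # [B, N, S, H], matching the layout noted in the diff
    print(attn(q, k, v).shape)  # torch.Size([1, 8, 16, 64])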