From 7c77b1d28e4719a0c4f67a5f2032c5058b839cb2 Mon Sep 17 00:00:00 2001 From: Bhargav Date: Thu, 2 Jan 2025 11:38:48 +0530 Subject: [PATCH 1/3] [SW-207148] Add batch dim idx to support latest deepspeed DistributedAttention (#37) * Temp patch for batch_dim * Adding batch_dim_idx as per latest deepspeed * Update modeling_llama.py --- .../models/llama/modeling_llama.py | 88 ++++++++++++++----- 1 file changed, 66 insertions(+), 22 deletions(-) diff --git a/optimum/habana/transformers/models/llama/modeling_llama.py b/optimum/habana/transformers/models/llama/modeling_llama.py index 67f07437a1..b5d8648ed0 100755 --- a/optimum/habana/transformers/models/llama/modeling_llama.py +++ b/optimum/habana/transformers/models/llama/modeling_llama.py @@ -427,11 +427,67 @@ def forward(self, cur, dim, idx): return self.update(self.cache, cur, dim, idx, self.inp_seq_len) -def GaudiDistributedAttention(fused_scaled_dot_product_attention, fused_scaled_dot_product_attention_distributed): - if parallel_state.sequence_parallel_is_initialized() and parallel_state.get_sequence_parallel_world_size() > 1: - return fused_scaled_dot_product_attention_distributed - else: - return fused_scaled_dot_product_attention +class GaudiDistributedAttention(torch.nn.Module): + def __init__(self, scale, attention_dropout, enable_recompute, flash_attention_fp8): + super().__init__() + self._hpu_module_fsdpa = ModuleFusedSDPA( + FusedSDPA, + scale=scale, + attention_dropout=attention_dropout, + enable_recompute=enable_recompute, + flash_attention_fp8=flash_attention_fp8, + ) + if parallel_state.sequence_parallel_is_initialized() and parallel_state.get_sequence_parallel_world_size() > 1: + from deepspeed.sequence.layer import DistributedAttention + + self._hpu_module_fsdpa_distributed = DistributedAttention( + self._hpu_module_fsdpa, parallel_state.get_sequence_parallel_group(), 1, 2 + ) + + def forward( + self, + query, + key, + value, + attn_mask, + dropout_p, + is_casual, + scale, + softmax_mode, + recompute_mode, + valid_sequence_lengths, + padding_side="left", + ): + if parallel_state.sequence_parallel_is_initialized() and parallel_state.get_sequence_parallel_world_size() > 1: + return self._hpu_module_fsdpa_distributed( + query, + key, + value, + 0, # As the shape for inputs is [B, N, S, H] + None, + attn_mask, + dropout_p, + is_casual, + scale, + softmax_mode, + recompute_mode, + valid_sequence_lengths, + padding_side, + ) + else: + return self._hpu_module_fsdpa( + query, + key, + value, + attn_mask, + dropout_p, + is_casual, + scale, + softmax_mode, + recompute_mode, + valid_sequence_lengths, + padding_side, + ) class GaudiLlamaAttention(LlamaAttention): @@ -459,8 +515,7 @@ def __init__(self, config: LlamaConfig, layer_idx: Optional[int] = None): self.inp_seq_len = -1 self.norm_factor = 1.0 / math.sqrt(self.head_dim) self.fused_scaled_dot_product_attention = ( - ModuleFusedSDPA( - FusedSDPA, + GaudiDistributedAttention( scale=self.norm_factor, attention_dropout=self.attention_dropout, enable_recompute=False, @@ -469,15 +524,6 @@ def __init__(self, config: LlamaConfig, layer_idx: Optional[int] = None): if FusedSDPA else None ) - # https://github.com/microsoft/DeepSpeed/issues/4359 - # for all2all comm, Distributed Attention cares about sequence (s) and number of heads (h) dimensions. 
In HPU, they are at 1 and 2 indices - self.fused_scaled_dot_product_attention_distributed = None - if parallel_state.sequence_parallel_is_initialized() and parallel_state.get_sequence_parallel_world_size() > 1: - from deepspeed.sequence.layer import DistributedAttention - - self.fused_scaled_dot_product_attention_distributed = DistributedAttention( - self.fused_scaled_dot_product_attention, parallel_state.get_sequence_parallel_group(), 1, 2 - ) def get_k_proj_weight(self): """4bit quantization in GPTQ replaces the k_proj.weight with qweight.""" @@ -683,13 +729,11 @@ def pre_attn_forward( kv_seq_len = key_states.shape[-2] else: past_key_value = None - fused_scaled_dot_product_attention = GaudiDistributedAttention( - self.fused_scaled_dot_product_attention, self.fused_scaled_dot_product_attention_distributed - ) + if use_flash_attention and FusedSDPA is not None: if q_len == 1: # next token - attn_output = fused_scaled_dot_product_attention( + attn_output = self.fused_scaled_dot_product_attention( query_states, key_states, value_states, @@ -707,7 +751,7 @@ def pre_attn_forward( softmax_mode = "fast" if flash_attention_fast_softmax else "None" if flash_attention_causal_mask: # causal masking on first token requires inputs to be of the same length - attn_output = fused_scaled_dot_product_attention( + attn_output = self.fused_scaled_dot_product_attention( query_states, key_states, value_states, @@ -721,7 +765,7 @@ def pre_attn_forward( "left", ) else: - attn_output = fused_scaled_dot_product_attention( + attn_output = self.fused_scaled_dot_product_attention( query_states, key_states, value_states, From 076fe807884c8c063511e0adc09417e67e95ae30 Mon Sep 17 00:00:00 2001 From: Bhargav Date: Wed, 8 Jan 2025 18:03:00 +0530 Subject: [PATCH 2/3] Restructuring code (#100) --- .../models/llama/modeling_llama.py | 44 +++++++++++++------ 1 file changed, 31 insertions(+), 13 deletions(-) diff --git a/optimum/habana/transformers/models/llama/modeling_llama.py b/optimum/habana/transformers/models/llama/modeling_llama.py index b5d8648ed0..bd035c2b3e 100755 --- a/optimum/habana/transformers/models/llama/modeling_llama.py +++ b/optimum/habana/transformers/models/llama/modeling_llama.py @@ -428,15 +428,9 @@ def forward(self, cur, dim, idx): class GaudiDistributedAttention(torch.nn.Module): - def __init__(self, scale, attention_dropout, enable_recompute, flash_attention_fp8): + def __init__(self, hpu_module_fsdpa, scale, attention_dropout, enable_recompute, flash_attention_fp8): super().__init__() - self._hpu_module_fsdpa = ModuleFusedSDPA( - FusedSDPA, - scale=scale, - attention_dropout=attention_dropout, - enable_recompute=enable_recompute, - flash_attention_fp8=flash_attention_fp8, - ) + self._hpu_module_fsdpa = hpu_module_fsdpa if parallel_state.sequence_parallel_is_initialized() and parallel_state.get_sequence_parallel_world_size() > 1: from deepspeed.sequence.layer import DistributedAttention @@ -490,6 +484,13 @@ def forward( ) +def GetGaudiDistributedAttention(fused_scaled_dot_product_attention, fused_scaled_dot_product_attention_distributed): + if parallel_state.sequence_parallel_is_initialized() and parallel_state.get_sequence_parallel_world_size() > 1: + return fused_scaled_dot_product_attention_distributed + else: + return fused_scaled_dot_product_attention + + class GaudiLlamaAttention(LlamaAttention): def __init__(self, config: LlamaConfig, layer_idx: Optional[int] = None): super().__init__(config, layer_idx) @@ -515,7 +516,8 @@ def __init__(self, config: LlamaConfig, layer_idx: Optional[int] = None): 
self.inp_seq_len = -1 self.norm_factor = 1.0 / math.sqrt(self.head_dim) self.fused_scaled_dot_product_attention = ( - GaudiDistributedAttention( + ModuleFusedSDPA( + FusedSDPA, scale=self.norm_factor, attention_dropout=self.attention_dropout, enable_recompute=False, @@ -524,6 +526,20 @@ def __init__(self, config: LlamaConfig, layer_idx: Optional[int] = None): if FusedSDPA else None ) + # for all2all comm, Distributed Attention cares about sequence (s) and number of heads (h) dimensions. In HPU, they are at 1 and 2 indices + self.fused_scaled_dot_product_attention_distributed = None + if parallel_state.sequence_parallel_is_initialized() and parallel_state.get_sequence_parallel_world_size() > 1: + self.fused_scaled_dot_product_attention_distributed = ( + GaudiDistributedAttention( + self.fused_scaled_dot_product_attention, + scale=self.norm_factor, + attention_dropout=self.attention_dropout, + enable_recompute=False, + flash_attention_fp8=getattr(config, "flash_attention_fp8", False), + ) + if FusedSDPA + else None + ) def get_k_proj_weight(self): """4bit quantization in GPTQ replaces the k_proj.weight with qweight.""" @@ -729,11 +745,13 @@ def pre_attn_forward( kv_seq_len = key_states.shape[-2] else: past_key_value = None - + fused_scaled_dot_product_attention = GetGaudiDistributedAttention( + self.fused_scaled_dot_product_attention, self.fused_scaled_dot_product_attention_distributed + ) if use_flash_attention and FusedSDPA is not None: if q_len == 1: # next token - attn_output = self.fused_scaled_dot_product_attention( + attn_output = fused_scaled_dot_product_attention( query_states, key_states, value_states, @@ -751,7 +769,7 @@ def pre_attn_forward( softmax_mode = "fast" if flash_attention_fast_softmax else "None" if flash_attention_causal_mask: # causal masking on first token requires inputs to be of the same length - attn_output = self.fused_scaled_dot_product_attention( + attn_output = fused_scaled_dot_product_attention( query_states, key_states, value_states, @@ -765,7 +783,7 @@ def pre_attn_forward( "left", ) else: - attn_output = self.fused_scaled_dot_product_attention( + attn_output = fused_scaled_dot_product_attention( query_states, key_states, value_states, From aad9c4946dc22edfa3f3d1bf9d77623af532f0f1 Mon Sep 17 00:00:00 2001 From: Bhargav Date: Thu, 6 Feb 2025 07:31:07 +0200 Subject: [PATCH 3/3] Addressing review comments --- .../models/llama/modeling_llama.py | 26 +++++++++++-------- 1 file changed, 15 insertions(+), 11 deletions(-) diff --git a/optimum/habana/transformers/models/llama/modeling_llama.py b/optimum/habana/transformers/models/llama/modeling_llama.py index bd035c2b3e..ead766c6ee 100755 --- a/optimum/habana/transformers/models/llama/modeling_llama.py +++ b/optimum/habana/transformers/models/llama/modeling_llama.py @@ -397,9 +397,9 @@ def allocate(self, inp_seq_len, dtype, device, shape): self.inp_seq_len = inp_seq_len self.cache = torch.zeros(shape, dtype=dtype, device=device) else: - assert ( - self.inp_seq_len == inp_seq_len - ), f"inp_seq_len must be the same. self.inp_seq_len:{self.inp_seq_len} inp_seq_len:{inp_seq_len}" + assert self.inp_seq_len == inp_seq_len, ( + f"inp_seq_len must be the same. 
self.inp_seq_len:{self.inp_seq_len} inp_seq_len:{inp_seq_len}" + ) self.cache.fill_(0) @staticmethod @@ -428,7 +428,9 @@ def forward(self, cur, dim, idx): class GaudiDistributedAttention(torch.nn.Module): - def __init__(self, hpu_module_fsdpa, scale, attention_dropout, enable_recompute, flash_attention_fp8): + def __init__( + self, hpu_module_fsdpa: ModuleFusedSDPA, scale, attention_dropout, enable_recompute, flash_attention_fp8 + ): super().__init__() self._hpu_module_fsdpa = hpu_module_fsdpa if parallel_state.sequence_parallel_is_initialized() and parallel_state.get_sequence_parallel_world_size() > 1: @@ -440,11 +442,11 @@ def __init__(self, hpu_module_fsdpa, scale, attention_dropout, enable_recompute, def forward( self, - query, - key, - value, - attn_mask, - dropout_p, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + attn_mask: torch.Tensor, + dropout_p: float, is_casual, scale, softmax_mode, @@ -484,7 +486,9 @@ def forward( ) -def GetGaudiDistributedAttention(fused_scaled_dot_product_attention, fused_scaled_dot_product_attention_distributed): +def get_gaudi_distributed_attention( + fused_scaled_dot_product_attention, fused_scaled_dot_product_attention_distributed +): if parallel_state.sequence_parallel_is_initialized() and parallel_state.get_sequence_parallel_world_size() > 1: return fused_scaled_dot_product_attention_distributed else: @@ -745,7 +749,7 @@ def pre_attn_forward( kv_seq_len = key_states.shape[-2] else: past_key_value = None - fused_scaled_dot_product_attention = GetGaudiDistributedAttention( + fused_scaled_dot_product_attention = get_gaudi_distributed_attention( self.fused_scaled_dot_product_attention, self.fused_scaled_dot_product_attention_distributed ) if use_flash_attention and FusedSDPA is not None:
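
Note (illustrative sketch, not part of the patches): taken together, the three commits converge on a wrapper-plus-dispatch pattern. GaudiDistributedAttention wraps the already-constructed ModuleFusedSDPA in DeepSpeed's DistributedAttention (configured with all-to-all indices 1 and 2 for the [B, N, S, H] layout) and inserts batch_dim_idx=0 only on the distributed call path, since the newer DeepSpeed DistributedAttention.forward expects it right after query/key/value. get_gaudi_distributed_attention then selects, per forward call, either the plain fused module or the distributed wrapper, depending on whether sequence parallelism is active. The sketch below reproduces that shape in plain PyTorch under stated assumptions: LocalSDPA, SketchDistributedAttention, and pick_attention are hypothetical stand-ins for ModuleFusedSDPA, deepspeed.sequence.layer.DistributedAttention, and get_gaudi_distributed_attention (not the real Habana or DeepSpeed classes), and the sequence-parallel all-to-all itself is elided.

import torch
import torch.nn.functional as F


class LocalSDPA(torch.nn.Module):
    # Stand-in for ModuleFusedSDPA; inputs follow the [B, N, S, H] layout
    # (batch, heads, sequence, head_dim) used on HPU.
    def forward(self, query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False):
        return F.scaled_dot_product_attention(
            query, key, value, attn_mask=attn_mask, dropout_p=dropout_p, is_causal=is_causal
        )


class SketchDistributedAttention(torch.nn.Module):
    # Stand-in for deepspeed.sequence.layer.DistributedAttention. The real class
    # runs an all-to-all over the heads/sequence axes (indices 1 and 2 for the
    # [B, N, S, H] layout, as the patch configures) around the wrapped local
    # attention; the communication is elided here and only the calling
    # convention is kept.
    def __init__(self, local_attn, sequence_parallel_group=None, scatter_idx=1, gather_idx=2):
        super().__init__()
        self.local_attn = local_attn
        self.sequence_parallel_group = sequence_parallel_group
        self.scatter_idx = scatter_idx
        self.gather_idx = gather_idx

    def forward(self, query, key, value, batch_dim_idx, *args, **kwargs):
        # batch_dim_idx (0 for [B, N, S, H]) tells the all-to-all which axis is
        # the batch axis; with the communication skipped it is only sanity-checked.
        assert batch_dim_idx == 0
        return self.local_attn(query, key, value, *args, **kwargs)


def pick_attention(local_attn, distributed_attn, sequence_parallel_world_size):
    # Mirrors get_gaudi_distributed_attention: choose the distributed wrapper
    # only when sequence parallelism is actually active.
    return distributed_attn if sequence_parallel_world_size > 1 else local_attn


if __name__ == "__main__":
    b, n, s, h = 2, 8, 16, 64
    q, k, v = (torch.randn(b, n, s, h) for _ in range(3))

    local = LocalSDPA()
    distributed = SketchDistributedAttention(local, scatter_idx=1, gather_idx=2)

    # Local path (world size 1): same signature as the fused SDPA module.
    out = pick_attention(local, distributed, 1)(q, k, v, is_causal=True)
    print(out.shape)  # torch.Size([2, 8, 16, 64])

    # Distributed path (world size > 1): an extra batch_dim_idx is inserted
    # right after query/key/value, as in GaudiDistributedAttention.forward.
    out = pick_attention(local, distributed, 2)(q, k, v, 0, is_causal=True)
    print(out.shape)  # torch.Size([2, 8, 16, 64])

Keeping the signature difference inside the wrapper, rather than at every call site in pre_attn_forward, is the main design point of PATCH 2 and PATCH 3: the three attention call sites go back to a single fused_scaled_dot_product_attention callable chosen per call, and only GaudiDistributedAttention.forward knows that the distributed path needs batch_dim_idx.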