diff --git a/examples/bench_llama_7b.py b/examples/bench_llama_7b.py
index 8c610083..fa757645 100644
--- a/examples/bench_llama_7b.py
+++ b/examples/bench_llama_7b.py
@@ -20,6 +20,7 @@
     TokenizerArgs,
     TokensArgs,
 )
+from nanotron.config.config import AdamWOptimizerArgs
 from nanotron.logging import human_format

 # Config for a llama model with 6.74M parameters
@@ -47,11 +48,13 @@
     weight_decay=0.01,
     clip_grad=1.0,
     accumulate_grad_in_fp32=True,
-    adam_eps=1e-08,
-    adam_beta1=0.9,
-    adam_beta2=0.95,
-    torch_adam_is_fused=True,
     learning_rate_scheduler=learning_rate,
+    optimizer_factory=AdamWOptimizerArgs(
+        adam_eps=1e-08,
+        adam_beta1=0.9,
+        adam_beta2=0.95,
+        torch_adam_is_fused=True,
+    ),
 )

 parallelism = ParallelismArgs(
diff --git a/src/nanotron/models/llama.py b/src/nanotron/models/llama.py
index 32aab9cd..e752a468 100644
--- a/src/nanotron/models/llama.py
+++ b/src/nanotron/models/llama.py
@@ -14,7 +14,7 @@
 # limitations under the License.
 """PyTorch LLaMa model."""

-from typing import Dict, Optional, Union, List
+from typing import Dict, Optional, Union

 import torch
 from torch import nn
@@ -188,35 +188,21 @@ def __init__(self, config: LlamaConfig, parallel_config: Optional[ParallelismArg
     @checkpoint_method(attr_name="checkpoint_attention")
     def forward(
         self,
-        query_states: torch.Tensor,  # [batch_size * q_length, n_local_q_heads, inner_dim]
-        key_states: torch.Tensor,  # [batch_size * kv_length, n_local_kv_heads, inner_dim]
-        value_states: torch.Tensor,  # [batch_size * kv_length, n_local_kv_heads, inner_dim]
-        q_sequence_mask: torch.Tensor,  # torch.BoolTensor [batch_size, q_length] (can be broadcasted to that size)
-        kv_sequence_mask: torch.Tensor,  # torch.BoolTensor [batch_size, kv_length] (can be broadcasted to that size)
+        query_states: torch.Tensor,  # [batch_size, q_length, n_local_q_heads, inner_dim]
+        key_states: torch.Tensor,  # [batch_size, kv_length, n_local_kv_heads, inner_dim]
+        value_states: torch.Tensor,  # [batch_size, kv_length, n_local_kv_heads, inner_dim]
     ):
-        from flash_attn.flash_attn_interface import flash_attn_varlen_func
-
-        # TODO @thomasw21: Compute once, instead of computing for each layers.
-        cu_seqlens_q = torch.zeros((q_sequence_mask.shape[0] + 1), dtype=torch.int32, device=query_states.device)
-        cu_seqlens_k = torch.zeros((kv_sequence_mask.shape[0] + 1), dtype=torch.int32, device=query_states.device)
-        torch.cumsum(q_sequence_mask.sum(-1, dtype=torch.int32), dim=0, dtype=torch.int32, out=cu_seqlens_q[1:])
-        torch.cumsum(kv_sequence_mask.sum(-1, dtype=torch.int32), dim=0, dtype=torch.int32, out=cu_seqlens_k[1:])
-
-        # TODO(kunhao): flash attn's causal means that the query can only attend to the keys before it. This is not
-        # what we want if we are using kv cache. This is a hack as we always have q_length == 1 when using kv cache.
-        causal = False if q_sequence_mask.shape[1] == 1 else True
+        from flash_attn.flash_attn_interface import flash_attn_func

         # NOTE: this scale is for µTransfer,
         # in SP, we use sqrt(1/d_h)
         softmax_scale = 1 / query_states.shape[-1] if self.is_using_mup else None
-        attn_output = flash_attn_varlen_func(
+        # For now we assume a causal mask. No magic here.
+        causal = True
+        attn_output = flash_attn_func(
             q=query_states,
             k=key_states,
             v=value_states,
-            cu_seqlens_q=cu_seqlens_q,
-            cu_seqlens_k=cu_seqlens_k,
-            max_seqlen_q=q_sequence_mask.shape[1],
-            max_seqlen_k=kv_sequence_mask.shape[1],
             dropout_p=0.0,
             softmax_scale=softmax_scale,
             causal=causal,
@@ -565,29 +551,14 @@ def forward(
             # [batch_size, seq_length, num_heads, d_qk]
             key_states, value_states = torch.split(key_value_states, 1, dim=2)

-            q_sequence_mask = sequence_mask
-            kv_sequence_mask = sequence_mask
-
             kv_length = key_states.shape[1]
-            # [batch_size, seq_length, num_heads, d_qk]
-            # Shaping for use in `flash-attn` version of flash-attn: `flash_attn_unpadded_func`
-            query_states = query_states.view(
-                batch_size * q_length, self.n_local_q_heads, self.d_qk
-            )  # [batch_size * q_length, self.n_heads, d_qk]
-
-            key_states = key_states.view(
-                batch_size * kv_length, self.n_local_kv_heads, self.d_qk
-            )  # [batch_size * kv_length, self.n_heads, d_qk]
-            value_states = value_states.view(
-                batch_size * kv_length, self.n_local_kv_heads, self.d_v
-            )  # [batch_size * kv_length, self.n_heads, d_v]
+            key_states = key_states.view(batch_size, kv_length, self.n_local_kv_heads, self.d_qk)
+            value_states = value_states.view(batch_size, kv_length, self.n_local_kv_heads, self.d_v)

             attention_output = self.attention(
                 query_states=query_states,
                 key_states=key_states,
                 value_states=value_states,
-                q_sequence_mask=q_sequence_mask,
-                kv_sequence_mask=kv_sequence_mask,
             )

             attention_output = (
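
Not part of the patch above: a minimal sketch contrasting the two flash-attn entry points this diff swaps. `flash_attn_varlen_func` takes flattened `[total_tokens, heads, head_dim]` tensors plus `cu_seqlens_*` offsets marking where each sequence starts, while `flash_attn_func` takes padded `[batch, seqlen, heads, head_dim]` tensors directly, which is why the reshapes and sequence masks disappear from `llama.py`. The snippet assumes flash-attn 2.x, a CUDA device, and fp16 inputs; the shapes are arbitrary.

```python
# Illustrative only (not from this patch): compares the padded flash_attn_func path
# the diff switches to against the varlen path it removes.
import torch
from flash_attn.flash_attn_interface import flash_attn_func, flash_attn_varlen_func

batch_size, seqlen, n_heads, head_dim = 2, 16, 4, 64
q = torch.randn(batch_size, seqlen, n_heads, head_dim, dtype=torch.float16, device="cuda")
k = torch.randn_like(q)
v = torch.randn_like(q)

# New path: padded [batch, seqlen, heads, head_dim] tensors, causal mask handled by the kernel.
out_padded = flash_attn_func(q=q, k=k, v=v, dropout_p=0.0, softmax_scale=None, causal=True)

# Old path: flatten to [total_tokens, heads, head_dim] and pass cumulative sequence offsets.
cu_seqlens = torch.arange(0, (batch_size + 1) * seqlen, seqlen, dtype=torch.int32, device="cuda")
out_varlen = flash_attn_varlen_func(
    q=q.reshape(batch_size * seqlen, n_heads, head_dim),
    k=k.reshape(batch_size * seqlen, n_heads, head_dim),
    v=v.reshape(batch_size * seqlen, n_heads, head_dim),
    cu_seqlens_q=cu_seqlens,
    cu_seqlens_k=cu_seqlens,
    max_seqlen_q=seqlen,
    max_seqlen_k=seqlen,
    dropout_p=0.0,
    softmax_scale=None,
    causal=True,
)

# With every sequence at full length the two calls should agree up to fp16 noise.
print((out_padded.reshape(batch_size * seqlen, n_heads, head_dim) - out_varlen).abs().max())
```

The varlen interface only pays off when sequences in a batch have different true lengths; for the fixed-length, fully-causal training setup this change assumes, the padded call does the same work with less bookkeeping.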