
Commit 8d0c61f

MLA first commit
1 parent 9055c66 commit 8d0c61f

File tree

1 file changed: +159 -0 lines changed

src/nanotron/models/llama.py (+159)
@@ -17,6 +17,7 @@
 from typing import Dict, List, Optional, Union
 
 import torch
+import torch.nn.functional as F
 from torch import nn
 from torch.utils.checkpoint import CheckpointFunction
 
@@ -692,6 +693,164 @@ def forward(
         return {"hidden_states": output, "sequence_mask": sequence_mask}
 
 
+class MLA(nn.Module):
+    def __init__(
+        self,
+        config: LlamaConfig,
+        parallel_config: Optional[ParallelismArgs],
+        tp_pg: dist.ProcessGroup,
+        layer_idx: int,
+    ):
+        super().__init__()
+
+        self.dim = config.hidden_size
+        self.n_heads = config.num_attention_heads
+        self.n_local_heads = config.num_attention_heads // tp_pg.size()
+        self.q_lora_rank = config.q_lora_rank
+        self.kv_lora_rank = config.kv_lora_rank
+        self.qk_nope_head_dim = config.qk_nope_head_dim
+        self.qk_rope_head_dim = config.qk_rope_head_dim
+        self.qk_head_dim = config.qk_nope_head_dim + config.qk_rope_head_dim
+        self.v_head_dim = config.v_head_dim
+
+        # tp related
+        tp_mode = parallel_config.tp_mode if parallel_config is not None else TensorParallelLinearMode.ALL_REDUCE
+        self.tp_mode = tp_mode
+        tp_linear_async_communication = (
+            parallel_config.tp_linear_async_communication if parallel_config is not None else False
+        )
+        q_up_contiguous_chunks = (
+            self.n_heads * self.qk_nope_head_dim,  # shape of q_nope
+            self.n_heads * self.qk_rope_head_dim,  # shape of q_rope
+        )
+        kv_up_contiguous_chunks = (
+            self.n_heads * self.qk_nope_head_dim,  # shape of k_nope
+            self.n_heads * self.v_head_dim,  # shape of v
+        )
+
+        assert (
+            self.n_heads % tp_pg.size() == 0
+        ), f"Number of attention heads ({self.n_heads}) must be divisible by TP size ({tp_pg.size()})."
+        assert (
+            self.q_lora_rank < self.n_heads * self.qk_head_dim
+        ), f"q_lora_rank ({self.q_lora_rank}) must be less than the product of the number of attention heads ({self.n_heads}) and the number of query/key head dimensions ({self.qk_head_dim})."
+        assert tp_mode == TensorParallelLinearMode.ALL_REDUCE, "MLA only supports all-reduce TP mode for now"
+
+        # Initialize rotary embedding
+        self.rotary_embedding = RotaryEmbedding(
+            dim=self.qk_rope_head_dim, end=config.max_position_embeddings, theta=config.rope_theta
+        )
+
+        # Initialize linear layers
+        self.q_down = nn.Linear(self.dim, self.q_lora_rank, bias=False)
+        self.q_norm = TritonRMSNorm(self.q_lora_rank, eps=config.rms_norm_eps)
+        self.q_up = TensorParallelColumnLinear(
+            self.q_lora_rank,
+            self.n_heads * self.qk_head_dim,
+            pg=tp_pg,
+            mode=tp_mode,
+            bias=False,
+            async_communication=tp_linear_async_communication,
+            contiguous_chunks=q_up_contiguous_chunks,
+            tp_recompute_allgather=parallel_config.tp_recompute_allgather,
+        )
+
+        self.kv_down = nn.Linear(self.dim, self.kv_lora_rank + self.qk_rope_head_dim, bias=False)
+        self.kv_norm = TritonRMSNorm(self.kv_lora_rank, eps=config.rms_norm_eps)
+        self.kv_up = TensorParallelColumnLinear(
+            self.kv_lora_rank,
+            self.n_heads * (self.qk_nope_head_dim + self.v_head_dim),
+            pg=tp_pg,
+            mode=tp_mode,
+            bias=False,
+            async_communication=tp_linear_async_communication,
+            contiguous_chunks=kv_up_contiguous_chunks,
+            tp_recompute_allgather=parallel_config.tp_recompute_allgather,
+        )
+
+        self.attention = CoreAttention(
+            config,
+            parallel_config=parallel_config,
+            layer_idx=layer_idx,
+        )
+
+        self.o_proj = TensorParallelRowLinear(
+            self.n_heads * self.v_head_dim,
+            self.dim,
+            pg=tp_pg,
+            mode=tp_mode,
+            bias=False,
+            async_communication=tp_linear_async_communication,
+        )
+
+    def forward(
+        self,
+        hidden_states,  # [seq_length, batch_size, hidden_size]
+        sequence_mask,  # [batch_size, seq_length]
+    ):
+        seq_len, batch_size, _ = hidden_states.shape
+
+        q = self.q_up(self.q_norm(self.q_down(hidden_states)))
+        q = q.view(seq_len, batch_size, self.n_local_heads, self.qk_head_dim)
+        q_nope, q_pe = torch.split(
+            q, [self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1
+        )  # [seq_len, batch_size, n_local_heads, qk_nope_head_dim], [seq_len, batch_size, n_local_heads, qk_rope_head_dim]
+        q_pe = (
+            self.rotary_embedding(q_pe.transpose(0, 1), position_ids=None).transpose(0, 1).contiguous()
+        )  # [seq_len, batch_size, n_local_heads, qk_rope_head_dim]
+        q = torch.cat(
+            [q_nope, q_pe.expand(-1, -1, self.n_local_heads, -1)], dim=-1
+        )  # [seq_len, batch_size, n_heads, qk_head_dim]
+
+        kv = self.kv_down(hidden_states)  # [seq_len, batch_size, qk_rope_head_dim + kv_lora_rank]
+        kv, k_pe = torch.split(
+            kv, [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1
+        )  # [seq_len, batch_size, kv_lora_rank], [seq_len, batch_size, qk_rope_head_dim]
+        k_pe = (
+            self.rotary_embedding(k_pe.unsqueeze(2).transpose(0, 1), position_ids=None).transpose(0, 1).contiguous()
+        )  # [seq_len, batch_size, 1, qk_rope_head_dim]
+        kv = self.kv_up(self.kv_norm(kv))  # [seq_len, batch_size, n_local_heads * (qk_nope_head_dim + v_head_dim)]
+        kv = kv.view(seq_len, batch_size, self.n_local_heads, self.qk_nope_head_dim + self.v_head_dim)
+        k_nope, v = torch.split(
+            kv, [self.qk_nope_head_dim, self.v_head_dim], dim=-1
+        )  # [seq_len, batch_size, n_local_heads, qk_nope_head_dim], [seq_len, batch_size, n_local_heads, v_head_dim]
+        k = torch.cat(
+            [k_nope, k_pe.expand(-1, -1, self.n_local_heads, -1)], dim=-1
+        )  # [seq_len, batch_size, n_heads, qk_head_dim]
+
+        # FA API doesn't seem to support different head dimensions for key and value. Use SDPA instead.
+        # TODO: a kernel for this.
+
+        # (seqlen, b, n_heads, d_qk) -> (b, n_heads, seqlen, d_qk)
+        q = q.permute(1, 2, 0, 3).contiguous()
+        k = k.permute(1, 2, 0, 3).contiguous()
+        v = v.permute(1, 2, 0, 3).contiguous()
+
+        # Mask for SDPA: https://pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html
+        if sequence_mask is not None:
+            # Step 1: Create a lower triangular mask (allows attending to past and current positions)
+            lower_triangular_mask = torch.tril(torch.ones(seq_len, seq_len))  # lower triangle
+            lower_triangular_mask = lower_triangular_mask.bool()  # Convert to boolean mask
+            # Step 2: Expand the sequence mask to match attention shape [batch_size, seq_len, seq_len]
+            sequence_mask_expanded = sequence_mask[:, None, :].expand(batch_size, seq_len, seq_len)
+            # Step 3: Combine both masks: True means allowed, False means masked
+            final_mask = (lower_triangular_mask) & (sequence_mask_expanded.bool())  # [batch_size, seq_len, seq_len]
+            final_mask = final_mask.unsqueeze(1).expand(
+                -1, self.n_local_heads, -1, -1
+            )  # [batch_size, n_local_heads, seq_len, seq_len]
+        else:
+            final_mask = torch.tril(torch.ones(seq_len, seq_len)).bool()
+
+        output = F.scaled_dot_product_attention(q, k, v, attn_mask=final_mask)
+        output = (
+            output.permute(2, 0, 1, 3).reshape(seq_len, batch_size, -1).contiguous()
+        )  # (b, n_heads, seqlen, d_qk) -> (seqlen, b, n_heads * d_qk)
+
+        output = self.o_proj(output)
+
+        return {"hidden_states": output, "sequence_mask": sequence_mask}
+
+
 class LlamaDecoderLayer(nn.Module):
     def __init__(
         self,
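
For readers skimming the diff, here is a minimal, self-contained sketch of the low-rank query path the new MLA module implements (q_down, a norm, then q_up, followed by the split into a "nope" part and a rotary "rope" part). Plain nn.Linear layers stand in for TritonRMSNorm and TensorParallelColumnLinear, and the sizes below are illustrative only, not taken from any particular config.

import torch
from torch import nn

# Illustrative sizes only (hypothetical, not from a real config).
hidden_size, n_heads = 2048, 16
q_lora_rank = 512
qk_nope_head_dim, qk_rope_head_dim = 64, 32
qk_head_dim = qk_nope_head_dim + qk_rope_head_dim  # 96

# Factorized query projection: hidden_size -> q_lora_rank -> n_heads * qk_head_dim.
# The commit inserts an RMSNorm between the two projections; it is omitted here.
q_down = nn.Linear(hidden_size, q_lora_rank, bias=False)
q_up = nn.Linear(q_lora_rank, n_heads * qk_head_dim, bias=False)

x = torch.randn(5, 2, hidden_size)  # [seq_len, batch_size, hidden_size]
q = q_up(q_down(x)).view(5, 2, n_heads, qk_head_dim)
q_nope, q_pe = torch.split(q, [qk_nope_head_dim, qk_rope_head_dim], dim=-1)
print(q_nope.shape, q_pe.shape)  # [5, 2, 16, 64] and [5, 2, 16, 32]

# For these sizes the factorized path holds fewer parameters than a single
# hidden_size -> n_heads * qk_head_dim projection; the bottleneck is only a
# bottleneck because q_lora_rank < n_heads * qk_head_dim, which the assertion enforces.
print(hidden_size * n_heads * qk_head_dim)                              # 3145728
print(hidden_size * q_lora_rank + q_lora_rank * n_heads * qk_head_dim)  # 1835008

The kv path in the commit follows the same pattern, except that the rope part of the key is split off before the norm and up-projection and is shared across all heads.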

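The forward pass builds a boolean attn_mask for F.scaled_dot_product_attention by AND-ing a causal lower-triangular mask with the padding mask, where True means "may attend" and False means "masked". Below is a minimal sketch of the same construction with toy shapes. It differs from the committed code in two assumed respects: the mask is created on q's device and left to broadcast over the head dimension, and q/k use a different head dim than v, which is the property that pushed the commit toward SDPA instead of the FlashAttention API.

import torch
import torch.nn.functional as F

batch_size, n_heads, seq_len = 2, 4, 6
qk_head_dim, v_head_dim = 8, 16  # SDPA allows k and v to use different head dims
q = torch.randn(batch_size, n_heads, seq_len, qk_head_dim)
k = torch.randn(batch_size, n_heads, seq_len, qk_head_dim)
v = torch.randn(batch_size, n_heads, seq_len, v_head_dim)

# Padding mask: True = real token, False = padding (second sequence ends early).
sequence_mask = torch.tensor([[1, 1, 1, 1, 1, 1],
                              [1, 1, 1, 1, 0, 0]], dtype=torch.bool)

# Causal part: each query position may attend to itself and earlier positions.
causal = torch.tril(torch.ones(seq_len, seq_len, dtype=torch.bool, device=q.device))
# Key-side padding, broadcast over query positions.
padding = sequence_mask[:, None, :].expand(batch_size, seq_len, seq_len)
# Boolean attn_mask: True = attend, False = masked out; broadcasts over heads.
attn_mask = (causal & padding)[:, None, :, :]  # [batch_size, 1, seq_len, seq_len]

out = F.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask)
print(out.shape)  # torch.Size([2, 4, 6, 16])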