@@ -17,6 +17,7 @@
 from typing import Dict, List, Optional, Union
 
 import torch
+import torch.nn.functional as F
 from torch import nn
 from torch.utils.checkpoint import CheckpointFunction
 
@@ -692,6 +693,164 @@ def forward(
         return {"hidden_states": output, "sequence_mask": sequence_mask}
 
 
+class MLA(nn.Module):
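+    """Multi-head Latent Attention (MLA), as introduced in DeepSeek-V2/V3.
+
+    Queries and keys/values are first compressed to low-rank latents (q_lora_rank /
+    kv_lora_rank) and then up-projected per head, while a small decoupled RoPE slice
+    (qk_rope_head_dim) carries the positional information.
+    """
+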
+    def __init__(
+        self,
+        config: LlamaConfig,
+        parallel_config: Optional[ParallelismArgs],
+        tp_pg: dist.ProcessGroup,
+        layer_idx: int,
+    ):
+        super().__init__()
+
+        self.dim = config.hidden_size
+        self.n_heads = config.num_attention_heads
+        self.n_local_heads = config.num_attention_heads // tp_pg.size()
+        self.q_lora_rank = config.q_lora_rank
+        self.kv_lora_rank = config.kv_lora_rank
+        self.qk_nope_head_dim = config.qk_nope_head_dim
+        self.qk_rope_head_dim = config.qk_rope_head_dim
+        self.qk_head_dim = config.qk_nope_head_dim + config.qk_rope_head_dim
+        self.v_head_dim = config.v_head_dim
+
+        # tp related
+        tp_mode = parallel_config.tp_mode if parallel_config is not None else TensorParallelLinearMode.ALL_REDUCE
+        self.tp_mode = tp_mode
+        tp_linear_async_communication = (
+            parallel_config.tp_linear_async_communication if parallel_config is not None else False
+        )
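+        # Sizes of the fused sub-projections inside each column-parallel output, so that
+        # every chunk (q_nope / q_rope, and k_nope / v) is sharded consistently across TP ranks.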
+        q_up_contiguous_chunks = (
+            self.n_heads * self.qk_nope_head_dim,  # shape of q_nope
+            self.n_heads * self.qk_rope_head_dim,  # shape of q_rope
+        )
+        kv_up_contiguous_chunks = (
+            self.n_heads * self.qk_nope_head_dim,  # shape of k_nope
+            self.n_heads * self.v_head_dim,  # shape of v
+        )
+
+        assert (
+            self.n_heads % tp_pg.size() == 0
+        ), f"Number of attention heads ({self.n_heads}) must be divisible by TP size ({tp_pg.size()})."
+        assert (
+            self.q_lora_rank < self.n_heads * self.qk_head_dim
+        ), f"q_lora_rank ({self.q_lora_rank}) must be less than n_heads * qk_head_dim ({self.n_heads} * {self.qk_head_dim} = {self.n_heads * self.qk_head_dim})."
+        assert tp_mode == TensorParallelLinearMode.ALL_REDUCE, "MLA only supports the ALL_REDUCE TP mode for now"
+
+        # Initialize rotary embedding
+        self.rotary_embedding = RotaryEmbedding(
+            dim=self.qk_rope_head_dim, end=config.max_position_embeddings, theta=config.rope_theta
+        )
+
+        # Initialize linear layers
+        self.q_down = nn.Linear(self.dim, self.q_lora_rank, bias=False)
+        self.q_norm = TritonRMSNorm(self.q_lora_rank, eps=config.rms_norm_eps)
+        self.q_up = TensorParallelColumnLinear(
+            self.q_lora_rank,
+            self.n_heads * self.qk_head_dim,
+            pg=tp_pg,
+            mode=tp_mode,
+            bias=False,
+            async_communication=tp_linear_async_communication,
+            contiguous_chunks=q_up_contiguous_chunks,
+            tp_recompute_allgather=parallel_config.tp_recompute_allgather,
+        )
+
+        self.kv_down = nn.Linear(self.dim, self.kv_lora_rank + self.qk_rope_head_dim, bias=False)
+        self.kv_norm = TritonRMSNorm(self.kv_lora_rank, eps=config.rms_norm_eps)
+        self.kv_up = TensorParallelColumnLinear(
+            self.kv_lora_rank,
+            self.n_heads * (self.qk_nope_head_dim + self.v_head_dim),
+            pg=tp_pg,
+            mode=tp_mode,
+            bias=False,
+            async_communication=tp_linear_async_communication,
+            contiguous_chunks=kv_up_contiguous_chunks,
+            tp_recompute_allgather=parallel_config.tp_recompute_allgather,
+        )
+
+        self.attention = CoreAttention(
+            config,
+            parallel_config=parallel_config,
+            layer_idx=layer_idx,
+        )
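+        # NOTE: self.attention is not used by forward() below, which calls
+        # F.scaled_dot_product_attention directly.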
+
+        self.o_proj = TensorParallelRowLinear(
+            self.n_heads * self.v_head_dim,
+            self.dim,
+            pg=tp_pg,
+            mode=tp_mode,
+            bias=False,
+            async_communication=tp_linear_async_communication,
+        )
+
+    def forward(
+        self,
+        hidden_states,  # [seq_length, batch_size, hidden_size]
+        sequence_mask,  # [batch_size, seq_length]
+    ):
+        seq_len, batch_size, _ = hidden_states.shape
+
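+        # Query path: down-project to the q latent, RMS-normalize, up-project to per-head
+        # (nope + rope) dims, and apply rotary embeddings to the rope slice only.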
+        q = self.q_up(self.q_norm(self.q_down(hidden_states)))
+        q = q.view(seq_len, batch_size, self.n_local_heads, self.qk_head_dim)
+        q_nope, q_pe = torch.split(
+            q, [self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1
+        )  # [seq_len, batch_size, n_local_heads, qk_nope_head_dim], [seq_len, batch_size, n_local_heads, qk_rope_head_dim]
+        q_pe = (
+            self.rotary_embedding(q_pe.transpose(0, 1), position_ids=None).transpose(0, 1).contiguous()
+        )  # [seq_len, batch_size, n_local_heads, qk_rope_head_dim]
+        q = torch.cat(
+            [q_nope, q_pe.expand(-1, -1, self.n_local_heads, -1)], dim=-1
+        )  # [seq_len, batch_size, n_local_heads, qk_head_dim]
+
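+        # Key/value path: one down-projection yields the compressed KV latent plus a shared
+        # rope slice; the latent is normalized and up-projected to per-head k_nope and v.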
+        kv = self.kv_down(hidden_states)  # [seq_len, batch_size, kv_lora_rank + qk_rope_head_dim]
+        kv, k_pe = torch.split(
+            kv, [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1
+        )  # [seq_len, batch_size, kv_lora_rank], [seq_len, batch_size, qk_rope_head_dim]
+        k_pe = (
+            self.rotary_embedding(k_pe.unsqueeze(2).transpose(0, 1), position_ids=None).transpose(0, 1).contiguous()
+        )  # [seq_len, batch_size, 1, qk_rope_head_dim]
+        kv = self.kv_up(self.kv_norm(kv))  # [seq_len, batch_size, n_local_heads * (qk_nope_head_dim + v_head_dim)]
+        kv = kv.view(seq_len, batch_size, self.n_local_heads, self.qk_nope_head_dim + self.v_head_dim)
+        k_nope, v = torch.split(
+            kv, [self.qk_nope_head_dim, self.v_head_dim], dim=-1
+        )  # [seq_len, batch_size, n_local_heads, qk_nope_head_dim], [seq_len, batch_size, n_local_heads, v_head_dim]
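+        # k_pe is a single shared RoPE key per position (MLA's decoupled RoPE); the expand
+        # below broadcasts it across the local heads.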
+        k = torch.cat(
+            [k_nope, k_pe.expand(-1, -1, self.n_local_heads, -1)], dim=-1
+        )  # [seq_len, batch_size, n_local_heads, qk_head_dim]
+
+        # FA API doesn't seem to support different head dimensions for key and value. Use SDPA instead.
+        # TODO: a kernel for this.
+
+        # (seq_len, b, n_local_heads, d) -> (b, n_local_heads, seq_len, d); d is qk_head_dim for q/k and v_head_dim for v
+        q = q.permute(1, 2, 0, 3).contiguous()
+        k = k.permute(1, 2, 0, 3).contiguous()
+        v = v.permute(1, 2, 0, 3).contiguous()
+
+        # Mask for SDPA: https://pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html
+        if sequence_mask is not None:
+            # Step 1: Create a lower triangular (causal) mask that allows attending to past and current positions
+            lower_triangular_mask = torch.tril(
+                torch.ones(seq_len, seq_len, dtype=torch.bool, device=hidden_states.device)
+            )
+            # Step 2: Expand the sequence (padding) mask to the attention shape [batch_size, seq_len, seq_len]
+            sequence_mask_expanded = sequence_mask[:, None, :].expand(batch_size, seq_len, seq_len)
+            # Step 3: Combine both masks: True means allowed, False means masked
+            final_mask = lower_triangular_mask & sequence_mask_expanded.bool()  # [batch_size, seq_len, seq_len]
+            final_mask = final_mask.unsqueeze(1).expand(
+                -1, self.n_local_heads, -1, -1
+            )  # [batch_size, n_local_heads, seq_len, seq_len]
+        else:
+            final_mask = torch.tril(torch.ones(seq_len, seq_len, dtype=torch.bool, device=hidden_states.device))
+
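+        # Note: with no explicit `scale`, F.scaled_dot_product_attention defaults to a softmax
+        # scale of 1 / sqrt(q.shape[-1]) = 1 / sqrt(qk_head_dim).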
+        output = F.scaled_dot_product_attention(q, k, v, attn_mask=final_mask)
+        output = (
+            output.permute(2, 0, 1, 3).reshape(seq_len, batch_size, -1).contiguous()
+        )  # (b, n_local_heads, seq_len, v_head_dim) -> (seq_len, b, n_local_heads * v_head_dim)
+
+        output = self.o_proj(output)
+
+        return {"hidden_states": output, "sequence_mask": sequence_mask}
+
+
 class LlamaDecoderLayer(nn.Module):
     def __init__(
         self,