
Add MLA #278

Open · zzhhjjj wants to merge 10 commits into main

Conversation

zzhhjjj (Collaborator) commented on Feb 5, 2025:

Add MLA to Nanotron.

Compared MLA with GQA on 25B tokens:
    1. LM loss:
        MLA: 2.58, GQA: 2.51 (a 0.07 difference)
    2. Throughput:
        31279 vs. 27466 tokens/s/GPU; MLA reaches about 87% of GQA's end-to-end throughput, which is expected given the MLA structure
    3. KV cache (per token, per layer; see the sketch below):
        GQA: hidden dim 4096, 32 heads, 8 KV heads -> 2048 elements for keys and values
        MLA: kv_lora_rank 512 + qk_rope_head_dim 64 -> 576 elements cached in total
        Roughly 4x less KV cache
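For reference, a back-of-the-envelope sketch (Python) of the per-token, per-layer cache widths quoted above; the resulting ~3.6x ratio is what the summary rounds to 4x:

    # Per-token, per-layer KV-cache width in elements, using the numbers above.
    # GQA: 8 KV heads, head_dim = hidden_dim / n_heads = 4096 / 32 = 128
    gqa_kv_heads = 8
    head_dim = 4096 // 32                          # 128
    gqa_cache = 2 * gqa_kv_heads * head_dim        # keys + values -> 2048

    # MLA: only the compressed KV latent plus the decoupled RoPE key is cached
    kv_lora_rank = 512
    qk_rope_head_dim = 64
    mla_cache = kv_lora_rank + qk_rope_head_dim    # 576

    print(gqa_cache, mla_cache, gqa_cache / mla_cache)  # 2048 576 ~3.56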

zzhhjjj (Collaborator, Author) commented on Feb 17, 2025:

Config example

model:   
  ddp_bucket_cap_mb: 25
  dtype: bfloat16
  init_method:
    std: 0.025
  make_vocab_size_divisible_by: 1  
  model_config:  
    ...
    vocab_size: 50272
    # MLA
    q_lora_rank: 1536 
    kv_lora_rank: 512
    qk_nope_head_dim: 128
    qk_rope_head_dim: 64
    v_head_dim: 128
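For orientation, here is a rough sketch of how these values combine into projection shapes, following the standard DeepSeek-style MLA bookkeeping. The hidden size and head count are taken from the GQA comparison above, and every layer name except q_down (which appears in the diff further down) is illustrative rather than the PR's actual attribute:

    # Sketch only: names other than q_down are illustrative, not necessarily the PR's.
    hidden_size = 4096                 # assumption, taken from the comparison above
    n_heads = 32                       # assumption, taken from the comparison above
    q_lora_rank, kv_lora_rank = 1536, 512
    qk_nope_head_dim, qk_rope_head_dim, v_head_dim = 128, 64, 128

    qk_head_dim = qk_nope_head_dim + qk_rope_head_dim            # 192 per-head query/key dim
    # query path: hidden -> low-rank latent -> per-head queries
    q_down_shape = (hidden_size, q_lora_rank)                    # cf. self.q_down below
    q_up_shape = (q_lora_rank, n_heads * qk_head_dim)
    # key/value path: hidden -> compressed latent (+ one shared RoPE key)
    kv_down_shape = (hidden_size, kv_lora_rank + qk_rope_head_dim)
    kv_up_shape = (kv_lora_rank, n_heads * (qk_nope_head_dim + v_head_dim))
    # at inference time, only the latent and the RoPE key need caching per token
    cached_width = kv_lora_rank + qk_rope_head_dim               # 576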

xrsrke self-requested a review on February 24, 2025.
xrsrke (Member) left a comment:

Overall, it looks good, but I recommend adding unit tests to sanity-check MLA with different values of tp_mode, async_communication, and tp_recompute_allgather, and making sure that the output shape of the MLA class is as expected.
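A rough sketch of what such a test could look like. This is hedged heavily: build_mla, the tp_mode values, and the forward signature are placeholders for whatever this PR actually exposes, not real API:

    # Sketch of the suggested sanity check; build_mla and the MLA call signature
    # are hypothetical placeholders, not the PR's actual API.
    import pytest
    import torch

    @pytest.mark.parametrize("tp_mode", ["ALL_REDUCE", "REDUCE_SCATTER"])   # placeholder names
    @pytest.mark.parametrize("async_communication", [False, True])
    @pytest.mark.parametrize("tp_recompute_allgather", [False, True])
    def test_mla_output_shape(tp_mode, async_communication, tp_recompute_allgather):
        seq_len, batch_size, hidden_size = 16, 2, 4096
        mla = build_mla(                 # hypothetical helper that constructs the MLA module
            tp_mode=tp_mode,
            async_communication=async_communication,
            tp_recompute_allgather=tp_recompute_allgather,
        )
        x = torch.randn(seq_len, batch_size, hidden_size)
        out = mla(x)                     # adapt to the real forward signature
        # MLA should keep the [seq_len, batch_size, hidden_size] layout of CausalSelfAttention
        assert out.shape == (seq_len, batch_size, hidden_size)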

xrsrke (Member) left a comment:

LFG

        q, [self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1
    )  # [seq_len, batch_size, n_local_heads, qk_nope_head_dim], [seq_len, batch_size, n_local_heads, qk_rope_head_dim]
    q_pe = (
        self.rotary_embedding(q_pe.transpose(0, 1), position_ids=None).transpose(0, 1).contiguous()
Member:

Is the transpose(0, 1) needed here?

zzhhjjj (Collaborator, Author):

Yes, otherwise the results would be different.

Member:

I meant: why not transpose once at the beginning of MLA's forward? That way we avoid doing multiple small transposes.

Member:

Because transposes are very slow, and there are a lot of them in MLA's forward.

@@ -701,8 +855,9 @@ def __init__(
        layer_idx: int,
    ):
        super().__init__()
        attn_cls = MLA if config.kv_lora_rank is not None else CausalSelfAttention
Member:

I'd rather make this more explicit, e.g. use config.use_mla here, and assert somewhere that the other fields (e.g. kv_lora_rank) are well defined. This can be done in config.py.

zzhhjjj (Collaborator, Author) commented on Mar 6, 2025:

It seems a bit redundant to me, since kv_lora_rank being set implies MLA in this case, so there's no unexpected behavior.

Member:

Sometimes redundancy is fine if it makes the code cleaner! I still think we should have use_mla somewhere, as kv_lora_rank only relates to MLA for now.
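A minimal sketch of that suggestion; the class and field names here are illustrative, not the PR's actual config classes:

    # Sketch only: just the shape of the suggestion, not the actual config.py.
    from dataclasses import dataclass
    from typing import Optional

    @dataclass
    class ModelConfig:
        use_mla: bool = False
        q_lora_rank: Optional[int] = None
        kv_lora_rank: Optional[int] = None
        qk_nope_head_dim: Optional[int] = None
        qk_rope_head_dim: Optional[int] = None
        v_head_dim: Optional[int] = None

        def __post_init__(self):
            # if MLA is requested, make sure all MLA dimensions are actually defined
            if self.use_mla:
                required = ("kv_lora_rank", "qk_nope_head_dim", "qk_rope_head_dim", "v_head_dim")
                missing = [name for name in required if getattr(self, name) is None]
                assert not missing, f"use_mla=True requires these fields to be set: {missing}"

The selection line above would then read attn_cls = MLA if config.use_mla else CausalSelfAttention.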

NouamaneTazi (Member) left a comment:

Left some comments!! Looking nice already.

if self.model_config.kv_lora_rank is not None:
    # set num_key_value_heads to None for MLA (as it's the same as num_attention_heads in the paper)
    # to avoid unintended errors
    self.model_config.num_key_value_heads = None
Member:

Please add a logger.warning here to warn the user.
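Something along these lines, for example; this is a sketch, and whether the codebase calls logger.warning directly or a rank-aware logging helper is an assumption here:

    # Sketch of the requested warning; the exact logging helper is an assumption.
    if self.model_config.kv_lora_rank is not None:
        logger.warning(
            "kv_lora_rank is set: using MLA and forcing num_key_value_heads=None "
            f"(was {self.model_config.num_key_value_heads})"
        )
        # MLA treats num_key_value_heads as num_attention_heads, so reset it to avoid misuse
        self.model_config.num_key_value_heads = None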

Comment on lines +60 to +64
q_lora_rank: Optional[int] = None
kv_lora_rank: Optional[int] = None
qk_nope_head_dim: Optional[int] = None
qk_rope_head_dim: Optional[int] = None
v_head_dim: Optional[int] = None
Member:

We could regroup these in an MLAConfig to keep them separate from the rest, or just follow transformers' config standards.
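For instance, a sketch of the MLAConfig grouping; names and defaults are illustrative only:

    # Sketch only: illustrative grouping, not the PR's actual config classes.
    from dataclasses import dataclass
    from typing import Optional

    @dataclass
    class MLAConfig:
        q_lora_rank: Optional[int] = None
        kv_lora_rank: int = 512
        qk_nope_head_dim: int = 128
        qk_rope_head_dim: int = 64
        v_head_dim: int = 128

    @dataclass
    class ModelConfig:
        # ... existing fields ...
        mla: Optional[MLAConfig] = None   # None -> standard attention, set -> MLA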

)

# Initialize linear layers
self.q_down = nn.Linear(self.dim, self.q_lora_rank, bias=False) # Note: this is duplicated across GPUs
Member:

Add a warning comment, please?
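For example, something like the following wording could work; the stated rationale for the replication is an assumption here and should be adjusted to the actual reason:

    # NOTE: q_down is a plain nn.Linear, so its weight is replicated on every GPU
    # rather than being tensor-parallel-sharded; keep this in mind for memory and
    # gradient synchronization. (Rationale wording is a suggestion, adapt as needed.)
    self.q_down = nn.Linear(self.dim, self.q_lora_rank, bias=False)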
