Add MLA #278
base: main
Conversation
Config example:

model:
  ddp_bucket_cap_mb: 25
  dtype: bfloat16
  init_method:
    std: 0.025
  make_vocab_size_divisible_by: 1
  model_config:
    ...
    vocab_size: 50272
    # MLA
    q_lora_rank: 1536
    kv_lora_rank: 512
    qk_nope_head_dim: 128
    qk_rope_head_dim: 64
    v_head_dim: 128
Overall, it looks good, but I recommend adding unit tests that sanity-check MLA with different values of tp_mode, async_communication, and tp_recompute_allgather, and that make sure the output shape of the MLA class is as expected.
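A minimal sketch of such a test, assuming a hypothetical `build_mla` test helper that wires up the model config, parallel config, and TP process group for a single-GPU run, and assuming the module returns a dict like the existing attention block; the real constructor and return format in nanotron may differ:

```python
import pytest
import torch


@pytest.mark.parametrize("tp_mode", ["ALL_REDUCE", "REDUCE_SCATTER"])
@pytest.mark.parametrize("async_communication", [False, True])
@pytest.mark.parametrize("tp_recompute_allgather", [False, True])
def test_mla_output_shape(tp_mode, async_communication, tp_recompute_allgather):
    seq_len, batch_size, hidden_size = 16, 2, 256

    # build_mla is an assumed test helper that constructs the MLA module
    # with a small dummy config and the given tensor-parallel settings.
    mla = build_mla(
        tp_mode=tp_mode,
        async_communication=async_communication,
        tp_recompute_allgather=tp_recompute_allgather,
    )

    hidden_states = torch.randn(seq_len, batch_size, hidden_size)
    sequence_mask = torch.ones(batch_size, seq_len, dtype=torch.bool)

    output = mla(hidden_states=hidden_states, sequence_mask=sequence_mask)

    # The attention block should preserve the [seq_len, batch_size, hidden_size] layout.
    assert output["hidden_states"].shape == (seq_len, batch_size, hidden_size)
```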
LFG
    q, [self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1
)  # [seq_len, batch_size, n_local_heads, qk_nope_head_dim], [seq_len, batch_size, n_local_heads, qk_rope_head_dim]
q_pe = (
    self.rotary_embedding(q_pe.transpose(0, 1), position_ids=None).transpose(0, 1).contiguous()
Is the transpose(0, 1) needed here?
Yes, otherwise the results would be different.
I meant: why not transpose at the beginning of MLA's forward? That way we avoid doing multiple small transposes.
Because transposes are very slow, and you have a lot of them in MLA's forward.
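For reference, a toy illustration of the trade-off being discussed; `toy_rope` is just a stand-in for a batch-first rotary module, not nanotron's actual implementation:

```python
import torch


def toy_rope(x):
    # Stand-in for a rotary embedding that expects batch-first input [batch, seq, heads, dim].
    return x  # identity; only the memory-layout handling matters for this illustration


seq_len, batch, heads, dim = 16, 2, 4, 64
q_pe = torch.randn(seq_len, batch, heads, dim)

# Current pattern: each rope call is wrapped in a transpose pair plus .contiguous(),
# which materializes an extra copy of the tensor every time it appears in forward().
out_per_call = toy_rope(q_pe.transpose(0, 1)).transpose(0, 1).contiguous()

# Suggested alternative: transpose once at the top of forward() and stay batch-first,
# so downstream ops need no per-tensor transposes.
q_pe_batch_first = q_pe.transpose(0, 1).contiguous()  # done once for the whole forward
out_once = toy_rope(q_pe_batch_first)

# Both layouts carry the same values; only the number of transpose/copy ops differs.
assert torch.equal(out_per_call, out_once.transpose(0, 1))
```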
@@ -701,8 +855,9 @@ def __init__(
        layer_idx: int,
    ):
        super().__init__()
        attn_cls = MLA if config.kv_lora_rank is not None else CausalSelfAttention
I'd rather make it more explicit, e.g. use config.use_mla here, and assert somewhere that the other configs (e.g. kv_lora_rank) are well defined. This can be done in config.py.
It seems a bit redundant to me, since kv_lora_rank being set already implies MLA in this case, so there's no unexpected behavior.
Sometimes redundancy is fine if it makes the code cleaner! I still think we should have use_mla somewhere, as kv_lora_rank only relates to MLA for now.
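A sketch of the use_mla flag plus config.py assertion suggested above; the `use_mla` field, the `__post_init__` check, and the `ModelArgs` name are assumptions standing in for the existing config dataclass:

```python
from dataclasses import dataclass
from typing import Optional


@dataclass
class ModelArgs:  # stand-in for the existing model config dataclass
    use_mla: bool = False
    q_lora_rank: Optional[int] = None
    kv_lora_rank: Optional[int] = None
    qk_nope_head_dim: Optional[int] = None
    qk_rope_head_dim: Optional[int] = None
    v_head_dim: Optional[int] = None

    def __post_init__(self):
        if self.use_mla:
            # MLA needs all of its low-rank / head-dim fields to be explicitly set.
            required = ("kv_lora_rank", "qk_nope_head_dim", "qk_rope_head_dim", "v_head_dim")
            missing = [name for name in required if getattr(self, name) is None]
            assert not missing, f"use_mla=True but the following fields are unset: {missing}"
```

The dispatch in the model would then read `attn_cls = MLA if config.use_mla else CausalSelfAttention`, keeping the choice of attention class independent of which MLA hyperparameters happen to be set.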
Left some comments!! Looking nice already.
if self.model_config.kv_lora_rank is not None:
    # set num_key_value_heads to None for MLA (as it's the same as num_attention_heads in the paper)
    # to avoid unintended errors
    self.model_config.num_key_value_heads = None
Please add a logger.warning here to warn the user.
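For instance, a sketch of how the quoted block could read with the warning added; it assumes a module-level `logger` like the one used elsewhere in the file, and the message wording is only illustrative:

```python
if self.model_config.kv_lora_rank is not None:
    # MLA uses num_attention_heads for both keys and values, so num_key_value_heads is ignored.
    logger.warning(
        "kv_lora_rank is set: overriding num_key_value_heads to None, "
        "since MLA does not use grouped key/value heads."
    )
    self.model_config.num_key_value_heads = None
```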
q_lora_rank: Optional[int] = None
kv_lora_rank: Optional[int] = None
qk_nope_head_dim: Optional[int] = None
qk_rope_head_dim: Optional[int] = None
v_head_dim: Optional[int] = None
We could regroup these in an MLAConfig to keep them separate from the rest, or just follow transformers' config standards.
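A possible regrouping, sketched below; the `MLAConfig` wrapper, the `mla` field name, and the `ModelArgs` stand-in are assumptions, while the inner field names come from this PR:

```python
from dataclasses import dataclass
from typing import Optional


@dataclass
class MLAConfig:
    """MLA-specific knobs, kept separate from the rest of the model config."""
    kv_lora_rank: int
    qk_nope_head_dim: int
    qk_rope_head_dim: int
    v_head_dim: int
    q_lora_rank: Optional[int] = None  # optional low-rank projection for queries


@dataclass
class ModelArgs:  # stand-in for the existing model config dataclass
    mla: Optional[MLAConfig] = None  # None -> fall back to CausalSelfAttention
```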
)

# Initialize linear layers
self.q_down = nn.Linear(self.dim, self.q_lora_rank, bias=False)  # Note: this is duplicated across GPUs
Add warning comment please?
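For example, a sketch of the quoted line with the comment spelled out; the wording is only a suggestion:

```python
# WARNING: q_down is a plain nn.Linear, not a tensor-parallel layer, so its weights
# (and their gradients / optimizer states) are replicated on every TP rank.
self.q_down = nn.Linear(self.dim, self.q_lora_rank, bias=False)
```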
Add MLA to Nanotron: