Add MLA #278

Open · wants to merge 10 commits into base: main
5 changes: 4 additions & 1 deletion README.md
@@ -42,8 +42,11 @@ pip install --pre torch --index-url https://download.pytorch.org/whl/nightly/cu1
pip install -e .

# Install dependencies if you want to use the example scripts
-pip install datasets transformers
+pip install datasets transformers datatrove
pip install triton "flash-attn>=2.5.0" --no-build-isolation

# If you want to use pre-commit
pip install pre-commit
```
> [!NOTE]
> If you get `undefined symbol: ncclCommRegister` error you should install torch 2.1.2 instead: `pip install torch==2.1.2 --index-url https://download.pytorch.org/whl/cu121`
105 changes: 105 additions & 0 deletions examples/mla/llama_with_mla.yaml
@@ -0,0 +1,105 @@
checkpoints:
checkpoint_interval: 100000
checkpoints_path: /fsx/haojun/long_context_weights/test
checkpoints_path_is_shared_file_system: false
resume_checkpoint_path: null
load_lr_scheduler: false
load_optimizer: false
save_final_state: true
save_initial_state: false
data_stages:
- data:
dataset:
dataset_folder:
- /fsx/elie_bakouch/data/fw-edu-dedup
num_loading_workers: 0
seed: 8
name: stable phase
start_training_step: 1
general:
benchmark_csv_path: null
consumed_train_samples: null
ignore_sanity_checks: true
project: MLA
run: long_run
seed: 6
step: null
logging:
iteration_step_info_interval: 1
log_level: info
log_level_replica: info
model:
ddp_bucket_cap_mb: 25
dtype: bfloat16
init_method:
std: 0.0006
make_vocab_size_divisible_by: 1
model_config:
bos_token_id: 128000
eos_token_id: 128001
hidden_act: silu
hidden_size: 4096
initializer_range: 0.02
intermediate_size: 8192
is_llama_config: true
max_position_embeddings: 2048
num_attention_heads: 32
num_hidden_layers: 10
num_key_value_heads: 8
pad_token_id: null
pretraining_tp: 2
rms_norm_eps: 1.0e-05
rope_interleaved: false
rope_scaling: null
rope_theta: 500000.0
tie_word_embeddings: true
use_cache: true
vocab_size: 128256
# MLA
q_lora_rank: 1536
kv_lora_rank: 512
qk_nope_head_dim: 128
qk_rope_head_dim: 64
v_head_dim: 192
optimizer:
accumulate_grad_in_fp32: true
clip_grad: 1.0
learning_rate_scheduler:
learning_rate: 0.0005
lr_decay_starting_step: 100000
lr_decay_steps: 80000
lr_decay_style: linear
lr_warmup_steps: 20000
lr_warmup_style: linear
min_decay_lr: 0.00005
optimizer_factory:
adam_beta1: 0.9
adam_beta2: 0.95
adam_eps: 1.0e-08
name: adamW
torch_adam_is_fused: true
weight_decay: 0.01
zero_stage: 0
parallelism:
dp: 8
expert_parallel_size: 1
pp: 1
pp_engine: 1f1b
recompute_layer: false
tp: 1
tp_linear_async_communication: false
tp_mode: ALL_REDUCE
tp_recompute_allgather: false
profiler: null
tokenizer:
tokenizer_max_length: null
tokenizer_name_or_path: meta-llama/Llama-3.1-8B
tokenizer_revision: null
tokens:
batch_accumulation_per_replica: 2
limit_test_batches: 0
limit_val_batches: 0
micro_batch_size: 4
sequence_length: 2048
train_steps: 200000
val_check_interval: -1
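
As a quick cross-check of the MLA fields above (not part of the PR), the dimensions have to satisfy the constraints asserted by the MLA module further down in this diff. A small standalone sketch, with the cache comparison following the description in the DeepSeek papers:

```python
# Standalone sanity arithmetic for the MLA dimensions in this config (illustration only).
num_attention_heads = 32
q_lora_rank, kv_lora_rank = 1536, 512
qk_nope_head_dim, qk_rope_head_dim, v_head_dim = 128, 64, 192

qk_head_dim = qk_nope_head_dim + qk_rope_head_dim       # 128 + 64 = 192
assert qk_head_dim == v_head_dim                        # required by the MLA module below
assert q_lora_rank < num_attention_heads * qk_head_dim  # 1536 < 32 * 192 = 6144

# Rough per-token cache comparison: MLA caches the compressed latent plus the
# decoupled RoPE key, while plain MHA caches full keys and values for every head.
mla_cache = kv_lora_rank + qk_rope_head_dim                    # 576 values per token
mha_cache = num_attention_heads * (qk_head_dim + v_head_dim)   # 12288 values per token
print(f"MLA cache: {mla_cache} vs MHA cache: {mha_cache} values per token")
```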
4 changes: 4 additions & 0 deletions src/nanotron/config/config.py
@@ -228,6 +228,10 @@ def __post_init__(self):
self.dtype = cast_str_to_torch_dtype(self.dtype)

self.model_config._is_using_mup = isinstance(self.init_method, SpectralMupInit)
if self.model_config.kv_lora_rank is not None:
# Set num_key_value_heads to None for MLA (in the paper it is the same as num_attention_heads)
# to avoid unintended errors
self.model_config.num_key_value_heads = None
Member: Please add a logger.warning here to warn the user.


# if self.model_config.max_position_embeddings is None:
# self.model_config.max_position_embeddings = 0
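
A minimal sketch of the warning requested above, written as a drop-in for the block inside `__post_init__` (so it is not standalone). It uses Python's standard logging module; nanotron's own logger helper could be substituted, and the exact wording is hypothetical:

```python
import logging

logger = logging.getLogger(__name__)

# Sketch only: warn the user when MLA silently overrides num_key_value_heads.
if self.model_config.kv_lora_rank is not None:
    if self.model_config.num_key_value_heads is not None:
        logger.warning(
            "kv_lora_rank is set, so MLA will be used; num_key_value_heads=%s "
            "is ignored and reset to None.",
            self.model_config.num_key_value_heads,
        )
    self.model_config.num_key_value_heads = None
```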
8 changes: 8 additions & 0 deletions src/nanotron/config/models_config.py
@@ -55,6 +55,14 @@ class LlamaConfig:
use_cache: bool = True
vocab_size: int = 32000

# MLA
# https://huggingface.co/deepseek-ai/DeepSeek-V3-Base/blob/main/config.json
q_lora_rank: Optional[int] = None
kv_lora_rank: Optional[int] = None
qk_nope_head_dim: Optional[int] = None
qk_rope_head_dim: Optional[int] = None
v_head_dim: Optional[int] = None
Comment on lines +60 to +64
Member: We could group these into an MLAConfig to keep them separate from the rest, or just follow transformers' config conventions.


def __post_init__(self):
# NOTE: users don't set self._init_method; ModelArgs will set it
# then we only pass LlamaConfig around
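
One possible shape for the regrouping suggested above; `MLAConfig` and the `mla` field are hypothetical names, not part of this PR:

```python
from dataclasses import dataclass
from typing import Optional


@dataclass
class MLAConfig:
    """Hypothetical container for the MLA-specific fields suggested in the review."""

    q_lora_rank: int
    kv_lora_rank: int
    qk_nope_head_dim: int
    qk_rope_head_dim: int
    v_head_dim: int


@dataclass
class LlamaConfig:
    # ... existing Llama fields ...
    # A single optional sub-config instead of five loose Optional[int] fields:
    # mla is None -> plain causal self-attention, set -> MLA.
    mla: Optional[MLAConfig] = None
```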
159 changes: 158 additions & 1 deletion src/nanotron/models/llama.py
@@ -692,6 +692,162 @@ def forward(
return {"hidden_states": output, "sequence_mask": sequence_mask}


class MLA(nn.Module):
def __init__(
self,
config: LlamaConfig,
parallel_config: Optional[ParallelismArgs],
tp_pg: dist.ProcessGroup,
layer_idx: int,
):
"""
Implementation of DeepSeek's MLA
Section 2.1.1. Multi-Head Latent Attention
DeepSeek-V3 Technical Report
https://arxiv.org/abs/2412.19437
"""
super().__init__()

self.dim = config.hidden_size
self.n_heads = config.num_attention_heads
self.n_local_heads = config.num_attention_heads // tp_pg.size()
self.q_lora_rank = config.q_lora_rank
self.kv_lora_rank = config.kv_lora_rank
self.qk_nope_head_dim = config.qk_nope_head_dim
self.qk_rope_head_dim = config.qk_rope_head_dim
self.qk_head_dim = config.qk_nope_head_dim + config.qk_rope_head_dim
self.v_head_dim = config.v_head_dim

# tp related
tp_mode = parallel_config.tp_mode if parallel_config is not None else TensorParallelLinearMode.ALL_REDUCE
self.tp_mode = tp_mode
tp_linear_async_communication = (
parallel_config.tp_linear_async_communication if parallel_config is not None else False
)
q_up_contiguous_chunks = (
self.n_heads * self.qk_nope_head_dim, # shape of q_nope
self.n_heads * self.qk_rope_head_dim, # shape of q_rope
)
kv_up_contiguous_chunks = (
self.n_heads * self.qk_nope_head_dim, # shape of k_nope
self.n_heads * self.v_head_dim, # shape of v
)

assert (
self.n_heads % tp_pg.size() == 0
), f"Number of attention heads ({self.n_heads}) must be divisible by TP size ({tp_pg.size()})."
assert (
self.q_lora_rank < self.n_heads * self.qk_head_dim
), f"q_lora_rank ({self.q_lora_rank}) must be less than the product of the number of attention heads ({self.n_heads}) and the number of query/key head dimensions ({self.qk_head_dim})."
assert tp_mode == TensorParallelLinearMode.ALL_REDUCE, "MLA only supports all-reduce TP mode for now"
# TODO: support different head dimensions for query/key and value
assert (
self.qk_head_dim == self.v_head_dim
), "MLA only supports equal query/key and value head dimensions for now"

# Initialize rotary embedding
self.rotary_embedding = RotaryEmbedding(
dim=self.qk_rope_head_dim, end=config.max_position_embeddings, theta=config.rope_theta
)

# Initialize linear layers
self.q_down = nn.Linear(self.dim, self.q_lora_rank, bias=False) # Note: this is duplicated across GPUs
Member: Please add a warning comment here.

self.q_norm = TritonRMSNorm(self.q_lora_rank, eps=config.rms_norm_eps)
self.q_up = TensorParallelColumnLinear(
self.q_lora_rank,
self.n_heads * self.qk_head_dim,
pg=tp_pg,
mode=tp_mode,
bias=False,
async_communication=tp_linear_async_communication,
contiguous_chunks=q_up_contiguous_chunks,
tp_recompute_allgather=parallel_config.tp_recompute_allgather,
)

self.kv_down = nn.Linear(
self.dim, self.kv_lora_rank + self.qk_rope_head_dim, bias=False
) # Note: this is duplicated across GPUs
self.kv_norm = TritonRMSNorm(self.kv_lora_rank, eps=config.rms_norm_eps)
self.kv_up = TensorParallelColumnLinear(
self.kv_lora_rank,
self.n_heads * (self.qk_nope_head_dim + self.v_head_dim),
pg=tp_pg,
mode=tp_mode,
bias=False,
async_communication=tp_linear_async_communication,
contiguous_chunks=kv_up_contiguous_chunks,
tp_recompute_allgather=parallel_config.tp_recompute_allgather,
)

self.attention = CoreAttention(
config,
parallel_config=parallel_config,
layer_idx=layer_idx,
)

self.o_proj = TensorParallelRowLinear(
self.n_heads * self.v_head_dim,
self.dim,
pg=tp_pg,
mode=tp_mode,
bias=False,
async_communication=tp_linear_async_communication,
)

def forward(
self,
hidden_states, # [seq_length, batch_size, hidden_size]
sequence_mask, # [batch_size, seq_length]
):
seq_len, batch_size, _ = hidden_states.shape

q = self.q_up(self.q_norm(self.q_down(hidden_states)))
q = q.view(seq_len, batch_size, self.n_local_heads, self.qk_head_dim)
q_nope, q_pe = torch.split(
q, [self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1
) # [seq_len, batch_size, n_local_heads, qk_nope_head_dim], [seq_len, batch_size, n_local_heads, qk_rope_head_dim]
q_pe = (
self.rotary_embedding(q_pe.transpose(0, 1), position_ids=None).transpose(0, 1).contiguous()
Member: Is the transpose(0, 1) needed here?
Collaborator (author): Yes, otherwise the results would be different.
Member: I meant: why not transpose once at the beginning of MLA's forward? That way we avoid doing multiple small transposes.
Member: Because transposes are very slow, and there are a lot of them in MLA's forward.
) # [seq_len, batch_size, n_local_heads, qk_rope_head_dim]
q = torch.cat(
[q_nope, q_pe.expand(-1, -1, self.n_local_heads, -1)], dim=-1
) # [seq_len, batch_size, n_heads, qk_head_dim]

kv = self.kv_down(hidden_states) # [seq_len, batch_size, kv_lora_rank + qk_rope_head_dim]
kv, k_pe = torch.split(
kv, [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1
) # [seq_len, batch_size, kv_lora_rank], [seq_len, batch_size, qk_rope_head_dim]
k_pe = (
self.rotary_embedding(k_pe.unsqueeze(2).transpose(0, 1), position_ids=None).transpose(0, 1).contiguous()
) # [seq_len, batch_size, 1, qk_rope_head_dim]
kv = self.kv_up(self.kv_norm(kv)) # [seq_len, batch_size, n_local_heads * (qk_nope_head_dim + v_head_dim)]
kv = kv.view(seq_len, batch_size, self.n_local_heads, self.qk_nope_head_dim + self.v_head_dim)
k_nope, v = torch.split(
kv, [self.qk_nope_head_dim, self.v_head_dim], dim=-1
) # [seq_len, batch_size, n_local_heads, qk_nope_head_dim], [seq_len, batch_size, n_local_heads, v_head_dim]
k = torch.cat(
[k_nope, k_pe.expand(-1, -1, self.n_local_heads, -1)], dim=-1
) # [seq_len, batch_size, n_heads, qk_head_dim]

q = q.transpose(0, 1).contiguous() # [batch_size, seq_len, n_heads, qk_head_dim]
k = k.transpose(0, 1).contiguous() # [batch_size, seq_len, n_heads, qk_head_dim]
v = v.transpose(0, 1).contiguous() # [batch_size, seq_len, n_heads, v_head_dim]

q = q.view(batch_size * seq_len, self.n_local_heads, self.qk_head_dim)
k = k.view(batch_size * seq_len, self.n_local_heads, self.qk_head_dim)
v = v.view(batch_size * seq_len, self.n_local_heads, self.v_head_dim)

output = self.attention(q, k, v, sequence_mask, sequence_mask)

output = (
output.view(batch_size, seq_len, self.n_local_heads * self.v_head_dim).transpose(0, 1).contiguous()
) # [seq_len, batch_size, n_local_heads * v_head_dim]

output = self.o_proj(output)

return {"hidden_states": output, "sequence_mask": sequence_mask}

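For orientation, the factorization this module implements, roughly in the notation of Section 2.1.1 of the DeepSeek-V3 report (note the code fuses the paper's separate down/rope projections into single q_down and kv_down matrices):

$$
c^{Q} = W^{DQ} h, \qquad [\,q^{C};\, q^{R}\,] = W^{UQ}\,\mathrm{RMSNorm}(c^{Q}), \qquad q^{R} \leftarrow \mathrm{RoPE}(q^{R})
$$

$$
[\,c^{KV};\, k^{R}\,] = W^{DKV} h, \qquad [\,k^{C};\, v\,] = W^{UKV}\,\mathrm{RMSNorm}(c^{KV}), \qquad k^{R} \leftarrow \mathrm{RoPE}(k^{R})
$$

$$
q = [\,q^{C};\, q^{R}\,], \qquad k = [\,k^{C};\, k^{R}\,], \qquad \mathrm{out} = W^{O}\,\mathrm{Attention}(q, k, v)
$$

Here q_down/q_up play the role of $W^{DQ}$/$W^{UQ}$, kv_down/kv_up the fused $W^{DKV}$/$W^{UKV}$, and o_proj is $W^{O}$; $k^{R}$ is shared across heads, which is why the forward broadcasts k_pe with expand.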

class LlamaDecoderLayer(nn.Module):
def __init__(
self,
@@ -701,8 +857,9 @@ def __init__(
layer_idx: int,
):
super().__init__()
attn_cls = MLA if config.kv_lora_rank is not None else CausalSelfAttention
Member: I'd rather make this more explicit, e.g. use config.use_mla here, and assert somewhere that the other fields (e.g. kv_lora_rank) are well defined. This can be done in config.py.
Collaborator (author) @zzhhjjj, Mar 6, 2025: It seems a bit redundant to me, since kv_lora_rank being set already implies MLA here, so there's no unexpected behavior.
Member: Sometimes redundancy is fine if it makes the code cleaner! I still think we should have use_mla somewhere, since kv_lora_rank only relates to MLA for now.
self.input_layernorm = TritonRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
-self.attn = CausalSelfAttention(
+self.attn = attn_cls(
config=config,
parallel_config=parallel_config,
tp_pg=tp_pg,
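
Following up on the use_mla discussion above, a sketch of what the more explicit variant could look like; `use_mla` and `_validate_mla` are hypothetical names, not in this PR:

```python
# Sketch: an explicit MLA switch plus validation on LlamaConfig (hypothetical names).
class LlamaConfig:  # illustrative stub; the real dataclass lives in models_config.py
    kv_lora_rank = None  # plus the other MLA fields from the diff above

    @property
    def use_mla(self) -> bool:
        """True when the MLA-specific fields should be used."""
        return self.kv_lora_rank is not None

    def _validate_mla(self) -> None:
        if self.use_mla:
            required = ("q_lora_rank", "kv_lora_rank", "qk_nope_head_dim",
                        "qk_rope_head_dim", "v_head_dim")
            missing = [name for name in required if getattr(self, name, None) is None]
            assert not missing, f"MLA is enabled but these fields are unset: {missing}"
```

The decoder layer could then read `attn_cls = MLA if config.use_mla else CausalSelfAttention`.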
2 changes: 1 addition & 1 deletion src/nanotron/parallel/tensor_parallel/functional.py
@@ -414,7 +414,7 @@ def backward(ctx, grad_output: torch.Tensor):
grad_weight = grad_output.T @ total_input
grad_input = grad_output @ weight
if group.size() == 1:
-sub_grad_input = grad_input
+sub_grad_input = grad_input.reshape(input_size) # [s*b, h_in] -> [s, b, h_in]
else:
# Seems that `reduce_scatter` need contiguous tensors: https://github.com/pytorch/pytorch/blob/2b267fa7f28e18ca6ea1de4201d2541a40411457/torch/distributed/nn/functional.py#L305
# We set grad_input to be contiguous in case it isn't already.
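
For context on this one-line change, a standalone toy example (not nanotron code): the matmuls in this backward appear to work on a flattened [s*b, h] view, so when there is a single TP rank the gradient has to be reshaped back to the input's original [s, b, h] shape before it is returned.

```python
import torch

# Toy shapes standing in for [sequence, batch, hidden_in] and hidden_out.
s, b, h_in, h_out = 4, 2, 8, 16
x = torch.randn(s, b, h_in)
weight = torch.randn(h_out, h_in)

total_input = x.reshape(-1, h_in)         # [s*b, h_in], as used in the forward pass
grad_output = torch.randn(s * b, h_out)   # gradient w.r.t. the flattened output

grad_input = grad_output @ weight         # [s*b, h_in]
sub_grad_input = grad_input.reshape(x.shape)  # back to [s, b, h_in], matching the input
assert sub_grad_input.shape == x.shape
```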
10 changes: 9 additions & 1 deletion src/nanotron/scaling/parametrization.py
@@ -11,7 +11,7 @@
TensorParallelRowLinear,
)
from torch import nn
-from torch.nn import init
+from torch.nn import Linear, init


class ParametrizationMethod(Enum):
@@ -38,11 +38,19 @@ def __init__(self, config: ModelArgs):
TensorParallelRowLinear: self._parametrize_row_linear,
TritonRMSNorm: self._parametrize_layer_norm,
TensorParallelEmbedding: self._parametrize_embedding,
Linear: self._parametrize_linear_layer,
}

self.std = config.init_method.std
self.num_layers = config.model_config.num_hidden_layers

def _parametrize_linear_layer(self, param_name: str, module: nn.Module):
assert param_name in ["weight", "bias"]
if "weight" == param_name:
init.normal_(module.weight, mean=0.0, std=self.std)
elif "bias" == param_name:
module.bias.zero_()

def _parametrize_column_linear(self, param_name: str, module: nn.Module):
assert param_name in ["weight", "bias"]

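For context, a toy illustration (not nanotron code) of what the new handler covers: the plain nn.Linear layers that MLA introduces (q_down, kv_down), which the existing tensor-parallel handlers would otherwise miss. The dimensions below mirror the example config above:

```python
import torch
from torch import nn
from torch.nn import init

std = 0.0006  # matches init_method.std in the example config above

# MLA's down-projections are plain nn.Linear modules (replicated across TP ranks),
# so they need their own entry in the parametrization dispatch table.
q_down = nn.Linear(4096, 1536, bias=False)       # hidden_size -> q_lora_rank
kv_down = nn.Linear(4096, 512 + 64, bias=False)  # hidden_size -> kv_lora_rank + qk_rope_head_dim

for layer in (q_down, kv_down):
    init.normal_(layer.weight, mean=0.0, std=std)

print(q_down.weight.std().item())  # roughly 0.0006
```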