From d7bf8be86a6c9ae1b20ece6f90fcccac57e4f438 Mon Sep 17 00:00:00 2001 From: Phuc Nguyen Date: Wed, 29 Jan 2025 12:47:04 +0000 Subject: [PATCH 01/40] first draft of domino forward pass --- examples/config_tiny_llama_domino.yaml | 113 ++++++++++++++++++ src/nanotron/config/parallelism_config.py | 17 +++ src/nanotron/models/llama.py | 90 +++++++++++--- src/nanotron/optim/gradient_accumulator.py | 4 + .../distributed_differentiable_primitives.py | 21 ++-- .../parallel/tensor_parallel/functional.py | 12 +- src/nanotron/parallel/tensor_parallel/nn.py | 5 +- 7 files changed, 235 insertions(+), 27 deletions(-) create mode 100644 examples/config_tiny_llama_domino.yaml diff --git a/examples/config_tiny_llama_domino.yaml b/examples/config_tiny_llama_domino.yaml new file mode 100644 index 00000000..66e22dbd --- /dev/null +++ b/examples/config_tiny_llama_domino.yaml @@ -0,0 +1,113 @@ +checkpoints: + checkpoint_interval: 10 + checkpoints_path: checkpoints + checkpoints_path_is_shared_file_system: false + resume_checkpoint_path: null + save_initial_state: false +data_stages: +- data: + dataset: + dataset_overwrite_cache: false + dataset_processing_num_proc_per_process: 1 + hf_dataset_config_name: null + hf_dataset_or_datasets: stas/openwebtext-10k + hf_dataset_splits: train + text_column_name: text + num_loading_workers: 1 + seed: 42 + name: Stable Training Stage + start_training_step: 1 +- data: + dataset: + dataset_overwrite_cache: false + dataset_processing_num_proc_per_process: 1 + hf_dataset_config_name: null + hf_dataset_or_datasets: stas/openwebtext-10k + hf_dataset_splits: train + text_column_name: text + num_loading_workers: 1 + seed: 42 + name: Annealing Phase + start_training_step: 10 +general: + benchmark_csv_path: null + consumed_train_samples: null + ignore_sanity_checks: true + project: debug + run: tiny_llama_%date_%jobid + seed: 42 + step: null +lighteval: null +logging: + iteration_step_info_interval: 1 + log_level: info + log_level_replica: info +model: + ddp_bucket_cap_mb: 25 + dtype: bfloat16 + init_method: + std: 0.025 + make_vocab_size_divisible_by: 1 + model_config: + bos_token_id: 1 + eos_token_id: 2 + hidden_act: silu + hidden_size: 16 + initializer_range: 0.02 + intermediate_size: 64 + is_llama_config: true + max_position_embeddings: 256 + num_attention_heads: 4 + num_hidden_layers: 2 + num_key_value_heads: 4 + pad_token_id: null + pretraining_tp: 1 + rms_norm_eps: 1.0e-05 + rope_scaling: null + tie_word_embeddings: true + use_cache: true + vocab_size: 256 +optimizer: + accumulate_grad_in_fp32: true + clip_grad: 1.0 + learning_rate_scheduler: + learning_rate: 0.0003 + lr_decay_starting_step: null + lr_decay_steps: 13 + lr_decay_style: cosine + lr_warmup_steps: 2 + lr_warmup_style: linear + min_decay_lr: 1.0e-05 + optimizer_factory: + adam_beta1: 0.9 + adam_beta2: 0.95 + adam_eps: 1.0e-08 + name: adamW + torch_adam_is_fused: true + weight_decay: 0.01 + zero_stage: 0 +parallelism: + # dp: 2 + # pp: 2 + dp: 1 + pp: 1 + tp: 2 + expert_parallel_size: 1 + pp_engine: 1f1b + tp_linear_async_communication: false + tp_mode: ALL_REDUCE + domino: + num_input_batches: 2 +profiler: null +tokenizer: + tokenizer_max_length: null + tokenizer_name_or_path: robot-test/dummy-tokenizer-wordlevel + tokenizer_revision: null +tokens: + batch_accumulation_per_replica: 1 + limit_test_batches: 0 + limit_val_batches: 0 + micro_batch_size: 2 + sequence_length: 256 + train_steps: 15 + val_check_interval: -1 diff --git a/src/nanotron/config/parallelism_config.py b/src/nanotron/config/parallelism_config.py 
index 7f20ad99..2701bf9c 100644 --- a/src/nanotron/config/parallelism_config.py +++ b/src/nanotron/config/parallelism_config.py @@ -11,6 +11,22 @@ from nanotron.parallel.tensor_parallel.nn import TensorParallelLinearMode +@dataclass +class DominoArgs: + """ + Domino: Eliminating Communication in LLM Training via Generic Tensor Slicing and Overlapping + https://arxiv.org/abs/2409.15241 + """ + + # NOTE: if the number of input batches is 1, + # it's equivalent to non-domino mode + # so if you want to enable domino mode, set this to > 1 + num_input_batches: int + + def __post_init__(self): + assert self.num_input_batches > 1, "In order to enable domino mode, set num_input_batches > 1" + + @dataclass class ParallelismArgs: """Arguments related to TP/PP/DP @@ -37,6 +53,7 @@ class ParallelismArgs: tp_recompute_allgather: bool = True expert_parallel_size: int = 1 + domino: Optional[DominoArgs] = None def __post_init__(self): # Conservative defaults diff --git a/src/nanotron/models/llama.py b/src/nanotron/models/llama.py index 88fb6bcb..bc495624 100644 --- a/src/nanotron/models/llama.py +++ b/src/nanotron/models/llama.py @@ -237,13 +237,14 @@ def __init__( mode=tp_mode, bias=False, async_communication=tp_linear_async_communication and tp_mode is TensorParallelLinearMode.REDUCE_SCATTER, + async_all_reduce=parallel_config.domino.num_input_batches > 1, ) self.split_silu_mul = GLUActivation(config.hidden_act) def forward(self, hidden_states): # [seq_length, batch_size, hidden_dim] merged_states = self.gate_up_proj(hidden_states) - hidden_states = self.down_proj(self.split_silu_mul(merged_states)) - return {"hidden_states": hidden_states} + hidden_states, work = self.down_proj(self.split_silu_mul(merged_states)) + return {"hidden_states": hidden_states, "work": work} class CoreAttention(nn.Module): @@ -335,6 +336,7 @@ def __init__( parallel_config: Optional[ParallelismArgs], tp_pg: dist.ProcessGroup, layer_idx: int, + async_all_reduce: bool = False, ): from flash_attn.layers.rotary import RotaryEmbedding as FlashRotaryEmbedding @@ -418,6 +420,7 @@ def __init__( mode=tp_mode, bias=False, async_communication=tp_linear_async_communication, + async_all_reduce=async_all_reduce, ) self.attention = CoreAttention( @@ -687,9 +690,9 @@ def forward( attention_output = ( attention_output.contiguous().view(batch_size, q_length, self.n_local_q_heads * self.d_v).transpose(0, 1) ) - output = self.o_proj(attention_output) + output, work = self.o_proj(attention_output) - return {"hidden_states": output, "sequence_mask": sequence_mask} + return {"hidden_states": output, "work": work, "sequence_mask": sequence_mask} class LlamaDecoderLayer(nn.Module): @@ -702,36 +705,80 @@ def __init__( ): super().__init__() self.input_layernorm = TritonRMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.attn = CausalSelfAttention( config=config, parallel_config=parallel_config, tp_pg=tp_pg, layer_idx=layer_idx, + async_all_reduce=parallel_config.domino.num_input_batches > 1, ) self.post_attention_layernorm = TritonRMSNorm(config.hidden_size, eps=config.rms_norm_eps) self.mlp = MLP(config=config, parallel_config=parallel_config, tp_pg=tp_pg) self.recompute_layer = parallel_config.recompute_layer + self.parallel_config = parallel_config def _core_forward( self, hidden_states: Union[torch.Tensor, TensorPointer], sequence_mask: Union[torch.Tensor, TensorPointer], ) -> List[Union[torch.Tensor, TensorPointer]]: - residual = hidden_states - hidden_states = self.input_layernorm(hidden_states) - output = 
self.attn(hidden_states=hidden_states, sequence_mask=sequence_mask) - hidden_states = output["hidden_states"] - hidden_states = hidden_states + residual + num_input_batches = self.parallel_config.domino.num_input_batches + assert num_input_batches == 2 + hidden_states = torch.chunk(hidden_states, chunks=num_input_batches, dim=1) + orig_sequence_mask = sequence_mask + sequence_mask = torch.chunk(sequence_mask, chunks=num_input_batches, dim=0) + + hidden_states0, hidden_states1 = hidden_states + sequence_mask0, sequence_mask1 = sequence_mask + + # # Combine the chunks into a list of dictionaries + # hidden_encoder_states_list = [ + # {"hidden_states": hidden_encoder_states["hidden_states"][i], "sequence_mask": hidden_encoder_states["sequence_mask"][i]} + # for i in range(num_input_batches) + # ] + + residual0 = hidden_states0 + residual1 = hidden_states1 + + hidden_states0 = self.input_layernorm(hidden_states0) + hidden_states1 = self.input_layernorm(hidden_states1) + + attn_output0 = self.attn(hidden_states=hidden_states0, sequence_mask=sequence_mask0) + attn_output0_work = attn_output0["work"] + + attn_output1 = self.attn(hidden_states=hidden_states1, sequence_mask=sequence_mask1) + attn_output1_work = attn_output1["work"] + + attn_output0_work.wait() + hidden_states0 = attn_output0["hidden_states"] + hidden_states0 = hidden_states0 + residual0 + residual0 = hidden_states0 + hidden_states0 = self.post_attention_layernorm(hidden_states0) - residual = hidden_states - hidden_states = self.post_attention_layernorm(hidden_states) - hidden_states = self.mlp(hidden_states=hidden_states)["hidden_states"] - hidden_states = hidden_states + residual + mlp_output0 = self.mlp(hidden_states=hidden_states0) - return hidden_states, output["sequence_mask"] + attn_output1_work.wait() + hidden_states1 = attn_output1["hidden_states"] + hidden_states1 = hidden_states1 + residual1 + residual1 = hidden_states1 + hidden_states1 = self.post_attention_layernorm(hidden_states1) + + mlp_output1 = self.mlp(hidden_states=hidden_states1) + mlp_output0["work"].wait() + mlp_output1["work"].wait() + + hidden_states0 = mlp_output0["hidden_states"] + hidden_states1 = mlp_output1["hidden_states"] + + hidden_states0 = hidden_states0 + residual0 + hidden_states1 = hidden_states1 + residual1 + + hidden_states = torch.cat([hidden_states0, hidden_states1], dim=1) + return hidden_states, orig_sequence_mask def _checkpointed_forward( self, @@ -899,9 +946,24 @@ def forward_with_hidden_states( "hidden_states": output["input_embeds"], "sequence_mask": input_mask, } + + # assert 1 == 1 + # num_input_batches = self.parallel_config.domino.num_input_batches + # hidden_encoder_states["hidden_states"] = torch.chunk(hidden_encoder_states["hidden_states"], chunks=num_input_batches, dim=1) + # hidden_encoder_states["sequence_mask"] = torch.chunk(hidden_encoder_states["sequence_mask"], chunks=num_input_batches, dim=0) + + # # Combine the chunks into a list of dictionaries + # hidden_encoder_states_list = [ + # {"hidden_states": hidden_encoder_states["hidden_states"][i], "sequence_mask": hidden_encoder_states["sequence_mask"][i]} + # for i in range(num_input_batches) + # ] + for encoder_block in self.decoder: hidden_encoder_states = encoder_block(**hidden_encoder_states) + # for hidden_encoder_states in hidden_encoder_states_list: + # hidden_encoder_states = encoder_block(**hidden_encoder_states) + hidden_states = self.final_layer_norm(input=hidden_encoder_states["hidden_states"])["hidden_states"] sharded_logits = 
self.lm_head(x=hidden_states)["logits"] diff --git a/src/nanotron/optim/gradient_accumulator.py b/src/nanotron/optim/gradient_accumulator.py index 2e940744..ba0d5dd0 100644 --- a/src/nanotron/optim/gradient_accumulator.py +++ b/src/nanotron/optim/gradient_accumulator.py @@ -202,6 +202,10 @@ def build_grad_buffers( return fp32_grad_buffers, contiguous_buffer_f32_gradients def backward(self, loss: torch.Tensor): + if isinstance(loss, tuple): + assert 1 == 1 + raise NotImplementedError("Not implemented yet") + result = loss.backward() for name, elt in self.fp32_grad_buffers.items(): diff --git a/src/nanotron/parallel/tensor_parallel/distributed_differentiable_primitives.py b/src/nanotron/parallel/tensor_parallel/distributed_differentiable_primitives.py index bd41347a..0ae9a4de 100644 --- a/src/nanotron/parallel/tensor_parallel/distributed_differentiable_primitives.py +++ b/src/nanotron/parallel/tensor_parallel/distributed_differentiable_primitives.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import Optional +from typing import Optional, Tuple import torch from torch import distributed as torch_dist @@ -32,23 +32,28 @@ def forward(ctx, tensor, group: Optional[ProcessGroup]): @staticmethod def backward(ctx, grad_output): group = ctx.group - return DifferentiableAllReduceSum.apply(grad_output, group), None + return DifferentiableAllReduceSum.apply(grad_output, group, False), None class DifferentiableAllReduceSum(torch.autograd.Function): """All-reduce in a differentiable fashion""" @staticmethod - def forward(ctx, tensor, group: Optional[ProcessGroup]): + def forward( + ctx, tensor, group: Optional[ProcessGroup], async_all_reduce: bool + ) -> Tuple[torch.Tensor, Optional["dist.Work"]]: if group.size() == 1: return tensor - dist.all_reduce(tensor, op=dist.ReduceOp.SUM, group=group) - return tensor + handle = dist.all_reduce(tensor, op=dist.ReduceOp.SUM, group=group, async_op=async_all_reduce) + if async_all_reduce: + return tensor, handle + else: + return tensor, None @staticmethod def backward(ctx, grad_output): - return grad_output, None + return grad_output, None, None class DifferentiableAllGather(torch.autograd.Function): @@ -134,8 +139,8 @@ def differentiable_identity(tensor, group: Optional[ProcessGroup] = None): return DifferentiableIdentity.apply(tensor, group) -def differentiable_all_reduce_sum(tensor, group: Optional[ProcessGroup] = None): - return DifferentiableAllReduceSum.apply(tensor, group) +def differentiable_all_reduce_sum(tensor, group: Optional[ProcessGroup] = None, async_all_reduce: bool = False): + return DifferentiableAllReduceSum.apply(tensor, group, async_all_reduce) def differentiable_all_gather(tensor, group: Optional[ProcessGroup] = None): diff --git a/src/nanotron/parallel/tensor_parallel/functional.py b/src/nanotron/parallel/tensor_parallel/functional.py index e2ee3a29..ffb3e3f9 100644 --- a/src/nanotron/parallel/tensor_parallel/functional.py +++ b/src/nanotron/parallel/tensor_parallel/functional.py @@ -13,7 +13,7 @@ # See the License for the specific language governing permissions and # limitations under the License. 
import math -from typing import Optional +from typing import Optional, Tuple import torch from torch.nn import functional as F @@ -587,18 +587,22 @@ def row_linear( bias: Optional[torch.Tensor], group: dist.ProcessGroup, tp_mode: TensorParallelLinearMode, + # TODO(xrsrke): use less confusing names for these arguments async_communication: bool, -): + async_all_reduce: bool, +) -> Tuple[torch.Tensor, Optional[torch.Future]]: if async_communication: return _RowLinearAsyncCommunication.apply(input, weight, bias, group, tp_mode) out = F.linear(input, weight, bias) if tp_mode is TensorParallelLinearMode.ALL_REDUCE: - out = differentiable_all_reduce_sum(out, group=group) + out, work = differentiable_all_reduce_sum(out, group=group, async_all_reduce=async_all_reduce) elif tp_mode is TensorParallelLinearMode.REDUCE_SCATTER: + assert async_all_reduce is False, "Async communication is not supported for REDUCE_SCATTER mode." out = differentiable_reduce_scatter_sum(out, group=group) + work = None else: raise ValueError(f"Got unexpected mode: {tp_mode}.") - return out + return out, work diff --git a/src/nanotron/parallel/tensor_parallel/nn.py b/src/nanotron/parallel/tensor_parallel/nn.py index 4c7325cd..19cbdf88 100644 --- a/src/nanotron/parallel/tensor_parallel/nn.py +++ b/src/nanotron/parallel/tensor_parallel/nn.py @@ -111,6 +111,7 @@ def __init__( device=None, dtype=None, async_communication: bool = False, + async_all_reduce: bool = False, contiguous_chunks: Optional[Tuple[int, ...]] = None, ): self.pg = pg @@ -133,6 +134,7 @@ def __init__( ) self.mode = mode self.async_communication = async_communication + self.async_all_reduce = async_all_reduce if self.mode is TensorParallelLinearMode.ALL_REDUCE and self.async_communication: raise ValueError("async_communication is not supported for ALL_REDUCE mode") @@ -166,6 +168,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: group=self.pg, tp_mode=self.mode, async_communication=self.async_communication, + async_all_reduce=self.async_all_reduce, ) def extra_repr(self) -> str: @@ -290,7 +293,7 @@ def forward(self, input_ids: torch.Tensor) -> torch.Tensor: out = out * (~input_mask[..., None]) if self.mode is TensorParallelLinearMode.ALL_REDUCE: - out = differentiable_all_reduce_sum(out, group=self.pg) + out, _ = differentiable_all_reduce_sum(out, group=self.pg, async_all_reduce=False) elif self.mode is TensorParallelLinearMode.REDUCE_SCATTER: out = differentiable_reduce_scatter_sum(out, group=self.pg) else: From 3803b1927e4063c2b329857dfbd0497778a235b1 Mon Sep 17 00:00:00 2001 From: Phuc Nguyen Date: Thu, 30 Jan 2025 14:03:36 +0000 Subject: [PATCH 02/40] support the backward pass --- src/nanotron/optim/gradient_accumulator.py | 2 +- src/nanotron/parallel/comm.py | 26 +++++++++++++++++++ .../parallel/pipeline_parallel/engine.py | 8 ++++-- .../distributed_differentiable_primitives.py | 22 +++++++++++++--- .../parallel/tensor_parallel/functional.py | 10 ++++++- src/nanotron/parallel/tensor_parallel/nn.py | 2 +- 6 files changed, 62 insertions(+), 8 deletions(-) create mode 100644 src/nanotron/parallel/comm.py diff --git a/src/nanotron/optim/gradient_accumulator.py b/src/nanotron/optim/gradient_accumulator.py index ba0d5dd0..b5ef7d89 100644 --- a/src/nanotron/optim/gradient_accumulator.py +++ b/src/nanotron/optim/gradient_accumulator.py @@ -202,7 +202,7 @@ def build_grad_buffers( return fp32_grad_buffers, contiguous_buffer_f32_gradients def backward(self, loss: torch.Tensor): - if isinstance(loss, tuple): + if not isinstance(loss, torch.Tensor): assert 1 == 1 
raise NotImplementedError("Not implemented yet") diff --git a/src/nanotron/parallel/comm.py b/src/nanotron/parallel/comm.py new file mode 100644 index 00000000..76e33e21 --- /dev/null +++ b/src/nanotron/parallel/comm.py @@ -0,0 +1,26 @@ +from typing import Dict + + +class AsyncCommBucket: + """ + + Variable._execution_engine.run_backward( # Calls into the C++ engine to run the backward pass + RuntimeError: expected Variable or None (got tuple) + Variable._execution_engine.run_backward( # Calls into the C++ engine to run the backward pass + RuntimeError: expected Variable or None (got tuple) + """ + + _async_op: Dict[int, "dist.Work"] = {} + + @staticmethod + def add(tensor_id: int, work: "dist.Work"): + AsyncCommBucket._async_op[tensor_id] = work + + @staticmethod + def get(tensor_id: int): + return AsyncCommBucket._async_op.get(tensor_id) + + @staticmethod + def wait(tensor_id: int): + work = AsyncCommBucket._async_op.pop(tensor_id) + work.wait() diff --git a/src/nanotron/parallel/pipeline_parallel/engine.py b/src/nanotron/parallel/pipeline_parallel/engine.py index ca9df312..8160f302 100644 --- a/src/nanotron/parallel/pipeline_parallel/engine.py +++ b/src/nanotron/parallel/pipeline_parallel/engine.py @@ -2,6 +2,9 @@ from typing import Dict, Iterable, Optional, Union import torch +from torch import nn as torch_nn +from torch.nn.parallel import DistributedDataParallel + from nanotron import distributed as dist from nanotron import logging from nanotron.distributed import ProcessGroup @@ -12,8 +15,6 @@ from nanotron.parallel.pipeline_parallel.state import PipelineTrainBatchState from nanotron.parallel.pipeline_parallel.tensor_pointer import TensorPointer from nanotron.utils import ContextManagers -from torch import nn as torch_nn -from torch.nn.parallel import DistributedDataParallel logger = logging.get_logger(__name__) @@ -83,6 +84,9 @@ def backward( if grad_accumulator is None: sum(activations).backward() else: + # if not isinstance(activations, torch.Tensor): + # raise NotImplementedError("Only support sum of tensors for now") + grad_accumulator.backward(sum(activations)) # TODO @nouamane: this fixes interleaved afab but makes 1f1b hang diff --git a/src/nanotron/parallel/tensor_parallel/distributed_differentiable_primitives.py b/src/nanotron/parallel/tensor_parallel/distributed_differentiable_primitives.py index 0ae9a4de..05ade53d 100644 --- a/src/nanotron/parallel/tensor_parallel/distributed_differentiable_primitives.py +++ b/src/nanotron/parallel/tensor_parallel/distributed_differentiable_primitives.py @@ -19,6 +19,7 @@ from nanotron import distributed as dist from nanotron.distributed import ProcessGroup +from nanotron.parallel.comm import AsyncCommBucket class DifferentiableIdentity(torch.autograd.Function): @@ -42,14 +43,29 @@ class DifferentiableAllReduceSum(torch.autograd.Function): def forward( ctx, tensor, group: Optional[ProcessGroup], async_all_reduce: bool ) -> Tuple[torch.Tensor, Optional["dist.Work"]]: + # ctx.mark_non_differentiable(async_all_reduce) + ctx.async_all_reduce = async_all_reduce + if group.size() == 1: return tensor + orig_id = id(tensor) handle = dist.all_reduce(tensor, op=dist.ReduceOp.SUM, group=group, async_op=async_all_reduce) + # if async_all_reduce: + # handle.wait() + new_id = id(tensor) + assert 1 == 1 + assert orig_id == new_id + # if async_all_reduce: + # return tensor, handle + # else: + # return tensor, None if async_all_reduce: - return tensor, handle - else: - return tensor, None + # AsyncCommBucket.add(tensor, handle) + # 
AsyncCommBucket.add(id(tensor), handle) + AsyncCommBucket.add(orig_id, handle) + + return tensor @staticmethod def backward(ctx, grad_output): diff --git a/src/nanotron/parallel/tensor_parallel/functional.py b/src/nanotron/parallel/tensor_parallel/functional.py index ffb3e3f9..b05a50bf 100644 --- a/src/nanotron/parallel/tensor_parallel/functional.py +++ b/src/nanotron/parallel/tensor_parallel/functional.py @@ -597,7 +597,15 @@ def row_linear( out = F.linear(input, weight, bias) if tp_mode is TensorParallelLinearMode.ALL_REDUCE: - out, work = differentiable_all_reduce_sum(out, group=group, async_all_reduce=async_all_reduce) + # out, work = differentiable_all_reduce_sum(out, group=group, async_all_reduce=async_all_reduce) + orig_out_id = id(out) + # NOTE: why the id(out) doesn't match the id(out) before the all_reduce? + out = differentiable_all_reduce_sum(out, group=group, async_all_reduce=async_all_reduce) + if async_all_reduce: + from nanotron.parallel.comm import AsyncCommBucket + + work = AsyncCommBucket.get(orig_out_id) + assert 1 == 1 elif tp_mode is TensorParallelLinearMode.REDUCE_SCATTER: assert async_all_reduce is False, "Async communication is not supported for REDUCE_SCATTER mode." out = differentiable_reduce_scatter_sum(out, group=group) diff --git a/src/nanotron/parallel/tensor_parallel/nn.py b/src/nanotron/parallel/tensor_parallel/nn.py index 19cbdf88..14a7486a 100644 --- a/src/nanotron/parallel/tensor_parallel/nn.py +++ b/src/nanotron/parallel/tensor_parallel/nn.py @@ -293,7 +293,7 @@ def forward(self, input_ids: torch.Tensor) -> torch.Tensor: out = out * (~input_mask[..., None]) if self.mode is TensorParallelLinearMode.ALL_REDUCE: - out, _ = differentiable_all_reduce_sum(out, group=self.pg, async_all_reduce=False) + out = differentiable_all_reduce_sum(out, group=self.pg, async_all_reduce=False) elif self.mode is TensorParallelLinearMode.REDUCE_SCATTER: out = differentiable_reduce_scatter_sum(out, group=self.pg) else: From d765fd57e29a30b0a083012bed6e443a2df72b7f Mon Sep 17 00:00:00 2001 From: Phuc Nguyen Date: Fri, 31 Jan 2025 14:20:32 +0000 Subject: [PATCH 03/40] the first draft for bwd overlapping --- src/nanotron/constants.py | 3 + src/nanotron/helpers.py | 2 + src/nanotron/models/llama.py | 57 ++++++++++++++++--- src/nanotron/parallel/comm.py | 46 +++++++++++++++ .../distributed_differentiable_primitives.py | 27 ++++++--- .../parallel/tensor_parallel/functional.py | 6 +- src/nanotron/parallel/tensor_parallel/nn.py | 5 +- 7 files changed, 127 insertions(+), 19 deletions(-) diff --git a/src/nanotron/constants.py b/src/nanotron/constants.py index 580bd99d..78fd0bb9 100644 --- a/src/nanotron/constants.py +++ b/src/nanotron/constants.py @@ -10,3 +10,6 @@ CHECKPOINT_FILE_NAME = "checkpoint_metadata.json" MODEL_CONFIG_FILE_NAME = "model_config.json" + + +CUDA_STREAMS = {} diff --git a/src/nanotron/helpers.py b/src/nanotron/helpers.py index 73ca3484..7f31d812 100644 --- a/src/nanotron/helpers.py +++ b/src/nanotron/helpers.py @@ -482,7 +482,9 @@ def get_profiler(config: Config): on_trace_ready=on_trace_ready, # record_shapes=True, # profile_memory=True, + with_flops=True, with_stack=True, + with_modules=True, ) else: prof = contextlib.nullcontext() diff --git a/src/nanotron/models/llama.py b/src/nanotron/models/llama.py index bc495624..fb112f8e 100644 --- a/src/nanotron/models/llama.py +++ b/src/nanotron/models/llama.py @@ -30,6 +30,7 @@ from nanotron.nn.activations import ACT2FN from nanotron.nn.layer_norm import TritonRMSNorm from nanotron.parallel import ParallelContext +from 
nanotron.parallel.comm import WaitComm from nanotron.parallel.parameters import NanotronParameter from nanotron.parallel.pipeline_parallel.block import PipelineBlock, TensorPointer from nanotron.parallel.pipeline_parallel.p2p import P2P @@ -46,6 +47,8 @@ logger = logging.get_logger(__name__) +DOMINO_COMM_STREAM = "domino_comm_stream_{}" + class RotaryEmbedding(nn.Module): def __init__(self, dim: int, end: int, theta: float = 10000.0): @@ -241,8 +244,8 @@ def __init__( ) self.split_silu_mul = GLUActivation(config.hidden_act) - def forward(self, hidden_states): # [seq_length, batch_size, hidden_dim] - merged_states = self.gate_up_proj(hidden_states) + def forward(self, hidden_states, handle_idx=None): # [seq_length, batch_size, hidden_dim] + merged_states = self.gate_up_proj(hidden_states, handle_idx) hidden_states, work = self.down_proj(self.split_silu_mul(merged_states)) return {"hidden_states": hidden_states, "work": work} @@ -437,6 +440,7 @@ def forward( self, hidden_states, # [seq_length, batch_size, hidden_size] sequence_mask, # [batch_size, seq_length] + handle_idx=None, ): from flash_attn import bert_padding from flash_attn.flash_attn_interface import ( @@ -445,7 +449,7 @@ def forward( ) qkv_states = self.qkv_proj( - hidden_states + hidden_states, handle_idx=handle_idx ) # [seq_length, batch_size, n_local_q_heads * d_qk + 2 * n_local_kv_heads * d_qk] q_length, batch_size, _ = qkv_states.shape @@ -720,6 +724,18 @@ def __init__( self.recompute_layer = parallel_config.recompute_layer self.parallel_config = parallel_config + # if parallel_config.domino is not None and parallel_config.domino.num_input_batches > 1: + # from nanotron.parallel.comm import CudaStreamManager + # # NOTE: we use different cuda streams for different gpus, so it can overlaps the communication + # CudaStreamManager.create(DOMINO_COMM_STREAM.format(torch.cuda.current_device())) + num_gpus = torch.cuda.device_count() + for i in range(num_gpus): + from nanotron import constants + + constants.CUDA_STREAMS[i] = torch.cuda.Stream(device=torch.device(f"cuda:{i}")) + + self.layer_idx = layer_idx + def _core_forward( self, hidden_states: Union[torch.Tensor, TensorPointer], @@ -747,29 +763,52 @@ def _core_forward( hidden_states0 = self.input_layernorm(hidden_states0) hidden_states1 = self.input_layernorm(hidden_states1) - attn_output0 = self.attn(hidden_states=hidden_states0, sequence_mask=sequence_mask0) + attn_output0 = self.attn( + hidden_states=hidden_states0, sequence_mask=sequence_mask0, handle_idx=f"layer_{self.layer_idx}_batch_0" + ) attn_output0_work = attn_output0["work"] - attn_output1 = self.attn(hidden_states=hidden_states1, sequence_mask=sequence_mask1) + attn_output1 = self.attn( + hidden_states=hidden_states1, sequence_mask=sequence_mask1, handle_idx=f"layer_{self.layer_idx}_batch_1" + ) attn_output1_work = attn_output1["work"] - attn_output0_work.wait() + from nanotron import constants + + comm_stream = constants.CUDA_STREAMS[torch.cuda.current_device()] + # comm_stream = CudaStreamManager.get(DOMINO_COMM_STREAM.format(torch.cuda.current_device())) + with torch.cuda.stream(comm_stream): + attn_output0_work.wait() + # attn_output0_work.wait() + hidden_states0 = attn_output0["hidden_states"] hidden_states0 = hidden_states0 + residual0 residual0 = hidden_states0 hidden_states0 = self.post_attention_layernorm(hidden_states0) + hidden_states0 = WaitComm.apply(hidden_states0, f"layer_{self.layer_idx}_batch_0") + # mlp_output0 = self.mlp(hidden_states=hidden_states0, handle_idx=f"layer_{self.layer_idx}_batch_0") 
mlp_output0 = self.mlp(hidden_states=hidden_states0) - attn_output1_work.wait() + with torch.cuda.stream(comm_stream): + attn_output1_work.wait() + # attn_output1_work.wait() + hidden_states1 = attn_output1["hidden_states"] hidden_states1 = hidden_states1 + residual1 residual1 = hidden_states1 hidden_states1 = self.post_attention_layernorm(hidden_states1) + hidden_states1 = WaitComm.apply(hidden_states1, f"layer_{self.layer_idx}_batch_1") + # mlp_output1 = self.mlp(hidden_states=hidden_states1, handle_idx=f"layer_{self.layer_idx}_batch_1") mlp_output1 = self.mlp(hidden_states=hidden_states1) - mlp_output0["work"].wait() - mlp_output1["work"].wait() + + with torch.cuda.stream(comm_stream): + mlp_output0["work"].wait() + mlp_output1["work"].wait() + + # mlp_output0["work"].wait() + # mlp_output1["work"].wait() hidden_states0 = mlp_output0["hidden_states"] hidden_states1 = mlp_output1["hidden_states"] diff --git a/src/nanotron/parallel/comm.py b/src/nanotron/parallel/comm.py index 76e33e21..e26e2814 100644 --- a/src/nanotron/parallel/comm.py +++ b/src/nanotron/parallel/comm.py @@ -1,5 +1,27 @@ +from contextlib import contextmanager from typing import Dict +import torch + + +class CudaStreamManager: + _streams: Dict[str, "torch.cuda.Stream"] = {} + + @staticmethod + def create(name: str): + assert name not in CudaStreamManager._streams + CudaStreamManager._streams[name] = torch.cuda.Stream() + + @staticmethod + def get(name: str): + return CudaStreamManager._streams.get(name) + + @contextmanager + def run_on_stream(name: str): + stream = CudaStreamManager.get(name) + with torch.cuda.stream(stream): + yield stream + class AsyncCommBucket: """ @@ -14,13 +36,37 @@ class AsyncCommBucket: @staticmethod def add(tensor_id: int, work: "dist.Work"): + assert ( + tensor_id not in AsyncCommBucket._async_op + ), f"tensor_id: {tensor_id}, keys: {AsyncCommBucket._async_op.keys()}" AsyncCommBucket._async_op[tensor_id] = work @staticmethod def get(tensor_id: int): return AsyncCommBucket._async_op.get(tensor_id) + @staticmethod + def pop(tensor_id: int): + return AsyncCommBucket._async_op.pop(tensor_id) + @staticmethod def wait(tensor_id: int): work = AsyncCommBucket._async_op.pop(tensor_id) work.wait() + + +class WaitComm(torch.autograd.Function): + @staticmethod + def forward(ctx, input, wait_handle_idx): + ctx.wait_handle_idx = wait_handle_idx + return input + + @staticmethod + def backward(ctx, grad_output): + import pydevd + + pydevd.settrace(suspend=False, trace_only_current_thread=True) + if ctx.wait_handle_idx != "layer_1_batch_1": + handle = AsyncCommBucket.pop(ctx.wait_handle_idx) + handle.wait() + return grad_output, None diff --git a/src/nanotron/parallel/tensor_parallel/distributed_differentiable_primitives.py b/src/nanotron/parallel/tensor_parallel/distributed_differentiable_primitives.py index 05ade53d..38c6bafd 100644 --- a/src/nanotron/parallel/tensor_parallel/distributed_differentiable_primitives.py +++ b/src/nanotron/parallel/tensor_parallel/distributed_differentiable_primitives.py @@ -26,14 +26,23 @@ class DifferentiableIdentity(torch.autograd.Function): """All-reduce gradients in a differentiable fashion""" @staticmethod - def forward(ctx, tensor, group: Optional[ProcessGroup]): + def forward(ctx, tensor, group: Optional[ProcessGroup], handle_idx=None): + # assert handle_idx is not None + ctx.handle_idx = handle_idx ctx.group = group return tensor @staticmethod def backward(ctx, grad_output): + # import pydevd + # pydevd.settrace(suspend=False, trace_only_current_thread=True) + # NOTE: 
lm_head is TensorParallelColumnLinear, and it doesn't do async + # assert ctx.handle_idx is not None group = ctx.group - return DifferentiableAllReduceSum.apply(grad_output, group, False), None + if ctx.handle_idx is not None: + assert 1 == 1 + + return DifferentiableAllReduceSum.apply(grad_output, group, True, ctx.handle_idx), None, None class DifferentiableAllReduceSum(torch.autograd.Function): @@ -41,7 +50,7 @@ class DifferentiableAllReduceSum(torch.autograd.Function): @staticmethod def forward( - ctx, tensor, group: Optional[ProcessGroup], async_all_reduce: bool + ctx, tensor, group: Optional[ProcessGroup], async_all_reduce: bool, handle_idx: Optional[int] = None ) -> Tuple[torch.Tensor, Optional["dist.Work"]]: # ctx.mark_non_differentiable(async_all_reduce) ctx.async_all_reduce = async_all_reduce @@ -63,13 +72,17 @@ def forward( if async_all_reduce: # AsyncCommBucket.add(tensor, handle) # AsyncCommBucket.add(id(tensor), handle) - AsyncCommBucket.add(orig_id, handle) + # try: + # AsyncCommBucket.add(orig_id if handle_idx is None else handle_idx, handle) + # except Exception as e: + # assert 1 == 1 + AsyncCommBucket.add(orig_id if handle_idx is None else handle_idx, handle) return tensor @staticmethod def backward(ctx, grad_output): - return grad_output, None, None + return grad_output, None, None, None class DifferentiableAllGather(torch.autograd.Function): @@ -151,8 +164,8 @@ def backward(ctx, grad_output): # ----------------- -def differentiable_identity(tensor, group: Optional[ProcessGroup] = None): - return DifferentiableIdentity.apply(tensor, group) +def differentiable_identity(tensor, group: Optional[ProcessGroup] = None, handle_idx=None): + return DifferentiableIdentity.apply(tensor, group, handle_idx) def differentiable_all_reduce_sum(tensor, group: Optional[ProcessGroup] = None, async_all_reduce: bool = False): diff --git a/src/nanotron/parallel/tensor_parallel/functional.py b/src/nanotron/parallel/tensor_parallel/functional.py index b05a50bf..ff43c98b 100644 --- a/src/nanotron/parallel/tensor_parallel/functional.py +++ b/src/nanotron/parallel/tensor_parallel/functional.py @@ -436,12 +436,13 @@ def column_linear( tp_mode: TensorParallelLinearMode, async_communication: bool, tp_recompute_allgather: bool = True, + handle_idx: Optional[int] = None, ): if async_communication: return _ColumnLinearAsyncCommunication.apply(input, weight, bias, group, tp_mode, tp_recompute_allgather) if tp_mode is TensorParallelLinearMode.ALL_REDUCE: - input = differentiable_identity(input, group=group) + input = differentiable_identity(input, group=group, handle_idx=handle_idx) return F.linear(input, weight, bias) if tp_mode is TensorParallelLinearMode.REDUCE_SCATTER: return _ColumnLinearNoAsyncCommunicationReduceScatterMode.apply( @@ -604,7 +605,8 @@ def row_linear( if async_all_reduce: from nanotron.parallel.comm import AsyncCommBucket - work = AsyncCommBucket.get(orig_out_id) + # work = AsyncCommBucket.get(orig_out_id) + work = AsyncCommBucket.pop(orig_out_id) assert 1 == 1 elif tp_mode is TensorParallelLinearMode.REDUCE_SCATTER: assert async_all_reduce is False, "Async communication is not supported for REDUCE_SCATTER mode." 
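The hunks above implement a recurring overlap pattern: `row_linear` launches `dist.all_reduce(..., async_op=True)`, stores the returned work handle in `AsyncCommBucket` under a string key, and the decoder layer waits on that handle only when the reduced output is actually consumed, so the matmuls of the other micro-batch can proceed while the reduction is in flight. What follows is a minimal, standalone sketch of that pattern, not nanotron's actual API; the single-process "gloo" group, the `tag` strings, and this re-implemented `AsyncCommBucket` are assumptions made only so the example runs on its own.

# A minimal sketch of the async all-reduce + handle-bucket pattern, assuming a
# single-process "gloo" group so it runs standalone; `AsyncCommBucket` and the
# tag strings mirror the names in this patch but are re-implemented here.
import os
import torch
import torch.distributed as dist


class AsyncCommBucket:
    _ops = {}

    @classmethod
    def add(cls, tag: str, work: "dist.Work"):
        assert tag not in cls._ops, f"duplicate async op for tag {tag}"
        cls._ops[tag] = work

    @classmethod
    def pop(cls, tag: str) -> "dist.Work":
        return cls._ops.pop(tag)


def async_all_reduce(tensor: torch.Tensor, tag: str) -> torch.Tensor:
    # Launch the reduction without blocking and remember its handle;
    # the caller waits on it later, right before the result is consumed.
    work = dist.all_reduce(tensor, op=dist.ReduceOp.SUM, async_op=True)
    AsyncCommBucket.add(tag, work)
    return tensor


if __name__ == "__main__":
    os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
    os.environ.setdefault("MASTER_PORT", "29501")
    dist.init_process_group(backend="gloo", rank=0, world_size=1)

    out0 = torch.ones(4)
    async_all_reduce(out0, tag="fwd.layer_mlp_0_batch_0")

    # ... compute for the other micro-batch would overlap the reduction here ...
    out1 = torch.ones(4) * 2

    AsyncCommBucket.pop("fwd.layer_mlp_0_batch_0").wait()  # out0 is now reduced
    print(out0.tolist(), out1.tolist())
    dist.destroy_process_group()

Keying the bucket by an explicit string tag rather than by `id(tensor)` also sidesteps the question raised in the diff's NOTE: the tensor returned from an `autograd.Function.apply` is generally a different Python object from the one passed into `forward`, so its `id()` does not survive the call.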
diff --git a/src/nanotron/parallel/tensor_parallel/nn.py b/src/nanotron/parallel/tensor_parallel/nn.py index 14a7486a..f4ceff63 100644 --- a/src/nanotron/parallel/tensor_parallel/nn.py +++ b/src/nanotron/parallel/tensor_parallel/nn.py @@ -52,6 +52,7 @@ def __init__( async_communication: bool = False, contiguous_chunks: Optional[Tuple[int, ...]] = None, tp_recompute_allgather: bool = True, + # handle_idx: Optional[int] = None, ): self.pg = pg self.world_size = pg.size() @@ -72,6 +73,7 @@ def __init__( self.mode = mode self.async_communication = async_communication + # self.handle_idx = handle_idx if contiguous_chunks is not None: assert ( @@ -85,7 +87,7 @@ def __init__( split_config=split_config, ) - def forward(self, x: torch.Tensor) -> torch.Tensor: + def forward(self, x: torch.Tensor, handle_idx=None) -> torch.Tensor: return column_linear( input=x, weight=self.weight, @@ -94,6 +96,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: tp_mode=self.mode, async_communication=self.async_communication, tp_recompute_allgather=self.tp_recompute_allgather, + handle_idx=handle_idx, ) def extra_repr(self) -> str: From 9924608fccfddcf6bb87548610653407bc581ef5 Mon Sep 17 00:00:00 2001 From: Phuc Nguyen Date: Mon, 3 Feb 2025 09:29:03 +0000 Subject: [PATCH 04/40] add backward pass overlapping --- src/nanotron/models/llama.py | 21 +++++++++++++-------- src/nanotron/parallel/comm.py | 12 +++++++++--- src/nanotron/trainer.py | 4 ++++ 3 files changed, 26 insertions(+), 11 deletions(-) diff --git a/src/nanotron/models/llama.py b/src/nanotron/models/llama.py index fb112f8e..8dbcc661 100644 --- a/src/nanotron/models/llama.py +++ b/src/nanotron/models/llama.py @@ -764,12 +764,16 @@ def _core_forward( hidden_states1 = self.input_layernorm(hidden_states1) attn_output0 = self.attn( - hidden_states=hidden_states0, sequence_mask=sequence_mask0, handle_idx=f"layer_{self.layer_idx}_batch_0" + hidden_states=hidden_states0, + sequence_mask=sequence_mask0, + handle_idx=f"layer_attn_{self.layer_idx}_batch_0", ) attn_output0_work = attn_output0["work"] attn_output1 = self.attn( - hidden_states=hidden_states1, sequence_mask=sequence_mask1, handle_idx=f"layer_{self.layer_idx}_batch_1" + hidden_states=hidden_states1, + sequence_mask=sequence_mask1, + handle_idx=f"layer_attn_{self.layer_idx}_batch_1", ) attn_output1_work = attn_output1["work"] @@ -785,10 +789,11 @@ def _core_forward( hidden_states0 = hidden_states0 + residual0 residual0 = hidden_states0 hidden_states0 = self.post_attention_layernorm(hidden_states0) - hidden_states0 = WaitComm.apply(hidden_states0, f"layer_{self.layer_idx}_batch_0") + hidden_states0 = WaitComm.apply(hidden_states0, f"layer_mlp_{self.layer_idx}_batch_1") - # mlp_output0 = self.mlp(hidden_states=hidden_states0, handle_idx=f"layer_{self.layer_idx}_batch_0") - mlp_output0 = self.mlp(hidden_states=hidden_states0) + mlp_output0 = self.mlp(hidden_states=hidden_states0, handle_idx=f"layer_mlp_{self.layer_idx}_batch_0") + mlp_output0 = WaitComm.apply(mlp_output0, f"layer_mlp_{self.layer_idx}_batch_1") + # mlp_output0 = self.mlp(hidden_states=hidden_states0) with torch.cuda.stream(comm_stream): attn_output1_work.wait() @@ -798,10 +803,10 @@ def _core_forward( hidden_states1 = hidden_states1 + residual1 residual1 = hidden_states1 hidden_states1 = self.post_attention_layernorm(hidden_states1) - hidden_states1 = WaitComm.apply(hidden_states1, f"layer_{self.layer_idx}_batch_1") + # hidden_states1 = WaitComm.apply(hidden_states1, f"layer_{self.layer_idx}_batch_1") - # mlp_output1 = 
self.mlp(hidden_states=hidden_states1, handle_idx=f"layer_{self.layer_idx}_batch_1") - mlp_output1 = self.mlp(hidden_states=hidden_states1) + mlp_output1 = self.mlp(hidden_states=hidden_states1, handle_idx=f"layer_mlp_{self.layer_idx}_batch_1") + # mlp_output1 = self.mlp(hidden_states=hidden_states1) with torch.cuda.stream(comm_stream): mlp_output0["work"].wait() diff --git a/src/nanotron/parallel/comm.py b/src/nanotron/parallel/comm.py index e26e2814..dd99f20a 100644 --- a/src/nanotron/parallel/comm.py +++ b/src/nanotron/parallel/comm.py @@ -54,6 +54,10 @@ def wait(tensor_id: int): work = AsyncCommBucket._async_op.pop(tensor_id) work.wait() + @staticmethod + def clear_all(): + AsyncCommBucket._async_op.clear() + class WaitComm(torch.autograd.Function): @staticmethod @@ -63,10 +67,12 @@ def forward(ctx, input, wait_handle_idx): @staticmethod def backward(ctx, grad_output): - import pydevd + # import pydevd - pydevd.settrace(suspend=False, trace_only_current_thread=True) - if ctx.wait_handle_idx != "layer_1_batch_1": + # pydevd.settrace(suspend=False, trace_only_current_thread=True) + # if ctx.wait_handle_idx != "layer_1_batch_1": + if ctx.wait_handle_idx != "layer_30_batch_1": handle = AsyncCommBucket.pop(ctx.wait_handle_idx) handle.wait() + # assert 1 == 1 return grad_output, None diff --git a/src/nanotron/trainer.py b/src/nanotron/trainer.py index 94b03c6e..52d5df48 100644 --- a/src/nanotron/trainer.py +++ b/src/nanotron/trainer.py @@ -578,6 +578,10 @@ def training_step( self.post_train_step() + from nanotron.parallel.comm import AsyncCommBucket + + AsyncCommBucket.clear_all() + return outputs, loss_avg def validation_step(self, dataloader: Iterator[Dict[str, Union[torch.Tensor, TensorPointer]]]) -> Iterable[Dict]: From d6bc8da4a5df9f35b3345148639a4775b90de568 Mon Sep 17 00:00:00 2001 From: Phuc Nguyen Date: Tue, 4 Feb 2025 14:32:38 +0000 Subject: [PATCH 05/40] fix some ops dont execute in the bwd pass --- examples/config_tiny_llama_domino.yaml | 113 ------------------ src/nanotron/models/llama.py | 49 ++++++-- src/nanotron/parallel/comm.py | 23 +++- .../distributed_differentiable_primitives.py | 84 ++++++++----- .../parallel/tensor_parallel/functional.py | 14 ++- src/nanotron/parallel/tensor_parallel/nn.py | 6 +- 6 files changed, 129 insertions(+), 160 deletions(-) delete mode 100644 examples/config_tiny_llama_domino.yaml diff --git a/examples/config_tiny_llama_domino.yaml b/examples/config_tiny_llama_domino.yaml deleted file mode 100644 index 66e22dbd..00000000 --- a/examples/config_tiny_llama_domino.yaml +++ /dev/null @@ -1,113 +0,0 @@ -checkpoints: - checkpoint_interval: 10 - checkpoints_path: checkpoints - checkpoints_path_is_shared_file_system: false - resume_checkpoint_path: null - save_initial_state: false -data_stages: -- data: - dataset: - dataset_overwrite_cache: false - dataset_processing_num_proc_per_process: 1 - hf_dataset_config_name: null - hf_dataset_or_datasets: stas/openwebtext-10k - hf_dataset_splits: train - text_column_name: text - num_loading_workers: 1 - seed: 42 - name: Stable Training Stage - start_training_step: 1 -- data: - dataset: - dataset_overwrite_cache: false - dataset_processing_num_proc_per_process: 1 - hf_dataset_config_name: null - hf_dataset_or_datasets: stas/openwebtext-10k - hf_dataset_splits: train - text_column_name: text - num_loading_workers: 1 - seed: 42 - name: Annealing Phase - start_training_step: 10 -general: - benchmark_csv_path: null - consumed_train_samples: null - ignore_sanity_checks: true - project: debug - run: 
tiny_llama_%date_%jobid - seed: 42 - step: null -lighteval: null -logging: - iteration_step_info_interval: 1 - log_level: info - log_level_replica: info -model: - ddp_bucket_cap_mb: 25 - dtype: bfloat16 - init_method: - std: 0.025 - make_vocab_size_divisible_by: 1 - model_config: - bos_token_id: 1 - eos_token_id: 2 - hidden_act: silu - hidden_size: 16 - initializer_range: 0.02 - intermediate_size: 64 - is_llama_config: true - max_position_embeddings: 256 - num_attention_heads: 4 - num_hidden_layers: 2 - num_key_value_heads: 4 - pad_token_id: null - pretraining_tp: 1 - rms_norm_eps: 1.0e-05 - rope_scaling: null - tie_word_embeddings: true - use_cache: true - vocab_size: 256 -optimizer: - accumulate_grad_in_fp32: true - clip_grad: 1.0 - learning_rate_scheduler: - learning_rate: 0.0003 - lr_decay_starting_step: null - lr_decay_steps: 13 - lr_decay_style: cosine - lr_warmup_steps: 2 - lr_warmup_style: linear - min_decay_lr: 1.0e-05 - optimizer_factory: - adam_beta1: 0.9 - adam_beta2: 0.95 - adam_eps: 1.0e-08 - name: adamW - torch_adam_is_fused: true - weight_decay: 0.01 - zero_stage: 0 -parallelism: - # dp: 2 - # pp: 2 - dp: 1 - pp: 1 - tp: 2 - expert_parallel_size: 1 - pp_engine: 1f1b - tp_linear_async_communication: false - tp_mode: ALL_REDUCE - domino: - num_input_batches: 2 -profiler: null -tokenizer: - tokenizer_max_length: null - tokenizer_name_or_path: robot-test/dummy-tokenizer-wordlevel - tokenizer_revision: null -tokens: - batch_accumulation_per_replica: 1 - limit_test_batches: 0 - limit_val_batches: 0 - micro_batch_size: 2 - sequence_length: 256 - train_steps: 15 - val_check_interval: -1 diff --git a/src/nanotron/models/llama.py b/src/nanotron/models/llama.py index 8dbcc661..47394240 100644 --- a/src/nanotron/models/llama.py +++ b/src/nanotron/models/llama.py @@ -49,6 +49,11 @@ DOMINO_COMM_STREAM = "domino_comm_stream_{}" +FWD_MLP_HANDLE_IDX = "fwd.layer_mlp_{}_batch_{}" +BWD_MLP_HANDLE_IDX = "bwd.layer_mlp_{}_batch_{}" +FWD_ATTN_HANDLE_IDX = "fwd.layer_attn_{}_batch_{}" +BWD_ATTN_HANDLE_IDX = "bwd.layer_attn_{}_batch_{}" + class RotaryEmbedding(nn.Module): def __init__(self, dim: int, end: int, theta: float = 10000.0): @@ -245,8 +250,8 @@ def __init__( self.split_silu_mul = GLUActivation(config.hidden_act) def forward(self, hidden_states, handle_idx=None): # [seq_length, batch_size, hidden_dim] - merged_states = self.gate_up_proj(hidden_states, handle_idx) - hidden_states, work = self.down_proj(self.split_silu_mul(merged_states)) + merged_states = self.gate_up_proj(hidden_states, async_all_reduce=True, handle_idx=handle_idx) + hidden_states, work = self.down_proj(self.split_silu_mul(merged_states), handle_idx) return {"hidden_states": hidden_states, "work": work} @@ -449,7 +454,7 @@ def forward( ) qkv_states = self.qkv_proj( - hidden_states, handle_idx=handle_idx + hidden_states, async_all_reduce=True, handle_idx=handle_idx ) # [seq_length, batch_size, n_local_q_heads * d_qk + 2 * n_local_kv_heads * d_qk] q_length, batch_size, _ = qkv_states.shape @@ -694,7 +699,7 @@ def forward( attention_output = ( attention_output.contiguous().view(batch_size, q_length, self.n_local_q_heads * self.d_v).transpose(0, 1) ) - output, work = self.o_proj(attention_output) + output, work = self.o_proj(attention_output, handle_idx=handle_idx) return {"hidden_states": output, "work": work, "sequence_mask": sequence_mask} @@ -766,14 +771,26 @@ def _core_forward( attn_output0 = self.attn( hidden_states=hidden_states0, sequence_mask=sequence_mask0, - handle_idx=f"layer_attn_{self.layer_idx}_batch_0", + # 
handle_idx=f"fwd.layer_attn_{self.layer_idx}_batch_0", + handle_idx=FWD_ATTN_HANDLE_IDX.format(self.layer_idx, 0), + ) + attn_output0["hidden_states"] = WaitComm.apply( + attn_output0["hidden_states"], + # f"bwd.layer_attn_{self.layer_idx}_batch_1" + BWD_ATTN_HANDLE_IDX.format(self.layer_idx, 1), ) attn_output0_work = attn_output0["work"] attn_output1 = self.attn( hidden_states=hidden_states1, sequence_mask=sequence_mask1, - handle_idx=f"layer_attn_{self.layer_idx}_batch_1", + # handle_idx=f"fwd.layer_attn_{self.layer_idx}_batch_1", + handle_idx=FWD_ATTN_HANDLE_IDX.format(self.layer_idx, 1), + ) + attn_output1["hidden_states"] = WaitComm.apply( + attn_output1["hidden_states"], + # f"bwd.layer_mlp_{self.layer_idx}_batch_0" + BWD_MLP_HANDLE_IDX.format(self.layer_idx, 0), ) attn_output1_work = attn_output1["work"] @@ -789,10 +806,18 @@ def _core_forward( hidden_states0 = hidden_states0 + residual0 residual0 = hidden_states0 hidden_states0 = self.post_attention_layernorm(hidden_states0) - hidden_states0 = WaitComm.apply(hidden_states0, f"layer_mlp_{self.layer_idx}_batch_1") + # hidden_states0 = WaitComm.apply(hidden_states0, f"bwd.layer_mlp_{self.layer_idx}_batch_0") - mlp_output0 = self.mlp(hidden_states=hidden_states0, handle_idx=f"layer_mlp_{self.layer_idx}_batch_0") - mlp_output0 = WaitComm.apply(mlp_output0, f"layer_mlp_{self.layer_idx}_batch_1") + mlp_output0 = self.mlp( + hidden_states=hidden_states0, + # handle_idx=f"fwd.layer_mlp_{self.layer_idx}_batch_0" + handle_idx=FWD_MLP_HANDLE_IDX.format(self.layer_idx, 0), + ) + mlp_output0["hidden_states"] = WaitComm.apply( + mlp_output0["hidden_states"], + # f"bwd.layer_mlp_{self.layer_idx}_batch_1" + BWD_MLP_HANDLE_IDX.format(self.layer_idx, 1), + ) # mlp_output0 = self.mlp(hidden_states=hidden_states0) with torch.cuda.stream(comm_stream): @@ -805,7 +830,11 @@ def _core_forward( hidden_states1 = self.post_attention_layernorm(hidden_states1) # hidden_states1 = WaitComm.apply(hidden_states1, f"layer_{self.layer_idx}_batch_1") - mlp_output1 = self.mlp(hidden_states=hidden_states1, handle_idx=f"layer_mlp_{self.layer_idx}_batch_1") + mlp_output1 = self.mlp( + hidden_states=hidden_states1, + # handle_idx=f"fwd.layer_mlp_{self.layer_idx}_batch_1" + handle_idx=FWD_MLP_HANDLE_IDX.format(self.layer_idx, 1), + ) # mlp_output1 = self.mlp(hidden_states=hidden_states1) with torch.cuda.stream(comm_stream): diff --git a/src/nanotron/parallel/comm.py b/src/nanotron/parallel/comm.py index dd99f20a..c2b18e05 100644 --- a/src/nanotron/parallel/comm.py +++ b/src/nanotron/parallel/comm.py @@ -47,6 +47,7 @@ def get(tensor_id: int): @staticmethod def pop(tensor_id: int): + assert tensor_id in AsyncCommBucket._async_op, f"tensor_id: {tensor_id}" return AsyncCommBucket._async_op.pop(tensor_id) @staticmethod @@ -59,6 +60,17 @@ def clear_all(): AsyncCommBucket._async_op.clear() +def is_async_comm(x): + import re + + NON_ASYNC_HANDLE_IDX = ["bwd.layer_mlp_{}_batch_1", "bwd.layer_attn_{}_batch_0"] + + patterns = [p.replace("{}", r"\d+") for p in NON_ASYNC_HANDLE_IDX] # Replace {} with regex for numbers + regex = re.compile("^(" + "|".join(patterns) + ")$") # Combine patterns into a single regex + not_async = bool(regex.match(x)) + return not not_async + + class WaitComm(torch.autograd.Function): @staticmethod def forward(ctx, input, wait_handle_idx): @@ -68,11 +80,16 @@ def forward(ctx, input, wait_handle_idx): @staticmethod def backward(ctx, grad_output): # import pydevd - # pydevd.settrace(suspend=False, trace_only_current_thread=True) - # if ctx.wait_handle_idx != 
"layer_1_batch_1": - if ctx.wait_handle_idx != "layer_30_batch_1": + + if "bwd.layer_mlp_1_batch_0" == ctx.wait_handle_idx: + assert 1 == 1 + + # if ctx.wait_handle_idx != "bwd.layer_mlp_1_batch_1": + # if ctx.wait_handle_idx != "layer_30_batch_1": + if is_async_comm(ctx.wait_handle_idx): handle = AsyncCommBucket.pop(ctx.wait_handle_idx) + assert handle is not None handle.wait() # assert 1 == 1 return grad_output, None diff --git a/src/nanotron/parallel/tensor_parallel/distributed_differentiable_primitives.py b/src/nanotron/parallel/tensor_parallel/distributed_differentiable_primitives.py index 38c6bafd..3badfb34 100644 --- a/src/nanotron/parallel/tensor_parallel/distributed_differentiable_primitives.py +++ b/src/nanotron/parallel/tensor_parallel/distributed_differentiable_primitives.py @@ -26,8 +26,9 @@ class DifferentiableIdentity(torch.autograd.Function): """All-reduce gradients in a differentiable fashion""" @staticmethod - def forward(ctx, tensor, group: Optional[ProcessGroup], handle_idx=None): + def forward(ctx, tensor, group: Optional[ProcessGroup], async_all_reduce: bool, handle_idx=None): # assert handle_idx is not None + ctx.async_all_reduce = async_all_reduce ctx.handle_idx = handle_idx ctx.group = group return tensor @@ -39,10 +40,31 @@ def backward(ctx, grad_output): # NOTE: lm_head is TensorParallelColumnLinear, and it doesn't do async # assert ctx.handle_idx is not None group = ctx.group - if ctx.handle_idx is not None: - assert 1 == 1 - return DifferentiableAllReduceSum.apply(grad_output, group, True, ctx.handle_idx), None, None + if ctx.handle_idx is not None and "fwd." in ctx.handle_idx: + handle_idx = ctx.handle_idx.replace("fwd.", "bwd.") + # if "bwd.layer_mlp_1_batch_1" == handle_idx: + # from nanotron.parallel.comm import is_async_comm + # async_all_reduce = is_async_comm(handle_idx) + # else: + # async_all_reduce = ctx.async_all_reduce + from nanotron.parallel.comm import is_async_comm + + async_all_reduce = is_async_comm(handle_idx) + else: + handle_idx = ctx.handle_idx + async_all_reduce = ctx.async_all_reduce + + return DifferentiableAllReduceSum.apply(grad_output, group, async_all_reduce, handle_idx), None, None, None + + +def is_last_batch_of_attn(x): + import re + + pattern = r"layer_attn_\d+_batch_0" + if re.match(pattern, x): + return True + return False class DifferentiableAllReduceSum(torch.autograd.Function): @@ -52,31 +74,33 @@ class DifferentiableAllReduceSum(torch.autograd.Function): def forward( ctx, tensor, group: Optional[ProcessGroup], async_all_reduce: bool, handle_idx: Optional[int] = None ) -> Tuple[torch.Tensor, Optional["dist.Work"]]: - # ctx.mark_non_differentiable(async_all_reduce) ctx.async_all_reduce = async_all_reduce if group.size() == 1: return tensor - orig_id = id(tensor) - handle = dist.all_reduce(tensor, op=dist.ReduceOp.SUM, group=group, async_op=async_all_reduce) - # if async_all_reduce: - # handle.wait() - new_id = id(tensor) - assert 1 == 1 - assert orig_id == new_id - # if async_all_reduce: - # return tensor, handle - # else: - # return tensor, None - if async_all_reduce: - # AsyncCommBucket.add(tensor, handle) - # AsyncCommBucket.add(id(tensor), handle) - # try: - # AsyncCommBucket.add(orig_id if handle_idx is None else handle_idx, handle) - # except Exception as e: - # assert 1 == 1 - AsyncCommBucket.add(orig_id if handle_idx is None else handle_idx, handle) + if handle_idx == "bwd.layer_mlp_1_batch_0": + assert 1 == 1 + + id(tensor) + if async_all_reduce is True: + if isinstance(handle_idx, str): + do_async = 
is_last_batch_of_attn(handle_idx) is False + else: + do_async = async_all_reduce + + handle = dist.all_reduce(tensor, op=dist.ReduceOp.SUM, group=group, async_op=do_async) + if do_async: + # # NOTE: id(tensor) is for the fwd pass, for the bwd pass, we do handle_idx + # if handle_idx is not None and "bwd." in handle_idx: + # AsyncCommBucket.add(orig_id if handle_idx is None else handle_idx, handle) + # else: + # AsyncCommBucket.add(orig_id, handle) + # NOTE: id(tensor) is for the fwd pass, for the bwd pass, we do handle_idx + assert handle_idx is not None + AsyncCommBucket.add(handle_idx, handle) + else: + dist.all_reduce(tensor, op=dist.ReduceOp.SUM, group=group) return tensor @@ -164,12 +188,16 @@ def backward(ctx, grad_output): # ----------------- -def differentiable_identity(tensor, group: Optional[ProcessGroup] = None, handle_idx=None): - return DifferentiableIdentity.apply(tensor, group, handle_idx) +def differentiable_identity( + tensor, group: Optional[ProcessGroup] = None, async_all_reduce: bool = False, handle_idx=None +): + return DifferentiableIdentity.apply(tensor, group, async_all_reduce, handle_idx) -def differentiable_all_reduce_sum(tensor, group: Optional[ProcessGroup] = None, async_all_reduce: bool = False): - return DifferentiableAllReduceSum.apply(tensor, group, async_all_reduce) +def differentiable_all_reduce_sum( + tensor, group: Optional[ProcessGroup] = None, async_all_reduce: bool = False, handle_idx=None +): + return DifferentiableAllReduceSum.apply(tensor, group, async_all_reduce, handle_idx) def differentiable_all_gather(tensor, group: Optional[ProcessGroup] = None): diff --git a/src/nanotron/parallel/tensor_parallel/functional.py b/src/nanotron/parallel/tensor_parallel/functional.py index ff43c98b..a3f1248e 100644 --- a/src/nanotron/parallel/tensor_parallel/functional.py +++ b/src/nanotron/parallel/tensor_parallel/functional.py @@ -436,13 +436,14 @@ def column_linear( tp_mode: TensorParallelLinearMode, async_communication: bool, tp_recompute_allgather: bool = True, + async_all_reduce: bool = False, handle_idx: Optional[int] = None, ): if async_communication: return _ColumnLinearAsyncCommunication.apply(input, weight, bias, group, tp_mode, tp_recompute_allgather) if tp_mode is TensorParallelLinearMode.ALL_REDUCE: - input = differentiable_identity(input, group=group, handle_idx=handle_idx) + input = differentiable_identity(input, group=group, async_all_reduce=async_all_reduce, handle_idx=handle_idx) return F.linear(input, weight, bias) if tp_mode is TensorParallelLinearMode.REDUCE_SCATTER: return _ColumnLinearNoAsyncCommunicationReduceScatterMode.apply( @@ -591,6 +592,7 @@ def row_linear( # TODO(xrsrke): use less confusing names for these arguments async_communication: bool, async_all_reduce: bool, + handle_idx=None, ) -> Tuple[torch.Tensor, Optional[torch.Future]]: if async_communication: return _RowLinearAsyncCommunication.apply(input, weight, bias, group, tp_mode) @@ -599,14 +601,18 @@ def row_linear( if tp_mode is TensorParallelLinearMode.ALL_REDUCE: # out, work = differentiable_all_reduce_sum(out, group=group, async_all_reduce=async_all_reduce) - orig_out_id = id(out) + id(out) # NOTE: why the id(out) doesn't match the id(out) before the all_reduce? 
- out = differentiable_all_reduce_sum(out, group=group, async_all_reduce=async_all_reduce) + out = differentiable_all_reduce_sum(out, group=group, async_all_reduce=async_all_reduce, handle_idx=handle_idx) if async_all_reduce: from nanotron.parallel.comm import AsyncCommBucket # work = AsyncCommBucket.get(orig_out_id) - work = AsyncCommBucket.pop(orig_out_id) + # work = AsyncCommBucket.pop(orig_out_id) + if handle_idx == "fwd.layer_mlp_1_batch_0": + assert 1 == 1 + + work = AsyncCommBucket.pop(handle_idx) assert 1 == 1 elif tp_mode is TensorParallelLinearMode.REDUCE_SCATTER: assert async_all_reduce is False, "Async communication is not supported for REDUCE_SCATTER mode." diff --git a/src/nanotron/parallel/tensor_parallel/nn.py b/src/nanotron/parallel/tensor_parallel/nn.py index f4ceff63..847454fd 100644 --- a/src/nanotron/parallel/tensor_parallel/nn.py +++ b/src/nanotron/parallel/tensor_parallel/nn.py @@ -87,7 +87,7 @@ def __init__( split_config=split_config, ) - def forward(self, x: torch.Tensor, handle_idx=None) -> torch.Tensor: + def forward(self, x: torch.Tensor, async_all_reduce=None, handle_idx=None) -> torch.Tensor: return column_linear( input=x, weight=self.weight, @@ -96,6 +96,7 @@ def forward(self, x: torch.Tensor, handle_idx=None) -> torch.Tensor: tp_mode=self.mode, async_communication=self.async_communication, tp_recompute_allgather=self.tp_recompute_allgather, + async_all_reduce=async_all_reduce, handle_idx=handle_idx, ) @@ -163,7 +164,7 @@ def _mark_all_parameters_in_module_as_sharded(self, split_config: SplitConfig): ) setattr(self, name, new_param) - def forward(self, x: torch.Tensor) -> torch.Tensor: + def forward(self, x: torch.Tensor, handle_idx=None) -> torch.Tensor: return row_linear( input=x, weight=self.weight, @@ -172,6 +173,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: tp_mode=self.mode, async_communication=self.async_communication, async_all_reduce=self.async_all_reduce, + handle_idx=handle_idx, ) def extra_repr(self) -> str: From 93b2f106bb91e4f9328546d77dfc169348eb12b9 Mon Sep 17 00:00:00 2001 From: Phuc Nguyen Date: Wed, 5 Feb 2025 10:49:03 +0000 Subject: [PATCH 06/40] fix can't find an ops in fwd --- src/nanotron/parallel/comm.py | 8 +++++++- .../distributed_differentiable_primitives.py | 11 +++++++---- src/nanotron/parallel/tensor_parallel/functional.py | 6 +++++- 3 files changed, 19 insertions(+), 6 deletions(-) diff --git a/src/nanotron/parallel/comm.py b/src/nanotron/parallel/comm.py index c2b18e05..459ea23d 100644 --- a/src/nanotron/parallel/comm.py +++ b/src/nanotron/parallel/comm.py @@ -63,7 +63,13 @@ def clear_all(): def is_async_comm(x): import re - NON_ASYNC_HANDLE_IDX = ["bwd.layer_mlp_{}_batch_1", "bwd.layer_attn_{}_batch_0"] + NON_ASYNC_HANDLE_IDX = [ + # "fwd.layer_attn_{}_batch_0", + # "fwd.layer_mlp_{}_batch_0", + # "fwd.layer_mlp_{}_batch_1", + "bwd.layer_mlp_{}_batch_1", + "bwd.layer_attn_{}_batch_0", + ] patterns = [p.replace("{}", r"\d+") for p in NON_ASYNC_HANDLE_IDX] # Replace {} with regex for numbers regex = re.compile("^(" + "|".join(patterns) + ")$") # Combine patterns into a single regex diff --git a/src/nanotron/parallel/tensor_parallel/distributed_differentiable_primitives.py b/src/nanotron/parallel/tensor_parallel/distributed_differentiable_primitives.py index 3badfb34..9d65878b 100644 --- a/src/nanotron/parallel/tensor_parallel/distributed_differentiable_primitives.py +++ b/src/nanotron/parallel/tensor_parallel/distributed_differentiable_primitives.py @@ -84,10 +84,13 @@ def forward( id(tensor) if async_all_reduce is 
True: - if isinstance(handle_idx, str): - do_async = is_last_batch_of_attn(handle_idx) is False - else: - do_async = async_all_reduce + # if isinstance(handle_idx, str): + # do_async = is_last_batch_of_attn(handle_idx) is False + # else: + # do_async = async_all_reduce + from nanotron.parallel.comm import is_async_comm + + do_async = is_async_comm(handle_idx) handle = dist.all_reduce(tensor, op=dist.ReduceOp.SUM, group=group, async_op=do_async) if do_async: diff --git a/src/nanotron/parallel/tensor_parallel/functional.py b/src/nanotron/parallel/tensor_parallel/functional.py index a3f1248e..f0ca3a0d 100644 --- a/src/nanotron/parallel/tensor_parallel/functional.py +++ b/src/nanotron/parallel/tensor_parallel/functional.py @@ -603,13 +603,17 @@ def row_linear( # out, work = differentiable_all_reduce_sum(out, group=group, async_all_reduce=async_all_reduce) id(out) # NOTE: why the id(out) doesn't match the id(out) before the all_reduce? + if handle_idx == "fwd.layer_attn_0_batch_0": + assert 1 == 1 + out = differentiable_all_reduce_sum(out, group=group, async_all_reduce=async_all_reduce, handle_idx=handle_idx) if async_all_reduce: from nanotron.parallel.comm import AsyncCommBucket # work = AsyncCommBucket.get(orig_out_id) # work = AsyncCommBucket.pop(orig_out_id) - if handle_idx == "fwd.layer_mlp_1_batch_0": + # if handle_idx == "fwd.layer_mlp_1_batch_0": + if handle_idx == "fwd.layer_attn_0_batch_0": assert 1 == 1 work = AsyncCommBucket.pop(handle_idx) From 31db05dafaf3190c1e3c33b4eb0b87cb9e5f0f04 Mon Sep 17 00:00:00 2001 From: Phuc Nguyen Date: Wed, 5 Feb 2025 16:49:44 +0000 Subject: [PATCH 07/40] partially overlapping bwd pass --- src/nanotron/models/llama.py | 166 ++++++++++++------ src/nanotron/parallel/comm.py | 5 +- .../distributed_differentiable_primitives.py | 6 + src/nanotron/trainer.py | 3 +- 4 files changed, 123 insertions(+), 57 deletions(-) diff --git a/src/nanotron/models/llama.py b/src/nanotron/models/llama.py index 47394240..acbece96 100644 --- a/src/nanotron/models/llama.py +++ b/src/nanotron/models/llama.py @@ -50,9 +50,9 @@ DOMINO_COMM_STREAM = "domino_comm_stream_{}" FWD_MLP_HANDLE_IDX = "fwd.layer_mlp_{}_batch_{}" -BWD_MLP_HANDLE_IDX = "bwd.layer_mlp_{}_batch_{}" FWD_ATTN_HANDLE_IDX = "fwd.layer_attn_{}_batch_{}" BWD_ATTN_HANDLE_IDX = "bwd.layer_attn_{}_batch_{}" +BWD_MLP_HANDLE_IDX = "bwd.layer_mlp_{}_batch_{}" class RotaryEmbedding(nn.Module): @@ -741,116 +741,176 @@ def __init__( self.layer_idx = layer_idx + # def _core_forward( + # self, + # hidden_states: Union[torch.Tensor, TensorPointer], + # sequence_mask: Union[torch.Tensor, TensorPointer], + # ) -> List[Union[torch.Tensor, TensorPointer]]: + # from nanotron import constants + + # num_input_batches = self.parallel_config.domino.num_input_batches + # orig_sequence_mask = sequence_mask + + # assert num_input_batches == 2 + # hidden_states = torch.chunk(hidden_states, chunks=num_input_batches, dim=1) + # sequence_mask = torch.chunk(sequence_mask, chunks=num_input_batches, dim=0) + + # hidden_states0, hidden_states1 = hidden_states + # sequence_mask0, sequence_mask1 = sequence_mask + + # residual0 = hidden_states0 + # residual1 = hidden_states1 + + # hidden_states0 = self.input_layernorm(hidden_states0) + # hidden_states1 = self.input_layernorm(hidden_states1) + + # attn_output0 = self.attn( + # hidden_states=hidden_states0, + # sequence_mask=sequence_mask0, + # handle_idx=FWD_ATTN_HANDLE_IDX.format(self.layer_idx, 0), + # ) + # # attn_output0["hidden_states"] = WaitComm.apply( + # # attn_output0["hidden_states"], + # 
# BWD_ATTN_HANDLE_IDX.format(self.layer_idx, 1), + # # ) + + # attn_output1 = self.attn( + # hidden_states=hidden_states1, + # sequence_mask=sequence_mask1, + # handle_idx=FWD_ATTN_HANDLE_IDX.format(self.layer_idx, 1), + # ) + # # attn_output1["hidden_states"] = WaitComm.apply( + # # attn_output1["hidden_states"], + # # BWD_MLP_HANDLE_IDX.format(self.layer_idx, 0), + # # ) + + # comm_stream = constants.CUDA_STREAMS[torch.cuda.current_device()] + # with torch.cuda.stream(comm_stream): + # attn_output0["work"].wait() + + # hidden_states0 = attn_output0["hidden_states"] + residual0 + # residual0 = hidden_states0 + # hidden_states0 = self.post_attention_layernorm(hidden_states0) + # hidden_states0 = WaitComm.apply( + # hidden_states0, + # BWD_MLP_HANDLE_IDX.format(self.layer_idx, 1), + # ) # new + + # mlp_output0 = self.mlp( + # hidden_states=hidden_states0, + # handle_idx=FWD_MLP_HANDLE_IDX.format(self.layer_idx, 0), + # ) + # # mlp_output0["hidden_states"] = WaitComm.apply( + # # mlp_output0["hidden_states"], + # # BWD_MLP_HANDLE_IDX.format(self.layer_idx, 1), + # # ) + + # with torch.cuda.stream(comm_stream): + # attn_output1["work"].wait() + + # hidden_states1 = attn_output1["hidden_states"] + residual1 + # residual1 = hidden_states1 + # hidden_states1 = self.post_attention_layernorm(hidden_states1) + # hidden_states1 = WaitComm.apply( + # hidden_states1, + # BWD_MLP_HANDLE_IDX.format(self.layer_idx, 0), + # ) + + # mlp_output1 = self.mlp( + # hidden_states=hidden_states1, + # handle_idx=FWD_MLP_HANDLE_IDX.format(self.layer_idx, 1), + # ) + + # with torch.cuda.stream(comm_stream): + # mlp_output0["work"].wait() + # mlp_output1["work"].wait() + + # hidden_states0 = mlp_output0["hidden_states"] + residual0 + # hidden_states1 = mlp_output1["hidden_states"] + residual1 + + # hidden_states = torch.cat([hidden_states0, hidden_states1], dim=1) + # return hidden_states, orig_sequence_mask + def _core_forward( self, hidden_states: Union[torch.Tensor, TensorPointer], sequence_mask: Union[torch.Tensor, TensorPointer], ) -> List[Union[torch.Tensor, TensorPointer]]: + from nanotron import constants num_input_batches = self.parallel_config.domino.num_input_batches + orig_sequence_mask = sequence_mask + assert num_input_batches == 2 hidden_states = torch.chunk(hidden_states, chunks=num_input_batches, dim=1) - orig_sequence_mask = sequence_mask sequence_mask = torch.chunk(sequence_mask, chunks=num_input_batches, dim=0) hidden_states0, hidden_states1 = hidden_states sequence_mask0, sequence_mask1 = sequence_mask - # # Combine the chunks into a list of dictionaries - # hidden_encoder_states_list = [ - # {"hidden_states": hidden_encoder_states["hidden_states"][i], "sequence_mask": hidden_encoder_states["sequence_mask"][i]} - # for i in range(num_input_batches) - # ] - residual0 = hidden_states0 residual1 = hidden_states1 hidden_states0 = self.input_layernorm(hidden_states0) hidden_states1 = self.input_layernorm(hidden_states1) + hidden_states0 = WaitComm.apply( + hidden_states0, + BWD_ATTN_HANDLE_IDX.format(self.layer_idx, 1), + ) + hidden_states1 = WaitComm.apply( + hidden_states1, + BWD_MLP_HANDLE_IDX.format(self.layer_idx, 0), + ) attn_output0 = self.attn( hidden_states=hidden_states0, sequence_mask=sequence_mask0, - # handle_idx=f"fwd.layer_attn_{self.layer_idx}_batch_0", handle_idx=FWD_ATTN_HANDLE_IDX.format(self.layer_idx, 0), ) - attn_output0["hidden_states"] = WaitComm.apply( - attn_output0["hidden_states"], - # f"bwd.layer_attn_{self.layer_idx}_batch_1" - BWD_ATTN_HANDLE_IDX.format(self.layer_idx, 
1), - ) - attn_output0_work = attn_output0["work"] - attn_output1 = self.attn( hidden_states=hidden_states1, sequence_mask=sequence_mask1, - # handle_idx=f"fwd.layer_attn_{self.layer_idx}_batch_1", handle_idx=FWD_ATTN_HANDLE_IDX.format(self.layer_idx, 1), ) - attn_output1["hidden_states"] = WaitComm.apply( - attn_output1["hidden_states"], - # f"bwd.layer_mlp_{self.layer_idx}_batch_0" - BWD_MLP_HANDLE_IDX.format(self.layer_idx, 0), - ) - attn_output1_work = attn_output1["work"] - - from nanotron import constants comm_stream = constants.CUDA_STREAMS[torch.cuda.current_device()] - # comm_stream = CudaStreamManager.get(DOMINO_COMM_STREAM.format(torch.cuda.current_device())) with torch.cuda.stream(comm_stream): - attn_output0_work.wait() - # attn_output0_work.wait() + attn_output0["work"].wait() - hidden_states0 = attn_output0["hidden_states"] - hidden_states0 = hidden_states0 + residual0 + hidden_states0 = attn_output0["hidden_states"] + residual0 residual0 = hidden_states0 hidden_states0 = self.post_attention_layernorm(hidden_states0) - # hidden_states0 = WaitComm.apply(hidden_states0, f"bwd.layer_mlp_{self.layer_idx}_batch_0") + hidden_states0 = WaitComm.apply( + hidden_states0, + BWD_MLP_HANDLE_IDX.format(self.layer_idx, 1), + ) # new mlp_output0 = self.mlp( hidden_states=hidden_states0, - # handle_idx=f"fwd.layer_mlp_{self.layer_idx}_batch_0" handle_idx=FWD_MLP_HANDLE_IDX.format(self.layer_idx, 0), ) - mlp_output0["hidden_states"] = WaitComm.apply( - mlp_output0["hidden_states"], - # f"bwd.layer_mlp_{self.layer_idx}_batch_1" - BWD_MLP_HANDLE_IDX.format(self.layer_idx, 1), - ) - # mlp_output0 = self.mlp(hidden_states=hidden_states0) with torch.cuda.stream(comm_stream): - attn_output1_work.wait() - # attn_output1_work.wait() + attn_output1["work"].wait() - hidden_states1 = attn_output1["hidden_states"] - hidden_states1 = hidden_states1 + residual1 + hidden_states1 = attn_output1["hidden_states"] + residual1 residual1 = hidden_states1 hidden_states1 = self.post_attention_layernorm(hidden_states1) - # hidden_states1 = WaitComm.apply(hidden_states1, f"layer_{self.layer_idx}_batch_1") mlp_output1 = self.mlp( hidden_states=hidden_states1, - # handle_idx=f"fwd.layer_mlp_{self.layer_idx}_batch_1" handle_idx=FWD_MLP_HANDLE_IDX.format(self.layer_idx, 1), ) - # mlp_output1 = self.mlp(hidden_states=hidden_states1) with torch.cuda.stream(comm_stream): mlp_output0["work"].wait() mlp_output1["work"].wait() - # mlp_output0["work"].wait() - # mlp_output1["work"].wait() - - hidden_states0 = mlp_output0["hidden_states"] - hidden_states1 = mlp_output1["hidden_states"] - - hidden_states0 = hidden_states0 + residual0 - hidden_states1 = hidden_states1 + residual1 + hidden_states0 = mlp_output0["hidden_states"] + residual0 + hidden_states1 = mlp_output1["hidden_states"] + residual1 hidden_states = torch.cat([hidden_states0, hidden_states1], dim=1) + assert 1 == 1 return hidden_states, orig_sequence_mask def _checkpointed_forward( diff --git a/src/nanotron/parallel/comm.py b/src/nanotron/parallel/comm.py index 459ea23d..b00f6e9e 100644 --- a/src/nanotron/parallel/comm.py +++ b/src/nanotron/parallel/comm.py @@ -91,11 +91,10 @@ def backward(ctx, grad_output): if "bwd.layer_mlp_1_batch_0" == ctx.wait_handle_idx: assert 1 == 1 - # if ctx.wait_handle_idx != "bwd.layer_mlp_1_batch_1": - # if ctx.wait_handle_idx != "layer_30_batch_1": if is_async_comm(ctx.wait_handle_idx): handle = AsyncCommBucket.pop(ctx.wait_handle_idx) assert handle is not None handle.wait() - # assert 1 == 1 + # assert handle.is_completed() is True + 
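Note: WaitComm is the backward-side counterpart of that bucket. Its forward is a pass-through that only remembers a handle name; its backward pops the matching asynchronous all-reduce and waits on it, so one half-batch's backward communication drains while the other half-batch's backward compute runs. A stripped-down sketch of the idea; _pending and WaitNamedComm are illustrative stand-ins, not nanotron's classes:

import torch

# name -> in-flight dist.Work handle, registered by the backward all-reduce
# of the other half-batch (stand-in for the patch's comm bucket)
_pending = {}

class WaitNamedComm(torch.autograd.Function):
    @staticmethod
    def forward(ctx, tensor, name):
        # No computation in forward: just remember which collective the
        # backward pass must be ordered after.
        ctx.name = name
        return tensor

    @staticmethod
    def backward(ctx, grad_output):
        work = _pending.pop(ctx.name, None)
        if work is not None:
            work.wait()
        # One gradient per forward input: the tensor gets grad_output, the name gets None.
        return grad_output, None

In the forward pass above, each half-batch's activations are tagged with the name of the other half-batch's backward collective, which is what lets that all-reduce stay in flight until its result is actually needed.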
return grad_output, None diff --git a/src/nanotron/parallel/tensor_parallel/distributed_differentiable_primitives.py b/src/nanotron/parallel/tensor_parallel/distributed_differentiable_primitives.py index 9d65878b..c4f69c05 100644 --- a/src/nanotron/parallel/tensor_parallel/distributed_differentiable_primitives.py +++ b/src/nanotron/parallel/tensor_parallel/distributed_differentiable_primitives.py @@ -55,6 +55,9 @@ def backward(ctx, grad_output): handle_idx = ctx.handle_idx async_all_reduce = ctx.async_all_reduce + if handle_idx is not None and "bwd." in handle_idx and async_all_reduce is True: + assert 1 == 1 + return DifferentiableAllReduceSum.apply(grad_output, group, async_all_reduce, handle_idx), None, None, None @@ -94,6 +97,9 @@ def forward( handle = dist.all_reduce(tensor, op=dist.ReduceOp.SUM, group=group, async_op=do_async) if do_async: + if "bwd" in handle_idx: + assert 1 == 1 + # # NOTE: id(tensor) is for the fwd pass, for the bwd pass, we do handle_idx # if handle_idx is not None and "bwd." in handle_idx: # AsyncCommBucket.add(orig_id if handle_idx is None else handle_idx, handle) diff --git a/src/nanotron/trainer.py b/src/nanotron/trainer.py index 52d5df48..96af6b12 100644 --- a/src/nanotron/trainer.py +++ b/src/nanotron/trainer.py @@ -580,7 +580,8 @@ def training_step( from nanotron.parallel.comm import AsyncCommBucket - AsyncCommBucket.clear_all() + assert len(AsyncCommBucket._async_op) == 0, f"AsyncCommBucket._async_op: {AsyncCommBucket._async_op}" + # AsyncCommBucket.clear_all() return outputs, loss_avg From 23f210815e06147cf95fff3fede4641cc1c43101 Mon Sep 17 00:00:00 2001 From: Phuc Nguyen Date: Mon, 10 Feb 2025 16:33:23 +0000 Subject: [PATCH 08/40] fix stream not sync --- src/nanotron/constants.py | 4 + src/nanotron/models/llama.py | 46 ++++---- src/nanotron/parallel/comm.py | 15 ++- src/nanotron/parallel/dependency.py | 102 ++++++++++++++++++ .../distributed_differentiable_primitives.py | 4 + src/nanotron/parallel/tensor_parallel/nn.py | 8 +- src/nanotron/trainer.py | 17 ++- 7 files changed, 172 insertions(+), 24 deletions(-) create mode 100644 src/nanotron/parallel/dependency.py diff --git a/src/nanotron/constants.py b/src/nanotron/constants.py index 78fd0bb9..3fe440a8 100644 --- a/src/nanotron/constants.py +++ b/src/nanotron/constants.py @@ -13,3 +13,7 @@ CUDA_STREAMS = {} + +CLOCK = 0 +_AUTOGRAD_RUNS = [] +_NOT_BWD_ASYNC_OPS = [] diff --git a/src/nanotron/models/llama.py b/src/nanotron/models/llama.py index acbece96..72ebf478 100644 --- a/src/nanotron/models/llama.py +++ b/src/nanotron/models/llama.py @@ -245,13 +245,15 @@ def __init__( mode=tp_mode, bias=False, async_communication=tp_linear_async_communication and tp_mode is TensorParallelLinearMode.REDUCE_SCATTER, - async_all_reduce=parallel_config.domino.num_input_batches > 1, + # async_all_reduce=parallel_config.domino.num_input_batches > 1, ) self.split_silu_mul = GLUActivation(config.hidden_act) def forward(self, hidden_states, handle_idx=None): # [seq_length, batch_size, hidden_dim] merged_states = self.gate_up_proj(hidden_states, async_all_reduce=True, handle_idx=handle_idx) - hidden_states, work = self.down_proj(self.split_silu_mul(merged_states), handle_idx) + hidden_states, work = self.down_proj( + self.split_silu_mul(merged_states), async_all_reduce=True, handle_idx=handle_idx + ) return {"hidden_states": hidden_states, "work": work} @@ -428,7 +430,7 @@ def __init__( mode=tp_mode, bias=False, async_communication=tp_linear_async_communication, - async_all_reduce=async_all_reduce, + # 
async_all_reduce=async_all_reduce, ) self.attention = CoreAttention( @@ -699,7 +701,7 @@ def forward( attention_output = ( attention_output.contiguous().view(batch_size, q_length, self.n_local_q_heads * self.d_v).transpose(0, 1) ) - output, work = self.o_proj(attention_output, handle_idx=handle_idx) + output, work = self.o_proj(attention_output, async_all_reduce=True, handle_idx=handle_idx) return {"hidden_states": output, "work": work, "sequence_mask": sequence_mask} @@ -876,6 +878,7 @@ def _core_forward( comm_stream = constants.CUDA_STREAMS[torch.cuda.current_device()] with torch.cuda.stream(comm_stream): attn_output0["work"].wait() + attn_output0["work"].is_completed() hidden_states0 = attn_output0["hidden_states"] + residual0 residual0 = hidden_states0 @@ -890,8 +893,16 @@ def _core_forward( handle_idx=FWD_MLP_HANDLE_IDX.format(self.layer_idx, 0), ) + # attn_output1["hidden_states"], mlp_output0["hidden_states"] = depend( + # run_after=attn_output1["hidden_states"], + # run_before=mlp_output0["hidden_states"] + # ) + with torch.cuda.stream(comm_stream): attn_output1["work"].wait() + attn_output1["work"].is_completed() + + torch.cuda.current_stream().wait_stream(comm_stream) hidden_states1 = attn_output1["hidden_states"] + residual1 residual1 = hidden_states1 @@ -906,11 +917,24 @@ def _core_forward( mlp_output0["work"].wait() mlp_output1["work"].wait() + mlp_output0["work"].is_completed() + mlp_output1["work"].is_completed() + + torch.cuda.current_stream().wait_stream(comm_stream) + hidden_states0 = mlp_output0["hidden_states"] + residual0 hidden_states1 = mlp_output1["hidden_states"] + residual1 + # hidden_states0, hidden_states1 = depend(run_after=hidden_states0, run_before=hidden_states1) + hidden_states = torch.cat([hidden_states0, hidden_states1], dim=1) assert 1 == 1 + + # assert attn_output0["work"].is_completed() + # assert attn_output1["work"].is_completed() + # assert mlp_output0["work"].is_completed() + # assert mlp_output1["work"].is_completed() + return hidden_states, orig_sequence_mask def _checkpointed_forward( @@ -1080,23 +1104,9 @@ def forward_with_hidden_states( "sequence_mask": input_mask, } - # assert 1 == 1 - # num_input_batches = self.parallel_config.domino.num_input_batches - # hidden_encoder_states["hidden_states"] = torch.chunk(hidden_encoder_states["hidden_states"], chunks=num_input_batches, dim=1) - # hidden_encoder_states["sequence_mask"] = torch.chunk(hidden_encoder_states["sequence_mask"], chunks=num_input_batches, dim=0) - - # # Combine the chunks into a list of dictionaries - # hidden_encoder_states_list = [ - # {"hidden_states": hidden_encoder_states["hidden_states"][i], "sequence_mask": hidden_encoder_states["sequence_mask"][i]} - # for i in range(num_input_batches) - # ] - for encoder_block in self.decoder: hidden_encoder_states = encoder_block(**hidden_encoder_states) - # for hidden_encoder_states in hidden_encoder_states_list: - # hidden_encoder_states = encoder_block(**hidden_encoder_states) - hidden_states = self.final_layer_norm(input=hidden_encoder_states["hidden_states"])["hidden_states"] sharded_logits = self.lm_head(x=hidden_states)["logits"] diff --git a/src/nanotron/parallel/comm.py b/src/nanotron/parallel/comm.py index b00f6e9e..789416c3 100644 --- a/src/nanotron/parallel/comm.py +++ b/src/nanotron/parallel/comm.py @@ -33,6 +33,7 @@ class AsyncCommBucket: """ _async_op: Dict[int, "dist.Work"] = {} + _copy_async_op: Dict[int, "dist.Work"] = {} @staticmethod def add(tensor_id: int, work: "dist.Work"): @@ -40,6 +41,7 @@ def add(tensor_id: int, 
work: "dist.Work"): tensor_id not in AsyncCommBucket._async_op ), f"tensor_id: {tensor_id}, keys: {AsyncCommBucket._async_op.keys()}" AsyncCommBucket._async_op[tensor_id] = work + AsyncCommBucket._copy_async_op[tensor_id] = work @staticmethod def get(tensor_id: int): @@ -58,6 +60,7 @@ def wait(tensor_id: int): @staticmethod def clear_all(): AsyncCommBucket._async_op.clear() + AsyncCommBucket._copy_async_op.clear() def is_async_comm(x): @@ -92,9 +95,19 @@ def backward(ctx, grad_output): assert 1 == 1 if is_async_comm(ctx.wait_handle_idx): + from nanotron.constants import _AUTOGRAD_RUNS + + _AUTOGRAD_RUNS.append(f"wait_{ctx.wait_handle_idx}") handle = AsyncCommBucket.pop(ctx.wait_handle_idx) assert handle is not None handle.wait() - # assert handle.is_completed() is True + # assert handle.is_completed() is True, f"ctx.wait_handle_idx: {ctx.wait_handle_idx}" + else: + + from nanotron import constants + + # if dist.get_rank() == 0: + # constants._NOT_BWD_ASYNC_OPS.append(ctx.wait_handle_idx) + constants._NOT_BWD_ASYNC_OPS.append(ctx.wait_handle_idx) return grad_output, None diff --git a/src/nanotron/parallel/dependency.py b/src/nanotron/parallel/dependency.py new file mode 100644 index 00000000..6a633d8a --- /dev/null +++ b/src/nanotron/parallel/dependency.py @@ -0,0 +1,102 @@ +from typing import Dict, Tuple + +import torch +from torch import Tensor + +_phonies: Dict[Tuple[torch.device, bool], Tensor] = {} + + +def get_phony(device: torch.device, *, requires_grad: bool) -> Tensor: + """Gets a phony. Phony is tensor without space. It is useful to make + arbitrary dependency in a autograd graph because it doesn't require any + gradient accumulation. + + .. note:: + + Phonies for each device are cached. If an autograd function gets a phony + internally, the phony must be detached to be returned. Otherwise, the + autograd engine will mutate the cached phony in-place:: + + class Phonify(torch.autograd.Function): + @staticmethod + def forward(ctx, input): + phony = get_phony(input.device, requires_grad=False) + return phony.detach() # detach() is necessary. 
+ + """ + key = (device, requires_grad) + + try: + phony = _phonies[key] + except KeyError: + with torch.cuda.stream(torch.cuda.default_stream(device)): + phony = torch.empty(0, device=device, requires_grad=requires_grad) + + _phonies[key] = phony + + return phony + + +def fork(input: Tensor) -> Tuple[Tensor, Tensor]: + """Branches out from an autograd lane of the given tensor.""" + if torch.is_grad_enabled() and input.requires_grad: + input, phony = Fork.apply(input) + else: + phony = get_phony(input.device, requires_grad=False) + + return input, phony + + +class Fork(torch.autograd.Function): + @staticmethod + def forward(ctx: "Fork", input: Tensor) -> Tuple[Tensor, Tensor]: # type: ignore + phony = get_phony(input.device, requires_grad=False) + return input, phony.detach() + + @staticmethod + def backward(ctx: "Fork", grad_input: Tensor, grad_grad: Tensor) -> Tensor: # type: ignore + # import pydevd + # pydevd.settrace(suspend=False, trace_only_current_thread=True) + return grad_input + + +def join(input: Tensor, phony: Tensor) -> Tensor: + """Merges two autograd lanes.""" + if torch.is_grad_enabled() and (input.requires_grad or phony.requires_grad): + input = Join.apply(input, phony) + + return input + + +class Join(torch.autograd.Function): + @staticmethod + def forward(ctx: "Join", input: Tensor, phony: Tensor) -> Tensor: # type: ignore + return input + + @staticmethod + def backward(ctx: "Join", grad_input: Tensor) -> Tuple[Tensor, None]: # type: ignore + # import pydevd + # pydevd.settrace(suspend=False, trace_only_current_thread=True) + return grad_input, None + + +# def depend(fork_from, join_to) -> None: +# # Ensure that batches[i-1] is executed after batches[i] in +# # # backpropagation by an explicit dependency. +# # if i != 0: +# # depend(batches[i-1], batches[i]) +# # depend(run_after, run_before) +# fork_from, phony = fork(fork_from) +# join_to = join(join_to, phony) +# return fork_from, join_to + + +def depend(run_after, run_before) -> None: + # Ensure that batches[i-1] is executed after batches[i] in + # # backpropagation by an explicit dependency. + # if i != 0: + # depend(batches[i-1], batches[i]) + # depend(run_after, run_before) + run_after, phony = fork(run_after) + run_before = join(run_before, phony) + return run_after, run_before diff --git a/src/nanotron/parallel/tensor_parallel/distributed_differentiable_primitives.py b/src/nanotron/parallel/tensor_parallel/distributed_differentiable_primitives.py index c4f69c05..58275368 100644 --- a/src/nanotron/parallel/tensor_parallel/distributed_differentiable_primitives.py +++ b/src/nanotron/parallel/tensor_parallel/distributed_differentiable_primitives.py @@ -58,6 +58,10 @@ def backward(ctx, grad_output): if handle_idx is not None and "bwd." 
in handle_idx and async_all_reduce is True: assert 1 == 1 + from nanotron.constants import _AUTOGRAD_RUNS + + _AUTOGRAD_RUNS.append(handle_idx) + return DifferentiableAllReduceSum.apply(grad_output, group, async_all_reduce, handle_idx), None, None, None diff --git a/src/nanotron/parallel/tensor_parallel/nn.py b/src/nanotron/parallel/tensor_parallel/nn.py index 847454fd..4fea1838 100644 --- a/src/nanotron/parallel/tensor_parallel/nn.py +++ b/src/nanotron/parallel/tensor_parallel/nn.py @@ -115,7 +115,7 @@ def __init__( device=None, dtype=None, async_communication: bool = False, - async_all_reduce: bool = False, + # async_all_reduce: bool = False, contiguous_chunks: Optional[Tuple[int, ...]] = None, ): self.pg = pg @@ -138,7 +138,7 @@ def __init__( ) self.mode = mode self.async_communication = async_communication - self.async_all_reduce = async_all_reduce + # self.async_all_reduce = async_all_reduce if self.mode is TensorParallelLinearMode.ALL_REDUCE and self.async_communication: raise ValueError("async_communication is not supported for ALL_REDUCE mode") @@ -164,7 +164,7 @@ def _mark_all_parameters_in_module_as_sharded(self, split_config: SplitConfig): ) setattr(self, name, new_param) - def forward(self, x: torch.Tensor, handle_idx=None) -> torch.Tensor: + def forward(self, x: torch.Tensor, async_all_reduce, handle_idx=None) -> torch.Tensor: return row_linear( input=x, weight=self.weight, @@ -172,7 +172,7 @@ def forward(self, x: torch.Tensor, handle_idx=None) -> torch.Tensor: group=self.pg, tp_mode=self.mode, async_communication=self.async_communication, - async_all_reduce=self.async_all_reduce, + async_all_reduce=async_all_reduce, handle_idx=handle_idx, ) diff --git a/src/nanotron/trainer.py b/src/nanotron/trainer.py index 96af6b12..e58af9f5 100644 --- a/src/nanotron/trainer.py +++ b/src/nanotron/trainer.py @@ -564,6 +564,9 @@ def training_step( self.config, self.parallel_context, self.unwrapped_model, self.grad_accumulator, self.optimizer ) + if dist.get_rank() == 0: + assert 1 == 1 + # Apply gradient self.optimizer.step() self.optimizer.zero_grad() @@ -580,8 +583,20 @@ def training_step( from nanotron.parallel.comm import AsyncCommBucket + # import torch.distributed as dist + + not_finished = [] + for k, v in AsyncCommBucket._copy_async_op.items(): + # assert v.is_completed(), f"AsyncCommBucket._copy_async_op: {AsyncCommBucket._copy_async_op}" + if v.is_completed() is not True: + not_finished.append((k, v)) + + # if dist.get_rank() == 0 and constants._NOT_BWD_ASYNC_OPS: + # assert 1 == 1 + + assert len(not_finished) == 0, f"AsyncCommBucket._copy_async_op: {not_finished}" assert len(AsyncCommBucket._async_op) == 0, f"AsyncCommBucket._async_op: {AsyncCommBucket._async_op}" - # AsyncCommBucket.clear_all() + AsyncCommBucket.clear_all() return outputs, loss_avg From 3e3ae8c414f841da6c90155a914374de7b9eaecd Mon Sep 17 00:00:00 2001 From: Phuc Nguyen Date: Fri, 21 Feb 2025 18:20:57 +0000 Subject: [PATCH 09/40] exp2a1c7c2_like_exp2a1c1_domini_llama3_3b_with_tp8_and_seqlen4096_and_mbs2_and_gbs_300k_and_input_splitting_and_commit_23f2_but_remove_call_is_async_comm_twice_and_keep_not_async_bwd.layer_mlp_1__and_bwd.layer_attn_0 --- src/nanotron/models/llama.py | 105 ------------------ src/nanotron/parallel/comm.py | 4 + .../distributed_differentiable_primitives.py | 64 +---------- 3 files changed, 9 insertions(+), 164 deletions(-) diff --git a/src/nanotron/models/llama.py b/src/nanotron/models/llama.py index 72ebf478..79f75855 100644 --- a/src/nanotron/models/llama.py +++ 
b/src/nanotron/models/llama.py @@ -743,96 +743,6 @@ def __init__( self.layer_idx = layer_idx - # def _core_forward( - # self, - # hidden_states: Union[torch.Tensor, TensorPointer], - # sequence_mask: Union[torch.Tensor, TensorPointer], - # ) -> List[Union[torch.Tensor, TensorPointer]]: - # from nanotron import constants - - # num_input_batches = self.parallel_config.domino.num_input_batches - # orig_sequence_mask = sequence_mask - - # assert num_input_batches == 2 - # hidden_states = torch.chunk(hidden_states, chunks=num_input_batches, dim=1) - # sequence_mask = torch.chunk(sequence_mask, chunks=num_input_batches, dim=0) - - # hidden_states0, hidden_states1 = hidden_states - # sequence_mask0, sequence_mask1 = sequence_mask - - # residual0 = hidden_states0 - # residual1 = hidden_states1 - - # hidden_states0 = self.input_layernorm(hidden_states0) - # hidden_states1 = self.input_layernorm(hidden_states1) - - # attn_output0 = self.attn( - # hidden_states=hidden_states0, - # sequence_mask=sequence_mask0, - # handle_idx=FWD_ATTN_HANDLE_IDX.format(self.layer_idx, 0), - # ) - # # attn_output0["hidden_states"] = WaitComm.apply( - # # attn_output0["hidden_states"], - # # BWD_ATTN_HANDLE_IDX.format(self.layer_idx, 1), - # # ) - - # attn_output1 = self.attn( - # hidden_states=hidden_states1, - # sequence_mask=sequence_mask1, - # handle_idx=FWD_ATTN_HANDLE_IDX.format(self.layer_idx, 1), - # ) - # # attn_output1["hidden_states"] = WaitComm.apply( - # # attn_output1["hidden_states"], - # # BWD_MLP_HANDLE_IDX.format(self.layer_idx, 0), - # # ) - - # comm_stream = constants.CUDA_STREAMS[torch.cuda.current_device()] - # with torch.cuda.stream(comm_stream): - # attn_output0["work"].wait() - - # hidden_states0 = attn_output0["hidden_states"] + residual0 - # residual0 = hidden_states0 - # hidden_states0 = self.post_attention_layernorm(hidden_states0) - # hidden_states0 = WaitComm.apply( - # hidden_states0, - # BWD_MLP_HANDLE_IDX.format(self.layer_idx, 1), - # ) # new - - # mlp_output0 = self.mlp( - # hidden_states=hidden_states0, - # handle_idx=FWD_MLP_HANDLE_IDX.format(self.layer_idx, 0), - # ) - # # mlp_output0["hidden_states"] = WaitComm.apply( - # # mlp_output0["hidden_states"], - # # BWD_MLP_HANDLE_IDX.format(self.layer_idx, 1), - # # ) - - # with torch.cuda.stream(comm_stream): - # attn_output1["work"].wait() - - # hidden_states1 = attn_output1["hidden_states"] + residual1 - # residual1 = hidden_states1 - # hidden_states1 = self.post_attention_layernorm(hidden_states1) - # hidden_states1 = WaitComm.apply( - # hidden_states1, - # BWD_MLP_HANDLE_IDX.format(self.layer_idx, 0), - # ) - - # mlp_output1 = self.mlp( - # hidden_states=hidden_states1, - # handle_idx=FWD_MLP_HANDLE_IDX.format(self.layer_idx, 1), - # ) - - # with torch.cuda.stream(comm_stream): - # mlp_output0["work"].wait() - # mlp_output1["work"].wait() - - # hidden_states0 = mlp_output0["hidden_states"] + residual0 - # hidden_states1 = mlp_output1["hidden_states"] + residual1 - - # hidden_states = torch.cat([hidden_states0, hidden_states1], dim=1) - # return hidden_states, orig_sequence_mask - def _core_forward( self, hidden_states: Union[torch.Tensor, TensorPointer], @@ -892,12 +802,6 @@ def _core_forward( hidden_states=hidden_states0, handle_idx=FWD_MLP_HANDLE_IDX.format(self.layer_idx, 0), ) - - # attn_output1["hidden_states"], mlp_output0["hidden_states"] = depend( - # run_after=attn_output1["hidden_states"], - # run_before=mlp_output0["hidden_states"] - # ) - with torch.cuda.stream(comm_stream): attn_output1["work"].wait() 
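Note: the waits issued under torch.cuda.stream(comm_stream) only order the side stream after the collective; the default compute stream must also be made to wait on comm_stream before it reads the reduced activations, which is what the retained wait_stream calls are for. A small sketch of that pattern; consume_after_comm and its arguments are placeholders, not the patch's functions:

import torch

def consume_after_comm(work, reduced, residual, comm_stream):
    with torch.cuda.stream(comm_stream):
        # For NCCL, wait() does not block the CPU; it only makes the current
        # stream (here: comm_stream) depend on the collective's completion.
        work.wait()
    # Without this, kernels later enqueued on the default stream could read
    # `reduced` before the all-reduce has actually finished.
    torch.cuda.current_stream().wait_stream(comm_stream)
    return reduced + residual  # now safely ordered after the collective

This is the gap the earlier "fix stream not sync" commit closes: waiting from a side stream alone does not synchronize the stream that consumes the result.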
attn_output1["work"].is_completed() @@ -925,16 +829,7 @@ def _core_forward( hidden_states0 = mlp_output0["hidden_states"] + residual0 hidden_states1 = mlp_output1["hidden_states"] + residual1 - # hidden_states0, hidden_states1 = depend(run_after=hidden_states0, run_before=hidden_states1) - hidden_states = torch.cat([hidden_states0, hidden_states1], dim=1) - assert 1 == 1 - - # assert attn_output0["work"].is_completed() - # assert attn_output1["work"].is_completed() - # assert mlp_output0["work"].is_completed() - # assert mlp_output1["work"].is_completed() - return hidden_states, orig_sequence_mask def _checkpointed_forward( diff --git a/src/nanotron/parallel/comm.py b/src/nanotron/parallel/comm.py index 789416c3..3b0e3bf4 100644 --- a/src/nanotron/parallel/comm.py +++ b/src/nanotron/parallel/comm.py @@ -73,6 +73,10 @@ def is_async_comm(x): "bwd.layer_mlp_{}_batch_1", "bwd.layer_attn_{}_batch_0", ] + # NON_ASYNC_HANDLE_IDX = [ + # "fwd.layer_mlp_{}_batch_1", + # "bwd.layer_attn_{}_batch_0", + # ] patterns = [p.replace("{}", r"\d+") for p in NON_ASYNC_HANDLE_IDX] # Replace {} with regex for numbers regex = re.compile("^(" + "|".join(patterns) + ")$") # Combine patterns into a single regex diff --git a/src/nanotron/parallel/tensor_parallel/distributed_differentiable_primitives.py b/src/nanotron/parallel/tensor_parallel/distributed_differentiable_primitives.py index 58275368..11501f97 100644 --- a/src/nanotron/parallel/tensor_parallel/distributed_differentiable_primitives.py +++ b/src/nanotron/parallel/tensor_parallel/distributed_differentiable_primitives.py @@ -27,7 +27,6 @@ class DifferentiableIdentity(torch.autograd.Function): @staticmethod def forward(ctx, tensor, group: Optional[ProcessGroup], async_all_reduce: bool, handle_idx=None): - # assert handle_idx is not None ctx.async_all_reduce = async_all_reduce ctx.handle_idx = handle_idx ctx.group = group @@ -35,45 +34,15 @@ def forward(ctx, tensor, group: Optional[ProcessGroup], async_all_reduce: bool, @staticmethod def backward(ctx, grad_output): - # import pydevd - # pydevd.settrace(suspend=False, trace_only_current_thread=True) - # NOTE: lm_head is TensorParallelColumnLinear, and it doesn't do async - # assert ctx.handle_idx is not None group = ctx.group - if ctx.handle_idx is not None and "fwd." in ctx.handle_idx: - handle_idx = ctx.handle_idx.replace("fwd.", "bwd.") - # if "bwd.layer_mlp_1_batch_1" == handle_idx: - # from nanotron.parallel.comm import is_async_comm - # async_all_reduce = is_async_comm(handle_idx) - # else: - # async_all_reduce = ctx.async_all_reduce - from nanotron.parallel.comm import is_async_comm - - async_all_reduce = is_async_comm(handle_idx) - else: - handle_idx = ctx.handle_idx - async_all_reduce = ctx.async_all_reduce - - if handle_idx is not None and "bwd." 
in handle_idx and async_all_reduce is True: - assert 1 == 1 - - from nanotron.constants import _AUTOGRAD_RUNS - - _AUTOGRAD_RUNS.append(handle_idx) + from nanotron.parallel.comm import is_async_comm + handle_idx = ctx.handle_idx.replace("fwd.", "bwd.") if ctx.handle_idx is not None else None + async_all_reduce = is_async_comm(handle_idx) if handle_idx is not None else ctx.async_all_reduce return DifferentiableAllReduceSum.apply(grad_output, group, async_all_reduce, handle_idx), None, None, None -def is_last_batch_of_attn(x): - import re - - pattern = r"layer_attn_\d+_batch_0" - if re.match(pattern, x): - return True - return False - - class DifferentiableAllReduceSum(torch.autograd.Function): """All-reduce in a differentiable fashion""" @@ -86,32 +55,9 @@ def forward( if group.size() == 1: return tensor - if handle_idx == "bwd.layer_mlp_1_batch_0": - assert 1 == 1 - - id(tensor) if async_all_reduce is True: - # if isinstance(handle_idx, str): - # do_async = is_last_batch_of_attn(handle_idx) is False - # else: - # do_async = async_all_reduce - from nanotron.parallel.comm import is_async_comm - - do_async = is_async_comm(handle_idx) - - handle = dist.all_reduce(tensor, op=dist.ReduceOp.SUM, group=group, async_op=do_async) - if do_async: - if "bwd" in handle_idx: - assert 1 == 1 - - # # NOTE: id(tensor) is for the fwd pass, for the bwd pass, we do handle_idx - # if handle_idx is not None and "bwd." in handle_idx: - # AsyncCommBucket.add(orig_id if handle_idx is None else handle_idx, handle) - # else: - # AsyncCommBucket.add(orig_id, handle) - # NOTE: id(tensor) is for the fwd pass, for the bwd pass, we do handle_idx - assert handle_idx is not None - AsyncCommBucket.add(handle_idx, handle) + handle = dist.all_reduce(tensor, op=dist.ReduceOp.SUM, group=group, async_op=True) + AsyncCommBucket.add(handle_idx, handle) else: dist.all_reduce(tensor, op=dist.ReduceOp.SUM, group=group) From c26148899f2121ec4a189623dfcde0ec9ca71be0 Mon Sep 17 00:00:00 2001 From: Phuc Nguyen Date: Fri, 21 Feb 2025 19:06:36 +0000 Subject: [PATCH 10/40] refactor --- src/nanotron/models/llama.py | 31 ++---- src/nanotron/optim/gradient_accumulator.py | 4 - src/nanotron/parallel/comm.py | 46 ++------ src/nanotron/parallel/dependency.py | 102 ------------------ .../parallel/pipeline_parallel/engine.py | 3 - .../parallel/tensor_parallel/domino.py | 18 ++++ .../parallel/tensor_parallel/functional.py | 21 +--- src/nanotron/parallel/tensor_parallel/nn.py | 4 - src/nanotron/trainer.py | 7 +- 9 files changed, 41 insertions(+), 195 deletions(-) delete mode 100644 src/nanotron/parallel/dependency.py create mode 100644 src/nanotron/parallel/tensor_parallel/domino.py diff --git a/src/nanotron/models/llama.py b/src/nanotron/models/llama.py index 79f75855..585c8acf 100644 --- a/src/nanotron/models/llama.py +++ b/src/nanotron/models/llama.py @@ -30,10 +30,16 @@ from nanotron.nn.activations import ACT2FN from nanotron.nn.layer_norm import TritonRMSNorm from nanotron.parallel import ParallelContext -from nanotron.parallel.comm import WaitComm +from nanotron.parallel.comm import CudaStreamManager, WaitComm from nanotron.parallel.parameters import NanotronParameter from nanotron.parallel.pipeline_parallel.block import PipelineBlock, TensorPointer from nanotron.parallel.pipeline_parallel.p2p import P2P +from nanotron.parallel.tensor_parallel.domino import ( + BWD_ATTN_HANDLE_IDX, + BWD_MLP_HANDLE_IDX, + FWD_ATTN_HANDLE_IDX, + FWD_MLP_HANDLE_IDX, +) from nanotron.parallel.tensor_parallel.functional import sharded_cross_entropy from 
nanotron.parallel.tensor_parallel.nn import ( TensorParallelColumnLinear, @@ -49,11 +55,6 @@ DOMINO_COMM_STREAM = "domino_comm_stream_{}" -FWD_MLP_HANDLE_IDX = "fwd.layer_mlp_{}_batch_{}" -FWD_ATTN_HANDLE_IDX = "fwd.layer_attn_{}_batch_{}" -BWD_ATTN_HANDLE_IDX = "bwd.layer_attn_{}_batch_{}" -BWD_MLP_HANDLE_IDX = "bwd.layer_mlp_{}_batch_{}" - class RotaryEmbedding(nn.Module): def __init__(self, dim: int, end: int, theta: float = 10000.0): @@ -730,17 +731,6 @@ def __init__( self.recompute_layer = parallel_config.recompute_layer self.parallel_config = parallel_config - - # if parallel_config.domino is not None and parallel_config.domino.num_input_batches > 1: - # from nanotron.parallel.comm import CudaStreamManager - # # NOTE: we use different cuda streams for different gpus, so it can overlaps the communication - # CudaStreamManager.create(DOMINO_COMM_STREAM.format(torch.cuda.current_device())) - num_gpus = torch.cuda.device_count() - for i in range(num_gpus): - from nanotron import constants - - constants.CUDA_STREAMS[i] = torch.cuda.Stream(device=torch.device(f"cuda:{i}")) - self.layer_idx = layer_idx def _core_forward( @@ -748,12 +738,12 @@ def _core_forward( hidden_states: Union[torch.Tensor, TensorPointer], sequence_mask: Union[torch.Tensor, TensorPointer], ) -> List[Union[torch.Tensor, TensorPointer]]: - from nanotron import constants - num_input_batches = self.parallel_config.domino.num_input_batches + assert num_input_batches == 2 + orig_sequence_mask = sequence_mask + comm_stream = CudaStreamManager.get(f"comm_stream_{torch.cuda.current_device()}") - assert num_input_batches == 2 hidden_states = torch.chunk(hidden_states, chunks=num_input_batches, dim=1) sequence_mask = torch.chunk(sequence_mask, chunks=num_input_batches, dim=0) @@ -785,7 +775,6 @@ def _core_forward( handle_idx=FWD_ATTN_HANDLE_IDX.format(self.layer_idx, 1), ) - comm_stream = constants.CUDA_STREAMS[torch.cuda.current_device()] with torch.cuda.stream(comm_stream): attn_output0["work"].wait() attn_output0["work"].is_completed() diff --git a/src/nanotron/optim/gradient_accumulator.py b/src/nanotron/optim/gradient_accumulator.py index b5ef7d89..2e940744 100644 --- a/src/nanotron/optim/gradient_accumulator.py +++ b/src/nanotron/optim/gradient_accumulator.py @@ -202,10 +202,6 @@ def build_grad_buffers( return fp32_grad_buffers, contiguous_buffer_f32_gradients def backward(self, loss: torch.Tensor): - if not isinstance(loss, torch.Tensor): - assert 1 == 1 - raise NotImplementedError("Not implemented yet") - result = loss.backward() for name, elt in self.fp32_grad_buffers.items(): diff --git a/src/nanotron/parallel/comm.py b/src/nanotron/parallel/comm.py index 3b0e3bf4..e1aa664c 100644 --- a/src/nanotron/parallel/comm.py +++ b/src/nanotron/parallel/comm.py @@ -3,17 +3,21 @@ import torch +from nanotron.parallel.tensor_parallel.domino import is_async_comm + class CudaStreamManager: _streams: Dict[str, "torch.cuda.Stream"] = {} @staticmethod - def create(name: str): + def create(name: str, device: torch.device = None): assert name not in CudaStreamManager._streams - CudaStreamManager._streams[name] = torch.cuda.Stream() + CudaStreamManager._streams[name] = torch.cuda.Stream(device=device) @staticmethod def get(name: str): + if name not in CudaStreamManager._streams: + CudaStreamManager.create(name) return CudaStreamManager._streams.get(name) @contextmanager @@ -63,27 +67,6 @@ def clear_all(): AsyncCommBucket._copy_async_op.clear() -def is_async_comm(x): - import re - - NON_ASYNC_HANDLE_IDX = [ - # "fwd.layer_attn_{}_batch_0", - 
# "fwd.layer_mlp_{}_batch_0", - # "fwd.layer_mlp_{}_batch_1", - "bwd.layer_mlp_{}_batch_1", - "bwd.layer_attn_{}_batch_0", - ] - # NON_ASYNC_HANDLE_IDX = [ - # "fwd.layer_mlp_{}_batch_1", - # "bwd.layer_attn_{}_batch_0", - # ] - - patterns = [p.replace("{}", r"\d+") for p in NON_ASYNC_HANDLE_IDX] # Replace {} with regex for numbers - regex = re.compile("^(" + "|".join(patterns) + ")$") # Combine patterns into a single regex - not_async = bool(regex.match(x)) - return not not_async - - class WaitComm(torch.autograd.Function): @staticmethod def forward(ctx, input, wait_handle_idx): @@ -92,26 +75,9 @@ def forward(ctx, input, wait_handle_idx): @staticmethod def backward(ctx, grad_output): - # import pydevd - # pydevd.settrace(suspend=False, trace_only_current_thread=True) - - if "bwd.layer_mlp_1_batch_0" == ctx.wait_handle_idx: - assert 1 == 1 - if is_async_comm(ctx.wait_handle_idx): - from nanotron.constants import _AUTOGRAD_RUNS - - _AUTOGRAD_RUNS.append(f"wait_{ctx.wait_handle_idx}") handle = AsyncCommBucket.pop(ctx.wait_handle_idx) assert handle is not None handle.wait() - # assert handle.is_completed() is True, f"ctx.wait_handle_idx: {ctx.wait_handle_idx}" - else: - - from nanotron import constants - - # if dist.get_rank() == 0: - # constants._NOT_BWD_ASYNC_OPS.append(ctx.wait_handle_idx) - constants._NOT_BWD_ASYNC_OPS.append(ctx.wait_handle_idx) return grad_output, None diff --git a/src/nanotron/parallel/dependency.py b/src/nanotron/parallel/dependency.py deleted file mode 100644 index 6a633d8a..00000000 --- a/src/nanotron/parallel/dependency.py +++ /dev/null @@ -1,102 +0,0 @@ -from typing import Dict, Tuple - -import torch -from torch import Tensor - -_phonies: Dict[Tuple[torch.device, bool], Tensor] = {} - - -def get_phony(device: torch.device, *, requires_grad: bool) -> Tensor: - """Gets a phony. Phony is tensor without space. It is useful to make - arbitrary dependency in a autograd graph because it doesn't require any - gradient accumulation. - - .. note:: - - Phonies for each device are cached. If an autograd function gets a phony - internally, the phony must be detached to be returned. Otherwise, the - autograd engine will mutate the cached phony in-place:: - - class Phonify(torch.autograd.Function): - @staticmethod - def forward(ctx, input): - phony = get_phony(input.device, requires_grad=False) - return phony.detach() # detach() is necessary. 
- - """ - key = (device, requires_grad) - - try: - phony = _phonies[key] - except KeyError: - with torch.cuda.stream(torch.cuda.default_stream(device)): - phony = torch.empty(0, device=device, requires_grad=requires_grad) - - _phonies[key] = phony - - return phony - - -def fork(input: Tensor) -> Tuple[Tensor, Tensor]: - """Branches out from an autograd lane of the given tensor.""" - if torch.is_grad_enabled() and input.requires_grad: - input, phony = Fork.apply(input) - else: - phony = get_phony(input.device, requires_grad=False) - - return input, phony - - -class Fork(torch.autograd.Function): - @staticmethod - def forward(ctx: "Fork", input: Tensor) -> Tuple[Tensor, Tensor]: # type: ignore - phony = get_phony(input.device, requires_grad=False) - return input, phony.detach() - - @staticmethod - def backward(ctx: "Fork", grad_input: Tensor, grad_grad: Tensor) -> Tensor: # type: ignore - # import pydevd - # pydevd.settrace(suspend=False, trace_only_current_thread=True) - return grad_input - - -def join(input: Tensor, phony: Tensor) -> Tensor: - """Merges two autograd lanes.""" - if torch.is_grad_enabled() and (input.requires_grad or phony.requires_grad): - input = Join.apply(input, phony) - - return input - - -class Join(torch.autograd.Function): - @staticmethod - def forward(ctx: "Join", input: Tensor, phony: Tensor) -> Tensor: # type: ignore - return input - - @staticmethod - def backward(ctx: "Join", grad_input: Tensor) -> Tuple[Tensor, None]: # type: ignore - # import pydevd - # pydevd.settrace(suspend=False, trace_only_current_thread=True) - return grad_input, None - - -# def depend(fork_from, join_to) -> None: -# # Ensure that batches[i-1] is executed after batches[i] in -# # # backpropagation by an explicit dependency. -# # if i != 0: -# # depend(batches[i-1], batches[i]) -# # depend(run_after, run_before) -# fork_from, phony = fork(fork_from) -# join_to = join(join_to, phony) -# return fork_from, join_to - - -def depend(run_after, run_before) -> None: - # Ensure that batches[i-1] is executed after batches[i] in - # # backpropagation by an explicit dependency. 
- # if i != 0: - # depend(batches[i-1], batches[i]) - # depend(run_after, run_before) - run_after, phony = fork(run_after) - run_before = join(run_before, phony) - return run_after, run_before diff --git a/src/nanotron/parallel/pipeline_parallel/engine.py b/src/nanotron/parallel/pipeline_parallel/engine.py index 8160f302..076943c7 100644 --- a/src/nanotron/parallel/pipeline_parallel/engine.py +++ b/src/nanotron/parallel/pipeline_parallel/engine.py @@ -84,9 +84,6 @@ def backward( if grad_accumulator is None: sum(activations).backward() else: - # if not isinstance(activations, torch.Tensor): - # raise NotImplementedError("Only support sum of tensors for now") - grad_accumulator.backward(sum(activations)) # TODO @nouamane: this fixes interleaved afab but makes 1f1b hang diff --git a/src/nanotron/parallel/tensor_parallel/domino.py b/src/nanotron/parallel/tensor_parallel/domino.py new file mode 100644 index 00000000..0139388b --- /dev/null +++ b/src/nanotron/parallel/tensor_parallel/domino.py @@ -0,0 +1,18 @@ +import re + +FWD_MLP_HANDLE_IDX = "fwd.layer_mlp_{}_batch_{}" +FWD_ATTN_HANDLE_IDX = "fwd.layer_attn_{}_batch_{}" +BWD_ATTN_HANDLE_IDX = "bwd.layer_attn_{}_batch_{}" +BWD_MLP_HANDLE_IDX = "bwd.layer_mlp_{}_batch_{}" + + +def is_async_comm(x): + NON_ASYNC_HANDLE_IDX = [ + "bwd.layer_mlp_{}_batch_1", + "bwd.layer_attn_{}_batch_0", + ] + + patterns = [p.replace("{}", r"\d+") for p in NON_ASYNC_HANDLE_IDX] # Replace {} with regex for numbers + regex = re.compile("^(" + "|".join(patterns) + ")$") # Combine patterns into a single regex + not_async = bool(regex.match(x)) + return not not_async diff --git a/src/nanotron/parallel/tensor_parallel/functional.py b/src/nanotron/parallel/tensor_parallel/functional.py index f0ca3a0d..654a14f3 100644 --- a/src/nanotron/parallel/tensor_parallel/functional.py +++ b/src/nanotron/parallel/tensor_parallel/functional.py @@ -19,6 +19,7 @@ from torch.nn import functional as F import nanotron.distributed as dist +from nanotron.parallel.comm import AsyncCommBucket from nanotron.parallel.tensor_parallel.distributed_differentiable_primitives import ( differentiable_all_reduce_sum, differentiable_identity, @@ -40,7 +41,8 @@ def forward( logits_max = torch.max(sharded_logits, dim=-1)[0] dist.all_reduce(logits_max, op=dist.ReduceOp.MAX, group=group) # Subtract the maximum value. - sharded_logits = sharded_logits - logits_max.unsqueeze(dim=-1) + # sharded_logits = sharded_logits - logits_max.unsqueeze(dim=-1) + sharded_logits.sub_(logits_max.unsqueeze(dim=-1)) # Get the shard's indices sharded_hidden_size = sharded_logits.shape[-1] @@ -600,24 +602,11 @@ def row_linear( out = F.linear(input, weight, bias) if tp_mode is TensorParallelLinearMode.ALL_REDUCE: - # out, work = differentiable_all_reduce_sum(out, group=group, async_all_reduce=async_all_reduce) - id(out) - # NOTE: why the id(out) doesn't match the id(out) before the all_reduce? 
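Note: the new domino.py module centralizes the handle-name scheme. Each collective is tagged with direction (fwd/bwd), sub-module (attn/mlp), layer index, and half-batch index, and is_async_comm keeps every collective asynchronous except the two backward cases listed in NON_ASYNC_HANDLE_IDX. A short sketch that mirrors the helper and shows how names are classified; the constants here are abbreviated stand-ins for the module's templates:

import re

# <fwd|bwd>.layer_<module>_<layer index>_batch_<half-batch index>
FWD_ATTN = "fwd.layer_attn_{}_batch_{}"
BWD_MLP = "bwd.layer_mlp_{}_batch_{}"

# Backward collectives the patch keeps synchronous.
_NON_ASYNC = [r"bwd\.layer_mlp_\d+_batch_1", r"bwd\.layer_attn_\d+_batch_0"]
_NON_ASYNC_RE = re.compile("^(" + "|".join(_NON_ASYNC) + ")$")

def is_async_comm(name: str) -> bool:
    return not bool(_NON_ASYNC_RE.match(name))

assert is_async_comm(FWD_ATTN.format(3, 0)) is True   # overlapped
assert is_async_comm(BWD_MLP.format(3, 1)) is False   # issued synchronously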
- if handle_idx == "fwd.layer_attn_0_batch_0": - assert 1 == 1 - out = differentiable_all_reduce_sum(out, group=group, async_all_reduce=async_all_reduce, handle_idx=handle_idx) if async_all_reduce: - from nanotron.parallel.comm import AsyncCommBucket - - # work = AsyncCommBucket.get(orig_out_id) - # work = AsyncCommBucket.pop(orig_out_id) - # if handle_idx == "fwd.layer_mlp_1_batch_0": - if handle_idx == "fwd.layer_attn_0_batch_0": - assert 1 == 1 - work = AsyncCommBucket.pop(handle_idx) - assert 1 == 1 + else: + work = None elif tp_mode is TensorParallelLinearMode.REDUCE_SCATTER: assert async_all_reduce is False, "Async communication is not supported for REDUCE_SCATTER mode." out = differentiable_reduce_scatter_sum(out, group=group) diff --git a/src/nanotron/parallel/tensor_parallel/nn.py b/src/nanotron/parallel/tensor_parallel/nn.py index 4fea1838..69e742ca 100644 --- a/src/nanotron/parallel/tensor_parallel/nn.py +++ b/src/nanotron/parallel/tensor_parallel/nn.py @@ -52,7 +52,6 @@ def __init__( async_communication: bool = False, contiguous_chunks: Optional[Tuple[int, ...]] = None, tp_recompute_allgather: bool = True, - # handle_idx: Optional[int] = None, ): self.pg = pg self.world_size = pg.size() @@ -73,7 +72,6 @@ def __init__( self.mode = mode self.async_communication = async_communication - # self.handle_idx = handle_idx if contiguous_chunks is not None: assert ( @@ -115,7 +113,6 @@ def __init__( device=None, dtype=None, async_communication: bool = False, - # async_all_reduce: bool = False, contiguous_chunks: Optional[Tuple[int, ...]] = None, ): self.pg = pg @@ -138,7 +135,6 @@ def __init__( ) self.mode = mode self.async_communication = async_communication - # self.async_all_reduce = async_all_reduce if self.mode is TensorParallelLinearMode.ALL_REDUCE and self.async_communication: raise ValueError("async_communication is not supported for ALL_REDUCE mode") diff --git a/src/nanotron/trainer.py b/src/nanotron/trainer.py index e58af9f5..5c5a3d2e 100644 --- a/src/nanotron/trainer.py +++ b/src/nanotron/trainer.py @@ -564,9 +564,6 @@ def training_step( self.config, self.parallel_context, self.unwrapped_model, self.grad_accumulator, self.optimizer ) - if dist.get_rank() == 0: - assert 1 == 1 - # Apply gradient self.optimizer.step() self.optimizer.zero_grad() @@ -594,8 +591,8 @@ def training_step( # if dist.get_rank() == 0 and constants._NOT_BWD_ASYNC_OPS: # assert 1 == 1 - assert len(not_finished) == 0, f"AsyncCommBucket._copy_async_op: {not_finished}" - assert len(AsyncCommBucket._async_op) == 0, f"AsyncCommBucket._async_op: {AsyncCommBucket._async_op}" + # assert len(not_finished) == 0, f"AsyncCommBucket._copy_async_op: {not_finished}" + # assert len(AsyncCommBucket._async_op) == 0, f"AsyncCommBucket._async_op: {AsyncCommBucket._async_op}" AsyncCommBucket.clear_all() return outputs, loss_avg From 841c7d6f721c7b6522558545b5b7f9e01f82ef46 Mon Sep 17 00:00:00 2001 From: Phuc Nguyen Date: Fri, 21 Feb 2025 20:01:40 +0000 Subject: [PATCH 11/40] add tests and more refactoring --- src/nanotron/config/parallelism_config.py | 8 ++ src/nanotron/models/llama.py | 65 ++++++++----- src/nanotron/parallel/comm.py | 41 +++++--- tests/helpers/llama.py | 49 +++++----- tests/test_base_model.py | 44 ++++++++- tests/test_comm.py | 113 ++++++++++++++++++++++ tests/test_domino.py | 71 ++++++++++++++ 7 files changed, 327 insertions(+), 64 deletions(-) create mode 100644 tests/test_comm.py create mode 100644 tests/test_domino.py diff --git a/src/nanotron/config/parallelism_config.py 
b/src/nanotron/config/parallelism_config.py index 2701bf9c..07688959 100644 --- a/src/nanotron/config/parallelism_config.py +++ b/src/nanotron/config/parallelism_config.py @@ -25,6 +25,7 @@ class DominoArgs: def __post_init__(self): assert self.num_input_batches > 1, "In order to enable domino mode, set num_input_batches > 1" + assert self.num_input_batches == 2, "Currently parallelism only supports 2 batches for Domino" @dataclass @@ -68,3 +69,10 @@ def __post_init__(self): self.pp_engine = cast_str_to_pipeline_engine(self.pp_engine) if isinstance(self.tp_mode, str): self.tp_mode = TensorParallelLinearMode[self.tp_mode.upper()] + + if self.is_domino_enabled is True: + assert self.tp > 1, "Domino requires TP > 1" + + @property + def is_domino_enabled(self) -> bool: + return True if self.domino else False diff --git a/src/nanotron/models/llama.py b/src/nanotron/models/llama.py index 585c8acf..c95453f1 100644 --- a/src/nanotron/models/llama.py +++ b/src/nanotron/models/llama.py @@ -246,7 +246,6 @@ def __init__( mode=tp_mode, bias=False, async_communication=tp_linear_async_communication and tp_mode is TensorParallelLinearMode.REDUCE_SCATTER, - # async_all_reduce=parallel_config.domino.num_input_batches > 1, ) self.split_silu_mul = GLUActivation(config.hidden_act) @@ -347,7 +346,6 @@ def __init__( parallel_config: Optional[ParallelismArgs], tp_pg: dist.ProcessGroup, layer_idx: int, - async_all_reduce: bool = False, ): from flash_attn.layers.rotary import RotaryEmbedding as FlashRotaryEmbedding @@ -431,7 +429,6 @@ def __init__( mode=tp_mode, bias=False, async_communication=tp_linear_async_communication, - # async_all_reduce=async_all_reduce, ) self.attention = CoreAttention( @@ -707,7 +704,7 @@ def forward( return {"hidden_states": output, "work": work, "sequence_mask": sequence_mask} -class LlamaDecoderLayer(nn.Module): +class _BaseLlamaDecoderLayer(nn.Module): def __init__( self, config: LlamaConfig, @@ -723,7 +720,6 @@ def __init__( parallel_config=parallel_config, tp_pg=tp_pg, layer_idx=layer_idx, - async_all_reduce=parallel_config.domino.num_input_batches > 1, ) self.post_attention_layernorm = TritonRMSNorm(config.hidden_size, eps=config.rms_norm_eps) @@ -733,6 +729,31 @@ def __init__( self.parallel_config = parallel_config self.layer_idx = layer_idx + def _checkpointed_forward( + self, + hidden_states: torch.Tensor, + sequence_mask: torch.Tensor, + ) -> List[torch.Tensor]: + return CheckpointFunction.apply(self._core_forward, True, hidden_states, sequence_mask) + + def forward( + self, + hidden_states: Union[torch.Tensor, TensorPointer], + sequence_mask: Union[torch.Tensor, TensorPointer], + ) -> Dict[str, Union[torch.Tensor, TensorPointer]]: + + if self.recompute_layer and not isinstance(hidden_states, TensorPointer): + hidden_states, sequence_mask = self._checkpointed_forward(hidden_states, sequence_mask) + else: + hidden_states, sequence_mask = self._core_forward(hidden_states, sequence_mask) + + return { + "hidden_states": hidden_states, + "sequence_mask": sequence_mask, + } + + +class DominoLlamaDecoderLayer(_BaseLlamaDecoderLayer): def _core_forward( self, hidden_states: Union[torch.Tensor, TensorPointer], @@ -785,7 +806,7 @@ def _core_forward( hidden_states0 = WaitComm.apply( hidden_states0, BWD_MLP_HANDLE_IDX.format(self.layer_idx, 1), - ) # new + ) mlp_output0 = self.mlp( hidden_states=hidden_states0, @@ -821,28 +842,26 @@ def _core_forward( hidden_states = torch.cat([hidden_states0, hidden_states1], dim=1) return hidden_states, orig_sequence_mask - def _checkpointed_forward( - 
self, - hidden_states: torch.Tensor, - sequence_mask: torch.Tensor, - ) -> List[torch.Tensor]: - return CheckpointFunction.apply(self._core_forward, True, hidden_states, sequence_mask) - def forward( +class LlamaDecoderLayer(_BaseLlamaDecoderLayer): + def _core_forward( self, hidden_states: Union[torch.Tensor, TensorPointer], sequence_mask: Union[torch.Tensor, TensorPointer], - ) -> Dict[str, Union[torch.Tensor, TensorPointer]]: + ) -> List[Union[torch.Tensor, TensorPointer]]: + residual = hidden_states + hidden_states = self.input_layernorm(hidden_states) - if self.recompute_layer and not isinstance(hidden_states, TensorPointer): - hidden_states, sequence_mask = self._checkpointed_forward(hidden_states, sequence_mask) - else: - hidden_states, sequence_mask = self._core_forward(hidden_states, sequence_mask) + output = self.attn(hidden_states=hidden_states, sequence_mask=sequence_mask) + hidden_states = output["hidden_states"] + hidden_states = hidden_states + residual - return { - "hidden_states": hidden_states, - "sequence_mask": sequence_mask, - } + residual = hidden_states + hidden_states = self.post_attention_layernorm(hidden_states) + hidden_states = self.mlp(hidden_states=hidden_states)["hidden_states"] + hidden_states = hidden_states + residual + + return hidden_states, output["sequence_mask"] class Embedding(nn.Module, AttachableStore): @@ -919,7 +938,7 @@ def __init__( [ PipelineBlock( p2p=self.p2p, - module_builder=LlamaDecoderLayer, + module_builder=DominoLlamaDecoderLayer if parallel_config.is_domino_enabled else LlamaDecoderLayer, module_kwargs={ "config": config, "parallel_config": parallel_config, diff --git a/src/nanotron/parallel/comm.py b/src/nanotron/parallel/comm.py index e1aa664c..8a14c388 100644 --- a/src/nanotron/parallel/comm.py +++ b/src/nanotron/parallel/comm.py @@ -40,27 +40,42 @@ class AsyncCommBucket: _copy_async_op: Dict[int, "dist.Work"] = {} @staticmethod - def add(tensor_id: int, work: "dist.Work"): - assert ( - tensor_id not in AsyncCommBucket._async_op - ), f"tensor_id: {tensor_id}, keys: {AsyncCommBucket._async_op.keys()}" - AsyncCommBucket._async_op[tensor_id] = work - AsyncCommBucket._copy_async_op[tensor_id] = work + def add(op_name: int, work: "dist.Work"): + assert op_name not in AsyncCommBucket._async_op, f"Operation with name: {op_name} already exists" + AsyncCommBucket._async_op[op_name] = work + AsyncCommBucket._copy_async_op[op_name] = work @staticmethod - def get(tensor_id: int): - return AsyncCommBucket._async_op.get(tensor_id) + def get(op_name: int): + if op_name not in AsyncCommBucket._async_op: + raise KeyError(f"Operation with name: {op_name} doesn't exist") + + return AsyncCommBucket._async_op.get(op_name) @staticmethod - def pop(tensor_id: int): - assert tensor_id in AsyncCommBucket._async_op, f"tensor_id: {tensor_id}" - return AsyncCommBucket._async_op.pop(tensor_id) + def pop(op_name: int): + if op_name not in AsyncCommBucket._async_op: + raise KeyError(f"Operation with name: {op_name} doesn't exist") + + return AsyncCommBucket._async_op.pop(op_name) @staticmethod - def wait(tensor_id: int): - work = AsyncCommBucket._async_op.pop(tensor_id) + def wait(op_name: int): + """Wait and remove the operation from the bucket""" + work = AsyncCommBucket.pop(op_name) work.wait() + @staticmethod + def is_all_completed() -> bool: + if not len(AsyncCommBucket._async_op) == 0: + return False + + not_finished = [] + for k, v in AsyncCommBucket._copy_async_op.items(): + if v.is_completed() is not True: + not_finished.append((k, v)) + return 
len(not_finished) == 0 + @staticmethod def clear_all(): AsyncCommBucket._async_op.clear() diff --git a/tests/helpers/llama.py b/tests/helpers/llama.py index 3f94031f..8aae7669 100644 --- a/tests/helpers/llama.py +++ b/tests/helpers/llama.py @@ -1,5 +1,6 @@ import torch from nanotron.config import ( + AdamWOptimizerArgs, AllForwardAllBackwardPipelineEngine, CheckpointsArgs, Config, @@ -46,7 +47,19 @@ ) -def get_llama_training_config(model_config: ModelArgs): +def get_parallel_config(parallel_context: ParallelContext): + return ParallelismArgs( + dp=parallel_context.data_parallel_size, + pp=parallel_context.pipeline_parallel_size, + tp=parallel_context.tensor_parallel_size, + expert_parallel_size=parallel_context.expert_parallel_size, + pp_engine=AllForwardAllBackwardPipelineEngine(), + tp_mode=TensorParallelLinearMode.ALL_REDUCE, + tp_linear_async_communication=False, + ) + + +def get_llama_training_config(model_config: ModelArgs, parallel_context): return Config( model=model_config, general=GeneralArgs(project="unittest", run="sanity_llama", seed=42), @@ -54,25 +67,20 @@ def get_llama_training_config(model_config: ModelArgs): checkpoints_path="./checkpoints", checkpoint_interval=10, ), - parallelism=ParallelismArgs( - dp=1, - pp=1, - tp=2, - expert_parallel_size=2, - pp_engine="1f1b", - tp_mode="ALL_REDUCE", - tp_linear_async_communication=False, - ), + parallelism=get_parallel_config(parallel_context), tokenizer=TokenizerArgs("gpt2"), optimizer=OptimizerArgs( zero_stage=0, weight_decay=0.01, clip_grad=1.0, accumulate_grad_in_fp32=False, - adam_eps=1e-08, - adam_beta1=0.9, - adam_beta2=0.95, - torch_adam_is_fused=True, + optimizer_factory=AdamWOptimizerArgs( + adam_eps=1e-08, + adam_beta1=0.9, + adam_beta2=0.95, + torch_adam_is_fused=True, + name="adamW", + ), learning_rate_scheduler=LRSchedulerArgs( learning_rate=3e-4, lr_warmup_steps=100, @@ -103,7 +111,10 @@ def get_llama_training_config(model_config: ModelArgs): def create_llama_from_config( - model_config: LlamaConfig, device: torch.device, parallel_context: ParallelContext + model_config: LlamaConfig, + parallel_config: ParallelismArgs, + device: torch.device, + parallel_context: ParallelContext, ) -> LlamaForTraining: """ @@ -114,14 +125,6 @@ def create_llama_from_config( the model created will have random weights. 
""" - parallel_config = ParallelismArgs( - dp=parallel_context.data_parallel_size, - pp=parallel_context.pipeline_parallel_size, - tp=parallel_context.tensor_parallel_size, - pp_engine=AllForwardAllBackwardPipelineEngine(), - tp_mode=TensorParallelLinearMode.ALL_REDUCE, - tp_linear_async_communication=False, - ) model = build_model( model_builder=lambda: LlamaForTraining( config=model_config, diff --git a/tests/test_base_model.py b/tests/test_base_model.py index b4759905..78f7e433 100644 --- a/tests/test_base_model.py +++ b/tests/test_base_model.py @@ -3,28 +3,28 @@ import torch.distributed as dist from helpers.llama import TINY_LLAMA_CONFIG, create_llama_from_config, get_llama_training_config from helpers.utils import init_distributed, rerun_if_address_is_in_use -from nanotron.config import Config, ModelArgs, RandomInit +from nanotron.config import ModelArgs, RandomInit from nanotron.parallel import ParallelContext from nanotron.parallel.pipeline_parallel.block import PipelineBlock from torch import nn @pytest.mark.parametrize("tp,dp,pp", [(1, 1, 1), (2, 2, 2)]) -@pytest.mark.skip @rerun_if_address_is_in_use() def test_get_named_modules_in_pp_rank(tp: int, dp: int, pp: int): model_args = ModelArgs(init_method=RandomInit(std=1.0), model_config=TINY_LLAMA_CONFIG) - config = get_llama_training_config(model_args) - init_distributed(tp=tp, dp=dp, pp=pp)(_test_get_named_modules_in_pp_rank)(config=config) + init_distributed(tp=tp, dp=dp, pp=pp)(_test_get_named_modules_in_pp_rank)(model_args=model_args) def _test_get_named_modules_in_pp_rank( parallel_context: ParallelContext, - config: Config, + model_args: ModelArgs, ): + config = get_llama_training_config(model_args, parallel_context) model = create_llama_from_config( model_config=config.model.model_config, + parallel_config=config.parallelism, device=torch.device("cuda"), parallel_context=parallel_context, ) @@ -43,3 +43,37 @@ def _test_get_named_modules_in_pp_rank( # not PipelineBlock assert isinstance(module, nn.Module) assert name not in modules_that_not_in_current_pp_rank + + +@pytest.mark.parametrize("tp,dp,pp", [(1, 1, 1), (2, 2, 1)]) +@rerun_if_address_is_in_use() +def test_llama_model(tp: int, dp: int, pp: int): + BATCH_SIZE, SEQ_LEN = 10, 128 + model_args = ModelArgs(init_method=RandomInit(std=1.0), model_config=TINY_LLAMA_CONFIG) + + init_distributed(tp=tp, dp=dp, pp=pp)(_test_llama_model)( + model_args=model_args, batch_size=BATCH_SIZE, seq_len=SEQ_LEN + ) + + +def _test_llama_model( + parallel_context: ParallelContext, + model_args: ModelArgs, + batch_size: int, + seq_len: int, +): + config = get_llama_training_config(model_args, parallel_context) + llama_model = create_llama_from_config( + model_config=config.model.model_config, + parallel_config=config.parallelism, + device=torch.device("cuda"), + parallel_context=parallel_context, + ) + llama_model.init_model_randomly(config=config) + + input_ids = torch.randint(0, config.model.model_config.vocab_size, size=(batch_size, seq_len), device="cuda") + input_mask = torch.ones_like(input_ids) + outputs = llama_model(input_ids, input_mask, input_mask, input_mask) + + assert list(outputs.keys()) == ["loss"] + assert isinstance(outputs["loss"], torch.Tensor) diff --git a/tests/test_comm.py b/tests/test_comm.py new file mode 100644 index 00000000..4039c61d --- /dev/null +++ b/tests/test_comm.py @@ -0,0 +1,113 @@ +import pytest +import torch +import torch.distributed as dist +from helpers.utils import ( + init_distributed, + rerun_if_address_is_in_use, +) +from nanotron.parallel import 
ParallelContext +from nanotron.parallel.comm import AsyncCommBucket, WaitComm + + +class MockWork: + def __init__(self): + self.completed = False + self.wait_called = False + + def wait(self): + self.wait_called = True + self.completed = True + + def is_completed(self): + return self.completed + + +@rerun_if_address_is_in_use() +def test_add_async_op_to_bucket(): + init_distributed(tp=2, dp=1, pp=1)(_test_add_async_op_to_bucket)() + + +def _test_add_async_op_to_bucket(parallel_context: ParallelContext): + OP_NAME = "test" + tensor = torch.randn(1, device="cuda") + work = dist.all_reduce(tensor, async_op=True) + + AsyncCommBucket.add(OP_NAME, work) + + assert AsyncCommBucket.get(OP_NAME) is work + + +@rerun_if_address_is_in_use() +def test_wait_async_op_to_bucket(): + init_distributed(tp=2, dp=1, pp=1)(_test_wait_async_op_to_bucket)() + + +def _test_wait_async_op_to_bucket(parallel_context: ParallelContext): + OP_NAME = "test" + work = MockWork() + + AsyncCommBucket.add(OP_NAME, work) + assert work.is_completed() is False + + AsyncCommBucket.wait(OP_NAME) + assert work.is_completed() + with pytest.raises(KeyError): + AsyncCommBucket.get(OP_NAME) + + +@rerun_if_address_is_in_use() +def test_is_all_completed_in_async_bucket(): + init_distributed(tp=2, dp=1, pp=1)(_test_wait_async_op_to_bucket)() + + +def _test_wait_async_op_to_bucket(parallel_context: ParallelContext): + OP_NAME = "test" + work = MockWork() + + AsyncCommBucket.add(OP_NAME, work) + assert AsyncCommBucket.is_all_completed() is False + + AsyncCommBucket.wait(OP_NAME) + assert AsyncCommBucket.is_all_completed() is True + + +@rerun_if_address_is_in_use() +def test_clear_ops_in_async_bucket(): + init_distributed(tp=2, dp=1, pp=1)(_test_clear_ops_in_async_bucket)() + + +def _test_clear_ops_in_async_bucket(parallel_context: ParallelContext): + AsyncCommBucket.add("test1", MockWork()) + AsyncCommBucket.add("test2", MockWork()) + AsyncCommBucket.add("test3", MockWork()) + + assert AsyncCommBucket.is_all_completed() is False + + AsyncCommBucket.clear_all() + assert AsyncCommBucket.is_all_completed() is True + with pytest.raises(KeyError): + AsyncCommBucket.get("test1") + + +@rerun_if_address_is_in_use() +def test_wait_comm(): + init_distributed(tp=2, dp=1, pp=1)(_test_wait_comm)() + + +def _test_wait_comm(parallel_context: ParallelContext): + tensor = torch.randn(1, device="cuda", requires_grad=True) + OP_NAME = "test" + + comm_stream = torch.cuda.Stream() + + with torch.cuda.stream(comm_stream): + work = MockWork() + AsyncCommBucket.add(OP_NAME, work) + + output = WaitComm.apply(tensor, OP_NAME) + assert work.is_completed() is False + + # NOTE: we test that it waits for the async op to complete + # automatically in autograd + (output + 1).backward() + assert work.is_completed() diff --git a/tests/test_domino.py b/tests/test_domino.py new file mode 100644 index 00000000..44d9d98a --- /dev/null +++ b/tests/test_domino.py @@ -0,0 +1,71 @@ +from copy import deepcopy + +import pytest +import torch +from helpers.llama import TINY_LLAMA_CONFIG, create_llama_from_config, get_llama_training_config +from helpers.utils import init_distributed, rerun_if_address_is_in_use +from nanotron.config import ModelArgs, RandomInit +from nanotron.config.parallelism_config import DominoArgs +from nanotron.models.llama import DominoLlamaDecoderLayer +from nanotron.parallel import ParallelContext +from nanotron.parallel.comm import AsyncCommBucket +from nanotron.parallel.tensor_parallel.domino import is_async_comm + + +@pytest.mark.parametrize( + "op_name, 
expected", + [ + ("fwd.layer_attn_1_batch_0", True), + ("fwd.layer_attn_1_batch_1", True), + ("fwd.layer_mlp_1_batch_0", True), + ("fwd.layer_mlp_1_batch_1", False), + ("bwd.layer_mlp_1_batch_1", True), + ("bwd.layer_mlp_1_batch_0", True), + ("bwd.layer_attn_1_batch_1", True), + ("bwd.layer_attn_1_batch_0", False), + ], +) +def test_is_async_comm(op_name, expected): + assert is_async_comm(op_name) == expected + + +@pytest.mark.parametrize("tp,dp,pp", [(2, 2, 1)]) +@rerun_if_address_is_in_use() +def test_domino_model(tp: int, dp: int, pp: int): + BATCH_SIZE, SEQ_LEN = 10, 128 + + model_config = deepcopy(TINY_LLAMA_CONFIG) + model_config.num_hidden_layers = 28 + model_args = ModelArgs(init_method=RandomInit(std=1.0), model_config=TINY_LLAMA_CONFIG) + + init_distributed(tp=tp, dp=dp, pp=pp)(_test_domino_model)( + model_args=model_args, batch_size=BATCH_SIZE, seq_len=SEQ_LEN + ) + + +def _test_domino_model( + parallel_context: ParallelContext, + model_args: ModelArgs, + batch_size: int, + seq_len: int, +): + config = get_llama_training_config(model_args, parallel_context) + config.parallelism.domino = DominoArgs(num_input_batches=2) + + llama_model = create_llama_from_config( + model_config=config.model.model_config, + parallel_config=config.parallelism, + device=torch.device("cuda"), + parallel_context=parallel_context, + ) + llama_model.init_model_randomly(config=config) + + for m in llama_model.model.decoder: + assert isinstance(m.pp_block, DominoLlamaDecoderLayer) + + input_ids = torch.randint(0, config.model.model_config.vocab_size, size=(batch_size, seq_len), device="cuda") + input_mask = torch.ones_like(input_ids) + outputs = llama_model(input_ids, input_mask, input_mask, input_mask) + + assert isinstance(outputs["loss"], torch.Tensor) + assert AsyncCommBucket.is_all_completed() is True From 8a0f993181ae81aa2ae2c8bfe5546ee8ce1a2740 Mon Sep 17 00:00:00 2001 From: Phuc Nguyen Date: Mon, 24 Feb 2025 10:35:54 +0000 Subject: [PATCH 12/40] add domino config, fix breaks in _RowLinearAsyncCommunication --- examples/domino/domino_config.yaml | 105 ++++++++++++++++++ .../parallel/tensor_parallel/functional.py | 30 ++--- 2 files changed, 121 insertions(+), 14 deletions(-) create mode 100644 examples/domino/domino_config.yaml diff --git a/examples/domino/domino_config.yaml b/examples/domino/domino_config.yaml new file mode 100644 index 00000000..07e27c37 --- /dev/null +++ b/examples/domino/domino_config.yaml @@ -0,0 +1,105 @@ +checkpoints: + checkpoint_interval: 10000 + checkpoints_path: checkpoints + checkpoints_path_is_shared_file_system: false + resume_checkpoint_path: null + load_lr_scheduler: false + load_optimizer: false + save_final_state: true + save_initial_state: false +data_stages: +- data: + dataset: + dataset_folder: + - /fsx/elie_bakouch/data/fw-edu-dedup + num_loading_workers: 0 + seed: 8 + name: stable phase + start_training_step: 1 +general: + benchmark_csv_path: null + consumed_train_samples: null + ignore_sanity_checks: true + project: nanotron_domino + run: domino_config + seed: 6 + step: null +logging: + iteration_step_info_interval: 1 + log_level: info + log_level_replica: info +model: + ddp_bucket_cap_mb: 25 + dtype: bfloat16 + init_method: + std: 0.041666666666666664 + make_vocab_size_divisible_by: 1 + model_config: + bos_token_id: 128000 + eos_token_id: 128001 + hidden_act: silu + hidden_size: 4096 + initializer_range: 0.02 + intermediate_size: 14336 + is_llama_config: true + max_position_embeddings: 4096 + num_attention_heads: 32 + num_hidden_layers: 15 + 
num_key_value_heads: 8 + pad_token_id: null + pretraining_tp: 2 + rms_norm_eps: 1.0e-05 + rope_interleaved: false + rope_scaling: + factor: 32.0 + high_freq_factor: 4.0 + low_freq_factor: 1.0 + original_max_position_embeddings: 4096 + rope_type: llama3 + rope_theta: 500000.0 + tie_word_embeddings: true + use_cache: true + vocab_size: 128256 +optimizer: + accumulate_grad_in_fp32: true + clip_grad: 1.0 + learning_rate_scheduler: + learning_rate: 0.00005 + lr_decay_starting_step: 50000 + lr_decay_steps: 10000 + lr_decay_style: linear + lr_warmup_steps: 1000 + lr_warmup_style: linear + min_decay_lr: 0 + optimizer_factory: + adam_beta1: 0.9 + adam_beta2: 0.95 + adam_eps: 1.0e-08 + name: adamW + torch_adam_is_fused: true + weight_decay: 0.01 + zero_stage: 1 +parallelism: + dp: 1 + expert_parallel_size: 1 + pp: 1 + pp_engine: 1f1b + recompute_layer: false + tp: 8 + tp_linear_async_communication: false + tp_mode: ALL_REDUCE + tp_recompute_allgather: false + domino: + num_input_batches: 2 +tokenizer: + tokenizer_max_length: null + tokenizer_name_or_path: meta-llama/Llama-3.2-3B + tokenizer_revision: null +tokens: + batch_accumulation_per_replica: 2 + limit_test_batches: 0 + limit_val_batches: 0 + micro_batch_size: 8 + sequence_length: 4096 + train_steps: 15000 + val_check_interval: -1 diff --git a/src/nanotron/parallel/tensor_parallel/functional.py b/src/nanotron/parallel/tensor_parallel/functional.py index 654a14f3..5ec5aa00 100644 --- a/src/nanotron/parallel/tensor_parallel/functional.py +++ b/src/nanotron/parallel/tensor_parallel/functional.py @@ -597,21 +597,23 @@ def row_linear( handle_idx=None, ) -> Tuple[torch.Tensor, Optional[torch.Future]]: if async_communication: - return _RowLinearAsyncCommunication.apply(input, weight, bias, group, tp_mode) - - out = F.linear(input, weight, bias) - - if tp_mode is TensorParallelLinearMode.ALL_REDUCE: - out = differentiable_all_reduce_sum(out, group=group, async_all_reduce=async_all_reduce, handle_idx=handle_idx) - if async_all_reduce: - work = AsyncCommBucket.pop(handle_idx) - else: - work = None - elif tp_mode is TensorParallelLinearMode.REDUCE_SCATTER: - assert async_all_reduce is False, "Async communication is not supported for REDUCE_SCATTER mode." - out = differentiable_reduce_scatter_sum(out, group=group) work = None + out = _RowLinearAsyncCommunication.apply(input, weight, bias, group, tp_mode) else: - raise ValueError(f"Got unexpected mode: {tp_mode}.") + out = F.linear(input, weight, bias) + if tp_mode is TensorParallelLinearMode.ALL_REDUCE: + out = differentiable_all_reduce_sum( + out, group=group, async_all_reduce=async_all_reduce, handle_idx=handle_idx + ) + if async_all_reduce: + work = AsyncCommBucket.pop(handle_idx) + else: + work = None + elif tp_mode is TensorParallelLinearMode.REDUCE_SCATTER: + assert async_all_reduce is False, "Async communication is not supported for REDUCE_SCATTER mode." 
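            # In REDUCE_SCATTER mode there is no overlapped collective: no handle is
            # registered in AsyncCommBucket, the reduce-scatter below is issued
            # synchronously, and the returned `work` stays None.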
+ out = differentiable_reduce_scatter_sum(out, group=group) + work = None + else: + raise ValueError(f"Got unexpected mode: {tp_mode}.") return out, work From a61d2df784c050de48560785fc8eba6b8c0713b0 Mon Sep 17 00:00:00 2001 From: Phuc Nguyen Date: Tue, 25 Feb 2025 10:55:43 +0000 Subject: [PATCH 13/40] add bwd.layer_mlp_x_batch_1 as async op --- src/nanotron/parallel/comm.py | 3 +++ .../distributed_differentiable_primitives.py | 12 ++++++++---- src/nanotron/parallel/tensor_parallel/domino.py | 2 +- 3 files changed, 12 insertions(+), 5 deletions(-) diff --git a/src/nanotron/parallel/comm.py b/src/nanotron/parallel/comm.py index 8a14c388..03e39980 100644 --- a/src/nanotron/parallel/comm.py +++ b/src/nanotron/parallel/comm.py @@ -91,6 +91,9 @@ def forward(ctx, input, wait_handle_idx): @staticmethod def backward(ctx, grad_output): if is_async_comm(ctx.wait_handle_idx): + if "bwd.layer_mlp_27_batch_1" == ctx.wait_handle_idx: + assert 1 == 1 + handle = AsyncCommBucket.pop(ctx.wait_handle_idx) assert handle is not None handle.wait() diff --git a/src/nanotron/parallel/tensor_parallel/distributed_differentiable_primitives.py b/src/nanotron/parallel/tensor_parallel/distributed_differentiable_primitives.py index 11501f97..b77e25e3 100644 --- a/src/nanotron/parallel/tensor_parallel/distributed_differentiable_primitives.py +++ b/src/nanotron/parallel/tensor_parallel/distributed_differentiable_primitives.py @@ -19,7 +19,7 @@ from nanotron import distributed as dist from nanotron.distributed import ProcessGroup -from nanotron.parallel.comm import AsyncCommBucket +from nanotron.parallel.comm import AsyncCommBucket, is_async_comm class DifferentiableIdentity(torch.autograd.Function): @@ -34,12 +34,16 @@ def forward(ctx, tensor, group: Optional[ProcessGroup], async_all_reduce: bool, @staticmethod def backward(ctx, grad_output): - group = ctx.group - - from nanotron.parallel.comm import is_async_comm + import pydevd + pydevd.settrace(suspend=False, trace_only_current_thread=True) + group = ctx.group handle_idx = ctx.handle_idx.replace("fwd.", "bwd.") if ctx.handle_idx is not None else None async_all_reduce = is_async_comm(handle_idx) if handle_idx is not None else ctx.async_all_reduce + + if handle_idx is not None and "bwd.layer_mlp_" in handle_idx and "batch_1" in handle_idx: + assert 1 == 1 + return DifferentiableAllReduceSum.apply(grad_output, group, async_all_reduce, handle_idx), None, None, None diff --git a/src/nanotron/parallel/tensor_parallel/domino.py b/src/nanotron/parallel/tensor_parallel/domino.py index 0139388b..50dfc95e 100644 --- a/src/nanotron/parallel/tensor_parallel/domino.py +++ b/src/nanotron/parallel/tensor_parallel/domino.py @@ -8,7 +8,7 @@ def is_async_comm(x): NON_ASYNC_HANDLE_IDX = [ - "bwd.layer_mlp_{}_batch_1", + # "bwd.layer_mlp_{}_batch_1", "bwd.layer_attn_{}_batch_0", ] From 06e17bcb403d1b3213a2be1d21182b69139d35fb Mon Sep 17 00:00:00 2001 From: Phuc Nguyen Date: Tue, 25 Feb 2025 13:41:45 +0000 Subject: [PATCH 14/40] =?UTF-8?q?-=20add=20cuda=20stream=20sync=20after=20?= =?UTF-8?q?attn=5Foutput0[work]=20-=20execute=20backward=20comm=20in=20a?= =?UTF-8?q?=20separate=20stream=20=20=20=20=20-=20make=20commm=20stream=20?= =?UTF-8?q?in=20the=20backward=20pass=20wait=20for=20compute=20stream=20be?= =?UTF-8?q?fore=20run=20backward=20comm=20-=20make=20WaitComm=E2=80=99s=20?= =?UTF-8?q?compute=20stream=20to=20wait=20for=20the=20comm=20stream?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- src/nanotron/models/llama.py | 24 +++++++-- 
src/nanotron/parallel/comm.py | 10 +++- .../distributed_differentiable_primitives.py | 49 +++++++++++++------ .../parallel/tensor_parallel/domino.py | 2 + .../parallel/tensor_parallel/functional.py | 8 ++- src/nanotron/parallel/tensor_parallel/nn.py | 6 ++- src/nanotron/trainer.py | 2 +- 7 files changed, 75 insertions(+), 26 deletions(-) diff --git a/src/nanotron/models/llama.py b/src/nanotron/models/llama.py index c95453f1..8e4682a6 100644 --- a/src/nanotron/models/llama.py +++ b/src/nanotron/models/llama.py @@ -249,10 +249,12 @@ def __init__( ) self.split_silu_mul = GLUActivation(config.hidden_act) - def forward(self, hidden_states, handle_idx=None): # [seq_length, batch_size, hidden_dim] - merged_states = self.gate_up_proj(hidden_states, async_all_reduce=True, handle_idx=handle_idx) + def forward(self, hidden_states, handle_idx=None, comm_stream=None): # [seq_length, batch_size, hidden_dim] + merged_states = self.gate_up_proj( + hidden_states, async_all_reduce=True, handle_idx=handle_idx, comm_stream=comm_stream + ) hidden_states, work = self.down_proj( - self.split_silu_mul(merged_states), async_all_reduce=True, handle_idx=handle_idx + self.split_silu_mul(merged_states), async_all_reduce=True, handle_idx=handle_idx, comm_stream=comm_stream ) return {"hidden_states": hidden_states, "work": work} @@ -446,6 +448,7 @@ def forward( hidden_states, # [seq_length, batch_size, hidden_size] sequence_mask, # [batch_size, seq_length] handle_idx=None, + comm_stream=None, ): from flash_attn import bert_padding from flash_attn.flash_attn_interface import ( @@ -454,7 +457,7 @@ def forward( ) qkv_states = self.qkv_proj( - hidden_states, async_all_reduce=True, handle_idx=handle_idx + hidden_states, async_all_reduce=True, handle_idx=handle_idx, comm_stream=comm_stream ) # [seq_length, batch_size, n_local_q_heads * d_qk + 2 * n_local_kv_heads * d_qk] q_length, batch_size, _ = qkv_states.shape @@ -699,7 +702,9 @@ def forward( attention_output = ( attention_output.contiguous().view(batch_size, q_length, self.n_local_q_heads * self.d_v).transpose(0, 1) ) - output, work = self.o_proj(attention_output, async_all_reduce=True, handle_idx=handle_idx) + output, work = self.o_proj( + attention_output, async_all_reduce=True, handle_idx=handle_idx, comm_stream=comm_stream + ) return {"hidden_states": output, "work": work, "sequence_mask": sequence_mask} @@ -779,38 +784,46 @@ def _core_forward( hidden_states0 = WaitComm.apply( hidden_states0, BWD_ATTN_HANDLE_IDX.format(self.layer_idx, 1), + comm_stream, ) hidden_states1 = WaitComm.apply( hidden_states1, BWD_MLP_HANDLE_IDX.format(self.layer_idx, 0), + comm_stream, ) attn_output0 = self.attn( hidden_states=hidden_states0, sequence_mask=sequence_mask0, handle_idx=FWD_ATTN_HANDLE_IDX.format(self.layer_idx, 0), + comm_stream=comm_stream, ) attn_output1 = self.attn( hidden_states=hidden_states1, sequence_mask=sequence_mask1, handle_idx=FWD_ATTN_HANDLE_IDX.format(self.layer_idx, 1), + comm_stream=comm_stream, ) with torch.cuda.stream(comm_stream): attn_output0["work"].wait() attn_output0["work"].is_completed() + torch.cuda.current_stream().wait_stream(comm_stream) + hidden_states0 = attn_output0["hidden_states"] + residual0 residual0 = hidden_states0 hidden_states0 = self.post_attention_layernorm(hidden_states0) hidden_states0 = WaitComm.apply( hidden_states0, BWD_MLP_HANDLE_IDX.format(self.layer_idx, 1), + comm_stream, ) mlp_output0 = self.mlp( hidden_states=hidden_states0, handle_idx=FWD_MLP_HANDLE_IDX.format(self.layer_idx, 0), + comm_stream=comm_stream, ) with 
torch.cuda.stream(comm_stream): attn_output1["work"].wait() @@ -825,6 +838,7 @@ def _core_forward( mlp_output1 = self.mlp( hidden_states=hidden_states1, handle_idx=FWD_MLP_HANDLE_IDX.format(self.layer_idx, 1), + comm_stream=comm_stream, ) with torch.cuda.stream(comm_stream): diff --git a/src/nanotron/parallel/comm.py b/src/nanotron/parallel/comm.py index 03e39980..33c8b5d2 100644 --- a/src/nanotron/parallel/comm.py +++ b/src/nanotron/parallel/comm.py @@ -84,12 +84,15 @@ def clear_all(): class WaitComm(torch.autograd.Function): @staticmethod - def forward(ctx, input, wait_handle_idx): + def forward(ctx, input, wait_handle_idx, comm_stream): ctx.wait_handle_idx = wait_handle_idx + ctx.comm_stream = comm_stream return input @staticmethod def backward(ctx, grad_output): + # import pydevd + # pydevd.settrace(suspend=False, trace_only_current_thread=True) if is_async_comm(ctx.wait_handle_idx): if "bwd.layer_mlp_27_batch_1" == ctx.wait_handle_idx: assert 1 == 1 @@ -97,5 +100,8 @@ def backward(ctx, grad_output): handle = AsyncCommBucket.pop(ctx.wait_handle_idx) assert handle is not None handle.wait() + assert 1 == 1 - return grad_output, None + torch.cuda.current_stream().wait_stream(ctx.comm_stream) + + return grad_output, None, None diff --git a/src/nanotron/parallel/tensor_parallel/distributed_differentiable_primitives.py b/src/nanotron/parallel/tensor_parallel/distributed_differentiable_primitives.py index b77e25e3..7482ae77 100644 --- a/src/nanotron/parallel/tensor_parallel/distributed_differentiable_primitives.py +++ b/src/nanotron/parallel/tensor_parallel/distributed_differentiable_primitives.py @@ -26,17 +26,18 @@ class DifferentiableIdentity(torch.autograd.Function): """All-reduce gradients in a differentiable fashion""" @staticmethod - def forward(ctx, tensor, group: Optional[ProcessGroup], async_all_reduce: bool, handle_idx=None): + def forward(ctx, tensor, group: Optional[ProcessGroup], async_all_reduce: bool, handle_idx=None, comm_stream=None): ctx.async_all_reduce = async_all_reduce ctx.handle_idx = handle_idx ctx.group = group + ctx.comm_stream = comm_stream return tensor @staticmethod def backward(ctx, grad_output): - import pydevd + # import pydevd + # pydevd.settrace(suspend=False, trace_only_current_thread=True) - pydevd.settrace(suspend=False, trace_only_current_thread=True) group = ctx.group handle_idx = ctx.handle_idx.replace("fwd.", "bwd.") if ctx.handle_idx is not None else None async_all_reduce = is_async_comm(handle_idx) if handle_idx is not None else ctx.async_all_reduce @@ -44,7 +45,13 @@ def backward(ctx, grad_output): if handle_idx is not None and "bwd.layer_mlp_" in handle_idx and "batch_1" in handle_idx: assert 1 == 1 - return DifferentiableAllReduceSum.apply(grad_output, group, async_all_reduce, handle_idx), None, None, None + return ( + DifferentiableAllReduceSum.apply(grad_output, group, async_all_reduce, handle_idx, ctx.comm_stream), + None, + None, + None, + None, + ) class DifferentiableAllReduceSum(torch.autograd.Function): @@ -52,24 +59,38 @@ class DifferentiableAllReduceSum(torch.autograd.Function): @staticmethod def forward( - ctx, tensor, group: Optional[ProcessGroup], async_all_reduce: bool, handle_idx: Optional[int] = None + ctx, + tensor, + group: Optional[ProcessGroup], + async_all_reduce: bool, + handle_idx: Optional[int] = None, + comm_stream=None, ) -> Tuple[torch.Tensor, Optional["dist.Work"]]: + from contextlib import nullcontext + ctx.async_all_reduce = async_all_reduce if group.size() == 1: return tensor - if async_all_reduce is True: - handle 
= dist.all_reduce(tensor, op=dist.ReduceOp.SUM, group=group, async_op=True) - AsyncCommBucket.add(handle_idx, handle) + if comm_stream is not None: + comm_stream.wait_stream(torch.cuda.current_stream()) + comm_context = torch.cuda.stream(comm_stream) else: - dist.all_reduce(tensor, op=dist.ReduceOp.SUM, group=group) + comm_context = nullcontext() + + with comm_context: + if async_all_reduce is True: + handle = dist.all_reduce(tensor, op=dist.ReduceOp.SUM, group=group, async_op=True) + AsyncCommBucket.add(handle_idx, handle) + else: + dist.all_reduce(tensor, op=dist.ReduceOp.SUM, group=group) return tensor @staticmethod def backward(ctx, grad_output): - return grad_output, None, None, None + return grad_output, None, None, None, None class DifferentiableAllGather(torch.autograd.Function): @@ -152,15 +173,15 @@ def backward(ctx, grad_output): def differentiable_identity( - tensor, group: Optional[ProcessGroup] = None, async_all_reduce: bool = False, handle_idx=None + tensor, group: Optional[ProcessGroup] = None, async_all_reduce: bool = False, handle_idx=None, comm_stream=None ): - return DifferentiableIdentity.apply(tensor, group, async_all_reduce, handle_idx) + return DifferentiableIdentity.apply(tensor, group, async_all_reduce, handle_idx, comm_stream) def differentiable_all_reduce_sum( - tensor, group: Optional[ProcessGroup] = None, async_all_reduce: bool = False, handle_idx=None + tensor, group: Optional[ProcessGroup] = None, async_all_reduce: bool = False, handle_idx=None, comm_stream=None ): - return DifferentiableAllReduceSum.apply(tensor, group, async_all_reduce, handle_idx) + return DifferentiableAllReduceSum.apply(tensor, group, async_all_reduce, handle_idx, comm_stream) def differentiable_all_gather(tensor, group: Optional[ProcessGroup] = None): diff --git a/src/nanotron/parallel/tensor_parallel/domino.py b/src/nanotron/parallel/tensor_parallel/domino.py index 50dfc95e..10e8306d 100644 --- a/src/nanotron/parallel/tensor_parallel/domino.py +++ b/src/nanotron/parallel/tensor_parallel/domino.py @@ -9,8 +9,10 @@ def is_async_comm(x): NON_ASYNC_HANDLE_IDX = [ # "bwd.layer_mlp_{}_batch_1", + "fwd.layer_mlp_{}_batch_1", "bwd.layer_attn_{}_batch_0", ] + assert "fwd." 
not in x patterns = [p.replace("{}", r"\d+") for p in NON_ASYNC_HANDLE_IDX] # Replace {} with regex for numbers regex = re.compile("^(" + "|".join(patterns) + ")$") # Combine patterns into a single regex diff --git a/src/nanotron/parallel/tensor_parallel/functional.py b/src/nanotron/parallel/tensor_parallel/functional.py index 5ec5aa00..005f4624 100644 --- a/src/nanotron/parallel/tensor_parallel/functional.py +++ b/src/nanotron/parallel/tensor_parallel/functional.py @@ -440,12 +440,15 @@ def column_linear( tp_recompute_allgather: bool = True, async_all_reduce: bool = False, handle_idx: Optional[int] = None, + comm_stream: Optional[torch.cuda.Stream] = None, ): if async_communication: return _ColumnLinearAsyncCommunication.apply(input, weight, bias, group, tp_mode, tp_recompute_allgather) if tp_mode is TensorParallelLinearMode.ALL_REDUCE: - input = differentiable_identity(input, group=group, async_all_reduce=async_all_reduce, handle_idx=handle_idx) + input = differentiable_identity( + input, group=group, async_all_reduce=async_all_reduce, handle_idx=handle_idx, comm_stream=comm_stream + ) return F.linear(input, weight, bias) if tp_mode is TensorParallelLinearMode.REDUCE_SCATTER: return _ColumnLinearNoAsyncCommunicationReduceScatterMode.apply( @@ -595,6 +598,7 @@ def row_linear( async_communication: bool, async_all_reduce: bool, handle_idx=None, + comm_stream=None, ) -> Tuple[torch.Tensor, Optional[torch.Future]]: if async_communication: work = None @@ -603,7 +607,7 @@ def row_linear( out = F.linear(input, weight, bias) if tp_mode is TensorParallelLinearMode.ALL_REDUCE: out = differentiable_all_reduce_sum( - out, group=group, async_all_reduce=async_all_reduce, handle_idx=handle_idx + out, group=group, async_all_reduce=async_all_reduce, handle_idx=handle_idx, comm_stream=comm_stream ) if async_all_reduce: work = AsyncCommBucket.pop(handle_idx) diff --git a/src/nanotron/parallel/tensor_parallel/nn.py b/src/nanotron/parallel/tensor_parallel/nn.py index 69e742ca..c27b5b87 100644 --- a/src/nanotron/parallel/tensor_parallel/nn.py +++ b/src/nanotron/parallel/tensor_parallel/nn.py @@ -85,7 +85,7 @@ def __init__( split_config=split_config, ) - def forward(self, x: torch.Tensor, async_all_reduce=None, handle_idx=None) -> torch.Tensor: + def forward(self, x: torch.Tensor, async_all_reduce=None, handle_idx=None, comm_stream=None) -> torch.Tensor: return column_linear( input=x, weight=self.weight, @@ -96,6 +96,7 @@ def forward(self, x: torch.Tensor, async_all_reduce=None, handle_idx=None) -> to tp_recompute_allgather=self.tp_recompute_allgather, async_all_reduce=async_all_reduce, handle_idx=handle_idx, + comm_stream=comm_stream, ) def extra_repr(self) -> str: @@ -160,7 +161,7 @@ def _mark_all_parameters_in_module_as_sharded(self, split_config: SplitConfig): ) setattr(self, name, new_param) - def forward(self, x: torch.Tensor, async_all_reduce, handle_idx=None) -> torch.Tensor: + def forward(self, x: torch.Tensor, async_all_reduce, handle_idx=None, comm_stream=None) -> torch.Tensor: return row_linear( input=x, weight=self.weight, @@ -170,6 +171,7 @@ def forward(self, x: torch.Tensor, async_all_reduce, handle_idx=None) -> torch.T async_communication=self.async_communication, async_all_reduce=async_all_reduce, handle_idx=handle_idx, + comm_stream=comm_stream, ) def extra_repr(self) -> str: diff --git a/src/nanotron/trainer.py b/src/nanotron/trainer.py index 5c5a3d2e..ecfc9d18 100644 --- a/src/nanotron/trainer.py +++ b/src/nanotron/trainer.py @@ -592,7 +592,7 @@ def training_step( # assert 1 == 1 # assert 
len(not_finished) == 0, f"AsyncCommBucket._copy_async_op: {not_finished}" - # assert len(AsyncCommBucket._async_op) == 0, f"AsyncCommBucket._async_op: {AsyncCommBucket._async_op}" + assert len(AsyncCommBucket._async_op) == 0, f"AsyncCommBucket._async_op: {AsyncCommBucket._async_op}" AsyncCommBucket.clear_all() return outputs, loss_avg From 8d449429573a5c8cbb8136448ab5a4e011d67480 Mon Sep 17 00:00:00 2001 From: Phuc Nguyen Date: Tue, 25 Feb 2025 13:59:59 +0000 Subject: [PATCH 15/40] wait default_stream instead of current_stream --- src/nanotron/models/llama.py | 9 ++++++--- src/nanotron/parallel/comm.py | 2 +- .../distributed_differentiable_primitives.py | 2 +- 3 files changed, 8 insertions(+), 5 deletions(-) diff --git a/src/nanotron/models/llama.py b/src/nanotron/models/llama.py index 8e4682a6..3f437c99 100644 --- a/src/nanotron/models/llama.py +++ b/src/nanotron/models/llama.py @@ -805,11 +805,12 @@ def _core_forward( comm_stream=comm_stream, ) + comm_stream.wait_stream(torch.cuda.default_stream()) with torch.cuda.stream(comm_stream): attn_output0["work"].wait() attn_output0["work"].is_completed() - torch.cuda.current_stream().wait_stream(comm_stream) + torch.cuda.default_stream().wait_stream(comm_stream) hidden_states0 = attn_output0["hidden_states"] + residual0 residual0 = hidden_states0 @@ -825,11 +826,12 @@ def _core_forward( handle_idx=FWD_MLP_HANDLE_IDX.format(self.layer_idx, 0), comm_stream=comm_stream, ) + comm_stream.wait_stream(torch.cuda.default_stream()) with torch.cuda.stream(comm_stream): attn_output1["work"].wait() attn_output1["work"].is_completed() - torch.cuda.current_stream().wait_stream(comm_stream) + torch.cuda.default_stream().wait_stream(comm_stream) hidden_states1 = attn_output1["hidden_states"] + residual1 residual1 = hidden_states1 @@ -841,6 +843,7 @@ def _core_forward( comm_stream=comm_stream, ) + comm_stream.wait_stream(torch.cuda.default_stream()) with torch.cuda.stream(comm_stream): mlp_output0["work"].wait() mlp_output1["work"].wait() @@ -848,7 +851,7 @@ def _core_forward( mlp_output0["work"].is_completed() mlp_output1["work"].is_completed() - torch.cuda.current_stream().wait_stream(comm_stream) + torch.cuda.default_stream().wait_stream(comm_stream) hidden_states0 = mlp_output0["hidden_states"] + residual0 hidden_states1 = mlp_output1["hidden_states"] + residual1 diff --git a/src/nanotron/parallel/comm.py b/src/nanotron/parallel/comm.py index 33c8b5d2..33db29ed 100644 --- a/src/nanotron/parallel/comm.py +++ b/src/nanotron/parallel/comm.py @@ -102,6 +102,6 @@ def backward(ctx, grad_output): handle.wait() assert 1 == 1 - torch.cuda.current_stream().wait_stream(ctx.comm_stream) + torch.cuda.default_stream().wait_stream(ctx.comm_stream) return grad_output, None, None diff --git a/src/nanotron/parallel/tensor_parallel/distributed_differentiable_primitives.py b/src/nanotron/parallel/tensor_parallel/distributed_differentiable_primitives.py index 7482ae77..e7f7f3fd 100644 --- a/src/nanotron/parallel/tensor_parallel/distributed_differentiable_primitives.py +++ b/src/nanotron/parallel/tensor_parallel/distributed_differentiable_primitives.py @@ -74,7 +74,7 @@ def forward( return tensor if comm_stream is not None: - comm_stream.wait_stream(torch.cuda.current_stream()) + comm_stream.wait_stream(torch.cuda.default_stream()) comm_context = torch.cuda.stream(comm_stream) else: comm_context = nullcontext() From aa77e6cd20c50d4df82a640f216ccf3487e3d989 Mon Sep 17 00:00:00 2001 From: Phuc Nguyen Date: Tue, 25 Feb 2025 15:24:42 +0000 Subject: [PATCH 16/40] put 
torch.cuda.synchronize() everywhere --- src/nanotron/models/llama.py | 12 ++++++++++++ src/nanotron/parallel/comm.py | 2 ++ src/nanotron/parallel/tensor_parallel/domino.py | 4 ++-- 3 files changed, 16 insertions(+), 2 deletions(-) diff --git a/src/nanotron/models/llama.py b/src/nanotron/models/llama.py index 3f437c99..3fb549e1 100644 --- a/src/nanotron/models/llama.py +++ b/src/nanotron/models/llama.py @@ -805,11 +805,15 @@ def _core_forward( comm_stream=comm_stream, ) + assert torch.cuda.current_stream() == torch.cuda.default_stream() + torch.cuda.synchronize() comm_stream.wait_stream(torch.cuda.default_stream()) with torch.cuda.stream(comm_stream): attn_output0["work"].wait() attn_output0["work"].is_completed() + assert torch.cuda.current_stream() == torch.cuda.default_stream() + torch.cuda.synchronize() torch.cuda.default_stream().wait_stream(comm_stream) hidden_states0 = attn_output0["hidden_states"] + residual0 @@ -826,11 +830,15 @@ def _core_forward( handle_idx=FWD_MLP_HANDLE_IDX.format(self.layer_idx, 0), comm_stream=comm_stream, ) + + assert torch.cuda.current_stream() == torch.cuda.default_stream() + torch.cuda.synchronize() comm_stream.wait_stream(torch.cuda.default_stream()) with torch.cuda.stream(comm_stream): attn_output1["work"].wait() attn_output1["work"].is_completed() + torch.cuda.synchronize() torch.cuda.default_stream().wait_stream(comm_stream) hidden_states1 = attn_output1["hidden_states"] + residual1 @@ -843,6 +851,8 @@ def _core_forward( comm_stream=comm_stream, ) + assert torch.cuda.current_stream() == torch.cuda.default_stream() + torch.cuda.synchronize() comm_stream.wait_stream(torch.cuda.default_stream()) with torch.cuda.stream(comm_stream): mlp_output0["work"].wait() @@ -851,6 +861,8 @@ def _core_forward( mlp_output0["work"].is_completed() mlp_output1["work"].is_completed() + assert torch.cuda.current_stream() == torch.cuda.default_stream() + torch.cuda.synchronize() torch.cuda.default_stream().wait_stream(comm_stream) hidden_states0 = mlp_output0["hidden_states"] + residual0 diff --git a/src/nanotron/parallel/comm.py b/src/nanotron/parallel/comm.py index 33db29ed..15c3f64f 100644 --- a/src/nanotron/parallel/comm.py +++ b/src/nanotron/parallel/comm.py @@ -102,6 +102,8 @@ def backward(ctx, grad_output): handle.wait() assert 1 == 1 + torch.cuda.synchronize() + assert torch.cuda.current_stream() == torch.cuda.default_stream() torch.cuda.default_stream().wait_stream(ctx.comm_stream) return grad_output, None, None diff --git a/src/nanotron/parallel/tensor_parallel/domino.py b/src/nanotron/parallel/tensor_parallel/domino.py index 10e8306d..1eba8a4c 100644 --- a/src/nanotron/parallel/tensor_parallel/domino.py +++ b/src/nanotron/parallel/tensor_parallel/domino.py @@ -8,8 +8,8 @@ def is_async_comm(x): NON_ASYNC_HANDLE_IDX = [ - # "bwd.layer_mlp_{}_batch_1", - "fwd.layer_mlp_{}_batch_1", + # "fwd.layer_mlp_{}_batch_1", + "bwd.layer_mlp_{}_batch_1", "bwd.layer_attn_{}_batch_0", ] assert "fwd." 
not in x From 76b5f9a78dde92549810edf3dbff80305f435d41 Mon Sep 17 00:00:00 2001 From: Phuc Nguyen Date: Tue, 25 Feb 2025 15:37:49 +0000 Subject: [PATCH 17/40] only bwd.layer_attn_{}_batch_0 as non async --- src/nanotron/parallel/tensor_parallel/domino.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/nanotron/parallel/tensor_parallel/domino.py b/src/nanotron/parallel/tensor_parallel/domino.py index 1eba8a4c..1d34174e 100644 --- a/src/nanotron/parallel/tensor_parallel/domino.py +++ b/src/nanotron/parallel/tensor_parallel/domino.py @@ -9,7 +9,7 @@ def is_async_comm(x): NON_ASYNC_HANDLE_IDX = [ # "fwd.layer_mlp_{}_batch_1", - "bwd.layer_mlp_{}_batch_1", + # "bwd.layer_mlp_{}_batch_1", "bwd.layer_attn_{}_batch_0", ] assert "fwd." not in x From fe7ee7ea071e9ccfc12d681f1635323709461042 Mon Sep 17 00:00:00 2001 From: Phuc Nguyen Date: Tue, 25 Feb 2025 17:37:22 +0000 Subject: [PATCH 18/40] exp7a7_like_exp7a6_but_remove_fwd_pass_cuda_syncronization --- src/nanotron/models/llama.py | 12 ++++++------ src/nanotron/parallel/tensor_parallel/domino.py | 3 ++- src/nanotron/trainer.py | 5 ++++- 3 files changed, 12 insertions(+), 8 deletions(-) diff --git a/src/nanotron/models/llama.py b/src/nanotron/models/llama.py index 3fb549e1..f558c102 100644 --- a/src/nanotron/models/llama.py +++ b/src/nanotron/models/llama.py @@ -806,14 +806,14 @@ def _core_forward( ) assert torch.cuda.current_stream() == torch.cuda.default_stream() - torch.cuda.synchronize() + # torch.cuda.synchronize() comm_stream.wait_stream(torch.cuda.default_stream()) with torch.cuda.stream(comm_stream): attn_output0["work"].wait() attn_output0["work"].is_completed() assert torch.cuda.current_stream() == torch.cuda.default_stream() - torch.cuda.synchronize() + # torch.cuda.synchronize() torch.cuda.default_stream().wait_stream(comm_stream) hidden_states0 = attn_output0["hidden_states"] + residual0 @@ -832,13 +832,13 @@ def _core_forward( ) assert torch.cuda.current_stream() == torch.cuda.default_stream() - torch.cuda.synchronize() + # torch.cuda.synchronize() comm_stream.wait_stream(torch.cuda.default_stream()) with torch.cuda.stream(comm_stream): attn_output1["work"].wait() attn_output1["work"].is_completed() - torch.cuda.synchronize() + # torch.cuda.synchronize() torch.cuda.default_stream().wait_stream(comm_stream) hidden_states1 = attn_output1["hidden_states"] + residual1 @@ -852,7 +852,7 @@ def _core_forward( ) assert torch.cuda.current_stream() == torch.cuda.default_stream() - torch.cuda.synchronize() + # torch.cuda.synchronize() comm_stream.wait_stream(torch.cuda.default_stream()) with torch.cuda.stream(comm_stream): mlp_output0["work"].wait() @@ -862,7 +862,7 @@ def _core_forward( mlp_output1["work"].is_completed() assert torch.cuda.current_stream() == torch.cuda.default_stream() - torch.cuda.synchronize() + # torch.cuda.synchronize() torch.cuda.default_stream().wait_stream(comm_stream) hidden_states0 = mlp_output0["hidden_states"] + residual0 diff --git a/src/nanotron/parallel/tensor_parallel/domino.py b/src/nanotron/parallel/tensor_parallel/domino.py index 1d34174e..97dbe0dc 100644 --- a/src/nanotron/parallel/tensor_parallel/domino.py +++ b/src/nanotron/parallel/tensor_parallel/domino.py @@ -8,11 +8,12 @@ def is_async_comm(x): NON_ASYNC_HANDLE_IDX = [ + # NOTE: execute all fwd's comm in async # "fwd.layer_mlp_{}_batch_1", # "bwd.layer_mlp_{}_batch_1", "bwd.layer_attn_{}_batch_0", ] - assert "fwd." not in x + # assert "fwd." 
not in x patterns = [p.replace("{}", r"\d+") for p in NON_ASYNC_HANDLE_IDX] # Replace {} with regex for numbers regex = re.compile("^(" + "|".join(patterns) + ")$") # Combine patterns into a single regex diff --git a/src/nanotron/trainer.py b/src/nanotron/trainer.py index ecfc9d18..a1c71a09 100644 --- a/src/nanotron/trainer.py +++ b/src/nanotron/trainer.py @@ -418,6 +418,7 @@ def train( ], **kwargs, ) -> None: + # torch.cuda.set_sync_debug_mode("warn") self.pre_training(**kwargs) if self.config.checkpoints.save_initial_state and self.init_checkpoint_path is None: @@ -579,19 +580,21 @@ def training_step( self.post_train_step() from nanotron.parallel.comm import AsyncCommBucket + from nanotron.parallel.tensor_parallel.domino import is_async_comm # import torch.distributed as dist not_finished = [] for k, v in AsyncCommBucket._copy_async_op.items(): # assert v.is_completed(), f"AsyncCommBucket._copy_async_op: {AsyncCommBucket._copy_async_op}" + assert is_async_comm(k) is True, f"k: {k}" if v.is_completed() is not True: not_finished.append((k, v)) # if dist.get_rank() == 0 and constants._NOT_BWD_ASYNC_OPS: # assert 1 == 1 - # assert len(not_finished) == 0, f"AsyncCommBucket._copy_async_op: {not_finished}" + assert len(not_finished) == 0, f"AsyncCommBucket._copy_async_op: {not_finished}" assert len(AsyncCommBucket._async_op) == 0, f"AsyncCommBucket._async_op: {AsyncCommBucket._async_op}" AsyncCommBucket.clear_all() From e0a9bd0d89854bb61ebc9199f15c2f06920581bb Mon Sep 17 00:00:00 2001 From: Phuc Nguyen Date: Wed, 26 Feb 2025 12:27:01 +0000 Subject: [PATCH 19/40] remove torch.cuda.synchronize in WaitComm.backward --- src/nanotron/constants.py | 7 ------- src/nanotron/models/llama.py | 6 ------ src/nanotron/parallel/comm.py | 9 ++++++++- 3 files changed, 8 insertions(+), 14 deletions(-) diff --git a/src/nanotron/constants.py b/src/nanotron/constants.py index 3fe440a8..580bd99d 100644 --- a/src/nanotron/constants.py +++ b/src/nanotron/constants.py @@ -10,10 +10,3 @@ CHECKPOINT_FILE_NAME = "checkpoint_metadata.json" MODEL_CONFIG_FILE_NAME = "model_config.json" - - -CUDA_STREAMS = {} - -CLOCK = 0 -_AUTOGRAD_RUNS = [] -_NOT_BWD_ASYNC_OPS = [] diff --git a/src/nanotron/models/llama.py b/src/nanotron/models/llama.py index f558c102..87e4a883 100644 --- a/src/nanotron/models/llama.py +++ b/src/nanotron/models/llama.py @@ -806,14 +806,12 @@ def _core_forward( ) assert torch.cuda.current_stream() == torch.cuda.default_stream() - # torch.cuda.synchronize() comm_stream.wait_stream(torch.cuda.default_stream()) with torch.cuda.stream(comm_stream): attn_output0["work"].wait() attn_output0["work"].is_completed() assert torch.cuda.current_stream() == torch.cuda.default_stream() - # torch.cuda.synchronize() torch.cuda.default_stream().wait_stream(comm_stream) hidden_states0 = attn_output0["hidden_states"] + residual0 @@ -832,13 +830,11 @@ def _core_forward( ) assert torch.cuda.current_stream() == torch.cuda.default_stream() - # torch.cuda.synchronize() comm_stream.wait_stream(torch.cuda.default_stream()) with torch.cuda.stream(comm_stream): attn_output1["work"].wait() attn_output1["work"].is_completed() - # torch.cuda.synchronize() torch.cuda.default_stream().wait_stream(comm_stream) hidden_states1 = attn_output1["hidden_states"] + residual1 @@ -852,7 +848,6 @@ def _core_forward( ) assert torch.cuda.current_stream() == torch.cuda.default_stream() - # torch.cuda.synchronize() comm_stream.wait_stream(torch.cuda.default_stream()) with torch.cuda.stream(comm_stream): mlp_output0["work"].wait() @@ -862,7 +857,6 @@ def 
_core_forward( mlp_output1["work"].is_completed() assert torch.cuda.current_stream() == torch.cuda.default_stream() - # torch.cuda.synchronize() torch.cuda.default_stream().wait_stream(comm_stream) hidden_states0 = mlp_output0["hidden_states"] + residual0 diff --git a/src/nanotron/parallel/comm.py b/src/nanotron/parallel/comm.py index 15c3f64f..bb624b77 100644 --- a/src/nanotron/parallel/comm.py +++ b/src/nanotron/parallel/comm.py @@ -91,6 +91,13 @@ def forward(ctx, input, wait_handle_idx, comm_stream): @staticmethod def backward(ctx, grad_output): + """ + + NOTE: because the communication operation is already being executed + so the communication stream don't have to wait for the compute stream here + but the compute stream waits for the communication stream + before proceeding + """ # import pydevd # pydevd.settrace(suspend=False, trace_only_current_thread=True) if is_async_comm(ctx.wait_handle_idx): @@ -102,7 +109,7 @@ def backward(ctx, grad_output): handle.wait() assert 1 == 1 - torch.cuda.synchronize() + # torch.cuda.synchronize() assert torch.cuda.current_stream() == torch.cuda.default_stream() torch.cuda.default_stream().wait_stream(ctx.comm_stream) From a772ff06c7ae0bd0d231a3b9772aed8d599283b2 Mon Sep 17 00:00:00 2001 From: Phuc Nguyen Date: Wed, 26 Feb 2025 15:42:51 +0000 Subject: [PATCH 20/40] add back torch.cuda.synchronize in WaitComm.backward and small refactors --- src/nanotron/models/llama.py | 2 -- src/nanotron/parallel/comm.py | 6 +----- src/nanotron/parallel/tensor_parallel/domino.py | 7 ++----- src/nanotron/parallel/tensor_parallel/functional.py | 1 - src/nanotron/trainer.py | 9 ++------- 5 files changed, 5 insertions(+), 20 deletions(-) diff --git a/src/nanotron/models/llama.py b/src/nanotron/models/llama.py index 87e4a883..9ec70b24 100644 --- a/src/nanotron/models/llama.py +++ b/src/nanotron/models/llama.py @@ -53,8 +53,6 @@ logger = logging.get_logger(__name__) -DOMINO_COMM_STREAM = "domino_comm_stream_{}" - class RotaryEmbedding(nn.Module): def __init__(self, dim: int, end: int, theta: float = 10000.0): diff --git a/src/nanotron/parallel/comm.py b/src/nanotron/parallel/comm.py index bb624b77..42c70b41 100644 --- a/src/nanotron/parallel/comm.py +++ b/src/nanotron/parallel/comm.py @@ -101,15 +101,11 @@ def backward(ctx, grad_output): # import pydevd # pydevd.settrace(suspend=False, trace_only_current_thread=True) if is_async_comm(ctx.wait_handle_idx): - if "bwd.layer_mlp_27_batch_1" == ctx.wait_handle_idx: - assert 1 == 1 - handle = AsyncCommBucket.pop(ctx.wait_handle_idx) assert handle is not None handle.wait() - assert 1 == 1 - # torch.cuda.synchronize() + torch.cuda.synchronize() assert torch.cuda.current_stream() == torch.cuda.default_stream() torch.cuda.default_stream().wait_stream(ctx.comm_stream) diff --git a/src/nanotron/parallel/tensor_parallel/domino.py b/src/nanotron/parallel/tensor_parallel/domino.py index 97dbe0dc..59bcbb76 100644 --- a/src/nanotron/parallel/tensor_parallel/domino.py +++ b/src/nanotron/parallel/tensor_parallel/domino.py @@ -6,14 +6,11 @@ BWD_MLP_HANDLE_IDX = "bwd.layer_mlp_{}_batch_{}" -def is_async_comm(x): +def is_async_comm(x: str) -> bool: NON_ASYNC_HANDLE_IDX = [ - # NOTE: execute all fwd's comm in async - # "fwd.layer_mlp_{}_batch_1", - # "bwd.layer_mlp_{}_batch_1", + "fwd.layer_mlp_{}_batch_1", "bwd.layer_attn_{}_batch_0", ] - # assert "fwd." 
not in x patterns = [p.replace("{}", r"\d+") for p in NON_ASYNC_HANDLE_IDX] # Replace {} with regex for numbers regex = re.compile("^(" + "|".join(patterns) + ")$") # Combine patterns into a single regex diff --git a/src/nanotron/parallel/tensor_parallel/functional.py b/src/nanotron/parallel/tensor_parallel/functional.py index 005f4624..a5256006 100644 --- a/src/nanotron/parallel/tensor_parallel/functional.py +++ b/src/nanotron/parallel/tensor_parallel/functional.py @@ -41,7 +41,6 @@ def forward( logits_max = torch.max(sharded_logits, dim=-1)[0] dist.all_reduce(logits_max, op=dist.ReduceOp.MAX, group=group) # Subtract the maximum value. - # sharded_logits = sharded_logits - logits_max.unsqueeze(dim=-1) sharded_logits.sub_(logits_max.unsqueeze(dim=-1)) # Get the shard's indices diff --git a/src/nanotron/trainer.py b/src/nanotron/trainer.py index a1c71a09..8f4401aa 100644 --- a/src/nanotron/trainer.py +++ b/src/nanotron/trainer.py @@ -582,8 +582,6 @@ def training_step( from nanotron.parallel.comm import AsyncCommBucket from nanotron.parallel.tensor_parallel.domino import is_async_comm - # import torch.distributed as dist - not_finished = [] for k, v in AsyncCommBucket._copy_async_op.items(): # assert v.is_completed(), f"AsyncCommBucket._copy_async_op: {AsyncCommBucket._copy_async_op}" @@ -591,11 +589,8 @@ def training_step( if v.is_completed() is not True: not_finished.append((k, v)) - # if dist.get_rank() == 0 and constants._NOT_BWD_ASYNC_OPS: - # assert 1 == 1 - - assert len(not_finished) == 0, f"AsyncCommBucket._copy_async_op: {not_finished}" - assert len(AsyncCommBucket._async_op) == 0, f"AsyncCommBucket._async_op: {AsyncCommBucket._async_op}" + # assert len(not_finished) == 0, f"AsyncCommBucket._copy_async_op: {not_finished}" + # assert len(AsyncCommBucket._async_op) == 0, f"AsyncCommBucket._async_op: {AsyncCommBucket._async_op}" AsyncCommBucket.clear_all() return outputs, loss_avg From 543ef564a02ae19b4654cf2590f711c65e186d70 Mon Sep 17 00:00:00 2001 From: Phuc Nguyen Date: Thu, 27 Feb 2025 13:40:02 +0000 Subject: [PATCH 21/40] add ctx.comm_stream.wait_stream(torch.cuda.default_stream()) to WaitComm, and remove torch.cuda.synchronize() in WaitComm --- src/nanotron/models/llama.py | 2 ++ src/nanotron/parallel/comm.py | 3 ++- .../tensor_parallel/distributed_differentiable_primitives.py | 1 + src/nanotron/parallel/tensor_parallel/domino.py | 2 +- src/nanotron/trainer.py | 2 +- 5 files changed, 7 insertions(+), 3 deletions(-) diff --git a/src/nanotron/models/llama.py b/src/nanotron/models/llama.py index 9ec70b24..185469d3 100644 --- a/src/nanotron/models/llama.py +++ b/src/nanotron/models/llama.py @@ -856,6 +856,8 @@ def _core_forward( assert torch.cuda.current_stream() == torch.cuda.default_stream() torch.cuda.default_stream().wait_stream(comm_stream) + # NOTE: before concat, we need to synchronize the streams + # torch.cuda.synchronize() hidden_states0 = mlp_output0["hidden_states"] + residual0 hidden_states1 = mlp_output1["hidden_states"] + residual1 diff --git a/src/nanotron/parallel/comm.py b/src/nanotron/parallel/comm.py index 42c70b41..e9419e23 100644 --- a/src/nanotron/parallel/comm.py +++ b/src/nanotron/parallel/comm.py @@ -101,11 +101,12 @@ def backward(ctx, grad_output): # import pydevd # pydevd.settrace(suspend=False, trace_only_current_thread=True) if is_async_comm(ctx.wait_handle_idx): + ctx.comm_stream.wait_stream(torch.cuda.default_stream()) handle = AsyncCommBucket.pop(ctx.wait_handle_idx) assert handle is not None handle.wait() - torch.cuda.synchronize() + # 
torch.cuda.synchronize() assert torch.cuda.current_stream() == torch.cuda.default_stream() torch.cuda.default_stream().wait_stream(ctx.comm_stream) diff --git a/src/nanotron/parallel/tensor_parallel/distributed_differentiable_primitives.py b/src/nanotron/parallel/tensor_parallel/distributed_differentiable_primitives.py index e7f7f3fd..89dd499b 100644 --- a/src/nanotron/parallel/tensor_parallel/distributed_differentiable_primitives.py +++ b/src/nanotron/parallel/tensor_parallel/distributed_differentiable_primitives.py @@ -81,6 +81,7 @@ def forward( with comm_context: if async_all_reduce is True: + assert comm_stream is not None handle = dist.all_reduce(tensor, op=dist.ReduceOp.SUM, group=group, async_op=True) AsyncCommBucket.add(handle_idx, handle) else: diff --git a/src/nanotron/parallel/tensor_parallel/domino.py b/src/nanotron/parallel/tensor_parallel/domino.py index 59bcbb76..f47e34fc 100644 --- a/src/nanotron/parallel/tensor_parallel/domino.py +++ b/src/nanotron/parallel/tensor_parallel/domino.py @@ -8,7 +8,7 @@ def is_async_comm(x: str) -> bool: NON_ASYNC_HANDLE_IDX = [ - "fwd.layer_mlp_{}_batch_1", + # "fwd.layer_mlp_{}_batch_1", "bwd.layer_attn_{}_batch_0", ] diff --git a/src/nanotron/trainer.py b/src/nanotron/trainer.py index 8f4401aa..6362fe43 100644 --- a/src/nanotron/trainer.py +++ b/src/nanotron/trainer.py @@ -418,7 +418,7 @@ def train( ], **kwargs, ) -> None: - # torch.cuda.set_sync_debug_mode("warn") + torch.cuda.set_sync_debug_mode("warn") self.pre_training(**kwargs) if self.config.checkpoints.save_initial_state and self.init_checkpoint_path is None: From 36c998069ff173eba9b81781d7891b6dea56279c Mon Sep 17 00:00:00 2001 From: Phuc Nguyen Date: Thu, 27 Feb 2025 13:54:41 +0000 Subject: [PATCH 22/40] exp7a10_like_exp7a6_but_remove_fwd_pass_cuda_syncronization_and_remove_cuda_syncronize_in_wait_comm_bwd_and_add_comm_syncronize_in_waitcomm_and_commit_543ef56 --- src/nanotron/parallel/comm.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/nanotron/parallel/comm.py b/src/nanotron/parallel/comm.py index e9419e23..1f6c0b71 100644 --- a/src/nanotron/parallel/comm.py +++ b/src/nanotron/parallel/comm.py @@ -101,12 +101,13 @@ def backward(ctx, grad_output): # import pydevd # pydevd.settrace(suspend=False, trace_only_current_thread=True) if is_async_comm(ctx.wait_handle_idx): - ctx.comm_stream.wait_stream(torch.cuda.default_stream()) + # ctx.comm_stream.wait_stream(torch.cuda.default_stream()) handle = AsyncCommBucket.pop(ctx.wait_handle_idx) assert handle is not None handle.wait() # torch.cuda.synchronize() + ctx.comm_stream.synchronize() assert torch.cuda.current_stream() == torch.cuda.default_stream() torch.cuda.default_stream().wait_stream(ctx.comm_stream) From 613eb16836b2883c9941f7b09ade4c094c085232 Mon Sep 17 00:00:00 2001 From: Phuc Nguyen Date: Fri, 28 Feb 2025 15:00:09 +0000 Subject: [PATCH 23/40] remove comments and add typing --- src/nanotron/models/llama.py | 10 ------ src/nanotron/parallel/comm.py | 7 +--- .../distributed_differentiable_primitives.py | 36 +++++++++++-------- .../parallel/tensor_parallel/functional.py | 6 ++-- src/nanotron/parallel/tensor_parallel/nn.py | 16 +++++++-- src/nanotron/trainer.py | 14 +------- 6 files changed, 41 insertions(+), 48 deletions(-) diff --git a/src/nanotron/models/llama.py b/src/nanotron/models/llama.py index 185469d3..0151f8e9 100644 --- a/src/nanotron/models/llama.py +++ b/src/nanotron/models/llama.py @@ -803,13 +803,10 @@ def _core_forward( comm_stream=comm_stream, ) - assert torch.cuda.current_stream() == 
torch.cuda.default_stream() comm_stream.wait_stream(torch.cuda.default_stream()) with torch.cuda.stream(comm_stream): attn_output0["work"].wait() attn_output0["work"].is_completed() - - assert torch.cuda.current_stream() == torch.cuda.default_stream() torch.cuda.default_stream().wait_stream(comm_stream) hidden_states0 = attn_output0["hidden_states"] + residual0 @@ -827,12 +824,10 @@ def _core_forward( comm_stream=comm_stream, ) - assert torch.cuda.current_stream() == torch.cuda.default_stream() comm_stream.wait_stream(torch.cuda.default_stream()) with torch.cuda.stream(comm_stream): attn_output1["work"].wait() attn_output1["work"].is_completed() - torch.cuda.default_stream().wait_stream(comm_stream) hidden_states1 = attn_output1["hidden_states"] + residual1 @@ -845,7 +840,6 @@ def _core_forward( comm_stream=comm_stream, ) - assert torch.cuda.current_stream() == torch.cuda.default_stream() comm_stream.wait_stream(torch.cuda.default_stream()) with torch.cuda.stream(comm_stream): mlp_output0["work"].wait() @@ -853,11 +847,7 @@ def _core_forward( mlp_output0["work"].is_completed() mlp_output1["work"].is_completed() - - assert torch.cuda.current_stream() == torch.cuda.default_stream() torch.cuda.default_stream().wait_stream(comm_stream) - # NOTE: before concat, we need to synchronize the streams - # torch.cuda.synchronize() hidden_states0 = mlp_output0["hidden_states"] + residual0 hidden_states1 = mlp_output1["hidden_states"] + residual1 diff --git a/src/nanotron/parallel/comm.py b/src/nanotron/parallel/comm.py index 1f6c0b71..47deb265 100644 --- a/src/nanotron/parallel/comm.py +++ b/src/nanotron/parallel/comm.py @@ -72,6 +72,7 @@ def is_all_completed() -> bool: not_finished = [] for k, v in AsyncCommBucket._copy_async_op.items(): + assert is_async_comm(k) is True, f"Operation with name {k} wasn't executed asynchronously!" if v.is_completed() is not True: not_finished.append((k, v)) return len(not_finished) == 0 @@ -92,23 +93,17 @@ def forward(ctx, input, wait_handle_idx, comm_stream): @staticmethod def backward(ctx, grad_output): """ - NOTE: because the communication operation is already being executed so the communication stream don't have to wait for the compute stream here but the compute stream waits for the communication stream before proceeding """ - # import pydevd - # pydevd.settrace(suspend=False, trace_only_current_thread=True) if is_async_comm(ctx.wait_handle_idx): - # ctx.comm_stream.wait_stream(torch.cuda.default_stream()) handle = AsyncCommBucket.pop(ctx.wait_handle_idx) assert handle is not None handle.wait() - # torch.cuda.synchronize() ctx.comm_stream.synchronize() - assert torch.cuda.current_stream() == torch.cuda.default_stream() torch.cuda.default_stream().wait_stream(ctx.comm_stream) return grad_output, None, None diff --git a/src/nanotron/parallel/tensor_parallel/distributed_differentiable_primitives.py b/src/nanotron/parallel/tensor_parallel/distributed_differentiable_primitives.py index 89dd499b..3d06413b 100644 --- a/src/nanotron/parallel/tensor_parallel/distributed_differentiable_primitives.py +++ b/src/nanotron/parallel/tensor_parallel/distributed_differentiable_primitives.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. 
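# Minimal sketch of the comm-stream overlap pattern the refactored primitives below
# implement, assuming an initialized process group, a CUDA device, and a dedicated
# communication stream. `async_all_reduce_on_stream`, `tensor`, `group` and `op_name`
# are placeholder names for illustration, not part of this patch.
import torch
import torch.distributed as dist

from nanotron.parallel.comm import AsyncCommBucket


def async_all_reduce_on_stream(tensor, group, op_name, comm_stream):
    # The comm stream first waits for the default stream so `tensor` is fully
    # produced, then the all-reduce is launched asynchronously on the comm
    # stream and its handle is parked in the bucket for a later wait.
    comm_stream.wait_stream(torch.cuda.default_stream())
    with torch.cuda.stream(comm_stream):
        handle = dist.all_reduce(tensor, op=dist.ReduceOp.SUM, group=group, async_op=True)
        AsyncCommBucket.add(op_name, handle)
    return tensor


# Once the reduced value is actually needed:
#     AsyncCommBucket.wait(op_name)                       # pops the handle and waits on it
#     torch.cuda.default_stream().wait_stream(comm_stream)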
+from contextlib import nullcontext from typing import Optional, Tuple import torch @@ -26,7 +27,14 @@ class DifferentiableIdentity(torch.autograd.Function): """All-reduce gradients in a differentiable fashion""" @staticmethod - def forward(ctx, tensor, group: Optional[ProcessGroup], async_all_reduce: bool, handle_idx=None, comm_stream=None): + def forward( + ctx, + tensor: torch.Tensor, + group: Optional[ProcessGroup], + async_all_reduce: bool, + handle_idx: Optional[str] = None, + comm_stream: Optional[torch.cuda.Stream] = None, + ): ctx.async_all_reduce = async_all_reduce ctx.handle_idx = handle_idx ctx.group = group @@ -34,17 +42,11 @@ def forward(ctx, tensor, group: Optional[ProcessGroup], async_all_reduce: bool, return tensor @staticmethod - def backward(ctx, grad_output): - # import pydevd - # pydevd.settrace(suspend=False, trace_only_current_thread=True) - + def backward(ctx, grad_output: torch.Tensor): group = ctx.group handle_idx = ctx.handle_idx.replace("fwd.", "bwd.") if ctx.handle_idx is not None else None async_all_reduce = is_async_comm(handle_idx) if handle_idx is not None else ctx.async_all_reduce - if handle_idx is not None and "bwd.layer_mlp_" in handle_idx and "batch_1" in handle_idx: - assert 1 == 1 - return ( DifferentiableAllReduceSum.apply(grad_output, group, async_all_reduce, handle_idx, ctx.comm_stream), None, @@ -60,14 +62,12 @@ class DifferentiableAllReduceSum(torch.autograd.Function): @staticmethod def forward( ctx, - tensor, + tensor: torch.Tensor, group: Optional[ProcessGroup], async_all_reduce: bool, handle_idx: Optional[int] = None, - comm_stream=None, + comm_stream: Optional[torch.cuda.Stream] = None, ) -> Tuple[torch.Tensor, Optional["dist.Work"]]: - from contextlib import nullcontext - ctx.async_all_reduce = async_all_reduce if group.size() == 1: @@ -174,13 +174,21 @@ def backward(ctx, grad_output): def differentiable_identity( - tensor, group: Optional[ProcessGroup] = None, async_all_reduce: bool = False, handle_idx=None, comm_stream=None + tensor, + group: Optional[ProcessGroup] = None, + async_all_reduce: bool = False, + handle_idx: Optional[str] = None, + comm_stream: Optional[torch.cuda.Stream] = None, ): return DifferentiableIdentity.apply(tensor, group, async_all_reduce, handle_idx, comm_stream) def differentiable_all_reduce_sum( - tensor, group: Optional[ProcessGroup] = None, async_all_reduce: bool = False, handle_idx=None, comm_stream=None + tensor, + group: Optional[ProcessGroup] = None, + async_all_reduce: bool = False, + handle_idx: Optional[str] = None, + comm_stream: Optional[torch.cuda.Stream] = None, ): return DifferentiableAllReduceSum.apply(tensor, group, async_all_reduce, handle_idx, comm_stream) diff --git a/src/nanotron/parallel/tensor_parallel/functional.py b/src/nanotron/parallel/tensor_parallel/functional.py index a5256006..291923a9 100644 --- a/src/nanotron/parallel/tensor_parallel/functional.py +++ b/src/nanotron/parallel/tensor_parallel/functional.py @@ -438,7 +438,7 @@ def column_linear( async_communication: bool, tp_recompute_allgather: bool = True, async_all_reduce: bool = False, - handle_idx: Optional[int] = None, + handle_idx: Optional[str] = None, comm_stream: Optional[torch.cuda.Stream] = None, ): if async_communication: @@ -596,8 +596,8 @@ def row_linear( # TODO(xrsrke): use less confusing names for these arguments async_communication: bool, async_all_reduce: bool, - handle_idx=None, - comm_stream=None, + handle_idx: Optional[str] = None, + comm_stream: Optional[torch.cuda.Stream] = None, ) -> Tuple[torch.Tensor, 
Optional[torch.Future]]: if async_communication: work = None diff --git a/src/nanotron/parallel/tensor_parallel/nn.py b/src/nanotron/parallel/tensor_parallel/nn.py index c27b5b87..bba46523 100644 --- a/src/nanotron/parallel/tensor_parallel/nn.py +++ b/src/nanotron/parallel/tensor_parallel/nn.py @@ -85,7 +85,13 @@ def __init__( split_config=split_config, ) - def forward(self, x: torch.Tensor, async_all_reduce=None, handle_idx=None, comm_stream=None) -> torch.Tensor: + def forward( + self, + x: torch.Tensor, + async_all_reduce=None, + handle_idx: Optional[str] = None, + comm_stream: Optional[torch.cuda.Stream] = None, + ) -> torch.Tensor: return column_linear( input=x, weight=self.weight, @@ -161,7 +167,13 @@ def _mark_all_parameters_in_module_as_sharded(self, split_config: SplitConfig): ) setattr(self, name, new_param) - def forward(self, x: torch.Tensor, async_all_reduce, handle_idx=None, comm_stream=None) -> torch.Tensor: + def forward( + self, + x: torch.Tensor, + async_all_reduce, + handle_idx: Optional[str] = None, + comm_stream: Optional[torch.cuda.Stream] = None, + ) -> torch.Tensor: return row_linear( input=x, weight=self.weight, diff --git a/src/nanotron/trainer.py b/src/nanotron/trainer.py index 6362fe43..a11b91a2 100644 --- a/src/nanotron/trainer.py +++ b/src/nanotron/trainer.py @@ -61,6 +61,7 @@ from nanotron.models.starcoder2 import Starcoder2ForTraining from nanotron.optim.clip_grads import clip_grad_norm from nanotron.parallel import ParallelContext +from nanotron.parallel.comm import AsyncCommBucket from nanotron.parallel.data_parallel.utils import sync_gradients_across_dp from nanotron.parallel.parameters import NanotronParameter, sanity_check from nanotron.parallel.pipeline_parallel.engine import ( @@ -418,7 +419,6 @@ def train( ], **kwargs, ) -> None: - torch.cuda.set_sync_debug_mode("warn") self.pre_training(**kwargs) if self.config.checkpoints.save_initial_state and self.init_checkpoint_path is None: @@ -579,18 +579,6 @@ def training_step( self.post_train_step() - from nanotron.parallel.comm import AsyncCommBucket - from nanotron.parallel.tensor_parallel.domino import is_async_comm - - not_finished = [] - for k, v in AsyncCommBucket._copy_async_op.items(): - # assert v.is_completed(), f"AsyncCommBucket._copy_async_op: {AsyncCommBucket._copy_async_op}" - assert is_async_comm(k) is True, f"k: {k}" - if v.is_completed() is not True: - not_finished.append((k, v)) - - # assert len(not_finished) == 0, f"AsyncCommBucket._copy_async_op: {not_finished}" - # assert len(AsyncCommBucket._async_op) == 0, f"AsyncCommBucket._async_op: {AsyncCommBucket._async_op}" AsyncCommBucket.clear_all() return outputs, loss_avg From 600f01ab735ca66c40ad7ee09da420b8cf1704bc Mon Sep 17 00:00:00 2001 From: Phuc Nguyen Date: Fri, 28 Feb 2025 16:01:58 +0000 Subject: [PATCH 24/40] remove explicite async_op arg --- src/nanotron/models/llama.py | 42 +++++++++---------- src/nanotron/parallel/comm.py | 14 +++---- .../distributed_differentiable_primitives.py | 25 ++++++----- .../parallel/tensor_parallel/domino.py | 8 ++-- .../parallel/tensor_parallel/functional.py | 32 +++++++------- src/nanotron/parallel/tensor_parallel/nn.py | 14 +++---- 6 files changed, 66 insertions(+), 69 deletions(-) diff --git a/src/nanotron/models/llama.py b/src/nanotron/models/llama.py index 0151f8e9..cb612c2c 100644 --- a/src/nanotron/models/llama.py +++ b/src/nanotron/models/llama.py @@ -35,10 +35,10 @@ from nanotron.parallel.pipeline_parallel.block import PipelineBlock, TensorPointer from nanotron.parallel.pipeline_parallel.p2p 
import P2P from nanotron.parallel.tensor_parallel.domino import ( - BWD_ATTN_HANDLE_IDX, - BWD_MLP_HANDLE_IDX, - FWD_ATTN_HANDLE_IDX, - FWD_MLP_HANDLE_IDX, + BWD_ATTN_OP_NAME, + BWD_MLP_OP_NAME, + FWD_ATTN_OP_NAME, + FWD_MLP_OP_NAME, ) from nanotron.parallel.tensor_parallel.functional import sharded_cross_entropy from nanotron.parallel.tensor_parallel.nn import ( @@ -247,12 +247,12 @@ def __init__( ) self.split_silu_mul = GLUActivation(config.hidden_act) - def forward(self, hidden_states, handle_idx=None, comm_stream=None): # [seq_length, batch_size, hidden_dim] - merged_states = self.gate_up_proj( - hidden_states, async_all_reduce=True, handle_idx=handle_idx, comm_stream=comm_stream - ) + def forward( + self, hidden_states, op_name: Optional[str] = None, comm_stream: Optional[torch.cuda.Stream] = None + ): # [seq_length, batch_size, hidden_dim] + merged_states = self.gate_up_proj(hidden_states, op_name=op_name, comm_stream=comm_stream) hidden_states, work = self.down_proj( - self.split_silu_mul(merged_states), async_all_reduce=True, handle_idx=handle_idx, comm_stream=comm_stream + self.split_silu_mul(merged_states), op_name=op_name, comm_stream=comm_stream ) return {"hidden_states": hidden_states, "work": work} @@ -445,8 +445,8 @@ def forward( self, hidden_states, # [seq_length, batch_size, hidden_size] sequence_mask, # [batch_size, seq_length] - handle_idx=None, - comm_stream=None, + op_name: Optional[str] = None, + comm_stream: Optional[torch.cuda.Stream] = None, ): from flash_attn import bert_padding from flash_attn.flash_attn_interface import ( @@ -455,7 +455,7 @@ def forward( ) qkv_states = self.qkv_proj( - hidden_states, async_all_reduce=True, handle_idx=handle_idx, comm_stream=comm_stream + hidden_states, op_name=op_name, comm_stream=comm_stream ) # [seq_length, batch_size, n_local_q_heads * d_qk + 2 * n_local_kv_heads * d_qk] q_length, batch_size, _ = qkv_states.shape @@ -700,9 +700,7 @@ def forward( attention_output = ( attention_output.contiguous().view(batch_size, q_length, self.n_local_q_heads * self.d_v).transpose(0, 1) ) - output, work = self.o_proj( - attention_output, async_all_reduce=True, handle_idx=handle_idx, comm_stream=comm_stream - ) + output, work = self.o_proj(attention_output, op_name=op_name, comm_stream=comm_stream) return {"hidden_states": output, "work": work, "sequence_mask": sequence_mask} @@ -781,25 +779,25 @@ def _core_forward( hidden_states1 = self.input_layernorm(hidden_states1) hidden_states0 = WaitComm.apply( hidden_states0, - BWD_ATTN_HANDLE_IDX.format(self.layer_idx, 1), + BWD_ATTN_OP_NAME.format(self.layer_idx, 1), comm_stream, ) hidden_states1 = WaitComm.apply( hidden_states1, - BWD_MLP_HANDLE_IDX.format(self.layer_idx, 0), + BWD_MLP_OP_NAME.format(self.layer_idx, 0), comm_stream, ) attn_output0 = self.attn( hidden_states=hidden_states0, sequence_mask=sequence_mask0, - handle_idx=FWD_ATTN_HANDLE_IDX.format(self.layer_idx, 0), + op_name=FWD_ATTN_OP_NAME.format(self.layer_idx, 0), comm_stream=comm_stream, ) attn_output1 = self.attn( hidden_states=hidden_states1, sequence_mask=sequence_mask1, - handle_idx=FWD_ATTN_HANDLE_IDX.format(self.layer_idx, 1), + op_name=FWD_ATTN_OP_NAME.format(self.layer_idx, 1), comm_stream=comm_stream, ) @@ -814,13 +812,13 @@ def _core_forward( hidden_states0 = self.post_attention_layernorm(hidden_states0) hidden_states0 = WaitComm.apply( hidden_states0, - BWD_MLP_HANDLE_IDX.format(self.layer_idx, 1), + BWD_MLP_OP_NAME.format(self.layer_idx, 1), comm_stream, ) mlp_output0 = self.mlp( hidden_states=hidden_states0, - 
handle_idx=FWD_MLP_HANDLE_IDX.format(self.layer_idx, 0), + op_name=FWD_MLP_OP_NAME.format(self.layer_idx, 0), comm_stream=comm_stream, ) @@ -836,7 +834,7 @@ def _core_forward( mlp_output1 = self.mlp( hidden_states=hidden_states1, - handle_idx=FWD_MLP_HANDLE_IDX.format(self.layer_idx, 1), + op_name=FWD_MLP_OP_NAME.format(self.layer_idx, 1), comm_stream=comm_stream, ) diff --git a/src/nanotron/parallel/comm.py b/src/nanotron/parallel/comm.py index 47deb265..ef1f87f4 100644 --- a/src/nanotron/parallel/comm.py +++ b/src/nanotron/parallel/comm.py @@ -46,21 +46,21 @@ def add(op_name: int, work: "dist.Work"): AsyncCommBucket._copy_async_op[op_name] = work @staticmethod - def get(op_name: int): + def get(op_name: str) -> "dist.Work": if op_name not in AsyncCommBucket._async_op: raise KeyError(f"Operation with name: {op_name} doesn't exist") return AsyncCommBucket._async_op.get(op_name) @staticmethod - def pop(op_name: int): + def pop(op_name: str) -> "dist.Work": if op_name not in AsyncCommBucket._async_op: raise KeyError(f"Operation with name: {op_name} doesn't exist") return AsyncCommBucket._async_op.pop(op_name) @staticmethod - def wait(op_name: int): + def wait(op_name: str): """Wait and remove the operation from the bucket""" work = AsyncCommBucket.pop(op_name) work.wait() @@ -85,8 +85,8 @@ def clear_all(): class WaitComm(torch.autograd.Function): @staticmethod - def forward(ctx, input, wait_handle_idx, comm_stream): - ctx.wait_handle_idx = wait_handle_idx + def forward(ctx, input, op_name, comm_stream): + ctx.op_name = op_name ctx.comm_stream = comm_stream return input @@ -98,8 +98,8 @@ def backward(ctx, grad_output): but the compute stream waits for the communication stream before proceeding """ - if is_async_comm(ctx.wait_handle_idx): - handle = AsyncCommBucket.pop(ctx.wait_handle_idx) + if is_async_comm(ctx.op_name): + handle = AsyncCommBucket.pop(ctx.op_name) assert handle is not None handle.wait() diff --git a/src/nanotron/parallel/tensor_parallel/distributed_differentiable_primitives.py b/src/nanotron/parallel/tensor_parallel/distributed_differentiable_primitives.py index 3d06413b..7cc1dc06 100644 --- a/src/nanotron/parallel/tensor_parallel/distributed_differentiable_primitives.py +++ b/src/nanotron/parallel/tensor_parallel/distributed_differentiable_primitives.py @@ -32,11 +32,11 @@ def forward( tensor: torch.Tensor, group: Optional[ProcessGroup], async_all_reduce: bool, - handle_idx: Optional[str] = None, + op_name: Optional[str] = None, comm_stream: Optional[torch.cuda.Stream] = None, ): ctx.async_all_reduce = async_all_reduce - ctx.handle_idx = handle_idx + ctx.op_name = op_name ctx.group = group ctx.comm_stream = comm_stream return tensor @@ -44,11 +44,11 @@ def forward( @staticmethod def backward(ctx, grad_output: torch.Tensor): group = ctx.group - handle_idx = ctx.handle_idx.replace("fwd.", "bwd.") if ctx.handle_idx is not None else None - async_all_reduce = is_async_comm(handle_idx) if handle_idx is not None else ctx.async_all_reduce + op_name = ctx.op_name.replace("fwd.", "bwd.") if ctx.op_name is not None else None + # async_all_reduce = is_async_comm(op_name) if op_name is not None else ctx.async_all_reduce return ( - DifferentiableAllReduceSum.apply(grad_output, group, async_all_reduce, handle_idx, ctx.comm_stream), + DifferentiableAllReduceSum.apply(grad_output, group, op_name, ctx.comm_stream), None, None, None, @@ -64,10 +64,10 @@ def forward( ctx, tensor: torch.Tensor, group: Optional[ProcessGroup], - async_all_reduce: bool, - handle_idx: Optional[int] = None, + 
op_name: Optional[int] = None, comm_stream: Optional[torch.cuda.Stream] = None, ) -> Tuple[torch.Tensor, Optional["dist.Work"]]: + async_all_reduce = is_async_comm(op_name) if op_name is not None else False ctx.async_all_reduce = async_all_reduce if group.size() == 1: @@ -83,7 +83,7 @@ def forward( if async_all_reduce is True: assert comm_stream is not None handle = dist.all_reduce(tensor, op=dist.ReduceOp.SUM, group=group, async_op=True) - AsyncCommBucket.add(handle_idx, handle) + AsyncCommBucket.add(op_name, handle) else: dist.all_reduce(tensor, op=dist.ReduceOp.SUM, group=group) @@ -177,20 +177,19 @@ def differentiable_identity( tensor, group: Optional[ProcessGroup] = None, async_all_reduce: bool = False, - handle_idx: Optional[str] = None, + op_name: Optional[str] = None, comm_stream: Optional[torch.cuda.Stream] = None, ): - return DifferentiableIdentity.apply(tensor, group, async_all_reduce, handle_idx, comm_stream) + return DifferentiableIdentity.apply(tensor, group, async_all_reduce, op_name, comm_stream) def differentiable_all_reduce_sum( tensor, group: Optional[ProcessGroup] = None, - async_all_reduce: bool = False, - handle_idx: Optional[str] = None, + op_name: Optional[str] = None, comm_stream: Optional[torch.cuda.Stream] = None, ): - return DifferentiableAllReduceSum.apply(tensor, group, async_all_reduce, handle_idx, comm_stream) + return DifferentiableAllReduceSum.apply(tensor, group, op_name, comm_stream) def differentiable_all_gather(tensor, group: Optional[ProcessGroup] = None): diff --git a/src/nanotron/parallel/tensor_parallel/domino.py b/src/nanotron/parallel/tensor_parallel/domino.py index f47e34fc..7c4a10b8 100644 --- a/src/nanotron/parallel/tensor_parallel/domino.py +++ b/src/nanotron/parallel/tensor_parallel/domino.py @@ -1,9 +1,9 @@ import re -FWD_MLP_HANDLE_IDX = "fwd.layer_mlp_{}_batch_{}" -FWD_ATTN_HANDLE_IDX = "fwd.layer_attn_{}_batch_{}" -BWD_ATTN_HANDLE_IDX = "bwd.layer_attn_{}_batch_{}" -BWD_MLP_HANDLE_IDX = "bwd.layer_mlp_{}_batch_{}" +FWD_MLP_OP_NAME = "fwd.layer_mlp_{}_batch_{}" +FWD_ATTN_OP_NAME = "fwd.layer_attn_{}_batch_{}" +BWD_ATTN_OP_NAME = "bwd.layer_attn_{}_batch_{}" +BWD_MLP_OP_NAME = "bwd.layer_mlp_{}_batch_{}" def is_async_comm(x: str) -> bool: diff --git a/src/nanotron/parallel/tensor_parallel/functional.py b/src/nanotron/parallel/tensor_parallel/functional.py index 291923a9..ac645dfb 100644 --- a/src/nanotron/parallel/tensor_parallel/functional.py +++ b/src/nanotron/parallel/tensor_parallel/functional.py @@ -25,6 +25,7 @@ differentiable_identity, differentiable_reduce_scatter_sum, ) +from nanotron.parallel.tensor_parallel.domino import is_async_comm from nanotron.parallel.tensor_parallel.enum import TensorParallelLinearMode from nanotron.parallel.utils import MemoryBuffer, assert_cuda_max_connections_set_to_1 @@ -437,17 +438,19 @@ def column_linear( tp_mode: TensorParallelLinearMode, async_communication: bool, tp_recompute_allgather: bool = True, - async_all_reduce: bool = False, - handle_idx: Optional[str] = None, + op_name: Optional[str] = None, comm_stream: Optional[torch.cuda.Stream] = None, ): + is_async_all_reduce = is_async_comm(op_name) if op_name is not None else False + assert not ( + is_async_all_reduce and async_communication + ), "DoMiNo isn't support weight's async communication for column linear." 
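A small stand-alone sketch of the op-name bookkeeping these hunks introduce: each forward/backward communication is named per layer and per input split, and the async all-reduce handle is parked under that name so a later point in the schedule can retrieve it. The dict here is a stand-in for AsyncCommBucket, and the process group is assumed to be initialized; only the template string is taken verbatim from domino.py above.

    import torch.distributed as dist

    FWD_ATTN_OP_NAME = "fwd.layer_attn_{}_batch_{}"   # same template as in domino.py

    async_op_bucket = {}   # op_name -> dist.Work, stand-in for AsyncCommBucket

    def all_reduce_async(tensor, group, op_name):
        # launch the all-reduce without blocking and remember the handle under its name
        handle = dist.all_reduce(tensor, op=dist.ReduceOp.SUM, group=group, async_op=True)
        async_op_bucket[op_name] = handle
        return tensor

    # e.g. attention output of layer 3, input split 0:
    # all_reduce_async(out, tp_group, FWD_ATTN_OP_NAME.format(3, 0))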
+ if async_communication: return _ColumnLinearAsyncCommunication.apply(input, weight, bias, group, tp_mode, tp_recompute_allgather) if tp_mode is TensorParallelLinearMode.ALL_REDUCE: - input = differentiable_identity( - input, group=group, async_all_reduce=async_all_reduce, handle_idx=handle_idx, comm_stream=comm_stream - ) + input = differentiable_identity(input, group=group, op_name=op_name, comm_stream=comm_stream) return F.linear(input, weight, bias) if tp_mode is TensorParallelLinearMode.REDUCE_SCATTER: return _ColumnLinearNoAsyncCommunicationReduceScatterMode.apply( @@ -593,27 +596,28 @@ def row_linear( bias: Optional[torch.Tensor], group: dist.ProcessGroup, tp_mode: TensorParallelLinearMode, - # TODO(xrsrke): use less confusing names for these arguments async_communication: bool, - async_all_reduce: bool, - handle_idx: Optional[str] = None, + op_name: Optional[str] = None, comm_stream: Optional[torch.cuda.Stream] = None, ) -> Tuple[torch.Tensor, Optional[torch.Future]]: + is_async_all_reduce = is_async_comm(op_name) if op_name is not None else False + assert not ( + is_async_all_reduce and async_communication + ), "DoMiNo isn't support weight's async communication for row linear." + if async_communication: work = None out = _RowLinearAsyncCommunication.apply(input, weight, bias, group, tp_mode) else: out = F.linear(input, weight, bias) if tp_mode is TensorParallelLinearMode.ALL_REDUCE: - out = differentiable_all_reduce_sum( - out, group=group, async_all_reduce=async_all_reduce, handle_idx=handle_idx, comm_stream=comm_stream - ) - if async_all_reduce: - work = AsyncCommBucket.pop(handle_idx) + out = differentiable_all_reduce_sum(out, group=group, op_name=op_name, comm_stream=comm_stream) + if is_async_all_reduce: + work = AsyncCommBucket.pop(op_name) else: work = None elif tp_mode is TensorParallelLinearMode.REDUCE_SCATTER: - assert async_all_reduce is False, "Async communication is not supported for REDUCE_SCATTER mode." + assert is_async_all_reduce is False, "Async communication is not supported for REDUCE_SCATTER mode." 
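For readers less familiar with row-parallel linears, a single-process sketch of why row_linear ends in an all-reduce sum: each TP rank holds a column slice of the activation and a row slice of the weight, and the partial products sum to the full result. Plain matmul with an [in_features, out_features] weight layout is used here for clarity; the shapes are arbitrary.

    import torch

    x = torch.randn(4, 8)          # full activation
    w = torch.randn(8, 6)          # full weight, [in_features, out_features] for clarity
    x0, x1 = x[:, :4], x[:, 4:]    # what TP rank 0 / rank 1 hold
    w0, w1 = w[:4, :], w[4:, :]

    partial0, partial1 = x0 @ w0, x1 @ w1
    # the all-reduce sum reconstructs the full output on every rank
    assert torch.allclose(partial0 + partial1, x @ w, atol=1e-5)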
out = differentiable_reduce_scatter_sum(out, group=group) work = None else: diff --git a/src/nanotron/parallel/tensor_parallel/nn.py b/src/nanotron/parallel/tensor_parallel/nn.py index bba46523..9ad2a51b 100644 --- a/src/nanotron/parallel/tensor_parallel/nn.py +++ b/src/nanotron/parallel/tensor_parallel/nn.py @@ -88,8 +88,7 @@ def __init__( def forward( self, x: torch.Tensor, - async_all_reduce=None, - handle_idx: Optional[str] = None, + op_name: Optional[str] = None, comm_stream: Optional[torch.cuda.Stream] = None, ) -> torch.Tensor: return column_linear( @@ -100,8 +99,7 @@ def forward( tp_mode=self.mode, async_communication=self.async_communication, tp_recompute_allgather=self.tp_recompute_allgather, - async_all_reduce=async_all_reduce, - handle_idx=handle_idx, + op_name=op_name, comm_stream=comm_stream, ) @@ -170,8 +168,7 @@ def _mark_all_parameters_in_module_as_sharded(self, split_config: SplitConfig): def forward( self, x: torch.Tensor, - async_all_reduce, - handle_idx: Optional[str] = None, + op_name: Optional[str] = None, comm_stream: Optional[torch.cuda.Stream] = None, ) -> torch.Tensor: return row_linear( @@ -181,8 +178,7 @@ def forward( group=self.pg, tp_mode=self.mode, async_communication=self.async_communication, - async_all_reduce=async_all_reduce, - handle_idx=handle_idx, + op_name=op_name, comm_stream=comm_stream, ) @@ -308,7 +304,7 @@ def forward(self, input_ids: torch.Tensor) -> torch.Tensor: out = out * (~input_mask[..., None]) if self.mode is TensorParallelLinearMode.ALL_REDUCE: - out = differentiable_all_reduce_sum(out, group=self.pg, async_all_reduce=False) + out = differentiable_all_reduce_sum(out, group=self.pg) elif self.mode is TensorParallelLinearMode.REDUCE_SCATTER: out = differentiable_reduce_scatter_sum(out, group=self.pg) else: From 29a8914630679969c8914f9745a0ac71e37176d6 Mon Sep 17 00:00:00 2001 From: Phuc Nguyen Date: Fri, 7 Mar 2025 12:16:06 +0000 Subject: [PATCH 25/40] pass stream amanger to llama's modules --- src/nanotron/models/llama.py | 49 ++++++++++++------- src/nanotron/parallel/comm.py | 42 +++++++--------- .../distributed_differentiable_primitives.py | 6 +-- src/nanotron/parallel/tensor_parallel/nn.py | 16 +++--- tests/test_comm.py | 10 +++- 5 files changed, 66 insertions(+), 57 deletions(-) diff --git a/src/nanotron/models/llama.py b/src/nanotron/models/llama.py index cb612c2c..85938726 100644 --- a/src/nanotron/models/llama.py +++ b/src/nanotron/models/llama.py @@ -214,6 +214,7 @@ def __init__( config: LlamaConfig, parallel_config: Optional[ParallelismArgs], tp_pg: dist.ProcessGroup, + stream_manager: Optional[CudaStreamManager] = None, ): super().__init__() @@ -227,6 +228,7 @@ def __init__( config.intermediate_size, # shape of gate_linear config.intermediate_size, # shape of up_linear ) + comm_stream = stream_manager.get(f"comm_stream_{torch.cuda.current_device()}") self.gate_up_proj = TensorParallelColumnLinear( config.hidden_size, 2 * config.intermediate_size, @@ -236,6 +238,7 @@ def __init__( async_communication=tp_linear_async_communication, contiguous_chunks=gate_up_contiguous_chunks, tp_recompute_allgather=parallel_config.tp_recompute_allgather, + comm_stream=comm_stream, ) self.down_proj = TensorParallelRowLinear( config.intermediate_size, @@ -244,16 +247,15 @@ def __init__( mode=tp_mode, bias=False, async_communication=tp_linear_async_communication and tp_mode is TensorParallelLinearMode.REDUCE_SCATTER, + comm_stream=comm_stream, ) self.split_silu_mul = GLUActivation(config.hidden_act) def forward( - self, hidden_states, op_name: 
Optional[str] = None, comm_stream: Optional[torch.cuda.Stream] = None + self, hidden_states: torch.Tensor, op_name: Optional[str] = None ): # [seq_length, batch_size, hidden_dim] - merged_states = self.gate_up_proj(hidden_states, op_name=op_name, comm_stream=comm_stream) - hidden_states, work = self.down_proj( - self.split_silu_mul(merged_states), op_name=op_name, comm_stream=comm_stream - ) + merged_states = self.gate_up_proj(hidden_states, op_name=op_name) + hidden_states, work = self.down_proj(self.split_silu_mul(merged_states), op_name=op_name) return {"hidden_states": hidden_states, "work": work} @@ -346,6 +348,7 @@ def __init__( parallel_config: Optional[ParallelismArgs], tp_pg: dist.ProcessGroup, layer_idx: int, + stream_manager: Optional[CudaStreamManager] = None, ): from flash_attn.layers.rotary import RotaryEmbedding as FlashRotaryEmbedding @@ -392,6 +395,7 @@ def __init__( config.num_key_value_heads * self.d_qk, # shape of k config.num_key_value_heads * self.d_qk, # shape of v ) + comm_stream = stream_manager.get(f"comm_stream_{torch.cuda.current_device()}") self.qkv_proj = TensorParallelColumnLinear( self.d_model, config.num_attention_heads * self.d_qk + 2 * config.num_key_value_heads * self.d_qk, @@ -401,6 +405,7 @@ def __init__( async_communication=tp_linear_async_communication, contiguous_chunks=qkv_contiguous_chunks, tp_recompute_allgather=parallel_config.tp_recompute_allgather, + comm_stream=comm_stream, ) # TODO(kunhao): We want to have only one version per device and not one version per layer. if config.rope_interleaved: @@ -429,6 +434,7 @@ def __init__( mode=tp_mode, bias=False, async_communication=tp_linear_async_communication, + comm_stream=comm_stream, ) self.attention = CoreAttention( @@ -445,8 +451,9 @@ def forward( self, hidden_states, # [seq_length, batch_size, hidden_size] sequence_mask, # [batch_size, seq_length] + # NOTE: because we dynamically determine which input split + # of domino at runtime, so we need to pass in the op_name op_name: Optional[str] = None, - comm_stream: Optional[torch.cuda.Stream] = None, ): from flash_attn import bert_padding from flash_attn.flash_attn_interface import ( @@ -455,7 +462,7 @@ def forward( ) qkv_states = self.qkv_proj( - hidden_states, op_name=op_name, comm_stream=comm_stream + hidden_states, op_name=op_name ) # [seq_length, batch_size, n_local_q_heads * d_qk + 2 * n_local_kv_heads * d_qk] q_length, batch_size, _ = qkv_states.shape @@ -700,7 +707,7 @@ def forward( attention_output = ( attention_output.contiguous().view(batch_size, q_length, self.n_local_q_heads * self.d_v).transpose(0, 1) ) - output, work = self.o_proj(attention_output, op_name=op_name, comm_stream=comm_stream) + output, work = self.o_proj(attention_output, op_name=op_name) return {"hidden_states": output, "work": work, "sequence_mask": sequence_mask} @@ -712,6 +719,7 @@ def __init__( parallel_config: Optional[ParallelismArgs], tp_pg: dist.ProcessGroup, layer_idx: int, + stream_manager: Optional[CudaStreamManager] = None, ): super().__init__() self.input_layernorm = TritonRMSNorm(config.hidden_size, eps=config.rms_norm_eps) @@ -721,14 +729,16 @@ def __init__( parallel_config=parallel_config, tp_pg=tp_pg, layer_idx=layer_idx, + stream_manager=stream_manager, ) self.post_attention_layernorm = TritonRMSNorm(config.hidden_size, eps=config.rms_norm_eps) - self.mlp = MLP(config=config, parallel_config=parallel_config, tp_pg=tp_pg) + self.mlp = MLP(config=config, parallel_config=parallel_config, tp_pg=tp_pg, stream_manager=stream_manager) self.recompute_layer = 
parallel_config.recompute_layer self.parallel_config = parallel_config self.layer_idx = layer_idx + self.stream_manager = stream_manager def _checkpointed_forward( self, @@ -761,10 +771,9 @@ def _core_forward( sequence_mask: Union[torch.Tensor, TensorPointer], ) -> List[Union[torch.Tensor, TensorPointer]]: num_input_batches = self.parallel_config.domino.num_input_batches - assert num_input_batches == 2 orig_sequence_mask = sequence_mask - comm_stream = CudaStreamManager.get(f"comm_stream_{torch.cuda.current_device()}") + comm_stream = self.stream_manager.get(f"comm_stream_{torch.cuda.current_device()}") hidden_states = torch.chunk(hidden_states, chunks=num_input_batches, dim=1) sequence_mask = torch.chunk(sequence_mask, chunks=num_input_batches, dim=0) @@ -792,19 +801,17 @@ def _core_forward( hidden_states=hidden_states0, sequence_mask=sequence_mask0, op_name=FWD_ATTN_OP_NAME.format(self.layer_idx, 0), - comm_stream=comm_stream, ) attn_output1 = self.attn( hidden_states=hidden_states1, sequence_mask=sequence_mask1, op_name=FWD_ATTN_OP_NAME.format(self.layer_idx, 1), - comm_stream=comm_stream, ) comm_stream.wait_stream(torch.cuda.default_stream()) with torch.cuda.stream(comm_stream): attn_output0["work"].wait() - attn_output0["work"].is_completed() + # assert attn_output0["work"].is_completed() torch.cuda.default_stream().wait_stream(comm_stream) hidden_states0 = attn_output0["hidden_states"] + residual0 @@ -819,13 +826,12 @@ def _core_forward( mlp_output0 = self.mlp( hidden_states=hidden_states0, op_name=FWD_MLP_OP_NAME.format(self.layer_idx, 0), - comm_stream=comm_stream, ) comm_stream.wait_stream(torch.cuda.default_stream()) with torch.cuda.stream(comm_stream): attn_output1["work"].wait() - attn_output1["work"].is_completed() + # assert attn_output1["work"].is_completed() torch.cuda.default_stream().wait_stream(comm_stream) hidden_states1 = attn_output1["hidden_states"] + residual1 @@ -835,7 +841,6 @@ def _core_forward( mlp_output1 = self.mlp( hidden_states=hidden_states1, op_name=FWD_MLP_OP_NAME.format(self.layer_idx, 1), - comm_stream=comm_stream, ) comm_stream.wait_stream(torch.cuda.default_stream()) @@ -843,8 +848,8 @@ def _core_forward( mlp_output0["work"].wait() mlp_output1["work"].wait() - mlp_output0["work"].is_completed() - mlp_output1["work"].is_completed() + # assert mlp_output0["work"].is_completed() + # assert mlp_output1["work"].is_completed() torch.cuda.default_stream().wait_stream(comm_stream) hidden_states0 = mlp_output0["hidden_states"] + residual0 @@ -926,6 +931,7 @@ def __init__( parallel_config.tp_linear_async_communication if parallel_config is not None else False ) + self._init_cuda_stream_for_comm() self.token_position_embeddings = PipelineBlock( p2p=self.p2p, module_builder=Embedding, @@ -955,6 +961,7 @@ def __init__( "parallel_config": parallel_config, "tp_pg": parallel_context.tp_pg, "layer_idx": layer_idx, + "stream_manager": self.stream_manager, }, module_input_keys={"hidden_states", "sequence_mask"}, module_output_keys={"hidden_states", "sequence_mask"}, @@ -997,6 +1004,10 @@ def __init__( module_output_keys={"output"}, ) + def _init_cuda_stream_for_comm(self): + self.stream_manager = CudaStreamManager() + self.stream_manager.create(f"comm_stream_{torch.cuda.current_device()}", device=torch.cuda.current_device()) + def forward( self, input_ids: Union[torch.Tensor, TensorPointer], # [batch_size, seq_length] diff --git a/src/nanotron/parallel/comm.py b/src/nanotron/parallel/comm.py index ef1f87f4..265f1966 100644 --- a/src/nanotron/parallel/comm.py +++ 
b/src/nanotron/parallel/comm.py @@ -7,41 +7,33 @@ class CudaStreamManager: - _streams: Dict[str, "torch.cuda.Stream"] = {} + def __init__(self): + self._streams: Dict[str, "torch.cuda.Stream"] = {} - @staticmethod - def create(name: str, device: torch.device = None): - assert name not in CudaStreamManager._streams - CudaStreamManager._streams[name] = torch.cuda.Stream(device=device) + def create(self, name: str, device: torch.device): + assert name not in self._streams + self._streams[name] = torch.cuda.Stream(device=device) - @staticmethod - def get(name: str): - if name not in CudaStreamManager._streams: - CudaStreamManager.create(name) - return CudaStreamManager._streams.get(name) + def get(self, name: str): + if name not in self._streams: + self.create(name) + return self._streams.get(name) @contextmanager - def run_on_stream(name: str): - stream = CudaStreamManager.get(name) + def run_on_stream(self, name: str): + stream = self.get(name) with torch.cuda.stream(stream): yield stream class AsyncCommBucket: - """ - - Variable._execution_engine.run_backward( # Calls into the C++ engine to run the backward pass - RuntimeError: expected Variable or None (got tuple) - Variable._execution_engine.run_backward( # Calls into the C++ engine to run the backward pass - RuntimeError: expected Variable or None (got tuple) - """ - _async_op: Dict[int, "dist.Work"] = {} _copy_async_op: Dict[int, "dist.Work"] = {} @staticmethod def add(op_name: int, work: "dist.Work"): assert op_name not in AsyncCommBucket._async_op, f"Operation with name: {op_name} already exists" + assert work is not None AsyncCommBucket._async_op[op_name] = work AsyncCommBucket._copy_async_op[op_name] = work @@ -84,14 +76,19 @@ def clear_all(): class WaitComm(torch.autograd.Function): + """ + Enforce a tensor to wait for the communication operation to finish + in torch's autograd graph + """ + @staticmethod - def forward(ctx, input, op_name, comm_stream): + def forward(ctx, input: torch.Tensor, op_name: str, comm_stream: torch.cuda.Stream): ctx.op_name = op_name ctx.comm_stream = comm_stream return input @staticmethod - def backward(ctx, grad_output): + def backward(ctx, grad_output: torch.Tensor): """ NOTE: because the communication operation is already being executed so the communication stream don't have to wait for the compute stream here @@ -100,7 +97,6 @@ def backward(ctx, grad_output): """ if is_async_comm(ctx.op_name): handle = AsyncCommBucket.pop(ctx.op_name) - assert handle is not None handle.wait() ctx.comm_stream.synchronize() diff --git a/src/nanotron/parallel/tensor_parallel/distributed_differentiable_primitives.py b/src/nanotron/parallel/tensor_parallel/distributed_differentiable_primitives.py index 7cc1dc06..43b43264 100644 --- a/src/nanotron/parallel/tensor_parallel/distributed_differentiable_primitives.py +++ b/src/nanotron/parallel/tensor_parallel/distributed_differentiable_primitives.py @@ -31,11 +31,9 @@ def forward( ctx, tensor: torch.Tensor, group: Optional[ProcessGroup], - async_all_reduce: bool, op_name: Optional[str] = None, comm_stream: Optional[torch.cuda.Stream] = None, ): - ctx.async_all_reduce = async_all_reduce ctx.op_name = op_name ctx.group = group ctx.comm_stream = comm_stream @@ -45,7 +43,6 @@ def forward( def backward(ctx, grad_output: torch.Tensor): group = ctx.group op_name = ctx.op_name.replace("fwd.", "bwd.") if ctx.op_name is not None else None - # async_all_reduce = is_async_comm(op_name) if op_name is not None else ctx.async_all_reduce return ( DifferentiableAllReduceSum.apply(grad_output, 
group, op_name, ctx.comm_stream), @@ -176,11 +173,10 @@ def backward(ctx, grad_output): def differentiable_identity( tensor, group: Optional[ProcessGroup] = None, - async_all_reduce: bool = False, op_name: Optional[str] = None, comm_stream: Optional[torch.cuda.Stream] = None, ): - return DifferentiableIdentity.apply(tensor, group, async_all_reduce, op_name, comm_stream) + return DifferentiableIdentity.apply(tensor, group, op_name, comm_stream) def differentiable_all_reduce_sum( diff --git a/src/nanotron/parallel/tensor_parallel/nn.py b/src/nanotron/parallel/tensor_parallel/nn.py index 9ad2a51b..8eccdd1d 100644 --- a/src/nanotron/parallel/tensor_parallel/nn.py +++ b/src/nanotron/parallel/tensor_parallel/nn.py @@ -52,6 +52,7 @@ def __init__( async_communication: bool = False, contiguous_chunks: Optional[Tuple[int, ...]] = None, tp_recompute_allgather: bool = True, + comm_stream: Optional[torch.cuda.Stream] = None, ): self.pg = pg self.world_size = pg.size() @@ -72,6 +73,7 @@ def __init__( self.mode = mode self.async_communication = async_communication + self.comm_stream = comm_stream if contiguous_chunks is not None: assert ( @@ -89,7 +91,6 @@ def forward( self, x: torch.Tensor, op_name: Optional[str] = None, - comm_stream: Optional[torch.cuda.Stream] = None, ) -> torch.Tensor: return column_linear( input=x, @@ -100,7 +101,7 @@ def forward( async_communication=self.async_communication, tp_recompute_allgather=self.tp_recompute_allgather, op_name=op_name, - comm_stream=comm_stream, + comm_stream=self.comm_stream, ) def extra_repr(self) -> str: @@ -119,6 +120,7 @@ def __init__( dtype=None, async_communication: bool = False, contiguous_chunks: Optional[Tuple[int, ...]] = None, + comm_stream: Optional[torch.cuda.Stream] = None, ): self.pg = pg self.world_size = pg.size() @@ -127,6 +129,7 @@ def __init__( self.in_features = in_features // self.world_size self.out_features = out_features + self.comm_stream = comm_stream # No need to shard the bias term, only rank 0 would have it bias = dist.get_rank(self.pg) == 0 and bias @@ -165,12 +168,7 @@ def _mark_all_parameters_in_module_as_sharded(self, split_config: SplitConfig): ) setattr(self, name, new_param) - def forward( - self, - x: torch.Tensor, - op_name: Optional[str] = None, - comm_stream: Optional[torch.cuda.Stream] = None, - ) -> torch.Tensor: + def forward(self, x: torch.Tensor, op_name: Optional[str] = None) -> torch.Tensor: return row_linear( input=x, weight=self.weight, @@ -179,7 +177,7 @@ def forward( tp_mode=self.mode, async_communication=self.async_communication, op_name=op_name, - comm_stream=comm_stream, + comm_stream=self.comm_stream, ) def extra_repr(self) -> str: diff --git a/tests/test_comm.py b/tests/test_comm.py index 4039c61d..d49252a4 100644 --- a/tests/test_comm.py +++ b/tests/test_comm.py @@ -6,7 +6,7 @@ rerun_if_address_is_in_use, ) from nanotron.parallel import ParallelContext -from nanotron.parallel.comm import AsyncCommBucket, WaitComm +from nanotron.parallel.comm import AsyncCommBucket, CudaStreamManager, WaitComm class MockWork: @@ -22,6 +22,14 @@ def is_completed(self): return self.completed +def test_cuda_stream_manager(): + manager = CudaStreamManager() + manager.create("test", torch.device("cuda")) + + stream = manager.get("test") + assert isinstance(stream, torch.cuda.Stream) + + @rerun_if_address_is_in_use() def test_add_async_op_to_bucket(): init_distributed(tp=2, dp=1, pp=1)(_test_add_async_op_to_bucket)() From 75abb32d42449baf011bf997081d63128afc4d73 Mon Sep 17 00:00:00 2001 From: Phuc Nguyen Date: Fri, 7 Mar 
2025 12:36:18 +0000 Subject: [PATCH 26/40] move domino's assert args to config --- src/nanotron/config/parallelism_config.py | 8 ++++++++ src/nanotron/parallel/comm.py | 8 ++++---- .../parallel/pipeline_parallel/engine.py | 16 ++++++++-------- .../distributed_differentiable_primitives.py | 4 ++-- src/nanotron/parallel/tensor_parallel/domino.py | 6 +++++- .../parallel/tensor_parallel/functional.py | 16 +++------------- tests/test_domino.py | 6 +++--- 7 files changed, 33 insertions(+), 31 deletions(-) diff --git a/src/nanotron/config/parallelism_config.py b/src/nanotron/config/parallelism_config.py index 07688959..1a655f0c 100644 --- a/src/nanotron/config/parallelism_config.py +++ b/src/nanotron/config/parallelism_config.py @@ -72,6 +72,14 @@ def __post_init__(self): if self.is_domino_enabled is True: assert self.tp > 1, "Domino requires TP > 1" + # NOTE: For DoMiNo since we overlapping the communication + # so it doesnt matter whether it's all_reduce or reduce_scatter + # so we just support and tested with all_reduce up to now + # but in principle, it should work with reduce_scatter as well + assert ( + self.tp_linear_async_communication is False + ), "Domino requires TP linear async communication to be False" + assert self.tp_mode == TensorParallelLinearMode.ALL_REDUCE, "Domino requires TP mode to be ALL_REDUCE" @property def is_domino_enabled(self) -> bool: diff --git a/src/nanotron/parallel/comm.py b/src/nanotron/parallel/comm.py index 265f1966..4ffa63e4 100644 --- a/src/nanotron/parallel/comm.py +++ b/src/nanotron/parallel/comm.py @@ -3,7 +3,7 @@ import torch -from nanotron.parallel.tensor_parallel.domino import is_async_comm +from nanotron.parallel.tensor_parallel.domino import is_domino_async_comm class CudaStreamManager: @@ -64,7 +64,7 @@ def is_all_completed() -> bool: not_finished = [] for k, v in AsyncCommBucket._copy_async_op.items(): - assert is_async_comm(k) is True, f"Operation with name {k} wasn't executed asynchronously!" + assert is_domino_async_comm(k) is True, f"Operation with name {k} wasn't executed asynchronously!" if v.is_completed() is not True: not_finished.append((k, v)) return len(not_finished) == 0 @@ -78,7 +78,7 @@ def clear_all(): class WaitComm(torch.autograd.Function): """ Enforce a tensor to wait for the communication operation to finish - in torch's autograd graph + in torch's autograd graph. 
""" @staticmethod @@ -95,7 +95,7 @@ def backward(ctx, grad_output: torch.Tensor): but the compute stream waits for the communication stream before proceeding """ - if is_async_comm(ctx.op_name): + if is_domino_async_comm(ctx.op_name): handle = AsyncCommBucket.pop(ctx.op_name) handle.wait() diff --git a/src/nanotron/parallel/pipeline_parallel/engine.py b/src/nanotron/parallel/pipeline_parallel/engine.py index 076943c7..fc4b63ca 100644 --- a/src/nanotron/parallel/pipeline_parallel/engine.py +++ b/src/nanotron/parallel/pipeline_parallel/engine.py @@ -2,7 +2,7 @@ from typing import Dict, Iterable, Optional, Union import torch -from torch import nn as torch_nn +from torch import nn from torch.nn.parallel import DistributedDataParallel from nanotron import distributed as dist @@ -29,7 +29,7 @@ def forward( context: ContextManagers, state: PipelineTrainBatchState, micro_batch: Dict[str, Union[torch.Tensor, TensorPointer]], - model: torch_nn.Module, + model: nn.Module, ) -> Dict[str, Union[torch.Tensor, TensorPointer]]: # Increment the number of backwards state.nb_forwards += 1 @@ -59,7 +59,7 @@ def forward( return output @staticmethod - def _get_fwd_context(model: torch_nn.Module): + def _get_fwd_context(model: nn.Module): is_ddp = isinstance(model, DistributedDataParallel) # We never to trigger a DDP sync in the next backward pass context = ContextManagers([model.no_sync()] if is_ddp else []) @@ -97,7 +97,7 @@ def backward( def _get_bwd_context( self, - model: torch_nn.Module, + model: nn.Module, nb_backwards: int, grad_accumulator: Optional[GradientAccumulator], ): @@ -118,7 +118,7 @@ def _get_bwd_context( @abstractmethod def train_batch_iter( self, - model: torch_nn.Module, + model: nn.Module, pg: ProcessGroup, batch: Iterable[Dict[str, Union[torch.Tensor, TensorPointer]]], nb_microbatches: int, @@ -130,7 +130,7 @@ def train_batch_iter( @torch.inference_mode() def validate_batch_iter( self, - model: torch_nn.Module, + model: nn.Module, batch: Iterable[Dict[str, Union[torch.Tensor, TensorPointer]]], nb_microbatches: int, ) -> Iterable[Dict[str, Union[torch.Tensor, TensorPointer]]]: @@ -169,7 +169,7 @@ def __init__(self): def train_batch_iter( self, - model: torch_nn.Module, + model: nn.Module, pg: ProcessGroup, batch: Iterable[Dict[str, Union[torch.Tensor, TensorPointer]]], nb_microbatches: int, @@ -226,7 +226,7 @@ def __init__(self): def train_batch_iter( self, - model: torch_nn.Module, + model: nn.Module, pg: ProcessGroup, batch: Iterable[Dict[str, Union[torch.Tensor, TensorPointer]]], nb_microbatches: int, diff --git a/src/nanotron/parallel/tensor_parallel/distributed_differentiable_primitives.py b/src/nanotron/parallel/tensor_parallel/distributed_differentiable_primitives.py index 43b43264..8a5ca3ad 100644 --- a/src/nanotron/parallel/tensor_parallel/distributed_differentiable_primitives.py +++ b/src/nanotron/parallel/tensor_parallel/distributed_differentiable_primitives.py @@ -20,7 +20,7 @@ from nanotron import distributed as dist from nanotron.distributed import ProcessGroup -from nanotron.parallel.comm import AsyncCommBucket, is_async_comm +from nanotron.parallel.comm import AsyncCommBucket, is_domino_async_comm class DifferentiableIdentity(torch.autograd.Function): @@ -64,7 +64,7 @@ def forward( op_name: Optional[int] = None, comm_stream: Optional[torch.cuda.Stream] = None, ) -> Tuple[torch.Tensor, Optional["dist.Work"]]: - async_all_reduce = is_async_comm(op_name) if op_name is not None else False + async_all_reduce = is_domino_async_comm(op_name) if op_name is not None else False 
ctx.async_all_reduce = async_all_reduce if group.size() == 1: diff --git a/src/nanotron/parallel/tensor_parallel/domino.py b/src/nanotron/parallel/tensor_parallel/domino.py index 7c4a10b8..9bfe79d6 100644 --- a/src/nanotron/parallel/tensor_parallel/domino.py +++ b/src/nanotron/parallel/tensor_parallel/domino.py @@ -6,7 +6,11 @@ BWD_MLP_OP_NAME = "bwd.layer_mlp_{}_batch_{}" -def is_async_comm(x: str) -> bool: +def is_domino_async_comm(x: str) -> bool: + """ + Determine whether a module (e.g., mlp, attention) + performs all-reduce asynchronously in tensor parallelism + """ NON_ASYNC_HANDLE_IDX = [ # "fwd.layer_mlp_{}_batch_1", "bwd.layer_attn_{}_batch_0", diff --git a/src/nanotron/parallel/tensor_parallel/functional.py b/src/nanotron/parallel/tensor_parallel/functional.py index ac645dfb..fec55a4a 100644 --- a/src/nanotron/parallel/tensor_parallel/functional.py +++ b/src/nanotron/parallel/tensor_parallel/functional.py @@ -25,7 +25,7 @@ differentiable_identity, differentiable_reduce_scatter_sum, ) -from nanotron.parallel.tensor_parallel.domino import is_async_comm +from nanotron.parallel.tensor_parallel.domino import is_domino_async_comm from nanotron.parallel.tensor_parallel.enum import TensorParallelLinearMode from nanotron.parallel.utils import MemoryBuffer, assert_cuda_max_connections_set_to_1 @@ -441,11 +441,6 @@ def column_linear( op_name: Optional[str] = None, comm_stream: Optional[torch.cuda.Stream] = None, ): - is_async_all_reduce = is_async_comm(op_name) if op_name is not None else False - assert not ( - is_async_all_reduce and async_communication - ), "DoMiNo isn't support weight's async communication for column linear." - if async_communication: return _ColumnLinearAsyncCommunication.apply(input, weight, bias, group, tp_mode, tp_recompute_allgather) @@ -600,11 +595,6 @@ def row_linear( op_name: Optional[str] = None, comm_stream: Optional[torch.cuda.Stream] = None, ) -> Tuple[torch.Tensor, Optional[torch.Future]]: - is_async_all_reduce = is_async_comm(op_name) if op_name is not None else False - assert not ( - is_async_all_reduce and async_communication - ), "DoMiNo isn't support weight's async communication for row linear." - if async_communication: work = None out = _RowLinearAsyncCommunication.apply(input, weight, bias, group, tp_mode) @@ -612,12 +602,12 @@ def row_linear( out = F.linear(input, weight, bias) if tp_mode is TensorParallelLinearMode.ALL_REDUCE: out = differentiable_all_reduce_sum(out, group=group, op_name=op_name, comm_stream=comm_stream) - if is_async_all_reduce: + is_domino_async_all_reduce = is_domino_async_comm(op_name) if op_name is not None else False + if is_domino_async_all_reduce: work = AsyncCommBucket.pop(op_name) else: work = None elif tp_mode is TensorParallelLinearMode.REDUCE_SCATTER: - assert is_async_all_reduce is False, "Async communication is not supported for REDUCE_SCATTER mode." 
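The body of is_domino_async_comm is only partially visible in this hunk; a hedged reconstruction, consistent with the NON_ASYNC_HANDLE_IDX templates shown above and with the expectations in tests/test_domino.py, would be:

    import re

    NON_ASYNC_HANDLE_IDX = ["bwd.layer_attn_{}_batch_0"]

    def is_domino_async_comm_sketch(op_name: str) -> bool:
        # assumption: "{}" in a template stands for the layer index
        for template in NON_ASYNC_HANDLE_IDX:
            pattern = re.escape(template).replace(re.escape("{}"), r"\d+")
            if re.fullmatch(pattern, op_name):
                return False   # this handle is waited on synchronously
        return True

    assert is_domino_async_comm_sketch("bwd.layer_attn_1_batch_0") is False
    assert is_domino_async_comm_sketch("fwd.layer_mlp_1_batch_1") is True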
out = differentiable_reduce_scatter_sum(out, group=group) work = None else: diff --git a/tests/test_domino.py b/tests/test_domino.py index 44d9d98a..f6698a9b 100644 --- a/tests/test_domino.py +++ b/tests/test_domino.py @@ -9,7 +9,7 @@ from nanotron.models.llama import DominoLlamaDecoderLayer from nanotron.parallel import ParallelContext from nanotron.parallel.comm import AsyncCommBucket -from nanotron.parallel.tensor_parallel.domino import is_async_comm +from nanotron.parallel.tensor_parallel.domino import is_domino_async_comm @pytest.mark.parametrize( @@ -25,8 +25,8 @@ ("bwd.layer_attn_1_batch_0", False), ], ) -def test_is_async_comm(op_name, expected): - assert is_async_comm(op_name) == expected +def test_is_domino_async_comm(op_name, expected): + assert is_domino_async_comm(op_name) == expected @pytest.mark.parametrize("tp,dp,pp", [(2, 2, 1)]) From da4220c25fd20eae4399053ee1dd4b18c454882a Mon Sep 17 00:00:00 2001 From: Phuc Nguyen Date: Fri, 7 Mar 2025 14:27:05 +0000 Subject: [PATCH 27/40] add retrieving async distributed handle from comm bucket instead of returning it directly in linear modules --- src/nanotron/constants.py | 4 + src/nanotron/models/base.py | 2 +- src/nanotron/models/llama.py | 73 ++++++----- src/nanotron/parallel/comm.py | 120 +++++++++++------- .../distributed_differentiable_primitives.py | 23 ++-- .../parallel/tensor_parallel/functional.py | 27 ++-- src/nanotron/parallel/tensor_parallel/nn.py | 13 +- src/nanotron/trainer.py | 12 +- tests/test_comm.py | 46 ++++--- 9 files changed, 185 insertions(+), 135 deletions(-) diff --git a/src/nanotron/constants.py b/src/nanotron/constants.py index 580bd99d..61878181 100644 --- a/src/nanotron/constants.py +++ b/src/nanotron/constants.py @@ -10,3 +10,7 @@ CHECKPOINT_FILE_NAME = "checkpoint_metadata.json" MODEL_CONFIG_FILE_NAME = "model_config.json" + + +### FOR COMMUNICATION ### +CUDA_STREAM_COMM_NAME = "comm_stream_{}" diff --git a/src/nanotron/models/base.py b/src/nanotron/models/base.py index 14ac6908..bb7e7f41 100644 --- a/src/nanotron/models/base.py +++ b/src/nanotron/models/base.py @@ -71,7 +71,7 @@ def get_embeddings_lm_head_tied_names(self) -> list[str]: Example for GPT2 model: ["model.token_position_embeddings.pp_block.token_embedding.weight", "model.lm_head.pp_block.weight"] """ return [] - + def get_named_params_without_weight_decay(self) -> List[str]: """Return a list of named parameters that should not have weight decay applied to them.""" return [] diff --git a/src/nanotron/models/llama.py b/src/nanotron/models/llama.py index 85938726..c514131c 100644 --- a/src/nanotron/models/llama.py +++ b/src/nanotron/models/llama.py @@ -30,7 +30,7 @@ from nanotron.nn.activations import ACT2FN from nanotron.nn.layer_norm import TritonRMSNorm from nanotron.parallel import ParallelContext -from nanotron.parallel.comm import CudaStreamManager, WaitComm +from nanotron.parallel.comm import CudaStreamManager, insert_backward_sync_to_tensor from nanotron.parallel.parameters import NanotronParameter from nanotron.parallel.pipeline_parallel.block import PipelineBlock, TensorPointer from nanotron.parallel.pipeline_parallel.p2p import P2P @@ -228,7 +228,7 @@ def __init__( config.intermediate_size, # shape of gate_linear config.intermediate_size, # shape of up_linear ) - comm_stream = stream_manager.get(f"comm_stream_{torch.cuda.current_device()}") + stream_manager.get_default_comm_stream() self.gate_up_proj = TensorParallelColumnLinear( config.hidden_size, 2 * config.intermediate_size, @@ -238,7 +238,7 @@ def __init__( 
async_communication=tp_linear_async_communication, contiguous_chunks=gate_up_contiguous_chunks, tp_recompute_allgather=parallel_config.tp_recompute_allgather, - comm_stream=comm_stream, + stream_manager=stream_manager, ) self.down_proj = TensorParallelRowLinear( config.intermediate_size, @@ -247,7 +247,7 @@ def __init__( mode=tp_mode, bias=False, async_communication=tp_linear_async_communication and tp_mode is TensorParallelLinearMode.REDUCE_SCATTER, - comm_stream=comm_stream, + stream_manager=stream_manager, ) self.split_silu_mul = GLUActivation(config.hidden_act) @@ -255,8 +255,8 @@ def forward( self, hidden_states: torch.Tensor, op_name: Optional[str] = None ): # [seq_length, batch_size, hidden_dim] merged_states = self.gate_up_proj(hidden_states, op_name=op_name) - hidden_states, work = self.down_proj(self.split_silu_mul(merged_states), op_name=op_name) - return {"hidden_states": hidden_states, "work": work} + hidden_states = self.down_proj(self.split_silu_mul(merged_states), op_name=op_name) + return {"hidden_states": hidden_states} class CoreAttention(nn.Module): @@ -395,7 +395,6 @@ def __init__( config.num_key_value_heads * self.d_qk, # shape of k config.num_key_value_heads * self.d_qk, # shape of v ) - comm_stream = stream_manager.get(f"comm_stream_{torch.cuda.current_device()}") self.qkv_proj = TensorParallelColumnLinear( self.d_model, config.num_attention_heads * self.d_qk + 2 * config.num_key_value_heads * self.d_qk, @@ -405,7 +404,7 @@ def __init__( async_communication=tp_linear_async_communication, contiguous_chunks=qkv_contiguous_chunks, tp_recompute_allgather=parallel_config.tp_recompute_allgather, - comm_stream=comm_stream, + stream_manager=stream_manager, ) # TODO(kunhao): We want to have only one version per device and not one version per layer. 
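At this point in the series every Llama module receives a shared CudaStreamManager instead of a raw stream. A short usage sketch of the instance-based manager API introduced earlier in the patch (create / get / run_on_stream), assuming a CUDA device is available:

    import torch
    from nanotron.parallel.comm import CudaStreamManager

    manager = CudaStreamManager()
    manager.create("comm_stream_0", device=torch.device("cuda", 0))

    with manager.run_on_stream("comm_stream_0") as stream:
        # kernels launched in this block are enqueued on the communication stream
        assert torch.cuda.current_stream() == stream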
if config.rope_interleaved: @@ -434,7 +433,7 @@ def __init__( mode=tp_mode, bias=False, async_communication=tp_linear_async_communication, - comm_stream=comm_stream, + stream_manager=stream_manager, ) self.attention = CoreAttention( @@ -707,9 +706,10 @@ def forward( attention_output = ( attention_output.contiguous().view(batch_size, q_length, self.n_local_q_heads * self.d_v).transpose(0, 1) ) - output, work = self.o_proj(attention_output, op_name=op_name) + # output, work = self.o_proj(attention_output, op_name=op_name) + output = self.o_proj(attention_output, op_name=op_name) - return {"hidden_states": output, "work": work, "sequence_mask": sequence_mask} + return {"hidden_states": output, "sequence_mask": sequence_mask} class _BaseLlamaDecoderLayer(nn.Module): @@ -773,7 +773,8 @@ def _core_forward( num_input_batches = self.parallel_config.domino.num_input_batches orig_sequence_mask = sequence_mask - comm_stream = self.stream_manager.get(f"comm_stream_{torch.cuda.current_device()}") + comm_stream = self.stream_manager.get_default_comm_stream() + comm_bucket = self.stream_manager.comm_bucket hidden_states = torch.chunk(hidden_states, chunks=num_input_batches, dim=1) sequence_mask = torch.chunk(sequence_mask, chunks=num_input_batches, dim=0) @@ -786,15 +787,15 @@ def _core_forward( hidden_states0 = self.input_layernorm(hidden_states0) hidden_states1 = self.input_layernorm(hidden_states1) - hidden_states0 = WaitComm.apply( + hidden_states0 = insert_backward_sync_to_tensor( hidden_states0, BWD_ATTN_OP_NAME.format(self.layer_idx, 1), - comm_stream, + self.stream_manager, ) - hidden_states1 = WaitComm.apply( + hidden_states1 = insert_backward_sync_to_tensor( hidden_states1, BWD_MLP_OP_NAME.format(self.layer_idx, 0), - comm_stream, + self.stream_manager, ) attn_output0 = self.attn( @@ -810,17 +811,18 @@ def _core_forward( comm_stream.wait_stream(torch.cuda.default_stream()) with torch.cuda.stream(comm_stream): - attn_output0["work"].wait() + comm_bucket.wait(FWD_ATTN_OP_NAME.format(self.layer_idx, 0)) # assert attn_output0["work"].is_completed() + torch.cuda.default_stream().wait_stream(comm_stream) hidden_states0 = attn_output0["hidden_states"] + residual0 residual0 = hidden_states0 hidden_states0 = self.post_attention_layernorm(hidden_states0) - hidden_states0 = WaitComm.apply( + hidden_states0 = insert_backward_sync_to_tensor( hidden_states0, BWD_MLP_OP_NAME.format(self.layer_idx, 1), - comm_stream, + self.stream_manager, ) mlp_output0 = self.mlp( @@ -830,8 +832,11 @@ def _core_forward( comm_stream.wait_stream(torch.cuda.default_stream()) with torch.cuda.stream(comm_stream): - attn_output1["work"].wait() + # assert 1 == 1 + # attn_output1["work"].wait() + comm_bucket.wait(FWD_ATTN_OP_NAME.format(self.layer_idx, 1)) # assert attn_output1["work"].is_completed() + torch.cuda.default_stream().wait_stream(comm_stream) hidden_states1 = attn_output1["hidden_states"] + residual1 @@ -845,11 +850,14 @@ def _core_forward( comm_stream.wait_stream(torch.cuda.default_stream()) with torch.cuda.stream(comm_stream): - mlp_output0["work"].wait() - mlp_output1["work"].wait() + # mlp_output0["work"].wait() + # mlp_output1["work"].wait() + + # # assert mlp_output0["work"].is_completed() + # # assert mlp_output1["work"].is_completed() + comm_bucket.wait(FWD_MLP_OP_NAME.format(self.layer_idx, 0)) + comm_bucket.wait(FWD_MLP_OP_NAME.format(self.layer_idx, 1)) - # assert mlp_output0["work"].is_completed() - # assert mlp_output1["work"].is_completed() torch.cuda.default_stream().wait_stream(comm_stream) hidden_states0 
= mlp_output0["hidden_states"] + residual0 @@ -918,6 +926,7 @@ def __init__( config: LlamaConfig, parallel_context: ParallelContext, parallel_config: Optional[ParallelismArgs], + stream_manager: Optional[CudaStreamManager] = None, ): super().__init__() @@ -931,7 +940,6 @@ def __init__( parallel_config.tp_linear_async_communication if parallel_config is not None else False ) - self._init_cuda_stream_for_comm() self.token_position_embeddings = PipelineBlock( p2p=self.p2p, module_builder=Embedding, @@ -961,7 +969,7 @@ def __init__( "parallel_config": parallel_config, "tp_pg": parallel_context.tp_pg, "layer_idx": layer_idx, - "stream_manager": self.stream_manager, + "stream_manager": stream_manager, }, module_input_keys={"hidden_states", "sequence_mask"}, module_output_keys={"hidden_states", "sequence_mask"}, @@ -1004,10 +1012,6 @@ def __init__( module_output_keys={"output"}, ) - def _init_cuda_stream_for_comm(self): - self.stream_manager = CudaStreamManager() - self.stream_manager.create(f"comm_stream_{torch.cuda.current_device()}", device=torch.cuda.current_device()) - def forward( self, input_ids: Union[torch.Tensor, TensorPointer], # [batch_size, seq_length] @@ -1115,9 +1119,15 @@ def __init__( parallel_context: ParallelContext, parallel_config: Optional[ParallelismArgs], random_states: Optional[RandomStates] = None, + stream_manager: Optional[CudaStreamManager] = None, ): super().__init__() - self.model = LlamaModel(config=config, parallel_context=parallel_context, parallel_config=parallel_config) + self.model = LlamaModel( + config=config, + parallel_context=parallel_context, + parallel_config=parallel_config, + stream_manager=stream_manager, + ) self.loss = PipelineBlock( p2p=self.model.p2p, module_builder=Loss, @@ -1132,6 +1142,7 @@ def __init__( self.parallel_context = parallel_context self.config = config self.parallel_config = parallel_config + self.stream_manager = stream_manager def forward( self, diff --git a/src/nanotron/parallel/comm.py b/src/nanotron/parallel/comm.py index 4ffa63e4..f7bed601 100644 --- a/src/nanotron/parallel/comm.py +++ b/src/nanotron/parallel/comm.py @@ -3,76 +3,83 @@ import torch +from nanotron.constants import CUDA_STREAM_COMM_NAME from nanotron.parallel.tensor_parallel.domino import is_domino_async_comm -class CudaStreamManager: - def __init__(self): - self._streams: Dict[str, "torch.cuda.Stream"] = {} - - def create(self, name: str, device: torch.device): - assert name not in self._streams - self._streams[name] = torch.cuda.Stream(device=device) - - def get(self, name: str): - if name not in self._streams: - self.create(name) - return self._streams.get(name) - - @contextmanager - def run_on_stream(self, name: str): - stream = self.get(name) - with torch.cuda.stream(stream): - yield stream - - class AsyncCommBucket: - _async_op: Dict[int, "dist.Work"] = {} - _copy_async_op: Dict[int, "dist.Work"] = {} + """ + Store aynchronous communication operations. 
+ """ - @staticmethod - def add(op_name: int, work: "dist.Work"): - assert op_name not in AsyncCommBucket._async_op, f"Operation with name: {op_name} already exists" + def __init__(self): + self._async_op: Dict[int, "dist.Work"] = {} + self._copy_async_op: Dict[int, "dist.Work"] = {} + + def add(self, op_name: int, work: "dist.Work"): + assert op_name not in self._async_op, f"Operation with name: {op_name} already exists" assert work is not None - AsyncCommBucket._async_op[op_name] = work - AsyncCommBucket._copy_async_op[op_name] = work + self._async_op[op_name] = work + self._copy_async_op[op_name] = work - @staticmethod - def get(op_name: str) -> "dist.Work": - if op_name not in AsyncCommBucket._async_op: + def get(self, op_name: str) -> "dist.Work": + if op_name not in self._async_op: raise KeyError(f"Operation with name: {op_name} doesn't exist") - return AsyncCommBucket._async_op.get(op_name) + return self._async_op.get(op_name) - @staticmethod - def pop(op_name: str) -> "dist.Work": - if op_name not in AsyncCommBucket._async_op: + def pop(self, op_name: str) -> "dist.Work": + if op_name not in self._async_op: raise KeyError(f"Operation with name: {op_name} doesn't exist") - return AsyncCommBucket._async_op.pop(op_name) + return self._async_op.pop(op_name) - @staticmethod - def wait(op_name: str): + def wait(self, op_name: str): """Wait and remove the operation from the bucket""" - work = AsyncCommBucket.pop(op_name) + work = self.pop(op_name) work.wait() - @staticmethod - def is_all_completed() -> bool: - if not len(AsyncCommBucket._async_op) == 0: + def is_all_completed(self) -> bool: + if not len(self._async_op) == 0: return False not_finished = [] - for k, v in AsyncCommBucket._copy_async_op.items(): + for k, v in self._copy_async_op.items(): assert is_domino_async_comm(k) is True, f"Operation with name {k} wasn't executed asynchronously!" if v.is_completed() is not True: not_finished.append((k, v)) return len(not_finished) == 0 - @staticmethod - def clear_all(): - AsyncCommBucket._async_op.clear() - AsyncCommBucket._copy_async_op.clear() + def clear_all(self): + self._async_op.clear() + self._copy_async_op.clear() + + +class CudaStreamManager: + def __init__(self): + self._streams: Dict[str, "torch.cuda.Stream"] = {} + self.comm_bucket = AsyncCommBucket() + + def create(self, name: str, device: torch.device): + assert name not in self._streams + self._streams[name] = torch.cuda.Stream(device=device) + + def get(self, name: str): + if name not in self._streams: + self.create(name) + return self._streams.get(name) + + def get_default_comm_stream(self) -> torch.cuda.Stream: + """ + Return the default communication stream for the current cuda device. 
+ """ + return self.get(CUDA_STREAM_COMM_NAME.format(torch.cuda.current_device())) + + @contextmanager + def run_on_stream(self, name: str): + stream = self.get(name) + with torch.cuda.stream(stream): + yield stream class WaitComm(torch.autograd.Function): @@ -82,9 +89,11 @@ class WaitComm(torch.autograd.Function): """ @staticmethod - def forward(ctx, input: torch.Tensor, op_name: str, comm_stream: torch.cuda.Stream): + def forward(ctx, input: torch.Tensor, op_name: str, comm_stream: torch.cuda.Stream, comm_bucket: AsyncCommBucket): + assert isinstance(comm_stream, torch.cuda.Stream) ctx.op_name = op_name ctx.comm_stream = comm_stream + ctx.comm_bucket = comm_bucket return input @staticmethod @@ -96,10 +105,23 @@ def backward(ctx, grad_output: torch.Tensor): before proceeding """ if is_domino_async_comm(ctx.op_name): - handle = AsyncCommBucket.pop(ctx.op_name) + handle = ctx.comm_bucket.pop(ctx.op_name) handle.wait() ctx.comm_stream.synchronize() torch.cuda.default_stream().wait_stream(ctx.comm_stream) - return grad_output, None, None + return grad_output, None, None, None + + +def insert_backward_sync_to_tensor( + tensor: torch.Tensor, op_name: str, stream_manager: CudaStreamManager +) -> torch.Tensor: + """ + Insert a wait communication operation of a given op_name to the autograd graph + of a tensor. + """ + + assert isinstance(stream_manager, CudaStreamManager) + comm_stream = stream_manager.get(CUDA_STREAM_COMM_NAME.format(torch.cuda.current_device())) + return WaitComm.apply(tensor, op_name, comm_stream, stream_manager.comm_bucket) diff --git a/src/nanotron/parallel/tensor_parallel/distributed_differentiable_primitives.py b/src/nanotron/parallel/tensor_parallel/distributed_differentiable_primitives.py index 8a5ca3ad..e50d34ba 100644 --- a/src/nanotron/parallel/tensor_parallel/distributed_differentiable_primitives.py +++ b/src/nanotron/parallel/tensor_parallel/distributed_differentiable_primitives.py @@ -20,7 +20,7 @@ from nanotron import distributed as dist from nanotron.distributed import ProcessGroup -from nanotron.parallel.comm import AsyncCommBucket, is_domino_async_comm +from nanotron.parallel.comm import CudaStreamManager, is_domino_async_comm class DifferentiableIdentity(torch.autograd.Function): @@ -32,11 +32,11 @@ def forward( tensor: torch.Tensor, group: Optional[ProcessGroup], op_name: Optional[str] = None, - comm_stream: Optional[torch.cuda.Stream] = None, + stream_manager: Optional[CudaStreamManager] = None, ): ctx.op_name = op_name ctx.group = group - ctx.comm_stream = comm_stream + ctx.stream_manager = stream_manager return tensor @staticmethod @@ -45,7 +45,7 @@ def backward(ctx, grad_output: torch.Tensor): op_name = ctx.op_name.replace("fwd.", "bwd.") if ctx.op_name is not None else None return ( - DifferentiableAllReduceSum.apply(grad_output, group, op_name, ctx.comm_stream), + DifferentiableAllReduceSum.apply(grad_output, group, op_name, ctx.stream_manager), None, None, None, @@ -62,7 +62,7 @@ def forward( tensor: torch.Tensor, group: Optional[ProcessGroup], op_name: Optional[int] = None, - comm_stream: Optional[torch.cuda.Stream] = None, + stream_manager: Optional[CudaStreamManager] = None, ) -> Tuple[torch.Tensor, Optional["dist.Work"]]: async_all_reduce = is_domino_async_comm(op_name) if op_name is not None else False ctx.async_all_reduce = async_all_reduce @@ -70,7 +70,8 @@ def forward( if group.size() == 1: return tensor - if comm_stream is not None: + if stream_manager is not None: + comm_stream = stream_manager.get_default_comm_stream() 
comm_stream.wait_stream(torch.cuda.default_stream()) comm_context = torch.cuda.stream(comm_stream) else: @@ -80,7 +81,7 @@ def forward( if async_all_reduce is True: assert comm_stream is not None handle = dist.all_reduce(tensor, op=dist.ReduceOp.SUM, group=group, async_op=True) - AsyncCommBucket.add(op_name, handle) + stream_manager.comm_bucket.add(op_name, handle) else: dist.all_reduce(tensor, op=dist.ReduceOp.SUM, group=group) @@ -174,18 +175,18 @@ def differentiable_identity( tensor, group: Optional[ProcessGroup] = None, op_name: Optional[str] = None, - comm_stream: Optional[torch.cuda.Stream] = None, + stream_manager: Optional[CudaStreamManager] = None, ): - return DifferentiableIdentity.apply(tensor, group, op_name, comm_stream) + return DifferentiableIdentity.apply(tensor, group, op_name, stream_manager) def differentiable_all_reduce_sum( tensor, group: Optional[ProcessGroup] = None, op_name: Optional[str] = None, - comm_stream: Optional[torch.cuda.Stream] = None, + stream_manager: Optional[CudaStreamManager] = None, ): - return DifferentiableAllReduceSum.apply(tensor, group, op_name, comm_stream) + return DifferentiableAllReduceSum.apply(tensor, group, op_name, stream_manager) def differentiable_all_gather(tensor, group: Optional[ProcessGroup] = None): diff --git a/src/nanotron/parallel/tensor_parallel/functional.py b/src/nanotron/parallel/tensor_parallel/functional.py index fec55a4a..9ac09606 100644 --- a/src/nanotron/parallel/tensor_parallel/functional.py +++ b/src/nanotron/parallel/tensor_parallel/functional.py @@ -19,13 +19,12 @@ from torch.nn import functional as F import nanotron.distributed as dist -from nanotron.parallel.comm import AsyncCommBucket +from nanotron.parallel.comm import CudaStreamManager from nanotron.parallel.tensor_parallel.distributed_differentiable_primitives import ( differentiable_all_reduce_sum, differentiable_identity, differentiable_reduce_scatter_sum, ) -from nanotron.parallel.tensor_parallel.domino import is_domino_async_comm from nanotron.parallel.tensor_parallel.enum import TensorParallelLinearMode from nanotron.parallel.utils import MemoryBuffer, assert_cuda_max_connections_set_to_1 @@ -439,13 +438,13 @@ def column_linear( async_communication: bool, tp_recompute_allgather: bool = True, op_name: Optional[str] = None, - comm_stream: Optional[torch.cuda.Stream] = None, + stream_manager: Optional[CudaStreamManager] = None, ): if async_communication: return _ColumnLinearAsyncCommunication.apply(input, weight, bias, group, tp_mode, tp_recompute_allgather) if tp_mode is TensorParallelLinearMode.ALL_REDUCE: - input = differentiable_identity(input, group=group, op_name=op_name, comm_stream=comm_stream) + input = differentiable_identity(input, group=group, op_name=op_name, stream_manager=stream_manager) return F.linear(input, weight, bias) if tp_mode is TensorParallelLinearMode.REDUCE_SCATTER: return _ColumnLinearNoAsyncCommunicationReduceScatterMode.apply( @@ -593,24 +592,24 @@ def row_linear( tp_mode: TensorParallelLinearMode, async_communication: bool, op_name: Optional[str] = None, - comm_stream: Optional[torch.cuda.Stream] = None, + stream_manager: Optional[CudaStreamManager] = None, ) -> Tuple[torch.Tensor, Optional[torch.Future]]: if async_communication: - work = None out = _RowLinearAsyncCommunication.apply(input, weight, bias, group, tp_mode) else: out = F.linear(input, weight, bias) if tp_mode is TensorParallelLinearMode.ALL_REDUCE: - out = differentiable_all_reduce_sum(out, group=group, op_name=op_name, comm_stream=comm_stream) - 
is_domino_async_all_reduce = is_domino_async_comm(op_name) if op_name is not None else False - if is_domino_async_all_reduce: - work = AsyncCommBucket.pop(op_name) - else: - work = None + out = differentiable_all_reduce_sum(out, group=group, op_name=op_name, stream_manager=stream_manager) + # is_domino_async_all_reduce = is_domino_async_comm(op_name) if op_name is not None else False + # if is_domino_async_all_reduce: + # work = AsyncCommBucket.pop(op_name) + # else: + # work = None elif tp_mode is TensorParallelLinearMode.REDUCE_SCATTER: out = differentiable_reduce_scatter_sum(out, group=group) - work = None + # work = None else: raise ValueError(f"Got unexpected mode: {tp_mode}.") - return out, work + # return out, work + return out diff --git a/src/nanotron/parallel/tensor_parallel/nn.py b/src/nanotron/parallel/tensor_parallel/nn.py index 8eccdd1d..53f1f930 100644 --- a/src/nanotron/parallel/tensor_parallel/nn.py +++ b/src/nanotron/parallel/tensor_parallel/nn.py @@ -19,6 +19,7 @@ from nanotron import distributed as dist from nanotron.distributed import get_global_rank +from nanotron.parallel.comm import CudaStreamManager from nanotron.parallel.parameters import NanotronParameter from nanotron.parallel.sharded_parameters import ( SplitConfig, @@ -52,7 +53,7 @@ def __init__( async_communication: bool = False, contiguous_chunks: Optional[Tuple[int, ...]] = None, tp_recompute_allgather: bool = True, - comm_stream: Optional[torch.cuda.Stream] = None, + stream_manager: Optional[CudaStreamManager] = None, ): self.pg = pg self.world_size = pg.size() @@ -73,7 +74,7 @@ def __init__( self.mode = mode self.async_communication = async_communication - self.comm_stream = comm_stream + self.stream_manager = stream_manager if contiguous_chunks is not None: assert ( @@ -101,7 +102,7 @@ def forward( async_communication=self.async_communication, tp_recompute_allgather=self.tp_recompute_allgather, op_name=op_name, - comm_stream=self.comm_stream, + stream_manager=self.stream_manager, ) def extra_repr(self) -> str: @@ -120,7 +121,7 @@ def __init__( dtype=None, async_communication: bool = False, contiguous_chunks: Optional[Tuple[int, ...]] = None, - comm_stream: Optional[torch.cuda.Stream] = None, + stream_manager: Optional[CudaStreamManager] = None, ): self.pg = pg self.world_size = pg.size() @@ -129,7 +130,7 @@ def __init__( self.in_features = in_features // self.world_size self.out_features = out_features - self.comm_stream = comm_stream + self.stream_manager = stream_manager # No need to shard the bias term, only rank 0 would have it bias = dist.get_rank(self.pg) == 0 and bias @@ -177,7 +178,7 @@ def forward(self, x: torch.Tensor, op_name: Optional[str] = None) -> torch.Tenso tp_mode=self.mode, async_communication=self.async_communication, op_name=op_name, - comm_stream=self.comm_stream, + stream_manager=self.stream_manager, ) def extra_repr(self) -> str: diff --git a/src/nanotron/trainer.py b/src/nanotron/trainer.py index a11b91a2..5cccd47a 100644 --- a/src/nanotron/trainer.py +++ b/src/nanotron/trainer.py @@ -35,7 +35,7 @@ SpectralMupInit, get_config_from_file, ) -from nanotron.constants import MODEL_CONFIG_FILE_NAME +from nanotron.constants import CUDA_STREAM_COMM_NAME, MODEL_CONFIG_FILE_NAME from nanotron.dataloader import sanity_check_dataloader from nanotron.helpers import ( _vocab_size_with_padding, @@ -61,7 +61,7 @@ from nanotron.models.starcoder2 import Starcoder2ForTraining from nanotron.optim.clip_grads import clip_grad_norm from nanotron.parallel import ParallelContext -from 
nanotron.parallel.comm import AsyncCommBucket +from nanotron.parallel.comm import CudaStreamManager from nanotron.parallel.data_parallel.utils import sync_gradients_across_dp from nanotron.parallel.parameters import NanotronParameter, sanity_check from nanotron.parallel.pipeline_parallel.engine import ( @@ -579,7 +579,7 @@ def training_step( self.post_train_step() - AsyncCommBucket.clear_all() + self.stream_manager.comm_bucket.clear_all() return outputs, loss_avg @@ -717,12 +717,18 @@ def _init_model_instance(self) -> NanotronModel: model_config_cls in CONFIG_TO_MODEL_CLASS ), f"Unsupported model config {model_config_cls}. Only {CONFIG_TO_MODEL_CLASS.keys()} are supported" + self.stream_manager = CudaStreamManager() + self.stream_manager.create( + CUDA_STREAM_COMM_NAME.format(torch.cuda.current_device()), device=torch.cuda.current_device() + ) + model = self._init_model( model_builder=lambda: CONFIG_TO_MODEL_CLASS[model_config_cls]( config=self.model_config, parallel_context=self.parallel_context, parallel_config=self.config.parallelism, random_states=self.random_states, + stream_manager=self.stream_manager, ), ) return model diff --git a/tests/test_comm.py b/tests/test_comm.py index d49252a4..39514a38 100644 --- a/tests/test_comm.py +++ b/tests/test_comm.py @@ -6,7 +6,7 @@ rerun_if_address_is_in_use, ) from nanotron.parallel import ParallelContext -from nanotron.parallel.comm import AsyncCommBucket, CudaStreamManager, WaitComm +from nanotron.parallel.comm import AsyncCommBucket, CudaStreamManager, insert_backward_sync_to_tensor class MockWork: @@ -40,9 +40,10 @@ def _test_add_async_op_to_bucket(parallel_context: ParallelContext): tensor = torch.randn(1, device="cuda") work = dist.all_reduce(tensor, async_op=True) - AsyncCommBucket.add(OP_NAME, work) + comm_bucket = AsyncCommBucket() + comm_bucket.add(OP_NAME, work) - assert AsyncCommBucket.get(OP_NAME) is work + assert comm_bucket.get(OP_NAME) is work @rerun_if_address_is_in_use() @@ -53,14 +54,15 @@ def test_wait_async_op_to_bucket(): def _test_wait_async_op_to_bucket(parallel_context: ParallelContext): OP_NAME = "test" work = MockWork() + comm_bucket = AsyncCommBucket() - AsyncCommBucket.add(OP_NAME, work) + comm_bucket.add(OP_NAME, work) assert work.is_completed() is False - AsyncCommBucket.wait(OP_NAME) + comm_bucket.wait(OP_NAME) assert work.is_completed() with pytest.raises(KeyError): - AsyncCommBucket.get(OP_NAME) + comm_bucket.get(OP_NAME) @rerun_if_address_is_in_use() @@ -71,12 +73,13 @@ def test_is_all_completed_in_async_bucket(): def _test_wait_async_op_to_bucket(parallel_context: ParallelContext): OP_NAME = "test" work = MockWork() + comm_bucket = AsyncCommBucket() - AsyncCommBucket.add(OP_NAME, work) - assert AsyncCommBucket.is_all_completed() is False + comm_bucket.add(OP_NAME, work) + assert comm_bucket.is_all_completed() is False - AsyncCommBucket.wait(OP_NAME) - assert AsyncCommBucket.is_all_completed() is True + comm_bucket.wait(OP_NAME) + assert comm_bucket.is_all_completed() is True @rerun_if_address_is_in_use() @@ -85,16 +88,18 @@ def test_clear_ops_in_async_bucket(): def _test_clear_ops_in_async_bucket(parallel_context: ParallelContext): - AsyncCommBucket.add("test1", MockWork()) - AsyncCommBucket.add("test2", MockWork()) - AsyncCommBucket.add("test3", MockWork()) + comm_bucket = AsyncCommBucket() - assert AsyncCommBucket.is_all_completed() is False + comm_bucket.add("test1", MockWork()) + comm_bucket.add("test2", MockWork()) + comm_bucket.add("test3", MockWork()) - AsyncCommBucket.clear_all() - assert 
AsyncCommBucket.is_all_completed() is True + assert comm_bucket.is_all_completed() is False + + comm_bucket.clear_all() + assert comm_bucket.is_all_completed() is True with pytest.raises(KeyError): - AsyncCommBucket.get("test1") + comm_bucket.get("test1") @rerun_if_address_is_in_use() @@ -103,16 +108,17 @@ def test_wait_comm(): def _test_wait_comm(parallel_context: ParallelContext): - tensor = torch.randn(1, device="cuda", requires_grad=True) OP_NAME = "test" + tensor = torch.randn(1, device="cuda", requires_grad=True) + stream_manager = CudaStreamManager() comm_stream = torch.cuda.Stream() with torch.cuda.stream(comm_stream): work = MockWork() - AsyncCommBucket.add(OP_NAME, work) + stream_manager.comm_bucket.add(OP_NAME, work) - output = WaitComm.apply(tensor, OP_NAME) + output = insert_backward_sync_to_tensor(tensor, OP_NAME, stream_manager) assert work.is_completed() is False # NOTE: we test that it waits for the async op to complete From d7a636fb81d03fc1e4efeb2a8db4dbf2dc98b6dd Mon Sep 17 00:00:00 2001 From: Phuc Nguyen Date: Fri, 7 Mar 2025 15:43:11 +0000 Subject: [PATCH 28/40] small refactor --- src/nanotron/models/llama.py | 9 ------ .../parallel/tensor_parallel/functional.py | 28 ++++++++----------- 2 files changed, 11 insertions(+), 26 deletions(-) diff --git a/src/nanotron/models/llama.py b/src/nanotron/models/llama.py index c514131c..2d623a36 100644 --- a/src/nanotron/models/llama.py +++ b/src/nanotron/models/llama.py @@ -812,7 +812,6 @@ def _core_forward( comm_stream.wait_stream(torch.cuda.default_stream()) with torch.cuda.stream(comm_stream): comm_bucket.wait(FWD_ATTN_OP_NAME.format(self.layer_idx, 0)) - # assert attn_output0["work"].is_completed() torch.cuda.default_stream().wait_stream(comm_stream) @@ -832,10 +831,7 @@ def _core_forward( comm_stream.wait_stream(torch.cuda.default_stream()) with torch.cuda.stream(comm_stream): - # assert 1 == 1 - # attn_output1["work"].wait() comm_bucket.wait(FWD_ATTN_OP_NAME.format(self.layer_idx, 1)) - # assert attn_output1["work"].is_completed() torch.cuda.default_stream().wait_stream(comm_stream) @@ -850,11 +846,6 @@ def _core_forward( comm_stream.wait_stream(torch.cuda.default_stream()) with torch.cuda.stream(comm_stream): - # mlp_output0["work"].wait() - # mlp_output1["work"].wait() - - # # assert mlp_output0["work"].is_completed() - # # assert mlp_output1["work"].is_completed() comm_bucket.wait(FWD_MLP_OP_NAME.format(self.layer_idx, 0)) comm_bucket.wait(FWD_MLP_OP_NAME.format(self.layer_idx, 1)) diff --git a/src/nanotron/parallel/tensor_parallel/functional.py b/src/nanotron/parallel/tensor_parallel/functional.py index 9ac09606..57ca7446 100644 --- a/src/nanotron/parallel/tensor_parallel/functional.py +++ b/src/nanotron/parallel/tensor_parallel/functional.py @@ -13,7 +13,7 @@ # See the License for the specific language governing permissions and # limitations under the License. 
import math -from typing import Optional, Tuple +from typing import Optional import torch from torch.nn import functional as F @@ -593,23 +593,17 @@ def row_linear( async_communication: bool, op_name: Optional[str] = None, stream_manager: Optional[CudaStreamManager] = None, -) -> Tuple[torch.Tensor, Optional[torch.Future]]: +): if async_communication: - out = _RowLinearAsyncCommunication.apply(input, weight, bias, group, tp_mode) + return _RowLinearAsyncCommunication.apply(input, weight, bias, group, tp_mode) + + out = F.linear(input, weight, bias) + + if tp_mode is TensorParallelLinearMode.ALL_REDUCE: + out = differentiable_all_reduce_sum(out, group=group, op_name=op_name, stream_manager=stream_manager) + elif tp_mode is TensorParallelLinearMode.REDUCE_SCATTER: + out = differentiable_reduce_scatter_sum(out, group=group) else: - out = F.linear(input, weight, bias) - if tp_mode is TensorParallelLinearMode.ALL_REDUCE: - out = differentiable_all_reduce_sum(out, group=group, op_name=op_name, stream_manager=stream_manager) - # is_domino_async_all_reduce = is_domino_async_comm(op_name) if op_name is not None else False - # if is_domino_async_all_reduce: - # work = AsyncCommBucket.pop(op_name) - # else: - # work = None - elif tp_mode is TensorParallelLinearMode.REDUCE_SCATTER: - out = differentiable_reduce_scatter_sum(out, group=group) - # work = None - else: - raise ValueError(f"Got unexpected mode: {tp_mode}.") + raise ValueError(f"Got unexpected mode: {tp_mode}.") - # return out, work return out From d3d8c10cffd39793333dba8a97fb69d513fb66e9 Mon Sep 17 00:00:00 2001 From: Phuc Nguyen Date: Fri, 7 Mar 2025 16:05:16 +0000 Subject: [PATCH 29/40] add CudaStreamManager.init_default_comm_stream and fix domino test --- src/nanotron/models/llama.py | 5 ++++- src/nanotron/parallel/comm.py | 6 ++++++ src/nanotron/trainer.py | 6 ++---- tests/helpers/llama.py | 5 +++++ tests/test_domino.py | 7 +++++-- 5 files changed, 22 insertions(+), 7 deletions(-) diff --git a/src/nanotron/models/llama.py b/src/nanotron/models/llama.py index 2d623a36..ef182c7b 100644 --- a/src/nanotron/models/llama.py +++ b/src/nanotron/models/llama.py @@ -228,7 +228,6 @@ def __init__( config.intermediate_size, # shape of gate_linear config.intermediate_size, # shape of up_linear ) - stream_manager.get_default_comm_stream() self.gate_up_proj = TensorParallelColumnLinear( config.hidden_size, 2 * config.intermediate_size, @@ -765,6 +764,10 @@ def forward( class DominoLlamaDecoderLayer(_BaseLlamaDecoderLayer): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + assert self.stream_manager is not None, "DominoLlamaDecoderLayer requires a stream_manager" + def _core_forward( self, hidden_states: Union[torch.Tensor, TensorPointer], diff --git a/src/nanotron/parallel/comm.py b/src/nanotron/parallel/comm.py index f7bed601..95495213 100644 --- a/src/nanotron/parallel/comm.py +++ b/src/nanotron/parallel/comm.py @@ -60,6 +60,12 @@ def __init__(self): self._streams: Dict[str, "torch.cuda.Stream"] = {} self.comm_bucket = AsyncCommBucket() + def init_default_comm_stream(self): + """ + Initialize the default communication stream for the current cuda device. 
+ """ + self.create(CUDA_STREAM_COMM_NAME.format(torch.cuda.current_device()), torch.cuda.current_device()) + def create(self, name: str, device: torch.device): assert name not in self._streams self._streams[name] = torch.cuda.Stream(device=device) diff --git a/src/nanotron/trainer.py b/src/nanotron/trainer.py index 5cccd47a..b7f241b6 100644 --- a/src/nanotron/trainer.py +++ b/src/nanotron/trainer.py @@ -35,7 +35,7 @@ SpectralMupInit, get_config_from_file, ) -from nanotron.constants import CUDA_STREAM_COMM_NAME, MODEL_CONFIG_FILE_NAME +from nanotron.constants import MODEL_CONFIG_FILE_NAME from nanotron.dataloader import sanity_check_dataloader from nanotron.helpers import ( _vocab_size_with_padding, @@ -718,9 +718,7 @@ def _init_model_instance(self) -> NanotronModel: ), f"Unsupported model config {model_config_cls}. Only {CONFIG_TO_MODEL_CLASS.keys()} are supported" self.stream_manager = CudaStreamManager() - self.stream_manager.create( - CUDA_STREAM_COMM_NAME.format(torch.cuda.current_device()), device=torch.cuda.current_device() - ) + self.stream_manager.init_default_comm_stream() model = self._init_model( model_builder=lambda: CONFIG_TO_MODEL_CLASS[model_config_cls]( diff --git a/tests/helpers/llama.py b/tests/helpers/llama.py index 8aae7669..d80b1b30 100644 --- a/tests/helpers/llama.py +++ b/tests/helpers/llama.py @@ -1,3 +1,5 @@ +from typing import Optional + import torch from nanotron.config import ( AdamWOptimizerArgs, @@ -20,6 +22,7 @@ from nanotron.config.config import PretrainDatasetsArgs from nanotron.models import build_model from nanotron.models.llama import LlamaForTraining +from nanotron.parallel.comm import CudaStreamManager from nanotron.parallel.context import ParallelContext from nanotron.trainer import mark_tied_parameters @@ -115,6 +118,7 @@ def create_llama_from_config( parallel_config: ParallelismArgs, device: torch.device, parallel_context: ParallelContext, + stream_manager: Optional[CudaStreamManager] = None, ) -> LlamaForTraining: """ @@ -131,6 +135,7 @@ def create_llama_from_config( parallel_context=parallel_context, parallel_config=parallel_config, random_states=None, + stream_manager=stream_manager, ), parallel_context=parallel_context, dtype=torch.bfloat16, diff --git a/tests/test_domino.py b/tests/test_domino.py index f6698a9b..b9f93a24 100644 --- a/tests/test_domino.py +++ b/tests/test_domino.py @@ -8,7 +8,7 @@ from nanotron.config.parallelism_config import DominoArgs from nanotron.models.llama import DominoLlamaDecoderLayer from nanotron.parallel import ParallelContext -from nanotron.parallel.comm import AsyncCommBucket +from nanotron.parallel.comm import CudaStreamManager from nanotron.parallel.tensor_parallel.domino import is_domino_async_comm @@ -51,12 +51,15 @@ def _test_domino_model( ): config = get_llama_training_config(model_args, parallel_context) config.parallelism.domino = DominoArgs(num_input_batches=2) + stream_manager = CudaStreamManager() + stream_manager.init_default_comm_stream() llama_model = create_llama_from_config( model_config=config.model.model_config, parallel_config=config.parallelism, device=torch.device("cuda"), parallel_context=parallel_context, + stream_manager=stream_manager, ) llama_model.init_model_randomly(config=config) @@ -68,4 +71,4 @@ def _test_domino_model( outputs = llama_model(input_ids, input_mask, input_mask, input_mask) assert isinstance(outputs["loss"], torch.Tensor) - assert AsyncCommBucket.is_all_completed() is True + assert stream_manager.comm_bucket.is_all_completed() is True From 
74d415c1c02b9463214fb46db060c0efbfa5a0e4 Mon Sep 17 00:00:00 2001 From: Phuc Nguyen Date: Fri, 7 Mar 2025 18:43:08 +0000 Subject: [PATCH 30/40] removing op_name in the forward pass by adding OpNameContext --- src/nanotron/models/llama.py | 49 ++++++++----------- .../parallel/tensor_parallel/domino.py | 31 ++++++++++++ .../parallel/tensor_parallel/functional.py | 5 +- src/nanotron/parallel/tensor_parallel/nn.py | 5 +- tests/test_domino.py | 42 +++++++++++++++- 5 files changed, 97 insertions(+), 35 deletions(-) diff --git a/src/nanotron/models/llama.py b/src/nanotron/models/llama.py index ef182c7b..0fb2af4d 100644 --- a/src/nanotron/models/llama.py +++ b/src/nanotron/models/llama.py @@ -39,6 +39,7 @@ BWD_MLP_OP_NAME, FWD_ATTN_OP_NAME, FWD_MLP_OP_NAME, + OpNameContext, ) from nanotron.parallel.tensor_parallel.functional import sharded_cross_entropy from nanotron.parallel.tensor_parallel.nn import ( @@ -250,11 +251,9 @@ def __init__( ) self.split_silu_mul = GLUActivation(config.hidden_act) - def forward( - self, hidden_states: torch.Tensor, op_name: Optional[str] = None - ): # [seq_length, batch_size, hidden_dim] - merged_states = self.gate_up_proj(hidden_states, op_name=op_name) - hidden_states = self.down_proj(self.split_silu_mul(merged_states), op_name=op_name) + def forward(self, hidden_states: torch.Tensor): # [seq_length, batch_size, hidden_dim] + merged_states = self.gate_up_proj(hidden_states) + hidden_states = self.down_proj(self.split_silu_mul(merged_states)) return {"hidden_states": hidden_states} @@ -449,9 +448,6 @@ def forward( self, hidden_states, # [seq_length, batch_size, hidden_size] sequence_mask, # [batch_size, seq_length] - # NOTE: because we dynamically determine which input split - # of domino at runtime, so we need to pass in the op_name - op_name: Optional[str] = None, ): from flash_attn import bert_padding from flash_attn.flash_attn_interface import ( @@ -460,7 +456,7 @@ def forward( ) qkv_states = self.qkv_proj( - hidden_states, op_name=op_name + hidden_states ) # [seq_length, batch_size, n_local_q_heads * d_qk + 2 * n_local_kv_heads * d_qk] q_length, batch_size, _ = qkv_states.shape @@ -706,7 +702,7 @@ def forward( attention_output.contiguous().view(batch_size, q_length, self.n_local_q_heads * self.d_v).transpose(0, 1) ) # output, work = self.o_proj(attention_output, op_name=op_name) - output = self.o_proj(attention_output, op_name=op_name) + output = self.o_proj(attention_output) return {"hidden_states": output, "sequence_mask": sequence_mask} @@ -801,16 +797,17 @@ def _core_forward( self.stream_manager, ) - attn_output0 = self.attn( - hidden_states=hidden_states0, - sequence_mask=sequence_mask0, - op_name=FWD_ATTN_OP_NAME.format(self.layer_idx, 0), - ) - attn_output1 = self.attn( - hidden_states=hidden_states1, - sequence_mask=sequence_mask1, - op_name=FWD_ATTN_OP_NAME.format(self.layer_idx, 1), - ) + with OpNameContext(FWD_ATTN_OP_NAME.format(self.layer_idx, 0)): + attn_output0 = self.attn( + hidden_states=hidden_states0, + sequence_mask=sequence_mask0, + ) + + with OpNameContext(FWD_ATTN_OP_NAME.format(self.layer_idx, 1)): + attn_output1 = self.attn( + hidden_states=hidden_states1, + sequence_mask=sequence_mask1, + ) comm_stream.wait_stream(torch.cuda.default_stream()) with torch.cuda.stream(comm_stream): @@ -827,10 +824,8 @@ def _core_forward( self.stream_manager, ) - mlp_output0 = self.mlp( - hidden_states=hidden_states0, - op_name=FWD_MLP_OP_NAME.format(self.layer_idx, 0), - ) + with OpNameContext(FWD_MLP_OP_NAME.format(self.layer_idx, 0)): + mlp_output0 = 
self.mlp(hidden_states=hidden_states0) comm_stream.wait_stream(torch.cuda.default_stream()) with torch.cuda.stream(comm_stream): @@ -842,10 +837,8 @@ def _core_forward( residual1 = hidden_states1 hidden_states1 = self.post_attention_layernorm(hidden_states1) - mlp_output1 = self.mlp( - hidden_states=hidden_states1, - op_name=FWD_MLP_OP_NAME.format(self.layer_idx, 1), - ) + with OpNameContext(FWD_MLP_OP_NAME.format(self.layer_idx, 1)): + mlp_output1 = self.mlp(hidden_states=hidden_states1) comm_stream.wait_stream(torch.cuda.default_stream()) with torch.cuda.stream(comm_stream): diff --git a/src/nanotron/parallel/tensor_parallel/domino.py b/src/nanotron/parallel/tensor_parallel/domino.py index 9bfe79d6..d7a98f04 100644 --- a/src/nanotron/parallel/tensor_parallel/domino.py +++ b/src/nanotron/parallel/tensor_parallel/domino.py @@ -1,10 +1,14 @@ import re +import threading +from typing import Optional FWD_MLP_OP_NAME = "fwd.layer_mlp_{}_batch_{}" FWD_ATTN_OP_NAME = "fwd.layer_attn_{}_batch_{}" BWD_ATTN_OP_NAME = "bwd.layer_attn_{}_batch_{}" BWD_MLP_OP_NAME = "bwd.layer_mlp_{}_batch_{}" +_operation_context = threading.local() + def is_domino_async_comm(x: str) -> bool: """ @@ -20,3 +24,30 @@ def is_domino_async_comm(x: str) -> bool: regex = re.compile("^(" + "|".join(patterns) + ")$") # Combine patterns into a single regex not_async = bool(regex.match(x)) return not not_async + + +class OpNameContext: + """ + A context manager to set the name of a module operation + """ + + def __init__(self, op_name: str): + self.op_name = op_name + self.previous_op_name = None + + def __enter__(self): + if not hasattr(_operation_context, "current_op_name"): + _operation_context.current_op_name = None + self.previous_op_name = _operation_context.current_op_name + _operation_context.current_op_name = self.op_name + return self + + def __exit__(self, exc_type, exc_val, exc_tb): + _operation_context.current_op_name = self.previous_op_name + + +def get_op_name() -> Optional[str]: + """ + Get the name of the current operation. 
+ """ + return getattr(_operation_context, "current_op_name", None) diff --git a/src/nanotron/parallel/tensor_parallel/functional.py b/src/nanotron/parallel/tensor_parallel/functional.py index 57ca7446..03357160 100644 --- a/src/nanotron/parallel/tensor_parallel/functional.py +++ b/src/nanotron/parallel/tensor_parallel/functional.py @@ -25,6 +25,7 @@ differentiable_identity, differentiable_reduce_scatter_sum, ) +from nanotron.parallel.tensor_parallel.domino import get_op_name from nanotron.parallel.tensor_parallel.enum import TensorParallelLinearMode from nanotron.parallel.utils import MemoryBuffer, assert_cuda_max_connections_set_to_1 @@ -437,13 +438,13 @@ def column_linear( tp_mode: TensorParallelLinearMode, async_communication: bool, tp_recompute_allgather: bool = True, - op_name: Optional[str] = None, stream_manager: Optional[CudaStreamManager] = None, ): if async_communication: return _ColumnLinearAsyncCommunication.apply(input, weight, bias, group, tp_mode, tp_recompute_allgather) if tp_mode is TensorParallelLinearMode.ALL_REDUCE: + op_name = get_op_name() input = differentiable_identity(input, group=group, op_name=op_name, stream_manager=stream_manager) return F.linear(input, weight, bias) if tp_mode is TensorParallelLinearMode.REDUCE_SCATTER: @@ -591,7 +592,6 @@ def row_linear( group: dist.ProcessGroup, tp_mode: TensorParallelLinearMode, async_communication: bool, - op_name: Optional[str] = None, stream_manager: Optional[CudaStreamManager] = None, ): if async_communication: @@ -600,6 +600,7 @@ def row_linear( out = F.linear(input, weight, bias) if tp_mode is TensorParallelLinearMode.ALL_REDUCE: + op_name = get_op_name() out = differentiable_all_reduce_sum(out, group=group, op_name=op_name, stream_manager=stream_manager) elif tp_mode is TensorParallelLinearMode.REDUCE_SCATTER: out = differentiable_reduce_scatter_sum(out, group=group) diff --git a/src/nanotron/parallel/tensor_parallel/nn.py b/src/nanotron/parallel/tensor_parallel/nn.py index 53f1f930..f92fe0ee 100644 --- a/src/nanotron/parallel/tensor_parallel/nn.py +++ b/src/nanotron/parallel/tensor_parallel/nn.py @@ -91,7 +91,6 @@ def __init__( def forward( self, x: torch.Tensor, - op_name: Optional[str] = None, ) -> torch.Tensor: return column_linear( input=x, @@ -101,7 +100,6 @@ def forward( tp_mode=self.mode, async_communication=self.async_communication, tp_recompute_allgather=self.tp_recompute_allgather, - op_name=op_name, stream_manager=self.stream_manager, ) @@ -169,7 +167,7 @@ def _mark_all_parameters_in_module_as_sharded(self, split_config: SplitConfig): ) setattr(self, name, new_param) - def forward(self, x: torch.Tensor, op_name: Optional[str] = None) -> torch.Tensor: + def forward(self, x: torch.Tensor) -> torch.Tensor: return row_linear( input=x, weight=self.weight, @@ -177,7 +175,6 @@ def forward(self, x: torch.Tensor, op_name: Optional[str] = None) -> torch.Tenso group=self.pg, tp_mode=self.mode, async_communication=self.async_communication, - op_name=op_name, stream_manager=self.stream_manager, ) diff --git a/tests/test_domino.py b/tests/test_domino.py index b9f93a24..e3182de4 100644 --- a/tests/test_domino.py +++ b/tests/test_domino.py @@ -9,7 +9,7 @@ from nanotron.models.llama import DominoLlamaDecoderLayer from nanotron.parallel import ParallelContext from nanotron.parallel.comm import CudaStreamManager -from nanotron.parallel.tensor_parallel.domino import is_domino_async_comm +from nanotron.parallel.tensor_parallel.domino import OpNameContext, get_op_name, is_domino_async_comm @pytest.mark.parametrize( @@ -72,3 
+72,43 @@ def _test_domino_model( assert isinstance(outputs["loss"], torch.Tensor) assert stream_manager.comm_bucket.is_all_completed() is True + + +### OpNameContext tests ### + + +def test_op_name_context_reentry(): + assert get_op_name() is None + context = OpNameContext("reusable_op") + + with context: + assert get_op_name() == "reusable_op" + + assert get_op_name() is None + + with context: + assert get_op_name() == "reusable_op" + + assert get_op_name() is None + + +def test_deeply_nested_contexts(): + with OpNameContext("level1"): + assert get_op_name() == "level1" + + with OpNameContext("level2"): + assert get_op_name() == "level2" + + assert get_op_name() == "level1" + + +def test_multiple_sequential_contexts(): + assert get_op_name() is None + + with OpNameContext("first_op"): + assert get_op_name() == "first_op" + + with OpNameContext("second_op"): + assert get_op_name() == "second_op" + + assert get_op_name() is None From 08a4472f5eec54ec6425e45b8404a100d72e2ecc Mon Sep 17 00:00:00 2001 From: Phuc Nguyen Date: Sat, 8 Mar 2025 10:46:21 +0000 Subject: [PATCH 31/40] add CudaStreamManager as context --- src/nanotron/models/llama.py | 19 ++-- src/nanotron/parallel/comm.py | 36 +------- .../distributed_differentiable_primitives.py | 18 ++-- .../parallel/tensor_parallel/domino.py | 87 ++++++++++++++++--- .../parallel/tensor_parallel/functional.py | 7 +- src/nanotron/parallel/tensor_parallel/nn.py | 7 -- tests/test_domino.py | 14 +-- 7 files changed, 104 insertions(+), 84 deletions(-) diff --git a/src/nanotron/models/llama.py b/src/nanotron/models/llama.py index 0fb2af4d..484f8321 100644 --- a/src/nanotron/models/llama.py +++ b/src/nanotron/models/llama.py @@ -39,7 +39,7 @@ BWD_MLP_OP_NAME, FWD_ATTN_OP_NAME, FWD_MLP_OP_NAME, - OpNameContext, + set_operation_context, ) from nanotron.parallel.tensor_parallel.functional import sharded_cross_entropy from nanotron.parallel.tensor_parallel.nn import ( @@ -215,7 +215,6 @@ def __init__( config: LlamaConfig, parallel_config: Optional[ParallelismArgs], tp_pg: dist.ProcessGroup, - stream_manager: Optional[CudaStreamManager] = None, ): super().__init__() @@ -238,7 +237,6 @@ def __init__( async_communication=tp_linear_async_communication, contiguous_chunks=gate_up_contiguous_chunks, tp_recompute_allgather=parallel_config.tp_recompute_allgather, - stream_manager=stream_manager, ) self.down_proj = TensorParallelRowLinear( config.intermediate_size, @@ -247,7 +245,6 @@ def __init__( mode=tp_mode, bias=False, async_communication=tp_linear_async_communication and tp_mode is TensorParallelLinearMode.REDUCE_SCATTER, - stream_manager=stream_manager, ) self.split_silu_mul = GLUActivation(config.hidden_act) @@ -346,7 +343,6 @@ def __init__( parallel_config: Optional[ParallelismArgs], tp_pg: dist.ProcessGroup, layer_idx: int, - stream_manager: Optional[CudaStreamManager] = None, ): from flash_attn.layers.rotary import RotaryEmbedding as FlashRotaryEmbedding @@ -402,7 +398,6 @@ def __init__( async_communication=tp_linear_async_communication, contiguous_chunks=qkv_contiguous_chunks, tp_recompute_allgather=parallel_config.tp_recompute_allgather, - stream_manager=stream_manager, ) # TODO(kunhao): We want to have only one version per device and not one version per layer. 
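The interleaving in `_core_forward` above comes down to a standard two-stream overlap: work queued on the communication stream after `comm_stream.wait_stream(default)` runs concurrently with later kernels on the default stream, and `default_stream.wait_stream(comm_stream)` restores ordering before the result is consumed. Below is a minimal, self-contained sketch of that pattern, illustrative only: the element-wise multiply stands in for the asynchronous all-reduce, and none of these names come from the patch.

import torch

default = torch.cuda.default_stream()
comm = torch.cuda.Stream()

x = torch.randn(4096, 4096, device="cuda")
w = torch.randn(4096, 4096, device="cuda")

y0 = x @ w                    # "batch 0" compute on the default stream
comm.wait_stream(default)     # comm must not start before y0 exists
with torch.cuda.stream(comm):
    y0_reduced = y0 * 2.0     # stand-in for the asynchronous all-reduce of batch 0
y1 = x @ w                    # "batch 1" compute, overlaps with the comm stream
default.wait_stream(comm)     # only consume y0_reduced once the comm stream is done
out = y0_reduced + y1
torch.cuda.synchronize()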
if config.rope_interleaved: @@ -431,7 +426,6 @@ def __init__( mode=tp_mode, bias=False, async_communication=tp_linear_async_communication, - stream_manager=stream_manager, ) self.attention = CoreAttention( @@ -724,11 +718,10 @@ def __init__( parallel_config=parallel_config, tp_pg=tp_pg, layer_idx=layer_idx, - stream_manager=stream_manager, ) self.post_attention_layernorm = TritonRMSNorm(config.hidden_size, eps=config.rms_norm_eps) - self.mlp = MLP(config=config, parallel_config=parallel_config, tp_pg=tp_pg, stream_manager=stream_manager) + self.mlp = MLP(config=config, parallel_config=parallel_config, tp_pg=tp_pg) self.recompute_layer = parallel_config.recompute_layer self.parallel_config = parallel_config @@ -797,13 +790,13 @@ def _core_forward( self.stream_manager, ) - with OpNameContext(FWD_ATTN_OP_NAME.format(self.layer_idx, 0)): + with set_operation_context(FWD_ATTN_OP_NAME.format(self.layer_idx, 0), self.stream_manager): attn_output0 = self.attn( hidden_states=hidden_states0, sequence_mask=sequence_mask0, ) - with OpNameContext(FWD_ATTN_OP_NAME.format(self.layer_idx, 1)): + with set_operation_context(FWD_ATTN_OP_NAME.format(self.layer_idx, 1), self.stream_manager): attn_output1 = self.attn( hidden_states=hidden_states1, sequence_mask=sequence_mask1, @@ -824,7 +817,7 @@ def _core_forward( self.stream_manager, ) - with OpNameContext(FWD_MLP_OP_NAME.format(self.layer_idx, 0)): + with set_operation_context(FWD_MLP_OP_NAME.format(self.layer_idx, 0), self.stream_manager): mlp_output0 = self.mlp(hidden_states=hidden_states0) comm_stream.wait_stream(torch.cuda.default_stream()) @@ -837,7 +830,7 @@ def _core_forward( residual1 = hidden_states1 hidden_states1 = self.post_attention_layernorm(hidden_states1) - with OpNameContext(FWD_MLP_OP_NAME.format(self.layer_idx, 1)): + with set_operation_context(FWD_MLP_OP_NAME.format(self.layer_idx, 1), self.stream_manager): mlp_output1 = self.mlp(hidden_states=hidden_states1) comm_stream.wait_stream(torch.cuda.default_stream()) diff --git a/src/nanotron/parallel/comm.py b/src/nanotron/parallel/comm.py index 95495213..551a8459 100644 --- a/src/nanotron/parallel/comm.py +++ b/src/nanotron/parallel/comm.py @@ -4,7 +4,6 @@ import torch from nanotron.constants import CUDA_STREAM_COMM_NAME -from nanotron.parallel.tensor_parallel.domino import is_domino_async_comm class AsyncCommBucket: @@ -45,7 +44,7 @@ def is_all_completed(self) -> bool: not_finished = [] for k, v in self._copy_async_op.items(): - assert is_domino_async_comm(k) is True, f"Operation with name {k} wasn't executed asynchronously!" + # assert is_domino_async_comm(k) is True, f"Operation with name {k} wasn't executed asynchronously!" if v.is_completed() is not True: not_finished.append((k, v)) return len(not_finished) == 0 @@ -88,38 +87,6 @@ def run_on_stream(self, name: str): yield stream -class WaitComm(torch.autograd.Function): - """ - Enforce a tensor to wait for the communication operation to finish - in torch's autograd graph. 
- """ - - @staticmethod - def forward(ctx, input: torch.Tensor, op_name: str, comm_stream: torch.cuda.Stream, comm_bucket: AsyncCommBucket): - assert isinstance(comm_stream, torch.cuda.Stream) - ctx.op_name = op_name - ctx.comm_stream = comm_stream - ctx.comm_bucket = comm_bucket - return input - - @staticmethod - def backward(ctx, grad_output: torch.Tensor): - """ - NOTE: because the communication operation is already being executed - so the communication stream don't have to wait for the compute stream here - but the compute stream waits for the communication stream - before proceeding - """ - if is_domino_async_comm(ctx.op_name): - handle = ctx.comm_bucket.pop(ctx.op_name) - handle.wait() - - ctx.comm_stream.synchronize() - torch.cuda.default_stream().wait_stream(ctx.comm_stream) - - return grad_output, None, None, None - - def insert_backward_sync_to_tensor( tensor: torch.Tensor, op_name: str, stream_manager: CudaStreamManager ) -> torch.Tensor: @@ -127,6 +94,7 @@ def insert_backward_sync_to_tensor( Insert a wait communication operation of a given op_name to the autograd graph of a tensor. """ + from nanotron.parallel.tensor_parallel.domino import WaitComm assert isinstance(stream_manager, CudaStreamManager) comm_stream = stream_manager.get(CUDA_STREAM_COMM_NAME.format(torch.cuda.current_device())) diff --git a/src/nanotron/parallel/tensor_parallel/distributed_differentiable_primitives.py b/src/nanotron/parallel/tensor_parallel/distributed_differentiable_primitives.py index e50d34ba..2e87b681 100644 --- a/src/nanotron/parallel/tensor_parallel/distributed_differentiable_primitives.py +++ b/src/nanotron/parallel/tensor_parallel/distributed_differentiable_primitives.py @@ -20,7 +20,7 @@ from nanotron import distributed as dist from nanotron.distributed import ProcessGroup -from nanotron.parallel.comm import CudaStreamManager, is_domino_async_comm +from nanotron.parallel.tensor_parallel.domino import is_domino_async_comm class DifferentiableIdentity(torch.autograd.Function): @@ -32,10 +32,10 @@ def forward( tensor: torch.Tensor, group: Optional[ProcessGroup], op_name: Optional[str] = None, - stream_manager: Optional[CudaStreamManager] = None, + stream_manager: Optional["CudaStreamManager"] = None, ): - ctx.op_name = op_name ctx.group = group + ctx.op_name = op_name ctx.stream_manager = stream_manager return tensor @@ -49,7 +49,6 @@ def backward(ctx, grad_output: torch.Tensor): None, None, None, - None, ) @@ -62,7 +61,7 @@ def forward( tensor: torch.Tensor, group: Optional[ProcessGroup], op_name: Optional[int] = None, - stream_manager: Optional[CudaStreamManager] = None, + stream_manager: Optional["CudaStreamManager"] = None, ) -> Tuple[torch.Tensor, Optional["dist.Work"]]: async_all_reduce = is_domino_async_comm(op_name) if op_name is not None else False ctx.async_all_reduce = async_all_reduce @@ -70,7 +69,8 @@ def forward( if group.size() == 1: return tensor - if stream_manager is not None: + if async_all_reduce is True: + assert stream_manager is not None, f"op_name: {op_name}" comm_stream = stream_manager.get_default_comm_stream() comm_stream.wait_stream(torch.cuda.default_stream()) comm_context = torch.cuda.stream(comm_stream) @@ -89,7 +89,7 @@ def forward( @staticmethod def backward(ctx, grad_output): - return grad_output, None, None, None, None + return grad_output, None, None, None class DifferentiableAllGather(torch.autograd.Function): @@ -175,7 +175,7 @@ def differentiable_identity( tensor, group: Optional[ProcessGroup] = None, op_name: Optional[str] = None, - stream_manager: 
Optional[CudaStreamManager] = None, + stream_manager: Optional["CudaStreamManager"] = None, ): return DifferentiableIdentity.apply(tensor, group, op_name, stream_manager) @@ -184,7 +184,7 @@ def differentiable_all_reduce_sum( tensor, group: Optional[ProcessGroup] = None, op_name: Optional[str] = None, - stream_manager: Optional[CudaStreamManager] = None, + stream_manager: Optional["CudaStreamManager"] = None, ): return DifferentiableAllReduceSum.apply(tensor, group, op_name, stream_manager) diff --git a/src/nanotron/parallel/tensor_parallel/domino.py b/src/nanotron/parallel/tensor_parallel/domino.py index d7a98f04..6c147214 100644 --- a/src/nanotron/parallel/tensor_parallel/domino.py +++ b/src/nanotron/parallel/tensor_parallel/domino.py @@ -1,7 +1,12 @@ import re import threading +from contextlib import contextmanager from typing import Optional +import torch + +from nanotron.parallel.comm import AsyncCommBucket, CudaStreamManager + FWD_MLP_OP_NAME = "fwd.layer_mlp_{}_batch_{}" FWD_ATTN_OP_NAME = "fwd.layer_attn_{}_batch_{}" BWD_ATTN_OP_NAME = "bwd.layer_attn_{}_batch_{}" @@ -26,28 +31,90 @@ def is_domino_async_comm(x: str) -> bool: return not not_async -class OpNameContext: +class OperationContext: """ - A context manager to set the name of a module operation + Context manager that sets both operation name and stream manager in thread-local context. + + Args: + op_name: Name of the current operation + stream_manager: Associated CUDA stream manager """ - def __init__(self, op_name: str): + def __init__(self, op_name: str, stream_manager: CudaStreamManager): self.op_name = op_name - self.previous_op_name = None + self.stream_manager = stream_manager + self._previous_op_name: Optional[str] = None + self._previous_stream_manager: Optional[CudaStreamManager] = None def __enter__(self): - if not hasattr(_operation_context, "current_op_name"): - _operation_context.current_op_name = None - self.previous_op_name = _operation_context.current_op_name - _operation_context.current_op_name = self.op_name + """Store current context and set new values""" + # Handle operation name + if not hasattr(_operation_context, "_current_op_name"): + _operation_context._current_op_name = None + self._previous_op_name = _operation_context._current_op_name + _operation_context._current_op_name = self.op_name + + # Handle stream manager + if not hasattr(_operation_context, "_current_stream_manager"): + _operation_context._current_stream_manager = None + self._previous_stream_manager = _operation_context._current_stream_manager + _operation_context._current_stream_manager = self.stream_manager + return self def __exit__(self, exc_type, exc_val, exc_tb): - _operation_context.current_op_name = self.previous_op_name + """Restore previous context values""" + _operation_context._current_op_name = self._previous_op_name + _operation_context._current_stream_manager = self._previous_stream_manager + + +class WaitComm(torch.autograd.Function): + """ + Enforce a tensor to wait for the communication operation to finish + in torch's autograd graph. 
+ """ + + @staticmethod + def forward(ctx, input: torch.Tensor, op_name: str, comm_stream: torch.cuda.Stream, comm_bucket: AsyncCommBucket): + assert isinstance(comm_stream, torch.cuda.Stream) + ctx.op_name = op_name + ctx.comm_stream = comm_stream + ctx.comm_bucket = comm_bucket + return input + + @staticmethod + def backward(ctx, grad_output: torch.Tensor): + """ + NOTE: because the communication operation is already being executed + so the communication stream don't have to wait for the compute stream here + but the compute stream waits for the communication stream + before proceeding + """ + if is_domino_async_comm(ctx.op_name): + handle = ctx.comm_bucket.pop(ctx.op_name) + handle.wait() + + ctx.comm_stream.synchronize() + torch.cuda.default_stream().wait_stream(ctx.comm_stream) + + return grad_output, None, None, None + + +@contextmanager +def set_operation_context(name: str, stream_manager: CudaStreamManager): + with OperationContext(name, stream_manager): + yield def get_op_name() -> Optional[str]: """ Get the name of the current operation. """ - return getattr(_operation_context, "current_op_name", None) + return getattr(_operation_context, "_current_op_name", None) + + +def get_stream_manager() -> Optional[CudaStreamManager]: + """ + Get the stream manager for the current operation. + """ + return getattr(_operation_context, "_current_stream_manager", None) diff --git a/src/nanotron/parallel/tensor_parallel/functional.py b/src/nanotron/parallel/tensor_parallel/functional.py index 03357160..030d863e 100644 --- a/src/nanotron/parallel/tensor_parallel/functional.py +++ b/src/nanotron/parallel/tensor_parallel/functional.py @@ -19,13 +19,12 @@ from torch.nn import functional as F import nanotron.distributed as dist -from nanotron.parallel.comm import CudaStreamManager from nanotron.parallel.tensor_parallel.distributed_differentiable_primitives import ( differentiable_all_reduce_sum, differentiable_identity, differentiable_reduce_scatter_sum, ) -from nanotron.parallel.tensor_parallel.domino import get_op_name +from nanotron.parallel.tensor_parallel.domino import get_op_name, get_stream_manager from nanotron.parallel.tensor_parallel.enum import TensorParallelLinearMode from nanotron.parallel.utils import MemoryBuffer, assert_cuda_max_connections_set_to_1 @@ -438,13 +437,13 @@ def column_linear( tp_mode: TensorParallelLinearMode, async_communication: bool, tp_recompute_allgather: bool = True, - stream_manager: Optional[CudaStreamManager] = None, ): if async_communication: return _ColumnLinearAsyncCommunication.apply(input, weight, bias, group, tp_mode, tp_recompute_allgather) if tp_mode is TensorParallelLinearMode.ALL_REDUCE: op_name = get_op_name() + stream_manager = get_stream_manager() input = differentiable_identity(input, group=group, op_name=op_name, stream_manager=stream_manager) return F.linear(input, weight, bias) if tp_mode is TensorParallelLinearMode.REDUCE_SCATTER: @@ -592,7 +591,6 @@ def row_linear( group: dist.ProcessGroup, tp_mode: TensorParallelLinearMode, async_communication: bool, - stream_manager: Optional[CudaStreamManager] = None, ): if async_communication: return _RowLinearAsyncCommunication.apply(input, weight, bias, group, tp_mode) @@ -601,6 +599,7 @@ def row_linear( if tp_mode is TensorParallelLinearMode.ALL_REDUCE: op_name = get_op_name() + stream_manager = get_stream_manager() out = differentiable_all_reduce_sum(out, group=group, op_name=op_name, stream_manager=stream_manager) elif tp_mode is TensorParallelLinearMode.REDUCE_SCATTER: out = 
differentiable_reduce_scatter_sum(out, group=group) diff --git a/src/nanotron/parallel/tensor_parallel/nn.py b/src/nanotron/parallel/tensor_parallel/nn.py index f92fe0ee..5598df2f 100644 --- a/src/nanotron/parallel/tensor_parallel/nn.py +++ b/src/nanotron/parallel/tensor_parallel/nn.py @@ -19,7 +19,6 @@ from nanotron import distributed as dist from nanotron.distributed import get_global_rank -from nanotron.parallel.comm import CudaStreamManager from nanotron.parallel.parameters import NanotronParameter from nanotron.parallel.sharded_parameters import ( SplitConfig, @@ -53,7 +52,6 @@ def __init__( async_communication: bool = False, contiguous_chunks: Optional[Tuple[int, ...]] = None, tp_recompute_allgather: bool = True, - stream_manager: Optional[CudaStreamManager] = None, ): self.pg = pg self.world_size = pg.size() @@ -74,7 +72,6 @@ def __init__( self.mode = mode self.async_communication = async_communication - self.stream_manager = stream_manager if contiguous_chunks is not None: assert ( @@ -100,7 +97,6 @@ def forward( tp_mode=self.mode, async_communication=self.async_communication, tp_recompute_allgather=self.tp_recompute_allgather, - stream_manager=self.stream_manager, ) def extra_repr(self) -> str: @@ -119,7 +115,6 @@ def __init__( dtype=None, async_communication: bool = False, contiguous_chunks: Optional[Tuple[int, ...]] = None, - stream_manager: Optional[CudaStreamManager] = None, ): self.pg = pg self.world_size = pg.size() @@ -128,7 +123,6 @@ def __init__( self.in_features = in_features // self.world_size self.out_features = out_features - self.stream_manager = stream_manager # No need to shard the bias term, only rank 0 would have it bias = dist.get_rank(self.pg) == 0 and bias @@ -175,7 +169,6 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: group=self.pg, tp_mode=self.mode, async_communication=self.async_communication, - stream_manager=self.stream_manager, ) def extra_repr(self) -> str: diff --git a/tests/test_domino.py b/tests/test_domino.py index e3182de4..fc28c192 100644 --- a/tests/test_domino.py +++ b/tests/test_domino.py @@ -9,7 +9,7 @@ from nanotron.models.llama import DominoLlamaDecoderLayer from nanotron.parallel import ParallelContext from nanotron.parallel.comm import CudaStreamManager -from nanotron.parallel.tensor_parallel.domino import OpNameContext, get_op_name, is_domino_async_comm +from nanotron.parallel.tensor_parallel.domino import OperationContext, get_op_name, is_domino_async_comm @pytest.mark.parametrize( @@ -74,12 +74,12 @@ def _test_domino_model( assert stream_manager.comm_bucket.is_all_completed() is True -### OpNameContext tests ### +### OperationContext tests ### def test_op_name_context_reentry(): assert get_op_name() is None - context = OpNameContext("reusable_op") + context = OperationContext("reusable_op") with context: assert get_op_name() == "reusable_op" @@ -93,10 +93,10 @@ def test_op_name_context_reentry(): def test_deeply_nested_contexts(): - with OpNameContext("level1"): + with OperationContext("level1"): assert get_op_name() == "level1" - with OpNameContext("level2"): + with OperationContext("level2"): assert get_op_name() == "level2" assert get_op_name() == "level1" @@ -105,10 +105,10 @@ def test_deeply_nested_contexts(): def test_multiple_sequential_contexts(): assert get_op_name() is None - with OpNameContext("first_op"): + with OperationContext("first_op"): assert get_op_name() == "first_op" - with OpNameContext("second_op"): + with OperationContext("second_op"): assert get_op_name() == "second_op" assert get_op_name() is None From 
684b1b9001a7106cfa46adde3e8860096bd748fe Mon Sep 17 00:00:00 2001 From: Phuc Nguyen Date: Sat, 8 Mar 2025 10:51:57 +0000 Subject: [PATCH 32/40] small refactor --- src/nanotron/models/llama.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/src/nanotron/models/llama.py b/src/nanotron/models/llama.py index 484f8321..82fd08d7 100644 --- a/src/nanotron/models/llama.py +++ b/src/nanotron/models/llama.py @@ -796,6 +796,8 @@ def _core_forward( sequence_mask=sequence_mask0, ) + # TODO: maybe try to bucket all the communication as in DPP, + # do it at at once with set_operation_context(FWD_ATTN_OP_NAME.format(self.layer_idx, 1), self.stream_manager): attn_output1 = self.attn( hidden_states=hidden_states1, @@ -843,6 +845,8 @@ def _core_forward( hidden_states0 = mlp_output0["hidden_states"] + residual0 hidden_states1 = mlp_output1["hidden_states"] + residual1 + # TODO: make sure no memory overhead, + # and try a fixed memory buffer as in section 4.2 in the paper hidden_states = torch.cat([hidden_states0, hidden_states1], dim=1) return hidden_states, orig_sequence_mask From f8e8b1ff1a58bc066966be6a806b1c3eade051b7 Mon Sep 17 00:00:00 2001 From: Phuc Nguyen Date: Mon, 10 Mar 2025 11:08:21 +0000 Subject: [PATCH 33/40] Reverting repository to commit 74d415c1c02b9463214fb46db060c0efbfa5a0e4 --- src/nanotron/models/llama.py | 23 ++--- src/nanotron/parallel/comm.py | 36 +++++++- .../distributed_differentiable_primitives.py | 18 ++-- .../parallel/tensor_parallel/domino.py | 87 +++---------------- .../parallel/tensor_parallel/functional.py | 7 +- src/nanotron/parallel/tensor_parallel/nn.py | 7 ++ tests/test_domino.py | 14 +-- 7 files changed, 84 insertions(+), 108 deletions(-) diff --git a/src/nanotron/models/llama.py b/src/nanotron/models/llama.py index 82fd08d7..0fb2af4d 100644 --- a/src/nanotron/models/llama.py +++ b/src/nanotron/models/llama.py @@ -39,7 +39,7 @@ BWD_MLP_OP_NAME, FWD_ATTN_OP_NAME, FWD_MLP_OP_NAME, - set_operation_context, + OpNameContext, ) from nanotron.parallel.tensor_parallel.functional import sharded_cross_entropy from nanotron.parallel.tensor_parallel.nn import ( @@ -215,6 +215,7 @@ def __init__( config: LlamaConfig, parallel_config: Optional[ParallelismArgs], tp_pg: dist.ProcessGroup, + stream_manager: Optional[CudaStreamManager] = None, ): super().__init__() @@ -237,6 +238,7 @@ def __init__( async_communication=tp_linear_async_communication, contiguous_chunks=gate_up_contiguous_chunks, tp_recompute_allgather=parallel_config.tp_recompute_allgather, + stream_manager=stream_manager, ) self.down_proj = TensorParallelRowLinear( config.intermediate_size, @@ -245,6 +247,7 @@ def __init__( mode=tp_mode, bias=False, async_communication=tp_linear_async_communication and tp_mode is TensorParallelLinearMode.REDUCE_SCATTER, + stream_manager=stream_manager, ) self.split_silu_mul = GLUActivation(config.hidden_act) @@ -343,6 +346,7 @@ def __init__( parallel_config: Optional[ParallelismArgs], tp_pg: dist.ProcessGroup, layer_idx: int, + stream_manager: Optional[CudaStreamManager] = None, ): from flash_attn.layers.rotary import RotaryEmbedding as FlashRotaryEmbedding @@ -398,6 +402,7 @@ def __init__( async_communication=tp_linear_async_communication, contiguous_chunks=qkv_contiguous_chunks, tp_recompute_allgather=parallel_config.tp_recompute_allgather, + stream_manager=stream_manager, ) # TODO(kunhao): We want to have only one version per device and not one version per layer. 
if config.rope_interleaved: @@ -426,6 +431,7 @@ def __init__( mode=tp_mode, bias=False, async_communication=tp_linear_async_communication, + stream_manager=stream_manager, ) self.attention = CoreAttention( @@ -718,10 +724,11 @@ def __init__( parallel_config=parallel_config, tp_pg=tp_pg, layer_idx=layer_idx, + stream_manager=stream_manager, ) self.post_attention_layernorm = TritonRMSNorm(config.hidden_size, eps=config.rms_norm_eps) - self.mlp = MLP(config=config, parallel_config=parallel_config, tp_pg=tp_pg) + self.mlp = MLP(config=config, parallel_config=parallel_config, tp_pg=tp_pg, stream_manager=stream_manager) self.recompute_layer = parallel_config.recompute_layer self.parallel_config = parallel_config @@ -790,15 +797,13 @@ def _core_forward( self.stream_manager, ) - with set_operation_context(FWD_ATTN_OP_NAME.format(self.layer_idx, 0), self.stream_manager): + with OpNameContext(FWD_ATTN_OP_NAME.format(self.layer_idx, 0)): attn_output0 = self.attn( hidden_states=hidden_states0, sequence_mask=sequence_mask0, ) - # TODO: maybe try to bucket all the communication as in DPP, - # do it at at once - with set_operation_context(FWD_ATTN_OP_NAME.format(self.layer_idx, 1), self.stream_manager): + with OpNameContext(FWD_ATTN_OP_NAME.format(self.layer_idx, 1)): attn_output1 = self.attn( hidden_states=hidden_states1, sequence_mask=sequence_mask1, @@ -819,7 +824,7 @@ def _core_forward( self.stream_manager, ) - with set_operation_context(FWD_MLP_OP_NAME.format(self.layer_idx, 0), self.stream_manager): + with OpNameContext(FWD_MLP_OP_NAME.format(self.layer_idx, 0)): mlp_output0 = self.mlp(hidden_states=hidden_states0) comm_stream.wait_stream(torch.cuda.default_stream()) @@ -832,7 +837,7 @@ def _core_forward( residual1 = hidden_states1 hidden_states1 = self.post_attention_layernorm(hidden_states1) - with set_operation_context(FWD_MLP_OP_NAME.format(self.layer_idx, 1), self.stream_manager): + with OpNameContext(FWD_MLP_OP_NAME.format(self.layer_idx, 1)): mlp_output1 = self.mlp(hidden_states=hidden_states1) comm_stream.wait_stream(torch.cuda.default_stream()) @@ -845,8 +850,6 @@ def _core_forward( hidden_states0 = mlp_output0["hidden_states"] + residual0 hidden_states1 = mlp_output1["hidden_states"] + residual1 - # TODO: make sure no memory overhead, - # and try a fixed memory buffer as in section 4.2 in the paper hidden_states = torch.cat([hidden_states0, hidden_states1], dim=1) return hidden_states, orig_sequence_mask diff --git a/src/nanotron/parallel/comm.py b/src/nanotron/parallel/comm.py index 551a8459..95495213 100644 --- a/src/nanotron/parallel/comm.py +++ b/src/nanotron/parallel/comm.py @@ -4,6 +4,7 @@ import torch from nanotron.constants import CUDA_STREAM_COMM_NAME +from nanotron.parallel.tensor_parallel.domino import is_domino_async_comm class AsyncCommBucket: @@ -44,7 +45,7 @@ def is_all_completed(self) -> bool: not_finished = [] for k, v in self._copy_async_op.items(): - # assert is_domino_async_comm(k) is True, f"Operation with name {k} wasn't executed asynchronously!" + assert is_domino_async_comm(k) is True, f"Operation with name {k} wasn't executed asynchronously!" if v.is_completed() is not True: not_finished.append((k, v)) return len(not_finished) == 0 @@ -87,6 +88,38 @@ def run_on_stream(self, name: str): yield stream +class WaitComm(torch.autograd.Function): + """ + Enforce a tensor to wait for the communication operation to finish + in torch's autograd graph. 
+ """ + + @staticmethod + def forward(ctx, input: torch.Tensor, op_name: str, comm_stream: torch.cuda.Stream, comm_bucket: AsyncCommBucket): + assert isinstance(comm_stream, torch.cuda.Stream) + ctx.op_name = op_name + ctx.comm_stream = comm_stream + ctx.comm_bucket = comm_bucket + return input + + @staticmethod + def backward(ctx, grad_output: torch.Tensor): + """ + NOTE: because the communication operation has already been launched, + the communication stream doesn't have to wait for the compute stream here; + instead, the compute stream waits for the communication stream + before proceeding. + """ + if is_domino_async_comm(ctx.op_name): + handle = ctx.comm_bucket.pop(ctx.op_name) + handle.wait() + + ctx.comm_stream.synchronize() + torch.cuda.default_stream().wait_stream(ctx.comm_stream) + + return grad_output, None, None, None + + def insert_backward_sync_to_tensor( tensor: torch.Tensor, op_name: str, stream_manager: CudaStreamManager ) -> torch.Tensor: @@ -94,7 +127,6 @@ def insert_backward_sync_to_tensor( Insert a wait communication operation of a given op_name to the autograd graph of a tensor. """ - from nanotron.parallel.tensor_parallel.domino import WaitComm assert isinstance(stream_manager, CudaStreamManager) comm_stream = stream_manager.get(CUDA_STREAM_COMM_NAME.format(torch.cuda.current_device())) diff --git a/src/nanotron/parallel/tensor_parallel/distributed_differentiable_primitives.py b/src/nanotron/parallel/tensor_parallel/distributed_differentiable_primitives.py index 2e87b681..e50d34ba 100644 --- a/src/nanotron/parallel/tensor_parallel/distributed_differentiable_primitives.py +++ b/src/nanotron/parallel/tensor_parallel/distributed_differentiable_primitives.py @@ -20,7 +20,7 @@ from nanotron import distributed as dist from nanotron.distributed import ProcessGroup -from nanotron.parallel.tensor_parallel.domino import is_domino_async_comm +from nanotron.parallel.comm import CudaStreamManager, is_domino_async_comm class DifferentiableIdentity(torch.autograd.Function): @@ -32,10 +32,10 @@ def forward( tensor: torch.Tensor, group: Optional[ProcessGroup], op_name: Optional[str] = None, - stream_manager: Optional["CudaStreamManager"] = None, + stream_manager: Optional[CudaStreamManager] = None, ): - ctx.group = group ctx.op_name = op_name + ctx.group = group ctx.stream_manager = stream_manager return tensor @@ -49,6 +49,7 @@ def backward(ctx, grad_output: torch.Tensor): None, None, None, + None, ) @@ -61,7 +62,7 @@ def forward( tensor: torch.Tensor, group: Optional[ProcessGroup], op_name: Optional[int] = None, - stream_manager: Optional["CudaStreamManager"] = None, + stream_manager: Optional[CudaStreamManager] = None, ) -> Tuple[torch.Tensor, Optional["dist.Work"]]: async_all_reduce = is_domino_async_comm(op_name) if op_name is not None else False ctx.async_all_reduce = async_all_reduce @@ -69,8 +70,7 @@ def forward( if group.size() == 1: return tensor - if async_all_reduce is True: - assert stream_manager is not None, f"op_name: {op_name}" + if stream_manager is not None: comm_stream = stream_manager.get_default_comm_stream() comm_stream.wait_stream(torch.cuda.default_stream()) comm_context = torch.cuda.stream(comm_stream) @@ -89,7 +89,7 @@ def forward( @staticmethod def backward(ctx, grad_output): - return grad_output, None, None, None + return grad_output, None, None, None, None class DifferentiableAllGather(torch.autograd.Function): @@ -175,7 +175,7 @@ def differentiable_identity( tensor, group: Optional[ProcessGroup] = None, op_name: Optional[str] = None, - stream_manager: 
Optional["CudaStreamManager"] = None, + stream_manager: Optional[CudaStreamManager] = None, ): return DifferentiableIdentity.apply(tensor, group, op_name, stream_manager) @@ -184,7 +184,7 @@ def differentiable_all_reduce_sum( tensor, group: Optional[ProcessGroup] = None, op_name: Optional[str] = None, - stream_manager: Optional["CudaStreamManager"] = None, + stream_manager: Optional[CudaStreamManager] = None, ): return DifferentiableAllReduceSum.apply(tensor, group, op_name, stream_manager) diff --git a/src/nanotron/parallel/tensor_parallel/domino.py b/src/nanotron/parallel/tensor_parallel/domino.py index 6c147214..d7a98f04 100644 --- a/src/nanotron/parallel/tensor_parallel/domino.py +++ b/src/nanotron/parallel/tensor_parallel/domino.py @@ -1,12 +1,7 @@ import re import threading -from contextlib import contextmanager from typing import Optional -import torch - -from nanotron.parallel.comm import AsyncCommBucket, CudaStreamManager - FWD_MLP_OP_NAME = "fwd.layer_mlp_{}_batch_{}" FWD_ATTN_OP_NAME = "fwd.layer_attn_{}_batch_{}" BWD_ATTN_OP_NAME = "bwd.layer_attn_{}_batch_{}" @@ -31,90 +26,28 @@ def is_domino_async_comm(x: str) -> bool: return not not_async -class OperationContext: +class OpNameContext: """ - Context manager that sets both operation name and stream manager in thread-local context. - - Args: - op_name: Name of the current operation - stream_manager: Associated CUDA stream manager + A context manager to set the name of a module operation """ - def __init__(self, op_name: str, stream_manager: CudaStreamManager): + def __init__(self, op_name: str): self.op_name = op_name - self.stream_manager = stream_manager - self._previous_op_name: Optional[str] = None - self._previous_stream_manager: Optional[CudaStreamManager] = None + self.previous_op_name = None def __enter__(self): - """Store current context and set new values""" - # Handle operation name - if not hasattr(_operation_context, "_current_op_name"): - _operation_context._current_op_name = None - self._previous_op_name = _operation_context._current_op_name - _operation_context._current_op_name = self.op_name - - # Handle stream manager - if not hasattr(_operation_context, "_current_stream_manager"): - _operation_context._current_stream_manager = None - self._previous_stream_manager = _operation_context._current_stream_manager - _operation_context._current_stream_manager = self.stream_manager - + if not hasattr(_operation_context, "current_op_name"): + _operation_context.current_op_name = None + self.previous_op_name = _operation_context.current_op_name + _operation_context.current_op_name = self.op_name return self def __exit__(self, exc_type, exc_val, exc_tb): - """Restore previous context values""" - _operation_context._current_op_name = self._previous_op_name - _operation_context._current_stream_manager = self._previous_stream_manager - - -class WaitComm(torch.autograd.Function): - """ - Enforce a tensor to wait for the communication operation to finish - in torch's autograd graph. 
- """ - - @staticmethod - def forward(ctx, input: torch.Tensor, op_name: str, comm_stream: torch.cuda.Stream, comm_bucket: AsyncCommBucket): - assert isinstance(comm_stream, torch.cuda.Stream) - ctx.op_name = op_name - ctx.comm_stream = comm_stream - ctx.comm_bucket = comm_bucket - return input - - @staticmethod - def backward(ctx, grad_output: torch.Tensor): - """ - NOTE: because the communication operation is already being executed - so the communication stream don't have to wait for the compute stream here - but the compute stream waits for the communication stream - before proceeding - """ - if is_domino_async_comm(ctx.op_name): - handle = ctx.comm_bucket.pop(ctx.op_name) - handle.wait() - - ctx.comm_stream.synchronize() - torch.cuda.default_stream().wait_stream(ctx.comm_stream) - - return grad_output, None, None, None - - -@contextmanager -def set_operation_context(name: str, stream_manager: CudaStreamManager): - with OperationContext(name, stream_manager): - yield + _operation_context.current_op_name = self.previous_op_name def get_op_name() -> Optional[str]: """ Get the name of the current operation. """ - return getattr(_operation_context, "_current_op_name", None) - - -def get_stream_manager() -> Optional[CudaStreamManager]: - """ - Get the stream manager for the current operation. - """ - return getattr(_operation_context, "_current_stream_manager", None) + return getattr(_operation_context, "current_op_name", None) diff --git a/src/nanotron/parallel/tensor_parallel/functional.py b/src/nanotron/parallel/tensor_parallel/functional.py index 030d863e..03357160 100644 --- a/src/nanotron/parallel/tensor_parallel/functional.py +++ b/src/nanotron/parallel/tensor_parallel/functional.py @@ -19,12 +19,13 @@ from torch.nn import functional as F import nanotron.distributed as dist +from nanotron.parallel.comm import CudaStreamManager from nanotron.parallel.tensor_parallel.distributed_differentiable_primitives import ( differentiable_all_reduce_sum, differentiable_identity, differentiable_reduce_scatter_sum, ) -from nanotron.parallel.tensor_parallel.domino import get_op_name, get_stream_manager +from nanotron.parallel.tensor_parallel.domino import get_op_name from nanotron.parallel.tensor_parallel.enum import TensorParallelLinearMode from nanotron.parallel.utils import MemoryBuffer, assert_cuda_max_connections_set_to_1 @@ -437,13 +438,13 @@ def column_linear( tp_mode: TensorParallelLinearMode, async_communication: bool, tp_recompute_allgather: bool = True, + stream_manager: Optional[CudaStreamManager] = None, ): if async_communication: return _ColumnLinearAsyncCommunication.apply(input, weight, bias, group, tp_mode, tp_recompute_allgather) if tp_mode is TensorParallelLinearMode.ALL_REDUCE: op_name = get_op_name() - stream_manager = get_stream_manager() input = differentiable_identity(input, group=group, op_name=op_name, stream_manager=stream_manager) return F.linear(input, weight, bias) if tp_mode is TensorParallelLinearMode.REDUCE_SCATTER: @@ -591,6 +592,7 @@ def row_linear( group: dist.ProcessGroup, tp_mode: TensorParallelLinearMode, async_communication: bool, + stream_manager: Optional[CudaStreamManager] = None, ): if async_communication: return _RowLinearAsyncCommunication.apply(input, weight, bias, group, tp_mode) @@ -599,7 +601,6 @@ def row_linear( if tp_mode is TensorParallelLinearMode.ALL_REDUCE: op_name = get_op_name() - stream_manager = get_stream_manager() out = differentiable_all_reduce_sum(out, group=group, op_name=op_name, stream_manager=stream_manager) elif tp_mode is 
TensorParallelLinearMode.REDUCE_SCATTER: out = differentiable_reduce_scatter_sum(out, group=group) diff --git a/src/nanotron/parallel/tensor_parallel/nn.py b/src/nanotron/parallel/tensor_parallel/nn.py index 5598df2f..f92fe0ee 100644 --- a/src/nanotron/parallel/tensor_parallel/nn.py +++ b/src/nanotron/parallel/tensor_parallel/nn.py @@ -19,6 +19,7 @@ from nanotron import distributed as dist from nanotron.distributed import get_global_rank +from nanotron.parallel.comm import CudaStreamManager from nanotron.parallel.parameters import NanotronParameter from nanotron.parallel.sharded_parameters import ( SplitConfig, @@ -52,6 +53,7 @@ def __init__( async_communication: bool = False, contiguous_chunks: Optional[Tuple[int, ...]] = None, tp_recompute_allgather: bool = True, + stream_manager: Optional[CudaStreamManager] = None, ): self.pg = pg self.world_size = pg.size() @@ -72,6 +74,7 @@ def __init__( self.mode = mode self.async_communication = async_communication + self.stream_manager = stream_manager if contiguous_chunks is not None: assert ( @@ -97,6 +100,7 @@ def forward( tp_mode=self.mode, async_communication=self.async_communication, tp_recompute_allgather=self.tp_recompute_allgather, + stream_manager=self.stream_manager, ) def extra_repr(self) -> str: @@ -115,6 +119,7 @@ def __init__( dtype=None, async_communication: bool = False, contiguous_chunks: Optional[Tuple[int, ...]] = None, + stream_manager: Optional[CudaStreamManager] = None, ): self.pg = pg self.world_size = pg.size() @@ -123,6 +128,7 @@ def __init__( self.in_features = in_features // self.world_size self.out_features = out_features + self.stream_manager = stream_manager # No need to shard the bias term, only rank 0 would have it bias = dist.get_rank(self.pg) == 0 and bias @@ -169,6 +175,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: group=self.pg, tp_mode=self.mode, async_communication=self.async_communication, + stream_manager=self.stream_manager, ) def extra_repr(self) -> str: diff --git a/tests/test_domino.py b/tests/test_domino.py index fc28c192..e3182de4 100644 --- a/tests/test_domino.py +++ b/tests/test_domino.py @@ -9,7 +9,7 @@ from nanotron.models.llama import DominoLlamaDecoderLayer from nanotron.parallel import ParallelContext from nanotron.parallel.comm import CudaStreamManager -from nanotron.parallel.tensor_parallel.domino import OperationContext, get_op_name, is_domino_async_comm +from nanotron.parallel.tensor_parallel.domino import OpNameContext, get_op_name, is_domino_async_comm @pytest.mark.parametrize( @@ -74,12 +74,12 @@ def _test_domino_model( assert stream_manager.comm_bucket.is_all_completed() is True -### OperationContext tests ### +### OpNameContext tests ### def test_op_name_context_reentry(): assert get_op_name() is None - context = OperationContext("reusable_op") + context = OpNameContext("reusable_op") with context: assert get_op_name() == "reusable_op" @@ -93,10 +93,10 @@ def test_op_name_context_reentry(): def test_deeply_nested_contexts(): - with OperationContext("level1"): + with OpNameContext("level1"): assert get_op_name() == "level1" - with OperationContext("level2"): + with OpNameContext("level2"): assert get_op_name() == "level2" assert get_op_name() == "level1" @@ -105,10 +105,10 @@ def test_deeply_nested_contexts(): def test_multiple_sequential_contexts(): assert get_op_name() is None - with OperationContext("first_op"): + with OpNameContext("first_op"): assert get_op_name() == "first_op" - with OperationContext("second_op"): + with OpNameContext("second_op"): assert get_op_name() 
== "second_op" assert get_op_name() is None From 61ff007090de6221be540caaccadd33d245334aa Mon Sep 17 00:00:00 2001 From: Phuc Nguyen Date: Mon, 10 Mar 2025 11:11:30 +0000 Subject: [PATCH 34/40] add todos --- src/nanotron/models/llama.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/src/nanotron/models/llama.py b/src/nanotron/models/llama.py index 0fb2af4d..93adbd2a 100644 --- a/src/nanotron/models/llama.py +++ b/src/nanotron/models/llama.py @@ -797,6 +797,8 @@ def _core_forward( self.stream_manager, ) + # TODO: maybe try to bucket all the communication as in DPP, + # do it at at once with OpNameContext(FWD_ATTN_OP_NAME.format(self.layer_idx, 0)): attn_output0 = self.attn( hidden_states=hidden_states0, @@ -850,6 +852,8 @@ def _core_forward( hidden_states0 = mlp_output0["hidden_states"] + residual0 hidden_states1 = mlp_output1["hidden_states"] + residual1 + # TODO: make sure no memory overhead, + # and try a fixed memory buffer as in section 4.2 in the paper hidden_states = torch.cat([hidden_states0, hidden_states1], dim=1) return hidden_states, orig_sequence_mask From 9039ce24c6ee945b98f90bb5877b6fe8c719a2d5 Mon Sep 17 00:00:00 2001 From: Phuc Nguyen Date: Mon, 10 Mar 2025 11:12:34 +0000 Subject: [PATCH 35/40] add todo --- src/nanotron/parallel/tensor_parallel/domino.py | 1 + 1 file changed, 1 insertion(+) diff --git a/src/nanotron/parallel/tensor_parallel/domino.py b/src/nanotron/parallel/tensor_parallel/domino.py index d7a98f04..0a3f66e0 100644 --- a/src/nanotron/parallel/tensor_parallel/domino.py +++ b/src/nanotron/parallel/tensor_parallel/domino.py @@ -32,6 +32,7 @@ class OpNameContext: """ def __init__(self, op_name: str): + # TODO: support passing stream_manager as a part of an operation context self.op_name = op_name self.previous_op_name = None From 62fb3b2e67dd9e02acff58e3237d91977dad48cb Mon Sep 17 00:00:00 2001 From: Phuc Nguyen Date: Mon, 10 Mar 2025 11:29:41 +0000 Subject: [PATCH 36/40] add todos --- src/nanotron/config/parallelism_config.py | 1 + src/nanotron/models/llama.py | 2 ++ 2 files changed, 3 insertions(+) diff --git a/src/nanotron/config/parallelism_config.py b/src/nanotron/config/parallelism_config.py index 1a655f0c..6d880364 100644 --- a/src/nanotron/config/parallelism_config.py +++ b/src/nanotron/config/parallelism_config.py @@ -79,6 +79,7 @@ def __post_init__(self): assert ( self.tp_linear_async_communication is False ), "Domino requires TP linear async communication to be False" + # TODO: support REDUCE_SCATTER mode for Domino assert self.tp_mode == TensorParallelLinearMode.ALL_REDUCE, "Domino requires TP mode to be ALL_REDUCE" @property diff --git a/src/nanotron/models/llama.py b/src/nanotron/models/llama.py index 93adbd2a..d00fb58e 100644 --- a/src/nanotron/models/llama.py +++ b/src/nanotron/models/llama.py @@ -811,6 +811,8 @@ def _core_forward( sequence_mask=sequence_mask1, ) + # TODO(xrsrke): double check if we need this explicit synchronization + # otherwise, remove it comm_stream.wait_stream(torch.cuda.default_stream()) with torch.cuda.stream(comm_stream): comm_bucket.wait(FWD_ATTN_OP_NAME.format(self.layer_idx, 0)) From 7c7b6f7af831598fc4de5109f5c20d7cb60d1ef7 Mon Sep 17 00:00:00 2001 From: Phuc Nguyen Date: Thu, 13 Mar 2025 14:51:56 +0000 Subject: [PATCH 37/40] add todo and undo torch_nn --- src/nanotron/models/llama.py | 3 ++- .../parallel/pipeline_parallel/engine.py | 16 +++++++------- .../parallel/tensor_parallel/domino.py | 22 +++++++++++++++++-- 3 files changed, 30 insertions(+), 11 deletions(-) diff --git a/src/nanotron/models/llama.py 
b/src/nanotron/models/llama.py index d00fb58e..72fd2440 100644 --- a/src/nanotron/models/llama.py +++ b/src/nanotron/models/llama.py @@ -701,7 +701,6 @@ def forward( attention_output = ( attention_output.contiguous().view(batch_size, q_length, self.n_local_q_heads * self.d_v).transpose(0, 1) ) - # output, work = self.o_proj(attention_output, op_name=op_name) output = self.o_proj(attention_output) return {"hidden_states": output, "sequence_mask": sequence_mask} @@ -784,6 +783,8 @@ def _core_forward( residual0 = hidden_states0 residual1 = hidden_states1 + # TODO: overlap the 'layernorm > attn' of the second batch + # with the comm of the first batch in both forward and backward hidden_states0 = self.input_layernorm(hidden_states0) hidden_states1 = self.input_layernorm(hidden_states1) hidden_states0 = insert_backward_sync_to_tensor( diff --git a/src/nanotron/parallel/pipeline_parallel/engine.py b/src/nanotron/parallel/pipeline_parallel/engine.py index fc4b63ca..076943c7 100644 --- a/src/nanotron/parallel/pipeline_parallel/engine.py +++ b/src/nanotron/parallel/pipeline_parallel/engine.py @@ -2,7 +2,7 @@ from typing import Dict, Iterable, Optional, Union import torch -from torch import nn +from torch import nn as torch_nn from torch.nn.parallel import DistributedDataParallel from nanotron import distributed as dist @@ -29,7 +29,7 @@ def forward( context: ContextManagers, state: PipelineTrainBatchState, micro_batch: Dict[str, Union[torch.Tensor, TensorPointer]], - model: nn.Module, + model: torch_nn.Module, ) -> Dict[str, Union[torch.Tensor, TensorPointer]]: # Increment the number of backwards state.nb_forwards += 1 @@ -59,7 +59,7 @@ def forward( return output @staticmethod - def _get_fwd_context(model: nn.Module): + def _get_fwd_context(model: torch_nn.Module): is_ddp = isinstance(model, DistributedDataParallel) # We never to trigger a DDP sync in the next backward pass context = ContextManagers([model.no_sync()] if is_ddp else []) @@ -97,7 +97,7 @@ def backward( def _get_bwd_context( self, - model: nn.Module, + model: torch_nn.Module, nb_backwards: int, grad_accumulator: Optional[GradientAccumulator], ): @@ -118,7 +118,7 @@ def _get_bwd_context( @abstractmethod def train_batch_iter( self, - model: nn.Module, + model: torch_nn.Module, pg: ProcessGroup, batch: Iterable[Dict[str, Union[torch.Tensor, TensorPointer]]], nb_microbatches: int, @@ -130,7 +130,7 @@ def train_batch_iter( @torch.inference_mode() def validate_batch_iter( self, - model: nn.Module, + model: torch_nn.Module, batch: Iterable[Dict[str, Union[torch.Tensor, TensorPointer]]], nb_microbatches: int, ) -> Iterable[Dict[str, Union[torch.Tensor, TensorPointer]]]: @@ -169,7 +169,7 @@ def __init__(self): def train_batch_iter( self, - model: nn.Module, + model: torch_nn.Module, pg: ProcessGroup, batch: Iterable[Dict[str, Union[torch.Tensor, TensorPointer]]], nb_microbatches: int, @@ -226,7 +226,7 @@ def __init__(self): def train_batch_iter( self, - model: nn.Module, + model: torch_nn.Module, pg: ProcessGroup, batch: Iterable[Dict[str, Union[torch.Tensor, TensorPointer]]], nb_microbatches: int, diff --git a/src/nanotron/parallel/tensor_parallel/domino.py b/src/nanotron/parallel/tensor_parallel/domino.py index 0a3f66e0..c5152112 100644 --- a/src/nanotron/parallel/tensor_parallel/domino.py +++ b/src/nanotron/parallel/tensor_parallel/domino.py @@ -1,3 +1,10 @@ +""" +Implementation of communication overlapping +in the paper "Domino: Eliminating Communication in LLM Training via +Generic Tensor Slicing and Overlapping" 
+https://arxiv.org/abs/2409.15241 +""" + import re import threading from typing import Optional @@ -13,10 +20,21 @@ def is_domino_async_comm(x: str) -> bool: """ Determine whether a module (e.g., mlp, attention) - performs all-reduce asynchronously in tensor parallelism + runs all-reduce asynchronously in tensor parallelism + based on its module name. + + Currently support intra-layer communication overlapping + as described in domino's input splitting approach. + + How do we determine it? + + In the forward pass: We run all the forward pass's communication asynchronously + diagram: https://imgur.com/a/g5Ou2iZ + + + In the backward pass: We run all backward pass's communication asynchronously + except for the first batch's attention module. + https://imgur.com/a/MrZb57a """ NON_ASYNC_HANDLE_IDX = [ - # "fwd.layer_mlp_{}_batch_1", "bwd.layer_attn_{}_batch_0", ] From dda052e31944c4f942506b7cb84b7c86c2ee64d1 Mon Sep 17 00:00:00 2001 From: Phuc Nguyen Date: Tue, 1 Apr 2025 15:05:22 +0000 Subject: [PATCH 38/40] remove hardcoded operation names for backward passes in tp primitives and fix domino tests --- examples/config_domino.yaml | 108 ++++++++++++++++++ src/nanotron/models/llama.py | 20 +++- .../distributed_differentiable_primitives.py | 6 +- .../parallel/tensor_parallel/domino.py | 35 +++--- .../parallel/tensor_parallel/functional.py | 6 +- tests/helpers/llama_helper.py | 65 ++++++++--- tests/test_domino.py | 52 +++++---- 7 files changed, 227 insertions(+), 65 deletions(-) create mode 100644 examples/config_domino.yaml diff --git a/examples/config_domino.yaml b/examples/config_domino.yaml new file mode 100644 index 00000000..c60d9bf3 --- /dev/null +++ b/examples/config_domino.yaml @@ -0,0 +1,108 @@ +checkpoints: + checkpoint_interval: 10000 + checkpoints_path: checkpoints + checkpoints_path_is_shared_file_system: false + resume_checkpoint_path: null + load_lr_scheduler: false + load_optimizer: false + save_final_state: true + save_initial_state: false +data_stages: +- data: + dataset: + dataset_folder: + - /fsx/loubna/datasets/llama_tokenized/other_sources/wiki + token_size_in_bytes: 4 + tokenizer_name: meta-llama/Llama-3.2-1B + vocab_size: 128256 + num_loading_workers: 8 + seed: 42 + name: Training Stage + start_training_step: 1 +general: + benchmark_csv_path: null + consumed_train_samples: null + ignore_sanity_checks: true + project: nanotron_domino + run: domino_config + seed: 6 + step: null +logging: + iteration_step_info_interval: 1 + log_level: info + log_level_replica: info +model: + ddp_bucket_cap_mb: 25 + dtype: bfloat16 + init_method: + std: 0.041666666666666664 + make_vocab_size_divisible_by: 1 + model_config: + bos_token_id: 128000 + eos_token_id: 128001 + hidden_act: silu + hidden_size: 4096 + initializer_range: 0.02 + intermediate_size: 14336 + is_llama_config: true + max_position_embeddings: 4096 + num_attention_heads: 32 + num_hidden_layers: 15 + num_key_value_heads: 8 + pad_token_id: null + pretraining_tp: 2 + rms_norm_eps: 1.0e-05 + rope_interleaved: false + rope_scaling: + factor: 32.0 + high_freq_factor: 4.0 + low_freq_factor: 1.0 + original_max_position_embeddings: 4096 + rope_type: llama3 + rope_theta: 500000.0 + tie_word_embeddings: true + use_cache: true + vocab_size: 128256 +optimizer: + accumulate_grad_in_fp32: true + clip_grad: 1.0 + learning_rate_scheduler: + learning_rate: 0.00005 + lr_decay_starting_step: 50000 + lr_decay_steps: 10000 + lr_decay_style: linear + lr_warmup_steps: 1000 + lr_warmup_style: linear + min_decay_lr: 0 + optimizer_factory: + adam_beta1: 
0.9 + adam_beta2: 0.95 + adam_eps: 1.0e-08 + name: adamW + torch_adam_is_fused: true + weight_decay: 0.01 + zero_stage: 1 +parallelism: + dp: 1 + expert_parallel_size: 1 + pp: 1 + pp_engine: 1f1b + recompute_layer: false + tp: 8 + tp_linear_async_communication: false + tp_mode: ALL_REDUCE + tp_recompute_allgather: false + domino: + num_input_batches: 2 +tokenizer: + tokenizer_max_length: null + tokenizer_name_or_path: meta-llama/Llama-3.2-1B + tokenizer_revision: null +tokens: + batch_accumulation_per_replica: 2 + limit_test_batches: 0 + limit_val_batches: 0 + micro_batch_size: 8 + sequence_length: 4096 + train_steps: 15000 + val_check_interval: -1 diff --git a/src/nanotron/models/llama.py b/src/nanotron/models/llama.py index 9bd1eced..4d7de8a6 100644 --- a/src/nanotron/models/llama.py +++ b/src/nanotron/models/llama.py @@ -825,13 +825,19 @@ def _core_forward( # TODO: maybe try to bucket all the communication as in DPP, # do it at at once - with OpNameContext(FWD_ATTN_OP_NAME.format(self.layer_idx, 0)): + with OpNameContext( + fwd_op_name=FWD_ATTN_OP_NAME.format(self.layer_idx, 0), + bwd_op_name=BWD_ATTN_OP_NAME.format(self.layer_idx, 0), + ): attn_output0 = self.attn( hidden_states=hidden_states0, sequence_mask=sequence_mask0, ) - with OpNameContext(FWD_ATTN_OP_NAME.format(self.layer_idx, 1)): + with OpNameContext( + fwd_op_name=FWD_ATTN_OP_NAME.format(self.layer_idx, 1), + bwd_op_name=BWD_ATTN_OP_NAME.format(self.layer_idx, 1), + ): attn_output1 = self.attn( hidden_states=hidden_states1, sequence_mask=sequence_mask1, @@ -854,7 +860,10 @@ def _core_forward( self.stream_manager, ) - with OpNameContext(FWD_MLP_OP_NAME.format(self.layer_idx, 0)): + with OpNameContext( + fwd_op_name=FWD_MLP_OP_NAME.format(self.layer_idx, 0), + bwd_op_name=BWD_MLP_OP_NAME.format(self.layer_idx, 0), + ): mlp_output0 = self.mlp(hidden_states=hidden_states0) comm_stream.wait_stream(torch.cuda.default_stream()) @@ -867,7 +876,10 @@ def _core_forward( residual1 = hidden_states1 hidden_states1 = self.post_attention_layernorm(hidden_states1) - with OpNameContext(FWD_MLP_OP_NAME.format(self.layer_idx, 1)): + with OpNameContext( + fwd_op_name=FWD_MLP_OP_NAME.format(self.layer_idx, 1), + bwd_op_name=BWD_MLP_OP_NAME.format(self.layer_idx, 1), + ): mlp_output1 = self.mlp(hidden_states=hidden_states1) comm_stream.wait_stream(torch.cuda.default_stream()) diff --git a/src/nanotron/parallel/tensor_parallel/distributed_differentiable_primitives.py b/src/nanotron/parallel/tensor_parallel/distributed_differentiable_primitives.py index e50d34ba..a5e65258 100644 --- a/src/nanotron/parallel/tensor_parallel/distributed_differentiable_primitives.py +++ b/src/nanotron/parallel/tensor_parallel/distributed_differentiable_primitives.py @@ -21,6 +21,7 @@ from nanotron import distributed as dist from nanotron.distributed import ProcessGroup from nanotron.parallel.comm import CudaStreamManager, is_domino_async_comm +from nanotron.parallel.tensor_parallel.domino import get_current_bwd_op_name class DifferentiableIdentity(torch.autograd.Function): @@ -34,7 +35,7 @@ def forward( op_name: Optional[str] = None, stream_manager: Optional[CudaStreamManager] = None, ): - ctx.op_name = op_name + ctx.bwd_op_name = get_current_bwd_op_name() ctx.group = group ctx.stream_manager = stream_manager return tensor @@ -42,10 +43,9 @@ def forward( @staticmethod def backward(ctx, grad_output: torch.Tensor): group = ctx.group - op_name = ctx.op_name.replace("fwd.", "bwd.") if ctx.op_name is not None else None return ( - 
DifferentiableAllReduceSum.apply(grad_output, group, op_name, ctx.stream_manager), + DifferentiableAllReduceSum.apply(grad_output, group, ctx.bwd_op_name, ctx.stream_manager), None, None, None, diff --git a/src/nanotron/parallel/tensor_parallel/domino.py b/src/nanotron/parallel/tensor_parallel/domino.py index c5152112..3d0b4415 100644 --- a/src/nanotron/parallel/tensor_parallel/domino.py +++ b/src/nanotron/parallel/tensor_parallel/domino.py @@ -45,28 +45,29 @@ def is_domino_async_comm(x: str) -> bool: class OpNameContext: - """ - A context manager to set the name of a module operation - """ + """Track both forward and backward operation names""" - def __init__(self, op_name: str): - # TODO: support passing stream_manager as a part of an operation context - self.op_name = op_name - self.previous_op_name = None + def __init__(self, fwd_op_name: str, bwd_op_name: str): + self.fwd_op_name = fwd_op_name + self.bwd_op_name = bwd_op_name + self.prev_fwd = None + self.prev_bwd = None def __enter__(self): - if not hasattr(_operation_context, "current_op_name"): - _operation_context.current_op_name = None - self.previous_op_name = _operation_context.current_op_name - _operation_context.current_op_name = self.op_name + self.prev_fwd = getattr(_operation_context, "current_fwd_op_name", None) + self.prev_bwd = getattr(_operation_context, "current_bwd_op_name", None) + _operation_context.current_fwd_op_name = self.fwd_op_name + _operation_context.current_bwd_op_name = self.bwd_op_name return self def __exit__(self, exc_type, exc_val, exc_tb): - _operation_context.current_op_name = self.previous_op_name + _operation_context.current_fwd_op_name = self.prev_fwd + _operation_context.current_bwd_op_name = self.prev_bwd -def get_op_name() -> Optional[str]: - """ - Get the name of the current operation. 
- """ - return getattr(_operation_context, "current_op_name", None) +def get_current_fwd_op_name() -> Optional[str]: + return getattr(_operation_context, "current_fwd_op_name", None) + + +def get_current_bwd_op_name() -> Optional[str]: + return getattr(_operation_context, "current_bwd_op_name", None) diff --git a/src/nanotron/parallel/tensor_parallel/functional.py b/src/nanotron/parallel/tensor_parallel/functional.py index 7d430e38..9ed815a6 100644 --- a/src/nanotron/parallel/tensor_parallel/functional.py +++ b/src/nanotron/parallel/tensor_parallel/functional.py @@ -25,7 +25,7 @@ differentiable_identity, differentiable_reduce_scatter_sum, ) -from nanotron.parallel.tensor_parallel.domino import get_op_name +from nanotron.parallel.tensor_parallel.domino import get_current_fwd_op_name from nanotron.parallel.tensor_parallel.enum import TensorParallelLinearMode from nanotron.parallel.utils import MemoryBuffer @@ -560,7 +560,7 @@ def column_linear( return _ColumnLinearAsyncCommunication.apply(input, weight, bias, group, tp_mode, tp_recompute_allgather) if tp_mode is TensorParallelLinearMode.ALL_REDUCE: - op_name = get_op_name() + op_name = get_current_fwd_op_name() input = differentiable_identity(input, group=group, op_name=op_name, stream_manager=stream_manager) return F.linear(input, weight, bias) if tp_mode is TensorParallelLinearMode.REDUCE_SCATTER: @@ -714,7 +714,7 @@ def row_linear( out = F.linear(input, weight, bias) if tp_mode is TensorParallelLinearMode.ALL_REDUCE: - op_name = get_op_name() + op_name = get_current_fwd_op_name() out = differentiable_all_reduce_sum(out, group=group, op_name=op_name, stream_manager=stream_manager) elif tp_mode is TensorParallelLinearMode.REDUCE_SCATTER: out = differentiable_reduce_scatter_sum(out, group=group) diff --git a/tests/helpers/llama_helper.py b/tests/helpers/llama_helper.py index 7334f857..13c66658 100644 --- a/tests/helpers/llama_helper.py +++ b/tests/helpers/llama_helper.py @@ -1,3 +1,5 @@ +from typing import Optional + import torch from nanotron.config import ( AdamWOptimizerArgs, @@ -20,6 +22,7 @@ from nanotron.config.config import PretrainDatasetsArgs from nanotron.models import build_model from nanotron.models.llama import LlamaForTraining +from nanotron.parallel.comm import CudaStreamManager from nanotron.parallel.context import ParallelContext from nanotron.trainer import mark_tied_parameters @@ -47,15 +50,21 @@ ) -def get_llama_training_config(model_config: ModelArgs): - return Config( - model=model_config, - general=GeneralArgs(project="unittest", run="sanity_llama", seed=42), - checkpoints=CheckpointsArgs( - checkpoints_path="./checkpoints", - checkpoint_interval=10, - ), - parallelism=ParallelismArgs( +def get_parallel_config(parallel_context: ParallelContext): + return ParallelismArgs( + dp=parallel_context.data_parallel_size, + pp=parallel_context.pipeline_parallel_size, + tp=parallel_context.tensor_parallel_size, + expert_parallel_size=parallel_context.expert_parallel_size, + pp_engine=AllForwardAllBackwardPipelineEngine(), + tp_mode=TensorParallelLinearMode.ALL_REDUCE, + tp_linear_async_communication=False, + ) + + +def get_llama_training_config(model_config: ModelArgs, parallel_context: ParallelContext): + if parallel_context is None: + parallel_config = ParallelismArgs( dp=1, pp=1, tp=2, @@ -63,7 +72,18 @@ def get_llama_training_config(model_config: ModelArgs): pp_engine="1f1b", tp_mode="ALL_REDUCE", tp_linear_async_communication=False, + ) + else: + parallel_config = get_parallel_config(parallel_context) + + return Config( + 
model=model_config, + general=GeneralArgs(project="unittest", run="sanity_llama", seed=42), + checkpoints=CheckpointsArgs( + checkpoints_path="./checkpoints", + checkpoint_interval=10, ), + parallelism=parallel_config, tokenizer=TokenizerArgs("gpt2"), optimizer=OptimizerArgs( optimizer_factory=AdamWOptimizerArgs( @@ -106,7 +126,11 @@ def get_llama_training_config(model_config: ModelArgs): def create_llama_from_config( - model_config: LlamaConfig, device: torch.device, parallel_context: ParallelContext + model_config: LlamaConfig, + device: torch.device, + parallel_context: ParallelContext, + parallel_config: Optional[ParallelismArgs] = None, + stream_manager: Optional[CudaStreamManager] = None, ) -> LlamaForTraining: """ @@ -117,20 +141,25 @@ def create_llama_from_config( the model created will have random weights. """ - parallel_config = ParallelismArgs( - dp=parallel_context.data_parallel_size, - pp=parallel_context.pipeline_parallel_size, - tp=parallel_context.tensor_parallel_size, - pp_engine=AllForwardAllBackwardPipelineEngine(), - tp_mode=TensorParallelLinearMode.ALL_REDUCE, - tp_linear_async_communication=False, - ) + if parallel_config is None: + parallel_config = ParallelismArgs( + dp=parallel_context.data_parallel_size, + pp=parallel_context.pipeline_parallel_size, + tp=parallel_context.tensor_parallel_size, + pp_engine=AllForwardAllBackwardPipelineEngine(), + tp_mode=TensorParallelLinearMode.ALL_REDUCE, + tp_linear_async_communication=False, + ) + else: + parallel_config = parallel_config + model = build_model( model_builder=lambda: LlamaForTraining( config=model_config, parallel_context=parallel_context, parallel_config=parallel_config, random_states=None, + stream_manager=stream_manager, ), parallel_context=parallel_context, dtype=torch.bfloat16, diff --git a/tests/test_domino.py b/tests/test_domino.py index e3182de4..843437a0 100644 --- a/tests/test_domino.py +++ b/tests/test_domino.py @@ -1,15 +1,21 @@ +import os from copy import deepcopy import pytest import torch -from helpers.llama import TINY_LLAMA_CONFIG, create_llama_from_config, get_llama_training_config +from helpers.llama_helper import TINY_LLAMA_CONFIG, create_llama_from_config, get_llama_training_config from helpers.utils import init_distributed, rerun_if_address_is_in_use from nanotron.config import ModelArgs, RandomInit from nanotron.config.parallelism_config import DominoArgs from nanotron.models.llama import DominoLlamaDecoderLayer from nanotron.parallel import ParallelContext from nanotron.parallel.comm import CudaStreamManager -from nanotron.parallel.tensor_parallel.domino import OpNameContext, get_op_name, is_domino_async_comm +from nanotron.parallel.tensor_parallel.domino import ( + OpNameContext, + get_current_bwd_op_name, + get_current_fwd_op_name, + is_domino_async_comm, +) @pytest.mark.parametrize( @@ -18,7 +24,7 @@ ("fwd.layer_attn_1_batch_0", True), ("fwd.layer_attn_1_batch_1", True), ("fwd.layer_mlp_1_batch_0", True), - ("fwd.layer_mlp_1_batch_1", False), + ("fwd.layer_mlp_1_batch_1", True), ("bwd.layer_mlp_1_batch_1", True), ("bwd.layer_mlp_1_batch_0", True), ("bwd.layer_attn_1_batch_1", True), @@ -32,6 +38,7 @@ def test_is_domino_async_comm(op_name, expected): @pytest.mark.parametrize("tp,dp,pp", [(2, 2, 1)]) @rerun_if_address_is_in_use() def test_domino_model(tp: int, dp: int, pp: int): + os.environ["CUDA_DEVICE_MAX_CONNECTIONS"] = "1" BATCH_SIZE, SEQ_LEN = 10, 128 model_config = deepcopy(TINY_LLAMA_CONFIG) @@ -78,37 +85,42 @@ def _test_domino_model( def test_op_name_context_reentry(): - assert 
get_op_name() is None - context = OpNameContext("reusable_op") + assert get_current_fwd_op_name() is None + assert get_current_bwd_op_name() is None + context = OpNameContext(fwd_op_name="fwd.reusable_op", bwd_op_name="bwd.reusable_op") with context: - assert get_op_name() == "reusable_op" + assert get_current_fwd_op_name() == "fwd.reusable_op" + assert get_current_bwd_op_name() == "bwd.reusable_op" - assert get_op_name() is None + assert get_current_fwd_op_name() is None + assert get_current_bwd_op_name() is None with context: - assert get_op_name() == "reusable_op" + assert get_current_fwd_op_name() == "fwd.reusable_op" + assert get_current_bwd_op_name() == "bwd.reusable_op" - assert get_op_name() is None + assert get_current_fwd_op_name() is None + assert get_current_bwd_op_name() is None def test_deeply_nested_contexts(): - with OpNameContext("level1"): - assert get_op_name() == "level1" + with OpNameContext(fwd_op_name="fwd.level1", bwd_op_name="fwd.level1"): + assert get_current_fwd_op_name() == "fwd.level1" - with OpNameContext("level2"): - assert get_op_name() == "level2" + with OpNameContext(fwd_op_name="fwd.level2", bwd_op_name="fwd.level2"): + assert get_current_fwd_op_name() == "fwd.level2" - assert get_op_name() == "level1" + assert get_current_fwd_op_name() == "fwd.level1" def test_multiple_sequential_contexts(): - assert get_op_name() is None + assert get_current_fwd_op_name() is None - with OpNameContext("first_op"): - assert get_op_name() == "first_op" + with OpNameContext(fwd_op_name="fwd.first_op", bwd_op_name="bwd.first_op"): + assert get_current_fwd_op_name() == "fwd.first_op" - with OpNameContext("second_op"): - assert get_op_name() == "second_op" + with OpNameContext(fwd_op_name="fwd.second_op", bwd_op_name="bwd.second_op"): + assert get_current_fwd_op_name() == "fwd.second_op" - assert get_op_name() is None + assert get_current_fwd_op_name() is None From 9a428dcf76240a5fe7f924dd499f5f80c3b62914 Mon Sep 17 00:00:00 2001 From: Phuc Nguyen Date: Tue, 1 Apr 2025 15:54:13 +0000 Subject: [PATCH 39/40] add config --- examples/{ => domino}/config_domino.yaml | 0 examples/domino/domino_config.yaml | 105 ----------------------- 2 files changed, 105 deletions(-) rename examples/{ => domino}/config_domino.yaml (100%) delete mode 100644 examples/domino/domino_config.yaml diff --git a/examples/config_domino.yaml b/examples/domino/config_domino.yaml similarity index 100% rename from examples/config_domino.yaml rename to examples/domino/config_domino.yaml diff --git a/examples/domino/domino_config.yaml b/examples/domino/domino_config.yaml deleted file mode 100644 index 07e27c37..00000000 --- a/examples/domino/domino_config.yaml +++ /dev/null @@ -1,105 +0,0 @@ -checkpoints: - checkpoint_interval: 10000 - checkpoints_path: checkpoints - checkpoints_path_is_shared_file_system: false - resume_checkpoint_path: null - load_lr_scheduler: false - load_optimizer: false - save_final_state: true - save_initial_state: false -data_stages: -- data: - dataset: - dataset_folder: - - /fsx/elie_bakouch/data/fw-edu-dedup - num_loading_workers: 0 - seed: 8 - name: stable phase - start_training_step: 1 -general: - benchmark_csv_path: null - consumed_train_samples: null - ignore_sanity_checks: true - project: nanotron_domino - run: domino_config - seed: 6 - step: null -logging: - iteration_step_info_interval: 1 - log_level: info - log_level_replica: info -model: - ddp_bucket_cap_mb: 25 - dtype: bfloat16 - init_method: - std: 0.041666666666666664 - make_vocab_size_divisible_by: 1 - model_config: - 
bos_token_id: 128000 - eos_token_id: 128001 - hidden_act: silu - hidden_size: 4096 - initializer_range: 0.02 - intermediate_size: 14336 - is_llama_config: true - max_position_embeddings: 4096 - num_attention_heads: 32 - num_hidden_layers: 15 - num_key_value_heads: 8 - pad_token_id: null - pretraining_tp: 2 - rms_norm_eps: 1.0e-05 - rope_interleaved: false - rope_scaling: - factor: 32.0 - high_freq_factor: 4.0 - low_freq_factor: 1.0 - original_max_position_embeddings: 4096 - rope_type: llama3 - rope_theta: 500000.0 - tie_word_embeddings: true - use_cache: true - vocab_size: 128256 -optimizer: - accumulate_grad_in_fp32: true - clip_grad: 1.0 - learning_rate_scheduler: - learning_rate: 0.00005 - lr_decay_starting_step: 50000 - lr_decay_steps: 10000 - lr_decay_style: linear - lr_warmup_steps: 1000 - lr_warmup_style: linear - min_decay_lr: 0 - optimizer_factory: - adam_beta1: 0.9 - adam_beta2: 0.95 - adam_eps: 1.0e-08 - name: adamW - torch_adam_is_fused: true - weight_decay: 0.01 - zero_stage: 1 -parallelism: - dp: 1 - expert_parallel_size: 1 - pp: 1 - pp_engine: 1f1b - recompute_layer: false - tp: 8 - tp_linear_async_communication: false - tp_mode: ALL_REDUCE - tp_recompute_allgather: false - domino: - num_input_batches: 2 -tokenizer: - tokenizer_max_length: null - tokenizer_name_or_path: meta-llama/Llama-3.2-3B - tokenizer_revision: null -tokens: - batch_accumulation_per_replica: 2 - limit_test_batches: 0 - limit_val_batches: 0 - micro_batch_size: 8 - sequence_length: 4096 - train_steps: 15000 - val_check_interval: -1 From 7521671e0e9fb6bce93bb6a404dbd7258709991f Mon Sep 17 00:00:00 2001 From: Phuc Nguyen Date: Tue, 1 Apr 2025 16:23:44 +0000 Subject: [PATCH 40/40] last changes --- examples/domino/config_domino.yaml | 2 +- src/nanotron/parallel/comm.py | 1 - 2 files changed, 1 insertion(+), 2 deletions(-) diff --git a/examples/domino/config_domino.yaml b/examples/domino/config_domino.yaml index c60d9bf3..9887ad26 100644 --- a/examples/domino/config_domino.yaml +++ b/examples/domino/config_domino.yaml @@ -47,7 +47,7 @@ model: is_llama_config: true max_position_embeddings: 4096 num_attention_heads: 32 - num_hidden_layers: 15 + num_hidden_layers: 20 num_key_value_heads: 8 pad_token_id: null pretraining_tp: 2 diff --git a/src/nanotron/parallel/comm.py b/src/nanotron/parallel/comm.py index 95495213..d112076a 100644 --- a/src/nanotron/parallel/comm.py +++ b/src/nanotron/parallel/comm.py @@ -45,7 +45,6 @@ def is_all_completed(self) -> bool: not_finished = [] for k, v in self._copy_async_op.items(): - assert is_domino_async_comm(k) is True, f"Operation with name {k} wasn't executed asynchronously!" if v.is_completed() is not True: not_finished.append((k, v)) return len(not_finished) == 0
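
For reference, the op-name convention threaded through these patches can be exercised in isolation with a minimal, self-contained sketch. The name templates and the single synchronous case below mirror src/nanotron/parallel/tensor_parallel/domino.py as of PATCH 38; the regex matching is an illustrative re-implementation of is_domino_async_comm's behaviour, not its exact code.

import re

# Name templates from domino.py: the first slot is the layer index,
# the second is the micro-batch index produced by Domino's input splitting.
FWD_ATTN_OP_NAME = "fwd.layer_attn_{}_batch_{}"
FWD_MLP_OP_NAME = "fwd.layer_mlp_{}_batch_{}"
BWD_ATTN_OP_NAME = "bwd.layer_attn_{}_batch_{}"
BWD_MLP_OP_NAME = "bwd.layer_mlp_{}_batch_{}"

# After PATCH 38, only the backward all-reduce of the first micro-batch's attention
# is kept synchronous; every other all-reduce overlaps with compute.
NON_ASYNC_HANDLE_IDX = ["bwd.layer_attn_{}_batch_0"]

def is_domino_async_comm(op_name: str) -> bool:
    # An op runs asynchronously unless it matches one of the synchronous templates,
    # where "{}" stands for an arbitrary layer index.
    patterns = [re.escape(t).replace(re.escape("{}"), r"\d+") for t in NON_ASYNC_HANDLE_IDX]
    return not any(re.fullmatch(p, op_name) for p in patterns)

if __name__ == "__main__":
    layer_idx = 0
    for template in (FWD_ATTN_OP_NAME, FWD_MLP_OP_NAME, BWD_ATTN_OP_NAME, BWD_MLP_OP_NAME):
        for batch_idx in (0, 1):
            name = template.format(layer_idx, batch_idx)
            print(name, is_domino_async_comm(name))
    # Prints True for every op except "bwd.layer_attn_0_batch_0", matching the
    # expectations in tests/test_domino.py::test_is_domino_async_comm.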
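
The WaitComm autograd function moved into src/nanotron/parallel/comm.py by PATCH 33 defers waiting on an asynchronous all-reduce handle until gradients flow back through the tensor it wrapped. The deferred-wait pattern can be sketched without CUDA streams or a process group as follows; FakeWork and DeferredWait are illustrative stand-ins under that simplification, not nanotron APIs.

import torch

class FakeWork:
    # Stand-in for the dist.Work handle that AsyncCommBucket keeps per op name.
    def wait(self) -> None:
        print("blocking until the async all-reduce has finished")

class DeferredWait(torch.autograd.Function):
    # Identity in the forward pass; waits on the communication handle only in backward.
    @staticmethod
    def forward(ctx, tensor: torch.Tensor, handle: FakeWork) -> torch.Tensor:
        ctx.handle = handle
        return tensor

    @staticmethod
    def backward(ctx, grad_output: torch.Tensor):
        ctx.handle.wait()  # compute proceeds only after the communication has completed
        return grad_output, None  # one gradient per forward argument; the handle gets None

x = torch.ones(4, requires_grad=True)
y = DeferredWait.apply(x, FakeWork())
y.sum().backward()
assert torch.equal(x.grad, torch.ones(4))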