From d7bf8be86a6c9ae1b20ece6f90fcccac57e4f438 Mon Sep 17 00:00:00 2001 From: Phuc Nguyen <b3f0cus@icloud.com> Date: Wed, 29 Jan 2025 12:47:04 +0000 Subject: [PATCH 01/17] first draft of domino forward pass --- examples/config_tiny_llama_domino.yaml | 113 ++++++++++++++++++ src/nanotron/config/parallelism_config.py | 17 +++ src/nanotron/models/llama.py | 90 +++++++++++--- src/nanotron/optim/gradient_accumulator.py | 4 + .../distributed_differentiable_primitives.py | 21 ++-- .../parallel/tensor_parallel/functional.py | 12 +- src/nanotron/parallel/tensor_parallel/nn.py | 5 +- 7 files changed, 235 insertions(+), 27 deletions(-) create mode 100644 examples/config_tiny_llama_domino.yaml diff --git a/examples/config_tiny_llama_domino.yaml b/examples/config_tiny_llama_domino.yaml new file mode 100644 index 00000000..66e22dbd --- /dev/null +++ b/examples/config_tiny_llama_domino.yaml @@ -0,0 +1,113 @@ +checkpoints: + checkpoint_interval: 10 + checkpoints_path: checkpoints + checkpoints_path_is_shared_file_system: false + resume_checkpoint_path: null + save_initial_state: false +data_stages: +- data: + dataset: + dataset_overwrite_cache: false + dataset_processing_num_proc_per_process: 1 + hf_dataset_config_name: null + hf_dataset_or_datasets: stas/openwebtext-10k + hf_dataset_splits: train + text_column_name: text + num_loading_workers: 1 + seed: 42 + name: Stable Training Stage + start_training_step: 1 +- data: + dataset: + dataset_overwrite_cache: false + dataset_processing_num_proc_per_process: 1 + hf_dataset_config_name: null + hf_dataset_or_datasets: stas/openwebtext-10k + hf_dataset_splits: train + text_column_name: text + num_loading_workers: 1 + seed: 42 + name: Annealing Phase + start_training_step: 10 +general: + benchmark_csv_path: null + consumed_train_samples: null + ignore_sanity_checks: true + project: debug + run: tiny_llama_%date_%jobid + seed: 42 + step: null +lighteval: null +logging: + iteration_step_info_interval: 1 + log_level: info + log_level_replica: info +model: + ddp_bucket_cap_mb: 25 + dtype: bfloat16 + init_method: + std: 0.025 + make_vocab_size_divisible_by: 1 + model_config: + bos_token_id: 1 + eos_token_id: 2 + hidden_act: silu + hidden_size: 16 + initializer_range: 0.02 + intermediate_size: 64 + is_llama_config: true + max_position_embeddings: 256 + num_attention_heads: 4 + num_hidden_layers: 2 + num_key_value_heads: 4 + pad_token_id: null + pretraining_tp: 1 + rms_norm_eps: 1.0e-05 + rope_scaling: null + tie_word_embeddings: true + use_cache: true + vocab_size: 256 +optimizer: + accumulate_grad_in_fp32: true + clip_grad: 1.0 + learning_rate_scheduler: + learning_rate: 0.0003 + lr_decay_starting_step: null + lr_decay_steps: 13 + lr_decay_style: cosine + lr_warmup_steps: 2 + lr_warmup_style: linear + min_decay_lr: 1.0e-05 + optimizer_factory: + adam_beta1: 0.9 + adam_beta2: 0.95 + adam_eps: 1.0e-08 + name: adamW + torch_adam_is_fused: true + weight_decay: 0.01 + zero_stage: 0 +parallelism: + # dp: 2 + # pp: 2 + dp: 1 + pp: 1 + tp: 2 + expert_parallel_size: 1 + pp_engine: 1f1b + tp_linear_async_communication: false + tp_mode: ALL_REDUCE + domino: + num_input_batches: 2 +profiler: null +tokenizer: + tokenizer_max_length: null + tokenizer_name_or_path: robot-test/dummy-tokenizer-wordlevel + tokenizer_revision: null +tokens: + batch_accumulation_per_replica: 1 + limit_test_batches: 0 + limit_val_batches: 0 + micro_batch_size: 2 + sequence_length: 256 + train_steps: 15 + val_check_interval: -1 diff --git a/src/nanotron/config/parallelism_config.py 
b/src/nanotron/config/parallelism_config.py index 7f20ad99..2701bf9c 100644 --- a/src/nanotron/config/parallelism_config.py +++ b/src/nanotron/config/parallelism_config.py @@ -11,6 +11,22 @@ from nanotron.parallel.tensor_parallel.nn import TensorParallelLinearMode +@dataclass +class DominoArgs: + """ + Domino: Eliminating Communication in LLM Training via Generic Tensor Slicing and Overlapping + https://arxiv.org/abs/2409.15241 + """ + + # NOTE: if the number of input batches is 1, + # it's equivalent to non-domino mode + # so if you want to enable domino mode, set this to > 1 + num_input_batches: int + + def __post_init__(self): + assert self.num_input_batches > 1, "In order to enable domino mode, set num_input_batches > 1" + + @dataclass class ParallelismArgs: """Arguments related to TP/PP/DP @@ -37,6 +53,7 @@ class ParallelismArgs: tp_recompute_allgather: bool = True expert_parallel_size: int = 1 + domino: Optional[DominoArgs] = None def __post_init__(self): # Conservative defaults diff --git a/src/nanotron/models/llama.py b/src/nanotron/models/llama.py index 88fb6bcb..bc495624 100644 --- a/src/nanotron/models/llama.py +++ b/src/nanotron/models/llama.py @@ -237,13 +237,14 @@ def __init__( mode=tp_mode, bias=False, async_communication=tp_linear_async_communication and tp_mode is TensorParallelLinearMode.REDUCE_SCATTER, + async_all_reduce=parallel_config.domino.num_input_batches > 1, ) self.split_silu_mul = GLUActivation(config.hidden_act) def forward(self, hidden_states): # [seq_length, batch_size, hidden_dim] merged_states = self.gate_up_proj(hidden_states) - hidden_states = self.down_proj(self.split_silu_mul(merged_states)) - return {"hidden_states": hidden_states} + hidden_states, work = self.down_proj(self.split_silu_mul(merged_states)) + return {"hidden_states": hidden_states, "work": work} class CoreAttention(nn.Module): @@ -335,6 +336,7 @@ def __init__( parallel_config: Optional[ParallelismArgs], tp_pg: dist.ProcessGroup, layer_idx: int, + async_all_reduce: bool = False, ): from flash_attn.layers.rotary import RotaryEmbedding as FlashRotaryEmbedding @@ -418,6 +420,7 @@ def __init__( mode=tp_mode, bias=False, async_communication=tp_linear_async_communication, + async_all_reduce=async_all_reduce, ) self.attention = CoreAttention( @@ -687,9 +690,9 @@ def forward( attention_output = ( attention_output.contiguous().view(batch_size, q_length, self.n_local_q_heads * self.d_v).transpose(0, 1) ) - output = self.o_proj(attention_output) + output, work = self.o_proj(attention_output) - return {"hidden_states": output, "sequence_mask": sequence_mask} + return {"hidden_states": output, "work": work, "sequence_mask": sequence_mask} class LlamaDecoderLayer(nn.Module): @@ -702,36 +705,80 @@ def __init__( ): super().__init__() self.input_layernorm = TritonRMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.attn = CausalSelfAttention( config=config, parallel_config=parallel_config, tp_pg=tp_pg, layer_idx=layer_idx, + async_all_reduce=parallel_config.domino.num_input_batches > 1, ) self.post_attention_layernorm = TritonRMSNorm(config.hidden_size, eps=config.rms_norm_eps) self.mlp = MLP(config=config, parallel_config=parallel_config, tp_pg=tp_pg) self.recompute_layer = parallel_config.recompute_layer + self.parallel_config = parallel_config def _core_forward( self, hidden_states: Union[torch.Tensor, TensorPointer], sequence_mask: Union[torch.Tensor, TensorPointer], ) -> List[Union[torch.Tensor, TensorPointer]]: - residual = hidden_states - hidden_states = self.input_layernorm(hidden_states) - 
output = self.attn(hidden_states=hidden_states, sequence_mask=sequence_mask) - hidden_states = output["hidden_states"] - hidden_states = hidden_states + residual + num_input_batches = self.parallel_config.domino.num_input_batches + assert num_input_batches == 2 + hidden_states = torch.chunk(hidden_states, chunks=num_input_batches, dim=1) + orig_sequence_mask = sequence_mask + sequence_mask = torch.chunk(sequence_mask, chunks=num_input_batches, dim=0) + + hidden_states0, hidden_states1 = hidden_states + sequence_mask0, sequence_mask1 = sequence_mask + + # # Combine the chunks into a list of dictionaries + # hidden_encoder_states_list = [ + # {"hidden_states": hidden_encoder_states["hidden_states"][i], "sequence_mask": hidden_encoder_states["sequence_mask"][i]} + # for i in range(num_input_batches) + # ] + + residual0 = hidden_states0 + residual1 = hidden_states1 + + hidden_states0 = self.input_layernorm(hidden_states0) + hidden_states1 = self.input_layernorm(hidden_states1) + + attn_output0 = self.attn(hidden_states=hidden_states0, sequence_mask=sequence_mask0) + attn_output0_work = attn_output0["work"] + + attn_output1 = self.attn(hidden_states=hidden_states1, sequence_mask=sequence_mask1) + attn_output1_work = attn_output1["work"] + + attn_output0_work.wait() + hidden_states0 = attn_output0["hidden_states"] + hidden_states0 = hidden_states0 + residual0 + residual0 = hidden_states0 + hidden_states0 = self.post_attention_layernorm(hidden_states0) - residual = hidden_states - hidden_states = self.post_attention_layernorm(hidden_states) - hidden_states = self.mlp(hidden_states=hidden_states)["hidden_states"] - hidden_states = hidden_states + residual + mlp_output0 = self.mlp(hidden_states=hidden_states0) - return hidden_states, output["sequence_mask"] + attn_output1_work.wait() + hidden_states1 = attn_output1["hidden_states"] + hidden_states1 = hidden_states1 + residual1 + residual1 = hidden_states1 + hidden_states1 = self.post_attention_layernorm(hidden_states1) + + mlp_output1 = self.mlp(hidden_states=hidden_states1) + mlp_output0["work"].wait() + mlp_output1["work"].wait() + + hidden_states0 = mlp_output0["hidden_states"] + hidden_states1 = mlp_output1["hidden_states"] + + hidden_states0 = hidden_states0 + residual0 + hidden_states1 = hidden_states1 + residual1 + + hidden_states = torch.cat([hidden_states0, hidden_states1], dim=1) + return hidden_states, orig_sequence_mask def _checkpointed_forward( self, @@ -899,9 +946,24 @@ def forward_with_hidden_states( "hidden_states": output["input_embeds"], "sequence_mask": input_mask, } + + # assert 1 == 1 + # num_input_batches = self.parallel_config.domino.num_input_batches + # hidden_encoder_states["hidden_states"] = torch.chunk(hidden_encoder_states["hidden_states"], chunks=num_input_batches, dim=1) + # hidden_encoder_states["sequence_mask"] = torch.chunk(hidden_encoder_states["sequence_mask"], chunks=num_input_batches, dim=0) + + # # Combine the chunks into a list of dictionaries + # hidden_encoder_states_list = [ + # {"hidden_states": hidden_encoder_states["hidden_states"][i], "sequence_mask": hidden_encoder_states["sequence_mask"][i]} + # for i in range(num_input_batches) + # ] + for encoder_block in self.decoder: hidden_encoder_states = encoder_block(**hidden_encoder_states) + # for hidden_encoder_states in hidden_encoder_states_list: + # hidden_encoder_states = encoder_block(**hidden_encoder_states) + hidden_states = self.final_layer_norm(input=hidden_encoder_states["hidden_states"])["hidden_states"] sharded_logits = 
self.lm_head(x=hidden_states)["logits"] diff --git a/src/nanotron/optim/gradient_accumulator.py b/src/nanotron/optim/gradient_accumulator.py index 2e940744..ba0d5dd0 100644 --- a/src/nanotron/optim/gradient_accumulator.py +++ b/src/nanotron/optim/gradient_accumulator.py @@ -202,6 +202,10 @@ def build_grad_buffers( return fp32_grad_buffers, contiguous_buffer_f32_gradients def backward(self, loss: torch.Tensor): + if isinstance(loss, tuple): + assert 1 == 1 + raise NotImplementedError("Not implemented yet") + result = loss.backward() for name, elt in self.fp32_grad_buffers.items(): diff --git a/src/nanotron/parallel/tensor_parallel/distributed_differentiable_primitives.py b/src/nanotron/parallel/tensor_parallel/distributed_differentiable_primitives.py index bd41347a..0ae9a4de 100644 --- a/src/nanotron/parallel/tensor_parallel/distributed_differentiable_primitives.py +++ b/src/nanotron/parallel/tensor_parallel/distributed_differentiable_primitives.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import Optional +from typing import Optional, Tuple import torch from torch import distributed as torch_dist @@ -32,23 +32,28 @@ def forward(ctx, tensor, group: Optional[ProcessGroup]): @staticmethod def backward(ctx, grad_output): group = ctx.group - return DifferentiableAllReduceSum.apply(grad_output, group), None + return DifferentiableAllReduceSum.apply(grad_output, group, False), None class DifferentiableAllReduceSum(torch.autograd.Function): """All-reduce in a differentiable fashion""" @staticmethod - def forward(ctx, tensor, group: Optional[ProcessGroup]): + def forward( + ctx, tensor, group: Optional[ProcessGroup], async_all_reduce: bool + ) -> Tuple[torch.Tensor, Optional["dist.Work"]]: if group.size() == 1: return tensor - dist.all_reduce(tensor, op=dist.ReduceOp.SUM, group=group) - return tensor + handle = dist.all_reduce(tensor, op=dist.ReduceOp.SUM, group=group, async_op=async_all_reduce) + if async_all_reduce: + return tensor, handle + else: + return tensor, None @staticmethod def backward(ctx, grad_output): - return grad_output, None + return grad_output, None, None class DifferentiableAllGather(torch.autograd.Function): @@ -134,8 +139,8 @@ def differentiable_identity(tensor, group: Optional[ProcessGroup] = None): return DifferentiableIdentity.apply(tensor, group) -def differentiable_all_reduce_sum(tensor, group: Optional[ProcessGroup] = None): - return DifferentiableAllReduceSum.apply(tensor, group) +def differentiable_all_reduce_sum(tensor, group: Optional[ProcessGroup] = None, async_all_reduce: bool = False): + return DifferentiableAllReduceSum.apply(tensor, group, async_all_reduce) def differentiable_all_gather(tensor, group: Optional[ProcessGroup] = None): diff --git a/src/nanotron/parallel/tensor_parallel/functional.py b/src/nanotron/parallel/tensor_parallel/functional.py index e2ee3a29..ffb3e3f9 100644 --- a/src/nanotron/parallel/tensor_parallel/functional.py +++ b/src/nanotron/parallel/tensor_parallel/functional.py @@ -13,7 +13,7 @@ # See the License for the specific language governing permissions and # limitations under the License. 
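# A minimal sketch, not part of the diff, of the interleaving that the _core_forward
# hunk above implements, stripped to its shape. It assumes attn and mlp return a
# (partial_output, async_handle) pair the way the patched o_proj/down_proj now do;
# all names here (domino_block, norm1, norm2, w0, wm0, ...) are illustrative only.
import torch

def domino_block(x, attn, mlp, norm1, norm2):
    # split the micro-batch in two along the batch dimension of [seq, batch, hidden]
    x0, x1 = torch.chunk(x, chunks=2, dim=1)
    h0, w0 = attn(norm1(x0))      # batch 0 attention; its TP all-reduce is left in flight
    h1, w1 = attn(norm1(x1))      # batch 1 compute overlaps batch 0's communication
    w0.wait()                     # batch 0's attention output is now complete
    y0 = x0 + h0
    m0, wm0 = mlp(norm2(y0))      # batch 0 MLP; again leaves an all-reduce in flight
    w1.wait()
    y1 = x1 + h1
    m1, wm1 = mlp(norm2(y1))
    wm0.wait()
    wm1.wait()
    return torch.cat([y0 + m0, y1 + m1], dim=1)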
import math -from typing import Optional +from typing import Optional, Tuple import torch from torch.nn import functional as F @@ -587,18 +587,22 @@ def row_linear( bias: Optional[torch.Tensor], group: dist.ProcessGroup, tp_mode: TensorParallelLinearMode, + # TODO(xrsrke): use less confusing names for these arguments async_communication: bool, -): + async_all_reduce: bool, +) -> Tuple[torch.Tensor, Optional[torch.Future]]: if async_communication: return _RowLinearAsyncCommunication.apply(input, weight, bias, group, tp_mode) out = F.linear(input, weight, bias) if tp_mode is TensorParallelLinearMode.ALL_REDUCE: - out = differentiable_all_reduce_sum(out, group=group) + out, work = differentiable_all_reduce_sum(out, group=group, async_all_reduce=async_all_reduce) elif tp_mode is TensorParallelLinearMode.REDUCE_SCATTER: + assert async_all_reduce is False, "Async communication is not supported for REDUCE_SCATTER mode." out = differentiable_reduce_scatter_sum(out, group=group) + work = None else: raise ValueError(f"Got unexpected mode: {tp_mode}.") - return out + return out, work diff --git a/src/nanotron/parallel/tensor_parallel/nn.py b/src/nanotron/parallel/tensor_parallel/nn.py index 4c7325cd..19cbdf88 100644 --- a/src/nanotron/parallel/tensor_parallel/nn.py +++ b/src/nanotron/parallel/tensor_parallel/nn.py @@ -111,6 +111,7 @@ def __init__( device=None, dtype=None, async_communication: bool = False, + async_all_reduce: bool = False, contiguous_chunks: Optional[Tuple[int, ...]] = None, ): self.pg = pg @@ -133,6 +134,7 @@ def __init__( ) self.mode = mode self.async_communication = async_communication + self.async_all_reduce = async_all_reduce if self.mode is TensorParallelLinearMode.ALL_REDUCE and self.async_communication: raise ValueError("async_communication is not supported for ALL_REDUCE mode") @@ -166,6 +168,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: group=self.pg, tp_mode=self.mode, async_communication=self.async_communication, + async_all_reduce=self.async_all_reduce, ) def extra_repr(self) -> str: @@ -290,7 +293,7 @@ def forward(self, input_ids: torch.Tensor) -> torch.Tensor: out = out * (~input_mask[..., None]) if self.mode is TensorParallelLinearMode.ALL_REDUCE: - out = differentiable_all_reduce_sum(out, group=self.pg) + out, _ = differentiable_all_reduce_sum(out, group=self.pg, async_all_reduce=False) elif self.mode is TensorParallelLinearMode.REDUCE_SCATTER: out = differentiable_reduce_scatter_sum(out, group=self.pg) else: From 3803b1927e4063c2b329857dfbd0497778a235b1 Mon Sep 17 00:00:00 2001 From: Phuc Nguyen <b3f0cus@icloud.com> Date: Thu, 30 Jan 2025 14:03:36 +0000 Subject: [PATCH 02/17] support the backward pass --- src/nanotron/optim/gradient_accumulator.py | 2 +- src/nanotron/parallel/comm.py | 26 +++++++++++++++++++ .../parallel/pipeline_parallel/engine.py | 8 ++++-- .../distributed_differentiable_primitives.py | 22 +++++++++++++--- .../parallel/tensor_parallel/functional.py | 10 ++++++- src/nanotron/parallel/tensor_parallel/nn.py | 2 +- 6 files changed, 62 insertions(+), 8 deletions(-) create mode 100644 src/nanotron/parallel/comm.py diff --git a/src/nanotron/optim/gradient_accumulator.py b/src/nanotron/optim/gradient_accumulator.py index ba0d5dd0..b5ef7d89 100644 --- a/src/nanotron/optim/gradient_accumulator.py +++ b/src/nanotron/optim/gradient_accumulator.py @@ -202,7 +202,7 @@ def build_grad_buffers( return fp32_grad_buffers, contiguous_buffer_f32_gradients def backward(self, loss: torch.Tensor): - if isinstance(loss, tuple): + if not isinstance(loss, 
torch.Tensor): assert 1 == 1 raise NotImplementedError("Not implemented yet") diff --git a/src/nanotron/parallel/comm.py b/src/nanotron/parallel/comm.py new file mode 100644 index 00000000..76e33e21 --- /dev/null +++ b/src/nanotron/parallel/comm.py @@ -0,0 +1,26 @@ +from typing import Dict + + +class AsyncCommBucket: + """ + + Variable._execution_engine.run_backward( # Calls into the C++ engine to run the backward pass + RuntimeError: expected Variable or None (got tuple) + Variable._execution_engine.run_backward( # Calls into the C++ engine to run the backward pass + RuntimeError: expected Variable or None (got tuple) + """ + + _async_op: Dict[int, "dist.Work"] = {} + + @staticmethod + def add(tensor_id: int, work: "dist.Work"): + AsyncCommBucket._async_op[tensor_id] = work + + @staticmethod + def get(tensor_id: int): + return AsyncCommBucket._async_op.get(tensor_id) + + @staticmethod + def wait(tensor_id: int): + work = AsyncCommBucket._async_op.pop(tensor_id) + work.wait() diff --git a/src/nanotron/parallel/pipeline_parallel/engine.py b/src/nanotron/parallel/pipeline_parallel/engine.py index ca9df312..8160f302 100644 --- a/src/nanotron/parallel/pipeline_parallel/engine.py +++ b/src/nanotron/parallel/pipeline_parallel/engine.py @@ -2,6 +2,9 @@ from typing import Dict, Iterable, Optional, Union import torch +from torch import nn as torch_nn +from torch.nn.parallel import DistributedDataParallel + from nanotron import distributed as dist from nanotron import logging from nanotron.distributed import ProcessGroup @@ -12,8 +15,6 @@ from nanotron.parallel.pipeline_parallel.state import PipelineTrainBatchState from nanotron.parallel.pipeline_parallel.tensor_pointer import TensorPointer from nanotron.utils import ContextManagers -from torch import nn as torch_nn -from torch.nn.parallel import DistributedDataParallel logger = logging.get_logger(__name__) @@ -83,6 +84,9 @@ def backward( if grad_accumulator is None: sum(activations).backward() else: + # if not isinstance(activations, torch.Tensor): + # raise NotImplementedError("Only support sum of tensors for now") + grad_accumulator.backward(sum(activations)) # TODO @nouamane: this fixes interleaved afab but makes 1f1b hang diff --git a/src/nanotron/parallel/tensor_parallel/distributed_differentiable_primitives.py b/src/nanotron/parallel/tensor_parallel/distributed_differentiable_primitives.py index 0ae9a4de..05ade53d 100644 --- a/src/nanotron/parallel/tensor_parallel/distributed_differentiable_primitives.py +++ b/src/nanotron/parallel/tensor_parallel/distributed_differentiable_primitives.py @@ -19,6 +19,7 @@ from nanotron import distributed as dist from nanotron.distributed import ProcessGroup +from nanotron.parallel.comm import AsyncCommBucket class DifferentiableIdentity(torch.autograd.Function): @@ -42,14 +43,29 @@ class DifferentiableAllReduceSum(torch.autograd.Function): def forward( ctx, tensor, group: Optional[ProcessGroup], async_all_reduce: bool ) -> Tuple[torch.Tensor, Optional["dist.Work"]]: + # ctx.mark_non_differentiable(async_all_reduce) + ctx.async_all_reduce = async_all_reduce + if group.size() == 1: return tensor + orig_id = id(tensor) handle = dist.all_reduce(tensor, op=dist.ReduceOp.SUM, group=group, async_op=async_all_reduce) + # if async_all_reduce: + # handle.wait() + new_id = id(tensor) + assert 1 == 1 + assert orig_id == new_id + # if async_all_reduce: + # return tensor, handle + # else: + # return tensor, None if async_all_reduce: - return tensor, handle - else: - return tensor, None + # AsyncCommBucket.add(tensor, 
handle) + # AsyncCommBucket.add(id(tensor), handle) + AsyncCommBucket.add(orig_id, handle) + + return tensor @staticmethod def backward(ctx, grad_output): diff --git a/src/nanotron/parallel/tensor_parallel/functional.py b/src/nanotron/parallel/tensor_parallel/functional.py index ffb3e3f9..b05a50bf 100644 --- a/src/nanotron/parallel/tensor_parallel/functional.py +++ b/src/nanotron/parallel/tensor_parallel/functional.py @@ -597,7 +597,15 @@ def row_linear( out = F.linear(input, weight, bias) if tp_mode is TensorParallelLinearMode.ALL_REDUCE: - out, work = differentiable_all_reduce_sum(out, group=group, async_all_reduce=async_all_reduce) + # out, work = differentiable_all_reduce_sum(out, group=group, async_all_reduce=async_all_reduce) + orig_out_id = id(out) + # NOTE: why the id(out) doesn't match the id(out) before the all_reduce? + out = differentiable_all_reduce_sum(out, group=group, async_all_reduce=async_all_reduce) + if async_all_reduce: + from nanotron.parallel.comm import AsyncCommBucket + + work = AsyncCommBucket.get(orig_out_id) + assert 1 == 1 elif tp_mode is TensorParallelLinearMode.REDUCE_SCATTER: assert async_all_reduce is False, "Async communication is not supported for REDUCE_SCATTER mode." out = differentiable_reduce_scatter_sum(out, group=group) diff --git a/src/nanotron/parallel/tensor_parallel/nn.py b/src/nanotron/parallel/tensor_parallel/nn.py index 19cbdf88..14a7486a 100644 --- a/src/nanotron/parallel/tensor_parallel/nn.py +++ b/src/nanotron/parallel/tensor_parallel/nn.py @@ -293,7 +293,7 @@ def forward(self, input_ids: torch.Tensor) -> torch.Tensor: out = out * (~input_mask[..., None]) if self.mode is TensorParallelLinearMode.ALL_REDUCE: - out, _ = differentiable_all_reduce_sum(out, group=self.pg, async_all_reduce=False) + out = differentiable_all_reduce_sum(out, group=self.pg, async_all_reduce=False) elif self.mode is TensorParallelLinearMode.REDUCE_SCATTER: out = differentiable_reduce_scatter_sum(out, group=self.pg) else: From d765fd57e29a30b0a083012bed6e443a2df72b7f Mon Sep 17 00:00:00 2001 From: Phuc Nguyen <b3f0cus@icloud.com> Date: Fri, 31 Jan 2025 14:20:32 +0000 Subject: [PATCH 03/17] the first draft for bwd overlapping --- src/nanotron/constants.py | 3 + src/nanotron/helpers.py | 2 + src/nanotron/models/llama.py | 57 ++++++++++++++++--- src/nanotron/parallel/comm.py | 46 +++++++++++++++ .../distributed_differentiable_primitives.py | 27 ++++++--- .../parallel/tensor_parallel/functional.py | 6 +- src/nanotron/parallel/tensor_parallel/nn.py | 5 +- 7 files changed, 127 insertions(+), 19 deletions(-) diff --git a/src/nanotron/constants.py b/src/nanotron/constants.py index 580bd99d..78fd0bb9 100644 --- a/src/nanotron/constants.py +++ b/src/nanotron/constants.py @@ -10,3 +10,6 @@ CHECKPOINT_FILE_NAME = "checkpoint_metadata.json" MODEL_CONFIG_FILE_NAME = "model_config.json" + + +CUDA_STREAMS = {} diff --git a/src/nanotron/helpers.py b/src/nanotron/helpers.py index 73ca3484..7f31d812 100644 --- a/src/nanotron/helpers.py +++ b/src/nanotron/helpers.py @@ -482,7 +482,9 @@ def get_profiler(config: Config): on_trace_ready=on_trace_ready, # record_shapes=True, # profile_memory=True, + with_flops=True, with_stack=True, + with_modules=True, ) else: prof = contextlib.nullcontext() diff --git a/src/nanotron/models/llama.py b/src/nanotron/models/llama.py index bc495624..fb112f8e 100644 --- a/src/nanotron/models/llama.py +++ b/src/nanotron/models/llama.py @@ -30,6 +30,7 @@ from nanotron.nn.activations import ACT2FN from nanotron.nn.layer_norm import TritonRMSNorm from 
nanotron.parallel import ParallelContext +from nanotron.parallel.comm import WaitComm from nanotron.parallel.parameters import NanotronParameter from nanotron.parallel.pipeline_parallel.block import PipelineBlock, TensorPointer from nanotron.parallel.pipeline_parallel.p2p import P2P @@ -46,6 +47,8 @@ logger = logging.get_logger(__name__) +DOMINO_COMM_STREAM = "domino_comm_stream_{}" + class RotaryEmbedding(nn.Module): def __init__(self, dim: int, end: int, theta: float = 10000.0): @@ -241,8 +244,8 @@ def __init__( ) self.split_silu_mul = GLUActivation(config.hidden_act) - def forward(self, hidden_states): # [seq_length, batch_size, hidden_dim] - merged_states = self.gate_up_proj(hidden_states) + def forward(self, hidden_states, handle_idx=None): # [seq_length, batch_size, hidden_dim] + merged_states = self.gate_up_proj(hidden_states, handle_idx) hidden_states, work = self.down_proj(self.split_silu_mul(merged_states)) return {"hidden_states": hidden_states, "work": work} @@ -437,6 +440,7 @@ def forward( self, hidden_states, # [seq_length, batch_size, hidden_size] sequence_mask, # [batch_size, seq_length] + handle_idx=None, ): from flash_attn import bert_padding from flash_attn.flash_attn_interface import ( @@ -445,7 +449,7 @@ def forward( ) qkv_states = self.qkv_proj( - hidden_states + hidden_states, handle_idx=handle_idx ) # [seq_length, batch_size, n_local_q_heads * d_qk + 2 * n_local_kv_heads * d_qk] q_length, batch_size, _ = qkv_states.shape @@ -720,6 +724,18 @@ def __init__( self.recompute_layer = parallel_config.recompute_layer self.parallel_config = parallel_config + # if parallel_config.domino is not None and parallel_config.domino.num_input_batches > 1: + # from nanotron.parallel.comm import CudaStreamManager + # # NOTE: we use different cuda streams for different gpus, so it can overlaps the communication + # CudaStreamManager.create(DOMINO_COMM_STREAM.format(torch.cuda.current_device())) + num_gpus = torch.cuda.device_count() + for i in range(num_gpus): + from nanotron import constants + + constants.CUDA_STREAMS[i] = torch.cuda.Stream(device=torch.device(f"cuda:{i}")) + + self.layer_idx = layer_idx + def _core_forward( self, hidden_states: Union[torch.Tensor, TensorPointer], @@ -747,29 +763,52 @@ def _core_forward( hidden_states0 = self.input_layernorm(hidden_states0) hidden_states1 = self.input_layernorm(hidden_states1) - attn_output0 = self.attn(hidden_states=hidden_states0, sequence_mask=sequence_mask0) + attn_output0 = self.attn( + hidden_states=hidden_states0, sequence_mask=sequence_mask0, handle_idx=f"layer_{self.layer_idx}_batch_0" + ) attn_output0_work = attn_output0["work"] - attn_output1 = self.attn(hidden_states=hidden_states1, sequence_mask=sequence_mask1) + attn_output1 = self.attn( + hidden_states=hidden_states1, sequence_mask=sequence_mask1, handle_idx=f"layer_{self.layer_idx}_batch_1" + ) attn_output1_work = attn_output1["work"] - attn_output0_work.wait() + from nanotron import constants + + comm_stream = constants.CUDA_STREAMS[torch.cuda.current_device()] + # comm_stream = CudaStreamManager.get(DOMINO_COMM_STREAM.format(torch.cuda.current_device())) + with torch.cuda.stream(comm_stream): + attn_output0_work.wait() + # attn_output0_work.wait() + hidden_states0 = attn_output0["hidden_states"] hidden_states0 = hidden_states0 + residual0 residual0 = hidden_states0 hidden_states0 = self.post_attention_layernorm(hidden_states0) + hidden_states0 = WaitComm.apply(hidden_states0, f"layer_{self.layer_idx}_batch_0") + # mlp_output0 = self.mlp(hidden_states=hidden_states0, 
handle_idx=f"layer_{self.layer_idx}_batch_0") mlp_output0 = self.mlp(hidden_states=hidden_states0) - attn_output1_work.wait() + with torch.cuda.stream(comm_stream): + attn_output1_work.wait() + # attn_output1_work.wait() + hidden_states1 = attn_output1["hidden_states"] hidden_states1 = hidden_states1 + residual1 residual1 = hidden_states1 hidden_states1 = self.post_attention_layernorm(hidden_states1) + hidden_states1 = WaitComm.apply(hidden_states1, f"layer_{self.layer_idx}_batch_1") + # mlp_output1 = self.mlp(hidden_states=hidden_states1, handle_idx=f"layer_{self.layer_idx}_batch_1") mlp_output1 = self.mlp(hidden_states=hidden_states1) - mlp_output0["work"].wait() - mlp_output1["work"].wait() + + with torch.cuda.stream(comm_stream): + mlp_output0["work"].wait() + mlp_output1["work"].wait() + + # mlp_output0["work"].wait() + # mlp_output1["work"].wait() hidden_states0 = mlp_output0["hidden_states"] hidden_states1 = mlp_output1["hidden_states"] diff --git a/src/nanotron/parallel/comm.py b/src/nanotron/parallel/comm.py index 76e33e21..e26e2814 100644 --- a/src/nanotron/parallel/comm.py +++ b/src/nanotron/parallel/comm.py @@ -1,5 +1,27 @@ +from contextlib import contextmanager from typing import Dict +import torch + + +class CudaStreamManager: + _streams: Dict[str, "torch.cuda.Stream"] = {} + + @staticmethod + def create(name: str): + assert name not in CudaStreamManager._streams + CudaStreamManager._streams[name] = torch.cuda.Stream() + + @staticmethod + def get(name: str): + return CudaStreamManager._streams.get(name) + + @contextmanager + def run_on_stream(name: str): + stream = CudaStreamManager.get(name) + with torch.cuda.stream(stream): + yield stream + class AsyncCommBucket: """ @@ -14,13 +36,37 @@ class AsyncCommBucket: @staticmethod def add(tensor_id: int, work: "dist.Work"): + assert ( + tensor_id not in AsyncCommBucket._async_op + ), f"tensor_id: {tensor_id}, keys: {AsyncCommBucket._async_op.keys()}" AsyncCommBucket._async_op[tensor_id] = work @staticmethod def get(tensor_id: int): return AsyncCommBucket._async_op.get(tensor_id) + @staticmethod + def pop(tensor_id: int): + return AsyncCommBucket._async_op.pop(tensor_id) + @staticmethod def wait(tensor_id: int): work = AsyncCommBucket._async_op.pop(tensor_id) work.wait() + + +class WaitComm(torch.autograd.Function): + @staticmethod + def forward(ctx, input, wait_handle_idx): + ctx.wait_handle_idx = wait_handle_idx + return input + + @staticmethod + def backward(ctx, grad_output): + import pydevd + + pydevd.settrace(suspend=False, trace_only_current_thread=True) + if ctx.wait_handle_idx != "layer_1_batch_1": + handle = AsyncCommBucket.pop(ctx.wait_handle_idx) + handle.wait() + return grad_output, None diff --git a/src/nanotron/parallel/tensor_parallel/distributed_differentiable_primitives.py b/src/nanotron/parallel/tensor_parallel/distributed_differentiable_primitives.py index 05ade53d..38c6bafd 100644 --- a/src/nanotron/parallel/tensor_parallel/distributed_differentiable_primitives.py +++ b/src/nanotron/parallel/tensor_parallel/distributed_differentiable_primitives.py @@ -26,14 +26,23 @@ class DifferentiableIdentity(torch.autograd.Function): """All-reduce gradients in a differentiable fashion""" @staticmethod - def forward(ctx, tensor, group: Optional[ProcessGroup]): + def forward(ctx, tensor, group: Optional[ProcessGroup], handle_idx=None): + # assert handle_idx is not None + ctx.handle_idx = handle_idx ctx.group = group return tensor @staticmethod def backward(ctx, grad_output): + # import pydevd + # pydevd.settrace(suspend=False, 
trace_only_current_thread=True) + # NOTE: lm_head is TensorParallelColumnLinear, and it doesn't do async + # assert ctx.handle_idx is not None group = ctx.group - return DifferentiableAllReduceSum.apply(grad_output, group, False), None + if ctx.handle_idx is not None: + assert 1 == 1 + + return DifferentiableAllReduceSum.apply(grad_output, group, True, ctx.handle_idx), None, None class DifferentiableAllReduceSum(torch.autograd.Function): @@ -41,7 +50,7 @@ class DifferentiableAllReduceSum(torch.autograd.Function): @staticmethod def forward( - ctx, tensor, group: Optional[ProcessGroup], async_all_reduce: bool + ctx, tensor, group: Optional[ProcessGroup], async_all_reduce: bool, handle_idx: Optional[int] = None ) -> Tuple[torch.Tensor, Optional["dist.Work"]]: # ctx.mark_non_differentiable(async_all_reduce) ctx.async_all_reduce = async_all_reduce @@ -63,13 +72,17 @@ def forward( if async_all_reduce: # AsyncCommBucket.add(tensor, handle) # AsyncCommBucket.add(id(tensor), handle) - AsyncCommBucket.add(orig_id, handle) + # try: + # AsyncCommBucket.add(orig_id if handle_idx is None else handle_idx, handle) + # except Exception as e: + # assert 1 == 1 + AsyncCommBucket.add(orig_id if handle_idx is None else handle_idx, handle) return tensor @staticmethod def backward(ctx, grad_output): - return grad_output, None, None + return grad_output, None, None, None class DifferentiableAllGather(torch.autograd.Function): @@ -151,8 +164,8 @@ def backward(ctx, grad_output): # ----------------- -def differentiable_identity(tensor, group: Optional[ProcessGroup] = None): - return DifferentiableIdentity.apply(tensor, group) +def differentiable_identity(tensor, group: Optional[ProcessGroup] = None, handle_idx=None): + return DifferentiableIdentity.apply(tensor, group, handle_idx) def differentiable_all_reduce_sum(tensor, group: Optional[ProcessGroup] = None, async_all_reduce: bool = False): diff --git a/src/nanotron/parallel/tensor_parallel/functional.py b/src/nanotron/parallel/tensor_parallel/functional.py index b05a50bf..ff43c98b 100644 --- a/src/nanotron/parallel/tensor_parallel/functional.py +++ b/src/nanotron/parallel/tensor_parallel/functional.py @@ -436,12 +436,13 @@ def column_linear( tp_mode: TensorParallelLinearMode, async_communication: bool, tp_recompute_allgather: bool = True, + handle_idx: Optional[int] = None, ): if async_communication: return _ColumnLinearAsyncCommunication.apply(input, weight, bias, group, tp_mode, tp_recompute_allgather) if tp_mode is TensorParallelLinearMode.ALL_REDUCE: - input = differentiable_identity(input, group=group) + input = differentiable_identity(input, group=group, handle_idx=handle_idx) return F.linear(input, weight, bias) if tp_mode is TensorParallelLinearMode.REDUCE_SCATTER: return _ColumnLinearNoAsyncCommunicationReduceScatterMode.apply( @@ -604,7 +605,8 @@ def row_linear( if async_all_reduce: from nanotron.parallel.comm import AsyncCommBucket - work = AsyncCommBucket.get(orig_out_id) + # work = AsyncCommBucket.get(orig_out_id) + work = AsyncCommBucket.pop(orig_out_id) assert 1 == 1 elif tp_mode is TensorParallelLinearMode.REDUCE_SCATTER: assert async_all_reduce is False, "Async communication is not supported for REDUCE_SCATTER mode." 
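Taken together, the hunks above converge on one mechanism: launch the tensor-parallel all-reduce with async_op=True, park the returned dist.Work handle in a process-global bucket under a known key, and let the caller pop and wait on it only after it has overlapped some independent compute. The following is a minimal self-contained sketch of that mechanism under those assumptions, using illustrative names (_pending, AllReduceSumAsync, row_linear_async) rather than the nanotron API:

import torch
import torch.distributed as dist
import torch.nn.functional as F

_pending = {}  # key -> dist.Work, one entry per in-flight collective

class AllReduceSumAsync(torch.autograd.Function):
    @staticmethod
    def forward(ctx, tensor, group, key):
        # non-blocking all-reduce; the result is only valid after the handle completes
        _pending[key] = dist.all_reduce(tensor, op=dist.ReduceOp.SUM, group=group, async_op=True)
        return tensor

    @staticmethod
    def backward(ctx, grad_output):
        # for an all-reduce-sum, each rank's gradient is simply the incoming gradient
        return grad_output, None, None

def row_linear_async(x, weight, group, key):
    out = F.linear(x, weight)
    out = AllReduceSumAsync.apply(out, group, key)
    # hand both the partial result and its handle back to the caller
    return out, _pending.pop(key)

The caller then runs out, work = row_linear_async(...), computes on the other micro-batch, and calls work.wait() only right before out is actually consumed, which is exactly the overlap the decoder-layer changes rely on.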
diff --git a/src/nanotron/parallel/tensor_parallel/nn.py b/src/nanotron/parallel/tensor_parallel/nn.py index 14a7486a..f4ceff63 100644 --- a/src/nanotron/parallel/tensor_parallel/nn.py +++ b/src/nanotron/parallel/tensor_parallel/nn.py @@ -52,6 +52,7 @@ def __init__( async_communication: bool = False, contiguous_chunks: Optional[Tuple[int, ...]] = None, tp_recompute_allgather: bool = True, + # handle_idx: Optional[int] = None, ): self.pg = pg self.world_size = pg.size() @@ -72,6 +73,7 @@ def __init__( self.mode = mode self.async_communication = async_communication + # self.handle_idx = handle_idx if contiguous_chunks is not None: assert ( @@ -85,7 +87,7 @@ def __init__( split_config=split_config, ) - def forward(self, x: torch.Tensor) -> torch.Tensor: + def forward(self, x: torch.Tensor, handle_idx=None) -> torch.Tensor: return column_linear( input=x, weight=self.weight, @@ -94,6 +96,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: tp_mode=self.mode, async_communication=self.async_communication, tp_recompute_allgather=self.tp_recompute_allgather, + handle_idx=handle_idx, ) def extra_repr(self) -> str: From 9924608fccfddcf6bb87548610653407bc581ef5 Mon Sep 17 00:00:00 2001 From: Phuc Nguyen <b3f0cus@icloud.com> Date: Mon, 3 Feb 2025 09:29:03 +0000 Subject: [PATCH 04/17] add backward pass overlapping --- src/nanotron/models/llama.py | 21 +++++++++++++-------- src/nanotron/parallel/comm.py | 12 +++++++++--- src/nanotron/trainer.py | 4 ++++ 3 files changed, 26 insertions(+), 11 deletions(-) diff --git a/src/nanotron/models/llama.py b/src/nanotron/models/llama.py index fb112f8e..8dbcc661 100644 --- a/src/nanotron/models/llama.py +++ b/src/nanotron/models/llama.py @@ -764,12 +764,16 @@ def _core_forward( hidden_states1 = self.input_layernorm(hidden_states1) attn_output0 = self.attn( - hidden_states=hidden_states0, sequence_mask=sequence_mask0, handle_idx=f"layer_{self.layer_idx}_batch_0" + hidden_states=hidden_states0, + sequence_mask=sequence_mask0, + handle_idx=f"layer_attn_{self.layer_idx}_batch_0", ) attn_output0_work = attn_output0["work"] attn_output1 = self.attn( - hidden_states=hidden_states1, sequence_mask=sequence_mask1, handle_idx=f"layer_{self.layer_idx}_batch_1" + hidden_states=hidden_states1, + sequence_mask=sequence_mask1, + handle_idx=f"layer_attn_{self.layer_idx}_batch_1", ) attn_output1_work = attn_output1["work"] @@ -785,10 +789,11 @@ def _core_forward( hidden_states0 = hidden_states0 + residual0 residual0 = hidden_states0 hidden_states0 = self.post_attention_layernorm(hidden_states0) - hidden_states0 = WaitComm.apply(hidden_states0, f"layer_{self.layer_idx}_batch_0") + hidden_states0 = WaitComm.apply(hidden_states0, f"layer_mlp_{self.layer_idx}_batch_1") - # mlp_output0 = self.mlp(hidden_states=hidden_states0, handle_idx=f"layer_{self.layer_idx}_batch_0") - mlp_output0 = self.mlp(hidden_states=hidden_states0) + mlp_output0 = self.mlp(hidden_states=hidden_states0, handle_idx=f"layer_mlp_{self.layer_idx}_batch_0") + mlp_output0 = WaitComm.apply(mlp_output0, f"layer_mlp_{self.layer_idx}_batch_1") + # mlp_output0 = self.mlp(hidden_states=hidden_states0) with torch.cuda.stream(comm_stream): attn_output1_work.wait() @@ -798,10 +803,10 @@ def _core_forward( hidden_states1 = hidden_states1 + residual1 residual1 = hidden_states1 hidden_states1 = self.post_attention_layernorm(hidden_states1) - hidden_states1 = WaitComm.apply(hidden_states1, f"layer_{self.layer_idx}_batch_1") + # hidden_states1 = WaitComm.apply(hidden_states1, f"layer_{self.layer_idx}_batch_1") - # mlp_output1 = 
self.mlp(hidden_states=hidden_states1, handle_idx=f"layer_{self.layer_idx}_batch_1") - mlp_output1 = self.mlp(hidden_states=hidden_states1) + mlp_output1 = self.mlp(hidden_states=hidden_states1, handle_idx=f"layer_mlp_{self.layer_idx}_batch_1") + # mlp_output1 = self.mlp(hidden_states=hidden_states1) with torch.cuda.stream(comm_stream): mlp_output0["work"].wait() diff --git a/src/nanotron/parallel/comm.py b/src/nanotron/parallel/comm.py index e26e2814..dd99f20a 100644 --- a/src/nanotron/parallel/comm.py +++ b/src/nanotron/parallel/comm.py @@ -54,6 +54,10 @@ def wait(tensor_id: int): work = AsyncCommBucket._async_op.pop(tensor_id) work.wait() + @staticmethod + def clear_all(): + AsyncCommBucket._async_op.clear() + class WaitComm(torch.autograd.Function): @staticmethod @@ -63,10 +67,12 @@ def forward(ctx, input, wait_handle_idx): @staticmethod def backward(ctx, grad_output): - import pydevd + # import pydevd - pydevd.settrace(suspend=False, trace_only_current_thread=True) - if ctx.wait_handle_idx != "layer_1_batch_1": + # pydevd.settrace(suspend=False, trace_only_current_thread=True) + # if ctx.wait_handle_idx != "layer_1_batch_1": + if ctx.wait_handle_idx != "layer_30_batch_1": handle = AsyncCommBucket.pop(ctx.wait_handle_idx) handle.wait() + # assert 1 == 1 return grad_output, None diff --git a/src/nanotron/trainer.py b/src/nanotron/trainer.py index 94b03c6e..52d5df48 100644 --- a/src/nanotron/trainer.py +++ b/src/nanotron/trainer.py @@ -578,6 +578,10 @@ def training_step( self.post_train_step() + from nanotron.parallel.comm import AsyncCommBucket + + AsyncCommBucket.clear_all() + return outputs, loss_avg def validation_step(self, dataloader: Iterator[Dict[str, Union[torch.Tensor, TensorPointer]]]) -> Iterable[Dict]: From d6bc8da4a5df9f35b3345148639a4775b90de568 Mon Sep 17 00:00:00 2001 From: Phuc Nguyen <b3f0cus@icloud.com> Date: Tue, 4 Feb 2025 14:32:38 +0000 Subject: [PATCH 05/17] fix some ops dont execute in the bwd pass --- examples/config_tiny_llama_domino.yaml | 113 ------------------ src/nanotron/models/llama.py | 49 ++++++-- src/nanotron/parallel/comm.py | 23 +++- .../distributed_differentiable_primitives.py | 84 ++++++++----- .../parallel/tensor_parallel/functional.py | 14 ++- src/nanotron/parallel/tensor_parallel/nn.py | 6 +- 6 files changed, 129 insertions(+), 160 deletions(-) delete mode 100644 examples/config_tiny_llama_domino.yaml diff --git a/examples/config_tiny_llama_domino.yaml b/examples/config_tiny_llama_domino.yaml deleted file mode 100644 index 66e22dbd..00000000 --- a/examples/config_tiny_llama_domino.yaml +++ /dev/null @@ -1,113 +0,0 @@ -checkpoints: - checkpoint_interval: 10 - checkpoints_path: checkpoints - checkpoints_path_is_shared_file_system: false - resume_checkpoint_path: null - save_initial_state: false -data_stages: -- data: - dataset: - dataset_overwrite_cache: false - dataset_processing_num_proc_per_process: 1 - hf_dataset_config_name: null - hf_dataset_or_datasets: stas/openwebtext-10k - hf_dataset_splits: train - text_column_name: text - num_loading_workers: 1 - seed: 42 - name: Stable Training Stage - start_training_step: 1 -- data: - dataset: - dataset_overwrite_cache: false - dataset_processing_num_proc_per_process: 1 - hf_dataset_config_name: null - hf_dataset_or_datasets: stas/openwebtext-10k - hf_dataset_splits: train - text_column_name: text - num_loading_workers: 1 - seed: 42 - name: Annealing Phase - start_training_step: 10 -general: - benchmark_csv_path: null - consumed_train_samples: null - ignore_sanity_checks: true - project: debug - run: 
tiny_llama_%date_%jobid - seed: 42 - step: null -lighteval: null -logging: - iteration_step_info_interval: 1 - log_level: info - log_level_replica: info -model: - ddp_bucket_cap_mb: 25 - dtype: bfloat16 - init_method: - std: 0.025 - make_vocab_size_divisible_by: 1 - model_config: - bos_token_id: 1 - eos_token_id: 2 - hidden_act: silu - hidden_size: 16 - initializer_range: 0.02 - intermediate_size: 64 - is_llama_config: true - max_position_embeddings: 256 - num_attention_heads: 4 - num_hidden_layers: 2 - num_key_value_heads: 4 - pad_token_id: null - pretraining_tp: 1 - rms_norm_eps: 1.0e-05 - rope_scaling: null - tie_word_embeddings: true - use_cache: true - vocab_size: 256 -optimizer: - accumulate_grad_in_fp32: true - clip_grad: 1.0 - learning_rate_scheduler: - learning_rate: 0.0003 - lr_decay_starting_step: null - lr_decay_steps: 13 - lr_decay_style: cosine - lr_warmup_steps: 2 - lr_warmup_style: linear - min_decay_lr: 1.0e-05 - optimizer_factory: - adam_beta1: 0.9 - adam_beta2: 0.95 - adam_eps: 1.0e-08 - name: adamW - torch_adam_is_fused: true - weight_decay: 0.01 - zero_stage: 0 -parallelism: - # dp: 2 - # pp: 2 - dp: 1 - pp: 1 - tp: 2 - expert_parallel_size: 1 - pp_engine: 1f1b - tp_linear_async_communication: false - tp_mode: ALL_REDUCE - domino: - num_input_batches: 2 -profiler: null -tokenizer: - tokenizer_max_length: null - tokenizer_name_or_path: robot-test/dummy-tokenizer-wordlevel - tokenizer_revision: null -tokens: - batch_accumulation_per_replica: 1 - limit_test_batches: 0 - limit_val_batches: 0 - micro_batch_size: 2 - sequence_length: 256 - train_steps: 15 - val_check_interval: -1 diff --git a/src/nanotron/models/llama.py b/src/nanotron/models/llama.py index 8dbcc661..47394240 100644 --- a/src/nanotron/models/llama.py +++ b/src/nanotron/models/llama.py @@ -49,6 +49,11 @@ DOMINO_COMM_STREAM = "domino_comm_stream_{}" +FWD_MLP_HANDLE_IDX = "fwd.layer_mlp_{}_batch_{}" +BWD_MLP_HANDLE_IDX = "bwd.layer_mlp_{}_batch_{}" +FWD_ATTN_HANDLE_IDX = "fwd.layer_attn_{}_batch_{}" +BWD_ATTN_HANDLE_IDX = "bwd.layer_attn_{}_batch_{}" + class RotaryEmbedding(nn.Module): def __init__(self, dim: int, end: int, theta: float = 10000.0): @@ -245,8 +250,8 @@ def __init__( self.split_silu_mul = GLUActivation(config.hidden_act) def forward(self, hidden_states, handle_idx=None): # [seq_length, batch_size, hidden_dim] - merged_states = self.gate_up_proj(hidden_states, handle_idx) - hidden_states, work = self.down_proj(self.split_silu_mul(merged_states)) + merged_states = self.gate_up_proj(hidden_states, async_all_reduce=True, handle_idx=handle_idx) + hidden_states, work = self.down_proj(self.split_silu_mul(merged_states), handle_idx) return {"hidden_states": hidden_states, "work": work} @@ -449,7 +454,7 @@ def forward( ) qkv_states = self.qkv_proj( - hidden_states, handle_idx=handle_idx + hidden_states, async_all_reduce=True, handle_idx=handle_idx ) # [seq_length, batch_size, n_local_q_heads * d_qk + 2 * n_local_kv_heads * d_qk] q_length, batch_size, _ = qkv_states.shape @@ -694,7 +699,7 @@ def forward( attention_output = ( attention_output.contiguous().view(batch_size, q_length, self.n_local_q_heads * self.d_v).transpose(0, 1) ) - output, work = self.o_proj(attention_output) + output, work = self.o_proj(attention_output, handle_idx=handle_idx) return {"hidden_states": output, "work": work, "sequence_mask": sequence_mask} @@ -766,14 +771,26 @@ def _core_forward( attn_output0 = self.attn( hidden_states=hidden_states0, sequence_mask=sequence_mask0, - handle_idx=f"layer_attn_{self.layer_idx}_batch_0", + # 
handle_idx=f"fwd.layer_attn_{self.layer_idx}_batch_0", + handle_idx=FWD_ATTN_HANDLE_IDX.format(self.layer_idx, 0), + ) + attn_output0["hidden_states"] = WaitComm.apply( + attn_output0["hidden_states"], + # f"bwd.layer_attn_{self.layer_idx}_batch_1" + BWD_ATTN_HANDLE_IDX.format(self.layer_idx, 1), ) attn_output0_work = attn_output0["work"] attn_output1 = self.attn( hidden_states=hidden_states1, sequence_mask=sequence_mask1, - handle_idx=f"layer_attn_{self.layer_idx}_batch_1", + # handle_idx=f"fwd.layer_attn_{self.layer_idx}_batch_1", + handle_idx=FWD_ATTN_HANDLE_IDX.format(self.layer_idx, 1), + ) + attn_output1["hidden_states"] = WaitComm.apply( + attn_output1["hidden_states"], + # f"bwd.layer_mlp_{self.layer_idx}_batch_0" + BWD_MLP_HANDLE_IDX.format(self.layer_idx, 0), ) attn_output1_work = attn_output1["work"] @@ -789,10 +806,18 @@ def _core_forward( hidden_states0 = hidden_states0 + residual0 residual0 = hidden_states0 hidden_states0 = self.post_attention_layernorm(hidden_states0) - hidden_states0 = WaitComm.apply(hidden_states0, f"layer_mlp_{self.layer_idx}_batch_1") + # hidden_states0 = WaitComm.apply(hidden_states0, f"bwd.layer_mlp_{self.layer_idx}_batch_0") - mlp_output0 = self.mlp(hidden_states=hidden_states0, handle_idx=f"layer_mlp_{self.layer_idx}_batch_0") - mlp_output0 = WaitComm.apply(mlp_output0, f"layer_mlp_{self.layer_idx}_batch_1") + mlp_output0 = self.mlp( + hidden_states=hidden_states0, + # handle_idx=f"fwd.layer_mlp_{self.layer_idx}_batch_0" + handle_idx=FWD_MLP_HANDLE_IDX.format(self.layer_idx, 0), + ) + mlp_output0["hidden_states"] = WaitComm.apply( + mlp_output0["hidden_states"], + # f"bwd.layer_mlp_{self.layer_idx}_batch_1" + BWD_MLP_HANDLE_IDX.format(self.layer_idx, 1), + ) # mlp_output0 = self.mlp(hidden_states=hidden_states0) with torch.cuda.stream(comm_stream): @@ -805,7 +830,11 @@ def _core_forward( hidden_states1 = self.post_attention_layernorm(hidden_states1) # hidden_states1 = WaitComm.apply(hidden_states1, f"layer_{self.layer_idx}_batch_1") - mlp_output1 = self.mlp(hidden_states=hidden_states1, handle_idx=f"layer_mlp_{self.layer_idx}_batch_1") + mlp_output1 = self.mlp( + hidden_states=hidden_states1, + # handle_idx=f"fwd.layer_mlp_{self.layer_idx}_batch_1" + handle_idx=FWD_MLP_HANDLE_IDX.format(self.layer_idx, 1), + ) # mlp_output1 = self.mlp(hidden_states=hidden_states1) with torch.cuda.stream(comm_stream): diff --git a/src/nanotron/parallel/comm.py b/src/nanotron/parallel/comm.py index dd99f20a..c2b18e05 100644 --- a/src/nanotron/parallel/comm.py +++ b/src/nanotron/parallel/comm.py @@ -47,6 +47,7 @@ def get(tensor_id: int): @staticmethod def pop(tensor_id: int): + assert tensor_id in AsyncCommBucket._async_op, f"tensor_id: {tensor_id}" return AsyncCommBucket._async_op.pop(tensor_id) @staticmethod @@ -59,6 +60,17 @@ def clear_all(): AsyncCommBucket._async_op.clear() +def is_async_comm(x): + import re + + NON_ASYNC_HANDLE_IDX = ["bwd.layer_mlp_{}_batch_1", "bwd.layer_attn_{}_batch_0"] + + patterns = [p.replace("{}", r"\d+") for p in NON_ASYNC_HANDLE_IDX] # Replace {} with regex for numbers + regex = re.compile("^(" + "|".join(patterns) + ")$") # Combine patterns into a single regex + not_async = bool(regex.match(x)) + return not not_async + + class WaitComm(torch.autograd.Function): @staticmethod def forward(ctx, input, wait_handle_idx): @@ -68,11 +80,16 @@ def forward(ctx, input, wait_handle_idx): @staticmethod def backward(ctx, grad_output): # import pydevd - # pydevd.settrace(suspend=False, trace_only_current_thread=True) - # if ctx.wait_handle_idx != 
"layer_1_batch_1": - if ctx.wait_handle_idx != "layer_30_batch_1": + + if "bwd.layer_mlp_1_batch_0" == ctx.wait_handle_idx: + assert 1 == 1 + + # if ctx.wait_handle_idx != "bwd.layer_mlp_1_batch_1": + # if ctx.wait_handle_idx != "layer_30_batch_1": + if is_async_comm(ctx.wait_handle_idx): handle = AsyncCommBucket.pop(ctx.wait_handle_idx) + assert handle is not None handle.wait() # assert 1 == 1 return grad_output, None diff --git a/src/nanotron/parallel/tensor_parallel/distributed_differentiable_primitives.py b/src/nanotron/parallel/tensor_parallel/distributed_differentiable_primitives.py index 38c6bafd..3badfb34 100644 --- a/src/nanotron/parallel/tensor_parallel/distributed_differentiable_primitives.py +++ b/src/nanotron/parallel/tensor_parallel/distributed_differentiable_primitives.py @@ -26,8 +26,9 @@ class DifferentiableIdentity(torch.autograd.Function): """All-reduce gradients in a differentiable fashion""" @staticmethod - def forward(ctx, tensor, group: Optional[ProcessGroup], handle_idx=None): + def forward(ctx, tensor, group: Optional[ProcessGroup], async_all_reduce: bool, handle_idx=None): # assert handle_idx is not None + ctx.async_all_reduce = async_all_reduce ctx.handle_idx = handle_idx ctx.group = group return tensor @@ -39,10 +40,31 @@ def backward(ctx, grad_output): # NOTE: lm_head is TensorParallelColumnLinear, and it doesn't do async # assert ctx.handle_idx is not None group = ctx.group - if ctx.handle_idx is not None: - assert 1 == 1 - return DifferentiableAllReduceSum.apply(grad_output, group, True, ctx.handle_idx), None, None + if ctx.handle_idx is not None and "fwd." in ctx.handle_idx: + handle_idx = ctx.handle_idx.replace("fwd.", "bwd.") + # if "bwd.layer_mlp_1_batch_1" == handle_idx: + # from nanotron.parallel.comm import is_async_comm + # async_all_reduce = is_async_comm(handle_idx) + # else: + # async_all_reduce = ctx.async_all_reduce + from nanotron.parallel.comm import is_async_comm + + async_all_reduce = is_async_comm(handle_idx) + else: + handle_idx = ctx.handle_idx + async_all_reduce = ctx.async_all_reduce + + return DifferentiableAllReduceSum.apply(grad_output, group, async_all_reduce, handle_idx), None, None, None + + +def is_last_batch_of_attn(x): + import re + + pattern = r"layer_attn_\d+_batch_0" + if re.match(pattern, x): + return True + return False class DifferentiableAllReduceSum(torch.autograd.Function): @@ -52,31 +74,33 @@ class DifferentiableAllReduceSum(torch.autograd.Function): def forward( ctx, tensor, group: Optional[ProcessGroup], async_all_reduce: bool, handle_idx: Optional[int] = None ) -> Tuple[torch.Tensor, Optional["dist.Work"]]: - # ctx.mark_non_differentiable(async_all_reduce) ctx.async_all_reduce = async_all_reduce if group.size() == 1: return tensor - orig_id = id(tensor) - handle = dist.all_reduce(tensor, op=dist.ReduceOp.SUM, group=group, async_op=async_all_reduce) - # if async_all_reduce: - # handle.wait() - new_id = id(tensor) - assert 1 == 1 - assert orig_id == new_id - # if async_all_reduce: - # return tensor, handle - # else: - # return tensor, None - if async_all_reduce: - # AsyncCommBucket.add(tensor, handle) - # AsyncCommBucket.add(id(tensor), handle) - # try: - # AsyncCommBucket.add(orig_id if handle_idx is None else handle_idx, handle) - # except Exception as e: - # assert 1 == 1 - AsyncCommBucket.add(orig_id if handle_idx is None else handle_idx, handle) + if handle_idx == "bwd.layer_mlp_1_batch_0": + assert 1 == 1 + + id(tensor) + if async_all_reduce is True: + if isinstance(handle_idx, str): + do_async = 
is_last_batch_of_attn(handle_idx) is False + else: + do_async = async_all_reduce + + handle = dist.all_reduce(tensor, op=dist.ReduceOp.SUM, group=group, async_op=do_async) + if do_async: + # # NOTE: id(tensor) is for the fwd pass, for the bwd pass, we do handle_idx + # if handle_idx is not None and "bwd." in handle_idx: + # AsyncCommBucket.add(orig_id if handle_idx is None else handle_idx, handle) + # else: + # AsyncCommBucket.add(orig_id, handle) + # NOTE: id(tensor) is for the fwd pass, for the bwd pass, we do handle_idx + assert handle_idx is not None + AsyncCommBucket.add(handle_idx, handle) + else: + dist.all_reduce(tensor, op=dist.ReduceOp.SUM, group=group) return tensor @@ -164,12 +188,16 @@ def backward(ctx, grad_output): # ----------------- -def differentiable_identity(tensor, group: Optional[ProcessGroup] = None, handle_idx=None): - return DifferentiableIdentity.apply(tensor, group, handle_idx) +def differentiable_identity( + tensor, group: Optional[ProcessGroup] = None, async_all_reduce: bool = False, handle_idx=None +): + return DifferentiableIdentity.apply(tensor, group, async_all_reduce, handle_idx) -def differentiable_all_reduce_sum(tensor, group: Optional[ProcessGroup] = None, async_all_reduce: bool = False): - return DifferentiableAllReduceSum.apply(tensor, group, async_all_reduce) +def differentiable_all_reduce_sum( + tensor, group: Optional[ProcessGroup] = None, async_all_reduce: bool = False, handle_idx=None +): + return DifferentiableAllReduceSum.apply(tensor, group, async_all_reduce, handle_idx) def differentiable_all_gather(tensor, group: Optional[ProcessGroup] = None): diff --git a/src/nanotron/parallel/tensor_parallel/functional.py b/src/nanotron/parallel/tensor_parallel/functional.py index ff43c98b..a3f1248e 100644 --- a/src/nanotron/parallel/tensor_parallel/functional.py +++ b/src/nanotron/parallel/tensor_parallel/functional.py @@ -436,13 +436,14 @@ def column_linear( tp_mode: TensorParallelLinearMode, async_communication: bool, tp_recompute_allgather: bool = True, + async_all_reduce: bool = False, handle_idx: Optional[int] = None, ): if async_communication: return _ColumnLinearAsyncCommunication.apply(input, weight, bias, group, tp_mode, tp_recompute_allgather) if tp_mode is TensorParallelLinearMode.ALL_REDUCE: - input = differentiable_identity(input, group=group, handle_idx=handle_idx) + input = differentiable_identity(input, group=group, async_all_reduce=async_all_reduce, handle_idx=handle_idx) return F.linear(input, weight, bias) if tp_mode is TensorParallelLinearMode.REDUCE_SCATTER: return _ColumnLinearNoAsyncCommunicationReduceScatterMode.apply( @@ -591,6 +592,7 @@ def row_linear( # TODO(xrsrke): use less confusing names for these arguments async_communication: bool, async_all_reduce: bool, + handle_idx=None, ) -> Tuple[torch.Tensor, Optional[torch.Future]]: if async_communication: return _RowLinearAsyncCommunication.apply(input, weight, bias, group, tp_mode) @@ -599,14 +601,18 @@ def row_linear( if tp_mode is TensorParallelLinearMode.ALL_REDUCE: # out, work = differentiable_all_reduce_sum(out, group=group, async_all_reduce=async_all_reduce) - orig_out_id = id(out) + id(out) # NOTE: why the id(out) doesn't match the id(out) before the all_reduce? 
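# A likely answer to the NOTE above, stated as an assumption rather than a fix:
# torch.autograd.Function.apply re-wraps its outputs so it can attach the new
# grad_fn, so even when forward() returns the very tensor it received, the caller
# gets back a different Python object that merely shares storage with the input.
# A quick illustrative check (Passthrough is a made-up name):
#
#     class Passthrough(torch.autograd.Function):
#         @staticmethod
#         def forward(ctx, x):
#             return x
#         @staticmethod
#         def backward(ctx, g):
#             return g
#
#     x = torch.ones(2, requires_grad=True)
#     y = Passthrough.apply(x)
#     print(y.data_ptr() == x.data_ptr(), id(y) == id(x))  # True False
#
# If that holds, keying AsyncCommBucket on id() across the apply() boundary is
# fragile, and the explicit handle_idx strings introduced in this patch are the
# more robust key.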
- out = differentiable_all_reduce_sum(out, group=group, async_all_reduce=async_all_reduce) + out = differentiable_all_reduce_sum(out, group=group, async_all_reduce=async_all_reduce, handle_idx=handle_idx) if async_all_reduce: from nanotron.parallel.comm import AsyncCommBucket # work = AsyncCommBucket.get(orig_out_id) - work = AsyncCommBucket.pop(orig_out_id) + # work = AsyncCommBucket.pop(orig_out_id) + if handle_idx == "fwd.layer_mlp_1_batch_0": + assert 1 == 1 + + work = AsyncCommBucket.pop(handle_idx) assert 1 == 1 elif tp_mode is TensorParallelLinearMode.REDUCE_SCATTER: assert async_all_reduce is False, "Async communication is not supported for REDUCE_SCATTER mode." diff --git a/src/nanotron/parallel/tensor_parallel/nn.py b/src/nanotron/parallel/tensor_parallel/nn.py index f4ceff63..847454fd 100644 --- a/src/nanotron/parallel/tensor_parallel/nn.py +++ b/src/nanotron/parallel/tensor_parallel/nn.py @@ -87,7 +87,7 @@ def __init__( split_config=split_config, ) - def forward(self, x: torch.Tensor, handle_idx=None) -> torch.Tensor: + def forward(self, x: torch.Tensor, async_all_reduce=None, handle_idx=None) -> torch.Tensor: return column_linear( input=x, weight=self.weight, @@ -96,6 +96,7 @@ def forward(self, x: torch.Tensor, handle_idx=None) -> torch.Tensor: tp_mode=self.mode, async_communication=self.async_communication, tp_recompute_allgather=self.tp_recompute_allgather, + async_all_reduce=async_all_reduce, handle_idx=handle_idx, ) @@ -163,7 +164,7 @@ def _mark_all_parameters_in_module_as_sharded(self, split_config: SplitConfig): ) setattr(self, name, new_param) - def forward(self, x: torch.Tensor) -> torch.Tensor: + def forward(self, x: torch.Tensor, handle_idx=None) -> torch.Tensor: return row_linear( input=x, weight=self.weight, @@ -172,6 +173,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: tp_mode=self.mode, async_communication=self.async_communication, async_all_reduce=self.async_all_reduce, + handle_idx=handle_idx, ) def extra_repr(self) -> str: From 93b2f106bb91e4f9328546d77dfc169348eb12b9 Mon Sep 17 00:00:00 2001 From: Phuc Nguyen <b3f0cus@icloud.com> Date: Wed, 5 Feb 2025 10:49:03 +0000 Subject: [PATCH 06/17] fix can't find an ops in fwd --- src/nanotron/parallel/comm.py | 8 +++++++- .../distributed_differentiable_primitives.py | 11 +++++++---- src/nanotron/parallel/tensor_parallel/functional.py | 6 +++++- 3 files changed, 19 insertions(+), 6 deletions(-) diff --git a/src/nanotron/parallel/comm.py b/src/nanotron/parallel/comm.py index c2b18e05..459ea23d 100644 --- a/src/nanotron/parallel/comm.py +++ b/src/nanotron/parallel/comm.py @@ -63,7 +63,13 @@ def clear_all(): def is_async_comm(x): import re - NON_ASYNC_HANDLE_IDX = ["bwd.layer_mlp_{}_batch_1", "bwd.layer_attn_{}_batch_0"] + NON_ASYNC_HANDLE_IDX = [ + # "fwd.layer_attn_{}_batch_0", + # "fwd.layer_mlp_{}_batch_0", + # "fwd.layer_mlp_{}_batch_1", + "bwd.layer_mlp_{}_batch_1", + "bwd.layer_attn_{}_batch_0", + ] patterns = [p.replace("{}", r"\d+") for p in NON_ASYNC_HANDLE_IDX] # Replace {} with regex for numbers regex = re.compile("^(" + "|".join(patterns) + ")$") # Combine patterns into a single regex diff --git a/src/nanotron/parallel/tensor_parallel/distributed_differentiable_primitives.py b/src/nanotron/parallel/tensor_parallel/distributed_differentiable_primitives.py index 3badfb34..9d65878b 100644 --- a/src/nanotron/parallel/tensor_parallel/distributed_differentiable_primitives.py +++ b/src/nanotron/parallel/tensor_parallel/distributed_differentiable_primitives.py @@ -84,10 +84,13 @@ def forward( id(tensor) if 
async_all_reduce is True: - if isinstance(handle_idx, str): - do_async = is_last_batch_of_attn(handle_idx) is False - else: - do_async = async_all_reduce + # if isinstance(handle_idx, str): + # do_async = is_last_batch_of_attn(handle_idx) is False + # else: + # do_async = async_all_reduce + from nanotron.parallel.comm import is_async_comm + + do_async = is_async_comm(handle_idx) handle = dist.all_reduce(tensor, op=dist.ReduceOp.SUM, group=group, async_op=do_async) if do_async: diff --git a/src/nanotron/parallel/tensor_parallel/functional.py b/src/nanotron/parallel/tensor_parallel/functional.py index a3f1248e..f0ca3a0d 100644 --- a/src/nanotron/parallel/tensor_parallel/functional.py +++ b/src/nanotron/parallel/tensor_parallel/functional.py @@ -603,13 +603,17 @@ def row_linear( # out, work = differentiable_all_reduce_sum(out, group=group, async_all_reduce=async_all_reduce) id(out) # NOTE: why the id(out) doesn't match the id(out) before the all_reduce? + if handle_idx == "fwd.layer_attn_0_batch_0": + assert 1 == 1 + out = differentiable_all_reduce_sum(out, group=group, async_all_reduce=async_all_reduce, handle_idx=handle_idx) if async_all_reduce: from nanotron.parallel.comm import AsyncCommBucket # work = AsyncCommBucket.get(orig_out_id) # work = AsyncCommBucket.pop(orig_out_id) - if handle_idx == "fwd.layer_mlp_1_batch_0": + # if handle_idx == "fwd.layer_mlp_1_batch_0": + if handle_idx == "fwd.layer_attn_0_batch_0": assert 1 == 1 work = AsyncCommBucket.pop(handle_idx) From 31db05dafaf3190c1e3c33b4eb0b87cb9e5f0f04 Mon Sep 17 00:00:00 2001 From: Phuc Nguyen <b3f0cus@icloud.com> Date: Wed, 5 Feb 2025 16:49:44 +0000 Subject: [PATCH 07/17] partially overlapping bwd pass --- src/nanotron/models/llama.py | 166 ++++++++++++------ src/nanotron/parallel/comm.py | 5 +- .../distributed_differentiable_primitives.py | 6 + src/nanotron/trainer.py | 3 +- 4 files changed, 123 insertions(+), 57 deletions(-) diff --git a/src/nanotron/models/llama.py b/src/nanotron/models/llama.py index 47394240..acbece96 100644 --- a/src/nanotron/models/llama.py +++ b/src/nanotron/models/llama.py @@ -50,9 +50,9 @@ DOMINO_COMM_STREAM = "domino_comm_stream_{}" FWD_MLP_HANDLE_IDX = "fwd.layer_mlp_{}_batch_{}" -BWD_MLP_HANDLE_IDX = "bwd.layer_mlp_{}_batch_{}" FWD_ATTN_HANDLE_IDX = "fwd.layer_attn_{}_batch_{}" BWD_ATTN_HANDLE_IDX = "bwd.layer_attn_{}_batch_{}" +BWD_MLP_HANDLE_IDX = "bwd.layer_mlp_{}_batch_{}" class RotaryEmbedding(nn.Module): @@ -741,116 +741,176 @@ def __init__( self.layer_idx = layer_idx + # def _core_forward( + # self, + # hidden_states: Union[torch.Tensor, TensorPointer], + # sequence_mask: Union[torch.Tensor, TensorPointer], + # ) -> List[Union[torch.Tensor, TensorPointer]]: + # from nanotron import constants + + # num_input_batches = self.parallel_config.domino.num_input_batches + # orig_sequence_mask = sequence_mask + + # assert num_input_batches == 2 + # hidden_states = torch.chunk(hidden_states, chunks=num_input_batches, dim=1) + # sequence_mask = torch.chunk(sequence_mask, chunks=num_input_batches, dim=0) + + # hidden_states0, hidden_states1 = hidden_states + # sequence_mask0, sequence_mask1 = sequence_mask + + # residual0 = hidden_states0 + # residual1 = hidden_states1 + + # hidden_states0 = self.input_layernorm(hidden_states0) + # hidden_states1 = self.input_layernorm(hidden_states1) + + # attn_output0 = self.attn( + # hidden_states=hidden_states0, + # sequence_mask=sequence_mask0, + # handle_idx=FWD_ATTN_HANDLE_IDX.format(self.layer_idx, 0), + # ) + # # attn_output0["hidden_states"] = WaitComm.apply( 
+ # # attn_output0["hidden_states"], + # # BWD_ATTN_HANDLE_IDX.format(self.layer_idx, 1), + # # ) + + # attn_output1 = self.attn( + # hidden_states=hidden_states1, + # sequence_mask=sequence_mask1, + # handle_idx=FWD_ATTN_HANDLE_IDX.format(self.layer_idx, 1), + # ) + # # attn_output1["hidden_states"] = WaitComm.apply( + # # attn_output1["hidden_states"], + # # BWD_MLP_HANDLE_IDX.format(self.layer_idx, 0), + # # ) + + # comm_stream = constants.CUDA_STREAMS[torch.cuda.current_device()] + # with torch.cuda.stream(comm_stream): + # attn_output0["work"].wait() + + # hidden_states0 = attn_output0["hidden_states"] + residual0 + # residual0 = hidden_states0 + # hidden_states0 = self.post_attention_layernorm(hidden_states0) + # hidden_states0 = WaitComm.apply( + # hidden_states0, + # BWD_MLP_HANDLE_IDX.format(self.layer_idx, 1), + # ) # new + + # mlp_output0 = self.mlp( + # hidden_states=hidden_states0, + # handle_idx=FWD_MLP_HANDLE_IDX.format(self.layer_idx, 0), + # ) + # # mlp_output0["hidden_states"] = WaitComm.apply( + # # mlp_output0["hidden_states"], + # # BWD_MLP_HANDLE_IDX.format(self.layer_idx, 1), + # # ) + + # with torch.cuda.stream(comm_stream): + # attn_output1["work"].wait() + + # hidden_states1 = attn_output1["hidden_states"] + residual1 + # residual1 = hidden_states1 + # hidden_states1 = self.post_attention_layernorm(hidden_states1) + # hidden_states1 = WaitComm.apply( + # hidden_states1, + # BWD_MLP_HANDLE_IDX.format(self.layer_idx, 0), + # ) + + # mlp_output1 = self.mlp( + # hidden_states=hidden_states1, + # handle_idx=FWD_MLP_HANDLE_IDX.format(self.layer_idx, 1), + # ) + + # with torch.cuda.stream(comm_stream): + # mlp_output0["work"].wait() + # mlp_output1["work"].wait() + + # hidden_states0 = mlp_output0["hidden_states"] + residual0 + # hidden_states1 = mlp_output1["hidden_states"] + residual1 + + # hidden_states = torch.cat([hidden_states0, hidden_states1], dim=1) + # return hidden_states, orig_sequence_mask + def _core_forward( self, hidden_states: Union[torch.Tensor, TensorPointer], sequence_mask: Union[torch.Tensor, TensorPointer], ) -> List[Union[torch.Tensor, TensorPointer]]: + from nanotron import constants num_input_batches = self.parallel_config.domino.num_input_batches + orig_sequence_mask = sequence_mask + assert num_input_batches == 2 hidden_states = torch.chunk(hidden_states, chunks=num_input_batches, dim=1) - orig_sequence_mask = sequence_mask sequence_mask = torch.chunk(sequence_mask, chunks=num_input_batches, dim=0) hidden_states0, hidden_states1 = hidden_states sequence_mask0, sequence_mask1 = sequence_mask - # # Combine the chunks into a list of dictionaries - # hidden_encoder_states_list = [ - # {"hidden_states": hidden_encoder_states["hidden_states"][i], "sequence_mask": hidden_encoder_states["sequence_mask"][i]} - # for i in range(num_input_batches) - # ] - residual0 = hidden_states0 residual1 = hidden_states1 hidden_states0 = self.input_layernorm(hidden_states0) hidden_states1 = self.input_layernorm(hidden_states1) + hidden_states0 = WaitComm.apply( + hidden_states0, + BWD_ATTN_HANDLE_IDX.format(self.layer_idx, 1), + ) + hidden_states1 = WaitComm.apply( + hidden_states1, + BWD_MLP_HANDLE_IDX.format(self.layer_idx, 0), + ) attn_output0 = self.attn( hidden_states=hidden_states0, sequence_mask=sequence_mask0, - # handle_idx=f"fwd.layer_attn_{self.layer_idx}_batch_0", handle_idx=FWD_ATTN_HANDLE_IDX.format(self.layer_idx, 0), ) - attn_output0["hidden_states"] = WaitComm.apply( - attn_output0["hidden_states"], - # f"bwd.layer_attn_{self.layer_idx}_batch_1" - 
BWD_ATTN_HANDLE_IDX.format(self.layer_idx, 1), - ) - attn_output0_work = attn_output0["work"] - attn_output1 = self.attn( hidden_states=hidden_states1, sequence_mask=sequence_mask1, - # handle_idx=f"fwd.layer_attn_{self.layer_idx}_batch_1", handle_idx=FWD_ATTN_HANDLE_IDX.format(self.layer_idx, 1), ) - attn_output1["hidden_states"] = WaitComm.apply( - attn_output1["hidden_states"], - # f"bwd.layer_mlp_{self.layer_idx}_batch_0" - BWD_MLP_HANDLE_IDX.format(self.layer_idx, 0), - ) - attn_output1_work = attn_output1["work"] - - from nanotron import constants comm_stream = constants.CUDA_STREAMS[torch.cuda.current_device()] - # comm_stream = CudaStreamManager.get(DOMINO_COMM_STREAM.format(torch.cuda.current_device())) with torch.cuda.stream(comm_stream): - attn_output0_work.wait() - # attn_output0_work.wait() + attn_output0["work"].wait() - hidden_states0 = attn_output0["hidden_states"] - hidden_states0 = hidden_states0 + residual0 + hidden_states0 = attn_output0["hidden_states"] + residual0 residual0 = hidden_states0 hidden_states0 = self.post_attention_layernorm(hidden_states0) - # hidden_states0 = WaitComm.apply(hidden_states0, f"bwd.layer_mlp_{self.layer_idx}_batch_0") + hidden_states0 = WaitComm.apply( + hidden_states0, + BWD_MLP_HANDLE_IDX.format(self.layer_idx, 1), + ) # new mlp_output0 = self.mlp( hidden_states=hidden_states0, - # handle_idx=f"fwd.layer_mlp_{self.layer_idx}_batch_0" handle_idx=FWD_MLP_HANDLE_IDX.format(self.layer_idx, 0), ) - mlp_output0["hidden_states"] = WaitComm.apply( - mlp_output0["hidden_states"], - # f"bwd.layer_mlp_{self.layer_idx}_batch_1" - BWD_MLP_HANDLE_IDX.format(self.layer_idx, 1), - ) - # mlp_output0 = self.mlp(hidden_states=hidden_states0) with torch.cuda.stream(comm_stream): - attn_output1_work.wait() - # attn_output1_work.wait() + attn_output1["work"].wait() - hidden_states1 = attn_output1["hidden_states"] - hidden_states1 = hidden_states1 + residual1 + hidden_states1 = attn_output1["hidden_states"] + residual1 residual1 = hidden_states1 hidden_states1 = self.post_attention_layernorm(hidden_states1) - # hidden_states1 = WaitComm.apply(hidden_states1, f"layer_{self.layer_idx}_batch_1") mlp_output1 = self.mlp( hidden_states=hidden_states1, - # handle_idx=f"fwd.layer_mlp_{self.layer_idx}_batch_1" handle_idx=FWD_MLP_HANDLE_IDX.format(self.layer_idx, 1), ) - # mlp_output1 = self.mlp(hidden_states=hidden_states1) with torch.cuda.stream(comm_stream): mlp_output0["work"].wait() mlp_output1["work"].wait() - # mlp_output0["work"].wait() - # mlp_output1["work"].wait() - - hidden_states0 = mlp_output0["hidden_states"] - hidden_states1 = mlp_output1["hidden_states"] - - hidden_states0 = hidden_states0 + residual0 - hidden_states1 = hidden_states1 + residual1 + hidden_states0 = mlp_output0["hidden_states"] + residual0 + hidden_states1 = mlp_output1["hidden_states"] + residual1 hidden_states = torch.cat([hidden_states0, hidden_states1], dim=1) + assert 1 == 1 return hidden_states, orig_sequence_mask def _checkpointed_forward( diff --git a/src/nanotron/parallel/comm.py b/src/nanotron/parallel/comm.py index 459ea23d..b00f6e9e 100644 --- a/src/nanotron/parallel/comm.py +++ b/src/nanotron/parallel/comm.py @@ -91,11 +91,10 @@ def backward(ctx, grad_output): if "bwd.layer_mlp_1_batch_0" == ctx.wait_handle_idx: assert 1 == 1 - # if ctx.wait_handle_idx != "bwd.layer_mlp_1_batch_1": - # if ctx.wait_handle_idx != "layer_30_batch_1": if is_async_comm(ctx.wait_handle_idx): handle = AsyncCommBucket.pop(ctx.wait_handle_idx) assert handle is not None handle.wait() - # assert 1 == 1 
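Stripped of the debugging asserts, WaitComm is the piece that makes the backward overlap safe: it is an identity in the forward pass, and in the backward pass it blocks on the async gradient all-reduce that was registered under its tag before letting the gradient flow further. In the decoder layer it is applied to an activation of one micro-batch with the tag of the other micro-batch's pending backward all-reduce, which is exactly the pairing the BWD_*_HANDLE_IDX strings encode. A reduced sketch, assuming a plain dict plays the role of the bucket; the real class additionally re-synchronizes the communication stream with the default stream:

import torch
import torch.distributed as dist
from typing import Dict

class WaitCommSketch(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x: torch.Tensor, wait_key: str, bucket: Dict[str, "dist.Work"]):
        ctx.wait_key = wait_key
        ctx.bucket = bucket
        return x  # no-op in forward

    @staticmethod
    def backward(ctx, grad_output):
        handle = ctx.bucket.pop(ctx.wait_key, None)
        if handle is not None:   # the matching bwd all-reduce was launched async
            handle.wait()        # fence: its result is needed past this point
        return grad_output, None, None  # no grads for wait_key / bucket
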
+ # assert handle.is_completed() is True + return grad_output, None diff --git a/src/nanotron/parallel/tensor_parallel/distributed_differentiable_primitives.py b/src/nanotron/parallel/tensor_parallel/distributed_differentiable_primitives.py index 9d65878b..c4f69c05 100644 --- a/src/nanotron/parallel/tensor_parallel/distributed_differentiable_primitives.py +++ b/src/nanotron/parallel/tensor_parallel/distributed_differentiable_primitives.py @@ -55,6 +55,9 @@ def backward(ctx, grad_output): handle_idx = ctx.handle_idx async_all_reduce = ctx.async_all_reduce + if handle_idx is not None and "bwd." in handle_idx and async_all_reduce is True: + assert 1 == 1 + return DifferentiableAllReduceSum.apply(grad_output, group, async_all_reduce, handle_idx), None, None, None @@ -94,6 +97,9 @@ def forward( handle = dist.all_reduce(tensor, op=dist.ReduceOp.SUM, group=group, async_op=do_async) if do_async: + if "bwd" in handle_idx: + assert 1 == 1 + # # NOTE: id(tensor) is for the fwd pass, for the bwd pass, we do handle_idx # if handle_idx is not None and "bwd." in handle_idx: # AsyncCommBucket.add(orig_id if handle_idx is None else handle_idx, handle) diff --git a/src/nanotron/trainer.py b/src/nanotron/trainer.py index 52d5df48..96af6b12 100644 --- a/src/nanotron/trainer.py +++ b/src/nanotron/trainer.py @@ -580,7 +580,8 @@ def training_step( from nanotron.parallel.comm import AsyncCommBucket - AsyncCommBucket.clear_all() + assert len(AsyncCommBucket._async_op) == 0, f"AsyncCommBucket._async_op: {AsyncCommBucket._async_op}" + # AsyncCommBucket.clear_all() return outputs, loss_avg From 23f210815e06147cf95fff3fede4641cc1c43101 Mon Sep 17 00:00:00 2001 From: Phuc Nguyen <b3f0cus@icloud.com> Date: Mon, 10 Feb 2025 16:33:23 +0000 Subject: [PATCH 08/17] fix stream not sync --- src/nanotron/constants.py | 4 + src/nanotron/models/llama.py | 46 ++++---- src/nanotron/parallel/comm.py | 15 ++- src/nanotron/parallel/dependency.py | 102 ++++++++++++++++++ .../distributed_differentiable_primitives.py | 4 + src/nanotron/parallel/tensor_parallel/nn.py | 8 +- src/nanotron/trainer.py | 17 ++- 7 files changed, 172 insertions(+), 24 deletions(-) create mode 100644 src/nanotron/parallel/dependency.py diff --git a/src/nanotron/constants.py b/src/nanotron/constants.py index 78fd0bb9..3fe440a8 100644 --- a/src/nanotron/constants.py +++ b/src/nanotron/constants.py @@ -13,3 +13,7 @@ CUDA_STREAMS = {} + +CLOCK = 0 +_AUTOGRAD_RUNS = [] +_NOT_BWD_ASYNC_OPS = [] diff --git a/src/nanotron/models/llama.py b/src/nanotron/models/llama.py index acbece96..72ebf478 100644 --- a/src/nanotron/models/llama.py +++ b/src/nanotron/models/llama.py @@ -245,13 +245,15 @@ def __init__( mode=tp_mode, bias=False, async_communication=tp_linear_async_communication and tp_mode is TensorParallelLinearMode.REDUCE_SCATTER, - async_all_reduce=parallel_config.domino.num_input_batches > 1, + # async_all_reduce=parallel_config.domino.num_input_batches > 1, ) self.split_silu_mul = GLUActivation(config.hidden_act) def forward(self, hidden_states, handle_idx=None): # [seq_length, batch_size, hidden_dim] merged_states = self.gate_up_proj(hidden_states, async_all_reduce=True, handle_idx=handle_idx) - hidden_states, work = self.down_proj(self.split_silu_mul(merged_states), handle_idx) + hidden_states, work = self.down_proj( + self.split_silu_mul(merged_states), async_all_reduce=True, handle_idx=handle_idx + ) return {"hidden_states": hidden_states, "work": work} @@ -428,7 +430,7 @@ def __init__( mode=tp_mode, bias=False, async_communication=tp_linear_async_communication, 
- async_all_reduce=async_all_reduce, + # async_all_reduce=async_all_reduce, ) self.attention = CoreAttention( @@ -699,7 +701,7 @@ def forward( attention_output = ( attention_output.contiguous().view(batch_size, q_length, self.n_local_q_heads * self.d_v).transpose(0, 1) ) - output, work = self.o_proj(attention_output, handle_idx=handle_idx) + output, work = self.o_proj(attention_output, async_all_reduce=True, handle_idx=handle_idx) return {"hidden_states": output, "work": work, "sequence_mask": sequence_mask} @@ -876,6 +878,7 @@ def _core_forward( comm_stream = constants.CUDA_STREAMS[torch.cuda.current_device()] with torch.cuda.stream(comm_stream): attn_output0["work"].wait() + attn_output0["work"].is_completed() hidden_states0 = attn_output0["hidden_states"] + residual0 residual0 = hidden_states0 @@ -890,8 +893,16 @@ def _core_forward( handle_idx=FWD_MLP_HANDLE_IDX.format(self.layer_idx, 0), ) + # attn_output1["hidden_states"], mlp_output0["hidden_states"] = depend( + # run_after=attn_output1["hidden_states"], + # run_before=mlp_output0["hidden_states"] + # ) + with torch.cuda.stream(comm_stream): attn_output1["work"].wait() + attn_output1["work"].is_completed() + + torch.cuda.current_stream().wait_stream(comm_stream) hidden_states1 = attn_output1["hidden_states"] + residual1 residual1 = hidden_states1 @@ -906,11 +917,24 @@ def _core_forward( mlp_output0["work"].wait() mlp_output1["work"].wait() + mlp_output0["work"].is_completed() + mlp_output1["work"].is_completed() + + torch.cuda.current_stream().wait_stream(comm_stream) + hidden_states0 = mlp_output0["hidden_states"] + residual0 hidden_states1 = mlp_output1["hidden_states"] + residual1 + # hidden_states0, hidden_states1 = depend(run_after=hidden_states0, run_before=hidden_states1) + hidden_states = torch.cat([hidden_states0, hidden_states1], dim=1) assert 1 == 1 + + # assert attn_output0["work"].is_completed() + # assert attn_output1["work"].is_completed() + # assert mlp_output0["work"].is_completed() + # assert mlp_output1["work"].is_completed() + return hidden_states, orig_sequence_mask def _checkpointed_forward( @@ -1080,23 +1104,9 @@ def forward_with_hidden_states( "sequence_mask": input_mask, } - # assert 1 == 1 - # num_input_batches = self.parallel_config.domino.num_input_batches - # hidden_encoder_states["hidden_states"] = torch.chunk(hidden_encoder_states["hidden_states"], chunks=num_input_batches, dim=1) - # hidden_encoder_states["sequence_mask"] = torch.chunk(hidden_encoder_states["sequence_mask"], chunks=num_input_batches, dim=0) - - # # Combine the chunks into a list of dictionaries - # hidden_encoder_states_list = [ - # {"hidden_states": hidden_encoder_states["hidden_states"][i], "sequence_mask": hidden_encoder_states["sequence_mask"][i]} - # for i in range(num_input_batches) - # ] - for encoder_block in self.decoder: hidden_encoder_states = encoder_block(**hidden_encoder_states) - # for hidden_encoder_states in hidden_encoder_states_list: - # hidden_encoder_states = encoder_block(**hidden_encoder_states) - hidden_states = self.final_layer_norm(input=hidden_encoder_states["hidden_states"])["hidden_states"] sharded_logits = self.lm_head(x=hidden_states)["logits"] diff --git a/src/nanotron/parallel/comm.py b/src/nanotron/parallel/comm.py index b00f6e9e..789416c3 100644 --- a/src/nanotron/parallel/comm.py +++ b/src/nanotron/parallel/comm.py @@ -33,6 +33,7 @@ class AsyncCommBucket: """ _async_op: Dict[int, "dist.Work"] = {} + _copy_async_op: Dict[int, "dist.Work"] = {} @staticmethod def add(tensor_id: int, work: "dist.Work"): 
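The second dictionary added here is bookkeeping for a leak check rather than for the overlap itself: _async_op is drained by pop() as handles get awaited, while _copy_async_op keeps every handle issued during the step so the trainer can verify, after the optimizer step, that no all-reduce is still in flight and none was left un-popped. A condensed sketch of that end-of-step check, mirroring the trainer hunk further below; assumed to run once per training step:

def assert_no_pending_comms(bucket) -> None:
    # handles that were issued this step but never finished
    unfinished = [(k, w) for k, w in bucket._copy_async_op.items() if not w.is_completed()]
    assert len(unfinished) == 0, f"async ops still in flight: {[k for k, _ in unfinished]}"
    # handles that finished but were never consumed by a wait()/pop()
    assert len(bucket._async_op) == 0, f"un-popped async ops: {list(bucket._async_op)}"
    bucket.clear_all()  # reset both dicts for the next step
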
@@ -40,6 +41,7 @@ def add(tensor_id: int, work: "dist.Work"): tensor_id not in AsyncCommBucket._async_op ), f"tensor_id: {tensor_id}, keys: {AsyncCommBucket._async_op.keys()}" AsyncCommBucket._async_op[tensor_id] = work + AsyncCommBucket._copy_async_op[tensor_id] = work @staticmethod def get(tensor_id: int): @@ -58,6 +60,7 @@ def wait(tensor_id: int): @staticmethod def clear_all(): AsyncCommBucket._async_op.clear() + AsyncCommBucket._copy_async_op.clear() def is_async_comm(x): @@ -92,9 +95,19 @@ def backward(ctx, grad_output): assert 1 == 1 if is_async_comm(ctx.wait_handle_idx): + from nanotron.constants import _AUTOGRAD_RUNS + + _AUTOGRAD_RUNS.append(f"wait_{ctx.wait_handle_idx}") handle = AsyncCommBucket.pop(ctx.wait_handle_idx) assert handle is not None handle.wait() - # assert handle.is_completed() is True + # assert handle.is_completed() is True, f"ctx.wait_handle_idx: {ctx.wait_handle_idx}" + else: + + from nanotron import constants + + # if dist.get_rank() == 0: + # constants._NOT_BWD_ASYNC_OPS.append(ctx.wait_handle_idx) + constants._NOT_BWD_ASYNC_OPS.append(ctx.wait_handle_idx) return grad_output, None diff --git a/src/nanotron/parallel/dependency.py b/src/nanotron/parallel/dependency.py new file mode 100644 index 00000000..6a633d8a --- /dev/null +++ b/src/nanotron/parallel/dependency.py @@ -0,0 +1,102 @@ +from typing import Dict, Tuple + +import torch +from torch import Tensor + +_phonies: Dict[Tuple[torch.device, bool], Tensor] = {} + + +def get_phony(device: torch.device, *, requires_grad: bool) -> Tensor: + """Gets a phony. Phony is tensor without space. It is useful to make + arbitrary dependency in a autograd graph because it doesn't require any + gradient accumulation. + + .. note:: + + Phonies for each device are cached. If an autograd function gets a phony + internally, the phony must be detached to be returned. Otherwise, the + autograd engine will mutate the cached phony in-place:: + + class Phonify(torch.autograd.Function): + @staticmethod + def forward(ctx, input): + phony = get_phony(input.device, requires_grad=False) + return phony.detach() # detach() is necessary. 
+ + """ + key = (device, requires_grad) + + try: + phony = _phonies[key] + except KeyError: + with torch.cuda.stream(torch.cuda.default_stream(device)): + phony = torch.empty(0, device=device, requires_grad=requires_grad) + + _phonies[key] = phony + + return phony + + +def fork(input: Tensor) -> Tuple[Tensor, Tensor]: + """Branches out from an autograd lane of the given tensor.""" + if torch.is_grad_enabled() and input.requires_grad: + input, phony = Fork.apply(input) + else: + phony = get_phony(input.device, requires_grad=False) + + return input, phony + + +class Fork(torch.autograd.Function): + @staticmethod + def forward(ctx: "Fork", input: Tensor) -> Tuple[Tensor, Tensor]: # type: ignore + phony = get_phony(input.device, requires_grad=False) + return input, phony.detach() + + @staticmethod + def backward(ctx: "Fork", grad_input: Tensor, grad_grad: Tensor) -> Tensor: # type: ignore + # import pydevd + # pydevd.settrace(suspend=False, trace_only_current_thread=True) + return grad_input + + +def join(input: Tensor, phony: Tensor) -> Tensor: + """Merges two autograd lanes.""" + if torch.is_grad_enabled() and (input.requires_grad or phony.requires_grad): + input = Join.apply(input, phony) + + return input + + +class Join(torch.autograd.Function): + @staticmethod + def forward(ctx: "Join", input: Tensor, phony: Tensor) -> Tensor: # type: ignore + return input + + @staticmethod + def backward(ctx: "Join", grad_input: Tensor) -> Tuple[Tensor, None]: # type: ignore + # import pydevd + # pydevd.settrace(suspend=False, trace_only_current_thread=True) + return grad_input, None + + +# def depend(fork_from, join_to) -> None: +# # Ensure that batches[i-1] is executed after batches[i] in +# # # backpropagation by an explicit dependency. +# # if i != 0: +# # depend(batches[i-1], batches[i]) +# # depend(run_after, run_before) +# fork_from, phony = fork(fork_from) +# join_to = join(join_to, phony) +# return fork_from, join_to + + +def depend(run_after, run_before) -> None: + # Ensure that batches[i-1] is executed after batches[i] in + # # backpropagation by an explicit dependency. + # if i != 0: + # depend(batches[i-1], batches[i]) + # depend(run_after, run_before) + run_after, phony = fork(run_after) + run_before = join(run_before, phony) + return run_after, run_before diff --git a/src/nanotron/parallel/tensor_parallel/distributed_differentiable_primitives.py b/src/nanotron/parallel/tensor_parallel/distributed_differentiable_primitives.py index c4f69c05..58275368 100644 --- a/src/nanotron/parallel/tensor_parallel/distributed_differentiable_primitives.py +++ b/src/nanotron/parallel/tensor_parallel/distributed_differentiable_primitives.py @@ -58,6 +58,10 @@ def backward(ctx, grad_output): if handle_idx is not None and "bwd." 
in handle_idx and async_all_reduce is True: assert 1 == 1 + from nanotron.constants import _AUTOGRAD_RUNS + + _AUTOGRAD_RUNS.append(handle_idx) + return DifferentiableAllReduceSum.apply(grad_output, group, async_all_reduce, handle_idx), None, None, None diff --git a/src/nanotron/parallel/tensor_parallel/nn.py b/src/nanotron/parallel/tensor_parallel/nn.py index 847454fd..4fea1838 100644 --- a/src/nanotron/parallel/tensor_parallel/nn.py +++ b/src/nanotron/parallel/tensor_parallel/nn.py @@ -115,7 +115,7 @@ def __init__( device=None, dtype=None, async_communication: bool = False, - async_all_reduce: bool = False, + # async_all_reduce: bool = False, contiguous_chunks: Optional[Tuple[int, ...]] = None, ): self.pg = pg @@ -138,7 +138,7 @@ def __init__( ) self.mode = mode self.async_communication = async_communication - self.async_all_reduce = async_all_reduce + # self.async_all_reduce = async_all_reduce if self.mode is TensorParallelLinearMode.ALL_REDUCE and self.async_communication: raise ValueError("async_communication is not supported for ALL_REDUCE mode") @@ -164,7 +164,7 @@ def _mark_all_parameters_in_module_as_sharded(self, split_config: SplitConfig): ) setattr(self, name, new_param) - def forward(self, x: torch.Tensor, handle_idx=None) -> torch.Tensor: + def forward(self, x: torch.Tensor, async_all_reduce, handle_idx=None) -> torch.Tensor: return row_linear( input=x, weight=self.weight, @@ -172,7 +172,7 @@ def forward(self, x: torch.Tensor, handle_idx=None) -> torch.Tensor: group=self.pg, tp_mode=self.mode, async_communication=self.async_communication, - async_all_reduce=self.async_all_reduce, + async_all_reduce=async_all_reduce, handle_idx=handle_idx, ) diff --git a/src/nanotron/trainer.py b/src/nanotron/trainer.py index 96af6b12..e58af9f5 100644 --- a/src/nanotron/trainer.py +++ b/src/nanotron/trainer.py @@ -564,6 +564,9 @@ def training_step( self.config, self.parallel_context, self.unwrapped_model, self.grad_accumulator, self.optimizer ) + if dist.get_rank() == 0: + assert 1 == 1 + # Apply gradient self.optimizer.step() self.optimizer.zero_grad() @@ -580,8 +583,20 @@ def training_step( from nanotron.parallel.comm import AsyncCommBucket + # import torch.distributed as dist + + not_finished = [] + for k, v in AsyncCommBucket._copy_async_op.items(): + # assert v.is_completed(), f"AsyncCommBucket._copy_async_op: {AsyncCommBucket._copy_async_op}" + if v.is_completed() is not True: + not_finished.append((k, v)) + + # if dist.get_rank() == 0 and constants._NOT_BWD_ASYNC_OPS: + # assert 1 == 1 + + assert len(not_finished) == 0, f"AsyncCommBucket._copy_async_op: {not_finished}" assert len(AsyncCommBucket._async_op) == 0, f"AsyncCommBucket._async_op: {AsyncCommBucket._async_op}" - # AsyncCommBucket.clear_all() + AsyncCommBucket.clear_all() return outputs, loss_avg From eac4ac59360b57015c5c5cf43b7851024c5bc7f5 Mon Sep 17 00:00:00 2001 From: Phuc Nguyen <b3f0cus@icloud.com> Date: Tue, 11 Feb 2025 14:47:21 +0000 Subject: [PATCH 09/17] add cuda stream syncronization for the bwd pass --- src/nanotron/models/llama.py | 26 +++++++------------------- src/nanotron/parallel/comm.py | 6 ++++-- src/nanotron/trainer.py | 2 +- 3 files changed, 12 insertions(+), 22 deletions(-) diff --git a/src/nanotron/models/llama.py b/src/nanotron/models/llama.py index 72ebf478..bb10021d 100644 --- a/src/nanotron/models/llama.py +++ b/src/nanotron/models/llama.py @@ -842,6 +842,7 @@ def _core_forward( num_input_batches = self.parallel_config.domino.num_input_batches orig_sequence_mask = sequence_mask + comm_stream = 
constants.CUDA_STREAMS[torch.cuda.current_device()] assert num_input_batches == 2 hidden_states = torch.chunk(hidden_states, chunks=num_input_batches, dim=1) @@ -855,14 +856,8 @@ def _core_forward( hidden_states0 = self.input_layernorm(hidden_states0) hidden_states1 = self.input_layernorm(hidden_states1) - hidden_states0 = WaitComm.apply( - hidden_states0, - BWD_ATTN_HANDLE_IDX.format(self.layer_idx, 1), - ) - hidden_states1 = WaitComm.apply( - hidden_states1, - BWD_MLP_HANDLE_IDX.format(self.layer_idx, 0), - ) + hidden_states0 = WaitComm.apply(hidden_states0, BWD_ATTN_HANDLE_IDX.format(self.layer_idx, 1), comm_stream) + hidden_states1 = WaitComm.apply(hidden_states1, BWD_MLP_HANDLE_IDX.format(self.layer_idx, 0), comm_stream) attn_output0 = self.attn( hidden_states=hidden_states0, @@ -875,7 +870,6 @@ def _core_forward( handle_idx=FWD_ATTN_HANDLE_IDX.format(self.layer_idx, 1), ) - comm_stream = constants.CUDA_STREAMS[torch.cuda.current_device()] with torch.cuda.stream(comm_stream): attn_output0["work"].wait() attn_output0["work"].is_completed() @@ -884,8 +878,7 @@ def _core_forward( residual0 = hidden_states0 hidden_states0 = self.post_attention_layernorm(hidden_states0) hidden_states0 = WaitComm.apply( - hidden_states0, - BWD_MLP_HANDLE_IDX.format(self.layer_idx, 1), + hidden_states0, BWD_MLP_HANDLE_IDX.format(self.layer_idx, 1), comm_stream ) # new mlp_output0 = self.mlp( @@ -928,12 +921,6 @@ def _core_forward( # hidden_states0, hidden_states1 = depend(run_after=hidden_states0, run_before=hidden_states1) hidden_states = torch.cat([hidden_states0, hidden_states1], dim=1) - assert 1 == 1 - - # assert attn_output0["work"].is_completed() - # assert attn_output1["work"].is_completed() - # assert mlp_output0["work"].is_completed() - # assert mlp_output1["work"].is_completed() return hidden_states, orig_sequence_mask @@ -1104,8 +1091,9 @@ def forward_with_hidden_states( "sequence_mask": input_mask, } - for encoder_block in self.decoder: - hidden_encoder_states = encoder_block(**hidden_encoder_states) + for layer_idx, encoder_block in enumerate(self.decoder): + with torch.profiler.record_function(f"layer_{layer_idx}"): + hidden_encoder_states = encoder_block(**hidden_encoder_states) hidden_states = self.final_layer_norm(input=hidden_encoder_states["hidden_states"])["hidden_states"] diff --git a/src/nanotron/parallel/comm.py b/src/nanotron/parallel/comm.py index 789416c3..38a3c134 100644 --- a/src/nanotron/parallel/comm.py +++ b/src/nanotron/parallel/comm.py @@ -82,8 +82,9 @@ def is_async_comm(x): class WaitComm(torch.autograd.Function): @staticmethod - def forward(ctx, input, wait_handle_idx): + def forward(ctx, input, wait_handle_idx, comm_stream): ctx.wait_handle_idx = wait_handle_idx + ctx.comm_stream = comm_stream return input @staticmethod @@ -101,6 +102,7 @@ def backward(ctx, grad_output): handle = AsyncCommBucket.pop(ctx.wait_handle_idx) assert handle is not None handle.wait() + torch.cuda.default_stream().wait_stream(ctx.comm_stream) # assert handle.is_completed() is True, f"ctx.wait_handle_idx: {ctx.wait_handle_idx}" else: @@ -110,4 +112,4 @@ def backward(ctx, grad_output): # constants._NOT_BWD_ASYNC_OPS.append(ctx.wait_handle_idx) constants._NOT_BWD_ASYNC_OPS.append(ctx.wait_handle_idx) - return grad_output, None + return grad_output, None, None diff --git a/src/nanotron/trainer.py b/src/nanotron/trainer.py index e58af9f5..dfec12c0 100644 --- a/src/nanotron/trainer.py +++ b/src/nanotron/trainer.py @@ -594,7 +594,7 @@ def training_step( # if dist.get_rank() == 0 and 
constants._NOT_BWD_ASYNC_OPS: # assert 1 == 1 - assert len(not_finished) == 0, f"AsyncCommBucket._copy_async_op: {not_finished}" + assert len(not_finished) == 0, f"len={len(not_finished)}, AsyncCommBucket._copy_async_op: {not_finished}" assert len(AsyncCommBucket._async_op) == 0, f"AsyncCommBucket._async_op: {AsyncCommBucket._async_op}" AsyncCommBucket.clear_all() From 3a438ff238dd8da16e25a3bf83bd1856ee0d727b Mon Sep 17 00:00:00 2001 From: Phuc Nguyen <b3f0cus@icloud.com> Date: Tue, 11 Feb 2025 16:47:23 +0000 Subject: [PATCH 10/17] domino but non_async_last_batch_mlp_and_non_async_first_batch_attn --- src/nanotron/models/llama.py | 30 ++--- src/nanotron/parallel/comm.py | 52 -------- .../distributed_differentiable_primitives.py | 111 ++++++++++-------- .../parallel/tensor_parallel/functional.py | 24 ++-- src/nanotron/parallel/tensor_parallel/nn.py | 18 ++- 5 files changed, 101 insertions(+), 134 deletions(-) diff --git a/src/nanotron/models/llama.py b/src/nanotron/models/llama.py index bb10021d..a6554bd8 100644 --- a/src/nanotron/models/llama.py +++ b/src/nanotron/models/llama.py @@ -30,10 +30,12 @@ from nanotron.nn.activations import ACT2FN from nanotron.nn.layer_norm import TritonRMSNorm from nanotron.parallel import ParallelContext -from nanotron.parallel.comm import WaitComm + +# from nanotron.parallel.comm import WaitComm from nanotron.parallel.parameters import NanotronParameter from nanotron.parallel.pipeline_parallel.block import PipelineBlock, TensorPointer from nanotron.parallel.pipeline_parallel.p2p import P2P +from nanotron.parallel.tensor_parallel.domino import WaitComm from nanotron.parallel.tensor_parallel.functional import sharded_cross_entropy from nanotron.parallel.tensor_parallel.nn import ( TensorParallelColumnLinear, @@ -249,11 +251,9 @@ def __init__( ) self.split_silu_mul = GLUActivation(config.hidden_act) - def forward(self, hidden_states, handle_idx=None): # [seq_length, batch_size, hidden_dim] - merged_states = self.gate_up_proj(hidden_states, async_all_reduce=True, handle_idx=handle_idx) - hidden_states, work = self.down_proj( - self.split_silu_mul(merged_states), async_all_reduce=True, handle_idx=handle_idx - ) + def forward(self, hidden_states, op_name): # [seq_length, batch_size, hidden_dim] + merged_states = self.gate_up_proj(hidden_states, op_name=op_name) + hidden_states, work = self.down_proj(self.split_silu_mul(merged_states), op_name=op_name) return {"hidden_states": hidden_states, "work": work} @@ -447,7 +447,7 @@ def forward( self, hidden_states, # [seq_length, batch_size, hidden_size] sequence_mask, # [batch_size, seq_length] - handle_idx=None, + op_name, ): from flash_attn import bert_padding from flash_attn.flash_attn_interface import ( @@ -456,7 +456,7 @@ def forward( ) qkv_states = self.qkv_proj( - hidden_states, async_all_reduce=True, handle_idx=handle_idx + hidden_states, op_name=op_name ) # [seq_length, batch_size, n_local_q_heads * d_qk + 2 * n_local_kv_heads * d_qk] q_length, batch_size, _ = qkv_states.shape @@ -701,7 +701,7 @@ def forward( attention_output = ( attention_output.contiguous().view(batch_size, q_length, self.n_local_q_heads * self.d_v).transpose(0, 1) ) - output, work = self.o_proj(attention_output, async_all_reduce=True, handle_idx=handle_idx) + output, work = self.o_proj(attention_output, op_name=op_name) return {"hidden_states": output, "work": work, "sequence_mask": sequence_mask} @@ -862,12 +862,12 @@ def _core_forward( attn_output0 = self.attn( hidden_states=hidden_states0, sequence_mask=sequence_mask0, - 
handle_idx=FWD_ATTN_HANDLE_IDX.format(self.layer_idx, 0), + op_name=FWD_ATTN_HANDLE_IDX.format(self.layer_idx, 0), ) attn_output1 = self.attn( hidden_states=hidden_states1, sequence_mask=sequence_mask1, - handle_idx=FWD_ATTN_HANDLE_IDX.format(self.layer_idx, 1), + op_name=FWD_ATTN_HANDLE_IDX.format(self.layer_idx, 1), ) with torch.cuda.stream(comm_stream): @@ -883,7 +883,7 @@ def _core_forward( mlp_output0 = self.mlp( hidden_states=hidden_states0, - handle_idx=FWD_MLP_HANDLE_IDX.format(self.layer_idx, 0), + op_name=FWD_MLP_HANDLE_IDX.format(self.layer_idx, 0), ) # attn_output1["hidden_states"], mlp_output0["hidden_states"] = depend( @@ -903,15 +903,15 @@ def _core_forward( mlp_output1 = self.mlp( hidden_states=hidden_states1, - handle_idx=FWD_MLP_HANDLE_IDX.format(self.layer_idx, 1), + op_name=FWD_MLP_HANDLE_IDX.format(self.layer_idx, 1), ) with torch.cuda.stream(comm_stream): mlp_output0["work"].wait() - mlp_output1["work"].wait() + # mlp_output1["work"].wait() mlp_output0["work"].is_completed() - mlp_output1["work"].is_completed() + # mlp_output1["work"].is_completed() torch.cuda.current_stream().wait_stream(comm_stream) diff --git a/src/nanotron/parallel/comm.py b/src/nanotron/parallel/comm.py index 38a3c134..6dbb041f 100644 --- a/src/nanotron/parallel/comm.py +++ b/src/nanotron/parallel/comm.py @@ -61,55 +61,3 @@ def wait(tensor_id: int): def clear_all(): AsyncCommBucket._async_op.clear() AsyncCommBucket._copy_async_op.clear() - - -def is_async_comm(x): - import re - - NON_ASYNC_HANDLE_IDX = [ - # "fwd.layer_attn_{}_batch_0", - # "fwd.layer_mlp_{}_batch_0", - # "fwd.layer_mlp_{}_batch_1", - "bwd.layer_mlp_{}_batch_1", - "bwd.layer_attn_{}_batch_0", - ] - - patterns = [p.replace("{}", r"\d+") for p in NON_ASYNC_HANDLE_IDX] # Replace {} with regex for numbers - regex = re.compile("^(" + "|".join(patterns) + ")$") # Combine patterns into a single regex - not_async = bool(regex.match(x)) - return not not_async - - -class WaitComm(torch.autograd.Function): - @staticmethod - def forward(ctx, input, wait_handle_idx, comm_stream): - ctx.wait_handle_idx = wait_handle_idx - ctx.comm_stream = comm_stream - return input - - @staticmethod - def backward(ctx, grad_output): - # import pydevd - # pydevd.settrace(suspend=False, trace_only_current_thread=True) - - if "bwd.layer_mlp_1_batch_0" == ctx.wait_handle_idx: - assert 1 == 1 - - if is_async_comm(ctx.wait_handle_idx): - from nanotron.constants import _AUTOGRAD_RUNS - - _AUTOGRAD_RUNS.append(f"wait_{ctx.wait_handle_idx}") - handle = AsyncCommBucket.pop(ctx.wait_handle_idx) - assert handle is not None - handle.wait() - torch.cuda.default_stream().wait_stream(ctx.comm_stream) - # assert handle.is_completed() is True, f"ctx.wait_handle_idx: {ctx.wait_handle_idx}" - else: - - from nanotron import constants - - # if dist.get_rank() == 0: - # constants._NOT_BWD_ASYNC_OPS.append(ctx.wait_handle_idx) - constants._NOT_BWD_ASYNC_OPS.append(ctx.wait_handle_idx) - - return grad_output, None, None diff --git a/src/nanotron/parallel/tensor_parallel/distributed_differentiable_primitives.py b/src/nanotron/parallel/tensor_parallel/distributed_differentiable_primitives.py index 58275368..5ac3bedf 100644 --- a/src/nanotron/parallel/tensor_parallel/distributed_differentiable_primitives.py +++ b/src/nanotron/parallel/tensor_parallel/distributed_differentiable_primitives.py @@ -20,16 +20,16 @@ from nanotron import distributed as dist from nanotron.distributed import ProcessGroup from nanotron.parallel.comm import AsyncCommBucket +from 
nanotron.parallel.tensor_parallel.domino import is_async_comm class DifferentiableIdentity(torch.autograd.Function): """All-reduce gradients in a differentiable fashion""" @staticmethod - def forward(ctx, tensor, group: Optional[ProcessGroup], async_all_reduce: bool, handle_idx=None): - # assert handle_idx is not None + def forward(ctx, tensor, group: Optional[ProcessGroup], async_all_reduce: bool, op_name: str = None): ctx.async_all_reduce = async_all_reduce - ctx.handle_idx = handle_idx + ctx.op_name = op_name ctx.group = group return tensor @@ -41,28 +41,35 @@ def backward(ctx, grad_output): # assert ctx.handle_idx is not None group = ctx.group - if ctx.handle_idx is not None and "fwd." in ctx.handle_idx: - handle_idx = ctx.handle_idx.replace("fwd.", "bwd.") - # if "bwd.layer_mlp_1_batch_1" == handle_idx: - # from nanotron.parallel.comm import is_async_comm - # async_all_reduce = is_async_comm(handle_idx) - # else: - # async_all_reduce = ctx.async_all_reduce - from nanotron.parallel.comm import is_async_comm + # if ctx.handle_idx is not None and "fwd." in ctx.handle_idx: + # handle_idx = ctx.handle_idx.replace("fwd.", "bwd.") + # # if "bwd.layer_mlp_1_batch_1" == handle_idx: + # # from nanotron.parallel.comm import is_async_comm + # # async_all_reduce = is_async_comm(handle_idx) + # # else: + # # async_all_reduce = ctx.async_all_reduce + # # from nanotron.parallel.comm import is_async_comm + # from nanotron.parallel.tensor_parallel.domino import is_async_comm - async_all_reduce = is_async_comm(handle_idx) - else: - handle_idx = ctx.handle_idx - async_all_reduce = ctx.async_all_reduce + # async_all_reduce = is_async_comm(handle_idx) + # else: + # handle_idx = ctx.handle_idx + # async_all_reduce = ctx.async_all_reduce + + # if handle_idx is not None and "bwd." in handle_idx and async_all_reduce is True: + # assert 1 == 1 + + op_name = ctx.op_name.replace("fwd.", "bwd.") if ctx.op_name is not None else ctx.op_name + async_all_reduce = is_async_comm(op_name) if ctx.op_name is not None else ctx.async_all_reduce - if handle_idx is not None and "bwd." 
in handle_idx and async_all_reduce is True: + if op_name is not None and "layer_mlp_27_batch_1" in op_name: assert 1 == 1 from nanotron.constants import _AUTOGRAD_RUNS - _AUTOGRAD_RUNS.append(handle_idx) + _AUTOGRAD_RUNS.append(ctx.op_name) - return DifferentiableAllReduceSum.apply(grad_output, group, async_all_reduce, handle_idx), None, None, None + return DifferentiableAllReduceSum.apply(grad_output, group, async_all_reduce, op_name), None, None, None def is_last_batch_of_attn(x): @@ -79,39 +86,45 @@ class DifferentiableAllReduceSum(torch.autograd.Function): @staticmethod def forward( - ctx, tensor, group: Optional[ProcessGroup], async_all_reduce: bool, handle_idx: Optional[int] = None + ctx, tensor, group: Optional[ProcessGroup], async_all_reduce: bool, op_name: str = None ) -> Tuple[torch.Tensor, Optional["dist.Work"]]: ctx.async_all_reduce = async_all_reduce if group.size() == 1: return tensor - if handle_idx == "bwd.layer_mlp_1_batch_0": - assert 1 == 1 - - id(tensor) - if async_all_reduce is True: - # if isinstance(handle_idx, str): - # do_async = is_last_batch_of_attn(handle_idx) is False - # else: - # do_async = async_all_reduce - from nanotron.parallel.comm import is_async_comm - - do_async = is_async_comm(handle_idx) - - handle = dist.all_reduce(tensor, op=dist.ReduceOp.SUM, group=group, async_op=do_async) - if do_async: - if "bwd" in handle_idx: - assert 1 == 1 - - # # NOTE: id(tensor) is for the fwd pass, for the bwd pass, we do handle_idx - # if handle_idx is not None and "bwd." in handle_idx: - # AsyncCommBucket.add(orig_id if handle_idx is None else handle_idx, handle) - # else: - # AsyncCommBucket.add(orig_id, handle) - # NOTE: id(tensor) is for the fwd pass, for the bwd pass, we do handle_idx - assert handle_idx is not None - AsyncCommBucket.add(handle_idx, handle) + # if handle_idx == "bwd.layer_mlp_1_batch_0": + # assert 1 == 1 + + # id(tensor) + # if async_all_reduce is True: + # # if isinstance(handle_idx, str): + # # do_async = is_last_batch_of_attn(handle_idx) is False + # # else: + # # do_async = async_all_reduce + # # from nanotron.parallel.comm import is_async_comm + # from nanotron.parallel.tensor_parallel.domino import is_async_comm + + # do_async = is_async_comm(handle_idx) + + # handle = dist.all_reduce(tensor, op=dist.ReduceOp.SUM, group=group, async_op=do_async) + # if do_async: + # if "bwd" in handle_idx: + # assert 1 == 1 + + # # # NOTE: id(tensor) is for the fwd pass, for the bwd pass, we do handle_idx + # # if handle_idx is not None and "bwd." 
in handle_idx: + # # AsyncCommBucket.add(orig_id if handle_idx is None else handle_idx, handle) + # # else: + # # AsyncCommBucket.add(orig_id, handle) + # # NOTE: id(tensor) is for the fwd pass, for the bwd pass, we do handle_idx + # assert handle_idx is not None + # AsyncCommBucket.add(handle_idx, handle) + # else: + # dist.all_reduce(tensor, op=dist.ReduceOp.SUM, group=group) + if async_all_reduce: + handle = dist.all_reduce(tensor, op=dist.ReduceOp.SUM, group=group, async_op=True) + AsyncCommBucket.add(op_name, handle) else: dist.all_reduce(tensor, op=dist.ReduceOp.SUM, group=group) @@ -202,15 +215,15 @@ def backward(ctx, grad_output): def differentiable_identity( - tensor, group: Optional[ProcessGroup] = None, async_all_reduce: bool = False, handle_idx=None + tensor, group: Optional[ProcessGroup] = None, async_all_reduce: bool = False, op_name: str = None ): - return DifferentiableIdentity.apply(tensor, group, async_all_reduce, handle_idx) + return DifferentiableIdentity.apply(tensor, group, async_all_reduce, op_name) def differentiable_all_reduce_sum( - tensor, group: Optional[ProcessGroup] = None, async_all_reduce: bool = False, handle_idx=None + tensor, group: Optional[ProcessGroup] = None, async_all_reduce: bool = False, op_name: str = None ): - return DifferentiableAllReduceSum.apply(tensor, group, async_all_reduce, handle_idx) + return DifferentiableAllReduceSum.apply(tensor, group, async_all_reduce, op_name) def differentiable_all_gather(tensor, group: Optional[ProcessGroup] = None): diff --git a/src/nanotron/parallel/tensor_parallel/functional.py b/src/nanotron/parallel/tensor_parallel/functional.py index f0ca3a0d..6d69408c 100644 --- a/src/nanotron/parallel/tensor_parallel/functional.py +++ b/src/nanotron/parallel/tensor_parallel/functional.py @@ -437,13 +437,13 @@ def column_linear( async_communication: bool, tp_recompute_allgather: bool = True, async_all_reduce: bool = False, - handle_idx: Optional[int] = None, + op_name: Optional[str] = None, ): if async_communication: return _ColumnLinearAsyncCommunication.apply(input, weight, bias, group, tp_mode, tp_recompute_allgather) if tp_mode is TensorParallelLinearMode.ALL_REDUCE: - input = differentiable_identity(input, group=group, async_all_reduce=async_all_reduce, handle_idx=handle_idx) + input = differentiable_identity(input, group=group, async_all_reduce=async_all_reduce, op_name=op_name) return F.linear(input, weight, bias) if tp_mode is TensorParallelLinearMode.REDUCE_SCATTER: return _ColumnLinearNoAsyncCommunicationReduceScatterMode.apply( @@ -592,7 +592,7 @@ def row_linear( # TODO(xrsrke): use less confusing names for these arguments async_communication: bool, async_all_reduce: bool, - handle_idx=None, + op_name: Optional[str] = None, ) -> Tuple[torch.Tensor, Optional[torch.Future]]: if async_communication: return _RowLinearAsyncCommunication.apply(input, weight, bias, group, tp_mode) @@ -601,23 +601,31 @@ def row_linear( if tp_mode is TensorParallelLinearMode.ALL_REDUCE: # out, work = differentiable_all_reduce_sum(out, group=group, async_all_reduce=async_all_reduce) - id(out) + # id(out) # NOTE: why the id(out) doesn't match the id(out) before the all_reduce? 
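On the id(out) question in the note above: object identity is not a reliable key across a torch.autograd.Function boundary, since apply() returns a new tensor object (same storage, new grad_fn) rather than the input itself, so ids captured on one side of the call need not match ids captured on the other. That is presumably why these hunks drop the id()-based keys in favour of explicit op-name strings. A small self-contained check of the identity behaviour, using a throwaway identity Function:

import torch

class _Identity(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x):
        return x

    @staticmethod
    def backward(ctx, grad_output):
        return grad_output

x = torch.randn(4, requires_grad=True)
y = _Identity.apply(x)
print(id(x) == id(y))  # False: y is a new tensor object carrying the grad_fn
print(y.grad_fn)       # e.g. an _IdentityBackward node, while x.grad_fn stays None
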
- if handle_idx == "fwd.layer_attn_0_batch_0": + if op_name == "fwd.layer_attn_0_batch_0": assert 1 == 1 - out = differentiable_all_reduce_sum(out, group=group, async_all_reduce=async_all_reduce, handle_idx=handle_idx) + if op_name == "fwd.layer_mlp_0_batch_1": + assert 1 == 1 + + if op_name == "fwd.layer_attn_0_batch_0": + assert 1 == 1 + + out = differentiable_all_reduce_sum(out, group=group, async_all_reduce=async_all_reduce, op_name=op_name) if async_all_reduce: from nanotron.parallel.comm import AsyncCommBucket # work = AsyncCommBucket.get(orig_out_id) # work = AsyncCommBucket.pop(orig_out_id) # if handle_idx == "fwd.layer_mlp_1_batch_0": - if handle_idx == "fwd.layer_attn_0_batch_0": + if op_name == "fwd.layer_attn_0_batch_0": assert 1 == 1 - work = AsyncCommBucket.pop(handle_idx) + work = AsyncCommBucket.pop(op_name) assert 1 == 1 + else: + work = None elif tp_mode is TensorParallelLinearMode.REDUCE_SCATTER: assert async_all_reduce is False, "Async communication is not supported for REDUCE_SCATTER mode." out = differentiable_reduce_scatter_sum(out, group=group) diff --git a/src/nanotron/parallel/tensor_parallel/nn.py b/src/nanotron/parallel/tensor_parallel/nn.py index 4fea1838..2e6fd5a4 100644 --- a/src/nanotron/parallel/tensor_parallel/nn.py +++ b/src/nanotron/parallel/tensor_parallel/nn.py @@ -31,6 +31,7 @@ differentiable_identity, differentiable_reduce_scatter_sum, ) +from nanotron.parallel.tensor_parallel.domino import is_async_comm from nanotron.parallel.tensor_parallel.enum import TensorParallelLinearMode from nanotron.parallel.tensor_parallel.functional import ( column_linear, @@ -52,7 +53,6 @@ def __init__( async_communication: bool = False, contiguous_chunks: Optional[Tuple[int, ...]] = None, tp_recompute_allgather: bool = True, - # handle_idx: Optional[int] = None, ): self.pg = pg self.world_size = pg.size() @@ -73,7 +73,6 @@ def __init__( self.mode = mode self.async_communication = async_communication - # self.handle_idx = handle_idx if contiguous_chunks is not None: assert ( @@ -87,7 +86,7 @@ def __init__( split_config=split_config, ) - def forward(self, x: torch.Tensor, async_all_reduce=None, handle_idx=None) -> torch.Tensor: + def forward(self, x: torch.Tensor, op_name: str = None) -> torch.Tensor: return column_linear( input=x, weight=self.weight, @@ -96,8 +95,8 @@ def forward(self, x: torch.Tensor, async_all_reduce=None, handle_idx=None) -> to tp_mode=self.mode, async_communication=self.async_communication, tp_recompute_allgather=self.tp_recompute_allgather, - async_all_reduce=async_all_reduce, - handle_idx=handle_idx, + async_all_reduce=False if op_name is None else is_async_comm(op_name), + op_name=op_name, ) def extra_repr(self) -> str: @@ -115,7 +114,6 @@ def __init__( device=None, dtype=None, async_communication: bool = False, - # async_all_reduce: bool = False, contiguous_chunks: Optional[Tuple[int, ...]] = None, ): self.pg = pg @@ -138,7 +136,7 @@ def __init__( ) self.mode = mode self.async_communication = async_communication - # self.async_all_reduce = async_all_reduce + if self.mode is TensorParallelLinearMode.ALL_REDUCE and self.async_communication: raise ValueError("async_communication is not supported for ALL_REDUCE mode") @@ -164,7 +162,7 @@ def _mark_all_parameters_in_module_as_sharded(self, split_config: SplitConfig): ) setattr(self, name, new_param) - def forward(self, x: torch.Tensor, async_all_reduce, handle_idx=None) -> torch.Tensor: + def forward(self, x: torch.Tensor, op_name: str = None) -> torch.Tensor: return row_linear( input=x, 
weight=self.weight, @@ -172,8 +170,8 @@ def forward(self, x: torch.Tensor, async_all_reduce, handle_idx=None) -> torch.T group=self.pg, tp_mode=self.mode, async_communication=self.async_communication, - async_all_reduce=async_all_reduce, - handle_idx=handle_idx, + async_all_reduce=False if op_name is None else is_async_comm(op_name), + op_name=op_name, ) def extra_repr(self) -> str: From da948dfab4690c13435e1c0ab7ca0ae957bb0451 Mon Sep 17 00:00:00 2001 From: Phuc Nguyen <b3f0cus@icloud.com> Date: Tue, 11 Feb 2025 16:48:10 +0000 Subject: [PATCH 11/17] non_async_last_batch_mlp_and_non_async_first_batch_attn --- .../parallel/tensor_parallel/domino.py | 60 +++++++++++++++++++ tests/test_domino.py | 19 ++++++ 2 files changed, 79 insertions(+) create mode 100644 src/nanotron/parallel/tensor_parallel/domino.py create mode 100644 tests/test_domino.py diff --git a/src/nanotron/parallel/tensor_parallel/domino.py b/src/nanotron/parallel/tensor_parallel/domino.py new file mode 100644 index 00000000..7fb68abc --- /dev/null +++ b/src/nanotron/parallel/tensor_parallel/domino.py @@ -0,0 +1,60 @@ +import re + +import torch + +from nanotron.parallel.comm import AsyncCommBucket + + +def is_async_comm(op_name: str): + """ + There are two operations that we can't overlap + for the forward pass: the last micro-batch of the mlp layer + for the backward pass: the first micro-batch of the attention layer + """ + NON_ASYNC_HANDLE_IDX = [ + # "fwd.layer_attn_{}_batch_0", + # "fwd.layer_mlp_{}_batch_0", + "fwd.layer_mlp_{}_batch_1", + # "bwd.layer_mlp_{}_batch_1", + "bwd.layer_attn_{}_batch_0", + ] + + patterns = [p.replace("{}", r"\d+") for p in NON_ASYNC_HANDLE_IDX] # Replace {} with regex for numbers + regex = re.compile("^(" + "|".join(patterns) + ")$") # Combine patterns into a single regex + not_async = bool(regex.match(op_name)) + return not not_async + + +class WaitComm(torch.autograd.Function): + @staticmethod + def forward(ctx, input, wait_handle_idx, comm_stream): + ctx.wait_handle_idx = wait_handle_idx + ctx.comm_stream = comm_stream + return input + + @staticmethod + def backward(ctx, grad_output): + # import pydevd + # pydevd.settrace(suspend=False, trace_only_current_thread=True) + + if "bwd.layer_mlp_1_batch_0" == ctx.wait_handle_idx: + assert 1 == 1 + + if is_async_comm(ctx.wait_handle_idx): + from nanotron.constants import _AUTOGRAD_RUNS + + _AUTOGRAD_RUNS.append(f"wait_{ctx.wait_handle_idx}") + handle = AsyncCommBucket.pop(ctx.wait_handle_idx) + assert handle is not None + handle.wait() + torch.cuda.default_stream().wait_stream(ctx.comm_stream) + # assert handle.is_completed() is True, f"ctx.wait_handle_idx: {ctx.wait_handle_idx}" + else: + + from nanotron import constants + + # if dist.get_rank() == 0: + # constants._NOT_BWD_ASYNC_OPS.append(ctx.wait_handle_idx) + constants._NOT_BWD_ASYNC_OPS.append(ctx.wait_handle_idx) + + return grad_output, None, None diff --git a/tests/test_domino.py b/tests/test_domino.py new file mode 100644 index 00000000..8f474ff8 --- /dev/null +++ b/tests/test_domino.py @@ -0,0 +1,19 @@ +import pytest +from nanotron.parallel.tensor_parallel.domino import is_async_comm + + +@pytest.mark.parametrize( + "op_name, expected", + [ + ("fwd.layer_attn_1_batch_0", True), + ("fwd.layer_attn_1_batch_1", True), + ("fwd.layer_mlp_1_batch_0", True), + ("fwd.layer_mlp_1_batch_1", False), + ("bwd.layer_mlp_1_batch_1", True), + ("bwd.layer_mlp_1_batch_0", True), + ("bwd.layer_attn_1_batch_1", True), + ("bwd.layer_attn_1_batch_0", False), + ], +) +def test_is_async_comm(op_name, expected): 
+ assert is_async_comm(op_name) == expected From aa3e97393058c0b9bcc7bf65dccccc2c8a3f5c91 Mon Sep 17 00:00:00 2001 From: Phuc Nguyen <b3f0cus@icloud.com> Date: Wed, 12 Feb 2025 13:21:51 +0000 Subject: [PATCH 12/17] backup before refactoring --- examples/config_llama_domino.yaml | 98 ++++++++++++++++ src/nanotron/models/llama.py | 107 ++---------------- .../distributed_differentiable_primitives.py | 105 ++++++----------- .../parallel/tensor_parallel/domino.py | 15 ++- src/nanotron/trainer.py | 6 - 5 files changed, 149 insertions(+), 182 deletions(-) create mode 100644 examples/config_llama_domino.yaml diff --git a/examples/config_llama_domino.yaml b/examples/config_llama_domino.yaml new file mode 100644 index 00000000..b9811fdd --- /dev/null +++ b/examples/config_llama_domino.yaml @@ -0,0 +1,98 @@ +checkpoints: + checkpoint_interval: 1000 + checkpoints_path: checkpoints + checkpoints_path_is_shared_file_system: false + resume_checkpoint_path: null + save_initial_state: false +data_stages: +- data: + dataset: + dataset_overwrite_cache: false + dataset_processing_num_proc_per_process: 1 + hf_dataset_config_name: null + hf_dataset_or_datasets: roneneldan/TinyStories + hf_dataset_splits: train + text_column_name: text + num_loading_workers: 1 + seed: 42 + name: Stable Training Stage + start_training_step: 1 +general: + benchmark_csv_path: null + consumed_train_samples: null + ignore_sanity_checks: true + project: nanotron_domino + run: config_llama_domino + seed: 42 + step: null +lighteval: null +logging: + iteration_step_info_interval: 1 + log_level: info + log_level_replica: info +model: + ddp_bucket_cap_mb: 25 + dtype: bfloat16 + init_method: + std: 0.025 + make_vocab_size_divisible_by: 1 + model_config: + bos_token_id: 128000 + eos_token_id: 128001 + hidden_act: silu + hidden_size: 4096 + initializer_range: 0.02 + intermediate_size: 16384 + is_llama_config: true + max_position_embeddings: 4096 + num_attention_heads: 32 + num_hidden_layers: 32 + num_key_value_heads: 8 + pad_token_id: null + pretraining_tp: 1 + rms_norm_eps: 1.0e-05 + rope_scaling: null + tie_word_embeddings: true + use_cache: true + vocab_size: 128256 +optimizer: + accumulate_grad_in_fp32: true + clip_grad: 1.0 + learning_rate_scheduler: + learning_rate: 0.0003 + lr_decay_starting_step: null + lr_decay_steps: 1000 + lr_decay_style: cosine + lr_warmup_steps: 500 + lr_warmup_style: linear + min_decay_lr: 1.0e-05 + optimizer_factory: + adam_beta1: 0.9 + adam_beta2: 0.95 + adam_eps: 1.0e-08 + name: adamW + torch_adam_is_fused: true + weight_decay: 0.01 + zero_stage: 0 +parallelism: + dp: 1 + pp: 1 + tp: 8 + expert_parallel_size: 1 + pp_engine: 1f1b + tp_linear_async_communication: false + tp_mode: ALL_REDUCE + domino: + num_input_batches: 2 +tokenizer: + tokenizer_max_length: null + tokenizer_name_or_path: meta-llama/Meta-Llama-3-8B + tokenizer_revision: null +tokens: + batch_accumulation_per_replica: 1 + limit_test_batches: 0 + limit_val_batches: 0 + micro_batch_size: 2 + sequence_length: 4096 + train_steps: 1500 + val_check_interval: -1 diff --git a/src/nanotron/models/llama.py b/src/nanotron/models/llama.py index a6554bd8..b6785604 100644 --- a/src/nanotron/models/llama.py +++ b/src/nanotron/models/llama.py @@ -35,7 +35,13 @@ from nanotron.parallel.parameters import NanotronParameter from nanotron.parallel.pipeline_parallel.block import PipelineBlock, TensorPointer from nanotron.parallel.pipeline_parallel.p2p import P2P -from nanotron.parallel.tensor_parallel.domino import WaitComm +from 
nanotron.parallel.tensor_parallel.domino import ( + BWD_ATTN_HANDLE_IDX, + BWD_MLP_HANDLE_IDX, + FWD_ATTN_HANDLE_IDX, + FWD_MLP_HANDLE_IDX, + WaitComm, +) from nanotron.parallel.tensor_parallel.functional import sharded_cross_entropy from nanotron.parallel.tensor_parallel.nn import ( TensorParallelColumnLinear, @@ -51,11 +57,6 @@ DOMINO_COMM_STREAM = "domino_comm_stream_{}" -FWD_MLP_HANDLE_IDX = "fwd.layer_mlp_{}_batch_{}" -FWD_ATTN_HANDLE_IDX = "fwd.layer_attn_{}_batch_{}" -BWD_ATTN_HANDLE_IDX = "bwd.layer_attn_{}_batch_{}" -BWD_MLP_HANDLE_IDX = "bwd.layer_mlp_{}_batch_{}" - class RotaryEmbedding(nn.Module): def __init__(self, dim: int, end: int, theta: float = 10000.0): @@ -743,96 +744,6 @@ def __init__( self.layer_idx = layer_idx - # def _core_forward( - # self, - # hidden_states: Union[torch.Tensor, TensorPointer], - # sequence_mask: Union[torch.Tensor, TensorPointer], - # ) -> List[Union[torch.Tensor, TensorPointer]]: - # from nanotron import constants - - # num_input_batches = self.parallel_config.domino.num_input_batches - # orig_sequence_mask = sequence_mask - - # assert num_input_batches == 2 - # hidden_states = torch.chunk(hidden_states, chunks=num_input_batches, dim=1) - # sequence_mask = torch.chunk(sequence_mask, chunks=num_input_batches, dim=0) - - # hidden_states0, hidden_states1 = hidden_states - # sequence_mask0, sequence_mask1 = sequence_mask - - # residual0 = hidden_states0 - # residual1 = hidden_states1 - - # hidden_states0 = self.input_layernorm(hidden_states0) - # hidden_states1 = self.input_layernorm(hidden_states1) - - # attn_output0 = self.attn( - # hidden_states=hidden_states0, - # sequence_mask=sequence_mask0, - # handle_idx=FWD_ATTN_HANDLE_IDX.format(self.layer_idx, 0), - # ) - # # attn_output0["hidden_states"] = WaitComm.apply( - # # attn_output0["hidden_states"], - # # BWD_ATTN_HANDLE_IDX.format(self.layer_idx, 1), - # # ) - - # attn_output1 = self.attn( - # hidden_states=hidden_states1, - # sequence_mask=sequence_mask1, - # handle_idx=FWD_ATTN_HANDLE_IDX.format(self.layer_idx, 1), - # ) - # # attn_output1["hidden_states"] = WaitComm.apply( - # # attn_output1["hidden_states"], - # # BWD_MLP_HANDLE_IDX.format(self.layer_idx, 0), - # # ) - - # comm_stream = constants.CUDA_STREAMS[torch.cuda.current_device()] - # with torch.cuda.stream(comm_stream): - # attn_output0["work"].wait() - - # hidden_states0 = attn_output0["hidden_states"] + residual0 - # residual0 = hidden_states0 - # hidden_states0 = self.post_attention_layernorm(hidden_states0) - # hidden_states0 = WaitComm.apply( - # hidden_states0, - # BWD_MLP_HANDLE_IDX.format(self.layer_idx, 1), - # ) # new - - # mlp_output0 = self.mlp( - # hidden_states=hidden_states0, - # handle_idx=FWD_MLP_HANDLE_IDX.format(self.layer_idx, 0), - # ) - # # mlp_output0["hidden_states"] = WaitComm.apply( - # # mlp_output0["hidden_states"], - # # BWD_MLP_HANDLE_IDX.format(self.layer_idx, 1), - # # ) - - # with torch.cuda.stream(comm_stream): - # attn_output1["work"].wait() - - # hidden_states1 = attn_output1["hidden_states"] + residual1 - # residual1 = hidden_states1 - # hidden_states1 = self.post_attention_layernorm(hidden_states1) - # hidden_states1 = WaitComm.apply( - # hidden_states1, - # BWD_MLP_HANDLE_IDX.format(self.layer_idx, 0), - # ) - - # mlp_output1 = self.mlp( - # hidden_states=hidden_states1, - # handle_idx=FWD_MLP_HANDLE_IDX.format(self.layer_idx, 1), - # ) - - # with torch.cuda.stream(comm_stream): - # mlp_output0["work"].wait() - # mlp_output1["work"].wait() - - # hidden_states0 = mlp_output0["hidden_states"] + 
residual0 - # hidden_states1 = mlp_output1["hidden_states"] + residual1 - - # hidden_states = torch.cat([hidden_states0, hidden_states1], dim=1) - # return hidden_states, orig_sequence_mask - def _core_forward( self, hidden_states: Union[torch.Tensor, TensorPointer], @@ -908,12 +819,10 @@ def _core_forward( with torch.cuda.stream(comm_stream): mlp_output0["work"].wait() - # mlp_output1["work"].wait() - mlp_output0["work"].is_completed() - # mlp_output1["work"].is_completed() torch.cuda.current_stream().wait_stream(comm_stream) + # torch.cuda.synchronize() hidden_states0 = mlp_output0["hidden_states"] + residual0 hidden_states1 = mlp_output1["hidden_states"] + residual1 diff --git a/src/nanotron/parallel/tensor_parallel/distributed_differentiable_primitives.py b/src/nanotron/parallel/tensor_parallel/distributed_differentiable_primitives.py index 5ac3bedf..5446f571 100644 --- a/src/nanotron/parallel/tensor_parallel/distributed_differentiable_primitives.py +++ b/src/nanotron/parallel/tensor_parallel/distributed_differentiable_primitives.py @@ -27,58 +27,39 @@ class DifferentiableIdentity(torch.autograd.Function): """All-reduce gradients in a differentiable fashion""" @staticmethod - def forward(ctx, tensor, group: Optional[ProcessGroup], async_all_reduce: bool, op_name: str = None): + def forward( + ctx, + tensor, + group: Optional[ProcessGroup], + async_all_reduce: bool, + op_name: str = None, + comm_stream: torch.cuda.Stream = None, + ): + ctx.group = group ctx.async_all_reduce = async_all_reduce ctx.op_name = op_name - ctx.group = group + ctx.comm_stream = comm_stream return tensor @staticmethod def backward(ctx, grad_output): # import pydevd # pydevd.settrace(suspend=False, trace_only_current_thread=True) - # NOTE: lm_head is TensorParallelColumnLinear, and it doesn't do async - # assert ctx.handle_idx is not None - group = ctx.group - - # if ctx.handle_idx is not None and "fwd." in ctx.handle_idx: - # handle_idx = ctx.handle_idx.replace("fwd.", "bwd.") - # # if "bwd.layer_mlp_1_batch_1" == handle_idx: - # # from nanotron.parallel.comm import is_async_comm - # # async_all_reduce = is_async_comm(handle_idx) - # # else: - # # async_all_reduce = ctx.async_all_reduce - # # from nanotron.parallel.comm import is_async_comm - # from nanotron.parallel.tensor_parallel.domino import is_async_comm - - # async_all_reduce = is_async_comm(handle_idx) - # else: - # handle_idx = ctx.handle_idx - # async_all_reduce = ctx.async_all_reduce - - # if handle_idx is not None and "bwd." 
in handle_idx and async_all_reduce is True: - # assert 1 == 1 - - op_name = ctx.op_name.replace("fwd.", "bwd.") if ctx.op_name is not None else ctx.op_name - async_all_reduce = is_async_comm(op_name) if ctx.op_name is not None else ctx.async_all_reduce - - if op_name is not None and "layer_mlp_27_batch_1" in op_name: - assert 1 == 1 - from nanotron.constants import _AUTOGRAD_RUNS _AUTOGRAD_RUNS.append(ctx.op_name) - return DifferentiableAllReduceSum.apply(grad_output, group, async_all_reduce, op_name), None, None, None - + group = ctx.group -def is_last_batch_of_attn(x): - import re + op_name = ctx.op_name.replace("fwd.", "bwd.") if ctx.op_name is not None else ctx.op_name + async_all_reduce = is_async_comm(op_name) if ctx.op_name is not None else ctx.async_all_reduce - pattern = r"layer_attn_\d+_batch_0" - if re.match(pattern, x): - return True - return False + return ( + DifferentiableAllReduceSum.apply(grad_output, group, async_all_reduce, op_name, ctx.comm_stream), + None, + None, + None, + ) class DifferentiableAllReduceSum(torch.autograd.Function): @@ -86,47 +67,25 @@ class DifferentiableAllReduceSum(torch.autograd.Function): @staticmethod def forward( - ctx, tensor, group: Optional[ProcessGroup], async_all_reduce: bool, op_name: str = None + ctx, + tensor, + group: Optional[ProcessGroup], + async_all_reduce: bool, + op_name: str = None, + comm_stream: torch.cuda.Stream = None, ) -> Tuple[torch.Tensor, Optional["dist.Work"]]: ctx.async_all_reduce = async_all_reduce + ctx.comm_stream = comm_stream if group.size() == 1: return tensor - # if handle_idx == "bwd.layer_mlp_1_batch_0": - # assert 1 == 1 - - # id(tensor) - # if async_all_reduce is True: - # # if isinstance(handle_idx, str): - # # do_async = is_last_batch_of_attn(handle_idx) is False - # # else: - # # do_async = async_all_reduce - # # from nanotron.parallel.comm import is_async_comm - # from nanotron.parallel.tensor_parallel.domino import is_async_comm - - # do_async = is_async_comm(handle_idx) - - # handle = dist.all_reduce(tensor, op=dist.ReduceOp.SUM, group=group, async_op=do_async) - # if do_async: - # if "bwd" in handle_idx: - # assert 1 == 1 - - # # # NOTE: id(tensor) is for the fwd pass, for the bwd pass, we do handle_idx - # # if handle_idx is not None and "bwd." 
in handle_idx: - # # AsyncCommBucket.add(orig_id if handle_idx is None else handle_idx, handle) - # # else: - # # AsyncCommBucket.add(orig_id, handle) - # # NOTE: id(tensor) is for the fwd pass, for the bwd pass, we do handle_idx - # assert handle_idx is not None - # AsyncCommBucket.add(handle_idx, handle) - # else: - # dist.all_reduce(tensor, op=dist.ReduceOp.SUM, group=group) - if async_all_reduce: - handle = dist.all_reduce(tensor, op=dist.ReduceOp.SUM, group=group, async_op=True) - AsyncCommBucket.add(op_name, handle) - else: - dist.all_reduce(tensor, op=dist.ReduceOp.SUM, group=group) + with torch.cuda.stream(comm_stream): + if async_all_reduce: + handle = dist.all_reduce(tensor, op=dist.ReduceOp.SUM, group=group, async_op=True) + AsyncCommBucket.add(op_name, handle) + else: + dist.all_reduce(tensor, op=dist.ReduceOp.SUM, group=group) return tensor diff --git a/src/nanotron/parallel/tensor_parallel/domino.py b/src/nanotron/parallel/tensor_parallel/domino.py index 7fb68abc..d864bcbf 100644 --- a/src/nanotron/parallel/tensor_parallel/domino.py +++ b/src/nanotron/parallel/tensor_parallel/domino.py @@ -4,6 +4,11 @@ from nanotron.parallel.comm import AsyncCommBucket +FWD_MLP_HANDLE_IDX = "fwd.layer_mlp_{}_batch_{}" +FWD_ATTN_HANDLE_IDX = "fwd.layer_attn_{}_batch_{}" +BWD_ATTN_HANDLE_IDX = "bwd.layer_attn_{}_batch_{}" +BWD_MLP_HANDLE_IDX = "bwd.layer_mlp_{}_batch_{}" + def is_async_comm(op_name: str): """ @@ -40,6 +45,9 @@ def backward(ctx, grad_output): if "bwd.layer_mlp_1_batch_0" == ctx.wait_handle_idx: assert 1 == 1 + if "bwd.layer_mlp_0_batch_1" == ctx.wait_handle_idx: + assert 1 == 1 + if is_async_comm(ctx.wait_handle_idx): from nanotron.constants import _AUTOGRAD_RUNS @@ -48,13 +56,12 @@ def backward(ctx, grad_output): assert handle is not None handle.wait() torch.cuda.default_stream().wait_stream(ctx.comm_stream) - # assert handle.is_completed() is True, f"ctx.wait_handle_idx: {ctx.wait_handle_idx}" else: - from nanotron import constants - # if dist.get_rank() == 0: - # constants._NOT_BWD_ASYNC_OPS.append(ctx.wait_handle_idx) constants._NOT_BWD_ASYNC_OPS.append(ctx.wait_handle_idx) + # if "bwd.layer_mlp_0_batch_1" == ctx.wait_handle_idx: + # assert AsyncCommBucket._copy_async_op.get(ctx.wait_handle_idx).is_completed() is True + return grad_output, None, None diff --git a/src/nanotron/trainer.py b/src/nanotron/trainer.py index dfec12c0..fb05ebf5 100644 --- a/src/nanotron/trainer.py +++ b/src/nanotron/trainer.py @@ -564,9 +564,6 @@ def training_step( self.config, self.parallel_context, self.unwrapped_model, self.grad_accumulator, self.optimizer ) - if dist.get_rank() == 0: - assert 1 == 1 - # Apply gradient self.optimizer.step() self.optimizer.zero_grad() @@ -583,11 +580,8 @@ def training_step( from nanotron.parallel.comm import AsyncCommBucket - # import torch.distributed as dist - not_finished = [] for k, v in AsyncCommBucket._copy_async_op.items(): - # assert v.is_completed(), f"AsyncCommBucket._copy_async_op: {AsyncCommBucket._copy_async_op}" if v.is_completed() is not True: not_finished.append((k, v)) From ea09a25ef6c8493cc7be2779ce9a581ba27b7dca Mon Sep 17 00:00:00 2001 From: Phuc Nguyen <b3f0cus@icloud.com> Date: Wed, 12 Feb 2025 13:39:08 +0000 Subject: [PATCH 13/17] refactor --- examples/config_llama_domino.yaml | 2 +- src/nanotron/optim/gradient_accumulator.py | 4 ---- src/nanotron/parallel/comm.py | 10 ++++++++ .../parallel/pipeline_parallel/engine.py | 3 --- .../distributed_differentiable_primitives.py | 8 ------- .../parallel/tensor_parallel/domino.py | 23 
+------------------ .../parallel/tensor_parallel/functional.py | 22 +----------------- src/nanotron/sanity_checks.py | 4 ++++ src/nanotron/trainer.py | 16 ++----------- 9 files changed, 19 insertions(+), 73 deletions(-) diff --git a/examples/config_llama_domino.yaml b/examples/config_llama_domino.yaml index b9811fdd..30f59161 100644 --- a/examples/config_llama_domino.yaml +++ b/examples/config_llama_domino.yaml @@ -20,7 +20,7 @@ data_stages: general: benchmark_csv_path: null consumed_train_samples: null - ignore_sanity_checks: true + ignore_sanity_checks: false project: nanotron_domino run: config_llama_domino seed: 42 diff --git a/src/nanotron/optim/gradient_accumulator.py b/src/nanotron/optim/gradient_accumulator.py index b5ef7d89..2e940744 100644 --- a/src/nanotron/optim/gradient_accumulator.py +++ b/src/nanotron/optim/gradient_accumulator.py @@ -202,10 +202,6 @@ def build_grad_buffers( return fp32_grad_buffers, contiguous_buffer_f32_gradients def backward(self, loss: torch.Tensor): - if not isinstance(loss, torch.Tensor): - assert 1 == 1 - raise NotImplementedError("Not implemented yet") - result = loss.backward() for name, elt in self.fp32_grad_buffers.items(): diff --git a/src/nanotron/parallel/comm.py b/src/nanotron/parallel/comm.py index 6dbb041f..63736718 100644 --- a/src/nanotron/parallel/comm.py +++ b/src/nanotron/parallel/comm.py @@ -57,6 +57,16 @@ def wait(tensor_id: int): work = AsyncCommBucket._async_op.pop(tensor_id) work.wait() + @staticmethod + def is_all_completed() -> bool: + assert len(AsyncCommBucket._async_op) == 0, "there are still some async ops haven't executed" + + not_finished = [] + for k, v in AsyncCommBucket._copy_async_op.items(): + if v.is_completed() is not True: + not_finished.append((k, v)) + return len(not_finished) == 0 + @staticmethod def clear_all(): AsyncCommBucket._async_op.clear() diff --git a/src/nanotron/parallel/pipeline_parallel/engine.py b/src/nanotron/parallel/pipeline_parallel/engine.py index 8160f302..076943c7 100644 --- a/src/nanotron/parallel/pipeline_parallel/engine.py +++ b/src/nanotron/parallel/pipeline_parallel/engine.py @@ -84,9 +84,6 @@ def backward( if grad_accumulator is None: sum(activations).backward() else: - # if not isinstance(activations, torch.Tensor): - # raise NotImplementedError("Only support sum of tensors for now") - grad_accumulator.backward(sum(activations)) # TODO @nouamane: this fixes interleaved afab but makes 1f1b hang diff --git a/src/nanotron/parallel/tensor_parallel/distributed_differentiable_primitives.py b/src/nanotron/parallel/tensor_parallel/distributed_differentiable_primitives.py index 5446f571..1254fdb1 100644 --- a/src/nanotron/parallel/tensor_parallel/distributed_differentiable_primitives.py +++ b/src/nanotron/parallel/tensor_parallel/distributed_differentiable_primitives.py @@ -43,17 +43,9 @@ def forward( @staticmethod def backward(ctx, grad_output): - # import pydevd - # pydevd.settrace(suspend=False, trace_only_current_thread=True) - from nanotron.constants import _AUTOGRAD_RUNS - - _AUTOGRAD_RUNS.append(ctx.op_name) - group = ctx.group - op_name = ctx.op_name.replace("fwd.", "bwd.") if ctx.op_name is not None else ctx.op_name async_all_reduce = is_async_comm(op_name) if ctx.op_name is not None else ctx.async_all_reduce - return ( DifferentiableAllReduceSum.apply(grad_output, group, async_all_reduce, op_name, ctx.comm_stream), None, diff --git a/src/nanotron/parallel/tensor_parallel/domino.py b/src/nanotron/parallel/tensor_parallel/domino.py index d864bcbf..e35cac06 100644 --- 
a/src/nanotron/parallel/tensor_parallel/domino.py +++ b/src/nanotron/parallel/tensor_parallel/domino.py @@ -39,29 +39,8 @@ def forward(ctx, input, wait_handle_idx, comm_stream): @staticmethod def backward(ctx, grad_output): - # import pydevd - # pydevd.settrace(suspend=False, trace_only_current_thread=True) - - if "bwd.layer_mlp_1_batch_0" == ctx.wait_handle_idx: - assert 1 == 1 - - if "bwd.layer_mlp_0_batch_1" == ctx.wait_handle_idx: - assert 1 == 1 - if is_async_comm(ctx.wait_handle_idx): - from nanotron.constants import _AUTOGRAD_RUNS - - _AUTOGRAD_RUNS.append(f"wait_{ctx.wait_handle_idx}") - handle = AsyncCommBucket.pop(ctx.wait_handle_idx) - assert handle is not None - handle.wait() + AsyncCommBucket.wait(ctx.wait_handle_idx) torch.cuda.default_stream().wait_stream(ctx.comm_stream) - else: - from nanotron import constants - - constants._NOT_BWD_ASYNC_OPS.append(ctx.wait_handle_idx) - - # if "bwd.layer_mlp_0_batch_1" == ctx.wait_handle_idx: - # assert AsyncCommBucket._copy_async_op.get(ctx.wait_handle_idx).is_completed() is True return grad_output, None, None diff --git a/src/nanotron/parallel/tensor_parallel/functional.py b/src/nanotron/parallel/tensor_parallel/functional.py index 6d69408c..915a2c31 100644 --- a/src/nanotron/parallel/tensor_parallel/functional.py +++ b/src/nanotron/parallel/tensor_parallel/functional.py @@ -19,6 +19,7 @@ from torch.nn import functional as F import nanotron.distributed as dist +from nanotron.parallel.comm import AsyncCommBucket from nanotron.parallel.tensor_parallel.distributed_differentiable_primitives import ( differentiable_all_reduce_sum, differentiable_identity, @@ -600,30 +601,9 @@ def row_linear( out = F.linear(input, weight, bias) if tp_mode is TensorParallelLinearMode.ALL_REDUCE: - # out, work = differentiable_all_reduce_sum(out, group=group, async_all_reduce=async_all_reduce) - # id(out) - # NOTE: why the id(out) doesn't match the id(out) before the all_reduce? 
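[Editor's aside] The WaitComm cleanup in this hunk reduces to a small identity autograd op: the forward pass is a no-op that remembers an op name and a communication stream, and the backward pass pops the matching async handle from the bucket, waits on it, and then makes the default stream wait on the comm stream. The sketch below is a minimal, self-contained illustration of that pattern; ToyCommBucket and ToyWaitComm are stand-ins written for this note, not nanotron APIs, and the stream argument may be None so the snippet also runs without CUDA.

    from typing import Dict, Optional

    import torch


    class ToyCommBucket:
        # op_name -> async work handle (anything exposing .wait()), mirroring AsyncCommBucket
        _ops: Dict[str, object] = {}

        @classmethod
        def add(cls, op_name: str, work) -> None:
            assert op_name not in cls._ops, f"{op_name} already registered"
            cls._ops[op_name] = work

        @classmethod
        def wait(cls, op_name: str) -> None:
            cls._ops.pop(op_name).wait()


    class ToyWaitComm(torch.autograd.Function):
        """Identity in forward; in backward, block on the bucketed handle before
        letting gradients flow, then order compute after the communication stream."""

        @staticmethod
        def forward(ctx, x: torch.Tensor, op_name: str, comm_stream: Optional["torch.cuda.Stream"]):
            ctx.op_name = op_name
            ctx.comm_stream = comm_stream
            return x

        @staticmethod
        def backward(ctx, grad_output: torch.Tensor):
            ToyCommBucket.wait(ctx.op_name)  # wait for the async all-reduce issued earlier
            if ctx.comm_stream is not None:
                torch.cuda.default_stream().wait_stream(ctx.comm_stream)
            return grad_output, None, None

In the real layer, the handle stored under op_name is registered by the asynchronous all-reduce of the matching forward pass, so this backward-side wait is what turns the bucket into a cross-pass dependency.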
- if op_name == "fwd.layer_attn_0_batch_0": - assert 1 == 1 - - if op_name == "fwd.layer_mlp_0_batch_1": - assert 1 == 1 - - if op_name == "fwd.layer_attn_0_batch_0": - assert 1 == 1 - out = differentiable_all_reduce_sum(out, group=group, async_all_reduce=async_all_reduce, op_name=op_name) if async_all_reduce: - from nanotron.parallel.comm import AsyncCommBucket - - # work = AsyncCommBucket.get(orig_out_id) - # work = AsyncCommBucket.pop(orig_out_id) - # if handle_idx == "fwd.layer_mlp_1_batch_0": - if op_name == "fwd.layer_attn_0_batch_0": - assert 1 == 1 - work = AsyncCommBucket.pop(op_name) - assert 1 == 1 else: work = None elif tp_mode is TensorParallelLinearMode.REDUCE_SCATTER: diff --git a/src/nanotron/sanity_checks.py b/src/nanotron/sanity_checks.py index 56ef1e2e..9d1a1589 100644 --- a/src/nanotron/sanity_checks.py +++ b/src/nanotron/sanity_checks.py @@ -10,6 +10,7 @@ from nanotron.models import NanotronModel from nanotron.optim.gradient_accumulator import GradientAccumulator from nanotron.parallel import ParallelContext +from nanotron.parallel.comm import AsyncCommBucket from nanotron.parallel.tied_parameters import get_tied_id_to_param logger = get_logger(__name__) @@ -239,6 +240,9 @@ def before_optim_step_sanity_checks( # SANITY CHECK: run model specific sanity checks unwrapped_model.before_optim_step_sanity_checks() + # SANITY CHECK: for domino + assert AsyncCommBucket.is_all_completed(), "There are still some async ops haven't finishing" + def after_optim_step_sanity_checks( config: Config, diff --git a/src/nanotron/trainer.py b/src/nanotron/trainer.py index fb05ebf5..35686590 100644 --- a/src/nanotron/trainer.py +++ b/src/nanotron/trainer.py @@ -61,6 +61,7 @@ from nanotron.models.starcoder2 import Starcoder2ForTraining from nanotron.optim.clip_grads import clip_grad_norm from nanotron.parallel import ParallelContext +from nanotron.parallel.comm import AsyncCommBucket from nanotron.parallel.data_parallel.utils import sync_gradients_across_dp from nanotron.parallel.parameters import NanotronParameter, sanity_check from nanotron.parallel.pipeline_parallel.engine import ( @@ -563,6 +564,7 @@ def training_step( before_optim_step_sanity_checks( self.config, self.parallel_context, self.unwrapped_model, self.grad_accumulator, self.optimizer ) + AsyncCommBucket.clear_all() # Apply gradient self.optimizer.step() @@ -578,20 +580,6 @@ def training_step( self.post_train_step() - from nanotron.parallel.comm import AsyncCommBucket - - not_finished = [] - for k, v in AsyncCommBucket._copy_async_op.items(): - if v.is_completed() is not True: - not_finished.append((k, v)) - - # if dist.get_rank() == 0 and constants._NOT_BWD_ASYNC_OPS: - # assert 1 == 1 - - assert len(not_finished) == 0, f"len={len(not_finished)}, AsyncCommBucket._copy_async_op: {not_finished}" - assert len(AsyncCommBucket._async_op) == 0, f"AsyncCommBucket._async_op: {AsyncCommBucket._async_op}" - AsyncCommBucket.clear_all() - return outputs, loss_avg def validation_step(self, dataloader: Iterator[Dict[str, Union[torch.Tensor, TensorPointer]]]) -> Iterable[Dict]: From 7761d82401d7f3e9eb1a106007c0da2e1f6460e0 Mon Sep 17 00:00:00 2001 From: Phuc Nguyen <b3f0cus@icloud.com> Date: Wed, 12 Feb 2025 16:35:19 +0000 Subject: [PATCH 14/17] refactor --- examples/config_llama_domino.yaml | 2 +- src/nanotron/config/parallelism_config.py | 8 + src/nanotron/constants.py | 7 - src/nanotron/models/llama.py | 171 +++++++++--------- src/nanotron/parallel/comm.py | 6 +- .../parallel/tensor_parallel/domino.py | 8 +- src/nanotron/trainer.py | 5 + 
tests/helpers/llama.py | 49 ++--- tests/test_base_model.py | 32 +++- tests/test_domino.py | 52 ++++++ 10 files changed, 214 insertions(+), 126 deletions(-) diff --git a/examples/config_llama_domino.yaml b/examples/config_llama_domino.yaml index 30f59161..b9811fdd 100644 --- a/examples/config_llama_domino.yaml +++ b/examples/config_llama_domino.yaml @@ -20,7 +20,7 @@ data_stages: general: benchmark_csv_path: null consumed_train_samples: null - ignore_sanity_checks: false + ignore_sanity_checks: true project: nanotron_domino run: config_llama_domino seed: 42 diff --git a/src/nanotron/config/parallelism_config.py b/src/nanotron/config/parallelism_config.py index 2701bf9c..07688959 100644 --- a/src/nanotron/config/parallelism_config.py +++ b/src/nanotron/config/parallelism_config.py @@ -25,6 +25,7 @@ class DominoArgs: def __post_init__(self): assert self.num_input_batches > 1, "In order to enable domino mode, set num_input_batches > 1" + assert self.num_input_batches == 2, "Currently parallelism only supports 2 batches for Domino" @dataclass @@ -68,3 +69,10 @@ def __post_init__(self): self.pp_engine = cast_str_to_pipeline_engine(self.pp_engine) if isinstance(self.tp_mode, str): self.tp_mode = TensorParallelLinearMode[self.tp_mode.upper()] + + if self.is_domino_enabled is True: + assert self.tp > 1, "Domino requires TP > 1" + + @property + def is_domino_enabled(self) -> bool: + return True if self.domino else False diff --git a/src/nanotron/constants.py b/src/nanotron/constants.py index 3fe440a8..580bd99d 100644 --- a/src/nanotron/constants.py +++ b/src/nanotron/constants.py @@ -10,10 +10,3 @@ CHECKPOINT_FILE_NAME = "checkpoint_metadata.json" MODEL_CONFIG_FILE_NAME = "model_config.json" - - -CUDA_STREAMS = {} - -CLOCK = 0 -_AUTOGRAD_RUNS = [] -_NOT_BWD_ASYNC_OPS = [] diff --git a/src/nanotron/models/llama.py b/src/nanotron/models/llama.py index b6785604..ca0b50c6 100644 --- a/src/nanotron/models/llama.py +++ b/src/nanotron/models/llama.py @@ -30,8 +30,7 @@ from nanotron.nn.activations import ACT2FN from nanotron.nn.layer_norm import TritonRMSNorm from nanotron.parallel import ParallelContext - -# from nanotron.parallel.comm import WaitComm +from nanotron.parallel.comm import CudaStreamManager from nanotron.parallel.parameters import NanotronParameter from nanotron.parallel.pipeline_parallel.block import PipelineBlock, TensorPointer from nanotron.parallel.pipeline_parallel.p2p import P2P @@ -55,8 +54,6 @@ logger = logging.get_logger(__name__) -DOMINO_COMM_STREAM = "domino_comm_stream_{}" - class RotaryEmbedding(nn.Module): def __init__(self, dim: int, end: int, theta: float = 10000.0): @@ -248,11 +245,10 @@ def __init__( mode=tp_mode, bias=False, async_communication=tp_linear_async_communication and tp_mode is TensorParallelLinearMode.REDUCE_SCATTER, - # async_all_reduce=parallel_config.domino.num_input_batches > 1, ) self.split_silu_mul = GLUActivation(config.hidden_act) - def forward(self, hidden_states, op_name): # [seq_length, batch_size, hidden_dim] + def forward(self, hidden_states, op_name: str = None): # [seq_length, batch_size, hidden_dim] merged_states = self.gate_up_proj(hidden_states, op_name=op_name) hidden_states, work = self.down_proj(self.split_silu_mul(merged_states), op_name=op_name) return {"hidden_states": hidden_states, "work": work} @@ -347,7 +343,6 @@ def __init__( parallel_config: Optional[ParallelismArgs], tp_pg: dist.ProcessGroup, layer_idx: int, - async_all_reduce: bool = False, ): from flash_attn.layers.rotary import RotaryEmbedding as FlashRotaryEmbedding @@ -431,7 
+426,6 @@ def __init__( mode=tp_mode, bias=False, async_communication=tp_linear_async_communication, - # async_all_reduce=async_all_reduce, ) self.attention = CoreAttention( @@ -448,7 +442,7 @@ def forward( self, hidden_states, # [seq_length, batch_size, hidden_size] sequence_mask, # [batch_size, seq_length] - op_name, + op_name: str = None, ): from flash_attn import bert_padding from flash_attn.flash_attn_interface import ( @@ -707,7 +701,7 @@ def forward( return {"hidden_states": output, "work": work, "sequence_mask": sequence_mask} -class LlamaDecoderLayer(nn.Module): +class _BaseLlamaDecoderLayer(nn.Module): def __init__( self, config: LlamaConfig, @@ -723,7 +717,6 @@ def __init__( parallel_config=parallel_config, tp_pg=tp_pg, layer_idx=layer_idx, - async_all_reduce=parallel_config.domino.num_input_batches > 1, ) self.post_attention_layernorm = TritonRMSNorm(config.hidden_size, eps=config.rms_norm_eps) @@ -731,31 +724,93 @@ def __init__( self.recompute_layer = parallel_config.recompute_layer self.parallel_config = parallel_config + self.layer_idx = layer_idx - # if parallel_config.domino is not None and parallel_config.domino.num_input_batches > 1: - # from nanotron.parallel.comm import CudaStreamManager - # # NOTE: we use different cuda streams for different gpus, so it can overlaps the communication - # CudaStreamManager.create(DOMINO_COMM_STREAM.format(torch.cuda.current_device())) - num_gpus = torch.cuda.device_count() - for i in range(num_gpus): - from nanotron import constants + def _checkpointed_forward( + self, + hidden_states: torch.Tensor, + sequence_mask: torch.Tensor, + ) -> List[torch.Tensor]: + return CheckpointFunction.apply(self._core_forward, True, hidden_states, sequence_mask) - constants.CUDA_STREAMS[i] = torch.cuda.Stream(device=torch.device(f"cuda:{i}")) + def forward( + self, + hidden_states: Union[torch.Tensor, TensorPointer], + sequence_mask: Union[torch.Tensor, TensorPointer], + ) -> Dict[str, Union[torch.Tensor, TensorPointer]]: - self.layer_idx = layer_idx + if self.recompute_layer and not isinstance(hidden_states, TensorPointer): + hidden_states, sequence_mask = self._checkpointed_forward(hidden_states, sequence_mask) + else: + hidden_states, sequence_mask = self._core_forward(hidden_states, sequence_mask) + + return { + "hidden_states": hidden_states, + "sequence_mask": sequence_mask, + } + +class LlamaDecoderLayer(_BaseLlamaDecoderLayer): def _core_forward( self, hidden_states: Union[torch.Tensor, TensorPointer], sequence_mask: Union[torch.Tensor, TensorPointer], ) -> List[Union[torch.Tensor, TensorPointer]]: - from nanotron import constants + residual = hidden_states + hidden_states = self.input_layernorm(hidden_states) + + output = self.attn(hidden_states=hidden_states, sequence_mask=sequence_mask) + hidden_states = output["hidden_states"] + hidden_states = hidden_states + residual + + residual = hidden_states + hidden_states = self.post_attention_layernorm(hidden_states) + hidden_states = self.mlp(hidden_states=hidden_states)["hidden_states"] + hidden_states = hidden_states + residual + + return hidden_states, output["sequence_mask"] + +class Embedding(nn.Module, AttachableStore): + def __init__(self, tp_pg: dist.ProcessGroup, config: LlamaConfig, parallel_config: Optional[ParallelismArgs]): + super().__init__() + self.token_embedding = TensorParallelEmbedding( + num_embeddings=config.vocab_size, + embedding_dim=config.hidden_size, + padding_idx=config.pad_token_id, + pg=tp_pg, + mode=parallel_config.tp_mode if parallel_config is not None else 
TensorParallelLinearMode.ALL_REDUCE, + ) + self.pg = tp_pg + + def forward(self, input_ids: torch.Tensor, input_mask: torch.Tensor): # [batch_size, seq_length] + store = self.get_local_store() + if store is not None: + if "past_length" in store: + past_length = store["past_length"] + else: + past_length = torch.zeros(1, dtype=torch.long, device=input_ids.device).expand(input_ids.shape[0]) + + cumsum_mask = input_mask.cumsum(-1, dtype=torch.long) + # Store new past_length in store + store["past_length"] = past_length + cumsum_mask[:, -1] + + # Format input in `[seq_length, batch_size]` to support high TP with low batch_size + input_ids = input_ids.transpose(0, 1) + input_embeds = self.token_embedding(input_ids) + return {"input_embeds": input_embeds} + + +class DominoLlamaDecoderLayer(_BaseLlamaDecoderLayer): + def _core_forward( + self, + hidden_states: Union[torch.Tensor, TensorPointer], + sequence_mask: Union[torch.Tensor, TensorPointer], + ) -> List[Union[torch.Tensor, TensorPointer]]: num_input_batches = self.parallel_config.domino.num_input_batches orig_sequence_mask = sequence_mask - comm_stream = constants.CUDA_STREAMS[torch.cuda.current_device()] + comm_stream = CudaStreamManager.get(f"comm_stream_{torch.cuda.current_device()}") - assert num_input_batches == 2 hidden_states = torch.chunk(hidden_states, chunks=num_input_batches, dim=1) sequence_mask = torch.chunk(sequence_mask, chunks=num_input_batches, dim=0) @@ -788,20 +843,13 @@ def _core_forward( hidden_states0 = attn_output0["hidden_states"] + residual0 residual0 = hidden_states0 hidden_states0 = self.post_attention_layernorm(hidden_states0) - hidden_states0 = WaitComm.apply( - hidden_states0, BWD_MLP_HANDLE_IDX.format(self.layer_idx, 1), comm_stream - ) # new + hidden_states0 = WaitComm.apply(hidden_states0, BWD_MLP_HANDLE_IDX.format(self.layer_idx, 1), comm_stream) mlp_output0 = self.mlp( hidden_states=hidden_states0, op_name=FWD_MLP_HANDLE_IDX.format(self.layer_idx, 0), ) - # attn_output1["hidden_states"], mlp_output0["hidden_states"] = depend( - # run_after=attn_output1["hidden_states"], - # run_before=mlp_output0["hidden_states"] - # ) - with torch.cuda.stream(comm_stream): attn_output1["work"].wait() attn_output1["work"].is_completed() @@ -822,70 +870,14 @@ def _core_forward( mlp_output0["work"].is_completed() torch.cuda.current_stream().wait_stream(comm_stream) - # torch.cuda.synchronize() hidden_states0 = mlp_output0["hidden_states"] + residual0 hidden_states1 = mlp_output1["hidden_states"] + residual1 - # hidden_states0, hidden_states1 = depend(run_after=hidden_states0, run_before=hidden_states1) - hidden_states = torch.cat([hidden_states0, hidden_states1], dim=1) return hidden_states, orig_sequence_mask - def _checkpointed_forward( - self, - hidden_states: torch.Tensor, - sequence_mask: torch.Tensor, - ) -> List[torch.Tensor]: - return CheckpointFunction.apply(self._core_forward, True, hidden_states, sequence_mask) - - def forward( - self, - hidden_states: Union[torch.Tensor, TensorPointer], - sequence_mask: Union[torch.Tensor, TensorPointer], - ) -> Dict[str, Union[torch.Tensor, TensorPointer]]: - - if self.recompute_layer and not isinstance(hidden_states, TensorPointer): - hidden_states, sequence_mask = self._checkpointed_forward(hidden_states, sequence_mask) - else: - hidden_states, sequence_mask = self._core_forward(hidden_states, sequence_mask) - - return { - "hidden_states": hidden_states, - "sequence_mask": sequence_mask, - } - - -class Embedding(nn.Module, AttachableStore): - def __init__(self, tp_pg: 
dist.ProcessGroup, config: LlamaConfig, parallel_config: Optional[ParallelismArgs]): - super().__init__() - self.token_embedding = TensorParallelEmbedding( - num_embeddings=config.vocab_size, - embedding_dim=config.hidden_size, - padding_idx=config.pad_token_id, - pg=tp_pg, - mode=parallel_config.tp_mode if parallel_config is not None else TensorParallelLinearMode.ALL_REDUCE, - ) - self.pg = tp_pg - - def forward(self, input_ids: torch.Tensor, input_mask: torch.Tensor): # [batch_size, seq_length] - store = self.get_local_store() - if store is not None: - if "past_length" in store: - past_length = store["past_length"] - else: - past_length = torch.zeros(1, dtype=torch.long, device=input_ids.device).expand(input_ids.shape[0]) - - cumsum_mask = input_mask.cumsum(-1, dtype=torch.long) - # Store new past_length in store - store["past_length"] = past_length + cumsum_mask[:, -1] - - # Format input in `[seq_length, batch_size]` to support high TP with low batch_size - input_ids = input_ids.transpose(0, 1) - input_embeds = self.token_embedding(input_ids) - return {"input_embeds": input_embeds} - class LlamaModel(nn.Module): """Build pipeline graph""" @@ -896,6 +888,8 @@ def __init__( parallel_context: ParallelContext, parallel_config: Optional[ParallelismArgs], ): + # from nanotron.parallel.tensor_parallel.domino import DominoLlamaDecoderLayer + super().__init__() # Declare all the nodes @@ -931,7 +925,7 @@ def __init__( [ PipelineBlock( p2p=self.p2p, - module_builder=LlamaDecoderLayer, + module_builder=DominoLlamaDecoderLayer if parallel_config.is_domino_enabled else LlamaDecoderLayer, module_kwargs={ "config": config, "parallel_config": parallel_config, @@ -992,7 +986,6 @@ def forward_with_hidden_states( input_mask: Union[torch.Tensor, TensorPointer], # [batch_size, seq_length] ): # all tensors are optional as most ranks don't need anything from the dataloader. 
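[Editor's aside] As a reading aid for the new DominoLlamaDecoderLayer._core_forward above: the layer splits its input into two micro-batches along the batch dimension and interleaves them so that one micro-batch's tensor-parallel all-reduce overlaps with the other micro-batch's compute, joining everything before the final residual adds. The schedule is sketched below with the attention/MLP blocks and the communication handles replaced by stubs (attn, mlp and _FakeWork are illustrative only); in the real code the wait() calls are issued under the dedicated comm stream and the WaitComm markers insert the mirror-image dependencies for the backward pass.

    import torch


    class _FakeWork:
        # stands in for the dist.Work handle returned by an async tensor-parallel all-reduce
        def wait(self):
            pass


    def attn(x):
        return x, _FakeWork()  # (hidden_states, all-reduce handle)


    def mlp(x):
        return x, _FakeWork()


    def domino_block(hidden_states: torch.Tensor, input_norm, post_attn_norm) -> torch.Tensor:
        # [seq_length, batch_size, hidden] -> two micro-batches along the batch dim
        h0, h1 = torch.chunk(hidden_states, chunks=2, dim=1)
        res0, res1 = h0, h1

        # launch attention for both micro-batches; their all-reduces run asynchronously
        a0, work_a0 = attn(input_norm(h0))
        a1, work_a1 = attn(input_norm(h1))

        # finish micro-batch 0 while micro-batch 1's attention all-reduce is still in flight
        work_a0.wait()  # done under `with torch.cuda.stream(comm_stream)` in nanotron
        h0 = a0 + res0
        res0 = h0
        m0, work_m0 = mlp(post_attn_norm(h0))

        # symmetric step for micro-batch 1, overlapping with micro-batch 0's MLP compute
        work_a1.wait()
        h1 = a1 + res1
        res1 = h1
        m1, work_m1 = mlp(post_attn_norm(h1))

        # join: all pending all-reduces must finish before the final residual adds
        work_m0.wait()
        work_m1.wait()
        return torch.cat([m0 + res0, m1 + res1], dim=1)

For example, domino_block(torch.randn(16, 4, 32), torch.nn.Identity(), torch.nn.Identity()) runs the schedule end to end on CPU.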
- output = self.token_position_embeddings(input_ids=input_ids, input_mask=input_mask) hidden_encoder_states = { diff --git a/src/nanotron/parallel/comm.py b/src/nanotron/parallel/comm.py index 63736718..248b966c 100644 --- a/src/nanotron/parallel/comm.py +++ b/src/nanotron/parallel/comm.py @@ -8,12 +8,14 @@ class CudaStreamManager: _streams: Dict[str, "torch.cuda.Stream"] = {} @staticmethod - def create(name: str): + def create(name: str, device: torch.device = None): assert name not in CudaStreamManager._streams - CudaStreamManager._streams[name] = torch.cuda.Stream() + CudaStreamManager._streams[name] = torch.cuda.Stream(device=device) @staticmethod def get(name: str): + if name not in CudaStreamManager._streams: + CudaStreamManager.create(name) return CudaStreamManager._streams.get(name) @contextmanager diff --git a/src/nanotron/parallel/tensor_parallel/domino.py b/src/nanotron/parallel/tensor_parallel/domino.py index e35cac06..26050c3a 100644 --- a/src/nanotron/parallel/tensor_parallel/domino.py +++ b/src/nanotron/parallel/tensor_parallel/domino.py @@ -4,6 +4,11 @@ from nanotron.parallel.comm import AsyncCommBucket +# from nanotron.models.llama import _BaseLlamaDecoderLayer +# from nanotron.parallel.pipeline_parallel.block import TensorPointer +# from nanotron.parallel.comm import CudaStreamManager + + FWD_MLP_HANDLE_IDX = "fwd.layer_mlp_{}_batch_{}" FWD_ATTN_HANDLE_IDX = "fwd.layer_attn_{}_batch_{}" BWD_ATTN_HANDLE_IDX = "bwd.layer_attn_{}_batch_{}" @@ -17,10 +22,7 @@ def is_async_comm(op_name: str): for the backward pass: the first micro-batch of the attention layer """ NON_ASYNC_HANDLE_IDX = [ - # "fwd.layer_attn_{}_batch_0", - # "fwd.layer_mlp_{}_batch_0", "fwd.layer_mlp_{}_batch_1", - # "bwd.layer_mlp_{}_batch_1", "bwd.layer_attn_{}_batch_0", ] diff --git a/src/nanotron/trainer.py b/src/nanotron/trainer.py index 35686590..f5370525 100644 --- a/src/nanotron/trainer.py +++ b/src/nanotron/trainer.py @@ -444,6 +444,11 @@ def train( # free memory gc.collect() torch.cuda.empty_cache() + + # num_gpus = torch.cuda.device_count() + # for i in range(num_gpus): + # CudaStreamManager.create(f"comm_stream_{i}", device=torch.device(f"cuda:{i}")) + with prof: for self.iteration_step in range(self.initial_iter_step, self.last_iter_step + 1): if isinstance(prof, torch.profiler.profile): diff --git a/tests/helpers/llama.py b/tests/helpers/llama.py index 3f94031f..8aae7669 100644 --- a/tests/helpers/llama.py +++ b/tests/helpers/llama.py @@ -1,5 +1,6 @@ import torch from nanotron.config import ( + AdamWOptimizerArgs, AllForwardAllBackwardPipelineEngine, CheckpointsArgs, Config, @@ -46,7 +47,19 @@ ) -def get_llama_training_config(model_config: ModelArgs): +def get_parallel_config(parallel_context: ParallelContext): + return ParallelismArgs( + dp=parallel_context.data_parallel_size, + pp=parallel_context.pipeline_parallel_size, + tp=parallel_context.tensor_parallel_size, + expert_parallel_size=parallel_context.expert_parallel_size, + pp_engine=AllForwardAllBackwardPipelineEngine(), + tp_mode=TensorParallelLinearMode.ALL_REDUCE, + tp_linear_async_communication=False, + ) + + +def get_llama_training_config(model_config: ModelArgs, parallel_context): return Config( model=model_config, general=GeneralArgs(project="unittest", run="sanity_llama", seed=42), @@ -54,25 +67,20 @@ def get_llama_training_config(model_config: ModelArgs): checkpoints_path="./checkpoints", checkpoint_interval=10, ), - parallelism=ParallelismArgs( - dp=1, - pp=1, - tp=2, - expert_parallel_size=2, - pp_engine="1f1b", - 
tp_mode="ALL_REDUCE", - tp_linear_async_communication=False, - ), + parallelism=get_parallel_config(parallel_context), tokenizer=TokenizerArgs("gpt2"), optimizer=OptimizerArgs( zero_stage=0, weight_decay=0.01, clip_grad=1.0, accumulate_grad_in_fp32=False, - adam_eps=1e-08, - adam_beta1=0.9, - adam_beta2=0.95, - torch_adam_is_fused=True, + optimizer_factory=AdamWOptimizerArgs( + adam_eps=1e-08, + adam_beta1=0.9, + adam_beta2=0.95, + torch_adam_is_fused=True, + name="adamW", + ), learning_rate_scheduler=LRSchedulerArgs( learning_rate=3e-4, lr_warmup_steps=100, @@ -103,7 +111,10 @@ def get_llama_training_config(model_config: ModelArgs): def create_llama_from_config( - model_config: LlamaConfig, device: torch.device, parallel_context: ParallelContext + model_config: LlamaConfig, + parallel_config: ParallelismArgs, + device: torch.device, + parallel_context: ParallelContext, ) -> LlamaForTraining: """ @@ -114,14 +125,6 @@ def create_llama_from_config( the model created will have random weights. """ - parallel_config = ParallelismArgs( - dp=parallel_context.data_parallel_size, - pp=parallel_context.pipeline_parallel_size, - tp=parallel_context.tensor_parallel_size, - pp_engine=AllForwardAllBackwardPipelineEngine(), - tp_mode=TensorParallelLinearMode.ALL_REDUCE, - tp_linear_async_communication=False, - ) model = build_model( model_builder=lambda: LlamaForTraining( config=model_config, diff --git a/tests/test_base_model.py b/tests/test_base_model.py index b4759905..410e302c 100644 --- a/tests/test_base_model.py +++ b/tests/test_base_model.py @@ -10,7 +10,6 @@ @pytest.mark.parametrize("tp,dp,pp", [(1, 1, 1), (2, 2, 2)]) -@pytest.mark.skip @rerun_if_address_is_in_use() def test_get_named_modules_in_pp_rank(tp: int, dp: int, pp: int): model_args = ModelArgs(init_method=RandomInit(std=1.0), model_config=TINY_LLAMA_CONFIG) @@ -43,3 +42,34 @@ def _test_get_named_modules_in_pp_rank( # not PipelineBlock assert isinstance(module, nn.Module) assert name not in modules_that_not_in_current_pp_rank + + +@pytest.mark.parametrize("tp,dp,pp", [(1, 1, 1), (2, 2, 1)]) +@rerun_if_address_is_in_use() +def test_llama_model(tp: int, dp: int, pp: int): + BATCH_SIZE, SEQ_LEN = 10, 128 + model_args = ModelArgs(init_method=RandomInit(std=1.0), model_config=TINY_LLAMA_CONFIG) + config = get_llama_training_config(model_args) + + init_distributed(tp=tp, dp=dp, pp=pp)(_test_llama_model)(config=config, batch_size=BATCH_SIZE, seq_len=SEQ_LEN) + + +def _test_llama_model( + parallel_context: ParallelContext, + config: Config, + batch_size: int, + seq_len: int, +): + llama_model = create_llama_from_config( + model_config=config.model.model_config, + device=torch.device("cuda"), + parallel_context=parallel_context, + ) + llama_model.init_model_randomly(config=config) + + input_ids = torch.randint(0, config.model.model_config.vocab_size, size=(batch_size, seq_len), device="cuda") + input_mask = torch.ones_like(input_ids) + outputs = llama_model(input_ids, input_mask, input_mask, input_mask) + + assert list(outputs.keys()) == ["loss"] + assert isinstance(outputs["loss"], torch.Tensor) diff --git a/tests/test_domino.py b/tests/test_domino.py index 8f474ff8..977e5407 100644 --- a/tests/test_domino.py +++ b/tests/test_domino.py @@ -1,4 +1,14 @@ +from copy import deepcopy + import pytest +import torch +from helpers.llama import TINY_LLAMA_CONFIG, create_llama_from_config, get_llama_training_config +from helpers.utils import init_distributed, rerun_if_address_is_in_use +from nanotron.config import ModelArgs, RandomInit +from 
nanotron.config.parallelism_config import DominoArgs +from nanotron.models.llama import DominoLlamaDecoderLayer +from nanotron.parallel import ParallelContext +from nanotron.parallel.comm import AsyncCommBucket from nanotron.parallel.tensor_parallel.domino import is_async_comm @@ -17,3 +27,45 @@ ) def test_is_async_comm(op_name, expected): assert is_async_comm(op_name) == expected + + +@pytest.mark.parametrize("tp,dp,pp", [(2, 2, 1)]) +@rerun_if_address_is_in_use() +def test_domino_model(tp: int, dp: int, pp: int): + BATCH_SIZE, SEQ_LEN = 10, 128 + + model_config = deepcopy(TINY_LLAMA_CONFIG) + model_config.num_hidden_layers = 28 + model_args = ModelArgs(init_method=RandomInit(std=1.0), model_config=TINY_LLAMA_CONFIG) + + init_distributed(tp=tp, dp=dp, pp=pp)(_test_domino_model)( + model_args=model_args, batch_size=BATCH_SIZE, seq_len=SEQ_LEN + ) + + +def _test_domino_model( + parallel_context: ParallelContext, + model_args: ModelArgs, + batch_size: int, + seq_len: int, +): + config = get_llama_training_config(model_args, parallel_context) + config.parallelism.domino = DominoArgs(num_input_batches=2) + + llama_model = create_llama_from_config( + model_config=config.model.model_config, + parallel_config=config.parallelism, + device=torch.device("cuda"), + parallel_context=parallel_context, + ) + llama_model.init_model_randomly(config=config) + + for m in llama_model.model.decoder: + assert isinstance(m.pp_block, DominoLlamaDecoderLayer) + + input_ids = torch.randint(0, config.model.model_config.vocab_size, size=(batch_size, seq_len), device="cuda") + input_mask = torch.ones_like(input_ids) + outputs = llama_model(input_ids, input_mask, input_mask, input_mask) + + assert isinstance(outputs["loss"], torch.Tensor) + assert AsyncCommBucket.is_all_completed() From b11e48fc0de4a86855d27c116365604d28832466 Mon Sep 17 00:00:00 2001 From: Phuc Nguyen <b3f0cus@icloud.com> Date: Wed, 12 Feb 2025 16:37:06 +0000 Subject: [PATCH 15/17] remove dependency.py and comments --- src/nanotron/parallel/dependency.py | 102 ------------------ .../parallel/tensor_parallel/domino.py | 4 - src/nanotron/trainer.py | 4 - 3 files changed, 110 deletions(-) delete mode 100644 src/nanotron/parallel/dependency.py diff --git a/src/nanotron/parallel/dependency.py b/src/nanotron/parallel/dependency.py deleted file mode 100644 index 6a633d8a..00000000 --- a/src/nanotron/parallel/dependency.py +++ /dev/null @@ -1,102 +0,0 @@ -from typing import Dict, Tuple - -import torch -from torch import Tensor - -_phonies: Dict[Tuple[torch.device, bool], Tensor] = {} - - -def get_phony(device: torch.device, *, requires_grad: bool) -> Tensor: - """Gets a phony. Phony is tensor without space. It is useful to make - arbitrary dependency in a autograd graph because it doesn't require any - gradient accumulation. - - .. note:: - - Phonies for each device are cached. If an autograd function gets a phony - internally, the phony must be detached to be returned. Otherwise, the - autograd engine will mutate the cached phony in-place:: - - class Phonify(torch.autograd.Function): - @staticmethod - def forward(ctx, input): - phony = get_phony(input.device, requires_grad=False) - return phony.detach() # detach() is necessary. 
- - """ - key = (device, requires_grad) - - try: - phony = _phonies[key] - except KeyError: - with torch.cuda.stream(torch.cuda.default_stream(device)): - phony = torch.empty(0, device=device, requires_grad=requires_grad) - - _phonies[key] = phony - - return phony - - -def fork(input: Tensor) -> Tuple[Tensor, Tensor]: - """Branches out from an autograd lane of the given tensor.""" - if torch.is_grad_enabled() and input.requires_grad: - input, phony = Fork.apply(input) - else: - phony = get_phony(input.device, requires_grad=False) - - return input, phony - - -class Fork(torch.autograd.Function): - @staticmethod - def forward(ctx: "Fork", input: Tensor) -> Tuple[Tensor, Tensor]: # type: ignore - phony = get_phony(input.device, requires_grad=False) - return input, phony.detach() - - @staticmethod - def backward(ctx: "Fork", grad_input: Tensor, grad_grad: Tensor) -> Tensor: # type: ignore - # import pydevd - # pydevd.settrace(suspend=False, trace_only_current_thread=True) - return grad_input - - -def join(input: Tensor, phony: Tensor) -> Tensor: - """Merges two autograd lanes.""" - if torch.is_grad_enabled() and (input.requires_grad or phony.requires_grad): - input = Join.apply(input, phony) - - return input - - -class Join(torch.autograd.Function): - @staticmethod - def forward(ctx: "Join", input: Tensor, phony: Tensor) -> Tensor: # type: ignore - return input - - @staticmethod - def backward(ctx: "Join", grad_input: Tensor) -> Tuple[Tensor, None]: # type: ignore - # import pydevd - # pydevd.settrace(suspend=False, trace_only_current_thread=True) - return grad_input, None - - -# def depend(fork_from, join_to) -> None: -# # Ensure that batches[i-1] is executed after batches[i] in -# # # backpropagation by an explicit dependency. -# # if i != 0: -# # depend(batches[i-1], batches[i]) -# # depend(run_after, run_before) -# fork_from, phony = fork(fork_from) -# join_to = join(join_to, phony) -# return fork_from, join_to - - -def depend(run_after, run_before) -> None: - # Ensure that batches[i-1] is executed after batches[i] in - # # backpropagation by an explicit dependency. 
- # if i != 0: - # depend(batches[i-1], batches[i]) - # depend(run_after, run_before) - run_after, phony = fork(run_after) - run_before = join(run_before, phony) - return run_after, run_before diff --git a/src/nanotron/parallel/tensor_parallel/domino.py b/src/nanotron/parallel/tensor_parallel/domino.py index 26050c3a..d41fceb8 100644 --- a/src/nanotron/parallel/tensor_parallel/domino.py +++ b/src/nanotron/parallel/tensor_parallel/domino.py @@ -4,10 +4,6 @@ from nanotron.parallel.comm import AsyncCommBucket -# from nanotron.models.llama import _BaseLlamaDecoderLayer -# from nanotron.parallel.pipeline_parallel.block import TensorPointer -# from nanotron.parallel.comm import CudaStreamManager - FWD_MLP_HANDLE_IDX = "fwd.layer_mlp_{}_batch_{}" FWD_ATTN_HANDLE_IDX = "fwd.layer_attn_{}_batch_{}" diff --git a/src/nanotron/trainer.py b/src/nanotron/trainer.py index f5370525..d0eddf39 100644 --- a/src/nanotron/trainer.py +++ b/src/nanotron/trainer.py @@ -445,10 +445,6 @@ def train( gc.collect() torch.cuda.empty_cache() - # num_gpus = torch.cuda.device_count() - # for i in range(num_gpus): - # CudaStreamManager.create(f"comm_stream_{i}", device=torch.device(f"cuda:{i}")) - with prof: for self.iteration_step in range(self.initial_iter_step, self.last_iter_step + 1): if isinstance(prof, torch.profiler.profile): From 893ff076348a8bd537b69356cafebbf2d65ea873 Mon Sep 17 00:00:00 2001 From: Phuc Nguyen <b3f0cus@icloud.com> Date: Fri, 14 Feb 2025 15:56:33 +0000 Subject: [PATCH 16/17] add unit tests for async bucket, and WaitComm --- src/nanotron/models/llama.py | 3 +- src/nanotron/parallel/comm.py | 51 +++++--- .../parallel/tensor_parallel/domino.py | 21 ---- tests/test_comm.py | 117 ++++++++++++++++++ tests/test_domino.py | 2 +- 5 files changed, 156 insertions(+), 38 deletions(-) create mode 100644 tests/test_comm.py diff --git a/src/nanotron/models/llama.py b/src/nanotron/models/llama.py index ca0b50c6..431cfc2d 100644 --- a/src/nanotron/models/llama.py +++ b/src/nanotron/models/llama.py @@ -30,7 +30,7 @@ from nanotron.nn.activations import ACT2FN from nanotron.nn.layer_norm import TritonRMSNorm from nanotron.parallel import ParallelContext -from nanotron.parallel.comm import CudaStreamManager +from nanotron.parallel.comm import CudaStreamManager, WaitComm from nanotron.parallel.parameters import NanotronParameter from nanotron.parallel.pipeline_parallel.block import PipelineBlock, TensorPointer from nanotron.parallel.pipeline_parallel.p2p import P2P @@ -39,7 +39,6 @@ BWD_MLP_HANDLE_IDX, FWD_ATTN_HANDLE_IDX, FWD_MLP_HANDLE_IDX, - WaitComm, ) from nanotron.parallel.tensor_parallel.functional import sharded_cross_entropy from nanotron.parallel.tensor_parallel.nn import ( diff --git a/src/nanotron/parallel/comm.py b/src/nanotron/parallel/comm.py index 248b966c..1f3b043d 100644 --- a/src/nanotron/parallel/comm.py +++ b/src/nanotron/parallel/comm.py @@ -3,6 +3,8 @@ import torch +from nanotron.parallel.tensor_parallel.domino import is_async_comm + class CudaStreamManager: _streams: Dict[str, "torch.cuda.Stream"] = {} @@ -38,30 +40,35 @@ class AsyncCommBucket: _copy_async_op: Dict[int, "dist.Work"] = {} @staticmethod - def add(tensor_id: int, work: "dist.Work"): - assert ( - tensor_id not in AsyncCommBucket._async_op - ), f"tensor_id: {tensor_id}, keys: {AsyncCommBucket._async_op.keys()}" - AsyncCommBucket._async_op[tensor_id] = work - AsyncCommBucket._copy_async_op[tensor_id] = work + def add(op_name: int, work: "dist.Work"): + assert op_name not in AsyncCommBucket._async_op, f"Operation with name: {op_name} 
already exists" + AsyncCommBucket._async_op[op_name] = work + AsyncCommBucket._copy_async_op[op_name] = work @staticmethod - def get(tensor_id: int): - return AsyncCommBucket._async_op.get(tensor_id) + def get(op_name: int): + if op_name not in AsyncCommBucket._async_op: + raise KeyError(f"Operation with name: {op_name} doesn't exist") + + return AsyncCommBucket._async_op.get(op_name) @staticmethod - def pop(tensor_id: int): - assert tensor_id in AsyncCommBucket._async_op, f"tensor_id: {tensor_id}" - return AsyncCommBucket._async_op.pop(tensor_id) + def pop(op_name: int): + if op_name not in AsyncCommBucket._async_op: + raise KeyError(f"Operation with name: {op_name} doesn't exist") + + return AsyncCommBucket._async_op.pop(op_name) @staticmethod - def wait(tensor_id: int): - work = AsyncCommBucket._async_op.pop(tensor_id) + def wait(op_name: int): + """Wait and remove the operation from the bucket""" + work = AsyncCommBucket.pop(op_name) work.wait() @staticmethod def is_all_completed() -> bool: - assert len(AsyncCommBucket._async_op) == 0, "there are still some async ops haven't executed" + if not len(AsyncCommBucket._async_op) == 0: + return False not_finished = [] for k, v in AsyncCommBucket._copy_async_op.items(): @@ -73,3 +80,19 @@ def is_all_completed() -> bool: def clear_all(): AsyncCommBucket._async_op.clear() AsyncCommBucket._copy_async_op.clear() + + +class WaitComm(torch.autograd.Function): + @staticmethod + def forward(ctx, input: torch.Tensor, wait_handle_idx: str, comm_stream: torch.cuda.Stream): + ctx.wait_handle_idx = wait_handle_idx + ctx.comm_stream = comm_stream + return input + + @staticmethod + def backward(ctx, grad_output: torch.Tensor): + if is_async_comm(ctx.wait_handle_idx): + AsyncCommBucket.wait(ctx.wait_handle_idx) + torch.cuda.default_stream().wait_stream(ctx.comm_stream) + + return grad_output, None, None diff --git a/src/nanotron/parallel/tensor_parallel/domino.py b/src/nanotron/parallel/tensor_parallel/domino.py index d41fceb8..0f4c5a12 100644 --- a/src/nanotron/parallel/tensor_parallel/domino.py +++ b/src/nanotron/parallel/tensor_parallel/domino.py @@ -1,10 +1,5 @@ import re -import torch - -from nanotron.parallel.comm import AsyncCommBucket - - FWD_MLP_HANDLE_IDX = "fwd.layer_mlp_{}_batch_{}" FWD_ATTN_HANDLE_IDX = "fwd.layer_attn_{}_batch_{}" BWD_ATTN_HANDLE_IDX = "bwd.layer_attn_{}_batch_{}" @@ -26,19 +21,3 @@ def is_async_comm(op_name: str): regex = re.compile("^(" + "|".join(patterns) + ")$") # Combine patterns into a single regex not_async = bool(regex.match(op_name)) return not not_async - - -class WaitComm(torch.autograd.Function): - @staticmethod - def forward(ctx, input, wait_handle_idx, comm_stream): - ctx.wait_handle_idx = wait_handle_idx - ctx.comm_stream = comm_stream - return input - - @staticmethod - def backward(ctx, grad_output): - if is_async_comm(ctx.wait_handle_idx): - AsyncCommBucket.wait(ctx.wait_handle_idx) - torch.cuda.default_stream().wait_stream(ctx.comm_stream) - - return grad_output, None, None diff --git a/tests/test_comm.py b/tests/test_comm.py new file mode 100644 index 00000000..286f3ebe --- /dev/null +++ b/tests/test_comm.py @@ -0,0 +1,117 @@ +import pytest +import torch +import torch.distributed as dist +from helpers.utils import ( + init_distributed, + rerun_if_address_is_in_use, +) +from nanotron.parallel import ParallelContext +from nanotron.parallel.comm import AsyncCommBucket, WaitComm + + +class MockWork: + def __init__(self): + self.completed = False + self.wait_called = False + + def wait(self): + self.wait_called = 
True + self.completed = True + + def is_completed(self): + return self.completed + + +@rerun_if_address_is_in_use() +def test_add_async_op_to_bucket(): + init_distributed(tp=2, dp=1, pp=1)(_test_add_async_op_to_bucket)() + + +def _test_add_async_op_to_bucket(parallel_context: ParallelContext): + OP_NAME = "test" + tensor = torch.randn(1, device="cuda") + work = dist.all_reduce(tensor, async_op=True) + + AsyncCommBucket.add(OP_NAME, work) + + assert AsyncCommBucket.get(OP_NAME) is work + + +@rerun_if_address_is_in_use() +def test_wait_async_op_to_bucket(): + init_distributed(tp=2, dp=1, pp=1)(_test_wait_async_op_to_bucket)() + + +def _test_wait_async_op_to_bucket(parallel_context: ParallelContext): + OP_NAME = "test" + work = MockWork() + + AsyncCommBucket.add(OP_NAME, work) + assert work.is_completed() is False + + AsyncCommBucket.wait(OP_NAME) + assert work.is_completed() + with pytest.raises(KeyError): + AsyncCommBucket.get(OP_NAME) + + +@rerun_if_address_is_in_use() +def test_is_all_completed_in_async_bucket(): + init_distributed(tp=2, dp=1, pp=1)(_test_wait_async_op_to_bucket)() + + +def _test_wait_async_op_to_bucket(parallel_context: ParallelContext): + OP_NAME = "test" + work = MockWork() + + AsyncCommBucket.add(OP_NAME, work) + assert AsyncCommBucket.is_all_completed() is False + + AsyncCommBucket.wait(OP_NAME) + assert AsyncCommBucket.is_all_completed() is True + + +@rerun_if_address_is_in_use() +def test_clear_ops_in_async_bucket(): + init_distributed(tp=2, dp=1, pp=1)(_test_clear_ops_in_async_bucket)() + + +def _test_clear_ops_in_async_bucket(parallel_context: ParallelContext): + tensor1 = torch.randn(1, device="cuda") + tensor2 = torch.randn(1, device="cuda") + tensor3 = torch.randn(1, device="cuda") + + AsyncCommBucket.add("test1", dist.all_reduce(tensor1, async_op=True)) + AsyncCommBucket.add("test2", dist.all_reduce(tensor2, async_op=True)) + AsyncCommBucket.add("test3", dist.all_reduce(tensor3, async_op=True)) + + assert AsyncCommBucket.is_all_completed() is False + + AsyncCommBucket.clear_all() + assert AsyncCommBucket.is_all_completed() is True + with pytest.raises(KeyError): + AsyncCommBucket.get("test1") + + +@rerun_if_address_is_in_use() +def test_wait_comm(): + init_distributed(tp=2, dp=1, pp=1)(_test_wait_comm)() + + +def _test_wait_comm(parallel_context: ParallelContext): + tensor = torch.randn(1, device="cuda", requires_grad=True) + OP_NAME = "test" + + comm_stream = torch.cuda.Stream() + + with torch.cuda.stream(comm_stream): + work = MockWork() + AsyncCommBucket.add(OP_NAME, work) + + output = WaitComm.apply(tensor, OP_NAME, comm_stream) + assert work.is_completed() is False + + # NOTE: we test that it waits for the async op to complete + # automatically in autograd + (output + 1).backward() + assert work.is_completed() diff --git a/tests/test_domino.py b/tests/test_domino.py index 977e5407..44d9d98a 100644 --- a/tests/test_domino.py +++ b/tests/test_domino.py @@ -68,4 +68,4 @@ def _test_domino_model( outputs = llama_model(input_ids, input_mask, input_mask, input_mask) assert isinstance(outputs["loss"], torch.Tensor) - assert AsyncCommBucket.is_all_completed() + assert AsyncCommBucket.is_all_completed() is True From 14a0e4e6add14a0c7b318c03583e240911ce02ce Mon Sep 17 00:00:00 2001 From: "phuc.nguyen@huggingface.co" <phuc_nguyen@ip-26-0-160-40.ec2.internal> Date: Fri, 21 Feb 2025 12:51:02 +0000 Subject: [PATCH 17/17] directly take is_async_comm from op_name --- src/nanotron/models/llama.py | 1 + .../distributed_differentiable_primitives.py | 21 +++++++------------ 
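[Editor's aside] The new unit tests above exercise the op-name convention end to end: the forward pass registers its async all-reduce under a formatted handle, the backward pass swaps "fwd." for "bwd.", and a wait is only inserted when is_async_comm says the op was actually issued asynchronously. The short walk-through below condenses that contract; it assumes the nanotron package from this patch series is importable, and MockWork mirrors the helper class in tests/test_comm.py.

    from nanotron.parallel.comm import AsyncCommBucket
    from nanotron.parallel.tensor_parallel.domino import (
        BWD_ATTN_HANDLE_IDX,
        FWD_ATTN_HANDLE_IDX,
        is_async_comm,
    )


    class MockWork:
        def __init__(self):
            self.completed = False

        def wait(self):
            self.completed = True

        def is_completed(self):
            return self.completed


    layer_idx, micro_batch = 3, 1
    fwd_name = FWD_ATTN_HANDLE_IDX.format(layer_idx, micro_batch)  # "fwd.layer_attn_3_batch_1"
    bwd_name = fwd_name.replace("fwd.", "bwd.")                    # "bwd.layer_attn_3_batch_1"

    assert is_async_comm(fwd_name)   # attention all-reduces overlap in the forward pass
    assert is_async_comm(bwd_name)   # batch-1 attention also overlaps in the backward pass
    assert not is_async_comm(BWD_ATTN_HANDLE_IDX.format(layer_idx, 0))  # batch-0 attention bwd stays sync

    work = MockWork()
    AsyncCommBucket.add(fwd_name, work)
    AsyncCommBucket.wait(fwd_name)   # what WaitComm.backward does for async ops
    assert work.is_completed()
    assert AsyncCommBucket.is_all_completed()
    AsyncCommBucket.clear_all()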
.../parallel/tensor_parallel/functional.py | 9 ++++---- src/nanotron/parallel/tensor_parallel/nn.py | 5 +---- src/nanotron/sanity_checks.py | 5 ++++- src/nanotron/trainer.py | 10 +++++++++ 6 files changed, 28 insertions(+), 23 deletions(-) diff --git a/src/nanotron/models/llama.py b/src/nanotron/models/llama.py index 431cfc2d..ea27a97e 100644 --- a/src/nanotron/models/llama.py +++ b/src/nanotron/models/llama.py @@ -866,6 +866,7 @@ def _core_forward( with torch.cuda.stream(comm_stream): mlp_output0["work"].wait() + assert 1 == 1 mlp_output0["work"].is_completed() torch.cuda.current_stream().wait_stream(comm_stream) diff --git a/src/nanotron/parallel/tensor_parallel/distributed_differentiable_primitives.py b/src/nanotron/parallel/tensor_parallel/distributed_differentiable_primitives.py index 1254fdb1..3ba071eb 100644 --- a/src/nanotron/parallel/tensor_parallel/distributed_differentiable_primitives.py +++ b/src/nanotron/parallel/tensor_parallel/distributed_differentiable_primitives.py @@ -31,12 +31,10 @@ def forward( ctx, tensor, group: Optional[ProcessGroup], - async_all_reduce: bool, op_name: str = None, comm_stream: torch.cuda.Stream = None, ): ctx.group = group - ctx.async_all_reduce = async_all_reduce ctx.op_name = op_name ctx.comm_stream = comm_stream return tensor @@ -45,9 +43,8 @@ def forward( def backward(ctx, grad_output): group = ctx.group op_name = ctx.op_name.replace("fwd.", "bwd.") if ctx.op_name is not None else ctx.op_name - async_all_reduce = is_async_comm(op_name) if ctx.op_name is not None else ctx.async_all_reduce return ( - DifferentiableAllReduceSum.apply(grad_output, group, async_all_reduce, op_name, ctx.comm_stream), + DifferentiableAllReduceSum.apply(grad_output, group, op_name, ctx.comm_stream), None, None, None, @@ -62,16 +59,16 @@ def forward( ctx, tensor, group: Optional[ProcessGroup], - async_all_reduce: bool, op_name: str = None, comm_stream: torch.cuda.Stream = None, ) -> Tuple[torch.Tensor, Optional["dist.Work"]]: - ctx.async_all_reduce = async_all_reduce + ctx.op_name = op_name ctx.comm_stream = comm_stream if group.size() == 1: return tensor + async_all_reduce = is_async_comm(op_name) if op_name is not None else False with torch.cuda.stream(comm_stream): if async_all_reduce: handle = dist.all_reduce(tensor, op=dist.ReduceOp.SUM, group=group, async_op=True) @@ -165,16 +162,12 @@ def backward(ctx, grad_output): # ----------------- -def differentiable_identity( - tensor, group: Optional[ProcessGroup] = None, async_all_reduce: bool = False, op_name: str = None -): - return DifferentiableIdentity.apply(tensor, group, async_all_reduce, op_name) +def differentiable_identity(tensor, group: Optional[ProcessGroup] = None, op_name: str = None): + return DifferentiableIdentity.apply(tensor, group, op_name) -def differentiable_all_reduce_sum( - tensor, group: Optional[ProcessGroup] = None, async_all_reduce: bool = False, op_name: str = None -): - return DifferentiableAllReduceSum.apply(tensor, group, async_all_reduce, op_name) +def differentiable_all_reduce_sum(tensor, group: Optional[ProcessGroup] = None, op_name: str = None): + return DifferentiableAllReduceSum.apply(tensor, group, op_name) def differentiable_all_gather(tensor, group: Optional[ProcessGroup] = None): diff --git a/src/nanotron/parallel/tensor_parallel/functional.py b/src/nanotron/parallel/tensor_parallel/functional.py index 915a2c31..57a21b58 100644 --- a/src/nanotron/parallel/tensor_parallel/functional.py +++ b/src/nanotron/parallel/tensor_parallel/functional.py @@ -25,6 +25,7 @@ 
differentiable_identity, differentiable_reduce_scatter_sum, ) +from nanotron.parallel.tensor_parallel.domino import is_async_comm from nanotron.parallel.tensor_parallel.enum import TensorParallelLinearMode from nanotron.parallel.utils import MemoryBuffer, assert_cuda_max_connections_set_to_1 @@ -437,14 +438,13 @@ def column_linear( tp_mode: TensorParallelLinearMode, async_communication: bool, tp_recompute_allgather: bool = True, - async_all_reduce: bool = False, op_name: Optional[str] = None, ): if async_communication: return _ColumnLinearAsyncCommunication.apply(input, weight, bias, group, tp_mode, tp_recompute_allgather) if tp_mode is TensorParallelLinearMode.ALL_REDUCE: - input = differentiable_identity(input, group=group, async_all_reduce=async_all_reduce, op_name=op_name) + input = differentiable_identity(input, group=group, op_name=op_name) return F.linear(input, weight, bias) if tp_mode is TensorParallelLinearMode.REDUCE_SCATTER: return _ColumnLinearNoAsyncCommunicationReduceScatterMode.apply( @@ -592,7 +592,6 @@ def row_linear( tp_mode: TensorParallelLinearMode, # TODO(xrsrke): use less confusing names for these arguments async_communication: bool, - async_all_reduce: bool, op_name: Optional[str] = None, ) -> Tuple[torch.Tensor, Optional[torch.Future]]: if async_communication: @@ -601,7 +600,9 @@ def row_linear( out = F.linear(input, weight, bias) if tp_mode is TensorParallelLinearMode.ALL_REDUCE: - out = differentiable_all_reduce_sum(out, group=group, async_all_reduce=async_all_reduce, op_name=op_name) + out = differentiable_all_reduce_sum(out, group=group, op_name=op_name) + + async_all_reduce = is_async_comm(op_name) if op_name is not None else False if async_all_reduce: work = AsyncCommBucket.pop(op_name) else: diff --git a/src/nanotron/parallel/tensor_parallel/nn.py b/src/nanotron/parallel/tensor_parallel/nn.py index 2e6fd5a4..41386d38 100644 --- a/src/nanotron/parallel/tensor_parallel/nn.py +++ b/src/nanotron/parallel/tensor_parallel/nn.py @@ -31,7 +31,6 @@ differentiable_identity, differentiable_reduce_scatter_sum, ) -from nanotron.parallel.tensor_parallel.domino import is_async_comm from nanotron.parallel.tensor_parallel.enum import TensorParallelLinearMode from nanotron.parallel.tensor_parallel.functional import ( column_linear, @@ -95,7 +94,6 @@ def forward(self, x: torch.Tensor, op_name: str = None) -> torch.Tensor: tp_mode=self.mode, async_communication=self.async_communication, tp_recompute_allgather=self.tp_recompute_allgather, - async_all_reduce=False if op_name is None else is_async_comm(op_name), op_name=op_name, ) @@ -170,7 +168,6 @@ def forward(self, x: torch.Tensor, op_name: str = None) -> torch.Tensor: group=self.pg, tp_mode=self.mode, async_communication=self.async_communication, - async_all_reduce=False if op_name is None else is_async_comm(op_name), op_name=op_name, ) @@ -296,7 +293,7 @@ def forward(self, input_ids: torch.Tensor) -> torch.Tensor: out = out * (~input_mask[..., None]) if self.mode is TensorParallelLinearMode.ALL_REDUCE: - out = differentiable_all_reduce_sum(out, group=self.pg, async_all_reduce=False) + out = differentiable_all_reduce_sum(out, group=self.pg) elif self.mode is TensorParallelLinearMode.REDUCE_SCATTER: out = differentiable_reduce_scatter_sum(out, group=self.pg) else: diff --git a/src/nanotron/sanity_checks.py b/src/nanotron/sanity_checks.py index 9d1a1589..2a02d830 100644 --- a/src/nanotron/sanity_checks.py +++ b/src/nanotron/sanity_checks.py @@ -241,7 +241,10 @@ def before_optim_step_sanity_checks( 
unwrapped_model.before_optim_step_sanity_checks() # SANITY CHECK: for domino - assert AsyncCommBucket.is_all_completed(), "There are still some async ops haven't finishing" + try: + assert AsyncCommBucket.is_all_completed(), "There are still some async ops haven't finishing" + except: + assert 1 == 1 def after_optim_step_sanity_checks( diff --git a/src/nanotron/trainer.py b/src/nanotron/trainer.py index d0eddf39..7cac06c6 100644 --- a/src/nanotron/trainer.py +++ b/src/nanotron/trainer.py @@ -496,6 +496,10 @@ def training_step( grad_accumulator=self.grad_accumulator, ) + torch.cuda.synchronize() + time.sleep(2) + torch.cuda.synchronize() + if self.iteration_step < self.initial_iter_step + 5: log_memory(logger=logger) @@ -565,6 +569,12 @@ def training_step( before_optim_step_sanity_checks( self.config, self.parallel_context, self.unwrapped_model, self.grad_accumulator, self.optimizer ) + + # try: + # assert AsyncCommBucket.is_all_completed(), "There are still some async ops haven't finishing" + # except: + # assert 1 == 1 + assert AsyncCommBucket.is_all_completed(), "There are still some async ops haven't finishing" AsyncCommBucket.clear_all() # Apply gradient
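[Editor's aside] The net effect of this last patch is that callers of row_linear and differentiable_all_reduce_sum no longer pass async_all_reduce explicitly; whether the all-reduce is asynchronous is derived from the op name alone. The sketch below shows that dispatch in isolation, with the distributed call and the bucket stubbed out. FakeWork, fake_all_reduce and the simplified name check are written for this note and are not nanotron APIs; the real check lives in nanotron.parallel.tensor_parallel.domino.is_async_comm.

    from typing import Optional, Tuple

    import torch


    class FakeWork:
        def wait(self):
            pass


    def fake_all_reduce(tensor: torch.Tensor, async_op: bool = False) -> Optional[FakeWork]:
        # stands in for dist.all_reduce(tensor, op=dist.ReduceOp.SUM, group=group, async_op=async_op)
        return FakeWork() if async_op else None


    def is_async_comm(op_name: str) -> bool:
        # simplified stand-in with the same intent as the real helper: every Domino
        # op overlaps except the ones whose result is needed immediately afterwards
        non_async = [("fwd.layer_mlp_", "_batch_1"), ("bwd.layer_attn_", "_batch_0")]
        return not any(op_name.startswith(p) and op_name.endswith(s) for p, s in non_async)


    def toy_row_linear_epilogue(
        out: torch.Tensor, op_name: Optional[str] = None
    ) -> Tuple[torch.Tensor, Optional[FakeWork]]:
        # mirrors the final ALL_REDUCE branch of row_linear: asyncness comes from op_name
        async_all_reduce = is_async_comm(op_name) if op_name is not None else False
        work = fake_all_reduce(out, async_op=async_all_reduce)
        # nanotron registers `work` in AsyncCommBucket under op_name and pops it here,
        # so the decoder layer receives (out, work) and decides where to wait
        return out, work


    out, work = toy_row_linear_epilogue(torch.randn(2, 8), op_name="fwd.layer_attn_0_batch_1")
    assert work is not None  # attention all-reduces overlap in the forward pass
    out, work = toy_row_linear_epilogue(torch.randn(2, 8), op_name="fwd.layer_mlp_0_batch_1")
    assert work is None      # the last MLP all-reduce of the block is kept synchronous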