diff --git a/examples/config_llama_domino.yaml b/examples/config_llama_domino.yaml new file mode 100644 index 00000000..b9811fdd --- /dev/null +++ b/examples/config_llama_domino.yaml @@ -0,0 +1,98 @@ +checkpoints: + checkpoint_interval: 1000 + checkpoints_path: checkpoints + checkpoints_path_is_shared_file_system: false + resume_checkpoint_path: null + save_initial_state: false +data_stages: +- data: + dataset: + dataset_overwrite_cache: false + dataset_processing_num_proc_per_process: 1 + hf_dataset_config_name: null + hf_dataset_or_datasets: roneneldan/TinyStories + hf_dataset_splits: train + text_column_name: text + num_loading_workers: 1 + seed: 42 + name: Stable Training Stage + start_training_step: 1 +general: + benchmark_csv_path: null + consumed_train_samples: null + ignore_sanity_checks: true + project: nanotron_domino + run: config_llama_domino + seed: 42 + step: null +lighteval: null +logging: + iteration_step_info_interval: 1 + log_level: info + log_level_replica: info +model: + ddp_bucket_cap_mb: 25 + dtype: bfloat16 + init_method: + std: 0.025 + make_vocab_size_divisible_by: 1 + model_config: + bos_token_id: 128000 + eos_token_id: 128001 + hidden_act: silu + hidden_size: 4096 + initializer_range: 0.02 + intermediate_size: 16384 + is_llama_config: true + max_position_embeddings: 4096 + num_attention_heads: 32 + num_hidden_layers: 32 + num_key_value_heads: 8 + pad_token_id: null + pretraining_tp: 1 + rms_norm_eps: 1.0e-05 + rope_scaling: null + tie_word_embeddings: true + use_cache: true + vocab_size: 128256 +optimizer: + accumulate_grad_in_fp32: true + clip_grad: 1.0 + learning_rate_scheduler: + learning_rate: 0.0003 + lr_decay_starting_step: null + lr_decay_steps: 1000 + lr_decay_style: cosine + lr_warmup_steps: 500 + lr_warmup_style: linear + min_decay_lr: 1.0e-05 + optimizer_factory: + adam_beta1: 0.9 + adam_beta2: 0.95 + adam_eps: 1.0e-08 + name: adamW + torch_adam_is_fused: true + weight_decay: 0.01 + zero_stage: 0 +parallelism: + dp: 1 + pp: 1 + tp: 8 + expert_parallel_size: 1 + pp_engine: 1f1b + tp_linear_async_communication: false + tp_mode: ALL_REDUCE + domino: + num_input_batches: 2 +tokenizer: + tokenizer_max_length: null + tokenizer_name_or_path: meta-llama/Meta-Llama-3-8B + tokenizer_revision: null +tokens: + batch_accumulation_per_replica: 1 + limit_test_batches: 0 + limit_val_batches: 0 + micro_batch_size: 2 + sequence_length: 4096 + train_steps: 1500 + val_check_interval: -1 diff --git a/src/nanotron/config/parallelism_config.py b/src/nanotron/config/parallelism_config.py index 7f20ad99..07688959 100644 --- a/src/nanotron/config/parallelism_config.py +++ b/src/nanotron/config/parallelism_config.py @@ -11,6 +11,23 @@ from nanotron.parallel.tensor_parallel.nn import TensorParallelLinearMode +@dataclass +class DominoArgs: + """ + Domino: Eliminating Communication in LLM Training via Generic Tensor Slicing and Overlapping + https://arxiv.org/abs/2409.15241 + """ + + # NOTE: if the number of input batches is 1, + # it's equivalent to non-domino mode + # so if you want to enable domino mode, set this to > 1 + num_input_batches: int + + def __post_init__(self): + assert self.num_input_batches > 1, "In order to enable domino mode, set num_input_batches > 1" + assert self.num_input_batches == 2, "Currently parallelism only supports 2 batches for Domino" + + @dataclass class ParallelismArgs: """Arguments related to TP/PP/DP @@ -37,6 +54,7 @@ class ParallelismArgs: tp_recompute_allgather: bool = True expert_parallel_size: int = 1 + domino: Optional[DominoArgs] = None def 
__post_init__(self): # Conservative defaults @@ -51,3 +69,10 @@ def __post_init__(self): self.pp_engine = cast_str_to_pipeline_engine(self.pp_engine) if isinstance(self.tp_mode, str): self.tp_mode = TensorParallelLinearMode[self.tp_mode.upper()] + + if self.is_domino_enabled is True: + assert self.tp > 1, "Domino requires TP > 1" + + @property + def is_domino_enabled(self) -> bool: + return True if self.domino else False diff --git a/src/nanotron/helpers.py b/src/nanotron/helpers.py index 73ca3484..7f31d812 100644 --- a/src/nanotron/helpers.py +++ b/src/nanotron/helpers.py @@ -482,7 +482,9 @@ def get_profiler(config: Config): on_trace_ready=on_trace_ready, # record_shapes=True, # profile_memory=True, + with_flops=True, with_stack=True, + with_modules=True, ) else: prof = contextlib.nullcontext() diff --git a/src/nanotron/models/llama.py b/src/nanotron/models/llama.py index 88fb6bcb..ea27a97e 100644 --- a/src/nanotron/models/llama.py +++ b/src/nanotron/models/llama.py @@ -30,9 +30,16 @@ from nanotron.nn.activations import ACT2FN from nanotron.nn.layer_norm import TritonRMSNorm from nanotron.parallel import ParallelContext +from nanotron.parallel.comm import CudaStreamManager, WaitComm from nanotron.parallel.parameters import NanotronParameter from nanotron.parallel.pipeline_parallel.block import PipelineBlock, TensorPointer from nanotron.parallel.pipeline_parallel.p2p import P2P +from nanotron.parallel.tensor_parallel.domino import ( + BWD_ATTN_HANDLE_IDX, + BWD_MLP_HANDLE_IDX, + FWD_ATTN_HANDLE_IDX, + FWD_MLP_HANDLE_IDX, +) from nanotron.parallel.tensor_parallel.functional import sharded_cross_entropy from nanotron.parallel.tensor_parallel.nn import ( TensorParallelColumnLinear, @@ -240,10 +247,10 @@ def __init__( ) self.split_silu_mul = GLUActivation(config.hidden_act) - def forward(self, hidden_states): # [seq_length, batch_size, hidden_dim] - merged_states = self.gate_up_proj(hidden_states) - hidden_states = self.down_proj(self.split_silu_mul(merged_states)) - return {"hidden_states": hidden_states} + def forward(self, hidden_states, op_name: str = None): # [seq_length, batch_size, hidden_dim] + merged_states = self.gate_up_proj(hidden_states, op_name=op_name) + hidden_states, work = self.down_proj(self.split_silu_mul(merged_states), op_name=op_name) + return {"hidden_states": hidden_states, "work": work} class CoreAttention(nn.Module): @@ -434,6 +441,7 @@ def forward( self, hidden_states, # [seq_length, batch_size, hidden_size] sequence_mask, # [batch_size, seq_length] + op_name: str = None, ): from flash_attn import bert_padding from flash_attn.flash_attn_interface import ( @@ -442,7 +450,7 @@ def forward( ) qkv_states = self.qkv_proj( - hidden_states + hidden_states, op_name=op_name ) # [seq_length, batch_size, n_local_q_heads * d_qk + 2 * n_local_kv_heads * d_qk] q_length, batch_size, _ = qkv_states.shape @@ -687,12 +695,12 @@ def forward( attention_output = ( attention_output.contiguous().view(batch_size, q_length, self.n_local_q_heads * self.d_v).transpose(0, 1) ) - output = self.o_proj(attention_output) + output, work = self.o_proj(attention_output, op_name=op_name) - return {"hidden_states": output, "sequence_mask": sequence_mask} + return {"hidden_states": output, "work": work, "sequence_mask": sequence_mask} -class LlamaDecoderLayer(nn.Module): +class _BaseLlamaDecoderLayer(nn.Module): def __init__( self, config: LlamaConfig, @@ -702,6 +710,7 @@ def __init__( ): super().__init__() self.input_layernorm = TritonRMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.attn = 
CausalSelfAttention( config=config, parallel_config=parallel_config, @@ -713,25 +722,8 @@ def __init__( self.mlp = MLP(config=config, parallel_config=parallel_config, tp_pg=tp_pg) self.recompute_layer = parallel_config.recompute_layer - - def _core_forward( - self, - hidden_states: Union[torch.Tensor, TensorPointer], - sequence_mask: Union[torch.Tensor, TensorPointer], - ) -> List[Union[torch.Tensor, TensorPointer]]: - residual = hidden_states - hidden_states = self.input_layernorm(hidden_states) - - output = self.attn(hidden_states=hidden_states, sequence_mask=sequence_mask) - hidden_states = output["hidden_states"] - hidden_states = hidden_states + residual - - residual = hidden_states - hidden_states = self.post_attention_layernorm(hidden_states) - hidden_states = self.mlp(hidden_states=hidden_states)["hidden_states"] - hidden_states = hidden_states + residual - - return hidden_states, output["sequence_mask"] + self.parallel_config = parallel_config + self.layer_idx = layer_idx def _checkpointed_forward( self, @@ -757,6 +749,27 @@ def forward( } +class LlamaDecoderLayer(_BaseLlamaDecoderLayer): + def _core_forward( + self, + hidden_states: Union[torch.Tensor, TensorPointer], + sequence_mask: Union[torch.Tensor, TensorPointer], + ) -> List[Union[torch.Tensor, TensorPointer]]: + residual = hidden_states + hidden_states = self.input_layernorm(hidden_states) + + output = self.attn(hidden_states=hidden_states, sequence_mask=sequence_mask) + hidden_states = output["hidden_states"] + hidden_states = hidden_states + residual + + residual = hidden_states + hidden_states = self.post_attention_layernorm(hidden_states) + hidden_states = self.mlp(hidden_states=hidden_states)["hidden_states"] + hidden_states = hidden_states + residual + + return hidden_states, output["sequence_mask"] + + class Embedding(nn.Module, AttachableStore): def __init__(self, tp_pg: dist.ProcessGroup, config: LlamaConfig, parallel_config: Optional[ParallelismArgs]): super().__init__() @@ -787,6 +800,85 @@ def forward(self, input_ids: torch.Tensor, input_mask: torch.Tensor): # [batch_ return {"input_embeds": input_embeds} +class DominoLlamaDecoderLayer(_BaseLlamaDecoderLayer): + def _core_forward( + self, + hidden_states: Union[torch.Tensor, TensorPointer], + sequence_mask: Union[torch.Tensor, TensorPointer], + ) -> List[Union[torch.Tensor, TensorPointer]]: + num_input_batches = self.parallel_config.domino.num_input_batches + orig_sequence_mask = sequence_mask + comm_stream = CudaStreamManager.get(f"comm_stream_{torch.cuda.current_device()}") + + hidden_states = torch.chunk(hidden_states, chunks=num_input_batches, dim=1) + sequence_mask = torch.chunk(sequence_mask, chunks=num_input_batches, dim=0) + + hidden_states0, hidden_states1 = hidden_states + sequence_mask0, sequence_mask1 = sequence_mask + + residual0 = hidden_states0 + residual1 = hidden_states1 + + hidden_states0 = self.input_layernorm(hidden_states0) + hidden_states1 = self.input_layernorm(hidden_states1) + hidden_states0 = WaitComm.apply(hidden_states0, BWD_ATTN_HANDLE_IDX.format(self.layer_idx, 1), comm_stream) + hidden_states1 = WaitComm.apply(hidden_states1, BWD_MLP_HANDLE_IDX.format(self.layer_idx, 0), comm_stream) + + attn_output0 = self.attn( + hidden_states=hidden_states0, + sequence_mask=sequence_mask0, + op_name=FWD_ATTN_HANDLE_IDX.format(self.layer_idx, 0), + ) + attn_output1 = self.attn( + hidden_states=hidden_states1, + sequence_mask=sequence_mask1, + op_name=FWD_ATTN_HANDLE_IDX.format(self.layer_idx, 1), + ) + + with torch.cuda.stream(comm_stream): + 
attn_output0["work"].wait() + attn_output0["work"].is_completed() + + hidden_states0 = attn_output0["hidden_states"] + residual0 + residual0 = hidden_states0 + hidden_states0 = self.post_attention_layernorm(hidden_states0) + hidden_states0 = WaitComm.apply(hidden_states0, BWD_MLP_HANDLE_IDX.format(self.layer_idx, 1), comm_stream) + + mlp_output0 = self.mlp( + hidden_states=hidden_states0, + op_name=FWD_MLP_HANDLE_IDX.format(self.layer_idx, 0), + ) + + with torch.cuda.stream(comm_stream): + attn_output1["work"].wait() + attn_output1["work"].is_completed() + + torch.cuda.current_stream().wait_stream(comm_stream) + + hidden_states1 = attn_output1["hidden_states"] + residual1 + residual1 = hidden_states1 + hidden_states1 = self.post_attention_layernorm(hidden_states1) + + mlp_output1 = self.mlp( + hidden_states=hidden_states1, + op_name=FWD_MLP_HANDLE_IDX.format(self.layer_idx, 1), + ) + + with torch.cuda.stream(comm_stream): + mlp_output0["work"].wait() + assert 1 == 1 + mlp_output0["work"].is_completed() + + torch.cuda.current_stream().wait_stream(comm_stream) + + hidden_states0 = mlp_output0["hidden_states"] + residual0 + hidden_states1 = mlp_output1["hidden_states"] + residual1 + + hidden_states = torch.cat([hidden_states0, hidden_states1], dim=1) + + return hidden_states, orig_sequence_mask + + class LlamaModel(nn.Module): """Build pipeline graph""" @@ -796,6 +888,8 @@ def __init__( parallel_context: ParallelContext, parallel_config: Optional[ParallelismArgs], ): + # from nanotron.parallel.tensor_parallel.domino import DominoLlamaDecoderLayer + super().__init__() # Declare all the nodes @@ -831,7 +925,7 @@ def __init__( [ PipelineBlock( p2p=self.p2p, - module_builder=LlamaDecoderLayer, + module_builder=DominoLlamaDecoderLayer if parallel_config.is_domino_enabled else LlamaDecoderLayer, module_kwargs={ "config": config, "parallel_config": parallel_config, @@ -892,15 +986,16 @@ def forward_with_hidden_states( input_mask: Union[torch.Tensor, TensorPointer], # [batch_size, seq_length] ): # all tensors are optional as most ranks don't need anything from the dataloader. 
-
         output = self.token_position_embeddings(input_ids=input_ids, input_mask=input_mask)
         hidden_encoder_states = {
             "hidden_states": output["input_embeds"],
             "sequence_mask": input_mask,
         }
-        for encoder_block in self.decoder:
-            hidden_encoder_states = encoder_block(**hidden_encoder_states)
+
+        for layer_idx, encoder_block in enumerate(self.decoder):
+            with torch.profiler.record_function(f"layer_{layer_idx}"):
+                hidden_encoder_states = encoder_block(**hidden_encoder_states)

         hidden_states = self.final_layer_norm(input=hidden_encoder_states["hidden_states"])["hidden_states"]
diff --git a/src/nanotron/parallel/comm.py b/src/nanotron/parallel/comm.py
new file mode 100644
index 00000000..1f3b043d
--- /dev/null
+++ b/src/nanotron/parallel/comm.py
@@ -0,0 +1,98 @@
+from contextlib import contextmanager
+from typing import Dict
+
+import torch
+
+from nanotron import distributed as dist
+from nanotron.parallel.tensor_parallel.domino import is_async_comm
+
+
+class CudaStreamManager:
+    _streams: Dict[str, "torch.cuda.Stream"] = {}
+
+    @staticmethod
+    def create(name: str, device: torch.device = None):
+        assert name not in CudaStreamManager._streams
+        CudaStreamManager._streams[name] = torch.cuda.Stream(device=device)
+
+    @staticmethod
+    def get(name: str):
+        if name not in CudaStreamManager._streams:
+            CudaStreamManager.create(name)
+        return CudaStreamManager._streams.get(name)
+
+    @staticmethod
+    @contextmanager
+    def run_on_stream(name: str):
+        stream = CudaStreamManager.get(name)
+        with torch.cuda.stream(stream):
+            yield stream
+
+
+class AsyncCommBucket:
+    """Registry of in-flight async communication handles, keyed by op name.
+
+    `_async_op` holds the handles that still have to be waited on, while
+    `_copy_async_op` keeps a copy so that sanity checks can verify that every
+    launched op actually completed before the optimizer step.
+    """
+
+    _async_op: Dict[str, "dist.Work"] = {}
+    _copy_async_op: Dict[str, "dist.Work"] = {}
+
+    @staticmethod
+    def add(op_name: str, work: "dist.Work"):
+        assert op_name not in AsyncCommBucket._async_op, f"Operation with name: {op_name} already exists"
+        AsyncCommBucket._async_op[op_name] = work
+        AsyncCommBucket._copy_async_op[op_name] = work
+
+    @staticmethod
+    def get(op_name: str):
+        if op_name not in AsyncCommBucket._async_op:
+            raise KeyError(f"Operation with name: {op_name} doesn't exist")
+
+        return AsyncCommBucket._async_op.get(op_name)
+
+    @staticmethod
+    def pop(op_name: str):
+        if op_name not in AsyncCommBucket._async_op:
+            raise KeyError(f"Operation with name: {op_name} doesn't exist")
+
+        return AsyncCommBucket._async_op.pop(op_name)
+
+    @staticmethod
+    def wait(op_name: str):
+        """Wait for the operation and remove it from the bucket"""
+        work = AsyncCommBucket.pop(op_name)
+        work.wait()
+
+    @staticmethod
+    def is_all_completed() -> bool:
+        if not len(AsyncCommBucket._async_op) == 0:
+            return False
+
+        not_finished = []
+        for k, v in AsyncCommBucket._copy_async_op.items():
+            if v.is_completed() is not True:
+                not_finished.append((k, v))
+        return len(not_finished) == 0
+
+    @staticmethod
+    def clear_all():
+        AsyncCommBucket._async_op.clear()
+        AsyncCommBucket._copy_async_op.clear()
+
+
+class WaitComm(torch.autograd.Function):
+    """Identity in the forward pass; in the backward pass, blocks until the
+    async communication registered under `wait_handle_idx` has finished."""
+
+    @staticmethod
+    def forward(ctx, input: torch.Tensor, wait_handle_idx: str, comm_stream: torch.cuda.Stream):
+        ctx.wait_handle_idx = wait_handle_idx
+        ctx.comm_stream = comm_stream
+        return input
+
+    @staticmethod
+    def backward(ctx, grad_output: torch.Tensor):
+        if is_async_comm(ctx.wait_handle_idx):
+            AsyncCommBucket.wait(ctx.wait_handle_idx)
torch.cuda.default_stream().wait_stream(ctx.comm_stream) + + return grad_output, None, None diff --git a/src/nanotron/parallel/pipeline_parallel/engine.py b/src/nanotron/parallel/pipeline_parallel/engine.py index ca9df312..076943c7 100644 --- a/src/nanotron/parallel/pipeline_parallel/engine.py +++ b/src/nanotron/parallel/pipeline_parallel/engine.py @@ -2,6 +2,9 @@ from typing import Dict, Iterable, Optional, Union import torch +from torch import nn as torch_nn +from torch.nn.parallel import DistributedDataParallel + from nanotron import distributed as dist from nanotron import logging from nanotron.distributed import ProcessGroup @@ -12,8 +15,6 @@ from nanotron.parallel.pipeline_parallel.state import PipelineTrainBatchState from nanotron.parallel.pipeline_parallel.tensor_pointer import TensorPointer from nanotron.utils import ContextManagers -from torch import nn as torch_nn -from torch.nn.parallel import DistributedDataParallel logger = logging.get_logger(__name__) diff --git a/src/nanotron/parallel/tensor_parallel/distributed_differentiable_primitives.py b/src/nanotron/parallel/tensor_parallel/distributed_differentiable_primitives.py index bd41347a..3ba071eb 100644 --- a/src/nanotron/parallel/tensor_parallel/distributed_differentiable_primitives.py +++ b/src/nanotron/parallel/tensor_parallel/distributed_differentiable_primitives.py @@ -12,43 +12,75 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import Optional +from typing import Optional, Tuple import torch from torch import distributed as torch_dist from nanotron import distributed as dist from nanotron.distributed import ProcessGroup +from nanotron.parallel.comm import AsyncCommBucket +from nanotron.parallel.tensor_parallel.domino import is_async_comm class DifferentiableIdentity(torch.autograd.Function): """All-reduce gradients in a differentiable fashion""" @staticmethod - def forward(ctx, tensor, group: Optional[ProcessGroup]): + def forward( + ctx, + tensor, + group: Optional[ProcessGroup], + op_name: str = None, + comm_stream: torch.cuda.Stream = None, + ): ctx.group = group + ctx.op_name = op_name + ctx.comm_stream = comm_stream return tensor @staticmethod def backward(ctx, grad_output): group = ctx.group - return DifferentiableAllReduceSum.apply(grad_output, group), None + op_name = ctx.op_name.replace("fwd.", "bwd.") if ctx.op_name is not None else ctx.op_name + return ( + DifferentiableAllReduceSum.apply(grad_output, group, op_name, ctx.comm_stream), + None, + None, + None, + ) class DifferentiableAllReduceSum(torch.autograd.Function): """All-reduce in a differentiable fashion""" @staticmethod - def forward(ctx, tensor, group: Optional[ProcessGroup]): + def forward( + ctx, + tensor, + group: Optional[ProcessGroup], + op_name: str = None, + comm_stream: torch.cuda.Stream = None, + ) -> Tuple[torch.Tensor, Optional["dist.Work"]]: + ctx.op_name = op_name + ctx.comm_stream = comm_stream + if group.size() == 1: return tensor - dist.all_reduce(tensor, op=dist.ReduceOp.SUM, group=group) + async_all_reduce = is_async_comm(op_name) if op_name is not None else False + with torch.cuda.stream(comm_stream): + if async_all_reduce: + handle = dist.all_reduce(tensor, op=dist.ReduceOp.SUM, group=group, async_op=True) + AsyncCommBucket.add(op_name, handle) + else: + dist.all_reduce(tensor, op=dist.ReduceOp.SUM, group=group) + return tensor @staticmethod def backward(ctx, grad_output): - return grad_output, None + return grad_output, None, None, None class 
DifferentiableAllGather(torch.autograd.Function): @@ -130,12 +162,12 @@ def backward(ctx, grad_output): # ----------------- -def differentiable_identity(tensor, group: Optional[ProcessGroup] = None): - return DifferentiableIdentity.apply(tensor, group) +def differentiable_identity(tensor, group: Optional[ProcessGroup] = None, op_name: str = None): + return DifferentiableIdentity.apply(tensor, group, op_name) -def differentiable_all_reduce_sum(tensor, group: Optional[ProcessGroup] = None): - return DifferentiableAllReduceSum.apply(tensor, group) +def differentiable_all_reduce_sum(tensor, group: Optional[ProcessGroup] = None, op_name: str = None): + return DifferentiableAllReduceSum.apply(tensor, group, op_name) def differentiable_all_gather(tensor, group: Optional[ProcessGroup] = None): diff --git a/src/nanotron/parallel/tensor_parallel/domino.py b/src/nanotron/parallel/tensor_parallel/domino.py new file mode 100644 index 00000000..0f4c5a12 --- /dev/null +++ b/src/nanotron/parallel/tensor_parallel/domino.py @@ -0,0 +1,23 @@ +import re + +FWD_MLP_HANDLE_IDX = "fwd.layer_mlp_{}_batch_{}" +FWD_ATTN_HANDLE_IDX = "fwd.layer_attn_{}_batch_{}" +BWD_ATTN_HANDLE_IDX = "bwd.layer_attn_{}_batch_{}" +BWD_MLP_HANDLE_IDX = "bwd.layer_mlp_{}_batch_{}" + + +def is_async_comm(op_name: str): + """ + There are two operations that we can't overlap + for the forward pass: the last micro-batch of the mlp layer + for the backward pass: the first micro-batch of the attention layer + """ + NON_ASYNC_HANDLE_IDX = [ + "fwd.layer_mlp_{}_batch_1", + "bwd.layer_attn_{}_batch_0", + ] + + patterns = [p.replace("{}", r"\d+") for p in NON_ASYNC_HANDLE_IDX] # Replace {} with regex for numbers + regex = re.compile("^(" + "|".join(patterns) + ")$") # Combine patterns into a single regex + not_async = bool(regex.match(op_name)) + return not not_async diff --git a/src/nanotron/parallel/tensor_parallel/functional.py b/src/nanotron/parallel/tensor_parallel/functional.py index e2ee3a29..57a21b58 100644 --- a/src/nanotron/parallel/tensor_parallel/functional.py +++ b/src/nanotron/parallel/tensor_parallel/functional.py @@ -13,17 +13,19 @@ # See the License for the specific language governing permissions and # limitations under the License. 
 import math
-from typing import Optional
+from typing import Optional, Tuple

 import torch
 from torch.nn import functional as F

 import nanotron.distributed as dist
+from nanotron.parallel.comm import AsyncCommBucket
 from nanotron.parallel.tensor_parallel.distributed_differentiable_primitives import (
     differentiable_all_reduce_sum,
     differentiable_identity,
     differentiable_reduce_scatter_sum,
 )
+from nanotron.parallel.tensor_parallel.domino import is_async_comm
 from nanotron.parallel.tensor_parallel.enum import TensorParallelLinearMode
 from nanotron.parallel.utils import MemoryBuffer, assert_cuda_max_connections_set_to_1
@@ -436,12 +438,13 @@ def column_linear(
     tp_mode: TensorParallelLinearMode,
     async_communication: bool,
     tp_recompute_allgather: bool = True,
+    op_name: Optional[str] = None,
 ):
     if async_communication:
         return _ColumnLinearAsyncCommunication.apply(input, weight, bias, group, tp_mode, tp_recompute_allgather)

     if tp_mode is TensorParallelLinearMode.ALL_REDUCE:
-        input = differentiable_identity(input, group=group)
+        input = differentiable_identity(input, group=group, op_name=op_name)
         return F.linear(input, weight, bias)
     if tp_mode is TensorParallelLinearMode.REDUCE_SCATTER:
         return _ColumnLinearNoAsyncCommunicationReduceScatterMode.apply(
@@ -587,18 +590,28 @@ def row_linear(
     bias: Optional[torch.Tensor],
     group: dist.ProcessGroup,
     tp_mode: TensorParallelLinearMode,
+    # TODO(xrsrke): use less confusing names for these arguments
     async_communication: bool,
-):
+    op_name: Optional[str] = None,
+) -> Tuple[torch.Tensor, Optional["dist.Work"]]:
     if async_communication:
         return _RowLinearAsyncCommunication.apply(input, weight, bias, group, tp_mode)

     out = F.linear(input, weight, bias)

+    # NOTE: computed before branching so the REDUCE_SCATTER assert below can also use it.
+    async_all_reduce = is_async_comm(op_name) if op_name is not None else False
+
     if tp_mode is TensorParallelLinearMode.ALL_REDUCE:
-        out = differentiable_all_reduce_sum(out, group=group)
+        out = differentiable_all_reduce_sum(out, group=group, op_name=op_name)
+
+        if async_all_reduce:
+            work = AsyncCommBucket.pop(op_name)
+        else:
+            work = None
     elif tp_mode is TensorParallelLinearMode.REDUCE_SCATTER:
+        assert async_all_reduce is False, "Async communication is not supported for REDUCE_SCATTER mode."
         out = differentiable_reduce_scatter_sum(out, group=group)
+        work = None
     else:
         raise ValueError(f"Got unexpected mode: {tp_mode}.")

-    return out
+    return out, work
diff --git a/src/nanotron/parallel/tensor_parallel/nn.py b/src/nanotron/parallel/tensor_parallel/nn.py
index 4c7325cd..41386d38 100644
--- a/src/nanotron/parallel/tensor_parallel/nn.py
+++ b/src/nanotron/parallel/tensor_parallel/nn.py
@@ -85,7 +85,7 @@ def __init__(
             split_config=split_config,
         )

-    def forward(self, x: torch.Tensor) -> torch.Tensor:
+    def forward(self, x: torch.Tensor, op_name: str = None) -> torch.Tensor:
         return column_linear(
             input=x,
             weight=self.weight,
@@ -94,6 +94,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
             tp_mode=self.mode,
             async_communication=self.async_communication,
             tp_recompute_allgather=self.tp_recompute_allgather,
+            op_name=op_name,
         )

     def extra_repr(self) -> str:
@@ -133,6 +134,7 @@ def __init__(
         )
         self.mode = mode
         self.async_communication = async_communication
+
         if self.mode is TensorParallelLinearMode.ALL_REDUCE and self.async_communication:
             raise ValueError("async_communication is not supported for ALL_REDUCE mode")

@@ -158,7 +160,7 @@ def _mark_all_parameters_in_module_as_sharded(self, split_config: SplitConfig):
             )
             setattr(self, name, new_param)

-    def forward(self, x: torch.Tensor) -> torch.Tensor:
+    def forward(self, x: torch.Tensor, op_name: str = None):
         return row_linear(
             input=x,
             weight=self.weight,
@@ -166,6 +168,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
             group=self.pg,
             tp_mode=self.mode,
             async_communication=self.async_communication,
+            op_name=op_name,
         )

     def extra_repr(self) -> str:
diff --git a/src/nanotron/sanity_checks.py b/src/nanotron/sanity_checks.py
index 56ef1e2e..2a02d830 100644
--- a/src/nanotron/sanity_checks.py
+++ b/src/nanotron/sanity_checks.py
@@ -10,6 +10,7 @@
 from nanotron.models import NanotronModel
 from nanotron.optim.gradient_accumulator import GradientAccumulator
 from nanotron.parallel import ParallelContext
+from nanotron.parallel.comm import AsyncCommBucket
 from nanotron.parallel.tied_parameters import get_tied_id_to_param

 logger = get_logger(__name__)
@@ -239,6 +240,12 @@ def before_optim_step_sanity_checks(
     # SANITY CHECK: run model specific sanity checks
     unwrapped_model.before_optim_step_sanity_checks()

+    # SANITY CHECK: for Domino, every async communication op must have completed by now
+    assert AsyncCommBucket.is_all_completed(), "There are still async communication ops that haven't finished"
+

 def after_optim_step_sanity_checks(
     config: Config,
diff --git a/src/nanotron/trainer.py b/src/nanotron/trainer.py
index 94b03c6e..7cac06c6 100644
--- a/src/nanotron/trainer.py
+++ b/src/nanotron/trainer.py
@@ -61,6 +61,7 @@
 from nanotron.models.starcoder2 import Starcoder2ForTraining
 from nanotron.optim.clip_grads import clip_grad_norm
 from nanotron.parallel import ParallelContext
+from nanotron.parallel.comm import AsyncCommBucket
 from nanotron.parallel.data_parallel.utils import sync_gradients_across_dp
 from nanotron.parallel.parameters import NanotronParameter, sanity_check
 from nanotron.parallel.pipeline_parallel.engine import (
@@ -443,6 +444,7 @@ def train(
             # free memory
             gc.collect()
             torch.cuda.empty_cache()
+
         with prof:
             for self.iteration_step in range(self.initial_iter_step, self.last_iter_step + 1):
                 if isinstance(prof, torch.profiler.profile):
@@ -564,6 +570,13 @@ def training_step(
             self.config, self.parallel_context, self.unwrapped_model, self.grad_accumulator, self.optimizer
         )

+        # Domino: every async all-reduce launched this step must have been waited on; then reset the bucket.
+        assert AsyncCommBucket.is_all_completed(), "There are still async communication ops that haven't finished"
+        AsyncCommBucket.clear_all()
+
         # Apply gradient
         self.optimizer.step()
         self.optimizer.zero_grad()
diff --git a/tests/helpers/llama.py b/tests/helpers/llama.py
index 3f94031f..8aae7669 100644
--- a/tests/helpers/llama.py
+++ b/tests/helpers/llama.py
@@ -1,5 +1,6 @@
 import torch
 from nanotron.config import (
+    AdamWOptimizerArgs,
     AllForwardAllBackwardPipelineEngine,
     CheckpointsArgs,
     Config,
@@ -46,7 +47,19 @@
 )


-def get_llama_training_config(model_config: ModelArgs):
+def get_parallel_config(parallel_context: ParallelContext):
+    return ParallelismArgs(
+        dp=parallel_context.data_parallel_size,
+        pp=parallel_context.pipeline_parallel_size,
+        tp=parallel_context.tensor_parallel_size,
+        expert_parallel_size=parallel_context.expert_parallel_size,
+        pp_engine=AllForwardAllBackwardPipelineEngine(),
+        tp_mode=TensorParallelLinearMode.ALL_REDUCE,
+        tp_linear_async_communication=False,
+    )
+
+
+def get_llama_training_config(model_config: ModelArgs, parallel_context: ParallelContext):
     return Config(
         model=model_config,
         general=GeneralArgs(project="unittest", run="sanity_llama", seed=42),
@@ -54,25 +67,20 @@ def get_llama_training_config(model_config: ModelArgs):
             checkpoints_path="./checkpoints",
             checkpoint_interval=10,
         ),
-        parallelism=ParallelismArgs(
-            dp=1,
-            pp=1,
-            tp=2,
-            expert_parallel_size=2,
-            pp_engine="1f1b",
-            tp_mode="ALL_REDUCE",
-            tp_linear_async_communication=False,
-        ),
+        parallelism=get_parallel_config(parallel_context),
         tokenizer=TokenizerArgs("gpt2"),
         optimizer=OptimizerArgs(
             zero_stage=0,
             weight_decay=0.01,
             clip_grad=1.0,
             accumulate_grad_in_fp32=False,
-            adam_eps=1e-08,
-            adam_beta1=0.9,
-            adam_beta2=0.95,
-            torch_adam_is_fused=True,
+            optimizer_factory=AdamWOptimizerArgs(
+                adam_eps=1e-08,
+                adam_beta1=0.9,
+                adam_beta2=0.95,
+                torch_adam_is_fused=True,
+                name="adamW",
+            ),
             learning_rate_scheduler=LRSchedulerArgs(
                 learning_rate=3e-4,
                 lr_warmup_steps=100,
@@ -103,7 +111,10 @@


 def create_llama_from_config(
-    model_config: LlamaConfig, device: torch.device, parallel_context: ParallelContext
+    model_config: LlamaConfig,
+    parallel_config: ParallelismArgs,
+    device: torch.device,
+    parallel_context: ParallelContext,
 ) -> LlamaForTraining:
     """
     the model created will have random weights.
""" - parallel_config = ParallelismArgs( - dp=parallel_context.data_parallel_size, - pp=parallel_context.pipeline_parallel_size, - tp=parallel_context.tensor_parallel_size, - pp_engine=AllForwardAllBackwardPipelineEngine(), - tp_mode=TensorParallelLinearMode.ALL_REDUCE, - tp_linear_async_communication=False, - ) model = build_model( model_builder=lambda: LlamaForTraining( config=model_config, diff --git a/tests/test_base_model.py b/tests/test_base_model.py index b4759905..410e302c 100644 --- a/tests/test_base_model.py +++ b/tests/test_base_model.py @@ -10,7 +10,6 @@ @pytest.mark.parametrize("tp,dp,pp", [(1, 1, 1), (2, 2, 2)]) -@pytest.mark.skip @rerun_if_address_is_in_use() def test_get_named_modules_in_pp_rank(tp: int, dp: int, pp: int): model_args = ModelArgs(init_method=RandomInit(std=1.0), model_config=TINY_LLAMA_CONFIG) @@ -43,3 +42,34 @@ def _test_get_named_modules_in_pp_rank( # not PipelineBlock assert isinstance(module, nn.Module) assert name not in modules_that_not_in_current_pp_rank + + +@pytest.mark.parametrize("tp,dp,pp", [(1, 1, 1), (2, 2, 1)]) +@rerun_if_address_is_in_use() +def test_llama_model(tp: int, dp: int, pp: int): + BATCH_SIZE, SEQ_LEN = 10, 128 + model_args = ModelArgs(init_method=RandomInit(std=1.0), model_config=TINY_LLAMA_CONFIG) + config = get_llama_training_config(model_args) + + init_distributed(tp=tp, dp=dp, pp=pp)(_test_llama_model)(config=config, batch_size=BATCH_SIZE, seq_len=SEQ_LEN) + + +def _test_llama_model( + parallel_context: ParallelContext, + config: Config, + batch_size: int, + seq_len: int, +): + llama_model = create_llama_from_config( + model_config=config.model.model_config, + device=torch.device("cuda"), + parallel_context=parallel_context, + ) + llama_model.init_model_randomly(config=config) + + input_ids = torch.randint(0, config.model.model_config.vocab_size, size=(batch_size, seq_len), device="cuda") + input_mask = torch.ones_like(input_ids) + outputs = llama_model(input_ids, input_mask, input_mask, input_mask) + + assert list(outputs.keys()) == ["loss"] + assert isinstance(outputs["loss"], torch.Tensor) diff --git a/tests/test_comm.py b/tests/test_comm.py new file mode 100644 index 00000000..286f3ebe --- /dev/null +++ b/tests/test_comm.py @@ -0,0 +1,117 @@ +import pytest +import torch +import torch.distributed as dist +from helpers.utils import ( + init_distributed, + rerun_if_address_is_in_use, +) +from nanotron.parallel import ParallelContext +from nanotron.parallel.comm import AsyncCommBucket, WaitComm + + +class MockWork: + def __init__(self): + self.completed = False + self.wait_called = False + + def wait(self): + self.wait_called = True + self.completed = True + + def is_completed(self): + return self.completed + + +@rerun_if_address_is_in_use() +def test_add_async_op_to_bucket(): + init_distributed(tp=2, dp=1, pp=1)(_test_add_async_op_to_bucket)() + + +def _test_add_async_op_to_bucket(parallel_context: ParallelContext): + OP_NAME = "test" + tensor = torch.randn(1, device="cuda") + work = dist.all_reduce(tensor, async_op=True) + + AsyncCommBucket.add(OP_NAME, work) + + assert AsyncCommBucket.get(OP_NAME) is work + + +@rerun_if_address_is_in_use() +def test_wait_async_op_to_bucket(): + init_distributed(tp=2, dp=1, pp=1)(_test_wait_async_op_to_bucket)() + + +def _test_wait_async_op_to_bucket(parallel_context: ParallelContext): + OP_NAME = "test" + work = MockWork() + + AsyncCommBucket.add(OP_NAME, work) + assert work.is_completed() is False + + AsyncCommBucket.wait(OP_NAME) + assert work.is_completed() + with pytest.raises(KeyError): 
+        AsyncCommBucket.get(OP_NAME)
+
+
+@rerun_if_address_is_in_use()
+def test_is_all_completed_in_async_bucket():
+    init_distributed(tp=2, dp=1, pp=1)(_test_is_all_completed_in_async_bucket)()
+
+
+def _test_is_all_completed_in_async_bucket(parallel_context: ParallelContext):
+    OP_NAME = "test"
+    work = MockWork()
+
+    AsyncCommBucket.add(OP_NAME, work)
+    assert AsyncCommBucket.is_all_completed() is False
+
+    AsyncCommBucket.wait(OP_NAME)
+    assert AsyncCommBucket.is_all_completed() is True
+
+
+@rerun_if_address_is_in_use()
+def test_clear_ops_in_async_bucket():
+    init_distributed(tp=2, dp=1, pp=1)(_test_clear_ops_in_async_bucket)()
+
+
+def _test_clear_ops_in_async_bucket(parallel_context: ParallelContext):
+    tensor1 = torch.randn(1, device="cuda")
+    tensor2 = torch.randn(1, device="cuda")
+    tensor3 = torch.randn(1, device="cuda")
+
+    AsyncCommBucket.add("test1", dist.all_reduce(tensor1, async_op=True))
+    AsyncCommBucket.add("test2", dist.all_reduce(tensor2, async_op=True))
+    AsyncCommBucket.add("test3", dist.all_reduce(tensor3, async_op=True))
+
+    assert AsyncCommBucket.is_all_completed() is False
+
+    AsyncCommBucket.clear_all()
+    assert AsyncCommBucket.is_all_completed() is True
+    with pytest.raises(KeyError):
+        AsyncCommBucket.get("test1")
+
+
+@rerun_if_address_is_in_use()
+def test_wait_comm():
+    init_distributed(tp=2, dp=1, pp=1)(_test_wait_comm)()
+
+
+def _test_wait_comm(parallel_context: ParallelContext):
+    tensor = torch.randn(1, device="cuda", requires_grad=True)
+    OP_NAME = "test"
+
+    comm_stream = torch.cuda.Stream()
+
+    with torch.cuda.stream(comm_stream):
+        work = MockWork()
+        AsyncCommBucket.add(OP_NAME, work)
+
+    output = WaitComm.apply(tensor, OP_NAME, comm_stream)
+    assert work.is_completed() is False
+
+    # NOTE: we test that it waits for the async op to complete
+    # automatically in autograd
+    (output + 1).backward()
+    assert work.is_completed()
diff --git a/tests/test_domino.py b/tests/test_domino.py
new file mode 100644
index 00000000..44d9d98a
--- /dev/null
+++ b/tests/test_domino.py
@@ -0,0 +1,71 @@
+from copy import deepcopy
+
+import pytest
+import torch
+from helpers.llama import TINY_LLAMA_CONFIG, create_llama_from_config, get_llama_training_config
+from helpers.utils import init_distributed, rerun_if_address_is_in_use
+from nanotron.config import ModelArgs, RandomInit
+from nanotron.config.parallelism_config import DominoArgs
+from nanotron.models.llama import DominoLlamaDecoderLayer
+from nanotron.parallel import ParallelContext
+from nanotron.parallel.comm import AsyncCommBucket
+from nanotron.parallel.tensor_parallel.domino import is_async_comm
+
+
+@pytest.mark.parametrize(
+    "op_name, expected",
+    [
+        ("fwd.layer_attn_1_batch_0", True),
+        ("fwd.layer_attn_1_batch_1", True),
+        ("fwd.layer_mlp_1_batch_0", True),
+        ("fwd.layer_mlp_1_batch_1", False),
+        ("bwd.layer_mlp_1_batch_1", True),
+        ("bwd.layer_mlp_1_batch_0", True),
+        ("bwd.layer_attn_1_batch_1", True),
+        ("bwd.layer_attn_1_batch_0", False),
+    ],
+)
+def test_is_async_comm(op_name, expected):
+    assert is_async_comm(op_name) == expected
+
+
+@pytest.mark.parametrize("tp,dp,pp", [(2, 2, 1)])
+@rerun_if_address_is_in_use()
+def test_domino_model(tp: int, dp: int, pp: int):
+    BATCH_SIZE, SEQ_LEN = 10, 128
+
+    model_config = deepcopy(TINY_LLAMA_CONFIG)
+    model_config.num_hidden_layers = 28
+    model_args = ModelArgs(init_method=RandomInit(std=1.0), model_config=model_config)
+
+    init_distributed(tp=tp, dp=dp, pp=pp)(_test_domino_model)(
+        model_args=model_args, batch_size=BATCH_SIZE, seq_len=SEQ_LEN
+    )
+
+
+def 
_test_domino_model( + parallel_context: ParallelContext, + model_args: ModelArgs, + batch_size: int, + seq_len: int, +): + config = get_llama_training_config(model_args, parallel_context) + config.parallelism.domino = DominoArgs(num_input_batches=2) + + llama_model = create_llama_from_config( + model_config=config.model.model_config, + parallel_config=config.parallelism, + device=torch.device("cuda"), + parallel_context=parallel_context, + ) + llama_model.init_model_randomly(config=config) + + for m in llama_model.model.decoder: + assert isinstance(m.pp_block, DominoLlamaDecoderLayer) + + input_ids = torch.randint(0, config.model.model_config.vocab_size, size=(batch_size, seq_len), device="cuda") + input_mask = torch.ones_like(input_ids) + outputs = llama_model(input_ids, input_mask, input_mask, input_mask) + + assert isinstance(outputs["loss"], torch.Tensor) + assert AsyncCommBucket.is_all_completed() is True
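Usage sketch (illustrative, not part of the patch): assuming the domino module introduced above is importable, the snippet below shows how the handle-name templates expand and which of the resulting ops is_async_comm lets run asynchronously; layer_idx = 3 is an arbitrary example value.

# Illustrative only: how Domino op names are built and which ones may overlap.
from nanotron.parallel.tensor_parallel.domino import (
    BWD_ATTN_HANDLE_IDX,
    FWD_ATTN_HANDLE_IDX,
    FWD_MLP_HANDLE_IDX,
    is_async_comm,
)

layer_idx = 3  # hypothetical layer index
for batch_idx in (0, 1):
    attn = FWD_ATTN_HANDLE_IDX.format(layer_idx, batch_idx)  # "fwd.layer_attn_3_batch_{0,1}"
    mlp = FWD_MLP_HANDLE_IDX.format(layer_idx, batch_idx)    # "fwd.layer_mlp_3_batch_{0,1}"
    print(attn, is_async_comm(attn))  # True for both: attention all-reduces can overlap with compute
    print(mlp, is_async_comm(mlp))    # True for batch 0, False for batch 1 (last forward MLP stays synchronous)

bwd_attn0 = BWD_ATTN_HANDLE_IDX.format(layer_idx, 0)
print(bwd_attn0, is_async_comm(bwd_attn0))  # False: the first attention op in backward stays synchronous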