diff --git a/examples/domino/config_domino.yaml b/examples/domino/config_domino.yaml new file mode 100644 index 00000000..9887ad26 --- /dev/null +++ b/examples/domino/config_domino.yaml @@ -0,0 +1,108 @@ +checkpoints: + checkpoint_interval: 10000 + checkpoints_path: checkpoints + checkpoints_path_is_shared_file_system: false + resume_checkpoint_path: null + load_lr_scheduler: false + load_optimizer: false + save_final_state: true + save_initial_state: false +data_stages: +- data: + dataset: + dataset_folder: + - /fsx/loubna/datasets/llama_tokenized/other_sources/wiki + token_size_in_bytes: 4 + tokenizer_name: meta-llama/Llama-3.2-1B + vocab_size: 128256 + num_loading_workers: 8 + seed: 42 + name: Training Stage + start_training_step: 1 +general: + benchmark_csv_path: null + consumed_train_samples: null + ignore_sanity_checks: true + project: nanotron_domino + run: domino_config + seed: 6 + step: null +logging: + iteration_step_info_interval: 1 + log_level: info + log_level_replica: info +model: + ddp_bucket_cap_mb: 25 + dtype: bfloat16 + init_method: + std: 0.041666666666666664 + make_vocab_size_divisible_by: 1 + model_config: + bos_token_id: 128000 + eos_token_id: 128001 + hidden_act: silu + hidden_size: 4096 + initializer_range: 0.02 + intermediate_size: 14336 + is_llama_config: true + max_position_embeddings: 4096 + num_attention_heads: 32 + num_hidden_layers: 20 + num_key_value_heads: 8 + pad_token_id: null + pretraining_tp: 2 + rms_norm_eps: 1.0e-05 + rope_interleaved: false + rope_scaling: + factor: 32.0 + high_freq_factor: 4.0 + low_freq_factor: 1.0 + original_max_position_embeddings: 4096 + rope_type: llama3 + rope_theta: 500000.0 + tie_word_embeddings: true + use_cache: true + vocab_size: 128256 +optimizer: + accumulate_grad_in_fp32: true + clip_grad: 1.0 + learning_rate_scheduler: + learning_rate: 0.00005 + lr_decay_starting_step: 50000 + lr_decay_steps: 10000 + lr_decay_style: linear + lr_warmup_steps: 1000 + lr_warmup_style: linear + min_decay_lr: 0 + 
optimizer_factory: + adam_beta1: 0.9 + adam_beta2: 0.95 + adam_eps: 1.0e-08 + name: adamW + torch_adam_is_fused: true + weight_decay: 0.01 + zero_stage: 1 +parallelism: + dp: 1 + expert_parallel_size: 1 + pp: 1 + pp_engine: 1f1b + recompute_layer: false + tp: 8 + tp_linear_async_communication: false + tp_mode: ALL_REDUCE + tp_recompute_allgather: false + domino: + num_input_batches: 2 +tokenizer: + tokenizer_max_length: null + tokenizer_name_or_path: meta-llama/Llama-3.2-1B + tokenizer_revision: null +tokens: + batch_accumulation_per_replica: 2 + limit_test_batches: 0 + limit_val_batches: 0 + micro_batch_size: 8 + sequence_length: 4096 + train_steps: 15000 + val_check_interval: -1 diff --git a/src/nanotron/config/parallelism_config.py b/src/nanotron/config/parallelism_config.py index 48aa941e..72f667ab 100644 --- a/src/nanotron/config/parallelism_config.py +++ b/src/nanotron/config/parallelism_config.py @@ -11,6 +11,23 @@ from nanotron.parallel.tensor_parallel.nn import TensorParallelLinearMode +@dataclass +class DominoArgs: + """ + Domino: Eliminating Communication in LLM Training via Generic Tensor Slicing and Overlapping + https://arxiv.org/abs/2409.15241 + """ + + # NOTE: if the number of input batches is 1, + # it's equivalent to non-domino mode + # so if you want to enable domino mode, set this to > 1 + num_input_batches: int + + def __post_init__(self): + assert self.num_input_batches > 1, "In order to enable domino mode, set num_input_batches > 1" + assert self.num_input_batches == 2, "Currently parallelism only supports 2 batches for Domino" + + @dataclass class ParallelismArgs: """Arguments related to TP/PP/DP @@ -39,6 +56,7 @@ class ParallelismArgs: expert_parallel_size: int = 1 context_parallel_size: int = 1 + domino: Optional[DominoArgs] = None def __post_init__(self): # Conservative defaults @@ -53,3 +71,19 @@ def __post_init__(self): self.pp_engine = cast_str_to_pipeline_engine(self.pp_engine) if isinstance(self.tp_mode, str): self.tp_mode = 
TensorParallelLinearMode[self.tp_mode.upper()] + + if self.is_domino_enabled is True: + assert self.tp > 1, "Domino requires TP > 1" + # NOTE: For DoMiNo, since we overlap the communication, + # it doesn't matter whether it's all_reduce or reduce_scatter, + # so we only support (and have tested) all_reduce up to now, + # but in principle, it should work with reduce_scatter as well + assert ( + self.tp_linear_async_communication is False + ), "Domino requires TP linear async communication to be False" + # TODO: support REDUCE_SCATTER mode for Domino + assert self.tp_mode == TensorParallelLinearMode.ALL_REDUCE, "Domino requires TP mode to be ALL_REDUCE" + + @property + def is_domino_enabled(self) -> bool: + return True if self.domino else False diff --git a/src/nanotron/constants.py b/src/nanotron/constants.py index 580bd99d..61878181 100644 --- a/src/nanotron/constants.py +++ b/src/nanotron/constants.py @@ -10,3 +10,7 @@ CHECKPOINT_FILE_NAME = "checkpoint_metadata.json" MODEL_CONFIG_FILE_NAME = "model_config.json" + + +### FOR COMMUNICATION ### +CUDA_STREAM_COMM_NAME = "comm_stream_{}" diff --git a/src/nanotron/helpers.py b/src/nanotron/helpers.py index 9cbf53f7..5234a0e7 100644 --- a/src/nanotron/helpers.py +++ b/src/nanotron/helpers.py @@ -518,6 +518,7 @@ def get_profiler(config: Config): record_shapes=config.profiler.record_shapes, profile_memory=config.profiler.profile_memory, with_stack=config.profiler.with_stack, + with_modules=True, ) else: prof = contextlib.nullcontext() diff --git a/src/nanotron/models/llama.py b/src/nanotron/models/llama.py index db820644..4d7de8a6 100644 --- a/src/nanotron/models/llama.py +++ b/src/nanotron/models/llama.py @@ -34,9 +34,17 @@ from nanotron.nn.activations import ACT2FN from nanotron.nn.layer_norm import TritonRMSNorm from nanotron.parallel import ParallelContext +from nanotron.parallel.comm import CudaStreamManager, insert_backward_sync_to_tensor from nanotron.parallel.parameters import NanotronParameter from 
nanotron.parallel.pipeline_parallel.block import PipelineBlock, TensorPointer from nanotron.parallel.pipeline_parallel.p2p import P2P +from nanotron.parallel.tensor_parallel.domino import ( + BWD_ATTN_OP_NAME, + BWD_MLP_OP_NAME, + FWD_ATTN_OP_NAME, + FWD_MLP_OP_NAME, + OpNameContext, +) from nanotron.parallel.tensor_parallel.functional import sharded_cross_entropy from nanotron.parallel.tensor_parallel.nn import ( TensorParallelColumnLinear, @@ -211,6 +219,7 @@ def __init__( config: LlamaConfig, parallel_config: Optional[ParallelismArgs], tp_pg: dist.ProcessGroup, + stream_manager: Optional[CudaStreamManager] = None, ): super().__init__() @@ -233,6 +242,7 @@ def __init__( async_communication=tp_linear_async_communication, contiguous_chunks=gate_up_contiguous_chunks, tp_recompute_allgather=parallel_config.tp_recompute_allgather, + stream_manager=stream_manager, ) self.down_proj = TensorParallelRowLinear( config.intermediate_size, @@ -241,10 +251,11 @@ def __init__( mode=tp_mode, bias=False, async_communication=tp_linear_async_communication and tp_mode is TensorParallelLinearMode.REDUCE_SCATTER, + stream_manager=stream_manager, ) self.split_silu_mul = GLUActivation(config.hidden_act) - def forward(self, hidden_states): # [seq_length, batch_size, hidden_dim] + def forward(self, hidden_states: torch.Tensor): # [seq_length, batch_size, hidden_dim] merged_states = self.gate_up_proj(hidden_states) hidden_states = self.down_proj(self.split_silu_mul(merged_states)) return {"hidden_states": hidden_states} @@ -339,6 +350,7 @@ def __init__( parallel_config: Optional[ParallelismArgs], tp_pg: dist.ProcessGroup, layer_idx: int, + stream_manager: Optional[CudaStreamManager] = None, ): from flash_attn.layers.rotary import RotaryEmbedding as FlashRotaryEmbedding @@ -394,6 +406,7 @@ def __init__( async_communication=tp_linear_async_communication, contiguous_chunks=qkv_contiguous_chunks, tp_recompute_allgather=parallel_config.tp_recompute_allgather, + stream_manager=stream_manager, ) 
# TODO(kunhao): We want to have only one version per device and not one version per layer. if config.rope_interleaved: @@ -422,6 +435,7 @@ def __init__( mode=tp_mode, bias=False, async_communication=tp_linear_async_communication, + stream_manager=stream_manager, ) self.attention = CoreAttention( @@ -717,46 +731,33 @@ def _forward_training(self, query_states, key_states, value_states, sequence_mas return {"hidden_states": output, "sequence_mask": sequence_mask} -class LlamaDecoderLayer(nn.Module): +class _BaseLlamaDecoderLayer(nn.Module): def __init__( self, config: LlamaConfig, parallel_config: Optional[ParallelismArgs], tp_pg: dist.ProcessGroup, layer_idx: int, + stream_manager: Optional[CudaStreamManager] = None, ): super().__init__() self.input_layernorm = TritonRMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.attn = CausalSelfAttention( config=config, parallel_config=parallel_config, tp_pg=tp_pg, layer_idx=layer_idx, + stream_manager=stream_manager, ) self.post_attention_layernorm = TritonRMSNorm(config.hidden_size, eps=config.rms_norm_eps) - self.mlp = MLP(config=config, parallel_config=parallel_config, tp_pg=tp_pg) + self.mlp = MLP(config=config, parallel_config=parallel_config, tp_pg=tp_pg, stream_manager=stream_manager) self.recompute_layer = parallel_config.recompute_layer - - def _core_forward( - self, - hidden_states: Union[torch.Tensor, TensorPointer], - sequence_mask: Union[torch.Tensor, TensorPointer], - ) -> List[Union[torch.Tensor, TensorPointer]]: - residual = hidden_states - hidden_states = self.input_layernorm(hidden_states) - - output = self.attn(hidden_states=hidden_states, sequence_mask=sequence_mask) - hidden_states = output["hidden_states"] - hidden_states = hidden_states + residual - - residual = hidden_states - hidden_states = self.post_attention_layernorm(hidden_states) - hidden_states = self.mlp(hidden_states=hidden_states)["hidden_states"] - hidden_states = hidden_states + residual - - return hidden_states, 
output["sequence_mask"] + self.parallel_config = parallel_config + self.layer_idx = layer_idx + self.stream_manager = stream_manager def _checkpointed_forward( self, @@ -782,6 +783,142 @@ def forward( } + +class DominoLlamaDecoderLayer(_BaseLlamaDecoderLayer): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + assert self.stream_manager is not None, "DominoLlamaDecoderLayer requires a stream_manager" + + def _core_forward( + self, + hidden_states: Union[torch.Tensor, TensorPointer], + sequence_mask: Union[torch.Tensor, TensorPointer], + ) -> List[Union[torch.Tensor, TensorPointer]]: + num_input_batches = self.parallel_config.domino.num_input_batches + + orig_sequence_mask = sequence_mask + comm_stream = self.stream_manager.get_default_comm_stream() + comm_bucket = self.stream_manager.comm_bucket + + hidden_states = torch.chunk(hidden_states, chunks=num_input_batches, dim=1) + sequence_mask = torch.chunk(sequence_mask, chunks=num_input_batches, dim=0) + + hidden_states0, hidden_states1 = hidden_states + sequence_mask0, sequence_mask1 = sequence_mask + + residual0 = hidden_states0 + residual1 = hidden_states1 + + # TODO: overlap the 'layernorm > attn' of the second batch + # with the comm of the first batch in both forward and backward + hidden_states0 = self.input_layernorm(hidden_states0) + hidden_states1 = self.input_layernorm(hidden_states1) + hidden_states0 = insert_backward_sync_to_tensor( + hidden_states0, + BWD_ATTN_OP_NAME.format(self.layer_idx, 1), + self.stream_manager, + ) + hidden_states1 = insert_backward_sync_to_tensor( + hidden_states1, + BWD_MLP_OP_NAME.format(self.layer_idx, 0), + self.stream_manager, + ) + + # TODO: maybe try to bucket all the communication as in DDP, + do it all at once + with OpNameContext( + fwd_op_name=FWD_ATTN_OP_NAME.format(self.layer_idx, 0), + bwd_op_name=BWD_ATTN_OP_NAME.format(self.layer_idx, 0), + ): + attn_output0 = self.attn( + hidden_states=hidden_states0, + sequence_mask=sequence_mask0, + ) + 
+ with OpNameContext( + fwd_op_name=FWD_ATTN_OP_NAME.format(self.layer_idx, 1), + bwd_op_name=BWD_ATTN_OP_NAME.format(self.layer_idx, 1), + ): + attn_output1 = self.attn( + hidden_states=hidden_states1, + sequence_mask=sequence_mask1, + ) + + # TODO(xrsrke): double check if we need this explicit synchronization + # otherwise, remove it + comm_stream.wait_stream(torch.cuda.default_stream()) + with torch.cuda.stream(comm_stream): + comm_bucket.wait(FWD_ATTN_OP_NAME.format(self.layer_idx, 0)) + + torch.cuda.default_stream().wait_stream(comm_stream) + + hidden_states0 = attn_output0["hidden_states"] + residual0 + residual0 = hidden_states0 + hidden_states0 = self.post_attention_layernorm(hidden_states0) + hidden_states0 = insert_backward_sync_to_tensor( + hidden_states0, + BWD_MLP_OP_NAME.format(self.layer_idx, 1), + self.stream_manager, + ) + + with OpNameContext( + fwd_op_name=FWD_MLP_OP_NAME.format(self.layer_idx, 0), + bwd_op_name=BWD_MLP_OP_NAME.format(self.layer_idx, 0), + ): + mlp_output0 = self.mlp(hidden_states=hidden_states0) + + comm_stream.wait_stream(torch.cuda.default_stream()) + with torch.cuda.stream(comm_stream): + comm_bucket.wait(FWD_ATTN_OP_NAME.format(self.layer_idx, 1)) + + torch.cuda.default_stream().wait_stream(comm_stream) + + hidden_states1 = attn_output1["hidden_states"] + residual1 + residual1 = hidden_states1 + hidden_states1 = self.post_attention_layernorm(hidden_states1) + + with OpNameContext( + fwd_op_name=FWD_MLP_OP_NAME.format(self.layer_idx, 1), + bwd_op_name=BWD_MLP_OP_NAME.format(self.layer_idx, 1), + ): + mlp_output1 = self.mlp(hidden_states=hidden_states1) + + comm_stream.wait_stream(torch.cuda.default_stream()) + with torch.cuda.stream(comm_stream): + comm_bucket.wait(FWD_MLP_OP_NAME.format(self.layer_idx, 0)) + comm_bucket.wait(FWD_MLP_OP_NAME.format(self.layer_idx, 1)) + + torch.cuda.default_stream().wait_stream(comm_stream) + + hidden_states0 = mlp_output0["hidden_states"] + residual0 + hidden_states1 = 
mlp_output1["hidden_states"] + residual1 + + # TODO: make sure no memory overhead, + # and try a fixed memory buffer as in section 4.2 in the paper + hidden_states = torch.cat([hidden_states0, hidden_states1], dim=1) + return hidden_states, orig_sequence_mask + + +class LlamaDecoderLayer(_BaseLlamaDecoderLayer): + def _core_forward( + self, + hidden_states: Union[torch.Tensor, TensorPointer], + sequence_mask: Union[torch.Tensor, TensorPointer], + ) -> List[Union[torch.Tensor, TensorPointer]]: + residual = hidden_states + hidden_states = self.input_layernorm(hidden_states) + + output = self.attn(hidden_states=hidden_states, sequence_mask=sequence_mask) + hidden_states = output["hidden_states"] + hidden_states = hidden_states + residual + + residual = hidden_states + hidden_states = self.post_attention_layernorm(hidden_states) + hidden_states = self.mlp(hidden_states=hidden_states)["hidden_states"] + hidden_states = hidden_states + residual + + return hidden_states, output["sequence_mask"] + + class Embedding(nn.Module, AttachableStore): def __init__(self, tp_pg: dist.ProcessGroup, config: LlamaConfig, parallel_config: Optional[ParallelismArgs]): super().__init__() @@ -820,6 +957,7 @@ def __init__( config: LlamaConfig, parallel_context: ParallelContext, parallel_config: Optional[ParallelismArgs], + stream_manager: Optional[CudaStreamManager] = None, ): super().__init__() @@ -856,12 +994,13 @@ def __init__( [ PipelineBlock( p2p=self.p2p, - module_builder=LlamaDecoderLayer, + module_builder=DominoLlamaDecoderLayer if parallel_config.is_domino_enabled else LlamaDecoderLayer, module_kwargs={ "config": config, "parallel_config": parallel_config, "tp_pg": parallel_context.tp_pg, "layer_idx": layer_idx, + "stream_manager": stream_manager, }, module_input_keys={"hidden_states", "sequence_mask"}, module_output_keys={"hidden_states", "sequence_mask"}, @@ -924,6 +1063,7 @@ def forward_with_hidden_states( "hidden_states": output["input_embeds"], "sequence_mask": input_mask, } + 
for encoder_block in self.decoder: hidden_encoder_states = encoder_block(**hidden_encoder_states) @@ -1030,10 +1170,15 @@ def __init__( parallel_context: ParallelContext, parallel_config: Optional[ParallelismArgs], random_states: Optional[RandomStates] = None, + stream_manager: Optional[CudaStreamManager] = None, ): super().__init__() - self.model = LlamaModel(config=config, parallel_context=parallel_context, parallel_config=parallel_config) - + self.model = LlamaModel( + config=config, + parallel_context=parallel_context, + parallel_config=parallel_config, + stream_manager=stream_manager, + ) # Choose the appropriate loss class based on config loss_kwargs = { "tp_pg": parallel_context.tp_pg, @@ -1056,6 +1201,7 @@ def __init__( self.parallel_context = parallel_context self.config = config self.parallel_config = parallel_config + self.stream_manager = stream_manager def forward( self, diff --git a/src/nanotron/parallel/comm.py b/src/nanotron/parallel/comm.py new file mode 100644 index 00000000..d112076a --- /dev/null +++ b/src/nanotron/parallel/comm.py @@ -0,0 +1,132 @@ +from contextlib import contextmanager +from typing import Dict + +import torch + +from nanotron.constants import CUDA_STREAM_COMM_NAME +from nanotron.parallel.tensor_parallel.domino import is_domino_async_comm + + +class AsyncCommBucket: + """ + Store asynchronous communication operations. 
+ """ + + def __init__(self): + self._async_op: Dict[int, "dist.Work"] = {} + self._copy_async_op: Dict[int, "dist.Work"] = {} + + def add(self, op_name: int, work: "dist.Work"): + assert op_name not in self._async_op, f"Operation with name: {op_name} already exists" + assert work is not None + self._async_op[op_name] = work + self._copy_async_op[op_name] = work + + def get(self, op_name: str) -> "dist.Work": + if op_name not in self._async_op: + raise KeyError(f"Operation with name: {op_name} doesn't exist") + + return self._async_op.get(op_name) + + def pop(self, op_name: str) -> "dist.Work": + if op_name not in self._async_op: + raise KeyError(f"Operation with name: {op_name} doesn't exist") + + return self._async_op.pop(op_name) + + def wait(self, op_name: str): + """Wait and remove the operation from the bucket""" + work = self.pop(op_name) + work.wait() + + def is_all_completed(self) -> bool: + if not len(self._async_op) == 0: + return False + + not_finished = [] + for k, v in self._copy_async_op.items(): + if v.is_completed() is not True: + not_finished.append((k, v)) + return len(not_finished) == 0 + + def clear_all(self): + self._async_op.clear() + self._copy_async_op.clear() + + +class CudaStreamManager: + def __init__(self): + self._streams: Dict[str, "torch.cuda.Stream"] = {} + self.comm_bucket = AsyncCommBucket() + + def init_default_comm_stream(self): + """ + Initialize the default communication stream for the current cuda device. + """ + self.create(CUDA_STREAM_COMM_NAME.format(torch.cuda.current_device()), torch.cuda.current_device()) + + def create(self, name: str, device: torch.device): + assert name not in self._streams + self._streams[name] = torch.cuda.Stream(device=device) + + def get(self, name: str): + if name not in self._streams: + self.create(name) + return self._streams.get(name) + + def get_default_comm_stream(self) -> torch.cuda.Stream: + """ + Return the default communication stream for the current cuda device. 
+ """ + return self.get(CUDA_STREAM_COMM_NAME.format(torch.cuda.current_device())) + + @contextmanager + def run_on_stream(self, name: str): + stream = self.get(name) + with torch.cuda.stream(stream): + yield stream + + +class WaitComm(torch.autograd.Function): + """ + Enforce a tensor to wait for the communication operation to finish + in torch's autograd graph. + """ + + @staticmethod + def forward(ctx, input: torch.Tensor, op_name: str, comm_stream: torch.cuda.Stream, comm_bucket: AsyncCommBucket): + assert isinstance(comm_stream, torch.cuda.Stream) + ctx.op_name = op_name + ctx.comm_stream = comm_stream + ctx.comm_bucket = comm_bucket + return input + + @staticmethod + def backward(ctx, grad_output: torch.Tensor): + """ + NOTE: because the communication operation is already being executed + so the communication stream don't have to wait for the compute stream here + but the compute stream waits for the communication stream + before proceeding + """ + if is_domino_async_comm(ctx.op_name): + handle = ctx.comm_bucket.pop(ctx.op_name) + handle.wait() + + ctx.comm_stream.synchronize() + torch.cuda.default_stream().wait_stream(ctx.comm_stream) + + return grad_output, None, None, None + + +def insert_backward_sync_to_tensor( + tensor: torch.Tensor, op_name: str, stream_manager: CudaStreamManager +) -> torch.Tensor: + """ + Insert a wait communication operation of a given op_name to the autograd graph + of a tensor. 
+ """ + + assert isinstance(stream_manager, CudaStreamManager) + comm_stream = stream_manager.get(CUDA_STREAM_COMM_NAME.format(torch.cuda.current_device())) + return WaitComm.apply(tensor, op_name, comm_stream, stream_manager.comm_bucket) diff --git a/src/nanotron/parallel/tensor_parallel/distributed_differentiable_primitives.py b/src/nanotron/parallel/tensor_parallel/distributed_differentiable_primitives.py index bd41347a..a5e65258 100644 --- a/src/nanotron/parallel/tensor_parallel/distributed_differentiable_primitives.py +++ b/src/nanotron/parallel/tensor_parallel/distributed_differentiable_primitives.py @@ -12,43 +12,84 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import Optional +from contextlib import nullcontext +from typing import Optional, Tuple import torch from torch import distributed as torch_dist from nanotron import distributed as dist from nanotron.distributed import ProcessGroup +from nanotron.parallel.comm import CudaStreamManager, is_domino_async_comm +from nanotron.parallel.tensor_parallel.domino import get_current_bwd_op_name class DifferentiableIdentity(torch.autograd.Function): """All-reduce gradients in a differentiable fashion""" @staticmethod - def forward(ctx, tensor, group: Optional[ProcessGroup]): + def forward( + ctx, + tensor: torch.Tensor, + group: Optional[ProcessGroup], + op_name: Optional[str] = None, + stream_manager: Optional[CudaStreamManager] = None, + ): + ctx.bwd_op_name = get_current_bwd_op_name() ctx.group = group + ctx.stream_manager = stream_manager return tensor @staticmethod - def backward(ctx, grad_output): + def backward(ctx, grad_output: torch.Tensor): group = ctx.group - return DifferentiableAllReduceSum.apply(grad_output, group), None + + return ( + DifferentiableAllReduceSum.apply(grad_output, group, ctx.bwd_op_name, ctx.stream_manager), + None, + None, + None, + None, + ) class DifferentiableAllReduceSum(torch.autograd.Function): 
"""All-reduce in a differentiable fashion""" @staticmethod - def forward(ctx, tensor, group: Optional[ProcessGroup]): + def forward( + ctx, + tensor: torch.Tensor, + group: Optional[ProcessGroup], + op_name: Optional[int] = None, + stream_manager: Optional[CudaStreamManager] = None, + ) -> Tuple[torch.Tensor, Optional["dist.Work"]]: + async_all_reduce = is_domino_async_comm(op_name) if op_name is not None else False + ctx.async_all_reduce = async_all_reduce + if group.size() == 1: return tensor - dist.all_reduce(tensor, op=dist.ReduceOp.SUM, group=group) + if stream_manager is not None: + comm_stream = stream_manager.get_default_comm_stream() + comm_stream.wait_stream(torch.cuda.default_stream()) + comm_context = torch.cuda.stream(comm_stream) + else: + comm_context = nullcontext() + + with comm_context: + if async_all_reduce is True: + assert comm_stream is not None + handle = dist.all_reduce(tensor, op=dist.ReduceOp.SUM, group=group, async_op=True) + stream_manager.comm_bucket.add(op_name, handle) + else: + dist.all_reduce(tensor, op=dist.ReduceOp.SUM, group=group) + return tensor @staticmethod def backward(ctx, grad_output): - return grad_output, None + return grad_output, None, None, None, None class DifferentiableAllGather(torch.autograd.Function): @@ -130,12 +171,22 @@ def backward(ctx, grad_output): # ----------------- -def differentiable_identity(tensor, group: Optional[ProcessGroup] = None): - return DifferentiableIdentity.apply(tensor, group) +def differentiable_identity( + tensor, + group: Optional[ProcessGroup] = None, + op_name: Optional[str] = None, + stream_manager: Optional[CudaStreamManager] = None, +): + return DifferentiableIdentity.apply(tensor, group, op_name, stream_manager) -def differentiable_all_reduce_sum(tensor, group: Optional[ProcessGroup] = None): - return DifferentiableAllReduceSum.apply(tensor, group) +def differentiable_all_reduce_sum( + tensor, + group: Optional[ProcessGroup] = None, + op_name: Optional[str] = None, + 
stream_manager: Optional[CudaStreamManager] = None, +): + return DifferentiableAllReduceSum.apply(tensor, group, op_name, stream_manager) def differentiable_all_gather(tensor, group: Optional[ProcessGroup] = None): diff --git a/src/nanotron/parallel/tensor_parallel/domino.py b/src/nanotron/parallel/tensor_parallel/domino.py new file mode 100644 index 00000000..3d0b4415 --- /dev/null +++ b/src/nanotron/parallel/tensor_parallel/domino.py @@ -0,0 +1,73 @@ +""" +Implementation of communication overlapping +in the paper "Domino: Eliminating Communication in LLM Training via +Generic Tensor Slicing and Overlapping" +https://arxiv.org/abs/2409.15241 +""" + +import re +import threading +from typing import Optional + +FWD_MLP_OP_NAME = "fwd.layer_mlp_{}_batch_{}" +FWD_ATTN_OP_NAME = "fwd.layer_attn_{}_batch_{}" +BWD_ATTN_OP_NAME = "bwd.layer_attn_{}_batch_{}" +BWD_MLP_OP_NAME = "bwd.layer_mlp_{}_batch_{}" + +_operation_context = threading.local() + + +def is_domino_async_comm(x: str) -> bool: + """ + Determine whether a module (e.g., mlp, attention) + runs all-reduce asynchronously in tensor parallelism + based on its module name. + + Currently support intra-layer communication overlapping + as described in domino's input splitting approach. + + How do we determine it? + + In the forward pass: We run all the forward pass's communication asynchronously + diagram: https://imgur.com/a/g5Ou2iZ + + + In the backward pass: We run all backward pass's communication asynchronously + except for the first batch's attention module. 
+ https://imgur.com/a/MrZb57a + """ + NON_ASYNC_HANDLE_IDX = [ + "bwd.layer_attn_{}_batch_0", + ] + + patterns = [p.replace("{}", r"\d+") for p in NON_ASYNC_HANDLE_IDX] # Replace {} with regex for numbers + regex = re.compile("^(" + "|".join(patterns) + ")$") # Combine patterns into a single regex + not_async = bool(regex.match(x)) + return not not_async + + +class OpNameContext: + """Track both forward and backward operation names""" + + def __init__(self, fwd_op_name: str, bwd_op_name: str): + self.fwd_op_name = fwd_op_name + self.bwd_op_name = bwd_op_name + self.prev_fwd = None + self.prev_bwd = None + + def __enter__(self): + self.prev_fwd = getattr(_operation_context, "current_fwd_op_name", None) + self.prev_bwd = getattr(_operation_context, "current_bwd_op_name", None) + _operation_context.current_fwd_op_name = self.fwd_op_name + _operation_context.current_bwd_op_name = self.bwd_op_name + return self + + def __exit__(self, exc_type, exc_val, exc_tb): + _operation_context.current_fwd_op_name = self.prev_fwd + _operation_context.current_bwd_op_name = self.prev_bwd + + +def get_current_fwd_op_name() -> Optional[str]: + return getattr(_operation_context, "current_fwd_op_name", None) + + +def get_current_bwd_op_name() -> Optional[str]: + return getattr(_operation_context, "current_bwd_op_name", None) diff --git a/src/nanotron/parallel/tensor_parallel/functional.py b/src/nanotron/parallel/tensor_parallel/functional.py index 722bc52f..9ed815a6 100644 --- a/src/nanotron/parallel/tensor_parallel/functional.py +++ b/src/nanotron/parallel/tensor_parallel/functional.py @@ -19,11 +19,13 @@ from torch.nn import functional as F import nanotron.distributed as dist +from nanotron.parallel.comm import CudaStreamManager from nanotron.parallel.tensor_parallel.distributed_differentiable_primitives import ( differentiable_all_reduce_sum, differentiable_identity, differentiable_reduce_scatter_sum, ) +from nanotron.parallel.tensor_parallel.domino import get_current_fwd_op_name from 
nanotron.parallel.tensor_parallel.enum import TensorParallelLinearMode from nanotron.parallel.utils import MemoryBuffer @@ -40,7 +42,7 @@ def forward( logits_max = torch.max(sharded_logits, dim=-1)[0] dist.all_reduce(logits_max, op=dist.ReduceOp.MAX, group=group) # Subtract the maximum value. - sharded_logits = sharded_logits - logits_max.unsqueeze(dim=-1) + sharded_logits.sub_(logits_max.unsqueeze(dim=-1)) # Get the shard's indices sharded_hidden_size = sharded_logits.shape[-1] @@ -552,12 +554,14 @@ def column_linear( tp_mode: TensorParallelLinearMode, async_communication: bool, tp_recompute_allgather: bool = True, + stream_manager: Optional[CudaStreamManager] = None, ): if async_communication: return _ColumnLinearAsyncCommunication.apply(input, weight, bias, group, tp_mode, tp_recompute_allgather) if tp_mode is TensorParallelLinearMode.ALL_REDUCE: - input = differentiable_identity(input, group=group) + op_name = get_current_fwd_op_name() + input = differentiable_identity(input, group=group, op_name=op_name, stream_manager=stream_manager) return F.linear(input, weight, bias) if tp_mode is TensorParallelLinearMode.REDUCE_SCATTER: return _ColumnLinearNoAsyncCommunicationReduceScatterMode.apply( @@ -702,6 +706,7 @@ def row_linear( group: dist.ProcessGroup, tp_mode: TensorParallelLinearMode, async_communication: bool, + stream_manager: Optional[CudaStreamManager] = None, ): if async_communication: return _RowLinearAsyncCommunication.apply(input, weight, bias, group, tp_mode) @@ -709,7 +714,8 @@ def row_linear( out = F.linear(input, weight, bias) if tp_mode is TensorParallelLinearMode.ALL_REDUCE: - out = differentiable_all_reduce_sum(out, group=group) + op_name = get_current_fwd_op_name() + out = differentiable_all_reduce_sum(out, group=group, op_name=op_name, stream_manager=stream_manager) elif tp_mode is TensorParallelLinearMode.REDUCE_SCATTER: out = differentiable_reduce_scatter_sum(out, group=group) else: diff --git a/src/nanotron/parallel/tensor_parallel/nn.py 
b/src/nanotron/parallel/tensor_parallel/nn.py index b738c7b3..28ce424a 100644 --- a/src/nanotron/parallel/tensor_parallel/nn.py +++ b/src/nanotron/parallel/tensor_parallel/nn.py @@ -21,6 +21,7 @@ from nanotron import distributed as dist from nanotron.distributed import get_global_rank from nanotron.logging import get_logger +from nanotron.parallel.comm import CudaStreamManager from nanotron.parallel.parameters import NanotronParameter from nanotron.parallel.sharded_parameters import ( SplitConfig, @@ -56,6 +57,7 @@ def __init__( async_communication: bool = False, contiguous_chunks: Optional[Tuple[int, ...]] = None, tp_recompute_allgather: bool = True, + stream_manager: Optional[CudaStreamManager] = None, ): self.pg = pg self.world_size = pg.size() @@ -76,6 +78,7 @@ def __init__( self.mode = mode self.async_communication = async_communication + self.stream_manager = stream_manager if self.world_size > 1: assert ( @@ -94,7 +97,10 @@ def __init__( split_config=split_config, ) - def forward(self, x: torch.Tensor) -> torch.Tensor: + def forward( + self, + x: torch.Tensor, + ) -> torch.Tensor: return column_linear( input=x, weight=self.weight, @@ -103,6 +109,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: tp_mode=self.mode, async_communication=self.async_communication, tp_recompute_allgather=self.tp_recompute_allgather, + stream_manager=self.stream_manager, ) def extra_repr(self) -> str: @@ -121,6 +128,7 @@ def __init__( dtype=None, async_communication: bool = False, contiguous_chunks: Optional[Tuple[int, ...]] = None, + stream_manager: Optional[CudaStreamManager] = None, ): self.pg = pg self.world_size = pg.size() @@ -129,6 +137,7 @@ def __init__( self.in_features = in_features // self.world_size self.out_features = out_features + self.stream_manager = stream_manager # No need to shard the bias term, only rank 0 would have it bias = dist.get_rank(self.pg) == 0 and bias @@ -175,6 +184,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: group=self.pg, 
tp_mode=self.mode, async_communication=self.async_communication, + stream_manager=self.stream_manager, ) def extra_repr(self) -> str: diff --git a/src/nanotron/trainer.py b/src/nanotron/trainer.py index 4cb272c4..08de40ed 100644 --- a/src/nanotron/trainer.py +++ b/src/nanotron/trainer.py @@ -64,6 +64,7 @@ from nanotron.models.starcoder2 import Starcoder2ForTraining from nanotron.optim.clip_grads import clip_grad_norm from nanotron.parallel import ParallelContext +from nanotron.parallel.comm import CudaStreamManager from nanotron.parallel.data_parallel.utils import sync_gradients_across_dp from nanotron.parallel.parameters import NanotronParameter, sanity_check from nanotron.parallel.pipeline_parallel.engine import ( @@ -609,6 +610,7 @@ def training_step( self.post_train_step() + self.stream_manager.comm_bucket.clear_all() return outputs, loss_avg, z_loss_avg def validation_step(self, dataloader: Iterator[Dict[str, Union[torch.Tensor, TensorPointer]]]) -> Iterable[Dict]: @@ -742,12 +744,16 @@ def _init_model_instance(self) -> NanotronModel: model_config_cls in CONFIG_TO_MODEL_CLASS ), f"Unsupported model config {model_config_cls}. 
Only {CONFIG_TO_MODEL_CLASS.keys()} are supported" + self.stream_manager = CudaStreamManager() + self.stream_manager.init_default_comm_stream() + model = self._init_model( model_builder=lambda: CONFIG_TO_MODEL_CLASS[model_config_cls]( config=self.model_config, parallel_context=self.parallel_context, parallel_config=self.config.parallelism, random_states=self.random_states, + stream_manager=self.stream_manager, ), ) return model diff --git a/tests/helpers/llama_helper.py b/tests/helpers/llama_helper.py index 7334f857..13c66658 100644 --- a/tests/helpers/llama_helper.py +++ b/tests/helpers/llama_helper.py @@ -1,3 +1,5 @@ +from typing import Optional + import torch from nanotron.config import ( AdamWOptimizerArgs, @@ -20,6 +22,7 @@ from nanotron.config.config import PretrainDatasetsArgs from nanotron.models import build_model from nanotron.models.llama import LlamaForTraining +from nanotron.parallel.comm import CudaStreamManager from nanotron.parallel.context import ParallelContext from nanotron.trainer import mark_tied_parameters @@ -47,15 +50,21 @@ ) -def get_llama_training_config(model_config: ModelArgs): - return Config( - model=model_config, - general=GeneralArgs(project="unittest", run="sanity_llama", seed=42), - checkpoints=CheckpointsArgs( - checkpoints_path="./checkpoints", - checkpoint_interval=10, - ), - parallelism=ParallelismArgs( +def get_parallel_config(parallel_context: ParallelContext): + return ParallelismArgs( + dp=parallel_context.data_parallel_size, + pp=parallel_context.pipeline_parallel_size, + tp=parallel_context.tensor_parallel_size, + expert_parallel_size=parallel_context.expert_parallel_size, + pp_engine=AllForwardAllBackwardPipelineEngine(), + tp_mode=TensorParallelLinearMode.ALL_REDUCE, + tp_linear_async_communication=False, + ) + + +def get_llama_training_config(model_config: ModelArgs, parallel_context: ParallelContext): + if parallel_context is None: + parallel_config = ParallelismArgs( dp=1, pp=1, tp=2, @@ -63,7 +72,18 @@ def 
get_llama_training_config(model_config: ModelArgs): pp_engine="1f1b", tp_mode="ALL_REDUCE", tp_linear_async_communication=False, + ) + else: + parallel_config = get_parallel_config(parallel_context) + + return Config( + model=model_config, + general=GeneralArgs(project="unittest", run="sanity_llama", seed=42), + checkpoints=CheckpointsArgs( + checkpoints_path="./checkpoints", + checkpoint_interval=10, ), + parallelism=parallel_config, tokenizer=TokenizerArgs("gpt2"), optimizer=OptimizerArgs( optimizer_factory=AdamWOptimizerArgs( @@ -106,7 +126,11 @@ def get_llama_training_config(model_config: ModelArgs): def create_llama_from_config( - model_config: LlamaConfig, device: torch.device, parallel_context: ParallelContext + model_config: LlamaConfig, + device: torch.device, + parallel_context: ParallelContext, + parallel_config: Optional[ParallelismArgs] = None, + stream_manager: Optional[CudaStreamManager] = None, ) -> LlamaForTraining: """ @@ -117,20 +141,25 @@ def create_llama_from_config( the model created will have random weights. 
""" - parallel_config = ParallelismArgs( - dp=parallel_context.data_parallel_size, - pp=parallel_context.pipeline_parallel_size, - tp=parallel_context.tensor_parallel_size, - pp_engine=AllForwardAllBackwardPipelineEngine(), - tp_mode=TensorParallelLinearMode.ALL_REDUCE, - tp_linear_async_communication=False, - ) + if parallel_config is None: + parallel_config = ParallelismArgs( + dp=parallel_context.data_parallel_size, + pp=parallel_context.pipeline_parallel_size, + tp=parallel_context.tensor_parallel_size, + pp_engine=AllForwardAllBackwardPipelineEngine(), + tp_mode=TensorParallelLinearMode.ALL_REDUCE, + tp_linear_async_communication=False, + ) + else: + parallel_config = parallel_config + model = build_model( model_builder=lambda: LlamaForTraining( config=model_config, parallel_context=parallel_context, parallel_config=parallel_config, random_states=None, + stream_manager=stream_manager, ), parallel_context=parallel_context, dtype=torch.bfloat16, diff --git a/tests/test_base_model.py b/tests/test_base_model.py index b3474bc2..81cdd83e 100644 --- a/tests/test_base_model.py +++ b/tests/test_base_model.py @@ -2,7 +2,7 @@ import torch import torch.distributed as dist from helpers.utils import init_distributed, rerun_if_address_is_in_use -from nanotron.config import Config, ModelArgs, RandomInit +from nanotron.config import ModelArgs, RandomInit from nanotron.parallel import ParallelContext from nanotron.parallel.pipeline_parallel.block import PipelineBlock from torch import nn @@ -11,21 +11,21 @@ @pytest.mark.parametrize("tp,dp,pp", [(1, 1, 1), (2, 2, 2)]) -@pytest.mark.skip @rerun_if_address_is_in_use() def test_get_named_modules_in_pp_rank(tp: int, dp: int, pp: int): model_args = ModelArgs(init_method=RandomInit(std=1.0), model_config=TINY_LLAMA_CONFIG) - config = get_llama_training_config(model_args) - init_distributed(tp=tp, dp=dp, pp=pp)(_test_get_named_modules_in_pp_rank)(config=config) + init_distributed(tp=tp, dp=dp, 
pp=pp)(_test_get_named_modules_in_pp_rank)(model_args=model_args) def _test_get_named_modules_in_pp_rank( parallel_context: ParallelContext, - config: Config, + model_args: ModelArgs, ): + config = get_llama_training_config(model_args, parallel_context) model = create_llama_from_config( model_config=config.model.model_config, + parallel_config=config.parallelism, device=torch.device("cuda"), parallel_context=parallel_context, ) @@ -44,3 +44,37 @@ def _test_get_named_modules_in_pp_rank( # not PipelineBlock assert isinstance(module, nn.Module) assert name not in modules_that_not_in_current_pp_rank + + +@pytest.mark.parametrize("tp,dp,pp", [(1, 1, 1), (2, 2, 1)]) +@rerun_if_address_is_in_use() +def test_llama_model(tp: int, dp: int, pp: int): + BATCH_SIZE, SEQ_LEN = 10, 128 + model_args = ModelArgs(init_method=RandomInit(std=1.0), model_config=TINY_LLAMA_CONFIG) + + init_distributed(tp=tp, dp=dp, pp=pp)(_test_llama_model)( + model_args=model_args, batch_size=BATCH_SIZE, seq_len=SEQ_LEN + ) + + +def _test_llama_model( + parallel_context: ParallelContext, + model_args: ModelArgs, + batch_size: int, + seq_len: int, +): + config = get_llama_training_config(model_args, parallel_context) + llama_model = create_llama_from_config( + model_config=config.model.model_config, + parallel_config=config.parallelism, + device=torch.device("cuda"), + parallel_context=parallel_context, + ) + llama_model.init_model_randomly(config=config) + + input_ids = torch.randint(0, config.model.model_config.vocab_size, size=(batch_size, seq_len), device="cuda") + input_mask = torch.ones_like(input_ids) + outputs = llama_model(input_ids, input_mask, input_mask, input_mask) + + assert list(outputs.keys()) == ["loss"] + assert isinstance(outputs["loss"], torch.Tensor) diff --git a/tests/test_comm.py b/tests/test_comm.py new file mode 100644 index 00000000..39514a38 --- /dev/null +++ b/tests/test_comm.py @@ -0,0 +1,127 @@ +import pytest +import torch +import torch.distributed as dist +from 
helpers.utils import ( + init_distributed, + rerun_if_address_is_in_use, +) +from nanotron.parallel import ParallelContext +from nanotron.parallel.comm import AsyncCommBucket, CudaStreamManager, insert_backward_sync_to_tensor + + +class MockWork: + def __init__(self): + self.completed = False + self.wait_called = False + + def wait(self): + self.wait_called = True + self.completed = True + + def is_completed(self): + return self.completed + + +def test_cuda_stream_manager(): + manager = CudaStreamManager() + manager.create("test", torch.device("cuda")) + + stream = manager.get("test") + assert isinstance(stream, torch.cuda.Stream) + + +@rerun_if_address_is_in_use() +def test_add_async_op_to_bucket(): + init_distributed(tp=2, dp=1, pp=1)(_test_add_async_op_to_bucket)() + + +def _test_add_async_op_to_bucket(parallel_context: ParallelContext): + OP_NAME = "test" + tensor = torch.randn(1, device="cuda") + work = dist.all_reduce(tensor, async_op=True) + + comm_bucket = AsyncCommBucket() + comm_bucket.add(OP_NAME, work) + + assert comm_bucket.get(OP_NAME) is work + + +@rerun_if_address_is_in_use() +def test_wait_async_op_to_bucket(): + init_distributed(tp=2, dp=1, pp=1)(_test_wait_async_op_to_bucket)() + + +def _test_wait_async_op_to_bucket(parallel_context: ParallelContext): + OP_NAME = "test" + work = MockWork() + comm_bucket = AsyncCommBucket() + + comm_bucket.add(OP_NAME, work) + assert work.is_completed() is False + + comm_bucket.wait(OP_NAME) + assert work.is_completed() + with pytest.raises(KeyError): + comm_bucket.get(OP_NAME) + + +@rerun_if_address_is_in_use() +def test_is_all_completed_in_async_bucket(): + init_distributed(tp=2, dp=1, pp=1)(_test_wait_async_op_to_bucket)() + + +def _test_wait_async_op_to_bucket(parallel_context: ParallelContext): + OP_NAME = "test" + work = MockWork() + comm_bucket = AsyncCommBucket() + + comm_bucket.add(OP_NAME, work) + assert comm_bucket.is_all_completed() is False + + comm_bucket.wait(OP_NAME) + assert 
comm_bucket.is_all_completed() is True + + +@rerun_if_address_is_in_use() +def test_clear_ops_in_async_bucket(): + init_distributed(tp=2, dp=1, pp=1)(_test_clear_ops_in_async_bucket)() + + +def _test_clear_ops_in_async_bucket(parallel_context: ParallelContext): + comm_bucket = AsyncCommBucket() + + comm_bucket.add("test1", MockWork()) + comm_bucket.add("test2", MockWork()) + comm_bucket.add("test3", MockWork()) + + assert comm_bucket.is_all_completed() is False + + comm_bucket.clear_all() + assert comm_bucket.is_all_completed() is True + with pytest.raises(KeyError): + comm_bucket.get("test1") + + +@rerun_if_address_is_in_use() +def test_wait_comm(): + init_distributed(tp=2, dp=1, pp=1)(_test_wait_comm)() + + +def _test_wait_comm(parallel_context: ParallelContext): + OP_NAME = "test" + tensor = torch.randn(1, device="cuda", requires_grad=True) + stream_manager = CudaStreamManager() + + comm_stream = torch.cuda.Stream() + + with torch.cuda.stream(comm_stream): + work = MockWork() + stream_manager.comm_bucket.add(OP_NAME, work) + + output = insert_backward_sync_to_tensor(tensor, OP_NAME, stream_manager) + assert work.is_completed() is False + + # NOTE: we test that it waits for the async op to complete + # automatically in autograd + (output + 1).backward() + assert work.is_completed() diff --git a/tests/test_domino.py b/tests/test_domino.py new file mode 100644 index 00000000..843437a0 --- /dev/null +++ b/tests/test_domino.py @@ -0,0 +1,126 @@ +import os +from copy import deepcopy + +import pytest +import torch +from helpers.llama_helper import TINY_LLAMA_CONFIG, create_llama_from_config, get_llama_training_config +from helpers.utils import init_distributed, rerun_if_address_is_in_use +from nanotron.config import ModelArgs, RandomInit +from nanotron.config.parallelism_config import DominoArgs +from nanotron.models.llama import DominoLlamaDecoderLayer +from nanotron.parallel import ParallelContext +from nanotron.parallel.comm import CudaStreamManager +from 
nanotron.parallel.tensor_parallel.domino import ( + OpNameContext, + get_current_bwd_op_name, + get_current_fwd_op_name, + is_domino_async_comm, +) + + +@pytest.mark.parametrize( + "op_name, expected", + [ + ("fwd.layer_attn_1_batch_0", True), + ("fwd.layer_attn_1_batch_1", True), + ("fwd.layer_mlp_1_batch_0", True), + ("fwd.layer_mlp_1_batch_1", True), + ("bwd.layer_mlp_1_batch_1", True), + ("bwd.layer_mlp_1_batch_0", True), + ("bwd.layer_attn_1_batch_1", True), + ("bwd.layer_attn_1_batch_0", False), + ], +) +def test_is_domino_async_comm(op_name, expected): + assert is_domino_async_comm(op_name) == expected + + +@pytest.mark.parametrize("tp,dp,pp", [(2, 2, 1)]) +@rerun_if_address_is_in_use() +def test_domino_model(tp: int, dp: int, pp: int): + os.environ["CUDA_DEVICE_MAX_CONNECTIONS"] = "1" + BATCH_SIZE, SEQ_LEN = 10, 128 + + model_config = deepcopy(TINY_LLAMA_CONFIG) + model_config.num_hidden_layers = 28 + model_args = ModelArgs(init_method=RandomInit(std=1.0), model_config=TINY_LLAMA_CONFIG) + + init_distributed(tp=tp, dp=dp, pp=pp)(_test_domino_model)( + model_args=model_args, batch_size=BATCH_SIZE, seq_len=SEQ_LEN + ) + + +def _test_domino_model( + parallel_context: ParallelContext, + model_args: ModelArgs, + batch_size: int, + seq_len: int, +): + config = get_llama_training_config(model_args, parallel_context) + config.parallelism.domino = DominoArgs(num_input_batches=2) + stream_manager = CudaStreamManager() + stream_manager.init_default_comm_stream() + + llama_model = create_llama_from_config( + model_config=config.model.model_config, + parallel_config=config.parallelism, + device=torch.device("cuda"), + parallel_context=parallel_context, + stream_manager=stream_manager, + ) + llama_model.init_model_randomly(config=config) + + for m in llama_model.model.decoder: + assert isinstance(m.pp_block, DominoLlamaDecoderLayer) + + input_ids = torch.randint(0, config.model.model_config.vocab_size, size=(batch_size, seq_len), device="cuda") + input_mask = 
torch.ones_like(input_ids) + outputs = llama_model(input_ids, input_mask, input_mask, input_mask) + + assert isinstance(outputs["loss"], torch.Tensor) + assert stream_manager.comm_bucket.is_all_completed() is True + + +### OpNameContext tests ### + + +def test_op_name_context_reentry(): + assert get_current_fwd_op_name() is None + assert get_current_bwd_op_name() is None + context = OpNameContext(fwd_op_name="fwd.reusable_op", bwd_op_name="bwd.reusable_op") + + with context: + assert get_current_fwd_op_name() == "fwd.reusable_op" + assert get_current_bwd_op_name() == "bwd.reusable_op" + + assert get_current_fwd_op_name() is None + assert get_current_bwd_op_name() is None + + with context: + assert get_current_fwd_op_name() == "fwd.reusable_op" + assert get_current_bwd_op_name() == "bwd.reusable_op" + + assert get_current_fwd_op_name() is None + assert get_current_bwd_op_name() is None + + +def test_deeply_nested_contexts(): + with OpNameContext(fwd_op_name="fwd.level1", bwd_op_name="fwd.level1"): + assert get_current_fwd_op_name() == "fwd.level1" + + with OpNameContext(fwd_op_name="fwd.level2", bwd_op_name="fwd.level2"): + assert get_current_fwd_op_name() == "fwd.level2" + + assert get_current_fwd_op_name() == "fwd.level1" + + +def test_multiple_sequential_contexts(): + assert get_current_fwd_op_name() is None + + with OpNameContext(fwd_op_name="fwd.first_op", bwd_op_name="bwd.first_op"): + assert get_current_fwd_op_name() == "fwd.first_op" + + with OpNameContext(fwd_op_name="fwd.second_op", bwd_op_name="bwd.second_op"): + assert get_current_fwd_op_name() == "fwd.second_op" + + assert get_current_fwd_op_name() is None