From 159376c4466e7e2bfec3bc9e59cc5aa9c4657168 Mon Sep 17 00:00:00 2001 From: Feng Tian Date: Mon, 20 May 2024 11:14:43 -0700 Subject: [PATCH 1/3] enable doraPP --- pippy/PipelineSchedule.py | 216 +++++++++++++++++++++++ test/test_pipeline_schedule_e2e.py | 13 +- torchpippy.egg-info/PKG-INFO | 18 ++ torchpippy.egg-info/SOURCES.txt | 48 +++++ torchpippy.egg-info/dependency_links.txt | 1 + torchpippy.egg-info/requires.txt | 1 + torchpippy.egg-info/top_level.txt | 2 + 7 files changed, 296 insertions(+), 3 deletions(-) create mode 100644 torchpippy.egg-info/PKG-INFO create mode 100644 torchpippy.egg-info/SOURCES.txt create mode 100644 torchpippy.egg-info/dependency_links.txt create mode 100644 torchpippy.egg-info/requires.txt create mode 100644 torchpippy.egg-info/top_level.txt diff --git a/pippy/PipelineSchedule.py b/pippy/PipelineSchedule.py index 4f79c1287..ae01ccfa7 100644 --- a/pippy/PipelineSchedule.py +++ b/pippy/PipelineSchedule.py @@ -14,6 +14,8 @@ from .microbatch import merge_chunks, split_args_kwargs_into_chunks logger = logging.getLogger(__name__) +logger.setLevel(logging.DEBUG) + class PipelineSchedule(ABC): @@ -776,3 +778,217 @@ def backward_stage_local_index(step): # Return losses if there is a container passed in self._update_losses(self._stages, losses) + + +class ScheduleDoraPP(PipelineScheduleMulti): + """ + This is interleaved dfs+bfs zero bubble schedule. + """ + def __init__( + self, + stages: List[PipelineStageBase], + n_microbatches: int, + loss_fn: Optional[Callable] = None, + output_merge_spec: Optional[Union[Dict[str, Any], Tuple[Any]]] = None, + ): + self.pp_group_size = stages[0].group_size + # TODO: is this limitation a must? + if n_microbatches % self.pp_group_size != 0: + raise ValueError( + "Interleaved 1F1B requires the number of microbatches to be a " + f"multiple of the number of pipeline ranks ({self.pp_group_size}), " + f"but got {n_microbatches}." + ) + + super().__init__( + stages=stages, + n_microbatches=n_microbatches, + loss_fn=loss_fn, + output_merge_spec=output_merge_spec, + ) + + self.n_local_stages = len(stages) + self.rank = stages[0].group_rank + + def _step_microbatches( + self, + arg_mbs: Optional[List] = None, + kwarg_mbs: Optional[List] = None, + target_mbs: Optional[List] = None, + losses: Optional[List] = None, + ): + """ + Operate on the microbatches for doraPP schedule (https://arxiv.org/pdf/2104.04473.pdf). + + Highest rank has a warmup (fwd only) count of [len(stages) - 1] * number of PP ranks + and each rank away from highest rank adds 2 warmup steps due to: + - one happened before highest rank's warmup started, + - one waiting for backward result to trickle down from highest rank + + TODO: Interleaved 1F1B does not support using sorted_batch_isend_irecv() + because it requires recvs and sends from different peers + to execute in the same coalesced operation. As a result, this schedule does + not support models with skip connections. 
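+
+        Example (assuming 4 pipeline ranks and 2 local stages per rank, with
+        enough microbatches that the warmup count is not capped): the highest
+        rank warms up with (2 - 1) * 4 = 4 forward-only steps, while rank 0
+        warms up with 4 + 2 * 3 = 10 steps before entering the 1F1B phase.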
+ """ + arg_mbs, kwarg_mbs = self._check_inputs( + arg_mbs, kwarg_mbs, target_mbs, losses + ) + + # increment warmup_steps by 2 for each hop away + warmup_steps = (self.n_local_stages - 1) * self.pp_group_size + warmup_steps += 2 * ((self.pp_group_size - 1) - self.rank) + warmup_steps = min( + warmup_steps, self._n_microbatches * self.n_local_stages + ) + fwd_bwd_steps = ( + self.n_local_stages * self._n_microbatches + ) - warmup_steps + cooldown_steps = ( + self.n_local_stages * self._n_microbatches + ) - fwd_bwd_steps + + assert ( + warmup_steps + fwd_bwd_steps * 2 + cooldown_steps + == self.n_local_stages * self._n_microbatches * 2 + ) + total_steps = warmup_steps + fwd_bwd_steps + cooldown_steps + + logger.debug( + f""" + n_microbatches {self._n_microbatches} + stages {self.n_local_stages} + rank {self.rank} + warmup_steps {warmup_steps} + 1f1b {fwd_bwd_steps} + cooldown_steps {cooldown_steps} + """ + ) + + def forward_stage_local_index(step): + return (step // self.pp_group_size) % self.n_local_stages + + def backward_stage_local_index(step): + return ( + self.n_local_stages + - 1 + - ((step - warmup_steps) // self.pp_group_size) + % self.n_local_stages + ) + + fwd_stage_mb_index: Dict[PipelineStageBase, int] = defaultdict(int) + bwd_stage_mb_index: Dict[PipelineStageBase, int] = defaultdict(int) + + # Delay send waits + sends_to_wait: List[dist.Work] = [] + + # Store ops (potentially across steps) + ops: List[dist.P2POp] = [] + + # Warmup Phase (forward only) + for step in range(warmup_steps): + fwd_stage = self._stages[forward_stage_local_index(step)] + + # This will assign the current microbatch index and update it for future steps + fwd_stage_mb_index[fwd_stage] = ( + mb_index := fwd_stage_mb_index[fwd_stage] + ) + 1 + + logger.debug( + f"Rank {self.rank}: {step=}, {fwd_stage.stage_index=}, {mb_index=}" + ) + + with record_function(f"Forward {step}"): + ops.extend(fwd_stage.get_fwd_recv_ops()) + if ops: + work = dist.batch_isend_irecv(ops).pop() + work.wait() + ops.clear() + + output = fwd_stage.forward_one_chunk(arg_mbs[mb_index], kwarg_mbs[mb_index]) # type: ignore[index] + + ops.extend(fwd_stage.get_fwd_send_ops()) + # If we are right before the fwd-bwd step, then we need to delay the send to the next step, + # This is because fwd-bwd send/recvs among ranks need to be aligned to prevent a hang. 
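+                # The un-flushed send ops simply stay in `ops` and are batched
+                # together with the recvs issued at the start of the first 1F1B step.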
+ # In the edge cases where there are no fwd_bwds and cooldown is immediate, then no delay is needed + if ops and (step != warmup_steps - 1 or fwd_bwd_steps == 0): + work = dist.batch_isend_irecv(ops).pop() + sends_to_wait.append(work) + ops.clear() + + self._maybe_compute_loss( + fwd_stage, output, target_mbs, mb_index + ) + + # 1F1B Phase (forward and backward) + for step in range(warmup_steps, warmup_steps + fwd_bwd_steps): + fwd_stage = self._stages[forward_stage_local_index(step)] + bwd_stage = self._stages[backward_stage_local_index(step)] + + fwd_stage_mb_index[fwd_stage] = ( + fwd_mb_index := fwd_stage_mb_index[fwd_stage] + ) + 1 + bwd_stage_mb_index[bwd_stage] = ( + bwd_mb_index := bwd_stage_mb_index[bwd_stage] + ) + 1 + + bwd_stage._configure_data_parallel_mode( + bwd_mb_index == self._n_microbatches - 1 + ) + logger.debug( + f"Rank {self.rank}: {step=}, {fwd_stage.stage_index=}, {bwd_stage.stage_index=}, {fwd_mb_index=}, {bwd_mb_index=}" + ) + with record_function(f"1F1B {step}"): + ops.extend(fwd_stage.get_fwd_recv_ops()) + ops.extend(bwd_stage.get_bwd_recv_ops()) + if ops: + work = dist.batch_isend_irecv(ops).pop() + work.wait() + ops.clear() + + # Forward + output = fwd_stage.forward_one_chunk(arg_mbs[fwd_mb_index], kwarg_mbs[fwd_mb_index]) # type: ignore[index] + ops.extend(fwd_stage.get_fwd_send_ops()) + self._maybe_compute_loss( + fwd_stage, output, target_mbs, fwd_mb_index + ) + + # Backward + loss = self._maybe_get_loss(bwd_stage, bwd_mb_index) + bwd_stage.backward_one_chunk(loss=loss) + ops.extend(bwd_stage.get_bwd_send_ops()) + + # Cooldown Phase (backward only) + for step in range(warmup_steps + fwd_bwd_steps, total_steps): + bwd_stage = self._stages[backward_stage_local_index(step)] + bwd_stage_mb_index[bwd_stage] = ( + bwd_mb_index := bwd_stage_mb_index[bwd_stage] + ) + 1 + bwd_stage._configure_data_parallel_mode( + bwd_mb_index == self._n_microbatches - 1 + ) + + logger.debug( + f"Rank {self.rank}: {step=}, {bwd_stage.stage_index=}, {bwd_mb_index=}" + ) + with record_function(f"Cooldown {step}"): + ops.extend(bwd_stage.get_bwd_recv_ops()) + if ops: + work = dist.batch_isend_irecv(ops).pop() + work.wait() + ops.clear() + + loss = self._maybe_get_loss(bwd_stage, bwd_mb_index) + bwd_stage.backward_one_chunk(loss=loss) + + ops.extend(bwd_stage.get_bwd_send_ops()) + if ops: + work = dist.batch_isend_irecv(ops).pop() + sends_to_wait.append(work) + ops.clear() + + # Make sure all sends are finished + for work in sends_to_wait: + work.wait() + + # Return losses if there is a container passed in + self._update_losses(self._stages, losses) diff --git a/test/test_pipeline_schedule_e2e.py b/test/test_pipeline_schedule_e2e.py index 4efd22455..71b1a5064 100644 --- a/test/test_pipeline_schedule_e2e.py +++ b/test/test_pipeline_schedule_e2e.py @@ -34,10 +34,13 @@ ScheduleLoopedBFS, ) +from pippy.PipelineSchedule import ScheduleDoraPP + from torch.distributed._tensor.device_mesh import init_device_mesh from torch.profiler import record_function logger = logging.getLogger(__name__) +logger.setLevel(logging.DEBUG) # profiling context manager @@ -176,7 +179,7 @@ def rank_print(msg): if kwargs["stage_type"] == "manual": stage_model = ManualPipelineStage( module_list[rank], - stage_id=rank, + stage_index=rank, num_stages=world_size, device=device, input_args=input_args, @@ -186,7 +189,7 @@ def rank_print(msg): stage_model_looped = [ ManualPipelineStage( module_list[rank], - stage_id=(world_size * i) + rank, + stage_index=(world_size * i) + rank, num_stages=world_size * world_size, 
device=device, input_args=input_args, @@ -232,6 +235,10 @@ def rank_print(msg): my_schedule = ScheduleInterleaved1F1B( stage_model_looped, n_microbatches, loss_fn ) + elif schedule == "doraPP": + my_schedule = ScheduleDoraPP( + stage_model_looped, n_microbatches, loss_fn + ) if _run_profiler: logger.info(f"====== Rank {rank} profile ======") @@ -299,7 +306,7 @@ def set_up_logging(rank, log_level): "--schedules", type=str, nargs="+", - choices=["gpipe", "1f1b", "looped_bfs", "interleaved_1f1b"], + choices=["gpipe", "1f1b", "looped_bfs", "interleaved_1f1b","doraPP"], default=["interleaved_1f1b"], ) parser.add_argument("--device", type=str, default="cuda") diff --git a/torchpippy.egg-info/PKG-INFO b/torchpippy.egg-info/PKG-INFO new file mode 100644 index 000000000..3a05234b2 --- /dev/null +++ b/torchpippy.egg-info/PKG-INFO @@ -0,0 +1,18 @@ +Metadata-Version: 2.1 +Name: torchpippy +Version: 0.2.0+cdc0ac6 +Summary: Pipeline Parallelism for PyTorch +Home-page: https://github.com/pytorch/PiPPy +Author: PiPPy Team +License: BSD +Description-Content-Type: text/markdown +License-File: LICENSE +Requires-Dist: torch>=2.3.0.dev + + +The PiPPy project stands for Pipeline Parallelism for PyTorch. It consists of a +compiler and runtime stack for automated parallelism and scaling of PyTorch +models. PiPPy partitions the code of the model in a pipelined fashion and +enables multiple micro-batches to execute different parts of the model code +concurrently. For details, please visit PiPPy's [GitHub +page](https://github.com/pytorch/PiPPy). diff --git a/torchpippy.egg-info/SOURCES.txt b/torchpippy.egg-info/SOURCES.txt new file mode 100644 index 000000000..1961b525c --- /dev/null +++ b/torchpippy.egg-info/SOURCES.txt @@ -0,0 +1,48 @@ +LICENSE +README.md +pyproject.toml +setup.py +pippy/ManualPipelineStage.py +pippy/ModelSplit.py +pippy/PipelineSchedule.py +pippy/_IR.py +pippy/_PipelineStage.py +pippy/__init__.py +pippy/_backward.py +pippy/_debug.py +pippy/_unflatten.py +pippy/_utils.py +pippy/microbatch.py +pippy/version.py +pippy/utilities/__init__.py +pippy/utilities/hf_checkpoint.py +test/__init__.py +test/hf_test.py +test/local_test_c10d_ddp.py +test/local_test_checkpoint.py +test/local_test_null_coalesce_accumulate.py +test/test_autosplit.py +test/test_bwd.py +test/test_chunkspec.py +test/test_composability.py +test/test_cpu_init.py +test/test_fwd.py +test/test_grad.py +test/test_interleave.py +test/test_ir.py +test/test_microbatch.py +test/test_optim.py +test/test_pipe.py +test/test_pipe_bwd.py +test/test_pipeline_schedule.py +test/test_pipeline_schedule_e2e.py +test/test_pipeline_stage.py +test/test_skip_conn.py +test/test_stage_backward.py +test/test_transformer.py +test/test_unflatten.py +torchpippy.egg-info/PKG-INFO +torchpippy.egg-info/SOURCES.txt +torchpippy.egg-info/dependency_links.txt +torchpippy.egg-info/requires.txt +torchpippy.egg-info/top_level.txt \ No newline at end of file diff --git a/torchpippy.egg-info/dependency_links.txt b/torchpippy.egg-info/dependency_links.txt new file mode 100644 index 000000000..8b1378917 --- /dev/null +++ b/torchpippy.egg-info/dependency_links.txt @@ -0,0 +1 @@ + diff --git a/torchpippy.egg-info/requires.txt b/torchpippy.egg-info/requires.txt new file mode 100644 index 000000000..498366136 --- /dev/null +++ b/torchpippy.egg-info/requires.txt @@ -0,0 +1 @@ +torch>=2.3.0.dev diff --git a/torchpippy.egg-info/top_level.txt b/torchpippy.egg-info/top_level.txt new file mode 100644 index 000000000..239cc8fd3 --- /dev/null +++ b/torchpippy.egg-info/top_level.txt @@ -0,0 
+1,2 @@ +pippy +test From 1e0f73d155192ea5c514c84b67edec69b75a638c Mon Sep 17 00:00:00 2001 From: Feng Tian Date: Wed, 29 May 2024 14:36:30 -0700 Subject: [PATCH 2/3] WIP --- pippy/PipelineSchedule.py | 765 ++++++++++++++++++++++++++++- test/test_pipeline_schedule_e2e.py | 2 +- 2 files changed, 761 insertions(+), 6 deletions(-) diff --git a/pippy/PipelineSchedule.py b/pippy/PipelineSchedule.py index ae01ccfa7..b338593c0 100644 --- a/pippy/PipelineSchedule.py +++ b/pippy/PipelineSchedule.py @@ -816,9 +816,11 @@ def _step_microbatches( kwarg_mbs: Optional[List] = None, target_mbs: Optional[List] = None, losses: Optional[List] = None, + microbatch_size: Optional[int] = None, + model_dim: Optional[int] = None, ): """ - Operate on the microbatches for doraPP schedule (https://arxiv.org/pdf/2104.04473.pdf). + Operate on the microbatches for doraPP schedule . Highest rank has a warmup (fwd only) count of [len(stages) - 1] * number of PP ranks and each rank away from highest rank adds 2 warmup steps due to: @@ -893,10 +895,6 @@ def backward_stage_local_index(step): mb_index := fwd_stage_mb_index[fwd_stage] ) + 1 - logger.debug( - f"Rank {self.rank}: {step=}, {fwd_stage.stage_index=}, {mb_index=}" - ) - with record_function(f"Forward {step}"): ops.extend(fwd_stage.get_fwd_recv_ops()) if ops: @@ -992,3 +990,760 @@ def backward_stage_local_index(step): # Return losses if there is a container passed in self._update_losses(self._stages, losses) + + + + + + + + + + ########################################################## + # Dora PP Xlformer Implementation + ########################################################## + + input_tensors = [[] for _ in range(len(self._stages))] + output_tensors = [[] for _ in range(len(self._stages))] + output_tensor_grads = [[] for _ in range(len(self._stages))] + # We need to pop input, output and grad during bwd, we use this list to track real input tensor index. + popped_input_tensors = [[] for _ in range(len(self._stages))] + input_tensor_grad = None + + pipeline_parallel_size = self.pp_group_size + pipeline_parallel_rank = self._stage.stage_index + + microbatch_x = arg_mbs + microbatch_y = target_mbs + microbatch_mask = None + mask = None + if mask is not None: + microbatch_mask = mask.split(args.pipeline_parallel_microbatch_size, dim=0) + + num_microbatches = self._n_microbatches + # microbatch_attn_bias = [ + # model[0].get_attn_bias(microbatch_x[i], cache=None) + # for i in range(num_microbatches) + # ] + microbatch_attn_bias = [ + self._stages[0].submodule.get_attn_bias(microbatch_x[i], cache=None) + for i in range(num_microbatches) + ] + + + # TODO: get the model args from API directly, should modify it later + assert(microbatch_size is not None), "microbatch_size is None" + assert(model_dim is not None), "model_dim is None" + + microbatch_less_than_pp = num_microbatches < pipeline_parallel_size + num_round = max(num_microbatches // pipeline_parallel_size, 1) + assert ( + num_microbatches % num_round == 0 + ), "Number of microbatches should be divisible by number of pipeline rounds." 
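+        # e.g. 8 microbatches over 4 pipeline ranks gives num_round = 2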
+ # the number of microbatches run in each round, in dfs it is pipeline_parallel_size + num_microbatch_per_round = num_microbatches // num_round + + tensor_shape = ( + microbatch_size, + model_dim, + ) + + num_model_chunks = len(model) + total_num_microbatches = num_microbatches * num_model_chunks + + dtype = get_torch_dtype(args.dtype) + + mpu.set_virtual_pipeline_model_parallel_rank(0) + all_warmup_microbatches = False + + if not args.model.enable_ddp: + for model_chunk in model: + model_chunk._rebuild_full_params_recursive() + else: + for model_chunk in model: + model_chunk.zero_grad() + + + num_warmup_microbatches = 0 + # The number of microbatches that last pipeline stage run before 1f1b. + num_warmup_microbatches += (num_model_chunks - 1) * num_microbatch_per_round + # From last PP stage up, each rank will be 2 more than the previous one. + num_warmup_microbatches += ( + pipeline_parallel_size - pipeline_parallel_rank - 1 + ) * 2 + num_warmup_microbatches = min(num_warmup_microbatches, total_num_microbatches) + num_microbatches_remaining = total_num_microbatches - num_warmup_microbatches + # The number of 1f1b for zero bubble schedule + if num_microbatches == pipeline_parallel_size: + num_1f1b_microbatches = pipeline_parallel_rank + else: + num_1f1b_microbatches = 2 * pipeline_parallel_rank + + # Checkpoint the activations of partial Transformer layers in a number of micro-batches + # within the maximum outstanding micro-batch backpropagations. + # Micro-batches with the ids less than 'num_microbatches_with_partial_activation_checkpoints' + # checkpoint partial Transformer layers (or skip checkpointing) and + # the rest of micro-batches within a window of micro-batches checkpoint + # all Transformer layers. The window of micro-batches is set by the maximum + # outstanding backpropagations and becomes smaller at later pipeline stages. + # Please refer the appendix C in https://arxiv.org/pdf/2205.05198.pdf + max_outstanding_backprops = None + if args.num_microbatches_with_partial_activation_checkpoints is not None: + max_outstanding_backprops = num_warmup_microbatches + 1 + + p0_chunk0_batch = [0, 0] + mean_losses = [] + + def get_model_chunk_id(microbatch_id, forward): + """Helper method to get the model chunk ID given the iteration number. + Each group has num_microbatch_per_round * num_model_chunks microbatches. + within each chunk, there are num_microbatch_per_round microbatches. + backward is reverse order of forward. 
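+
+            For example, with num_microbatch_per_round = 4 and num_model_chunks = 2,
+            forward microbatch ids 0-3 map to model chunk 0 and ids 4-7 to chunk 1;
+            in the backward direction the mapping is reversed (0-3 -> 1, 4-7 -> 0).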
+ """ + microbatch_id_in_group = microbatch_id % ( + num_microbatch_per_round * num_model_chunks + ) + model_chunk_id = microbatch_id_in_group // num_microbatch_per_round + if not forward: + model_chunk_id = num_model_chunks - model_chunk_id - 1 + return model_chunk_id + + def get_real_microbatch_id(microbatch_id: int) -> int: + """Get the microbatch id for input tokens.""" + microbatch_group_size = num_microbatch_per_round * num_model_chunks + microbatch_group_id = microbatch_id // microbatch_group_size + real_microbatch_id_in_group = ( + microbatch_id % microbatch_group_size + ) % num_microbatch_per_round + real_microbatch_id = ( + real_microbatch_id_in_group + microbatch_group_id * num_microbatch_per_round + ) + return real_microbatch_id + + def is_first_microbatch_for_model_chunk(microbatch_id: int) -> bool: + """Check if an iteration is the first for a model chunk.""" + microbatch_group_size = num_microbatch_per_round * num_model_chunks + microbatch_group_id = microbatch_id // microbatch_group_size + microbatch_id_in_group = microbatch_id % microbatch_group_size + if microbatch_group_id == 0: + return microbatch_id_in_group % num_microbatch_per_round == 0 + else: + return False + + def is_last_microbatch_for_model_chunk(microbatch_id: int) -> bool: + """Check if an iteration is the last for a model chunk.""" + microbatch_group_size = num_microbatch_per_round * num_model_chunks + num_microbatch_groups = total_num_microbatches // microbatch_group_size + microbatch_group_id = microbatch_id // microbatch_group_size + microbatch_id_in_group = microbatch_id % microbatch_group_size + if microbatch_group_id == num_microbatch_groups - 1: + return ( + microbatch_id_in_group % num_microbatch_per_round + == num_microbatch_per_round - 1 + ) + else: + return False + + def get_input_index(microbatch_id): + """Get pipeline input index for a microbatch""" + microbatch_group_size = num_microbatch_per_round * num_model_chunks + microbatch_id_in_group = microbatch_id % microbatch_group_size + microbatch_group_id = microbatch_id // microbatch_group_size + input_index = microbatch_id_in_group % num_microbatch_per_round + return input_index + microbatch_group_id * num_microbatch_per_round + + def microbatch_fwd( + model_chunk_id, + input_tensor, + microbatch_tokens, + y, + state, + mask, + mean_losses, + is_first_microbatch=False, + recompute_attn=None, + recompute_fc1_fc3=None, + attn_bias=None, + ): + if input_tensor is None: + assert mpu.is_pipeline_first_stage() + else: + assert not mpu.is_pipeline_first_stage() + + if args.num_microbatches_with_partial_activation_checkpoints is not None: + output, _ = model[model_chunk_id]( + microbatch_tokens, + pipeline_parallel_input_tensor=input_tensor, + is_first_microbatch=is_first_microbatch, + recompute_attn=recompute_attn, + recompute_fc1_fc3=recompute_fc1_fc3, + precomputed_attn_bias=attn_bias, + ) + else: + output, _ = model[model_chunk_id]( + microbatch_tokens, + pipeline_parallel_input_tensor=input_tensor, + is_first_microbatch=is_first_microbatch, + precomputed_attn_bias=attn_bias, + ) + + if mpu.is_pipeline_last_stage(): + if loss_fn is not None: + loss = loss_fn( + output, + y, + mask, + ) + output = loss.mean() / num_microbatches + else: + if args.model.loss_parallel: + tok_loss = state.scale * vocab_parallel_cross_entropy( + partial_logits=output, + target=y, + z_loss_multiplier=args.z_loss_multiplier, + ) + else: + tok_loss = state.scale * F.cross_entropy( + output.flatten(0, 1), y.flatten(0, 1), reduction="none" + ) + if mask is None: + output = 
tok_loss.mean() / num_microbatches + else: + mask = mask.flatten(0, 1) + tok_loss = tok_loss * mask + output = tok_loss.sum() / (mask.sum() + 1e-6) / num_microbatches + mean_losses.append(output) + p0_chunk0_batch[1] += 1 + return output + + def deallocate_output_tensor(out): + """Deallocate the output tensor's '.data' field. + This method should be called right after the output tensor has been + sent to the next pipeline stage. At this point, the output tensor is + only useful for its '.grad_fn' field, and not its '.data'. + """ + assert isinstance(out, torch.Tensor), ( + "expected Tensor, found %s." % type(out).__name__ + ) + assert out._base is None, "counter-productive to free a view of another tensor." + out.data.storage().resize_(0) + + def custom_backward(output, grad_output): + """Custom backward where directly call C++ autograd engine. + Since Pytorch's 'backward' checks that the output and + grad have the same shape. We need to manually call the C++ autograd + instead of using Pytorch's torch.autograd.backward. + So that the 'deallocate_output_tensor' optimization can work. + """ + + assert ( + output.storage().size() == 0 + ), "output should be pseudo-'freed' in schedule, to optimize memory" + assert isinstance(output, torch.Tensor), ( + "output == '%s'." % type(output).__name__ + ) + assert isinstance(grad_output, (torch.Tensor, type(None))), ( + "grad_output == '%s'." % type(grad_output).__name__ + ) + + # Handle scalar output + if grad_output is None: + assert output.numel() == 1, "implicit grad requires scalar output." + grad_output = torch.ones_like( + output, + memory_format=torch.preserve_format, + ) + + # Call c++ engine [ see torch/csrc/autograd/python_engine.cpp ] + Variable._execution_engine.run_backward( + tensors=(output,), + grad_tensors=(grad_output,), + keep_graph=False, + create_graph=False, + inputs=tuple(), + allow_unreachable=True, + accumulate_grad=True, + ) + + def microbatch_bwd(input_tensor, output_tensor, output_tensor_grad): + if input_tensor is not None: + input_tensor.retain_grad() + if output_tensor_grad is None: + output_tensor.backward() + else: + if args.deallocate_pipeline_outputs: + custom_backward(output_tensor, output_tensor_grad) + else: + torch.autograd.backward(output_tensor, grad_tensors=output_tensor_grad) + if input_tensor is not None: + return input_tensor.grad + return None + + def forward_step_helper( + microbatch_id, p0_chunk0_batch, recompute_attn=None, recompute_fc1_fc3=None + ): + """Helper method to run forward step with model split into chunks + (run set_virtual_pipeline_model_parallel_rank() before calling + forward_step()).""" + model_chunk_id = get_model_chunk_id(microbatch_id, forward=True) + mpu.set_virtual_pipeline_model_parallel_rank(model_chunk_id) + + is_first_microbatch = is_first_microbatch_for_model_chunk(microbatch_id) + + # forward step + if mpu.is_pipeline_first_stage(): + # This is to make sure each model chunk has the number of input same as num_microbatch + # For other pipeline stages, input will append the received tensor from previous pipeline stage + if len(input_tensors[model_chunk_id]) == len( + output_tensors[model_chunk_id] + ): + input_tensors[model_chunk_id].append(None) + + # input_tensors has all the input for each model chunk. + # If not first PP stage(including virtual), we will use the very last input in input_tensors. 
+ # On the first PP stage, if num_microbatch_per_round is larger than pipeline stage, + # this means we will receive the input num_microbatch_per_round - pipeline_parallel_size earlier than it will be used. + # So we need to use the input according to index of microbatch. We first figure out in this model chunk, which microbatch we are running. + # then substract the number of popped input_tensors. + if mpu.is_pipeline_first_stage(ignore_virtual=True): + input_index = get_input_index(microbatch_id) + input_index -= len(popped_input_tensors[model_chunk_id]) + else: + input_index = -1 + input_tensor = input_tensors[model_chunk_id][input_index] + real_microbatch_id = get_real_microbatch_id(microbatch_id) + output_tensor = microbatch_fwd( + model_chunk_id, + input_tensor, + microbatch_x[real_microbatch_id], + microbatch_y[p0_chunk0_batch[1]], + state, + ( + microbatch_mask[real_microbatch_id] + if microbatch_mask is not None + else None + ), + mean_losses, + is_first_microbatch=is_first_microbatch, + recompute_attn=recompute_attn, + recompute_fc1_fc3=recompute_fc1_fc3, + attn_bias=microbatch_attn_bias[real_microbatch_id], + ) + output_tensors[model_chunk_id].append(output_tensor) + return output_tensor + + def backward_step_helper(microbatch_id): + """Helper method to run backward step with model split into chunks + (run set_virtual_pipeline_model_parallel_rank() before calling + backward_step()).""" + model_chunk_id = get_model_chunk_id(microbatch_id, forward=False) + mpu.set_virtual_pipeline_model_parallel_rank(model_chunk_id) + + if mpu.is_pipeline_last_stage(): + if len(output_tensor_grads[model_chunk_id]) == 0: + output_tensor_grads[model_chunk_id].append(None) + input_tensor = input_tensors[model_chunk_id].pop(0) + popped_input_tensors[model_chunk_id].append(None) + output_tensor = output_tensors[model_chunk_id].pop(0) + output_tensor_grad = output_tensor_grads[model_chunk_id].pop(0) + + input_tensor_grad = microbatch_bwd( + input_tensor, output_tensor, output_tensor_grad + ) + # Reuse the deallocate_output_tensor function to release input_tensor + if input_tensor is not None: + deallocate_output_tensor(input_tensor) + + return input_tensor_grad + + mpu.set_virtual_pipeline_model_parallel_rank(0) + with record_function("warmup forward passes p2p comm"): + input_tensors[0].append( + p2p_communication.recv_forward( + tensor_shape, dtype, batch_p2p_comm=batch_p2p_communication + ) + ) + + with record_function("warmup forward passes"): + fwd_wait_handles = None + bwd_wait_handles = None + for k in range(num_warmup_microbatches): + if fwd_wait_handles is not None: + for req in fwd_wait_handles: + req.wait() + + # Decide to checkpoint all layers' activations of the current micro-batch + if max_outstanding_backprops is not None: + checkpoint_activations_microbatch = ( + k % max_outstanding_backprops + >= args.num_microbatches_with_partial_activation_checkpoints + ) + else: + checkpoint_activations_microbatch = None + + with record_function("1f"): + output_tensor = forward_step_helper( + k, + p0_chunk0_batch, + recompute_attn=checkpoint_activations_microbatch + and args.mb_recompute_attn, + recompute_fc1_fc3=checkpoint_activations_microbatch + and args.mb_recompute_fc1_fc3, + ) + + # Determine the model chunk that received input from this iteration belongs to. + # On the first PP stage, if num_microbatch_per_round is larger than pipeline stage, + # this means we will receive the input num_microbatch_per_round - pipeline_parallel_size earlier than it will be used by its model chunk. 
+ # so to determine the true model chunk, we need to add num_microbatch_per_round - pipeline_parallel_size. + next_forward_model_chunk_id = None + if mpu.is_pipeline_first_stage(ignore_virtual=True): + if microbatch_less_than_pp: + next_forward_model_chunk_id = get_model_chunk_id( + k + 1, + forward=True, + ) + else: + next_forward_model_chunk_id = get_model_chunk_id( + k + 1 + num_microbatch_per_round - pipeline_parallel_size, + forward=True, + ) + else: + next_forward_model_chunk_id = get_model_chunk_id(k + 1, forward=True) + + + recv_prev = True + # For first PP rank, there are two cases that to not receive: + # (1) Before first model chunk of last PP stage start to run, there is nothing to receive. + # (2) when last model chunk of last PP stage start running, last PP rank wont send input anymore. + if mpu.is_pipeline_first_stage(ignore_virtual=True): + if microbatch_less_than_pp: + if k < num_microbatch_per_round - 1: + recv_prev = False + else: + if k < pipeline_parallel_size - 1: + recv_prev = False + elif ( + k + >= (num_model_chunks - 1) * num_microbatch_per_round + + pipeline_parallel_size + - 1 + ): + recv_prev = False + if k == (total_num_microbatches - 1): + recv_prev = False + + # Don't send tensor downstream if on last stage. + if mpu.is_pipeline_last_stage(): + output_tensor = None + + # Send and receive tensors as appropriate (send tensors computed + # in this iteration; receive tensors for next iteration + + ( + input_tensor, + fwd_wait_handles, + ) = p2p_communication.send_forward_recv_forward( + output_tensor, + recv_prev=recv_prev, + tensor_shape=tensor_shape, + dtype=dtype, + batch_p2p_comm=batch_p2p_communication, + overlap_p2p_comm=True, + ) + + if k == (num_warmup_microbatches - 1) and not all_warmup_microbatches: + input_tensor_grad = None + recv_next = True + if mpu.is_pipeline_last_stage(ignore_virtual=True): + recv_next = False + + ( + output_tensor_grad, + bwd_wait_handles, + ) = p2p_communication.send_backward_recv_backward( + input_tensor_grad, + recv_next=recv_next, + tensor_shape=tensor_shape, + batch_p2p_comm=batch_p2p_communication, + dtype=dtype, + overlap_p2p_comm=True, + ) + + output_tensor_grads[num_model_chunks - 1].append(output_tensor_grad) + # make sure number of input tensor is same as number of microbatch + if recv_prev: + input_tensors[next_forward_model_chunk_id].append(input_tensor) + + if args.deallocate_pipeline_outputs and output_tensor is not None: + deallocate_output_tensor(output_tensor) + + # Run 1F1B in steady state. + with record_function("forward 1F1B steady"): + for k in range(num_microbatches_remaining): + # Forward pass. 
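+                # In each steady-state iteration, the forward microbatch
+                # (forward_k) runs num_warmup_microbatches ahead of the
+                # backward microbatch (k) handled later in the same iteration.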
+ forward_k = k + num_warmup_microbatches + sync_grads = is_last_microbatch_for_model_chunk(k) + + # Decide to checkpoint all layers' activations of the current micro-batch + if max_outstanding_backprops is not None: + checkpoint_activations_microbatch = ( + forward_k % max_outstanding_backprops + >= args.num_microbatches_with_partial_activation_checkpoints + ) + else: + checkpoint_activations_microbatch = None + + if fwd_wait_handles is not None: + for req in fwd_wait_handles: + req.wait() + + if args.deallocate_pipeline_outputs and output_tensor is not None: + deallocate_output_tensor(output_tensor) + with record_function("1f"): + output_tensor = forward_step_helper( + forward_k, + p0_chunk0_batch, + recompute_attn=checkpoint_activations_microbatch + and args.mb_recompute_attn, + recompute_fc1_fc3=checkpoint_activations_microbatch + and args.mb_recompute_fc1_fc3, + ) + + # Determine if current stage has anything to send in either direction, + # otherwise set tensor to None. + forward_model_chunk_id = get_model_chunk_id(forward_k, forward=True) + mpu.set_virtual_pipeline_model_parallel_rank(forward_model_chunk_id) + + # Last virtual stage no activation tensor to send + if mpu.is_pipeline_last_stage(): + output_tensor = None + + # Determine if peers are sending, and where in data structure to put + # received tensors. + recv_prev = True + if mpu.is_pipeline_first_stage(ignore_virtual=True): + # First stage is ahead of last stage by (pipeline_parallel_size - 1). + next_forward_model_chunk_id = get_model_chunk_id( + forward_k - (pipeline_parallel_size - 1), forward=True + ) + if next_forward_model_chunk_id == (num_model_chunks - 1): + recv_prev = False + next_forward_model_chunk_id += 1 + else: + next_forward_model_chunk_id = get_model_chunk_id( + forward_k + 1, forward=True + ) + + # If last iteration, don't receive; we already received one extra + # before the start of the for loop. + if k == (num_microbatches_remaining - 1): + recv_prev = False + + # Send activation tensor to the next stage and receive activation tensor from the + # previous stage + ( + input_tensor, + fwd_wait_handles, + ) = p2p_communication.send_forward_recv_forward( + output_tensor, + recv_prev=recv_prev, + tensor_shape=tensor_shape, + dtype=dtype, + batch_p2p_comm=batch_p2p_communication, + overlap_p2p_comm=True, + ) + + if bwd_wait_handles is not None: + for req in bwd_wait_handles: + req.wait() + + if input_tensor_grad is not None: + deallocate_output_tensor(input_tensor_grad) + + # Backward pass. + backward_k = k + backward_model_chunk_id = get_model_chunk_id(backward_k, forward=False) + + if not args.model.enable_ddp and sync_grads: + model[ + backward_model_chunk_id + ].dont_wait_current_stream_for_post_all_gather = True + with ( + nullcontext() + if sync_grads + else model[backward_model_chunk_id].no_sync() + ): + with record_function("1b"): + input_tensor_grad = backward_step_helper(backward_k) + + mpu.set_virtual_pipeline_model_parallel_rank(backward_model_chunk_id) + + # First virtual stage no activation gradient tensor to send + if mpu.is_pipeline_first_stage(): + input_tensor_grad = None + + # Determine if the current virtual stage has an activation gradient tensor to receive + recv_next = True + if mpu.is_pipeline_last_stage(ignore_virtual=True): + # Last stage is ahead of first stage by (pipeline_parallel_size - 1). 
+ next_backward_model_chunk_id = get_model_chunk_id( + backward_k - (pipeline_parallel_size - 1), forward=False + ) + if next_backward_model_chunk_id == 0: + recv_next = False + next_backward_model_chunk_id -= 1 + else: + next_backward_model_chunk_id = get_model_chunk_id( + backward_k + 1, forward=False + ) + + ( + output_tensor_grad, + bwd_wait_handles, + ) = p2p_communication.send_backward_recv_backward( + input_tensor_grad, + recv_next=recv_next, + tensor_shape=tensor_shape, + dtype=dtype, + batch_p2p_comm=batch_p2p_communication, + overlap_p2p_comm=True, + ) + if not args.model.enable_ddp and sync_grads: + model[ + backward_model_chunk_id + ].dont_wait_current_stream_for_post_all_gather = True + with ( + nullcontext() + if sync_grads + else model[backward_model_chunk_id].no_sync() + ): + if args.zero_bubble and k >= num_1f1b_microbatches: + with record_function("zero bubble 1w"): + WeightGradStore.pop() + + # Put input_tensor and output_tensor_grad in data structures in the + # right location. + if recv_prev: + input_tensors[next_forward_model_chunk_id].append(input_tensor) + if recv_next: + output_tensor_grads[next_backward_model_chunk_id].append( + output_tensor_grad + ) + model_chunk_id = get_model_chunk_id(backward_k, forward=False) + + if args.deallocate_pipeline_outputs and output_tensor is not None: + deallocate_output_tensor(output_tensor) + + # Run cooldown backward passes (flush out pipeline). + with record_function("cooldown backward"): + if overlap_p2p_communication and bwd_wait_handles is not None: + for wait_handle in bwd_wait_handles: + wait_handle.wait() + if input_tensor_grad is not None: + deallocate_output_tensor(input_tensor_grad) + + if all_warmup_microbatches: + output_tensor_grads[num_model_chunks - 1].append( + p2p_communication.recv_backward( + tensor_shape, batch_p2p_comm=batch_p2p_communication, dtype=dtype + ) + ) + for k in range(num_microbatches_remaining, total_num_microbatches): + if overlap_p2p_communication and bwd_wait_handles is not None: + for wait_handle in bwd_wait_handles: + wait_handle.wait() + # same as warmup, for last PP stage, currently received grad is + # (num_microbatch_per_round - pipeline_parallel_size) earlier than its corresponding model chunk + if mpu.is_pipeline_last_stage(ignore_virtual=True): + if microbatch_less_than_pp: + next_backward_model_chunk_id = get_model_chunk_id( + k + 1, + forward=False, + ) + else: + next_backward_model_chunk_id = get_model_chunk_id( + k + 1 + num_microbatch_per_round - pipeline_parallel_size, + forward=False, + ) + else: + next_backward_model_chunk_id = get_model_chunk_id(k + 1, forward=False) + model_chunk_id = get_model_chunk_id(k, forward=False) + if not args.model.enable_ddp and is_last_microbatch_for_model_chunk(k): + model[ + model_chunk_id + ].dont_wait_current_stream_for_post_all_gather = True + with ( + nullcontext() + if is_last_microbatch_for_model_chunk(k) + else model[model_chunk_id].no_sync() + ): + with record_function("1b"): + input_tensor_grad = backward_step_helper(k) + + recv_next = True + # for last pp stage, if it start the very last model chunk, then no need to receive + # edge case is when it is bfs, before first model chunk of first pp stage start bwd, last stage doesnt need to receive. 
+ if mpu.is_pipeline_last_stage(ignore_virtual=True): + if microbatch_less_than_pp: + if k < num_microbatch_per_round - 1: + recv_next = False + else: + if k < pipeline_parallel_size - 1: + recv_next = False + elif ( + k + >= total_num_microbatches + - num_microbatch_per_round + - 1 + + pipeline_parallel_size + ): + recv_next = False + if k == (total_num_microbatches - 1): + recv_next = False + + ( + output_tensor_grad, + bwd_wait_handles, + ) = p2p_communication.send_backward_recv_backward( + input_tensor_grad, + recv_next=recv_next, + tensor_shape=tensor_shape, + dtype=dtype, + batch_p2p_comm=batch_p2p_communication, + overlap_p2p_comm=True, + ) + if recv_next: + output_tensor_grads[next_backward_model_chunk_id].append( + output_tensor_grad + ) + + with ( + nullcontext() + if is_last_microbatch_for_model_chunk(k) + else model[model_chunk_id].no_sync() + ): + with record_function("zero bubble 1w"): + WeightGradStore.pop() + while WeightGradStore.weight_grad_queue.qsize() > 0: + with record_function("zero bubble 1w"): + WeightGradStore.pop() + + # Make sure all communication is finished + torch.cuda.synchronize() + + for model_chunk_id in range(num_model_chunks): + model[model_chunk_id].dont_wait_current_stream_for_post_all_gather = False + # logger.warning(f"model_chunk: {model_chunk_id}; rank: {torch.distributed.get_rank()}") + model[model_chunk_id]._wait_for_post_backward() + + if len(mean_losses) > 0: + sum_loss_across_mb = torch.stack(mean_losses).sum() + else: + sum_loss_across_mb = torch.zeros([], dtype=torch.float32, device="cuda") + + torch.distributed.broadcast( + sum_loss_across_mb, + src=mpu.get_pipeline_model_parallel_last_rank(), + group=mpu.get_pipeline_model_parallel_group(), + ) + return sum_loss_across_mb, None diff --git a/test/test_pipeline_schedule_e2e.py b/test/test_pipeline_schedule_e2e.py index 71b1a5064..93f4fffa5 100644 --- a/test/test_pipeline_schedule_e2e.py +++ b/test/test_pipeline_schedule_e2e.py @@ -237,7 +237,7 @@ def rank_print(msg): ) elif schedule == "doraPP": my_schedule = ScheduleDoraPP( - stage_model_looped, n_microbatches, loss_fn + stage_model_looped, n_microbatches, loss_fn, microbatch_size=microbatch_size, model_dim=input_dim, ) if _run_profiler: From 663189a6855800d621f764b34397b0ffc66f4839 Mon Sep 17 00:00:00 2001 From: Feng Tian Date: Thu, 13 Jun 2024 11:37:46 -0700 Subject: [PATCH 3/3] [wip] share up to date work --- pippy/PipelineSchedule.py | 906 +---------------------------- pippy/zero_bubble.py | 739 +++++++++++++++++++++++ test/test_pipeline_schedule_e2e.py | 9 +- 3 files changed, 766 insertions(+), 888 deletions(-) create mode 100644 pippy/zero_bubble.py diff --git a/pippy/PipelineSchedule.py b/pippy/PipelineSchedule.py index b338593c0..4c90f2be4 100644 --- a/pippy/PipelineSchedule.py +++ b/pippy/PipelineSchedule.py @@ -792,6 +792,7 @@ def __init__( output_merge_spec: Optional[Union[Dict[str, Any], Tuple[Any]]] = None, ): self.pp_group_size = stages[0].group_size + self.deallocate_pipeline_outputs = False # TODO: is this limitation a must? if n_microbatches % self.pp_group_size != 0: raise ValueError( @@ -819,8 +820,9 @@ def _step_microbatches( microbatch_size: Optional[int] = None, model_dim: Optional[int] = None, ): + """ - Operate on the microbatches for doraPP schedule . + Operate on the microbatches for interleaved 1f1b schedule (https://arxiv.org/pdf/2104.04473.pdf). 
Highest rank has a warmup (fwd only) count of [len(stages) - 1] * number of PP ranks and each rank away from highest rank adds 2 warmup steps due to: @@ -836,534 +838,36 @@ def _step_microbatches( arg_mbs, kwarg_mbs, target_mbs, losses ) - # increment warmup_steps by 2 for each hop away - warmup_steps = (self.n_local_stages - 1) * self.pp_group_size - warmup_steps += 2 * ((self.pp_group_size - 1) - self.rank) - warmup_steps = min( - warmup_steps, self._n_microbatches * self.n_local_stages - ) - fwd_bwd_steps = ( - self.n_local_stages * self._n_microbatches - ) - warmup_steps - cooldown_steps = ( - self.n_local_stages * self._n_microbatches - ) - fwd_bwd_steps - - assert ( - warmup_steps + fwd_bwd_steps * 2 + cooldown_steps - == self.n_local_stages * self._n_microbatches * 2 - ) - total_steps = warmup_steps + fwd_bwd_steps + cooldown_steps - - logger.debug( - f""" - n_microbatches {self._n_microbatches} - stages {self.n_local_stages} - rank {self.rank} - warmup_steps {warmup_steps} - 1f1b {fwd_bwd_steps} - cooldown_steps {cooldown_steps} - """ - ) - - def forward_stage_local_index(step): - return (step // self.pp_group_size) % self.n_local_stages - - def backward_stage_local_index(step): - return ( - self.n_local_stages - - 1 - - ((step - warmup_steps) // self.pp_group_size) - % self.n_local_stages - ) - - fwd_stage_mb_index: Dict[PipelineStageBase, int] = defaultdict(int) - bwd_stage_mb_index: Dict[PipelineStageBase, int] = defaultdict(int) - - # Delay send waits - sends_to_wait: List[dist.Work] = [] - - # Store ops (potentially across steps) - ops: List[dist.P2POp] = [] - - # Warmup Phase (forward only) - for step in range(warmup_steps): - fwd_stage = self._stages[forward_stage_local_index(step)] - - # This will assign the current microbatch index and update it for future steps - fwd_stage_mb_index[fwd_stage] = ( - mb_index := fwd_stage_mb_index[fwd_stage] - ) + 1 - - with record_function(f"Forward {step}"): - ops.extend(fwd_stage.get_fwd_recv_ops()) - if ops: - work = dist.batch_isend_irecv(ops).pop() - work.wait() - ops.clear() - - output = fwd_stage.forward_one_chunk(arg_mbs[mb_index], kwarg_mbs[mb_index]) # type: ignore[index] - - ops.extend(fwd_stage.get_fwd_send_ops()) - # If we are right before the fwd-bwd step, then we need to delay the send to the next step, - # This is because fwd-bwd send/recvs among ranks need to be aligned to prevent a hang. 
- # In the edge cases where there are no fwd_bwds and cooldown is immediate, then no delay is needed - if ops and (step != warmup_steps - 1 or fwd_bwd_steps == 0): - work = dist.batch_isend_irecv(ops).pop() - sends_to_wait.append(work) - ops.clear() - - self._maybe_compute_loss( - fwd_stage, output, target_mbs, mb_index - ) - - # 1F1B Phase (forward and backward) - for step in range(warmup_steps, warmup_steps + fwd_bwd_steps): - fwd_stage = self._stages[forward_stage_local_index(step)] - bwd_stage = self._stages[backward_stage_local_index(step)] - - fwd_stage_mb_index[fwd_stage] = ( - fwd_mb_index := fwd_stage_mb_index[fwd_stage] - ) + 1 - bwd_stage_mb_index[bwd_stage] = ( - bwd_mb_index := bwd_stage_mb_index[bwd_stage] - ) + 1 - - bwd_stage._configure_data_parallel_mode( - bwd_mb_index == self._n_microbatches - 1 - ) - logger.debug( - f"Rank {self.rank}: {step=}, {fwd_stage.stage_index=}, {bwd_stage.stage_index=}, {fwd_mb_index=}, {bwd_mb_index=}" - ) - with record_function(f"1F1B {step}"): - ops.extend(fwd_stage.get_fwd_recv_ops()) - ops.extend(bwd_stage.get_bwd_recv_ops()) - if ops: - work = dist.batch_isend_irecv(ops).pop() - work.wait() - ops.clear() - - # Forward - output = fwd_stage.forward_one_chunk(arg_mbs[fwd_mb_index], kwarg_mbs[fwd_mb_index]) # type: ignore[index] - ops.extend(fwd_stage.get_fwd_send_ops()) - self._maybe_compute_loss( - fwd_stage, output, target_mbs, fwd_mb_index - ) - - # Backward - loss = self._maybe_get_loss(bwd_stage, bwd_mb_index) - bwd_stage.backward_one_chunk(loss=loss) - ops.extend(bwd_stage.get_bwd_send_ops()) - - # Cooldown Phase (backward only) - for step in range(warmup_steps + fwd_bwd_steps, total_steps): - bwd_stage = self._stages[backward_stage_local_index(step)] - bwd_stage_mb_index[bwd_stage] = ( - bwd_mb_index := bwd_stage_mb_index[bwd_stage] - ) + 1 - bwd_stage._configure_data_parallel_mode( - bwd_mb_index == self._n_microbatches - 1 - ) - - logger.debug( - f"Rank {self.rank}: {step=}, {bwd_stage.stage_index=}, {bwd_mb_index=}" - ) - with record_function(f"Cooldown {step}"): - ops.extend(bwd_stage.get_bwd_recv_ops()) - if ops: - work = dist.batch_isend_irecv(ops).pop() - work.wait() - ops.clear() - - loss = self._maybe_get_loss(bwd_stage, bwd_mb_index) - bwd_stage.backward_one_chunk(loss=loss) - - ops.extend(bwd_stage.get_bwd_send_ops()) - if ops: - work = dist.batch_isend_irecv(ops).pop() - sends_to_wait.append(work) - ops.clear() - - # Make sure all sends are finished - for work in sends_to_wait: - work.wait() - - # Return losses if there is a container passed in - self._update_losses(self._stages, losses) - - - - - - - - - - ########################################################## - # Dora PP Xlformer Implementation - ########################################################## - - input_tensors = [[] for _ in range(len(self._stages))] - output_tensors = [[] for _ in range(len(self._stages))] - output_tensor_grads = [[] for _ in range(len(self._stages))] - # We need to pop input, output and grad during bwd, we use this list to track real input tensor index. 
- popped_input_tensors = [[] for _ in range(len(self._stages))] - input_tensor_grad = None - - pipeline_parallel_size = self.pp_group_size - pipeline_parallel_rank = self._stage.stage_index - - microbatch_x = arg_mbs - microbatch_y = target_mbs - microbatch_mask = None - mask = None - if mask is not None: - microbatch_mask = mask.split(args.pipeline_parallel_microbatch_size, dim=0) - - num_microbatches = self._n_microbatches - # microbatch_attn_bias = [ - # model[0].get_attn_bias(microbatch_x[i], cache=None) - # for i in range(num_microbatches) - # ] - microbatch_attn_bias = [ - self._stages[0].submodule.get_attn_bias(microbatch_x[i], cache=None) - for i in range(num_microbatches) - ] - - - # TODO: get the model args from API directly, should modify it later - assert(microbatch_size is not None), "microbatch_size is None" - assert(model_dim is not None), "model_dim is None" - - microbatch_less_than_pp = num_microbatches < pipeline_parallel_size - num_round = max(num_microbatches // pipeline_parallel_size, 1) + num_round = max(self._n_microbatches // self.pp_group_size, 1) assert ( - num_microbatches % num_round == 0 + self._n_microbatches % num_round == 0 ), "Number of microbatches should be divisible by number of pipeline rounds." # the number of microbatches run in each round, in dfs it is pipeline_parallel_size - num_microbatch_per_round = num_microbatches // num_round - - tensor_shape = ( - microbatch_size, - model_dim, - ) - - num_model_chunks = len(model) - total_num_microbatches = num_microbatches * num_model_chunks - - dtype = get_torch_dtype(args.dtype) - - mpu.set_virtual_pipeline_model_parallel_rank(0) - all_warmup_microbatches = False - - if not args.model.enable_ddp: - for model_chunk in model: - model_chunk._rebuild_full_params_recursive() - else: - for model_chunk in model: - model_chunk.zero_grad() + num_microbatch_per_round = self._n_microbatches // num_round + total_num_microbatches = self._n_microbatches * self.n_local_stages + # increment warmup_steps by 2 for each hop away num_warmup_microbatches = 0 # The number of microbatches that last pipeline stage run before 1f1b. - num_warmup_microbatches += (num_model_chunks - 1) * num_microbatch_per_round + num_warmup_microbatches += (self.n_local_stages - 1) * num_microbatch_per_round # From last PP stage up, each rank will be 2 more than the previous one. num_warmup_microbatches += ( - pipeline_parallel_size - pipeline_parallel_rank - 1 + self.pp_group_size - self.rank - 1 ) * 2 + num_warmup_microbatches = min(num_warmup_microbatches, total_num_microbatches) num_microbatches_remaining = total_num_microbatches - num_warmup_microbatches - # The number of 1f1b for zero bubble schedule - if num_microbatches == pipeline_parallel_size: - num_1f1b_microbatches = pipeline_parallel_rank + if self._n_microbatches == self.pp_group_size: + num_1f1b_microbatches = self.rank else: - num_1f1b_microbatches = 2 * pipeline_parallel_rank - - # Checkpoint the activations of partial Transformer layers in a number of micro-batches - # within the maximum outstanding micro-batch backpropagations. - # Micro-batches with the ids less than 'num_microbatches_with_partial_activation_checkpoints' - # checkpoint partial Transformer layers (or skip checkpointing) and - # the rest of micro-batches within a window of micro-batches checkpoint - # all Transformer layers. The window of micro-batches is set by the maximum - # outstanding backpropagations and becomes smaller at later pipeline stages. 
- # Please refer the appendix C in https://arxiv.org/pdf/2205.05198.pdf - max_outstanding_backprops = None - if args.num_microbatches_with_partial_activation_checkpoints is not None: - max_outstanding_backprops = num_warmup_microbatches + 1 - - p0_chunk0_batch = [0, 0] - mean_losses = [] - - def get_model_chunk_id(microbatch_id, forward): - """Helper method to get the model chunk ID given the iteration number. - Each group has num_microbatch_per_round * num_model_chunks microbatches. - within each chunk, there are num_microbatch_per_round microbatches. - backward is reverse order of forward. - """ - microbatch_id_in_group = microbatch_id % ( - num_microbatch_per_round * num_model_chunks - ) - model_chunk_id = microbatch_id_in_group // num_microbatch_per_round - if not forward: - model_chunk_id = num_model_chunks - model_chunk_id - 1 - return model_chunk_id - - def get_real_microbatch_id(microbatch_id: int) -> int: - """Get the microbatch id for input tokens.""" - microbatch_group_size = num_microbatch_per_round * num_model_chunks - microbatch_group_id = microbatch_id // microbatch_group_size - real_microbatch_id_in_group = ( - microbatch_id % microbatch_group_size - ) % num_microbatch_per_round - real_microbatch_id = ( - real_microbatch_id_in_group + microbatch_group_id * num_microbatch_per_round - ) - return real_microbatch_id - - def is_first_microbatch_for_model_chunk(microbatch_id: int) -> bool: - """Check if an iteration is the first for a model chunk.""" - microbatch_group_size = num_microbatch_per_round * num_model_chunks - microbatch_group_id = microbatch_id // microbatch_group_size - microbatch_id_in_group = microbatch_id % microbatch_group_size - if microbatch_group_id == 0: - return microbatch_id_in_group % num_microbatch_per_round == 0 - else: - return False - - def is_last_microbatch_for_model_chunk(microbatch_id: int) -> bool: - """Check if an iteration is the last for a model chunk.""" - microbatch_group_size = num_microbatch_per_round * num_model_chunks - num_microbatch_groups = total_num_microbatches // microbatch_group_size - microbatch_group_id = microbatch_id // microbatch_group_size - microbatch_id_in_group = microbatch_id % microbatch_group_size - if microbatch_group_id == num_microbatch_groups - 1: - return ( - microbatch_id_in_group % num_microbatch_per_round - == num_microbatch_per_round - 1 - ) - else: - return False - - def get_input_index(microbatch_id): - """Get pipeline input index for a microbatch""" - microbatch_group_size = num_microbatch_per_round * num_model_chunks - microbatch_id_in_group = microbatch_id % microbatch_group_size - microbatch_group_id = microbatch_id // microbatch_group_size - input_index = microbatch_id_in_group % num_microbatch_per_round - return input_index + microbatch_group_id * num_microbatch_per_round - - def microbatch_fwd( - model_chunk_id, - input_tensor, - microbatch_tokens, - y, - state, - mask, - mean_losses, - is_first_microbatch=False, - recompute_attn=None, - recompute_fc1_fc3=None, - attn_bias=None, - ): - if input_tensor is None: - assert mpu.is_pipeline_first_stage() - else: - assert not mpu.is_pipeline_first_stage() - - if args.num_microbatches_with_partial_activation_checkpoints is not None: - output, _ = model[model_chunk_id]( - microbatch_tokens, - pipeline_parallel_input_tensor=input_tensor, - is_first_microbatch=is_first_microbatch, - recompute_attn=recompute_attn, - recompute_fc1_fc3=recompute_fc1_fc3, - precomputed_attn_bias=attn_bias, - ) - else: - output, _ = model[model_chunk_id]( - microbatch_tokens, - 
pipeline_parallel_input_tensor=input_tensor, - is_first_microbatch=is_first_microbatch, - precomputed_attn_bias=attn_bias, - ) - - if mpu.is_pipeline_last_stage(): - if loss_fn is not None: - loss = loss_fn( - output, - y, - mask, - ) - output = loss.mean() / num_microbatches - else: - if args.model.loss_parallel: - tok_loss = state.scale * vocab_parallel_cross_entropy( - partial_logits=output, - target=y, - z_loss_multiplier=args.z_loss_multiplier, - ) - else: - tok_loss = state.scale * F.cross_entropy( - output.flatten(0, 1), y.flatten(0, 1), reduction="none" - ) - if mask is None: - output = tok_loss.mean() / num_microbatches - else: - mask = mask.flatten(0, 1) - tok_loss = tok_loss * mask - output = tok_loss.sum() / (mask.sum() + 1e-6) / num_microbatches - mean_losses.append(output) - p0_chunk0_batch[1] += 1 - return output - - def deallocate_output_tensor(out): - """Deallocate the output tensor's '.data' field. - This method should be called right after the output tensor has been - sent to the next pipeline stage. At this point, the output tensor is - only useful for its '.grad_fn' field, and not its '.data'. - """ - assert isinstance(out, torch.Tensor), ( - "expected Tensor, found %s." % type(out).__name__ - ) - assert out._base is None, "counter-productive to free a view of another tensor." - out.data.storage().resize_(0) - - def custom_backward(output, grad_output): - """Custom backward where directly call C++ autograd engine. - Since Pytorch's 'backward' checks that the output and - grad have the same shape. We need to manually call the C++ autograd - instead of using Pytorch's torch.autograd.backward. - So that the 'deallocate_output_tensor' optimization can work. - """ + num_1f1b_microbatches = 2 * self.rank - assert ( - output.storage().size() == 0 - ), "output should be pseudo-'freed' in schedule, to optimize memory" - assert isinstance(output, torch.Tensor), ( - "output == '%s'." % type(output).__name__ - ) - assert isinstance(grad_output, (torch.Tensor, type(None))), ( - "grad_output == '%s'." % type(grad_output).__name__ - ) - - # Handle scalar output - if grad_output is None: - assert output.numel() == 1, "implicit grad requires scalar output." 
- grad_output = torch.ones_like( - output, - memory_format=torch.preserve_format, - ) - - # Call c++ engine [ see torch/csrc/autograd/python_engine.cpp ] - Variable._execution_engine.run_backward( - tensors=(output,), - grad_tensors=(grad_output,), - keep_graph=False, - create_graph=False, - inputs=tuple(), - allow_unreachable=True, - accumulate_grad=True, - ) - - def microbatch_bwd(input_tensor, output_tensor, output_tensor_grad): - if input_tensor is not None: - input_tensor.retain_grad() - if output_tensor_grad is None: - output_tensor.backward() - else: - if args.deallocate_pipeline_outputs: - custom_backward(output_tensor, output_tensor_grad) - else: - torch.autograd.backward(output_tensor, grad_tensors=output_tensor_grad) - if input_tensor is not None: - return input_tensor.grad - return None - - def forward_step_helper( - microbatch_id, p0_chunk0_batch, recompute_attn=None, recompute_fc1_fc3=None - ): - """Helper method to run forward step with model split into chunks - (run set_virtual_pipeline_model_parallel_rank() before calling - forward_step()).""" - model_chunk_id = get_model_chunk_id(microbatch_id, forward=True) - mpu.set_virtual_pipeline_model_parallel_rank(model_chunk_id) - - is_first_microbatch = is_first_microbatch_for_model_chunk(microbatch_id) - - # forward step - if mpu.is_pipeline_first_stage(): - # This is to make sure each model chunk has the number of input same as num_microbatch - # For other pipeline stages, input will append the received tensor from previous pipeline stage - if len(input_tensors[model_chunk_id]) == len( - output_tensors[model_chunk_id] - ): - input_tensors[model_chunk_id].append(None) - - # input_tensors has all the input for each model chunk. - # If not first PP stage(including virtual), we will use the very last input in input_tensors. - # On the first PP stage, if num_microbatch_per_round is larger than pipeline stage, - # this means we will receive the input num_microbatch_per_round - pipeline_parallel_size earlier than it will be used. - # So we need to use the input according to index of microbatch. We first figure out in this model chunk, which microbatch we are running. - # then substract the number of popped input_tensors. 
- if mpu.is_pipeline_first_stage(ignore_virtual=True): - input_index = get_input_index(microbatch_id) - input_index -= len(popped_input_tensors[model_chunk_id]) - else: - input_index = -1 - input_tensor = input_tensors[model_chunk_id][input_index] - real_microbatch_id = get_real_microbatch_id(microbatch_id) - output_tensor = microbatch_fwd( - model_chunk_id, - input_tensor, - microbatch_x[real_microbatch_id], - microbatch_y[p0_chunk0_batch[1]], - state, - ( - microbatch_mask[real_microbatch_id] - if microbatch_mask is not None - else None - ), - mean_losses, - is_first_microbatch=is_first_microbatch, - recompute_attn=recompute_attn, - recompute_fc1_fc3=recompute_fc1_fc3, - attn_bias=microbatch_attn_bias[real_microbatch_id], - ) - output_tensors[model_chunk_id].append(output_tensor) - return output_tensor - - def backward_step_helper(microbatch_id): - """Helper method to run backward step with model split into chunks - (run set_virtual_pipeline_model_parallel_rank() before calling - backward_step()).""" - model_chunk_id = get_model_chunk_id(microbatch_id, forward=False) - mpu.set_virtual_pipeline_model_parallel_rank(model_chunk_id) - - if mpu.is_pipeline_last_stage(): - if len(output_tensor_grads[model_chunk_id]) == 0: - output_tensor_grads[model_chunk_id].append(None) - input_tensor = input_tensors[model_chunk_id].pop(0) - popped_input_tensors[model_chunk_id].append(None) - output_tensor = output_tensors[model_chunk_id].pop(0) - output_tensor_grad = output_tensor_grads[model_chunk_id].pop(0) - - input_tensor_grad = microbatch_bwd( - input_tensor, output_tensor, output_tensor_grad - ) - # Reuse the deallocate_output_tensor function to release input_tensor - if input_tensor is not None: - deallocate_output_tensor(input_tensor) - - return input_tensor_grad - - mpu.set_virtual_pipeline_model_parallel_rank(0) - with record_function("warmup forward passes p2p comm"): - input_tensors[0].append( - p2p_communication.recv_forward( - tensor_shape, dtype, batch_p2p_comm=batch_p2p_communication - ) - ) + print("------------------------") + print(f"{num_warmup_microbatches=}, {num_microbatches_remaining=}, {num_1f1b_microbatches=}, {total_num_microbatches=}") + print("------------------------") + # Run warmup steps. with record_function("warmup forward passes"): fwd_wait_handles = None bwd_wait_handles = None @@ -1372,378 +876,14 @@ def backward_step_helper(microbatch_id): for req in fwd_wait_handles: req.wait() - # Decide to checkpoint all layers' activations of the current micro-batch - if max_outstanding_backprops is not None: - checkpoint_activations_microbatch = ( - k % max_outstanding_backprops - >= args.num_microbatches_with_partial_activation_checkpoints - ) - else: - checkpoint_activations_microbatch = None - - with record_function("1f"): - output_tensor = forward_step_helper( - k, - p0_chunk0_batch, - recompute_attn=checkpoint_activations_microbatch - and args.mb_recompute_attn, - recompute_fc1_fc3=checkpoint_activations_microbatch - and args.mb_recompute_fc1_fc3, - ) - - # Determine the model chunk that received input from this iteration belongs to. - # On the first PP stage, if num_microbatch_per_round is larger than pipeline stage, - # this means we will receive the input num_microbatch_per_round - pipeline_parallel_size earlier than it will be used by its model chunk. - # so to determine the true model chunk, we need to add num_microbatch_per_round - pipeline_parallel_size. 
- next_forward_model_chunk_id = None - if mpu.is_pipeline_first_stage(ignore_virtual=True): - if microbatch_less_than_pp: - next_forward_model_chunk_id = get_model_chunk_id( - k + 1, - forward=True, - ) - else: - next_forward_model_chunk_id = get_model_chunk_id( - k + 1 + num_microbatch_per_round - pipeline_parallel_size, - forward=True, - ) - else: - next_forward_model_chunk_id = get_model_chunk_id(k + 1, forward=True) - - - recv_prev = True - # For first PP rank, there are two cases that to not receive: - # (1) Before first model chunk of last PP stage start to run, there is nothing to receive. - # (2) when last model chunk of last PP stage start running, last PP rank wont send input anymore. - if mpu.is_pipeline_first_stage(ignore_virtual=True): - if microbatch_less_than_pp: - if k < num_microbatch_per_round - 1: - recv_prev = False - else: - if k < pipeline_parallel_size - 1: - recv_prev = False - elif ( - k - >= (num_model_chunks - 1) * num_microbatch_per_round - + pipeline_parallel_size - - 1 - ): - recv_prev = False - if k == (total_num_microbatches - 1): - recv_prev = False - - # Don't send tensor downstream if on last stage. - if mpu.is_pipeline_last_stage(): - output_tensor = None - - # Send and receive tensors as appropriate (send tensors computed - # in this iteration; receive tensors for next iteration - - ( - input_tensor, - fwd_wait_handles, - ) = p2p_communication.send_forward_recv_forward( - output_tensor, - recv_prev=recv_prev, - tensor_shape=tensor_shape, - dtype=dtype, - batch_p2p_comm=batch_p2p_communication, - overlap_p2p_comm=True, - ) - - if k == (num_warmup_microbatches - 1) and not all_warmup_microbatches: - input_tensor_grad = None - recv_next = True - if mpu.is_pipeline_last_stage(ignore_virtual=True): - recv_next = False - - ( - output_tensor_grad, - bwd_wait_handles, - ) = p2p_communication.send_backward_recv_backward( - input_tensor_grad, - recv_next=recv_next, - tensor_shape=tensor_shape, - batch_p2p_comm=batch_p2p_communication, - dtype=dtype, - overlap_p2p_comm=True, - ) - - output_tensor_grads[num_model_chunks - 1].append(output_tensor_grad) - # make sure number of input tensor is same as number of microbatch - if recv_prev: - input_tensors[next_forward_model_chunk_id].append(input_tensor) - - if args.deallocate_pipeline_outputs and output_tensor is not None: - deallocate_output_tensor(output_tensor) + with record_function("1f"): + print("forward step") # Run 1F1B in steady state. with record_function("forward 1F1B steady"): for k in range(num_microbatches_remaining): - # Forward pass. 
- forward_k = k + num_warmup_microbatches - sync_grads = is_last_microbatch_for_model_chunk(k) - - # Decide to checkpoint all layers' activations of the current micro-batch - if max_outstanding_backprops is not None: - checkpoint_activations_microbatch = ( - forward_k % max_outstanding_backprops - >= args.num_microbatches_with_partial_activation_checkpoints - ) - else: - checkpoint_activations_microbatch = None + print("fwd_bwd") - if fwd_wait_handles is not None: - for req in fwd_wait_handles: - req.wait() - - if args.deallocate_pipeline_outputs and output_tensor is not None: - deallocate_output_tensor(output_tensor) - with record_function("1f"): - output_tensor = forward_step_helper( - forward_k, - p0_chunk0_batch, - recompute_attn=checkpoint_activations_microbatch - and args.mb_recompute_attn, - recompute_fc1_fc3=checkpoint_activations_microbatch - and args.mb_recompute_fc1_fc3, - ) - - # Determine if current stage has anything to send in either direction, - # otherwise set tensor to None. - forward_model_chunk_id = get_model_chunk_id(forward_k, forward=True) - mpu.set_virtual_pipeline_model_parallel_rank(forward_model_chunk_id) - - # Last virtual stage no activation tensor to send - if mpu.is_pipeline_last_stage(): - output_tensor = None - - # Determine if peers are sending, and where in data structure to put - # received tensors. - recv_prev = True - if mpu.is_pipeline_first_stage(ignore_virtual=True): - # First stage is ahead of last stage by (pipeline_parallel_size - 1). - next_forward_model_chunk_id = get_model_chunk_id( - forward_k - (pipeline_parallel_size - 1), forward=True - ) - if next_forward_model_chunk_id == (num_model_chunks - 1): - recv_prev = False - next_forward_model_chunk_id += 1 - else: - next_forward_model_chunk_id = get_model_chunk_id( - forward_k + 1, forward=True - ) - - # If last iteration, don't receive; we already received one extra - # before the start of the for loop. - if k == (num_microbatches_remaining - 1): - recv_prev = False - - # Send activation tensor to the next stage and receive activation tensor from the - # previous stage - ( - input_tensor, - fwd_wait_handles, - ) = p2p_communication.send_forward_recv_forward( - output_tensor, - recv_prev=recv_prev, - tensor_shape=tensor_shape, - dtype=dtype, - batch_p2p_comm=batch_p2p_communication, - overlap_p2p_comm=True, - ) - - if bwd_wait_handles is not None: - for req in bwd_wait_handles: - req.wait() - - if input_tensor_grad is not None: - deallocate_output_tensor(input_tensor_grad) - - # Backward pass. - backward_k = k - backward_model_chunk_id = get_model_chunk_id(backward_k, forward=False) - - if not args.model.enable_ddp and sync_grads: - model[ - backward_model_chunk_id - ].dont_wait_current_stream_for_post_all_gather = True - with ( - nullcontext() - if sync_grads - else model[backward_model_chunk_id].no_sync() - ): - with record_function("1b"): - input_tensor_grad = backward_step_helper(backward_k) - - mpu.set_virtual_pipeline_model_parallel_rank(backward_model_chunk_id) - - # First virtual stage no activation gradient tensor to send - if mpu.is_pipeline_first_stage(): - input_tensor_grad = None - - # Determine if the current virtual stage has an activation gradient tensor to receive - recv_next = True - if mpu.is_pipeline_last_stage(ignore_virtual=True): - # Last stage is ahead of first stage by (pipeline_parallel_size - 1). 
- next_backward_model_chunk_id = get_model_chunk_id( - backward_k - (pipeline_parallel_size - 1), forward=False - ) - if next_backward_model_chunk_id == 0: - recv_next = False - next_backward_model_chunk_id -= 1 - else: - next_backward_model_chunk_id = get_model_chunk_id( - backward_k + 1, forward=False - ) - - ( - output_tensor_grad, - bwd_wait_handles, - ) = p2p_communication.send_backward_recv_backward( - input_tensor_grad, - recv_next=recv_next, - tensor_shape=tensor_shape, - dtype=dtype, - batch_p2p_comm=batch_p2p_communication, - overlap_p2p_comm=True, - ) - if not args.model.enable_ddp and sync_grads: - model[ - backward_model_chunk_id - ].dont_wait_current_stream_for_post_all_gather = True - with ( - nullcontext() - if sync_grads - else model[backward_model_chunk_id].no_sync() - ): - if args.zero_bubble and k >= num_1f1b_microbatches: - with record_function("zero bubble 1w"): - WeightGradStore.pop() - - # Put input_tensor and output_tensor_grad in data structures in the - # right location. - if recv_prev: - input_tensors[next_forward_model_chunk_id].append(input_tensor) - if recv_next: - output_tensor_grads[next_backward_model_chunk_id].append( - output_tensor_grad - ) - model_chunk_id = get_model_chunk_id(backward_k, forward=False) - - if args.deallocate_pipeline_outputs and output_tensor is not None: - deallocate_output_tensor(output_tensor) - - # Run cooldown backward passes (flush out pipeline). - with record_function("cooldown backward"): - if overlap_p2p_communication and bwd_wait_handles is not None: - for wait_handle in bwd_wait_handles: - wait_handle.wait() - if input_tensor_grad is not None: - deallocate_output_tensor(input_tensor_grad) - - if all_warmup_microbatches: - output_tensor_grads[num_model_chunks - 1].append( - p2p_communication.recv_backward( - tensor_shape, batch_p2p_comm=batch_p2p_communication, dtype=dtype - ) - ) + with torch.profiler.record_function("cooldown backward"): for k in range(num_microbatches_remaining, total_num_microbatches): - if overlap_p2p_communication and bwd_wait_handles is not None: - for wait_handle in bwd_wait_handles: - wait_handle.wait() - # same as warmup, for last PP stage, currently received grad is - # (num_microbatch_per_round - pipeline_parallel_size) earlier than its corresponding model chunk - if mpu.is_pipeline_last_stage(ignore_virtual=True): - if microbatch_less_than_pp: - next_backward_model_chunk_id = get_model_chunk_id( - k + 1, - forward=False, - ) - else: - next_backward_model_chunk_id = get_model_chunk_id( - k + 1 + num_microbatch_per_round - pipeline_parallel_size, - forward=False, - ) - else: - next_backward_model_chunk_id = get_model_chunk_id(k + 1, forward=False) - model_chunk_id = get_model_chunk_id(k, forward=False) - if not args.model.enable_ddp and is_last_microbatch_for_model_chunk(k): - model[ - model_chunk_id - ].dont_wait_current_stream_for_post_all_gather = True - with ( - nullcontext() - if is_last_microbatch_for_model_chunk(k) - else model[model_chunk_id].no_sync() - ): - with record_function("1b"): - input_tensor_grad = backward_step_helper(k) - - recv_next = True - # for last pp stage, if it start the very last model chunk, then no need to receive - # edge case is when it is bfs, before first model chunk of first pp stage start bwd, last stage doesnt need to receive. 
- if mpu.is_pipeline_last_stage(ignore_virtual=True): - if microbatch_less_than_pp: - if k < num_microbatch_per_round - 1: - recv_next = False - else: - if k < pipeline_parallel_size - 1: - recv_next = False - elif ( - k - >= total_num_microbatches - - num_microbatch_per_round - - 1 - + pipeline_parallel_size - ): - recv_next = False - if k == (total_num_microbatches - 1): - recv_next = False - - ( - output_tensor_grad, - bwd_wait_handles, - ) = p2p_communication.send_backward_recv_backward( - input_tensor_grad, - recv_next=recv_next, - tensor_shape=tensor_shape, - dtype=dtype, - batch_p2p_comm=batch_p2p_communication, - overlap_p2p_comm=True, - ) - if recv_next: - output_tensor_grads[next_backward_model_chunk_id].append( - output_tensor_grad - ) - - with ( - nullcontext() - if is_last_microbatch_for_model_chunk(k) - else model[model_chunk_id].no_sync() - ): - with record_function("zero bubble 1w"): - WeightGradStore.pop() - while WeightGradStore.weight_grad_queue.qsize() > 0: - with record_function("zero bubble 1w"): - WeightGradStore.pop() - - # Make sure all communication is finished - torch.cuda.synchronize() - - for model_chunk_id in range(num_model_chunks): - model[model_chunk_id].dont_wait_current_stream_for_post_all_gather = False - # logger.warning(f"model_chunk: {model_chunk_id}; rank: {torch.distributed.get_rank()}") - model[model_chunk_id]._wait_for_post_backward() - - if len(mean_losses) > 0: - sum_loss_across_mb = torch.stack(mean_losses).sum() - else: - sum_loss_across_mb = torch.zeros([], dtype=torch.float32, device="cuda") - - torch.distributed.broadcast( - sum_loss_across_mb, - src=mpu.get_pipeline_model_parallel_last_rank(), - group=mpu.get_pipeline_model_parallel_group(), - ) - return sum_loss_across_mb, None + print("cooldown") diff --git a/pippy/zero_bubble.py b/pippy/zero_bubble.py new file mode 100644 index 000000000..40d71efbc --- /dev/null +++ b/pippy/zero_bubble.py @@ -0,0 +1,739 @@ +########################################################## +# Dora PP Xlformer Implementation +########################################################## + +input_tensors = [[] for _ in range(len(self._stages))] +output_tensors = [[] for _ in range(len(self._stages))] +output_tensor_grads = [[] for _ in range(len(self._stages))] +# We need to pop input, output and grad during bwd, we use this list to track real input tensor index. +popped_input_tensors = [[] for _ in range(len(self._stages))] +input_tensor_grad = None + +pipeline_parallel_size = self.pp_group_size +pipeline_parallel_rank = self._stage.stage_index + +microbatch_x = arg_mbs +microbatch_y = target_mbs +microbatch_mask = None +mask = None +if mask is not None: + microbatch_mask = mask.split(microbatch_size, dim=0) + +num_microbatches = self._n_microbatches +# microbatch_attn_bias = [ +# model[0].get_attn_bias(microbatch_x[i], cache=None) +# for i in range(num_microbatches) +# ] +microbatch_attn_bias = [ + self._stages[0].submodule.get_attn_bias(microbatch_x[i], cache=None) + for i in range(num_microbatches) +] + + +# TODO: get the model args from API directly, should modify it later +assert(microbatch_size is not None), "microbatch_size is None" +assert(model_dim is not None), "model_dim is None" + +microbatch_less_than_pp = num_microbatches < pipeline_parallel_size +num_round = max(num_microbatches // pipeline_parallel_size, 1) +assert ( + num_microbatches % num_round == 0 +), "Number of microbatches should be divisible by number of pipeline rounds." 
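+# Illustrative example (hypothetical sizes, not taken from this patch): with
+# num_microbatches = 8 and pipeline_parallel_size = 2, num_round = 4 and the
+# num_microbatch_per_round computed below is 2, i.e. a DFS-style schedule.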
+# The number of microbatches run in each round; in DFS it is pipeline_parallel_size.
+num_microbatch_per_round = num_microbatches // num_round
+
+tensor_shape = (
+    microbatch_size,
+    model_dim,
+)
+
+num_model_chunks = len(self._stages)
+total_num_microbatches = num_microbatches * num_model_chunks
+
+dtype = torch.float16  # get_torch_dtype(args.dtype)
+
+#mpu.set_virtual_pipeline_model_parallel_rank(0)
+all_warmup_microbatches = False
+
+# if not args.model.enable_ddp:
+#     for model_chunk in model:
+#         model_chunk._rebuild_full_params_recursive()
+# else:
+#     for model_chunk in model:
+#         model_chunk.zero_grad()
+
+# FSDP only
+for model_chunk in self._stages:
+    model_chunk._rebuild_full_params_recursive()
+
+
+num_warmup_microbatches = 0
+# The number of microbatches that the last pipeline stage runs before 1F1B.
+num_warmup_microbatches += (num_model_chunks - 1) * num_microbatch_per_round
+# Each rank away from the last PP stage adds 2 more warmup microbatches.
+num_warmup_microbatches += (
+    pipeline_parallel_size - pipeline_parallel_rank - 1
+) * 2
+num_warmup_microbatches = min(num_warmup_microbatches, total_num_microbatches)
+num_microbatches_remaining = total_num_microbatches - num_warmup_microbatches
+# The number of 1F1B steps for the zero-bubble schedule.
+if num_microbatches == pipeline_parallel_size:
+    num_1f1b_microbatches = pipeline_parallel_rank
+else:
+    num_1f1b_microbatches = 2 * pipeline_parallel_rank
+
+# Checkpoint the activations of partial Transformer layers in a number of micro-batches
+# within the maximum outstanding micro-batch backpropagations.
+# Micro-batches with ids less than 'num_microbatches_with_partial_activation_checkpoints'
+# checkpoint partial Transformer layers (or skip checkpointing), and
+# the rest of the micro-batches within a window of micro-batches checkpoint
+# all Transformer layers. The window of micro-batches is set by the maximum
+# outstanding backpropagations and becomes smaller at later pipeline stages.
+# Please refer to Appendix C in https://arxiv.org/pdf/2205.05198.pdf
+
+# max_outstanding_backprops = None
+# if args.num_microbatches_with_partial_activation_checkpoints is not None:
+#     max_outstanding_backprops = num_warmup_microbatches + 1
+
+p0_chunk0_batch = [0, 0]
+mean_losses = []
+
+def get_model_chunk_id(microbatch_id, forward):
+    """Helper method to get the model chunk ID given the iteration number.
+
+    Each group has num_microbatch_per_round * num_model_chunks microbatches;
+    within each group, each model chunk gets num_microbatch_per_round microbatches.
+    The backward order is the reverse of the forward order.
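+
+    Illustrative example (hypothetical sizes, not taken from this patch): with
+    num_microbatch_per_round = 4 and num_model_chunks = 2, forward microbatches
+    0-3 map to model chunk 0 and microbatches 4-7 map to chunk 1; in the
+    backward direction the mapping is reversed, so microbatches 0-3 map to
+    chunk 1 and 4-7 map to chunk 0.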
+    """
+    microbatch_id_in_group = microbatch_id % (
+        num_microbatch_per_round * num_model_chunks
+    )
+    model_chunk_id = microbatch_id_in_group // num_microbatch_per_round
+    if not forward:
+        model_chunk_id = num_model_chunks - model_chunk_id - 1
+    return model_chunk_id
+
+def get_real_microbatch_id(microbatch_id: int) -> int:
+    """Get the microbatch id for input tokens."""
+    microbatch_group_size = num_microbatch_per_round * num_model_chunks
+    microbatch_group_id = microbatch_id // microbatch_group_size
+    real_microbatch_id_in_group = (
+        microbatch_id % microbatch_group_size
+    ) % num_microbatch_per_round
+    real_microbatch_id = (
+        real_microbatch_id_in_group + microbatch_group_id * num_microbatch_per_round
+    )
+    return real_microbatch_id
+
+def is_first_microbatch_for_model_chunk(microbatch_id: int) -> bool:
+    """Check if an iteration is the first for a model chunk."""
+    microbatch_group_size = num_microbatch_per_round * num_model_chunks
+    microbatch_group_id = microbatch_id // microbatch_group_size
+    microbatch_id_in_group = microbatch_id % microbatch_group_size
+    if microbatch_group_id == 0:
+        return microbatch_id_in_group % num_microbatch_per_round == 0
+    else:
+        return False
+
+def is_last_microbatch_for_model_chunk(microbatch_id: int) -> bool:
+    """Check if an iteration is the last for a model chunk."""
+    microbatch_group_size = num_microbatch_per_round * num_model_chunks
+    num_microbatch_groups = total_num_microbatches // microbatch_group_size
+    microbatch_group_id = microbatch_id // microbatch_group_size
+    microbatch_id_in_group = microbatch_id % microbatch_group_size
+    if microbatch_group_id == num_microbatch_groups - 1:
+        return (
+            microbatch_id_in_group % num_microbatch_per_round
+            == num_microbatch_per_round - 1
+        )
+    else:
+        return False
+
+def get_input_index(microbatch_id):
+    """Get the pipeline input index for a microbatch."""
+    microbatch_group_size = num_microbatch_per_round * num_model_chunks
+    microbatch_id_in_group = microbatch_id % microbatch_group_size
+    microbatch_group_id = microbatch_id // microbatch_group_size
+    input_index = microbatch_id_in_group % num_microbatch_per_round
+    return input_index + microbatch_group_id * num_microbatch_per_round
+
+def microbatch_fwd(
+    model_chunk_id,
+    input_tensor,
+    microbatch_tokens,
+    y,
+    state,
+    mask,
+    mean_losses,
+    is_first_microbatch=False,
+    recompute_attn=None,
+    recompute_fc1_fc3=None,
+    attn_bias=None,
+):
+    if input_tensor is None:
+        assert self.rank == 0  # first stage
+    else:
+        assert self.rank != 0 or model_chunk_id != 0  # not the first virtual stage
+
+    output, _ = self._stages[model_chunk_id](
+        microbatch_tokens,
+        pipeline_parallel_input_tensor=input_tensor,
+        is_first_microbatch=is_first_microbatch,
+        precomputed_attn_bias=attn_bias,
+    )
+
+    if self.rank == self.pp_group_size - 1:
+        if loss_fn is not None:
+            loss = loss_fn(
+                output,
+                y,
+                mask,
+            )
+            output = loss.mean() / num_microbatches
+        else:
+            if args.model.loss_parallel:
+                tok_loss = state.scale * vocab_parallel_cross_entropy(
+                    partial_logits=output,
+                    target=y,
+                    z_loss_multiplier=args.z_loss_multiplier,
+                )
+            else:
+                tok_loss = state.scale * F.cross_entropy(
+                    output.flatten(0, 1), y.flatten(0, 1), reduction="none"
+                )
+            if mask is None:
+                output = tok_loss.mean() / num_microbatches
+            else:
+                mask = mask.flatten(0, 1)
+                tok_loss = tok_loss * mask
+                output = tok_loss.sum() / (mask.sum() + 1e-6) / num_microbatches
+        mean_losses.append(output)
+    p0_chunk0_batch[1] += 1
+    return output
+
+def deallocate_output_tensor(out):
+    """Deallocate the output tensor's '.data' field.
+    This method should be called right after the output tensor has been
+    sent to the next pipeline stage. At this point, the output tensor is
+    only useful for its '.grad_fn' field, and not its '.data'.
+    """
+    assert isinstance(out, torch.Tensor), (
+        "expected Tensor, found %s." % type(out).__name__
+    )
+    assert out._base is None, "counter-productive to free a view of another tensor."
+    out.data.storage().resize_(0)
+
+def custom_backward(output, grad_output):
+    """Custom backward that calls the C++ autograd engine directly.
+
+    PyTorch's 'backward' checks that the output and the grad have the same
+    shape, so we call the C++ autograd engine instead of
+    torch.autograd.backward so that the 'deallocate_output_tensor'
+    optimization can work.
+    """
+
+    assert (
+        output.storage().size() == 0
+    ), "output should be pseudo-'freed' in schedule, to optimize memory"
+    assert isinstance(output, torch.Tensor), (
+        "output == '%s'." % type(output).__name__
+    )
+    assert isinstance(grad_output, (torch.Tensor, type(None))), (
+        "grad_output == '%s'." % type(grad_output).__name__
+    )
+
+    # Handle scalar output
+    if grad_output is None:
+        assert output.numel() == 1, "implicit grad requires scalar output."
+        grad_output = torch.ones_like(
+            output,
+            memory_format=torch.preserve_format,
+        )
+
+    # Call the C++ engine [ see torch/csrc/autograd/python_engine.cpp ]
+    Variable._execution_engine.run_backward(
+        tensors=(output,),
+        grad_tensors=(grad_output,),
+        keep_graph=False,
+        create_graph=False,
+        inputs=tuple(),
+        allow_unreachable=True,
+        accumulate_grad=True,
+    )
+
+def microbatch_bwd(input_tensor, output_tensor, output_tensor_grad):
+    if input_tensor is not None:
+        input_tensor.retain_grad()
+    if output_tensor_grad is None:
+        output_tensor.backward()
+    else:
+        # if args.deallocate_pipeline_outputs:
+        #     custom_backward(output_tensor, output_tensor_grad)
+        # else:
+        torch.autograd.backward(output_tensor, grad_tensors=output_tensor_grad)
+    if input_tensor is not None:
+        return input_tensor.grad
+    return None
+
+def forward_step_helper(
+    microbatch_id, p0_chunk0_batch, recompute_attn=None, recompute_fc1_fc3=None
+):
+    """Helper method to run a forward step with the model split into chunks;
+    get_model_chunk_id() selects the chunk for this microbatch."""
+    model_chunk_id = get_model_chunk_id(microbatch_id, forward=True)
+    #mpu.set_virtual_pipeline_model_parallel_rank(model_chunk_id)
+
+    is_first_microbatch = is_first_microbatch_for_model_chunk(microbatch_id)
+
+    # forward step
+    if self.rank == 0:
+        # Make sure each model chunk has the same number of inputs as num_microbatches.
+        # For other pipeline stages, the tensor received from the previous pipeline
+        # stage is appended instead.
+        if len(input_tensors[model_chunk_id]) == len(
+            output_tensors[model_chunk_id]
+        ):
+            input_tensors[model_chunk_id].append(None)
+
+    # input_tensors holds all the inputs for each model chunk.
+    # If this is not the first PP stage (including virtual), we use the very last input in input_tensors.
+    # On the first PP stage, if num_microbatch_per_round is larger than the pipeline size,
+    # the input is received num_microbatch_per_round - pipeline_parallel_size steps earlier than it is used.
+    # So we pick the input by microbatch index: first figure out which microbatch of this
+    # model chunk we are running, then subtract the number of popped input_tensors.
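+    # Illustrative example (hypothetical sizes, not taken from this patch): with
+    # num_microbatch_per_round = 4 and num_model_chunks = 2, microbatch_id 9
+    # falls into group 1 with in-group index 1, so get_input_index(9) returns
+    # 1 + 1 * 4 = 5 before the popped inputs are subtracted.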
+ if self.rank == 0: + input_index = get_input_index(microbatch_id) + input_index -= len(popped_input_tensors[model_chunk_id]) + else: + input_index = -1 + input_tensor = input_tensors[model_chunk_id][input_index] + real_microbatch_id = get_real_microbatch_id(microbatch_id) + output_tensor = microbatch_fwd( + model_chunk_id, + input_tensor, + microbatch_x[real_microbatch_id], + microbatch_y[p0_chunk0_batch[1]], + state, + ( + microbatch_mask[real_microbatch_id] + if microbatch_mask is not None + else None + ), + mean_losses, + is_first_microbatch=is_first_microbatch, + recompute_attn=recompute_attn, + recompute_fc1_fc3=recompute_fc1_fc3, + attn_bias=microbatch_attn_bias[real_microbatch_id], + ) + output_tensors[model_chunk_id].append(output_tensor) + return output_tensor + +def backward_step_helper(microbatch_id): + """Helper method to run backward step with model split into chunks + (run set_virtual_pipeline_model_parallel_rank() before calling + backward_step()).""" + model_chunk_id = get_model_chunk_id(microbatch_id, forward=False) + #mpu.set_virtual_pipeline_model_parallel_rank(model_chunk_id) + + if self.rank == self.pp_group_size -1: + if len(output_tensor_grads[model_chunk_id]) == 0: + output_tensor_grads[model_chunk_id].append(None) + input_tensor = input_tensors[model_chunk_id].pop(0) + popped_input_tensors[model_chunk_id].append(None) + output_tensor = output_tensors[model_chunk_id].pop(0) + output_tensor_grad = output_tensor_grads[model_chunk_id].pop(0) + + input_tensor_grad = microbatch_bwd( + input_tensor, output_tensor, output_tensor_grad + ) + # Reuse the deallocate_output_tensor function to release input_tensor + if input_tensor is not None: + deallocate_output_tensor(input_tensor) + + return input_tensor_grad + +#mpu.set_virtual_pipeline_model_parallel_rank(0) +with record_function("warmup forward passes p2p comm"): + input_tensors[0].append( + p2p_communication.recv_forward( + tensor_shape, dtype, batch_p2p_comm=batch_p2p_communication + ) + ) + +with record_function("warmup forward passes"): + fwd_wait_handles = None + bwd_wait_handles = None + for k in range(num_warmup_microbatches): + if fwd_wait_handles is not None: + for req in fwd_wait_handles: + req.wait() + + # Decide to checkpoint all layers' activations of the current micro-batch + # if max_outstanding_backprops is not None: + # checkpoint_activations_microbatch = ( + # k % max_outstanding_backprops + # >= args.num_microbatches_with_partial_activation_checkpoints + # ) + # else: + checkpoint_activations_microbatch = None + + with record_function("1f"): + output_tensor = forward_step_helper( + k, + p0_chunk0_batch, + recompute_attn=checkpoint_activations_microbatch, + recompute_fc1_fc3=checkpoint_activations_microbatch + ) + + # Determine the model chunk that received input from this iteration belongs to. + # On the first PP stage, if num_microbatch_per_round is larger than pipeline stage, + # this means we will receive the input num_microbatch_per_round - pipeline_parallel_size earlier than it will be used by its model chunk. + # so to determine the true model chunk, we need to add num_microbatch_per_round - pipeline_parallel_size. 
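+        # Illustrative example (hypothetical sizes, not taken from this patch): with
+        # pipeline_parallel_size = 2 and num_microbatch_per_round = 4, the first rank
+        # looks up the owning chunk of the tensor received at warmup step k as
+        # get_model_chunk_id(k + 1 + 4 - 2), i.e. two microbatches ahead of the next step.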
+ next_forward_model_chunk_id = None + if self.rank == 0: + if microbatch_less_than_pp: + next_forward_model_chunk_id = get_model_chunk_id( + k + 1, + forward=True, + ) + else: + next_forward_model_chunk_id = get_model_chunk_id( + k + 1 + num_microbatch_per_round - pipeline_parallel_size, + forward=True, + ) + else: + next_forward_model_chunk_id = get_model_chunk_id(k + 1, forward=True) + + + recv_prev = True + # For first PP rank, there are two cases that to not receive: + # (1) Before first model chunk of last PP stage start to run, there is nothing to receive. + # (2) when last model chunk of last PP stage start running, last PP rank wont send input anymore. + if self.rank == 0: + if microbatch_less_than_pp: + if k < num_microbatch_per_round - 1: + recv_prev = False + else: + if k < pipeline_parallel_size - 1: + recv_prev = False + elif ( + k + >= (num_model_chunks - 1) * num_microbatch_per_round + + pipeline_parallel_size + - 1 + ): + recv_prev = False + if k == (total_num_microbatches - 1): + recv_prev = False + + # Don't send tensor downstream if on last stage. + if self.rank == self.pp_group_size - 1: + output_tensor = None + + # Send and receive tensors as appropriate (send tensors computed + # in this iteration; receive tensors for next iteration + + ( + input_tensor, + fwd_wait_handles, + ) = p2p_communication.send_forward_recv_forward( + output_tensor, + recv_prev=recv_prev, + tensor_shape=tensor_shape, + dtype=dtype, + batch_p2p_comm=batch_p2p_communication, + overlap_p2p_comm=True, + ) + + if k == (num_warmup_microbatches - 1) and not all_warmup_microbatches: + input_tensor_grad = None + recv_next = True + if self.rank == self.pp_group_size - 1: + recv_next = False + + ( + output_tensor_grad, + bwd_wait_handles, + ) = p2p_communication.send_backward_recv_backward( + input_tensor_grad, + recv_next=recv_next, + tensor_shape=tensor_shape, + batch_p2p_comm=batch_p2p_communication, + dtype=dtype, + overlap_p2p_comm=True, + ) + + output_tensor_grads[num_model_chunks - 1].append(output_tensor_grad) + # make sure number of input tensor is same as number of microbatch + if recv_prev: + input_tensors[next_forward_model_chunk_id].append(input_tensor) + + if self.deallocate_pipeline_outputs and output_tensor is not None: + deallocate_output_tensor(output_tensor) + +# Run 1F1B in steady state. +with record_function("forward 1F1B steady"): + for k in range(num_microbatches_remaining): + # Forward pass. + forward_k = k + num_warmup_microbatches + sync_grads = is_last_microbatch_for_model_chunk(k) + + # Decide to checkpoint all layers' activations of the current micro-batch + # if max_outstanding_backprops is not None: + # checkpoint_activations_microbatch = ( + # forward_k % max_outstanding_backprops + # >= args.num_microbatches_with_partial_activation_checkpoints + # ) + # else: + checkpoint_activations_microbatch = None + + if fwd_wait_handles is not None: + for req in fwd_wait_handles: + req.wait() + + if self.deallocate_pipeline_outputs and output_tensor is not None: + deallocate_output_tensor(output_tensor) + with record_function("1f"): + output_tensor = forward_step_helper( + forward_k, + p0_chunk0_batch, + recompute_attn=checkpoint_activations_microbatch, + recompute_fc1_fc3=checkpoint_activations_microbatch, + ) + + # Determine if current stage has anything to send in either direction, + # otherwise set tensor to None. 
+ forward_model_chunk_id = get_model_chunk_id(forward_k, forward=True) + #mpu.set_virtual_pipeline_model_parallel_rank(forward_model_chunk_id) + + # Last virtual stage no activation tensor to send + if self.rank == self.pp_group_size - 1: + output_tensor = None + + # Determine if peers are sending, and where in data structure to put + # received tensors. + recv_prev = True + if self.rank == 0: + # First stage is ahead of last stage by (pipeline_parallel_size - 1). + next_forward_model_chunk_id = get_model_chunk_id( + forward_k - (pipeline_parallel_size - 1), forward=True + ) + if next_forward_model_chunk_id == (num_model_chunks - 1): + recv_prev = False + next_forward_model_chunk_id += 1 + else: + next_forward_model_chunk_id = get_model_chunk_id( + forward_k + 1, forward=True + ) + + # If last iteration, don't receive; we already received one extra + # before the start of the for loop. + if k == (num_microbatches_remaining - 1): + recv_prev = False + + # Send activation tensor to the next stage and receive activation tensor from the + # previous stage + ( + input_tensor, + fwd_wait_handles, + ) = p2p_communication.send_forward_recv_forward( + output_tensor, + recv_prev=recv_prev, + tensor_shape=tensor_shape, + dtype=dtype, + batch_p2p_comm=batch_p2p_communication, + overlap_p2p_comm=True, + ) + + if bwd_wait_handles is not None: + for req in bwd_wait_handles: + req.wait() + + if input_tensor_grad is not None: + deallocate_output_tensor(input_tensor_grad) + + # Backward pass. + backward_k = k + backward_model_chunk_id = get_model_chunk_id(backward_k, forward=False) + + if not args.model.enable_ddp and sync_grads: + model[ + backward_model_chunk_id + ].dont_wait_current_stream_for_post_all_gather = True + with ( + nullcontext() + if sync_grads + else model[backward_model_chunk_id].no_sync() + ): + with record_function("1b"): + input_tensor_grad = backward_step_helper(backward_k) + + mpu.set_virtual_pipeline_model_parallel_rank(backward_model_chunk_id) + + # First virtual stage no activation gradient tensor to send + if mpu.is_pipeline_first_stage(): + input_tensor_grad = None + + # Determine if the current virtual stage has an activation gradient tensor to receive + recv_next = True + if mpu.is_pipeline_last_stage(ignore_virtual=True): + # Last stage is ahead of first stage by (pipeline_parallel_size - 1). + next_backward_model_chunk_id = get_model_chunk_id( + backward_k - (pipeline_parallel_size - 1), forward=False + ) + if next_backward_model_chunk_id == 0: + recv_next = False + next_backward_model_chunk_id -= 1 + else: + next_backward_model_chunk_id = get_model_chunk_id( + backward_k + 1, forward=False + ) + + ( + output_tensor_grad, + bwd_wait_handles, + ) = p2p_communication.send_backward_recv_backward( + input_tensor_grad, + recv_next=recv_next, + tensor_shape=tensor_shape, + dtype=dtype, + batch_p2p_comm=batch_p2p_communication, + overlap_p2p_comm=True, + ) + if not args.model.enable_ddp and sync_grads: + model[ + backward_model_chunk_id + ].dont_wait_current_stream_for_post_all_gather = True + with ( + nullcontext() + if sync_grads + else model[backward_model_chunk_id].no_sync() + ): + if args.zero_bubble and k >= num_1f1b_microbatches: + with record_function("zero bubble 1w"): + WeightGradStore.pop() + + # Put input_tensor and output_tensor_grad in data structures in the + # right location. 
+ if recv_prev: + input_tensors[next_forward_model_chunk_id].append(input_tensor) + if recv_next: + output_tensor_grads[next_backward_model_chunk_id].append( + output_tensor_grad + ) + model_chunk_id = get_model_chunk_id(backward_k, forward=False) + +if args.deallocate_pipeline_outputs and output_tensor is not None: + deallocate_output_tensor(output_tensor) + +# Run cooldown backward passes (flush out pipeline). +with record_function("cooldown backward"): + if overlap_p2p_communication and bwd_wait_handles is not None: + for wait_handle in bwd_wait_handles: + wait_handle.wait() + if input_tensor_grad is not None: + deallocate_output_tensor(input_tensor_grad) + + if all_warmup_microbatches: + output_tensor_grads[num_model_chunks - 1].append( + p2p_communication.recv_backward( + tensor_shape, batch_p2p_comm=batch_p2p_communication, dtype=dtype + ) + ) + for k in range(num_microbatches_remaining, total_num_microbatches): + if overlap_p2p_communication and bwd_wait_handles is not None: + for wait_handle in bwd_wait_handles: + wait_handle.wait() + # same as warmup, for last PP stage, currently received grad is + # (num_microbatch_per_round - pipeline_parallel_size) earlier than its corresponding model chunk + if mpu.is_pipeline_last_stage(ignore_virtual=True): + if microbatch_less_than_pp: + next_backward_model_chunk_id = get_model_chunk_id( + k + 1, + forward=False, + ) + else: + next_backward_model_chunk_id = get_model_chunk_id( + k + 1 + num_microbatch_per_round - pipeline_parallel_size, + forward=False, + ) + else: + next_backward_model_chunk_id = get_model_chunk_id(k + 1, forward=False) + model_chunk_id = get_model_chunk_id(k, forward=False) + if not args.model.enable_ddp and is_last_microbatch_for_model_chunk(k): + model[ + model_chunk_id + ].dont_wait_current_stream_for_post_all_gather = True + with ( + nullcontext() + if is_last_microbatch_for_model_chunk(k) + else model[model_chunk_id].no_sync() + ): + with record_function("1b"): + input_tensor_grad = backward_step_helper(k) + + recv_next = True + # for last pp stage, if it start the very last model chunk, then no need to receive + # edge case is when it is bfs, before first model chunk of first pp stage start bwd, last stage doesnt need to receive. 
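+        # Illustrative example (hypothetical sizes, not taken from this patch): with
+        # pipeline_parallel_size = 2, num_microbatch_per_round = 4 and
+        # num_model_chunks = 2 (so total_num_microbatches = 16), the last rank
+        # stops receiving once k >= 16 - 4 - 1 + 2 = 13, matching the check below.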
+ if mpu.is_pipeline_last_stage(ignore_virtual=True): + if microbatch_less_than_pp: + if k < num_microbatch_per_round - 1: + recv_next = False + else: + if k < pipeline_parallel_size - 1: + recv_next = False + elif ( + k + >= total_num_microbatches + - num_microbatch_per_round + - 1 + + pipeline_parallel_size + ): + recv_next = False + if k == (total_num_microbatches - 1): + recv_next = False + + ( + output_tensor_grad, + bwd_wait_handles, + ) = p2p_communication.send_backward_recv_backward( + input_tensor_grad, + recv_next=recv_next, + tensor_shape=tensor_shape, + dtype=dtype, + batch_p2p_comm=batch_p2p_communication, + overlap_p2p_comm=True, + ) + if recv_next: + output_tensor_grads[next_backward_model_chunk_id].append( + output_tensor_grad + ) + + with ( + nullcontext() + if is_last_microbatch_for_model_chunk(k) + else model[model_chunk_id].no_sync() + ): + with record_function("zero bubble 1w"): + WeightGradStore.pop() +while WeightGradStore.weight_grad_queue.qsize() > 0: + with record_function("zero bubble 1w"): + WeightGradStore.pop() + + # Make sure all communication is finished + torch.cuda.synchronize() + +for model_chunk_id in range(num_model_chunks): + model[model_chunk_id].dont_wait_current_stream_for_post_all_gather = False + # logger.warning(f"model_chunk: {model_chunk_id}; rank: {torch.distributed.get_rank()}") + model[model_chunk_id]._wait_for_post_backward() + +if len(mean_losses) > 0: + sum_loss_across_mb = torch.stack(mean_losses).sum() +else: + sum_loss_across_mb = torch.zeros([], dtype=torch.float32, device="cuda") + +torch.distributed.broadcast( + sum_loss_across_mb, + src=mpu.get_pipeline_model_parallel_last_rank(), + group=mpu.get_pipeline_model_parallel_group(), +) +return sum_loss_across_mb, None diff --git a/test/test_pipeline_schedule_e2e.py b/test/test_pipeline_schedule_e2e.py index 93f4fffa5..0b8620954 100644 --- a/test/test_pipeline_schedule_e2e.py +++ b/test/test_pipeline_schedule_e2e.py @@ -167,8 +167,8 @@ def rank_print(msg): module_list = torch.nn.ModuleList( modules=[model for i in range(world_size)] ) - microbatch_size = 8 - global_batch_size = 64 + microbatch_size = 1 + global_batch_size = 8 assert global_batch_size % microbatch_size == 0 n_microbatches = int(global_batch_size / microbatch_size) @@ -203,7 +203,7 @@ def rank_print(msg): # pipe = pipeline(model, n_microbatches, example_args=(input_args,)) # stage = PipelineStage(pipe, rank, device) - print(f"{[sm.stage_index for sm in stage_model_looped]}") + print(f"Stage: {rank} {[sm.stage_index for sm in stage_model_looped]}") x_cuda_empty = torch.empty_like(x, device="cuda") @@ -237,8 +237,7 @@ def rank_print(msg): ) elif schedule == "doraPP": my_schedule = ScheduleDoraPP( - stage_model_looped, n_microbatches, loss_fn, microbatch_size=microbatch_size, model_dim=input_dim, - ) + stage_model_looped, n_microbatches, loss_fn) if _run_profiler: logger.info(f"====== Rank {rank} profile ======")