diff --git a/megatron/core/pipeline_parallel/schedules.py b/megatron/core/pipeline_parallel/schedules.py
index 14fc6041574..3832b7f6514 100644
--- a/megatron/core/pipeline_parallel/schedules.py
+++ b/megatron/core/pipeline_parallel/schedules.py
@@ -104,10 +104,20 @@ def forward_step(data_iterator, model):
 
         num_microbatches (int, required): The number of microbatches to go through
 
-        seq_length (int, required): Sequence length of the current global batch. If this is a dual-stack
-            transformer, this is the encoder's sequence length. This is ignored if variable_seq_lengths
-            in the config is True. Otherwise, each microbatch in the current global batch size must use
-            this sequence length.
+        seq_length (int, required): Sequence length of the current global batch. If this is a
+            dual-stack transformer, this is the encoder's sequence length. Its effect depends on
+            which schedule is selected:
+
+            * ``forward_backward_no_pipelining`` (pp_size=1): ``seq_length`` is unused.
+            * ``forward_backward_pipelining_without_interleaving`` (pp_size>1, vp_size=None):
+              when ``config.variable_seq_lengths=True``, P2P activation tensor shapes are
+              exchanged dynamically and ``seq_length`` is ignored; otherwise every microbatch
+              in the current global batch must use exactly this ``seq_length``.
+            * ``forward_backward_pipelining_with_interleaving`` (pp_size>1, vp_size>1):
+              ``seq_length`` is always used to size the P2P activation buffer as
+              ``[seq_length, micro_batch_size, hidden_size]``. When
+              ``config.variable_seq_lengths=True`` it acts as the per-step maximum sequence
+              length used for P2P; actual microbatches may be shorter, but must not exceed it.
 
         micro_batch_size (int, required): The number of sequences in a microbatch.
 
@@ -911,6 +921,12 @@ def forward_backward_pipelining_with_interleaving(
     """Run interleaved 1F1B schedule (model split into model chunks), with
     communication between pipeline stages as needed.
 
+    ``seq_length`` is required and is always used here to size the P2P activation buffer as
+    ``[seq_length, micro_batch_size, hidden_size]`` (then divided by ``cp_group.size()`` and,
+    when sequence parallelism is enabled, by ``tp_group.size()``). When
+    ``config.variable_seq_lengths=True`` it acts as the per-step maximum sequence length;
+    actual microbatches may be shorter, but must not exceed it.
+
     Returns dictionary with losses if the last stage, empty dict otherwise."""
 
     # Convention used in this function:
@@ -2051,8 +2067,13 @@ def forward_backward_pipelining_without_interleaving(
     ] = None,
     force_all_reduce: Optional[bool] = False,
 ):
-    """Run non-interleaved 1F1B schedule, with communication between pipeline
-    stages. Returns dictionary with losses if the last stage, empty dict otherwise."""
+    """Run non-interleaved 1F1B schedule, with communication between pipeline stages.
+
+    When ``config.variable_seq_lengths=True``, P2P activation tensor shapes are exchanged
+    dynamically (see ``get_tensor_shapes``) and ``seq_length`` is ignored. Otherwise every
+    microbatch in the current global batch must use exactly this ``seq_length``.
+
+    Returns dictionary with losses if the last stage, empty dict otherwise."""
 
     if isinstance(model, list):
         assert (
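A small, self-contained sketch of the arithmetic the new docstrings describe may help reviewers sanity-check the wording. `p2p_buffer_shape` and `pick_seq_length` below are hypothetical helpers written only for illustration (they are not Megatron-Core API): the divisions follow the note added to `forward_backward_pipelining_with_interleaving`, and the `max(...)` shows how a caller could choose `seq_length` as the per-step maximum when `config.variable_seq_lengths=True`.

```python
# Illustration only -- these helpers are NOT part of Megatron-Core; they just
# restate the rules described in the docstring changes above.

def p2p_buffer_shape(
    seq_length: int,
    micro_batch_size: int,
    hidden_size: int,
    cp_size: int = 1,
    tp_size: int = 1,
    sequence_parallel: bool = False,
) -> tuple[int, int, int]:
    """Shape of the interleaved schedule's P2P activation tensor:
    [seq_length, micro_batch_size, hidden_size], with the sequence dimension
    divided by the context-parallel size and, when sequence parallelism is
    enabled, also by the tensor-parallel size."""
    seq = seq_length // cp_size
    if sequence_parallel:
        seq //= tp_size
    return (seq, micro_batch_size, hidden_size)


def pick_seq_length(microbatch_seq_lengths: list[int]) -> int:
    """With config.variable_seq_lengths=True and the interleaved schedule, pass the
    per-step maximum; shorter microbatches are allowed, longer ones are not."""
    return max(microbatch_seq_lengths)


if __name__ == "__main__":
    # e.g. seq_length=4096, micro_batch_size=2, hidden_size=4096, cp=2, tp=4,
    # sequence parallelism on -> each P2P tensor is [512, 2, 4096].
    print(p2p_buffer_shape(4096, 2, 4096, cp_size=2, tp_size=4, sequence_parallel=True))
    # Microbatches of length 3500, 4096, 2048 in this step -> pass seq_length=4096.
    print(pick_seq_length([3500, 4096, 2048]))
```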