33 changes: 27 additions & 6 deletions megatron/core/pipeline_parallel/schedules.py
@@ -104,10 +104,20 @@ def forward_step(data_iterator, model):
num_microbatches (int, required):
The number of microbatches to go through

-seq_length (int, required): Sequence length of the current global batch. If this is a dual-stack
-transformer, this is the encoder's sequence length. This is ignored if variable_seq_lengths
-in the config is True. Otherwise, each microbatch in the current global batch size must use
-this sequence length.
+seq_length (int, required): Sequence length of the current global batch. If this is a
+dual-stack transformer, this is the encoder's sequence length. Its effect depends on
+which schedule is selected (see the sketch below):
+
+* ``forward_backward_no_pipelining`` (pp_size=1): ``seq_length`` is unused.
+* ``forward_backward_pipelining_without_interleaving`` (pp_size>1, vp_size=None):
+  when ``config.variable_seq_lengths=True``, P2P activation tensor shapes are
+  exchanged dynamically and ``seq_length`` is ignored; otherwise every microbatch
+  in the current global batch must use exactly this ``seq_length``.
+* ``forward_backward_pipelining_with_interleaving`` (pp_size>1, vp_size>1):
+  ``seq_length`` is always used to size the P2P activation buffer as
+  ``[seq_length, micro_batch_size, hidden_size]``. When
+  ``config.variable_seq_lengths=True`` it acts as the per-step maximum sequence
+  length used for P2P; actual microbatches may be shorter, but must not exceed it.

micro_batch_size (int, required): The number of sequences in a microbatch.
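
To make the mapping above concrete, here is a minimal sketch of the schedule selection,
keyed on the pipeline sizes named in the list. The dispatch itself lives in
``get_forward_backward_func`` in this file; ``select_schedule`` and its exact guard
conditions are illustrative assumptions, not the actual Megatron-LM code.

    from typing import Optional

    def select_schedule(pp_size: int, vp_size: Optional[int]) -> str:
        """Illustrative mapping from parallel sizes to the schedule that runs,
        annotated with how each schedule treats seq_length."""
        if pp_size == 1:
            # Single pipeline stage: no P2P activation buffers, seq_length unused.
            return "forward_backward_no_pipelining"
        if vp_size is not None and vp_size > 1:
            # Interleaved: seq_length always sizes the P2P activation buffer.
            return "forward_backward_pipelining_with_interleaving"
        # Non-interleaved: seq_length is enforced per microbatch unless
        # config.variable_seq_lengths=True, in which case shapes are exchanged.
        return "forward_backward_pipelining_without_interleaving"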

@@ -911,6 +921,12 @@ def forward_backward_pipelining_with_interleaving(
"""Run interleaved 1F1B schedule (model split into model chunks), with
communication between pipeline stages as needed.

+``seq_length`` is required and is always used here to size the P2P activation buffer as
+``[seq_length, micro_batch_size, hidden_size]`` (then divided by ``cp_group.size()`` and,
+when sequence parallelism is enabled, by ``tp_group.size()``). When
+``config.variable_seq_lengths=True`` it acts as the per-step maximum sequence length;
+actual microbatches may be shorter, but must not exceed it.
+
Returns dictionary with losses if the last stage, empty dict otherwise."""
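
A hedged sketch of the buffer-shape rule above. ``interleaved_p2p_buffer_shape`` is a
hypothetical helper written for illustration (it does not exist in this module), and it
assumes integer-divisible sizes.

    def interleaved_p2p_buffer_shape(
        seq_length: int,
        micro_batch_size: int,
        hidden_size: int,
        cp_size: int = 1,
        tp_size: int = 1,
        sequence_parallel: bool = False,
    ) -> tuple:
        """Fixed P2P buffer shape for the interleaved schedule: the sequence
        dimension is split across the context-parallel group and, when
        sequence parallelism is enabled, across the tensor-parallel group."""
        s = seq_length // cp_size
        if sequence_parallel:
            s //= tp_size
        return (s, micro_batch_size, hidden_size)

    # Example: seq_length=4096, cp_size=2, tp_size=4 with sequence parallelism
    # yields a per-rank buffer of [512, micro_batch_size, hidden_size].
    assert interleaved_p2p_buffer_shape(4096, 1, 8192, 2, 4, True) == (512, 1, 8192)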

# Convention used in this function:
@@ -2051,8 +2067,13 @@ def forward_backward_pipelining_without_interleaving(
] = None,
force_all_reduce: Optional[bool] = False,
):
"""Run non-interleaved 1F1B schedule, with communication between pipeline
stages. Returns dictionary with losses if the last stage, empty dict otherwise."""
"""Run non-interleaved 1F1B schedule, with communication between pipeline stages.

When ``config.variable_seq_lengths=True``, P2P activation tensor shapes are exchanged
dynamically (see ``get_tensor_shapes``) and ``seq_length`` is ignored. Otherwise every
microbatch in the current global batch must use exactly this ``seq_length``.

Returns dictionary with losses if the last stage, empty dict otherwise."""
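
As a minimal sketch of that rule, ``resolve_p2p_shape`` below is a hypothetical stand-in
for the behavior of ``get_tensor_shapes``, not its actual signature; in the real schedule
the per-microbatch shape is exchanged between adjacent pipeline stages at runtime.

    def resolve_p2p_shape(
        actual_seq_length: int,
        seq_length: int,
        micro_batch_size: int,
        hidden_size: int,
        variable_seq_lengths: bool,
    ) -> tuple:
        """P2P activation shape for the non-interleaved schedule."""
        if variable_seq_lengths:
            # Dynamic shapes: each microbatch's own length is used (and, in the
            # real code, communicated to the neighbor stage before the tensor).
            return (actual_seq_length, micro_batch_size, hidden_size)
        # Fixed shapes: the configured seq_length is mandatory for every microbatch.
        assert actual_seq_length == seq_length, (
            "every microbatch must use exactly the configured seq_length"
        )
        return (seq_length, micro_batch_size, hidden_size)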

if isinstance(model, list):
assert (