From 82acdca574735f4aab49ba43c8fd2a775a813e02 Mon Sep 17 00:00:00 2001 From: Will Constable Date: Wed, 27 Aug 2025 20:37:52 -0700 Subject: [PATCH] Add pipeline parallel example --- examples/example_pipeline.py | 750 +++++++++++++++++++++++++++++++++++ 1 file changed, 750 insertions(+) create mode 100644 examples/example_pipeline.py diff --git a/examples/example_pipeline.py b/examples/example_pipeline.py new file mode 100644 index 00000000..944de9d1 --- /dev/null +++ b/examples/example_pipeline.py @@ -0,0 +1,750 @@ +# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. +# +# This source code is licensed under the BSD license found in the +# LICENSE file in the root directory of this source tree. + +import logging +import os +import time +from dataclasses import dataclass +from typing import Callable, ClassVar + +import torch +import torch.nn.functional as F +from torch import nn +from torch.distributed.fsdp import MixedPrecisionPolicy +from torch.distributed.pipelining.schedules import ( + PipelineScheduleMulti, + _PipelineSchedule, + get_schedule_class, +) +from torch.distributed.pipelining.stage import PipelineStage +from torch.distributed.tensor.placement_types import Replicate, Shard +from torch.nn.attention import SDPBackend, sdpa_kernel + +from autoparallel.api import AutoParallel + +logger = logging.getLogger(__name__) + + +def has_cuda_capability(major: int, minor: int) -> bool: + return torch.cuda.is_available() and torch.cuda.get_device_capability() >= ( + major, + minor, + ) + + +# Pipeline helpers copied from torchtitan + + +def build_pipeline_schedule( + stages: list[PipelineStage], + loss_fn: Callable, + pipeline_parallel_schedule: str, + microbatch_size: int, + local_batch_size: int, + pipeline_parallel_degree: int, +) -> _PipelineSchedule: + """Builds a pipeline schedule for the given configuration and stages.""" + schedule_class = get_schedule_class(pipeline_parallel_schedule) + + looped_schedule = issubclass(schedule_class, PipelineScheduleMulti) + # validate that the batch size is divisible by the microbatch_size otherwise we'll hang or error during training + if local_batch_size % microbatch_size != 0: + raise ValueError( + f"Batch size {local_batch_size} must be divisible by {microbatch_size=}. " + ) + n_microbatches = local_batch_size // microbatch_size + # We expect that the number of local stages (`len(stages)`) is the same across all ranks + num_total_stages = pipeline_parallel_degree * len(stages) + if n_microbatches < num_total_stages: + logger.warning( + f"Number of microbatches ({n_microbatches}) is less than the total number " + f"of stages ({num_total_stages}) which may result in a bubble in the pipeline." + ) + + schedule = schedule_class( + stages if looped_schedule else stages[0], + n_microbatches=n_microbatches, + loss_fn=loss_fn, + ) + logger.info( + f"Using pipeline schedule {pipeline_parallel_schedule} " + f"with {n_microbatches} microbatches and {num_total_stages} stages." + ) + return schedule + + +# / pipeline helpers + + +class ScaledDotProductAttention(torch.nn.Module): + backends: ClassVar[list[SDPBackend]] = [] + + def __init__(self, attn_mask_type: str) -> None: + super().__init__() + if attn_mask_type != "causal": + raise ValueError( + "TorchTitan with SDPA currently only supports causal mask." 
+ ) + + ScaledDotProductAttention._init_backend() + + @classmethod + def _init_backend(cls) -> None: + if cls.backends: + return + + # Add CuDNN on B200 w/ highest priority + cls.backends = [ + SDPBackend.FLASH_ATTENTION, + SDPBackend.EFFICIENT_ATTENTION, + SDPBackend.MATH, + ] + if has_cuda_capability(10, 0): + cls.backends.insert(0, SDPBackend.CUDNN_ATTENTION) + + def forward( + self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor + ) -> torch.Tensor: + assert self.backends, "SDPA Backends should not be empty." + with sdpa_kernel(self.backends, set_priority=True): + return F.scaled_dot_product_attention(q, k, v, is_causal=True) + + +def build_attention( + use_flex_attn: bool, attn_mask_type: str, fixed_block_size: int | None = None +): + if use_flex_attn: + raise NotImplementedError() + # return FlexAttention(attn_mask_type, fixed_block_size) + else: + if fixed_block_size is not None: + raise ValueError( + "TorchTitan with SDPA currently does not support fixed_block_size." + ) + if attn_mask_type != "causal": + raise ValueError( + "TorchTitan with SDPA currently only supports causal mask." + ) + return ScaledDotProductAttention(attn_mask_type) + + +@dataclass +class TransformerModelArgs: + dim: int = 4096 + n_layers: int = 32 + n_heads: int = 32 + n_kv_heads: int | None = None + vocab_size: int = 64000 # -1 # defined later by tokenizer + multiple_of: int = 256 # make SwiGLU hidden layer size multiple of large power of 2 + ffn_dim_multiplier: float | None = None + norm_eps: float = 1e-5 + rope_theta: float = 10000 + + max_seq_len: int = 2048 + # If `True`, then each transformer block init uses its layer ID, and if + # `False`, each uses the total number of transformer blocks + depth_init: bool = True + + use_flex_attn: bool = False + attn_mask_type: str = "causal" + eos_id: int = 0 + + def update_from_config(self, job_config, tokenizer) -> None: + self.vocab_size = tokenizer.n_words + self.max_seq_len = job_config.training.seq_len + self.eos_id = tokenizer.eos_id + + if job_config.activation_checkpoint.mode == "selective" and self.use_flex_attn: + raise ValueError( + "FlexAttention is not compatible with selective AC yet. " + "See https://github.com/pytorch/pytorch/issues/147879" + ) + + if job_config.parallelism.context_parallel_degree > 1 and self.use_flex_attn: + raise ValueError( + "FlexAttention is not compatible with CP yet. " + "We are still working on this." + ) + + def get_nparams_and_flops(self, model: nn.Module, seq_len: int) -> tuple[int, int]: + nparams = sum(p.numel() for p in model.parameters()) + nparams_embedding = sum( + sum(p.numel() for p in m.parameters()) + for m in model.children() + if isinstance(m, nn.Embedding) + ) + + l, h, q, t = ( + self.n_layers, + self.n_heads, + self.dim // self.n_heads, + seq_len, + ) + # Reasoning behind the factor of 12 for the self-attention part of the formula: + # 1. each self-attention has 2 matmul in the forward and 4 in the backward (6) + # 2. the flash attention does 1 more matmul recomputation in the backward + # but recomputation should not be counted in calculating MFU (+0) + # 3. each matmul performs 1 multiplication and 1 addition (*2) + # 4. we follow the convention and do not account for sparsity in causal attention + num_flops_per_token = 6 * (nparams - nparams_embedding) + 12 * l * h * q * t + + return nparams, num_flops_per_token + + +def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0) -> torch.Tensor: + """ + Precompute the frequency tensor for complex exponentials (cis) with given dimensions. 
+ + This function calculates a frequency tensor with complex exponentials using the given dimension 'dim' + and the end index 'end'. The 'theta' parameter scales the frequencies. + The returned tensor contains complex values in complex64 data type. + + Args: + dim (int): Dimension of the frequency tensor. + end (int): End index for precomputing frequencies. + theta (float | None): Scaling factor for frequency computation. Defaults to 10000.0. + + Returns: + torch.Tensor: Precomputed frequency tensor with complex exponentials. + """ + freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim)) + t = torch.arange(end, device=freqs.device) + freqs = torch.outer(t, freqs).float() + freqs_cis = torch.polar(torch.ones_like(freqs), freqs) # complex64 + return freqs_cis + + +def reshape_for_broadcast(freqs_cis: torch.Tensor, x: torch.Tensor) -> torch.Tensor: + """ + Reshape frequency tensor for broadcasting it with another tensor. + + This function reshapes the frequency tensor to have the same shape as the target tensor 'x' + for the purpose of broadcasting the frequency tensor during element-wise operations. + + The input freqs_cis tensor is assumed to be of shape (max_seqlen, dim), + and the first seqlen elements will be sliced, but dim must match x. + + Args: + freqs_cis (torch.Tensor): Frequency tensor to be reshaped. + x (torch.Tensor): Target tensor for broadcasting compatibility. + + Returns: + torch.Tensor: Reshaped frequency tensor. + """ + ndim = x.ndim + assert ndim > 1 + seqlen = x.shape[1] + freqs_cis = freqs_cis[0:seqlen] + assert freqs_cis.shape == (seqlen, x.shape[-1]) + shape = [d if i == 1 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)] + return freqs_cis.view(*shape) + + +def apply_rotary_emb( + xq: torch.Tensor, + xk: torch.Tensor, + freqs_cis: torch.Tensor, +) -> tuple[torch.Tensor, torch.Tensor]: + """ + Apply rotary embeddings to input tensors using the given frequency tensor. + + This function applies rotary embeddings to the given query 'xq' and key 'xk' tensors using the provided + frequency tensor 'freqs_cis'. The input tensors are reshaped as complex numbers, and the frequency tensor + is reshaped for broadcasting compatibility. The resulting tensors contain rotary embeddings and are + returned as real tensors. + + Args: + xq (torch.Tensor): Query tensor to apply rotary embeddings. + xk (torch.Tensor): Key tensor to apply rotary embeddings. + freqs_cis (torch.Tensor): Precomputed frequency tensor for complex exponentials. + + Returns: + tuple[torch.Tensor, torch.Tensor]: Tuple of modified query tensor and key tensor with rotary embeddings. + """ + xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2)) + xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2)) + freqs_cis = reshape_for_broadcast(freqs_cis, xq_) + xq_out = torch.view_as_real(xq_ * freqs_cis).flatten(3) + xk_out = torch.view_as_real(xk_ * freqs_cis).flatten(3) + return xq_out.type_as(xq), xk_out.type_as(xk) + + +def repeat_kv(x: torch.Tensor, n_rep: int) -> torch.Tensor: + """torch.repeat_interleave(x, dim=2, repeats=n_rep)""" + bs, slen, n_kv_heads, head_dim = x.shape + if n_rep == 1: + return x + return ( + torch.unsqueeze(x, dim=3) + .expand(bs, slen, n_kv_heads, n_rep, head_dim) + .reshape(bs, slen, n_kv_heads * n_rep, head_dim) + ) + + +class Attention(nn.Module): + """ + Multi-head attention module. + + Args: + model_args (TransformerModelArgs): Model configuration arguments. + + Attributes: + n_kv_heads (int): Number of key and value heads. 
+ n_heads (int): Number of query heads. + n_rep (int): Number of repetitions for local heads. + head_dim (int): Dimension size of each attention head. + wq (Linear): Linear transformation for queries. + wk (Linear): Linear transformation for keys. + wv (Linear): Linear transformation for values. + wo (Linear): Linear transformation for output. + + """ + + def __init__(self, model_args: TransformerModelArgs): + super().__init__() + self.n_heads = model_args.n_heads + self.n_kv_heads = ( + model_args.n_heads + if model_args.n_kv_heads is None + else model_args.n_kv_heads + ) + self.n_rep = self.n_heads // self.n_kv_heads + self.head_dim = model_args.dim // model_args.n_heads + + self.wq = nn.Linear( + model_args.dim, model_args.n_heads * self.head_dim, bias=False + ) + self.wk = nn.Linear(model_args.dim, self.n_kv_heads * self.head_dim, bias=False) + self.wv = nn.Linear(model_args.dim, self.n_kv_heads * self.head_dim, bias=False) + self.wo = nn.Linear( + model_args.n_heads * self.head_dim, model_args.dim, bias=False + ) + self.sdpa = build_attention(model_args.use_flex_attn, model_args.attn_mask_type) + + def init_weights(self, init_std: float): + for linear in (self.wq, self.wk, self.wv): + nn.init.trunc_normal_(linear.weight, mean=0.0, std=0.02) + nn.init.trunc_normal_(self.wo.weight, mean=0.0, std=init_std) + + def forward( + self, + x: torch.Tensor, + freqs_cis: torch.Tensor, + ): + """ + Forward pass of the attention module. + + Args: + x (torch.Tensor): Input tensor. + freqs_cis (torch.Tensor): Precomputed frequency tensor. + + Returns: + torch.Tensor: Output tensor after attention. + + """ + + bs, seqlen, _ = x.shape + xq, xk, xv = self.wq(x), self.wk(x), self.wv(x) + + # Use -1 instead of `n_heads` (or `n_kv_heads`) to infer the actual + # local heads from sizes of xq, xk, and xv as TP may have sharded them + # after the above linear ops. + xq = xq.view(bs, seqlen, -1, self.head_dim) + xk = xk.view(bs, seqlen, -1, self.head_dim) + xv = xv.view(bs, seqlen, -1, self.head_dim) + + # TODO: uncomment + xq, xk = apply_rotary_emb(xq, xk, freqs_cis=freqs_cis) + + # repeat k/v heads if n_kv_heads < n_heads + keys = repeat_kv(xk, self.n_rep) # (bs, seqlen, n_local_heads, head_dim) + values = repeat_kv(xv, self.n_rep) # (bs, seqlen, n_local_heads, head_dim) + + xq = xq.transpose(1, 2) # (bs, n_local_heads, seqlen, head_dim) + xk = keys.transpose(1, 2) # (bs, n_local_heads, seqlen, head_dim) + xv = values.transpose(1, 2) # (bs, n_local_heads, seqlen, head_dim) + + output = self.sdpa(xq, xk, xv) + + output = output.transpose( + 1, 2 + ).contiguous() # (bs, seqlen, n_local_heads, head_dim) + output = output.view(bs, seqlen, -1) + return self.wo(output) + + +class FeedForward(nn.Module): + """ + FeedForward module + + Args: + dim (int): Input dimension. + hidden_dim (int): Hidden dimension of the feedforward layer. + multiple_of (int): Value to ensure hidden dimension is a multiple of this value. + ffn_dim_multiplier (float | None): Custom multiplier for hidden dimension. Defaults to None. + + Attributes: + w1 (Linear): Linear transformation for the first layer. + w2 (Linear): Linear transformation for the second layer. + w3 (Linear): Linear transformation for the third layer. 
+ + """ + + def __init__( + self, + dim: int, + hidden_dim: int, + multiple_of: int, + ffn_dim_multiplier: float | None, + ): + super().__init__() + hidden_dim = int(2 * hidden_dim / 3) + # custom dim factor multiplier + if ffn_dim_multiplier is not None: + hidden_dim = int(ffn_dim_multiplier * hidden_dim) + hidden_dim = multiple_of * ((hidden_dim + multiple_of - 1) // multiple_of) + + self.w1 = nn.Linear(dim, hidden_dim, bias=False) + self.w2 = nn.Linear(hidden_dim, dim, bias=False) + self.w3 = nn.Linear(dim, hidden_dim, bias=False) + + def forward(self, x): + return self.w2(F.silu(self.w1(x)) * self.w3(x)) + + def init_weights(self, init_std: float): + nn.init.trunc_normal_(self.w1.weight, mean=0.0, std=0.02) + for linear in (self.w2, self.w3): + nn.init.trunc_normal_(linear.weight, mean=0.0, std=init_std) + + +class TransformerBlock(nn.Module): + """ + TransformerBlock Module + + Args: + layer_id (int): Identifier for the layer. + model_args (TransformerModelArgs): Model configuration arguments. + + Attributes: + n_heads (int): Number of attention heads. + dim (int): Dimension size of the model. + head_dim (int): Dimension size of each attention head. + attention (Attention): Attention module. + feed_forward (FeedForward): FeedForward module. + layer_id (int): Identifier for the layer. + attention_norm (RMSNorm): Layer normalization for attention output. + ffn_norm (RMSNorm): Layer normalization for feedforward output. + + """ + + def __init__(self, layer_id: int, model_args: TransformerModelArgs): + super().__init__() + self.n_heads = model_args.n_heads + self.dim = model_args.dim + + self.attention = Attention(model_args) + self.feed_forward = FeedForward( + dim=model_args.dim, + hidden_dim=4 * model_args.dim, + multiple_of=model_args.multiple_of, + ffn_dim_multiplier=model_args.ffn_dim_multiplier, + ) + self.attention_norm = nn.RMSNorm(model_args.dim, eps=model_args.norm_eps) + self.ffn_norm = nn.RMSNorm(model_args.dim, eps=model_args.norm_eps) + + if model_args.depth_init: + self.weight_init_std = 0.02 / (2 * (layer_id + 1)) ** 0.5 + else: + self.weight_init_std = 0.02 / (2 * model_args.n_layers) ** 0.5 + + def forward( + self, + x: torch.Tensor, + freqs_cis: torch.Tensor, + ): + """ + Perform a forward pass through the TransformerBlock. + + Args: + x (torch.Tensor): Input tensor. + freqs_cis (torch.Tensor): Precomputed cosine and sine frequencies. + + Returns: + torch.Tensor: Output tensor after applying attention and feedforward layers. + + """ + h = x + self.attention(self.attention_norm(x), freqs_cis) + out = h + self.feed_forward(self.ffn_norm(h)) + return out + + def init_weights(self): + return + for norm in (self.attention_norm, self.ffn_norm): + norm.reset_parameters() + self.attention.init_weights(self.weight_init_std) + self.feed_forward.init_weights(self.weight_init_std) + + +class Transformer(nn.Module): + """ + Transformer Module + + Args: + model_args (TransformerModelArgs): Model configuration arguments. + + Attributes: + model_args (TransformerModelArgs): Model configuration arguments. + vocab_size (int): Vocabulary size. + n_layers (int): Number of layers in the model. + tok_embeddings (ParallelEmbedding): Token embeddings. + layers (torch.nn.ModuleList): List of Transformer blocks. + norm (RMSNorm): Layer normalization for the model output. + output (ColumnParallelLinear): Linear layer for final output. + freqs_cis (torch.Tensor): Precomputed cosine and sine frequencies. 
+ + """ + + def __init__(self, model_args: TransformerModelArgs): + super().__init__() + self.model_args = model_args + self.vocab_size = model_args.vocab_size + self.n_layers = model_args.n_layers + self.eos_id = model_args.eos_id + + self.tok_embeddings = nn.Embedding(model_args.vocab_size, model_args.dim) + + # TODO persistent should be set to false, since this buffer can be recomputed. + # however, we set it to true for 2 reasons. (1) due to pytorch/pytorch#123411, + # compile or pipeline-tracer will not correctly handle non-persistent buffers, + # so we need to fix that. (2) if we initialize pipeline-parallel models from + # a seed checkpoint rather than calling init_weights, we need freqs_cis to be + # initialized by the checkpoint, or we need to add a separate initializer for + # just the non-persistent buffers that is called after loading checkpoints. + self.register_buffer("freqs_cis", self._precompute_freqs_cis(), persistent=True) + + self.layers = torch.nn.ModuleDict() + for layer_id in range(model_args.n_layers): + self.layers[str(layer_id)] = TransformerBlock(layer_id, model_args) + self.norm = nn.RMSNorm(model_args.dim, eps=model_args.norm_eps) + self.output = nn.Linear(model_args.dim, model_args.vocab_size, bias=False) + + def init_weights( + self, + buffer_device: torch.device | None = None, + ): + """ + [Note: On ``init_weights`` vs. ``reset_parameters``] + Modules may define ``reset_parameters`` to initialize parameter values. + ``reset_parameters`` is meant to only initialize directly owned + parameters/buffers, not those of their child modules, and it can be + used to give the initial values for these tensors. + Separately, users may want custom initialization for their modules, + different from that in ``reset_parameters``. For this, we define + ``init_weights``. We only call it in the constructor of this + ``Transformer`` root module to avoid reinitializing tensors. + """ + buffer_device = buffer_device or self.freqs_cis.device # type: ignore + with torch.device(buffer_device): + self.freqs_cis = self._precompute_freqs_cis() + if self.tok_embeddings is not None: + nn.init.normal_(self.tok_embeddings.weight) + for layer in self.layers.values(): + if layer is not None: + layer.init_weights() + if self.norm is not None: + self.norm.reset_parameters() + final_out_std = self.model_args.dim**-0.5 + cutoff_factor = 3 + if self.output is not None: + nn.init.trunc_normal_( + self.output.weight, + mean=0.0, + std=final_out_std, + a=-cutoff_factor * final_out_std, + b=cutoff_factor * final_out_std, + ) + + def _precompute_freqs_cis(self) -> torch.Tensor: + return precompute_freqs_cis( + self.model_args.dim // self.model_args.n_heads, + # Need to compute until at least the max token limit for generation + # TODO: explain in docs/composability.md why we removed the 2x + # relaxing in our CP enablement PR + self.model_args.max_seq_len, + self.model_args.rope_theta, + ) + + def forward(self, tokens: torch.Tensor, input_batch: torch.Tensor | None = None): + """ + Perform a forward pass through the Transformer model. + + Args: + tokens (torch.Tensor): Input token indices if pipeline parallelism is not enabled. + If pipeline parallelism is enabled, this will be the input token indices + for the ranks on the first pipeline stage. This will be the activation of the + previous pipeline stage if the current rank is not on the first stage. + input_batch (torch.Tensor): The input batch read from the dataloader. + This will always be the input batch regardless of the pipeline stage. 
+ This field is required for non-first PP stages to perform document + masking attention (to analyze the boundary of the document). + + Returns: + torch.Tensor: Output logits after applying the Transformer model. + + """ + # passthrough for nonexistent layers, allows easy configuration of pipeline parallel stages + h = self.tok_embeddings(tokens) if self.tok_embeddings else tokens + + for layer in self.layers.values(): + h = layer(h, self.freqs_cis) + + h = self.norm(h) if self.norm else h + output = self.output(h) if self.output else h + return output + + +# ============================================================== +# AutoParallel code starts here +# ============================================================== + +torch.distributed.init_process_group() +assert "WORLD_SIZE" in os.environ, "run with torchrun --nproc-per-node 4 or 8" +world_size = int(os.getenv("WORLD_SIZE")) +pp_degree = 2 +world_mesh = torch.distributed.device_mesh.init_device_mesh( + "cuda", + (pp_degree, world_size // pp_degree), + mesh_dim_names=( + "pp", + "dp", + ), +) +use_1d_mesh = True +pp_rank = world_mesh["pp"].get_local_rank() +spmd_mesh = world_mesh["dp"] + +global_batch_size = 4 * world_mesh.shape[0] +dp_degree = world_mesh.size(1) +local_batch_size = global_batch_size // dp_degree +n_microbatches = 4 +microbatch_size = local_batch_size // n_microbatches +assert microbatch_size >= 1, f"invalid config {local_batch_size=}, {n_microbatches=}" +spmd_batch_size = microbatch_size * dp_degree + +seqlen = 2048 * 4 +vocab_size = 128256 +use_vocab_parallel = not use_1d_mesh +device = torch.device("cuda") + + +def model_fn(): + model_args = TransformerModelArgs( + n_layers=4, + vocab_size=vocab_size, + max_seq_len=seqlen, + multiple_of=1024, + ffn_dim_multiplier=1.3, + n_kv_heads=8, + ) + m = Transformer(model_args) + + # Very simplified / hardcoded version of split util in torchtitan, for the case of pp=2 n_layers=4 + if pp_rank == 0: + del m.layers["2"] + del m.layers["3"] + m.norm = None + m.output = None + else: + m.tok_embeddings = None + del m.layers["0"] + del m.layers["1"] + return m + + +def input_fn(runtime=False): + # batch_size = microbatch_size if runtime else spmd_batch_size + # note: don't use microbatch_size at runtime, bc pipeline runtime will actually split the mb's + batch_size = spmd_batch_size + if pp_rank == 0: + x = torch.randint(0, vocab_size, (batch_size, seqlen), device=device) + else: + x = torch.randn( + (batch_size, seqlen, model.model_args.dim), + device=device, + dtype=torch.bfloat16, + requires_grad=True, + ) + return x + + +# parallelize the model +with torch.device("meta"): + model = model_fn() + +mp_policy = MixedPrecisionPolicy(param_dtype=torch.bfloat16, reduce_dtype=torch.float32) + +# parallelize the model +with AutoParallel( + model, input_fn, spmd_mesh, mp_policy, compile=True, repeated_subgraphs=True +) as autop: + autop.add_parameter_memory_constraint(low=None, high=None) + + x_sharding = (Shard(0),) + (Replicate(),) * (spmd_mesh.ndim - 1) + out_sharding = x_sharding + if use_vocab_parallel: + # add vocab parallel constraint + assert spmd_mesh.ndim == 2, "Only 2d mesh supported here" + out_sharding = (Shard(0), Shard(2)) + + autop.add_input_constraints([x_sharding]) + autop.add_output_constraints([out_sharding]) + + # example of how to add manual constraints + if use_1d_mesh and pp_rank == 0: + # add constraint on the output sharding of embedding bag + # otherwise it might decide that it's ok to replicate both inputs. 
This is indeed fine
+    # for a 1d mesh, but the current cost model doesn't take output activation memory into
+    # account, so it underestimates the cost. TODO: add an activation memory constraint to
+    # avoid these cases.
+    embedding_nodes = autop.gm.graph.find_nodes(
+        op="call_function", target=torch.ops.aten.embedding.default
+    )
+    autop.sharding_optimizer.add_node_constraint(embedding_nodes[0], x_sharding)
+
+    t = time.time()
+    sharding_placement = autop.optimize_placement()
+    print(f"Took {time.time() - t:.2f} s")
+    parallel_mod = autop.apply_placement(sharding_placement)
+
+# run weight init on our sharded DTensor params
+torch.manual_seed(pp_rank)
+parallel_mod.to_empty(device="cuda")
+parallel_mod.init_weights()
+
+# now let's run it
+x = input_fn(runtime=True)
+stage = PipelineStage(
+    parallel_mod,
+    stage_index=pp_rank,
+    num_stages=pp_degree,
+    device="cuda",
+    group=world_mesh.get_group("pp"),
+)
+schedule = build_pipeline_schedule(
+    stages=[stage],
+    loss_fn=None,
+    pipeline_parallel_schedule="1F1B",
+    microbatch_size=microbatch_size,
+    local_batch_size=spmd_batch_size,
+    pipeline_parallel_degree=pp_degree,
+)
+if pp_rank == 0:
+    schedule.step(x)
+else:
+    schedule.step()
+print("All good!")
+torch.distributed.barrier()
+torch.cuda.synchronize()
+torch.distributed.destroy_process_group()
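
Note on training with a loss: the schedule above is built with loss_fn=None, so
build_pipeline_schedule configures a forward-only pipeline and the run is essentially a
smoke test. Below is a minimal sketch of how the same setup could drive a training step;
the cross-entropy loss_fn and the random target batch are illustrative assumptions, not
part of this patch.

    # Hedged sketch (not part of the patch): give the schedule a loss so it also runs backward.
    def loss_fn(logits: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
        # logits: (batch, seqlen, vocab_size); target: (batch, seqlen) token ids
        return F.cross_entropy(logits.flatten(0, 1).float(), target.flatten(0, 1))

    schedule = build_pipeline_schedule(
        stages=[stage],
        loss_fn=loss_fn,  # a non-None loss enables the backward pass per microbatch
        pipeline_parallel_schedule="1F1B",
        microbatch_size=microbatch_size,
        local_batch_size=spmd_batch_size,
        pipeline_parallel_degree=pp_degree,
    )

    target = torch.randint(0, vocab_size, (spmd_batch_size, seqlen), device=device)
    if pp_rank == 0:
        schedule.step(x)  # first stage feeds the input token microbatches
    else:
        losses = []  # the schedule fills this with one loss tensor per microbatch
        schedule.step(target=target, losses=losses)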