diff --git a/megatron/core/models/hybrid/hybrid_block.py b/megatron/core/models/hybrid/hybrid_block.py index 5494d531e52..7073bf72066 100644 --- a/megatron/core/models/hybrid/hybrid_block.py +++ b/megatron/core/models/hybrid/hybrid_block.py @@ -20,6 +20,7 @@ from megatron.core.fp8_utils import get_fp8_context from megatron.core.inference.contexts import BaseInferenceContext from megatron.core.models.hybrid.hybrid_layer_allocation import Symbols as LayerSymbols +from megatron.core.models.hybrid.hybrid_layer_fusion import build_fused_layer from megatron.core.packed_seq_params import PackedSeqParams from megatron.core.process_groups_config import ProcessGroupCollection from megatron.core.transformer import TransformerConfig @@ -36,8 +37,19 @@ class HybridStackSubmodules: """ A class for the module specs for the HybridStack. + + The `*_layer` fields specify the full block used for a stand-alone layer + (a single symbol in the hybrid layer pattern). + + The `*_mixer` fields specify the primitive module plugged into a + `TransformerLayer` when two layers are fused via the `[XY]` syntax + in the hybrid layer pattern. For a fused group `[XY]`, the sequence mixer + corresponding to `X` is used as `self_attention` and the channel mixer + corresponding to `Y` is used as `mlp` inside a freshly-built + `TransformerLayer` spec. """ + # Stand-alone layer specs – one block per pattern symbol. mamba_layer: Union[ModuleSpec, type] = IdentityOp gdn_layer: Union[ModuleSpec, type] = IdentityOp attention_layer: Union[ModuleSpec, type] = IdentityOp @@ -46,6 +58,16 @@ class HybridStackSubmodules: moe_layer: Union[ModuleSpec, type] = IdentityOp mtp_block_spec: Optional[ModuleSpec] = None + # Primitive specs used when building a fused TransformerLayer at runtime. + # Sequence mixers fill the `self_attention` slot: + mamba_mixer: Union[ModuleSpec, type] = IdentityOp + gdn_mixer: Union[ModuleSpec, type] = IdentityOp + attention_mixer: Union[ModuleSpec, type] = IdentityOp + dsa_mixer: Union[ModuleSpec, type] = IdentityOp + # Channel mixers fill the `mlp` slot: + mlp_mixer: Union[ModuleSpec, type] = IdentityOp + moe_mixer: Union[ModuleSpec, type] = IdentityOp + class HybridStack(GraphableMegatronModule, MegatronModule): """ @@ -60,7 +82,8 @@ class HybridStack(GraphableMegatronModule, MegatronModule): this pipeline segment. When provided (by HybridModel), pipeline stage selection has already been done via '|' separators in the pattern. pp_layer_offset (int, optional): the global layer offset for this pipeline - segment. Defaults to 0. + segment, measured in physical blocks (fused groups count as one). + Defaults to 0. post_layer_norm (bool, optional): whether to include a final layer norm. Defaults to True. post_process (bool, optional): whether to include an output layer. @@ -118,7 +141,24 @@ def __init__( else: quant_init_context = nullcontext() with quant_init_context: - if layer_type == LayerSymbols.MAMBA: + if len(layer_type) > 1: + # Multi-character entries come from bracketed fusion groups + # in the hybrid layer pattern, e.g., "[*-]" -> "*-". + # `layer_number` already includes `pp_layer_offset` (see + # the computation at the top of this loop), so the outer + # TransformerLayer must not add it again – same contract + # as the stand-alone TransformerLayer dispatches below. 
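+                    # For example (illustrative), a fused group "[*-]" arrives here as
+                    # layer_type "*-" and becomes a TransformerLayer whose
+                    # `self_attention` is built from `submodules.attention_mixer`
+                    # and whose `mlp` from `submodules.mlp_mixer`.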
+                        layer = build_fused_layer(
+                            layer_type,
+                            submodules,
+                            config=self.config,
+                            layer_number=layer_number,
+                            pg_collection=pg_collection,
+                            pp_layer_offset=pp_layer_offset,
+                            is_mtp_layer=is_mtp_layer,
+                            add_layer_offset=False,
+                        )
+                    elif layer_type == LayerSymbols.MAMBA:
                         layer = build_module(
                             submodules.mamba_layer,
                             config=self.config,
@@ -200,10 +240,19 @@ def mamba_state_shapes_per_request(self) -> Optional[Tuple[Tuple[int], Tuple[int
         """
         Returns the Mamba conv and ssm states shapes per input sequence if this block contains
         Mamba layers (this may not be the case with PP > 1).
+
+        A stand-alone Mamba block exposes `mamba_state_shapes_per_request`
+        directly (it is a `MambaLayer`). A fused `[M...]` block is a
+        `TransformerLayer` whose `self_attention` slot holds the MambaMixer,
+        so we have to descend one level.
         """
         for layer_type, layer in zip(self.layer_type_list, self.layers):
             if layer_type == LayerSymbols.MAMBA:
                 return layer.mamba_state_shapes_per_request()
+            # Fused Mamba is surfaced via the enclosing TransformerLayer's `self_attention`
+            # attribute.
+            elif LayerSymbols.MAMBA in layer_type:
+                return layer.self_attention.mamba_state_shapes_per_request()
         return None

     def _should_call_local_cudagraph(self, *args, **kwargs):
diff --git a/megatron/core/models/hybrid/hybrid_layer_allocation.py b/megatron/core/models/hybrid/hybrid_layer_allocation.py
index f1ba94ef7fa..ab67b69e74d 100644
--- a/megatron/core/models/hybrid/hybrid_layer_allocation.py
+++ b/megatron/core/models/hybrid/hybrid_layer_allocation.py
@@ -22,6 +22,8 @@ class Symbols:
     MOE = 'E'
     PIPE = '|'
     MTP_SEPARATOR = "/"
+    FUSION_START = "["
+    FUSION_END = "]"
     VALID_LAYERS = {MAMBA, GDN, ATTENTION, DS_ATTENTION, MLP, MOE}

     @classmethod
@@ -137,8 +139,8 @@ def get_hybrid_total_layer_count(pattern: str) -> int:
         Total number of layers in the main decoder pattern.
     """
     main_pattern = pattern.split(Symbols.MTP_SEPARATOR)[0]
-    _validate_pattern(main_pattern, "main", allow_pipe=True)
-    return len(main_pattern.replace(Symbols.PIPE, ''))
+    _validate_pattern(main_pattern, "main", allow_pipe=True, allow_brackets=True)
+    return count_pattern_layers(main_pattern)


 def get_hybrid_total_pipeline_segment_count(pattern: str) -> int:
@@ -204,8 +206,50 @@ def parse_hybrid_pattern(pattern: Optional[str]) -> ParsedHybridPattern:
     depth. The main pattern may contain "|" pipe symbols for pipeline stage boundaries.

+    A matching pair of "[" and "]" square brackets indicates that a fusion optimization should be
+    applied. Currently, this only works for a sequence mixer (an Attention variant, Mamba, or GDN)
+    followed by a channel mixer (MLP or MoE), e.g., "*-" for Attention+MLP or "ME" for Mamba+MoE.
+    The contents of a bracket pair may not cross pipeline or MTP boundaries.
+
     Format: "<main_pattern>/<mtp_pattern_depth_1>/<mtp_pattern_depth_2>/..."

+    Semantics of layer counting in the presence of fusion:
+
+    - A fused group `[XY]` counts as one (1) physical transformer block. Every
+      function in this module that returns a "number of layers" (such as
+      `get_hybrid_total_layer_count`, `count_pattern_layers`, or the length of
+      the list returned by `validate_segment_layers`) follows this convention.
+      Pipeline parallelism, VPP/fVPP offsets, and `config.num_layers` all use
+      the physical-block count.
+    - `get_hybrid_layer_counts` returns per-type sub-layer counts (it counts
+      each character inside `[...]` separately). This is the right count for
+      FLOPs or parameter budgeting – a fused `[*-]` does attention + MLP
+      compute, so `{"*": 1, "-": 1}` accurately reflects the work done.
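+
+    For example (illustrative), "[*-]M[*-]M" is four physical blocks but six
+    sub-layers of compute; both counts come from helpers defined later in this
+    module:
+
+        >>> count_pattern_layers("[*-]M[*-]M")
+        4
+        >>> get_sub_layer_offset("[*-]M[*-]M", 4)
+        6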
+ + Pipeline-parallelism caveats when fusion is used: + + - Auto-split PP (a pipeless pattern with `pipeline_model_parallel_size > + 1`) slices `layer_type_list` into equal-sized physical-block chunks. + Ranks whose chunk lands on fused blocks do more compute per block than + ranks that get stand-alone blocks, which widens the PP bubble. Use + explicit "|" separators to place fused blocks by compute rather than by + block count. + - `--decoder-first-pipeline-num-layers` / `--decoder-last-pipeline-num-layers` + count physical blocks, not sub-layers. Fusing blocks on only the first + or last stage therefore silently skews compute relative to the middle + stages. + - fVPP with explicit "|" gives full control over per-segment layout, but + note that two patterns with the same sub-layer compute can have + *different* physical-block counts (e.g., `*-M*` = 4 blocks vs. `[*-]M*` + = 3 blocks). VPP scheduling is driven by physical-block segments, so + fused and unfused segments are not schedule-equivalent even when their + FLOPs match. + - Per-layer metrics (grad norms, router stats, loss contributions) are + keyed on `layer.layer_number`, which is the physical-block index. A + fused block's attention and MLP sub-layers share the same layer number, + so dashboards that separate metrics per sub-layer type will conflate + them inside a fused block. + Args: pattern: Unified pattern string, e.g., "M*M*/MM/MM" or just "M*M*" @@ -228,22 +272,30 @@ def parse_hybrid_pattern(pattern: Optional[str]) -> ParsedHybridPattern: >>> parse_hybrid_pattern("M-M-|M-M*-/MM/MM") ParsedHybridPattern(main_pattern="M-M-|M-M*-", mtp_pattern="MM", mtp_num_depths=2) + + >>> parse_hybrid_pattern("[M-][M-]|[M-]M[*-]/MM/MM") + ParsedHybridPattern(main_pattern="[M-][M-]|[M-]M[*-]", mtp_pattern="MM", mtp_num_depths=2) """ if pattern is None: return ParsedHybridPattern(main_pattern=None, mtp_pattern=None, mtp_num_depths=0) + # Validate bracket structure before splitting on '/', otherwise a fusion + # group that crosses the MTP boundary looks like two unrelated unmatched + # bracket errors. + _validate_brackets(pattern, "hybrid") + parts = pattern.split(Symbols.MTP_SEPARATOR) if len(parts) == 1: # No MTP separator found - pattern is main decoder only main_pattern = parts[0] - _validate_pattern(main_pattern, "main", allow_pipe=True) + _validate_pattern(main_pattern, "main", allow_pipe=True, allow_brackets=True) return ParsedHybridPattern(main_pattern=main_pattern, mtp_pattern=None, mtp_num_depths=0) # First part is main decoder pattern main_pattern = parts[0] if main_pattern: - _validate_pattern(main_pattern, "main", allow_pipe=True) + _validate_pattern(main_pattern, "main", allow_pipe=True, allow_brackets=True) # Remaining parts are MTP patterns (one per depth) mtp_parts = parts[1:] @@ -264,7 +316,7 @@ def parse_hybrid_pattern(pattern: Optional[str]) -> ParsedHybridPattern: f"Full pattern: '{pattern}'" ) - _validate_pattern(mtp_pattern, "MTP", allow_pipe=False) + _validate_pattern(mtp_pattern, "MTP", allow_pipe=False, allow_brackets=True) return ParsedHybridPattern( main_pattern=main_pattern if main_pattern else None, @@ -273,18 +325,27 @@ def parse_hybrid_pattern(pattern: Optional[str]) -> ParsedHybridPattern: ) -def _validate_pattern(pattern: str, pattern_name: str, allow_pipe: bool = False) -> None: +def _validate_pattern( + pattern: str, pattern_name: str, allow_pipe: bool = False, allow_brackets: bool = False +) -> None: """Validate that a pattern contains only valid layer symbols. 
Args: pattern: Layer pattern string to validate pattern_name: Name of pattern for error messages (e.g., "main" or "MTP") allow_pipe: Whether to allow the pipe '|' separator (for main patterns) + allow_brackets: Whether to allow fusion bracket markers '[' and ']' Raises: - ValueError: If pattern contains invalid symbols + ValueError: If pattern contains invalid symbols or brackets are malformed """ - valid_chars = Symbols.VALID_LAYERS | {Symbols.PIPE} if allow_pipe else Symbols.VALID_LAYERS + valid_chars = Symbols.VALID_LAYERS.copy() + if allow_pipe: + valid_chars.add(Symbols.PIPE) + if allow_brackets: + valid_chars.add(Symbols.FUSION_START) + valid_chars.add(Symbols.FUSION_END) + for char in pattern: if char not in valid_chars: raise ValueError( @@ -292,39 +353,300 @@ def _validate_pattern(pattern: str, pattern_name: str, allow_pipe: bool = False) f"Valid symbols are: {valid_chars}" ) + if allow_brackets: + _validate_brackets(pattern, pattern_name) + # Disallow Attention + MLA/DSA hybridity. if Symbols.ATTENTION in pattern and Symbols.DS_ATTENTION in pattern: raise ValueError("Not supported to have both Attention and MLA/DSA in one model") +def _validate_brackets(pattern: str, pattern_name: str) -> None: + """Validate that fusion brackets in `pattern` are well-formed. + + This only enforces syntactic well-formedness. Semantic constraints on + fusion groups (exactly two layers, sequence-mixer followed by channel-mixer) + are validated later, at the point where the fused layer is actually + constructed. + + Rules: + - Brackets must be balanced: every '[' must have a matching ']'. + - No nesting: '[' inside an open bracket group is invalid. + - No empty groups: '[]' with no layer symbols inside is invalid. + - A bracket group must contain at least 2 layer symbols (fusion of one + layer is meaningless). + - Bracket groups may not span across pipe ('|') or MTP ('/') boundaries. + + Args: + pattern: The pattern string (already validated for character legality). + pattern_name: Human-readable name for error messages. + + Raises: + ValueError: On any bracket violation. + """ + depth = 0 + group_layer_count = 0 + + for i, char in enumerate(pattern): + if char == Symbols.FUSION_START: + if depth > 0: + raise ValueError( + f"In {pattern_name} pattern, nested '[' at position {i} is not allowed. " + f"Pattern: '{pattern}'" + ) + depth += 1 + group_layer_count = 0 + elif char == Symbols.FUSION_END: + if depth <= 0: + raise ValueError( + f"In {pattern_name} pattern, unmatched ']' at position {i}. " + f"Pattern: '{pattern}'" + ) + if group_layer_count == 0: + raise ValueError( + f"In {pattern_name} pattern, empty fusion group '[]' at position {i}. " + f"Pattern: '{pattern}'" + ) + if group_layer_count < 2: + raise ValueError( + f"In {pattern_name} pattern, fusion group ending at position {i} " + f"contains only {group_layer_count} layer – need at least 2. " + f"Pattern: '{pattern}'" + ) + depth -= 1 + elif char == Symbols.PIPE: + if depth > 0: + raise ValueError( + f"In {pattern_name} pattern, pipe '|' at position {i} appears " + f"inside a fusion group '[...|...]'. Fusion groups may not cross " + f"pipeline boundaries. Pattern: '{pattern}'" + ) + elif char == Symbols.MTP_SEPARATOR: + if depth > 0: + raise ValueError( + f"In {pattern_name} pattern, MTP separator '/' at position {i} appears " + f"inside a fusion group '[.../...]'. Fusion groups may not cross " + f"multi-token prediction boundaries. Pattern: '{pattern}'" + ) + else: + # Must be a valid layer symbol. 
+ if depth > 0: + group_layer_count += 1 + + if depth > 0: + raise ValueError( + f"In {pattern_name} pattern, unmatched '[' with no closing ']'. " + f"Pattern: '{pattern}'" + ) + # This should never happen with the above logic. + if depth < 0: + raise ValueError( + f"In {pattern_name} pattern, found more ']' than '[' brackets. Pattern: '{pattern}'" + ) + + +def strip_brackets(pattern: str) -> str: + """Remove fusion bracket markers from a pattern string. + + Returns the pattern with all '[' and ']' characters removed, leaving only + layer symbols and (if present) pipe and MTP separators. + + Args: + pattern: A pattern string that may contain fusion brackets. + + Returns: + The pattern with brackets stripped. + + Examples: + >>> strip_brackets("[*-]M[*-]M") + '*-M*-M' + >>> strip_brackets("M*M*") + 'M*M*' + """ + return pattern.replace(Symbols.FUSION_START, '').replace(Symbols.FUSION_END, '') + + +def count_pattern_layers(pattern: str) -> int: + """Count the number of physical layer blocks in a pattern. + + Each fusion group `[...]` counts as a single layer (because its sub-layers + are fused into one transformer block at runtime), regardless of how many + sub-layer symbols it contains. Pipe and MTP separators are ignored. + + Args: + pattern: A pattern string that may contain fusion brackets and/or + pipe/MTP separators. Assumed to have already passed bracket + validation. + + Returns: + Number of physical layer blocks the pattern represents. + + Examples: + >>> count_pattern_layers("M*M*") + 4 + >>> count_pattern_layers("[*-]M[*-]M") # 2 fused + 2 mamba + 4 + >>> count_pattern_layers("[*-]M|[*-]M") + 4 + """ + count = 0 + in_group = False + for char in pattern: + if char == Symbols.FUSION_START: + in_group = True + count += 1 # whole group counts as one block + elif char == Symbols.FUSION_END: + in_group = False + elif char in (Symbols.PIPE, Symbols.MTP_SEPARATOR): + continue + elif not in_group and char in Symbols.VALID_LAYERS: + count += 1 + return count + + +def get_sub_layer_offset(main_pattern: str, physical_offset: int) -> int: + """Count sub-layers corresponding to the first `physical_offset` physical blocks. + + A fused group `[XY]` counts as one physical block but contains `len(XY)` + sub-layers; a stand-alone symbol counts as one of each. Pipe '|' separators + are ignored – sub-layer indices are global across pipeline segments, so + this helper returns the value `HybridStack` needs to emit checkpoint keys + in the canonical unfused layout regardless of fusion placement. + + This is the sub-layer analogue of the physical-block offset returned by + `select_pipeline_segment`. + + Args: + main_pattern: Main decoder pattern (may contain '|' and '[...]'). + Assumed already validated (`_validate_pattern` with + `allow_pipe=True, allow_brackets=True`). + physical_offset: Number of physical blocks preceding the point of + interest (e.g., the value returned by `select_pipeline_segment` + for the current pipeline segment). Must be non-negative. + + Returns: + Sub-layer count across the `physical_offset` earliest physical blocks. 
+ + Examples: + >>> get_sub_layer_offset("M*M*", 0) + 0 + >>> get_sub_layer_offset("M*M*", 2) + 2 + >>> get_sub_layer_offset("[M-]M", 1) + 2 + >>> get_sub_layer_offset("[M-]M", 2) + 3 + >>> get_sub_layer_offset("M|[*-]", 2) + 3 + """ + if physical_offset <= 0: + return 0 + + sub_count = 0 + physical_count = 0 + in_group = False + group_sub_count = 0 + + for char in main_pattern: + if physical_count >= physical_offset: + break + if char == Symbols.FUSION_START: + in_group = True + group_sub_count = 0 + elif char == Symbols.FUSION_END: + in_group = False + sub_count += group_sub_count + physical_count += 1 + elif char == Symbols.PIPE: + continue + elif in_group: + group_sub_count += 1 + elif char in Symbols.VALID_LAYERS: + sub_count += 1 + physical_count += 1 + return sub_count + + +def parse_fusion_groups(pattern: str) -> List[Tuple[int, int]]: + """Extract fusion groups from a bracket-annotated pattern string. + + Each `[...]` group maps to a `(start, end)` tuple of layer indices + (inclusive) in the *bracket-stripped* pattern. + + Args: + pattern: A single segment pattern (no pipe separators) that may + contain fusion brackets. + + Returns: + List of `(start_layer_index, end_layer_index)` tuples (inclusive). + Empty list when the pattern has no brackets. + + Examples: + >>> parse_fusion_groups("[*-]M[*-]M") + [(0, 1), (3, 4)] + >>> parse_fusion_groups("M*M*") + [] + """ + _validate_brackets(pattern, "fusion") + + groups: List[Tuple[int, int]] = [] + layer_index = 0 + group_start: Optional[int] = None + + for char in pattern: + if char == Symbols.FUSION_START: + group_start = layer_index + elif char == Symbols.FUSION_END: + assert group_start is not None + # end is inclusive – the last layer index in the group + groups.append((group_start, layer_index - 1)) + group_start = None + elif char in Symbols.VALID_LAYERS: + layer_index += 1 + + return groups + + def validate_segment_layers(segment: str) -> List[str]: """Validate and convert a single pipeline segment pattern to a layer type list. This is used after the main pattern has been split by '|' into segments. - Each segment should contain only valid layer symbols (no '|'). + Each segment should contain only valid layer symbols (no '|'). Fusion + bracket markers ('[', ']') are permitted and collapse their contents into + a single multi-character list entry: `"[*-]M"` becomes `['*-', 'M']`. + Each entry in the returned list therefore corresponds to exactly one + physical transformer block – a single-character entry is a stand-alone + layer, a multi-character entry is a fused layer where each character + names one fused sub-layer. Args: - segment: A single pipeline segment pattern string (e.g., "M-M*-") + segment: A single pipeline segment pattern string (e.g., "[*-]M[*-]M") Returns: - List of layer type characters. + List of layer entries. Each entry is a string of length 1 (normal + layer) or length ≥ 2 (fused layer, one character per fused sub-layer). Raises: - ValueError: If segment contains invalid layer symbols. + ValueError: If segment contains invalid layer symbols or malformed brackets. """ - layer_type_list = list(segment) - for layer_char in layer_type_list: - if layer_char not in Symbols.VALID_LAYERS: - raise ValueError( - f"In hybrid layer pattern segment, '{layer_char}' is not " - f"one of {Symbols.VALID_LAYERS}" - ) - - # Disallow Attention + MLA/DSA hybridity. 
- if Symbols.ATTENTION in segment and Symbols.DS_ATTENTION in segment: - raise ValueError("Not supported to have both Attention and MLA/DSA in one model") - - return layer_type_list + _validate_pattern(segment, "segment", allow_pipe=False, allow_brackets=True) + + result: List[str] = [] + group_chars: List[str] = [] + in_group = False + for char in segment: + if char == Symbols.FUSION_START: + in_group = True + group_chars.clear() + elif char == Symbols.FUSION_END: + in_group = False + result.append(''.join(group_chars)) + group_chars.clear() + elif in_group: + group_chars.append(char) + else: + result.append(char) + return result def select_pipeline_segment( @@ -397,6 +719,23 @@ def select_pipeline_segment( layer_type_list = validate_segment_layers(full_pattern) num_layers = len(layer_type_list) + # Auto-split PP counts physical blocks, but a fused block does more + # compute per block than a stand-alone one. A pipeline rank whose + # slice lands on fused blocks can become a straggler and widen the PP + # bubble. Warn once so users know to insert explicit '|' separators + # if they need compute-aware balancing. + if any(len(entry) > 1 for entry in layer_type_list): + log_single_rank( + logger, + logging.WARNING, + "Auto-split PP on a hybrid pattern that contains fusion groups " + "'[...]' may produce imbalanced compute across ranks: fused " + "layers do more work per physical block than stand-alone " + "ones, so ranks whose slice lands on fused blocks become " + "stragglers. Consider inserting explicit '|' separators in " + "--hybrid-layer-pattern to control per-stage compute.", + ) + if first_stage_layers is not None or last_stage_layers is not None: first = first_stage_layers or 0 last = last_stage_layers or 0 @@ -467,7 +806,7 @@ def select_pipeline_segment( f"the current PP/VPP configuration." ) - layer_offset = sum(len(segments[i]) for i in range(segment_index)) + layer_offset = sum(count_pattern_layers(segments[i]) for i in range(segment_index)) my_segment = segments[segment_index] layer_type_list = validate_segment_layers(my_segment) @@ -488,11 +827,18 @@ def get_layer_maps_from_layer_type_list(layer_type_list: list[str]) -> dict[str, """ Returns maps from global layer index to the corresponding layer index for each valid layer type (those in Symbols.VALID_LAYERS) given a layer type list. + + Each entry of `layer_type_list` is expected to be a 1-character layer + symbol (a normal layer) or a multi-character string (a fused layer where + each character names one fused sub-layer). Fused sub-layers all share the + same global layer index – they are contained in one physical block – but + contribute to their respective per-type maps independently. 
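+
+    Example (illustrative), for a three-block segment whose middle block is a
+    fused Attention+MLP group:
+
+        >>> maps = get_layer_maps_from_layer_type_list(['M', '*-', 'M'])
+        >>> maps['M']
+        {0: 0, 2: 1}
+        >>> maps['*'], maps['-']
+        ({1: 0}, {1: 0})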
""" layer_types = [symbol for symbol in Symbols.name_sorted_valid_layer_symbols()] layer_maps = {layer_type: {} for layer_type in layer_types} - for global_layer_idx, layer_type in enumerate(layer_type_list): - layer_map = layer_maps[layer_type] - local_layer_idx = len(layer_map) - layer_map[global_layer_idx] = local_layer_idx + for global_layer_idx, layer_entry in enumerate(layer_type_list): + for layer_type in layer_entry: + layer_map = layer_maps[layer_type] + local_layer_idx = len(layer_map) + layer_map[global_layer_idx] = local_layer_idx return layer_maps diff --git a/megatron/core/models/hybrid/hybrid_layer_fusion.py b/megatron/core/models/hybrid/hybrid_layer_fusion.py new file mode 100644 index 00000000000..84091c70bc5 --- /dev/null +++ b/megatron/core/models/hybrid/hybrid_layer_fusion.py @@ -0,0 +1,349 @@ +# Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + +"""Construction of fused hybrid-pattern layers. + +A hybrid layer pattern may contain bracketed groups like `[*-]` or `[ME]` +that instruct `HybridStack` to fuse two adjacent layers into a single +`TransformerLayer`. This module owns the logic that takes such a +multi-symbol entry and builds the corresponding `TransformerLayer` at +runtime – the sequence mixer (first symbol) becomes `self_attention` and +the channel mixer (second symbol) becomes `mlp`. + +The public entry point is `build_fused_layer`, called by +`megatron.core.models.hybrid.hybrid_block.HybridStack` when it +encounters a multi-character entry in its `layer_type_list`. +""" + +from typing import TYPE_CHECKING + +from megatron.core.dist_checkpointing.mapping import ShardedStateDict +from megatron.core.dist_checkpointing.utils import apply_prefix_mapping +from megatron.core.extensions.transformer_engine import TENorm +from megatron.core.fusions.fused_bias_dropout import get_bias_dropout_add +from megatron.core.models.hybrid.hybrid_layer_allocation import Symbols as LayerSymbols +from megatron.core.process_groups_config import ProcessGroupCollection +from megatron.core.ssm.mamba_mixer import MambaMixer +from megatron.core.transformer import TransformerConfig +from megatron.core.transformer.identity_op import IdentityOp +from megatron.core.transformer.spec_utils import ModuleSpec, build_module +from megatron.core.transformer.transformer_layer import TransformerLayer, TransformerLayerSubmodules + +if TYPE_CHECKING: + # Avoid a circular import at runtime: `HybridStackSubmodules` lives in + # `hybrid_block` which imports from this module. + from megatron.core.models.hybrid.hybrid_block import HybridStackSubmodules + + +# Sequence mixers legal as the first element of a fusion group; maps the +# pattern symbol to the attribute on `HybridStackSubmodules` that supplies +# the primitive spec used in the `self_attention` slot of the fused +# `TransformerLayer`. +_FUSION_SEQUENCE_MIXERS: dict[str, str] = { + LayerSymbols.MAMBA: "mamba_mixer", + LayerSymbols.GDN: "gdn_mixer", + LayerSymbols.ATTENTION: "attention_mixer", + LayerSymbols.DS_ATTENTION: "dsa_mixer", +} + +# Channel mixers legal as the second element of a fusion group; maps the +# pattern symbol to the attribute on `HybridStackSubmodules` that supplies +# the primitive spec used in the `mlp` slot of the fused `TransformerLayer`. +_FUSION_CHANNEL_MIXERS: dict[str, str] = { + LayerSymbols.MLP: "mlp_mixer", + LayerSymbols.MOE: "moe_mixer", +} + + +class MambaMixerForTransformerLayer(MambaMixer): + """`MambaMixer` adapted for use as `TransformerLayer.self_attention`. 
+ + `MambaMixer` isn't quite drop-in for TransformerLayer's self-attention + slot: its constructor requires a `d_model` kwarg that TransformerLayer + does not forward, and its `forward` only accepts `hidden_states, + inference_context, packed_seq_params` while `TransformerLayer.forward` + hands its `self_attention` a richer set (`attention_mask`, + `rotary_pos_*`, `attention_bias`, `sequence_len_offset`). + + This subclass bridges both gaps without touching the underlying Mamba + mechanism: + + - `__init__` defaults `d_model` to `config.hidden_size` when the caller + didn't supply it; an explicit `d_model` still wins. + - `forward` accepts every kwarg that `TransformerLayer.forward` + currently emits (plus a `**_unused_future_kwargs` sink for forward + compatibility), but only forwards the three that `MambaMixer` uses. + + Stand-alone `MambaLayer` paths continue to use the plain `MambaMixer` + class – this wrapper only sits in the fusion code path. + """ + + def __init__(self, config, submodules, **kwargs): + # TransformerLayer does not forward `d_model` when constructing + # `self_attention`; fall back to the model hidden size. + kwargs.setdefault("d_model", config.hidden_size) + super().__init__(config, submodules, **kwargs) + + def forward( + self, + hidden_states, + # kwargs MambaMixer actually uses: + inference_context=None, + packed_seq_params=None, + *, + inference_params=None, + # kwargs TransformerLayer.forward hands to self_attention that + # MambaMixer does not need – absorbed and ignored. Listed + # explicitly so the contract with TransformerLayer is visible here. + attention_mask=None, + rotary_pos_emb=None, + rotary_pos_cos=None, + rotary_pos_sin=None, + rotary_pos_cos_sin=None, + attention_bias=None, + sequence_len_offset=None, + **_unused_future_kwargs, + ): + """Dispatch to `MambaMixer.forward`, discarding unused arguments.""" + return super().forward( + hidden_states, + inference_context=inference_context, + packed_seq_params=packed_seq_params, + inference_params=inference_params, + ) + + +def build_fused_layer( + fused_symbols: str, + submodules: "HybridStackSubmodules", + config: TransformerConfig, + layer_number: int, + pg_collection: ProcessGroupCollection, + pp_layer_offset: int, + is_mtp_layer: bool, + add_layer_offset: bool, +): + """Build a single fused TransformerLayer from a multi-symbol pattern entry. + + A pattern like `[*-]` is passed in as `fused_symbols="*-"` and becomes + a standard `TransformerLayer` whose `self_attention` slot is the + sequence-mixer primitive for the first symbol and whose `mlp` slot is + the channel-mixer primitive for the second. + + Not every primitive has the same `__init__` signature. `TransformerLayer` + forwards only a small, fixed set of kwargs to its `self_attention` slot + (`config, layer_number, pg_collection, pp_layer_offset, cp_comm_type`), + so we pick appropriate kwargs to pass to the outer `TransformerLayer` + based on what the inner sequence mixer will actually accept, mirroring + the existing per-primitive dispatch in `HybridStack`. + + `MambaMixer`'s signature mismatches (needs `d_model`, forward rejects + `TransformerLayer`'s extra kwargs) are resolved by + `MambaMixerForTransformerLayer`, a thin `MambaMixer` subclass wired up + as the `mamba_mixer` entry in `hybrid_stack_spec.submodules`. + + Args: + fused_symbols: The fused group contents (e.g., "*-", "ME"). + submodules: The `HybridStackSubmodules` carrying the primitive + mixer specs to draw from. + config: The shared `TransformerConfig`. 
+ layer_number: 1-indexed global layer number for this block + (`pp_layer_offset` already added by the caller). + pg_collection: Process-group collection. + pp_layer_offset: Offset to add to layer numbers for pipeline stages. + is_mtp_layer: Whether this block sits in an MTP stage. + add_layer_offset: Whether the enclosing `TransformerLayer` should + add its own pipeline offset to `layer_number`. Callers inside + `HybridStack` have already included the offset, so they + pass `False` here – matching the existing stand-alone dispatches. + + Raises: + ValueError: If the fused group is not exactly two symbols, or if the + first symbol is not a sequence mixer, or if the second symbol is + not a channel mixer. + """ + if len(fused_symbols) != 2: + raise ValueError( + f"Hybrid-layer fusion currently supports exactly two fused layers, " + f"but got {len(fused_symbols)} in group '[{fused_symbols}]'. The " + f"first must be a sequence mixer " + f"({sorted(_FUSION_SEQUENCE_MIXERS)}) and the second must be a " + f"channel mixer ({sorted(_FUSION_CHANNEL_MIXERS)})." + ) + + seq_sym, chan_sym = fused_symbols[0], fused_symbols[1] + if seq_sym not in _FUSION_SEQUENCE_MIXERS: + raise ValueError( + f"Hybrid-layer fusion requires the first fused layer to be a " + f"sequence mixer (one of {sorted(_FUSION_SEQUENCE_MIXERS)}), but " + f"got '{seq_sym}' in group '[{fused_symbols}]'." + ) + if chan_sym not in _FUSION_CHANNEL_MIXERS: + raise ValueError( + f"Hybrid-layer fusion requires the second fused layer to be a " + f"channel mixer (one of {sorted(_FUSION_CHANNEL_MIXERS)}), but got " + f"'{chan_sym}' in group '[{fused_symbols}]'." + ) + + self_attention = getattr(submodules, _FUSION_SEQUENCE_MIXERS[seq_sym]) + mlp = getattr(submodules, _FUSION_CHANNEL_MIXERS[chan_sym]) + + # Norms that are not already fused into a primitive's linear layer need + # to be supplied externally by the enclosing TransformerLayer. Currently + # that is only DSA (input layernorm) and MoE (pre-MLP layernorm). + input_layernorm = TENorm if seq_sym == LayerSymbols.DS_ATTENTION else IdentityOp + pre_mlp_layernorm = TENorm if chan_sym == LayerSymbols.MOE else IdentityOp + + fused_spec = ModuleSpec( + module=TransformerLayer, + submodules=TransformerLayerSubmodules( + input_layernorm=input_layernorm, + self_attention=self_attention, + self_attn_bda=get_bias_dropout_add, + pre_mlp_layernorm=pre_mlp_layernorm, + mlp=mlp, + mlp_bda=get_bias_dropout_add, + ), + ) + + # Build kwargs that match the inner sequence mixer's signature + # + # TransformerLayer forwards these to `self_attention` (minus the ones + # TransformerLayer consumes itself: `is_mtp_layer`, `add_layer_offset`). + # The stand-alone dispatches in HybridStack pick which optional kwargs to + # pass based on the primitive – we do the same here. + build_kwargs: dict = dict( + config=config, + layer_number=layer_number, + pg_collection=pg_collection, + is_mtp_layer=is_mtp_layer, + add_layer_offset=add_layer_offset, + ) + # GatedDeltaNet.__init__ does not accept `pp_layer_offset`; SelfAttention, + # MLASelfAttention, and MambaMixer do. Mirror the existing stand-alone + # dispatches by conditionally including it. + if seq_sym != LayerSymbols.GDN: + build_kwargs["pp_layer_offset"] = pp_layer_offset + + return build_module(fused_spec, **build_kwargs) + + +# Canonical slot name each layer symbol uses in its stand-alone +# sharded-state-dict output – i.e., the attribute path under which the +# primitive's weights live when the block contains just that symbol. 
For a +# fused `[XY]` block, the same primitives sit under `self_attention` (for the +# sequence mixer X) and `mlp` (for the channel mixer Y); canonicalization +# uses this table to rewrite fused keys back into the stand-alone layout so +# fused and unfused patterns produce the same checkpoint keys. Only Mamba +# needs an intra-block rename (`self_attention.` -> `mixer.`); every other +# sequence mixer is stand-alone-hosted in a TransformerLayer whose +# `self_attention` slot already matches the fused layout. +_CANONICAL_SLOT_FOR_SYMBOL: dict[str, str] = { + LayerSymbols.MAMBA: "mixer", + LayerSymbols.GDN: "self_attention", + LayerSymbols.ATTENTION: "self_attention", + LayerSymbols.DS_ATTENTION: "self_attention", + LayerSymbols.MLP: "mlp", + LayerSymbols.MOE: "mlp", +} + + +def canonicalize_hybrid_sharded_state_dict( + sharded_state_dict: ShardedStateDict, + layer_prefix: str, + layer_type_list: list[str], + physical_offset: int = 0, + sub_layer_offset: int = 0, +) -> None: + """Rewrite HybridStack layer keys into the canonical (unfused) layout, in place. + + `HybridStack.sharded_state_dict` emits keys indexed by global physical + block position within the model (a fused `[XY]` group still occupies a + single physical block). Fused blocks are realized as `TransformerLayer`s + whose `self_attention` slot holds the sequence mixer and `mlp` slot + holds the channel mixer, so their keys do not match what a stand-alone + `X` followed by stand-alone `Y` would produce. This function rewrites + each fused block's keys into two sub-layer-indexed prefixes that + do match: `layers.{sub_layer_offset + i}.mixer.*` for mamba sub-layers, + `layers.{sub_layer_offset + i}.mlp.*` for MLP, etc. Stand-alone blocks + are simply re-indexed from physical to sub-layer index. + + The resulting keys are fusion-independent – a checkpoint written with + `[*-]M` and one written with `*-M` end up with the same set of keys, so + the dist_checkpointing layer can load either into either. + + Args: + sharded_state_dict: The sharded state dict to rewrite in place. Only + entries whose keys start with `layer_prefix` are touched. + layer_prefix: The full prefix up to and including `"layers."`, e.g. + `"decoder.layers."`. Keys outside this prefix are left alone. + layer_type_list: The per-physical-block layer-type symbols for this + pipeline segment (a single `"M"`, `"*"`, etc. for a stand-alone + block; a two-char string like `"*-"` for a fused group). + physical_offset: The global physical-block index at which this + pipeline segment starts (i.e., the value `HybridStack` uses to + derive each layer's `layer_number`). Defaults to 0, which is + correct for non-pipeline-parallel runs. + sub_layer_offset: The global sub-layer index at which this pipeline + segment starts. Accounts for sub-layers contributed by earlier + pipeline segments so that fused groups in those segments are + correctly counted. When the model is not pipeline-parallel (or + no earlier segment contains a fusion group), this is equal to + `physical_offset` and may be left at its default. + + Notes: + - Build up a combined prefix map across all layers and apply it in + one `apply_prefix_mapping` pass. This keeps the rewrite narrow + (sibling keys outside `layer_prefix` are untouched) and lets the + function safely run on the full model state dict without tripping + on embedding or output-layer entries. + - Order matters inside the combined prefix map: `apply_prefix_mapping` + picks the first matching prefix, so the specific sub-prefixes + (e.g. 
`"input_layernorm."`) are inserted before the bare block + prefix fallback. + - Norms attached to the X sub-layer (e.g. `input_layernorm` for DSA) + stay with X's sub-layer index; norms attached to Y (e.g. + `pre_mlp_layernorm` for MoE) attach to Y's sub-layer index. + """ + prefix_map: dict[str, str] = {} + sub_layer_cursor = sub_layer_offset + + for local_layer_idx, layer_type in enumerate(layer_type_list): + physical_prefix = f'{layer_prefix}{physical_offset + local_layer_idx}.' + + if len(layer_type) == 1: + # Stand-alone block: one physical block == one sub-layer, and the + # block's attribute layout already matches the canonical + # stand-alone layout. Only the outer block index needs to move + # from the local module-list index to the global sub-layer index. + canonical_prefix = f'{layer_prefix}{sub_layer_cursor}.' + prefix_map[physical_prefix] = canonical_prefix + sub_layer_cursor += 1 + else: + # Fused block `[XY]`: split the single physical block's keys into + # two sub-layer prefixes so the checkpoint looks exactly as it + # would for stand-alone `X` followed by stand-alone `Y`. + x_sym, y_sym = layer_type[0], layer_type[1] + canonical_x_prefix = f'{layer_prefix}{sub_layer_cursor}.' + canonical_y_prefix = f'{layer_prefix}{sub_layer_cursor + 1}.' + slot_for_x = _CANONICAL_SLOT_FOR_SYMBOL[x_sym] + slot_for_y = _CANONICAL_SLOT_FOR_SYMBOL[y_sym] + + # Specific sub-prefixes before the bare block prefix fallback. + prefix_map[f'{physical_prefix}input_layernorm.'] = ( + f'{canonical_x_prefix}input_layernorm.' + ) + prefix_map[f'{physical_prefix}self_attention.'] = f'{canonical_x_prefix}{slot_for_x}.' + prefix_map[f'{physical_prefix}self_attn_bda.'] = f'{canonical_x_prefix}self_attn_bda.' + prefix_map[f'{physical_prefix}pre_mlp_layernorm.'] = ( + f'{canonical_y_prefix}pre_mlp_layernorm.' + ) + prefix_map[f'{physical_prefix}mlp.'] = f'{canonical_y_prefix}{slot_for_y}.' + prefix_map[f'{physical_prefix}mlp_bda.'] = f'{canonical_y_prefix}mlp_bda.' + # Fallback for any stray top-level fused-block keys (e.g., + # `_extra_state` attached to the TransformerLayer itself); + # attach them to X's sub-layer index by convention. + prefix_map[physical_prefix] = canonical_x_prefix + sub_layer_cursor += 2 + + if prefix_map: + apply_prefix_mapping(sharded_state_dict, prefix_map) diff --git a/megatron/core/models/hybrid/hybrid_layer_specs.py b/megatron/core/models/hybrid/hybrid_layer_specs.py index a34a45a32ba..fe70914c489 100755 --- a/megatron/core/models/hybrid/hybrid_layer_specs.py +++ b/megatron/core/models/hybrid/hybrid_layer_specs.py @@ -14,6 +14,7 @@ get_moe_module_spec, ) from megatron.core.models.hybrid.hybrid_block import HybridStack, HybridStackSubmodules +from megatron.core.models.hybrid.hybrid_layer_fusion import MambaMixerForTransformerLayer from megatron.core.ssm.gated_delta_net import GatedDeltaNet, GatedDeltaNetSubmodules from megatron.core.ssm.mamba_layer import MambaLayer, MambaLayerSubmodules from megatron.core.ssm.mamba_mixer import MambaMixer, MambaMixerSubmodules @@ -61,6 +62,154 @@ moe_inference = get_inference_optimized_moe_spec() +# Primitive (non-block) specs used when building a fused TransformerLayer at +# runtime. Each primitive is the exact same `ModuleSpec` that already sits +# inside the corresponding stand-alone block spec below – defining them once +# up here lets both places share a single source of truth and lets +# `hybrid_block.py` assemble a `TransformerLayer` on demand without knowing +# the internal wiring of any individual primitive. 
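+#
+# For example (illustrative): a fused "[*-]" block assembled by
+# `build_fused_layer` ends up with `_attention_mixer_spec` in its
+# `self_attention` slot and `_mlp_mixer_spec` in its `mlp` slot, the same
+# specs the stand-alone `attention_layer` and `mlp_layer` blocks below embed.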
+ +# Sequence mixers (fill the `self_attention` slot of a `TransformerLayer`). +_mamba_mixer_spec = ModuleSpec( + module=MambaMixer, + submodules=MambaMixerSubmodules( + in_proj=TELayerNormColumnParallelLinear, out_proj=TERowParallelLinear + ), +) + +# Fusion-slot variant: same submodules as `_mamba_mixer_spec`, but the +# top-level module is the `MambaMixerForTransformerLayer` subclass that +# adapts `__init__` and `forward` signatures to what TransformerLayer expects. +_mamba_mixer_fusion_spec = ModuleSpec( + module=MambaMixerForTransformerLayer, + submodules=MambaMixerSubmodules( + in_proj=TELayerNormColumnParallelLinear, out_proj=TERowParallelLinear + ), +) + +_gdn_mixer_spec = ModuleSpec( + module=GatedDeltaNet, + submodules=GatedDeltaNetSubmodules( + in_proj=TELayerNormColumnParallelLinear, out_norm=TENorm, out_proj=TERowParallelLinear + ), +) + +_attention_mixer_spec = ModuleSpec( + module=SelfAttention, + params={"attn_mask_type": AttnMaskType.causal}, + submodules=SelfAttentionSubmodules( + linear_qkv=TELayerNormColumnParallelLinear, + core_attention=TEDotProductAttention, + linear_proj=TERowParallelLinear, + ), +) + +_dsa_mixer_spec = ModuleSpec( + module=MLASelfAttention, + params={"attn_mask_type": AttnMaskType.causal}, + submodules=MLASelfAttentionSubmodules( + linear_q_proj=TEColumnParallelLinear, + linear_q_down_proj=TELinear, + linear_q_up_proj=TEColumnParallelLinear, + linear_kv_down_proj=TELinear, + linear_kv_up_proj=TEColumnParallelLinear, + core_attention=ModuleSpec( + module=DSAttention, + submodules=DSAttentionSubmodules( + indexer=ModuleSpec( + module=DSAIndexer, + submodules=DSAIndexerSubmodules( + linear_wq_b=TELinear, + linear_wk=TELinear, + k_norm=TENorm, + linear_weights_proj=TELinear, + ), + ) + ), + ), + linear_proj=TERowParallelLinear, + q_layernorm=IdentityOp, + kv_layernorm=IdentityOp, + ), +) + +# Channel mixers (fill the `mlp` slot of a `TransformerLayer`). +_mlp_mixer_spec = ModuleSpec( + module=MLP, + submodules=MLPSubmodules( + linear_fc1=TELayerNormColumnParallelLinear, linear_fc2=TERowParallelLinear + ), +) + +_moe_mixer_spec = moe + +# Inference-variant primitives – identical shape, but route through the +# inference-optimised linear classes where the training specs use TE ones. +_mamba_mixer_inference_spec = ModuleSpec( + module=MambaMixer, + submodules=MambaMixerSubmodules( + in_proj=InferenceLayerNormColumnParallelLinear, out_proj=InferenceRowParallelLinear + ), +) + +# Fusion-slot inference variant: same as `_mamba_mixer_inference_spec` but +# with the TransformerLayer-adapted subclass as the top-level module. 
+_mamba_mixer_inference_fusion_spec = ModuleSpec( + module=MambaMixerForTransformerLayer, + submodules=MambaMixerSubmodules( + in_proj=InferenceLayerNormColumnParallelLinear, out_proj=InferenceRowParallelLinear + ), +) + +_attention_mixer_inference_spec = ModuleSpec( + module=SelfAttention, + params={"attn_mask_type": AttnMaskType.causal}, + submodules=SelfAttentionSubmodules( + linear_qkv=InferenceLayerNormColumnParallelLinear, + core_attention=TEDotProductAttention, + linear_proj=InferenceRowParallelLinear, + ), +) + +_dsa_mixer_inference_spec = ModuleSpec( + module=MLASelfAttention, + params={"attn_mask_type": AttnMaskType.causal}, + submodules=MLASelfAttentionSubmodules( + linear_q_proj=TEColumnParallelLinear, + linear_q_down_proj=TELinear, + linear_q_up_proj=TEColumnParallelLinear, + linear_kv_down_proj=TELinear, + linear_kv_up_proj=TEColumnParallelLinear, + core_attention=ModuleSpec( + module=DSAttention, + submodules=DSAttentionSubmodules( + indexer=ModuleSpec( + module=DSAIndexer, + submodules=DSAIndexerSubmodules( + linear_wq_b=TELinear, + linear_wk=TELinear, + k_norm=TENorm, + linear_weights_proj=TELinear, + ), + ) + ), + ), + linear_proj=InferenceRowParallelLinear, + q_layernorm=IdentityOp, + kv_layernorm=IdentityOp, + ), +) + +_mlp_mixer_inference_spec = ModuleSpec( + module=MLP, + submodules=MLPSubmodules( + linear_fc1=InferenceLayerNormColumnParallelLinear, linear_fc2=InferenceRowParallelLinear + ), +) + +_moe_mixer_inference_spec = moe_inference + + # MTP block spec - provides norms and projection only. # Inner layers are built by MultiTokenPredictionLayer using nested HybridStack _hybrid_mtp_block_spec = ModuleSpec( @@ -88,27 +237,13 @@ mamba_layer=ModuleSpec( module=MambaLayer, submodules=MambaLayerSubmodules( - mixer=ModuleSpec( - module=MambaMixer, - submodules=MambaMixerSubmodules( - in_proj=TELayerNormColumnParallelLinear, out_proj=TERowParallelLinear - ), - ), - mamba_bda=get_bias_dropout_add, + mixer=_mamba_mixer_spec, mamba_bda=get_bias_dropout_add ), ), gdn_layer=ModuleSpec( module=TransformerLayer, submodules=TransformerLayerSubmodules( - self_attention=ModuleSpec( - module=GatedDeltaNet, - submodules=GatedDeltaNetSubmodules( - in_proj=TELayerNormColumnParallelLinear, - out_norm=TENorm, - out_proj=TERowParallelLinear, - ), - ), - self_attn_bda=get_bias_dropout_add, + self_attention=_gdn_mixer_spec, self_attn_bda=get_bias_dropout_add ), ), # Started with spec from gpt_layer_specs.py (with MLP removed) @@ -117,50 +252,14 @@ attention_layer=ModuleSpec( module=TransformerLayer, submodules=TransformerLayerSubmodules( - self_attention=ModuleSpec( - module=SelfAttention, - params={"attn_mask_type": AttnMaskType.causal}, - submodules=SelfAttentionSubmodules( - linear_qkv=TELayerNormColumnParallelLinear, - core_attention=TEDotProductAttention, - linear_proj=TERowParallelLinear, - ), - ), - self_attn_bda=get_bias_dropout_add, + self_attention=_attention_mixer_spec, self_attn_bda=get_bias_dropout_add ), ), dsa_layer=ModuleSpec( module=TransformerLayer, submodules=TransformerLayerSubmodules( input_layernorm=TENorm, - self_attention=ModuleSpec( - module=MLASelfAttention, - params={"attn_mask_type": AttnMaskType.causal}, - submodules=MLASelfAttentionSubmodules( - linear_q_proj=TEColumnParallelLinear, - linear_q_down_proj=TELinear, - linear_q_up_proj=TEColumnParallelLinear, - linear_kv_down_proj=TELinear, - linear_kv_up_proj=TEColumnParallelLinear, - core_attention=ModuleSpec( - module=DSAttention, - submodules=DSAttentionSubmodules( - indexer=ModuleSpec( - module=DSAIndexer, - 
submodules=DSAIndexerSubmodules( - linear_wq_b=TELinear, - linear_wk=TELinear, - k_norm=TENorm, - linear_weights_proj=TELinear, - ), - ) - ), - ), - linear_proj=TERowParallelLinear, - q_layernorm=IdentityOp, - kv_layernorm=IdentityOp, - ), - ), + self_attention=_dsa_mixer_spec, self_attn_bda=get_bias_dropout_add, ), ), @@ -170,22 +269,28 @@ mlp_layer=ModuleSpec( module=MLPLayer, submodules=TransformerLayerSubmodules( - mlp=ModuleSpec( - module=MLP, - submodules=MLPSubmodules( - linear_fc1=TELayerNormColumnParallelLinear, linear_fc2=TERowParallelLinear - ), - ), - mlp_bda=get_bias_dropout_add, + mlp=_mlp_mixer_spec, mlp_bda=get_bias_dropout_add ), ), moe_layer=ModuleSpec( module=MoETransformerLayer, submodules=TransformerLayerSubmodules( - pre_mlp_layernorm=TENorm, mlp=moe, mlp_bda=get_bias_dropout_add + pre_mlp_layernorm=TENorm, mlp=_moe_mixer_spec, mlp_bda=get_bias_dropout_add ), ), mtp_block_spec=_hybrid_mtp_block_spec, + # Primitives reused when the hybrid-layer pattern fuses two adjacent + # layers via the `[XY]` syntax; see `hybrid_layer_fusion.build_fused_layer`. + # Sequence mixers. MambaMixer uses the TransformerLayer-adapter + # subclass so its __init__/forward signatures match what + # TransformerLayer emits in its `self_attention` slot. + mamba_mixer=_mamba_mixer_fusion_spec, + gdn_mixer=_gdn_mixer_spec, + attention_mixer=_attention_mixer_spec, + dsa_mixer=_dsa_mixer_spec, + # Channel mixers + mlp_mixer=_mlp_mixer_spec, + moe_mixer=_moe_mixer_spec, ), ) @@ -196,14 +301,7 @@ mamba_layer=ModuleSpec( module=MambaLayer, submodules=MambaLayerSubmodules( - mixer=ModuleSpec( - module=MambaMixer, - submodules=MambaMixerSubmodules( - in_proj=InferenceLayerNormColumnParallelLinear, - out_proj=InferenceRowParallelLinear, - ), - ), - mamba_bda=get_bias_dropout_add, + mixer=_mamba_mixer_inference_spec, mamba_bda=get_bias_dropout_add ), ), # Started with spec from gpt_layer_specs.py (with MLP removed) @@ -212,50 +310,14 @@ attention_layer=ModuleSpec( module=TransformerLayer, submodules=TransformerLayerSubmodules( - self_attention=ModuleSpec( - module=SelfAttention, - params={"attn_mask_type": AttnMaskType.causal}, - submodules=SelfAttentionSubmodules( - linear_qkv=InferenceLayerNormColumnParallelLinear, - core_attention=TEDotProductAttention, - linear_proj=InferenceRowParallelLinear, - ), - ), - self_attn_bda=get_bias_dropout_add, + self_attention=_attention_mixer_inference_spec, self_attn_bda=get_bias_dropout_add ), ), dsa_layer=ModuleSpec( module=TransformerLayer, submodules=TransformerLayerSubmodules( input_layernorm=TENorm, - self_attention=ModuleSpec( - module=MLASelfAttention, - params={"attn_mask_type": AttnMaskType.causal}, - submodules=MLASelfAttentionSubmodules( - linear_q_proj=TEColumnParallelLinear, - linear_q_down_proj=TELinear, - linear_q_up_proj=TEColumnParallelLinear, - linear_kv_down_proj=TELinear, - linear_kv_up_proj=TEColumnParallelLinear, - core_attention=ModuleSpec( - module=DSAttention, - submodules=DSAttentionSubmodules( - indexer=ModuleSpec( - module=DSAIndexer, - submodules=DSAIndexerSubmodules( - linear_wq_b=TELinear, - linear_wk=TELinear, - k_norm=TENorm, - linear_weights_proj=TELinear, - ), - ) - ), - ), - linear_proj=InferenceRowParallelLinear, - q_layernorm=IdentityOp, - kv_layernorm=IdentityOp, - ), - ), + self_attention=_dsa_mixer_inference_spec, self_attn_bda=get_bias_dropout_add, ), ), @@ -265,21 +327,16 @@ mlp_layer=ModuleSpec( module=MLPLayer, submodules=TransformerLayerSubmodules( - mlp=ModuleSpec( - module=MLP, - submodules=MLPSubmodules( - 
linear_fc1=InferenceLayerNormColumnParallelLinear, - linear_fc2=InferenceRowParallelLinear, - ), - ), - mlp_bda=get_bias_dropout_add, + mlp=_mlp_mixer_inference_spec, mlp_bda=get_bias_dropout_add ), ), moe_layer=ModuleSpec( # Use inference-optimized MoE layer for end-to-end CUDA graph support module=TransformerLayer, submodules=TransformerLayerSubmodules( - pre_mlp_layernorm=TENorm, mlp=moe_inference, mlp_bda=get_bias_dropout_add + pre_mlp_layernorm=TENorm, + mlp=_moe_mixer_inference_spec, + mlp_bda=get_bias_dropout_add, ), ), mtp_block_spec=ModuleSpec( @@ -299,6 +356,14 @@ ] ), ), + # Inference variants of the fusion primitives; GDN is intentionally + # omitted (the inference stack does not support GDN layers). MambaMixer + # uses the TransformerLayer-adapter subclass, same as the training spec. + mamba_mixer=_mamba_mixer_inference_fusion_spec, + attention_mixer=_attention_mixer_inference_spec, + dsa_mixer=_dsa_mixer_inference_spec, + mlp_mixer=_mlp_mixer_inference_spec, + moe_mixer=_moe_mixer_inference_spec, ), ) diff --git a/megatron/core/models/hybrid/hybrid_model.py b/megatron/core/models/hybrid/hybrid_model.py index 88a97ec777f..1b090d2b5c6 100644 --- a/megatron/core/models/hybrid/hybrid_model.py +++ b/megatron/core/models/hybrid/hybrid_model.py @@ -177,6 +177,7 @@ def __init__( # Parse unified pattern to extract main and MTP components, and # determine the pipeline segment for this model instance. from megatron.core.models.hybrid.hybrid_layer_allocation import ( + get_sub_layer_offset, parse_hybrid_pattern, select_pipeline_segment, ) @@ -185,13 +186,24 @@ def __init__( self.mtp_pattern = parsed.mtp_pattern self.mtp_num_depths = parsed.mtp_num_depths + main_pattern = parsed.main_pattern or '' layer_type_list, layer_offset = select_pipeline_segment( - parsed.main_pattern or '', + main_pattern, self.pg_collection.pp, vp_stage, first_stage_layers=self.config.num_layers_in_first_pipeline_stage, last_stage_layers=self.config.num_layers_in_last_pipeline_stage, ) + # Read at checkpoint save/load time by + # `megatron.training.checkpointing._apply_hybrid_canonicalization_if_applicable`, + # which rewrites the decoder's sharded keys into a fusion-independent + # layout so checkpoints saved under one fusion configuration load + # cleanly under another. `_decoder_physical_offset` mirrors `layer_offset` + # (physical-block index where this pipeline segment starts); + # `_decoder_sub_layer_offset` is its sub-layer counterpart, counting + # each character of a fused `[XY]` group separately. + self._decoder_physical_offset = layer_offset + self._decoder_sub_layer_offset = get_sub_layer_offset(main_pattern, layer_offset) # Determine if MTP is needed (based on pattern parsing) self.mtp_process = ( diff --git a/megatron/training/arguments.py b/megatron/training/arguments.py index 5f1968dcc27..caba9229148 100644 --- a/megatron/training/arguments.py +++ b/megatron/training/arguments.py @@ -685,11 +685,18 @@ def validate_args(args, defaults={}): if args.hybrid_layer_pattern is not None: # Derive num_layers from pattern; hybrid_layer_pattern always overrides --num-layers when # both are present (e.g. when loading from checkpoint with --use-checkpoint-args). + # NOTE: `num_layers_in_pattern` counts physical transformer blocks: a fusion group + # `[XY]` in the pattern contributes 1 (one fused TransformerLayer), not 2. 
This is the + # right count for --pipeline-model-parallel-size divisibility and for sizing nn.Module + # containers, but it is not the sub-layer / compute-unit count: FLOPs or parameter + # budgeting should count sub-symbols (see `get_hybrid_layer_counts()` for per-type + # sub-layer counts). num_layers_in_pattern = get_hybrid_total_layer_count(args.hybrid_layer_pattern) if args.num_layers is not None and args.num_layers != num_layers_in_pattern: warn_rank_0( f'--hybrid-layer-pattern is set; ignoring --num-layers ({args.num_layers}) and ' - f'using the layer count derived from the pattern ({num_layers_in_pattern}).', + f'using the physical-block count derived from the pattern ' + f'({num_layers_in_pattern}). Fusion groups "[...]" each count as one block.', args.rank, ) args.num_layers = num_layers_in_pattern @@ -709,6 +716,25 @@ def validate_args(args, defaults={}): '--decoder-last-pipeline-num-layers should not be specified ' 'as the pipeline layout is explicitly defined.' ) + # Uneven PP + fusion: warn the user that first/last stage layer counts + # are measured in physical blocks, not compute units. A stage + # consisting of fused blocks does more work per block than one of + # stand-alone blocks, so a value picked against compute targets will + # under-count the actual load on the uneven stages. + pattern_has_fusion = Symbols.FUSION_START in args.hybrid_layer_pattern + if pattern_has_fusion and ( + args.decoder_first_pipeline_num_layers is not None + or args.decoder_last_pipeline_num_layers is not None + ): + warn_rank_0( + 'Using --decoder-first/last-pipeline-num-layers with a ' + '--hybrid-layer-pattern that contains fusion groups "[...]": ' + 'the arguments count physical blocks, not sub-layer compute ' + 'units, so fused stages will do more work per block than ' + 'stand-alone stages. Adjust the values to compensate if ' + 'balancing by compute rather than by block count.', + args.rank, + ) assert args.num_layers_per_virtual_pipeline_stage is None, ( '--num-layers-per-virtual-pipeline-stage should not be used with ' '--hybrid-layer-pattern. To specify virtual pipelining, describe a number of ' diff --git a/megatron/training/checkpointing.py b/megatron/training/checkpointing.py index 1441a71518d..971e2554167 100644 --- a/megatron/training/checkpointing.py +++ b/megatron/training/checkpointing.py @@ -983,6 +983,53 @@ def maybe_save_dataloader_state(train_iterator, iteration, dataloader_save_path) torch.save(dataloader_save_dict, data_state_save_path) +def _apply_hybrid_canonicalization_if_applicable(model, model_sd): + """Rewrite HybridModel decoder keys into the fusion-independent layout. + + A `HybridModel` whose `--hybrid-layer-pattern` contains fused groups + (e.g. `[*-]M[*-]M`) builds each fused group as a single + `TransformerLayer` in the decoder's `nn.ModuleList`. Its + `sharded_state_dict` therefore emits keys indexed by physical block + position and uses the `TransformerLayer` slot names + (`self_attention`, `mlp`). For the checkpoint on disk to be + fusion-independent – so a checkpoint saved under one fusion placement + can be loaded under another – those keys need to be rewritten into the + layout a stand-alone unfused pattern would produce (sub-layer + indexing, `mixer.*` for mamba sub-layers, etc.). That rewrite is a + serialization concern only: it happens here, at the boundary between + the in-memory model state and the on-disk checkpoint, and does not + leak into model construction or forward. 
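+
+    For example (illustrative), with the non-pipeline pattern "[M-]*" the
+    decoder holds two physical blocks, and the rewrite maps
+    `decoder.layers.0.self_attention.*` -> `decoder.layers.0.mixer.*`,
+    `decoder.layers.0.mlp.*` -> `decoder.layers.1.mlp.*`, and
+    `decoder.layers.1.*` -> `decoder.layers.2.*`, i.e. the key set that the
+    unfused pattern "M-*" would have produced.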
+ + No-op when `model` is not a `HybridModel` (or does not expose the + introspection attributes we need). + """ + from megatron.core.models.hybrid.hybrid_layer_fusion import ( + canonicalize_hybrid_sharded_state_dict, + ) + + unwrapped = unwrap_model(model) + # Duck-type on the HybridModel attributes we need; avoids importing + # HybridModel here (and therefore avoids the heavy downstream imports + # its module triggers). HybridModel populates these in `__init__`, so + # their presence is a reliable signal. + decoder = getattr(unwrapped, 'decoder', None) + if ( + decoder is None + or not hasattr(decoder, 'layer_type_list') + or not hasattr(unwrapped, '_decoder_sub_layer_offset') + or not hasattr(unwrapped, '_decoder_physical_offset') + ): + return + + canonicalize_hybrid_sharded_state_dict( + model_sd, + layer_prefix='decoder.layers.', + layer_type_list=decoder.layer_type_list, + physical_offset=unwrapped._decoder_physical_offset, + sub_layer_offset=unwrapped._decoder_sub_layer_offset, + ) + + def generate_state_dict( args, model, @@ -1016,6 +1063,17 @@ def generate_state_dict( } }) ) + # HybridModel with fused layer patterns emits decoder keys indexed + # by physical block position and with fused-TransformerLayer slot + # names (`self_attention`, `mlp`). Apply the fusion-independent + # canonicalization right here, at the serialization boundary, so + # that (a) the transformation is visible only in the state dict + # that flows into save/load and (b) the model itself stays + # unaware of checkpoint-specific key conventions. Runs on both + # the save path (via save_checkpoint) and the load path + # (via load_checkpoint, which also routes through generate_state_dict + # to build the expected state dict shape). + _apply_hybrid_canonicalization_if_applicable(model[i], model_sd) else: # torch, torch_dcp, fsdp_dtensor model_sd = model[i].state_dict_for_save_checkpoint() diff --git a/tests/unit_tests/ssm/test_hybrid_block.py b/tests/unit_tests/ssm/test_hybrid_block.py index 08bf7f2bc28..4e9cae6d017 100644 --- a/tests/unit_tests/ssm/test_hybrid_block.py +++ b/tests/unit_tests/ssm/test_hybrid_block.py @@ -103,6 +103,58 @@ def test_gpu_forward(self): assert output.shape[2] == block.config.hidden_size assert output.dtype == torch.float32 + @pytest.mark.parametrize( + ("layer_pattern", "expected_layer_type"), + [ + ( + Symbols.FUSION_START + Symbols.ATTENTION + Symbols.MLP + Symbols.FUSION_END, + Symbols.ATTENTION + Symbols.MLP, + ), + ( + Symbols.FUSION_START + Symbols.MAMBA + Symbols.MLP + Symbols.FUSION_END, + Symbols.MAMBA + Symbols.MLP, + ), + ], + ) + def test_fused_gpu_forward_backward(self, layer_pattern, expected_layer_type): + """Test CUDA forward+backward through a fused hybrid-pattern block.""" + block = self.get_mamba_block(layer_pattern) + assert block.layer_type_list == [expected_layer_type] + assert len(block.layers) == 1 + assert isinstance(block.layers[0], TransformerLayer) + + block.cuda() + block.train() + micro_batch_size = 2 + sequence_length = 32 + hidden_states = torch.randn( + (sequence_length, micro_batch_size, block.config.hidden_size), + device="cuda", + requires_grad=True, + ) + attention_mask = torch.ones( + (micro_batch_size, 1, sequence_length, sequence_length), dtype=bool, device="cuda" + ) + + output = block(hidden_states, attention_mask=attention_mask) + assert output.shape == hidden_states.shape + assert output.dtype == torch.float32 + + loss = output.float().square().mean() + loss.backward() + + assert hidden_states.grad is not None + assert 
torch.isfinite(hidden_states.grad).all().item() + + grads = [ + param.grad + for param in block.parameters() + if param.requires_grad and param.grad is not None + ] + assert grads + assert all(torch.isfinite(grad).all().item() for grad in grads) + assert any(torch.count_nonzero(grad).item() > 0 for grad in grads) + def test_layer_types(self): """ Make sure that the layer types specified with layer_pattern @@ -190,3 +242,485 @@ def test_mixed_attention_and_dsa_layer_types(self): layer_pattern = Symbols.MAMBA + Symbols.ATTENTION + Symbols.DS_ATTENTION + Symbols.MAMBA with pytest.raises(ValueError): block = self.get_dsa_mamba_block(layer_pattern) + + +@pytest.mark.internal +class TestFusedLayerValidation: + """Unit tests for the construction-time validation of fused layers. + + These tests exercise the failure paths in `build_fused_layer` directly + so they don't need a process group or CUDA set up. The happy path is + covered by higher-level forward tests in `TestHybridBlock`. + """ + + def _call(self, fused_symbols: str): + from megatron.core.models.hybrid.hybrid_block import HybridStackSubmodules + from megatron.core.models.hybrid.hybrid_layer_fusion import build_fused_layer + + # The mixer attributes are only read for valid fusion groups; the + # defaults (IdentityOp) are never touched for the error paths here. + return build_fused_layer( + fused_symbols, + submodules=HybridStackSubmodules(), + config=None, + layer_number=1, + pg_collection=None, + pp_layer_offset=0, + is_mtp_layer=False, + add_layer_offset=False, + ) + + def test_single_layer_fusion_rejected(self): + with pytest.raises(ValueError, match="exactly two fused layers"): + self._call(Symbols.ATTENTION) + + def test_three_layer_fusion_rejected(self): + with pytest.raises(ValueError, match="exactly two fused layers"): + self._call(Symbols.MAMBA + Symbols.ATTENTION + Symbols.MLP) + + def test_channel_mixer_first_rejected(self): + # MLP followed by attention – wrong order. + with pytest.raises(ValueError, match="first fused layer.*sequence mixer"): + self._call(Symbols.MLP + Symbols.ATTENTION) + + def test_sequence_mixer_second_rejected(self): + # Attention followed by Mamba – second slot must be a channel mixer. + with pytest.raises(ValueError, match="second fused layer.*channel mixer"): + self._call(Symbols.ATTENTION + Symbols.MAMBA) + + def test_two_channel_mixers_rejected(self): + with pytest.raises(ValueError, match="first fused layer.*sequence mixer"): + self._call(Symbols.MLP + Symbols.MOE) + + +@pytest.mark.internal +class TestMambaMixerForTransformerLayer: + """Unit tests for the `MambaMixerForTransformerLayer` adapter subclass. + + These probe only the adapter behaviour (kwarg defaulting in `__init__`, + kwarg filtering in `forward`), so they don't need CUDA, process groups, + or even a real `MambaMixer` instance – we poke at the class via stubbed-out + base methods so the tests also work when `mamba-ssm` isn't installed. 
+ """ + + def test_init_defaults_d_model_to_hidden_size(self): + from types import SimpleNamespace + + from megatron.core.models.hybrid.hybrid_layer_fusion import MambaMixerForTransformerLayer + from megatron.core.ssm import mamba_mixer as mm + + captured = {} + + def fake_init(self, config, submodules, **kwargs): + captured["config"] = config + captured["submodules"] = submodules + captured["kwargs"] = kwargs + + original = mm.MambaMixer.__init__ + mm.MambaMixer.__init__ = fake_init + try: + instance = object.__new__(MambaMixerForTransformerLayer) + # Simulate TransformerLayer's all-keyword call without `d_model`. + MambaMixerForTransformerLayer.__init__( + instance, + config=SimpleNamespace(hidden_size=128), + submodules="sub-sentinel", + layer_number=3, + pg_collection="pg-sentinel", + pp_layer_offset=0, + ) + assert captured["kwargs"]["d_model"] == 128 + assert captured["kwargs"]["layer_number"] == 3 + assert captured["kwargs"]["pg_collection"] == "pg-sentinel" + assert captured["kwargs"]["pp_layer_offset"] == 0 + finally: + mm.MambaMixer.__init__ = original + + def test_init_preserves_explicit_d_model(self): + from types import SimpleNamespace + + from megatron.core.models.hybrid.hybrid_layer_fusion import MambaMixerForTransformerLayer + from megatron.core.ssm import mamba_mixer as mm + + captured = {} + + def fake_init(self, config, submodules, **kwargs): + captured["kwargs"] = kwargs + + original = mm.MambaMixer.__init__ + mm.MambaMixer.__init__ = fake_init + try: + instance = object.__new__(MambaMixerForTransformerLayer) + MambaMixerForTransformerLayer.__init__( + instance, + config=SimpleNamespace(hidden_size=128), + submodules="sub-sentinel", + d_model=999, # caller's explicit value must win + ) + assert captured["kwargs"]["d_model"] == 999 + finally: + mm.MambaMixer.__init__ = original + + def test_forward_filters_transformer_layer_kwargs(self): + from megatron.core.models.hybrid.hybrid_layer_fusion import MambaMixerForTransformerLayer + from megatron.core.ssm import mamba_mixer as mm + + captured = {} + + def fake_forward( + self, + hidden_states, + inference_context=None, + *, + inference_params=None, + packed_seq_params=None, + ): + captured["hidden_states"] = hidden_states + captured["inference_context"] = inference_context + captured["inference_params"] = inference_params + captured["packed_seq_params"] = packed_seq_params + return ("out-sentinel", "bias-sentinel") + + original = mm.MambaMixer.forward + mm.MambaMixer.forward = fake_forward + try: + instance = object.__new__(MambaMixerForTransformerLayer) + # Simulate TransformerLayer.forward's call – passes a bunch of + # kwargs the base MambaMixer doesn't accept. The adapter must + # swallow them and forward only the three it uses. 
+ out, bias = MambaMixerForTransformerLayer.forward( + instance, + "hidden-sentinel", + attention_mask="mask-sentinel", + inference_context="ctx-sentinel", + rotary_pos_emb="rot-sentinel", + rotary_pos_cos="cos-sentinel", + rotary_pos_sin="sin-sentinel", + rotary_pos_cos_sin="cos-sin-sentinel", + attention_bias="bias-sentinel", + packed_seq_params="packed-sentinel", + sequence_len_offset="offset-sentinel", + ) + assert out == "out-sentinel" + assert bias == "bias-sentinel" + assert captured["hidden_states"] == "hidden-sentinel" + assert captured["inference_context"] == "ctx-sentinel" + assert captured["packed_seq_params"] == "packed-sentinel" + assert captured["inference_params"] is None + finally: + mm.MambaMixer.forward = original + + def test_forward_accepts_unknown_future_kwargs(self): + """If TransformerLayer.forward grows a new kwarg, the adapter should + silently absorb it rather than breaking the fused build at runtime. + """ + from megatron.core.models.hybrid.hybrid_layer_fusion import MambaMixerForTransformerLayer + from megatron.core.ssm import mamba_mixer as mm + + def fake_forward( + self, + hidden_states, + inference_context=None, + *, + inference_params=None, + packed_seq_params=None, + ): + return ("out", None) + + original = mm.MambaMixer.forward + mm.MambaMixer.forward = fake_forward + try: + instance = object.__new__(MambaMixerForTransformerLayer) + # A brand-new kwarg we've never heard of must not raise. + out, _ = MambaMixerForTransformerLayer.forward( + instance, "hidden-sentinel", some_future_kwarg_added_in_2027=42 + ) + assert out == "out" + finally: + mm.MambaMixer.forward = original + + +@pytest.mark.internal +class TestMambaStateShapesWithFusion: + """Regression test: `HybridStack.mamba_state_shapes_per_request` must + recognise fused blocks whose sequence mixer is Mamba. Before the fix the + method only matched the bare `"M"` layer type, so a stack where every + Mamba was inside a `[M-]` / `[ME]` group returned `None` and + inference cache allocation silently broke. + """ + + def _make_stub_stack(self, layer_type_list, layers): + from megatron.core.models.hybrid.hybrid_block import HybridStack + + stub = object.__new__(HybridStack) + stub.layer_type_list = layer_type_list + stub.layers = layers + return stub + + def test_standalone_mamba_is_found(self): + from megatron.core.models.hybrid.hybrid_block import HybridStack + + class FakeMambaLayer: + def mamba_state_shapes_per_request(self): + return ("conv-shape", "ssm-shape") + + stub = self._make_stub_stack(["*", "M", "-"], ["attn", FakeMambaLayer(), "mlp"]) + assert HybridStack.mamba_state_shapes_per_request(stub) == ("conv-shape", "ssm-shape") + + def test_fused_mamba_is_found_via_self_attention(self): + from megatron.core.models.hybrid.hybrid_block import HybridStack + + class FakeMambaMixer: + def mamba_state_shapes_per_request(self): + return ("conv-shape", "ssm-shape") + + class FakeTransformerLayer: + self_attention = FakeMambaMixer() + + stub = self._make_stub_stack(["*", "M-"], ["attn", FakeTransformerLayer()]) + assert HybridStack.mamba_state_shapes_per_request(stub) == ("conv-shape", "ssm-shape") + + def test_standalone_wins_over_fused(self): + # When both forms are present, the loop returns the first match. + # The stand-alone "M" is earlier in this list, so its shapes win. 
+ from megatron.core.models.hybrid.hybrid_block import HybridStack + + class FakeStandaloneMamba: + def mamba_state_shapes_per_request(self): + return "standalone" + + class FakeFusedMixer: + def mamba_state_shapes_per_request(self): + return "fused" + + class FakeTransformerLayer: + self_attention = FakeFusedMixer() + + stub = self._make_stub_stack(["M", "M-"], [FakeStandaloneMamba(), FakeTransformerLayer()]) + assert HybridStack.mamba_state_shapes_per_request(stub) == "standalone" + + def test_no_mamba_at_all_returns_none(self): + from megatron.core.models.hybrid.hybrid_block import HybridStack + + stub = self._make_stub_stack(["*", "-", "*-"], ["attn", "mlp", "fused-*-"]) + assert HybridStack.mamba_state_shapes_per_request(stub) is None + + def test_fused_me_is_found(self): + # "[ME]" – Mamba sequence mixer + MoE channel mixer, fused. + from megatron.core.models.hybrid.hybrid_block import HybridStack + + class FakeMambaMixer: + def mamba_state_shapes_per_request(self): + return "me-fused" + + class FakeTransformerLayer: + self_attention = FakeMambaMixer() + + stub = self._make_stub_stack(["ME"], [FakeTransformerLayer()]) + assert HybridStack.mamba_state_shapes_per_request(stub) == "me-fused" + + +@pytest.mark.internal +class TestCanonicalShardedStateDict: + """Tests for `canonicalize_hybrid_sharded_state_dict`. + + The helper rewrites checkpoint keys so a fused `[XY]` block looks + exactly as a stand-alone `X` followed by stand-alone `Y` would. This + makes checkpoints structurally equivalent across fusion placements, + so an unfused save can be loaded into a fused model (and vice versa) – + the dist_checkpointing layer never sees the difference. + + Tests are hermetic: they construct a sharded state dict from fake + layers whose `sharded_state_dict` methods return bare `ShardedObject`s + (no CUDA, no process group, no `nn.Module` init), then inspect the key + set after canonicalization. + """ + + def _make_sharded_object(self, key): + """Construct a minimal ShardedObject carrying just the key under test.""" + from megatron.core.dist_checkpointing.mapping import ShardedObject + + return ShardedObject(key=key, data=None, global_shape=(1,), global_offset=(0,)) + + def _fake_layer(self, sub_keys): + """Fake layer whose `sharded_state_dict` yields a ShardedObject per sub_key.""" + maker = self._make_sharded_object + + class FakeLayer: + def sharded_state_dict(self, prefix, sharded_pp_offset, metadata): + return {f"{prefix}{sub_key}": maker(f"{prefix}{sub_key}") for sub_key in sub_keys} + + return FakeLayer() + + def _make_stub_stack(self, layer_type_list, layers, sub_layer_offset=0): + # The canonicalization is a free function; the "stack" here is just + # a parameter bag tying the inputs together. + return { + "layer_type_list": layer_type_list, + "layers": layers, + "sub_layer_offset": sub_layer_offset, + } + + def _run(self, stub): + from megatron.core.models.hybrid.hybrid_layer_fusion import ( + canonicalize_hybrid_sharded_state_dict, + ) + + # Reproduce the input the canonicalization sees in real runs: + # `HybridStack.sharded_state_dict` outputs each layer's keys at + # `layers.{global_physical_idx}.*`. For a single, non-pipeline-parallel + # segment that index equals the local module-list index, which is + # what the fake layers below produce. 
+ state_dict = {} + for local_idx, layer in enumerate(stub["layers"]): + state_dict.update(layer.sharded_state_dict(f"layers.{local_idx}.", [], None)) + canonicalize_hybrid_sharded_state_dict( + state_dict, + layer_prefix="layers.", + layer_type_list=stub["layer_type_list"], + sub_layer_offset=stub["sub_layer_offset"], + ) + return state_dict + + def _keys(self, sharded_state_dict): + return {v.key for v in sharded_state_dict.values()} + + def test_unfused_pattern_is_pure_index_passthrough(self): + # Stand-alone entries map 1:1 from local module-list index to global + # sub-layer index. With `sub_layer_offset=0` they are identity. + mamba = self._fake_layer(["mixer.weight"]) + attn = self._fake_layer(["self_attention.linear_qkv.weight"]) + stub = self._make_stub_stack(["M", "*"], [mamba, attn]) + assert self._keys(self._run(stub)) == { + "layers.0.mixer.weight", + "layers.1.self_attention.linear_qkv.weight", + } + + def test_fused_mamba_mlp_splits_and_renames(self): + # `[M-]`: TransformerLayer with self_attention=MambaMixer, mlp=MLP. + # Canonical: `layers.0.mixer.*` (renamed from self_attention) and + # `layers.1.mlp.*` (split into the next sub-layer index). + fused = self._fake_layer( + [ + "self_attention.in_proj.weight", + "self_attention.out_proj.weight", + "mlp.linear_fc1.weight", + "mlp.linear_fc2.weight", + ] + ) + stub = self._make_stub_stack(["M-"], [fused]) + assert self._keys(self._run(stub)) == { + "layers.0.mixer.in_proj.weight", + "layers.0.mixer.out_proj.weight", + "layers.1.mlp.linear_fc1.weight", + "layers.1.mlp.linear_fc2.weight", + } + + def test_fused_attention_mlp_splits_without_rename(self): + # `[*-]`: stand-alone attention also lives under `self_attention.*`, + # so only the outer block index needs to split – no intra-block rename. + fused = self._fake_layer(["self_attention.linear_qkv.weight", "mlp.linear_fc1.weight"]) + stub = self._make_stub_stack(["*-"], [fused]) + assert self._keys(self._run(stub)) == { + "layers.0.self_attention.linear_qkv.weight", + "layers.1.mlp.linear_fc1.weight", + } + + def test_fused_dsa_preserves_input_layernorm_on_x(self): + # `[D-]` fuses DSA (whose stand-alone block ships with a TENorm + # `input_layernorm`) with MLP. The norm belongs to the X sub-layer. + fused = self._fake_layer( + ["input_layernorm.weight", "self_attention.linear_proj.weight", "mlp.linear_fc1.weight"] + ) + stub = self._make_stub_stack(["D-"], [fused]) + assert self._keys(self._run(stub)) == { + "layers.0.input_layernorm.weight", + "layers.0.self_attention.linear_proj.weight", + "layers.1.mlp.linear_fc1.weight", + } + + def test_fused_moe_preserves_pre_mlp_layernorm_on_y(self): + # `[*E]` fuses attention with MoE (whose stand-alone block ships a + # TENorm `pre_mlp_layernorm`). That norm belongs to the Y sub-layer. + fused = self._fake_layer( + ["self_attention.linear_qkv.weight", "pre_mlp_layernorm.weight", "mlp.router.weight"] + ) + stub = self._make_stub_stack(["*E"], [fused]) + assert self._keys(self._run(stub)) == { + "layers.0.self_attention.linear_qkv.weight", + "layers.1.pre_mlp_layernorm.weight", + "layers.1.mlp.router.weight", + } + + def test_mixed_pattern_cursor_advances_per_sub_layer(self): + # `M[*-]M`: stand-alone M at sub 0, fused at sub 1+2, stand-alone M + # at sub 3. Physical indices 0, 1, 2 compress to sub 0, 1, 2, 3. 
+ mamba0 = self._fake_layer(["mixer.A"]) + fused = self._fake_layer(["self_attention.Q", "mlp.W"]) + mamba1 = self._fake_layer(["mixer.B"]) + stub = self._make_stub_stack(["M", "*-", "M"], [mamba0, fused, mamba1]) + assert self._keys(self._run(stub)) == { + "layers.0.mixer.A", + "layers.1.self_attention.Q", + "layers.2.mlp.W", + "layers.3.mixer.B", + } + + def test_sub_layer_offset_is_added_to_every_entry(self): + # Second PP segment: earlier segments contributed 4 sub-layers. + mamba = self._fake_layer(["mixer.A"]) + fused = self._fake_layer(["self_attention.Q", "mlp.W"]) + stub = self._make_stub_stack(["M", "M-"], [mamba, fused], sub_layer_offset=4) + assert self._keys(self._run(stub)) == { + "layers.4.mixer.A", + "layers.5.mixer.Q", + "layers.6.mlp.W", + } + + def test_unfused_and_fused_produce_identical_keys(self): + # The central compatibility guarantee: a pattern containing fused + # groups and its bracket-stripped twin, fed the same sub-layer + # contents, yield identical sharded keys. That is what makes a + # checkpoint saved unfused loadable into a fused model and vice versa. + unfused_mamba = self._fake_layer(["mixer.W"]) + unfused_mlp = self._fake_layer(["mlp.W"]) + unfused_stub = self._make_stub_stack(["M", "-"], [unfused_mamba, unfused_mlp]) + + fused = self._fake_layer(["self_attention.W", "mlp.W"]) + fused_stub = self._make_stub_stack(["M-"], [fused]) + + assert self._keys(self._run(unfused_stub)) == self._keys(self._run(fused_stub)) + + def test_fused_and_another_fusion_layout_produce_identical_keys(self): + # Fused-to-fused with different placement: `[M-][M-]` vs `M[-M]-`... + # but that's not valid (-M is sequence-mixer-second). Use a + # three-layer equivalent: `[M-]M-` vs `M-[M-]`. Sub-layer sequence is + # M, -, M, - in both cases. + left_fused = self._fake_layer(["self_attention.A", "mlp.B"]) + left_mamba = self._fake_layer(["mixer.C"]) + left_mlp = self._fake_layer(["mlp.D"]) + left_stub = self._make_stub_stack(["M-", "M", "-"], [left_fused, left_mamba, left_mlp]) + + right_mamba = self._fake_layer(["mixer.A"]) + right_mlp = self._fake_layer(["mlp.B"]) + right_fused = self._fake_layer(["self_attention.C", "mlp.D"]) + right_stub = self._make_stub_stack(["M", "-", "M-"], [right_mamba, right_mlp, right_fused]) + + assert self._keys(self._run(left_stub)) == self._keys(self._run(right_stub)) + + def test_stray_top_level_block_key_falls_back_to_x(self): + # A bare top-level key on a fused block (e.g., a hypothetical + # `_extra_state` attached to the TransformerLayer itself) falls back + # to the X sub-layer's index, not the Y one. 
+ fused = self._fake_layer( + [ + "_extra_state", # stray top-level key + "self_attention.linear_qkv.weight", + "mlp.linear_fc1.weight", + ] + ) + stub = self._make_stub_stack(["*-"], [fused]) + keys = self._keys(self._run(stub)) + assert "layers.0._extra_state" in keys + assert "layers.0.self_attention.linear_qkv.weight" in keys + assert "layers.1.mlp.linear_fc1.weight" in keys diff --git a/tests/unit_tests/ssm/test_hybrid_layer_allocation.py b/tests/unit_tests/ssm/test_hybrid_layer_allocation.py index fe0d7c2dc1e..0e73b7c317e 100644 --- a/tests/unit_tests/ssm/test_hybrid_layer_allocation.py +++ b/tests/unit_tests/ssm/test_hybrid_layer_allocation.py @@ -12,9 +12,12 @@ get_hybrid_total_layer_count, get_hybrid_total_pipeline_segment_count, get_layer_maps_from_layer_type_list, + get_sub_layer_offset, + parse_fusion_groups, parse_hybrid_pattern, pattern_from_ratios, select_pipeline_segment, + strip_brackets, validate_segment_layers, ) @@ -688,3 +691,231 @@ def test_all_mamba(self): assert mamba_map == {0: 0, 1: 1, 2: 2} assert mlp_map == {} assert moe_map == {} + + def test_fused_entries(self): + """Fused multi-char entries contribute to each sub-type's map at the + same physical (global) layer index. + """ + # Physical layer 0 is fused attention+MLP; layer 1 is stand-alone mamba. + maps = get_layer_maps_from_layer_type_list(["*-", "M"]) + attention_map, mamba_map, mlp_map = operator.itemgetter( + Symbols.ATTENTION, Symbols.MAMBA, Symbols.MLP + )(maps) + assert attention_map == {0: 0} + assert mlp_map == {0: 0} + assert mamba_map == {1: 0} + + # Two fused blocks interleaved with stand-alone mamba. + maps = get_layer_maps_from_layer_type_list(["*-", "M", "*-", "M"]) + attention_map, mamba_map, mlp_map = operator.itemgetter( + Symbols.ATTENTION, Symbols.MAMBA, Symbols.MLP + )(maps) + assert attention_map == {0: 0, 2: 1} + assert mlp_map == {0: 0, 2: 1} + assert mamba_map == {1: 0, 3: 1} + + +@pytest.mark.internal +class TestStripBrackets: + + def test_no_brackets(self): + assert strip_brackets("M*M*") == "M*M*" + + def test_with_brackets(self): + assert strip_brackets("[*-]M[*-]M") == "*-M*-M" + + def test_with_brackets_and_pipes(self): + assert strip_brackets("[*-]M|[*-]M") == "*-M|*-M" + + def test_empty(self): + assert strip_brackets("") == "" + + def test_all_fused(self): + assert strip_brackets("[*-*-]") == "*-*-" + + +@pytest.mark.internal +class TestParseFusionGroups: + + def test_no_brackets(self): + assert parse_fusion_groups("M*M*") == [] + + def test_single_group(self): + assert parse_fusion_groups("[*-]MM") == [(0, 1)] + + def test_two_groups(self): + assert parse_fusion_groups("[*-]M[*-]M") == [(0, 1), (3, 4)] + + def test_group_at_end(self): + assert parse_fusion_groups("MM[*-]") == [(2, 3)] + + def test_three_layer_group(self): + assert parse_fusion_groups("[M*-]M") == [(0, 2)] + + def test_all_fused(self): + assert parse_fusion_groups("[*-*-]") == [(0, 3)] + + def test_adjacent_groups(self): + assert parse_fusion_groups("[*-][ME]") == [(0, 1), (2, 3)] + + +@pytest.mark.internal +class TestBracketValidation: + """Tests for bracket validation in _validate_pattern (via parse_hybrid_pattern + and validate_segment_layers). 
+ """ + + def test_valid_brackets_in_main_pattern(self): + """Brackets in main pattern are accepted.""" + result = parse_hybrid_pattern("[*-]M[*-]M") + assert result.main_pattern == "[*-]M[*-]M" + + def test_valid_brackets_with_pipes(self): + """Brackets within pipe segments are accepted.""" + result = parse_hybrid_pattern("[*-]M|[*-]M") + assert result.main_pattern == "[*-]M|[*-]M" + + def test_valid_brackets_in_segment(self): + """validate_segment_layers collapses each fusion group into a single entry.""" + # Fused "[*-]" -> single multi-char entry, other layers stay single-char. + result = validate_segment_layers("[*-]M*") + assert result == ['*-', 'M', '*'] + + # Multiple fusion groups and mixed layers. + assert validate_segment_layers("[*-]M[*-]M") == ['*-', 'M', '*-', 'M'] + + # No brackets -> unchanged (every entry is single-character). + assert validate_segment_layers("M*-M") == ['M', '*', '-', 'M'] + + # Fused group of 3 sub-layers is still a single physical block. + assert validate_segment_layers("[M*-]M") == ['M*-', 'M'] + + def test_unmatched_open_bracket(self): + with pytest.raises(ValueError, match="unmatched '\\['"): + parse_hybrid_pattern("[*-M*") + + def test_unmatched_close_bracket(self): + with pytest.raises(ValueError, match="unmatched '\\]'"): + parse_hybrid_pattern("*-]M*") + + def test_nested_brackets(self): + with pytest.raises(ValueError, match="nested '\\['"): + parse_hybrid_pattern("[[*-]]M") + + def test_empty_brackets(self): + with pytest.raises(ValueError, match="empty fusion group"): + parse_hybrid_pattern("[]M*") + + def test_single_layer_bracket(self): + with pytest.raises(ValueError, match="only 1 layer"): + parse_hybrid_pattern("[*]M*") + + def test_brackets_crossing_pipe(self): + with pytest.raises(ValueError, match="pipe '\\|'.*inside"): + parse_hybrid_pattern("[*|M]") + + def test_brackets_crossing_mtp(self): + with pytest.raises(ValueError, match="MTP separator '/'.*inside"): + parse_hybrid_pattern("[*/M]") + + def test_layer_count_with_brackets(self): + """get_hybrid_total_layer_count counts each fusion group as 1 physical layer.""" + # "[*-]M[*-]M" -> 2 fused blocks + 2 mamba = 4 physical layers + assert get_hybrid_total_layer_count("[*-]M[*-]M") == 4 + assert get_hybrid_total_layer_count("[*-]M|[*-]M") == 4 + # 3-layer fusion still counts as 1 block + assert get_hybrid_total_layer_count("[M*-]M") == 2 + # No-bracket pattern is unchanged + assert get_hybrid_total_layer_count("M*M*") == 4 + + def test_layer_counts_with_brackets(self): + """get_hybrid_layer_counts correctly ignores brackets.""" + counts = get_hybrid_layer_counts("[*-]M[*-]M") + assert counts['*'] == 2 + assert counts['-'] == 2 + assert counts['M'] == 2 + + def test_brackets_in_mtp_pattern(self): + """Brackets in MTP patterns are accepted.""" + result = parse_hybrid_pattern("M*M*/[*-]/[*-]") + assert result.mtp_pattern == "[*-]" + assert result.mtp_num_depths == 2 + + @patch('megatron.core.models.hybrid.hybrid_layer_allocation.log_on_each_pipeline_stage') + def test_offset_with_brackets_in_segment(self, mock_log): + """select_pipeline_segment counts each fusion group as 1 layer for offset.""" + # Segment 0: "[*-]M" -> 1 fused block + 1 mamba = 2 physical layers + # Segment 1 should therefore start at offset 2 (not 3). 
+ _, offset = select_pipeline_segment("[*-]M|M*", pp_group=None, vp_stage=1) + assert offset == 2 + + # Segment 0: "[*-][ME]" -> 2 fused blocks + _, offset = select_pipeline_segment("[*-][ME]|M", pp_group=None, vp_stage=1) + assert offset == 2 + + +@pytest.mark.internal +class TestGetSubLayerOffset: + """Tests for `get_sub_layer_offset`. + + The helper converts a physical-block offset (the value produced by + `select_pipeline_segment`) into the corresponding sub-layer offset – + i.e., what it would be if the pattern had no fusion brackets. This is + what `HybridStack` feeds into its canonical sharded-state-dict rewrite. + """ + + def test_zero_offset(self): + assert get_sub_layer_offset("M*M*", 0) == 0 + assert get_sub_layer_offset("[M-]M", 0) == 0 + assert get_sub_layer_offset("", 0) == 0 + + def test_negative_offset_returns_zero(self): + # Guard against the caller passing a negative physical offset – the + # helper should clamp rather than over-count. + assert get_sub_layer_offset("M*M*", -1) == 0 + + def test_pattern_without_fusion_is_identity(self): + # When no fusion groups are present, physical == sub-layer count. + assert get_sub_layer_offset("M*M*", 1) == 1 + assert get_sub_layer_offset("M*M*", 2) == 2 + assert get_sub_layer_offset("M*M*", 4) == 4 + + def test_fusion_group_counts_sub_layers(self): + # A [M-] group contributes 1 physical but 2 sub-layers. + assert get_sub_layer_offset("[M-]M", 1) == 2 + assert get_sub_layer_offset("[M-]M", 2) == 3 + + def test_three_layer_fusion(self): + # A [M*-] group contributes 1 physical but 3 sub-layers. + assert get_sub_layer_offset("[M*-]M", 1) == 3 + assert get_sub_layer_offset("[M*-]M", 2) == 4 + + def test_pipes_are_ignored(self): + # Sub-layer indices are global across PP segments, so pipes do not + # reset or contribute. + assert get_sub_layer_offset("M|[*-]", 1) == 1 + assert get_sub_layer_offset("M|[*-]", 2) == 3 + + def test_consistent_with_count_pattern_layers(self): + # Feeding the full physical-block count of the (bracket-stripped, + # pipe-stripped) pattern should yield the full sub-layer count. + pattern = "[M-]M|[*-][ME]" + from megatron.core.models.hybrid.hybrid_layer_allocation import count_pattern_layers + + total_physical = count_pattern_layers(pattern) + total_sub = sum(1 for ch in pattern if ch in Symbols.VALID_LAYERS) + assert get_sub_layer_offset(pattern, total_physical) == total_sub + + def test_offset_beyond_pattern_returns_full_sub_count(self): + # Asking for more physical blocks than the pattern contains just + # walks off the end and returns the cumulative sub-layer count. + total_sub = sum(1 for ch in "[M-]M" if ch in Symbols.VALID_LAYERS) + assert get_sub_layer_offset("[M-]M", 10) == total_sub + + def test_mirrors_physical_offset_for_unfused_segment(self): + # In a mixed pattern, segments up to the first fusion have + # sub_offset == physical_offset. + assert get_sub_layer_offset("M-M-|[*-]", 1) == 1 + assert get_sub_layer_offset("M-M-|[*-]", 4) == 4 + # Beyond the fusion boundary, the sub offset diverges. + assert get_sub_layer_offset("M-M-|[*-]", 5) == 6
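
For reference, the sub-layer accounting that `TestGetSubLayerOffset` exercises can be
summarised in a few lines. The sketch below is illustrative only, not the library's
`get_sub_layer_offset` implementation; it assumes the `Symbols` constants
(`VALID_LAYERS`, `FUSION_START`, `FUSION_END`) from `hybrid_layer_allocation.py` and
reproduces the behaviour the tests above pin down: every layer symbol counts as one
sub-layer, a bracketed group counts as one physical block, and '|' separators count
as neither.

# Illustrative sketch only -- not Megatron's get_sub_layer_offset implementation.
from megatron.core.models.hybrid.hybrid_layer_allocation import Symbols


def sub_layer_offset_sketch(pattern: str, physical_offset: int) -> int:
    """Convert a physical-block offset into the equivalent sub-layer offset."""
    sub_layers = 0       # counts every layer symbol, inside or outside brackets
    physical_blocks = 0  # counts stand-alone symbols and whole [...] groups
    in_group = False
    for ch in pattern:
        if physical_blocks >= physical_offset and not in_group:
            break  # reached the start of the requested pipeline segment
        if ch == Symbols.FUSION_START:
            in_group = True
        elif ch == Symbols.FUSION_END:
            in_group = False
            physical_blocks += 1  # the whole fused group is one physical block
        elif ch in Symbols.VALID_LAYERS:
            sub_layers += 1
            if not in_group:
                physical_blocks += 1
        # '|' pipe separators neither count nor reset anything
    return sub_layers


# Matches the expectations in TestGetSubLayerOffset:
assert sub_layer_offset_sketch("[M-]M", 1) == 2
assert sub_layer_offset_sketch("M|[*-]", 2) == 3
assert sub_layer_offset_sketch("M*M*", 4) == 4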