53 changes: 51 additions & 2 deletions megatron/core/models/hybrid/hybrid_block.py
@@ -20,6 +20,7 @@
from megatron.core.fp8_utils import get_fp8_context
from megatron.core.inference.contexts import BaseInferenceContext
from megatron.core.models.hybrid.hybrid_layer_allocation import Symbols as LayerSymbols
from megatron.core.models.hybrid.hybrid_layer_fusion import build_fused_layer
from megatron.core.packed_seq_params import PackedSeqParams
from megatron.core.process_groups_config import ProcessGroupCollection
from megatron.core.transformer import TransformerConfig
@@ -36,8 +37,19 @@
class HybridStackSubmodules:
"""
A container for the module specs used by the HybridStack.

The `*_layer` fields specify the full block used for a stand-alone layer
(a single symbol in the hybrid layer pattern).

The `*_mixer` fields specify the primitive module plugged into a
`TransformerLayer` when two layers are fused via the `[XY]` syntax
in the hybrid layer pattern. For a fused group `[XY]`, the sequence mixer
corresponding to `X` is used as `self_attention` and the channel mixer
corresponding to `Y` is used as `mlp` inside a freshly built
`TransformerLayer` spec.
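
Example (an illustrative sketch, not a canonical spec; a real mixer spec
would typically also carry `params` and `submodules`): with `*` as the
attention symbol and `-` as the MLP symbol, the fused group `[*-]` pairs
`attention_mixer` with `mlp_mixer`:

    from megatron.core.transformer.attention import SelfAttention
    from megatron.core.transformer.mlp import MLP
    from megatron.core.transformer.spec_utils import ModuleSpec

    submodules = HybridStackSubmodules(
        attention_mixer=ModuleSpec(module=SelfAttention),
        mlp_mixer=ModuleSpec(module=MLP),
    )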
"""

# Stand-alone layer specs: one block per pattern symbol.
mamba_layer: Union[ModuleSpec, type] = IdentityOp
gdn_layer: Union[ModuleSpec, type] = IdentityOp
attention_layer: Union[ModuleSpec, type] = IdentityOp
@@ -46,6 +58,16 @@ class HybridStackSubmodules:
moe_layer: Union[ModuleSpec, type] = IdentityOp
mtp_block_spec: Optional[ModuleSpec] = None

# Primitive specs used when building a fused TransformerLayer at runtime.
# Sequence mixers fill the `self_attention` slot:
mamba_mixer: Union[ModuleSpec, type] = IdentityOp
gdn_mixer: Union[ModuleSpec, type] = IdentityOp
attention_mixer: Union[ModuleSpec, type] = IdentityOp
dsa_mixer: Union[ModuleSpec, type] = IdentityOp
# Channel mixers fill the `mlp` slot:
mlp_mixer: Union[ModuleSpec, type] = IdentityOp
moe_mixer: Union[ModuleSpec, type] = IdentityOp


class HybridStack(GraphableMegatronModule, MegatronModule):
"""
@@ -60,7 +82,8 @@ class HybridStack(GraphableMegatronModule, MegatronModule):
this pipeline segment. When provided (by HybridModel), pipeline stage
selection has already been done via '|' separators in the pattern.
pp_layer_offset (int, optional): the global layer offset for this pipeline
segment. Defaults to 0.
segment, measured in physical blocks (fused groups count as one).
Defaults to 0.
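For example (illustrative): with the pattern "M[M-]|*", the second
segment gets pp_layer_offset=2, since "M" and the fused "[M-]" group
make two physical blocks.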
post_layer_norm (bool, optional): whether to include a final layer norm.
Defaults to True.
post_process (bool, optional): whether to include an output layer.
@@ -118,7 +141,24 @@ def __init__(
else:
quant_init_context = nullcontext()
with quant_init_context:
if layer_type == LayerSymbols.MAMBA:
if len(layer_type) > 1:
# Multi-character entries come from bracketed fusion groups
# in the hybrid layer pattern, e.g., "[*-]" -> "*-".
# `layer_number` already includes `pp_layer_offset` (see
# the computation at the top of this loop), so the outer
# TransformerLayer must not add it again; this matches the contract
# of the stand-alone TransformerLayer dispatches below.
layer = build_fused_layer(
layer_type,
submodules,
config=self.config,
layer_number=layer_number,
pg_collection=pg_collection,
pp_layer_offset=pp_layer_offset,
is_mtp_layer=is_mtp_layer,
add_layer_offset=False,
)
elif layer_type == LayerSymbols.MAMBA:
layer = build_module(
submodules.mamba_layer,
config=self.config,
@@ -200,10 +240,19 @@ def mamba_state_shapes_per_request(self) -> Optional[Tuple[Tuple[int], Tuple[int
"""
Returns the Mamba conv and ssm state shapes per input sequence
if this block contains Mamba layers (this may not be the case with PP > 1).

A stand-alone Mamba block exposes `mamba_state_shapes_per_request`
directly (it is a `MambaLayer`). A fused `[M...]` block is a
`TransformerLayer` whose `self_attention` slot holds the `MambaMixer`,
so we have to descend one level.
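
For example (illustrative, with `M` as the Mamba symbol and `-` as the
MLP symbol): for the pattern "M[M-]", the first block is a `MambaLayer`
queried directly, while the second is a fused `TransformerLayer` queried
through `layer.self_attention`.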
"""
for layer_type, layer in zip(self.layer_type_list, self.layers):
if layer_type == LayerSymbols.MAMBA:
return layer.mamba_state_shapes_per_request()
# Fused Mamba is surfaced via the enclosing TransformerLayer's `self_attention`
# attribute.
elif LayerSymbols.MAMBA in layer_type:
return layer.self_attention.mamba_state_shapes_per_request()
return None

def _should_call_local_cudagraph(self, *args, **kwargs):