21 changes: 18 additions & 3 deletions megatron/core/distributed/fsdp/mcore_fsdp_adapter.py
@@ -40,6 +40,7 @@
from megatron.core.distributed.data_parallel_base import _BaseDataParallel
from megatron.core.distributed.distributed_data_parallel_config import DistributedDataParallelConfig
from megatron.core.process_groups_config import ProcessGroupCollection
from megatron.core.ssm.mamba_layer import MambaLayer
from megatron.core.transformer.transformer_config import TransformerConfig
from megatron.core.transformer.transformer_layer import TransformerLayer
from megatron.core.utils import is_te_min_version, log_single_rank
@@ -151,7 +152,7 @@ def __init__(
self.fsdp_unit_modules = fsdp_unit_modules
else:
if self.ddp_config.data_parallel_sharding_strategy == "optim_grads_params":
self.fsdp_unit_modules = [TransformerLayer]
self.fsdp_unit_modules = [TransformerLayer, MambaLayer]
else:
self.fsdp_unit_modules = []

@@ -211,7 +212,9 @@ def load_state_dict(self, state_dict, strict=True):

self.module.load_state_dict(custom_state_dict, strict=strict)

def _detect_parallelism_type(self, param_name: str, module: nn.Module) -> Optional[str]:
def _detect_parallelism_type(
self, param_name: str, module: nn.Module, param: nn.Parameter = None
) -> Optional[str]:
"""
Infer tensor-parallelism type for a parameter under a given module
(forked from Megatron-Bridge).
@@ -252,6 +255,18 @@ def _detect_parallelism_type(self, param_name: str, module: nn.Module) -> Optional[str]:
return "replicated"
return "row"

# Fallback to inspecting parameter-level TP attributes.
# Some modules (e.g. MambaMixer) set tensor_model_parallel and partition_dim
# directly on parameters rather than on the owning module.
if param is not None and getattr(param, "tensor_model_parallel", False):
partition_dim = getattr(param, "partition_dim", None)
if partition_dim == 0:
return "column"
elif partition_dim == 1:
if "bias" in param_name:
return "replicated"
return "row"

# Fallback for normalization layers
if any(norm in module_type for norm in ["Norm", "Normalization"]):
return "replicated"
@@ -277,7 +292,7 @@ def _annotate_tensor_parallelism(self, root_module: nn.Module) -> None:
"""
for submodule in root_module.modules():
for name, param in submodule.named_parameters(recurse=False):
detected_type = self._detect_parallelism_type(name, submodule)
detected_type = self._detect_parallelism_type(name, submodule, param)
if detected_type is not None:
setattr(param, "_tensor_parallel_mode", detected_type)

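For reference, a minimal standalone sketch of the parameter-level fallback added above, using a plain nn.Conv1d to stand in for MambaMixer's conv1d. The helper classify_param_level_tp is illustrative only, not part of the adapter:

```python
from typing import Optional

import torch.nn as nn


def classify_param_level_tp(param_name: str, param: nn.Parameter) -> Optional[str]:
    # Mirrors the fallback in _detect_parallelism_type: inspect TP metadata
    # attached directly to the parameter rather than to the owning module.
    if not getattr(param, "tensor_model_parallel", False):
        return None
    partition_dim = getattr(param, "partition_dim", None)
    if partition_dim == 0:
        return "column"
    if partition_dim == 1:
        return "replicated" if "bias" in param_name else "row"
    return None


conv1d = nn.Conv1d(8, 8, kernel_size=4, groups=8)
conv1d.weight.tensor_model_parallel = True  # set on the parameter itself,
conv1d.weight.partition_dim = 0             # not on the owning module
assert classify_param_level_tp("weight", conv1d.weight) == "column"
```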
@@ -759,6 +759,17 @@ def _pre_forward_param_unshard(module: nn.Module, *unused):
if self.enable_fine_grained_param_gather_hook:
param_list = list(module.parameters(recurse=False))

extra_forward_param_modules = getattr(module, "_fsdp_extra_forward_param_modules", ())
if isinstance(extra_forward_param_modules, nn.Module):
extra_forward_param_modules = (extra_forward_param_modules,)
if extra_forward_param_modules:
seen_param_ids = {id(param) for param in param_list}
for extra_module in extra_forward_param_modules:
for extra_param in extra_module.parameters():
if id(extra_param) not in seen_param_ids:
param_list.append(extra_param)
seen_param_ids.add(id(extra_param))

Comment on lines +762 to +772
@cspades cspades Apr 29, 2026

There is also a post-forward / post-backward hook that calls this function: release_module_parameters. It needs to be called on the modules in extra_forward_param_modules so we can re-shard them.

Early note: what about the pre-backward param unshard?

We need to ensure that any rogue weights are re-sharded after the forward pass and un-sharded again during the backward pass. Did you check that the Conv-1D weights are re-sharded?

# All-gather the parameters before the forward pass.
self.all_gather_and_wait_parameters_ready(
params=param_list,
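As a usage illustration of the hook change above, a small self-contained sketch (plain PyTorch; ToyMixer and collect_forward_params are hypothetical stand-ins, not real APIs) of how a module declares _fsdp_extra_forward_param_modules and how the gather list picks up those parameters exactly once:

```python
import torch
import torch.nn as nn


class ToyMixer(nn.Module):
    def __init__(self):
        super().__init__()
        self.A_log = nn.Parameter(torch.zeros(16))     # direct parameter of the mixer
        self.conv1d = nn.Conv1d(16, 16, 4, groups=16)  # read directly by a fused kernel
        # Ask the FSDP pre-forward hook to also gather conv1d's parameters.
        self._fsdp_extra_forward_param_modules = (self.conv1d,)


def collect_forward_params(module: nn.Module):
    # Stand-in for the gather-list construction in _pre_forward_param_unshard.
    param_list = list(module.parameters(recurse=False))
    extra = getattr(module, "_fsdp_extra_forward_param_modules", ())
    if isinstance(extra, nn.Module):
        extra = (extra,)
    seen = {id(p) for p in param_list}
    for extra_module in extra:
        for p in extra_module.parameters():
            if id(p) not in seen:
                param_list.append(p)
                seen.add(id(p))
    return param_list


mixer = ToyMixer()
params = collect_forward_params(mixer)
assert {id(p) for p in params} == {
    id(mixer.A_log), id(mixer.conv1d.weight), id(mixer.conv1d.bias)
}
```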
@@ -19,6 +19,7 @@
from .. import parallel_state, tensor_parallel
from ..models.common.embeddings.language_model_embedding import LanguageModelEmbedding
from ..models.common.embeddings.rotary_pos_embedding import RotaryEmbedding
from ..ssm.mamba_layer import MambaLayer
from ..transformer.transformer_config import TransformerConfig
from ..transformer.transformer_layer import TransformerLayer
from .data_parallel_base import _BaseDataParallel
@@ -59,6 +60,7 @@ def __init__(
module: torch.nn.Module,
sub_modules_to_wrap: Set[torch.nn.Module] = {
TransformerLayer,
MambaLayer,
LanguageModelEmbedding,
RotaryEmbedding,
tensor_parallel.ColumnParallelLinear,
17 changes: 12 additions & 5 deletions megatron/core/ssm/mamba_context_parallel.py
@@ -70,6 +70,7 @@ def __init__(
A_log_cp1: torch.Tensor,
D_cp1: torch.Tensor,
D_has_hdim: bool,
mixer=None,
) -> None:
if not HAVE_EINOPS:
raise ImportError("einops is required by the Mamba model but cannot be imported")
@@ -84,6 +85,7 @@ def __init__(
self.A_log_cp1 = A_log_cp1
self.D_cp1 = D_cp1
self.D_has_hdim = D_has_hdim
self._mixer = mixer

self.cp_size = self.cp_group.size()

@@ -231,24 +233,29 @@ def conv1d_channels(self):
def get_conv1d_weight(self) -> torch.Tensor:
"""Returns a slice of the conv1d weight relevant to the current context parallel rank"""
# weight shape: [conv_dim, 1, d_conv]
return self._slice_conv_param(self.conv1d_cp1.weight)
conv1d = self._mixer.conv1d if self._mixer is not None else self.conv1d_cp1
return self._slice_conv_param(conv1d.weight)
Comment on lines -234 to +237
Have to ask, where/when do we have "stale" references such that we can no longer directly retrieve the weight?


def get_conv1d_bias(self) -> torch.Tensor:
"""Returns a slice of the conv1d bias relevant to the current context parallel rank"""
# bias shape: [conv_dim]
return self._slice_conv_param(self.conv1d_cp1.bias)
conv1d = self._mixer.conv1d if self._mixer is not None else self.conv1d_cp1
return self._slice_conv_param(conv1d.bias)

def get_dt_bias(self) -> torch.Tensor:
"""Returns a slice of dt_bias relevant to the current context parallel rank"""
return self._slice_vector_param(self.dt_bias_cp1)
param = self._mixer.dt_bias if self._mixer is not None else self.dt_bias_cp1
return self._slice_vector_param(param)

def get_A_log(self) -> torch.Tensor:
"""Returns a slice of A_log relevant to the current context parallel rank"""
return self._slice_vector_param(self.A_log_cp1)
param = self._mixer.A_log if self._mixer is not None else self.A_log_cp1
return self._slice_vector_param(param)

def get_D(self) -> torch.Tensor:
"""Returns a slice of D relevant to the current context parallel rank"""
return self._slice_vector_param(self.D_cp1, has_hdim=self.D_has_hdim)
param = self._mixer.D if self._mixer is not None else self.D_cp1
return self._slice_vector_param(param, has_hdim=self.D_has_hdim)

def _slice_conv_param(self, param: torch.Tensor) -> torch.Tensor:
"""
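On the reviewer's question about stale references: a plain-PyTorch sketch of one way a constructor-time reference can go stale. If a wrapper replaces the module's parameter object, a reference captured earlier keeps pointing at the old tensor, while an attribute lookup through the live module (as the getters now do via self._mixer) sees the current one. The replacement shown here is illustrative, not a claim about what Megatron-FSDP does internally:

```python
import torch
import torch.nn as nn

conv = nn.Conv1d(4, 4, kernel_size=3, groups=4)
cached_weight = conv.weight            # reference captured at construction time

# A wrapper that swaps the parameter object (for example, installing a re-gathered,
# unsharded copy) leaves the cached reference pointing at the old tensor.
conv.weight = nn.Parameter(torch.ones_like(conv.weight))

assert cached_weight is not conv.weight   # the cached reference is now stale
current_weight = conv.weight              # attribute lookup at call time stays correct
```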
25 changes: 17 additions & 8 deletions megatron/core/ssm/mamba_mixer.py
@@ -316,6 +316,9 @@ def __init__(
else:
nn.init.kaiming_uniform_(self.conv1d.weight, a=math.sqrt(5))

# The fused Mamba path reads conv1d weights directly instead of calling conv1d.forward().
self._fsdp_extra_forward_param_modules = (self.conv1d,)

self.activation = "silu"
self.act = nn.SiLU()

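To make the comment above concrete, a short plain-PyTorch sketch (the hook and shapes are illustrative) of why reading conv1d.weight directly, as the fused mamba_split_conv1d_scan_combined path does, bypasses hooks registered on the conv1d module, which is why its parameters are registered via _fsdp_extra_forward_param_modules on the mixer:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

conv = nn.Conv1d(8, 8, kernel_size=4, groups=8, padding=3)
conv.register_forward_pre_hook(lambda module, inputs: print("conv1d pre-forward hook fired"))

x = torch.randn(2, 8, 16)
y_module = conv(x)                               # module call: the hook fires
y_fused = F.conv1d(x, conv.weight, conv.bias,    # direct weight read, like the fused
                   padding=3, groups=8)          # kernel path: no hook fires
```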
@@ -408,6 +411,7 @@ def __init__(
A_log_cp1=self.A_log,
D_cp1=self.D,
D_has_hdim=self.D_has_hdim,
mixer=self,
)
self.tp_group = pg_collection.tp

@@ -700,17 +704,22 @@ def _ssm_training(
assert sequence_packing_available, reason_for_no_sequence_packing
seq_idx = packed_seq_params.seq_idx

conv1d_weight = rearrange(self.cp.get_conv1d_weight(), "d 1 w -> d w")
conv1d_bias = self.cp.get_conv1d_bias()
dt_bias = self.cp.get_dt_bias().float()
D = (
rearrange(self.cp.get_D().float(), "(h p) -> h p", p=self.headdim)
if self.D_has_hdim
else self.cp.get_D()
)

y = mamba_split_conv1d_scan_combined(
zxBCdt,
rearrange(self.cp.get_conv1d_weight(), "d 1 w -> d w"),
self.cp.get_conv1d_bias(),
self.cp.get_dt_bias().float(),
conv1d_weight,
conv1d_bias,
dt_bias,
A,
D=(
rearrange(self.cp.get_D().float(), "(h p) -> h p", p=self.headdim)
if self.D_has_hdim
else self.cp.get_D()
),
D=D,
chunk_size=self.chunk_size,
activation=self.activation,
headdim=None if self.D_has_hdim else self.headdim,
@@ -203,6 +203,64 @@ def __init__(self, mode):
assert fsdp._detect_parallelism_type("weight", TELinear("none")) == "replicated"


def test_detect_parallelism_param_level_tp_attributes():
"""Parameters with tensor_model_parallel/partition_dim set directly on them
(rather than on the owning module) should be detected via the param-level fallback.
This is the pattern used by MambaMixer's conv1d, A_log, dt_bias, D parameters.
"""
fsdp = _make_fsdp_for_unit_tests()

class PlainModule(nn.Module):
def __init__(self):
super().__init__()
self.weight = nn.Parameter(torch.empty(4, 4))

module = PlainModule()

# Without param-level attrs -> None (no inference possible)
assert fsdp._detect_parallelism_type("weight", module) is None

# With param-level column (partition_dim=0)
module.weight.tensor_model_parallel = True
module.weight.partition_dim = 0
assert fsdp._detect_parallelism_type("weight", module, module.weight) == "column"

# With param-level row (partition_dim=1)
module.weight.partition_dim = 1
assert fsdp._detect_parallelism_type("weight", module, module.weight) == "row"

# Row-parallel bias should be replicated
bias = nn.Parameter(torch.empty(4))
bias.tensor_model_parallel = True
bias.partition_dim = 1
module.bias = bias
assert fsdp._detect_parallelism_type("bias", module, module.bias) == "replicated"


def test_detect_parallelism_param_level_tp_overrides_norm_fallback():
"""A Norm-like module whose weight has param-level TP attributes should be
classified by the param-level check, NOT the norm-name fallback.
This is the pattern used by MambaMixer's ExtendedRMSNorm, whose weight is
TP-sharded (partition_dim=0) rather than replicated.
"""
fsdp = _make_fsdp_for_unit_tests()

class ExtendedRMSNorm(nn.Module):
def __init__(self):
super().__init__()
self.weight = nn.Parameter(torch.empty(8))

module = ExtendedRMSNorm()

# Without param-level attrs, norm fallback should return "replicated"
assert fsdp._detect_parallelism_type("weight", module) == "replicated"

# With param-level TP attrs, should return "column" instead
module.weight.tensor_model_parallel = True
module.weight.partition_dim = 0
assert fsdp._detect_parallelism_type("weight", module, module.weight) == "column"


def test_detect_parallelism_returns_none_when_cannot_infer():
fsdp = _make_fsdp_for_unit_tests()

@@ -251,3 +309,82 @@ def __init__(self):
# For unknown module type, _detect_parallelism_type should return None
# and _annotate_tensor_parallelism must not set the attribute.
assert not hasattr(root.plain.weight, "_tensor_parallel_mode")


def test_annotate_tensor_parallelism_mamba_mixer_like_module():
"""Simulate a MambaMixer-like module hierarchy where TP attributes are set on
parameters rather than modules. Verify that _annotate_tensor_parallelism
correctly classifies all parameters.
"""
fsdp = _make_fsdp_for_unit_tests()

class ColumnParallelLinear(nn.Module):
"""Stands in for in_proj (module-level TP, detected via registry)."""

def __init__(self):
super().__init__()
self.weight = nn.Parameter(torch.empty(8, 4))

class RowParallelLinear(nn.Module):
"""Stands in for out_proj (module-level TP, detected via registry)."""

def __init__(self):
super().__init__()
self.weight = nn.Parameter(torch.empty(4, 8))

class ExtendedRMSNorm(nn.Module):
"""Norm with param-level TP (should NOT fall through to norm fallback)."""

def __init__(self):
super().__init__()
self.weight = nn.Parameter(torch.empty(8))
self.weight.tensor_model_parallel = True
self.weight.partition_dim = 0

class MambaMixer(nn.Module):
"""Simulates MambaMixer with conv1d and raw TP-sharded parameters."""

def __init__(self):
super().__init__()
self.in_proj = ColumnParallelLinear()
self.out_proj = RowParallelLinear()
self.norm = ExtendedRMSNorm()

# conv1d: standard nn.Conv1d with param-level TP attrs
self.conv1d = nn.Conv1d(8, 8, 4, groups=8)
self.conv1d.weight.tensor_model_parallel = True
self.conv1d.weight.partition_dim = 0
self.conv1d.bias.tensor_model_parallel = True
self.conv1d.bias.partition_dim = 0

# Raw parameters with param-level TP attrs
self.A_log = nn.Parameter(torch.empty(4))
self.A_log.tensor_model_parallel = True
self.A_log.partition_dim = 0

self.dt_bias = nn.Parameter(torch.empty(4))
self.dt_bias.tensor_model_parallel = True
self.dt_bias.partition_dim = 0

self.D = nn.Parameter(torch.empty(4))
self.D.tensor_model_parallel = True
self.D.partition_dim = 0

mixer = MambaMixer()
fsdp._annotate_tensor_parallelism(mixer)

# Module-level detection (via registry)
assert mixer.in_proj.weight._tensor_parallel_mode == "column"
assert mixer.out_proj.weight._tensor_parallel_mode == "row"

# Param-level detection for conv1d
assert mixer.conv1d.weight._tensor_parallel_mode == "column"
assert mixer.conv1d.bias._tensor_parallel_mode == "column"

# Param-level detection for raw parameters on MambaMixer
assert mixer.A_log._tensor_parallel_mode == "column"
assert mixer.dt_bias._tensor_parallel_mode == "column"
assert mixer.D._tensor_parallel_mode == "column"

# Param-level detection overrides norm fallback
assert mixer.norm.weight._tensor_parallel_mode == "column"