diff --git a/megatron/core/distributed/fsdp/mcore_fsdp_adapter.py b/megatron/core/distributed/fsdp/mcore_fsdp_adapter.py index 996860afe26..46ea75491c5 100644 --- a/megatron/core/distributed/fsdp/mcore_fsdp_adapter.py +++ b/megatron/core/distributed/fsdp/mcore_fsdp_adapter.py @@ -40,6 +40,7 @@ from megatron.core.distributed.data_parallel_base import _BaseDataParallel from megatron.core.distributed.distributed_data_parallel_config import DistributedDataParallelConfig from megatron.core.process_groups_config import ProcessGroupCollection +from megatron.core.ssm.mamba_layer import MambaLayer from megatron.core.transformer.transformer_config import TransformerConfig from megatron.core.transformer.transformer_layer import TransformerLayer from megatron.core.utils import is_te_min_version, log_single_rank @@ -151,7 +152,7 @@ def __init__( self.fsdp_unit_modules = fsdp_unit_modules else: if self.ddp_config.data_parallel_sharding_strategy == "optim_grads_params": - self.fsdp_unit_modules = [TransformerLayer] + self.fsdp_unit_modules = [TransformerLayer, MambaLayer] else: self.fsdp_unit_modules = [] @@ -211,7 +212,9 @@ def load_state_dict(self, state_dict, strict=True): self.module.load_state_dict(custom_state_dict, strict=strict) - def _detect_parallelism_type(self, param_name: str, module: nn.Module) -> Optional[str]: + def _detect_parallelism_type( + self, param_name: str, module: nn.Module, param: nn.Parameter = None + ) -> Optional[str]: """ Infer tensor-parallelism type for a parameter under a given module (forked from Megatron-Bridge). @@ -252,6 +255,18 @@ def _detect_parallelism_type(self, param_name: str, module: nn.Module) -> Option return "replicated" return "row" + # Fallback to inspecting parameter-level TP attributes. + # Some modules (e.g. MambaMixer) set tensor_model_parallel and partition_dim + # directly on parameters rather than on the owning module. + if param is not None and getattr(param, "tensor_model_parallel", False): + partition_dim = getattr(param, "partition_dim", None) + if partition_dim == 0: + return "column" + elif partition_dim == 1: + if "bias" in param_name: + return "replicated" + return "row" + # Fallback for normalization layers if any(norm in module_type for norm in ["Norm", "Normalization"]): return "replicated" @@ -277,7 +292,7 @@ def _annotate_tensor_parallelism(self, root_module: nn.Module) -> None: """ for submodule in root_module.modules(): for name, param in submodule.named_parameters(recurse=False): - detected_type = self._detect_parallelism_type(name, submodule) + detected_type = self._detect_parallelism_type(name, submodule, param) if detected_type is not None: setattr(param, "_tensor_parallel_mode", detected_type) diff --git a/megatron/core/distributed/torch_fully_sharded_data_parallel.py b/megatron/core/distributed/torch_fully_sharded_data_parallel.py index 43321bc78cc..d397aac4b5b 100644 --- a/megatron/core/distributed/torch_fully_sharded_data_parallel.py +++ b/megatron/core/distributed/torch_fully_sharded_data_parallel.py @@ -19,6 +19,7 @@ from .. import parallel_state, tensor_parallel from ..models.common.embeddings.language_model_embedding import LanguageModelEmbedding from ..models.common.embeddings.rotary_pos_embedding import RotaryEmbedding +from ..ssm.mamba_layer import MambaLayer from ..transformer.transformer_config import TransformerConfig from ..transformer.transformer_layer import TransformerLayer from .data_parallel_base import _BaseDataParallel @@ -59,6 +60,7 @@ def __init__( module: torch.nn.Module, sub_modules_to_wrap: Set[torch.nn.Module] = { TransformerLayer, + MambaLayer, LanguageModelEmbedding, RotaryEmbedding, tensor_parallel.ColumnParallelLinear, diff --git a/tests/unit_tests/distributed/megatron_fsdp/test_mcore_tensor_parallelism_detect.py b/tests/unit_tests/distributed/megatron_fsdp/test_mcore_tensor_parallelism_detect.py index c69ca817872..63aa24e34cb 100644 --- a/tests/unit_tests/distributed/megatron_fsdp/test_mcore_tensor_parallelism_detect.py +++ b/tests/unit_tests/distributed/megatron_fsdp/test_mcore_tensor_parallelism_detect.py @@ -203,6 +203,64 @@ def __init__(self, mode): assert fsdp._detect_parallelism_type("weight", TELinear("none")) == "replicated" +def test_detect_parallelism_param_level_tp_attributes(): + """Parameters with tensor_model_parallel/partition_dim set directly on them + (rather than on the owning module) should be detected via the param-level fallback. + This is the pattern used by MambaMixer's conv1d, A_log, dt_bias, D parameters. + """ + fsdp = _make_fsdp_for_unit_tests() + + class PlainModule(nn.Module): + def __init__(self): + super().__init__() + self.weight = nn.Parameter(torch.empty(4, 4)) + + module = PlainModule() + + # Without param-level attrs -> None (no inference possible) + assert fsdp._detect_parallelism_type("weight", module) is None + + # With param-level column (partition_dim=0) + module.weight.tensor_model_parallel = True + module.weight.partition_dim = 0 + assert fsdp._detect_parallelism_type("weight", module, module.weight) == "column" + + # With param-level row (partition_dim=1) + module.weight.partition_dim = 1 + assert fsdp._detect_parallelism_type("weight", module, module.weight) == "row" + + # Row-parallel bias should be replicated + bias = nn.Parameter(torch.empty(4)) + bias.tensor_model_parallel = True + bias.partition_dim = 1 + module.bias = bias + assert fsdp._detect_parallelism_type("bias", module, module.bias) == "replicated" + + +def test_detect_parallelism_param_level_tp_overrides_norm_fallback(): + """A Norm-like module whose weight has param-level TP attributes should be + classified by the param-level check, NOT the norm-name fallback. + This is the pattern used by MambaMixer's ExtendedRMSNorm, whose weight is + TP-sharded (partition_dim=0) rather than replicated. + """ + fsdp = _make_fsdp_for_unit_tests() + + class ExtendedRMSNorm(nn.Module): + def __init__(self): + super().__init__() + self.weight = nn.Parameter(torch.empty(8)) + + module = ExtendedRMSNorm() + + # Without param-level attrs, norm fallback should return "replicated" + assert fsdp._detect_parallelism_type("weight", module) == "replicated" + + # With param-level TP attrs, should return "column" instead + module.weight.tensor_model_parallel = True + module.weight.partition_dim = 0 + assert fsdp._detect_parallelism_type("weight", module, module.weight) == "column" + + def test_detect_parallelism_returns_none_when_cannot_infer(): fsdp = _make_fsdp_for_unit_tests() @@ -251,3 +309,82 @@ def __init__(self): # For unknown module type, _detect_parallelism_type should return None # and _annotate_tensor_parallelism must not set the attribute. assert not hasattr(root.plain.weight, "_tensor_parallel_mode") + + +def test_annotate_tensor_parallelism_mamba_mixer_like_module(): + """Simulate a MambaMixer-like module hierarchy where TP attributes are set on + parameters rather than modules. Verify that _annotate_tensor_parallelism + correctly classifies all parameters. + """ + fsdp = _make_fsdp_for_unit_tests() + + class ColumnParallelLinear(nn.Module): + """Stands in for in_proj (module-level TP, detected via registry).""" + + def __init__(self): + super().__init__() + self.weight = nn.Parameter(torch.empty(8, 4)) + + class RowParallelLinear(nn.Module): + """Stands in for out_proj (module-level TP, detected via registry).""" + + def __init__(self): + super().__init__() + self.weight = nn.Parameter(torch.empty(4, 8)) + + class ExtendedRMSNorm(nn.Module): + """Norm with param-level TP (should NOT fall through to norm fallback).""" + + def __init__(self): + super().__init__() + self.weight = nn.Parameter(torch.empty(8)) + self.weight.tensor_model_parallel = True + self.weight.partition_dim = 0 + + class MambaMixer(nn.Module): + """Simulates MambaMixer with conv1d and raw TP-sharded parameters.""" + + def __init__(self): + super().__init__() + self.in_proj = ColumnParallelLinear() + self.out_proj = RowParallelLinear() + self.norm = ExtendedRMSNorm() + + # conv1d: standard nn.Conv1d with param-level TP attrs + self.conv1d = nn.Conv1d(8, 8, 4, groups=8) + self.conv1d.weight.tensor_model_parallel = True + self.conv1d.weight.partition_dim = 0 + self.conv1d.bias.tensor_model_parallel = True + self.conv1d.bias.partition_dim = 0 + + # Raw parameters with param-level TP attrs + self.A_log = nn.Parameter(torch.empty(4)) + self.A_log.tensor_model_parallel = True + self.A_log.partition_dim = 0 + + self.dt_bias = nn.Parameter(torch.empty(4)) + self.dt_bias.tensor_model_parallel = True + self.dt_bias.partition_dim = 0 + + self.D = nn.Parameter(torch.empty(4)) + self.D.tensor_model_parallel = True + self.D.partition_dim = 0 + + mixer = MambaMixer() + fsdp._annotate_tensor_parallelism(mixer) + + # Module-level detection (via registry) + assert mixer.in_proj.weight._tensor_parallel_mode == "column" + assert mixer.out_proj.weight._tensor_parallel_mode == "row" + + # Param-level detection for conv1d + assert mixer.conv1d.weight._tensor_parallel_mode == "column" + assert mixer.conv1d.bias._tensor_parallel_mode == "column" + + # Param-level detection for raw parameters on MambaMixer + assert mixer.A_log._tensor_parallel_mode == "column" + assert mixer.dt_bias._tensor_parallel_mode == "column" + assert mixer.D._tensor_parallel_mode == "column" + + # Param-level detection overrides norm fallback + assert mixer.norm.weight._tensor_parallel_mode == "column"