diff --git a/megatron/core/distributed/fsdp/mcore_fsdp_adapter.py b/megatron/core/distributed/fsdp/mcore_fsdp_adapter.py
index 996860afe26..46ea75491c5 100644
--- a/megatron/core/distributed/fsdp/mcore_fsdp_adapter.py
+++ b/megatron/core/distributed/fsdp/mcore_fsdp_adapter.py
@@ -40,6 +40,7 @@
 from megatron.core.distributed.data_parallel_base import _BaseDataParallel
 from megatron.core.distributed.distributed_data_parallel_config import DistributedDataParallelConfig
 from megatron.core.process_groups_config import ProcessGroupCollection
+from megatron.core.ssm.mamba_layer import MambaLayer
 from megatron.core.transformer.transformer_config import TransformerConfig
 from megatron.core.transformer.transformer_layer import TransformerLayer
 from megatron.core.utils import is_te_min_version, log_single_rank
@@ -151,7 +152,7 @@ def __init__(
             self.fsdp_unit_modules = fsdp_unit_modules
         else:
             if self.ddp_config.data_parallel_sharding_strategy == "optim_grads_params":
-                self.fsdp_unit_modules = [TransformerLayer]
+                self.fsdp_unit_modules = [TransformerLayer, MambaLayer]
             else:
                 self.fsdp_unit_modules = []
 
@@ -211,7 +212,9 @@ def load_state_dict(self, state_dict, strict=True):
 
         self.module.load_state_dict(custom_state_dict, strict=strict)
 
-    def _detect_parallelism_type(self, param_name: str, module: nn.Module) -> Optional[str]:
+    def _detect_parallelism_type(
+        self, param_name: str, module: nn.Module, param: nn.Parameter = None
+    ) -> Optional[str]:
         """
         Infer tensor-parallelism type for a parameter under a given module
         (forked from Megatron-Bridge).
@@ -252,6 +255,18 @@ def _detect_parallelism_type(self, param_name: str, module: nn.Module) -> Option
                 return "replicated"
             return "row"
 
+        # Fallback to inspecting parameter-level TP attributes.
+        # Some modules (e.g. MambaMixer) set tensor_model_parallel and partition_dim
+        # directly on parameters rather than on the owning module.
+        if param is not None and getattr(param, "tensor_model_parallel", False):
+            partition_dim = getattr(param, "partition_dim", None)
+            if partition_dim == 0:
+                return "column"
+            elif partition_dim == 1:
+                if "bias" in param_name:
+                    return "replicated"
+                return "row"
+
         # Fallback for normalization layers
         if any(norm in module_type for norm in ["Norm", "Normalization"]):
             return "replicated"
@@ -277,7 +292,7 @@ def _annotate_tensor_parallelism(self, root_module: nn.Module) -> None:
         """
         for submodule in root_module.modules():
             for name, param in submodule.named_parameters(recurse=False):
-                detected_type = self._detect_parallelism_type(name, submodule)
+                detected_type = self._detect_parallelism_type(name, submodule, param)
                 if detected_type is not None:
                     setattr(param, "_tensor_parallel_mode", detected_type)
 
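For context, here is a minimal sketch (module and variable names are illustrative, not from the patch) of the parameter-level tagging pattern this new fallback targets: a raw `nn.Parameter` that carries `tensor_model_parallel` / `partition_dim` on itself, as MambaMixer does for `A_log`, `dt_bias`, `D`, and its `conv1d` weights.

```python
import torch
import torch.nn as nn


class ToyMixer(nn.Module):
    """Illustrative stand-in: TP metadata lives on the parameter, not the module."""

    def __init__(self):
        super().__init__()
        # Raw parameter sharded along dim 0 across tensor-parallel ranks.
        self.A_log = nn.Parameter(torch.empty(16))
        self.A_log.tensor_model_parallel = True
        self.A_log.partition_dim = 0


# With the new `param` argument, the fallback classifies this parameter even though
# ToyMixer is not in the module registry and its class name carries no TP hint:
#   fsdp._detect_parallelism_type("A_log", toy, toy.A_log)  # -> "column"
```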
diff --git a/megatron/core/distributed/fsdp/src/megatron_fsdp/megatron_fsdp.py b/megatron/core/distributed/fsdp/src/megatron_fsdp/megatron_fsdp.py
index 2a17315611a..64bbf206c14 100644
--- a/megatron/core/distributed/fsdp/src/megatron_fsdp/megatron_fsdp.py
+++ b/megatron/core/distributed/fsdp/src/megatron_fsdp/megatron_fsdp.py
@@ -759,6 +759,17 @@ def _pre_forward_param_unshard(module: nn.Module, *unused):
             if self.enable_fine_grained_param_gather_hook:
                 param_list = list(module.parameters(recurse=False))
 
+                extra_forward_param_modules = getattr(module, "_fsdp_extra_forward_param_modules", ())
+                if isinstance(extra_forward_param_modules, nn.Module):
+                    extra_forward_param_modules = (extra_forward_param_modules,)
+                if extra_forward_param_modules:
+                    seen_param_ids = {id(param) for param in param_list}
+                    for extra_module in extra_forward_param_modules:
+                        for extra_param in extra_module.parameters():
+                            if id(extra_param) not in seen_param_ids:
+                                param_list.append(extra_param)
+                                seen_param_ids.add(id(extra_param))
+
                 # All-gather the parameters before the forward pass.
                 self.all_gather_and_wait_parameters_ready(
                     params=param_list,
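Below is a hedged sketch (illustrative module, not from the patch) of the opt-in contract this hook now honors: a module that reads a submodule's weights directly in `forward()` lists that submodule under `_fsdp_extra_forward_param_modules`, so the fine-grained gather hook all-gathers those parameters even though the submodule's own pre-forward hook never fires. MambaMixer uses exactly this for its `conv1d` (see the `mamba_mixer.py` hunk further down).

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class FusedProjection(nn.Module):
    """Illustrative: consumes self.proj.weight directly instead of calling self.proj(x)."""

    def __init__(self, dim: int):
        super().__init__()
        self.proj = nn.Linear(dim, dim, bias=False)
        # Ask the fine-grained param-gather hook to also all-gather self.proj's
        # parameters before this module's forward; a single nn.Module or an
        # iterable of modules is accepted.
        self._fsdp_extra_forward_param_modules = (self.proj,)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # The weight is read directly, so self.proj's own forward (and hooks) never run.
        return F.linear(x, self.proj.weight)
```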
diff --git a/megatron/core/distributed/torch_fully_sharded_data_parallel.py b/megatron/core/distributed/torch_fully_sharded_data_parallel.py
index 43321bc78cc..d397aac4b5b 100644
--- a/megatron/core/distributed/torch_fully_sharded_data_parallel.py
+++ b/megatron/core/distributed/torch_fully_sharded_data_parallel.py
@@ -19,6 +19,7 @@
 from .. import parallel_state, tensor_parallel
 from ..models.common.embeddings.language_model_embedding import LanguageModelEmbedding
 from ..models.common.embeddings.rotary_pos_embedding import RotaryEmbedding
+from ..ssm.mamba_layer import MambaLayer
 from ..transformer.transformer_config import TransformerConfig
 from ..transformer.transformer_layer import TransformerLayer
 from .data_parallel_base import _BaseDataParallel
@@ -59,6 +60,7 @@ def __init__(
         module: torch.nn.Module,
         sub_modules_to_wrap: Set[torch.nn.Module] = {
             TransformerLayer,
+            MambaLayer,
             LanguageModelEmbedding,
             RotaryEmbedding,
             tensor_parallel.ColumnParallelLinear,
diff --git a/megatron/core/ssm/mamba_context_parallel.py b/megatron/core/ssm/mamba_context_parallel.py
index 3925f8bd8df..f8297766de0 100644
--- a/megatron/core/ssm/mamba_context_parallel.py
+++ b/megatron/core/ssm/mamba_context_parallel.py
@@ -70,6 +70,7 @@ def __init__(
         A_log_cp1: torch.Tensor,
         D_cp1: torch.Tensor,
         D_has_hdim: bool,
+        mixer=None,
     ) -> None:
         if not HAVE_EINOPS:
             raise ImportError("einops is required by the Mamba model but cannot be imported")
@@ -84,6 +85,7 @@ def __init__(
         self.A_log_cp1 = A_log_cp1
         self.D_cp1 = D_cp1
         self.D_has_hdim = D_has_hdim
+        self._mixer = mixer
 
         self.cp_size = self.cp_group.size()
 
@@ -231,24 +233,29 @@ def conv1d_channels(self):
     def get_conv1d_weight(self) -> torch.Tensor:
         """Returns a slice of the conv1d weight relevant to the current context parallel rank"""
         # weight shape: [conv_dim, 1, d_conv]
-        return self._slice_conv_param(self.conv1d_cp1.weight)
+        conv1d = self._mixer.conv1d if self._mixer is not None else self.conv1d_cp1
+        return self._slice_conv_param(conv1d.weight)
 
     def get_conv1d_bias(self) -> torch.Tensor:
         """Returns a slice of the conv1d bias relevant to the current context parallel rank"""
         # bias shape: [conv_dim]
-        return self._slice_conv_param(self.conv1d_cp1.bias)
+        conv1d = self._mixer.conv1d if self._mixer is not None else self.conv1d_cp1
+        return self._slice_conv_param(conv1d.bias)
 
     def get_dt_bias(self) -> torch.Tensor:
         """Returns a slice of dt_bias relevant to the current context parallel rank"""
-        return self._slice_vector_param(self.dt_bias_cp1)
+        param = self._mixer.dt_bias if self._mixer is not None else self.dt_bias_cp1
+        return self._slice_vector_param(param)
 
     def get_A_log(self) -> torch.Tensor:
         """Returns a slice of A_log relevant to the current context parallel rank"""
-        return self._slice_vector_param(self.A_log_cp1)
+        param = self._mixer.A_log if self._mixer is not None else self.A_log_cp1
+        return self._slice_vector_param(param)
 
     def get_D(self) -> torch.Tensor:
         """Returns a slice of D relevant to the current context parallel rank"""
-        return self._slice_vector_param(self.D_cp1, has_hdim=self.D_has_hdim)
+        param = self._mixer.D if self._mixer is not None else self.D_cp1
+        return self._slice_vector_param(param, has_hdim=self.D_has_hdim)
 
     def _slice_conv_param(self, param: torch.Tensor) -> torch.Tensor:
         """
diff --git a/megatron/core/ssm/mamba_mixer.py b/megatron/core/ssm/mamba_mixer.py
index 727c6ef5fd6..0b3eb6b7708 100644
--- a/megatron/core/ssm/mamba_mixer.py
+++ b/megatron/core/ssm/mamba_mixer.py
@@ -316,6 +316,9 @@ def __init__(
         else:
             nn.init.kaiming_uniform_(self.conv1d.weight, a=math.sqrt(5))
 
+        # The fused Mamba path reads conv1d weights directly instead of calling conv1d.forward().
+        self._fsdp_extra_forward_param_modules = (self.conv1d,)
+
         self.activation = "silu"
         self.act = nn.SiLU()
 
@@ -408,6 +411,7 @@ def __init__(
             A_log_cp1=self.A_log,
             D_cp1=self.D,
             D_has_hdim=self.D_has_hdim,
+            mixer=self,
         )
 
         self.tp_group = pg_collection.tp
@@ -700,17 +704,22 @@ def _ssm_training(
             assert sequence_packing_available, reason_for_no_sequence_packing
             seq_idx = packed_seq_params.seq_idx
 
+        conv1d_weight = rearrange(self.cp.get_conv1d_weight(), "d 1 w -> d w")
+        conv1d_bias = self.cp.get_conv1d_bias()
+        dt_bias = self.cp.get_dt_bias().float()
+        D = (
+            rearrange(self.cp.get_D().float(), "(h p) -> h p", p=self.headdim)
+            if self.D_has_hdim
+            else self.cp.get_D()
+        )
+
         y = mamba_split_conv1d_scan_combined(
             zxBCdt,
-            rearrange(self.cp.get_conv1d_weight(), "d 1 w -> d w"),
-            self.cp.get_conv1d_bias(),
-            self.cp.get_dt_bias().float(),
+            conv1d_weight,
+            conv1d_bias,
+            dt_bias,
             A,
-            D=(
-                rearrange(self.cp.get_D().float(), "(h p) -> h p", p=self.headdim)
-                if self.D_has_hdim
-                else self.cp.get_D()
-            ),
+            D=D,
             chunk_size=self.chunk_size,
             activation=self.activation,
             headdim=None if self.D_has_hdim else self.headdim,
diff --git a/tests/unit_tests/distributed/megatron_fsdp/test_mcore_tensor_parallelism_detect.py b/tests/unit_tests/distributed/megatron_fsdp/test_mcore_tensor_parallelism_detect.py
index c69ca817872..63aa24e34cb 100644
--- a/tests/unit_tests/distributed/megatron_fsdp/test_mcore_tensor_parallelism_detect.py
+++ b/tests/unit_tests/distributed/megatron_fsdp/test_mcore_tensor_parallelism_detect.py
@@ -203,6 +203,64 @@ def __init__(self, mode):
     assert fsdp._detect_parallelism_type("weight", TELinear("none")) == "replicated"
 
 
+def test_detect_parallelism_param_level_tp_attributes():
+    """Parameters with tensor_model_parallel/partition_dim set directly on them
+    (rather than on the owning module) should be detected via the param-level fallback.
+    This is the pattern used by MambaMixer's conv1d, A_log, dt_bias, D parameters.
+    """
+    fsdp = _make_fsdp_for_unit_tests()
+
+    class PlainModule(nn.Module):
+        def __init__(self):
+            super().__init__()
+            self.weight = nn.Parameter(torch.empty(4, 4))
+
+    module = PlainModule()
+
+    # Without param-level attrs -> None (no inference possible)
+    assert fsdp._detect_parallelism_type("weight", module) is None
+
+    # With param-level column (partition_dim=0)
+    module.weight.tensor_model_parallel = True
+    module.weight.partition_dim = 0
+    assert fsdp._detect_parallelism_type("weight", module, module.weight) == "column"
+
+    # With param-level row (partition_dim=1)
+    module.weight.partition_dim = 1
+    assert fsdp._detect_parallelism_type("weight", module, module.weight) == "row"
+
+    # Row-parallel bias should be replicated
+    bias = nn.Parameter(torch.empty(4))
+    bias.tensor_model_parallel = True
+    bias.partition_dim = 1
+    module.bias = bias
+    assert fsdp._detect_parallelism_type("bias", module, module.bias) == "replicated"
+
+
+def test_detect_parallelism_param_level_tp_overrides_norm_fallback():
+    """A Norm-like module whose weight has param-level TP attributes should be
+    classified by the param-level check, NOT the norm-name fallback.
+    This is the pattern used by MambaMixer's ExtendedRMSNorm, whose weight is
+    TP-sharded (partition_dim=0) rather than replicated.
+    """
+    fsdp = _make_fsdp_for_unit_tests()
+
+    class ExtendedRMSNorm(nn.Module):
+        def __init__(self):
+            super().__init__()
+            self.weight = nn.Parameter(torch.empty(8))
+
+    module = ExtendedRMSNorm()
+
+    # Without param-level attrs, norm fallback should return "replicated"
+    assert fsdp._detect_parallelism_type("weight", module) == "replicated"
+
+    # With param-level TP attrs, should return "column" instead
+    module.weight.tensor_model_parallel = True
+    module.weight.partition_dim = 0
+    assert fsdp._detect_parallelism_type("weight", module, module.weight) == "column"
+
+
 def test_detect_parallelism_returns_none_when_cannot_infer():
     fsdp = _make_fsdp_for_unit_tests()
 
@@ -251,3 +309,82 @@ def __init__(self):
     # For unknown module type, _detect_parallelism_type should return None
     # and _annotate_tensor_parallelism must not set the attribute.
     assert not hasattr(root.plain.weight, "_tensor_parallel_mode")
+
+
+def test_annotate_tensor_parallelism_mamba_mixer_like_module():
+    """Simulate a MambaMixer-like module hierarchy where TP attributes are set on
+    parameters rather than modules. Verify that _annotate_tensor_parallelism
+    correctly classifies all parameters.
+    """
+    fsdp = _make_fsdp_for_unit_tests()
+
+    class ColumnParallelLinear(nn.Module):
+        """Stands in for in_proj (module-level TP, detected via registry)."""
+
+        def __init__(self):
+            super().__init__()
+            self.weight = nn.Parameter(torch.empty(8, 4))
+
+    class RowParallelLinear(nn.Module):
+        """Stands in for out_proj (module-level TP, detected via registry)."""
+
+        def __init__(self):
+            super().__init__()
+            self.weight = nn.Parameter(torch.empty(4, 8))
+
+    class ExtendedRMSNorm(nn.Module):
+        """Norm with param-level TP (should NOT fall through to norm fallback)."""
+
+        def __init__(self):
+            super().__init__()
+            self.weight = nn.Parameter(torch.empty(8))
+            self.weight.tensor_model_parallel = True
+            self.weight.partition_dim = 0
+
+    class MambaMixer(nn.Module):
+        """Simulates MambaMixer with conv1d and raw TP-sharded parameters."""
+
+        def __init__(self):
+            super().__init__()
+            self.in_proj = ColumnParallelLinear()
+            self.out_proj = RowParallelLinear()
+            self.norm = ExtendedRMSNorm()
+
+            # conv1d: standard nn.Conv1d with param-level TP attrs
+            self.conv1d = nn.Conv1d(8, 8, 4, groups=8)
+            self.conv1d.weight.tensor_model_parallel = True
+            self.conv1d.weight.partition_dim = 0
+            self.conv1d.bias.tensor_model_parallel = True
+            self.conv1d.bias.partition_dim = 0
+
+            # Raw parameters with param-level TP attrs
+            self.A_log = nn.Parameter(torch.empty(4))
+            self.A_log.tensor_model_parallel = True
+            self.A_log.partition_dim = 0
+
+            self.dt_bias = nn.Parameter(torch.empty(4))
+            self.dt_bias.tensor_model_parallel = True
+            self.dt_bias.partition_dim = 0
+
+            self.D = nn.Parameter(torch.empty(4))
+            self.D.tensor_model_parallel = True
+            self.D.partition_dim = 0
+
+    mixer = MambaMixer()
+    fsdp._annotate_tensor_parallelism(mixer)
+
+    # Module-level detection (via registry)
+    assert mixer.in_proj.weight._tensor_parallel_mode == "column"
+    assert mixer.out_proj.weight._tensor_parallel_mode == "row"
+
+    # Param-level detection for conv1d
+    assert mixer.conv1d.weight._tensor_parallel_mode == "column"
+    assert mixer.conv1d.bias._tensor_parallel_mode == "column"
+
+    # Param-level detection for raw parameters on MambaMixer
+    assert mixer.A_log._tensor_parallel_mode == "column"
+    assert mixer.dt_bias._tensor_parallel_mode == "column"
+    assert mixer.D._tensor_parallel_mode == "column"
+
+    # Param-level detection overrides norm fallback
+    assert mixer.norm.weight._tensor_parallel_mode == "column"
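The added tests are plain unit tests with no distributed setup. Assuming the repo's usual pytest workflow, an invocation along these lines should select only the new cases (the `-k` expression matches the new test names by substring):

```
pytest tests/unit_tests/distributed/megatron_fsdp/test_mcore_tensor_parallelism_detect.py \
    -k "param_level or mamba_mixer_like"
```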