[codex] Fix Mamba conv params under fine-grained FSDP gather#4467
ilml wants to merge 3 commits into NVIDIA:main
Conversation
CC'd from DMs: Hmm, so the MambaMixer has a Conv1D submodule, but the Conv1D submodule's weights are used directly in autograd functions during the MambaMixer.forward() pass?
`MambaLayer` (`GraphableMegatronModule`) was not recognized as an FSDP sharding unit, causing its parameters to remain in the root group and defeating ZeRO-3 param sharding for Mamba and hybrid models. Additionally, `MambaMixer` sets `tensor_model_parallel` and `partition_dim` directly on parameters (conv1d, A_log, dt_bias, D, norm.weight) rather than on the owning module. The TP annotation logic only checked module-level attributes, so these parameters were either unclassified or misclassified by the norm-name fallback (e.g. `ExtendedRMSNorm` treated as replicated when actually TP-sharded).

Changes:
- Register `MambaLayer` in default `fsdp_unit_modules` (mcore_fsdp_adapter) and `sub_modules_to_wrap` (torch_fully_sharded_data_parallel)
- Add param-level TP attribute fallback in `_detect_parallelism_type`, placed before the norm-name fallback so TP-sharded norm weights are correctly classified
- Pass `param` through from `_annotate_tensor_parallelism`
- Add tests for param-level TP detection, the norm override, and a MambaMixer-like end-to-end annotation test

Co-Authored-By: Claude Opus 4.6 (1M context) <[email protected]>
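For illustration, a minimal sketch of the param-level fallback described above, assuming a `_detect_parallelism_type(module, param_name, param)` signature and string labels for the outcome; these are assumptions for this sketch, not the actual Megatron-FSDP implementation:

```python
# Sketch only: the real helper lives in Megatron-FSDP's TP annotation logic;
# the signature and return labels below are assumptions for illustration.
import torch.nn as nn


def _detect_parallelism_type(module: nn.Module, param_name: str, param: nn.Parameter) -> str:
    # Existing behavior: a module-level TP annotation wins.
    if getattr(module, "tensor_model_parallel", False):
        return "tensor_parallel"
    # New param-level fallback: MambaMixer stamps tensor_model_parallel /
    # partition_dim on the parameters themselves (conv1d, A_log, dt_bias, D,
    # norm.weight), so check the param before any name-based heuristic.
    if getattr(param, "tensor_model_parallel", False):
        return "tensor_parallel"
    # The name-based fallback runs last, so a TP-sharded norm weight is no
    # longer misclassified as replicated.
    if "norm" in param_name:
        return "replicated"
    return "data_parallel"
```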
Mamba's fused path reads conv1d weights directly instead of calling Conv1d.forward(), so fine-grained Megatron-FSDP never gathered those child parameters before the second forward. Register the conv module as an extra forward-gather source and resolve context-parallel Mamba params from the live mixer object.
Also, if we do implement this, I think Torch FSDP2 has something similar for "hitch-hiking" parameters into the all-gather. That's per-parameter / per-tensor though, so we're basically doing the same thing here with modules, I suppose.
cspades
left a comment
Looks pretty good; per-module parameter hitch-hiking seems reasonable to me since our FSDP units are based on modules as well.
@shjwudp @Autumn1998 @wujingyue in the rewrite we should expose this as an API, but this PR defines the feature; is this general enough?
```diff
-        return self._slice_conv_param(self.conv1d_cp1.weight)
+        conv1d = self._mixer.conv1d if self._mixer is not None else self.conv1d_cp1
+        return self._slice_conv_param(conv1d.weight)
```
Have to ask, where/when do we have "stale" references such that we can no longer directly retrieve the weight?
wujingyue
left a comment
Thanks for the fix!
IIUC, this is where the `conv1d` def calls `F.conv1d` instead of a submodule.
Instead, could MambaContextParallel's constructor create the submodule from conv1d_cp1? For your reference, https://gitlab-master.nvidia.com/clara-discovery/boltz/-/blob/dev/src/boltz/distributed/model/layers/triangular_attention.py#L1490 is an internal example that adopts this practice for context parallelism.
cc @cspades
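For what it's worth, a rough sketch of that alternative, assuming `MambaContextParallel` currently reads the CP-sliced `conv1d_cp1` weights and applies them via `F.conv1d`; the class and constructor shape below are guesses for illustration, not the real code:

```python
import torch
import torch.nn as nn


class MambaContextParallelSketch(nn.Module):
    """Hypothetical variant: hold the CP conv as a registered submodule so
    Megatron-FSDP's module pre-forward hooks gather its params automatically."""

    def __init__(self, conv1d_cp1: nn.Conv1d):
        super().__init__()
        # Registering the conv as a child module (instead of reading its weight
        # tensors and calling F.conv1d) gives FSDP a module boundary to hook
        # for gather and release.
        self.conv1d = conv1d_cp1

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Going through the submodule's own forward fires its hooks.
        return self.conv1d(x)
```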
```python
extra_forward_param_modules = getattr(module, "_fsdp_extra_forward_param_modules", ())
if isinstance(extra_forward_param_modules, nn.Module):
    extra_forward_param_modules = (extra_forward_param_modules,)
if extra_forward_param_modules:
    seen_param_ids = {id(param) for param in param_list}
    for extra_module in extra_forward_param_modules:
        for extra_param in extra_module.parameters():
            if id(extra_param) not in seen_param_ids:
                param_list.append(extra_param)
                seen_param_ids.add(id(extra_param))
```
There is also a post-forward / post-backward hook that calls this function: `release_module_parameters`. It needs to be called on the modules in `extra_forward_param_modules` so we can re-shard them.
Early note: what about the pre-backward param unshard?
We need to ensure that any rogue weights are re-sharded after the forward and un-sharded again during the backward pass. Did you check that the Conv1D weights are re-sharded?
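To make the symmetry concrete, a sketch of the opt-in attribute plus a helper that the gather path and the release/unshard hooks could all walk; the helper name is made up, and the real `release_module_parameters` call is deliberately not shown since its signature isn't quoted here:

```python
import torch.nn as nn

# Opt-in side (per the PR description), e.g. in MambaMixer.__init__:
#     self._fsdp_extra_forward_param_modules = (self.conv1d,)


def _modules_to_gather_and_release(module: nn.Module):
    """Hypothetical helper: the module itself plus any hitch-hiking modules it
    declares. The pre-forward gather, post-forward/post-backward release, and
    pre-backward unshard would all need to walk this same set."""
    yield module
    extra = getattr(module, "_fsdp_extra_forward_param_modules", ())
    if isinstance(extra, nn.Module):
        extra = (extra,)
    yield from extra
```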
cspades
left a comment
Still WIP, will approve when all features are implemented!
Summary
Fix fine-grained Megatron-FSDP parameter gathering for Mamba's fused conv path.
This branch is now stacked with the MambaLayer FSDP support from #4329 first, then the direct conv-param gather fix. #4329 makes `MambaLayer` an FSDP unit and fixes TP annotation for SSM parameters; this PR handles the remaining case where `MambaMixer.forward()` reads a child module's parameters without invoking that child module's forward hook.

Mamba's memory-efficient fused path reads `conv1d.weight` and `conv1d.bias` directly and passes them into `mamba_split_conv1d_scan_combined` instead of calling `Conv1d.forward()`. With fine-grained Megatron-FSDP gather enabled, that means the child `conv1d` module's pre-forward gather hook never runs. After the first forward releases parameter storage, the second forward can pass a null-base sharded view into causal-conv, producing an illegal memory access.

This PR adds an opt-in `_fsdp_extra_forward_param_modules` hook for modules that directly read child-module params, and uses it from `MambaMixer` for `self.conv1d`. It also makes Mamba context-parallel parameter access resolve through the live mixer object so FSDP-updated parameters are not bypassed by stale cached references.
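As a self-contained illustration of the failure pattern (a toy stand-in, not Mamba's actual kernel call): a parent module that reads its child's parameters as raw tensors, so a pre-forward hook registered on the child never fires.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class FusedParent(nn.Module):
    """Toy stand-in for MambaMixer's fused path."""

    def __init__(self):
        super().__init__()
        self.conv1d = nn.Conv1d(4, 4, kernel_size=3, padding=2, groups=4)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Like passing conv1d.weight / conv1d.bias into the fused kernel:
        # self.conv1d.forward() is never invoked, so a pre-forward gather hook
        # registered on self.conv1d never runs.
        return F.conv1d(x, self.conv1d.weight, self.conv1d.bias, padding=2, groups=4)


m = FusedParent()
m.conv1d.register_forward_pre_hook(lambda mod, inp: print("child gather hook fired"))
m(torch.randn(1, 4, 8))  # prints nothing: the child's hook is bypassed
```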
Validation

- `python3 -m py_compile megatron/core/distributed/fsdp/mcore_fsdp_adapter.py megatron/core/distributed/fsdp/src/megatron_fsdp/megatron_fsdp.py megatron/core/distributed/torch_fully_sharded_data_parallel.py megatron/core/ssm/mamba_context_parallel.py megatron/core/ssm/mamba_mixer.py tests/unit_tests/distributed/megatron_fsdp/test_mcore_tensor_parallelism_detect.py`
- `python3 -m pytest tests/unit_tests/distributed/megatron_fsdp/test_mcore_tensor_parallelism_detect.py -q` could not run on the login node because `pytest` is not installed there.
- `/home/tolong/work/dsv3/interactive_nt.sh` completed through iteration 50 and saved a checkpoint with no `CUDA error`, `illegal memory`, or `FAILED` in the log: `/lustre/fsw/coreai_dlalgo_llm/tolong/results/nemo_megatron/megatron/nemotron6/hybrid/debug/interactive_nt_debug/logs/interactive_nt_full_20260424_150942.log`