
[codex] Fix Mamba conv params under fine-grained FSDP gather #4467

Open
ilml wants to merge 3 commits into NVIDIA:main from ilml:codex/fix-mamba-fsdp-direct-conv-gather

Conversation

@ilml (Contributor) commented Apr 24, 2026

Summary

Fix fine-grained Megatron-FSDP parameter gathering for Mamba's fused conv path.

This branch is stacked: the MambaLayer FSDP support from #4329 comes first, followed by the direct conv-param gather fix. #4329 makes MambaLayer an FSDP unit and fixes TP annotation for SSM parameters; this PR handles the remaining case where MambaMixer.forward() reads a child module's parameters without invoking that child's forward hook.

Mamba's memory-efficient fused path reads conv1d.weight and conv1d.bias directly and passes them into mamba_split_conv1d_scan_combined, instead of calling Conv1d.forward(). With fine-grained Megatron-FSDP gather enabled, that means the child conv1d module's pre-forward gather hook never runs. After the first forward releases parameter storage, the second forward can pass a null-base sharded view into the causal-conv kernel, producing an illegal memory access.
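
For context, here is a minimal self-contained illustration of the hook-bypass pattern; it uses a toy module and plain PyTorch rather than the actual MambaMixer code:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class FusedMixer(nn.Module):
    """Toy stand-in for the fused path: reads the child's parameters
    directly instead of calling the child module."""

    def __init__(self):
        super().__init__()
        self.conv1d = nn.Conv1d(4, 4, kernel_size=3, padding=2, groups=4)

    def forward(self, x):
        # Direct attribute reads: self.conv1d(...) is never invoked, so any
        # pre-forward (gather) hook registered on self.conv1d does not run.
        w, b = self.conv1d.weight, self.conv1d.bias
        return F.conv1d(x, w, b, padding=2, groups=4)

mixer = FusedMixer()
fired = []
mixer.conv1d.register_forward_pre_hook(lambda m, inp: fired.append(True))
mixer(torch.randn(2, 4, 16))
print(fired)  # [] -- a fine-grained FSDP gather hook on conv1d is skipped the same way
```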

This PR adds an opt-in _fsdp_extra_forward_param_modules hook for modules that directly read child-module params, and uses it from MambaMixer for self.conv1d. It also makes Mamba context-parallel parameter access resolve through the live mixer object so FSDP-updated parameters are not bypassed by stale cached references.
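
On the module side, the opt-in is a single attribute pointing at the directly-read child. A minimal sketch (only the attribute name comes from this PR; the rest of the class is illustrative):

```python
import torch.nn as nn

class MambaMixerSketch(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1d = nn.Conv1d(4, 4, kernel_size=3, padding=2, groups=4)
        # Opt-in: tell fine-grained Megatron-FSDP that this module's forward
        # reads conv1d's parameters directly, so they must be gathered along
        # with this module's own parameters before forward runs.
        self._fsdp_extra_forward_param_modules = (self.conv1d,)
```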

Validation

  • python3 -m py_compile megatron/core/distributed/fsdp/mcore_fsdp_adapter.py megatron/core/distributed/fsdp/src/megatron_fsdp/megatron_fsdp.py megatron/core/distributed/torch_fully_sharded_data_parallel.py megatron/core/ssm/mamba_context_parallel.py megatron/core/ssm/mamba_mixer.py tests/unit_tests/distributed/megatron_fsdp/test_mcore_tensor_parallelism_detect.py
  • python3 -m pytest tests/unit_tests/distributed/megatron_fsdp/test_mcore_tensor_parallelism_detect.py -q could not run on the login node because pytest is not installed there.
  • Before stacking #4329 (Add MambaLayer FSDP support and fix TP annotation for SSM parameters), a clean interactive run of /home/tolong/work/dsv3/interactive_nt.sh completed through iteration 50 and saved a checkpoint, with no CUDA error, illegal memory access, or FAILED entries in the log:
    /lustre/fsw/coreai_dlalgo_llm/tolong/results/nemo_megatron/megatron/nemotron6/hybrid/debug/interactive_nt_debug/logs/interactive_nt_full_20260424_150942.log
  • An earlier two-iteration blocking-CUDA smoke test completed and saved a checkpoint at iteration 2, after reproducing the pre-fix crash at that boundary.

@copy-pr-bot (Bot) commented Apr 24, 2026

This pull request requires additional validation before any workflows can run on NVIDIA's runners.

Pull request vetters can view their responsibilities here.

Contributors can view more details about this message here.

@cspades (Member) commented Apr 24, 2026

CC'd from DMs:

Hmm, so the MambaMixer has a Conv1D submodule but the Conv1D submodule's weights are used in Autograd functions during the MambaMixer.forward() pass?

  • I assume you are using fine_grained_all_gather: since the Conv1D is actually a sub-module of MambaMixer, the only case where the Conv1D weights would not be AG'd is when we are using fine-grained module AG hooks, in which case we never trigger the Conv1D hook.
  • There is also a post-forward / post-backward hook that calls this function: release_module_parameters. It needs to be called on the modules in extra_forward_param_modules so we can re-shard them. Try asking Codex what the cleanest, most generalizable way is to have the hooks support re-sharding these untouched modules as well (see the sketch after this list).
  • Are you setting fsdp_unit_modules=[TransformerLayer, MambaLayer]? Either you hit the storage error on Step 1 (the very first AG), or, if MambaLayer is not an FSDP unit, it will never be re-sharded, skipping the post-forward / post-backward hook.
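
A toy illustration of the gather/release pairing discussed above (not the Megatron-FSDP implementation; gather_params and release_params are hypothetical stand-ins for the real all-gather and release_module_parameters logic):

```python
import torch.nn as nn

def install_extra_param_hooks(parent: nn.Module, gather_params, release_params):
    """Toy sketch: modules listed in _fsdp_extra_forward_param_modules get their
    parameters gathered before the parent's forward and released (re-sharded)
    after it, mirroring what the per-module hooks would have done."""
    extra = getattr(parent, "_fsdp_extra_forward_param_modules", ())
    if isinstance(extra, nn.Module):
        extra = (extra,)

    def pre_forward(module, args):
        for m in extra:
            gather_params(m)   # hypothetical: all-gather m's sharded params

    def post_forward(module, args, output):
        for m in extra:
            release_params(m)  # hypothetical: re-shard / free the gathered storage

    parent.register_forward_pre_hook(pre_forward)
    parent.register_forward_hook(post_forward)
```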

Phlip79 and others added 2 commits April 24, 2026 16:20
MambaLayer (GraphableMegatronModule) was not recognized as an FSDP
sharding unit, causing its parameters to remain in the root group and
defeating ZeRO-3 param sharding for Mamba and hybrid models.

Additionally, MambaMixer sets tensor_model_parallel and partition_dim
directly on parameters (conv1d, A_log, dt_bias, D, norm.weight) rather
than on the owning module. The TP annotation logic only checked
module-level attributes, so these parameters were either unclassified
or misclassified by the norm-name fallback (e.g. ExtendedRMSNorm
treated as replicated when actually TP-sharded).

Changes:
- Register MambaLayer in default fsdp_unit_modules (mcore_fsdp_adapter)
  and sub_modules_to_wrap (torch_fully_sharded_data_parallel)
- Add param-level TP attribute fallback in _detect_parallelism_type,
  placed before the norm-name fallback so TP-sharded norm weights are
  correctly classified
- Pass param through from _annotate_tensor_parallelism
- Add tests for param-level TP detection, norm override, and a
  MambaMixer-like end-to-end annotation test

Co-Authored-By: Claude Opus 4.6 (1M context) <[email protected]>
Mamba's fused path reads conv1d weights directly instead of calling Conv1d.forward(), so fine-grained Megatron-FSDP never gathered those child parameters before the second forward. Register the conv module as an extra forward-gather source and resolve context-parallel Mamba params from the live mixer object.
@ilml force-pushed the codex/fix-mamba-fsdp-direct-conv-gather branch from 5590832 to 19b794b on April 24, 2026 23:21
@cspades (Member) commented Apr 24, 2026

Also, if we do implement this, I think Torch FSDP2 has something similar for "hitch-hiking" parameters into the AG: Tensor.fsdp_pre_all_gather and Tensor.fsdp_post_all_gather (Example in TE: https://github.com/NVIDIA/TransformerEngine/blob/main/transformer_engine/pytorch/tensor/mxfp8_tensor.py#L613)

That's per-parameter / per-Tensor though, so we're basically doing the same thing here with modules, I suppose.

@ilml marked this pull request as ready for review on April 27, 2026 23:21
@ilml requested review from a team as code owners on April 27, 2026 23:21
@svcnvidia-nemo-ci requested a review from a team on April 27, 2026 23:21
@ilml requested review from cspades and shjwudp on April 27, 2026 23:23
@cspades (Member) left a comment

Looks pretty good; per-module parameter hitch-hiking seems reasonable to me, since our FSDP units are based on modules as well.

@shjwudp @Autumn1998 @wujingyue in the rewrite we should expose this as an API, but this PR defines the feature; is this general enough?

Comment on lines -234 to +237
-        return self._slice_conv_param(self.conv1d_cp1.weight)
+        conv1d = self._mixer.conv1d if self._mixer is not None else self.conv1d_cp1
+        return self._slice_conv_param(conv1d.weight)
Member:

Have to ask, where/when do we have "stale" references such that we can no longer directly retrieve the weight?

@wujingyue (Contributor) left a comment

Thanks for the fix!

IIUC, the referenced code is where def conv1d calls F.conv1d instead of a submodule.

Instead, could MambaContextParallel's constructor create the submodule from conv1d_cp1? For your reference, https://gitlab-master.nvidia.com/clara-discovery/boltz/-/blob/dev/src/boltz/distributed/model/layers/triangular_attention.py#L1490 is an internal example that adopts this practice for context parallelism.

cc @cspades

Comment on lines +762 to +772
extra_forward_param_modules = getattr(module, "_fsdp_extra_forward_param_modules", ())
if isinstance(extra_forward_param_modules, nn.Module):
    extra_forward_param_modules = (extra_forward_param_modules,)
if extra_forward_param_modules:
    seen_param_ids = {id(param) for param in param_list}
    for extra_module in extra_forward_param_modules:
        for extra_param in extra_module.parameters():
            if id(extra_param) not in seen_param_ids:
                param_list.append(extra_param)
                seen_param_ids.add(id(extra_param))

@cspades (Member) commented Apr 29, 2026

There is also a post-forward / post-backward hook that calls this function: release_module_parameters. It needs to be called on the modules in extra_forward_param_modules so we can re-shard them.

Early note, what about the pre-backward param unshard?

We need to ensure that any rogue weights are re-sharded, and un-sharded during the backward pass. Did you check that the Conv-1D weights are re-sharded?

@cspades (Member) left a comment

Still WIP, will approve when all features are implemented!
