Skip to content

feat(megatron): Qwen3.5 dense + MoE training/inference support via megatron-bridge#1384

Open
Adiactive wants to merge 9 commits into
areal-project:mainfrom
Adiactive:feat/megatron-qwen3_5-pr
Open

feat(megatron): Qwen3.5 dense + MoE training/inference support via megatron-bridge#1384
Adiactive wants to merge 9 commits into
areal-project:mainfrom
Adiactive:feat/megatron-qwen3_5-pr

Conversation

@Adiactive
Copy link
Copy Markdown
Contributor

@Adiactive Adiactive commented Jun 2, 2026

Description

Adds Megatron support for the Qwen3.5 series (both dense and MoE) on AReaL. The adaptations center on its new GDN (Gated Delta Net) hybrid-attention architecture, which mcore and megatron-bridge handle differently from a standard transformer — across weight conversion, the forward input format, and the vLLM rollout kernel.

Concretely it adds:

  • megatron-bridge-delegated live weight sync (use_bridge_for_update_weights on MegatronEngineConfig): _update_weights_from_distributed dispatches to bridge.export_hf_weights instead of the hand-rolled convert_*_to_hf registry, so GDN conversion (TP all-gather, GLU linear_fc1 stride-2 de-interleave) and MoE TEGroupedLinear expert weights are handled by the bridge itself.
  • BSHD padded forward path (use_padded_seq on MegatronEngineConfig): the GDN/SSM kernels reject packed (THD) input — mcore's GDN raises NotImplementedError("GDN does not support packed sequence for now.") (megatron/core/ssm/gated_delta_net.py#L301, megatron-core 0.17.0; THD support for GDN is landing upstream via Megatron-LM #2644 but is not yet merged into any release). This generalizes packed_context_parallel_forward so those layers run on [B, S] padded tensors reconstructed from cu_seqlens; the data container stays THD-packed, only the model forward is padded, sharing the reconstruction path with the existing VLM route. verl uses an equivalent control for these models — its data_format switch selects the same padded [B, S] input over packed THD.
  • megatron-bridge runtime patch (megatron_bridge_patches.py): monkey-patches Qwen3VLGPTModel._postprocess to restore word_embeddings on the MTP shadow-embedding closure (upstream bug, megatron-bridge PR #3143; not in any released version through 0.4.1). Auto-disables once a release ships the fix.
  • Vision-model registration: registers qwen3_5 / qwen3_5_moe in VALID_VISION_MODELS and fixes a non-contiguous broadcast in the engine.
  • vLLM gdn_prefill_backend config: exposes vLLM's GDN prefill backend selector so configs can set triton, avoiding the FlashInfer GDN prefill kernel hang (shm_broadcast stall -> sample_tokens timeout -> EngineDeadError). Upstream reports: vllm-project/vllm#38916, #36631, #35496; root cause in the FlashInfer kernel: #3329. The default None emits no flag, so non-GDN models are unaffected.
  • Example recipe (examples/vlm/qwen3_5_2b_megatron_geometry3k_grpo.yaml): a single-node (8-GPU) GRPO config for Qwen3.5-2B on geometry3k that wires the new flags end to end — bridge_type: megatron-bridge, use_padded_seq, use_bridge_for_update_weights on the actor, and gdn_prefill_backend: triton on the vLLM rollout.

Related Issue

N/A

Type of Change

  • 🐛 Bug fix
  • ✨ New feature
  • 💥 Breaking change
  • 📝 Documentation update
  • ♻️ Refactoring
  • ⚡ Performance improvement
  • ✅ Test coverage improvement

Checklist

  • I have read the Contributing Guide
  • Pre-commit hooks pass (pre-commit run --all-files)
  • Relevant tests pass; new tests added for new functionality
  • Documentation updated (if applicable; built with ./docs/build_all.sh)
  • Branch is up to date with main
  • Self-reviewed via /review-pr command
  • This PR was created by a coding agent via /create-pr
  • This PR is a breaking change

Breaking Change Details (if applicable): None. All new config fields default to their prior behavior (use_bridge_for_update_weights=False, use_padded_seq=False, gdn_prefill_backend=None), so existing models and configs are unaffected.

Tests

Added to tests/test_megatron_engine_distributed.py (+ runner scaffolding in tests/torchrun/run_megatron_engine_distributed.py). Constraints encoded in the configs: TP ≤ 2 (Qwen3.5-35B-A3B has num_key_value_heads = 2) and no context parallel (the GDN/SSM layers reject packed sequences; see Megatron-LM #4043 which is closed but not yet picked into any release; and the VLM-CP guard), so ranks are filled with PP/DP/EP, no CP.

Dense Qwen3.5 (validated on H100):

  • test_qwen3_5_single_gpu_forward — engine init + forward on 1 GPU (exercises the megatron-bridge load path, including the with torch.device("cpu") GDN ChunkedMapping fix).
  • test_qwen3_5_tensor_parallel — forward, TP=2.
  • test_qwen3_5_pipeline_parallel — forward, PP=2.
  • test_qwen3_5_hf_save_load — train -> HF save -> zero -> HF load -> retrain, weights must match, TP=2. (Uses HF safetensors rather than mcore DCP because dist_checkpointing does not yet support SSM/GDN flattened_range tensors.)
  • test_qwen3_5_virtual_pipeline_parallelskipped (@pytest.mark.skip): megatron-bridge's _broadcast_shared_embeddings does not support VPP + tied embeddings (the small dense Qwen3.5 models tie embeddings); VPP is an optional scheduling optimization, not required for initial support.
  • test_qwen3_5_grad_norm_mb_invarianceskipped (@pytest.mark.skip): the BSHD path (use_padded_seq) is not micro-batch invariant — per-microbatch padding boundaries cause a small grad_norm drift; verl sidesteps this by using micro-batch size 1.

Qwen3.5-MoE (Qwen3.5-35B-A3B; both pass on 4 GPUs):

  • test_qwen3_5_moe_expert_parallel — megatron forward + cross-rank logprob consistency under PP=2/TP=2/EP=2. The megatron-vs-FSDP logit cross-check is skipped for this model (_MODEL_SKIP_FSDP_COMPARE): AReaL's FSDP engine materializes the full fp32 35B per rank on load, which cannot fit alongside the megatron model. Forward-conversion correctness is instead covered by the round-trip below.
  • test_qwen3_5_moe_hf_save_load — HF save -> zero -> load -> compare under PP=2/TP=2/EP=2, validating MoE expert-weight conversion (TEGroupedLinear weight0..N + GLU linear_fc1 stride-2). The train step is skipped (_MODEL_SAVELOAD_SKIP_TRAIN) because a 35B AdamW optimizer state does not fit; the loaded HF weights are already non-trivial, so the round-trip still exercises conversion without an optimizer.

Additional Context

Dependency note: requires megatron-bridge (the engine selects it via bridge_type=megatron-bridge for the Qwen3.5 family); mbridge falls back to a Qwen3 substring match and emits wrong shapes for the GDN hybrid attention. Validated against megatron-core 0.17.0 and megatron-bridge 0.4.0.

Training Reward Example

Image: ghcr.io/inclusionai/areal-runtime:v1.0.4-vllm
Dataset: Geometry3k

Scheduler: Local
Qwen3.5-2B
image

Scheduler: Slurm
Qwen3.5-27B
image

Qwen3.5-35B-A3B
9c2fd3b3-e93d-4151-9d25-05194d4924ee

Adiactive added 7 commits June 2, 2026 16:37
…weights

Add use_bridge_for_update_weights flag that routes the live weight
update path through megatron-bridge.export_hf_weights instead of the
hand-rolled convert_to_hf registry. Required for new model families
(e.g. Qwen3.5) that don't have a registry entry.

The bridge handles TP/EP/PP gather and HF layout transformation
internally; AReaL keeps the bucketed broadcast loop unchanged. FP8
and LoRA paths fall back to the registry automatically.

Also fix a latent device-context bug in _load_model_from_hf:
megatron-bridge builds shard-index tensors via torch.arange() under
the caller's `with self.device:` context, putting them on CUDA while
HF weights are loaded on CPU. The resulting indexing error trips
ChunkedMapping for any model with GDN/Mamba conv1d weights (e.g.
Qwen3.5). Force CPU as the factory-op default just around the
bridge.load_hf_weights call.

Key changes:
- New MegatronEngineConfig.use_bridge_for_update_weights flag
- Refactor _update_weights_from_distributed into dispatch +
  _update_weights_via_registry helper
- New _update_weights_via_bridge streams from bridge.export_hf_weights
  and reuses the bucket broadcast loop
- Wrap bridge.load_hf_weights in `with torch.device("cpu"):` to
  prevent CUDA index / CPU tensor mismatch in ChunkedMapping
Add 1-GPU smoke + 5 multi-GPU tests (TP=2, PP=2, PP+VPP=2, DP=2
grad_norm invariance, DCP save/load) mirroring the Qwen3 dense set.
All Qwen3.5 tests route through bridge_type=megatron-bridge because
its GDN hybrid attention is only handled by megatron-bridge's model
definitions (mbridge would substring-match qwen3 and emit wrong
shapes).

NOTE: these tests currently fail at engine.forward because
megatron-core's GDN layer raises NotImplementedError on packed
(THD) sequences. A follow-up will add a BSHD path mirroring verl's
data_format switch; until then these tests document the expected
coverage and act as a regression target.

Key changes:
- Add qwen3_5 to MODEL_PATHS in run_megatron_engine_distributed.py
- Re-key MODEL_PATHS from areal.utils.testing_utils canonical paths
  so local-path overrides propagate from a single source
- Wire bridge_type via _MODEL_BRIDGE_OVERRIDES (qwen3_5 →
  megatron-bridge)
- Six test_qwen3_5_* tests in test_megatron_engine_distributed.py
…s for Qwen3.5

Qwen3.5's GDN (Gated Delta Net) layers reject packed (THD) sequences
in megatron-core. Add a BSHD path that reconstructs [B, S] padded input
from cu_seqlens inside packed_context_parallel_forward, mirroring the
existing VLM 2D-reconstruction logic but for text-only models.

Also add runtime monkey-patch for megatron-bridge PR #3143 (MTP shadow
embedding missing word_embeddings attribute under sequence_parallel +
tied embeddings). The patch lazily restores the attribute from the
closure before _postprocess runs, avoiding the need to replace the
full forward method.

Key changes:
- New MegatronEngineConfig.use_padded_seq flag (BSHD mode)
- Generalize VLM 2D-reconstruction path in packed_context_parallel_forward
  to also fire on use_padded_seq=True for non-VLM models
- CP>1 guard for use_padded_seq (same constraint VLM has)
- New megatron_bridge_patches.py with PR #3143 workaround
- New train_hf_save_load test_type in runner (replaces DCP for SSM
  models whose flattened_range tensors are unsupported by mcore DCP)
- Qwen3.5 tests: 1-GPU, TP=2, PP=2, HF save/load all pass;
  VPP and grad_norm_mb_invariance skipped with documented reasons
…broadcast

Add qwen3_5 and qwen3_5_moe to VALID_VISION_MODELS so the engine
loads the HF processor and passes pixel_values / image_grid_thw
through the VLM forward path. Qwen3.5's base architecture
(Qwen3_5ForConditionalGeneration) is inherently multimodal —
there is no separate qwen3_5_vl model_type.

Also fix a ValueError in _update_weights_via_bridge where
bridge.export_hf_weights yields non-contiguous tensor views
(from QKV split / gate-up chunk) that NCCL broadcast rejects.
Call .contiguous() before bucketing.
Qwen3.5 and other GDN hybrids default to vLLM's FlashInfer GDN prefill kernel, which hangs on SM90 — a runtime mbarrier deadlock (flashinfer #2623/#3329) and a JIT-compile deadlock (vLLM #41865/#39287), surfacing as shm_broadcast stall -> sample_tokens timeout -> EngineDeadError. Expose gdn_prefill_backend so configs can set "triton" (stable Triton/FLA kernel). None default emits no flag, so non-GDN models are unaffected.
Qwen3.5-35B-A3B megatron coverage via megatron-bridge, both running on 4 GPUs:

- test_qwen3_5_moe_expert_parallel: PP2/TP2/EP2 forward + cross-rank logprob consistency. CP is unavailable for the GDN series (Megatron-LM #4043) and the full-attention layers cap TP<=2, so ranks are filled with PP at EP=2.
- test_qwen3_5_moe_hf_save_load: save -> zero -> load -> compare round-trip validating MoE expert-weight conversion (TEGroupedLinear weight0..N + GLU linear_fc1 stride-2). The train step is skipped (_MODEL_SAVELOAD_SKIP_TRAIN) since a 35B-A3B optimizer state does not fit; the loaded HF weights are already non-trivial.

The megatron-vs-FSDP logit comparison is skipped for this model (_MODEL_SKIP_FSDP_COMPARE): AReaL's FSDP engine materializes the full fp32 35B per rank on load, which cannot fit.
Copy link
Copy Markdown
Contributor

@gemini-code-assist gemini-code-assist Bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Code Review

This pull request adds support for Qwen3.5 and Qwen3.5 MoE models, introducing configurations for padded sequence layouts and weight updates via megatron-bridge, along with runtime patches and distributed tests. Feedback highlights a critical issue where registering text-only Qwen3.5 models under VALID_VISION_MODELS will cause startup crashes due to missing processors. Additionally, suggestions are made to use getattr instead of direct __dict__ access for robustness, and to restrict gdn_prefill_backend choices using field metadata.

Important

The consumer version of Gemini Code Assist on GitHub is being sunset. Starting June 18, 2026, new organization installations will be blocked, and all code review activity will officially cease on July 17, 2026.
For more details on the timeline and next steps, please review the Help Documentation.

Comment on lines 9 to 12
"qwen3_vl_moe",
"qwen3_5",
"qwen3_5_moe",
"gemma3",
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

Registering the text-only qwen3_5 and qwen3_5_moe models in VALID_VISION_MODELS is problematic. When a model is classified as a vision model (self.is_vision_model = True), the engine initialization attempts to load a processor via load_hf_processor_and_tokenizer(self.config.path). Since standard text-only Qwen3.5 models (like Qwen/Qwen3.5-2B in the example recipe) do not have a processor, this will raise an OSError and crash the engine on startup in production.

Since the padded sequence reconstruction path in packed_context_parallel_forward is already guarded by use_padded_seq (which is set to True for Qwen3.5), these models do not need to be registered as vision models to run on padded inputs. They should be removed from VALID_VISION_MODELS.

Suggested change
"qwen3_vl_moe",
"qwen3_5",
"qwen3_5_moe",
"gemma3",
"qwen3_vl_moe",
"gemma3",

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Qwen3.5 has no text-only variant — every model in the series uses the multimodal Qwen3_5ForConditionalGeneration / Qwen3_5MoeForConditionalGeneration architecture and ships a preprocessor_config.json (plus video_preprocessor_config.json), so load_hf_processor_and_tokenizer resolves a processor for whole family. The dense single-GPU / TP / PP tests added in this PR run with this registration and pass, so there is no startup OSError. Registering the whole family as a vision model is intentional: there is no separate qwen3_5_vl model_type.

_orig_postprocess = Qwen3VLGPTModel._postprocess

def _patched_postprocess(self, *args, **kwargs):
emb = self.__dict__.get("embedding")
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

Using getattr(self, "embedding", None) is more robust and idiomatic than directly accessing self.__dict__.get("embedding"). Direct __dict__ access bypasses standard Python attribute resolution (including properties, custom __getattr__ overrides, and inheritance) and is generally discouraged unless strictly necessary.

Suggested change
emb = self.__dict__.get("embedding")
emb = getattr(self, "embedding", None)

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The self.__dict__.get("embedding") is intentional. The upstream bug temporarily replaces self.embedding with a plain closure (_sp_scatter_embedding); because that closure is not an nn.Module, nn.Module.__setattr__ stores it directly in self.__dict__, whereas the real embedding lives in self._modules and is only reachable via nn.Module.__getattr__. So __dict__.get("embedding") detects specifically the installed shadow closure (returning None when it isn't installed), which is exactly the state we want to act on — getattr(self, "embedding") would instead return the real LanguageModelEmbedding when the shadow is absent. The downstream callable(emb) and not hasattr(emb, "word_embeddings") guard keeps it correct either way.

Comment thread areal/api/cli_args.py Outdated
Comment thread areal/api/cli_args.py Outdated
},
)

use_padded_seq: bool = field(
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Recommend changing use_padded_seq from a CLI flag to an automatic decision based on model_type (following the pattern of is_valid_vision_model).

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Currently use_padded_seq is only a must for Qwen3.5, i.e. other supported model can simply ignore and use default field, which pretty much implies an "automatic decision" to use packed seq unless this config is set explicitly.

I do agree the engine should auto-derive this config e.g. default use_padded_seq to True when the model is GDN, since BSHD is a fallback only when THD is not supported. So user shouldn't bother adding this when training Qwen3.5 models.

Do you wanna keep this flag as an optional override (as what verl does, we provide these two options after all) or remove and let engine fully decide?

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think the current PR can first remove this option, as it is only needed for Qwen3.5 at the moment. If other models depend on this option in the future, it can be added separately.

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Removed the use_padded_seq field entirely and made the layout an automatic decision from model_type, mirroring is_valid_vision_model. If a future model ever needs to override the layout, we can reintroduce an explicit knob at that point

self.bridge_cls == "megatron-bridge"
and self.mcore_config.use_bridge_for_update_weights
and not self.quantization_config
and not self.config.use_lora
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If a user sets both use_lora and use_bridge_for_update_weights at the same time, it is recommended to add a log to report this issue.

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Added a one-time startup warning in MegatronEngine.initialize that fires when use_bridge_for_update_weights=True but a fallback condition forces the registry conversion path. It covers all silent-fallback cases - FP8/quantized training/lora

Adiactive added a commit to Adiactive/AReaL that referenced this pull request Jun 3, 2026
The padded (BSHD) vs packed (THD) forward layout is a hard architectural
property of the model -- GDN/SSM kernels (the Qwen3.5 family) reject
packed sequences -- not a user tunable. Exposing it as the
`use_padded_seq` config field let it be mis-set and risked silent
correctness or crash issues. Derive it from `model_type` instead so the
layout can never disagree with the architecture.

Also surface a startup warning when `use_bridge_for_update_weights=True`
but a fallback condition (non-megatron-bridge, FP8/quantized, or LoRA)
silently routes live weight sync through the registry path, so the
effective behavior is visible in logs.

Key changes:
- Add requires_padded_seq(model_type) helper in engine/core/model.py
- Derive self.use_padded_seq from model_type in MegatronEngine.initialize
- Remove use_padded_seq from MegatronEngineConfig and regenerate CLI docs
- Warn once when bridge weight-sync falls back to the registry path
- Drop the test-runner override map and example yaml flag

Refs: areal-project#1384
The padded (BSHD) vs packed (THD) forward layout is a hard architectural
property of the model -- GDN/SSM kernels (the Qwen3.5 family) reject
packed sequences -- not a user tunable. Exposing it as the
`use_padded_seq` config field let it be mis-set and risked silent
correctness or crash issues. Derive it from `model_type` instead so the
layout can never disagree with the architecture.

Also surface a startup warning when `use_bridge_for_update_weights=True`
but a fallback condition (non-megatron-bridge, FP8/quantized, or LoRA)
silently routes live weight sync through the registry path, so the
effective behavior is visible in logs.

Key changes:
- Add requires_padded_seq(model_type) helper in engine/core/model.py
- Derive self.use_padded_seq from model_type in MegatronEngine.initialize
- Remove use_padded_seq from MegatronEngineConfig and regenerate CLI docs
- Warn once when bridge weight-sync falls back to the registry path
- Drop the test-runner override map and example yaml flag

Refs: areal-project#1384
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants