feat(megatron): Qwen3.5 dense + MoE training/inference support via megatron-bridge #1384
base: main
Changes from 7 commits
d6d1157
f6d9900
efbfaa1
ba669e0
d2eede4
5acffeb
ff83b72
1d91eec
a3f73e7
| Original file line number | Diff line number | Diff line change | ||||||||||||
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
|
|
@@ -7,6 +7,8 @@ | |||||||||||||
| "qwen2_5_vl", | ||||||||||||||
| "qwen3_vl", | ||||||||||||||
| "qwen3_vl_moe", | ||||||||||||||
| "qwen3_5", | ||||||||||||||
| "qwen3_5_moe", | ||||||||||||||
| "gemma3", | ||||||||||||||
|
Comment on lines 9 to 12
Contributor
Registering the text-only Since the padded sequence reconstruction path in
Contributor
Author
Qwen3.5 has no text-only variant; every model in the series uses the multimodal
||||||||||||||
| ] | ||||||||||||||
| # This registry is used to check if a model is a vision model that we have checked it works with AReaL. | ||||||||||||||
|
|
||||||||||||||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -61,6 +61,7 @@ | |
| is_valid_vision_model, | ||
| lang_config, | ||
| ) | ||
| from areal.engine.megatron_utils import megatron_bridge_patches # noqa: F401 | ||
| from areal.engine.megatron_utils.checkpointer import MegatronCheckpointManager | ||
| from areal.engine.megatron_utils.deterministic import set_deterministic_algorithms | ||
| from areal.engine.megatron_utils.fp8 import FP8BlockwiseTensorHelper | ||
|
|
@@ -361,6 +362,17 @@ def initialize(self, addr: str | None, ft_spec: FinetuneSpec, *args, **kwargs): | |
| f"Loaded processor and tokenizer." | ||
| ) | ||
|
|
||
| if ( | ||
| self.mcore_config.use_padded_seq | ||
| and self.parallel_strategy.context_parallel_size > 1 | ||
| ): | ||
| raise NotImplementedError( | ||
| "Context parallel (CP > 1) is not supported with " | ||
| "use_padded_seq=True (the padded BSHD path operates on " | ||
| "[B, S] tensors and the CP path packs sequences). " | ||
| f"Got context_parallel_size={self.parallel_strategy.context_parallel_size}." | ||
| ) | ||
|
|
||
| self.quantization_config = getattr( | ||
| self.hf_config, "quantization_config", None | ||
| ) | ||
|
|
@@ -844,6 +856,7 @@ def forward_step(batch_iter, model): | |
| mb_input.padded_mb, | ||
| gather_cp_output=not cp_local, | ||
| is_vision_model=self.is_vision_model, | ||
| use_padded_seq=self.mcore_config.use_padded_seq, | ||
| ) | ||
|
|
||
| # Release tree attention metadata after forward pass | ||
|
|
@@ -1753,6 +1766,35 @@ def _update_weights_from_distributed(self, meta: WeightUpdateMeta) -> None: | |
|
|
||
| dist.barrier(group=self.cpu_group) | ||
|
|
||
| # Bridge delegation: when bridge_type=megatron-bridge and the user opts in, | ||
| # stream HF tensors directly from bridge.export_hf_weights. Falls back to | ||
| # the hand-rolled registry path for FP8 (quant_mapping in megatron-bridge | ||
| # is amax-style, not TE blockwise) and for LoRA (separate adapter export | ||
| # path not yet wired here). | ||
| use_bridge = ( | ||
| self.bridge_cls == "megatron-bridge" | ||
| and self.mcore_config.use_bridge_for_update_weights | ||
| and not self.quantization_config | ||
| and not self.config.use_lora | ||
|
Collaborator
If a user sets both `use_lora` and `use_bridge_for_update_weights`, I recommend logging a message that reports the conflict.
Contributor
Author
Added a one-time startup warning in
||
| ) | ||
| if use_bridge: | ||
| self._update_weights_via_bridge(meta) | ||
| else: | ||
| self._update_weights_via_registry(meta) | ||
|
|
||
| if dist.get_rank() == 0: | ||
| self.rollout_engine.continue_generation() | ||
|
|
||
| current_platform.synchronize() | ||
| dist.barrier(group=self.cpu_group) | ||
|
|
||
| def _update_weights_via_registry(self, meta: WeightUpdateMeta) -> None: | ||
| """Hand-rolled conversion path via convert_to_hf registry. | ||
|
|
||
| Used for FP8, LoRA, and models with a converter entry. Iterates this PP | ||
| rank's local params, TP-gathers per param, converts to HF layout, and | ||
| bucket-broadcasts to the rollout engine. | ||
| """ | ||
| num_moe_experts = self.tf_config.num_moe_experts | ||
| weight_chunked_mem_size = meta.weight_chunked_mem_mb * 1024 * 1024 | ||
|
|
||
|
|
@@ -1807,10 +1849,37 @@ def _update_weights_from_distributed(self, meta: WeightUpdateMeta) -> None: | |
|
|
||
| dist.barrier(group=self.cpu_group) | ||
|
|
||
| if dist.get_rank() == 0: | ||
| self.rollout_engine.continue_generation() | ||
| def _update_weights_via_bridge(self, meta: WeightUpdateMeta) -> None: | ||
| """Delegate live weight sync to megatron-bridge.export_hf_weights. | ||
|
|
||
| Streams (hf_name, hf_tensor) directly from the bridge, which handles | ||
| TP/EP/PP gather and layout transformation internally. Each PP rank | ||
| iterates the global parameter set (vs registry path which iterates only | ||
| local layers); non-PP-heads participate in collectives but do not bucket. | ||
| MoE expert weights are yielded inline by the bridge's grouped-export | ||
| path, so no separate second pass is needed. | ||
| """ | ||
| weight_chunked_mem_size = meta.weight_chunked_mem_mb * 1024 * 1024 | ||
| bucket: list[tuple[str, torch.Tensor]] = [] | ||
| bucket_size = 0 | ||
|
|
||
| for hf_name, hf_tensor in self.bridge.export_hf_weights( | ||
| self.model, | ||
| cpu=False, | ||
| show_progress=False, | ||
| ): | ||
| if not self.is_pipeline_parallel_head(): | ||
| continue | ||
| size = hf_tensor.numel() * hf_tensor.element_size() | ||
| if bucket_size + size > weight_chunked_mem_size: | ||
| self._update_bucket_weights_from_distributed(meta, bucket) | ||
| bucket_size = 0 | ||
| bucket.append((hf_name, hf_tensor.contiguous())) | ||
| bucket_size += size | ||
|
|
||
| if bucket: | ||
| self._update_bucket_weights_from_distributed(meta, bucket) | ||
|
|
||
| current_platform.synchronize() | ||
| dist.barrier(group=self.cpu_group) | ||
|
|
||
| @trace_perf("megatron_engine.update_weights_from_disk", category="io") | ||
|
|
@@ -1909,7 +1978,17 @@ def _load_model_from_hf(self, path: str) -> None: | |
| raise ValueError( | ||
| "Loading critic model is not supported with megatron-bridge." | ||
| ) | ||
| self.bridge.load_hf_weights(self.model, hf_path=path) | ||
| # megatron-bridge's load path builds shard-index tensors via | ||
| # ``torch.arange(...)`` to index HF weights that live on CPU. Under | ||
| # the caller's ``with self.device:`` (CUDA) context, those indices | ||
| # become CUDA tensors and the CPU-tensor indexing raises | ||
| # ``RuntimeError: indices should be either on cpu or on the same | ||
| # device as the indexed tensor (cpu)`` — triggered by ChunkedMapping | ||
| # for any model with GDN/Mamba-style conv1d weights (e.g. Qwen3.5). | ||
| # Force CPU as the factory-op default here; tensor data assignment | ||
| # to GPU model params is unaffected (handled by .copy_()). | ||
| with torch.device("cpu"): | ||
| self.bridge.load_hf_weights(self.model, hf_path=path) | ||
| else: | ||
| load_weights_from_hf_with_mbridge_fast( | ||
| bridge=self.bridge, | ||
|
|
||
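The bucketed streaming used by `_update_weights_via_bridge` can be illustrated in isolation. The following is a minimal sketch, not the PR's code: `iter_buckets` is a hypothetical helper, plain `bytes` payloads stand in for GPU tensors, and yielding a bucket stands in for the broadcast call. Two details are assumptions: the sketch resets the bucket explicitly after each flush (the diff resets only `bucket_size`, so the real flush helper presumably empties the list itself), and it skips flushing an empty bucket.

```python
from collections.abc import Iterable, Iterator


def iter_buckets(
    named_payloads: Iterable[tuple[str, bytes]],
    max_bucket_bytes: int,
) -> Iterator[list[tuple[str, bytes]]]:
    """Group (name, payload) pairs into buckets of at most max_bucket_bytes.

    A single payload larger than the limit still gets a bucket of its own.
    """
    bucket: list[tuple[str, bytes]] = []
    bucket_size = 0
    for name, payload in named_payloads:
        size = len(payload)
        # Flush first if adding this payload would overflow the bucket.
        if bucket and bucket_size + size > max_bucket_bytes:
            yield bucket
            bucket = []
            bucket_size = 0
        bucket.append((name, payload))
        bucket_size += size
    # Flush whatever is left after the stream ends.
    if bucket:
        yield bucket


# Hypothetical weights: sizes 4, 4, 6 and 20 bytes with a 10-byte limit.
weights = [("a", b"x" * 4), ("b", b"x" * 4), ("c", b"x" * 6), ("d", b"x" * 20)]
buckets = list(iter_buckets(weights, max_bucket_bytes=10))
bucket_names = [[name for name, _ in b] for b in buckets]
print(bucket_names)
```

Checking the limit before appending keeps every multi-item bucket within the budget, at the cost of letting one oversized item exceed it alone, which matches the order of operations in the diff.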
| Original file line number | Diff line number | Diff line change | ||||
|---|---|---|---|---|---|---|
| @@ -0,0 +1,82 @@ | ||||||
| # SPDX-License-Identifier: Apache-2.0 | ||||||
|
|
||||||
| """Runtime patches for megatron-bridge bugs not yet in a released version. | ||||||
|
|
||||||
| Each patch is keyed to an upstream PR and is auto-disabled once megatron-bridge | ||||||
| ships a release containing the fix. Apply patches at import time via | ||||||
| ``_apply_patches_on_import()`` at module bottom. | ||||||
| """ | ||||||
|
|
||||||
| from __future__ import annotations | ||||||
|
|
||||||
| import areal.utils.logging as logging | ||||||
|
|
||||||
| logger = logging.getLogger("MegatronBridgePatches") | ||||||
|
|
||||||
|
|
||||||
| def _patch_qwen3vl_pr3143_word_embeddings() -> None: | ||||||
| """megatron-bridge PR #3143: expose word_embeddings on MTP shadow embedding. | ||||||
|
|
||||||
| Bug (issue #3112 / PR #3143): in ``Qwen3VLGPTModel.forward``, when | ||||||
| ``mtp_process and sequence_parallel`` are both True, ``self.embedding`` is | ||||||
| temporarily replaced with a plain closure ``_sp_scatter_embedding``. The | ||||||
| closure lacks the ``word_embeddings`` attribute that | ||||||
| ``shared_embedding_or_output_weight()`` accesses during ``_postprocess`` | ||||||
| when ``share_embeddings_and_output_weights=True`` — typical for the | ||||||
| smaller Qwen3.5 dense models (0.8B/2B/4B). | ||||||
|
|
||||||
| Failure mode: | ||||||
| ``AttributeError: 'function' object has no attribute 'word_embeddings'`` | ||||||
|
|
||||||
| Affected versions: megatron-bridge 0.4.0 and 0.4.1. Fixed on ``main`` | ||||||
| by commit 20749b09 (PR #3143) but not in any non-alpha release yet. | ||||||
|
|
||||||
| Strategy: wrap ``Qwen3VLGPTModel._postprocess`` so it lazily restores | ||||||
| ``word_embeddings`` on the shadow embedding by inspecting its closure. | ||||||
| Closure-based recovery is non-invasive — we don't touch ``forward`` | ||||||
| itself (~70 LoC method). | ||||||
| """ | ||||||
| try: | ||||||
| from megatron.bridge.models.qwen_vl.modelling_qwen3_vl.text_model import ( | ||||||
| Qwen3VLGPTModel, | ||||||
| ) | ||||||
| except ImportError: | ||||||
| return | ||||||
|
|
||||||
| if getattr(Qwen3VLGPTModel, "_areal_pr3143_applied", False): | ||||||
| return | ||||||
|
|
||||||
| _orig_postprocess = Qwen3VLGPTModel._postprocess | ||||||
|
|
||||||
| def _patched_postprocess(self, *args, **kwargs): | ||||||
| emb = self.__dict__.get("embedding") | ||||||
|
Contributor
Using
Contributor
Author
The
||||||
| # Only intervene when the shadow closure is currently installed and | ||||||
| # lacks the expected attribute. | ||||||
| if ( | ||||||
| callable(emb) | ||||||
| and not hasattr(emb, "word_embeddings") | ||||||
| and emb.__closure__ is not None | ||||||
| ): | ||||||
| for cell in emb.__closure__: | ||||||
| try: | ||||||
| target = cell.cell_contents | ||||||
| except ValueError: | ||||||
| continue | ||||||
| if hasattr(target, "word_embeddings"): | ||||||
| emb.word_embeddings = target.word_embeddings | ||||||
| break | ||||||
| return _orig_postprocess(self, *args, **kwargs) | ||||||
|
|
||||||
| Qwen3VLGPTModel._postprocess = _patched_postprocess | ||||||
| Qwen3VLGPTModel._areal_pr3143_applied = True | ||||||
| logger.info( | ||||||
| "Applied megatron-bridge PR #3143 workaround: " | ||||||
| "Qwen3VLGPTModel shadow embedding word_embeddings restoration." | ||||||
| ) | ||||||
|
|
||||||
|
|
||||||
| def _apply_patches_on_import() -> None: | ||||||
| _patch_qwen3vl_pr3143_word_embeddings() | ||||||
|
|
||||||
|
|
||||||
| _apply_patches_on_import() | ||||||
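The closure-based recovery in `_patched_postprocess` relies on a plain Python mechanism: a nested function keeps its captured variables in `__closure__` cells, and function objects accept new attributes. The sketch below shows that mechanism with stand-in objects only; `Embedding`, `make_shadow`, and `restore_attr_from_closure` are hypothetical names for illustration and are not part of megatron-bridge or AReaL.

```python
class Embedding:
    """Stand-in for the real embedding module (hypothetical)."""

    def __init__(self) -> None:
        self.word_embeddings = "weight-placeholder"

    def __call__(self, x: int) -> int:
        return x + 1


def make_shadow(real: Embedding):
    """Return a plain closure wrapping `real`, like a shadow embedding."""

    def shadow(x: int) -> int:
        return real(x)

    return shadow


def restore_attr_from_closure(fn, attr: str) -> bool:
    """Copy `attr` onto `fn` from the first closure cell that has it.

    Returns True only when the attribute was missing and got restored.
    """
    if hasattr(fn, attr) or getattr(fn, "__closure__", None) is None:
        return False
    for cell in fn.__closure__:
        try:
            target = cell.cell_contents
        except ValueError:  # the cell exists but has not been filled yet
            continue
        if hasattr(target, attr):
            setattr(fn, attr, getattr(target, attr))
            return True
    return False


shadow = make_shadow(Embedding())
had_attr_before = hasattr(shadow, "word_embeddings")
restored = restore_attr_from_closure(shadow, "word_embeddings")
restored_again = restore_attr_from_closure(shadow, "word_embeddings")
print(had_attr_before, restored, restored_again, shadow.word_embeddings)
```

The second call is a no-op because the attribute is already present, which is the same guard that makes the patched method safe to run on every forward pass.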
Recommend changing `use_padded_seq` from a CLI flag to an automatic decision based on `model_type` (following the pattern of `is_valid_vision_model`).

Currently `use_padded_seq` is required only for Qwen3.5; other supported models can ignore it and use the default, which already amounts to an automatic decision to use packed sequences unless the config is set explicitly.

I do agree the engine should auto-derive this config, e.g. default `use_padded_seq` to `True` when the model is GDN, since BSHD is a fallback only when THD is not supported. Users shouldn't have to add this when training Qwen3.5 models.

Do you want to keep this flag as an optional override (as verl does; we provide both options after all) or remove it and let the engine fully decide?

I think the current PR can first remove this option, as it is only needed for Qwen3.5 at the moment. If other models depend on this option in the future, it can be added separately.

Removed the `use_padded_seq` field entirely and made the layout an automatic decision from `model_type`, mirroring `is_valid_vision_model`. If a future model ever needs to override the layout, we can reintroduce an explicit knob at that point.
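
A registry-based layout decision of the kind discussed in this thread might look roughly like the sketch below. It is an illustration under assumptions, not the merged code: `PADDED_SEQ_MODEL_TYPES`, `uses_padded_seq`, and `select_layout` are hypothetical names, the registry contents are guessed from the two model types added in this PR, and the `"bshd"`/`"thd"` return values are placeholders for whatever the engine uses internally. The context-parallel guard mirrors the `NotImplementedError` check shown earlier in the diff.

```python
# Hypothetical registry of model types that need the padded BSHD layout.
PADDED_SEQ_MODEL_TYPES = frozenset({"qwen3_5", "qwen3_5_moe"})


def uses_padded_seq(model_type: str) -> bool:
    """Return True when the model type is registered as padded-only."""
    return model_type in PADDED_SEQ_MODEL_TYPES


def select_layout(model_type: str, context_parallel_size: int = 1) -> str:
    """Pick the sequence layout from the model type alone.

    Packed THD is the default; padded BSHD is the fallback for models
    that cannot use it, and that fallback does not support CP > 1.
    """
    if not uses_padded_seq(model_type):
        return "thd"
    if context_parallel_size > 1:
        raise NotImplementedError(
            "Context parallel (CP > 1) is not supported with the padded "
            f"BSHD path. Got context_parallel_size={context_parallel_size}."
        )
    return "bshd"


layouts = {name: select_layout(name) for name in ("qwen3_5", "qwen3_vl")}
print(layouts)
```

Keeping the decision in one lookup means a future model that needs a different layout only requires a registry entry, and an explicit override can be layered on later if one is ever needed.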