-
Notifications
You must be signed in to change notification settings - Fork 514
feat(megatron): Qwen3.5 dense + MoE training/inference support via megatron-bridge #1384
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
d6d1157
f6d9900
efbfaa1
ba669e0
d2eede4
5acffeb
ff83b72
1d91eec
a3f73e7
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -60,7 +60,9 @@ | |
| disable_dropout_in_model, | ||
| is_valid_vision_model, | ||
| lang_config, | ||
| requires_padded_seq, | ||
| ) | ||
| from areal.engine.megatron_utils import megatron_bridge_patches # noqa: F401 | ||
| from areal.engine.megatron_utils.checkpointer import MegatronCheckpointManager | ||
| from areal.engine.megatron_utils.deterministic import set_deterministic_algorithms | ||
| from areal.engine.megatron_utils.fp8 import FP8BlockwiseTensorHelper | ||
|
|
@@ -346,6 +348,10 @@ def initialize(self, addr: str | None, ft_spec: FinetuneSpec, *args, **kwargs): | |
| ) | ||
|
|
||
| self.is_vision_model = is_valid_vision_model(self.hf_config.model_type) | ||
| # GDN/SSM models (e.g. Qwen3.5) reject packed THD input and must run | ||
| # the padded BSHD forward. Derived from model type rather than a | ||
| # config flag so the layout can't be mis-set. | ||
| self.use_padded_seq = requires_padded_seq(self.hf_config.model_type) | ||
| if self.is_vision_model: | ||
| if self.parallel_strategy.context_parallel_size > 1: | ||
| raise NotImplementedError( | ||
|
|
@@ -361,13 +367,40 @@ def initialize(self, addr: str | None, ft_spec: FinetuneSpec, *args, **kwargs): | |
| f"Loaded processor and tokenizer." | ||
| ) | ||
|
|
||
| if self.use_padded_seq and self.parallel_strategy.context_parallel_size > 1: | ||
| raise NotImplementedError( | ||
| f"Context parallel (CP > 1) is not supported for " | ||
| f"model_type={self.hf_config.model_type!r}, which requires the " | ||
| "padded BSHD forward (it operates on [B, S] tensors while the " | ||
| "CP path packs sequences). " | ||
| f"Got context_parallel_size={self.parallel_strategy.context_parallel_size}." | ||
| ) | ||
|
|
||
| self.quantization_config = getattr( | ||
| self.hf_config, "quantization_config", None | ||
| ) | ||
|
|
||
| self._check_and_apply_fp8_config() | ||
| self._validate_fp8_consistency() | ||
|
|
||
| # Warn once if bridge-delegated weight sync was requested but a | ||
| # fallback condition forces the registry conversion path (the | ||
| # dispatch in _update_weights_from_distributed silently falls back). | ||
| if self.mcore_config.use_bridge_for_update_weights: | ||
| fallback_reasons = [] | ||
| if self.bridge_cls != "megatron-bridge": | ||
| fallback_reasons.append(f"bridge_type={self.bridge_cls!r}") | ||
| if self.quantization_config: | ||
| fallback_reasons.append("FP8/quantized training") | ||
| if self.config.use_lora: | ||
| fallback_reasons.append("LoRA enabled") | ||
| if fallback_reasons: | ||
| self.logger.warning( | ||
| "use_bridge_for_update_weights=True, but live weight sync " | ||
| "will use the registry conversion path instead because: " | ||
| f"{', '.join(fallback_reasons)}." | ||
| ) | ||
|
|
||
| with self.device: | ||
| models = make_mcore_model( | ||
| hf_config=self.hf_config, | ||
|
|
@@ -844,6 +877,7 @@ def forward_step(batch_iter, model): | |
| mb_input.padded_mb, | ||
| gather_cp_output=not cp_local, | ||
| is_vision_model=self.is_vision_model, | ||
| use_padded_seq=self.use_padded_seq, | ||
| ) | ||
|
|
||
| # Release tree attention metadata after forward pass | ||
|
|
@@ -1753,6 +1787,35 @@ def _update_weights_from_distributed(self, meta: WeightUpdateMeta) -> None: | |
|
|
||
| dist.barrier(group=self.cpu_group) | ||
|
|
||
| # Bridge delegation: when bridge_type=megatron-bridge and the user opts in, | ||
| # stream HF tensors directly from bridge.export_hf_weights. Falls back to | ||
| # the hand-rolled registry path for FP8 (quant_mapping in megatron-bridge | ||
| # is amax-style, not TE blockwise) and for LoRA (separate adapter export | ||
| # path not yet wired here). | ||
| use_bridge = ( | ||
| self.bridge_cls == "megatron-bridge" | ||
| and self.mcore_config.use_bridge_for_update_weights | ||
| and not self.quantization_config | ||
| and not self.config.use_lora | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. If a user sets both use_lora and use_bridge_for_update_weights at the same time, it is recommended to add a log to report this issue.
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Added a one-time startup warning in |
||
| ) | ||
| if use_bridge: | ||
| self._update_weights_via_bridge(meta) | ||
| else: | ||
| self._update_weights_via_registry(meta) | ||
|
|
||
| if dist.get_rank() == 0: | ||
| self.rollout_engine.continue_generation() | ||
|
|
||
| current_platform.synchronize() | ||
| dist.barrier(group=self.cpu_group) | ||
|
|
||
| def _update_weights_via_registry(self, meta: WeightUpdateMeta) -> None: | ||
| """Hand-rolled conversion path via convert_to_hf registry. | ||
|
|
||
| Used for FP8, LoRA, and models with a converter entry. Iterates this PP | ||
| rank's local params, TP-gathers per param, converts to HF layout, and | ||
| bucket-broadcasts to the rollout engine. | ||
| """ | ||
| num_moe_experts = self.tf_config.num_moe_experts | ||
| weight_chunked_mem_size = meta.weight_chunked_mem_mb * 1024 * 1024 | ||
|
|
||
|
|
@@ -1807,10 +1870,37 @@ def _update_weights_from_distributed(self, meta: WeightUpdateMeta) -> None: | |
|
|
||
| dist.barrier(group=self.cpu_group) | ||
|
|
||
| if dist.get_rank() == 0: | ||
| self.rollout_engine.continue_generation() | ||
| def _update_weights_via_bridge(self, meta: WeightUpdateMeta) -> None: | ||
| """Delegate live weight sync to megatron-bridge.export_hf_weights. | ||
|
|
||
| Streams (hf_name, hf_tensor) directly from the bridge, which handles | ||
| TP/EP/PP gather and layout transformation internally. Each PP rank | ||
| iterates the global parameter set (vs registry path which iterates only | ||
| local layers); non-PP-heads participate in collectives but do not bucket. | ||
| MoE expert weights are yielded inline by the bridge's grouped-export | ||
| path, so no separate second pass is needed. | ||
| """ | ||
| weight_chunked_mem_size = meta.weight_chunked_mem_mb * 1024 * 1024 | ||
| bucket: list[tuple[str, torch.Tensor]] = [] | ||
| bucket_size = 0 | ||
|
|
||
| for hf_name, hf_tensor in self.bridge.export_hf_weights( | ||
| self.model, | ||
| cpu=False, | ||
| show_progress=False, | ||
| ): | ||
| if not self.is_pipeline_parallel_head(): | ||
| continue | ||
| size = hf_tensor.numel() * hf_tensor.element_size() | ||
| if bucket_size + size > weight_chunked_mem_size: | ||
| self._update_bucket_weights_from_distributed(meta, bucket) | ||
| bucket_size = 0 | ||
| bucket.append((hf_name, hf_tensor.contiguous())) | ||
| bucket_size += size | ||
|
|
||
| if bucket: | ||
| self._update_bucket_weights_from_distributed(meta, bucket) | ||
|
|
||
| current_platform.synchronize() | ||
| dist.barrier(group=self.cpu_group) | ||
|
|
||
| @trace_perf("megatron_engine.update_weights_from_disk", category="io") | ||
|
|
@@ -1909,7 +1999,17 @@ def _load_model_from_hf(self, path: str) -> None: | |
| raise ValueError( | ||
| "Loading critic model is not supported with megatron-bridge." | ||
| ) | ||
| self.bridge.load_hf_weights(self.model, hf_path=path) | ||
| # megatron-bridge's load path builds shard-index tensors via | ||
| # ``torch.arange(...)`` to index HF weights that live on CPU. Under | ||
| # the caller's ``with self.device:`` (CUDA) context, those indices | ||
| # become CUDA tensors and the CPU-tensor indexing raises | ||
| # ``RuntimeError: indices should be either on cpu or on the same | ||
| # device as the indexed tensor (cpu)`` — triggered by ChunkedMapping | ||
| # for any model with GDN/Mamba-style conv1d weights (e.g. Qwen3.5). | ||
| # Force CPU as the factory-op default here; tensor data assignment | ||
| # to GPU model params is unaffected (handled by .copy_()). | ||
| with torch.device("cpu"): | ||
| self.bridge.load_hf_weights(self.model, hf_path=path) | ||
| else: | ||
| load_weights_from_hf_with_mbridge_fast( | ||
| bridge=self.bridge, | ||
|
|
||
| Original file line number | Diff line number | Diff line change | ||||
|---|---|---|---|---|---|---|
| @@ -0,0 +1,84 @@ | ||||||
| # SPDX-License-Identifier: Apache-2.0 | ||||||
|
|
||||||
| """Runtime patches for megatron-bridge bugs not yet in a released version. | ||||||
|
|
||||||
| Each patch is keyed to an upstream PR. Patches are not version-gated; instead | ||||||
| each one's hot path becomes a no-op once the upstream fix is present (the patch | ||||||
| checks for the missing attribute/behavior before acting), and an idempotency | ||||||
| sentinel prevents double-application. Apply patches at import time via | ||||||
| ``_apply_patches_on_import()`` at module bottom. | ||||||
| """ | ||||||
|
|
||||||
| from __future__ import annotations | ||||||
|
|
||||||
| import areal.utils.logging as logging | ||||||
|
|
||||||
| logger = logging.getLogger("MegatronBridgePatches") | ||||||
|
|
||||||
|
|
||||||
| def _patch_qwen3vl_pr3143_word_embeddings() -> None: | ||||||
| """megatron-bridge PR #3143: expose word_embeddings on MTP shadow embedding. | ||||||
|
|
||||||
| Bug (issue #3112 / PR #3143): in ``Qwen3VLGPTModel.forward``, when | ||||||
| ``mtp_process and sequence_parallel`` are both True, ``self.embedding`` is | ||||||
| temporarily replaced with a plain closure ``_sp_scatter_embedding``. The | ||||||
| closure lacks the ``word_embeddings`` attribute that | ||||||
| ``shared_embedding_or_output_weight()`` accesses during ``_postprocess`` | ||||||
| when ``share_embeddings_and_output_weights=True`` — typical for the | ||||||
| smaller Qwen3.5 dense models (0.8B/2B/4B). | ||||||
|
|
||||||
| Failure mode: | ||||||
| ``AttributeError: 'function' object has no attribute 'word_embeddings'`` | ||||||
|
|
||||||
| Affected versions: megatron-bridge 0.4.0 and 0.4.1. Fixed on ``main`` | ||||||
| by commit 20749b09 (PR #3143) but not in any non-alpha release yet. | ||||||
|
|
||||||
| Strategy: wrap ``Qwen3VLGPTModel._postprocess`` so it lazily restores | ||||||
| ``word_embeddings`` on the shadow embedding by inspecting its closure. | ||||||
| Closure-based recovery is non-invasive — we don't touch ``forward`` | ||||||
| itself (~70 LoC method). | ||||||
| """ | ||||||
| try: | ||||||
| from megatron.bridge.models.qwen_vl.modelling_qwen3_vl.text_model import ( | ||||||
| Qwen3VLGPTModel, | ||||||
| ) | ||||||
| except ImportError: | ||||||
| return | ||||||
|
|
||||||
| if getattr(Qwen3VLGPTModel, "_areal_pr3143_applied", False): | ||||||
| return | ||||||
|
|
||||||
| _orig_postprocess = Qwen3VLGPTModel._postprocess | ||||||
|
|
||||||
| def _patched_postprocess(self, *args, **kwargs): | ||||||
| emb = self.__dict__.get("embedding") | ||||||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Using
Suggested change
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The |
||||||
| # Only intervene when the shadow closure is currently installed and | ||||||
| # lacks the expected attribute. | ||||||
| if ( | ||||||
| callable(emb) | ||||||
| and not hasattr(emb, "word_embeddings") | ||||||
| and emb.__closure__ is not None | ||||||
| ): | ||||||
| for cell in emb.__closure__: | ||||||
| try: | ||||||
| target = cell.cell_contents | ||||||
| except ValueError: | ||||||
| continue | ||||||
| if hasattr(target, "word_embeddings"): | ||||||
| emb.word_embeddings = target.word_embeddings | ||||||
| break | ||||||
| return _orig_postprocess(self, *args, **kwargs) | ||||||
|
|
||||||
| Qwen3VLGPTModel._postprocess = _patched_postprocess | ||||||
| Qwen3VLGPTModel._areal_pr3143_applied = True | ||||||
| logger.info( | ||||||
| "Applied megatron-bridge PR #3143 workaround: " | ||||||
| "Qwen3VLGPTModel shadow embedding word_embeddings restoration." | ||||||
| ) | ||||||
|
|
||||||
|
|
||||||
| def _apply_patches_on_import() -> None: | ||||||
| _patch_qwen3vl_pr3143_word_embeddings() | ||||||
|
|
||||||
|
|
||||||
| _apply_patches_on_import() | ||||||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Registering the text-only
qwen3_5andqwen3_5_moemodels inVALID_VISION_MODELSis problematic. When a model is classified as a vision model (self.is_vision_model = True), the engine initialization attempts to load a processor viaload_hf_processor_and_tokenizer(self.config.path). Since standard text-only Qwen3.5 models (likeQwen/Qwen3.5-2Bin the example recipe) do not have a processor, this will raise anOSErrorand crash the engine on startup in production.Since the padded sequence reconstruction path in
packed_context_parallel_forwardis already guarded byuse_padded_seq(which is set toTruefor Qwen3.5), these models do not need to be registered as vision models to run on padded inputs. They should be removed fromVALID_VISION_MODELS.There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Qwen3.5 has no text-only variant — every model in the series uses the multimodal
Qwen3_5ForConditionalGeneration/Qwen3_5MoeForConditionalGenerationarchitecture and ships apreprocessor_config.json(plusvideo_preprocessor_config.json), soload_hf_processor_and_tokenizerresolves a processor for whole family. The dense single-GPU / TP / PP tests added in this PR run with this registration and pass, so there is no startupOSError. Registering the whole family as a vision model is intentional: there is no separateqwen3_5_vlmodel_type.