Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 20 additions & 0 deletions areal/api/cli_args.py
Original file line number Diff line number Diff line change
Expand Up @@ -940,6 +940,16 @@ class MegatronEngineConfig:
},
)

use_bridge_for_update_weights: bool = field(
default=False,
metadata={
"help": "When True and bridge_type='megatron-bridge', delegate live "
"weight sync to bridge.export_hf_weights instead of the hand-rolled "
"convert_to_hf registry. Required for models without a registry entry "
"(e.g. Qwen3.5). FP8 paths fall back to the registry automatically.",
},
)


class SchedulingStrategyType(str, Enum):
separation = "separation"
Expand Down Expand Up @@ -1713,6 +1723,16 @@ class vLLMConfig:
)
enable_sleep_mode: bool = False
uvicorn_log_level: str = "warning"
# GDN prefill backend for hybrid models like Qwen3.5; "triton" avoids the
# FlashInfer GDN-kernel hang (vLLM #38916). None leaves vLLM's default, so
# no flag is emitted and non-GDN models are unaffected.
gdn_prefill_backend: str | None = field(
default=None,
metadata={
"help": "GDN prefill backend for hybrid models like Qwen3.5.",
"choices": ["triton", "flashinfer"],
},
)
# lora
enable_lora: bool = False
max_lora_rank: int = 16 # vllm's default
Expand Down
12 changes: 12 additions & 0 deletions areal/engine/core/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@
"qwen2_5_vl",
"qwen3_vl",
"qwen3_vl_moe",
"qwen3_5",
"qwen3_5_moe",
"gemma3",
Comment on lines 9 to 12
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

Registering the text-only qwen3_5 and qwen3_5_moe models in VALID_VISION_MODELS is problematic. When a model is classified as a vision model (self.is_vision_model = True), the engine initialization attempts to load a processor via load_hf_processor_and_tokenizer(self.config.path). Since standard text-only Qwen3.5 models (like Qwen/Qwen3.5-2B in the example recipe) do not have a processor, this will raise an OSError and crash the engine on startup in production.

Since the padded sequence reconstruction path in packed_context_parallel_forward is already guarded by use_padded_seq (which is set to True for Qwen3.5), these models do not need to be registered as vision models to run on padded inputs. They should be removed from VALID_VISION_MODELS.

Suggested change
"qwen3_vl_moe",
"qwen3_5",
"qwen3_5_moe",
"gemma3",
"qwen3_vl_moe",
"gemma3",

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Qwen3.5 has no text-only variant — every model in the series uses the multimodal Qwen3_5ForConditionalGeneration / Qwen3_5MoeForConditionalGeneration architecture and ships a preprocessor_config.json (plus video_preprocessor_config.json), so load_hf_processor_and_tokenizer resolves a processor for whole family. The dense single-GPU / TP / PP tests added in this PR run with this registration and pass, so there is no startup OSError. Registering the whole family as a vision model is intentional: there is no separate qwen3_5_vl model_type.

]
# This registry is used to check if a model is a vision model that we have checked it works with AReaL.
Expand Down Expand Up @@ -83,6 +85,16 @@ def is_qwen3_5_model(model_type: str) -> bool:
return model_type in ["qwen3_5", "qwen3_5_text", "qwen3_5_moe", "qwen3_5_moe_text"]


def requires_padded_seq(model_type: str) -> bool:
"""Whether the model must run the padded (BSHD) forward instead of packed (THD).

GDN/SSM models (currently the Qwen3.5 family) reject packed sequences in their
attention/SSM kernels, so they must run on padded ``[B, S]`` input. THD stays
the default for every other model.
"""
return is_qwen3_5_model(model_type)


# Copied from trl
def disable_dropout_in_model(model: torch.nn.Module) -> None:
for module in model.modules():
Expand Down
108 changes: 104 additions & 4 deletions areal/engine/megatron_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,9 @@
disable_dropout_in_model,
is_valid_vision_model,
lang_config,
requires_padded_seq,
)
from areal.engine.megatron_utils import megatron_bridge_patches # noqa: F401
from areal.engine.megatron_utils.checkpointer import MegatronCheckpointManager
from areal.engine.megatron_utils.deterministic import set_deterministic_algorithms
from areal.engine.megatron_utils.fp8 import FP8BlockwiseTensorHelper
Expand Down Expand Up @@ -346,6 +348,10 @@ def initialize(self, addr: str | None, ft_spec: FinetuneSpec, *args, **kwargs):
)

self.is_vision_model = is_valid_vision_model(self.hf_config.model_type)
# GDN/SSM models (e.g. Qwen3.5) reject packed THD input and must run
# the padded BSHD forward. Derived from model type rather than a
# config flag so the layout can't be mis-set.
self.use_padded_seq = requires_padded_seq(self.hf_config.model_type)
if self.is_vision_model:
if self.parallel_strategy.context_parallel_size > 1:
raise NotImplementedError(
Expand All @@ -361,13 +367,40 @@ def initialize(self, addr: str | None, ft_spec: FinetuneSpec, *args, **kwargs):
f"Loaded processor and tokenizer."
)

if self.use_padded_seq and self.parallel_strategy.context_parallel_size > 1:
raise NotImplementedError(
f"Context parallel (CP > 1) is not supported for "
f"model_type={self.hf_config.model_type!r}, which requires the "
"padded BSHD forward (it operates on [B, S] tensors while the "
"CP path packs sequences). "
f"Got context_parallel_size={self.parallel_strategy.context_parallel_size}."
)

self.quantization_config = getattr(
self.hf_config, "quantization_config", None
)

self._check_and_apply_fp8_config()
self._validate_fp8_consistency()

# Warn once if bridge-delegated weight sync was requested but a
# fallback condition forces the registry conversion path (the
# dispatch in _update_weights_from_distributed silently falls back).
if self.mcore_config.use_bridge_for_update_weights:
fallback_reasons = []
if self.bridge_cls != "megatron-bridge":
fallback_reasons.append(f"bridge_type={self.bridge_cls!r}")
if self.quantization_config:
fallback_reasons.append("FP8/quantized training")
if self.config.use_lora:
fallback_reasons.append("LoRA enabled")
if fallback_reasons:
self.logger.warning(
"use_bridge_for_update_weights=True, but live weight sync "
"will use the registry conversion path instead because: "
f"{', '.join(fallback_reasons)}."
)

with self.device:
models = make_mcore_model(
hf_config=self.hf_config,
Expand Down Expand Up @@ -844,6 +877,7 @@ def forward_step(batch_iter, model):
mb_input.padded_mb,
gather_cp_output=not cp_local,
is_vision_model=self.is_vision_model,
use_padded_seq=self.use_padded_seq,
)

# Release tree attention metadata after forward pass
Expand Down Expand Up @@ -1753,6 +1787,35 @@ def _update_weights_from_distributed(self, meta: WeightUpdateMeta) -> None:

dist.barrier(group=self.cpu_group)

# Bridge delegation: when bridge_type=megatron-bridge and the user opts in,
# stream HF tensors directly from bridge.export_hf_weights. Falls back to
# the hand-rolled registry path for FP8 (quant_mapping in megatron-bridge
# is amax-style, not TE blockwise) and for LoRA (separate adapter export
# path not yet wired here).
use_bridge = (
self.bridge_cls == "megatron-bridge"
and self.mcore_config.use_bridge_for_update_weights
and not self.quantization_config
and not self.config.use_lora
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If a user sets both use_lora and use_bridge_for_update_weights at the same time, it is recommended to add a log to report this issue.

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Added a one-time startup warning in MegatronEngine.initialize that fires when use_bridge_for_update_weights=True but a fallback condition forces the registry conversion path. It covers all silent-fallback cases - FP8/quantized training/lora

)
if use_bridge:
self._update_weights_via_bridge(meta)
else:
self._update_weights_via_registry(meta)

if dist.get_rank() == 0:
self.rollout_engine.continue_generation()

current_platform.synchronize()
dist.barrier(group=self.cpu_group)

def _update_weights_via_registry(self, meta: WeightUpdateMeta) -> None:
"""Hand-rolled conversion path via convert_to_hf registry.

Used for FP8, LoRA, and models with a converter entry. Iterates this PP
rank's local params, TP-gathers per param, converts to HF layout, and
bucket-broadcasts to the rollout engine.
"""
num_moe_experts = self.tf_config.num_moe_experts
weight_chunked_mem_size = meta.weight_chunked_mem_mb * 1024 * 1024

Expand Down Expand Up @@ -1807,10 +1870,37 @@ def _update_weights_from_distributed(self, meta: WeightUpdateMeta) -> None:

dist.barrier(group=self.cpu_group)

if dist.get_rank() == 0:
self.rollout_engine.continue_generation()
def _update_weights_via_bridge(self, meta: WeightUpdateMeta) -> None:
"""Delegate live weight sync to megatron-bridge.export_hf_weights.

Streams (hf_name, hf_tensor) directly from the bridge, which handles
TP/EP/PP gather and layout transformation internally. Each PP rank
iterates the global parameter set (vs registry path which iterates only
local layers); non-PP-heads participate in collectives but do not bucket.
MoE expert weights are yielded inline by the bridge's grouped-export
path, so no separate second pass is needed.
"""
weight_chunked_mem_size = meta.weight_chunked_mem_mb * 1024 * 1024
bucket: list[tuple[str, torch.Tensor]] = []
bucket_size = 0

for hf_name, hf_tensor in self.bridge.export_hf_weights(
self.model,
cpu=False,
show_progress=False,
):
if not self.is_pipeline_parallel_head():
continue
size = hf_tensor.numel() * hf_tensor.element_size()
if bucket_size + size > weight_chunked_mem_size:
self._update_bucket_weights_from_distributed(meta, bucket)
bucket_size = 0
bucket.append((hf_name, hf_tensor.contiguous()))
bucket_size += size

if bucket:
self._update_bucket_weights_from_distributed(meta, bucket)

current_platform.synchronize()
dist.barrier(group=self.cpu_group)

@trace_perf("megatron_engine.update_weights_from_disk", category="io")
Expand Down Expand Up @@ -1909,7 +1999,17 @@ def _load_model_from_hf(self, path: str) -> None:
raise ValueError(
"Loading critic model is not supported with megatron-bridge."
)
self.bridge.load_hf_weights(self.model, hf_path=path)
# megatron-bridge's load path builds shard-index tensors via
# ``torch.arange(...)`` to index HF weights that live on CPU. Under
# the caller's ``with self.device:`` (CUDA) context, those indices
# become CUDA tensors and the CPU-tensor indexing raises
# ``RuntimeError: indices should be either on cpu or on the same
# device as the indexed tensor (cpu)`` — triggered by ChunkedMapping
# for any model with GDN/Mamba-style conv1d weights (e.g. Qwen3.5).
# Force CPU as the factory-op default here; tensor data assignment
# to GPU model params is unaffected (handled by .copy_()).
with torch.device("cpu"):
self.bridge.load_hf_weights(self.model, hf_path=path)
else:
load_weights_from_hf_with_mbridge_fast(
bridge=self.bridge,
Expand Down
84 changes: 84 additions & 0 deletions areal/engine/megatron_utils/megatron_bridge_patches.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,84 @@
# SPDX-License-Identifier: Apache-2.0

"""Runtime patches for megatron-bridge bugs not yet in a released version.

Each patch is keyed to an upstream PR. Patches are not version-gated; instead
each one's hot path becomes a no-op once the upstream fix is present (the patch
checks for the missing attribute/behavior before acting), and an idempotency
sentinel prevents double-application. Apply patches at import time via
``_apply_patches_on_import()`` at module bottom.
"""

from __future__ import annotations

import areal.utils.logging as logging

logger = logging.getLogger("MegatronBridgePatches")


def _patch_qwen3vl_pr3143_word_embeddings() -> None:
"""megatron-bridge PR #3143: expose word_embeddings on MTP shadow embedding.

Bug (issue #3112 / PR #3143): in ``Qwen3VLGPTModel.forward``, when
``mtp_process and sequence_parallel`` are both True, ``self.embedding`` is
temporarily replaced with a plain closure ``_sp_scatter_embedding``. The
closure lacks the ``word_embeddings`` attribute that
``shared_embedding_or_output_weight()`` accesses during ``_postprocess``
when ``share_embeddings_and_output_weights=True`` — typical for the
smaller Qwen3.5 dense models (0.8B/2B/4B).

Failure mode:
``AttributeError: 'function' object has no attribute 'word_embeddings'``

Affected versions: megatron-bridge 0.4.0 and 0.4.1. Fixed on ``main``
by commit 20749b09 (PR #3143) but not in any non-alpha release yet.

Strategy: wrap ``Qwen3VLGPTModel._postprocess`` so it lazily restores
``word_embeddings`` on the shadow embedding by inspecting its closure.
Closure-based recovery is non-invasive — we don't touch ``forward``
itself (~70 LoC method).
"""
try:
from megatron.bridge.models.qwen_vl.modelling_qwen3_vl.text_model import (
Qwen3VLGPTModel,
)
except ImportError:
return

if getattr(Qwen3VLGPTModel, "_areal_pr3143_applied", False):
return

_orig_postprocess = Qwen3VLGPTModel._postprocess

def _patched_postprocess(self, *args, **kwargs):
emb = self.__dict__.get("embedding")
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

Using getattr(self, "embedding", None) is more robust and idiomatic than directly accessing self.__dict__.get("embedding"). Direct __dict__ access bypasses standard Python attribute resolution (including properties, custom __getattr__ overrides, and inheritance) and is generally discouraged unless strictly necessary.

Suggested change
emb = self.__dict__.get("embedding")
emb = getattr(self, "embedding", None)

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The self.__dict__.get("embedding") is intentional. The upstream bug temporarily replaces self.embedding with a plain closure (_sp_scatter_embedding); because that closure is not an nn.Module, nn.Module.__setattr__ stores it directly in self.__dict__, whereas the real embedding lives in self._modules and is only reachable via nn.Module.__getattr__. So __dict__.get("embedding") detects specifically the installed shadow closure (returning None when it isn't installed), which is exactly the state we want to act on — getattr(self, "embedding") would instead return the real LanguageModelEmbedding when the shadow is absent. The downstream callable(emb) and not hasattr(emb, "word_embeddings") guard keeps it correct either way.

# Only intervene when the shadow closure is currently installed and
# lacks the expected attribute.
if (
callable(emb)
and not hasattr(emb, "word_embeddings")
and emb.__closure__ is not None
):
for cell in emb.__closure__:
try:
target = cell.cell_contents
except ValueError:
continue
if hasattr(target, "word_embeddings"):
emb.word_embeddings = target.word_embeddings
break
return _orig_postprocess(self, *args, **kwargs)

Qwen3VLGPTModel._postprocess = _patched_postprocess
Qwen3VLGPTModel._areal_pr3143_applied = True
logger.info(
"Applied megatron-bridge PR #3143 workaround: "
"Qwen3VLGPTModel shadow embedding word_embeddings restoration."
)


def _apply_patches_on_import() -> None:
_patch_qwen3vl_pr3143_word_embeddings()


_apply_patches_on_import()
Loading
Loading