32 changes: 32 additions & 0 deletions areal/api/cli_args.py
@@ -940,6 +940,28 @@ class MegatronEngineConfig:
         },
     )

+    use_bridge_for_update_weights: bool = field(
+        default=False,
+        metadata={
+            "help": "When True and bridge_type='megatron-bridge', delegate live "
+            "weight sync to bridge.export_hf_weights instead of the hand-rolled "
+            "convert_to_hf registry. Required for models without a registry entry "
+            "(e.g. Qwen3.5). FP8 paths fall back to the registry automatically.",
+        },
+    )
+
+    use_padded_seq: bool = field(
Collaborator


Recommend changing use_padded_seq from a CLI flag to an automatic decision based on model_type (following the pattern of is_valid_vision_model).

Contributor Author


Currently use_padded_seq is required only for Qwen3.5. Every other supported model can ignore it and keep the default, which already amounts to an "automatic decision" to use packed sequences unless the field is set explicitly.

I agree the engine should derive this setting itself, e.g. default use_padded_seq to True when the model uses GDN, since BSHD is only a fallback for models that do not support THD. Users shouldn't have to add it when training Qwen3.5 models.

Do you want to keep the flag as an optional override (as verl does; we support both layouts after all), or remove it and let the engine decide?

Collaborator


I think the current PR can first remove this option, as it is only needed for Qwen3.5 at the moment. If other models depend on this option in the future, it can be added separately.

Contributor Author


Removed the use_padded_seq field entirely and made the layout an automatic decision based on model_type, mirroring is_valid_vision_model. If a future model needs to override the layout, we can reintroduce an explicit knob at that point.
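
For readers following the thread, here is a minimal sketch of what a model_type-driven layout decision could look like. It is not the code that landed in the PR: `PADDED_SEQ_MODEL_TYPES`, `needs_padded_seq`, and `select_input_layout` are hypothetical names, and the registry contents are simply the two Qwen3.5 model types mentioned in this diff.

```python
# Hypothetical registry: model types whose layers reject packed (THD) input.
# The two entries are taken from this PR; every name here is illustrative.
PADDED_SEQ_MODEL_TYPES = frozenset({"qwen3_5", "qwen3_5_moe"})


def needs_padded_seq(model_type: str) -> bool:
    """Return True when the model must run on padded [B, S] (BSHD) input."""
    return model_type in PADDED_SEQ_MODEL_TYPES


def select_input_layout(model_type: str) -> str:
    """Pick the input layout; packed THD stays the default."""
    return "bshd" if needs_padded_seq(model_type) else "thd"
```

With a registry like this, a recipe for a Qwen3.5 model needs no extra flag, and all other model types keep the packed path.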

+        default=False,
+        metadata={
+            "help": "Force padded (BSHD) input layout instead of packed (THD) for "
+            "forward / train_batch. Required for architectures whose state-space "
+            "or SSM layers reject packed sequences (e.g. Qwen3.5's GDN). Less "
+            "memory-efficient because attention computes over padding, so prefer "
+            "small per-microbatch sequence counts. Incompatible with "
+            "context_parallel_size > 1 (same constraint VLM has).",
+        },
+    )
+

 class SchedulingStrategyType(str, Enum):
     separation = "separation"
@@ -1713,6 +1735,16 @@ class vLLMConfig:
     )
     enable_sleep_mode: bool = False
     uvicorn_log_level: str = "warning"
+    # GDN prefill backend for hybrid models like Qwen3.5; "triton" avoids the
+    # FlashInfer GDN-kernel hang (vLLM #38916). None leaves vLLM's default, so
+    # no flag is emitted and non-GDN models are unaffected.
+    gdn_prefill_backend: str | None = field(
+        default=None,
+        metadata={
+            "help": "GDN prefill backend for hybrid models like Qwen3.5.",
+            "choices": ["triton", "flashinfer"],
+        },
+    )
     # lora
     enable_lora: bool = False
     max_lora_rank: int = 16  # vllm's default
2 changes: 2 additions & 0 deletions areal/engine/core/model.py
@@ -7,6 +7,8 @@
     "qwen2_5_vl",
     "qwen3_vl",
     "qwen3_vl_moe",
+    "qwen3_5",
+    "qwen3_5_moe",
     "gemma3",
Comment on lines 9 to 12
Contributor


high

Registering the text-only qwen3_5 and qwen3_5_moe models in VALID_VISION_MODELS is problematic. When a model is classified as a vision model (self.is_vision_model = True), the engine initialization attempts to load a processor via load_hf_processor_and_tokenizer(self.config.path). Since standard text-only Qwen3.5 models (like Qwen/Qwen3.5-2B in the example recipe) do not have a processor, this will raise an OSError and crash the engine on startup in production.

Since the padded sequence reconstruction path in packed_context_parallel_forward is already guarded by use_padded_seq (which is set to True for Qwen3.5), these models do not need to be registered as vision models to run on padded inputs. They should be removed from VALID_VISION_MODELS.

Suggested change
-    "qwen3_vl_moe",
-    "qwen3_5",
-    "qwen3_5_moe",
-    "gemma3",
+    "qwen3_vl_moe",
+    "gemma3",

Contributor Author


Qwen3.5 has no text-only variant: every model in the series uses the multimodal Qwen3_5ForConditionalGeneration / Qwen3_5MoeForConditionalGeneration architecture and ships a preprocessor_config.json (plus video_preprocessor_config.json), so load_hf_processor_and_tokenizer resolves a processor for the whole family. The dense single-GPU / TP / PP tests added in this PR run with this registration and pass, so there is no startup OSError. Registering the whole family as a vision model is intentional: there is no separate qwen3_5_vl model_type.

 ]
 # This registry is used to check if a model is a vision model that we have checked it works with AReaL.
87 changes: 83 additions & 4 deletions areal/engine/megatron_engine.py
@@ -61,6 +61,7 @@
     is_valid_vision_model,
     lang_config,
 )
+from areal.engine.megatron_utils import megatron_bridge_patches  # noqa: F401
 from areal.engine.megatron_utils.checkpointer import MegatronCheckpointManager
 from areal.engine.megatron_utils.deterministic import set_deterministic_algorithms
 from areal.engine.megatron_utils.fp8 import FP8BlockwiseTensorHelper
@@ -361,6 +362,17 @@ def initialize(self, addr: str | None, ft_spec: FinetuneSpec, *args, **kwargs):
                 f"Loaded processor and tokenizer."
             )

+        if (
+            self.mcore_config.use_padded_seq
+            and self.parallel_strategy.context_parallel_size > 1
+        ):
+            raise NotImplementedError(
+                "Context parallel (CP > 1) is not supported with "
+                "use_padded_seq=True (the padded BSHD path operates on "
+                "[B, S] tensors and the CP path packs sequences). "
+                f"Got context_parallel_size={self.parallel_strategy.context_parallel_size}."
+            )
+
         self.quantization_config = getattr(
             self.hf_config, "quantization_config", None
         )
@@ -844,6 +856,7 @@ def forward_step(batch_iter, model):
                 mb_input.padded_mb,
                 gather_cp_output=not cp_local,
                 is_vision_model=self.is_vision_model,
+                use_padded_seq=self.mcore_config.use_padded_seq,
             )

             # Release tree attention metadata after forward pass
@@ -1753,6 +1766,35 @@ def _update_weights_from_distributed(self, meta: WeightUpdateMeta) -> None:

         dist.barrier(group=self.cpu_group)

+        # Bridge delegation: when bridge_type=megatron-bridge and the user opts in,
+        # stream HF tensors directly from bridge.export_hf_weights. Falls back to
+        # the hand-rolled registry path for FP8 (quant_mapping in megatron-bridge
+        # is amax-style, not TE blockwise) and for LoRA (separate adapter export
+        # path not yet wired here).
+        use_bridge = (
+            self.bridge_cls == "megatron-bridge"
+            and self.mcore_config.use_bridge_for_update_weights
+            and not self.quantization_config
+            and not self.config.use_lora
Collaborator


If a user sets both use_lora and use_bridge_for_update_weights, the bridge path is silently skipped. I recommend logging a message so the fallback is visible.

Contributor Author


Added a one-time startup warning in MegatronEngine.initialize that fires when use_bridge_for_update_weights=True but a fallback condition forces the registry conversion path. It covers every silent-fallback case: FP8/quantized training and LoRA.
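
A rough sketch of how such a startup check might be structured follows. The warning code itself is not shown in this diff, so the function names, the message text, and the extra bridge_type check are assumptions made for illustration; only the three conditions mirror the `use_bridge` expression above.

```python
from __future__ import annotations

import logging

logger = logging.getLogger("bridge_fallback_sketch")  # hypothetical logger name


def bridge_fallback_reasons(
    use_bridge_for_update_weights: bool,
    bridge_type: str,
    quantization_config: object | None,
    use_lora: bool,
) -> list[str]:
    """List why an opted-in bridge sync would still use the registry path."""
    if not use_bridge_for_update_weights:
        return []  # the user never opted in, so there is nothing to report
    reasons = []
    if bridge_type != "megatron-bridge":
        reasons.append(f"bridge_type={bridge_type!r} is not 'megatron-bridge'")
    if quantization_config:
        reasons.append("FP8/quantized training is enabled")
    if use_lora:
        reasons.append("LoRA is enabled")
    return reasons


def warn_on_bridge_fallback(**settings) -> bool:
    """Log one warning naming every fallback reason; return True if any."""
    reasons = bridge_fallback_reasons(**settings)
    if reasons:
        logger.warning(
            "use_bridge_for_update_weights=True is ignored; using the "
            "registry path because: %s",
            "; ".join(reasons),
        )
    return bool(reasons)
```

Calling this once during initialization, rather than on every weight update, keeps the log quiet while still surfacing the configuration conflict.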

+        )
+        if use_bridge:
+            self._update_weights_via_bridge(meta)
+        else:
+            self._update_weights_via_registry(meta)
+
+        if dist.get_rank() == 0:
+            self.rollout_engine.continue_generation()
+
+        current_platform.synchronize()
+        dist.barrier(group=self.cpu_group)
+
+    def _update_weights_via_registry(self, meta: WeightUpdateMeta) -> None:
+        """Hand-rolled conversion path via convert_to_hf registry.
+        Used for FP8, LoRA, and models with a converter entry. Iterates this PP
+        rank's local params, TP-gathers per param, converts to HF layout, and
+        bucket-broadcasts to the rollout engine.
+        """
         num_moe_experts = self.tf_config.num_moe_experts
         weight_chunked_mem_size = meta.weight_chunked_mem_mb * 1024 * 1024

@@ -1807,10 +1849,37 @@ def _update_weights_from_distributed(self, meta: WeightUpdateMeta) -> None:

         dist.barrier(group=self.cpu_group)

-        if dist.get_rank() == 0:
-            self.rollout_engine.continue_generation()
+    def _update_weights_via_bridge(self, meta: WeightUpdateMeta) -> None:
+        """Delegate live weight sync to megatron-bridge.export_hf_weights.
+        Streams (hf_name, hf_tensor) directly from the bridge, which handles
+        TP/EP/PP gather and layout transformation internally. Each PP rank
+        iterates the global parameter set (vs registry path which iterates only
+        local layers); non-PP-heads participate in collectives but do not bucket.
+        MoE expert weights are yielded inline by the bridge's grouped-export
+        path, so no separate second pass is needed.
+        """
+        weight_chunked_mem_size = meta.weight_chunked_mem_mb * 1024 * 1024
+        bucket: list[tuple[str, torch.Tensor]] = []
+        bucket_size = 0
+
+        for hf_name, hf_tensor in self.bridge.export_hf_weights(
+            self.model,
+            cpu=False,
+            show_progress=False,
+        ):
+            if not self.is_pipeline_parallel_head():
+                continue
+            size = hf_tensor.numel() * hf_tensor.element_size()
+            if bucket_size + size > weight_chunked_mem_size:
+                self._update_bucket_weights_from_distributed(meta, bucket)
+                bucket_size = 0
+            bucket.append((hf_name, hf_tensor.contiguous()))
+            bucket_size += size
+
+        if bucket:
+            self._update_bucket_weights_from_distributed(meta, bucket)

         current_platform.synchronize()
         dist.barrier(group=self.cpu_group)

     @trace_perf("megatron_engine.update_weights_from_disk", category="io")
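
The size-bounded bucketing in `_update_weights_via_bridge` can be illustrated without torch. The sketch below is a generic version under stated assumptions: `iter_buckets` is a hypothetical helper that works on (name, byte size) pairs instead of tensors, and it resets the bucket explicitly because this diff does not show whether `_update_bucket_weights_from_distributed` empties the list it receives.

```python
from __future__ import annotations

from collections.abc import Iterable, Iterator


def iter_buckets(
    sized_items: Iterable[tuple[str, int]],
    max_bucket_bytes: int,
) -> Iterator[list[tuple[str, int]]]:
    """Group (name, nbytes) pairs into buckets of at most max_bucket_bytes.

    An item larger than the limit still gets a bucket of its own.
    """
    bucket: list[tuple[str, int]] = []
    bucket_bytes = 0
    for name, nbytes in sized_items:
        # Flush before adding when the next item would overflow the bucket.
        if bucket and bucket_bytes + nbytes > max_bucket_bytes:
            yield bucket
            bucket, bucket_bytes = [], 0
        bucket.append((name, nbytes))
        bucket_bytes += nbytes
    if bucket:  # flush the trailing partial bucket
        yield bucket


weights = [("a", 40), ("b", 50), ("c", 30), ("d", 200)]
buckets = list(iter_buckets(weights, max_bucket_bytes=100))
```

Here `buckets` holds three groups: `a` and `b` together, `c` alone, and the oversized `d` alone. Flushing before the append keeps every multi-item bucket within the limit.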
@@ -1909,7 +1978,17 @@ def _load_model_from_hf(self, path: str) -> None:
                 raise ValueError(
                     "Loading critic model is not supported with megatron-bridge."
                 )
-            self.bridge.load_hf_weights(self.model, hf_path=path)
+            # megatron-bridge's load path builds shard-index tensors via
+            # ``torch.arange(...)`` to index HF weights that live on CPU. Under
+            # the caller's ``with self.device:`` (CUDA) context, those indices
+            # become CUDA tensors and the CPU-tensor indexing raises
+            # ``RuntimeError: indices should be either on cpu or on the same
+            # device as the indexed tensor (cpu)``, triggered by ChunkedMapping
+            # for any model with GDN/Mamba-style conv1d weights (e.g. Qwen3.5).
+            # Force CPU as the factory-op default here; tensor data assignment
+            # to GPU model params is unaffected (handled by .copy_()).
+            with torch.device("cpu"):
+                self.bridge.load_hf_weights(self.model, hf_path=path)
         else:
             load_weights_from_hf_with_mbridge_fast(
                 bridge=self.bridge,
84 changes: 84 additions & 0 deletions areal/engine/megatron_utils/megatron_bridge_patches.py
@@ -0,0 +1,84 @@
+# SPDX-License-Identifier: Apache-2.0
+
+"""Runtime patches for megatron-bridge bugs not yet in a released version.
+
+Each patch is keyed to an upstream PR. Patches are not version-gated; instead
+each one's hot path becomes a no-op once the upstream fix is present (the patch
+checks for the missing attribute/behavior before acting), and an idempotency
+sentinel prevents double-application. Apply patches at import time via
+``_apply_patches_on_import()`` at module bottom.
+"""
+
+from __future__ import annotations
+
+import areal.utils.logging as logging
+
+logger = logging.getLogger("MegatronBridgePatches")
+
+
+def _patch_qwen3vl_pr3143_word_embeddings() -> None:
+    """megatron-bridge PR #3143: expose word_embeddings on MTP shadow embedding.
+
+    Bug (issue #3112 / PR #3143): in ``Qwen3VLGPTModel.forward``, when
+    ``mtp_process and sequence_parallel`` are both True, ``self.embedding`` is
+    temporarily replaced with a plain closure ``_sp_scatter_embedding``. The
+    closure lacks the ``word_embeddings`` attribute that
+    ``shared_embedding_or_output_weight()`` accesses during ``_postprocess``
+    when ``share_embeddings_and_output_weights=True``, which is typical for the
+    smaller Qwen3.5 dense models (0.8B/2B/4B).
+
+    Failure mode:
+        ``AttributeError: 'function' object has no attribute 'word_embeddings'``
+
+    Affected versions: megatron-bridge 0.4.0 and 0.4.1. Fixed on ``main``
+    by commit 20749b09 (PR #3143) but not in any non-alpha release yet.
+
+    Strategy: wrap ``Qwen3VLGPTModel._postprocess`` so it lazily restores
+    ``word_embeddings`` on the shadow embedding by inspecting its closure.
+    Closure-based recovery is non-invasive: we don't touch ``forward``
+    itself (~70 LoC method).
+    """
+    try:
+        from megatron.bridge.models.qwen_vl.modelling_qwen3_vl.text_model import (
+            Qwen3VLGPTModel,
+        )
+    except ImportError:
+        return
+
+    if getattr(Qwen3VLGPTModel, "_areal_pr3143_applied", False):
+        return
+
+    _orig_postprocess = Qwen3VLGPTModel._postprocess
+
+    def _patched_postprocess(self, *args, **kwargs):
+        emb = self.__dict__.get("embedding")
Contributor


medium

Using getattr(self, "embedding", None) is more robust and idiomatic than directly accessing self.__dict__.get("embedding"). Direct __dict__ access bypasses standard Python attribute resolution (including properties, custom __getattr__ overrides, and inheritance) and is generally discouraged unless strictly necessary.

Suggested change
-        emb = self.__dict__.get("embedding")
+        emb = getattr(self, "embedding", None)

Contributor Author


The self.__dict__.get("embedding") is intentional. The upstream bug temporarily replaces self.embedding with a plain closure (_sp_scatter_embedding); because that closure is not an nn.Module, nn.Module.__setattr__ stores it directly in self.__dict__, whereas the real embedding lives in self._modules and is only reachable via nn.Module.__getattr__. So __dict__.get("embedding") detects specifically the installed shadow closure (returning None when it isn't installed), which is exactly the state we want to act on — getattr(self, "embedding") would instead return the real LanguageModelEmbedding when the shadow is absent. The downstream callable(emb) and not hasattr(emb, "word_embeddings") guard keeps it correct either way.
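
The lookup difference described above can be reproduced in plain Python. The sketch below uses `TinyModule`, a hypothetical stand-in that only imitates the relevant part of nn.Module (children kept in `_modules` and served through `__getattr__`). The real nn.Module attribute logic is more involved, and writing the shadow straight into `__dict__` is an assumption about how the upstream code installs it.

```python
class TinyModule:
    """Stand-in for nn.Module: children live in _modules, not in __dict__."""

    def __init__(self) -> None:
        self.__dict__["_modules"] = {}

    def register(self, name: str, child: object) -> None:
        self._modules[name] = child

    def __getattr__(self, name: str):
        # Python calls this only after normal lookup (instance __dict__,
        # then the class) has failed.
        modules = self.__dict__.get("_modules", {})
        if name in modules:
            return modules[name]
        raise AttributeError(name)


class RealEmbedding:
    word_embeddings = "weight"


def shadow_state(model: TinyModule) -> tuple[object, object]:
    """Return (what __dict__.get sees, what getattr sees) for 'embedding'."""
    return model.__dict__.get("embedding"), getattr(model, "embedding", None)
```

Without a shadow, `shadow_state` reports `None` from `__dict__` and the registered child from `getattr`. Once a plain function is stored under `model.__dict__["embedding"]`, both lookups return that function, so only the `__dict__` probe distinguishes "shadow installed" from "real embedding present".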

+        # Only intervene when the shadow closure is currently installed and
+        # lacks the expected attribute.
+        if (
+            callable(emb)
+            and not hasattr(emb, "word_embeddings")
+            and emb.__closure__ is not None
+        ):
+            for cell in emb.__closure__:
+                try:
+                    target = cell.cell_contents
+                except ValueError:
+                    continue
+                if hasattr(target, "word_embeddings"):
+                    emb.word_embeddings = target.word_embeddings
+                    break
+        return _orig_postprocess(self, *args, **kwargs)
+
+    Qwen3VLGPTModel._postprocess = _patched_postprocess
+    Qwen3VLGPTModel._areal_pr3143_applied = True
+    logger.info(
+        "Applied megatron-bridge PR #3143 workaround: "
+        "Qwen3VLGPTModel shadow embedding word_embeddings restoration."
+    )
+
+
+def _apply_patches_on_import() -> None:
+    _patch_qwen3vl_pr3143_word_embeddings()
+
+
+_apply_patches_on_import()
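
The closure-inspection step in `_patched_postprocess` relies only on standard Python function attributes, so it can be tried in isolation. In this sketch `recover_from_closure`, `Embedding`, and `make_shadow` are hypothetical names standing in for the patch logic, the real embedding module, and the upstream `_sp_scatter_embedding` closure.

```python
from __future__ import annotations

from typing import Any, Callable


def recover_from_closure(fn: Callable[..., Any], attr: str) -> Any | None:
    """Return the first object captured by fn's closure that has attr."""
    for cell in getattr(fn, "__closure__", None) or ():
        try:
            target = cell.cell_contents
        except ValueError:  # the cell has not been filled yet
            continue
        if hasattr(target, attr):
            return target
    return None


class Embedding:
    def __init__(self) -> None:
        self.word_embeddings = "weight"

    def __call__(self, ids: list[int]) -> list[int]:
        return ids


def make_shadow(original: Embedding) -> Callable[[list[int]], list[int]]:
    def shadow(ids: list[int]) -> list[int]:
        return original(ids)  # `original` is captured in a closure cell

    return shadow
```

A function object accepts new attributes, so once the captured object is found, `shadow.word_embeddings = found.word_embeddings` restores the missing name without touching the code that created the closure.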
46 changes: 30 additions & 16 deletions areal/engine/megatron_utils/packed_context_parallel.py
@@ -288,6 +288,7 @@ def packed_context_parallel_forward(
     input_: dict[str, Any],
     gather_cp_output: bool = True,
     is_vision_model: bool = False,
+    use_padded_seq: bool = False,
 ):
     input_ids = input_["input_ids"]
     position_ids = input_.get("position_ids", None)
@@ -300,12 +301,18 @@ def packed_context_parallel_forward(
     packed_seq_params = None

     is_vision = is_vision_model and any(key in input_ for key in _VLM_FORWARD_KEYS)
+    # Architectures whose attention/SSM kernels reject packed sequences (e.g.
+    # Qwen3.5 GDN) must run on [B, S] padded input. The reconstruction logic
+    # below is shared with the VLM path; the difference is downstream
+    # (attention_mask and position_ids are passed through for text-only).
+    needs_padded_form = is_vision or use_padded_seq

-    # Track whether we reconstructed 2D batch form for vision
-    vision_repack_info = None
+    # Track shape metadata so the output can be repacked back to packed
+    # [total_len, ...] form on the last PP stage.
+    padded_repack_info = None

     if cu_seqlens is not None:
-        if not is_vision:
+        if not needs_padded_form:
             if attention_mask is not None or tree_triton_data is not None:
                 raise ValueError(
                     "Attention mask should be None when using packed sequences."
@@ -315,10 +322,9 @@ def packed_context_parallel_forward(
                 )
             input_ids = input_ids.contiguous()
         else:
-            # VLM models expect batch-form [B, S] input_ids for mRoPE position
-            # computation and vision token embedding replacement. Reconstruct
-            # padded 2D tensors from packed 1D using cu_seqlens via boolean
-            # masking, which avoids per-sample Python loop and GPU-CPU sync.
+            # VLM and BSHD-only models expect [B, S] padded input. Reconstruct
+            # padded 2D tensors from packed 1D via boolean masking, which avoids
+            # per-sample Python loop and GPU-CPU sync.
             batch_size = cu_seqlens.shape[0] - 1
             seq_lens = cu_seqlens[1:] - cu_seqlens[:-1]
             max_seqlen = int(seq_lens.max().item())
@@ -337,14 +343,15 @@ def packed_context_parallel_forward(
             )
             input_ids_2d[attention_mask] = input_ids
             input_ids = input_ids_2d
-            vision_repack_info = (cu_seqlens, seq_lens, max_seqlen)
+            padded_repack_info = (cu_seqlens, seq_lens, max_seqlen)

-    # Pass tree_triton_data as attention_mask if present (for Triton tree attention)
-    # Otherwise use the attention_mask from input (could be dense tensor for flex attention)
-    # For VLM: pass None; the model's get_rope_index uses the 2D attention_mask
-    # internally for correct mRoPE positions. Each batch slot holds one sequence
-    # with trailing padding, so causal attention yields correct outputs at
+    # VLM path: attention_mask=None; model's get_rope_index uses the 2D mask
+    # internally for mRoPE positions. Each batch slot holds one sequence with
+    # trailing padding, so causal attention yields correct outputs at
     # non-padding positions; padding outputs are discarded during repack.
+    #
+    # BSHD text-only path (use_padded_seq): pass our built attention_mask so
+    # the model's attention layers skip padding.
     if is_vision:
         final_attention_mask = None
     else:
@@ -360,6 +367,13 @@ def packed_context_parallel_forward(
             if key in input_:
                 vlm_kwargs[key] = input_[key]

+    # For BSHD text-only, drop the packed-form position_ids (a 1D tensor of
+    # length total_len), since they don't match the 2D [B, S] input. Let mcore
+    # compute the default torch.arange positions per row; padding positions
+    # are masked out by attention_mask.
+    if use_padded_seq and not is_vision:
+        position_ids = None
+
     try:
         output = model(
             input_ids=input_ids,
@@ -379,16 +393,16 @@ def packed_context_parallel_forward(
         ignore_virtual=False, vp_stage=model_vp_stage
     )

-    # Repack vision output to packed [total_len, ...] for the last PP stage only.
+    # Repack padded output to packed [total_len, ...] for the last PP stage only.
     # Intermediate stages must return their output unchanged so the pipeline
     # send/recv shapes match what the next stage expects (megatron-core's
     # `_communicate_shapes` negotiates based on this return value).
     #
     # On the last PP stage, megatron-core GPTModel returns logits already
     # transposed to [B, S, V] (gpt_model.py: `return logits.transpose(0, 1).contiguous()`),
     # so a boolean mask of valid positions selects the packed sequence.
-    if vision_repack_info is not None and is_pipeline_last_stage:
-        _, repack_seq_lens, repack_max_seqlen = vision_repack_info
+    if padded_repack_info is not None and is_pipeline_last_stage:
+        _, repack_seq_lens, repack_max_seqlen = padded_repack_info
         mask = (
             torch.arange(repack_max_seqlen, device=output.device)[None, :]
             < repack_seq_lens[:, None]
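
To make the pad-then-repack round trip in this file easier to follow, here is a small sketch that uses plain Python lists in place of tensors. `pad_from_packed` and `repack` are hypothetical helpers, not functions from the PR; they mirror the same idea of deriving per-sequence lengths from `cu_seqlens`, building a validity mask, and selecting the valid positions afterwards.

```python
from __future__ import annotations


def pad_from_packed(
    tokens: list[int], cu_seqlens: list[int], pad_value: int = 0
) -> tuple[list[list[int]], list[list[bool]]]:
    """Rebuild padded [B, S] rows and a validity mask from a packed sequence."""
    starts = cu_seqlens[:-1]
    seq_lens = [end - start for start, end in zip(starts, cu_seqlens[1:])]
    max_len = max(seq_lens)
    # mask[b][j] is True where row b holds a real token (j < seq_lens[b]).
    mask = [[j < n for j in range(max_len)] for n in seq_lens]
    padded = [
        tokens[start : start + n] + [pad_value] * (max_len - n)
        for start, n in zip(starts, seq_lens)
    ]
    return padded, mask


def repack(rows: list[list[int]], mask: list[list[bool]]) -> list[int]:
    """Select valid positions row by row, restoring packed [total_len] order."""
    return [
        value
        for row, mask_row in zip(rows, mask)
        for value, keep in zip(row, mask_row)
        if keep
    ]


packed = [11, 12, 13, 21, 22, 31]
padded, valid = pad_from_packed(packed, cu_seqlens=[0, 3, 5, 6])
```

Because the mask is derived only from the sequence lengths, the same mask that scatters tokens into padded rows also selects the model outputs on the way back, and padding positions are dropped.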