Skip to content

Commit 995a15d

Browse files
committed
refactor(engine): auto-derive padded-seq layout from model type
The padded (BSHD) vs. packed (THD) forward layout is a hard architectural property of the model, not a user tunable: GDN/SSM kernels (the Qwen3.5 family) reject packed sequences. Exposing it as the `use_padded_seq` config field allowed it to be mis-set, risking silent correctness bugs or crashes. Derive it from `model_type` instead so the layout can never disagree with the architecture.

Also surface a startup warning when `use_bridge_for_update_weights=True` but a fallback condition (non-megatron-bridge, FP8/quantized, or LoRA) silently routes live weight sync through the registry path, so the effective behavior is visible in logs.

Key changes:
- Add `requires_padded_seq(model_type)` helper in engine/core/model.py
- Derive `self.use_padded_seq` from `model_type` in `MegatronEngine.initialize`
- Remove `use_padded_seq` from `MegatronEngineConfig` and regenerate CLI docs
- Warn once when bridge weight-sync falls back to the registry path
- Drop the test-runner override map and example yaml flag

Refs: areal-project#1384
1 parent 6356c8e commit 995a15d

8 files changed

Lines changed: 103 additions & 94 deletions

File tree

areal/api/cli_args.py

Lines changed: 0 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -950,18 +950,6 @@ class MegatronEngineConfig:
950950
},
951951
)
952952

953-
use_padded_seq: bool = field(
954-
default=False,
955-
metadata={
956-
"help": "Force padded (BSHD) input layout instead of packed (THD) for "
957-
"forward / train_batch. Required for architectures whose state-space "
958-
"or SSM layers reject packed sequences (e.g. Qwen3.5's GDN). Less "
959-
"memory-efficient because attention computes over padding, so prefer "
960-
"small per-microbatch sequence counts. Incompatible with "
961-
"context_parallel_size > 1 (same constraint VLM has).",
962-
},
963-
)
964-
965953

966954
class SchedulingStrategyType(str, Enum):
967955
separation = "separation"

areal/engine/core/model.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -85,6 +85,16 @@ def is_qwen3_5_model(model_type: str) -> bool:
8585
return model_type in ["qwen3_5", "qwen3_5_text", "qwen3_5_moe", "qwen3_5_moe_text"]
8686

8787

88+
def requires_padded_seq(model_type: str) -> bool:
89+
"""Whether the model must run the padded (BSHD) forward instead of packed (THD).
90+
91+
GDN/SSM models (currently the Qwen3.5 family) reject packed sequences in their
92+
attention/SSM kernels, so they must run on padded ``[B, S]`` input. THD stays
93+
the default for every other model.
94+
"""
95+
return is_qwen3_5_model(model_type)
96+
97+
8898
# Copied from trl
8999
def disable_dropout_in_model(model: torch.nn.Module) -> None:
90100
for module in model.modules():

areal/engine/megatron_engine.py

Lines changed: 29 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -60,6 +60,7 @@
6060
disable_dropout_in_model,
6161
is_valid_vision_model,
6262
lang_config,
63+
requires_padded_seq,
6364
)
6465
from areal.engine.megatron_utils import megatron_bridge_patches # noqa: F401
6566
from areal.engine.megatron_utils.checkpointer import MegatronCheckpointManager
@@ -347,6 +348,10 @@ def initialize(self, addr: str | None, ft_spec: FinetuneSpec, *args, **kwargs):
347348
)
348349

349350
self.is_vision_model = is_valid_vision_model(self.hf_config.model_type)
351+
# GDN/SSM models (e.g. Qwen3.5) reject packed THD input and must run
352+
# the padded BSHD forward. Derived from model type rather than a
353+
# config flag so the layout can't be mis-set.
354+
self.use_padded_seq = requires_padded_seq(self.hf_config.model_type)
350355
if self.is_vision_model:
351356
if self.parallel_strategy.context_parallel_size > 1:
352357
raise NotImplementedError(
@@ -362,14 +367,12 @@ def initialize(self, addr: str | None, ft_spec: FinetuneSpec, *args, **kwargs):
362367
f"Loaded processor and tokenizer."
363368
)
364369

365-
if (
366-
self.mcore_config.use_padded_seq
367-
and self.parallel_strategy.context_parallel_size > 1
368-
):
370+
if self.use_padded_seq and self.parallel_strategy.context_parallel_size > 1:
369371
raise NotImplementedError(
370-
"Context parallel (CP > 1) is not supported with "
371-
"use_padded_seq=True (the padded BSHD path operates on "
372-
"[B, S] tensors and the CP path packs sequences). "
372+
f"Context parallel (CP > 1) is not supported for "
373+
f"model_type={self.hf_config.model_type!r}, which requires the "
374+
"padded BSHD forward (it operates on [B, S] tensors while the "
375+
"CP path packs sequences). "
373376
f"Got context_parallel_size={self.parallel_strategy.context_parallel_size}."
374377
)
375378

@@ -380,6 +383,24 @@ def initialize(self, addr: str | None, ft_spec: FinetuneSpec, *args, **kwargs):
380383
self._check_and_apply_fp8_config()
381384
self._validate_fp8_consistency()
382385

386+
# Warn once if bridge-delegated weight sync was requested but a
387+
# fallback condition forces the registry conversion path (the
388+
# dispatch in _update_weights_from_distributed silently falls back).
389+
if self.mcore_config.use_bridge_for_update_weights:
390+
fallback_reasons = []
391+
if self.bridge_cls != "megatron-bridge":
392+
fallback_reasons.append(f"bridge_type={self.bridge_cls!r}")
393+
if self.quantization_config:
394+
fallback_reasons.append("FP8/quantized training")
395+
if self.config.use_lora:
396+
fallback_reasons.append("LoRA enabled")
397+
if fallback_reasons:
398+
self.logger.warning(
399+
"use_bridge_for_update_weights=True, but live weight sync "
400+
"will use the registry conversion path instead because: "
401+
f"{', '.join(fallback_reasons)}."
402+
)
403+
383404
with self.device:
384405
models = make_mcore_model(
385406
hf_config=self.hf_config,
@@ -856,7 +877,7 @@ def forward_step(batch_iter, model):
856877
mb_input.padded_mb,
857878
gather_cp_output=not cp_local,
858879
is_vision_model=self.is_vision_model,
859-
use_padded_seq=self.mcore_config.use_padded_seq,
880+
use_padded_seq=self.use_padded_seq,
860881
)
861882

862883
# Release tree attention metadata after forward pass

0 commit comments

Comments
 (0)