Skip to content
Merged
Show file tree
Hide file tree
Changes from 7 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion docker/Dockerfile
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@ RUN pip install /tmp/wheels/flash_attn_3-*.whl && \

RUN pip install git+https://github.com/ISEEKYAN/mbridge.git@89eb10887887bc74853f89a4de258c0702932a1c --no-deps

RUN pip install flash-linear-attention==0.4.1
RUN pip install flash-linear-attention==0.4.2
RUN pip install tilelang -f https://tile-ai.github.io/whl/nightly/cu128/

RUN if [ "${ENABLE_CUDA_13}" = "1" ]; then \
Expand Down
11 changes: 11 additions & 0 deletions miles/backends/megatron_utils/actor.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,6 +110,17 @@ def init(
args, role
)

from megatron.core import mpu

if mpu.get_context_parallel_world_size() > 1:
from miles.backends.training_utils.cp_utils import setup_hybrid_cp

cp_group = mpu.get_context_parallel_group()
cp_rank = mpu.get_context_parallel_rank()
cp_world_size = mpu.get_context_parallel_world_size()
for model_chunk in self.model:
setup_hybrid_cp(model_chunk, cp_group, cp_rank, cp_world_size)

verify_megatron_parallel_state(self.model)

if role == "critic":
Expand Down
29 changes: 29 additions & 0 deletions miles/backends/training_utils/cp_utils.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,15 @@
import logging
from collections.abc import Callable

import torch
import torch.distributed as dist
import torch.nn as nn
import torch.nn.functional as F

from .parallel import get_parallel_state

logger = logging.getLogger(__name__)


def get_logits_and_tokens_offset_with_cp(
total_length: int,
Expand Down Expand Up @@ -336,3 +340,28 @@ def slice_log_prob_with_cp(
return chunk_1 + chunk_2
else:
return torch.cat([chunk_1, chunk_2], dim=0)


def setup_hybrid_cp(model: nn.Module, cp_group: dist.ProcessGroup, cp_rank: int, cp_world_size: int) -> None:
    """Configure GatedDeltaNet modules for native fla CP instead of all-gather duplication.

    Walks the model tree looking for HuggingfaceAttention submodules that have a
    ``linear_attn`` child (i.e. DeltaNet layers). For each one it sets the CP
    metadata so that ``_build_cp_context`` produces a valid context, and flips
    ``hybrid_cp`` so the parent skips the all-gather path.

    Args:
        model: Root module to scan (typically one Megatron model chunk).
        cp_group: Context-parallel process group to store on each DeltaNet layer.
        cp_rank: This rank's index within ``cp_group``.
        cp_world_size: Number of ranks in ``cp_group``.
    """
    # Imported lazily so this module stays usable when the plugin package is absent.
    from miles_plugins.models.hf_attention import HuggingfaceAttention

    count = 0
    for module in model.modules():
        if isinstance(module, HuggingfaceAttention):
            linear_attn = getattr(module, "linear_attn", None)
            if linear_attn is not None:
                linear_attn.cp_group = cp_group
                linear_attn.cp_rank = cp_rank
                linear_attn.cp_world_size = cp_world_size
                module.hybrid_cp = True
                count += 1

    if count > 0:
        # Lazy %-style args: no string formatting work when INFO logging is disabled.
        logger.info("Configured hybrid CP on %d DeltaNet modules (fla native state passing)", count)
3 changes: 3 additions & 0 deletions miles/utils/arguments.py
Original file line number Diff line number Diff line change
Expand Up @@ -2088,6 +2088,7 @@ def equal(x, y):
if "rope_theta" in hf_config.rope_parameters:
hf_config.rope_theta = hf_config.rope_parameters["rope_theta"]

is_moe = hasattr(hf_config, "num_experts") or hasattr(hf_config, "moe_intermediate_size")
Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@guapisolo I added this line to handle the case where the Qwen 3.5 hf_config's intermediate_size (5632) is not equal to ffn_hidden_size.

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Let me double check.

Copy link
Copy Markdown
Collaborator

@guapisolo guapisolo Apr 6, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think we can remove this line? This line is not related to your motivation.

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Curious — a few weeks ago, when I integrated Qwen 3.5, there were no assert errors here (and I believe you tested it back then too). Is this caused by a transformers version update, or something else?

for hf_config_name, megatron_config_name, compare_fn in [
("hidden_size", "hidden_size", equal),
("num_attention_heads", "num_attention_heads", equal),
Expand All @@ -2102,6 +2103,8 @@ def equal(x, y):
("rope_theta", "rotary_base", equal),
]:
if hasattr(hf_config, hf_config_name):
if is_moe and hf_config_name == "intermediate_size":
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

On my side this line might be redundant? https://huggingface.co/Qwen/Qwen3.5-35B-A3B/blob/main/config.json . There is no intermediate_size=5632 in qwen3.5

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Oh, I found it — we should update to `pip install transformers==5.2.0`.

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I added a workaround to avoid this under transformers 4.57.1.

continue
if not compare_fn(getattr(hf_config, hf_config_name), getattr(args, megatron_config_name)):
errors.append(
f"{hf_config_name} in hf config {getattr(hf_config, hf_config_name)} is not equal to "
Expand Down
8 changes: 6 additions & 2 deletions miles_plugins/models/hf_attention.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,10 @@ class HuggingfaceAttention(MegatronModule, ABC):
"cross attn" specializations.
"""

# Subclasses set this to True when the underlying module handles CP natively
# (e.g. via fla's state-passing CP for DeltaNet), bypassing the all-gather.
hybrid_cp: bool = False

def __init__(
self,
args,
Expand Down Expand Up @@ -115,7 +119,7 @@ def forward(
group=mpu.get_tensor_model_parallel_group(),
)

if mpu.get_context_parallel_world_size() > 1:
if mpu.get_context_parallel_world_size() > 1 and not self.hybrid_cp:
cp_size = mpu.get_context_parallel_world_size()
# Use custom all-gather whose backward returns local gradient
# instead of reduce-scatter, since the computation is duplicated.
Expand Down Expand Up @@ -150,7 +154,7 @@ def forward(

output = output.permute(1, 0, 2) # [seq_len, bsz, hidden_dim]

if mpu.get_context_parallel_world_size() > 1:
if mpu.get_context_parallel_world_size() > 1 and not self.hybrid_cp:
cp_rank = mpu.get_context_parallel_rank()
output_list = []
for i in range(len(cu_seqlens) - 1):
Expand Down
61 changes: 48 additions & 13 deletions miles_plugins/models/qwen3_5.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,12 @@
except ImportError:
pass

try:
from fla.ops.cp import FLACPContext, build_cp_context
except ImportError:
FLACPContext = None
build_cp_context = None

from .hf_attention import HuggingfaceAttention, _load_hf_config


Expand Down Expand Up @@ -81,24 +87,41 @@ def __init__(self, config, layer_idx: int):

self.out_proj = nn.Linear(self.value_dim, self.hidden_size, bias=False)

def _build_cp_context(self, local_seq_len: int, device: torch.device):
"""Build fla CP context from the local (sharded) sequence length."""
cp_group = getattr(self, "cp_group", None)
if cp_group is None or build_cp_context is None:
return None
global_seq_len = local_seq_len * self.cp_world_size
global_cu_seqlens = torch.tensor([0, global_seq_len], dtype=torch.int32, device=device)
return build_cp_context(
cu_seqlens=global_cu_seqlens,
group=cp_group,
conv1d_kernel_size=self.conv_kernel_size,
)
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

The _build_cp_context method currently constructs global_cu_seqlens for only a single sequence, which will lead to incorrect behavior or crashes when the batch size is greater than 1. It should be updated to account for the batch size by creating a cu_seqlens tensor that covers all sequences in the batch.

    def _build_cp_context(self, batch_size: int, local_seq_len: int, device: torch.device):
        """Build fla CP context from the local (sharded) sequence length."""
        cp_group = getattr(self, "cp_group", None)
        if cp_group is None or build_cp_context is None:
            return None
        cp_world_size = getattr(self, "cp_world_size", 1)
        global_seq_len = local_seq_len * cp_world_size
        global_cu_seqlens = torch.arange(
            0, (batch_size + 1) * global_seq_len, step=global_seq_len, dtype=torch.int32, device=device
        )
        return build_cp_context(
            cu_seqlens=global_cu_seqlens,
            group=cp_group,
            conv1d_kernel_size=self.conv_kernel_size,
        )
References
  1. Avoid hardcoding model dimensions; derive them from configuration or input tensor shapes instead.


def forward(
self,
hidden_states: torch.Tensor,
cu_seqlens: torch.Tensor = None,
):
batch_size, seq_len, _ = hidden_states.shape

cp_context = self._build_cp_context(seq_len, hidden_states.device)
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

The call to _build_cp_context needs to pass the batch_size to correctly initialize the CP context for multi-sequence batches.

Suggested change
cp_context = self._build_cp_context(seq_len, hidden_states.device)
cp_context = self._build_cp_context(batch_size, seq_len, hidden_states.device)


# Projections (flat layout: [Q_all, K_all, V_all])
mixed_qkv = self.in_proj_qkv(hidden_states)
z = self.in_proj_z(hidden_states)
z = z.reshape(batch_size, seq_len, -1, self.head_v_dim)
b = self.in_proj_b(hidden_states)
a = self.in_proj_a(hidden_states)

# Convolution on the flat QKV
# Convolution on the flat QKV (pass cp_context for boundary handling)
conv_cu_seqlens = cp_context.cu_seqlens if cp_context is not None else cu_seqlens
mixed_qkv, _ = self.conv1d(
x=mixed_qkv,
cu_seqlens=cu_seqlens,
cu_seqlens=conv_cu_seqlens,
cp_context=cp_context,
)

# Split into Q, K, V (flat split, matching HF layout)
Expand All @@ -118,17 +141,29 @@ def forward(
query = query.repeat_interleave(self.num_v_heads // self.num_k_heads, dim=2)
key = key.repeat_interleave(self.num_v_heads // self.num_k_heads, dim=2)

core_attn_out, last_recurrent_state = chunk_gated_delta_rule(
query,
key,
value,
g=g,
beta=beta,
initial_state=None,
output_final_state=False,
use_qk_l2norm_in_kernel=True,
cu_seqlens=cu_seqlens,
)
if cp_context is not None:
core_attn_out, _ = chunk_gated_delta_rule(
query,
key,
value,
g=g,
beta=beta,
use_qk_l2norm_in_kernel=True,
cu_seqlens=cp_context.cu_seqlens,
cp_context=cp_context,
)
else:
core_attn_out, _ = chunk_gated_delta_rule(
query,
key,
value,
g=g,
beta=beta,
initial_state=None,
output_final_state=False,
use_qk_l2norm_in_kernel=True,
cu_seqlens=cu_seqlens,
)

z_shape_og = z.shape
# reshape input data into 2D tensor
Expand Down
59 changes: 47 additions & 12 deletions miles_plugins/models/qwen3_next.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,12 @@
except ImportError:
pass

try:
from fla.ops.cp import FLACPContext, build_cp_context
except ImportError:
FLACPContext = None
build_cp_context = None

from .hf_attention import HuggingfaceAttention


Expand Down Expand Up @@ -74,6 +80,19 @@ def __init__(self, config, layer_idx: int):

self.out_proj = nn.Linear(self.value_dim, self.hidden_size, bias=False)

def _build_cp_context(self, local_seq_len: int, device: torch.device):
"""Build fla CP context from the local (sharded) sequence length."""
cp_group = getattr(self, "cp_group", None)
if cp_group is None or build_cp_context is None:
return None
global_seq_len = local_seq_len * self.cp_world_size
global_cu_seqlens = torch.tensor([0, global_seq_len], dtype=torch.int32, device=device)
return build_cp_context(
cu_seqlens=global_cu_seqlens,
group=cp_group,
conv1d_kernel_size=self.conv_kernel_size,
)
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

The _build_cp_context method ignores the batch size when constructing global_cu_seqlens. This will cause issues during Context Parallel execution if the batch size is greater than 1. The cu_seqlens should be generated based on the batch size.

    def _build_cp_context(self, batch_size: int, local_seq_len: int, device: torch.device):
        """Build fla CP context from the local (sharded) sequence length."""
        cp_group = getattr(self, "cp_group", None)
        if cp_group is None or build_cp_context is None:
            return None
        cp_world_size = getattr(self, "cp_world_size", 1)
        global_seq_len = local_seq_len * cp_world_size
        global_cu_seqlens = torch.arange(
            0, (batch_size + 1) * global_seq_len, step=global_seq_len, dtype=torch.int32, device=device
        )
        return build_cp_context(
            cu_seqlens=global_cu_seqlens,
            group=cp_group,
            conv1d_kernel_size=self.conv_kernel_size,
        )
References
  1. Avoid hardcoding model dimensions; derive them from configuration or input tensor shapes instead.


def fix_query_key_value_ordering(self, mixed_qkvz, mixed_ba):
"""
Derives `query`, `key` and `value` tensors from `mixed_qkvz` and `mixed_ba`.
Expand Down Expand Up @@ -108,16 +127,20 @@ def forward(
hidden_states: torch.Tensor,
cu_seqlens: torch.Tensor = None,
):
cp_context = self._build_cp_context(hidden_states.shape[1], hidden_states.device)
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

Update the call to _build_cp_context to include the batch size, ensuring the CP context is correctly built for batches with more than one sequence.

Suggested change
cp_context = self._build_cp_context(hidden_states.shape[1], hidden_states.device)
cp_context = self._build_cp_context(hidden_states.shape[0], hidden_states.shape[1], hidden_states.device)


projected_states_qkvz = self.in_proj_qkvz(hidden_states)
projected_states_ba = self.in_proj_ba(hidden_states)
query, key, value, z, b, a = self.fix_query_key_value_ordering(projected_states_qkvz, projected_states_ba)
query, key, value = (x.reshape(x.shape[0], x.shape[1], -1) for x in (query, key, value))

mixed_qkv = torch.cat((query, key, value), dim=-1)

conv_cu_seqlens = cp_context.cu_seqlens if cp_context is not None else cu_seqlens
mixed_qkv, _ = self.conv1d(
x=mixed_qkv,
cu_seqlens=cu_seqlens,
cu_seqlens=conv_cu_seqlens,
cp_context=cp_context,
)

query, key, value = torch.split(
Expand All @@ -140,17 +163,29 @@ def forward(
query = query.repeat_interleave(self.num_v_heads // self.num_k_heads, dim=2)
key = key.repeat_interleave(self.num_v_heads // self.num_k_heads, dim=2)

core_attn_out, last_recurrent_state = chunk_gated_delta_rule(
query,
key,
value,
g=g,
beta=beta,
initial_state=None,
output_final_state=False,
use_qk_l2norm_in_kernel=True,
cu_seqlens=cu_seqlens,
)
if cp_context is not None:
core_attn_out, _ = chunk_gated_delta_rule(
query,
key,
value,
g=g,
beta=beta,
use_qk_l2norm_in_kernel=True,
cu_seqlens=cp_context.cu_seqlens,
cp_context=cp_context,
)
else:
core_attn_out, _ = chunk_gated_delta_rule(
query,
key,
value,
g=g,
beta=beta,
initial_state=None,
output_final_state=False,
use_qk_l2norm_in_kernel=True,
cu_seqlens=cu_seqlens,
)

z_shape_og = z.shape
# reshape input data into 2D tensor
Expand Down
Loading
Loading