Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 9 additions & 0 deletions miles/backends/megatron_utils/actor.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,6 +115,15 @@ def init(
args, role
)

parallel_state = get_parallel_state()
if parallel_state.cp.size > 1:
from miles_plugins.models.cp_utils import detect_and_setup_hybrid_cp

for model_chunk in self.model:
detect_and_setup_hybrid_cp(
model_chunk, parallel_state.cp.group, parallel_state.cp.rank, parallel_state.cp.size
)

verify_megatron_parallel_state(self.model)

if role == "critic":
Expand Down
36 changes: 36 additions & 0 deletions miles/backends/training_utils/cp_utils.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,20 @@
import logging
from collections.abc import Callable

import torch
import torch.distributed as dist
import torch.nn as nn
import torch.nn.functional as F

from .parallel import get_parallel_state

try:
from fla.ops.cp import build_cp_context as _fla_build_cp_context
except ImportError:
_fla_build_cp_context = None

logger = logging.getLogger(__name__)


def get_logits_and_tokens_offset_with_cp(
total_length: int,
Expand Down Expand Up @@ -336,3 +345,30 @@ def slice_log_prob_with_cp(
return chunk_1 + chunk_2
else:
return torch.cat([chunk_1, chunk_2], dim=0)


def build_gdn_cp_context(module: nn.Module, cu_seqlens: torch.Tensor, device: torch.device):
    """Build fla CP context for a GatedDeltaNet module from packed sequence boundaries.

    Args:
        module: GDN module with ``cp_group`` / ``cp_world_size`` / ``conv_kernel_size``.
        cu_seqlens: Global packed sequence boundaries (e.g. ``packed_seq_params.cu_seqlens_q``).
        device: Target device.

    Returns ``None`` when CP is not configured on the module (``cp_group`` not set).
    Raises ``RuntimeError`` if hybrid CP is configured but ``fla.ops.cp`` is missing.
    """
    group = getattr(module, "cp_group", None)
    if group is None:
        # CP was never configured on this module; nothing to build.
        return None

    if _fla_build_cp_context is None:
        raise RuntimeError(
            "Hybrid CP requires fla.ops.cp (flash-linear-attention >= 0.4.2) but it could not be imported."
        )

    if cu_seqlens is None or cu_seqlens.numel() < 2:
        raise ValueError(f"Hybrid CP requires valid cu_seqlens (at least 2 elements) but got {cu_seqlens}")

    # fla expects int32 boundaries resident on the compute device.
    boundaries = cu_seqlens.to(device=device, dtype=torch.int32)
    return _fla_build_cp_context(
        cu_seqlens=boundaries,
        group=group,
        conv1d_kernel_size=module.conv_kernel_size,
    )
3 changes: 3 additions & 0 deletions miles/utils/arguments.py
Original file line number Diff line number Diff line change
Expand Up @@ -2140,6 +2140,9 @@ def equal(x, y):
),
("rope_theta", "rotary_base", equal),
]:
# FIXME: the Qwen3.5 transformers config has a bug; skip the intermediate_size check for this model type.
if getattr(hf_config, "model_type", "") == "qwen3_5_moe_text" and hf_config_name == "intermediate_size":
continue
if hasattr(hf_config, hf_config_name):
if not compare_fn(getattr(hf_config, hf_config_name), getattr(args, megatron_config_name)):
errors.append(
Expand Down
26 changes: 26 additions & 0 deletions miles_plugins/models/cp_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
import logging

import torch.distributed as dist
import torch.nn as nn

from miles_plugins.models.hf_attention import HuggingfaceAttention

logger = logging.getLogger(__name__)


def detect_and_setup_hybrid_cp(model: nn.Module, cp_group: dist.ProcessGroup, cp_rank: int, cp_world_size: int) -> int:
    """Scan for GatedDeltaNet modules and configure them for native fla CP.

    Returns the number of modules that were configured.
    """
    configured = 0
    for submodule in model.modules():
        if not isinstance(submodule, HuggingfaceAttention):
            continue
        gdn = getattr(submodule, "linear_attn", None)
        if gdn is None:
            continue
        # Attach the CP topology directly on the linear-attention module so it
        # can build a CP context at forward time; flag the wrapper as hybrid.
        gdn.cp_group = cp_group
        gdn.cp_rank = cp_rank
        gdn.cp_world_size = cp_world_size
        submodule.hybrid_cp = True
        configured += 1

    if configured > 0:
        logger.info(f"Configured hybrid CP on {configured} GDN modules (fla native state passing)")
    return configured
143 changes: 141 additions & 2 deletions miles_plugins/models/hf_attention.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,116 @@ def _fix_dtype(d):
return ns


def _get_cp_sequence_lengths(cu_seqlens, cp_size, local_total_len=None):
global_seq_lengths = [(cu_seqlens[i + 1] - cu_seqlens[i]).item() for i in range(len(cu_seqlens) - 1)]
local_seq_lengths = []
for global_seq_len in global_seq_lengths:
if global_seq_len % cp_size != 0:
raise ValueError(f"Expected sequence length {global_seq_len} to be divisible by cp_size={cp_size}")
local_seq_lengths.append(global_seq_len // cp_size)

if local_total_len is not None and sum(local_seq_lengths) != local_total_len:
raise ValueError(f"Expected local total length {local_total_len}, got {sum(local_seq_lengths)}")

return global_seq_lengths, local_seq_lengths


def _gather_cp_tensors(x, cp_group):
    """All-gather ``x`` across ``cp_group``; returns one tensor per rank, in rank order."""
    world_size = dist.get_world_size(group=cp_group)
    buffers = [torch.empty_like(x) for _ in range(world_size)]
    # all_gather requires contiguous input.
    dist.all_gather(buffers, x.contiguous(), group=cp_group)
    return buffers


def _zigzag_to_packed_shard_impl(hidden_states, cu_seqlens, cp_group, cp_rank, cp_size):
    """Convert zigzag ring-attn layout to the contiguous packed shard expected by fla CP.

    Each rank holds, per sequence, the ``cp_rank``-th and the mirrored
    ``(2*cp_size-1-cp_rank)``-th subchunk (the zigzag scheme). This gathers all
    shards, reconstructs each full sequence, concatenates them into one global
    packed stream, and returns this rank's contiguous slice of that stream.

    Args:
        hidden_states: This rank's local shard; dim 0 is the packed-token dim.
        cu_seqlens: Global packed sequence boundaries.
        cp_group: CP process group to gather over.
        cp_rank: This rank's index within ``cp_group``.
        cp_size: Number of CP ranks.
    """
    # Validates that every sequence divides evenly by cp_size and that the
    # local lengths sum to this shard's dim-0 size.
    global_seq_lengths, local_seq_lengths = _get_cp_sequence_lengths(cu_seqlens, cp_size, hidden_states.size(0))
    # gathered_by_rank[r][s] is rank r's local slice of sequence s.
    gathered_by_rank = [
        gathered.split(local_seq_lengths, dim=0) for gathered in _gather_cp_tensors(hidden_states, cp_group)
    ]

    full_sequences = []
    for seq_idx, global_seq_len in enumerate(global_seq_lengths):
        per_rank = [rank_seqs[seq_idx] for rank_seqs in gathered_by_rank]
        if global_seq_len % (2 * cp_size) == 0:
            # Zigzag layout: rank r holds subchunks r and (2*cp_size-1-r), so
            # the first halves go in rank order and the second halves in
            # reversed rank order to rebuild the original token order.
            subchunk_len = global_seq_len // (2 * cp_size)
            full_seq = torch.cat(
                [seq[:subchunk_len] for seq in per_rank] + [seq[subchunk_len:] for seq in per_rank][::-1],
                dim=0,
            )
        else:
            # Final local padding is appended contiguously on each rank, not in zigzag order.
            full_seq = torch.cat(per_rank, dim=0)
        full_sequences.append(full_seq)

    # Empty slice keeps dtype/device/trailing dims when there are no sequences.
    full_stream = torch.cat(full_sequences, dim=0) if full_sequences else hidden_states[:0]
    # Every rank owns an equal-length contiguous slice of the global stream.
    shard_len = hidden_states.size(0)
    return full_stream[cp_rank * shard_len : (cp_rank + 1) * shard_len]


def _packed_shard_to_zigzag_impl(hidden_states, cu_seqlens, cp_group, cp_rank, cp_size):
    """Convert contiguous packed shard layout back to zigzag ring-attn layout.

    Inverse of ``_zigzag_to_packed_shard_impl``: gathers every rank's
    contiguous shard, rebuilds the global packed stream, and re-slices it so
    this rank keeps its zigzag pair of subchunks per sequence.

    Args:
        hidden_states: This rank's contiguous packed shard; dim 0 is the
            packed-token dim.
        cu_seqlens: Global packed sequence boundaries.
        cp_group: CP process group to gather over.
        cp_rank: This rank's index within ``cp_group``.
        cp_size: Number of CP ranks.
    """
    global_seq_lengths, local_seq_lengths = _get_cp_sequence_lengths(cu_seqlens, cp_size, hidden_states.size(0))
    # Contiguous shards concatenated in rank order reproduce the global stream.
    full_stream = torch.cat(_gather_cp_tensors(hidden_states, cp_group), dim=0)
    full_sequences = full_stream.split(global_seq_lengths, dim=0)

    local_sequences = []
    for full_seq, global_seq_len, local_seq_len in zip(
        full_sequences, global_seq_lengths, local_seq_lengths, strict=True
    ):
        if global_seq_len % (2 * cp_size) == 0:
            # Zigzag: this rank keeps subchunk cp_rank plus the mirrored
            # subchunk (2*cp_size-1-cp_rank).
            subchunk_len = global_seq_len // (2 * cp_size)
            parts = full_seq.split(subchunk_len, dim=0)
            local_sequences.append(torch.cat([parts[cp_rank], parts[2 * cp_size - 1 - cp_rank]], dim=0))
        else:
            # Non-zigzag (padding) sequence: ranks hold plain contiguous slices.
            local_sequences.append(full_seq.split(local_seq_len, dim=0)[cp_rank])

    # Empty slice keeps dtype/device/trailing dims when there are no sequences.
    return torch.cat(local_sequences, dim=0) if local_sequences else hidden_states[:0]


class _ZigzagToPackedShard(torch.autograd.Function):
    """Convert zigzag ring-attn layout to contiguous packed shards for native fla CP.

    Backward applies the inverse relayout to the incoming gradient.
    """

    @staticmethod
    def forward(ctx, hidden_states, cu_seqlens, cp_group, cp_rank, cp_size):
        # Tensors go through save_for_backward; plain CP topology is stashed
        # on ctx as a single tuple.
        ctx.cp_topology = (cp_group, cp_rank, cp_size)
        ctx.save_for_backward(cu_seqlens)
        return _zigzag_to_packed_shard_impl(hidden_states, cu_seqlens, cp_group, cp_rank, cp_size)

    @staticmethod
    def backward(ctx, grad_output):
        cp_group, cp_rank, cp_size = ctx.cp_topology
        (cu_seqlens,) = ctx.saved_tensors
        grad_input = _packed_shard_to_zigzag_impl(grad_output, cu_seqlens, cp_group, cp_rank, cp_size)
        # Gradients only flow to hidden_states; the other inputs are non-differentiable.
        return grad_input, None, None, None, None


class _PackedShardToZigzag(torch.autograd.Function):
    """Convert contiguous packed shards back to zigzag ring-attn layout.

    Backward applies the inverse relayout to the incoming gradient.
    """

    @staticmethod
    def forward(ctx, hidden_states, cu_seqlens, cp_group, cp_rank, cp_size):
        # Tensors go through save_for_backward; plain CP topology is stashed
        # on ctx as a single tuple.
        ctx.cp_topology = (cp_group, cp_rank, cp_size)
        ctx.save_for_backward(cu_seqlens)
        return _packed_shard_to_zigzag_impl(hidden_states, cu_seqlens, cp_group, cp_rank, cp_size)

    @staticmethod
    def backward(ctx, grad_output):
        cp_group, cp_rank, cp_size = ctx.cp_topology
        (cu_seqlens,) = ctx.saved_tensors
        grad_input = _zigzag_to_packed_shard_impl(grad_output, cu_seqlens, cp_group, cp_rank, cp_size)
        # Gradients only flow to hidden_states; the other inputs are non-differentiable.
        return grad_input, None, None, None, None


def _zigzag_to_packed_shard(hidden_states, cu_seqlens, cp_group, cp_rank, cp_size):
    """Autograd-aware entry point for the zigzag -> contiguous packed shard relayout."""
    return _ZigzagToPackedShard.apply(hidden_states, cu_seqlens, cp_group, cp_rank, cp_size)


def _packed_shard_to_zigzag(hidden_states, cu_seqlens, cp_group, cp_rank, cp_size):
    """Autograd-aware entry point for the contiguous packed shard -> zigzag relayout."""
    return _PackedShardToZigzag.apply(hidden_states, cu_seqlens, cp_group, cp_rank, cp_size)


class _AllGatherForDuplicatedComputation(torch.autograd.Function):
"""All-gather whose backward just returns the local gradient slice (no reduce).

Expand Down Expand Up @@ -68,6 +178,10 @@ class HuggingfaceAttention(MegatronModule, ABC):
"cross attn" specializations.
"""

# Subclasses set this to True when the underlying module handles CP natively
# (e.g. via fla's state-passing CP for DeltaNet), bypassing the all-gather.
hybrid_cp: bool = False

def __init__(
self,
args,
Expand Down Expand Up @@ -115,7 +229,22 @@ def forward(
group=mpu.get_tensor_model_parallel_group(),
)

if mpu.get_context_parallel_world_size() > 1:
if mpu.get_context_parallel_world_size() > 1 and self.hybrid_cp:
cp_size = mpu.get_context_parallel_world_size()
# Native fla CP expects each rank to own a contiguous shard of the
# packed global token stream. In allgather-CP mode the data pipeline
# already provides that layout, so no extra relayout is
# needed here.
if not self.args.allgather_cp:
hidden_states = _zigzag_to_packed_shard(
hidden_states,
cu_seqlens,
mpu.get_context_parallel_group(),
mpu.get_context_parallel_rank(),
cp_size,
)

elif mpu.get_context_parallel_world_size() > 1:
cp_size = mpu.get_context_parallel_world_size()
# Use custom all-gather whose backward returns local gradient
# instead of reduce-scatter, since the computation is duplicated.
Expand Down Expand Up @@ -150,7 +279,17 @@ def forward(

output = output.permute(1, 0, 2) # [seq_len, bsz, hidden_dim]

if mpu.get_context_parallel_world_size() > 1:
if mpu.get_context_parallel_world_size() > 1 and self.hybrid_cp:
if not self.args.allgather_cp:
output = _packed_shard_to_zigzag(
output,
cu_seqlens,
mpu.get_context_parallel_group(),
mpu.get_context_parallel_rank(),
cp_size,
)

elif mpu.get_context_parallel_world_size() > 1:
cp_rank = mpu.get_context_parallel_rank()
output_list = []
for i in range(len(cu_seqlens) - 1):
Expand Down
44 changes: 31 additions & 13 deletions miles_plugins/models/qwen3_5.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,8 @@
except ImportError:
pass

from miles.backends.training_utils.cp_utils import build_gdn_cp_context

from .hf_attention import HuggingfaceAttention, _load_hf_config


Expand Down Expand Up @@ -88,17 +90,21 @@ def forward(
):
batch_size, seq_len, _ = hidden_states.shape

cp_context = build_gdn_cp_context(self, cu_seqlens, hidden_states.device)

# Projections (flat layout: [Q_all, K_all, V_all])
mixed_qkv = self.in_proj_qkv(hidden_states)
z = self.in_proj_z(hidden_states)
z = z.reshape(batch_size, seq_len, -1, self.head_v_dim)
b = self.in_proj_b(hidden_states)
a = self.in_proj_a(hidden_states)

# Convolution on the flat QKV
# Convolution on the flat QKV (pass cp_context for boundary handling)
conv_cu_seqlens = cp_context.cu_seqlens if cp_context is not None else cu_seqlens
mixed_qkv, _ = self.conv1d(
x=mixed_qkv,
cu_seqlens=cu_seqlens,
cu_seqlens=conv_cu_seqlens,
cp_context=cp_context,
)

# Split into Q, K, V (flat split, matching HF layout)
Expand All @@ -118,17 +124,29 @@ def forward(
query = query.repeat_interleave(self.num_v_heads // self.num_k_heads, dim=2)
key = key.repeat_interleave(self.num_v_heads // self.num_k_heads, dim=2)

core_attn_out, last_recurrent_state = chunk_gated_delta_rule(
query,
key,
value,
g=g,
beta=beta,
initial_state=None,
output_final_state=False,
use_qk_l2norm_in_kernel=True,
cu_seqlens=cu_seqlens,
)
if cp_context is not None:
core_attn_out, _ = chunk_gated_delta_rule(
query,
key,
value,
g=g,
beta=beta,
use_qk_l2norm_in_kernel=True,
cu_seqlens=cp_context.cu_seqlens,
cp_context=cp_context,
)
else:
core_attn_out, _ = chunk_gated_delta_rule(
query,
key,
value,
g=g,
beta=beta,
initial_state=None,
output_final_state=False,
use_qk_l2norm_in_kernel=True,
cu_seqlens=cu_seqlens,
)

z_shape_og = z.shape
# reshape input data into 2D tensor
Expand Down
Loading
Loading