-
Notifications
You must be signed in to change notification settings - Fork 150
feat: real cp support with relayout fix for qwen3.5 train/rollout mismatch #885
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 7 commits
304c748
e42a5d2
5d74611
b99ac8f
302ccfc
b6b4a1f
18699a3
7af979b
c5f194d
89d6dcd
ea7fafa
463a6ca
dca77e9
e304ef7
4471d46
2842d57
7314c5b
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -2088,6 +2088,7 @@ def equal(x, y): | |
| if "rope_theta" in hf_config.rope_parameters: | ||
| hf_config.rope_theta = hf_config.rope_parameters["rope_theta"] | ||
|
|
||
| is_moe = hasattr(hf_config, "num_experts") or hasattr(hf_config, "moe_intermediate_size") | ||
|
||
| for hf_config_name, megatron_config_name, compare_fn in [ | ||
| ("hidden_size", "hidden_size", equal), | ||
| ("num_attention_heads", "num_attention_heads", equal), | ||
|
|
@@ -2102,6 +2103,8 @@ def equal(x, y): | |
| ("rope_theta", "rotary_base", equal), | ||
| ]: | ||
| if hasattr(hf_config, hf_config_name): | ||
| if is_moe and hf_config_name == "intermediate_size": | ||
|
||
| continue | ||
| if not compare_fn(getattr(hf_config, hf_config_name), getattr(args, megatron_config_name)): | ||
| errors.append( | ||
| f"{hf_config_name} in hf config {getattr(hf_config, hf_config_name)} is not equal to " | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -15,6 +15,12 @@ | |
| except ImportError: | ||
| pass | ||
|
|
||
| try: | ||
| from fla.ops.cp import FLACPContext, build_cp_context | ||
| except ImportError: | ||
| FLACPContext = None | ||
| build_cp_context = None | ||
|
|
||
| from .hf_attention import HuggingfaceAttention, _load_hf_config | ||
|
|
||
|
|
||
|
|
@@ -81,24 +87,41 @@ def __init__(self, config, layer_idx: int): | |
|
|
||
| self.out_proj = nn.Linear(self.value_dim, self.hidden_size, bias=False) | ||
|
|
||
| def _build_cp_context(self, local_seq_len: int, device: torch.device): | ||
| """Build fla CP context from the local (sharded) sequence length.""" | ||
| cp_group = getattr(self, "cp_group", None) | ||
| if cp_group is None or build_cp_context is None: | ||
| return None | ||
| global_seq_len = local_seq_len * self.cp_world_size | ||
| global_cu_seqlens = torch.tensor([0, global_seq_len], dtype=torch.int32, device=device) | ||
| return build_cp_context( | ||
| cu_seqlens=global_cu_seqlens, | ||
| group=cp_group, | ||
| conv1d_kernel_size=self.conv_kernel_size, | ||
| ) | ||
|
||
|
|
||
| def forward( | ||
| self, | ||
| hidden_states: torch.Tensor, | ||
| cu_seqlens: torch.Tensor = None, | ||
| ): | ||
| batch_size, seq_len, _ = hidden_states.shape | ||
|
|
||
| cp_context = self._build_cp_context(seq_len, hidden_states.device) | ||
|
||
|
|
||
| # Projections (flat layout: [Q_all, K_all, V_all]) | ||
| mixed_qkv = self.in_proj_qkv(hidden_states) | ||
| z = self.in_proj_z(hidden_states) | ||
| z = z.reshape(batch_size, seq_len, -1, self.head_v_dim) | ||
| b = self.in_proj_b(hidden_states) | ||
| a = self.in_proj_a(hidden_states) | ||
|
|
||
| # Convolution on the flat QKV | ||
| # Convolution on the flat QKV (pass cp_context for boundary handling) | ||
| conv_cu_seqlens = cp_context.cu_seqlens if cp_context is not None else cu_seqlens | ||
| mixed_qkv, _ = self.conv1d( | ||
| x=mixed_qkv, | ||
| cu_seqlens=cu_seqlens, | ||
| cu_seqlens=conv_cu_seqlens, | ||
| cp_context=cp_context, | ||
| ) | ||
|
|
||
| # Split into Q, K, V (flat split, matching HF layout) | ||
|
|
@@ -118,17 +141,29 @@ def forward( | |
| query = query.repeat_interleave(self.num_v_heads // self.num_k_heads, dim=2) | ||
| key = key.repeat_interleave(self.num_v_heads // self.num_k_heads, dim=2) | ||
|
|
||
| core_attn_out, last_recurrent_state = chunk_gated_delta_rule( | ||
| query, | ||
| key, | ||
| value, | ||
| g=g, | ||
| beta=beta, | ||
| initial_state=None, | ||
| output_final_state=False, | ||
| use_qk_l2norm_in_kernel=True, | ||
| cu_seqlens=cu_seqlens, | ||
| ) | ||
| if cp_context is not None: | ||
| core_attn_out, _ = chunk_gated_delta_rule( | ||
| query, | ||
| key, | ||
| value, | ||
| g=g, | ||
| beta=beta, | ||
| use_qk_l2norm_in_kernel=True, | ||
| cu_seqlens=cp_context.cu_seqlens, | ||
| cp_context=cp_context, | ||
| ) | ||
| else: | ||
| core_attn_out, _ = chunk_gated_delta_rule( | ||
| query, | ||
| key, | ||
| value, | ||
| g=g, | ||
| beta=beta, | ||
| initial_state=None, | ||
| output_final_state=False, | ||
| use_qk_l2norm_in_kernel=True, | ||
| cu_seqlens=cu_seqlens, | ||
| ) | ||
|
|
||
| z_shape_og = z.shape | ||
| # reshape input data into 2D tensor | ||
|
|
||
| Original file line number | Diff line number | Diff line change | ||||
|---|---|---|---|---|---|---|
|
|
@@ -18,6 +18,12 @@ | |||||
| except ImportError: | ||||||
| pass | ||||||
|
|
||||||
| try: | ||||||
| from fla.ops.cp import FLACPContext, build_cp_context | ||||||
| except ImportError: | ||||||
| FLACPContext = None | ||||||
| build_cp_context = None | ||||||
|
|
||||||
| from .hf_attention import HuggingfaceAttention | ||||||
|
|
||||||
|
|
||||||
|
|
@@ -74,6 +80,19 @@ def __init__(self, config, layer_idx: int): | |||||
|
|
||||||
| self.out_proj = nn.Linear(self.value_dim, self.hidden_size, bias=False) | ||||||
|
|
||||||
| def _build_cp_context(self, local_seq_len: int, device: torch.device): | ||||||
| """Build fla CP context from the local (sharded) sequence length.""" | ||||||
| cp_group = getattr(self, "cp_group", None) | ||||||
| if cp_group is None or build_cp_context is None: | ||||||
| return None | ||||||
| global_seq_len = local_seq_len * self.cp_world_size | ||||||
| global_cu_seqlens = torch.tensor([0, global_seq_len], dtype=torch.int32, device=device) | ||||||
| return build_cp_context( | ||||||
| cu_seqlens=global_cu_seqlens, | ||||||
| group=cp_group, | ||||||
| conv1d_kernel_size=self.conv_kernel_size, | ||||||
| ) | ||||||
|
||||||
|
|
||||||
| def fix_query_key_value_ordering(self, mixed_qkvz, mixed_ba): | ||||||
| """ | ||||||
| Derives `query`, `key` and `value` tensors from `mixed_qkvz` and `mixed_ba`. | ||||||
|
|
@@ -108,16 +127,20 @@ def forward( | |||||
| hidden_states: torch.Tensor, | ||||||
| cu_seqlens: torch.Tensor = None, | ||||||
| ): | ||||||
| cp_context = self._build_cp_context(hidden_states.shape[1], hidden_states.device) | ||||||
|
||||||
Suggested change:
- cp_context = self._build_cp_context(hidden_states.shape[1], hidden_states.device)
+ cp_context = self._build_cp_context(hidden_states.shape[0], hidden_states.shape[1], hidden_states.device)
Uh oh!
There was an error while loading. Please reload this page.