Commit 4a7a9f4

[GLA] Clean head_first option

1 parent d7d1d53

3 files changed: 11 additions & 63 deletions

fla/ops/gla/chunk.py

Lines changed: 5 additions & 26 deletions
@@ -1,13 +1,11 @@
 # -*- coding: utf-8 -*-
 # Copyright (c) 2023-2025, Songlin Yang, Yu Zhang
 
-import warnings
 from typing import Optional, Tuple
 
 import torch
 import triton
 import triton.language as tl
-from einops import rearrange
 
 from fla.ops.common.chunk_h import chunk_bwd_dh, chunk_fwd_h
 from fla.ops.utils import prepare_chunk_indices
@@ -1225,18 +1223,17 @@ def chunk_gla(
     initial_state: torch.Tensor = None,
     output_final_state: bool = False,
     cu_seqlens: Optional[torch.LongTensor] = None,
-    head_first: bool = False
 ) -> Tuple[torch.Tensor, torch.Tensor]:
     r"""
     Args:
         q (torch.Tensor):
-            queries of shape `[B, T, H, K]` if `head_first=False` else `[B, H, T, K]`.
+            queries of shape `[B, T, H, K]`.
         k (torch.Tensor):
-            keys of shape `[B, T, H, K]` if `head_first=False` else `[B, H, T, K]`.
+            keys of shape `[B, T, H, K]`.
         v (torch.Tensor):
-            values of shape `[B, T, H, V]` if `head_first=False` else `[B, H, T, V]`.
+            values of shape `[B, T, H, V]`.
         g (torch.Tensor):
-            Forget gates of shape `[B, T, H, K]` if `head_first=False` else `[B, H, T, K]` applied to keys.
+            Forget gates of shape `[B, T, H, K]`.
         scale (Optional[int]):
             Scale factor for the attention scores.
             If not provided, it will default to `1 / sqrt(K)`. Default: `None`.
@@ -1249,13 +1246,10 @@ def chunk_gla(
         cu_seqlens (torch.LongTensor):
             Cumulative sequence lengths of shape `[N+1]` used for variable-length training,
             consistent with the FlashAttention API.
-        head_first (Optional[bool]):
-            Whether the inputs are in the head-first format, which is not supported for variable-length inputs.
-            Default: `False`.
 
     Returns:
         o (torch.Tensor):
-            Outputs of shape `[B, T, H, V]` if `head_first=False` else `[B, H, T, V]`.
+            Outputs of shape `[B, T, H, V]`.
         final_state (torch.Tensor):
             Final state of shape `[N, H, K, V]` if `output_final_state=True` else `None`.
 
@@ -1289,19 +1283,6 @@ def chunk_gla(
         >>> assert o.allclose(o_var.view(o.shape))
         >>> assert ht.allclose(ht_var)
     """
-    if head_first:
-        raise DeprecationWarning(
-            "head_first is deprecated and will be removed in a future version. "
-            "Please use head_first=False for now instead."
-        )
-        q, k, v, g = map(lambda x: rearrange(x, 'b h t ... -> b t h ...'), (q, k, v, g))
-    if not head_first and q.shape[1] < q.shape[2]:
-        warnings.warn(
-            f"Input tensor shape suggests potential format mismatch: seq_len ({q.shape[1]}) < num_heads ({q.shape[2]}). "
-            "This may indicate the inputs were passed in head-first format [B, H, T, ...] "
-            "when head_first=False was specified. "
-            "Please verify your input tensor format matches the expected shape [B, T, H, ...]."
-        )
     if cu_seqlens is not None:
         if q.shape[0] != 1:
             raise ValueError(
@@ -1316,6 +1297,4 @@ def chunk_gla(
     if scale is None:
         scale = q.shape[-1] ** -0.5
     o, final_state = ChunkGLAFunction.apply(q, k, v, g, scale, initial_state, output_final_state, cu_seqlens)
-    if head_first:
-        o = rearrange(o, 'b t h ... -> b h t ...')
     return o, final_state
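With the deprecation shim gone from `chunk_gla`, callers that still hold tensors in the legacy head-first layout `[B, H, T, ...]` must rearrange them before the call, which is exactly what the removed shim did internally. A minimal migration sketch (tensor sizes, the CUDA device, and the `logsigmoid` gate initialization are illustrative assumptions, not requirements of the API):

import torch
import torch.nn.functional as F
from einops import rearrange

from fla.ops.gla.chunk import chunk_gla

B, H, T, K, V = 2, 4, 128, 64, 64
# Legacy head-first tensors [B, H, T, ...]
q = torch.randn(B, H, T, K, device='cuda')
k = torch.randn(B, H, T, K, device='cuda')
v = torch.randn(B, H, T, V, device='cuda')
g = F.logsigmoid(torch.randn(B, H, T, K, device='cuda'))

# Do once at the call site what the deleted shim used to do:
# move to the only supported time-first layout [B, T, H, ...].
q, k, v, g = (rearrange(x, 'b h t ... -> b t h ...') for x in (q, k, v, g))

o, final_state = chunk_gla(q, k, v, g, output_final_state=True)
assert o.shape == (B, T, H, V)
# Final state is [N, H, K, V] per the docstring; N == B without cu_seqlens.
assert final_state.shape == (B, H, K, V)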

fla/ops/gla/fused_chunk.py

Lines changed: 0 additions & 5 deletions
@@ -615,16 +615,11 @@ def fused_chunk_gla(
     scale: int = -1,
     initial_state: torch.Tensor = None,
     output_final_state: bool = False,
-    head_first: bool = False
 ) -> Tuple[torch.Tensor, torch.Tensor]:
     if scale == -1:
         scale = q.shape[-1] ** -0.5
-    if not head_first:
-        q, k, v, g = map(lambda x: x.transpose(1, 2), (q, k, v, g))
     seq_len = q.shape[-2]
     q, k, v, g = map(lambda x: pad(x), [q, k, v, g])
     o, final_state = FusedChunkGLAFunction.apply(q, k, v, g, scale, initial_state, output_final_state)
     o = o[..., :seq_len, :].contiguous()
-    if not head_first:
-        o = o.transpose(1, 2)
     return o, final_state
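`fused_chunk_gla` keeps the sentinel default `scale: int = -1`, which the retained body resolves to `q.shape[-1] ** -0.5`, i.e. the standard 1/sqrt(K) attention scaling. A tiny sketch of that resolution (the `resolve_scale` helper is hypothetical, for illustration only; it is not part of the library):

import math

def resolve_scale(scale: float, head_dim: int) -> float:
    # -1 is the sentinel meaning "use the default 1/sqrt(K)".
    return head_dim ** -0.5 if scale == -1 else scale

assert math.isclose(resolve_scale(-1, 64), 1 / math.sqrt(64))  # default path
assert resolve_scale(0.25, 64) == 0.25                         # explicit override

Passing any value other than -1 therefore bypasses the 1/sqrt(K) normalization entirely.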

fla/ops/gla/fused_recurrent.py

Lines changed: 6 additions & 32 deletions
@@ -1,11 +1,9 @@
 # -*- coding: utf-8 -*-
 # Copyright (c) 2024, Songlin Yang, Yu Zhang
 
-import warnings
 from typing import Optional, Tuple
 
 import torch
-from einops import rearrange
 
 from fla.ops.common.fused_recurrent import fused_recurrent
 
@@ -21,20 +19,19 @@ def fused_recurrent_gla(
     output_final_state: bool = False,
     reverse: bool = False,
     cu_seqlens: Optional[torch.LongTensor] = None,
-    head_first: bool = False
 ) -> Tuple[torch.Tensor, torch.Tensor]:
     r"""
     Args:
         q (torch.Tensor):
-            queries of shape `[B, T, H, K]` if `head_first=False` else `[B, H, T, K]`.
+            queries of shape `[B, T, H, K]`.
         k (torch.Tensor):
-            keys of shape `[B, T, H, K]` if `head_first=False` else `[B, H, T, K]`.
+            keys of shape `[B, T, H, K]`.
         v (torch.Tensor):
-            values of shape `[B, T, H, V]` if `head_first=False` else `[B, H, T, V]`.
+            values of shape `[B, T, H, V]`.
         gk (torch.Tensor):
-            Forget gates of shape `[B, T, H, K]` if `head_first=False` else `[B, H, T, K]` applied to keys.
+            Forget gates of shape `[B, T, H, K]`.
         gv (torch.Tensor):
-            Forget gates of shape `[B, T, H, V]` if `head_first=False` else `[B, H, T, V]` applied to values.
+            Forget gates of shape `[B, T, H, V]` applied to values.
         scale (Optional[int]):
             Scale factor for the attention scores.
             If not provided, it will default to `1 / sqrt(K)`. Default: `None`.
@@ -49,13 +46,10 @@ def fused_recurrent_gla(
         cu_seqlens (torch.LongTensor):
             Cumulative sequence lengths of shape `[N+1]` used for variable-length training,
             consistent with the FlashAttention API.
-        head_first (Optional[bool]):
-            Whether the inputs are in the head-first format, which is not supported for variable-length inputs.
-            Default: `False`.
 
     Returns:
         o (torch.Tensor):
-            Outputs of shape `[B, T, H, V]` if `head_first=False` else `[B, H, T, V]`.
+            Outputs of shape `[B, T, H, V]`.
         final_state (torch.Tensor):
             Final state of shape `[N, H, K, V]` if `output_final_state=True` else `None`.
 
@@ -87,25 +81,7 @@ def fused_recurrent_gla(
         ...     cu_seqlens=cu_seqlens
         ... )
         >>> assert o.allclose(o_var.view(o.shape))
-        >>> assert ht.allclose(ht_var)
     """
-    if head_first:
-        raise DeprecationWarning(
-            "head_first is deprecated and will be removed in a future version. "
-            "Please use head_first=False for now instead."
-        )
-        q, k, v = map(lambda x: rearrange(x, 'b h t ... -> b t h ...'), (q, k, v))
-        if gk is not None:
-            gk = rearrange(gk, 'b h t ... -> b t h ...')
-        if gv is not None:
-            gv = rearrange(gv, 'b h t ... -> b t h ...')
-    if not head_first and q.shape[1] < q.shape[2]:
-        warnings.warn(
-            f"Input tensor shape suggests potential format mismatch: seq_len ({q.shape[1]}) < num_heads ({q.shape[2]}). "
-            "This may indicate the inputs were passed in head-first format [B, H, T, ...] "
-            "when head_first=False was specified. "
-            "Please verify your input tensor format matches the expected shape [B, T, H, ...]."
-        )
     if cu_seqlens is not None:
         if q.shape[0] != 1:
             raise ValueError(
@@ -132,6 +108,4 @@ def fused_recurrent_gla(
         reverse=reverse,
         cu_seqlens=cu_seqlens,
     )
-    if head_first:
-        o = rearrange(o, 'b t h ... -> b h t ...')
     return o, final_state
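The `cu_seqlens` path retained in `fused_recurrent_gla` (guarded by the surviving `q.shape[0] != 1` check) expects variable-length sequences packed into a single batch row, consistent with the FlashAttention API mentioned in the docstring. A hedged sketch of that packing (the sequence lengths, head counts, and CUDA device are illustrative assumptions):

import torch
import torch.nn.functional as F

from fla.ops.gla.fused_recurrent import fused_recurrent_gla

H, K, V = 4, 64, 64
lengths = [100, 156]           # two sequences of different lengths
T = sum(lengths)               # total packed length

# Packed layout: batch dim must be 1, sequences concatenated along T.
q = torch.randn(1, T, H, K, device='cuda')
k = torch.randn(1, T, H, K, device='cuda')
v = torch.randn(1, T, H, V, device='cuda')
gk = F.logsigmoid(torch.randn(1, T, H, K, device='cuda'))

# Cumulative sequence lengths of shape [N+1]: here [0, 100, 256].
cu_seqlens = torch.tensor([0, 100, 256], dtype=torch.long, device='cuda')

o, ht = fused_recurrent_gla(
    q, k, v, gk,
    output_final_state=True,
    cu_seqlens=cu_seqlens,
)
assert o.shape == (1, T, H, V)   # outputs stay packed
assert ht.shape == (2, H, K, V)  # one final state per sequence (N == 2)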
