@@ -1,13 +1,11 @@
 # -*- coding: utf-8 -*-
 # Copyright (c) 2023-2025, Songlin Yang, Yu Zhang

-import warnings
 from typing import Optional, Tuple

 import torch
 import triton
 import triton.language as tl
-from einops import rearrange

 from fla.ops.common.chunk_h import chunk_bwd_dh, chunk_fwd_h
 from fla.ops.utils import prepare_chunk_indices
@@ -1225,18 +1223,17 @@ def chunk_gla(
     initial_state: torch.Tensor = None,
     output_final_state: bool = False,
     cu_seqlens: Optional[torch.LongTensor] = None,
-    head_first: bool = False
 ) -> Tuple[torch.Tensor, torch.Tensor]:
     r"""
     Args:
         q (torch.Tensor):
-            queries of shape `[B, T, H, K]` if `head_first=False` else `[B, H, T, K]`.
+            queries of shape `[B, T, H, K]`.
         k (torch.Tensor):
-            keys of shape `[B, T, H, K]` if `head_first=False` else `[B, H, T, K]`.
+            keys of shape `[B, T, H, K]`.
         v (torch.Tensor):
-            values of shape `[B, T, H, V]` if `head_first=False` else `[B, H, T, V]`.
+            values of shape `[B, T, H, V]`.
         g (torch.Tensor):
-            Forget gates of shape `[B, T, H, K]` if `head_first=False` else `[B, H, T, K]` applied to keys.
+            Forget gates of shape `[B, T, H, K]`.
         scale (Optional[int]):
             Scale factor for the attention scores.
             If not provided, it will default to `1 / sqrt(K)`. Default: `None`.
@@ -1249,13 +1246,10 @@ def chunk_gla(
         cu_seqlens (torch.LongTensor):
             Cumulative sequence lengths of shape `[N+1]` used for variable-length training,
             consistent with the FlashAttention API.
-        head_first (Optional[bool]):
-            Whether the inputs are in the head-first format, which is not supported for variable-length inputs.
-            Default: `False`.

     Returns:
         o (torch.Tensor):
-            Outputs of shape `[B, T, H, V]` if `head_first=False` else `[B, H, T, V]`.
+            Outputs of shape `[B, T, H, V]`.
         final_state (torch.Tensor):
             Final state of shape `[N, H, K, V]` if `output_final_state=True` else `None`.

@@ -1289,19 +1283,6 @@ def chunk_gla(
         >>> assert o.allclose(o_var.view(o.shape))
         >>> assert ht.allclose(ht_var)
     """
-    if head_first:
-        raise DeprecationWarning(
-            "head_first is deprecated and will be removed in a future version. "
-            "Please use head_first=False for now instead."
-        )
-        q, k, v, g = map(lambda x: rearrange(x, 'b h t ... -> b t h ...'), (q, k, v, g))
-    if not head_first and q.shape[1] < q.shape[2]:
-        warnings.warn(
-            f"Input tensor shape suggests potential format mismatch: seq_len ({q.shape[1]}) < num_heads ({q.shape[2]}). "
-            "This may indicate the inputs were passed in head-first format [B, H, T, ...] "
-            "when head_first=False was specified. "
-            "Please verify your input tensor format matches the expected shape [B, T, H, ...]."
-        )
     if cu_seqlens is not None:
         if q.shape[0] != 1:
             raise ValueError(
@@ -1316,6 +1297,4 @@ def chunk_gla(
     if scale is None:
         scale = q.shape[-1] ** -0.5
     o, final_state = ChunkGLAFunction.apply(q, k, v, g, scale, initial_state, output_final_state, cu_seqlens)
-    if head_first:
-        o = rearrange(o, 'b t h ... -> b h t ...')
     return o, final_state
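
For callers that still hold their tensors in the old head-first layout `[B, H, T, ...]`, a minimal migration sketch follows. It assumes `chunk_gla` is importable from `fla.ops.gla` and that a CUDA device is available; the caller now performs the transposition that the removed `head_first` branch used to apply internally.

import torch
import torch.nn.functional as F
from einops import rearrange

from fla.ops.gla import chunk_gla  # assumed import path for the updated kernel

B, H, T, K, V = 2, 4, 1024, 64, 128
# Tensors stored in the old head-first layout [B, H, T, ...].
q_hf = torch.randn(B, H, T, K, device='cuda', dtype=torch.bfloat16)
k_hf = torch.randn(B, H, T, K, device='cuda', dtype=torch.bfloat16)
v_hf = torch.randn(B, H, T, V, device='cuda', dtype=torch.bfloat16)
# Log-space forget gates (values <= 0), also head-first here.
g_hf = F.logsigmoid(torch.randn(B, H, T, K, device='cuda', dtype=torch.bfloat16))

# Permute to the [B, T, H, ...] layout chunk_gla now expects,
# mirroring the rearrange the removed head_first branch performed.
q, k, v, g = map(lambda x: rearrange(x, 'b h t ... -> b t h ...'), (q_hf, k_hf, v_hf, g_hf))

o, final_state = chunk_gla(q, k, v, g, output_final_state=True)

# Transpose the output back only if downstream code still assumes head-first.
o_hf = rearrange(o, 'b t h ... -> b h t ...')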