Skip to content

Commit 43f136e

Browse files
Committed: "Raise DeprecationWarning if head_first passed"
1 parent: 3a7ecbf · commit: 43f136e

26 files changed: +53 lines added, −78 lines removed

fla/ops/attn/parallel.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,6 @@
11
# -*- coding: utf-8 -*-
22
# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang
33

4-
import warnings
54
from typing import Optional
65

76
import torch
@@ -713,15 +712,15 @@ def parallel_attn(
713712
Outputs of shape `[B, T, HQ, V]` if `head_first=False` else `[B, HQ, T, V]`.
714713
"""
715714
if head_first:
716-
warnings.warn(
715+
raise DeprecationWarning(
717716
"head_first is deprecated and will be removed in a future version. "
718717
"Please use head_first=False for now instead."
719718
)
720719
q, k, v = map(lambda x: rearrange(x, 'b h t ... -> b t h ...'), (q, k, v))
721720
if g is not None:
722721
g = rearrange(g, 'b h t ... -> b t h ...')
723722
if not head_first and q.shape[1] < q.shape[2]:
724-
warnings.warn(
723+
raise DeprecationWarning(
725724
f"Input tensor shape suggests potential format mismatch: seq_len ({q.shape[1]}) < num_heads ({q.shape[2]}). "
726725
"This may indicate the inputs were passed in head-first format [B, H, T, ...] "
727726
"when head_first=False was specified. "

fla/ops/delta_rule/chunk.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,6 @@
11
# -*- coding: utf-8 -*-
22
# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang
33

4-
import warnings
54
from typing import Optional
65

76
import torch
@@ -280,13 +279,13 @@ def chunk_delta_rule(
280279
assert len(beta.shape) == 3, "beta must be of shape (batch size, num of head, seq len)."
281280

282281
if head_first:
283-
warnings.warn(
282+
raise DeprecationWarning(
284283
"head_first is deprecated and will be removed in a future version. "
285284
"Please use head_first=False for now instead."
286285
)
287286
q, k, v, beta = map(lambda x: rearrange(x, 'b h t ... -> b t h ...'), (q, k, v, beta))
288287
if not head_first and q.shape[1] < q.shape[2]:
289-
warnings.warn(
288+
raise DeprecationWarning(
290289
f"Input tensor shape suggests potential format mismatch: seq_len ({q.shape[1]}) < num_heads ({q.shape[2]}). "
291290
"This may indicate the inputs were passed in head-first format [B, H, T, ...] "
292291
"when head_first=False was specified. "

fla/ops/delta_rule/fused_recurrent.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,6 @@
11
# -*- coding: utf-8 -*-
22
# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang
33

4-
import warnings
54
from typing import Optional, Tuple
65

76
import torch
@@ -514,13 +513,13 @@ def fused_recurrent_delta_rule(
514513
>>> assert ht.allclose(ht_var)
515514
"""
516515
if head_first:
517-
warnings.warn(
516+
raise DeprecationWarning(
518517
"head_first is deprecated and will be removed in a future version. "
519518
"Please use head_first=False for now instead."
520519
)
521520
q, k, v, beta = map(lambda x: rearrange(x, 'b h t ... -> b t h ...'), (q, k, v, beta))
522521
if not head_first and q.shape[1] < q.shape[2]:
523-
warnings.warn(
522+
raise DeprecationWarning(
524523
f"Input tensor shape suggests potential format mismatch: seq_len ({q.shape[1]}) < num_heads ({q.shape[2]}). "
525524
"This may indicate the inputs were passed in head-first format [B, H, T, ...] "
526525
"when head_first=False was specified. "

fla/ops/forgetting_attn/parallel.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,6 @@
11
# -*- coding: utf-8 -*-
22
# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang
33

4-
import warnings
54
from typing import Optional
65

76
import torch
@@ -49,13 +48,13 @@ def parallel_forgetting_attn(
4948
if cu_seqlens is not None:
5049
assert q.shape[0] == 1, "batch size must be 1 when cu_seqlens are provided"
5150
if head_first:
52-
warnings.warn(
51+
raise DeprecationWarning(
5352
"head_first is deprecated and will be removed in a future version. "
5453
"Please use head_first=False for now instead."
5554
)
5655
q, k, v, g = map(lambda x: rearrange(x, 'b h t ... -> b t h ...'), (q, k, v, g))
5756
if not head_first and q.shape[1] < q.shape[2]:
58-
warnings.warn(
57+
raise DeprecationWarning(
5958
f"Input tensor shape suggests potential format mismatch: seq_len ({q.shape[1]}) < num_heads ({q.shape[2]}). "
6059
"This may indicate the inputs were passed in head-first format [B, H, T, ...] "
6160
"when head_first=False was specified. "

fla/ops/gated_delta_rule/chunk.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,6 @@
11
# -*- coding: utf-8 -*-
22
# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang
33

4-
import warnings
54
from typing import Optional
65

76
import torch
@@ -313,13 +312,13 @@ def chunk_gated_delta_rule(
313312
assert len(beta.shape) == 3, "beta must be of shape [B, T, H] if head_first=False, or [B, H, T] otherwise."
314313

315314
if head_first:
316-
warnings.warn(
315+
raise DeprecationWarning(
317316
"head_first is deprecated and will be removed in a future version. "
318317
"Please use head_first=False for now instead."
319318
)
320319
q, k, v, beta, g = map(lambda x: rearrange(x, 'b h t ... -> b t h ...'), (q, k, v, beta, g))
321320
if not head_first and q.shape[1] < q.shape[2]:
322-
warnings.warn(
321+
raise DeprecationWarning(
323322
f"Input tensor shape suggests potential format mismatch: seq_len ({q.shape[1]}) < num_heads ({q.shape[2]}). "
324323
"This may indicate the inputs were passed in head-first format [B, H, T, ...] "
325324
"when head_first=False was specified. "

fla/ops/gated_delta_rule/fused_recurrent.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,6 @@
11
# -*- coding: utf-8 -*-
22
# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang
33

4-
import warnings
54
from typing import Optional, Tuple
65

76
import torch
@@ -283,13 +282,13 @@ def fused_recurrent_gated_delta_rule(
283282
>>> assert ht.allclose(ht_var)
284283
"""
285284
if head_first:
286-
warnings.warn(
285+
raise DeprecationWarning(
287286
"head_first is deprecated and will be removed in a future version. "
288287
"Please use head_first=False for now instead."
289288
)
290289
q, k, v, beta, g = map(lambda x: rearrange(x, 'b h t ... -> b t h ...'), (q, k, v, beta, g))
291290
if not head_first and q.shape[1] < q.shape[2]:
292-
warnings.warn(
291+
raise DeprecationWarning(
293292
f"Input tensor shape suggests potential format mismatch: seq_len ({q.shape[1]}) < num_heads ({q.shape[2]}). "
294293
"This may indicate the inputs were passed in head-first format [B, H, T, ...] "
295294
"when head_first=False was specified. "

fla/ops/generalized_delta_rule/dplr/chunk.py

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,6 @@
11
# -*- coding: utf-8 -*-
22
# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang
33

4-
import warnings
54
from typing import Optional
65

76
import torch
@@ -318,20 +317,20 @@ def chunk_dplr_delta_rule(
318317
Final state of shape `[N, H, K, V]` if `output_final_state=True` else `None`.
319318
"""
320319
if head_first:
321-
warnings.warn(
320+
raise DeprecationWarning(
322321
"head_first is deprecated and will be removed in a future version. "
323322
"Please use head_first=False for now instead."
324323
)
325324
q, k, v, a, b, gk = map(lambda x: rearrange(x, 'b h t ... -> b t h ...'), (q, k, v, a, b, gk))
326325
if not head_first and q.shape[1] < q.shape[2]:
327-
warnings.warn(
326+
raise DeprecationWarning(
328327
f"Input tensor shape suggests potential format mismatch: seq_len ({q.shape[1]}) < num_heads ({q.shape[2]}). "
329328
"This may indicate the inputs were passed in head-first format [B, H, T, ...] "
330329
"when head_first=False was specified. "
331330
"Please verify your input tensor format matches the expected shape [B, T, H, ...]."
332331
)
333332
if q.dtype == torch.float32:
334-
warnings.warn(
333+
raise DeprecationWarning(
335334
"""ChunkDeltaRuleFunction does not support float32. Please use bfloat16.
336335
If you want to use float32, please solve the issue by yourself."""
337336
)

fla/ops/generalized_delta_rule/dplr/fused_recurrent.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,6 @@
11
# -*- coding: utf-8 -*-
22
# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang
33

4-
import warnings
54
from typing import Optional, Tuple
65

76
import torch
@@ -249,13 +248,13 @@ def fused_recurrent_dplr_delta_rule(
249248
Default: `False`.
250249
"""
251250
if head_first:
252-
warnings.warn(
251+
raise DeprecationWarning(
253252
"head_first is deprecated and will be removed in a future version. "
254253
"Please use head_first=False for now instead."
255254
)
256255
q, k, v, a, b, gk = map(lambda x: rearrange(x, 'b h t ... -> b t h ...'), (q, k, v, a, b, gk))
257256
if not head_first and q.shape[1] < q.shape[2]:
258-
warnings.warn(
257+
raise DeprecationWarning(
259258
f"Input tensor shape suggests potential format mismatch: seq_len ({q.shape[1]}) < num_heads ({q.shape[2]}). "
260259
"This may indicate the inputs were passed in head-first format [B, H, T, ...] "
261260
"when head_first=False was specified. "

fla/ops/generalized_delta_rule/iplr/chunk.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,6 @@
11
# -*- coding: utf-8 -*-
22
# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang
33

4-
import warnings
54
from typing import Optional, Tuple
65

76
import torch
@@ -462,13 +461,13 @@ def chunk_iplr_delta_rule(
462461
assert q.dtype != torch.float32, "ChunkDeltaRuleFunction does not support float32. Please use bfloat16."
463462

464463
if head_first:
465-
warnings.warn(
464+
raise DeprecationWarning(
466465
"head_first is deprecated and will be removed in a future version. "
467466
"Please use head_first=False for now instead."
468467
)
469468
q, k, v, a, b = map(lambda x: rearrange(x, 'b h t ... -> b t h ...'), (q, k, v, a, b))
470469
if not head_first and q.shape[1] < q.shape[2]:
471-
warnings.warn(
470+
raise DeprecationWarning(
472471
f"Input tensor shape suggests potential format mismatch: seq_len ({q.shape[1]}) < num_heads ({q.shape[2]}). "
473472
"This may indicate the inputs were passed in head-first format [B, H, T, ...] "
474473
"when head_first=False was specified. "

fla/ops/generalized_delta_rule/iplr/fused_recurrent.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,6 @@
11
# -*- coding: utf-8 -*-
22
# Copyright (c) 2024-2025, Songlin Yang, Yu Zhang
33

4-
import warnings
54
from typing import Optional, Tuple
65

76
import torch
@@ -427,13 +426,13 @@ def fused_recurrent_iplr_delta_rule(
427426
428427
"""
429428
if head_first:
430-
warnings.warn(
429+
raise DeprecationWarning(
431430
"head_first is deprecated and will be removed in a future version. "
432431
"Please use head_first=False for now instead."
433432
)
434433
q, k, v, a, b = map(lambda x: rearrange(x, 'b h t ... -> b t h ...'), (q, k, v, a, b))
435434
if not head_first and q.shape[1] < q.shape[2]:
436-
warnings.warn(
435+
raise DeprecationWarning(
437436
f"Input tensor shape suggests potential format mismatch: seq_len ({q.shape[1]}) < num_heads ({q.shape[2]}). "
438437
"This may indicate the inputs were passed in head-first format [B, H, T, ...] "
439438
"when head_first=False was specified. "

fla/ops/gla/chunk.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,6 @@
11
# -*- coding: utf-8 -*-
22
# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang
33

4-
import warnings
54
from typing import Optional, Tuple
65

76
import torch
@@ -1290,13 +1289,13 @@ def chunk_gla(
12901289
>>> assert ht.allclose(ht_var)
12911290
"""
12921291
if head_first:
1293-
warnings.warn(
1292+
raise DeprecationWarning(
12941293
"head_first is deprecated and will be removed in a future version. "
12951294
"Please use head_first=False for now instead."
12961295
)
12971296
q, k, v, g = map(lambda x: rearrange(x, 'b h t ... -> b t h ...'), (q, k, v, g))
12981297
if not head_first and q.shape[1] < q.shape[2]:
1299-
warnings.warn(
1298+
raise DeprecationWarning(
13001299
f"Input tensor shape suggests potential format mismatch: seq_len ({q.shape[1]}) < num_heads ({q.shape[2]}). "
13011300
"This may indicate the inputs were passed in head-first format [B, H, T, ...] "
13021301
"when head_first=False was specified. "

fla/ops/gla/fused_recurrent.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,6 @@
11
# -*- coding: utf-8 -*-
22
# Copyright (c) 2024, Songlin Yang, Yu Zhang
33

4-
import warnings
54
from typing import Optional, Tuple
65

76
import torch
@@ -90,7 +89,7 @@ def fused_recurrent_gla(
9089
>>> assert ht.allclose(ht_var)
9190
"""
9291
if head_first:
93-
warnings.warn(
92+
raise DeprecationWarning(
9493
"head_first is deprecated and will be removed in a future version. "
9594
"Please use head_first=False for now instead."
9695
)
@@ -100,7 +99,7 @@ def fused_recurrent_gla(
10099
if gv is not None:
101100
gv = rearrange(gv, 'b h t ... -> b t h ...')
102101
if not head_first and q.shape[1] < q.shape[2]:
103-
warnings.warn(
102+
raise DeprecationWarning(
104103
f"Input tensor shape suggests potential format mismatch: seq_len ({q.shape[1]}) < num_heads ({q.shape[2]}). "
105104
"This may indicate the inputs were passed in head-first format [B, H, T, ...] "
106105
"when head_first=False was specified. "

fla/ops/gsa/chunk.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,6 @@
11
# -*- coding: utf-8 -*-
22
# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang
33

4-
import warnings
54
from typing import Optional, Tuple
65

76
import torch
@@ -1082,13 +1081,13 @@ def chunk_gsa(
10821081
>>> assert hv.allclose(hv_var)
10831082
"""
10841083
if head_first:
1085-
warnings.warn(
1084+
raise DeprecationWarning(
10861085
"head_first is deprecated and will be removed in a future version. "
10871086
"Please use head_first=False for now instead."
10881087
)
10891088
q, k, v, s, g = map(lambda x: rearrange(x, 'b h t ... -> b t h ...'), (q, k, v, s, g))
10901089
if not head_first and q.shape[1] < q.shape[2]:
1091-
warnings.warn(
1090+
raise DeprecationWarning(
10921091
f"Input tensor shape suggests potential format mismatch: seq_len ({q.shape[1]}) < num_heads ({q.shape[2]}). "
10931092
"This may indicate the inputs were passed in head-first format [B, H, T, ...] "
10941093
"when head_first=False was specified. "

fla/ops/gsa/fused_recurrent.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,6 @@
11
# -*- coding: utf-8 -*-
22
# Copyright (c) 2024, Songlin Yang, Yu Zhang
33

4-
import warnings
54
from typing import Optional, Tuple
65

76
import torch
@@ -510,13 +509,13 @@ def fused_recurrent_gsa(
510509
>>> assert hv.allclose(hv_var)
511510
"""
512511
if head_first:
513-
warnings.warn(
512+
raise DeprecationWarning(
514513
"head_first is deprecated and will be removed in a future version. "
515514
"Please use head_first=False for now instead."
516515
)
517516
q, k, v, s, g = map(lambda x: rearrange(x, 'b h t ... -> b t h ...'), (q, k, v, s, g))
518517
if not head_first and q.shape[1] < q.shape[2]:
519-
warnings.warn(
518+
raise DeprecationWarning(
520519
f"Input tensor shape suggests potential format mismatch: seq_len ({q.shape[1]}) < num_heads ({q.shape[2]}). "
521520
"This may indicate the inputs were passed in head-first format [B, H, T, ...] "
522521
"when head_first=False was specified. "

fla/ops/linear_attn/chunk.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,6 @@
11
# -*- coding: utf-8 -*-
22
# Copyright (c) 2023-2025, Yu Zhang, Songlin Yang
33

4-
import warnings
54
from typing import Optional, Tuple
65

76
import torch
@@ -51,14 +50,14 @@ def chunk_linear_attn(
5150
if scale is None:
5251
scale = k.shape[-1] ** -0.5
5352
if head_first:
54-
warnings.warn(
53+
raise DeprecationWarning(
5554
"head_first is deprecated and will be removed in a future version. "
5655
"Please use head_first=False for now instead."
5756
)
5857
q, k, v = map(lambda x: x.transpose(1, 2), (q, k, v))
5958
if not head_first:
6059
if q.shape[1] < q.shape[2]:
61-
warnings.warn(
60+
raise DeprecationWarning(
6261
f"Input tensor shape suggests potential format mismatch: seq_len ({q.shape[1]}) < num_heads ({q.shape[2]}). "
6362
"This may indicate the inputs were passed in head-first format [B, H, T, ...] "
6463
"when head_first=False was specified. "

fla/ops/linear_attn/fused_chunk.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,6 @@
11
# -*- coding: utf-8 -*-
22
# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang
33

4-
import warnings
54
from typing import Optional, Tuple
65

76
import torch
@@ -343,14 +342,14 @@ def fused_chunk_linear_attn(
343342
if scale is None:
344343
scale = q.shape[-1] ** -0.5
345344
if head_first:
346-
warnings.warn(
345+
raise DeprecationWarning(
347346
"head_first is deprecated and will be removed in a future version. "
348347
"Please use head_first=False for now instead."
349348
)
350349
q, k, v = map(lambda x: x.transpose(1, 2), (q, k, v))
351350
if not head_first:
352351
if q.shape[1] < q.shape[2]:
353-
warnings.warn(
352+
raise DeprecationWarning(
354353
f"Input tensor shape suggests potential format mismatch: seq_len ({q.shape[1]}) < num_heads ({q.shape[2]}). "
355354
"This may indicate the inputs were passed in head-first format [B, H, T, ...] "
356355
"when head_first=False was specified. "

0 commit comments

Comments (0)