
[Linear Attention] Update fused_recurrent.py for inference with normalization=true #268


Draft · wants to merge 19 commits into main
15 changes: 13 additions & 2 deletions fla/ops/linear_attn/fused_chunk.py
@@ -276,9 +276,10 @@ def fused_chunk_linear_attn(
v: torch.Tensor,
scale: Optional[float] = None,
initial_state: torch.Tensor = None,
z_state: torch.Tensor = None,
output_final_state: bool = False,
normalize: bool = True,
head_first: bool = True
head_first: bool = True,
) -> Tuple[torch.Tensor, torch.Tensor]:
r"""
Args:
@@ -293,6 +294,8 @@ def fused_chunk_linear_attn(
If not provided, it will default to `1 / sqrt(K)`. Default: `None`.
initial_state (Optional[torch.Tensor]):
Initial state of shape `[B, H, K, V]`. Default: `None`.
z_state (Optional[torch.Tensor]):
Z state of shape `[B, H, K, 1]`. This is only needed when normalization is enabled. Default: `None`.
output_final_state (Optional[bool]):
Whether to output the final state of shape `[B, H, K, V]`. Default: `False`.
normalize (bool):
@@ -311,8 +314,16 @@ def fused_chunk_linear_attn(
if not head_first:
q, k, v = map(lambda x: x.transpose(1, 2), (q, k, v))
o, final_state = FusedChunkLinearAttentionFunction.apply(q, k, v, scale, initial_state, output_final_state)

if normalize:
o = normalize_output(q * scale, k, o)
if z_state is None:
k_shape = list(k.shape)
k_shape[-2 ]= 1
z_state = k.new_zeros(k_shape)
o, z_state = normalize_output(q * scale, k, o, z_state)
if not head_first:
o = o.transpose(1, 2)

if normalize:
return o, (final_state, z_state)
return o, final_state
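
A minimal usage sketch of the chunked API as changed by this diff (not part of the PR): it assumes `fused_chunk_linear_attn` is exposed from `fla.ops.linear_attn` and that a CUDA device with the Triton kernels is available. With `normalize=True`, the second return value becomes the `(final_state, z_state)` pair.

import torch
from fla.ops.linear_attn import fused_chunk_linear_attn

B, H, T, K, V = 2, 4, 128, 64, 64
q = torch.randn(B, H, T, K, device='cuda')
k = torch.randn(B, H, T, K, device='cuda')
v = torch.randn(B, H, T, V, device='cuda')

# Prefill: process the prompt in one chunked pass and keep both states
# so a later decoding call can continue with normalization enabled.
o, (final_state, z_state) = fused_chunk_linear_attn(
    q, k, v,
    output_final_state=True,
    normalize=True,
    head_first=True,
)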
41 changes: 39 additions & 2 deletions fla/ops/linear_attn/fused_recurrent.py
@@ -235,17 +235,54 @@ def fused_recurrent_linear_attn(
v: torch.Tensor,
scale: Optional[float] = None,
initial_state: torch.Tensor = None,
z_state: torch.Tensor = None,
output_final_state: bool = False,
normalize: bool = False,
normalize: bool = True,
head_first: bool = True
) -> Tuple[torch.Tensor, torch.Tensor]:
r"""
Args:
q (torch.Tensor):
queries of shape `[B, H, T, K]` if `head_first=True` else `[B, T, H, K]`
k (torch.Tensor):
keys of shape `[B, H, T, K]` if `head_first=True` else `[B, T, H, K]`
v (torch.Tensor):
values of shape `[B, H, T, V]` if `head_first=True` else `[B, T, H, V]`
scale (Optional[int]):
Scale factor for linear attention scores.
If not provided, it will default to `1 / sqrt(K)`. Default: `None`.
initial_state (Optional[torch.Tensor]):
Initial state of shape `[B, H, K, V]`. Default: `None`.
z_state (Optional[torch.Tensor]):
Z state of shape `[B, H, K, 1]`. This is only needed when normalization is enabled. Default: `None`.
output_final_state (Optional[bool]):
Whether to output the final state of shape `[B, H, K, V]`. Default: `False`.
normalize (bool):
Whether to normalize the output. Default: `True`.
head_first (Optional[bool]):
Whether the inputs are in the head-first format. Default: `True`.

Returns:
o (torch.Tensor):
Outputs of shape `[B, H, T, V]` if `head_first=True` else `[B, T, H, V]`
final_state (torch.Tensor):
Final state of shape `[B, H, K, V]` if `output_final_state=True` else `None`
"""
if scale is None:
scale = q.shape[-1] ** -0.5
if not head_first:
q, k, v = map(lambda x: x.transpose(1, 2), (q, k, v))
o, final_state = FusedRecurrentLinearAttentionFunction.apply(q, k, v, scale, initial_state, output_final_state)

if normalize:
o = normalize_output(q * scale, k, o)
if z_state is None:
k_shape = list(k.shape)
k_shape[-2 ]= 1
Member (review comment):

Removing the space in `[-2]` would be better.
How about directly initializing `z_state` with

z_state = torch.zeros_like(k[..., 0, :]) if z_state is None else z_state

Member (review comment):

Also, I think `[B, H, K, 1]` could be confusing; would `[B, H, K]` be better?
There's no cost for an unsqueeze when updating the z state.

z_state = k.new_zeros(k_shape)
o, z_state = normalize_output(q * scale, k, o, z_state)
if not head_first:
o = o.transpose(1, 2)

if normalize:
return o, (final_state, z_state)
return o, final_state
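
The point of the patch is stepwise inference with `normalize=True`; the sketch below (illustrative only, not from this PR) shows how the carried `(final_state, z_state)` pair could be threaded through successive single-token calls, assuming the same import path and a CUDA device.

import torch
from fla.ops.linear_attn import fused_recurrent_linear_attn

B, H, K, V = 1, 4, 64, 64
state, z_state = None, None          # no history before the first token
outputs = []

for _ in range(16):
    # one new token per step, head-first layout: [B, H, T=1, K] / [B, H, T=1, V]
    q_t = torch.randn(B, H, 1, K, device='cuda')
    k_t = torch.randn(B, H, 1, K, device='cuda')
    v_t = torch.randn(B, H, 1, V, device='cuda')

    o_t, (state, z_state) = fused_recurrent_linear_attn(
        q_t, k_t, v_t,
        initial_state=state,
        z_state=z_state,
        output_final_state=True,
        normalize=True,
        head_first=True,
    )
    outputs.append(o_t)

o = torch.cat(outputs, dim=2)        # [B, H, 16, V]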
7 changes: 5 additions & 2 deletions fla/ops/linear_attn/utils.py
@@ -1,10 +1,13 @@
# -*- coding: utf-8 -*-
# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang

import torch


@torch.jit.script
def normalize_output(q, k, o):
def normalize_output(q, k, o, z_state):
k = k.cumsum(-2)
k = k + z_state
z = (q * k).sum(-1, keepdim=True)
return o / (z + 1e-10)
return o / (z + 1e-10), k[..., -1:, :]
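
As a sanity check of the carry logic in the patched `normalize_output` (re-implemented below in plain PyTorch purely for illustration): the denominator at position t is q_t · Σ_{s≤t} k_s, and the returned running key sum lets a later segment continue that cumulative sum, so splitting a sequence and carrying `z_state` should reproduce the single-pass result.

import torch

def normalize_output(q, k, o, z_state):
    k = k.cumsum(-2) + z_state              # running key sum, offset by the carried state
    z = (q * k).sum(-1, keepdim=True)       # denominator q_t . sum_{s<=t} k_s
    return o / (z + 1e-10), k[..., -1:, :]  # normalized output and the new carry

B, H, T, K, V = 1, 2, 8, 4, 4
# non-negative features, as produced by the positive feature maps typical of linear attention
q, k = torch.rand(B, H, T, K), torch.rand(B, H, T, K)
o = torch.randn(B, H, T, V)
zero = k.new_zeros(B, H, 1, K)

# whole sequence at once
o_full, _ = normalize_output(q, k, o, zero)

# same sequence in two halves, carrying z_state between them
o1, z1 = normalize_output(q[:, :, :4], k[:, :, :4], o[:, :, :4], zero)
o2, _ = normalize_output(q[:, :, 4:], k[:, :, 4:], o[:, :, 4:], z1)

assert torch.allclose(o_full, torch.cat([o1, o2], dim=2), atol=1e-5)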