-
Notifications
You must be signed in to change notification settings - Fork 205
[Linear Attention] Update fused_recurrent.py for inference with normalization=true #268
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
287ffae
30d5606
8ec615a
3cb3c2a
d8965f2
1c6ea0c
b5d64ba
884a597
4c4c68c
20adb41
f050482
246f17c
c6ce801
ed6e92c
3ee20dd
8806758
565bbb8
1db56cb
104a2e2
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -235,17 +235,54 @@ def fused_recurrent_linear_attn( | |
v: torch.Tensor, | ||
scale: Optional[float] = None, | ||
initial_state: torch.Tensor = None, | ||
z_state: torch.Tensor = None, | ||
output_final_state: bool = False, | ||
normalize: bool = False, | ||
normalize: bool = True, | ||
head_first: bool = True | ||
) -> Tuple[torch.Tensor, torch.Tensor]: | ||
r""" | ||
Args: | ||
q (torch.Tensor): | ||
queries of shape `[B, H, T, K]` if `head_first=True` else `[B, T, H, K]` | ||
k (torch.Tensor): | ||
keys of shape `[B, H, T, K]` if `head_first=True` else `[B, T, H, K]` | ||
v (torch.Tensor): | ||
values of shape `[B, H, T, V]` if `head_first=True` else `[B, T, H, V]` | ||
scale (Optional[float]): | ||
Scale factor for linear attention scores. | ||
If not provided, it will default to `1 / sqrt(K)`. Default: `None`. | ||
initial_state (Optional[torch.Tensor]): | ||
Initial state of shape `[B, H, K, V]`. Default: `None`. | ||
z_state (Optional[torch.Tensor]): | ||
Z state of shape `[B, H, K, 1]`. This is only needed when normalization is enabled. Default: `None`. | ||
output_final_state (Optional[bool]): | ||
Whether to output the final state of shape `[B, H, K, V]`. Default: `False`. | ||
normalize (bool): | ||
Whether to normalize the output. Default: `True`. | ||
head_first (Optional[bool]): | ||
Whether the inputs are in the head-first format. Default: `True`. | ||
|
||
Returns: | ||
o (torch.Tensor): | ||
Outputs of shape `[B, H, T, V]` if `head_first=True` else `[B, T, H, V]` | ||
final_state (torch.Tensor): | ||
Final state of shape `[B, H, K, V]` if `output_final_state=True` else `None` | ||
""" | ||
if scale is None: | ||
scale = q.shape[-1] ** -0.5 | ||
if not head_first: | ||
q, k, v = map(lambda x: x.transpose(1, 2), (q, k, v)) | ||
o, final_state = FusedRecurrentLinearAttentionFunction.apply(q, k, v, scale, initial_state, output_final_state) | ||
|
||
if normalize: | ||
o = normalize_output(q * scale, k, o) | ||
if z_state is None: | ||
k_shape = list(k.shape) | ||
k_shape[-2 ]= 1 | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Removing the space in [-2] would be better. z_state = torch.zeros_like(k[..., 0, :]) if z_state is None else z_state There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. also I think |
||
z_state = k.new_zeros(k_shape) | ||
o, z_state = normalize_output(q * scale, k, o, z_state) | ||
if not head_first: | ||
o = o.transpose(1, 2) | ||
|
||
if normalize: | ||
return o, (final_state, z_state) | ||
return o, final_state |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,10 +1,13 @@ | ||
# -*- coding: utf-8 -*- | ||
# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang | ||
|
||
import torch | ||
|
||
|
||
@torch.jit.script | ||
yzhangcs marked this conversation as resolved.
Show resolved
Hide resolved
|
||
def normalize_output(q, k, o): | ||
def normalize_output(q, k, o, z_state): | ||
k = k.cumsum(-2) | ||
yzhangcs marked this conversation as resolved.
Show resolved
Hide resolved
yzhangcs marked this conversation as resolved.
Show resolved
Hide resolved
|
||
k = k + z_state | ||
z = (q * k).sum(-1, keepdim=True) | ||
return o / (z + 1e-10) | ||
return o / (z + 1e-10), k[...,-1:,:] | ||
|
Uh oh!
There was an error while loading. Please reload this page.