[Linear Attention] Update fused_recurrent.py for inference with normalization=true #268
Conversation
Walkthrough: The changes introduce an additional optional parameter in three functions to support cumulative tensor handling.
Sequence Diagram(s)

```mermaid
sequenceDiagram
    participant FR as fused_recurrent_linear_attn
    participant NO as normalize_output
    FR->>NO: Call normalize_output(q * scale, k, o, cum_k)
    alt cum_k provided
        NO->>NO: Compute k = k + cum_k
    else No cum_k
        NO->>NO: Proceed without modifying k
    end
    NO->>FR: Return normalized output
    participant FC as fused_chunk_linear_attn
    FC->>NO: Call normalize_output(q * scale, k, o, cum_k)
    alt cum_k provided
        NO->>NO: Compute k = k + cum_k
    else No cum_k
        NO->>NO: Proceed without modifying k
    end
    NO->>FC: Return normalized output
```
Actionable comments posted: 1
🧹 Nitpick comments (2)
fla/ops/linear_attn/fused_recurrent.py (2)
238-238: Consider a consistent naming convention across files.

The parameter here is named cum_K with an uppercase K, while in normalize_output it's named cum_k with a lowercase k. Consider using consistent naming across the codebase to avoid confusion.

```diff
- cum_K: torch.Tensor = None,
+ cum_k: torch.Tensor = None,
```
249-249: Fix parameter case and formatting.

- The parameter case should match between the function definition and where it's used.
- Add a space after the comma for better readability.

```diff
- o = normalize_output(q * scale, k, o,cum_K)
+ o = normalize_output(q * scale, k, o, cum_K)
```

If you choose to update the parameter name for consistency:

```diff
- o = normalize_output(q * scale, k, o,cum_K)
+ o = normalize_output(q * scale, k, o, cum_k)
```
📒 Files selected for processing (2)

- fla/ops/linear_attn/fused_recurrent.py (2 hunks)
- fla/ops/linear_attn/utils.py (1 hunks)
🧰 Additional context used

🪛 Ruff (0.8.2)

fla/ops/linear_attn/utils.py

10-10: Undefined name `cum_K` (F821)
Actionable comments posted: 0
🧹 Nitpick comments (1)
fla/ops/linear_attn/fused_chunk.py (1)
279-279: Add parameter documentation in the docstring.

The new cum_k parameter has been added to the function signature, but its description is missing from the docstring. For better maintainability and clarity, please update the docstring to explain the purpose of cum_k, its expected type, and when it should be provided. Consider adding a description like:

    cum_k (Optional[torch.Tensor]): Cumulative key tensor used for normalization when running inference with a saved state. Defaults to `None`.
📒 Files selected for processing (1)

- fla/ops/linear_attn/fused_chunk.py (2 hunks)
🔇 Additional comments (1)
fla/ops/linear_attn/fused_chunk.py (1)
316-316: ✅ Verification successful

Fix for normalization with KV cache.

This change correctly addresses the issue described in the PR where QK1 (or Q @ sum(K)) is missing earlier keys when normalization is enabled. By passing the cumulative cum_k parameter to the normalize_output function, the normalization process can now properly calculate the output as QKV/QK1, including all previous keys in the cache.

🏁 Script executed:

```bash
#!/bin/bash
# Check if normalize_output has been updated to accept the cum_k parameter
cat fla/ops/linear_attn/utils.py | grep -A 10 "def normalize_output"
```

Length of output: 247
Normalization with KV Cache Fix Confirmed

The updated implementation of normalize_output now correctly accepts the cum_k parameter, ensuring that earlier keys (the cumulative key values) are incorporated into the normalization process. This resolves the issue where QK1 (or Q @ sum(K)) was missing contributions from previous keys when normalization was enabled.

- Location: fla/ops/linear_attn/utils.py
- Key Change: In normalize_output, the cumulative keys are now added to the cumsum of k when cum_k is not None.
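For readers following along, a minimal sketch of what the updated helper could look like, assuming normalize_output divides the output by $q \cdot \mathrm{cumsum}(k)$ and that cum_k is broadcastable against that running sum; the actual fla implementation may differ in details:

```python
from typing import Optional

import torch


def normalize_output(q: torch.Tensor, k: torch.Tensor, o: torch.Tensor,
                     cum_k: Optional[torch.Tensor] = None) -> torch.Tensor:
    # Running sum of keys over the sequence dimension, e.g. [B, H, T, K].
    z = k.cumsum(-2)
    if cum_k is not None:
        # Fold in keys accumulated before this chunk so the denominator
        # Q @ sum(K) also covers cached tokens.
        z = z + cum_k
    # Denominator q_t . sum_{i<=t} k_i, with a small epsilon for stability.
    denom = (q * z).sum(-1, keepdim=True) + 1e-6
    return o / denom
```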
Thanks for contributing! @yiyousong, can you add tests to your contribution? This will improve the robustness of the code. @yzhangcs, could you please give some comments?
Sorry, I don't think I understand how the tests work.
You could give it a try :)

```sh
pip install pytest
export COMPILER_MODE=1  # to speed up
pytest tests/ops/test_linear_attn.py
pytest tests/layers/test_linearatten_layer.py
```

You can see it will test the functions automatically. The thing you need to do is add tests for your changes.
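For instance, a rough test sketch along these lines could be added to tests/ops/test_linear_attn.py; the shapes, tolerances, and the way cum_k is built below are assumptions based on this PR's proposed signature, not existing repo tests. It compares one-shot normalized output against two-chunk inference that carries the KV state and cum_k:

```python
import pytest
import torch

from fla.ops.linear_attn import fused_recurrent_linear_attn


@pytest.mark.parametrize("B,H,T,D", [(2, 4, 64, 32)])
def test_normalized_inference_with_cum_k(B, H, T, D):
    q, k, v = (torch.randn(B, H, T, D) for _ in range(3))

    # Reference: process the whole sequence at once with normalization on.
    ref, _ = fused_recurrent_linear_attn(q, k, v, normalize=True)

    # Incremental: process the two halves, carrying the KV state and the
    # cumulative key sum (cum_k) across the chunk boundary.
    o1, state = fused_recurrent_linear_attn(
        q[:, :, :T // 2], k[:, :, :T // 2], v[:, :, :T // 2],
        normalize=True, output_final_state=True)
    cum_k = k[:, :, :T // 2].sum(-2, keepdim=True)
    o2, _ = fused_recurrent_linear_attn(
        q[:, :, T // 2:], k[:, :, T // 2:], v[:, :, T // 2:],
        initial_state=state, cum_k=cum_k, normalize=True)

    torch.testing.assert_close(torch.cat([o1, o2], dim=2), ref,
                               rtol=1e-3, atol=1e-3)
```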
@yiyousong Hello, could you please explain more about what this arg means and what the purpose of adding it is?
Linear attention without normalization equals $QKV$. Linear attention with normalization equals $\frac{QKV}{QK1}$. However, the implementation was harder than I thought, as the compiled function does not accept if statements, so I cannot simply add a few parameters. Maybe you need to change all the code involving normalization and the cache.
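To restate this in recurrence form (using $S_t$ for the KV state and $z_t$ for the running key sum; this is just a paraphrase of the point above, not notation from the repo):

$$S_t = S_{t-1} + k_t^{\top} v_t, \qquad z_t = z_{t-1} + k_t, \qquad o_t = \begin{cases} q_t S_t & \text{without normalization,} \\[4pt] \dfrac{q_t S_t}{q_t z_t^{\top}} & \text{with normalization.} \end{cases}$$

A cache that stores only $S_t$ drops $z_t$, so the denominator $QK1$ misses every key seen before the cached step; cum_k supplies exactly that missing $z$.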
@yiyousong Thank you, good point! We do need to support this. But I don't think …
fla/ops/linear_attn/utils.py (Outdated)

```diff
@@ -4,7 +4,9 @@
 @torch.jit.script
-def normalize_output(q, k, o):
+def normalize_output(q, k, o, cum_k=None):
```
@yiyousong Maybe we could pass initial_state as an arg with cum_k included, for API consistency.
I only use the ops, so I prefer passing it in separately. However, it's your code, your choice. I believe it doesn't matter as long as you don't merge them into one tensor.
```diff
@@ -235,6 +235,7 @@ def fused_recurrent_linear_attn(
     v: torch.Tensor,
     scale: Optional[float] = None,
     initial_state: torch.Tensor = None,
+    cum_k: torch.Tensor = None,
```
@yiyousong I think we could make initial_state a Tuple if normalize is True. What do you think?
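A rough sketch of that idea, purely illustrative (the unpacking logic and names below are assumptions, not the final API):

```python
from typing import Optional, Tuple, Union

import torch


def fused_recurrent_linear_attn(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    scale: Optional[float] = None,
    initial_state: Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor], None] = None,
    output_final_state: bool = False,
    normalize: bool = False,
):
    # When normalize is True, the caller passes (kv_state, z_state); otherwise
    # initial_state is just the kv_state tensor, keeping the old API intact.
    if normalize and isinstance(initial_state, tuple):
        kv_state, z_state = initial_state
    else:
        kv_state, z_state = initial_state, None
    # ... dispatch to the kernel with kv_state, and use z_state when
    # normalizing the output, as in the rest of this PR ...
```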
```diff
@@ -235,6 +235,7 @@ def fused_recurrent_linear_attn(
     v: torch.Tensor,
     scale: Optional[float] = None,
     initial_state: torch.Tensor = None,
+    cum_k: torch.Tensor = None,
     output_final_state: bool = False,
     normalize: bool = False,
```
@yiyousong Could you please add some docstrings, BTW?
@yiyousong Hello, are you still working on this PR?
I was evaluating using my own code (I only used fla.ops, not fla.layers).
These changes are based on the code I changed to make it work for my model. I probably won't work on this further.
```python
o = normalize_output(q * scale, k, o)
if z_state is None:
    k_shape = list(k.shape)
    k_shape[-2 ]= 1
```
Removing the space in [-2] would be better. How about directly initializing z_state with

    z_state = torch.zeros_like(k[..., 0, :]) if z_state is None else z_state
Also, I think [B, H, K, 1] could be confusing; would [B, H, K] be better? There's no cost for an unsqueeze when updating the z state.
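Putting both suggestions together, a small self-contained sketch (k is assumed to have shape [B, H, T, K]; the surrounding kernel code is omitted):

```python
import torch

B, H, T, K = 2, 4, 16, 32
k = torch.randn(B, H, T, K)
z_state = None  # no cached state on the first call

# Suggested initialization: keep z_state as [B, H, K] and create it lazily.
z_state = torch.zeros_like(k[..., 0, :]) if z_state is None else z_state

# Unsqueezing when the running key sum is combined with per-token tensors
# costs nothing, so storing an extra singleton dimension is unnecessary.
z = z_state.unsqueeze(-2) + k.cumsum(-2)  # [B, H, T, K]
```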
The current linear attention can save a $KV$ state cache. This works when normalization is not enabled. When normalization is enabled, the output should be $\frac{QKV}{QK1}$, and we can see that $QK1$ (or Q @ sum(K)) is missing the earlier keys.
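For context, a rough sketch of the intended step-by-step decoding loop under this PR's proposed signature; the [B, H, 1, D] layout and the way cum_k is maintained outside the op are assumptions, not the final API:

```python
import torch

from fla.ops.linear_attn import fused_recurrent_linear_attn

B, H, D = 1, 4, 64
state, cum_k = None, None
outputs = []

for step in range(16):
    # One new token per step: [B, H, 1, D].
    q, k, v = (torch.randn(B, H, 1, D) for _ in range(3))
    o, state = fused_recurrent_linear_attn(
        q, k, v,
        initial_state=state,   # cached KV state: sum_i k_i^T v_i
        cum_k=cum_k,           # cached key sum, so Q @ sum(K) sees earlier keys
        output_final_state=True,
        normalize=True,
    )
    # Maintain the cumulative key sum across steps, outside the kernel.
    cum_k = k if cum_k is None else cum_k + k
    outputs.append(o)
```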
The last pull request only modified one file; I'm not sure why that happened. I re-opened this and hope this version does contain both changes.