diff --git a/fla/ops/linear_attn/fused_chunk.py b/fla/ops/linear_attn/fused_chunk.py
index bfcc1212a..e670ad32e 100644
--- a/fla/ops/linear_attn/fused_chunk.py
+++ b/fla/ops/linear_attn/fused_chunk.py
@@ -276,9 +276,10 @@ def fused_chunk_linear_attn(
     v: torch.Tensor,
     scale: Optional[float] = None,
     initial_state: torch.Tensor = None,
+    z_state: torch.Tensor = None,
     output_final_state: bool = False,
     normalize: bool = True,
-    head_first: bool = True
+    head_first: bool = True,
 ) -> Tuple[torch.Tensor, torch.Tensor]:
     r"""
     Args:
@@ -293,6 +294,8 @@ def fused_chunk_linear_attn(
             If not provided, it will default to `1 / sqrt(K)`. Default: `None`.
         initial_state (Optional[torch.Tensor]):
             Initial state of shape `[B, H, K, V]`. Default: `None`.
+        z_state (Optional[torch.Tensor]):
+            Z state of shape `[B, H, 1, K]`. This is only needed when normalization is enabled. Default: `None`.
         output_final_state (Optional[bool]):
             Whether to output the final state of shape `[B, H, K, V]`. Default: `False`.
         normalize (bool):
@@ -311,8 +314,16 @@ def fused_chunk_linear_attn(
     if not head_first:
         q, k, v = map(lambda x: x.transpose(1, 2), (q, k, v))
     o, final_state = FusedChunkLinearAttentionFunction.apply(q, k, v, scale, initial_state, output_final_state)
+
     if normalize:
-        o = normalize_output(q * scale, k, o)
+        if z_state is None:
+            k_shape = list(k.shape)
+            k_shape[-2] = 1
+            z_state = k.new_zeros(k_shape)
+        o, z_state = normalize_output(q * scale, k, o, z_state)
     if not head_first:
         o = o.transpose(1, 2)
+
+    if normalize:
+        return o, (final_state, z_state)
     return o, final_state
diff --git a/fla/ops/linear_attn/fused_recurrent.py b/fla/ops/linear_attn/fused_recurrent.py
index b50b8c7bf..2394cf219 100644
--- a/fla/ops/linear_attn/fused_recurrent.py
+++ b/fla/ops/linear_attn/fused_recurrent.py
@@ -235,17 +235,54 @@ def fused_recurrent_linear_attn(
     v: torch.Tensor,
     scale: Optional[float] = None,
     initial_state: torch.Tensor = None,
+    z_state: torch.Tensor = None,
     output_final_state: bool = False,
-    normalize: bool = False,
+    normalize: bool = True,
     head_first: bool = True
 ) -> Tuple[torch.Tensor, torch.Tensor]:
+    r"""
+    Args:
+        q (torch.Tensor):
+            queries of shape `[B, H, T, K]` if `head_first=True` else `[B, T, H, K]`
+        k (torch.Tensor):
+            keys of shape `[B, H, T, K]` if `head_first=True` else `[B, T, H, K]`
+        v (torch.Tensor):
+            values of shape `[B, H, T, V]` if `head_first=True` else `[B, T, H, V]`
+        scale (Optional[float]):
+            Scale factor for linear attention scores.
+            If not provided, it will default to `1 / sqrt(K)`. Default: `None`.
+        initial_state (Optional[torch.Tensor]):
+            Initial state of shape `[B, H, K, V]`. Default: `None`.
+        z_state (Optional[torch.Tensor]):
+            Z state of shape `[B, H, 1, K]`. This is only needed when normalization is enabled. Default: `None`.
+        output_final_state (Optional[bool]):
+            Whether to output the final state of shape `[B, H, K, V]`. Default: `False`.
+        normalize (bool):
+            Whether to normalize the output. Default: `True`.
+        head_first (Optional[bool]):
+            Whether the inputs are in the head-first format. Default: `True`.
+
+    Returns:
+        o (torch.Tensor):
+            Outputs of shape `[B, H, T, V]` if `head_first=True` else `[B, T, H, V]`
+        final_state (torch.Tensor):
+            Final state of shape `[B, H, K, V]` if `output_final_state=True` else `None`
+    """
     if scale is None:
         scale = q.shape[-1] ** -0.5
     if not head_first:
         q, k, v = map(lambda x: x.transpose(1, 2), (q, k, v))
     o, final_state = FusedRecurrentLinearAttentionFunction.apply(q, k, v, scale, initial_state, output_final_state)
+
     if normalize:
-        o = normalize_output(q * scale, k, o)
+        if z_state is None:
+            k_shape = list(k.shape)
+            k_shape[-2] = 1
+            z_state = k.new_zeros(k_shape)
+        o, z_state = normalize_output(q * scale, k, o, z_state)
     if not head_first:
         o = o.transpose(1, 2)
+
+    if normalize:
+        return o, (final_state, z_state)
     return o, final_state
diff --git a/fla/ops/linear_attn/utils.py b/fla/ops/linear_attn/utils.py
index b44437683..8b3727bc1 100644
--- a/fla/ops/linear_attn/utils.py
+++ b/fla/ops/linear_attn/utils.py
@@ -1,10 +1,13 @@
 # -*- coding: utf-8 -*-
+# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang
 
 import torch
 
 
 @torch.jit.script
-def normalize_output(q, k, o):
+def normalize_output(q, k, o, z_state):
     k = k.cumsum(-2)
+    k = k + z_state
     z = (q * k).sum(-1, keepdim=True)
-    return o / (z + 1e-10)
+    return o / (z + 1e-10), k[..., -1:, :]
+
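
For context, here is a minimal usage sketch (not part of the diff) of the new interface: with `normalize=True`, the wrappers now return `o, (final_state, z_state)`, and feeding both states back into the next call carries the recurrent KV state and the key-normalizer across segments. Shapes, the `device` choice, and the import path are illustrative assumptions; the kernels are Triton-based and assume a CUDA device.

```python
import torch

from fla.ops.linear_attn import fused_recurrent_linear_attn

# Hypothetical sizes: B=batch, H=heads, T=segment length, K/V=head dims.
B, H, T, K, V = 2, 4, 128, 64, 64
device = 'cuda'  # assumption: the Triton kernels require a GPU

q = torch.randn(B, H, 2 * T, K, device=device)
k = torch.randn(B, H, 2 * T, K, device=device)
v = torch.randn(B, H, 2 * T, V, device=device)

# First segment: with normalize=True the output comes back together with a
# (final_state, z_state) pair; z_state is the running sum of keys used for
# normalization (shape [B, H, 1, K]).
o1, (state, z) = fused_recurrent_linear_attn(
    q[:, :, :T], k[:, :, :T], v[:, :, :T],
    output_final_state=True,
    normalize=True,
)

# Second segment: passing both states back resumes the recurrence and the
# normalization where the first call stopped.
o2, (state, z) = fused_recurrent_linear_attn(
    q[:, :, T:], k[:, :, T:], v[:, :, T:],
    initial_state=state,
    z_state=z,
    output_final_state=True,
    normalize=True,
)

# Expected to agree (up to numerics) with one normalized call over the full sequence.
o_full = torch.cat([o1, o2], dim=2)
```

Returning the z state alongside the KV state is what makes this possible: without it, each call would restart the key cumulative sum from zero and the per-segment outputs would be normalized inconsistently.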