From 287ffaee26846f0370c84e6d281247c553a8986a Mon Sep 17 00:00:00 2001 From: Y Song <56291756+yiyousong@users.noreply.github.com> Date: Tue, 1 Apr 2025 10:23:43 +0800 Subject: [PATCH 01/18] Update utils.py --- fla/ops/linear_attn/utils.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/fla/ops/linear_attn/utils.py b/fla/ops/linear_attn/utils.py index b44437683..5497a78e4 100644 --- a/fla/ops/linear_attn/utils.py +++ b/fla/ops/linear_attn/utils.py @@ -4,7 +4,9 @@ @torch.jit.script -def normalize_output(q, k, o): +def normalize_output(q, k, o, cum_k=None): k = k.cumsum(-2) + if cum_k is not None: + k=k+cum_K z = (q * k).sum(-1, keepdim=True) return o / (z + 1e-10) From 30d5606ff75497b10d990772c8aee4a02fe2d128 Mon Sep 17 00:00:00 2001 From: Y Song <56291756+yiyousong@users.noreply.github.com> Date: Tue, 1 Apr 2025 10:25:21 +0800 Subject: [PATCH 02/18] Update fused_recurrent.py --- fla/ops/linear_attn/fused_recurrent.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/fla/ops/linear_attn/fused_recurrent.py b/fla/ops/linear_attn/fused_recurrent.py index b50b8c7bf..a39aad278 100644 --- a/fla/ops/linear_attn/fused_recurrent.py +++ b/fla/ops/linear_attn/fused_recurrent.py @@ -235,6 +235,7 @@ def fused_recurrent_linear_attn( v: torch.Tensor, scale: Optional[float] = None, initial_state: torch.Tensor = None, + cum_K: torch.Tensor = None, output_final_state: bool = False, normalize: bool = False, head_first: bool = True @@ -245,7 +246,7 @@ def fused_recurrent_linear_attn( q, k, v = map(lambda x: x.transpose(1, 2), (q, k, v)) o, final_state = FusedRecurrentLinearAttentionFunction.apply(q, k, v, scale, initial_state, output_final_state) if normalize: - o = normalize_output(q * scale, k, o) + o = normalize_output(q * scale, k, o,cum_K) if not head_first: o = o.transpose(1, 2) return o, final_state From 8ec615ab1d435f0bb21e3beeee6f3888d2264b40 Mon Sep 17 00:00:00 2001 From: Y Song <56291756+yiyousong@users.noreply.github.com> Date: Tue, 1 Apr 2025 10:42:51 +0800 Subject: [PATCH 03/18] Update fused_recurrent.py --- fla/ops/linear_attn/fused_recurrent.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/fla/ops/linear_attn/fused_recurrent.py b/fla/ops/linear_attn/fused_recurrent.py index a39aad278..bc7c254ed 100644 --- a/fla/ops/linear_attn/fused_recurrent.py +++ b/fla/ops/linear_attn/fused_recurrent.py @@ -235,7 +235,7 @@ def fused_recurrent_linear_attn( v: torch.Tensor, scale: Optional[float] = None, initial_state: torch.Tensor = None, - cum_K: torch.Tensor = None, + cum_k: torch.Tensor = None, output_final_state: bool = False, normalize: bool = False, head_first: bool = True @@ -246,7 +246,7 @@ def fused_recurrent_linear_attn( q, k, v = map(lambda x: x.transpose(1, 2), (q, k, v)) o, final_state = FusedRecurrentLinearAttentionFunction.apply(q, k, v, scale, initial_state, output_final_state) if normalize: - o = normalize_output(q * scale, k, o,cum_K) + o = normalize_output(q * scale, k, o,cum_k) if not head_first: o = o.transpose(1, 2) return o, final_state From 3cb3c2acab3371b065127dd08fdb3194b9a98f18 Mon Sep 17 00:00:00 2001 From: Y Song <56291756+yiyousong@users.noreply.github.com> Date: Tue, 1 Apr 2025 10:43:09 +0800 Subject: [PATCH 04/18] Update utils.py --- fla/ops/linear_attn/utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/fla/ops/linear_attn/utils.py b/fla/ops/linear_attn/utils.py index 5497a78e4..73ddc6217 100644 --- a/fla/ops/linear_attn/utils.py +++ b/fla/ops/linear_attn/utils.py @@ -7,6 +7,6 @@ def 
normalize_output(q, k, o, cum_k=None): k = k.cumsum(-2) if cum_k is not None: - k=k+cum_K + k=k+cum_k z = (q * k).sum(-1, keepdim=True) return o / (z + 1e-10) From d8965f2a2a61edaf4b5d912a6118383f76543579 Mon Sep 17 00:00:00 2001 From: Y Song <56291756+yiyousong@users.noreply.github.com> Date: Tue, 1 Apr 2025 10:47:13 +0800 Subject: [PATCH 05/18] refactor --- fla/ops/linear_attn/utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/fla/ops/linear_attn/utils.py b/fla/ops/linear_attn/utils.py index 73ddc6217..71ac945e3 100644 --- a/fla/ops/linear_attn/utils.py +++ b/fla/ops/linear_attn/utils.py @@ -7,6 +7,6 @@ def normalize_output(q, k, o, cum_k=None): k = k.cumsum(-2) if cum_k is not None: - k=k+cum_k + k = k + cum_k z = (q * k).sum(-1, keepdim=True) return o / (z + 1e-10) From 1c6ea0ce59457fb5784a44babae65a70654475fa Mon Sep 17 00:00:00 2001 From: Y Song <56291756+yiyousong@users.noreply.github.com> Date: Tue, 1 Apr 2025 10:47:43 +0800 Subject: [PATCH 06/18] refactor --- fla/ops/linear_attn/fused_recurrent.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/fla/ops/linear_attn/fused_recurrent.py b/fla/ops/linear_attn/fused_recurrent.py index bc7c254ed..255934635 100644 --- a/fla/ops/linear_attn/fused_recurrent.py +++ b/fla/ops/linear_attn/fused_recurrent.py @@ -246,7 +246,7 @@ def fused_recurrent_linear_attn( q, k, v = map(lambda x: x.transpose(1, 2), (q, k, v)) o, final_state = FusedRecurrentLinearAttentionFunction.apply(q, k, v, scale, initial_state, output_final_state) if normalize: - o = normalize_output(q * scale, k, o,cum_k) + o = normalize_output(q * scale, k, o, cum_k) if not head_first: o = o.transpose(1, 2) return o, final_state From b5d64ba82214cd49dc0d693e7f72c3a1d972a196 Mon Sep 17 00:00:00 2001 From: Y Song <56291756+yiyousong@users.noreply.github.com> Date: Tue, 1 Apr 2025 10:52:10 +0800 Subject: [PATCH 07/18] Update fused_chunk.py --- fla/ops/linear_attn/fused_chunk.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/fla/ops/linear_attn/fused_chunk.py b/fla/ops/linear_attn/fused_chunk.py index bfcc1212a..563b5966d 100644 --- a/fla/ops/linear_attn/fused_chunk.py +++ b/fla/ops/linear_attn/fused_chunk.py @@ -276,6 +276,7 @@ def fused_chunk_linear_attn( v: torch.Tensor, scale: Optional[float] = None, initial_state: torch.Tensor = None, + cum_k: torch.Tensor = None, output_final_state: bool = False, normalize: bool = True, head_first: bool = True @@ -312,7 +313,7 @@ def fused_chunk_linear_attn( q, k, v = map(lambda x: x.transpose(1, 2), (q, k, v)) o, final_state = FusedChunkLinearAttentionFunction.apply(q, k, v, scale, initial_state, output_final_state) if normalize: - o = normalize_output(q * scale, k, o) + o = normalize_output(q * scale, k, o, cum_k) if not head_first: o = o.transpose(1, 2) return o, final_state From 884a59751c297098eb9f8b12d768da5663d427ca Mon Sep 17 00:00:00 2001 From: Yu Zhang Date: Tue, 1 Apr 2025 16:18:04 +0800 Subject: [PATCH 08/18] [Utils] Replace torch.jit.script with torch.compile --- fla/ops/linear_attn/utils.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/fla/ops/linear_attn/utils.py b/fla/ops/linear_attn/utils.py index 71ac945e3..92c2b1e06 100644 --- a/fla/ops/linear_attn/utils.py +++ b/fla/ops/linear_attn/utils.py @@ -1,9 +1,10 @@ # -*- coding: utf-8 -*- +# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang import torch -@torch.jit.script +@torch.compile def normalize_output(q, k, o, cum_k=None): k = k.cumsum(-2) if cum_k is not None: From 
4c4c68c5f7a05e6195b016bc88305affef83e5f0 Mon Sep 17 00:00:00 2001 From: Y Song <56291756+yiyousong@users.noreply.github.com> Date: Thu, 3 Apr 2025 16:57:36 +0800 Subject: [PATCH 09/18] Update fused_chunk.py --- fla/ops/linear_attn/fused_chunk.py | 21 ++++++++++++++++++--- 1 file changed, 18 insertions(+), 3 deletions(-) diff --git a/fla/ops/linear_attn/fused_chunk.py b/fla/ops/linear_attn/fused_chunk.py index 563b5966d..2df944460 100644 --- a/fla/ops/linear_attn/fused_chunk.py +++ b/fla/ops/linear_attn/fused_chunk.py @@ -276,10 +276,11 @@ def fused_chunk_linear_attn( v: torch.Tensor, scale: Optional[float] = None, initial_state: torch.Tensor = None, - cum_k: torch.Tensor = None, + z_state: torch.Tensor = None, output_final_state: bool = False, normalize: bool = True, - head_first: bool = True + head_first: bool = True, + output_z_state: bool = False ) -> Tuple[torch.Tensor, torch.Tensor]: r""" Args: @@ -294,12 +295,17 @@ def fused_chunk_linear_attn( If not provided, it will default to `1 / sqrt(K)`. Default: `None`. initial_state (Optional[torch.Tensor]): Initial state of shape `[B, H, K, V]`. Default: `None`. + z_state (Optional[torch.Tensor]): + Z state Of shape `[B, H, K, 1]. This is only needed when normalization is enabled. `. Default: `None`. output_final_state (Optional[bool]): Whether to output the final state of shape `[B, H, K, V]`. Default: `False`. normalize (bool): Whether to normalize the output. Default: `True`. head_first (Optional[bool]): Whether the inputs are in the head-first format. Default: `True`. + output_z_state (Optional[bool]): + Whether to output the final Z state of shape `[B, H, K, 1]`. For API consistency, we recommend to update Z outside the function. Default: `False`. + Returns: o (torch.Tensor): @@ -312,8 +318,17 @@ def fused_chunk_linear_attn( if not head_first: q, k, v = map(lambda x: x.transpose(1, 2), (q, k, v)) o, final_state = FusedChunkLinearAttentionFunction.apply(q, k, v, scale, initial_state, output_final_state) + if normalize: - o = normalize_output(q * scale, k, o, cum_k) + if z_state is None: + k_shape = list(k.shape) + k_shape[-2 ]= 1 + z_state = k.new_zeros(k_shape) + o = normalize_output(q * scale, k, o, z_state) if not head_first: o = o.transpose(1, 2) + + if normalize and output_z_state: + z_state = z_state + torch.sum(k, dim = -2, keepdim = True) + return o, final_state, z_state return o, final_state From 20adb418cb0b0338972a8312d4ef386db8bac103 Mon Sep 17 00:00:00 2001 From: Y Song <56291756+yiyousong@users.noreply.github.com> Date: Thu, 3 Apr 2025 17:00:01 +0800 Subject: [PATCH 10/18] Update fused_recurrent.py --- fla/ops/linear_attn/fused_recurrent.py | 49 +++++++++++++++++++++++--- 1 file changed, 45 insertions(+), 4 deletions(-) diff --git a/fla/ops/linear_attn/fused_recurrent.py b/fla/ops/linear_attn/fused_recurrent.py index 255934635..ed27b6db0 100644 --- a/fla/ops/linear_attn/fused_recurrent.py +++ b/fla/ops/linear_attn/fused_recurrent.py @@ -235,18 +235,59 @@ def fused_recurrent_linear_attn( v: torch.Tensor, scale: Optional[float] = None, initial_state: torch.Tensor = None, - cum_k: torch.Tensor = None, + z_state: torch.Tensor = None, output_final_state: bool = False, - normalize: bool = False, - head_first: bool = True + normalize: bool = True, + head_first: bool = True, + output_z_state: bool = False ) -> Tuple[torch.Tensor, torch.Tensor]: + r""" + Args: + q (torch.Tensor): + queries of shape `[B, H, T, K]` if `head_first=True` else `[B, T, H, K]` + k (torch.Tensor): + keys of shape `[B, H, T, K]` if `head_first=True` else 
`[B, T, H, K]` + v (torch.Tensor): + values of shape `[B, H, T, V]` if `head_first=True` else `[B, T, H, V]` + scale (Optional[int]): + Scale factor for linear attention scores. + If not provided, it will default to `1 / sqrt(K)`. Default: `None`. + initial_state (Optional[torch.Tensor]): + Initial state of shape `[B, H, K, V]`. Default: `None`. + z_state (Optional[torch.Tensor]): + Z state Of shape `[B, H, K, 1]. This is only needed when normalization is enabled. `. Default: `None`. + output_final_state (Optional[bool]): + Whether to output the final state of shape `[B, H, K, V]`. Default: `False`. + normalize (bool): + Whether to normalize the output. Default: `True`. + head_first (Optional[bool]): + Whether the inputs are in the head-first format. Default: `True`. + output_z_state (Optional[bool]): + Whether to output the final Z state of shape `[B, H, K, 1]`. This parameter is only effective when normalize=True. For API consistency, we recommend to update Z outside the function. Default: `False`. + + + Returns: + o (torch.Tensor): + Outputs of shape `[B, H, T, V]` if `head_first=True` else `[B, T, H, V]` + final_state (torch.Tensor): + Final state of shape `[B, H, K, V]` if `output_final_state=True` else `None` + """ if scale is None: scale = q.shape[-1] ** -0.5 if not head_first: q, k, v = map(lambda x: x.transpose(1, 2), (q, k, v)) o, final_state = FusedRecurrentLinearAttentionFunction.apply(q, k, v, scale, initial_state, output_final_state) + if normalize: - o = normalize_output(q * scale, k, o, cum_k) + if z_state is None: + k_shape = list(k.shape) + k_shape[-2 ]= 1 + z_state = k.new_zeros(k_shape) + o = normalize_output(q * scale, k, o, z_state) if not head_first: o = o.transpose(1, 2) + + if normalize and output_z_state: + z_state = z_state + torch.sum(k, dim = -2, keepdim = True) + return o, final_state, z_state return o, final_state From f050482c07ef1ea1997fad22257f6afa7a900bd3 Mon Sep 17 00:00:00 2001 From: Y Song <56291756+yiyousong@users.noreply.github.com> Date: Thu, 3 Apr 2025 17:01:33 +0800 Subject: [PATCH 11/18] Update fused_chunk.py --- fla/ops/linear_attn/fused_chunk.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/fla/ops/linear_attn/fused_chunk.py b/fla/ops/linear_attn/fused_chunk.py index 2df944460..80923812b 100644 --- a/fla/ops/linear_attn/fused_chunk.py +++ b/fla/ops/linear_attn/fused_chunk.py @@ -330,5 +330,5 @@ def fused_chunk_linear_attn( if normalize and output_z_state: z_state = z_state + torch.sum(k, dim = -2, keepdim = True) - return o, final_state, z_state + return o, (final_state, z_state) return o, final_state From 246f17c4a9b036c5c75da99d11c965867fcfa963 Mon Sep 17 00:00:00 2001 From: Y Song <56291756+yiyousong@users.noreply.github.com> Date: Thu, 3 Apr 2025 17:02:13 +0800 Subject: [PATCH 12/18] Update fused_recurrent.py --- fla/ops/linear_attn/fused_recurrent.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/fla/ops/linear_attn/fused_recurrent.py b/fla/ops/linear_attn/fused_recurrent.py index ed27b6db0..37a5d93f2 100644 --- a/fla/ops/linear_attn/fused_recurrent.py +++ b/fla/ops/linear_attn/fused_recurrent.py @@ -289,5 +289,5 @@ def fused_recurrent_linear_attn( if normalize and output_z_state: z_state = z_state + torch.sum(k, dim = -2, keepdim = True) - return o, final_state, z_state + return o, (final_state, z_state) return o, final_state From c6ce801f25f1a216b6d475e4670581c371af2e47 Mon Sep 17 00:00:00 2001 From: Y Song <56291756+yiyousong@users.noreply.github.com> Date: Thu, 3 Apr 2025 17:02:44 +0800 Subject: 
[PATCH 13/18] Update utils.py --- fla/ops/linear_attn/utils.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/fla/ops/linear_attn/utils.py b/fla/ops/linear_attn/utils.py index 92c2b1e06..f0a308d04 100644 --- a/fla/ops/linear_attn/utils.py +++ b/fla/ops/linear_attn/utils.py @@ -4,10 +4,10 @@ import torch -@torch.compile -def normalize_output(q, k, o, cum_k=None): +@torch.jit.script +def normalize_output(q, k, o, z_state): k = k.cumsum(-2) - if cum_k is not None: - k = k + cum_k + k = k + z_state z = (q * k).sum(-1, keepdim=True) return o / (z + 1e-10) + From ed6e92c184aa55325a31d15817fffdb4215780ce Mon Sep 17 00:00:00 2001 From: Y Song <56291756+yiyousong@users.noreply.github.com> Date: Thu, 3 Apr 2025 17:09:25 +0800 Subject: [PATCH 14/18] Update utils.py now returns z. --- fla/ops/linear_attn/utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/fla/ops/linear_attn/utils.py b/fla/ops/linear_attn/utils.py index f0a308d04..8b3727bc1 100644 --- a/fla/ops/linear_attn/utils.py +++ b/fla/ops/linear_attn/utils.py @@ -9,5 +9,5 @@ def normalize_output(q, k, o, z_state): k = k.cumsum(-2) k = k + z_state z = (q * k).sum(-1, keepdim=True) - return o / (z + 1e-10) + return o / (z + 1e-10), k[...,-1:,:] From 3ee20dd9e7104334381fddefe0c0c3c14ce86764 Mon Sep 17 00:00:00 2001 From: Y Song <56291756+yiyousong@users.noreply.github.com> Date: Thu, 3 Apr 2025 17:10:07 +0800 Subject: [PATCH 15/18] Update fused_recurrent.py --- fla/ops/linear_attn/fused_recurrent.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/fla/ops/linear_attn/fused_recurrent.py b/fla/ops/linear_attn/fused_recurrent.py index 37a5d93f2..6e7e2f4fa 100644 --- a/fla/ops/linear_attn/fused_recurrent.py +++ b/fla/ops/linear_attn/fused_recurrent.py @@ -283,11 +283,10 @@ def fused_recurrent_linear_attn( k_shape = list(k.shape) k_shape[-2 ]= 1 z_state = k.new_zeros(k_shape) - o = normalize_output(q * scale, k, o, z_state) + o, z_state = normalize_output(q * scale, k, o, z_state) if not head_first: o = o.transpose(1, 2) if normalize and output_z_state: - z_state = z_state + torch.sum(k, dim = -2, keepdim = True) return o, (final_state, z_state) return o, final_state From 8806758c622eb80e9c4da7eba6b8094a9c7eeafd Mon Sep 17 00:00:00 2001 From: Y Song <56291756+yiyousong@users.noreply.github.com> Date: Thu, 3 Apr 2025 17:15:48 +0800 Subject: [PATCH 16/18] Update fused_chunk.py --- fla/ops/linear_attn/fused_chunk.py | 8 ++------ 1 file changed, 2 insertions(+), 6 deletions(-) diff --git a/fla/ops/linear_attn/fused_chunk.py b/fla/ops/linear_attn/fused_chunk.py index 80923812b..3fb54088e 100644 --- a/fla/ops/linear_attn/fused_chunk.py +++ b/fla/ops/linear_attn/fused_chunk.py @@ -280,7 +280,6 @@ def fused_chunk_linear_attn( output_final_state: bool = False, normalize: bool = True, head_first: bool = True, - output_z_state: bool = False ) -> Tuple[torch.Tensor, torch.Tensor]: r""" Args: @@ -303,8 +302,6 @@ def fused_chunk_linear_attn( Whether to normalize the output. Default: `True`. head_first (Optional[bool]): Whether the inputs are in the head-first format. Default: `True`. - output_z_state (Optional[bool]): - Whether to output the final Z state of shape `[B, H, K, 1]`. For API consistency, we recommend to update Z outside the function. Default: `False`. 
Returns: @@ -324,11 +321,10 @@ def fused_chunk_linear_attn( k_shape = list(k.shape) k_shape[-2 ]= 1 z_state = k.new_zeros(k_shape) - o = normalize_output(q * scale, k, o, z_state) + o, z_state = normalize_output(q * scale, k, o, z_state) if not head_first: o = o.transpose(1, 2) - if normalize and output_z_state: - z_state = z_state + torch.sum(k, dim = -2, keepdim = True) + if normalize: return o, (final_state, z_state) return o, final_state From 565bbb88353a083b2ca17595220de35c686dd1a1 Mon Sep 17 00:00:00 2001 From: Y Song <56291756+yiyousong@users.noreply.github.com> Date: Thu, 3 Apr 2025 17:18:26 +0800 Subject: [PATCH 17/18] Update fused_recurrent.py --- fla/ops/linear_attn/fused_recurrent.py | 10 +++------- 1 file changed, 3 insertions(+), 7 deletions(-) diff --git a/fla/ops/linear_attn/fused_recurrent.py b/fla/ops/linear_attn/fused_recurrent.py index 6e7e2f4fa..2394cf219 100644 --- a/fla/ops/linear_attn/fused_recurrent.py +++ b/fla/ops/linear_attn/fused_recurrent.py @@ -238,8 +238,7 @@ def fused_recurrent_linear_attn( z_state: torch.Tensor = None, output_final_state: bool = False, normalize: bool = True, - head_first: bool = True, - output_z_state: bool = False + head_first: bool = True ) -> Tuple[torch.Tensor, torch.Tensor]: r""" Args: @@ -261,10 +260,7 @@ def fused_recurrent_linear_attn( normalize (bool): Whether to normalize the output. Default: `True`. head_first (Optional[bool]): - Whether the inputs are in the head-first format. Default: `True`. - output_z_state (Optional[bool]): - Whether to output the final Z state of shape `[B, H, K, 1]`. This parameter is only effective when normalize=True. For API consistency, we recommend to update Z outside the function. Default: `False`. - + Whether the inputs are in the head-first format. Default: `True`. Returns: o (torch.Tensor): @@ -287,6 +283,6 @@ def fused_recurrent_linear_attn( if not head_first: o = o.transpose(1, 2) - if normalize and output_z_state: + if normalize: return o, (final_state, z_state) return o, final_state From 1db56cb0c2a44b25319ac5241d001f4d09d266a2 Mon Sep 17 00:00:00 2001 From: Yu Zhang Date: Thu, 3 Apr 2025 17:28:11 +0800 Subject: [PATCH 18/18] Fix typos --- fla/ops/linear_attn/fused_chunk.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/fla/ops/linear_attn/fused_chunk.py b/fla/ops/linear_attn/fused_chunk.py index 3fb54088e..e670ad32e 100644 --- a/fla/ops/linear_attn/fused_chunk.py +++ b/fla/ops/linear_attn/fused_chunk.py @@ -295,14 +295,13 @@ def fused_chunk_linear_attn( initial_state (Optional[torch.Tensor]): Initial state of shape `[B, H, K, V]`. Default: `None`. z_state (Optional[torch.Tensor]): - Z state Of shape `[B, H, K, 1]. This is only needed when normalization is enabled. `. Default: `None`. + Z state of shape `[B, H, K, 1]`. This is only needed when normalization is enabled. Default: `None`. output_final_state (Optional[bool]): Whether to output the final state of shape `[B, H, K, V]`. Default: `False`. normalize (bool): Whether to normalize the output. Default: `True`. head_first (Optional[bool]): Whether the inputs are in the head-first format. Default: `True`. - Returns: o (torch.Tensor):
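
Taken together, these patches let callers carry the normalizer's running key sum (the "z state") across calls, in the same way `initial_state` carries the running k^T v state. Below is a minimal streaming sketch against the API as it stands after PATCH 17, where `normalize=True` returns `o, (final_state, z_state)`. The import path, tensor shapes, CUDA device, and the closeness check are illustrative assumptions, not part of the patches.

# Hedged usage sketch: chunk-by-chunk processing with normalization carried via z_state.
# Assumes fused_recurrent_linear_attn is exported from fla.ops.linear_attn and a CUDA device is available.
import torch
from fla.ops.linear_attn import fused_recurrent_linear_attn

B, H, T, K, V = 2, 4, 128, 64, 64
q = torch.randn(B, H, T, K, device='cuda')
k = torch.randn(B, H, T, K, device='cuda')
v = torch.randn(B, H, T, V, device='cuda')

# Reference: a single normalized pass over the full sequence.
o_full, _ = fused_recurrent_linear_attn(q, k, v, normalize=True, head_first=True)

# Streaming: process two halves, threading both the recurrent state (sum of k^T v outer products)
# and the z state (running sum of k used by the output normalizer).
state, z = None, None
chunks = []
for s in range(0, T, T // 2):
    q_i, k_i, v_i = (x[:, :, s:s + T // 2] for x in (q, k, v))
    o_i, (state, z) = fused_recurrent_linear_attn(
        q_i, k_i, v_i,
        initial_state=state,
        z_state=z,
        output_final_state=True,
        normalize=True,
        head_first=True,
    )
    chunks.append(o_i)
o_stream = torch.cat(chunks, dim=2)

# If the z state is threaded correctly, the streamed result should closely match the full pass.
print(torch.allclose(o_full, o_stream, atol=1e-4))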