@@ -53,7 +53,6 @@
 CROSS_ATTN_KV_LENGTH = common_types.CROSS_ATTN_KV_LENGTH
 
 
-
 def _maybe_aqt_einsum(quant: Quant):
   return jnp.einsum if quant is None else quant.einsum()
 
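For context on the first hunk: _maybe_aqt_einsum returns the stock jnp.einsum when no quantization config is given and the AQT-provided einsum otherwise, so callers can swap implementations without branching at every call site. A minimal usage sketch (the contraction string, operand names, and the quant value here are illustrative, not taken from this file):

    einsum_fn = _maybe_aqt_einsum(quant)                    # jnp.einsum when quant is None
    attn_scores = einsum_fn("bqhd,bkhd->bhqk", query, key)  # used exactly like jnp.einsum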
@@ -448,7 +447,16 @@ def _apply_attention(
     )
   elif attention_kernel == "flash":
     return _tpu_flash_attention(
-        query, key * scale, value, heads, mesh, axis_names_q, axis_names_kv, flash_block_sizes, dtype, attention_kernel,
+        query,
+        key * scale,
+        value,
+        heads,
+        mesh,
+        axis_names_q,
+        axis_names_kv,
+        flash_block_sizes,
+        dtype,
+        attention_kernel,
     )
   elif attention_kernel == "ring":
     return _tpu_flash_attention(
@@ -733,7 +741,7 @@ def __init__(
     else:
       axis_names_q = (BATCH, CROSS_ATTN_HEAD, CROSS_ATTN_Q_LENGTH, D_KV)
      axis_names_kv = (BATCH, CROSS_ATTN_HEAD, CROSS_ATTN_KV_LENGTH, D_KV)
-
+
     self.attention_op = NNXAttentionOp(
         mesh=mesh,
         attention_kernel=attention_kernel,
@@ -1542,4 +1550,4 @@ def setup(self):
   def __call__(self, hidden_states, deterministic=True):
     hidden_states = self.proj(hidden_states)
     hidden_linear, hidden_gelu = jnp.split(hidden_states, 2, axis=2)
-    return self.dropout_layer(hidden_linear * nn.gelu(hidden_gelu), deterministic=deterministic)
+    return self.dropout_layer(hidden_linear * nn.gelu(hidden_gelu), deterministic=deterministic)
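The return line in the final hunk is a GEGLU-style feed-forward: the projection output is split in two, one half gates the other through a GELU, and dropout is applied to the product. A self-contained sketch of that pattern in Flax linen, using illustrative module and field names (GEGLUSketch, dim_out, dropout_rate) rather than the repository's own classes:

    import jax.numpy as jnp
    import flax.linen as nn

    class GEGLUSketch(nn.Module):
      dim_out: int
      dropout_rate: float = 0.0

      @nn.compact
      def __call__(self, hidden_states, deterministic=True):
        # Project to 2x the target width, then split into a linear half and a gate half.
        hidden_states = nn.Dense(self.dim_out * 2)(hidden_states)
        hidden_linear, hidden_gelu = jnp.split(hidden_states, 2, axis=-1)
        # GELU-gated product, followed by dropout (disabled when deterministic=True).
        gated = hidden_linear * nn.gelu(hidden_gelu)
        return nn.Dropout(self.dropout_rate)(gated, deterministic=deterministic)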