 EMBED = common_types.EMBED
 Quant = quantizations.AqtQuantization
 
+SELF_ATTN_HEAD = common_types.SELF_ATTN_HEAD
+SELF_ATTN_Q_LENGTH = common_types.SELF_ATTN_Q_LENGTH
+SELF_ATTN_KV_LENGTH = common_types.SELF_ATTN_KV_LENGTH
+CROSS_ATTN_HEAD = common_types.CROSS_ATTN_HEAD
+CROSS_ATTN_Q_LENGTH = common_types.CROSS_ATTN_Q_LENGTH
+CROSS_ATTN_KV_LENGTH = common_types.CROSS_ATTN_KV_LENGTH
+
+
 
 def _maybe_aqt_einsum(quant: Quant):
   return jnp.einsum if quant is None else quant.einsum()
@@ -174,7 +182,6 @@ def _tpu_flash_attention(
     flash_block_sizes: BlockSizes,
     dtype: jnp.dtype = jnp.float32,
     attention_kernel: str = "flash",
-    is_self_attention: Optional[bool] = None,
 ) -> jax.Array:
   """TPU Flash Attention"""
 
@@ -203,22 +210,8 @@ def _tpu_flash_attention(
   query = _reshape_data_for_flash(query, heads)
   key = _reshape_data_for_flash(key, heads)
   value = _reshape_data_for_flash(value, heads)
-
-  # Use different sharding strategy for self-attn vs cross-attn
-  if is_self_attention is not None:
-    if is_self_attention:
-      # Self-attention: Context Parallelism (sharding along num_heads)
-      q_axis_names = PartitionSpec("data", ("fsdp", "tensor"), None, None)
-      kv_axis_names = PartitionSpec("data", ("fsdp", "tensor"), None, None)
-    else:
-      # Cross-attention: Sequence Parallelism for Q
-      # Q's sequence is sharded; K/V are replicated
-      q_axis_names = PartitionSpec("data", None, ("fsdp", "tensor"), None)
-      kv_axis_names = PartitionSpec("data", None, None, None)
-  else:
-    # Fallback to original maxdiffusion behavior if the flag isn't provided
-    q_axis_names = nn.logical_to_mesh_axes(axis_names_q)
-    kv_axis_names = nn.logical_to_mesh_axes(axis_names_kv)
+  q_axis_names = nn.logical_to_mesh_axes(axis_names_q)
+  kv_axis_names = nn.logical_to_mesh_axes(axis_names_kv)
 
   @functools.partial(
       shard_map.shard_map,
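
The deleted branch hard-coded PartitionSpecs for the self- and cross-attention cases; the new code always resolves the sharding through nn.logical_to_mesh_axes, so the same layouts are now expressed as logical-axis rules instead of kernel code. A minimal, self-contained sketch of that resolution follows; the axis-name strings and the rule set are illustrative assumptions, not the project's actual config:

import flax.linen as nn

# Hypothetical logical axis-name strings standing in for the common_types constants.
BATCH = "activation_batch"
SELF_ATTN_HEAD = "activation_self_attn_heads"
SELF_ATTN_Q_LENGTH = "activation_self_attn_q_length"
D_KV = "activation_kv"

# Illustrative rules mapping logical names to mesh axes ("data", "fsdp", "tensor").
rules = (
    (BATCH, "data"),
    (SELF_ATTN_HEAD, ("fsdp", "tensor")),  # shard attention heads
    (SELF_ATTN_Q_LENGTH, None),            # replicate the query sequence
    (D_KV, None),                          # replicate the head dimension
)

axis_names_q = (BATCH, SELF_ATTN_HEAD, SELF_ATTN_Q_LENGTH, D_KV)
with nn.logical_axis_rules(rules):
  # Expected: PartitionSpec('data', ('fsdp', 'tensor'), None, None),
  # the same layout the deleted self-attention branch spelled out by hand.
  print(nn.logical_to_mesh_axes(axis_names_q))

Under rules like these the resolved spec matches the removed hard-coded self-attention case, and changing the rules, rather than the kernel, is what now controls the sharding.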
@@ -435,7 +428,6 @@ def _apply_attention(
     axis_names_kv: AxisNames,
     flash_block_sizes: BlockSizes,
     dpa_layer: Callable,
-    is_self_attention: bool = True,
 ):
   """Routes to different attention kernels."""
   _check_attention_inputs(query, key, value)
@@ -456,7 +448,7 @@ def _apply_attention(
     )
   elif attention_kernel == "flash":
     return _tpu_flash_attention(
-        query, key * scale, value, heads, mesh, axis_names_q, axis_names_kv, flash_block_sizes, dtype, attention_kernel, is_self_attention,
+        query, key * scale, value, heads, mesh, axis_names_q, axis_names_kv, flash_block_sizes, dtype, attention_kernel,
     )
   elif attention_kernel == "ring":
     return _tpu_flash_attention(
@@ -591,7 +583,6 @@ def __init__(
       flash_block_sizes: BlockSizes = None,
       dtype: DType = jnp.float32,
       quant: Quant = None,
-      is_self_attention: bool = True,
   ):
     self.dpa_layer = None
     if attention_kernel == "cudnn_flash_te":
@@ -611,7 +602,6 @@ def __init__(
     self.flash_block_sizes = flash_block_sizes
     self.dtype = dtype
     self.quant = quant
-    self.is_self_attention = is_self_attention
 
   def apply_attention(self, query: Array, key: Array, value: Array):
     return _apply_attention(
@@ -632,7 +622,6 @@ def apply_attention(self, query: Array, key: Array, value: Array):
         axis_names_kv=self.axis_names_kv,
         flash_block_sizes=self.flash_block_sizes,
         dpa_layer=self.dpa_layer,
-        is_self_attention=self.is_self_attention,
     )
 
 
@@ -738,6 +727,13 @@ def __init__(
     self.value_axis_names = value_axis_names
     self.out_axis_names = out_axis_names
 
+    if is_self_attention:
+      axis_names_q = (BATCH, SELF_ATTN_HEAD, SELF_ATTN_Q_LENGTH, D_KV)
+      axis_names_kv = (BATCH, SELF_ATTN_HEAD, SELF_ATTN_KV_LENGTH, D_KV)
+    else:
+      axis_names_q = (BATCH, CROSS_ATTN_HEAD, CROSS_ATTN_Q_LENGTH, D_KV)
+      axis_names_kv = (BATCH, CROSS_ATTN_HEAD, CROSS_ATTN_KV_LENGTH, D_KV)
+
     self.attention_op = NNXAttentionOp(
         mesh=mesh,
         attention_kernel=attention_kernel,
@@ -747,11 +743,12 @@ def __init__(
         use_memory_efficient_attention=use_memory_efficient_attention,
         split_head_dim=split_head_dim,
         float32_qk_product=False,
+        axis_names_q=axis_names_q,
+        axis_names_kv=axis_names_kv,
         flash_min_seq_length=flash_min_seq_length,
         flash_block_sizes=flash_block_sizes,
         dtype=dtype,
         quant=quant,
-        is_self_attention=is_self_attention,
     )
     # None axes corresponds to the stacked weights across all blocks
     # because of the use of nnx.vmap and nnx.scan.
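
Taken together, the two __init__ hunks mean is_self_attention now only selects logical axis names at layer construction time; the physical layout comes from whatever logical-axis rules are in effect. A standalone sketch of that selection and of rules that would reproduce the deleted behaviour (the helper name, axis-name strings, and rules below are hypothetical, for illustration only):

import flax.linen as nn

# Hypothetical axis-name strings standing in for the common_types constants.
BATCH, D_KV = "activation_batch", "activation_kv"
SELF_ATTN_HEAD = "activation_self_attn_heads"
SELF_ATTN_Q_LENGTH = "activation_self_attn_q_length"
SELF_ATTN_KV_LENGTH = "activation_self_attn_kv_length"
CROSS_ATTN_HEAD = "activation_cross_attn_heads"
CROSS_ATTN_Q_LENGTH = "activation_cross_attn_q_length"
CROSS_ATTN_KV_LENGTH = "activation_cross_attn_kv_length"


def pick_axis_names(is_self_attention: bool):
  # Mirrors the branch added to __init__: the flag picks logical names only.
  if is_self_attention:
    return ((BATCH, SELF_ATTN_HEAD, SELF_ATTN_Q_LENGTH, D_KV),
            (BATCH, SELF_ATTN_HEAD, SELF_ATTN_KV_LENGTH, D_KV))
  return ((BATCH, CROSS_ATTN_HEAD, CROSS_ATTN_Q_LENGTH, D_KV),
          (BATCH, CROSS_ATTN_HEAD, CROSS_ATTN_KV_LENGTH, D_KV))


# Illustrative rules matching the deleted hard-coded specs: shard heads for
# self-attention, shard the query sequence for cross-attention, and leave
# unmapped names (e.g. cross-attention K/V length) replicated.
rules = (
    (BATCH, "data"),
    (SELF_ATTN_HEAD, ("fsdp", "tensor")),
    (CROSS_ATTN_Q_LENGTH, ("fsdp", "tensor")),
)

with nn.logical_axis_rules(rules):
  q_names, kv_names = pick_axis_names(is_self_attention=False)
  print(nn.logical_to_mesh_axes(q_names))   # PartitionSpec('data', None, ('fsdp', 'tensor'), None)
  print(nn.logical_to_mesh_axes(kv_names))  # PartitionSpec('data', None, None, None)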