@@ -174,6 +174,7 @@ def _tpu_flash_attention(
     flash_block_sizes: BlockSizes,
     dtype: jnp.dtype = jnp.float32,
     attention_kernel: str = "flash",
+    is_self_attention: Optional[bool] = None,
 ) -> jax.Array:
   """TPU Flash Attention"""
 
@@ -201,8 +202,22 @@ def _tpu_flash_attention(
   query = _reshape_data_for_flash(query, heads)
   key = _reshape_data_for_flash(key, heads)
   value = _reshape_data_for_flash(value, heads)
-  q_axis_names = nn.logical_to_mesh_axes(axis_names_q)
-  kv_axis_names = nn.logical_to_mesh_axes(axis_names_kv)
+
+  # Use different sharding strategy for self-attn vs cross-attn
+  if is_self_attention is not None:
+    if is_self_attention:
+      # Self-attention: Context Parallelism (sharding along num_heads)
+      q_axis_names = PartitionSpec("data", ("fsdp", "tensor"), None, None)
+      kv_axis_names = PartitionSpec("data", ("fsdp", "tensor"), None, None)
+    else:
+      # Cross-attention: Sequence Parallelism for Q
+      # Q's sequence is sharded; K/V are replicated
+      q_axis_names = PartitionSpec("data", None, ("fsdp", "tensor"), None)
+      kv_axis_names = PartitionSpec("data", None, None, None)
+  else:
+    # Fallback to original maxdiffusion behavior if the flag isn't provided
+    q_axis_names = nn.logical_to_mesh_axes(axis_names_q)
+    kv_axis_names = nn.logical_to_mesh_axes(axis_names_kv)
 
   @functools.partial(
       shard_map.shard_map,
@@ -419,6 +434,7 @@ def _apply_attention(
     axis_names_kv: AxisNames,
     flash_block_sizes: BlockSizes,
     dpa_layer: Callable,
+    is_self_attention: bool = True,
 ):
   """Routes to different attention kernels."""
   _check_attention_inputs(query, key, value)
@@ -439,7 +455,7 @@ def _apply_attention(
     )
   elif attention_kernel == "flash":
     return _tpu_flash_attention(
-        query, key * scale, value, heads, mesh, axis_names_q, axis_names_kv, flash_block_sizes, dtype
+        query, key * scale, value, heads, mesh, axis_names_q, axis_names_kv, flash_block_sizes, dtype, attention_kernel, is_self_attention,
     )
   elif attention_kernel == "ring":
     return _tpu_flash_attention(
@@ -574,6 +590,7 @@ def __init__(
       flash_block_sizes: BlockSizes = None,
       dtype: DType = jnp.float32,
       quant: Quant = None,
+      is_self_attention: bool = True,
   ):
     self.dpa_layer = None
     if attention_kernel == "cudnn_flash_te":
@@ -593,6 +610,7 @@ def __init__(
     self.flash_block_sizes = flash_block_sizes
     self.dtype = dtype
     self.quant = quant
+    self.is_self_attention = is_self_attention
 
   def apply_attention(self, query: Array, key: Array, value: Array):
     return _apply_attention(
@@ -613,6 +631,7 @@ def apply_attention(self, query: Array, key: Array, value: Array):
         axis_names_kv=self.axis_names_kv,
         flash_block_sizes=self.flash_block_sizes,
         dpa_layer=self.dpa_layer,
+        is_self_attention=self.is_self_attention,
     )
 
 
@@ -701,6 +720,7 @@ def __init__(
       precision: jax.lax.Precision = None,
       qkv_bias: bool = False,
       quant: Quant = None,
+      is_self_attention: bool = True,
   ):
     if attention_kernel == "cudnn_flash_te":
      raise NotImplementedError(f"Wan 2.1 has not been tested with {attention_kernel}")
@@ -730,6 +750,7 @@ def __init__(
        flash_block_sizes=flash_block_sizes,
        dtype=dtype,
        quant=quant,
+       is_self_attention=is_self_attention,
     )
    # None axes corresponds to the stacked weights across all blocks
    # because of the use of nnx.vmap and nnx.scan.
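
For readers unfamiliar with the PartitionSpec layouts introduced in the first hunk, the standalone sketch below (not part of this commit; the mesh layout, shapes, and head counts are illustrative assumptions) shows what the two strategies mean for a [batch, heads, seq, head_dim] tensor on a ("data", "fsdp", "tensor") mesh: the self-attention spec splits the heads axis across the fsdp and tensor mesh axes, while the cross-attention spec splits Q's sequence axis and leaves K/V fully replicated.

# Illustrative sketch only: mesh shape, tensor shapes, and head counts are assumptions,
# not values taken from the PR. Run with however many devices are available.
import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec

# Hypothetical mesh: put all devices on the "tensor" axis; "data" and "fsdp" are size 1.
devices = np.array(jax.devices()).reshape(1, 1, -1)
mesh = Mesh(devices, axis_names=("data", "fsdp", "tensor"))

# [batch, heads, seq, head_dim] after _reshape_data_for_flash-style reshaping.
# The sharded axis (heads or q_seq) must be divisible by the fsdp*tensor mesh size.
q = jnp.zeros((1, 8, 1024, 64))   # long query sequence
kv = jnp.zeros((1, 8, 128, 64))   # short cross-attention key/value sequence

# Self-attention spec from the diff: shard the heads axis across ("fsdp", "tensor").
self_attn_spec = PartitionSpec("data", ("fsdp", "tensor"), None, None)

# Cross-attention specs from the diff: shard Q along its sequence axis; replicate K/V.
cross_q_spec = PartitionSpec("data", None, ("fsdp", "tensor"), None)
cross_kv_spec = PartitionSpec("data", None, None, None)

q_self = jax.device_put(q, NamedSharding(mesh, self_attn_spec))
q_cross = jax.device_put(q, NamedSharding(mesh, cross_q_spec))
kv_cross = jax.device_put(kv, NamedSharding(mesh, cross_kv_spec))

print(q_self.sharding)   # heads split across the fsdp*tensor devices
print(q_cross.sharding)  # q_seq split across devices
print(kv_cross.sharding) # every device holds the full K/V

With the self-attention spec each device keeps the full query and key/value sequences but only a slice of the heads, so the flash kernel sees complete sequences per shard; with the cross-attention spec each device holds a slice of the long Q sequence plus a full copy of the short cross-attention K/V, which is why K/V can simply be replicated there.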