
Commit 3866671

[WAN] Use different sharding strategy for self and cross attention.
1 parent 95afb77 commit 3866671

2 files changed: 26 additions & 3 deletions

src/maxdiffusion/models/attention_flax.py

Lines changed: 24 additions & 3 deletions
@@ -174,6 +174,7 @@ def _tpu_flash_attention(
     flash_block_sizes: BlockSizes,
     dtype: jnp.dtype = jnp.float32,
     attention_kernel: str = "flash",
+    is_self_attention: Optional[bool] = None,
 ) -> jax.Array:
   """TPU Flash Attention"""

@@ -201,8 +202,22 @@ def _tpu_flash_attention(
   query = _reshape_data_for_flash(query, heads)
   key = _reshape_data_for_flash(key, heads)
   value = _reshape_data_for_flash(value, heads)
-  q_axis_names = nn.logical_to_mesh_axes(axis_names_q)
-  kv_axis_names = nn.logical_to_mesh_axes(axis_names_kv)
+
+  # Use different sharding strategy for self-attn vs cross-attn
+  if is_self_attention is not None:
+    if is_self_attention:
+      # Self-attention: Context Parallelism (sharding along num_heads)
+      q_axis_names = PartitionSpec("data", ("fsdp", "tensor"), None, None)
+      kv_axis_names = PartitionSpec("data", ("fsdp", "tensor"), None, None)
+    else:
+      # Cross-attention: Sequence Parallelism for Q
+      # Q's sequence is sharded; K/V are replicated
+      q_axis_names = PartitionSpec("data", None, ("fsdp", "tensor"), None)
+      kv_axis_names = PartitionSpec("data", None, None, None)
+  else:
+    # Fallback to original maxdiffusion behavior if the flag isn't provided
+    q_axis_names = nn.logical_to_mesh_axes(axis_names_q)
+    kv_axis_names = nn.logical_to_mesh_axes(axis_names_kv)

   @functools.partial(
       shard_map.shard_map,
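
The two branches above differ only in which dimension of the (batch, heads, seq, head_dim) activations is split across the mesh: self-attention splits the heads axis, while cross-attention splits Q's sequence axis and keeps K/V whole on every device. The following standalone sketch is not part of the commit; the mesh layout and shapes are illustrative, and only the ("data", "fsdp", "tensor") axis names and PartitionSpecs are taken from the diff:

    import numpy as np
    import jax
    import jax.numpy as jnp
    from jax.sharding import Mesh, NamedSharding, PartitionSpec

    # Illustrative 1 x N x 1 mesh over all local devices
    # (N must evenly divide whichever dimension gets sharded).
    devices = np.array(jax.devices()).reshape(1, -1, 1)
    mesh = Mesh(devices, axis_names=("data", "fsdp", "tensor"))

    # Placeholder (batch, heads, seq, head_dim) activation.
    q = jnp.zeros((1, 16, 4096, 128), dtype=jnp.bfloat16)

    # Self-attention branch: shard the heads axis across ("fsdp", "tensor").
    self_spec = PartitionSpec("data", ("fsdp", "tensor"), None, None)
    q_self = jax.device_put(q, NamedSharding(mesh, self_spec))

    # Cross-attention branch: shard Q's sequence axis instead; K/V would use
    # PartitionSpec("data", None, None, None), i.e. fully replicated.
    cross_q_spec = PartitionSpec("data", None, ("fsdp", "tensor"), None)
    q_cross = jax.device_put(q, NamedSharding(mesh, cross_q_spec))

    print(q_self.sharding.spec)   # heads split, sequence whole
    print(q_cross.sharding.spec)  # sequence split, heads whole

Under either spec, each device holds everything it needs for its shard of the attention problem (a subset of heads over the full sequence, or a chunk of queries against replicated keys/values), so the flash kernel inside shard_map needs no cross-device communication.
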
@@ -419,6 +434,7 @@ def _apply_attention(
     axis_names_kv: AxisNames,
     flash_block_sizes: BlockSizes,
     dpa_layer: Callable,
+    is_self_attention: bool = True,
 ):
   """Routes to different attention kernels."""
   _check_attention_inputs(query, key, value)
@@ -439,7 +455,7 @@ def _apply_attention(
     )
   elif attention_kernel == "flash":
     return _tpu_flash_attention(
-        query, key * scale, value, heads, mesh, axis_names_q, axis_names_kv, flash_block_sizes, dtype
+        query, key * scale, value, heads, mesh, axis_names_q, axis_names_kv, flash_block_sizes, dtype, attention_kernel, is_self_attention,
     )
   elif attention_kernel == "ring":
     return _tpu_flash_attention(
@@ -574,6 +590,7 @@ def __init__(
       flash_block_sizes: BlockSizes = None,
       dtype: DType = jnp.float32,
       quant: Quant = None,
+      is_self_attention: bool = True,
   ):
     self.dpa_layer = None
     if attention_kernel == "cudnn_flash_te":
@@ -593,6 +610,7 @@ def __init__(
     self.flash_block_sizes = flash_block_sizes
     self.dtype = dtype
     self.quant = quant
+    self.is_self_attention = is_self_attention

   def apply_attention(self, query: Array, key: Array, value: Array):
     return _apply_attention(
@@ -613,6 +631,7 @@ def apply_attention(self, query: Array, key: Array, value: Array):
         axis_names_kv=self.axis_names_kv,
         flash_block_sizes=self.flash_block_sizes,
         dpa_layer=self.dpa_layer,
+        is_self_attention=self.is_self_attention,
     )


@@ -701,6 +720,7 @@ def __init__(
       precision: jax.lax.Precision = None,
       qkv_bias: bool = False,
       quant: Quant = None,
+      is_self_attention: bool = True,
   ):
     if attention_kernel == "cudnn_flash_te":
       raise NotImplementedError(f"Wan 2.1 has not been tested with {attention_kernel}")
@@ -730,6 +750,7 @@ def __init__(
         flash_block_sizes=flash_block_sizes,
         dtype=dtype,
         quant=quant,
+        is_self_attention=is_self_attention,
     )
     # None axes corresponds to the stacked weights across all blocks
     # because of the use of nnx.vmap and nnx.scan.
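
The remaining hunks only thread the new flag from the attention wrapper down to _tpu_flash_attention; call sites that never pass it land in the else branch, where logical axis names are resolved through Flax's axis rules exactly as before. A minimal sketch of that fallback mechanism follows; the rule names here are made up for illustration, since the repo's real logical-to-mesh rules come from its configuration:

    import flax.linen as nn

    # Hypothetical logical-axis rules; maxdiffusion defines its own set in config.
    rules = (
        ("activation_batch", "data"),
        ("activation_heads", ("fsdp", "tensor")),
        ("activation_length", None),
        ("activation_kv", None),
    )

    with nn.logical_axis_rules(rules):
      # Same call as the fallback branch: logical names in, PartitionSpec out.
      q_spec = nn.logical_to_mesh_axes(
          ("activation_batch", "activation_heads", "activation_length", "activation_kv")
      )

    print(q_spec)  # e.g. PartitionSpec('data', ('fsdp', 'tensor'), None, None)
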

src/maxdiffusion/models/wan/transformers/transformer_wan.py

Lines changed: 2 additions & 0 deletions
@@ -282,6 +282,7 @@ def __init__(
         precision=precision,
         attention_kernel=attention,
         dropout=dropout,
+        is_self_attention=True,
     )

     # 1. Cross-attention
@@ -300,6 +301,7 @@ def __init__(
         precision=precision,
         attention_kernel=attention,
         dropout=dropout,
+        is_self_attention=False,
     )
     assert cross_attn_norm is True
     self.norm2 = FP32LayerNorm(rngs=rngs, dim=dim, eps=eps, elementwise_affine=True)
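
In the WAN block, the first attention runs over the video-latent sequence (self-attention) and the second attends from those latents to the much shorter conditioning sequence produced by the text encoder (cross-attention), so only Q is long in the cross-attention case. A rough size comparison suggests why replicating cross-attention K/V while sharding Q's sequence is a reasonable trade; the head count, head_dim, and sequence lengths below are assumptions for illustration, not values from the repo:

    # Back-of-the-envelope comparison; all shapes are illustrative assumptions.
    def mib(batch, heads, seq, head_dim, bytes_per_elem=2):  # bf16 activations
      return batch * heads * seq * head_dim * bytes_per_elem / 2**20

    q_mib = mib(batch=1, heads=40, seq=75_600, head_dim=128)    # video latents (assumed)
    kv_mib = 2 * mib(batch=1, heads=40, seq=512, head_dim=128)  # text tokens (assumed)

    print(f"cross-attn Q   ~{q_mib:7.1f} MiB -> worth sharding along the sequence axis")
    print(f"cross-attn K+V ~{kv_mib:7.1f} MiB -> cheap to replicate on every device")

Self-attention, by contrast, has Q, K, and V all defined over the long video sequence, so the commit shards those along the heads axis instead, giving each device a 1/N slice of every tensor.
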
