
Commit 8c4f774

update sharding rules for attn.
1 parent 320d282 commit 8c4f774

3 files changed: +43, -25 lines changed


src/maxdiffusion/common_types.py

Lines changed: 9 additions & 0 deletions

@@ -44,4 +44,13 @@
 KEEP_2 = "activation_keep_2"
 CONV_OUT = "activation_conv_out_channels"
 
+# For setting self/cross attention independently in splash kernel
+SELF_ATTN_HEAD = "activation_self_attn_heads"
+SELF_ATTN_Q_LENGTH = "activation_self_attn_q_length"
+SELF_ATTN_KV_LENGTH = "activation_self_attn_kv_length"
+CROSS_ATTN_HEAD = "activation_cross_attn_heads"
+CROSS_ATTN_Q_LENGTH = "activation_cross_attn_q_length"
+CROSS_ATTN_KV_LENGTH = "activation_cross_attn_kv_length"
+
+
 WAN_MODEL = "Wan2.1"
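Note: these constants are plain logical axis names. On their own they do nothing; they only affect sharding once a matching entry exists in logical_axis_rules (added to base_wan_14b.yml below) and flax resolves them to mesh axes. A minimal sketch of that resolution, with an illustrative rule and literal name strings (the full self/cross mapping is exercised in the config and attention_flax.py changes below):

import flax.linen as nn

# Illustrative rule: shard the self-attention heads axis over ('fsdp', 'tensor').
rules = (("activation_self_attn_heads", ("fsdp", "tensor")),)

spec = nn.logical_to_mesh_axes(
    ("activation_batch", "activation_self_attn_heads",
     "activation_self_attn_q_length", "activation_kv"),
    rules,
)
# Names without a matching rule resolve to None, i.e. that dimension stays replicated:
print(spec)  # -> PartitionSpec(None, ('fsdp', 'tensor'), None, None)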

src/maxdiffusion/configs/base_wan_14b.yml

Lines changed: 14 additions & 2 deletions

@@ -68,10 +68,21 @@ flash_block_sizes: {}
 # "block_kv" : 2048,
 # "block_q_dkv" : 3024,
 # "block_kv_dkv" : 2048,
-# "block_kv_dkv_compute" : 2048,
+# "block_kv_dkv_compute" : 1024,
 # "block_q_dq" : 3024,
 # "block_kv_dq" : 2048
 # }
+# Use on v5p
+flash_block_sizes: {
+  "block_q" : 1024,
+  "block_kv_compute" : 256,
+  "block_kv" : 3072,
+  "block_q_dkv" : 1024,
+  "block_kv_dkv" : 3072,
+  "block_kv_dkv_compute" : 256,
+  "block_q_dq" : 1024,
+  "block_kv_dq" : 3072
+}
 # GroupNorm groups
 norm_num_groups: 32
 
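These keys match the fields of the splash-attention BlockSizes dataclass shipped with JAX's Pallas TPU ops, which the commit message says this kernel targets. A rough sketch of how such a mapping could be turned into block sizes; the helper blocks_from_config is hypothetical (not maxdiffusion's actual config loader), and the import path may vary across JAX versions:

from jax.experimental.pallas.ops.tpu.splash_attention import splash_attention_kernel

def blocks_from_config(cfg: dict) -> splash_attention_kernel.BlockSizes:
  # cfg is the flash_block_sizes mapping from the YAML above.
  return splash_attention_kernel.BlockSizes(
      block_q=cfg["block_q"],
      block_kv_compute=cfg["block_kv_compute"],
      block_kv=cfg["block_kv"],
      block_q_dkv=cfg["block_q_dkv"],
      block_kv_dkv=cfg["block_kv_dkv"],
      block_kv_dkv_compute=cfg["block_kv_dkv_compute"],
      block_q_dq=cfg["block_q_dq"],
      block_kv_dq=cfg["block_kv_dq"],
  )

# Example with the v5p values above:
v5p_blocks = blocks_from_config({
    "block_q": 1024, "block_kv_compute": 256, "block_kv": 3072,
    "block_q_dkv": 1024, "block_kv_dkv": 3072, "block_kv_dkv_compute": 256,
    "block_q_dq": 1024, "block_kv_dq": 3072,
})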

@@ -132,8 +143,9 @@ mesh_axes: ['data', 'fsdp', 'tensor']
 logical_axis_rules: [
   ['batch', 'data'],
   ['activation_batch', 'data'],
+  ['activation_self_attn_heads', ['fsdp', 'tensor']],
+  ['activation_cross_attn_q_length', ['fsdp', 'tensor']],
   ['activation_length', 'fsdp'],
-
   ['activation_heads', 'tensor'],
   ['mlp','tensor'],
   ['embed','fsdp'],
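With just these two added rules, the new logical names resolve to the same placements that were previously hard-coded inside _tpu_flash_attention (removed in attention_flax.py below): self-attention shards heads over ('fsdp', 'tensor'), cross-attention shards the query sequence, and the names deliberately left without a rule (e.g. the cross-attention heads and K/V length) stay replicated. A quick check with an abbreviated, illustrative rule list rather than the full config:

import flax.linen as nn

rules = (
    ("activation_batch", "data"),
    ("activation_self_attn_heads", ("fsdp", "tensor")),
    ("activation_cross_attn_q_length", ("fsdp", "tensor")),
)

self_q = ("activation_batch", "activation_self_attn_heads",
          "activation_self_attn_q_length", "activation_kv")
cross_q = ("activation_batch", "activation_cross_attn_heads",
           "activation_cross_attn_q_length", "activation_kv")
cross_kv = ("activation_batch", "activation_cross_attn_heads",
            "activation_cross_attn_kv_length", "activation_kv")

print(nn.logical_to_mesh_axes(self_q, rules))
# -> PartitionSpec('data', ('fsdp', 'tensor'), None, None)   heads sharded (context parallel)
print(nn.logical_to_mesh_axes(cross_q, rules))
# -> PartitionSpec('data', None, ('fsdp', 'tensor'), None)   Q sequence sharded
print(nn.logical_to_mesh_axes(cross_kv, rules))
# -> PartitionSpec('data', None, None, None)                 K/V replicated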

src/maxdiffusion/models/attention_flax.py

Lines changed: 20 additions & 23 deletions

@@ -45,6 +45,14 @@
 EMBED = common_types.EMBED
 Quant = quantizations.AqtQuantization
 
+SELF_ATTN_HEAD = common_types.SELF_ATTN_HEAD
+SELF_ATTN_Q_LENGTH = common_types.SELF_ATTN_Q_LENGTH
+SELF_ATTN_KV_LENGTH = common_types.SELF_ATTN_KV_LENGTH
+CROSS_ATTN_HEAD = common_types.CROSS_ATTN_HEAD
+CROSS_ATTN_Q_LENGTH = common_types.CROSS_ATTN_Q_LENGTH
+CROSS_ATTN_KV_LENGTH = common_types.CROSS_ATTN_KV_LENGTH
+
+
 
 def _maybe_aqt_einsum(quant: Quant):
   return jnp.einsum if quant is None else quant.einsum()

@@ -174,7 +182,6 @@ def _tpu_flash_attention(
     flash_block_sizes: BlockSizes,
     dtype: jnp.dtype = jnp.float32,
     attention_kernel: str = "flash",
-    is_self_attention: Optional[bool] = None,
 ) -> jax.Array:
   """TPU Flash Attention"""
 

@@ -203,22 +210,8 @@
   query = _reshape_data_for_flash(query, heads)
   key = _reshape_data_for_flash(key, heads)
   value = _reshape_data_for_flash(value, heads)
-
-  # Use different sharding strategy for self-attn vs cross-attn
-  if is_self_attention is not None:
-    if is_self_attention:
-      # Self-attention: Context Parallelism (sharding along num_heads)
-      q_axis_names = PartitionSpec("data", ("fsdp", "tensor"), None, None)
-      kv_axis_names = PartitionSpec("data", ("fsdp", "tensor"), None, None)
-    else:
-      # Cross-attention: Sequence Parallelism for Q
-      # Q's sequence is sharded; K/V are replicated
-      q_axis_names = PartitionSpec("data", None, ("fsdp", "tensor"), None)
-      kv_axis_names = PartitionSpec("data", None, None, None)
-  else:
-    # Fallback to original maxdiffusion behavior if the flag isn't provided
-    q_axis_names = nn.logical_to_mesh_axes(axis_names_q)
-    kv_axis_names = nn.logical_to_mesh_axes(axis_names_kv)
+  q_axis_names = nn.logical_to_mesh_axes(axis_names_q)
+  kv_axis_names = nn.logical_to_mesh_axes(axis_names_kv)
 
   @functools.partial(
       shard_map.shard_map,
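Aside: with the branch above removed, both specs come from the caller-supplied logical axis names, and the shard_map wrapper that follows the deleted block consumes them unchanged. For readers unfamiliar with that pattern, a simplified, self-contained sketch of it (illustrative mesh shape, tensor shapes, and kernel body; not maxdiffusion's actual wrapper):

import functools
import jax
import jax.numpy as jnp
from jax.experimental import mesh_utils, shard_map
from jax.sharding import Mesh, PartitionSpec as P

# Illustrative 3-axis mesh matching the config's mesh_axes.
devices = mesh_utils.create_device_mesh((1, 1, jax.device_count()))
mesh = Mesh(devices, ("data", "fsdp", "tensor"))

# Self-attention case: heads sharded over ('fsdp', 'tensor') for both Q and K/V.
q_spec = P("data", ("fsdp", "tensor"), None, None)
kv_spec = P("data", ("fsdp", "tensor"), None, None)

@functools.partial(
    shard_map.shard_map, mesh=mesh,
    in_specs=(q_spec, kv_spec, kv_spec), out_specs=q_spec, check_rep=False)
def sharded_attention(q, k, v):
  # Stand-in for the splash/flash kernel: plain softmax attention on each shard.
  weights = jax.nn.softmax(jnp.einsum("bhqd,bhkd->bhqk", q, k), axis=-1)
  return jnp.einsum("bhqk,bhkd->bhqd", weights, v)

# Head count must be divisible by the number of devices on the sharded axes.
q = k = v = jnp.ones((1, 8, 128, 64), dtype=jnp.bfloat16)
out = sharded_attention(q, k, v)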
@@ -435,7 +428,6 @@ def _apply_attention(
     axis_names_kv: AxisNames,
     flash_block_sizes: BlockSizes,
     dpa_layer: Callable,
-    is_self_attention: bool = True,
 ):
   """Routes to different attention kernels."""
   _check_attention_inputs(query, key, value)

@@ -456,7 +448,7 @@
     )
   elif attention_kernel == "flash":
     return _tpu_flash_attention(
-        query, key * scale, value, heads, mesh, axis_names_q, axis_names_kv, flash_block_sizes, dtype, attention_kernel, is_self_attention,
+        query, key * scale, value, heads, mesh, axis_names_q, axis_names_kv, flash_block_sizes, dtype, attention_kernel,
     )
   elif attention_kernel == "ring":
     return _tpu_flash_attention(

@@ -591,7 +583,6 @@ def __init__(
       flash_block_sizes: BlockSizes = None,
       dtype: DType = jnp.float32,
       quant: Quant = None,
-      is_self_attention: bool = True,
   ):
     self.dpa_layer = None
     if attention_kernel == "cudnn_flash_te":

@@ -611,7 +602,6 @@ def __init__(
     self.flash_block_sizes = flash_block_sizes
     self.dtype = dtype
     self.quant = quant
-    self.is_self_attention = is_self_attention
 
   def apply_attention(self, query: Array, key: Array, value: Array):
     return _apply_attention(

@@ -632,7 +622,6 @@ def apply_attention(self, query: Array, key: Array, value: Array):
         axis_names_kv=self.axis_names_kv,
         flash_block_sizes=self.flash_block_sizes,
         dpa_layer=self.dpa_layer,
-        is_self_attention=self.is_self_attention,
     )
 
 

@@ -738,6 +727,13 @@ def __init__(
     self.value_axis_names = value_axis_names
     self.out_axis_names = out_axis_names
 
+    if is_self_attention:
+      axis_names_q = (BATCH, SELF_ATTN_HEAD, SELF_ATTN_Q_LENGTH, D_KV)
+      axis_names_kv = (BATCH, SELF_ATTN_HEAD, SELF_ATTN_KV_LENGTH, D_KV)
+    else:
+      axis_names_q = (BATCH, CROSS_ATTN_HEAD, CROSS_ATTN_Q_LENGTH, D_KV)
+      axis_names_kv = (BATCH, CROSS_ATTN_HEAD, CROSS_ATTN_KV_LENGTH, D_KV)
+
     self.attention_op = NNXAttentionOp(
         mesh=mesh,
         attention_kernel=attention_kernel,

@@ -747,11 +743,12 @@
         use_memory_efficient_attention=use_memory_efficient_attention,
         split_head_dim=split_head_dim,
         float32_qk_product=False,
+        axis_names_q=axis_names_q,
+        axis_names_kv=axis_names_kv,
         flash_min_seq_length=flash_min_seq_length,
         flash_block_sizes=flash_block_sizes,
         dtype=dtype,
         quant=quant,
-        is_self_attention=is_self_attention,
     )
     # None axes corresponds to the stacked weights across all blocks
     # because of the use of nnx.vmap and nnx.scan.
