From 3866671652df4dba7393156eeb24490bd0c1fd9f Mon Sep 17 00:00:00 2001
From: Hyesoo Yang
Date: Wed, 17 Sep 2025 06:53:03 +0000
Subject: [PATCH] [WAN] Use different sharding strategy for self and cross
 attention.

---
 src/maxdiffusion/models/attention_flax.py    | 27 ++++++++++++++++---
 .../wan/transformers/transformer_wan.py      |  2 ++
 2 files changed, 26 insertions(+), 3 deletions(-)

diff --git a/src/maxdiffusion/models/attention_flax.py b/src/maxdiffusion/models/attention_flax.py
index 5df5f334..d986f4c7 100644
--- a/src/maxdiffusion/models/attention_flax.py
+++ b/src/maxdiffusion/models/attention_flax.py
@@ -174,6 +174,7 @@ def _tpu_flash_attention(
     flash_block_sizes: BlockSizes,
     dtype: jnp.dtype = jnp.float32,
     attention_kernel: str = "flash",
+    is_self_attention: Optional[bool] = None,
 ) -> jax.Array:
   """TPU Flash Attention"""
 
@@ -201,8 +202,22 @@
   query = _reshape_data_for_flash(query, heads)
   key = _reshape_data_for_flash(key, heads)
   value = _reshape_data_for_flash(value, heads)
-  q_axis_names = nn.logical_to_mesh_axes(axis_names_q)
-  kv_axis_names = nn.logical_to_mesh_axes(axis_names_kv)
+
+  # Use different sharding strategies for self-attention vs. cross-attention.
+  if is_self_attention is not None:
+    if is_self_attention:
+      # Self-attention: shard along the num_heads dimension.
+      q_axis_names = PartitionSpec("data", ("fsdp", "tensor"), None, None)
+      kv_axis_names = PartitionSpec("data", ("fsdp", "tensor"), None, None)
+    else:
+      # Cross-attention: sequence parallelism for Q.
+      # Q's sequence dimension is sharded; K/V are replicated.
+      q_axis_names = PartitionSpec("data", None, ("fsdp", "tensor"), None)
+      kv_axis_names = PartitionSpec("data", None, None, None)
+  else:
+    # Fall back to the original maxdiffusion behavior if the flag isn't provided.
+    q_axis_names = nn.logical_to_mesh_axes(axis_names_q)
+    kv_axis_names = nn.logical_to_mesh_axes(axis_names_kv)
 
   @functools.partial(
       shard_map.shard_map,
@@ -419,6 +434,7 @@ def _apply_attention(
     axis_names_kv: AxisNames,
     flash_block_sizes: BlockSizes,
     dpa_layer: Callable,
+    is_self_attention: bool = True,
 ):
   """Routes to different attention kernels."""
   _check_attention_inputs(query, key, value)
@@ -439,7 +455,7 @@ def _apply_attention(
     )
   elif attention_kernel == "flash":
     return _tpu_flash_attention(
-        query, key * scale, value, heads, mesh, axis_names_q, axis_names_kv, flash_block_sizes, dtype
+        query, key * scale, value, heads, mesh, axis_names_q, axis_names_kv, flash_block_sizes, dtype, attention_kernel, is_self_attention,
     )
   elif attention_kernel == "ring":
     return _tpu_flash_attention(
@@ -574,6 +590,7 @@ def __init__(
       flash_block_sizes: BlockSizes = None,
       dtype: DType = jnp.float32,
       quant: Quant = None,
+      is_self_attention: bool = True,
   ):
     self.dpa_layer = None
     if attention_kernel == "cudnn_flash_te":
@@ -593,6 +610,7 @@ def __init__(
     self.flash_block_sizes = flash_block_sizes
     self.dtype = dtype
     self.quant = quant
+    self.is_self_attention = is_self_attention
 
   def apply_attention(self, query: Array, key: Array, value: Array):
     return _apply_attention(
@@ -613,6 +631,7 @@ def apply_attention(self, query: Array, key: Array, value: Array):
         axis_names_kv=self.axis_names_kv,
         flash_block_sizes=self.flash_block_sizes,
         dpa_layer=self.dpa_layer,
+        is_self_attention=self.is_self_attention,
     )
 
 
@@ -701,6 +720,7 @@ def __init__(
       precision: jax.lax.Precision = None,
       qkv_bias: bool = False,
       quant: Quant = None,
+      is_self_attention: bool = True,
   ):
     if attention_kernel == "cudnn_flash_te":
       raise NotImplementedError(f"Wan 2.1 has not been tested with {attention_kernel}")
@@ -730,6 +750,7 @@ def __init__(
         flash_block_sizes=flash_block_sizes,
         dtype=dtype,
         quant=quant,
+        is_self_attention=is_self_attention,
     )
     # None axes corresponds to the stacked weights across all blocks
     # because of the use of nnx.vmap and nnx.scan.
diff --git a/src/maxdiffusion/models/wan/transformers/transformer_wan.py b/src/maxdiffusion/models/wan/transformers/transformer_wan.py
index 48ed7b8e..a9bc8f35 100644
--- a/src/maxdiffusion/models/wan/transformers/transformer_wan.py
+++ b/src/maxdiffusion/models/wan/transformers/transformer_wan.py
@@ -282,6 +282,7 @@ def __init__(
         precision=precision,
         attention_kernel=attention,
         dropout=dropout,
+        is_self_attention=True,
     )
 
     # 1. Cross-attention
@@ -300,6 +301,7 @@ def __init__(
         precision=precision,
         attention_kernel=attention,
         dropout=dropout,
+        is_self_attention=False,
     )
     assert cross_attn_norm is True
     self.norm2 = FP32LayerNorm(rngs=rngs, dim=dim, eps=eps, elementwise_affine=True)