Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 24 additions & 3 deletions src/maxdiffusion/models/attention_flax.py
Original file line number Diff line number Diff line change
Expand Up @@ -174,6 +174,7 @@ def _tpu_flash_attention(
flash_block_sizes: BlockSizes,
dtype: jnp.dtype = jnp.float32,
attention_kernel: str = "flash",
is_self_attention: Optional[bool] = None,
) -> jax.Array:
"""TPU Flash Attention"""

Expand Down Expand Up @@ -201,8 +202,22 @@ def _tpu_flash_attention(
query = _reshape_data_for_flash(query, heads)
key = _reshape_data_for_flash(key, heads)
value = _reshape_data_for_flash(value, heads)
q_axis_names = nn.logical_to_mesh_axes(axis_names_q)
kv_axis_names = nn.logical_to_mesh_axes(axis_names_kv)

# Use different sharding strategy for self-attn vs cross-attn
if is_self_attention is not None:
if is_self_attention:
# Self-attention: Context Parallelism (sharding along num_heads)
q_axis_names = PartitionSpec("data", ("fsdp", "tensor"), None, None)
kv_axis_names = PartitionSpec("data", ("fsdp", "tensor"), None, None)
else:
# Cross-attention: Sequence Parallelism for Q
# Q's sequence is sharded; K/V are replicated
q_axis_names = PartitionSpec("data", None, ("fsdp", "tensor"), None)
kv_axis_names = PartitionSpec("data", None, None, None)
else:
# Fallback to original maxdiffusion behavior if the flag isn't provided
q_axis_names = nn.logical_to_mesh_axes(axis_names_q)
kv_axis_names = nn.logical_to_mesh_axes(axis_names_kv)

@functools.partial(
shard_map.shard_map,
Expand Down Expand Up @@ -419,6 +434,7 @@ def _apply_attention(
axis_names_kv: AxisNames,
flash_block_sizes: BlockSizes,
dpa_layer: Callable,
is_self_attention: bool = True,
):
"""Routes to different attention kernels."""
_check_attention_inputs(query, key, value)
Expand All @@ -439,7 +455,7 @@ def _apply_attention(
)
elif attention_kernel == "flash":
return _tpu_flash_attention(
query, key * scale, value, heads, mesh, axis_names_q, axis_names_kv, flash_block_sizes, dtype
query, key * scale, value, heads, mesh, axis_names_q, axis_names_kv, flash_block_sizes, dtype, attention_kernel, is_self_attention,
)
elif attention_kernel == "ring":
return _tpu_flash_attention(
Expand Down Expand Up @@ -574,6 +590,7 @@ def __init__(
flash_block_sizes: BlockSizes = None,
dtype: DType = jnp.float32,
quant: Quant = None,
is_self_attention: bool = True,
):
self.dpa_layer = None
if attention_kernel == "cudnn_flash_te":
Expand All @@ -593,6 +610,7 @@ def __init__(
self.flash_block_sizes = flash_block_sizes
self.dtype = dtype
self.quant = quant
self.is_self_attention = is_self_attention

def apply_attention(self, query: Array, key: Array, value: Array):
return _apply_attention(
Expand All @@ -613,6 +631,7 @@ def apply_attention(self, query: Array, key: Array, value: Array):
axis_names_kv=self.axis_names_kv,
flash_block_sizes=self.flash_block_sizes,
dpa_layer=self.dpa_layer,
is_self_attention=self.is_self_attention,
)


Expand Down Expand Up @@ -701,6 +720,7 @@ def __init__(
precision: jax.lax.Precision = None,
qkv_bias: bool = False,
quant: Quant = None,
is_self_attention: bool = True,
):
if attention_kernel == "cudnn_flash_te":
raise NotImplementedError(f"Wan 2.1 has not been tested with {attention_kernel}")
Expand Down Expand Up @@ -730,6 +750,7 @@ def __init__(
flash_block_sizes=flash_block_sizes,
dtype=dtype,
quant=quant,
is_self_attention=is_self_attention,
)
# None axes corresponds to the stacked weights across all blocks
# because of the use of nnx.vmap and nnx.scan.
Expand Down
2 changes: 2 additions & 0 deletions src/maxdiffusion/models/wan/transformers/transformer_wan.py
Original file line number Diff line number Diff line change
Expand Up @@ -282,6 +282,7 @@ def __init__(
precision=precision,
attention_kernel=attention,
dropout=dropout,
is_self_attention=True,
)

# 1. Cross-attention
Expand All @@ -300,6 +301,7 @@ def __init__(
precision=precision,
attention_kernel=attention,
dropout=dropout,
is_self_attention=False,
)
assert cross_attn_norm is True
self.norm2 = FP32LayerNorm(rngs=rngs, dim=dim, eps=eps, elementwise_affine=True)
Expand Down