34 changes: 33 additions & 1 deletion src/maxdiffusion/common_types.py
@@ -33,7 +33,11 @@
BlockSizes = splash_attention_kernel.BlockSizes

AxisNames = tuple[str, ...]

# Physical axis names for device meshes.
DATA = "data"
FSDP = "fsdp"
TENSOR = "tensor"
# Logical axis names for model parameters and activations.
BATCH = "activation_batch"
LENGTH = "activation_length"
KV_LENGTH = "activation_kv_length"
@@ -44,4 +48,32 @@
KEEP_2 = "activation_keep_2"
CONV_OUT = "activation_conv_out_channels"

# Logical axis names for sharding self- and cross-attention independently in the splash attention kernel.
SELF_ATTN_HEAD = "activation_self_attn_heads"
SELF_ATTN_Q_LENGTH = "activation_self_attn_q_length"
SELF_ATTN_KV_LENGTH = "activation_self_attn_kv_length"
CROSS_ATTN_HEAD = "activation_cross_attn_heads"
CROSS_ATTN_Q_LENGTH = "activation_cross_attn_q_length"
CROSS_ATTN_KV_LENGTH = "activation_cross_attn_kv_length"


WAN_MODEL = "Wan2.1"

### Common axis rules for ring attention ###
RING_ATTENTION_AXIS_RULES = [
[SELF_ATTN_HEAD, None],
[SELF_ATTN_Q_LENGTH, FSDP],
[SELF_ATTN_KV_LENGTH, FSDP],
[CROSS_ATTN_HEAD, None],
[CROSS_ATTN_Q_LENGTH, FSDP],
[CROSS_ATTN_KV_LENGTH, FSDP],
]

SEQUENCE_PARALLEL_AXIS_RULES = [
[SELF_ATTN_HEAD, None],
[SELF_ATTN_Q_LENGTH, FSDP],
[SELF_ATTN_KV_LENGTH, None],
[CROSS_ATTN_HEAD, None],
[CROSS_ATTN_Q_LENGTH, FSDP],
[CROSS_ATTN_KV_LENGTH, None],
]
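
Note (not part of the diff): for intuition, here is a minimal sketch of how these rule tables resolve logical attention axes to mesh axes by first match, assuming the `('data', 'fsdp', 'tensor')` mesh from `base_wan_14b.yml`. Ring attention shards both the q and kv sequence over `fsdp`, while the sequence-parallel rules shard only the q sequence and leave kv replicated. Helper names and the activation layout below are illustrative assumptions, not repo code.

```python
# Sketch: resolve logical axis names to mesh axes with a first-match lookup.
from jax.sharding import PartitionSpec

RING_RULES = [
    ["activation_self_attn_heads", None],
    ["activation_self_attn_q_length", "fsdp"],
    ["activation_self_attn_kv_length", "fsdp"],
]
SEQUENCE_PARALLEL_RULES = [
    ["activation_self_attn_heads", None],
    ["activation_self_attn_q_length", "fsdp"],
    ["activation_self_attn_kv_length", None],
]
# Assumed activation layout: (batch, heads, kv_length, head_dim).
KV_AXES = ("activation_batch", "activation_self_attn_heads",
           "activation_self_attn_kv_length", "activation_kv")


def to_partition_spec(logical_axes, rules):
  """Map each logical axis to its first matching mesh axis (None = replicated)."""
  lookup = {"activation_batch": "data"}  # batch rule taken from base_wan_14b.yml
  for name, mesh_axis in rules:
    lookup.setdefault(name, mesh_axis)
  return PartitionSpec(*(lookup.get(name) for name in logical_axes))


print(to_partition_spec(KV_AXES, RING_RULES))
# PartitionSpec('data', None, 'fsdp', None)  -> kv sequence sharded over fsdp
print(to_partition_spec(KV_AXES, SEQUENCE_PARALLEL_RULES))
# PartitionSpec('data', None, None, None)    -> kv replicated
```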
17 changes: 15 additions & 2 deletions src/maxdiffusion/configs/base_wan_14b.yml
@@ -68,10 +68,21 @@ flash_block_sizes: {}
# "block_kv" : 2048,
# "block_q_dkv" : 3024,
# "block_kv_dkv" : 2048,
# "block_kv_dkv_compute" : 2048,
# "block_kv_dkv_compute" : 1024,
# "block_q_dq" : 3024,
# "block_kv_dq" : 2048
# }
# Block sizes for TPU v5p
flash_block_sizes: {
"block_q" : 1024,
"block_kv_compute" : 256,
"block_kv" : 3072,
"block_q_dkv" : 1024,
"block_kv_dkv" : 3072,
"block_kv_dkv_compute" : 256,
"block_q_dq" : 1024,
"block_kv_dq" : 3072
}
# GroupNorm groups
norm_num_groups: 32

@@ -132,15 +143,17 @@ mesh_axes: ['data', 'fsdp', 'tensor']
logical_axis_rules: [
['batch', 'data'],
['activation_batch', 'data'],
['activation_self_attn_heads', ['fsdp', 'tensor']],
['activation_cross_attn_q_length', ['fsdp', 'tensor']],
['activation_length', 'fsdp'],

['activation_heads', 'tensor'],
['mlp','tensor'],
['embed','fsdp'],
['heads', 'tensor'],
['norm', 'tensor'],
['conv_batch', ['data','fsdp']],
['out_channels', 'tensor'],
['conv_in', 'fsdp'],
['conv_out', 'fsdp'],
]
data_sharding: [['data', 'fsdp', 'tensor']]
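Note (not part of the diff): the `mesh_axes` and `data_sharding` entries shown in this hunk translate into JAX constructs roughly as sketched below. The `(1, -1, 1)` device reshape is a placeholder for a real slice topology.

```python
# Sketch: mesh_axes / data_sharding from base_wan_14b.yml expressed in JAX.
import jax
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec

devices = np.array(jax.devices()).reshape(1, -1, 1)  # placeholder: all devices on 'fsdp'
mesh = Mesh(devices, axis_names=("data", "fsdp", "tensor"))

# data_sharding: [['data', 'fsdp', 'tensor']] shards the leading (batch)
# dimension jointly over every mesh axis.
batch_sharding = NamedSharding(mesh, PartitionSpec(("data", "fsdp", "tensor")))
```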
16 changes: 8 additions & 8 deletions src/maxdiffusion/max_utils.py
@@ -495,14 +495,14 @@ def get_flash_block_sizes(config):
flash_block_sizes = None
if len(config.flash_block_sizes.keys()) > 0:
flash_block_sizes = splash_attention_kernel.BlockSizes(
block_q=config.flash_block_sizes["block_q"],
block_kv_compute=config.flash_block_sizes["block_kv_compute"],
block_kv=config.flash_block_sizes["block_kv"],
block_q_dkv=config.flash_block_sizes["block_q_dkv"],
block_kv_dkv=config.flash_block_sizes["block_kv_dkv"],
block_kv_dkv_compute=config.flash_block_sizes["block_kv_dkv_compute"],
block_q_dq=config.flash_block_sizes["block_q_dq"],
block_kv_dq=config.flash_block_sizes["block_kv_dq"],
block_q=int(config.flash_block_sizes["block_q"]),
block_kv_compute=int(config.flash_block_sizes["block_kv_compute"]),
block_kv=int(config.flash_block_sizes["block_kv"]),
block_q_dkv=int(config.flash_block_sizes["block_q_dkv"]),
block_kv_dkv=int(config.flash_block_sizes["block_kv_dkv"]),
block_kv_dkv_compute=int(config.flash_block_sizes["block_kv_dkv_compute"]),
block_q_dq=int(config.flash_block_sizes["block_q_dq"]),
block_kv_dq=int(config.flash_block_sizes["block_kv_dq"]),
)
return flash_block_sizes

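Note (not part of the diff): config values can arrive as strings when overridden from the command line, which appears to be the motivation for the added `int()` casts. A hedged usage sketch follows; the splash-attention import path is assumed from its usual JAX location, and the dict-unpacking form is an illustrative shorthand, not the repo's code.

```python
# Sketch only: coerce the v5p block sizes from base_wan_14b.yml, as
# get_flash_block_sizes does key by key above.
from jax.experimental.pallas.ops.tpu.splash_attention import splash_attention_kernel

flash_block_sizes = {
    "block_q": "1024", "block_kv_compute": "256", "block_kv": "3072",
    "block_q_dkv": "1024", "block_kv_dkv": "3072", "block_kv_dkv_compute": "256",
    "block_q_dq": "1024", "block_kv_dq": "3072",
}
# Same effect as the per-key int() casts, written compactly.
block_sizes = splash_attention_kernel.BlockSizes(
    **{name: int(value) for name, value in flash_block_sizes.items()}
)
```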
31 changes: 29 additions & 2 deletions src/maxdiffusion/models/attention_flax.py
@@ -45,6 +45,13 @@
EMBED = common_types.EMBED
Quant = quantizations.AqtQuantization

SELF_ATTN_HEAD = common_types.SELF_ATTN_HEAD
SELF_ATTN_Q_LENGTH = common_types.SELF_ATTN_Q_LENGTH
SELF_ATTN_KV_LENGTH = common_types.SELF_ATTN_KV_LENGTH
CROSS_ATTN_HEAD = common_types.CROSS_ATTN_HEAD
CROSS_ATTN_Q_LENGTH = common_types.CROSS_ATTN_Q_LENGTH
CROSS_ATTN_KV_LENGTH = common_types.CROSS_ATTN_KV_LENGTH


def _maybe_aqt_einsum(quant: Quant):
return jnp.einsum if quant is None else quant.einsum()
@@ -184,7 +191,8 @@ def _tpu_flash_attention(
kv_max_block_size = key.shape[1]
else:
kv_max_block_size = q_max_block_size
if flash_block_sizes:
# Use the configured block sizes only when q and kv lengths match (self attention);
# for cross attention, override them with sizes derived from the actual sequence lengths.
if flash_block_sizes and key.shape[1] == query.shape[1]:
block_sizes = flash_block_sizes
else:
block_sizes = splash_attention_kernel.BlockSizes(
@@ -439,7 +447,16 @@ def _apply_attention(
)
elif attention_kernel == "flash":
return _tpu_flash_attention(
query, key * scale, value, heads, mesh, axis_names_q, axis_names_kv, flash_block_sizes, dtype
query,
key * scale,
value,
heads,
mesh,
axis_names_q,
axis_names_kv,
flash_block_sizes,
dtype,
attention_kernel,
)
elif attention_kernel == "ring":
return _tpu_flash_attention(
@@ -701,6 +718,7 @@ def __init__(
precision: jax.lax.Precision = None,
qkv_bias: bool = False,
quant: Quant = None,
is_self_attention: bool = True,
):
if attention_kernel == "cudnn_flash_te":
raise NotImplementedError(f"Wan 2.1 has not been tested with {attention_kernel}")
@@ -717,6 +735,13 @@
self.value_axis_names = value_axis_names
self.out_axis_names = out_axis_names

if is_self_attention:
axis_names_q = (BATCH, SELF_ATTN_HEAD, SELF_ATTN_Q_LENGTH, D_KV)
axis_names_kv = (BATCH, SELF_ATTN_HEAD, SELF_ATTN_KV_LENGTH, D_KV)
else:
axis_names_q = (BATCH, CROSS_ATTN_HEAD, CROSS_ATTN_Q_LENGTH, D_KV)
axis_names_kv = (BATCH, CROSS_ATTN_HEAD, CROSS_ATTN_KV_LENGTH, D_KV)

self.attention_op = NNXAttentionOp(
mesh=mesh,
attention_kernel=attention_kernel,
@@ -726,6 +751,8 @@
use_memory_efficient_attention=use_memory_efficient_attention,
split_head_dim=split_head_dim,
float32_qk_product=False,
axis_names_q=axis_names_q,
axis_names_kv=axis_names_kv,
flash_min_seq_length=flash_min_seq_length,
flash_block_sizes=flash_block_sizes,
dtype=dtype,
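Note (sketch, not repo code): the guard added in `_tpu_flash_attention` means configured `flash_block_sizes` are honored only when the q and kv sequence lengths match, i.e. self attention; cross attention, whose kv comes from the text encoder and is typically much shorter, falls back to block sizes derived from the actual lengths. Only the guard condition comes from the diff; the default size and fallback construction below are placeholders.

```python
# Sketch of the block-size selection intent; defaults here are placeholders.
from jax.experimental.pallas.ops.tpu.splash_attention import splash_attention_kernel


def select_block_sizes(flash_block_sizes, q_seq_len, kv_seq_len, default_block=512):
  # Configured sizes only apply to self attention (matching q/kv lengths).
  if flash_block_sizes and q_seq_len == kv_seq_len:
    return flash_block_sizes
  # Cross attention: cap blocks at the actual sequence lengths instead.
  q_block = min(default_block, q_seq_len)
  kv_block = min(default_block, kv_seq_len)
  return splash_attention_kernel.BlockSizes(
      block_q=q_block, block_kv_compute=kv_block, block_kv=kv_block,
      block_q_dkv=q_block, block_kv_dkv=kv_block, block_kv_dkv_compute=kv_block,
      block_q_dq=q_block, block_kv_dq=kv_block,
  )
```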
7 changes: 6 additions & 1 deletion src/maxdiffusion/models/wan/transformers/transformer_wan.py
@@ -282,6 +282,7 @@ def __init__(
precision=precision,
attention_kernel=attention,
dropout=dropout,
is_self_attention=True,
)

# 1. Cross-attention
@@ -300,6 +301,7 @@ def __init__(
precision=precision,
attention_kernel=attention,
dropout=dropout,
is_self_attention=False,
)
assert cross_attn_norm is True
self.norm2 = FP32LayerNorm(rngs=rngs, dim=dim, eps=eps, elementwise_affine=True)
@@ -351,7 +353,10 @@ def __call__(
# 2. Cross-attention
norm_hidden_states = self.norm2(hidden_states)
attn_output = self.attn2(
hidden_states=norm_hidden_states, encoder_hidden_states=encoder_hidden_states, deterministic=deterministic, rngs=rngs
hidden_states=norm_hidden_states,
encoder_hidden_states=encoder_hidden_states,
deterministic=deterministic,
rngs=rngs,
)
hidden_states = hidden_states + attn_output

12 changes: 10 additions & 2 deletions src/maxdiffusion/pyconfig.py
@@ -27,7 +27,7 @@
from . import max_logging
from . import max_utils
from .models.wan.wan_utils import CAUSVID_TRANSFORMER_MODEL_NAME_OR_PATH, WAN_21_FUSION_X_MODEL_NAME_OR_PATH
from maxdiffusion.common_types import LENGTH, KV_LENGTH
from maxdiffusion.common_types import LENGTH, KV_LENGTH, RING_ATTENTION_AXIS_RULES


def string_to_bool(s: str) -> bool:
@@ -180,14 +180,22 @@ def user_init(raw_keys):
raw_keys["logical_axis_rules"] = _lists_to_tuples(raw_keys["logical_axis_rules"])
# Verify qkv is sharded across sequence.
if raw_keys["attention"] == "ring":
max_logging.log("Using ring attention, adding sequence sharding to q and kv if not already present.")
logical_axis_rules = list(raw_keys["logical_axis_rules"])
max_logging.log(f"Initial logical axis rules: {logical_axis_rules}")
new_rules = []
q_seq_sharding = (LENGTH, "fsdp")
kv_seq_sharding = (KV_LENGTH, "fsdp")
if q_seq_sharding not in logical_axis_rules:
logical_axis_rules.append(q_seq_sharding)
if kv_seq_sharding not in logical_axis_rules:
logical_axis_rules.append(kv_seq_sharding)
raw_keys["logical_axis_rules"] = tuple(logical_axis_rules)
for ring_attention_axis_rule in RING_ATTENTION_AXIS_RULES:
if ring_attention_axis_rule not in logical_axis_rules:
max_logging.log(f"Adding ring attention axis rule {ring_attention_axis_rule}")
new_rules.append(ring_attention_axis_rule)
raw_keys["logical_axis_rules"] = tuple(new_rules) + tuple(logical_axis_rules)
max_logging.log(f"Final logical axis rules: {raw_keys['logical_axis_rules']}")

raw_keys["data_sharding"] = _lists_to_tuples(raw_keys["data_sharding"])

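Note (not part of the diff): to make the precedence concrete, a small illustrative sketch with trimmed rule lists, assuming first-match semantics when the logical axis rules are later consumed.

```python
# Sketch: prepend ring-attention rules so they win over the base yml rules.
RING_ATTENTION_AXIS_RULES = [
    ["activation_self_attn_heads", None],
    ["activation_self_attn_q_length", "fsdp"],
    ["activation_self_attn_kv_length", "fsdp"],
]
# Trimmed stand-in for the rules parsed from base_wan_14b.yml.
logical_axis_rules = [
    ["activation_batch", "data"],
    ["activation_self_attn_heads", ["fsdp", "tensor"]],
    ["activation_length", "fsdp"],
]

new_rules = [r for r in RING_ATTENTION_AXIS_RULES if r not in logical_axis_rules]
merged = tuple(new_rules) + tuple(logical_axis_rules)
# With first-match resolution, 'activation_self_attn_heads' now maps to None
# (replicated) rather than ('fsdp', 'tensor') from the yml.
```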