
Commit ea84cb9

Ring attention rules are added at the front, if not already present, to shard the sequence on the fsdp axis.
1 parent 8609f78 commit ea84cb9

File tree

src/maxdiffusion/common_types.py
src/maxdiffusion/pyconfig.py

2 files changed (+34 -3)

src/maxdiffusion/common_types.py

Lines changed: 24 additions & 1 deletion
@@ -33,7 +33,11 @@
 BlockSizes = splash_attention_kernel.BlockSizes

 AxisNames = tuple[str, ...]
-
+# Physical axis names for device meshes.
+DATA = "data"
+FSDP = "fsdp"
+TENSOR = "tensor"
+# Logical axis names for model parameters and activations.
 BATCH = "activation_batch"
 LENGTH = "activation_length"
 KV_LENGTH = "activation_kv_length"
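
Note (not part of the commit): the physical names above are the axes a device mesh would be keyed by. A minimal, hypothetical sketch of such a mesh, assuming a (1, N, 1) layout that puts every device on the fsdp axis:

# Illustrative only: a 3-D device mesh keyed by the new physical axis names.
# The (1, num_devices, 1) shape is an assumption for this sketch, not a project default.
import jax
from jax.sharding import Mesh
from jax.experimental import mesh_utils

from maxdiffusion.common_types import DATA, FSDP, TENSOR

num_devices = len(jax.devices())
devices = mesh_utils.create_device_mesh((1, num_devices, 1))  # data=1, fsdp=N, tensor=1
mesh = Mesh(devices, (DATA, FSDP, TENSOR))
print(mesh.shape)  # e.g. {'data': 1, 'fsdp': N, 'tensor': 1}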
@@ -54,3 +58,22 @@


 WAN_MODEL = "Wan2.1"
+
+### Common axis rules for ring attention ###
+RING_ATTENTION_AXIS_RULES = [
+    [SELF_ATTN_HEAD, None],
+    [SELF_ATTN_Q_LENGTH, FSDP],
+    [SELF_ATTN_KV_LENGTH, FSDP],
+    [CROSS_ATTN_HEAD, None],
+    [CROSS_ATTN_Q_LENGTH, FSDP],
+    [CROSS_ATTN_KV_LENGTH, FSDP],
+]
+
+SEQUENCE_PARALLEL_AXIS_RULES = [
+    [SELF_ATTN_HEAD, None],
+    [SELF_ATTN_Q_LENGTH, FSDP],
+    [SELF_ATTN_KV_LENGTH, None],
+    [CROSS_ATTN_HEAD, None],
+    [CROSS_ATTN_Q_LENGTH, FSDP],
+    [CROSS_ATTN_KV_LENGTH, None],
+]
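
Note (not part of the commit): a rough sketch of how these two rule sets resolve under Flax's logical-to-mesh mapping, assuming the SELF_ATTN_* constants exported by maxdiffusion.common_types and a mesh axis named "fsdp". With the ring rules, both the query and key/value sequence dimensions land on fsdp; with the sequence-parallel rules, only the query sequence dimension is sharded.

import flax.linen as nn

from maxdiffusion.common_types import (
    RING_ATTENTION_AXIS_RULES,
    SEQUENCE_PARALLEL_AXIS_RULES,
    SELF_ATTN_HEAD,
    SELF_ATTN_Q_LENGTH,
    SELF_ATTN_KV_LENGTH,
)

# Logical dimension names of illustrative query and key/value activations.
q_dims = (SELF_ATTN_HEAD, SELF_ATTN_Q_LENGTH)
kv_dims = (SELF_ATTN_HEAD, SELF_ATTN_KV_LENGTH)

print(nn.logical_to_mesh_axes(q_dims, rules=RING_ATTENTION_AXIS_RULES))     # PartitionSpec(None, 'fsdp')
print(nn.logical_to_mesh_axes(kv_dims, rules=RING_ATTENTION_AXIS_RULES))    # PartitionSpec(None, 'fsdp')
print(nn.logical_to_mesh_axes(kv_dims, rules=SEQUENCE_PARALLEL_AXIS_RULES)) # PartitionSpec(None, None)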

src/maxdiffusion/pyconfig.py

Lines changed: 10 additions & 2 deletions
@@ -27,7 +27,7 @@
 from . import max_logging
 from . import max_utils
 from .models.wan.wan_utils import CAUSVID_TRANSFORMER_MODEL_NAME_OR_PATH, WAN_21_FUSION_X_MODEL_NAME_OR_PATH
-from maxdiffusion.common_types import LENGTH, KV_LENGTH
+from maxdiffusion.common_types import LENGTH, KV_LENGTH, RING_ATTENTION_AXIS_RULES, SELF_ATTN_HEAD, SELF_ATTN_KV_LENGTH, SELF_ATTN_Q_LENGTH, CROSS_ATTN_HEAD, CROSS_ATTN_KV_LENGTH, CROSS_ATTN_Q_LENGTH


 def string_to_bool(s: str) -> bool:
@@ -180,14 +180,22 @@ def user_init(raw_keys):
   raw_keys["logical_axis_rules"] = _lists_to_tuples(raw_keys["logical_axis_rules"])
   # Verify qkv is sharded across sequence.
   if raw_keys["attention"] == "ring":
+    max_logging.log("Using ring attention, adding sequence sharding to q and kv if not already present.")
     logical_axis_rules = list(raw_keys["logical_axis_rules"])
+    max_logging.log(f"Initial logical axis rules: {logical_axis_rules}")
+    new_rules = []
     q_seq_sharding = (LENGTH, "fsdp")
     kv_seq_sharding = (KV_LENGTH, "fsdp")
     if q_seq_sharding not in logical_axis_rules:
       logical_axis_rules.append(q_seq_sharding)
     if kv_seq_sharding not in logical_axis_rules:
       logical_axis_rules.append(kv_seq_sharding)
-    raw_keys["logical_axis_rules"] = tuple(logical_axis_rules)
+    for ring_attention_axis_rule in RING_ATTENTION_AXIS_RULES:
+      if ring_attention_axis_rule not in logical_axis_rules:
+        max_logging.log(f"Adding ring attention axis rule {ring_attention_axis_rule}")
+        new_rules.append(ring_attention_axis_rule)
+    raw_keys["logical_axis_rules"] = tuple(new_rules) + tuple(logical_axis_rules)
+    max_logging.log(f"Final logical axis rules: {raw_keys['logical_axis_rules']}")

   raw_keys["data_sharding"] = _lists_to_tuples(raw_keys["data_sharding"])
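
Note (not part of the commit): per the commit message, missing ring-attention rules are added at the front of logical_axis_rules rather than appended, because Flax matches logical axis rules in order, so an earlier entry for a logical name takes precedence over a later, conflicting one. A standalone sketch of just that merge step, with hypothetical user-supplied rules:

from maxdiffusion.common_types import RING_ATTENTION_AXIS_RULES

def prepend_missing_ring_rules(logical_axis_rules):
  # Ring-attention rules the user has not supplied go to the front so they
  # win during rule matching; the user's own rules are kept unchanged.
  existing = list(logical_axis_rules)
  missing = [rule for rule in RING_ATTENTION_AXIS_RULES if rule not in existing]
  return tuple(missing) + tuple(existing)

# Hypothetical user config (values here are placeholders, not project defaults).
user_rules = [["activation_batch", "data"], ["activation_length", "fsdp"]]
print(prepend_missing_ring_rules(user_rules))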
