|
27 | 27 | from . import max_logging
|
28 | 28 | from . import max_utils
|
29 | 29 | from .models.wan.wan_utils import CAUSVID_TRANSFORMER_MODEL_NAME_OR_PATH, WAN_21_FUSION_X_MODEL_NAME_OR_PATH
|
30 |
| -from maxdiffusion.common_types import LENGTH, KV_LENGTH |
| 30 | +from maxdiffusion.common_types import LENGTH, KV_LENGTH, RING_ATTENTION_AXIS_RULES, SELF_ATTN_HEAD, SELF_ATTN_KV_LENGTH, SELF_ATTN_Q_LENGTH, CROSS_ATTN_HEAD, CROSS_ATTN_KV_LENGTH, CROSS_ATTN_Q_LENGTH |
31 | 31 |
|
32 | 32 |
|
33 | 33 | def string_to_bool(s: str) -> bool:
|
@@ -180,14 +180,22 @@ def user_init(raw_keys):
|
180 | 180 | raw_keys["logical_axis_rules"] = _lists_to_tuples(raw_keys["logical_axis_rules"])
|
181 | 181 | # Verify qkv is sharded across sequence.
|
182 | 182 | if raw_keys["attention"] == "ring":
|
| 183 | + max_logging.log("Using ring attention, adding sequence sharding to q and kv if not already present.") |
183 | 184 | logical_axis_rules = list(raw_keys["logical_axis_rules"])
|
| 185 | + max_logging.log(f"Initial logical axis rules: {logical_axis_rules}") |
| 186 | + new_rules = [] |
184 | 187 | q_seq_sharding = (LENGTH, "fsdp")
|
185 | 188 | kv_seq_sharding = (KV_LENGTH, "fsdp")
|
186 | 189 | if q_seq_sharding not in logical_axis_rules:
|
187 | 190 | logical_axis_rules.append(q_seq_sharding)
|
188 | 191 | if kv_seq_sharding not in logical_axis_rules:
|
189 | 192 | logical_axis_rules.append(kv_seq_sharding)
|
190 |
| - raw_keys["logical_axis_rules"] = tuple(logical_axis_rules) |
| 193 | + for ring_attention_axis_rule in RING_ATTENTION_AXIS_RULES: |
| 194 | + if ring_attention_axis_rule not in logical_axis_rules: |
| 195 | + max_logging.log(f"Adding ring attention axis rule {ring_attention_axis_rule}") |
| 196 | + new_rules.append(ring_attention_axis_rule) |
| 197 | + raw_keys["logical_axis_rules"] = tuple(new_rules) + tuple(logical_axis_rules) |
| 198 | + max_logging.log(f"Final logical axis rules: {raw_keys['logical_axis_rules']}") |
191 | 199 |
|
192 | 200 | raw_keys["data_sharding"] = _lists_to_tuples(raw_keys["data_sharding"])
|
193 | 201 |
|
|
0 commit comments