Commit 273f54c

Add manual constraints in Llama3 example (#130)
* [WIP] Remove Embedding constraint for 1d and add mm constraints for 2d
* Cleanup
* Cleanup
* Bugfix and disable by default
* Make it work with mm and einsum
1 parent d2ba202 commit 273f54c

examples/example_llama3.py

Lines changed: 92 additions & 14 deletions
@@ -11,7 +11,7 @@
 import torch.nn.functional as F
 from torch import nn
 from torch.distributed.fsdp import MixedPrecisionPolicy
-from torch.distributed.tensor.placement_types import Replicate, Shard
+from torch.distributed.tensor.placement_types import Partial, Replicate, Shard
 from torch.nn.attention import SDPBackend, sdpa_kernel
 from torch.testing._internal.distributed.fake_pg import FakeStore
 
@@ -556,7 +556,7 @@ def forward(self, tokens: torch.Tensor, input_batch: torch.Tensor | None = None)
 # AutoParallel code starts here
 # ==============================================================
 
-world_size = 256
+world_size = 64
 
 fake_store = FakeStore()
 torch.distributed.init_process_group(
@@ -579,7 +579,7 @@ def forward(self, tokens: torch.Tensor, input_batch: torch.Tensor | None = None)
     ),
 )
 
-batch_size = 4 * mesh.shape[0]
+batch_size = 2 * mesh.shape[0]
 seqlen = 2048 * 4
 vocab_size = 128256
 use_vocab_parallel = not use_1d_mesh
@@ -588,6 +588,8 @@ def forward(self, tokens: torch.Tensor, input_batch: torch.Tensor | None = None)
 
 def model_fn():
     model_args = TransformerModelArgs(
+        dim=4096,
+        n_heads=32,
         n_layers=32,
         vocab_size=vocab_size,
         max_seq_len=seqlen,
@@ -610,6 +612,90 @@ def input_fn():
 
 mp_policy = MixedPrecisionPolicy(param_dtype=torch.bfloat16, reduce_dtype=torch.float32)
 
+
+def group_mm_nodes_with_its_gradients(nodes):
+    fwd_nodes = [n for n in nodes if "nn_module_stack" in n.meta]
+    bwd_nodes = [n for n in nodes if "fwd_nn_module_stack" in n.meta]
+    assert len(fwd_nodes) * 2 == len(bwd_nodes)
+    res = {}
+    for fwd_node in fwd_nodes:
+        o = []
+        for bwd_node in bwd_nodes:
+            if fwd_node.meta["nn_module_stack"] == bwd_node.meta["fwd_nn_module_stack"]:
+                o.append(bwd_node)
+        assert len(o) == 2
+        res[fwd_node] = o
+    return res
+
+
+def force_tp_constraints(autop, mm_nodes, feat_dim=1, bwd_constraint=False):
+    # out = x @ w - S(0)R, RS(1) -> S(0)S(1)
+    # g_w = g.T @ x - S(1)S(0), S(0)R -> PS(0)
+    # g_x = g @ w.T - S(0)S(1), RS(0) -> S(0)P
+
+    add_node_constraint = autop.sharding_optimizer.add_node_constraint
+    fwd_bwd_groups = group_mm_nodes_with_its_gradients(mm_nodes)
+    fwd_nodes = list(fwd_bwd_groups.keys())
+    dim1 = 0 if feat_dim == 1 else 1
+    dim2 = 1 if feat_dim == 1 else 0
+    # assume there are 7 mm nodes per transformer block
+    # skip last mm as it's the final projection layer
+    assert (
+        len(fwd_nodes) - 1
+    ) % 7 == 0, f"expected 7 mm nodes per transformer block, {len(fwd_nodes) - 1}"
+    for block in range(0, len(fwd_nodes) - 1, 7):
+        fwd_nodes_block = fwd_nodes[block : block + 7]
+        # force the first 3 mm nodes to be S(0)S(1)
+        the_nodes = fwd_nodes_block[:3] + fwd_nodes_block[4:6]
+        for n in the_nodes:
+            add_node_constraint(n, (Shard(0), Shard(feat_dim)))
+            add_node_constraint(n.all_input_nodes[0], (Shard(0), Replicate()))
+            add_node_constraint(n.all_input_nodes[1], (Replicate(), Shard(1)))
+
+            if bwd_constraint:
+                bwd_nodes = fwd_bwd_groups[n]
+                # first is g_w, second is g_x
+                add_node_constraint(bwd_nodes[0], (Partial(), Shard(dim1)))
+                add_node_constraint(bwd_nodes[1], (Shard(0), Partial()))
+
+        # add reduction to finish TP, yielding S(0)P
+        the_nodes = fwd_nodes_block[3:4] + fwd_nodes_block[6:7]
+        for n in the_nodes:
+            add_node_constraint(n, (Shard(0), Partial()))
+            add_node_constraint(n.all_input_nodes[0], (Shard(0), Shard(feat_dim)))
+            add_node_constraint(n.all_input_nodes[1], (Replicate(), Shard(0)))
+
+            if bwd_constraint:
+                bwd_nodes = fwd_bwd_groups[n]
+                # first is g_w, second is g_x
+                add_node_constraint(bwd_nodes[0], (Partial(), Shard(dim2)))
+                add_node_constraint(bwd_nodes[1], (Shard(0), Shard(feat_dim)))
+
+
+def add_tp_constraints(autop):
+    mm_nodes = autop.gm.graph.find_nodes(
+        op="call_function", target=torch.ops.aten.mm.default
+    )
+    einsum_nodes = autop.gm.graph.find_nodes(
+        op="call_function", target=torch.ops.aten.einsum.default
+    )
+    assert (len(mm_nodes) > 0) ^ (
+        len(einsum_nodes) > 0
+    ), f"only one should be non-empty, got {len(mm_nodes)} and {len(einsum_nodes)}"
+    feat_dim = 1 if len(mm_nodes) > 0 else 2
+    tgt_nodes = mm_nodes + einsum_nodes
+    force_tp_constraints(autop, tgt_nodes, feat_dim=feat_dim, bwd_constraint=True)
+
+    if einsum_nodes:
+        # add sequence parallelism if we have einsum nodes
+        autop.sharding_optimizer.add_node_constraint(
+            list(tgt_nodes[3].users)[0], (Shard(0), Shard(1))
+        )
+        autop.sharding_optimizer.add_node_constraint(
+            list(list(tgt_nodes[3].users)[0].users)[0], (Shard(0), Shard(1))
+        )
+
+
 # parallelize the model
 with AutoParallel(
     model, input_fn, mesh, mp_policy, compile=True, repeated_subgraphs=True
@@ -626,17 +712,9 @@ def input_fn():
     autop.add_input_constraints([x_sharding])
     autop.add_output_constraints([out_sharding])
 
-    # example of how to add manual constraints
-    if use_1d_mesh:
-        # add constraint on the output sharding of embedding bag
-        # otherwise it might decide that it's ok to replicate both inputs. This is indeed fine
-        # for 1d but the current cost model doesn't take output memory into account, so it thinks
-        # it is not expensive. I should add an activation memory constraint as well to avoid
-        # those cases
-        embedding_nodes = autop.gm.graph.find_nodes(
-            op="call_function", target=torch.ops.aten.embedding.default
-        )
-        autop.sharding_optimizer.add_node_constraint(embedding_nodes[0], x_sharding)
+    enable_manual_constraint = False
+    if enable_manual_constraint and not use_1d_mesh:
+        add_tp_constraints(autop)
 
     t = time.time()
     sharding_placement = autop.optimize_placement()
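Aside, not part of the commit: in the comments of force_tp_constraints, S(d), R and P stand for the DTensor placements Shard(d), Replicate() and Partial() over the (dp, tp) mesh axes. Below is a minimal, hypothetical DTensor sketch of the forward annotation out = x @ w: S(0)R, RS(1) -> S(0)S(1). It reuses the fake-process-group trick from the example; the mesh shape, tensor shapes and names (x, w) are illustrative assumptions, not taken from the commit.

import torch
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.tensor import distribute_tensor
from torch.distributed.tensor.placement_types import Replicate, Shard
from torch.testing._internal.distributed.fake_pg import FakeStore

# fake process group, as in the example: single rank, no real communication
world_size = 4
torch.distributed.init_process_group(
    "fake", store=FakeStore(), rank=0, world_size=world_size
)
mesh = init_device_mesh("cpu", (2, 2), mesh_dim_names=("dp", "tp"))

x = torch.randn(8, 16)   # activations: (batch, feature), illustrative shapes
w = torch.randn(16, 32)  # weight: (feature, hidden), illustrative shapes

# x is S(0)R: batch-sharded along dp, replicated along tp
x_d = distribute_tensor(x, mesh, (Shard(0), Replicate()))
# w is RS(1): replicated along dp, column-sharded along tp
w_d = distribute_tensor(w, mesh, (Replicate(), Shard(1)))

# the matmul needs no redistribution; the result should land on S(0)S(1)
out = torch.mm(x_d, w_d)
print(out.placements)  # expected: (Shard(dim=0), Shard(dim=1))

torch.distributed.destroy_process_group()

The backward annotations (PS(0) and S(0)P) land on Partial() placements in the same way, which is why this commit extends the placement_types import with Partial.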
