import torch.nn.functional as F
from torch import nn
from torch.distributed.fsdp import MixedPrecisionPolicy
-from torch.distributed.tensor.placement_types import Replicate, Shard
+from torch.distributed.tensor.placement_types import Partial, Replicate, Shard
from torch.nn.attention import SDPBackend, sdpa_kernel
from torch.testing._internal.distributed.fake_pg import FakeStore

@@ -556,7 +556,7 @@ def forward(self, tokens: torch.Tensor, input_batch: torch.Tensor | None = None)
# AutoParallel code starts here
# ==============================================================

-world_size = 256
+world_size = 64

fake_store = FakeStore()
torch.distributed.init_process_group(
@@ -579,7 +579,7 @@ def forward(self, tokens: torch.Tensor, input_batch: torch.Tensor | None = None)
    ),
)

-batch_size = 4 * mesh.shape[0]
+batch_size = 2 * mesh.shape[0]
seqlen = 2048 * 4
vocab_size = 128256
use_vocab_parallel = not use_1d_mesh
@@ -588,6 +588,8 @@ def forward(self, tokens: torch.Tensor, input_batch: torch.Tensor | None = None)

def model_fn():
    model_args = TransformerModelArgs(
+        dim=4096,
+        n_heads=32,
        n_layers=32,
        vocab_size=vocab_size,
        max_seq_len=seqlen,
@@ -610,6 +612,90 @@ def input_fn():

mp_policy = MixedPrecisionPolicy(param_dtype=torch.bfloat16, reduce_dtype=torch.float32)

+
+def group_mm_nodes_with_its_gradients(nodes):
+    fwd_nodes = [n for n in nodes if "nn_module_stack" in n.meta]
+    bwd_nodes = [n for n in nodes if "fwd_nn_module_stack" in n.meta]
+    assert len(fwd_nodes) * 2 == len(bwd_nodes)
+    res = {}
+    for fwd_node in fwd_nodes:
+        o = []
+        for bwd_node in bwd_nodes:
+            if fwd_node.meta["nn_module_stack"] == bwd_node.meta["fwd_nn_module_stack"]:
+                o.append(bwd_node)
+        assert len(o) == 2
+        res[fwd_node] = o
+    return res
+
+
+def force_tp_constraints(autop, mm_nodes, feat_dim=1, bwd_constraint=False):
+    # out = x @ w - S(0)R, RS(1) -> S(0)S(1)
+    # g_w = g.T @ x - S(1)S(0), S(0)R -> PS(0)
+    # g_x = g @ w.T - S(0)S(1), RS(0) -> S(0)P
+
+    add_node_constraint = autop.sharding_optimizer.add_node_constraint
+    fwd_bwd_groups = group_mm_nodes_with_its_gradients(mm_nodes)
+    fwd_nodes = list(fwd_bwd_groups.keys())
+    dim1 = 0 if feat_dim == 1 else 1
+    dim2 = 1 if feat_dim == 1 else 0
+    # assume there are 7 mm nodes per transformer block
+    # skip last mm as it's the final projection layer
+    assert (
+        len(fwd_nodes) - 1
+    ) % 7 == 0, f"expected 7 mm nodes per transformer block, {len(fwd_nodes) - 1}"
+    for block in range(0, len(fwd_nodes) - 1, 7):
+        fwd_nodes_block = fwd_nodes[block : block + 7]
+        # force the first 3 mm nodes to be S(0)S(1)
+        the_nodes = fwd_nodes_block[:3] + fwd_nodes_block[4:6]
+        for n in the_nodes:
+            add_node_constraint(n, (Shard(0), Shard(feat_dim)))
+            add_node_constraint(n.all_input_nodes[0], (Shard(0), Replicate()))
+            add_node_constraint(n.all_input_nodes[1], (Replicate(), Shard(1)))
+
+            if bwd_constraint:
+                bwd_nodes = fwd_bwd_groups[n]
+                # first is g_w, second is g_x
+                add_node_constraint(bwd_nodes[0], (Partial(), Shard(dim1)))
+                add_node_constraint(bwd_nodes[1], (Shard(0), Partial()))
+
+        # add reduction to finish TP, yielding S(0)P
+        the_nodes = fwd_nodes_block[3:4] + fwd_nodes_block[6:7]
+        for n in the_nodes:
+            add_node_constraint(n, (Shard(0), Partial()))
+            add_node_constraint(n.all_input_nodes[0], (Shard(0), Shard(feat_dim)))
+            add_node_constraint(n.all_input_nodes[1], (Replicate(), Shard(0)))
+
+            if bwd_constraint:
+                bwd_nodes = fwd_bwd_groups[n]
+                # first is g_w, second is g_x
+                add_node_constraint(bwd_nodes[0], (Partial(), Shard(dim2)))
+                add_node_constraint(bwd_nodes[1], (Shard(0), Shard(feat_dim)))
+
+
+def add_tp_constraints(autop):
+    mm_nodes = autop.gm.graph.find_nodes(
+        op="call_function", target=torch.ops.aten.mm.default
+    )
+    einsum_nodes = autop.gm.graph.find_nodes(
+        op="call_function", target=torch.ops.aten.einsum.default
+    )
+    assert (len(mm_nodes) > 0) ^ (
+        len(einsum_nodes) > 0
+    ), f"only one should be non-empty, got {len(mm_nodes)} and {len(einsum_nodes)}"
+    feat_dim = 1 if len(mm_nodes) > 0 else 2
+    tgt_nodes = mm_nodes + einsum_nodes
+    force_tp_constraints(autop, tgt_nodes, feat_dim=feat_dim, bwd_constraint=True)
+
+    if einsum_nodes:
+        # add sequence parallelism if we have einsum nodes
+        autop.sharding_optimizer.add_node_constraint(
+            list(tgt_nodes[3].users)[0], (Shard(0), Shard(1))
+        )
+        autop.sharding_optimizer.add_node_constraint(
+            list(list(tgt_nodes[3].users)[0].users)[0], (Shard(0), Shard(1))
+        )
+
+
# parallelize the model
with AutoParallel(
    model, input_fn, mesh, mp_policy, compile=True, repeated_subgraphs=True
@@ -626,17 +712,9 @@ def input_fn():
    autop.add_input_constraints([x_sharding])
    autop.add_output_constraints([out_sharding])

-    # example of how to add manual constraints
-    if use_1d_mesh:
-        # add constraint on the output sharding of embedding bag
-        # otherwise it might decide that it's ok to replicate both inputs. This is indeed fine
-        # for 1d but the current cost model doesn't take output memory into account, so it thinks
-        # it is not expensive. I should add an activation memory constraint as well to avoid
-        # those cases
-        embedding_nodes = autop.gm.graph.find_nodes(
-            op="call_function", target=torch.ops.aten.embedding.default
-        )
-        autop.sharding_optimizer.add_node_constraint(embedding_nodes[0], x_sharding)
+    enable_manual_constraint = False
+    if enable_manual_constraint and not use_1d_mesh:
+        add_tp_constraints(autop)

    t = time.time()
    sharding_placement = autop.optimize_placement()
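
The placement notation in the comments of force_tp_constraints (S(d) = sharded along dim d, R = replicated, P = partial, i.e. pending a reduction) can be sanity-checked without any distributed setup. Below is a minimal single-process sketch, not part of this change and using plain tensors rather than DTensors, that reproduces the two matmul rules the constraints encode: a column-wise weight shard yields a feature-sharded output (S(0)R, RS(1) -> S(0)S(1)), while a row-wise shard yields partial outputs that only match the full matmul after the reduction that "finishes TP".

# Illustrative sketch only; names and sizes are made up for the example.
import torch

tp = 4                    # pretend tensor-parallel degree
x = torch.randn(8, 16)    # activations: S(0) over batch, replicated over tp
w = torch.randn(16, 32)   # weight

# Column-parallel: w sharded on dim 1 -> each "rank" computes a slice of the
# output features; concatenating along dim 1 recovers the full matmul (S(0)S(1)).
out_colwise = torch.cat([x @ w_shard for w_shard in w.chunk(tp, dim=1)], dim=1)
assert torch.allclose(out_colwise, x @ w, atol=1e-5)

# Row-parallel: x sharded on its feature dim, w sharded on dim 0 -> each "rank"
# produces a full-shaped but partial result (P); summing them is the reduction
# that recovers the full matmul (S(0)P -> S(0)R).
partials = [xc @ wr for xc, wr in zip(x.chunk(tp, dim=1), w.chunk(tp, dim=0))]
assert torch.allclose(torch.stack(partials).sum(dim=0), x @ w, atol=1e-5)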