@@ -660,6 +660,105 @@ def index_rule(mesh, op_schema):
     return out_strat
 
 
+@register_opschema_rule(torch.ops.aten.gather.default)
+def gather_strategy(mesh, op_schema):
+    from torch.distributed.tensor._op_schema import PlacementList
+    from torch.distributed.tensor._ops._embedding_ops import _MaskPartial
+    from torch.distributed.tensor._ops.utils import expand_to_full_mesh_op_strategy
+
+    input_strategy = op_schema.args_schema[0]
+    dim = op_schema.args_schema[1]
+    index_strategy = op_schema.args_schema[2]
+
+    input_shape = input_strategy.shape
+    index_shape = index_strategy.shape
+
+    single_mesh_dim_strategies = []
+
+    # placement list stores placements of [output, input, index]
+    # first we always have replicate all for inputs and output
+    all_replicate: PlacementList = [Replicate()] * 3
+    single_mesh_dim_strategies.append(all_replicate)
+
+    # input sharding, input sharded, index accepts mask partial, output follows index
+    # this only works when the input is sharded on the gather dimension, and
+    # index has size 1 on the gather dimension
+    if index_shape[dim] == 1:
+        index_partial_placement = _MaskPartial(offset_shape=input_shape, offset_dim=dim)
+        input_sharding: PlacementList = [
+            index_partial_placement,
+            Shard(dim),
+            index_partial_placement,
+        ]
+        single_mesh_dim_strategies.append(input_sharding)
+
+    # index sharding, input replicated, index sharded, output follows index
+    # this only works when the sharding dimension is the gather dimension
+    index_sharding: PlacementList = [Shard(dim), Replicate(), Shard(dim)]
+    single_mesh_dim_strategies.append(index_sharding)
+
+    if len(input_shape) == len(index_shape):
+        for d in range(len(input_shape)):
+            if d != dim:
+                sharding = [Shard(d), Shard(d), Shard(d)]
+                single_mesh_dim_strategies.append(sharding)
+
+    return expand_to_full_mesh_op_strategy(
+        mesh, op_schema, single_mesh_dim_strategies, input_index=1
+    )
+
+
+@register_opschema_rule(torch.ops.aten.scatter_add.default)
+def scatter_add_strategy(mesh, op_schema):
+    from torch.distributed.tensor._op_schema import PlacementList
+
+    # from torch.distributed.tensor._ops._embedding_ops import _MaskPartial
+    from torch.distributed.tensor._ops.utils import expand_to_full_mesh_op_strategy
+
+    input_strategy = op_schema.args_schema[0]
+    dim = op_schema.args_schema[1]
+    index_strategy = op_schema.args_schema[2]
+    # src_strategy = op_schema.args_schema[3]
+
+    input_shape = input_strategy.shape
+    index_shape = index_strategy.shape
+
+    single_mesh_dim_strategies = []
+
+    # placement list stores placements of [output, input, index, src]
+    # first we always have replicate all for inputs and output
+    all_replicate: PlacementList = [Replicate()] * 4
+    single_mesh_dim_strategies.append(all_replicate)
+
+    """
+    # input sharding, input sharded, index accepts mask partial, output follows index
+    # this only works when the input is sharded on the gather dimension, and
+    # index has size 1 on the gather dimension
+    if index_shape[dim] == 1:
+        index_partial_placement = _MaskPartial(offset_shape=input_shape, offset_dim=dim)
+        input_sharding: PlacementList = [
+            index_partial_placement,
+            Shard(dim),
+            index_partial_placement,
+        ]
+        single_mesh_dim_strategies.append(input_sharding)
+    """
+    # index sharding: input replicated, index and src sharded, output follows index
+    # this only works when the sharding dimension is the scatter dimension
+    index_sharding: PlacementList = [Shard(dim), Replicate(), Shard(dim), Shard(dim)]
+    single_mesh_dim_strategies.append(index_sharding)
+
+    if len(input_shape) == len(index_shape):
+        for d in range(len(input_shape)):
+            if d != dim:
+                sharding = [Shard(d), Shard(d), Shard(d), Shard(d)]
+                single_mesh_dim_strategies.append(sharding)
+
+    return expand_to_full_mesh_op_strategy(
+        mesh, op_schema, single_mesh_dim_strategies, input_index=1
+    )
+
+
 def sdpa_rule(op, mesh, op_schema):
     out_strat = get_op_strategy(op, op_schema)
     # remove wrong context-parallel strategy
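
Illustration (not part of the diff; the shapes and the two-way split below are made up for the example): a framework-free sketch of why the index-sharding entry in gather_strategy is valid. With the input replicated and the index sharded along the gather dim, each rank can run torch.gather on its local index shard, and the per-rank results are exactly the Shard(dim) pieces of the full output, so the output placement follows the index placement.

import torch

dim = 1
inp = torch.arange(12.0).reshape(3, 4)     # "replicated" input: every rank holds the full tensor
index = torch.randint(0, 4, (3, 6))        # global index tensor for gather along `dim`

reference = torch.gather(inp, dim, index)  # what a single device would compute

# Simulate two ranks, each owning one shard of `index` along the gather dim.
index_shards = index.chunk(2, dim=dim)
local_outputs = [torch.gather(inp, dim, shard) for shard in index_shards]

# Concatenating the per-rank outputs along the same dim recovers the full result.
assert torch.equal(torch.cat(local_outputs, dim=dim), reference)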