Commit 315f44b

Add gather and scatter_add strategies
They were taken from #29
1 parent fcc91a5 commit 315f44b

1 file changed: +99 −0


autoparallel/propagation_rules.py

Lines changed: 99 additions & 0 deletions
@@ -660,6 +660,105 @@ def index_rule(mesh, op_schema):
     return out_strat


+@register_opschema_rule(torch.ops.aten.gather.default)
+def gather_strategy(mesh, op_schema):
+    from torch.distributed.tensor._op_schema import PlacementList
+    from torch.distributed.tensor._ops._embedding_ops import _MaskPartial
+    from torch.distributed.tensor._ops.utils import expand_to_full_mesh_op_strategy
+
+    input_strategy = op_schema.args_schema[0]
+    dim = op_schema.args_schema[1]
+    index_strategy = op_schema.args_schema[2]
+
+    input_shape = input_strategy.shape
+    index_shape = index_strategy.shape
+
+    single_mesh_dim_strategies = []
+
+    # placement list stores placements of [output, input, index]
+    # first we always have replicate all for inputs and output
+    all_replicate: PlacementList = [Replicate()] * 3
+    single_mesh_dim_strategies.append(all_replicate)
+
+    # input sharding, input sharded, index accepts mask partial, output follows index
+    # this only works when the input is sharded on the gather dimension, and
+    # index has size 1 on the gather dimension
+    if index_shape[dim] == 1:
+        index_partial_placement = _MaskPartial(offset_shape=input_shape, offset_dim=dim)
+        input_sharding: PlacementList = [
+            index_partial_placement,
+            Shard(dim),
+            index_partial_placement,
+        ]
+        single_mesh_dim_strategies.append(input_sharding)
+
+    # index sharding, input replicated, index sharded, output follows index
+    # this only works when the sharding dimension is the gather dimension
+    index_sharding: PlacementList = [Shard(dim), Replicate(), Shard(dim)]
+    single_mesh_dim_strategies.append(index_sharding)
+
+    if len(input_shape) == len(index_shape):
+        for d in range(len(input_shape)):
+            if d != dim:
+                sharding = [Shard(d), Shard(d), Shard(d)]
+                single_mesh_dim_strategies.append(sharding)
+
+    return expand_to_full_mesh_op_strategy(
+        mesh, op_schema, single_mesh_dim_strategies, input_index=1
+    )
+
+
+@register_opschema_rule(torch.ops.aten.scatter_add.default)
+def scatter_add_strategy(mesh, op_schema):
+    from torch.distributed.tensor._op_schema import PlacementList
+
+    # from torch.distributed.tensor._ops._embedding_ops import _MaskPartial
+    from torch.distributed.tensor._ops.utils import expand_to_full_mesh_op_strategy
+
+    input_strategy = op_schema.args_schema[0]
+    dim = op_schema.args_schema[1]
+    index_strategy = op_schema.args_schema[2]
+    # src_strategy = op_schema.args_schema[3]
+
+    input_shape = input_strategy.shape
+    index_shape = index_strategy.shape
+
+    single_mesh_dim_strategies = []
+
+    # placement list stores placements of [output, input, index, src]
+    # first we always have replicate all for inputs and output
+    all_replicate: PlacementList = [Replicate()] * 4
+    single_mesh_dim_strategies.append(all_replicate)
+
+    """
+    # input sharding, input sharded, index accepts mask partial, output follows index
+    # this only works when the input is sharded on the gather dimension, and
+    # index has size 1 on the gather dimension
+    if index_shape[dim] == 1:
+        index_partial_placement = _MaskPartial(offset_shape=input_shape, offset_dim=dim)
+        input_sharding: PlacementList = [
+            index_partial_placement,
+            Shard(dim),
+            index_partial_placement,
+        ]
+        single_mesh_dim_strategies.append(input_sharding)
+    """
+    # index sharding, input replicated, index and src sharded, output follows index
+    # this only works when the sharding dimension is the scatter dimension
+    index_sharding: PlacementList = [Shard(dim), Replicate(), Shard(dim), Shard(dim)]
+    single_mesh_dim_strategies.append(index_sharding)
+
+    if len(input_shape) == len(index_shape):
+        for d in range(len(input_shape)):
+            if d != dim:
+                sharding = [Shard(d), Shard(d), Shard(d), Shard(d)]
+                single_mesh_dim_strategies.append(sharding)
+
+    return expand_to_full_mesh_op_strategy(
+        mesh, op_schema, single_mesh_dim_strategies, input_index=1
+    )
+
+
 def sdpa_rule(op, mesh, op_schema):
     out_strat = get_op_strategy(op, op_schema)
     # remove wrong context-parallel strategy
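
For readers new to these single-mesh-dim placement lists, here is a standalone, single-process sketch in plain PyTorch (not part of the diff above; tensor values are made up for illustration). It checks the three facts the gather strategy encodes: the output shape follows index, sharding every operand on a non-gather dimension keeps the op local, and the size-1-index case can be recovered by summing per-shard partial results, which is how the _MaskPartial placement is understood to behave here.

# Standalone sketch, separate from the commit above.
import torch

dim = 1
inp = torch.arange(12.0).reshape(3, 4)   # "input", shape (3, 4)
index = torch.tensor([[0], [3], [2]])    # "index", shape (3, 1): size 1 on the gather dim
full = torch.gather(inp, dim, index)

# 1) gather's output always has the shape of `index`, so the output placement
#    can simply follow the index placement.
assert full.shape == index.shape

# 2) Shard(d)/Shard(d)/Shard(d) for d != dim: sharding input and index on a
#    non-gather dimension keeps the gather local to each shard.
local = torch.cat(
    [
        torch.gather(a, dim, b)
        for a, b in zip(inp.chunk(2, dim=0), index.chunk(2, dim=0))
    ],
    dim=0,
)
assert torch.equal(local, full)

# 3) Size-1 index with input Shard(dim), simulating (per our reading) what the
#    _MaskPartial placement does: each shard gathers the entries whose index
#    falls into its slice and contributes zero elsewhere; summing the per-shard
#    results reproduces the full gather, i.e. the output is partial(sum).
width = inp.size(dim) // 2
partials = []
for s, shard in enumerate(inp.chunk(2, dim=dim)):
    lo, hi = s * width, (s + 1) * width
    in_range = (index >= lo) & (index < hi)
    local_idx = (index - lo).clamp(0, width - 1)
    partials.append(torch.gather(shard, dim, local_idx) * in_range)
assert torch.equal(sum(partials), full)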

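A matching sketch for scatter_add (again plain PyTorch with made-up values, separate from the commit) shows why the Shard(d) strategy appended in the loop above is valid: splitting input, index and src on a dimension other than the scatter dimension keeps scatter_add local to each shard.

# Standalone sketch, separate from the commit above: scatter_add along dim 0
# adds src[i][j] into out[index[i][j]][j], so each column j is independent.
import torch

dim = 0
inp = torch.zeros(3, 4)
index = torch.tensor([[0, 1, 2, 0], [2, 0, 1, 1]])  # same shape as src
src = torch.ones(2, 4)
full = inp.scatter_add(dim, index, src)              # out-of-place scatter_add

# Shard(d)/Shard(d)/Shard(d)/Shard(d) for d != dim: running scatter_add
# shard-by-shard along a non-scatter dimension and concatenating the local
# results reproduces the full output, so no communication is needed.
local = torch.cat(
    [
        i.scatter_add(dim, ix, s)
        for i, ix, s in zip(
            inp.chunk(2, dim=1), index.chunk(2, dim=1), src.chunk(2, dim=1)
        )
    ],
    dim=1,
)
assert torch.equal(local, full)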