
Commit 6b44ea6

Disable S(0)S(0) -> RS(0) optimization (#24)
Now that we are using DTensors to store the parameters, it is safer to rely on the redistribution logic from DTensor itself. We can optimize this further once DTensor has order information.
1 parent fa8ed5b commit 6b44ea6
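For context, the optimization being disabled special-cased redistributing a tensor placed as (Shard(0), Shard(0)) on a 2-D mesh to (Replicate(), Shard(0)). A minimal sketch of that same transition through DTensor's public API, which is the generic path the code now falls back to; the mesh shape, device type, and tensor sizes are illustrative assumptions, not taken from this repository:

# Sketch of the S(0)S(0) -> RS(0) transition via DTensor's public
# redistribute API. Assumes torch.distributed is already initialized
# with 4 ranks; sizes below are made up for illustration.
import torch
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.tensor import distribute_tensor
from torch.distributed.tensor.placement_types import Replicate, Shard

mesh = init_device_mesh("cuda", (2, 2))  # 2-D mesh over 4 ranks

t = torch.randn(8, 4)
# Parameter sharded on dim 0 along both mesh dimensions: S(0)S(0).
dt = distribute_tensor(t, mesh, [Shard(0), Shard(0)])

# The disabled fast path handled this transition by hand; the generic
# DTensor redistribution now performs it instead.
dt = dt.redistribute(mesh, [Replicate(), Shard(0)])  # RS(0)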

File tree

1 file changed: +18 -18 lines changed

autoparallel/apply_sharding.py

@@ -9,31 +9,31 @@
 from torch.distributed.tensor import DTensor
 from torch.distributed.tensor._dtensor_spec import DTensorSpec
 from torch.distributed.tensor._redistribute import redistribute_local_tensor
-from torch.distributed.tensor.placement_types import Partial, Replicate, Shard
+from torch.distributed.tensor.placement_types import Partial, Replicate, Shard  # noqa
 from torch.fx.experimental.proxy_tensor import make_fx
 from torch.utils._pytree import tree_flatten, tree_map_only
 
 
 def my_redistribute_local_tensor(arg, curr_spec, tgt_spec):
-    if curr_spec.placements == (Shard(0), Shard(0)) and tgt_spec.placements == (
-        Replicate(),
-        Shard(0),
-    ):
-        # TODO: double-check in which cases this is valid
-        x = curr_spec.placements[0]._to_replicate_tensor(
-            arg, curr_spec.mesh, 0, curr_spec.shape
-        )
-    elif curr_spec.placements == (Partial(), Shard(0)) and tgt_spec.placements == (
-        Shard(0),
-        Shard(0),
-    ):
-        x = curr_spec.placements[0]._reduce_shard_value(
-            arg, curr_spec.mesh, 0, tgt_spec.placements[0]
-        )
+    # if curr_spec.placements == (Shard(0), Shard(0)) and tgt_spec.placements == (
+    #     Replicate(),
+    #     Shard(0),
+    # ):
+    #     # TODO: double-check in which cases this is valid
+    #     x = curr_spec.placements[0]._to_replicate_tensor(
+    #         arg, curr_spec.mesh, 0, curr_spec.shape
+    #     )
+    # elif curr_spec.placements == (Partial(), Shard(0)) and tgt_spec.placements == (
+    #     Shard(0),
+    #     Shard(0),
+    # ):
+    #     x = curr_spec.placements[0]._reduce_shard_value(
+    #         arg, curr_spec.mesh, 0, tgt_spec.placements[0]
+    #     )
     # elif curr_spec.placements == (Partial(), Shard(1)) and tgt_spec.placements == (Replicate(), Shard(1)):
     #     from IPython import embed; embed(); sys.sdf
-    else:
-        x = redistribute_local_tensor(arg, curr_spec, tgt_spec)
+    # else:
+    x = redistribute_local_tensor(arg, curr_spec, tgt_spec)
     return x
 
 
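With the fast path commented out, my_redistribute_local_tensor is a thin wrapper over DTensor's redistribute_local_tensor. A hypothetical usage sketch of the simplified helper, assuming an initialized 2-D mesh and the private _spec attribute of DTensor; all names, sizes, and the import of the helper are illustrative assumptions:

# Hypothetical driver for the simplified helper. The specs are taken from
# real DTensors rather than built by hand, since DTensorSpec is a private
# type; everything below is an assumption for illustration.
import torch
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.tensor import distribute_tensor
from torch.distributed.tensor.placement_types import Replicate, Shard

from autoparallel.apply_sharding import my_redistribute_local_tensor

mesh = init_device_mesh("cuda", (2, 2))
ref = torch.randn(8, 4)

src = distribute_tensor(ref, mesh, [Shard(0), Shard(0)])     # S(0)S(0)
tgt = distribute_tensor(ref, mesh, [Replicate(), Shard(0)])  # RS(0)

# Redistribute the local shard; with the special cases disabled this is
# exactly redistribute_local_tensor(arg, curr_spec, tgt_spec).
local_out = my_redistribute_local_tensor(src.to_local(), src._spec, tgt._spec)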