|
9 | 9 | from torch.distributed.tensor import DTensor |
10 | 10 | from torch.distributed.tensor._dtensor_spec import DTensorSpec |
11 | 11 | from torch.distributed.tensor._redistribute import redistribute_local_tensor |
12 | | -from torch.distributed.tensor.placement_types import Partial, Replicate, Shard |
| 12 | +from torch.distributed.tensor.placement_types import Partial, Replicate, Shard # noqa |
13 | 13 | from torch.fx.experimental.proxy_tensor import make_fx |
14 | 14 | from torch.utils._pytree import tree_flatten, tree_map_only |
15 | 15 |
|
16 | 16 |
|
def my_redistribute_local_tensor(arg, curr_spec, tgt_spec):
    """Redistribute a local shard from ``curr_spec`` to ``tgt_spec``.

    Thin wrapper around DTensor's private ``redistribute_local_tensor``;
    earlier hand-rolled fast paths for specific placement transitions
    ((Shard(0), Shard(0)) -> (Replicate(), Shard(0)) and
    (Partial(), Shard(0)) -> (Shard(0), Shard(0))) were removed in favor
    of the generic implementation. See VCS history if they need reviving.

    Args:
        arg: The local ``torch.Tensor`` shard to redistribute.
        curr_spec: ``DTensorSpec`` describing the current placements/mesh.
        tgt_spec: ``DTensorSpec`` describing the desired placements.

    Returns:
        The local tensor redistributed to match ``tgt_spec``.
    """
    # NOTE(review): redistribute_local_tensor is a private torch API
    # (torch.distributed.tensor._redistribute) — pin torch versions or
    # confirm stability before relying on it long term.
    return redistribute_local_tensor(arg, curr_spec, tgt_spec)
38 | 38 |
|
39 | 39 |
|
|
0 commit comments