# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.

import torch.distributed.tensor._dtensor_spec as dtensor_spec
from torch.distributed.tensor._collective_utils import (
    MeshTopoInfo,
    allgather_cost,
    allreduce_cost,
    reduce_scatter_cost,
    spec_to_bytes,
)
from torch.distributed.tensor.placement_types import Partial, Shard


def all_to_all_cost(bytes_gb: float, mesh_topo: MeshTopoInfo, mesh_dim: int) -> float:
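    """
    Roughly estimate the cost (in microseconds) of an all-to-all over ``mesh_dim``,
    in the style of ``allgather_cost``/``allreduce_cost`` from
    ``torch.distributed.tensor._collective_utils``: a fixed base latency, a per-hop
    latency term, and a bandwidth term, with the number of hops modeled as the
    square of the number of devices on the mesh dim.
    """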
    num_devices_on_mesh_dim = mesh_topo.mesh_dim_devices[mesh_dim]
    mesh_dim_bandwidth = mesh_topo.mesh_dim_bandwidth[mesh_dim]
    num_hops = num_devices_on_mesh_dim**2
    # base latency + comm latency
    latency = 6.6 + num_hops * mesh_topo.mesh_dim_latency[mesh_dim]  # us
    bw = (bytes_gb * num_hops / num_devices_on_mesh_dim) / mesh_dim_bandwidth  # s
    return latency + bw * 1e6  # rescale to us


# This is a copy-paste of redistribute_cost from
# https://github.com/pytorch/pytorch/blob/main/torch/distributed/tensor/_collective_utils.py
# with an iteration order over mesh dims introduced.
# TODO: this should be improved, as for now we really only use the non-canonical
# order for PP -> S(0)S(0).
def redistribute_cost(
    current_spec: "dtensor_spec.DTensorSpec",
    target_spec: "dtensor_spec.DTensorSpec",
    order: list[int],
) -> float:
    """
    This function returns the cost of redistributing from the current to the target
    DTensorSpec.

    NOTE:
    1. Only communication cost is considered here, since the computation cost of a
       redistribute is trivial (i.e. only a narrow or a simple division is needed).
    2. Only redistribute cost on the same mesh is considered; cross-mesh communication
       cost is not needed for operator strategy estimation/selection.
    """
    if current_spec.mesh != target_spec.mesh:
        # make the cost infinite if the meshes are not the same
        # TODO: see if we want to support this once there's cross mesh communication
        return float("inf")

    if current_spec.is_replicated():
        # short-cut:
        # comm cost is 0 if the current spec is already fully replicated
        return 0.0

    mesh_topo = MeshTopoInfo.build_from_mesh(current_spec.mesh)
    cost = 0.0
    comm_bytes_gb = (
        spec_to_bytes(current_spec) / current_spec.num_shards / 1024 / 1024 / 1024
    )
    # Transformations considered for the redistribute cost:
    # 1. allgather 2. alltoall
    # 3. allreduce 4. reduce_scatter
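    # `order` is a permutation of the mesh dims: per-dim transitions are costed (and
    # comm_bytes_gb updated) in that sequence instead of always going dim 0..ndim-1.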
    curr_placements = [current_spec.placements[i] for i in order]
    tgt_placements = [target_spec.placements[i] for i in order]
    for i, current, target in zip(order, curr_placements, tgt_placements):
        if current == target:
            continue
        num_devices_on_mesh_dim = mesh_topo.mesh_dim_devices[i]
        if current.is_shard() and target.is_replicate():
            # allgather gives larger comm bytes
            comm_bytes_gb *= num_devices_on_mesh_dim
            # add up allgather comm cost
            cost += allgather_cost(comm_bytes_gb, mesh_topo, i)
        elif current.is_shard() and target.is_shard():
            # this should be an alltoall comm; since we haven't implemented it yet,
            # add a penalty to favor allgather instead
            # cost += all_to_all_cost(comm_bytes_gb, mesh_topo, i)
            cost += allgather_cost(comm_bytes_gb, mesh_topo, i) * 4.0
        elif current.is_partial() and target.is_replicate():
            # add up allreduce comm cost
            cost += allreduce_cost(comm_bytes_gb, mesh_topo, i)
        elif current.is_partial() and target.is_shard():
            # add up reduce_scatter comm cost
            cost += reduce_scatter_cost(comm_bytes_gb, mesh_topo, i)
            # after reduce_scatter the comm bytes for further collectives shrink by
            # the number of devices on this mesh dim
            comm_bytes_gb /= num_devices_on_mesh_dim
        elif current.is_shard() and target.is_partial():
            # ban shard -> partial as it does not make sense to perform
            # this redistribute
            return float("inf")

    return cost


def estimate_strategy_comms_cost(src_spec, tgt_spec):
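    """
    Estimate the communication cost of redistributing ``src_spec`` to ``tgt_spec``.

    Mesh dims are traversed in canonical order, except for the
    (Partial(), Partial()) -> (Shard(0), Shard(0)) case, which is traversed in
    reverse ([1, 0]), i.e. the non-canonical PP -> S(0)S(0) order mentioned in the
    TODO on ``redistribute_cost`` above.
    """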
    order = list(range(src_spec.mesh.ndim))
    if src_spec.placements == (Partial(), Partial()) and tgt_spec.placements == (
        Shard(0),
        Shard(0),
    ):
        order = [1, 0]
    comms_cost = redistribute_cost(src_spec, tgt_spec, order)
    return comms_cost
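

# A minimal usage sketch (illustrative only: it relies on private DTensor internals
# such as DTensorSpec/TensorMeta, whose exact constructors may differ across PyTorch
# versions, and assumes a 2-D DeviceMesh `mesh` is already initialized):
#
#     import torch
#     from torch.distributed.tensor._dtensor_spec import DTensorSpec, TensorMeta
#
#     meta = TensorMeta(shape=torch.Size([8, 8]), stride=(8, 1), dtype=torch.float32)
#     src = DTensorSpec(mesh, (Partial(), Partial()), tensor_meta=meta)
#     tgt = DTensorSpec(mesh, (Shard(0), Shard(0)), tensor_meta=meta)
#     cost_us = estimate_strategy_comms_cost(src, tgt)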