Commit 19f1fab

Use custom redistribute_cost for optimal PP->S(0)S(0) cost (#109)
* Use custom redistribute_cost for optimal PP->S(0)S(0) cost. Previously, given the default iteration order, it was always less expensive to do PP->S(0)P->S(0)S(0) instead of directly PP->S(0)S(0). We now favor doing it in a single pass by using the optimal redistribution cost for a given operation. For now, we only consider the PP->S(0)S(0) case, but we should generalize this to all cases in the future.
* Add all-to-all cost. Need to verify if it's correct.
* Using a2a == ag*4 gives better results for llama3 8b.
1 parent c680107 commit 19f1fab
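A minimal, standalone sketch of the iteration-order rule that estimate_strategy_comms_cost (added below) applies. Plain strings stand in for the real Partial()/Shard(0) placement objects so it runs without a device mesh, and pick_iteration_order is a hypothetical name used only for this illustration:

def pick_iteration_order(src_placements, tgt_placements, ndim):
    # Default: walk the mesh dimensions in canonical order.
    order = list(range(ndim))
    # Only the PP -> S(0)S(0) case gets the reversed, non-canonical order for now.
    if src_placements == ("P", "P") and tgt_placements == ("S(0)", "S(0)"):
        order = [1, 0]
    return order

assert pick_iteration_order(("P", "P"), ("S(0)", "S(0)"), 2) == [1, 0]
assert pick_iteration_order(("P", "P"), ("S(0)", "P"), 2) == [0, 1]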

File tree

3 files changed

+119
-1
lines changed

3 files changed

+119
-1
lines changed
autoparallel/collective_runtime_estimation.py

Lines changed: 103 additions & 0 deletions
@@ -0,0 +1,103 @@
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.

import torch.distributed.tensor._dtensor_spec as dtensor_spec
from torch.distributed.tensor._collective_utils import (
    MeshTopoInfo,
    allgather_cost,
    allreduce_cost,
    reduce_scatter_cost,
    spec_to_bytes,
)
from torch.distributed.tensor.placement_types import Partial, Shard


def all_to_all_cost(bytes_gb: float, mesh_topo: MeshTopoInfo, mesh_dim: int) -> float:
    num_devices_on_mesh_dim = mesh_topo.mesh_dim_devices[mesh_dim]
    mesh_dim_bandwidth = mesh_topo.mesh_dim_bandwidth[mesh_dim]
    num_hops = num_devices_on_mesh_dim**2
    # base latency + comm latency
    latency = 6.6 + num_hops * mesh_topo.mesh_dim_latency[mesh_dim]  # us
    bw = (bytes_gb * num_hops / num_devices_on_mesh_dim) / mesh_dim_bandwidth  # s
    return latency + bw * 1e6  # rescale to us


# This is a copy-paste from https://github.com/pytorch/pytorch/blob/main/torch/distributed/tensor/_collective_utils.py
# with an iteration order introduced.
# TODO: this should be improved, as we really only use the non-canonical order for
# PP->S(0)S(0) for now
def redistribute_cost(
    current_spec: "dtensor_spec.DTensorSpec",
    target_spec: "dtensor_spec.DTensorSpec",
    order: list[int],
) -> float:
    """
    This function returns the cost of redistributing from the current to the target DTensorSpec.

    NOTE:
    1. Only the communication cost is considered here, since the computation cost of a
       redistribute is trivial (i.e. we only need a narrow or a simple division).
    2. Only the redistribute cost on the same mesh is considered; cross-mesh communication
       cost is not needed for operator strategy estimation/selection.
    """
    if current_spec.mesh != target_spec.mesh:
        # make the cost infinite if the meshes are not the same
        # TODO: see if we want to support this once there's cross-mesh communication
        return float("inf")

    if current_spec.is_replicated():
        # short-cut:
        # comm cost is 0 if the current spec is already fully replicated
        return 0.0

    mesh_topo = MeshTopoInfo.build_from_mesh(current_spec.mesh)
    cost = 0.0
    comm_bytes_gb = (
        spec_to_bytes(current_spec) / current_spec.num_shards / 1024 / 1024 / 1024
    )
    # Transformations considered for redistribute cost:
    # 1. allgather 2. alltoall
    # 3. allreduce 4. reduce_scatter
    curr_placements = [current_spec.placements[i] for i in order]
    tgt_placements = [target_spec.placements[i] for i in order]
    for i, current, target in zip(order, curr_placements, tgt_placements):
        if current == target:
            continue
        num_devices_on_mesh_dim = mesh_topo.mesh_dim_devices[i]
        if current.is_shard() and target.is_replicate():
            # allgather gives larger comm bytes
            comm_bytes_gb *= num_devices_on_mesh_dim
            # add up allgather comm cost
            cost += allgather_cost(comm_bytes_gb, mesh_topo, i)
        elif current.is_shard() and target.is_shard():
            # this should be an alltoall comm; since we haven't implemented it yet,
            # add a penalty to favor allgather instead
            # cost += all_to_all_cost(comm_bytes_gb, mesh_topo, i)
            cost += allgather_cost(comm_bytes_gb, mesh_topo, i) * 4.0
        elif current.is_partial() and target.is_replicate():
            # add up allreduce comm cost
            cost += allreduce_cost(comm_bytes_gb, mesh_topo, i)
        elif current.is_partial() and target.is_shard():
            # add up reduce_scatter comm cost
            cost += reduce_scatter_cost(comm_bytes_gb, mesh_topo, i)
            # after reduce_scatter the comm bytes for further collectives are halved
            comm_bytes_gb /= num_devices_on_mesh_dim
        elif current.is_shard() and target.is_partial():
            # ban shard -> partial as it does not make sense to perform
            # this redistribute
            return float("inf")

    return cost


def estimate_strategy_comms_cost(src_spec, tgt_spec):
    order = list(range(src_spec.mesh.ndim))
    if src_spec.placements == (Partial(), Partial()) and tgt_spec.placements == (
        Shard(0),
        Shard(0),
    ):
        order = [1, 0]
    comms_cost = redistribute_cost(src_spec, tgt_spec, order)
    return comms_cost
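For intuition, the new alltoall cost model above is the sum of a latency term and a bandwidth term. The sketch below restates the same formula with plain numbers instead of a MeshTopoInfo; the device count, link bandwidth, and per-hop latency are made-up illustrative values, not measurements:

def all_to_all_cost_example(bytes_gb, num_devices, bandwidth_gb_s, link_latency_us):
    # Mirrors all_to_all_cost above, but takes plain numbers instead of a MeshTopoInfo.
    num_hops = num_devices**2
    latency = 6.6 + num_hops * link_latency_us  # base latency + comm latency, in us
    bw = (bytes_gb * num_hops / num_devices) / bandwidth_gb_s  # seconds
    return latency + bw * 1e6  # total cost in us

# 1 GB payload across 8 devices on a 100 GB/s link with 1.7 us per-hop latency:
# latency term = 6.6 + 64 * 1.7 = 115.4 us, bandwidth term = (1 * 64 / 8) / 100 s = 80000 us,
# so the bandwidth term dominates.
print(all_to_all_cost_example(1.0, 8, 100.0, 1.7))  # ~80115.4 us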

autoparallel/optimize_sharding.py

Lines changed: 14 additions & 0 deletions
@@ -96,6 +96,7 @@
 from torch.distributed.tensor.placement_types import Placement, Replicate, Shard
 from torch.utils._pytree import tree_flatten, tree_map_only
 
+from .collective_runtime_estimation import estimate_strategy_comms_cost
 from .compute_estimation import (
     _get_sharded_shape_stride,
     estimate_strategy_runtime_cost,
@@ -301,6 +302,19 @@ def build_ds(self):
             if node.op != "placeholder":
                 argi_strat = self.strats[self._all_input_nodes(node)[argi]]
             for ii, comm_cost in enumerate(xxi):
+                if node.op != "placeholder":
+                    src_spec = argi_strat.strategies[ii].output_specs
+                    # TODO: operator.getitem being special is something
+                    # we might want to change in the future
+                    if node.target == operator.getitem:
+                        src_spec = src_spec[node.args[1]]
+                    tgt_spec = ssi.input_specs[argi]
+                    assert isinstance(src_spec, DTensorSpec)
+                    assert isinstance(tgt_spec, DTensorSpec)
+                    # we use our custom comm_cost function to estimate the cost
+                    # of the collective operation
+                    comm_cost = estimate_strategy_comms_cost(src_spec, tgt_spec)
+
                 if node in grad_param_nodes:
                     comm_cost = comm_cost / self.rescale_grad_comm_cost_for_mp
                 # Imagine we start node_i from S(0)S(0) and we want to reach node_{i+2} at

autoparallel/ordered_sharding.py

Lines changed: 2 additions & 1 deletion
@@ -37,7 +37,8 @@ def _optimize_same_nd_sharding_as_1d(
         return redistribute_local_tensor(arg, curr_spec, tgt_spec)
 
     # TODO: make this more general, I'm playing safe for now
-    if not (curr_spec_first == Shard(0) and tgt_spec_first == Replicate()):
+    allowed_placements = [(Shard(0), Replicate()), (Partial(), Shard(0))]
+    if (curr_spec_first, tgt_spec_first) not in allowed_placements:
         print(f"NOT doing optimization for {str(curr_spec)} -> {str(tgt_spec)}")
         return redistribute_local_tensor(arg, curr_spec, tgt_spec)
 
0 commit comments
