56 changes: 38 additions & 18 deletions autoparallel/collective_runtime_estimation.py
@@ -12,7 +12,6 @@
allgather_cost,
allreduce_cost,
reduce_scatter_cost,
spec_to_bytes,
)
from torch.distributed.tensor.placement_types import Partial, Shard

@@ -63,29 +62,50 @@ def redistribute_cost(

mesh_topo = MeshTopoInfo.build_from_mesh(current_spec.mesh)
cost = 0.0
comm_bytes_gb = (
spec_to_bytes(current_spec) / current_spec.num_shards / 1024 / 1024 / 1024
import math

from torch.distributed._functional_collectives import _are_we_tracing
from torch.distributed.tensor._redistribute import (
_gen_transform_infos,
_gen_transform_infos_non_cached,
)
# Transformation that considered for redistribute cost:
# 1. allgather 2. alltoall
# 3. allreduce 4. reduce_scatter
curr_placements = [current_spec.placements[i] for i in order]
tgt_placements = [target_spec.placements[i] for i in order]

if _are_we_tracing():
transform_infos = _gen_transform_infos_non_cached(current_spec, target_spec)
else:
transform_infos = _gen_transform_infos(current_spec, target_spec)

# Transformation that considered for redistribute cost:
# 1. allgather 2. alltoall
# 3. allreduce 4. reduce_scatter
# curr_placements = [current_spec.placements[i] for i in order]
# tgt_placements = [target_spec.placements[i] for i in order]
is_contiguous: bool = check_contiguous_sizes_strides(
current_spec.shape, current_spec.stride
)
for i, current, target in zip(order, curr_placements, tgt_placements):
if current == target:
continue
num_devices_on_mesh_dim = mesh_topo.mesh_dim_devices[i]
for transform_info in transform_infos:
Contributor comment:

In general, what we would like to have here, I think, is the minimal redistribution cost over all possible input/output orderings.

This is to ensure that we don't have to increase the search space for AutoParallel when performing the optimization: we can focus only on the shardings (without the ordering) and then optimize the ordering afterwards.

Does that make sense?
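
A minimal Python sketch of that idea, for illustration only: it assumes a redistribute_cost-like callable that accepts an explicit mesh-dimension ordering (the order argument here is hypothetical, not the function in this PR) and brute-forces the minimum over all permutations.

import itertools

def min_cost_over_orderings(current_spec, target_spec, cost_fn):
    # Try every permutation of the mesh dimensions and keep the cheapest
    # redistribution; with 2-3 mesh dims this is only a handful of calls.
    ndims = len(current_spec.placements)
    return min(
        cost_fn(current_spec, target_spec, order=list(order))
        for order in itertools.permutations(range(ndims))
    )

This keeps the sharding search space unchanged: the solver scores shardings only, and the cheapest ordering falls out of the cost query itself.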

assert (
current_spec.tensor_meta is not None
), "spec should have tensor meta defined!"
comm_bytes_gb = (
current_spec.tensor_meta.dtype.itemsize
* math.prod(transform_info.logical_shape)
/ 1024
/ 1024
/ 1024
)
if not is_contiguous:
cost += compute_read_write_time(comm_bytes_gb * 2 * 1024**3)
current = transform_info.src_dst_placements[0]
target = transform_info.src_dst_placements[1]
if current == target:
continue
mesh_dim = transform_info.mesh_dim
num_devices_on_mesh_dim = mesh_topo.mesh_dim_devices[mesh_dim]
if current.is_shard() and target.is_replicate():
current = cast(Shard, current)
# allgather gives larger comm bytes
comm_bytes_gb *= num_devices_on_mesh_dim
# add up allgather comm cost
cost += allgather_cost(comm_bytes_gb, mesh_topo, i)
cost += allgather_cost(comm_bytes_gb, mesh_topo, mesh_dim)
if current.dim != 0:
# penalize cases like S(1) -> R as there are additional compute cost
# which corresponds to reshuffling the whole output tensor
@@ -98,7 +118,7 @@
target = cast(Shard, target)
# should be alltoall comm, since we haven't implement it yet, add penalty
# to favor allgather instead
cost += all_to_all_cost(comm_bytes_gb, mesh_topo, i) # us
cost += all_to_all_cost(comm_bytes_gb, mesh_topo, mesh_dim) # us

num_copies = 0
if current.dim != 0:
@@ -112,11 +132,11 @@

elif current.is_partial() and target.is_replicate():
# add up allreduce comm cost
cost += allreduce_cost(comm_bytes_gb, mesh_topo, i)
cost += allreduce_cost(comm_bytes_gb, mesh_topo, mesh_dim)
elif current.is_partial() and target.is_shard():
target = cast(Shard, target)
# add up reduce_scatter comm cost
cost += reduce_scatter_cost(comm_bytes_gb, mesh_topo, i)
cost += reduce_scatter_cost(comm_bytes_gb, mesh_topo, mesh_dim)
if target.dim != 0:
# penalize cases like P -> S(1) as there are additional compute cost
# which corresponds to reshuffling the whole input tensor
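
For reference, a small self-contained sketch of the per-transform byte accounting the new code uses (the helper name and the example shape are made up; the formula mirrors the diff's itemsize * math.prod(logical_shape) / 1024**3):

import math

def comm_gb(itemsize: int, logical_shape: tuple) -> float:
    # GiB moved by one redistribution step: element size times the number of
    # elements in the transform's logical shape, converted to GiB.
    return itemsize * math.prod(logical_shape) / 1024 / 1024 / 1024

# A float32 tensor with logical shape (8192, 4096):
# 4 * 8192 * 4096 bytes = 128 MiB = 0.125 GiB
print(comm_gb(4, (8192, 4096)))  # 0.125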