Commit 1af2a14

bdhirsh and wconstab authored
preserve tensor striding during compute estimation (#37)
* preserve tensor striding during compute estimation
* Fix other callsites

Co-authored-by: Will Constable <[email protected]>
1 parent 4b7f973 commit 1af2a14
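Why it matters: estimate_strategy_runtime_cost rebuilds each tensor argument as a fake tensor with the sharded size before estimating the op's cost, and torch.empty always yields a contiguous layout, so any non-default striding on the real argument was lost before the change. A minimal standalone sketch of the difference (the shapes and strides below are hypothetical, not taken from this commit):

import torch

# A transposed view is a common source of non-default strides.
x = torch.empty(4, 8).t()                  # shape (8, 4), stride (1, 8)

# Rebuilding with torch.empty discards the layout: the result is contiguous.
contiguous = torch.empty(x.shape, dtype=x.dtype)
print(contiguous.stride())                 # (4, 1)

# torch.empty_strided reproduces the original layout.
strided = torch.empty_strided(x.shape, x.stride(), dtype=x.dtype)
print(strided.stride())                    # (1, 8)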

File tree

2 files changed: +17 -7 lines changed


autoparallel/compute_estimation.py

Lines changed: 12 additions & 5 deletions
@@ -155,7 +155,7 @@ def _get_device_tflops(dtype):
     return device_limit.gemm_tflops[dtype]
 
 
-def _get_sharded_shape(spec):
+def _get_sharded_shape_stride(spec):
     mesh = spec.mesh
     tensor_shape = spec.tensor_meta.shape
     # TODO: take dtype into account as well
@@ -164,11 +164,15 @@ def _get_sharded_shape(spec):
     # TODO: find a better heuristic other than
     # running DTensor
     new_tensor_shape = list(tensor_shape)
+    new_tensor_stride = list(spec.tensor_meta.stride)
     for mesh_size, placement in zip(mesh.shape, placements):
         if placement.is_shard():
             dim = placement.dim
             new_tensor_shape[dim] = (new_tensor_shape[dim] + mesh_size - 1) // mesh_size
-    return new_tensor_shape
+            new_tensor_stride[dim] = (
+                new_tensor_stride[dim] + mesh_size - 1
+            ) // mesh_size
+    return new_tensor_shape, new_tensor_stride
 
 
 def estimate_strategy_runtime_cost(node, strategy):
@@ -191,15 +195,18 @@ def estimate_strategy_runtime_cost(node, strategy):
     if len(kwargs) > 0:
         for k, v in kwargs.items():
             assert not isinstance(v, torch.Tensor), f"{node} {v}"
-    args_shapes = tuple(_get_sharded_shape(spec) for spec in strategy.input_specs)
+    args_sizes_strides = tuple(
+        _get_sharded_shape_stride(spec) for spec in strategy.input_specs
+    )
 
     counter = 0
     args = list(args)
     for i, arg in enumerate(args):
         if isinstance(arg, torch.Tensor):
             with fake_mode:
-                args[i] = torch.empty(
-                    args_shapes[counter], device=arg.device, dtype=arg.dtype
+                sizes, strides = args_sizes_strides[counter]
+                args[i] = torch.empty_strided(
+                    sizes, strides, device=arg.device, dtype=arg.dtype
                 )
             counter += 1
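For reference, the renamed helper applies the same ceil-division to the stride as to the shape along every sharded mesh dimension, and estimate_strategy_runtime_cost then materializes a fake tensor with exactly those sizes and strides via torch.empty_strided. A standalone illustration of the arithmetic for a single Shard(0) placement (the global shape, stride, and mesh size are made-up values; the real helper reads them from spec.tensor_meta and spec.mesh):

# Hypothetical stand-ins for spec.tensor_meta.shape / .stride and one mesh dim.
tensor_shape = (1024, 512)     # global shape
tensor_stride = (512, 1)       # row-major strides
mesh_size, shard_dim = 4, 0    # Shard(0) over a mesh dimension of size 4

new_shape = list(tensor_shape)
new_stride = list(tensor_stride)
# Ceil-divide the sharded dim, mirroring _get_sharded_shape_stride.
new_shape[shard_dim] = (new_shape[shard_dim] + mesh_size - 1) // mesh_size
new_stride[shard_dim] = (new_stride[shard_dim] + mesh_size - 1) // mesh_size

print(new_shape, new_stride)   # [256, 512] [128, 1]

As the TODO comments in the function note, this is a heuristic for the local layout rather than an exact DTensor propagation.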

autoparallel/optimize_sharding.py

Lines changed: 5 additions & 2 deletions
@@ -85,7 +85,10 @@
 from torch.distributed.tensor.placement_types import Replicate, Shard
 from torch.utils._pytree import tree_flatten, tree_map_only
 
-from .compute_estimation import _get_sharded_shape, estimate_strategy_runtime_cost
+from .compute_estimation import (
+    _get_sharded_shape_stride,
+    estimate_strategy_runtime_cost,
+)
 from .propagation_rules import _create_all_options
 from .utils import get_placement_options
 
@@ -628,7 +631,7 @@ def add_parameter_memory_constraint(self, memory_factor_low, memory_factor_high)
         data = self.ds[(s_i, 0, ii, 0)]
         spec = data["inp_strat"]
         tensor_shape = spec.tensor_meta.shape
-        new_tensor_shape = _get_sharded_shape(spec)
+        new_tensor_shape, _ = _get_sharded_shape_stride(spec)
         new_size = math.prod(new_tensor_shape)
         old_size = math.prod(tensor_shape)
         elms.append(data["va"] * new_size / old_size)
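At this call site only the sharded shape is needed: the parameter-memory constraint scales each term by the ratio of sharded to full element count, so the stride half of the returned tuple is discarded. A small numeric illustration (the shapes below are made up):

import math

tensor_shape = (1024, 512)         # full parameter shape
new_tensor_shape = (256, 512)      # sharded shape, e.g. Shard(0) over 4 ranks
new_size = math.prod(new_tensor_shape)
old_size = math.prod(tensor_shape)
print(new_size / old_size)         # 0.25 -> this placement keeps a quarter of the elements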
