@@ -155,7 +155,7 @@ def _get_device_tflops(dtype):
     return device_limit.gemm_tflops[dtype]
 
 
-def _get_sharded_shape(spec):
+def _get_sharded_shape_stride(spec):
     mesh = spec.mesh
     tensor_shape = spec.tensor_meta.shape
     # TODO: take dtype into account as well
@@ -164,11 +164,15 @@ def _get_sharded_shape(spec):
     # TODO: find a better heuristic other than
     # running DTensor
     new_tensor_shape = list(tensor_shape)
+    new_tensor_stride = list(spec.tensor_meta.stride)
     for mesh_size, placement in zip(mesh.shape, placements):
         if placement.is_shard():
             dim = placement.dim
             new_tensor_shape[dim] = (new_tensor_shape[dim] + mesh_size - 1) // mesh_size
-    return new_tensor_shape
+            new_tensor_stride[dim] = (
+                new_tensor_stride[dim] + mesh_size - 1
+            ) // mesh_size
+    return new_tensor_shape, new_tensor_stride
 
 
 def estimate_strategy_runtime_cost(node, strategy):
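
For intuition (a sketch, not part of the commit): the renamed helper ceil-divides both the size and the stride of every sharded dim by the mesh size. A minimal standalone version with hypothetical numbers:

```python
# Sketch only: mirrors the ceil-division loop above without the
# DeviceMesh/placement machinery (all values here are hypothetical).
def ceil_div(a: int, b: int) -> int:
    return (a + b - 1) // b

# A (5, 8) row-major tensor sharded on dim 0 across a mesh of size 2.
shape, stride = [5, 8], [8, 1]
mesh_size, shard_dim = 2, 0
shape[shard_dim] = ceil_div(shape[shard_dim], mesh_size)
stride[shard_dim] = ceil_div(stride[shard_dim], mesh_size)
print(shape, stride)  # [3, 8] [4, 1]
```

Note that dividing the stride is the same rough heuristic as dividing the shape (see the TODO above): it approximates the local layout rather than recomputing contiguous strides from the local shape.
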
@@ -191,15 +195,18 @@ def estimate_strategy_runtime_cost(node, strategy):
     if len(kwargs) > 0:
         for k, v in kwargs.items():
             assert not isinstance(v, torch.Tensor), f"{node} {v}"
-    args_shapes = tuple(_get_sharded_shape(spec) for spec in strategy.input_specs)
+    args_sizes_strides = tuple(
+        _get_sharded_shape_stride(spec) for spec in strategy.input_specs
+    )
 
     counter = 0
     args = list(args)
     for i, arg in enumerate(args):
         if isinstance(arg, torch.Tensor):
             with fake_mode:
-                args[i] = torch.empty(
-                    args_shapes[counter], device=arg.device, dtype=arg.dtype
+                sizes, strides = args_sizes_strides[counter]
+                args[i] = torch.empty_strided(
+                    sizes, strides, device=arg.device, dtype=arg.dtype
                 )
             counter += 1
 
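
A minimal sketch of the new fake-tensor materialization (not the commit's code; the sizes/strides below are hypothetical): under FakeTensorMode, torch.empty_strided records layout metadata without allocating real storage, so the cost model now sees the sharded strides and not just the sharded shape.

```python
import torch
from torch._subclasses.fake_tensor import FakeTensorMode

fake_mode = FakeTensorMode()

# Hypothetical local (sizes, strides) as returned by
# _get_sharded_shape_stride for one input spec.
sizes, strides = [3, 8], [8, 1]

with fake_mode:
    # No real memory is allocated under FakeTensorMode; the tensor
    # only carries shape/stride/dtype/device metadata.
    t = torch.empty_strided(sizes, strides, device="cpu", dtype=torch.float32)

print(t.shape, t.stride())  # torch.Size([3, 8]) (8, 1)
```

In the actual code the device and dtype come from the real argument (`arg.device`, `arg.dtype`) rather than being hard-coded as above.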