Commit 8d5c4d1

Remove detach, clone, cat rules fixed upstream (#39)

1 parent d5edb93

2 files changed (+0, -49 lines)


autoparallel/propagation_rules.py

Lines changed: 0 additions & 47 deletions
@@ -314,53 +314,6 @@ def split_with_sizes_rule(mesh, specs):
     return OpStrategy(strats)
 
 
-@register_rule(torch.ops.aten.cat.default)
-def cat_rule(mesh, specs):
-    op_spec = specs[0]
-    dim = specs[1]
-    strats = []
-
-    num_tensors = len(op_spec)
-
-    inp_ts = []
-    for i in range(num_tensors):
-        tm = op_spec[i].strategies[0].output_spec.tensor_meta
-        inp_ts.append(_build_meta_tensor(tm))
-
-    out_t = torch.cat(inp_ts, dim)
-
-    banned_idxs = set()
-    for i, ss in enumerate(op_spec[0].strategies):
-        for placement in ss.output_spec.placements:
-            if placement.is_shard(dim) or placement.is_partial():
-                banned_idxs.add(i)
-
-    for i in range(1, num_tensors):
-        assert len(op_spec[i].strategies) == len(
-            op_spec[0].strategies
-        ), "Assume each cat input has same number of strategies"
-
-    for strat_idx, strat in enumerate(op_spec[0].strategies):
-        placements = strat.output_spec.placements
-        if any(p.is_shard(dim) or p.is_partial() for p in placements):
-            continue
-        output_spec = DTensorSpec(mesh, placements, tensor_meta=_gen_tensor_meta(out_t))
-
-        all_costs = []
-        input_specs = []
-        for i in range(num_tensors):
-            input_specs.append(op_spec[i].strategies[strat_idx].output_spec)
-            redistribute_costs = generate_redistribute_costs(op_spec[i], output_spec)
-            for banned in banned_idxs:
-                redistribute_costs[banned] = math.inf
-            all_costs.append(redistribute_costs)
-
-        s = OpSpec(output_spec, input_specs=input_specs)
-        s.redistribute_cost = all_costs
-        strats.append(s)
-    return OpStrategy(strats)
-
-
 @register_rule(torch.ops.prims.iota.default)
 def iota_rule(mesh, specs):
     raise NotImplementedError("Needs hardening, only tested on a few cases")
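
For context: the deleted cat_rule enumerated output strategies for torch.cat, skipping any strategy whose placements shard the concatenation dim or hold partial (unreduced) values, since concatenating such local shards directly would produce a wrong result; in the original rule those strategy indices were additionally priced at math.inf redistribute cost so the solver would never select them. A minimal sketch of that filter in isolation, assuming PyTorch 2.4+ where Shard/Replicate/Partial are importable from torch.distributed.tensor (the helper name is illustrative, not from the repo):

# Sketch of the strategy filter the removed cat_rule applied (illustrative,
# not the project's actual API). Placements that shard the concat dim or hold
# partial (unreduced) sums cannot feed torch.cat without a redistribute first.
from torch.distributed.tensor import Partial, Replicate, Shard

def banned_for_cat(placements, cat_dim):
    # True if concatenating along cat_dim would require redistribution.
    return any(p.is_shard(cat_dim) or p.is_partial() for p in placements)

print(banned_for_cat((Shard(0), Replicate()), cat_dim=0))  # True: sharded on dim 0
print(banned_for_cat((Shard(1),), cat_dim=0))              # False: shard dim differs
print(banned_for_cat((Partial(),), cat_dim=0))             # True: pending reduction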

autoparallel/utils.py

Lines changed: 0 additions & 2 deletions
@@ -70,8 +70,6 @@ def fill_missing_redistribute_cost(op, specs, out_strat):
     # TODO: the torch.ops.aten.slice.Tensor is wrong here and in the input_spec!!!!!
     handled_ops = {
         torch.ops.aten.ones_like.default,
-        torch.ops.aten.detach.default,
-        torch.ops.aten.clone.default,
         torch.ops.aten.full_like.default,
         torch.ops.aten.empty_like.default,
         torch.ops.prims.convert_element_type.default,
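
The utils.py side is the matching cleanup: aten.detach and aten.clone leave the handled_ops fallback set because upstream DTensor now propagates their sharding, so autoparallel no longer needs to fill in their redistribute costs itself. A toy sketch of how such a membership check gates a fallback path (hypothetical helper name; only the set contents mirror the diff):

import torch

# Ops that still need locally filled redistribute costs after this commit;
# detach/clone are absent because upstream DTensor now covers them.
handled_ops = {
    torch.ops.aten.ones_like.default,
    torch.ops.aten.full_like.default,
    torch.ops.aten.empty_like.default,
    torch.ops.prims.convert_element_type.default,
}

def needs_local_cost_fill(op) -> bool:  # hypothetical helper
    return op in handled_ops

print(needs_local_cost_fill(torch.ops.aten.ones_like.default))  # True
print(needs_local_cost_fill(torch.ops.aten.detach.default))     # False (fixed upstream)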
