@@ -314,53 +314,6 @@ def split_with_sizes_rule(mesh, specs):
314314 return OpStrategy (strats )
315315
316316
@register_rule(torch.ops.aten.cat.default)
def cat_rule(mesh, specs):
    """Sharding-propagation rule for ``aten.cat``.

    ``specs[0]`` is the list of per-input OpStrategy objects and ``specs[1]``
    is the concat dimension. Emits one OpSpec per strategy of the first input
    whose placements neither shard the concat dim nor are partial; the other
    inputs are redistributed to match, and redistributing *from* any source
    strategy that shards the concat dim (or is partial) is priced at infinity.
    """
    tensor_strategies = specs[0]
    dim = specs[1]
    n_inputs = len(tensor_strategies)

    # Build meta tensors for every input so the concatenated output's
    # metadata can be derived without touching real data.
    meta_inputs = [
        _build_meta_tensor(tensor_strategies[idx].strategies[0].output_spec.tensor_meta)
        for idx in range(n_inputs)
    ]
    cat_meta = torch.cat(meta_inputs, dim)

    # Source strategies that shard along the concat dim or hold partial
    # values cannot feed a cat directly, so redistribution out of them is
    # forbidden below.
    # NOTE(review): `dim` is used as-is; a negative dim would not match
    # Shard placements — confirm callers always pass a normalized dim.
    forbidden = {
        idx
        for idx, strategy in enumerate(tensor_strategies[0].strategies)
        if any(
            p.is_shard(dim) or p.is_partial()
            for p in strategy.output_spec.placements
        )
    }

    # The per-index strategy alignment below requires every input to expose
    # the same number of strategies.
    for idx in range(1, n_inputs):
        assert len(tensor_strategies[idx].strategies) == len(
            tensor_strategies[0].strategies
        ), "Assume each cat input has same number of strategies"

    out_strategies = []
    for strat_idx, strategy in enumerate(tensor_strategies[0].strategies):
        placements = strategy.output_spec.placements
        # Output placements that shard the concat dim or are partial are
        # not valid cat outputs; skip them entirely.
        if any(p.is_shard(dim) or p.is_partial() for p in placements):
            continue
        output_spec = DTensorSpec(
            mesh, placements, tensor_meta=_gen_tensor_meta(cat_meta)
        )

        input_specs = []
        all_costs = []
        for idx in range(n_inputs):
            input_specs.append(tensor_strategies[idx].strategies[strat_idx].output_spec)
            costs = generate_redistribute_costs(tensor_strategies[idx], output_spec)
            # Never redistribute out of a banned source strategy.
            for banned in forbidden:
                costs[banned] = math.inf
            all_costs.append(costs)

        candidate = OpSpec(output_spec, input_specs=input_specs)
        candidate.redistribute_cost = all_costs
        out_strategies.append(candidate)
    return OpStrategy(out_strategies)
363-
@register_rule(torch.ops.prims.iota.default)
def iota_rule(mesh, specs):
    """Sharding-propagation rule stub for ``prims.iota``.

    Deliberately unimplemented: always raises so callers fail loudly
    instead of getting an untested propagation result.
    """
    raise NotImplementedError("Needs hardening, only tested on a few cases")
0 commit comments