Skip to content

Commit 1053272

Browse files
authored
Update convert_element_type_rule to latest PyTorch signature (#31)
Taken from #29
1 parent 09fab9b commit 1053272

File tree

1 file changed

+1
-2
lines changed

1 file changed

+1
-2
lines changed

autoparallel/propagation_rules.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -535,8 +535,7 @@ def native_layer_norm_backward_rule(mesh, op_schema):
535535
def convert_element_type_rule(mesh, op_schema):
536536
from torch.distributed.tensor._ops._tensor_ops import default_strategy
537537

538-
# TODO: API has changed in latest main
539-
out_strat = default_strategy(mesh, op_schema)
538+
out_strat = default_strategy(op_schema)
540539
return out_strat
541540

542541

0 commit comments

Comments
 (0)