Commit 6cd2133

Remove decomposition from softmax (#171)
Taken from #3 and #29. Decomposing softmax_backward leads to prims.fma, which doesn't have a sharding rule, so we end up with Replicate as the only possible sharding.
1 parent 5dcba70 · commit 6cd2133

File tree

1 file changed: +2 -0 lines


autoparallel/api.py

Lines changed: 2 additions & 0 deletions
@@ -60,6 +60,8 @@ def _get_decomp_table():
     decomp_table.pop(torch.ops.aten.native_layer_norm.default)
     decomp_table.pop(torch.ops.aten.embedding_dense_backward.default)
     decomp_table.pop(torch.ops.aten.native_layer_norm_backward.default)
+    decomp_table.pop(torch.ops.aten._softmax_backward_data.default)
+    decomp_table.pop(torch.ops.aten._softmax.default)
 
     # decompose addmm to allow for TP on mm
     decomp_table.pop(torch.ops.aten.addmm.default)
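For context, a minimal sketch of the pattern this change relies on, assuming the table comes from torch._decomp.core_aten_decompositions (the repo's actual _get_decomp_table may be built differently, and build_decomp_table / f below are illustrative names): any entry popped from the decomposition table is left as a single ATen op in the traced graph, so the sharding propagation sees aten._softmax directly instead of the prims (e.g. prims.fma) its decomposition would produce.

    # Sketch only: assumes core_aten_decompositions as the table source and a
    # toy function; names here are illustrative, not the repo's actual code.
    import torch
    from torch._decomp import core_aten_decompositions
    from torch.fx.experimental.proxy_tensor import make_fx

    def build_decomp_table():
        decomp_table = core_aten_decompositions()
        # Keep softmax and its backward as single ops: decomposing the backward
        # yields prims.fma, which has no sharding rule, forcing Replicate.
        decomp_table.pop(torch.ops.aten._softmax.default, None)
        decomp_table.pop(torch.ops.aten._softmax_backward_data.default, None)
        return decomp_table

    def f(x):
        return torch.softmax(x, dim=-1)

    # With the pruned table, the traced graph keeps a single aten._softmax node
    # rather than its exp/sum/div decomposition.
    gm = make_fx(f, decomposition_table=build_decomp_table())(torch.randn(4, 8))
    print(gm.graph)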

0 commit comments