
Commit b54eefe

ysiraichi authored and pobin6 committed
log_softmax: fix meta function output argument dtype check. (pytorch#140289)

Tracking issue: pytorch#138399

Pull Request resolved: pytorch#140289
Approved by: https://github.com/ezyang
ghstack dependencies: pytorch#140186, pytorch#140286, pytorch#140288
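In eager mode, the structured `_log_softmax` kernel requires a user-supplied `out=` tensor to already have the expected result dtype; this change makes the Python meta/decomposition path enforce the same check. A minimal sketch of the behavior being matched (assuming eager rejects a mismatched `out` dtype, which is what the stricter check mirrors):

    import torch

    x = torch.randn(4, 4, dtype=torch.float32)

    # Matching dtype: accepted in eager, and now also by the decomposition.
    out_ok = torch.empty(4, 4, dtype=torch.float32)
    torch.ops.aten._log_softmax.out(x, 1, False, out=out_ok)

    # Mismatched dtype: eager raises a RuntimeError, so the meta function for
    # _log_softmax should raise as well instead of allowing a safe cast
    # (e.g. float32 -> float64).
    out_bad = torch.empty(4, 4, dtype=torch.float64)
    try:
        torch.ops.aten._log_softmax.out(x, 1, False, out=out_bad)
    except RuntimeError as e:
        print(e)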
1 parent 870fe9f commit b54eefe

File tree

2 files changed (+1 / -2 lines changed)


test/test_ops.py

Lines changed: 0 additions & 1 deletion
@@ -168,7 +168,6 @@ def reduction_dtype_filter(op):
     xfail("linalg.solve"),
     xfail("linalg.solve_ex"),
     xfail("linalg.solve_triangular"),
-    xfail("log_softmax"),
     xfail("logcumsumexp"),
     xfail("lu_solve"),
     xfail("lu_unpack"),

torch/_decomp/decompositions.py

Lines changed: 1 addition & 1 deletion
@@ -1220,7 +1220,7 @@ def _softmax(x: Tensor, dim: int, half_to_float: bool):


 @register_decomposition(aten._log_softmax)
-@out_wrapper()
+@out_wrapper(exact_dtype=True)
 def _log_softmax(x: Tensor, dim: int, half_to_float: bool):
     # eager log_softmax returns a contiguous tensor. Ensure that decomp also
     # returns a contiguous tensor.
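For context, `out_wrapper()` by default only requires that the computed result can be safely cast to the dtype of the `out=` tensor, while `out_wrapper(exact_dtype=True)` requires the dtypes to match exactly, which is what the eager kernel expects. A rough, self-contained sketch of the difference (the `check_out_dtype` helper below is hypothetical, not the actual wrapper code in torch/_prims_common/wrappers.py):

    import torch

    def check_out_dtype(result: torch.Tensor, out: torch.Tensor, exact_dtype: bool) -> None:
        # Hypothetical stand-in for the dtype validation that out_wrapper applies
        # to a user-provided `out=` tensor.
        if exact_dtype:
            # exact_dtype=True: the out tensor must already have the result dtype.
            if out.dtype != result.dtype:
                raise RuntimeError(
                    f"Expected out tensor to have dtype {result.dtype}, "
                    f"but got {out.dtype} instead"
                )
        elif not torch.can_cast(result.dtype, out.dtype):
            # Default: only reject dtypes the result cannot be safely cast to.
            raise RuntimeError(
                f"result dtype {result.dtype} cannot be cast to out dtype {out.dtype}"
            )

    res = torch.randn(3, dtype=torch.float32)
    out = torch.empty(3, dtype=torch.float64)
    check_out_dtype(res, out, exact_dtype=False)  # passes: float32 -> float64 is a safe cast
    try:
        check_out_dtype(res, out, exact_dtype=True)  # raises: dtypes differ, matching eager
    except RuntimeError as e:
        print(e)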
