Skip to content

Commit 870fe9f

Browse files
ysiraichipobin6
authored andcommitted
Fix unary references' out dtype check. (pytorch#140288)
Tracking issue: pytorch#138399 This PR fixes a number of reference implementations (which are also used as meta functions), making them more consistent with CPU device. More specifically, it fixes those operations that use `_make_elementwise_unary_reference` decorator, and don't error on mismatching out argument dtype while they error when using concrete devices (e.g. CPU). The fixed operations are: - `abs` - `ceil` - `floor` - `frac` - `isneginf` - `isposinf` - `sgn` - `sign` - `signbit` - `trunc` Pull Request resolved: pytorch#140288 Approved by: https://github.com/ezyang ghstack dependencies: pytorch#140186, pytorch#140286
1 parent 265ff3f commit 870fe9f

File tree

2 files changed

+42
-21
lines changed

2 files changed

+42
-21
lines changed

test/test_ops.py

Lines changed: 0 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -121,7 +121,6 @@ def reduction_dtype_filter(op):
121121
aten = torch.ops.aten
122122

123123
meta_consistency_out_dtype_mismatch_xfails = {
124-
xfail("abs"),
125124
xfail("addbmm"),
126125
xfail("addmv"),
127126
xfail("alias_copy"),
@@ -133,7 +132,6 @@ def reduction_dtype_filter(op):
133132
xfail("as_strided_copy"),
134133
xfail("baddbmm"),
135134
xfail("bucketize"),
136-
xfail("ceil"),
137135
xfail("conj_physical"),
138136
xfail("cross"),
139137
xfail("cummax"),
@@ -144,8 +142,6 @@ def reduction_dtype_filter(op):
144142
xfail("expand_copy"),
145143
xfail("fft.ihfft2"),
146144
xfail("fft.ihfftn"),
147-
xfail("floor"),
148-
xfail("frac"),
149145
xfail("frexp"),
150146
xfail("geqrf"),
151147
xfail("heaviside"),
@@ -154,8 +150,6 @@ def reduction_dtype_filter(op):
154150
xfail("index_copy"),
155151
xfail("index_select"),
156152
xfail("isin"),
157-
xfail("isneginf"),
158-
xfail("isposinf"),
159153
xfail("kthvalue"),
160154
xfail("lerp"),
161155
xfail("linalg.cross"),
@@ -209,9 +203,6 @@ def reduction_dtype_filter(op):
209203
xfail("scatter_reduce", "prod"),
210204
xfail("scatter_reduce", "sum"),
211205
xfail("searchsorted"),
212-
xfail("sgn"),
213-
xfail("sign"),
214-
xfail("signbit"),
215206
xfail("slice_scatter"),
216207
xfail("softmax"),
217208
xfail("sort"),
@@ -223,7 +214,6 @@ def reduction_dtype_filter(op):
223214
xfail("transpose_copy"),
224215
xfail("tril"),
225216
xfail("triu"),
226-
xfail("trunc"),
227217
xfail("unfold_copy"),
228218
xfail("unsqueeze_copy"),
229219
xfail("vdot"),

torch/_refs/__init__.py

Lines changed: 42 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -476,12 +476,13 @@ def _make_elementwise_unary_reference(
476476
*,
477477
aten_op=infer_aten_op,
478478
extra_meta=None,
479+
exact_dtype=False,
479480
) -> Callable:
480481
def inner(prim: Callable):
481482
nonlocal aten_op
482483

483484
@wraps(prim)
484-
@out_wrapper()
485+
@out_wrapper(exact_dtype=exact_dtype)
485486
@elementwise_unary_scalar_wrapper
486487
@elementwise_type_promotion_wrapper(
487488
type_promoting_args=("a",),
@@ -545,7 +546,10 @@ def _fn(a, *args, **kwargs):
545546
return _fn
546547

547548

548-
@_make_elementwise_unary_reference(ELEMENTWISE_TYPE_PROMOTION_KIND.COMPLEX_TO_FLOAT)
549+
@_make_elementwise_unary_reference(
550+
ELEMENTWISE_TYPE_PROMOTION_KIND.COMPLEX_TO_FLOAT,
551+
exact_dtype=True,
552+
)
549553
def abs(a):
550554
return prims.abs(a)
551555

@@ -585,7 +589,10 @@ def bitwise_not(a):
585589
return prims.bitwise_not(a)
586590

587591

588-
@_make_elementwise_unary_reference(ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT)
592+
@_make_elementwise_unary_reference(
593+
ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT,
594+
exact_dtype=True,
595+
)
589596
def ceil(a):
590597
return prims.ceil(a)
591598

@@ -679,12 +686,18 @@ def zero(input: TensorLikeType) -> TensorLikeType:
679686
return torch.zeros_like(input)
680687

681688

682-
@_make_elementwise_unary_reference(ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT)
689+
@_make_elementwise_unary_reference(
690+
ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT,
691+
exact_dtype=True,
692+
)
683693
def floor(a):
684694
return prims.floor(a)
685695

686696

687-
@_make_elementwise_unary_reference(ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT)
697+
@_make_elementwise_unary_reference(
698+
ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT,
699+
exact_dtype=True,
700+
)
688701
def frac(x: TensorLikeType) -> TensorLikeType:
689702
trunc_x = torch.mul(torch.floor(torch.abs(x)), torch.sign(x))
690703
return torch.sub(x, trunc_x)
@@ -719,7 +732,10 @@ def isinf(a: TensorLikeType) -> TensorLikeType:
719732
return torch.zeros_like(a, dtype=torch.bool)
720733

721734

722-
@_make_elementwise_unary_reference(ELEMENTWISE_TYPE_PROMOTION_KIND.ALWAYS_BOOL)
735+
@_make_elementwise_unary_reference(
736+
ELEMENTWISE_TYPE_PROMOTION_KIND.ALWAYS_BOOL,
737+
exact_dtype=True,
738+
)
723739
def isposinf(a: TensorLikeType) -> TensorLikeType:
724740
torch._check(
725741
not utils.is_complex_dtype(a.dtype),
@@ -730,7 +746,10 @@ def isposinf(a: TensorLikeType) -> TensorLikeType:
730746
return torch.zeros_like(a, dtype=torch.bool)
731747

732748

733-
@_make_elementwise_unary_reference(ELEMENTWISE_TYPE_PROMOTION_KIND.ALWAYS_BOOL)
749+
@_make_elementwise_unary_reference(
750+
ELEMENTWISE_TYPE_PROMOTION_KIND.ALWAYS_BOOL,
751+
exact_dtype=True,
752+
)
734753
def isneginf(a: TensorLikeType) -> TensorLikeType:
735754
torch._check(
736755
not utils.is_complex_dtype(a.dtype),
@@ -920,7 +939,10 @@ def sigmoid(a: TensorLikeType) -> TensorLikeType:
920939
return true_divide(1, add(1, exp(neg(a))))
921940

922941

923-
@_make_elementwise_unary_reference(ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT)
942+
@_make_elementwise_unary_reference(
943+
ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT,
944+
exact_dtype=True,
945+
)
924946
def sgn(a):
925947
if utils.is_complex_dtype(a.dtype):
926948
a_abs = a.abs()
@@ -929,12 +951,18 @@ def sgn(a):
929951
return a.sign()
930952

931953

932-
@_make_elementwise_unary_reference(ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT)
954+
@_make_elementwise_unary_reference(
955+
ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT,
956+
exact_dtype=True,
957+
)
933958
def sign(a):
934959
return prims.sign(a)
935960

936961

937-
@_make_elementwise_unary_reference(ELEMENTWISE_TYPE_PROMOTION_KIND.ALWAYS_BOOL)
962+
@_make_elementwise_unary_reference(
963+
ELEMENTWISE_TYPE_PROMOTION_KIND.ALWAYS_BOOL,
964+
exact_dtype=True,
965+
)
938966
def signbit(a):
939967
return prims.signbit(a)
940968

@@ -980,7 +1008,10 @@ def tanh(a):
9801008
return prims.tanh(a)
9811009

9821010

983-
@_make_elementwise_unary_reference(ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT)
1011+
@_make_elementwise_unary_reference(
1012+
ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT,
1013+
exact_dtype=True,
1014+
)
9841015
def trunc(a):
9851016
return prims.trunc(a)
9861017

0 commit comments

Comments
 (0)