from functools import partial, wraps

import torch

from torch.testing._internal.common_utils import \
    (TestCase, run_tests)
from torch.testing._internal.common_methods_invocations import \
    (op_db)
from torch.testing._internal.common_device_type import \
    (instantiate_device_type_tests, ops, dtypes, onlyOnCPUAndCUDA, skipCUDAIfRocm)
from torch.autograd.gradcheck import gradcheck, gradgradcheck
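
# A note on the harness: `op_db` is the list of OpInfo entries describing each
# operator under test. The `@ops` decorator parametrizes a test template over
# those entries (and their registered dtypes), `@dtypes` pins an explicit dtype
# list, and `instantiate_device_type_tests` at the bottom of this file generates
# the concrete per-device test classes that actually run.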

# Tests that apply to all operators
class TestOpInfo(TestCase):
    exact_dtype = True

    # Verifies that ops have their unsupported dtypes
    # registered correctly by testing that each claimed unsupported dtype
    # throws a runtime error
    @skipCUDAIfRocm
    @onlyOnCPUAndCUDA
    @ops(op_db, unsupported_dtypes_only=True)
    def test_unsupported_dtypes(self, device, dtype, op):
        samples = op.sample_inputs(device, dtype)
        if len(samples) == 0:
            self.skipTest("Skipped! No sample inputs!")

        # NOTE: only tests on first sample
        sample = samples[0]
        with self.assertRaises(RuntimeError):
            op(sample.input, *sample.args, **sample.kwargs)

    # Verifies that ops have their supported dtypes
    # registered correctly by testing that each claimed supported dtype
    # does NOT throw a runtime error
    @skipCUDAIfRocm
    @onlyOnCPUAndCUDA
    @ops(op_db)
    def test_supported_dtypes(self, device, dtype, op):
        samples = op.sample_inputs(device, dtype)
        if len(samples) == 0:
            self.skipTest("Skipped! No sample inputs!")

        # NOTE: only tests on first sample
        sample = samples[0]
        op(sample.input, *sample.args, **sample.kwargs)
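
# For reference, each sample returned by op.sample_inputs() is (roughly) a
# SampleInput carrying a primary tensor plus any extra positional args and
# kwargs, which is why every call site in this file has the shape
#     op(sample.input, *sample.args, **sample.kwargs)
# A hypothetical entry for torch.add might look like
#     SampleInput(make_tensor((S, S), device, dtype), args=(make_tensor((S, S), device, dtype),))
# (illustrative only; the real entries live in common_methods_invocations.py).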

class TestGradients(TestCase):
    exact_dtype = True

    # Copies inputs to inplace operations to avoid inplace modifications
    # to leaves requiring gradient
    def _get_safe_inplace(self, inplace_variant):
        @wraps(inplace_variant)
        def _fn(t, *args, **kwargs):
            return inplace_variant(t.clone(), *args, **kwargs)

        return _fn
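
    # (The clone above is needed because autograd raises a RuntimeError when an
    # in-place op modifies a leaf tensor that requires grad, and gradcheck feeds
    # exactly such leaves; cloning first yields a non-leaf that may be mutated.)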

    def _check_helper(self, device, dtype, op, variant, check):
        if variant is None:
            self.skipTest("Skipped! Variant not implemented.")
        if not op.supports_dtype(dtype, torch.device(device).type):
            self.skipTest(f"Skipped! {op.name} does not support dtype {str(dtype)}")

        samples = op.sample_inputs(device, dtype, requires_grad=True)
        for sample in samples:
            partial_fn = partial(variant, **sample.kwargs)
            if check == 'gradcheck':
                self.assertTrue(gradcheck(partial_fn, (sample.input,) + sample.args,
                                          check_grad_dtypes=True))
            elif check == 'gradgradcheck':
                self.assertTrue(gradgradcheck(partial_fn, (sample.input,) + sample.args,
                                              gen_non_contig_grad_outputs=False,
                                              check_grad_dtypes=True))
                self.assertTrue(gradgradcheck(partial_fn, (sample.input,) + sample.args,
                                              gen_non_contig_grad_outputs=True,
                                              check_grad_dtypes=True))
            else:
                self.assertTrue(False, msg="Unknown check requested!")

    def _grad_test_helper(self, device, dtype, op, variant):
        return self._check_helper(device, dtype, op, variant, 'gradcheck')

    def _gradgrad_test_helper(self, device, dtype, op, variant):
        return self._check_helper(device, dtype, op, variant, 'gradgradcheck')
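
    # gradcheck and gradgradcheck compare analytical gradients against numerical
    # (finite-difference) estimates, which is why the tests below are restricted
    # to double and complex double precision, where that comparison is reliable.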
    # Tests that gradients are computed correctly
    @dtypes(torch.double, torch.cdouble)
    @ops(op_db)
    def test_fn_grad(self, device, dtype, op):
        self._grad_test_helper(device, dtype, op, op.get_op())

    @dtypes(torch.double, torch.cdouble)
    @ops(op_db)
    def test_method_grad(self, device, dtype, op):
        self._grad_test_helper(device, dtype, op, op.get_method())

    @dtypes(torch.double, torch.cdouble)
    @ops(op_db)
    def test_inplace_grad(self, device, dtype, op):
        if not op.test_inplace_grad:
            self.skipTest("Skipped! Inplace gradcheck marked to skip.")
        self._grad_test_helper(device, dtype, op, self._get_safe_inplace(op.get_inplace()))

    # Test that gradients of gradients are computed correctly
    @dtypes(torch.double, torch.cdouble)
    @ops(op_db)
    def test_fn_gradgrad(self, device, dtype, op):
        self._gradgrad_test_helper(device, dtype, op, op.get_op())

    @dtypes(torch.double, torch.cdouble)
    @ops(op_db)
    def test_method_gradgrad(self, device, dtype, op):
        self._gradgrad_test_helper(device, dtype, op, op.get_method())

    @dtypes(torch.double, torch.cdouble)
    @ops(op_db)
    def test_inplace_gradgrad(self, device, dtype, op):
        if not op.test_inplace_grad:
            self.skipTest("Skipped! Inplace gradgradcheck marked to skip.")
        self._gradgrad_test_helper(device, dtype, op, self._get_safe_inplace(op.get_inplace()))

class TestOut(TestCase):
    exact_dtype = True

    @ops(op_db)
    def test_out(self, device, dtype, op):
        if not op.supports_tensor_out:
            self.skipTest("Skipped! Operator %s does not support out=..." % op.name)

        samples = op.sample_inputs(device, dtype)
        if len(samples) == 0:
            self.skipTest("Skipped! No sample inputs!")

        # NOTE: only tests on first sample
        sample = samples[0]

        # call it normally to get the expected result
        expected = op(sample.input, *sample.args, **sample.kwargs)

        # call it with out=... and check we get the expected result
        out_kwargs = sample.kwargs.copy()
        out_kwargs['out'] = out = torch.empty_like(expected)
        op(sample.input, *sample.args, **out_kwargs)
        self.assertEqual(expected, out)
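
        # Note: torch.empty_like(expected) already matches the expected result's
        # shape, dtype, and device, so only the simplest out= case is exercised
        # here; resizing and type promotion through out= are not covered.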

instantiate_device_type_tests(TestOpInfo, globals())
instantiate_device_type_tests(TestGradients, globals())
instantiate_device_type_tests(TestOut, globals())

if __name__ == '__main__':
    run_tests()
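
# To run this file directly: `python test_ops.py`. The instantiate_* calls above
# replace the generic classes with per-device variants (e.g. TestGradientsCPU,
# TestGradientsCUDA), and each generated test name also encodes the operator and
# dtype, so an individual case can be selected with pytest's -k filter, e.g.
# (hypothetical name) `pytest test_ops.py -k test_fn_grad_add_cpu_float64`.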