forked from pytorch/pytorch
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathtest_kernel_launch_checks.py
42 lines (33 loc) · 1.57 KB
/
test_kernel_launch_checks.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
from torch.testing._internal.common_utils import TestCase, run_tests
from torch.testing import check_cuda_kernel_launches, check_code_for_cuda_kernel_launches
class AlwaysCheckCudaLaunchTest(TestCase):
    """Tests for the CUDA kernel-launch check helpers from ``torch.testing``."""

    def test_check_code(self):
        """Run the snippet checker on hand-written CUDA code and compare counts."""
        # Kernel launches written with several spacing styles.  Two of the
        # launches below are not directly followed by
        # TORCH_CUDA_KERNEL_LAUNCH_CHECK(); the checker is expected to
        # report 2 for this snippet.
        spacing_variants = """
some_function_call<TemplateArg><<<1,2,0,stream>>>(arg1,arg2,arg3);
TORCH_CUDA_KERNEL_LAUNCH_CHECK();
some_function_call<TemplateArg><<<1,2,0,stream>>>(arg1,arg2,arg3);
some_function_call<TemplateArg><<<1,2,0,stream>>>(arg1,arg2,arg3);
TORCH_CUDA_KERNEL_LAUNCH_CHECK();
some_function_call<TemplateArg><<<1,2,0,stream>>>(arg1,arg2,arg3);
some_other_stuff;
some_function_call<TemplateArg><<<1,2,0,stream>>>(arg1,arg2,arg3);
TORCH_CUDA_KERNEL_LAUNCH_CHECK();
some_function_call<TemplateArg><<<1,2,0,stream>>> (arg1,arg2,arg3);
TORCH_CUDA_KERNEL_LAUNCH_CHECK();
some_function_call<TemplateArg><<<1,2,0,stream>>> ( arg1 , arg2 , arg3 ) ;
TORCH_CUDA_KERNEL_LAUNCH_CHECK();
"""
        spacing_result = check_code_for_cuda_kernel_launches(spacing_variants)
        self.assertEqual(2, spacing_result)

        # A launch inside a macro definition, with the check placed on the
        # backslash-continued line; the checker is expected to report 0.
        macro_variant = """
#define SOME_MACRO(x) some_function_call<<<1,2>>> ( x ) ; \\
TORCH_CUDA_KERNEL_LAUNCH_CHECK();
"""
        macro_result = check_code_for_cuda_kernel_launches(macro_variant)
        self.assertEqual(0, macro_result)

    def test_check_cuda_launches(self):
        """Invoke the repository-wide checker; its result is not asserted yet."""
        check_cuda_kernel_launches()

        # Placeholder assertion so the test passes as long as the call above
        # does not raise.
        # TODO: Enable this after warning messages have been dealt with.
        self.assertTrue(True)
        # self.assertTrue(check_cuda_kernel_launches() == 0)
# Run the test suite only when this file is executed as a script.
if __name__ == "__main__":
    run_tests()