Skip to content

Commit ff26158

Browse files
Add unit tests for triton_utils_v2 (#5073)
1 parent c35e540 commit ff26158

File tree

1 file changed

+163
-0
lines changed

1 file changed

+163
-0
lines changed
Lines changed: 163 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,163 @@
1+
import unittest
2+
from unittest.mock import MagicMock, patch
3+
4+
import paddle
5+
import triton.language as tl
6+
7+
TRITON_UTILS_V2_PATH = "fastdeploy.model_executor.ops.triton_ops.triton_utils_v2"
8+
import fastdeploy.model_executor.ops.triton_ops.triton_utils_v2 as tu2
9+
10+
11+
class TestGetValueHint(unittest.TestCase):
12+
"""Test the helper function get_value_hint from triton_utils_v2."""
13+
14+
def test_get_value_hint_int_and_float(self):
15+
"""Ensure get_value_hint handles mixed int and float values."""
16+
vals = [10, 1, -3, 1.5]
17+
hint = tu2.get_value_hint(vals)
18+
self.assertEqual(hint, "i64,i64,i64,fp32,")
19+
20+
21+
class TestKernelInterfaceV2(unittest.TestCase):
22+
"""Test cases for KernelInterface and decorator behavior."""
23+
24+
def mock_kernel(self, a, b, N: tl.constexpr, K: tl.constexpr):
25+
return
26+
27+
def test_kernel_interface_constexpr_detection(self):
28+
"""Verify constexpr argument detection and exclusion list generation."""
29+
kernel_interface = tu2.KernelInterface(self.mock_kernel, other_config={})
30+
self.assertEqual(kernel_interface.arg_names, ["a", "b", "N", "K"])
31+
self.assertEqual(kernel_interface.constexprs, [2, 3])
32+
self.assertEqual(kernel_interface.arg_exclude_constexpr, ["a", "b"])
33+
34+
@patch("paddle.distributed.get_rank", return_value=0)
35+
def test_decorator_cache_hit(self, _mock_rank):
36+
"""Ensure cached compiled ops are reused without recompilation."""
37+
kernel_interface = tu2.KernelInterface(self.mock_kernel, other_config={})
38+
kernel_interface.grid = [1, 1, 1]
39+
op_name = "haha_N8_K16"
40+
cached_fn = MagicMock()
41+
kernel_interface.func_map[op_name] = cached_fn
42+
kernel_interface.decorator(1, 2, N=8, K=16)
43+
cached_fn.assert_called_once_with(1, 2)
44+
45+
@patch("os.system")
46+
@patch("os.makedirs")
47+
@patch("os.getenv", return_value="/tmp/triton_cache/rank0")
48+
@patch("builtins.open", new_callable=MagicMock)
49+
@patch("importlib.import_module")
50+
@patch("paddle.distributed.get_rank", return_value=0)
51+
@patch(f"{TRITON_UTILS_V2_PATH}.build_package")
52+
@patch(f"{TRITON_UTILS_V2_PATH}.rename_c_to_cu")
53+
@patch(f"{TRITON_UTILS_V2_PATH}.multi_process_do")
54+
@patch(f"{TRITON_UTILS_V2_PATH}.extract_triton_kernel")
55+
@patch(f"{TRITON_UTILS_V2_PATH}.find_so_path")
56+
def test_decorator_compile_and_import(
57+
self,
58+
mock_find_so_path,
59+
mock_extract,
60+
mock_mp_do,
61+
mock_rename,
62+
mock_build,
63+
mock_rank,
64+
mock_import,
65+
mock_open,
66+
mock_getenv,
67+
mock_makedirs,
68+
mock_system,
69+
):
70+
"""Test full compilation → linking → building → importing pipeline when no cached .so exists."""
71+
mock_find_so_path.side_effect = [
72+
None,
73+
"/tmp/triton_cache/rank0/haha_N8_K16/mock_lib.so",
74+
]
75+
mock_module = MagicMock()
76+
mock_pybind_func = MagicMock()
77+
mock_module.haha_N8_K16_func = mock_pybind_func
78+
mock_import.return_value = mock_module
79+
mock_system.return_value = 0
80+
mock_mp_do.return_value = None
81+
mock_build.return_value = None
82+
mock_extract.return_value = None
83+
kernel_interface = tu2.KernelInterface(self.mock_kernel, other_config={})
84+
kernel_interface.grid = ["N * M * N"]
85+
kernel_interface.decorator(1, 2, N=8, K=16)
86+
mock_extract.assert_called_once()
87+
mock_mp_do.assert_called_once()
88+
mock_build.assert_called_once()
89+
mock_import.assert_called_once_with("haha_N8_K16_package")
90+
mock_pybind_func.assert_called_once_with(1, 2)
91+
92+
@patch("os.system")
93+
@patch("os.makedirs")
94+
@patch("os.getenv", return_value="/tmp/triton_cache/rank0")
95+
@patch("builtins.open", new_callable=MagicMock)
96+
@patch("importlib.import_module")
97+
@patch("paddle.distributed.get_rank", return_value=0)
98+
@patch(f"{TRITON_UTILS_V2_PATH}.build_package")
99+
@patch(f"{TRITON_UTILS_V2_PATH}.rename_c_to_cu")
100+
@patch(f"{TRITON_UTILS_V2_PATH}.multi_process_do")
101+
@patch(f"{TRITON_UTILS_V2_PATH}.extract_triton_kernel")
102+
@patch(f"{TRITON_UTILS_V2_PATH}.find_so_path")
103+
@patch(f"{TRITON_UTILS_V2_PATH}.get_pointer_hint")
104+
def test_tensor_and_none_branch(
105+
self,
106+
mock_get_pointer_hint,
107+
mock_find_so_path,
108+
mock_extract,
109+
mock_mp_do,
110+
mock_rename,
111+
mock_build,
112+
mock_rank,
113+
mock_import,
114+
mock_open,
115+
mock_getenv,
116+
mock_makedirs,
117+
mock_system,
118+
):
119+
"""Ensure decorator correctly handles Tensor and None arguments during dtype and pointer analysis."""
120+
ki = tu2.KernelInterface(self.mock_kernel, other_config={})
121+
mock_find_so_path.return_value = "/tmp/triton_cache/rank0/haha_N8_K16/mock_lib.so"
122+
mock_module = MagicMock()
123+
mock_pybind_func = MagicMock()
124+
mock_module.haha_N8_K16_func = mock_pybind_func
125+
mock_import.return_value = mock_module
126+
ki.grid = [1, 1, 1]
127+
a = paddle.to_tensor([1], dtype="float32")
128+
b = None
129+
mock_get_pointer_hint.return_value = "addr_hint"
130+
ki.decorator(a, b, N=8, K=16)
131+
mock_get_pointer_hint.assert_called_once()
132+
dtypes_arg = mock_get_pointer_hint.call_args[0][0]
133+
self.assertEqual(len(dtypes_arg), 2)
134+
self.assertEqual(dtypes_arg[0], a.dtype)
135+
self.assertEqual(dtypes_arg[1], paddle.int8)
136+
mock_import.assert_called_once_with("haha_N8_K16_package")
137+
mock_pybind_func.assert_called_once_with(a, b)
138+
139+
def test_getitem_sets_grid_and_returns_decorator(self):
140+
"""Ensure __getitem__ sets internal grid and returns a callable decorator."""
141+
kernel_interface = tu2.KernelInterface(self.mock_kernel, other_config={})
142+
dec = kernel_interface[["unused"]]
143+
self.assertTrue(isinstance(kernel_interface.grid, tuple))
144+
self.assertIn("max_possible_num_post_padded", kernel_interface.grid[0])
145+
self.assertTrue(callable(dec))
146+
147+
148+
class TestPaddleUseTritonV2(unittest.TestCase):
149+
"""Tests for paddle_use_triton_v2 decorator wrapper."""
150+
151+
def test_paddle_use_triton_v2_wraps_function(self):
152+
"""Verify paddle_use_triton_v2 returns a KernelInterface with correct key arguments."""
153+
154+
@tu2.paddle_use_triton_v2(other_config={"foo": 1}, key=["N", "K"])
155+
def my_kernel(a, N, K):
156+
return
157+
158+
self.assertIsInstance(my_kernel, tu2.KernelInterface)
159+
self.assertEqual(my_kernel.key_args, ["N", "K"])
160+
161+
162+
if __name__ == "__main__":
163+
unittest.main()

0 commit comments

Comments
 (0)