Commit d9d35de

[test] add ut and bm for get_last_loc (sgl-project#6746)
1 parent 6df81e8 commit d9d35de

File tree

1 file changed: +169 -0 lines changed

Lines changed: 169 additions & 0 deletions
@@ -0,0 +1,169 @@
import os

import torch
import triton
import triton.language as tl


@torch.compile(dynamic=True)
def get_last_loc_torch(
    req_to_token: torch.Tensor,
    req_pool_indices_tensor: torch.Tensor,
    prefix_lens_tensor: torch.Tensor,
) -> torch.Tensor:
    return torch.where(
        prefix_lens_tensor > 0,
        req_to_token[req_pool_indices_tensor, prefix_lens_tensor - 1],
        torch.full_like(prefix_lens_tensor, -1),
    )

@triton.jit
def get_last_loc_kernel(
    req_to_token,
    req_pool_indices_tensor,
    prefix_lens_tensor,
    result,
    num_tokens,
    req_to_token_stride,
    BLOCK_SIZE: tl.constexpr,
):
    # Each program instance handles BLOCK_SIZE entries; mask out-of-range lanes.
    pid = tl.program_id(0)
    offset = tl.arange(0, BLOCK_SIZE) + pid * BLOCK_SIZE
    mask = offset < num_tokens

    prefix_lens = tl.load(prefix_lens_tensor + offset, mask=mask, other=0)
    req_pool_indices = tl.load(req_pool_indices_tensor + offset, mask=mask, other=0)

    # For entries with a non-empty prefix, read the token at (row, prefix_len - 1);
    # all other entries fall back to -1 via the load's `other` value.
    token_mask = prefix_lens > 0
    token_index = req_pool_indices * req_to_token_stride + (prefix_lens - 1)
    tokens = tl.load(req_to_token + token_index, mask=token_mask, other=-1)

    tl.store(result + offset, tokens, mask=mask)

def get_last_loc_triton(
    req_to_token: torch.Tensor,
    req_pool_indices_tensor: torch.Tensor,
    prefix_lens_tensor: torch.Tensor,
) -> torch.Tensor:
    BLOCK_SIZE = 256
    num_tokens = prefix_lens_tensor.shape[0]
    result = torch.empty_like(prefix_lens_tensor)
    grid = (triton.cdiv(num_tokens, BLOCK_SIZE),)

    get_last_loc_kernel[grid](
        req_to_token,
        req_pool_indices_tensor,
        prefix_lens_tensor,
        result,
        num_tokens,
        req_to_token.stride(0),
        BLOCK_SIZE,
    )
    return result

def test_get_last_loc():
    max_batch = 4097
    max_context_len = 6148
    batch_size = 20

    # Initialize input tensors
    req_to_token = torch.zeros(
        (max_batch, max_context_len), dtype=torch.int32, device="cuda"
    )
    req_pool_indices = torch.arange(batch_size, dtype=torch.int64, device="cuda")
    pre_lens = torch.randint(
        -max_context_len // 2,
        max_context_len,
        (batch_size,),
        dtype=torch.int64,
        device="cuda",
    )

    last_loc_res = get_last_loc_triton(req_to_token, req_pool_indices, pre_lens)
    last_loc_ref = get_last_loc_torch(req_to_token, req_pool_indices, pre_lens)

    # Compare results
    torch.testing.assert_close(last_loc_res, last_loc_ref)

def get_benchmark():
    batch_sizes = [1, 2, 4, 8, 16, 32, 64, 128, 256, 512, 1024]

    @triton.testing.perf_report(
        triton.testing.Benchmark(
            x_names=["batch_size"],
            x_vals=batch_sizes,
            line_arg="provider",
            line_vals=["reference", "triton"],
            line_names=["PyTorch", "Triton"],
            styles=[("blue", "-"), ("green", "-")],
            ylabel="us",
            plot_name="get-last-loc-performance",
            args={},
        )
    )
    def benchmark(batch_size, provider):
        max_batch = 2048
        max_context_len = 16384

        req_to_token = torch.zeros(
            (max_batch, max_context_len), dtype=torch.int32, device="cuda"
        )
        req_pool_indices = torch.arange(batch_size, dtype=torch.int64, device="cuda")
        pre_lens = torch.randint(
            -max_context_len // 2,
            max_context_len,
            (batch_size,),
            dtype=torch.int64,
            device="cuda",
        )

        quantiles = [0.5, 0.2, 0.8]

        if provider == "reference":
            ms, min_ms, max_ms = triton.testing.do_bench(
                lambda: get_last_loc_torch(req_to_token, req_pool_indices, pre_lens),
                quantiles=quantiles,
            )
        elif provider == "triton":
            ms, min_ms, max_ms = triton.testing.do_bench(
                lambda: get_last_loc_triton(req_to_token, req_pool_indices, pre_lens),
                quantiles=quantiles,
            )

        return 1000 * ms, 1000 * max_ms, 1000 * min_ms

    return benchmark

def run_benchmark(save_path: str = "./configs/benchmark_ops/get_last_loc/"):
    """Run benchmark and save results"""

    # Ensure save path exists
    os.makedirs(save_path, exist_ok=True)

    # Run correctness test
    test_get_last_loc()
    print("Correctness test passed!")

    # Run performance test
    benchmark = get_benchmark()
    benchmark.run(print_data=True, save_path=save_path)


if __name__ == "__main__":
    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--save_path",
        type=str,
        default="./configs/benchmark_ops/get_last_loc/",
        help="Path to save benchmark results",
    )
    args = parser.parse_args()

    run_benchmark(args.save_path)
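
For reference, a minimal sketch of calling the two implementations above directly, assuming a CUDA-capable device and that the functions from this file are in scope. The shapes and values below are illustrative only and are not taken from the commit:

import torch

# Illustrative inputs: 8 request slots, 128 token slots per request, batch of 4.
req_to_token = torch.zeros((8, 128), dtype=torch.int32, device="cuda")
req_pool_indices = torch.arange(4, dtype=torch.int64, device="cuda")
prefix_lens = torch.tensor([0, 1, 5, 128], dtype=torch.int64, device="cuda")

# Entries with prefix_lens > 0 return req_to_token[row, prefix_lens - 1]; the rest return -1.
print(get_last_loc_torch(req_to_token, req_pool_indices, prefix_lens))
print(get_last_loc_triton(req_to_token, req_pool_indices, prefix_lens))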

0 commit comments
