-
Notifications
You must be signed in to change notification settings - Fork 76
/
Copy pathtest_blas_perf.py
174 lines (145 loc) · 4.87 KB
/
test_blas_perf.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
from typing import Generator
import pytest
import torch
from .attri_util import DEFAULT_METRICS, FLOAT_DTYPES, BenchLevel, model_shapes
from .conftest import Config
from .performance_utils import Benchmark, GenericBenchmark2DOnly
class BlasBenchmark(Benchmark):
    """
    Benchmark for the BLAS matrix-multiplication ops ``mm``, ``addmm`` and ``bmm``.

    Inputs come from ``input_fn(b, m, n, k, dtype, device, b_column_major)``,
    a generator supplied by the caller; every ``(b, m, n, k)`` entry of
    ``self.shapes`` is expanded through it. In addition to the inherited
    default metrics, this benchmark reports ``"tflops"``.
    """

    DEFAULT_METRICS = DEFAULT_METRICS[:] + ["tflops"]

    def __init__(self, *args, input_fn, **kwargs):
        super().__init__(*args, **kwargs)
        # Generator factory producing argument tuples for one (b, m, n, k) shape.
        self.input_fn = input_fn

    def get_input_iter(self, cur_dtype) -> Generator:
        """Yield argument tuples for every configured ``(b, m, n, k)`` shape.

        Row-major inputs (``b_column_major=False``) are always produced. The
        column-major variant of the second operand is added only when the
        bench level is ``BenchLevel.COMPREHENSIVE``.
        """
        for b, m, n, k in self.shapes:
            yield from self.input_fn(b, m, n, k, cur_dtype, self.device, False)
        if Config.bench_level == BenchLevel.COMPREHENSIVE:
            for b, m, n, k in self.shapes:
                yield from self.input_fn(b, m, n, k, cur_dtype, self.device, True)

    def set_more_shapes(self):
        """Return extra ``(b, m, n, k)`` shapes: a few very-large-K cases plus ``model_shapes()``."""
        large_k_shapes = [
            (8, 1848, 1536, 151936),
            (8, 1848, 1536, 128256),
            (8, 1848, 1536, 152064),
        ]
        model_shape_list = model_shapes()
        return large_k_shapes + model_shape_list

    def get_tflops(self, op, *args, **kwargs):
        """Return the floating-point operation count for one call with ``args``.

        Despite the name, this is a raw FLOP count for the given inputs, not a
        rate. Presumably the base class turns it into TFLOPS using the
        measured latency (TODO confirm). Ops other than mm/addmm/bmm give 0.
        """
        total_flops = 0
        if self.op_name == "mm":
            # shape (m, k) @ (k, n)
            # total_flops = m * n * 2k
            mat1, mat2 = args[0], args[1]
            total_flops = mat1.shape[0] * mat1.shape[1] * mat2.shape[1] * 2
        elif self.op_name == "addmm":
            # args = (bias, mat1 (m, k), mat2 (k, n))
            # total_flops = m * n * (2k + 1): the product plus one add per output.
            # Shapes come from mat1/mat2 rather than bias, which torch.addmm
            # allows to be broadcast.
            mat1, mat2 = args[1], args[2]
            total_flops = mat1.shape[0] * mat2.shape[1] * (mat1.shape[1] * 2 + 1)
        elif self.op_name == "bmm":
            # shape (b, m, k) @ (b, k, n)
            # total_flops = b * m * n * 2k
            batch1, batch2 = args[0], args[1]
            total_flops = (
                batch1.shape[0]
                * batch1.shape[1]
                * batch2.shape[2]
                * 2
                * batch1.shape[2]
            )
        return total_flops
def addmm_input_fn(b, m, n, k, cur_dtype, device, b_column_major):
    """Yield a single ``(bias, mat1, mat2)`` argument tuple for ``torch.addmm``.

    ``mat1`` is (m, k) and ``bias`` is (m, n). ``mat2`` has logical shape
    (k, n). When ``b_column_major`` is true, ``mat2`` is the transposed view
    of an (n, k) tensor, so it is column-major in memory. ``b`` is unused
    here; it is accepted so every BLAS input_fn shares one call signature.
    """
    tensor_kwargs = {"dtype": cur_dtype, "device": device}
    mat1 = torch.randn([m, k], **tensor_kwargs)
    bias = torch.randn([m, n], **tensor_kwargs)
    if not b_column_major:
        mat2 = torch.randn([k, n], **tensor_kwargs)
    else:
        mat2 = torch.randn([n, k], **tensor_kwargs).t()
    yield (bias, mat1, mat2)
def bmm_input_fn(b, m, n, k, cur_dtype, device, b_column_major):
    """Yield a single ``(batch1, batch2)`` argument tuple for ``torch.bmm``.

    ``batch1`` is (b, m, k) and ``batch2`` has logical shape (b, k, n). With
    ``b_column_major`` set, ``batch2`` is the ``transpose(1, 2)`` view of a
    (b, n, k) tensor, i.e. each matrix in the batch is column-major.
    """
    lhs = torch.randn([b, m, k], dtype=cur_dtype, device=device)
    rhs_shape = [b, n, k] if b_column_major else [b, k, n]
    rhs = torch.randn(rhs_shape, dtype=cur_dtype, device=device)
    yield lhs, (rhs.transpose(1, 2) if b_column_major else rhs)
def mm_input_fn(b, m, n, k, cur_dtype, device, b_column_major):
    """Yield a single ``(mat1, mat2)`` argument tuple for ``Tensor.mm``.

    ``mat1`` is (m, k) and ``mat2`` has logical shape (k, n). With
    ``b_column_major`` set, ``mat2`` is the transposed view of an (n, k)
    tensor. ``b`` is unused; it keeps the signature shared with the other
    BLAS input functions.
    """

    def _rand(*shape):
        return torch.randn(list(shape), dtype=cur_dtype, device=device)

    lhs = _rand(m, k)
    rhs = _rand(n, k).t() if b_column_major else _rand(k, n)
    yield lhs, rhs
@pytest.mark.parametrize(
    "op_name, torch_op, input_fn",
    [
        # Each case is marked with a pytest mark named after the op.
        pytest.param(name, op, fn, marks=getattr(pytest.mark, name))
        for name, op, fn in (
            ("addmm", torch.addmm, addmm_input_fn),
            ("bmm", torch.bmm, bmm_input_fn),
            ("mm", torch.Tensor.mm, mm_input_fn),
        )
    ],
)
def test_blas_benchmark(op_name, torch_op, input_fn):
    """Run the BLAS benchmark for one parametrized op over FLOAT_DTYPES."""
    benchmark = BlasBenchmark(
        op_name=op_name,
        torch_op=torch_op,
        dtypes=FLOAT_DTYPES,
        input_fn=input_fn,
    )
    benchmark.run()
class MvAndOuterBenchmark(GenericBenchmark2DOnly):
    """
    Benchmark for the vector ops ``mv`` and ``outer``.

    Each ``(m, n)`` entry of ``self.shapes`` is passed to ``self.input_fn`` as
    ``input_fn(m, n, dtype, device)``.
    """

    def set_more_shapes(self):
        # NOTE(review): returning None presumably tells the base class to add
        # no extra shapes beyond self.shapes; confirm against Benchmark.
        return None

    def get_input_iter(self, cur_dtype) -> Generator:
        """Yield every argument tuple produced by ``input_fn`` for each (m, n) shape."""
        for shape in self.shapes:
            m, n = shape
            yield from self.input_fn(m, n, cur_dtype, self.device)
def mv_input_fn(m, n, cur_dtype, device):
    """Yield one ``(matrix, vector)`` pair for ``Tensor.mv``.

    ``matrix`` is (m, n) and ``vector`` has length n.
    """
    opts = {"dtype": cur_dtype, "device": device}
    matrix = torch.randn([m, n], **opts)
    vector = torch.randn([n], **opts)
    yield matrix, vector
def outer_input_fn(m, n, cur_dtype, device):
    """Yield one pair of 1-D tensors, of lengths m and n, for ``Tensor.outer``."""
    yield tuple(
        torch.randn([size], dtype=cur_dtype, device=device) for size in (m, n)
    )
@pytest.mark.parametrize(
    "op_name, torch_op, input_fn",
    [
        # Each case is marked with a pytest mark named after the op.
        pytest.param(name, op, fn, marks=getattr(pytest.mark, name))
        for name, op, fn in (
            ("mv", torch.Tensor.mv, mv_input_fn),
            ("outer", torch.Tensor.outer, outer_input_fn),
        )
    ],
)
def test_mv_and_outer_benchmark(op_name, torch_op, input_fn):
    """Run the mv/outer benchmark for one parametrized op over FLOAT_DTYPES."""
    MvAndOuterBenchmark(
        op_name=op_name,
        torch_op=torch_op,
        dtypes=FLOAT_DTYPES,
        input_fn=input_fn,
    ).run()