forked from google/XNNPACK
-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy path qs8-packw.cc
113 lines (95 loc) · 2.97 KB
/
qs8-packw.cc
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
// Copyright 2024 Google LLC
//
// This source code is licensed under the BSD-style license found in the
// LICENSE file in the root directory of this source tree.
#include <cstddef>
#include <cstdint>
#include <string>

#include <gtest/gtest.h>
#include "xnnpack/common.h"
#include "xnnpack/isa-checks.h"
#include "xnnpack/packw.h"
#include "packw-microkernel-tester.h"
namespace {

// One parameter-table entry: a QS8 weight-packing microkernel plus the
// tiling parameters that the TEST_Ps in this file forward to
// PackWMicrokernelTester. Field order must match the aggregate initializer
// produced by XNN_QS8_UKERNEL below.
struct XnnTestQS8Param {
  // Stringified microkernel identifier; also used as the test-instance name.
  const char *name;
  // Microkernel under test.
  xnn_qs8_packw_gemm_goi_ukernel_fn ukernel;
  // ISA flags the kernel requires; checked via TEST_REQUIRES_ARCH_FLAGS.
  uint64_t arch_flags;
  // The tests multiply nr by nr_scale and pass the product as both n and nr.
  size_t nr;
  size_t kr;  // Forwarded to PackWMicrokernelTester::kr().
  size_t sr;  // Forwarded to PackWMicrokernelTester::sr().
  // Reference K block size; the k-range tests pick k values relative to it.
  size_t kblock;
  size_t nr_scale;  // Multiplier applied to nr (see nr above).
};

// Value-parameterized fixture; each TEST_P on it runs once per table entry.
class XnnTestQS8 : public testing::TestWithParam<XnnTestQS8Param> {};

// Names each parameterized instance after its microkernel (the `name` field).
std::string GetTestQS8Name(
    const testing::TestParamInfo<XnnTestQS8::ParamType>& info) {
  return info.param.name;
}

// X-macro: the generated header included below presumably invokes
// XNN_QS8_UKERNEL once per QS8 packw microkernel, yielding one table row each.
#define XNN_QS8_UKERNEL(arch_flags, ukernel, nr, kr, sr, kblock, nr_scale) \
  { #ukernel, ukernel, arch_flags, nr, kr, sr, kblock, nr_scale },

const XnnTestQS8Param xnn_test_qs8_params[] = {
#include "src/qs8-packw/qs8-packw.h"
};

#undef XNN_QS8_UKERNEL

}  // namespace
// Packs exactly one K block (k == kblock). n and the tester's nr are both set
// to the tile width nr * nr_scale.
TEST_P(XnnTestQS8, k_eq_kblock) {
  TEST_REQUIRES_ARCH_FLAGS(GetParam().arch_flags);
  const auto& params = GetParam();
  const size_t tile_n = params.nr * params.nr_scale;
  PackWMicrokernelTester()
      .n(tile_n)
      .k(params.kblock)
      .nr(tile_n)
      .kr(params.kr)
      .sr(params.sr)
      .Test(params.ukernel);
}
// Packs k = 5 * kblock (several whole K blocks) for one tile of width
// nr * nr_scale. Skipped when kblock <= 1.
TEST_P(XnnTestQS8, k_div_kblock) {
  const size_t kblock = GetParam().kblock;
  if (kblock < 2) {
    GTEST_SKIP();
  }
  TEST_REQUIRES_ARCH_FLAGS(GetParam().arch_flags);
  const size_t tile_n = GetParam().nr * GetParam().nr_scale;
  PackWMicrokernelTester()
      .n(tile_n)
      .k(kblock * 5)
      .nr(tile_n)
      .kr(GetParam().kr)
      .sr(GetParam().sr)
      .Test(GetParam().ukernel);
}
// Packs every k strictly smaller than one K block (1 .. kblock-1). Skipped
// when kblock <= 1, because that range would be empty.
TEST_P(XnnTestQS8, k_lt_kblock) {
  const size_t kblock = GetParam().kblock;
  if (kblock < 2) {
    GTEST_SKIP();
  }
  TEST_REQUIRES_ARCH_FLAGS(GetParam().arch_flags);
  const size_t tile_n = GetParam().nr * GetParam().nr_scale;
  for (size_t k_value = 1; k_value < kblock; ++k_value) {
    PackWMicrokernelTester()
        .n(tile_n)
        .k(k_value)
        .nr(tile_n)
        .kr(GetParam().kr)
        .sr(GetParam().sr)
        .Test(GetParam().ukernel);
  }
}
// Packs k values just above one K block: kblock+1 .. 2*kblock-1, or k = 2, 3
// when kblock == 1.
TEST_P(XnnTestQS8, k_gt_kblock) {
  TEST_REQUIRES_ARCH_FLAGS(GetParam().arch_flags);
  const size_t kblock = GetParam().kblock;
  size_t k_end = kblock * 2;
  if (kblock == 1) {
    k_end = 4;
  }
  const size_t tile_n = GetParam().nr * GetParam().nr_scale;
  for (size_t k_value = kblock + 1; k_value < k_end; ++k_value) {
    PackWMicrokernelTester()
        .n(tile_n)
        .k(k_value)
        .nr(tile_n)
        .kr(GetParam().kr)
        .sr(GetParam().sr)
        .Test(GetParam().ukernel);
  }
}
// n equals the tile width (nr * nr_scale) while k sweeps 1 .. 2*kblock-1,
// or 1 .. 3 when kblock == 1.
TEST_P(XnnTestQS8, n_eq_nr) {
  TEST_REQUIRES_ARCH_FLAGS(GetParam().arch_flags);
  const size_t kblock = GetParam().kblock;
  size_t k_end = kblock * 2;
  if (kblock == 1) {
    k_end = 4;
  }
  const size_t tile_n = GetParam().nr * GetParam().nr_scale;
  for (size_t k_value = 1; k_value < k_end; ++k_value) {
    PackWMicrokernelTester()
        .n(tile_n)
        .k(k_value)
        .nr(tile_n)
        .kr(GetParam().kr)
        .sr(GetParam().sr)
        .Test(GetParam().ukernel);
  }
}
// Instantiates every TEST_P of the XnnTestQS8 fixture once per entry of
// xnn_test_qs8_params. Each instance is named by GetTestQS8Name, i.e. after
// the microkernel's stringified identifier.
INSTANTIATE_TEST_SUITE_P(qs8_packw,
XnnTestQS8,
testing::ValuesIn(xnn_test_qs8_params),
GetTestQS8Name);