Skip to content

Commit c964fea

Browse files
committed
Q6_K NEON SDOT kernel: vdotq_s32 integer dot products instead of float FMA
Replaces ~320 float ops per block with 16 vdotq_s32 + 16 vaddvq_s32. Activations quantized to int8 blocks (reuses bn_quant_x_to_q8_blocks). Dispatched on __ARM_FEATURE_DOTPROD; float fallback unchanged.
1 parent 4b1e1c3 commit c964fea

4 files changed

Lines changed: 126 additions & 2 deletions

File tree

Makefile

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@ ifneq ($(filter arm% aarch%,$(UNAME_M)),)
3030
src/quant/q4_neon_sdot.c src/quant/q4_neon.c src/quant/q4_scalar.c \
3131
src/quant/q4_1_neon.c src/quant/q4_1_scalar.c \
3232
src/quant/bf16_neon.c src/quant/bf16_scalar.c \
33-
src/quant/q6k_neon.c src/quant/q6k_scalar.c \
33+
src/quant/q6k_neon_sdot.c src/quant/q6k_neon.c src/quant/q6k_scalar.c \
3434
src/quant/q8k_neon.c src/quant/q8k_scalar.c \
3535
src/quant/q4k_neon.c src/quant/q4k_scalar.c \
3636
src/quant/q5k_neon.c src/quant/q5k_scalar.c \

include/quant_internal.h

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -89,6 +89,14 @@ typedef struct {
8989
const float *x;
9090
} BnQ6KCtx;
9191

92+
// Q6_K SDOT context: arguments passed to bn_quant_q6k_neon_sdot_range through its void *ctx
93+
typedef struct {
94+
float *out;  // output vector; the kernel stores one float per processed row (out[row])
95+
const BnQWeight *W;  // weight matrix; the kernel reads W->cols and W->data (as BnBlockQ6K blocks)
96+
const int8_t *x_q;  // int8-quantized activations, filled by bn_quant_x_to_q8_blocks in the dispatcher
97+
const float *x_scales;  // activation scales; the kernel applies one per 32 consecutive activations
98+
} BnQ6KSdotCtx;
99+
92100
// Q8_K context
93101
typedef struct {
94102
float *out;
@@ -227,6 +235,7 @@ void bn_quant_q4_wasm_sdot_range(void *ctx, int start, int end);
227235
void bn_quant_q4_scalar_range(void *ctx, int start, int end);
228236

229237
// Q6_K kernels
238+
void bn_quant_q6k_neon_sdot_range(void *ctx, int start, int end);
230239
void bn_quant_q6k_neon_range(void *ctx, int start, int end);
231240
void bn_quant_q6k_avx2_range(void *ctx, int start, int end);
232241
void bn_quant_q6k_wasm_range(void *ctx, int start, int end);

src/quant/dispatch.c

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -120,15 +120,28 @@ void bn_quant_matvec(float *out, const BnQWeight *W, const float *x,
120120
}
121121

122122
if (W->type == BN_GGUF_TENSOR_Q6_K) {
123+
#if defined(__ARM_NEON) && defined(__ARM_FEATURE_DOTPROD)
124+
int n_blocks = W->cols / 32;
125+
if (n_blocks > BN_MAX_SCALE_BLOCKS) return;
126+
float x_scales[n_blocks];
127+
bn_quant_x_to_q8_blocks(x, x_q_buf, x_scales, W->cols);
128+
BnQ6KSdotCtx ctx = { out, W, x_q_buf, x_scales };
129+
BnTPTask task = { bn_quant_q6k_neon_sdot_range, &ctx, W->rows };
130+
#elif defined(__ARM_NEON)
123131
(void)x_q_buf;
124132
BnQ6KCtx ctx = { out, W, x };
125-
#ifdef __ARM_NEON
126133
BnTPTask task = { bn_quant_q6k_neon_range, &ctx, W->rows };
127134
#elif defined(__AVX2__)
135+
(void)x_q_buf;
136+
BnQ6KCtx ctx = { out, W, x };
128137
BnTPTask task = { bn_quant_q6k_avx2_range, &ctx, W->rows };
129138
#elif defined(__wasm_simd128__)
139+
(void)x_q_buf;
140+
BnQ6KCtx ctx = { out, W, x };
130141
BnTPTask task = { bn_quant_q6k_wasm_range, &ctx, W->rows };
131142
#else
143+
(void)x_q_buf;
144+
BnQ6KCtx ctx = { out, W, x };
132145
BnTPTask task = { bn_quant_q6k_scalar_range, &ctx, W->rows };
133146
#endif
134147
bn_tp_dispatch(pool, &task, 1);

src/quant/q6k_neon_sdot.c

Lines changed: 102 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,102 @@
1+
#include "quant_internal.h"
2+
#include <arm_neon.h>
3+
4+
#if defined(__ARM_NEON) && defined(__ARM_FEATURE_DOTPROD)
/*
 * Q6_K matrix-vector kernel built on NEON signed dot products (vdotq_s32).
 *
 * Computes out[row] for every row in [row_start, row_end) as the dot product
 * of one Q6_K-quantized weight row with the int8-quantized activation vector.
 *
 * ctx        points to a BnQ6KSdotCtx (out, W, x_q, x_scales).
 * row_start  first row to process (inclusive).
 * row_end    one past the last row to process (exclusive).
 *
 * The definition is compiled only when the dot-product extension is
 * available; this is the same condition under which the dispatcher selects
 * this kernel, so targets without SDOT neither build nor reference it.
 *
 * Only whole blocks of BN_QK_K columns are processed (cols / BN_QK_K);
 * any trailing columns beyond the last full block are ignored.
 * NOTE(review): the loop structure assumes BN_QK_K == 256 (two 128-element
 * chunks per block, eight activation scales per block) -- confirm against
 * the BnBlockQ6K definition.
 */
void bn_quant_q6k_neon_sdot_range(void *ctx, int row_start, int row_end) {
    const BnQ6KSdotCtx *c = (const BnQ6KSdotCtx *)ctx;
    int cols = c->W->cols;
    int n_blocks_per_row = cols / BN_QK_K;
    const BnBlockQ6K *blocks = (const BnBlockQ6K *)c->W->data;
    const int8_t *x_q = c->x_q;
    const float *x_scales = c->x_scales;

    const uint8x16_t mask_lo4 = vdupq_n_u8(0xF);  // low nibble of a ql byte
    const uint8x16_t mask_2 = vdupq_n_u8(3);      // one 2-bit field of a qh byte
    const int8x16_t bias32 = vdupq_n_s8(32);      // recentres the 6-bit value 0..63 to -32..31
    const int32x4_t zero = vdupq_n_s32(0);        // fresh accumulator for each vdotq_s32

    for (int row = row_start; row < row_end; row++) {
        float row_sum = 0.0f;
        for (int b = 0; b < n_blocks_per_row; b++) {
            const BnBlockQ6K *blk = &blocks[row * n_blocks_per_row + b];
            __builtin_prefetch(blk + 1, 0, 0);
            float d = bn_fp16_to_fp32(blk->d);
            const uint8_t *ql = blk->ql;
            const uint8_t *qh = blk->qh;
            const int8_t *sc = blk->scales;

            // 8 activation blocks per Q6_K block (256 / 32 = 8)
            const int8_t *xb = x_q + (b * BN_QK_K);
            const float *xs = x_scales + (b * 8);

            // Each chunk covers 128 weights: 64 ql bytes, 32 qh bytes, 8 scales.
            for (int chunk = 0; chunk < 2; chunk++) {
                uint8x16_t ql0 = vld1q_u8(ql);
                uint8x16_t ql1 = vld1q_u8(ql + 16);
                uint8x16_t ql2 = vld1q_u8(ql + 32);
                uint8x16_t ql3 = vld1q_u8(ql + 48);
                uint8x16_t qh0 = vld1q_u8(qh);
                uint8x16_t qh1 = vld1q_u8(qh + 16);

                // Unpack 8 weight vectors (identical to q6k_neon.c):
                // weight = (4 bits from ql | 2 bits from qh << 4) - 32.
                int8x16_t w0a = vsubq_s8(vreinterpretq_s8_u8(vorrq_u8(
                    vandq_u8(ql0, mask_lo4),
                    vshlq_n_u8(vandq_u8(qh0, mask_2), 4))), bias32);
                int8x16_t w0b = vsubq_s8(vreinterpretq_s8_u8(vorrq_u8(
                    vandq_u8(ql1, mask_lo4),
                    vshlq_n_u8(vandq_u8(qh1, mask_2), 4))), bias32);
                int8x16_t w1a = vsubq_s8(vreinterpretq_s8_u8(vorrq_u8(
                    vandq_u8(ql2, mask_lo4),
                    vshlq_n_u8(vandq_u8(vshrq_n_u8(qh0, 2), mask_2), 4))), bias32);
                int8x16_t w1b = vsubq_s8(vreinterpretq_s8_u8(vorrq_u8(
                    vandq_u8(ql3, mask_lo4),
                    vshlq_n_u8(vandq_u8(vshrq_n_u8(qh1, 2), mask_2), 4))), bias32);
                int8x16_t w2a = vsubq_s8(vreinterpretq_s8_u8(vorrq_u8(
                    vshrq_n_u8(ql0, 4),
                    vshlq_n_u8(vandq_u8(vshrq_n_u8(qh0, 4), mask_2), 4))), bias32);
                int8x16_t w2b = vsubq_s8(vreinterpretq_s8_u8(vorrq_u8(
                    vshrq_n_u8(ql1, 4),
                    vshlq_n_u8(vandq_u8(vshrq_n_u8(qh1, 4), mask_2), 4))), bias32);
                int8x16_t w3a = vsubq_s8(vreinterpretq_s8_u8(vorrq_u8(
                    vshrq_n_u8(ql2, 4),
                    vshlq_n_u8(vshrq_n_u8(qh0, 6), 4))), bias32);
                int8x16_t w3b = vsubq_s8(vreinterpretq_s8_u8(vorrq_u8(
                    vshrq_n_u8(ql3, 4),
                    vshlq_n_u8(vshrq_n_u8(qh1, 6), 4))), bias32);

                // 4 pairs, each pair = 32 elements = 2 weight sub-blocks sharing 1 activation scale.
                // The integer dot products are reduced with vaddvq_s32 and only then scaled in float.

                // Pair 0: w0a(sc[0]) + w0b(sc[1]), dx = xs[chunk*4 + 0]
                float dx0 = xs[chunk * 4 + 0];
                int32x4_t s0a = vdotq_s32(zero, w0a, vld1q_s8(xb));
                int32x4_t s0b = vdotq_s32(zero, w0b, vld1q_s8(xb + 16));
                row_sum += d * dx0 * ((float)sc[0] * (float)vaddvq_s32(s0a)
                                    + (float)sc[1] * (float)vaddvq_s32(s0b));

                // Pair 1: w1a(sc[2]) + w1b(sc[3]), dx = xs[chunk*4 + 1]
                float dx1 = xs[chunk * 4 + 1];
                int32x4_t s1a = vdotq_s32(zero, w1a, vld1q_s8(xb + 32));
                int32x4_t s1b = vdotq_s32(zero, w1b, vld1q_s8(xb + 48));
                row_sum += d * dx1 * ((float)sc[2] * (float)vaddvq_s32(s1a)
                                    + (float)sc[3] * (float)vaddvq_s32(s1b));

                // Pair 2: w2a(sc[4]) + w2b(sc[5]), dx = xs[chunk*4 + 2]
                float dx2 = xs[chunk * 4 + 2];
                int32x4_t s2a = vdotq_s32(zero, w2a, vld1q_s8(xb + 64));
                int32x4_t s2b = vdotq_s32(zero, w2b, vld1q_s8(xb + 80));
                row_sum += d * dx2 * ((float)sc[4] * (float)vaddvq_s32(s2a)
                                    + (float)sc[5] * (float)vaddvq_s32(s2b));

                // Pair 3: w3a(sc[6]) + w3b(sc[7]), dx = xs[chunk*4 + 3]
                float dx3 = xs[chunk * 4 + 3];
                int32x4_t s3a = vdotq_s32(zero, w3a, vld1q_s8(xb + 96));
                int32x4_t s3b = vdotq_s32(zero, w3b, vld1q_s8(xb + 112));
                row_sum += d * dx3 * ((float)sc[6] * (float)vaddvq_s32(s3a)
                                    + (float)sc[7] * (float)vaddvq_s32(s3b));

                // Advance to the second 128-element half of the block.
                xb += 128;
                ql += 64;
                qh += 32;
                sc += 8;
            }
        }
        c->out[row] = row_sum;
    }
}
#endif /* __ARM_NEON && __ARM_FEATURE_DOTPROD */

0 commit comments

Comments
 (0)