|
| 1 | +#include "quant_internal.h" |
| 2 | +#include <arm_neon.h> |
| 3 | + |
| 4 | +void bn_quant_q6k_neon_sdot_range(void *ctx, int row_start, int row_end) { |
| 5 | + BnQ6KSdotCtx *c = (BnQ6KSdotCtx *)ctx; |
| 6 | + int cols = c->W->cols; |
| 7 | + int n_blocks_per_row = cols / BN_QK_K; |
| 8 | + const BnBlockQ6K *blocks = (const BnBlockQ6K *)c->W->data; |
| 9 | + const int8_t *x_q = c->x_q; |
| 10 | + const float *x_scales = c->x_scales; |
| 11 | + |
| 12 | + const uint8x16_t mask_lo4 = vdupq_n_u8(0xF); |
| 13 | + const uint8x16_t mask_2 = vdupq_n_u8(3); |
| 14 | + const int8x16_t bias32 = vdupq_n_s8(32); |
| 15 | + const int32x4_t zero = vdupq_n_s32(0); |
| 16 | + |
| 17 | + for (int row = row_start; row < row_end; row++) { |
| 18 | + float row_sum = 0.0f; |
| 19 | + for (int b = 0; b < n_blocks_per_row; b++) { |
| 20 | + const BnBlockQ6K *blk = &blocks[row * n_blocks_per_row + b]; |
| 21 | + __builtin_prefetch(blk + 1, 0, 0); |
| 22 | + float d = bn_fp16_to_fp32(blk->d); |
| 23 | + const uint8_t *ql = blk->ql; |
| 24 | + const uint8_t *qh = blk->qh; |
| 25 | + const int8_t *sc = blk->scales; |
| 26 | + |
| 27 | + // 8 activation blocks per Q6_K block (256 / 32 = 8) |
| 28 | + const int8_t *xb = x_q + (b * BN_QK_K); |
| 29 | + const float *xs = x_scales + (b * 8); |
| 30 | + |
| 31 | + for (int chunk = 0; chunk < 2; chunk++) { |
| 32 | + uint8x16_t ql0 = vld1q_u8(ql); |
| 33 | + uint8x16_t ql1 = vld1q_u8(ql + 16); |
| 34 | + uint8x16_t ql2 = vld1q_u8(ql + 32); |
| 35 | + uint8x16_t ql3 = vld1q_u8(ql + 48); |
| 36 | + uint8x16_t qh0 = vld1q_u8(qh); |
| 37 | + uint8x16_t qh1 = vld1q_u8(qh + 16); |
| 38 | + |
| 39 | + // Unpack 8 weight vectors (identical to q6k_neon.c) |
| 40 | + int8x16_t w0a = vsubq_s8(vreinterpretq_s8_u8(vorrq_u8( |
| 41 | + vandq_u8(ql0, mask_lo4), |
| 42 | + vshlq_n_u8(vandq_u8(qh0, mask_2), 4))), bias32); |
| 43 | + int8x16_t w0b = vsubq_s8(vreinterpretq_s8_u8(vorrq_u8( |
| 44 | + vandq_u8(ql1, mask_lo4), |
| 45 | + vshlq_n_u8(vandq_u8(qh1, mask_2), 4))), bias32); |
| 46 | + int8x16_t w1a = vsubq_s8(vreinterpretq_s8_u8(vorrq_u8( |
| 47 | + vandq_u8(ql2, mask_lo4), |
| 48 | + vshlq_n_u8(vandq_u8(vshrq_n_u8(qh0, 2), mask_2), 4))), bias32); |
| 49 | + int8x16_t w1b = vsubq_s8(vreinterpretq_s8_u8(vorrq_u8( |
| 50 | + vandq_u8(ql3, mask_lo4), |
| 51 | + vshlq_n_u8(vandq_u8(vshrq_n_u8(qh1, 2), mask_2), 4))), bias32); |
| 52 | + int8x16_t w2a = vsubq_s8(vreinterpretq_s8_u8(vorrq_u8( |
| 53 | + vshrq_n_u8(ql0, 4), |
| 54 | + vshlq_n_u8(vandq_u8(vshrq_n_u8(qh0, 4), mask_2), 4))), bias32); |
| 55 | + int8x16_t w2b = vsubq_s8(vreinterpretq_s8_u8(vorrq_u8( |
| 56 | + vshrq_n_u8(ql1, 4), |
| 57 | + vshlq_n_u8(vandq_u8(vshrq_n_u8(qh1, 4), mask_2), 4))), bias32); |
| 58 | + int8x16_t w3a = vsubq_s8(vreinterpretq_s8_u8(vorrq_u8( |
| 59 | + vshrq_n_u8(ql2, 4), |
| 60 | + vshlq_n_u8(vshrq_n_u8(qh0, 6), 4))), bias32); |
| 61 | + int8x16_t w3b = vsubq_s8(vreinterpretq_s8_u8(vorrq_u8( |
| 62 | + vshrq_n_u8(ql3, 4), |
| 63 | + vshlq_n_u8(vshrq_n_u8(qh1, 6), 4))), bias32); |
| 64 | + |
| 65 | + // 4 pairs, each pair = 32 elements = 2 weight sub-blocks sharing 1 activation scale |
| 66 | + // Pair 0: w0a(sc[0]) + w0b(sc[1]), dx = xs[chunk*4 + 0] |
| 67 | + float dx0 = xs[chunk * 4 + 0]; |
| 68 | + int32x4_t s0a = vdotq_s32(zero, w0a, vld1q_s8(xb)); |
| 69 | + int32x4_t s0b = vdotq_s32(zero, w0b, vld1q_s8(xb + 16)); |
| 70 | + row_sum += d * dx0 * ((float)sc[0] * (float)vaddvq_s32(s0a) |
| 71 | + + (float)sc[1] * (float)vaddvq_s32(s0b)); |
| 72 | + |
| 73 | + // Pair 1: w1a(sc[2]) + w1b(sc[3]), dx = xs[chunk*4 + 1] |
| 74 | + float dx1 = xs[chunk * 4 + 1]; |
| 75 | + int32x4_t s1a = vdotq_s32(zero, w1a, vld1q_s8(xb + 32)); |
| 76 | + int32x4_t s1b = vdotq_s32(zero, w1b, vld1q_s8(xb + 48)); |
| 77 | + row_sum += d * dx1 * ((float)sc[2] * (float)vaddvq_s32(s1a) |
| 78 | + + (float)sc[3] * (float)vaddvq_s32(s1b)); |
| 79 | + |
| 80 | + // Pair 2: w2a(sc[4]) + w2b(sc[5]), dx = xs[chunk*4 + 2] |
| 81 | + float dx2 = xs[chunk * 4 + 2]; |
| 82 | + int32x4_t s2a = vdotq_s32(zero, w2a, vld1q_s8(xb + 64)); |
| 83 | + int32x4_t s2b = vdotq_s32(zero, w2b, vld1q_s8(xb + 80)); |
| 84 | + row_sum += d * dx2 * ((float)sc[4] * (float)vaddvq_s32(s2a) |
| 85 | + + (float)sc[5] * (float)vaddvq_s32(s2b)); |
| 86 | + |
| 87 | + // Pair 3: w3a(sc[6]) + w3b(sc[7]), dx = xs[chunk*4 + 3] |
| 88 | + float dx3 = xs[chunk * 4 + 3]; |
| 89 | + int32x4_t s3a = vdotq_s32(zero, w3a, vld1q_s8(xb + 96)); |
| 90 | + int32x4_t s3b = vdotq_s32(zero, w3b, vld1q_s8(xb + 112)); |
| 91 | + row_sum += d * dx3 * ((float)sc[6] * (float)vaddvq_s32(s3a) |
| 92 | + + (float)sc[7] * (float)vaddvq_s32(s3b)); |
| 93 | + |
| 94 | + xb += 128; |
| 95 | + ql += 64; |
| 96 | + qh += 32; |
| 97 | + sc += 8; |
| 98 | + } |
| 99 | + } |
| 100 | + c->out[row] = row_sum; |
| 101 | + } |
| 102 | +} |
0 commit comments