Skip to content

Commit cdfc75f

Browse files
committed
SSM NEON vectorization + threading, Q6_K drift fix, debug cleanup
Extract the SSM forward pass into backend range functions (ssm_neon.c, ssm_scalar.c) with threadpool dispatch over V-heads. The NEON kernels use 4x-unrolled vmlaq_f32 for the delta-recurrence state operations and also vectorize the L2 norm and the RMSNorm+SiLU gate. Fix Q6_K SDOT accumulation drift by factoring the block scale `d` out of the per-sub-block additions: each 256-element block now contributes a single `d * block_sum` instead of 8 separate `d * ...` additions. Remove all 21 #ifdef DEBUG blocks from transformer.c (-332 lines), delete 5 verify_*.py scripts, and update .gitignore to cover stale artifacts.
1 parent 48a9b03 commit cdfc75f

7 files changed

Lines changed: 456 additions & 448 deletions

File tree

.gitignore

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,13 +2,25 @@
22
bitnet
33
src/*.o
44
src/quant/*.o
5+
src/transformer/*.o
6+
*.dSYM/
57
test_gguf
68
test_quant
79
test_tokenizer
810
test_transformer
911
test_e2e
1012
test_safety
13+
test_threadpool
14+
test_arena
15+
test_q2k
16+
test_prefill
17+
test_kv_f16
1118
dump_model
19+
bench_kernels
20+
bench/*.wasm
21+
bench/*.js
22+
bitnet_prof*
23+
__pycache__/
1224

1325
# WASM build output
1426
wasm/bitnet.js

Makefile

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,8 @@ ifneq ($(filter arm% aarch%,$(UNAME_M)),)
4646

4747
TRANSFORMER_BACKEND = src/transformer/rmsnorm_neon.c src/transformer/rmsnorm_scalar.c \
4848
src/transformer/gqa_neon.c src/transformer/gqa_scalar.c \
49-
src/transformer/logits_neon.c src/transformer/logits_scalar.c
49+
src/transformer/logits_neon.c src/transformer/logits_scalar.c \
50+
src/transformer/ssm_neon.c src/transformer/ssm_scalar.c
5051
else
5152
# x86: AVX2 + scalar
5253
QUANT_BACKEND = src/quant/x_quant_avx2.c \
@@ -73,7 +74,8 @@ else
7374

7475
TRANSFORMER_BACKEND = src/transformer/rmsnorm_avx2.c src/transformer/rmsnorm_scalar.c \
7576
src/transformer/gqa_avx2.c src/transformer/gqa_scalar.c \
76-
src/transformer/logits_avx2.c src/transformer/logits_scalar.c
77+
src/transformer/logits_avx2.c src/transformer/logits_scalar.c \
78+
src/transformer/ssm_scalar.c
7779
endif
7880

7981
QUANT_SRCS = $(QUANT_COMMON) $(QUANT_BACKEND)
@@ -204,7 +206,8 @@ AVX2_QUANT_SRCS = $(QUANT_COMMON) \
204206

205207
AVX2_TRANSFORMER_BACKEND = src/transformer/rmsnorm_avx2.c src/transformer/rmsnorm_scalar.c \
206208
src/transformer/gqa_avx2.c src/transformer/gqa_scalar.c \
207-
src/transformer/logits_avx2.c src/transformer/logits_scalar.c
209+
src/transformer/logits_avx2.c src/transformer/logits_scalar.c \
210+
src/transformer/ssm_scalar.c
208211

209212
AVX2_SRCS = src/platform.c src/gguf.c $(AVX2_QUANT_SRCS) src/model.c \
210213
src/transformer.c $(AVX2_TRANSFORMER_BACKEND) src/tokenizer.c src/sampler.c \

include/transformer_internal.h

Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -90,4 +90,48 @@ void bn_transformer_logits_f16_wasm_range(void *ctx, int start, int end);
9090
void bn_transformer_logits_f16_scalar_range(void *ctx, int start, int end);
9191
void bn_transformer_logits_f32_range(void *ctx, int start, int end);
9292

93+
// --- SSM context structs ---
94+
95+
typedef struct {
96+
float *qkv; // [qkv_dim] input/output
97+
float *conv_state; // [(kern-1) * qkv_dim]
98+
const float *conv1d_w; // [qkv_dim * kern]
99+
int qkv_dim, kern;
100+
} BnSSMConvCtx;
101+
102+
typedef struct {
103+
float *q, *k; // [key_dim] each
104+
int head_dim;
105+
} BnSSML2NormCtx;
106+
107+
typedef struct {
108+
float *state, *out;
109+
const float *q, *k;
110+
float *v; // also temp for sk
111+
const float *alpha, *beta;
112+
int num_k_heads, head_k_dim, head_v_dim;
113+
float q_scale;
114+
} BnSSMDeltaCtx;
115+
116+
typedef struct {
117+
float *out;
118+
const float *z, *norm_w;
119+
float eps;
120+
int head_v_dim;
121+
} BnSSMGateCtx;
122+
123+
// --- SSM range function declarations ---
124+
125+
void bn_transformer_ssm_conv_silu_neon_range(void *ctx, int start, int end);
126+
void bn_transformer_ssm_conv_silu_scalar_range(void *ctx, int start, int end);
127+
128+
void bn_transformer_ssm_l2norm_neon_range(void *ctx, int start, int end);
129+
void bn_transformer_ssm_l2norm_scalar_range(void *ctx, int start, int end);
130+
131+
void bn_transformer_ssm_delta_neon_range(void *ctx, int start, int end);
132+
void bn_transformer_ssm_delta_scalar_range(void *ctx, int start, int end);
133+
134+
void bn_transformer_ssm_gate_neon_range(void *ctx, int start, int end);
135+
void bn_transformer_ssm_gate_scalar_range(void *ctx, int start, int end);
136+
93137
#endif // BN_TRANSFORMER_INTERNAL_H

src/quant/q6k_neon_sdot.c

Lines changed: 13 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,8 @@ void bn_quant_q6k_neon_sdot_range(void *ctx, int row_start, int row_end) {
2828
const int8_t *xb = x_q + (b * BN_QK_K);
2929
const float *xs = x_scales + (b * 8);
3030

31+
// Accumulate per-block to reduce float rounding from repeated d*
32+
float block_sum = 0.0f;
3133
for (int chunk = 0; chunk < 2; chunk++) {
3234
uint8x16_t ql0 = vld1q_u8(ql);
3335
uint8x16_t ql1 = vld1q_u8(ql + 16);
@@ -36,7 +38,7 @@ void bn_quant_q6k_neon_sdot_range(void *ctx, int row_start, int row_end) {
3638
uint8x16_t qh0 = vld1q_u8(qh);
3739
uint8x16_t qh1 = vld1q_u8(qh + 16);
3840

39-
// Unpack 8 weight vectors (identical to q6k_neon.c)
41+
// Unpack 8 weight vectors
4042
int8x16_t w0a = vsubq_s8(vreinterpretq_s8_u8(vorrq_u8(
4143
vandq_u8(ql0, mask_lo4),
4244
vshlq_n_u8(vandq_u8(qh0, mask_2), 4))), bias32);
@@ -62,40 +64,37 @@ void bn_quant_q6k_neon_sdot_range(void *ctx, int row_start, int row_end) {
6264
vshrq_n_u8(ql3, 4),
6365
vshlq_n_u8(vshrq_n_u8(qh1, 6), 4))), bias32);
6466

65-
// 4 pairs, each pair = 32 elements = 2 weight sub-blocks sharing 1 activation scale
66-
// Pair 0: w0a(sc[0]) + w0b(sc[1]), dx = xs[chunk*4 + 0]
67+
// 4 pairs: d factored out, accumulated into block_sum
6768
float dx0 = xs[chunk * 4 + 0];
6869
int32x4_t s0a = vdotq_s32(zero, w0a, vld1q_s8(xb));
6970
int32x4_t s0b = vdotq_s32(zero, w0b, vld1q_s8(xb + 16));
70-
row_sum += d * dx0 * ((float)sc[0] * (float)vaddvq_s32(s0a)
71-
+ (float)sc[1] * (float)vaddvq_s32(s0b));
71+
block_sum += dx0 * ((float)sc[0] * (float)vaddvq_s32(s0a)
72+
+ (float)sc[1] * (float)vaddvq_s32(s0b));
7273

73-
// Pair 1: w1a(sc[2]) + w1b(sc[3]), dx = xs[chunk*4 + 1]
7474
float dx1 = xs[chunk * 4 + 1];
7575
int32x4_t s1a = vdotq_s32(zero, w1a, vld1q_s8(xb + 32));
7676
int32x4_t s1b = vdotq_s32(zero, w1b, vld1q_s8(xb + 48));
77-
row_sum += d * dx1 * ((float)sc[2] * (float)vaddvq_s32(s1a)
78-
+ (float)sc[3] * (float)vaddvq_s32(s1b));
77+
block_sum += dx1 * ((float)sc[2] * (float)vaddvq_s32(s1a)
78+
+ (float)sc[3] * (float)vaddvq_s32(s1b));
7979

80-
// Pair 2: w2a(sc[4]) + w2b(sc[5]), dx = xs[chunk*4 + 2]
8180
float dx2 = xs[chunk * 4 + 2];
8281
int32x4_t s2a = vdotq_s32(zero, w2a, vld1q_s8(xb + 64));
8382
int32x4_t s2b = vdotq_s32(zero, w2b, vld1q_s8(xb + 80));
84-
row_sum += d * dx2 * ((float)sc[4] * (float)vaddvq_s32(s2a)
85-
+ (float)sc[5] * (float)vaddvq_s32(s2b));
83+
block_sum += dx2 * ((float)sc[4] * (float)vaddvq_s32(s2a)
84+
+ (float)sc[5] * (float)vaddvq_s32(s2b));
8685

87-
// Pair 3: w3a(sc[6]) + w3b(sc[7]), dx = xs[chunk*4 + 3]
8886
float dx3 = xs[chunk * 4 + 3];
8987
int32x4_t s3a = vdotq_s32(zero, w3a, vld1q_s8(xb + 96));
9088
int32x4_t s3b = vdotq_s32(zero, w3b, vld1q_s8(xb + 112));
91-
row_sum += d * dx3 * ((float)sc[6] * (float)vaddvq_s32(s3a)
92-
+ (float)sc[7] * (float)vaddvq_s32(s3b));
89+
block_sum += dx3 * ((float)sc[6] * (float)vaddvq_s32(s3a)
90+
+ (float)sc[7] * (float)vaddvq_s32(s3b));
9391

9492
xb += 128;
9593
ql += 64;
9694
qh += 32;
9795
sc += 8;
9896
}
97+
row_sum += d * block_sum;
9998
}
10099
c->out[row] = row_sum;
101100
}

0 commit comments

Comments
 (0)