Skip to content

Commit c4a49f2

Browse files
committed
Optional F16 KV cache (--kv16): halve attention DRAM bandwidth
Store KV cache in FP16 instead of F32 when --kv16 is passed. Writes K/V to temp F32 buffers, applies RoPE, then converts to F16 via NEON vcvt_f16_f32 (scalar fallback for non-ARM). Read path converts F16 back to a stack F32 buffer before dot products — 512 bytes fits in L1. Both gqa_range and flash_gqa_range get the F16 treatment. Arena sizing and allocation are conditional on the new kv_f16 config flag. Benchmark (M1 Max, bitnet-b1.58-2B-4T, 128 tokens): F32 KV: 40.8 tok/s F16 KV: 47.9 tok/s (+17%) Greedy argmax matches F32; generated text diverges slightly due to accumulated F16 rounding but remains coherent and correct.
1 parent d6db0f8 commit c4a49f2

9 files changed

Lines changed: 354 additions & 37 deletions

File tree

Makefile

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,7 @@ src/%.o: src/%.c
3535
$(CC) $(CFLAGS) -c -o $@ $<
3636

3737
# --- Tests ---
38-
.PHONY: debug test test_gguf test_quant test_tokenizer test_transformer test_threadpool test_safety test_prefill clean
38+
.PHONY: debug test test_gguf test_quant test_tokenizer test_transformer test_threadpool test_safety test_prefill test_kv_f16 clean
3939

4040
test: test_gguf test_quant test_tokenizer test_transformer test_threadpool test_safety
4141

@@ -71,5 +71,10 @@ test_prefill: test/test_prefill.c src/platform.c src/gguf.c src/quant.c src/mode
7171
src/sh_arena.c src/sh_log.c
7272
$(CC) $(CFLAGS) -o $@ $^ $(LDFLAGS)
7373

74+
test_kv_f16: test/test_kv_f16.c src/platform.c src/gguf.c src/quant.c src/model.c \
75+
src/transformer.c src/tokenizer.c src/sampler.c src/threadpool.c \
76+
src/sh_arena.c src/sh_log.c
77+
$(CC) $(CFLAGS) -o $@ $^ $(LDFLAGS)
78+
7479
clean:
75-
rm -f bitnet src/*.o test_gguf test_quant test_tokenizer test_transformer test_threadpool test_safety test_e2e test_prefill
80+
rm -f bitnet src/*.o test_gguf test_quant test_tokenizer test_transformer test_threadpool test_safety test_e2e test_prefill test_kv_f16

README.md

Lines changed: 14 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@ A clean-room, pure C inference engine for [BitNet b1.58](https://arxiv.org/abs/2
44

55
Inspired by Andrej Karpathy's [llama2.c](https://github.com/karpathy/llama2.c) — a beautifully minimal LLaMA inference implementation in a single C file — **bitnet.c** takes the same philosophy and applies it to Microsoft's [BitNet](https://github.com/microsoft/BitNet) architecture with its 1.58-bit ternary weights.
66

7-
Where Microsoft's official BitNet inference framework depends on a modified llama.cpp fork (~100K+ lines of C++), bitnet.c delivers a complete inference pipeline in ~4,100 lines of modular, readable C.
7+
Where Microsoft's official BitNet inference framework depends on a modified llama.cpp fork (~100K+ lines of C++), bitnet.c delivers a complete inference pipeline in ~4,500 lines of modular, readable C.
88

99
## Features
1010

@@ -13,6 +13,7 @@ Where Microsoft's official BitNet inference framework depends on a modified llam
1313
- **I2_S, TQ1_0, & TQ2_0 formats** — native support for Microsoft's I2_S and GGML ternary quantization
1414
- **Full transformer forward pass** — RoPE, GQA, RMSNorm, sub-norms, tied embeddings
1515
- **Flash GQA attention** — online softmax with KV-head grouping, single-pass over KV cache
16+
- **Optional F16 KV cache**`--kv16` halves attention DRAM bandwidth with minimal precision loss
1617
- **ARM NEON/SDOT optimizations** — SDOT int8 matvec, native FP16 logits, INT8 output embeddings
1718
- **Pthread thread pool** — persistent workers with condvar dispatch (~2us), replaces OpenMP
1819
- **BPE tokenizer** — loaded directly from GGUF metadata
@@ -50,8 +51,11 @@ Usage: ./bitnet <model.gguf> [options]
5051
--topp <float> Top-p sampling (default: 0.9)
5152
--seed <int> Random seed (default: 42)
5253
--maxseq <int> Max sequence length (default: model max)
54+
--flash Use flash attention (online softmax)
5355
--chat Interactive chat REPL mode
5456
--repeat-penalty <float> Repetition penalty (default: 1.0, chat: 1.1)
57+
--kv16 Store KV cache in FP16 (halves attention DRAM bandwidth)
58+
--no-prefill Disable batch prompt prefill (compute logits for every token)
5559
```
5660

5761
### Chat Mode
@@ -120,14 +124,17 @@ bitnet.c/
120124
│ ├── test_tokenizer.c # BPE encode/decode tests
121125
│ ├── test_threadpool.c # Thread pool dispatch tests
122126
│ ├── test_safety.c # Safety/bounds-checking regression tests
123-
│ └── test_e2e.c # End-to-end greedy decode test
127+
│ ├── test_prefill.c # Prefill vs sequential correctness test
128+
│ ├── test_kv_f16.c # F16 KV cache correctness test
129+
│ └── test_e2e.c # End-to-end greedy decode test
124130
├── wasm/
125131
│ ├── api.c # WASM-exported API wrapper
126132
│ ├── build.sh # Emscripten build script
127133
│ ├── worker.js # Web Worker for non-blocking inference
128134
│ └── index.html # Browser demo
129135
├── docs/
130-
│ └── roadmap.md # Development roadmap
136+
│ ├── roadmap.md # Development roadmap
137+
│ └── audit.md # Security/correctness audit
131138
└── Makefile
132139
```
133140

@@ -176,7 +183,7 @@ Benchmarked on Apple M1 Max (8 P-cores, 32 GB), `bitnet-b1.58-2B-4T` (I2_S forma
176183
| Baseline (scalar C) | ~15.5 | 1.0x |
177184
| + SDOT int8 accumulation + batch matvec | ~33 | 2.1x |
178185
| + Arithmetic ternary decode + RoPE precompute | ~38 | 2.5x |
179-
| + Pthread thread pool (replace OpenMP) | ~38 | 2.5x |
186+
| + Pthread thread pool (replace OpenMP) | ~41 | 2.6x |
180187
| + Arena allocator + native FP16 logits + prefetch | ~46 | 3.0x |
181188
| + INT8 output embeddings (SDOT logits) | **~52** | **3.4x** |
182189

@@ -214,6 +221,7 @@ BitNet b1.58 is a transformer variant where all linear layer weights are constra
214221

215222
| Format | Bits/Weight | Packing | Block Size |
216223
|--------|-------------|---------|------------|
224+
| I2_S | 2.0 | 2-bit interleaved (4 values/byte) + per-tensor scale | 128 |
217225
| TQ1_0 | 1.6875 | Base-3 (5 values/byte) + residual | 256 |
218226
| TQ2_0 | 2.0625 | 2-bit fields (4 values/byte) | 256 |
219227

@@ -223,9 +231,9 @@ BitNet b1.58 is a transformer variant where all linear layer weights are constra
223231
|-----------|------|
224232
| GGUF buffer (weights + F16 embeddings) | ~620 MB |
225233
| INT8 embedding cache (128K × 2560) | ~329 MB |
226-
| KV cache (30 layers × 2048 × 640 × 4 × 2) | ~298 MB |
234+
| KV cache (30 layers × 2048 × 640 × 4 × 2) | ~298 MB (~149 MB with `--kv16`) |
227235
| RunState activations | ~3 MB |
228-
| **Total** | **~1,250 MB** |
236+
| **Total** | **~1,250 MB** (~1,101 MB with `--kv16`) |
229237

230238
## Design Decisions
231239

include/model.h

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@ typedef struct {
1717
int head_size, kv_dim, kv_mul; // derived
1818
int has_ffn_gate, act_type; // 0=SiLU, 1=ReLU²
1919
int flash_attn; // use flash attention (online softmax)
20+
int kv_f16; // store KV cache in FP16 (halves attention DRAM bandwidth)
2021
} BnConfig;
2122

2223
typedef struct {
@@ -56,7 +57,7 @@ typedef struct {
5657
SHArena *arena; // arena for all RunState buffers
5758
} BnModel;
5859

59-
int bn_model_load(BnModel *m, BnGGUFFile *f, int max_seq_len);
60+
int bn_model_load(BnModel *m, BnGGUFFile *f, int max_seq_len, int kv_f16);
6061
void bn_model_free(BnModel *m);
6162
void bn_model_embed_token(const BnModel *m, float *out, int token);
6263

src/main.c

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@ typedef struct {
3030
float repeat_penalty;
3131
int repeat_set; // whether user explicitly set --repeat-penalty
3232
int no_prefill;
33+
int kv_f16;
3334
} CLIArgs;
3435

3536
static void print_usage(const char *prog) {
@@ -44,6 +45,7 @@ static void print_usage(const char *prog) {
4445
fprintf(stderr, " --flash Use flash attention (online softmax)\n");
4546
fprintf(stderr, " --chat Interactive chat REPL mode\n");
4647
fprintf(stderr, " --repeat-penalty <float> Repetition penalty (default: 1.0, chat: 1.1)\n");
48+
fprintf(stderr, " --kv16 Store KV cache in FP16 (halves attention DRAM bandwidth)\n");
4749
fprintf(stderr, " --no-prefill Disable batch prompt prefill (compute logits for every token)\n");
4850
}
4951

@@ -81,6 +83,8 @@ static CLIArgs parse_args(int argc, char **argv) {
8183
args.flash_attn = 1;
8284
} else if (strcmp(argv[i], "--chat") == 0) {
8385
args.chat = 1;
86+
} else if (strcmp(argv[i], "--kv16") == 0) {
87+
args.kv_f16 = 1;
8488
} else if (strcmp(argv[i], "--no-prefill") == 0) {
8589
args.no_prefill = 1;
8690
} else if (strcmp(argv[i], "--repeat-penalty") == 0 && i + 1 < argc) {
@@ -145,7 +149,7 @@ int main(int argc, char **argv) {
145149

146150
// Load model
147151
BnModel model;
148-
if (bn_model_load(&model, gf, args.max_seq_len) != 0) {
152+
if (bn_model_load(&model, gf, args.max_seq_len, args.kv_f16) != 0) {
149153
SH_LOG_ERROR("Failed to load model");
150154
bn_gguf_free(gf);
151155
bn_platform_unload_file(&mf);

src/model.c

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -55,9 +55,10 @@ static float *load_f32_tensor(BnGGUFFile *f, const char *name) {
5555

5656
// --- Model loading ---
5757

58-
int bn_model_load(BnModel *m, BnGGUFFile *f, int max_seq_len) {
58+
int bn_model_load(BnModel *m, BnGGUFFile *f, int max_seq_len, int kv_f16) {
5959
memset(m, 0, sizeof(BnModel));
6060
BnConfig *c = &m->config;
61+
c->kv_f16 = kv_f16;
6162

6263
// Try to detect architecture prefix
6364
const char *arch = bn_gguf_get_str(f, "general.architecture");
@@ -272,7 +273,8 @@ int bn_model_load(BnModel *m, BnGGUFFile *f, int max_seq_len) {
272273
arena_size += 2 * (size_t)c->hidden_dim * sizeof(float); // hb, hb2
273274
arena_size += att_size * sizeof(float); // att
274275
arena_size += (size_t)c->vocab_size * sizeof(float); // logits
275-
arena_size += 2 * kv_cache_size * sizeof(float); // key_cache, value_cache
276+
size_t kv_elem_size = c->kv_f16 ? sizeof(uint16_t) : sizeof(float);
277+
arena_size += 2 * kv_cache_size * kv_elem_size; // key_cache, value_cache
276278
arena_size += (size_t)x_q_size * sizeof(int8_t); // x_q
277279
arena_size += (size_t)half_head * sizeof(float); // rope_freq
278280
arena_size += emb_i8_bytes + emb_i8_scales_bytes; // INT8 embeddings
@@ -292,8 +294,8 @@ int bn_model_load(BnModel *m, BnGGUFFile *f, int max_seq_len) {
292294
s->hb2 = (float *)sh_arena_calloc(m->arena, c->hidden_dim, sizeof(float));
293295
s->att = (float *)sh_arena_calloc(m->arena, att_size, sizeof(float));
294296
s->logits = (float *)sh_arena_calloc(m->arena, c->vocab_size, sizeof(float));
295-
s->key_cache = (float *)sh_arena_calloc(m->arena, kv_cache_size, sizeof(float));
296-
s->value_cache = (float *)sh_arena_calloc(m->arena, kv_cache_size, sizeof(float));
297+
s->key_cache = (float *)sh_arena_calloc(m->arena, kv_cache_size, kv_elem_size);
298+
s->value_cache = (float *)sh_arena_calloc(m->arena, kv_cache_size, kv_elem_size);
297299
s->x_q = (int8_t *)sh_arena_calloc(m->arena, x_q_size, sizeof(int8_t));
298300
s->rope_freq = (float *)sh_arena_alloc(m->arena, half_head * sizeof(float));
299301

src/transformer.c

Lines changed: 124 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -75,6 +75,7 @@ static void gqa_range(void *ctx, int h_start, int h_end) {
7575
int kv_mul = g->kv_mul;
7676
int pos = g->pos;
7777
size_t loff = g->loff;
78+
int kv_f16 = c->kv_f16;
7879

7980
for (int h = h_start; h < h_end; h++) {
8081
float *q_h = s->q + h * head_size;
@@ -83,7 +84,22 @@ static void gqa_range(void *ctx, int h_start, int h_end) {
8384
float inv_sqrt_hs = 1.0f / sqrtf((float)head_size);
8485

8586
for (int t = 0; t <= pos; t++) {
86-
float *k_t = s->key_cache + loff + (size_t)t * kv_dim + kv_h * head_size;
87+
float k_buf[head_size];
88+
const float *k_t;
89+
if (kv_f16) {
90+
const uint16_t *k_f16 = (const uint16_t *)s->key_cache + loff + (size_t)t * kv_dim + kv_h * head_size;
91+
#ifdef __ARM_NEON
92+
for (int d = 0; d < head_size; d += 4) {
93+
float16x4_t hv = vreinterpret_f16_u16(vld1_u16(k_f16 + d));
94+
vst1q_f32(k_buf + d, vcvt_f32_f16(hv));
95+
}
96+
#else
97+
for (int d = 0; d < head_size; d++) k_buf[d] = bn_fp16_to_fp32(k_f16[d]);
98+
#endif
99+
k_t = k_buf;
100+
} else {
101+
k_t = s->key_cache + loff + (size_t)t * kv_dim + kv_h * head_size;
102+
}
87103
#ifdef __ARM_NEON
88104
float32x4_t a0 = vdupq_n_f32(0), a1 = vdupq_n_f32(0);
89105
float32x4_t a2 = vdupq_n_f32(0), a3 = vdupq_n_f32(0);
@@ -107,7 +123,22 @@ static void gqa_range(void *ctx, int h_start, int h_end) {
107123
float *xb_h = s->xb + h * head_size;
108124
memset(xb_h, 0, head_size * sizeof(float));
109125
for (int t = 0; t <= pos; t++) {
110-
float *v_t = s->value_cache + loff + (size_t)t * kv_dim + kv_h * head_size;
126+
float v_buf[head_size];
127+
const float *v_t;
128+
if (kv_f16) {
129+
const uint16_t *v_f16 = (const uint16_t *)s->value_cache + loff + (size_t)t * kv_dim + kv_h * head_size;
130+
#ifdef __ARM_NEON
131+
for (int d = 0; d < head_size; d += 4) {
132+
float16x4_t hv = vreinterpret_f16_u16(vld1_u16(v_f16 + d));
133+
vst1q_f32(v_buf + d, vcvt_f32_f16(hv));
134+
}
135+
#else
136+
for (int d = 0; d < head_size; d++) v_buf[d] = bn_fp16_to_fp32(v_f16[d]);
137+
#endif
138+
v_t = v_buf;
139+
} else {
140+
v_t = s->value_cache + loff + (size_t)t * kv_dim + kv_h * head_size;
141+
}
111142
float a = att[t];
112143
#ifdef __ARM_NEON
113144
float32x4_t a_v = vdupq_n_f32(a);
@@ -131,13 +162,15 @@ static void gqa_range(void *ctx, int h_start, int h_end) {
131162

132163
static void flash_gqa_range(void *ctx, int h_start, int h_end) {
133164
GQACtx *g = (GQACtx *)ctx;
165+
const BnConfig *c = g->c;
134166
BnRunState *s = g->s;
135167
int head_size = g->head_size;
136168
int kv_dim = g->kv_dim;
137169
int kv_mul = g->kv_mul;
138170
int pos = g->pos;
139171
size_t loff = g->loff;
140172
int n_pos = pos + 1;
173+
int kv_f16 = c->kv_f16;
141174
float inv_sqrt_hs = 1.0f / sqrtf((float)head_size);
142175

143176
for (int h = h_start; h < h_end; h++) {
@@ -156,10 +189,25 @@ static void flash_gqa_range(void *ctx, int h_start, int h_end) {
156189
if (t_end > n_pos) t_end = n_pos;
157190

158191
for (int t = t_start; t < t_end; t++) {
159-
float *k_t = s->key_cache + loff + (size_t)t * kv_dim + kv_h * head_size;
192+
float k_buf[head_size];
193+
const float *k_t;
194+
if (kv_f16) {
195+
const uint16_t *k_f16 = (const uint16_t *)s->key_cache + loff + (size_t)t * kv_dim + kv_h * head_size;
196+
for (int d = 0; d < head_size; d += 4) {
197+
float16x4_t hv = vreinterpret_f16_u16(vld1_u16(k_f16 + d));
198+
vst1q_f32(k_buf + d, vcvt_f32_f16(hv));
199+
}
200+
k_t = k_buf;
201+
} else {
202+
k_t = s->key_cache + loff + (size_t)t * kv_dim + kv_h * head_size;
203+
}
160204

161-
if (t + 1 < t_end)
162-
__builtin_prefetch(s->key_cache + loff + (size_t)(t+1) * kv_dim + kv_h * head_size, 0, 0);
205+
if (t + 1 < t_end) {
206+
if (kv_f16)
207+
__builtin_prefetch((const uint16_t *)s->key_cache + loff + (size_t)(t+1) * kv_dim + kv_h * head_size, 0, 0);
208+
else
209+
__builtin_prefetch(s->key_cache + loff + (size_t)(t+1) * kv_dim + kv_h * head_size, 0, 0);
210+
}
163211

164212
// Score: dot(Q, K) * scale
165213
float32x4_t a0 = vdupq_n_f32(0), a1 = vdupq_n_f32(0);
@@ -173,7 +221,18 @@ static void flash_gqa_range(void *ctx, int h_start, int h_end) {
173221
float score = neon_hsum_f32(vaddq_f32(vaddq_f32(a0, a1), vaddq_f32(a2, a3))) * inv_sqrt_hs;
174222

175223
// Online softmax update
176-
float *v_t = s->value_cache + loff + (size_t)t * kv_dim + kv_h * head_size;
224+
float v_buf[head_size];
225+
const float *v_t;
226+
if (kv_f16) {
227+
const uint16_t *v_f16 = (const uint16_t *)s->value_cache + loff + (size_t)t * kv_dim + kv_h * head_size;
228+
for (int d = 0; d < head_size; d += 4) {
229+
float16x4_t hv = vreinterpret_f16_u16(vld1_u16(v_f16 + d));
230+
vst1q_f32(v_buf + d, vcvt_f32_f16(hv));
231+
}
232+
v_t = v_buf;
233+
} else {
234+
v_t = s->value_cache + loff + (size_t)t * kv_dim + kv_h * head_size;
235+
}
177236
__builtin_prefetch(v_t, 0, 0);
178237

179238
float old_max = running_max;
@@ -394,28 +453,72 @@ static int forward_layers(BnModel *m, int token, int pos) {
394453

395454
rmsnorm(s->xb, s->x, lw->attn_norm, dim, c->norm_eps);
396455

397-
// QKV projections (unified path — bn_quant_matvec_batch handles SDOT internally)
398-
{
456+
if (c->kv_f16) {
457+
// F16 KV cache: write K/V to temp F32 buffers, apply RoPE, convert to F16
458+
float *k_tmp = s->hb, *v_tmp = s->hb2; // [hidden_dim] >= kv_dim
459+
BnMatvecTask qkv[3] = {
460+
{ s->q, &lw->wq },
461+
{ k_tmp, &lw->wk },
462+
{ v_tmp, &lw->wv },
463+
};
464+
bn_quant_matvec_batch(qkv, 3, s->xb, s->x_q, m->pool);
465+
466+
// RoPE on Q
467+
for (int i = 0; i < dim; i += 2) {
468+
int fi = (i / 2) % half_head;
469+
float v0 = s->q[i], v1 = s->q[i + 1];
470+
s->q[i] = v0 * rope_cos[fi] - v1 * rope_sin[fi];
471+
s->q[i + 1] = v0 * rope_sin[fi] + v1 * rope_cos[fi];
472+
}
473+
474+
// RoPE on K temp buffer
475+
for (int i = 0; i < kv_dim; i += 2) {
476+
int fi = (i / 2) % half_head;
477+
float v0 = k_tmp[i], v1 = k_tmp[i + 1];
478+
k_tmp[i] = v0 * rope_cos[fi] - v1 * rope_sin[fi];
479+
k_tmp[i + 1] = v0 * rope_sin[fi] + v1 * rope_cos[fi];
480+
}
481+
482+
// Convert F32 → F16 into cache
483+
uint16_t *kc = (uint16_t *)s->key_cache + loff + (size_t)pos * kv_dim;
484+
uint16_t *vc = (uint16_t *)s->value_cache + loff + (size_t)pos * kv_dim;
485+
#ifdef __ARM_NEON
486+
for (int i = 0; i < kv_dim; i += 4) {
487+
float32x4_t kv4 = vld1q_f32(k_tmp + i);
488+
float16x4_t kh4 = vcvt_f16_f32(kv4);
489+
vst1_u16(kc + i, vreinterpret_u16_f16(kh4));
490+
float32x4_t vv4 = vld1q_f32(v_tmp + i);
491+
float16x4_t vh4 = vcvt_f16_f32(vv4);
492+
vst1_u16(vc + i, vreinterpret_u16_f16(vh4));
493+
}
494+
#else
495+
for (int i = 0; i < kv_dim; i++) {
496+
kc[i] = bn_fp32_to_fp16(k_tmp[i]);
497+
vc[i] = bn_fp32_to_fp16(v_tmp[i]);
498+
}
499+
#endif
500+
} else {
501+
// F32 KV cache: matvec directly into cache, RoPE in-place
399502
BnMatvecTask qkv[3] = {
400503
{ s->q, &lw->wq },
401504
{ key_cache_row, &lw->wk },
402505
{ value_cache_row, &lw->wv },
403506
};
404507
bn_quant_matvec_batch(qkv, 3, s->xb, s->x_q, m->pool);
405-
}
406508

407-
// RoPE using precomputed cos/sin (no trig calls here)
408-
for (int i = 0; i < dim; i += 2) {
409-
int fi = (i / 2) % half_head;
410-
float v0 = s->q[i], v1 = s->q[i + 1];
411-
s->q[i] = v0 * rope_cos[fi] - v1 * rope_sin[fi];
412-
s->q[i + 1] = v0 * rope_sin[fi] + v1 * rope_cos[fi];
413-
}
414-
for (int i = 0; i < kv_dim; i += 2) {
415-
int fi = (i / 2) % half_head;
416-
float v0 = key_cache_row[i], v1 = key_cache_row[i + 1];
417-
key_cache_row[i] = v0 * rope_cos[fi] - v1 * rope_sin[fi];
418-
key_cache_row[i + 1] = v0 * rope_sin[fi] + v1 * rope_cos[fi];
509+
// RoPE using precomputed cos/sin (no trig calls here)
510+
for (int i = 0; i < dim; i += 2) {
511+
int fi = (i / 2) % half_head;
512+
float v0 = s->q[i], v1 = s->q[i + 1];
513+
s->q[i] = v0 * rope_cos[fi] - v1 * rope_sin[fi];
514+
s->q[i + 1] = v0 * rope_sin[fi] + v1 * rope_cos[fi];
515+
}
516+
for (int i = 0; i < kv_dim; i += 2) {
517+
int fi = (i / 2) % half_head;
518+
float v0 = key_cache_row[i], v1 = key_cache_row[i + 1];
519+
key_cache_row[i] = v0 * rope_cos[fi] - v1 * rope_sin[fi];
520+
key_cache_row[i + 1] = v0 * rope_sin[fi] + v1 * rope_cos[fi];
521+
}
419522
}
420523

421524
// GQA attention

0 commit comments

Comments
 (0)