Skip to content

Commit d5e9854

Browse files
committed
Streaming callback API, WASM chat mode, and C safety audit fixes
Add callback-based generate_response() in main.c for streaming token output, refactor chat REPL to use it. Add 5-function chat state machine to WASM API (init/reset/submit/next/end_turn) with loop detection and auto-reset. Add chat UI with mode toggle, message history, and stop button to the browser demo. Safety audit fixes: validate tensor end-bounds in GGUF parser, add VLA size guards, fix platform unload for is_mmap==0, consolidate model file ownership in bn_model_free, replace atoi/atof with strtol/strtof, add -Wshadow and make asan target, replace assert with runtime fallback in batch matvec, add softmax size guard, add tokenizer malloc overflow checks, fix stale test files to use bn_-prefixed API, and add new tests for arena allocator, sampler repeat penalty/top-p, and tokenizer edge cases.
1 parent ec11e07 commit d5e9854

18 files changed

Lines changed: 1277 additions & 97 deletions

Makefile

Lines changed: 13 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -6,9 +6,9 @@ LDFLAGS = -lm
66
# -march=native on Apple clang misses __ARM_FEATURE_FP16_VECTOR_ARITHMETIC.
77
UNAME_S := $(shell uname -s)
88
ifeq ($(UNAME_S),Darwin)
9-
CFLAGS = -O3 -mcpu=apple-m1 -Wall -Wextra -std=c11 -Iinclude
9+
CFLAGS = -O3 -mcpu=apple-m1 -Wall -Wextra -Wshadow -std=c11 -Iinclude
1010
else
11-
CFLAGS = -O3 -march=native -Wall -Wextra -std=c11 -Iinclude
11+
CFLAGS = -O3 -march=native -Wall -Wextra -Wshadow -std=c11 -Iinclude
1212
endif
1313

1414
# On Linux, enable GNU extensions for strdup, qsort_r, clock_gettime, etc.
@@ -30,14 +30,19 @@ bitnet: $(OBJS)
3030
debug: CFLAGS += -DDEBUG -g -O0
3131
debug: bitnet
3232

33+
# Sanitizer build (ASan + UBSan)
34+
asan: CFLAGS += -DDEBUG -g -O0 -fsanitize=address,undefined -fno-omit-frame-pointer
35+
asan: LDFLAGS += -fsanitize=address,undefined
36+
asan: bitnet
37+
3338
# Pattern rule for object files
3439
src/%.o: src/%.c
3540
$(CC) $(CFLAGS) -c -o $@ $<
3641

3742
# --- Tests ---
38-
.PHONY: debug test test_gguf test_quant test_tokenizer test_transformer test_threadpool test_safety test_prefill test_kv_f16 pgo clean
43+
.PHONY: debug asan test test_gguf test_quant test_tokenizer test_transformer test_threadpool test_safety test_arena test_prefill test_kv_f16 pgo clean
3944

40-
test: test_gguf test_quant test_tokenizer test_transformer test_threadpool test_safety
45+
test: test_gguf test_quant test_tokenizer test_transformer test_threadpool test_safety test_arena
4146

4247
test_gguf: test/test_gguf.c src/gguf.c src/platform.c src/sh_log.c
4348
$(CC) $(CFLAGS) -o $@ $^ $(LDFLAGS) && ./$@
@@ -61,6 +66,9 @@ test_safety: test/test_safety.c src/platform.c src/gguf.c src/quant.c src/model.
6166
src/sh_arena.c src/sh_log.c
6267
$(CC) $(CFLAGS) -o $@ $^ $(LDFLAGS) && ./$@
6368

69+
test_arena: test/test_arena.c src/sh_arena.c
70+
$(CC) $(CFLAGS) -o $@ $^ $(LDFLAGS) && ./$@
71+
6472
test_e2e: test/test_e2e.c src/platform.c src/gguf.c src/quant.c src/model.c \
6573
src/transformer.c src/tokenizer.c src/sampler.c src/threadpool.c \
6674
src/sh_arena.c src/sh_log.c
@@ -93,4 +101,4 @@ pgo:
93101
@echo "=== PGO build complete ==="
94102

95103
clean:
96-
rm -f bitnet src/*.o test_gguf test_quant test_tokenizer test_transformer test_threadpool test_safety test_e2e test_prefill test_kv_f16 default.profraw default.profdata
104+
rm -f bitnet src/*.o test_gguf test_quant test_tokenizer test_transformer test_threadpool test_safety test_arena test_e2e test_prefill test_kv_f16 default.profraw default.profdata

src/gguf.c

Lines changed: 35 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -171,8 +171,8 @@ BnGGUFFile *bn_gguf_open(const uint8_t *buf, size_t size) {
171171

172172
uint32_t magic = read_u32(&r);
173173
if (magic != BN_GGUF_MAGIC) {
174-
char buf[16]; snprintf(buf, sizeof(buf), "0x%08x", magic);
175-
SH_LOG_ERROR("Bad GGUF magic", "got", buf);
174+
char hex[16]; snprintf(hex, sizeof(hex), "0x%08x", magic);
175+
SH_LOG_ERROR("Bad GGUF magic", "got", hex);
176176
return NULL;
177177
}
178178

@@ -186,8 +186,8 @@ BnGGUFFile *bn_gguf_open(const uint8_t *buf, size_t size) {
186186
f->alignment = BN_GGUF_DEFAULT_ALIGNMENT;
187187

188188
if (f->version < 2 || f->version > 3) {
189-
char buf[16]; snprintf(buf, sizeof(buf), "%u", f->version);
190-
SH_LOG_ERROR("Unsupported GGUF version", "version", buf);
189+
char ver[16]; snprintf(ver, sizeof(ver), "%u", f->version);
190+
SH_LOG_ERROR("Unsupported GGUF version", "version", ver);
191191
free(f);
192192
return NULL;
193193
}
@@ -349,10 +349,39 @@ int bn_gguf_find_tensor(BnGGUFFile *f, const char *name) {
349349
return -1;
350350
}
351351

352-
// #13: Validate that tensor data pointer falls within the mapped buffer
352+
// Compute byte size for a tensor given its type and element count.
353+
// Returns 0 for unknown types.
354+
static size_t tensor_type_size(uint32_t type, uint64_t nelements) {
355+
switch (type) {
356+
case BN_GGUF_TENSOR_F32: return (size_t)nelements * 4;
357+
case BN_GGUF_TENSOR_F16: return (size_t)nelements * 2;
358+
// I2_S: 2 bits per element + 4-byte per-tensor scale
359+
case BN_GGUF_TENSOR_I2_S: return (size_t)(nelements / 4) + 4;
360+
// TQ1_0: 54 bytes per 256-element block
361+
case BN_GGUF_TENSOR_TQ1_0: return (size_t)(nelements / 256) * 54;
362+
// TQ2_0: 66 bytes per 256-element block
363+
case BN_GGUF_TENSOR_TQ2_0: return (size_t)(nelements / 256) * 66;
364+
default: return 0;
365+
}
366+
}
367+
368+
// #13: Validate that tensor data falls entirely within the mapped buffer
353369
void *bn_gguf_tensor_data(BnGGUFFile *f, int idx) {
354370
if (idx < 0 || (uint64_t)idx >= f->n_tensors) return NULL;
355-
size_t offset = f->data_offset + f->tensors[idx].offset;
371+
BnGGUFTensorInfo *t = &f->tensors[idx];
372+
size_t offset = f->data_offset + t->offset;
356373
if (offset >= f->raw_size) return NULL;
374+
375+
// Compute total elements from dims
376+
uint64_t nelements = 1;
377+
for (uint32_t d = 0; d < t->n_dims; d++)
378+
nelements *= t->dims[d];
379+
380+
size_t tsize = tensor_type_size(t->type, nelements);
381+
if (tsize > 0 && offset + tsize > f->raw_size) {
382+
SH_LOG_ERROR("Tensor data exceeds buffer", "tensor", t->name ? t->name : "?");
383+
return NULL;
384+
}
385+
357386
return f->raw + offset;
358387
}

src/main.c

Lines changed: 98 additions & 54 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,10 @@
1010
#include <stdio.h>
1111
#include <stdlib.h>
1212
#include <string.h>
13+
#include <limits.h>
14+
15+
// Callback for streaming token output. Return non-zero to stop generation.
16+
typedef int (*bn_token_callback)(const char *piece, int token_id, void *user_data);
1317

1418
#if defined(__APPLE__)
1519
#include <sys/sysctl.h>
@@ -49,6 +53,26 @@ static void print_usage(const char *prog) {
4953
fprintf(stderr, " --no-prefill Disable batch prompt prefill (compute logits for every token)\n");
5054
}
5155

56+
// Parse a base-10 integer command-line value.
//
//   s    - the argument text to convert
//   name - the option name, used only in the error message
//
// Returns the value as an int. Prints "Invalid value for <name>: <s>" to
// stderr and exits with status 1 when s is empty, has trailing characters
// after the number, or is outside [INT_MIN, INT_MAX].
//
// strtoll is used instead of strtol so the range check is meaningful even on
// platforms where long is no wider than int (there strtol would clamp an
// out-of-range input to INT_MAX/INT_MIN and it would be silently accepted).
// Out-of-range input for long long is clamped to LLONG_MIN/LLONG_MAX, which
// still fails the int range check below.
static int parse_int(const char *s, const char *name) {
    char *end = NULL;
    long long val = strtoll(s, &end, 10);
    // end == s: no digits were consumed (e.g. an empty string).
    if (end == s || *end != '\0' || val < INT_MIN || val > INT_MAX) {
        fprintf(stderr, "Invalid value for %s: %s\n", name, s);
        exit(1);
    }
    return (int)val;
}
65+
66+
// Parse a floating-point command-line value.
//
//   s    - the argument text to convert
//   name - the option name, used only in the error message
//
// Returns the value converted by strtof. Prints
// "Invalid value for <name>: <s>" to stderr and exits with status 1 when s
// is empty or has trailing characters after the number.
static float parse_float(const char *s, const char *name) {
    char *end = NULL;
    float val = strtof(s, &end);
    // end == s: nothing was converted (e.g. an empty string), which would
    // otherwise be accepted as 0.0 because *end is already '\0'.
    if (end == s || *end != '\0') {
        fprintf(stderr, "Invalid value for %s: %s\n", name, s);
        exit(1);
    }
    return val;
}
75+
5276
static CLIArgs parse_args(int argc, char **argv) {
5377
CLIArgs args = {0};
5478
args.prompt = "Hello";
@@ -69,16 +93,18 @@ static CLIArgs parse_args(int argc, char **argv) {
6993
if (strcmp(argv[i], "-p") == 0 && i + 1 < argc) {
7094
args.prompt = argv[++i];
7195
} else if (strcmp(argv[i], "-n") == 0 && i + 1 < argc) {
72-
args.n_tokens = atoi(argv[++i]);
96+
args.n_tokens = parse_int(argv[++i], "-n");
7397
} else if (strcmp(argv[i], "--temp") == 0 && i + 1 < argc) {
74-
args.temperature = (float)atof(argv[++i]);
98+
args.temperature = parse_float(argv[++i], "--temp");
7599
args.temp_set = 1;
76100
} else if (strcmp(argv[i], "--topp") == 0 && i + 1 < argc) {
77-
args.topp = (float)atof(argv[++i]);
101+
args.topp = parse_float(argv[++i], "--topp");
78102
} else if (strcmp(argv[i], "--seed") == 0 && i + 1 < argc) {
79-
args.seed = (uint64_t)atoll(argv[++i]);
103+
char *end;
104+
args.seed = (uint64_t)strtoull(argv[++i], &end, 10);
105+
if (*end != '\0') { fprintf(stderr, "Invalid value for --seed: %s\n", argv[i]); exit(1); }
80106
} else if (strcmp(argv[i], "--maxseq") == 0 && i + 1 < argc) {
81-
args.max_seq_len = atoi(argv[++i]);
107+
args.max_seq_len = parse_int(argv[++i], "--maxseq");
82108
} else if (strcmp(argv[i], "--flash") == 0) {
83109
args.flash_attn = 1;
84110
} else if (strcmp(argv[i], "--chat") == 0) {
@@ -88,7 +114,7 @@ static CLIArgs parse_args(int argc, char **argv) {
88114
} else if (strcmp(argv[i], "--no-prefill") == 0) {
89115
args.no_prefill = 1;
90116
} else if (strcmp(argv[i], "--repeat-penalty") == 0 && i + 1 < argc) {
91-
args.repeat_penalty = (float)atof(argv[++i]);
117+
args.repeat_penalty = parse_float(argv[++i], "--repeat-penalty");
92118
args.repeat_set = 1;
93119
} else {
94120
fprintf(stderr, "Unknown option: %s\n", argv[i]);
@@ -100,6 +126,67 @@ static CLIArgs parse_args(int argc, char **argv) {
100126
return args;
101127
}
102128

129+
// Loop detection constants
130+
#define LOOP_BUF_SIZE 32
131+
#define LOOP_NGRAM 4
132+
133+
// Generate tokens with callback-based streaming.
134+
// Returns: n_generated, -1 on loop, -2 on error.
135+
static int generate_response(BnModel *model, BnTokenizer *tok, BnSampler *sampler,
136+
int max_tokens, int *pos,
137+
bn_token_callback cb, void *user_data) {
138+
int loop_buf[LOOP_BUF_SIZE];
139+
int loop_idx = 0, gen_count = 0;
140+
memset(loop_buf, -1, sizeof(loop_buf));
141+
142+
float *logits = model->state.logits;
143+
if (!logits) return -2;
144+
145+
for (int i = 0; i < max_tokens; i++) {
146+
int next = bn_sampler_sample(sampler, logits);
147+
148+
if (next == tok->eot_id || next == tok->eos_id)
149+
break;
150+
151+
// Ring buffer loop detection
152+
loop_buf[loop_idx] = next;
153+
loop_idx = (loop_idx + 1) % LOOP_BUF_SIZE;
154+
gen_count++;
155+
156+
if (gen_count >= 2 * LOOP_NGRAM) {
157+
int looping = 1;
158+
for (int k = 0; k < LOOP_NGRAM; k++) {
159+
int a = loop_buf[((loop_idx - 1 - k) % LOOP_BUF_SIZE + LOOP_BUF_SIZE) % LOOP_BUF_SIZE];
160+
int b = loop_buf[((loop_idx - 1 - k - LOOP_NGRAM) % LOOP_BUF_SIZE + LOOP_BUF_SIZE) % LOOP_BUF_SIZE];
161+
if (a != b) { looping = 0; break; }
162+
}
163+
if (looping) return -1;
164+
}
165+
166+
bn_sampler_accept(sampler, next);
167+
168+
const char *piece = bn_tokenizer_decode(tok, next);
169+
if (piece && cb) {
170+
if (cb(piece, next, user_data))
171+
break;
172+
}
173+
174+
logits = bn_transformer_forward(model, next, *pos);
175+
(*pos)++;
176+
if (!logits) return -2;
177+
}
178+
179+
return gen_count;
180+
}
181+
182+
// Token callback that writes each decoded piece to stdout and flushes
// immediately so output appears as it is generated. token_id and user_data
// are unused. Always returns 0, i.e. never asks generation to stop.
static int print_token(const char *piece, int token_id, void *user_data) {
    printf("%s", piece);
    fflush(stdout);

    (void)user_data;
    (void)token_id;
    return 0;
}
189+
103190
int main(int argc, char **argv) {
104191
sh_log_init(NULL);
105192
CLIArgs args = parse_args(argc, argv);
@@ -187,7 +274,6 @@ int main(int argc, char **argv) {
187274
SH_LOG_ERROR("Failed to init tokenizer");
188275
bn_model_free(&model);
189276
bn_gguf_free(gf);
190-
bn_platform_unload_file(&mf);
191277
return 1;
192278
}
193279
{
@@ -220,7 +306,6 @@ int main(int argc, char **argv) {
220306
bn_tokenizer_free(&tokenizer);
221307
bn_model_free(&model);
222308
bn_gguf_free(gf);
223-
bn_platform_unload_file(&mf);
224309
return 1;
225310
}
226311

@@ -287,48 +372,9 @@ int main(int argc, char **argv) {
287372
break;
288373
}
289374

290-
// Generate until eot_id, eos_id, or seq_len
291-
// Loop detector: ring buffer of recent tokens, check for repeating n-grams
292-
#define LOOP_BUF_SIZE 32
293-
#define LOOP_NGRAM 4
294-
int loop_buf[LOOP_BUF_SIZE];
295-
int loop_idx = 0, gen_count = 0;
296-
memset(loop_buf, -1, sizeof(loop_buf));
297-
298-
for (int i = 0; i < args.n_tokens; i++) {
299-
int next = bn_sampler_sample(&sampler, logits);
300-
301-
if (next == tokenizer.eot_id || next == tokenizer.eos_id)
302-
break;
303-
304-
// Record token in ring buffer and check for loops
305-
loop_buf[loop_idx] = next;
306-
loop_idx = (loop_idx + 1) % LOOP_BUF_SIZE;
307-
gen_count++;
308-
309-
if (gen_count >= 2 * LOOP_NGRAM) {
310-
// Check if last LOOP_NGRAM tokens match the LOOP_NGRAM before them
311-
int looping = 1;
312-
for (int k = 0; k < LOOP_NGRAM; k++) {
313-
int a = loop_buf[((loop_idx - 1 - k) % LOOP_BUF_SIZE + LOOP_BUF_SIZE) % LOOP_BUF_SIZE];
314-
int b = loop_buf[((loop_idx - 1 - k - LOOP_NGRAM) % LOOP_BUF_SIZE + LOOP_BUF_SIZE) % LOOP_BUF_SIZE];
315-
if (a != b) { looping = 0; break; }
316-
}
317-
if (looping) { gen_count = -1; break; }
318-
}
319-
320-
bn_sampler_accept(&sampler, next);
321-
322-
const char *piece = bn_tokenizer_decode(&tokenizer, next);
323-
if (piece) {
324-
printf("%s", piece);
325-
fflush(stdout);
326-
}
327-
328-
logits = bn_transformer_forward(&model, next, pos);
329-
pos++;
330-
if (!logits) break;
331-
}
375+
int gen_count = generate_response(&model, &tokenizer, &sampler,
376+
args.n_tokens, &pos,
377+
print_token, NULL);
332378

333379
// Feed EOT into KV cache to close the assistant turn
334380
bn_transformer_forward(&model, tokenizer.eot_id, pos);
@@ -337,7 +383,7 @@ int main(int argc, char **argv) {
337383
printf("\n");
338384

339385
turn_count++;
340-
int should_reset = (turn_count >= 2) || (gen_count == -1);
386+
int should_reset = (turn_count >= 2) || (gen_count < 0);
341387
if (should_reset) {
342388
printf("[auto-reset: starting fresh]\n");
343389
pos = 0;
@@ -360,7 +406,6 @@ int main(int argc, char **argv) {
360406
bn_tokenizer_free(&tokenizer);
361407
bn_model_free(&model);
362408
bn_gguf_free(gf);
363-
bn_platform_unload_file(&mf);
364409
return 1;
365410
}
366411
int n_prompt = bn_tokenizer_encode(&tokenizer, args.prompt, 1, prompt_tokens,
@@ -442,9 +487,8 @@ int main(int argc, char **argv) {
442487
// Cleanup
443488
bn_sampler_free(&sampler);
444489
bn_tokenizer_free(&tokenizer);
445-
bn_model_free(&model);
490+
bn_model_free(&model); // also unloads mmap'd file
446491
bn_gguf_free(gf);
447-
bn_platform_unload_file(&mf);
448492

449493
return 0;
450494
}

src/model.c

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -346,6 +346,7 @@ void bn_model_free(BnModel *m) {
346346
bn_tp_free(m->pool);
347347
free(m->weights.layers);
348348
sh_arena_free(m->arena); // frees INT8 embeddings too
349+
bn_platform_unload_file(&m->file); // safe even if file.data is NULL
349350
memset(m, 0, sizeof(BnModel));
350351
}
351352

src/platform.c

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -69,8 +69,10 @@ void bn_platform_unload_file(BnMappedFile *f) {
6969
#else
7070
if (f->is_mmap == 1) {
7171
munmap(f->data, f->size);
72+
} else if (f->is_mmap == 0) {
73+
free(f->data);
7274
}
73-
// is_mmap == 0 or 2: don't free (external buffer or unused)
75+
// is_mmap == 2: externally owned, don't free
7476
#endif
7577
f->data = NULL;
7678
f->size = 0;

src/quant.c

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -671,7 +671,12 @@ void bn_quant_matvec_batch(const BnMatvecTask *tasks, int n_tasks,
671671
}
672672

673673
if (all_i2s) {
674-
assert(n_tasks <= 4 && "bn_quant_matvec_batch: max 4 tasks");
674+
// Fall back to individual matvecs if too many tasks for stack arrays
675+
if (n_tasks > 4) {
676+
for (int t = 0; t < n_tasks; t++)
677+
bn_quant_matvec(tasks[t].out, tasks[t].W, x, x_q_buf, pool);
678+
return;
679+
}
675680

676681
// Quantize x to int8 once, shared across all tasks
677682
float x_scale = bn_quant_x_to_i8(x, x_q_buf, cols);

0 commit comments

Comments
 (0)