Skip to content

Commit b3873c2

Browse files
committed
Fix audit findings: overflow, VLA guards, pread checks, portability
H1: Add BN_MAX_SCALE_BLOCKS/8 guard to Q4_K/Q6_K VLA allocations
H2: Cap cache n_slots to INT_MAX/2, use size_t for raw division
H3: Use unsigned for hash_size computation to avoid signed overflow
M1: Check/log fallback pread return values when prefetch fails
M2: Make threadpool dispatching flag _Atomic
M3: Better hash: layer*65537+expert instead of 16-bit truncation
M4: Guard layers[0] access with n_layers > 0 check
M5: Replace memset(0xFF) with explicit -1 loop for portability
L1: Static assert BN_QK_K % 16 == 0 for NEON alignment
L2-L3: Assert n divisibility in Q8_K and Q8_0 quantization
L4: Overflow-safe chunk size computation in threadpool
L5: Clean up partial prefetch init (free succeeded thread on failure)
L6: Move atomic cursors to pool-internal storage (no public _Atomic)
1 parent 13159b6 commit b3873c2

6 files changed

Lines changed: 57 additions & 37 deletions

File tree

include/threadpool.h

Lines changed: 0 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -4,19 +4,12 @@
44
// Persistent pthread thread pool with atomic work-stealing dispatch.
55
// Threads grab chunks of rows via atomic_fetch_add for load balancing.
66

7-
#ifndef __EMSCRIPTEN__
8-
#include <stdatomic.h>
9-
#endif
10-
117
typedef void (*bn_tp_fn)(void *ctx, int start, int end);
128

139
typedef struct {
1410
bn_tp_fn fn; // range function: called with [start, end)
1511
void *ctx; // opaque context pointer
1612
int n; // iteration count
17-
#ifndef __EMSCRIPTEN__
18-
_Atomic int cursor; // atomic work-stealing cursor (initialized by dispatch)
19-
#endif
2013
} BnTPTask;
2114

2215
typedef struct BnThreadPool BnThreadPool;

src/main.c

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -290,7 +290,8 @@ int main(int argc, char **argv) {
290290
bn_moe_prefetch_create(model.moe_state);
291291

292292
// Create expert LRU cache (pread only)
293-
if (args.cache_mb > 0 && !model.moe_state->mmap_base && model.moe_state->fd >= 0) {
293+
if (args.cache_mb > 0 && !model.moe_state->mmap_base && model.moe_state->fd >= 0
294+
&& model.config.n_layers > 0) {
294295
BnMoEExpertMap *em = &model.weights.layers[0].expert_map;
295296
model.moe_state->cache = bn_moe_cache_create(
296297
(size_t)args.cache_mb * 1024 * 1024,

src/moe.c

Lines changed: 33 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
#include <math.h>
77
#include <string.h>
88
#include <stdlib.h>
9+
#include <limits.h>
910

1011
#ifndef __EMSCRIPTEN__
1112
#include <unistd.h>
@@ -43,7 +44,7 @@ typedef struct {
4344
} BnMoECache;
4445

4546
static uint32_t moe_cache_hash(int layer, int expert_idx) {
46-
uint32_t key = ((uint32_t)layer << 16) | (uint32_t)(expert_idx & 0xFFFF);
47+
uint32_t key = (uint32_t)layer * 65537u + (uint32_t)expert_idx;
4748
// murmurhash3 finalizer
4849
key ^= key >> 16;
4950
key *= 0x85ebca6b;
@@ -187,8 +188,10 @@ void *bn_moe_cache_create(size_t budget_bytes, size_t gate_bytes,
187188
size_t entry_bytes = gate_bytes + up_bytes + down_bytes;
188189
if (entry_bytes == 0) return NULL;
189190

190-
int n_slots = (int)(budget_bytes / entry_bytes);
191-
if (n_slots < 1) return NULL;
191+
size_t raw_slots = budget_bytes / entry_bytes;
192+
if (raw_slots < 1) return NULL;
193+
if (raw_slots > (size_t)INT_MAX / 2) raw_slots = (size_t)INT_MAX / 2; // cap to avoid overflow
194+
int n_slots = (int)raw_slots;
192195

193196
BnMoECache *c = (BnMoECache *)calloc(1, sizeof(BnMoECache));
194197
if (!c) return NULL;
@@ -198,10 +201,10 @@ void *bn_moe_cache_create(size_t budget_bytes, size_t gate_bytes,
198201
c->up_bytes = up_bytes;
199202
c->n_slots = n_slots;
200203

201-
// Hash table: next power of 2 >= 2 * n_slots
202-
int hs = 1;
203-
while (hs < 2 * n_slots) hs *= 2;
204-
c->hash_size = hs;
204+
// Hash table: next power of 2 >= 2 * n_slots (unsigned to avoid overflow)
205+
unsigned hs = 1;
206+
while (hs < (unsigned)n_slots * 2) hs *= 2;
207+
c->hash_size = (int)hs;
205208

206209
// Allocate slab (32-byte aligned)
207210
size_t slab_size = (size_t)n_slots * entry_bytes;
@@ -222,8 +225,8 @@ void *bn_moe_cache_create(size_t budget_bytes, size_t gate_bytes,
222225
return NULL;
223226
}
224227

225-
// Initialize
226-
memset(c->hash_table, 0xFF, (size_t)hs * sizeof(int)); // -1
228+
// Initialize hash table to -1 (empty)
229+
for (int i = 0; i < (int)hs; i++) c->hash_table[i] = -1;
227230
c->lru_head = c->lru_tail = -1;
228231

229232
// Build free list (singly-linked via .next)
@@ -896,14 +899,16 @@ void bn_moe_forward(BnModel *m, BnLayerWeights *lw, int l) {
896899
ms->prefetch_wait_ms += moe_time_ms() - tw;
897900
COLLECT_PF_STATS(pf_gu);
898901
if (!ok) {
899-
pread(ms->fd, miss_g_dst, miss_g_sz, (off_t)miss_g_off);
900-
pread(ms->fd, miss_u_dst, miss_u_sz, (off_t)miss_u_off);
902+
if (pread(ms->fd, miss_g_dst, miss_g_sz, (off_t)miss_g_off) < 0)
903+
SH_LOG_ERROR("Fallback gate pread failed");
904+
if (pread(ms->fd, miss_u_dst, miss_u_sz, (off_t)miss_u_off) < 0)
905+
SH_LOG_ERROR("Fallback up pread failed");
901906
}
902907
} else {
903-
pread(ms->fd, miss_g_dst, miss_g_sz, (off_t)miss_g_off);
904-
pread(ms->fd, miss_u_dst, miss_u_sz, (off_t)miss_u_off);
908+
(void)pread(ms->fd, miss_g_dst, miss_g_sz, (off_t)miss_g_off);
909+
(void)pread(ms->fd, miss_u_dst, miss_u_sz, (off_t)miss_u_off);
905910
if (!pf_dn)
906-
pread(ms->fd, miss_d_dst, miss_d_sz, (off_t)miss_d_off);
911+
(void)pread(ms->fd, miss_d_dst, miss_d_sz, (off_t)miss_d_off);
907912
}
908913
gate_ptr = miss_g_dst;
909914
up_ptr = miss_u_dst;
@@ -934,13 +939,15 @@ void bn_moe_forward(BnModel *m, BnLayerWeights *lw, int l) {
934939
ms->prefetch_wait_ms += moe_time_ms() - tw;
935940
COLLECT_PF_STATS(pf_gu);
936941
if (!ok) {
937-
pread(ms->fd, g_dst, g_sz, (off_t)g_off);
938-
pread(ms->fd, u_dst, u_sz, (off_t)u_off);
942+
if (pread(ms->fd, g_dst, g_sz, (off_t)g_off) < 0)
943+
SH_LOG_ERROR("Fallback gate pread failed");
944+
if (pread(ms->fd, u_dst, u_sz, (off_t)u_off) < 0)
945+
SH_LOG_ERROR("Fallback up pread failed");
939946
}
940947
} else {
941-
pread(ms->fd, g_dst, g_sz, (off_t)g_off);
942-
pread(ms->fd, u_dst, u_sz, (off_t)u_off);
943-
pread(ms->fd, d_dst, d_sz, (off_t)d_off);
948+
(void)pread(ms->fd, g_dst, g_sz, (off_t)g_off);
949+
(void)pread(ms->fd, u_dst, u_sz, (off_t)u_off);
950+
(void)pread(ms->fd, d_dst, d_sz, (off_t)d_off);
944951
}
945952

946953
gate_ptr = g_dst;
@@ -1148,10 +1155,14 @@ void bn_moe_prefetch_create(BnMoEState *ms) {
11481155
if (ms->fd >= 0 && !ms->mmap_base) {
11491156
ms->prefetch = moe_prefetch_init(ms->fd);
11501157
ms->prefetch_down = moe_prefetch_init(ms->fd);
1151-
if (ms->prefetch && ms->prefetch_down)
1158+
if (ms->prefetch && ms->prefetch_down) {
11521159
SH_LOG_INFO("MoE I/O prefetch threads", "status", "2 created (gate+up, down)");
1153-
else
1154-
SH_LOG_INFO("MoE I/O prefetch threads", "status", "partial");
1160+
} else {
1161+
// Clean up partial init — free whichever succeeded
1162+
if (ms->prefetch) { moe_prefetch_free((BnMoEPrefetch *)ms->prefetch); ms->prefetch = NULL; }
1163+
if (ms->prefetch_down) { moe_prefetch_free((BnMoEPrefetch *)ms->prefetch_down); ms->prefetch_down = NULL; }
1164+
SH_LOG_WARN("MoE I/O prefetch threads failed to create");
1165+
}
11551166
}
11561167
#endif
11571168
}

src/quant/dispatch.c

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -128,6 +128,7 @@ void bn_quant_matvec(float *out, const BnQWeight *W, const float *x,
128128
if (W->type == BN_GGUF_TENSOR_Q6_K) {
129129
#if defined(__ARM_NEON) && defined(__ARM_FEATURE_DOTPROD)
130130
int n_sb = W->cols / BN_QK_K;
131+
if (n_sb < 1 || n_sb > BN_MAX_SCALE_BLOCKS / 8) return;
131132
float q8k_d[n_sb];
132133
int16_t q8k_bsums[n_sb * 16];
133134
bn_quant_x_to_q8k(x, x_q_buf, q8k_d, q8k_bsums, W->cols);
@@ -173,6 +174,7 @@ void bn_quant_matvec(float *out, const BnQWeight *W, const float *x,
173174
if (W->type == BN_GGUF_TENSOR_Q4_K) {
174175
#if defined(__ARM_NEON) && defined(__ARM_FEATURE_DOTPROD)
175176
int n_sb = W->cols / BN_QK_K;
177+
if (n_sb < 1 || n_sb > BN_MAX_SCALE_BLOCKS / 8) return;
176178
float q8k_d[n_sb];
177179
int16_t q8k_bsums[n_sb * 16];
178180
bn_quant_x_to_q8k(x, x_q_buf, q8k_d, q8k_bsums, W->cols);
@@ -632,6 +634,7 @@ void bn_quant_matvec_batch(const BnMatvecTask *tasks, int n_tasks,
632634

633635
if (all_q6k && n_tasks <= BN_MAX_BATCH) {
634636
int n_sb = cols / BN_QK_K;
637+
if (n_sb < 1 || n_sb > BN_MAX_SCALE_BLOCKS / 8) { for (int t = 0; t < n_tasks; t++) bn_quant_matvec(tasks[t].out, tasks[t].W, x, x_q_buf, pool); return; }
635638
float q8k_d[n_sb];
636639
int16_t q8k_bsums[n_sb * 16];
637640
bn_quant_x_to_q8k(x, x_q_buf, q8k_d, q8k_bsums, cols);
@@ -650,6 +653,7 @@ void bn_quant_matvec_batch(const BnMatvecTask *tasks, int n_tasks,
650653

651654
if (all_q4k && n_tasks <= BN_MAX_BATCH) {
652655
int n_sb = cols / BN_QK_K;
656+
if (n_sb < 1 || n_sb > BN_MAX_SCALE_BLOCKS / 8) { for (int t = 0; t < n_tasks; t++) bn_quant_matvec(tasks[t].out, tasks[t].W, x, x_q_buf, pool); return; }
653657
float q8k_d[n_sb];
654658
int16_t q8k_bsums[n_sb * 16];
655659
bn_quant_x_to_q8k(x, x_q_buf, q8k_d, q8k_bsums, cols);

src/quant/x_quant_neon.c

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
#include "quant_internal.h"
22
#include <arm_neon.h>
3+
#include <assert.h>
34
#include <math.h>
45

56
// Quantize float vector x[n] to int8, returning scale = amax/127.
@@ -113,11 +114,14 @@ void bn_quant_f16_rows_to_i8(const uint16_t *f16, int8_t *i8_out,
113114
}
114115
}
115116

117+
_Static_assert(BN_QK_K % 16 == 0, "BN_QK_K must be a multiple of 16 for NEON");
118+
116119
// Q8_K quantization: 256-element super-blocks with bsums for Q4_K SDOT.
117120
// x_d[n/256]: one float scale per super-block
118121
// x_bsums[n/256 * 16]: int16 sum per 16-element group (for min correction)
119122
void bn_quant_x_to_q8k(const float *x, int8_t *x_q, float *x_d,
120123
int16_t *x_bsums, int n) {
124+
assert(n % BN_QK_K == 0 && "bn_quant_x_to_q8k: n must be multiple of BN_QK_K");
121125
int n_sb = n / BN_QK_K;
122126
for (int sb = 0; sb < n_sb; sb++) {
123127
const float *xb = x + sb * BN_QK_K;
@@ -170,6 +174,7 @@ void bn_quant_x_to_q8k(const float *x, int8_t *x_q, float *x_d,
170174
// Per-block Q8_0 quantization for Q4_0 integer dot product path.
171175
// Quantizes each 32-element block independently with its own scale.
172176
void bn_quant_x_to_q8_blocks(const float *x, int8_t *x_q, float *x_scales, int n) {
177+
assert(n % 32 == 0 && "bn_quant_x_to_q8_blocks: n must be multiple of 32");
173178
int n_blocks = n / 32;
174179
for (int b = 0; b < n_blocks; b++) {
175180
const float *xb = x + b * 32;

src/threadpool.c

Lines changed: 13 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
#include <stdlib.h>
44
#include <stdint.h>
55
#include <stdatomic.h>
6+
#include <limits.h>
67
#include <assert.h>
78

89
#if defined(__APPLE__)
@@ -19,32 +20,36 @@ typedef struct {
1920
int tid;
2021
} WorkerArg;
2122

23+
#define TP_MAX_TASKS 32 // max concurrent tasks per dispatch
24+
2225
struct BnThreadPool {
2326
pthread_t *threads;
2427
int n_workers; // background threads
2528
int n_threads; // n_workers + 1 (main)
2629
BnTPTask *tasks;
2730
int n_tasks;
31+
_Atomic int cursors[TP_MAX_TASKS]; // atomic work-stealing cursors
2832
pthread_mutex_t mtx;
2933
pthread_cond_t work_cond;
3034
pthread_cond_t done_cond;
3135
int64_t generation;
3236
int n_done;
3337
int shutdown;
34-
int dispatching; // reentrancy guard
38+
_Atomic int dispatching; // reentrancy guard (main-thread-only, atomic for safety)
3539
};
3640

3741
// Execute all tasks via atomic work-stealing with adaptive chunk size.
3842
// Chunk = n / (4 * n_threads) — mostly static, stealing for tail imbalance.
39-
static void tp_execute(const BnThreadPool *pool) {
43+
static void tp_execute(BnThreadPool *pool) {
4044
int nt = pool->n_threads;
4145
for (int t = 0; t < pool->n_tasks; t++) {
4246
BnTPTask *task = &pool->tasks[t];
4347
int n = task->n;
44-
int chunk = n / (nt * 4);
48+
int nt4 = nt <= INT_MAX / 4 ? nt * 4 : nt; // avoid overflow
49+
int chunk = n / nt4;
4550
if (chunk < TP_CHUNK_MIN) chunk = TP_CHUNK_MIN;
4651
for (;;) {
47-
int start = atomic_fetch_add_explicit(&task->cursor, chunk,
52+
int start = atomic_fetch_add_explicit(&pool->cursors[t], chunk,
4853
memory_order_relaxed);
4954
if (start >= n) break;
5055
int end = start + chunk;
@@ -180,9 +185,10 @@ void bn_tp_dispatch(BnThreadPool *pool, BnTPTask *tasks, int n_tasks) {
180185
assert(!pool->dispatching && "bn_tp_dispatch is not reentrant");
181186
pool->dispatching = 1;
182187

183-
// Initialize atomic cursors
184-
for (int t = 0; t < n_tasks; t++)
185-
atomic_store_explicit(&tasks[t].cursor, 0, memory_order_relaxed);
188+
// Initialize atomic cursors (pool-internal storage)
189+
int capped_tasks = n_tasks <= TP_MAX_TASKS ? n_tasks : TP_MAX_TASKS;
190+
for (int t = 0; t < capped_tasks; t++)
191+
atomic_store_explicit(&pool->cursors[t], 0, memory_order_relaxed);
186192

187193
// Set up work and wake workers
188194
pthread_mutex_lock(&pool->mtx);

0 commit comments

Comments
 (0)