Skip to content

Commit 6a10781

Browse files
committed
Support attention.key_length for Qwen3 MoE (head_size != dim/n_heads)
Qwen3-30B-A3B has head_size=128, but dim/n_heads=64. Read the explicit head_size from the attention.key_length GGUF key, falling back to dim/n_heads when no positive value is read. Size the Q/xb buffers to at least n_heads*head_size (q_dim; never smaller than dim), and fix gated-Q detection to compare against q_dim instead of dim. Verified: Qwen3-30B-A3B Q4_K_M generates coherent output at 1.8 tok/s (single-threaded, 48 layers, 128 experts K=8, 18.6GB model).
1 parent 0386559 commit 6a10781

1 file changed

Lines changed: 17 additions & 6 deletions

File tree

src/model.c

Lines changed: 17 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -276,12 +276,15 @@ int bn_model_load(BnModel *m, BnGGUFFile *f, int max_seq_len, int kv_f16) {
276276
}
277277

278278
// Derived dimensions (safe now — denominators validated above)
279-
c->head_size = c->dim / c->n_heads;
279+
// Check for explicit head size (Qwen3 has key_length != dim/n_heads)
280+
snprintf(key, sizeof(key), "%s.attention.key_length", prefix);
281+
int explicit_head_size = (int)bn_gguf_get_u32(f, key);
282+
c->head_size = (explicit_head_size > 0) ? explicit_head_size : (c->dim / c->n_heads);
280283
c->kv_dim = c->head_size * c->n_kv_heads;
281284
c->kv_mul = c->n_heads / c->n_kv_heads;
282285

283286
// Validate alignment for SIMD vectorized paths
284-
if (c->dim % c->n_heads != 0) {
287+
if (explicit_head_size == 0 && c->dim % c->n_heads != 0) {
285288
SH_LOG_ERROR("dim not divisible by n_heads");
286289
return -1;
287290
}
@@ -687,7 +690,13 @@ int bn_model_load(BnModel *m, BnGGUFFile *f, int max_seq_len, int kv_f16) {
687690
goto fail_state;
688691
}
689692

693+
// q_dim = n_heads * head_size (may differ from dim when attention.key_length is set)
694+
int q_dim = c->n_heads * c->head_size;
695+
int xb_size = q_dim > c->dim ? q_dim : c->dim; // xb must hold attention output
696+
int q_size = xb_size; // q must match xb for attention head access pattern
697+
690698
int x_q_size = c->dim > c->hidden_dim ? c->dim : c->hidden_dim;
699+
if (q_dim > x_q_size) x_q_size = q_dim;
691700
int half_head = c->head_size / 2;
692701

693702
// Scratch buffer sizes — enlarged for hybrid SSM + gated-Q attention
@@ -702,9 +711,11 @@ int bn_model_load(BnModel *m, BnGGUFFile *f, int max_seq_len, int kv_f16) {
702711
if (c->ssm_inner_size > xb2_size) xb2_size = c->ssm_inner_size;
703712
if (c->ssm_inner_size > x_q_size) x_q_size = c->ssm_inner_size;
704713
// Gated Q attention: q_full (Q + gate) → hb
705-
int gq = 2 * c->dim;
714+
int gq = 2 * q_dim;
706715
if (gq > hb_size) hb_size = gq;
707716
}
717+
// Non-gated models with q_dim > dim still need hb for gated-Q check
718+
if (2 * q_dim > hb_size) hb_size = 2 * q_dim;
708719
// MoE shared expert may need larger hb/hb2 buffers
709720
if (c->has_shared_expert && c->shared_expert_intermediate_size > hb_size)
710721
hb_size = c->shared_expert_intermediate_size;
@@ -787,7 +798,7 @@ int bn_model_load(BnModel *m, BnGGUFFile *f, int max_seq_len, int kv_f16) {
787798

788799
// Compute total arena capacity (all RunState buffers + INT8 embeddings + Q4 repacking)
789800
size_t arena_size = 0;
790-
arena_size += (3 * (size_t)c->dim + (size_t)xb2_size) * sizeof(float); // x, xb, xb2, q (xb2 may be > dim for SSM)
801+
arena_size += ((size_t)c->dim + (size_t)xb_size + (size_t)xb2_size + (size_t)q_size) * sizeof(float); // x, xb, xb2, q
791802
arena_size += ((size_t)hb_size + (size_t)hb2_size) * sizeof(float); // hb, hb2
792803
arena_size += att_size * sizeof(float); // att
793804
arena_size += (size_t)c->vocab_size * sizeof(float); // logits
@@ -813,9 +824,9 @@ int bn_model_load(BnModel *m, BnGGUFFile *f, int max_seq_len, int kv_f16) {
813824
}
814825

815826
s->x = (float *)sh_arena_calloc(m->arena, c->dim, sizeof(float));
816-
s->xb = (float *)sh_arena_calloc(m->arena, c->dim, sizeof(float));
827+
s->xb = (float *)sh_arena_calloc(m->arena, xb_size, sizeof(float));
817828
s->xb2 = (float *)sh_arena_calloc(m->arena, xb2_size, sizeof(float));
818-
s->q = (float *)sh_arena_calloc(m->arena, c->dim, sizeof(float));
829+
s->q = (float *)sh_arena_calloc(m->arena, q_size, sizeof(float));
819830
s->hb = (float *)sh_arena_calloc(m->arena, hb_size, sizeof(float));
820831
s->hb2 = (float *)sh_arena_calloc(m->arena, hb2_size, sizeof(float));
821832
s->att = (float *)sh_arena_calloc(m->arena, att_size, sizeof(float));

0 commit comments

Comments
 (0)