@@ -276,12 +276,15 @@ int bn_model_load(BnModel *m, BnGGUFFile *f, int max_seq_len, int kv_f16) {
276276 }
277277
278278 // Derived dimensions (safe now — denominators validated above)
279- c -> head_size = c -> dim / c -> n_heads ;
279+ // Check for explicit head size (Qwen3 has key_length != dim/n_heads)
280+ snprintf (key , sizeof (key ), "%s.attention.key_length" , prefix );
281+ int explicit_head_size = (int )bn_gguf_get_u32 (f , key );
282+ c -> head_size = (explicit_head_size > 0 ) ? explicit_head_size : (c -> dim / c -> n_heads );
280283 c -> kv_dim = c -> head_size * c -> n_kv_heads ;
281284 c -> kv_mul = c -> n_heads / c -> n_kv_heads ;
282285
283286 // Validate alignment for SIMD vectorized paths
284- if (c -> dim % c -> n_heads != 0 ) {
287+ if (explicit_head_size == 0 && c -> dim % c -> n_heads != 0 ) {
285288 SH_LOG_ERROR ("dim not divisible by n_heads" );
286289 return -1 ;
287290 }
@@ -687,7 +690,13 @@ int bn_model_load(BnModel *m, BnGGUFFile *f, int max_seq_len, int kv_f16) {
687690 goto fail_state ;
688691 }
689692
693+ // q_dim = n_heads * head_size (may differ from dim when attention.key_length is set)
694+ int q_dim = c -> n_heads * c -> head_size ;
695+ int xb_size = q_dim > c -> dim ? q_dim : c -> dim ; // xb must hold attention output
696+ int q_size = xb_size ; // q must match xb for attention head access pattern
697+
690698 int x_q_size = c -> dim > c -> hidden_dim ? c -> dim : c -> hidden_dim ;
699+ if (q_dim > x_q_size ) x_q_size = q_dim ;
691700 int half_head = c -> head_size / 2 ;
692701
693702 // Scratch buffer sizes — enlarged for hybrid SSM + gated-Q attention
@@ -702,9 +711,11 @@ int bn_model_load(BnModel *m, BnGGUFFile *f, int max_seq_len, int kv_f16) {
702711 if (c -> ssm_inner_size > xb2_size ) xb2_size = c -> ssm_inner_size ;
703712 if (c -> ssm_inner_size > x_q_size ) x_q_size = c -> ssm_inner_size ;
704713 // Gated Q attention: q_full (Q + gate) → hb
705- int gq = 2 * c -> dim ;
714+ int gq = 2 * q_dim ;
706715 if (gq > hb_size ) hb_size = gq ;
707716 }
717+ // Non-gated models with q_dim > dim still need hb for gated-Q check
718+ if (2 * q_dim > hb_size ) hb_size = 2 * q_dim ;
708719 // MoE shared expert may need larger hb/hb2 buffers
709720 if (c -> has_shared_expert && c -> shared_expert_intermediate_size > hb_size )
710721 hb_size = c -> shared_expert_intermediate_size ;
@@ -787,7 +798,7 @@ int bn_model_load(BnModel *m, BnGGUFFile *f, int max_seq_len, int kv_f16) {
787798
788799 // Compute total arena capacity (all RunState buffers + INT8 embeddings + Q4 repacking)
789800 size_t arena_size = 0 ;
790- arena_size += (3 * (size_t )c -> dim + (size_t )xb2_size ) * sizeof (float ); // x, xb, xb2, q (xb2 may be > dim for SSM)
801+ arena_size += ((size_t )c -> dim + (size_t )xb_size + ( size_t ) xb2_size + ( size_t ) q_size ) * sizeof (float ); // x, xb, xb2, q
791802 arena_size += ((size_t )hb_size + (size_t )hb2_size ) * sizeof (float ); // hb, hb2
792803 arena_size += att_size * sizeof (float ); // att
793804 arena_size += (size_t )c -> vocab_size * sizeof (float ); // logits
@@ -813,9 +824,9 @@ int bn_model_load(BnModel *m, BnGGUFFile *f, int max_seq_len, int kv_f16) {
813824 }
814825
815826 s -> x = (float * )sh_arena_calloc (m -> arena , c -> dim , sizeof (float ));
816- s -> xb = (float * )sh_arena_calloc (m -> arena , c -> dim , sizeof (float ));
827+ s -> xb = (float * )sh_arena_calloc (m -> arena , xb_size , sizeof (float ));
817828 s -> xb2 = (float * )sh_arena_calloc (m -> arena , xb2_size , sizeof (float ));
818- s -> q = (float * )sh_arena_calloc (m -> arena , c -> dim , sizeof (float ));
829+ s -> q = (float * )sh_arena_calloc (m -> arena , q_size , sizeof (float ));
819830 s -> hb = (float * )sh_arena_calloc (m -> arena , hb_size , sizeof (float ));
820831 s -> hb2 = (float * )sh_arena_calloc (m -> arena , hb2_size , sizeof (float ));
821832 s -> att = (float * )sh_arena_calloc (m -> arena , att_size , sizeof (float ));
0 commit comments