@@ -1560,20 +1560,32 @@ void bn_transformer_gpu_emit_context_moe(BnTransformerGPUEmitContext *ctx,
15601560 }
15611561
15621562 if (lw -> shared .shared_gate .data && shared && shared -> shared_gate ) {
1563- uint32_t shared_gate_flags =
1564- lw -> shared .shared_gate .type == BN_GGUF_TENSOR_Q4_K ? 1u : 0u ;
1565- uint32_t shared_up_flags =
1566- lw -> shared .shared_up .type == BN_GGUF_TENSOR_Q4_K ? 1u : 0u ;
1567- emit_context_matvec_flags (
1568- ctx , lw -> shared .shared_gate .type ,
1569- shared -> shared_gate ,
1570- BN_GPU_VALUE_XB , BN_GPU_VALUE_HB , lw -> shared .shared_gate .rows ,
1571- lw -> shared .shared_gate .cols , 0 , shared_gate_flags );
1572- emit_context_matvec_flags (
1573- ctx , lw -> shared .shared_up .type ,
1574- shared -> shared_up ,
1575- BN_GPU_VALUE_XB , BN_GPU_VALUE_HB2 , lw -> shared .shared_up .rows ,
1576- lw -> shared .shared_up .cols , 0 , shared_up_flags );
1563+ if (shared -> shared_gateup_stacked ) {
1564+ emit_context_matvec_split (
1565+ ctx , lw -> shared .shared_gate .type ,
1566+ shared -> shared_gateup_stacked ,
1567+ BN_GPU_VALUE_XB , BN_GPU_VALUE_HB , BN_GPU_VALUE_HB2 , -1 ,
1568+ lw -> shared .shared_gate .rows + lw -> shared .shared_up .rows ,
1569+ lw -> shared .shared_gate .cols , lw -> shared .shared_gate .rows ,
1570+ 0 , 0 , 0 , 0 );
1571+ } else {
1572+ uint32_t shared_gate_flags =
1573+ lw -> shared .shared_gate .type == BN_GGUF_TENSOR_Q4_K ? 1u : 0u ;
1574+ uint32_t shared_up_flags =
1575+ lw -> shared .shared_up .type == BN_GGUF_TENSOR_Q4_K ? 1u : 0u ;
1576+ emit_context_matvec_flags (
1577+ ctx , lw -> shared .shared_gate .type ,
1578+ shared -> shared_gate ,
1579+ BN_GPU_VALUE_XB , BN_GPU_VALUE_HB ,
1580+ lw -> shared .shared_gate .rows , lw -> shared .shared_gate .cols , 0 ,
1581+ shared_gate_flags );
1582+ emit_context_matvec_flags (
1583+ ctx , lw -> shared .shared_up .type ,
1584+ shared -> shared_up ,
1585+ BN_GPU_VALUE_XB , BN_GPU_VALUE_HB2 ,
1586+ lw -> shared .shared_up .rows , lw -> shared .shared_up .cols , 0 ,
1587+ shared_up_flags );
1588+ }
15771589 bn_transformer_gpu_emit_context_activation (
15781590 ctx , BN_GPU_VALUE_HB , BN_GPU_VALUE_HB2 ,
15791591 lw -> shared .shared_gate .rows , 0 , BN_GPU_IR_ACTIVATION_SILU );
0 commit comments