@@ -6091,7 +6091,8 @@ static int cuda_read_activation(void *vctx, int buf_idx, void *out,
60916091 return err == cudaSuccess ? 0 : -1 ;
60926092}
60936093
6094- static int cuda_buffer_create_f16_cache (BnCudaBuffer *buf) {
6094+ static int cuda_buffer_create_f16_cache (BnCudaBuffer *buf,
6095+ int force_q6_f32) {
60956096 if (!buf || !buf->data || buf->rows <= 0 || buf->cols <= 0 ||
60966097 (buf->cols & 31 ) != 0 )
60976098 return 0 ;
@@ -6105,6 +6106,7 @@ static int cuda_buffer_create_f16_cache(BnCudaBuffer *buf) {
61056106 return 0 ;
61066107
61076108 int q6_as_f16 = buf->type == BN_GGUF_TENSOR_Q6_K &&
6109+ !force_q6_f32 &&
61086110 getenv (" BN_CUDA_DISABLE_Q6K_CUBLAS_F16" ) == NULL &&
61096111 getenv (" BN_CUDA_ENABLE_Q6K_MOE_DOWN_F32_CACHE" ) == NULL ;
61106112 size_t n = (size_t )buf->rows * (size_t )buf->cols ;
@@ -6224,7 +6226,7 @@ static void *cuda_buffer_create_impl(void *vctx, const void *data, size_t size,
62246226 return NULL ;
62256227 }
62266228 if (create_aux_cache)
6227- cuda_buffer_create_f16_cache (buf);
6229+ cuda_buffer_create_f16_cache (buf, create_aux_cache == 2 );
62286230 return buf;
62296231}
62306232
@@ -6239,6 +6241,12 @@ static void *cuda_buffer_create_quant_only(void *vctx, const void *data,
62396241 return cuda_buffer_create_impl (vctx, data, size, type, rows, cols, 0 );
62406242}
62416243
/*
 * Buffer-creation entry point that forwards every argument unchanged to
 * cuda_buffer_create_impl and supplies 2 as the final argument.
 *
 * In the visible tail of cuda_buffer_create_impl, a non-zero create_aux_cache
 * leads to cuda_buffer_create_f16_cache(buf, create_aux_cache == 2), so the
 * value 2 is what sets force_q6_f32 there. In cuda_buffer_create_f16_cache,
 * force_q6_f32 keeps q6_as_f16 at 0 for BN_GGUF_TENSOR_Q6_K buffers.
 * NOTE(review): assumes the final parameter of cuda_buffer_create_impl is
 * create_aux_cache -- its full signature is not visible here; confirm.
 * NOTE(review): which cache is built once q6_as_f16 is 0 is not visible in
 * this excerpt -- presumably an f32 one, going by the name; confirm.
 *
 * Returns whatever cuda_buffer_create_impl returns.
 */
6244+ static void *cuda_buffer_create_q6_f32_cache (void *vctx, const void *data,
6245+ size_t size, int type, int rows,
6246+ int cols) {
6247+ return cuda_buffer_create_impl (vctx, data, size, type, rows, cols, 2 ); /* 2: the value tested by `create_aux_cache == 2` in the impl */
6248+ }
6249+
62426250static void cuda_buffer_destroy (void *vctx, void *buffer) {
62436251 BnCudaCtx *ctx = (BnCudaCtx *)vctx;
62446252 if (cuda_ctx_set_device (ctx) != 0 ) return ;
@@ -9396,13 +9404,17 @@ enum {
93969404 BN_CUDA_PROFILE_QKV_MIXED = 64 ,
93979405 BN_CUDA_PROFILE_READBACK = 65 ,
93989406 BN_CUDA_PROFILE_LOGITS = 66 ,
9399- BN_CUDA_PROFILE_MAX = 67
9407+ BN_CUDA_PROFILE_MOE_GATEUP = 67 ,
9408+ BN_CUDA_PROFILE_MOE_DOWN = 68 ,
9409+ BN_CUDA_PROFILE_MAX = 69
94009410};
94019411
/*
 * Maps a profile code to a name string.
 *
 * The five BN_CUDA_PROFILE_* codes compared below (values 64..68 in the enum
 * that precedes this function) are resolved here; every other code is passed
 * through to cuda_op_name and its result is returned.
 * NOTE(review): the space right after the opening quote of each literal looks
 * like an artifact of this diff rendering rather than part of the real
 * string -- verify against the actual source file before relying on it.
 */
94029412static const char *cuda_profile_name (int code) {
94039413 if (code == BN_CUDA_PROFILE_QKV_MIXED) return " qkv_mixed" ;
94049414 if (code == BN_CUDA_PROFILE_READBACK) return " readback" ;
94059415 if (code == BN_CUDA_PROFILE_LOGITS) return " logits" ;
9416+ if (code == BN_CUDA_PROFILE_MOE_GATEUP) return " moe_gateup" ;
9417+ if (code == BN_CUDA_PROFILE_MOE_DOWN) return " moe_down" ;
/* Codes not matched above are delegated to cuda_op_name. */
94069418 return cuda_op_name (code);
94079419}
94089420
@@ -10509,6 +10521,15 @@ static int cuda_execute(void *vctx, const void *ops_raw, int n_ops,
1050910521 } else {
1051010522 int use_q4k_q8k_dot =
1051110523 getenv (" BN_CUDA_DISABLE_MOE_Q4K_Q8K_DOT" ) == NULL ;
10524+ int profile_moe_internal =
10525+ profile && getenv (" BN_CUDA_PROFILE_MOE_INTERNAL" );
10526+ cudaEvent_t moe_ev_start = NULL ;
10527+ cudaEvent_t moe_ev_stop = NULL ;
10528+ if (profile_moe_internal) {
10529+ cudaEventCreate (&moe_ev_start);
10530+ cudaEventCreate (&moe_ev_stop);
10531+ cudaEventRecord (moe_ev_start, ctx->exec_stream );
10532+ }
1051210533 if (use_q4k_q8k_dot) {
1051310534 if (cuda_ensure_q8_k (ctx, dim, 1 ) != 0 ) return -1 ;
1051410535 BnBlockQ8K *xq = (BnBlockQ8K *)ctx->d_q8_k ;
@@ -10534,11 +10555,20 @@ static int cuda_execute(void *vctx, const void *ops_raw, int n_ops,
1053410555 (const BnBlockQ4K *)up->data , xq, route, hidden,
1053510556 dim, n_experts, k);
1053610557 }
10558+ if (profile_moe_internal) {
10559+ cudaEventRecord (moe_ev_stop, ctx->exec_stream );
10560+ cudaEventSynchronize (moe_ev_stop);
10561+ float ms = 0 .0f ;
10562+ cudaEventElapsedTime (&ms, moe_ev_start, moe_ev_stop);
10563+ profile_ops[BN_CUDA_PROFILE_MOE_GATEUP]++;
10564+ profile_ms[BN_CUDA_PROFILE_MOE_GATEUP] += (double )ms;
10565+ cudaEventRecord (moe_ev_start, ctx->exec_stream );
10566+ }
1053710567 if (down_type == BN_GGUF_TENSOR_Q6_K) {
1053810568 int use_q6_float_down =
1053910569 getenv (" BN_CUDA_DISABLE_Q6K_FLOAT_MOE_DOWN" ) == NULL ;
10540- if (getenv ( " BN_CUDA_ENABLE_Q6K_MOE_DOWN_F32_CACHE " ) &&
10541- down-> f32_data ) {
10570+ if (down-> f32_data &&
10571+ getenv ( " BN_CUDA_DISABLE_Q6K_MOE_DOWN_F32_CACHE " ) == NULL ) {
1054210572 BN_CUDA_LAUNCH (ctx,
1054310573 moe_q6k_down_routed_f32_cache_warp_kernel,
1054410574 down_blocks, route_threads, 0 ,
@@ -10574,6 +10604,16 @@ static int cuda_execute(void *vctx, const void *ops_raw, int n_ops,
1057410604 out, (const BnBlockQ4K *)down->data , mid_q,
1057510605 route, dim, hidden, n_experts, k);
1057610606 }
10607+ if (profile_moe_internal) {
10608+ cudaEventRecord (moe_ev_stop, ctx->exec_stream );
10609+ cudaEventSynchronize (moe_ev_stop);
10610+ float ms = 0 .0f ;
10611+ cudaEventElapsedTime (&ms, moe_ev_start, moe_ev_stop);
10612+ profile_ops[BN_CUDA_PROFILE_MOE_DOWN]++;
10613+ profile_ms[BN_CUDA_PROFILE_MOE_DOWN] += (double )ms;
10614+ cudaEventDestroy (moe_ev_start);
10615+ cudaEventDestroy (moe_ev_stop);
10616+ }
1057710617 }
1057810618 }
1057910619 break ;
@@ -11453,6 +11493,7 @@ BnGPUBackend *bn_gpu_cuda_create(void) {
1145311493 (void )cublasSetMathMode (ctx->cublas , CUBLAS_TENSOR_OP_MATH);
1145411494 gpu->buffer_create = cuda_buffer_create;
1145511495 gpu->buffer_create_quant_only = cuda_buffer_create_quant_only;
11496+ gpu->buffer_create_q6_f32_cache = cuda_buffer_create_q6_f32_cache;
1145611497 gpu->buffer_destroy = cuda_buffer_destroy;
1145711498 gpu->matvec = cuda_matvec;
1145811499 gpu->matmul = cuda_matmul;
0 commit comments