Skip to content

Commit 8de6996

Browse files
committed
Auto-cache CUDA MoE Q6K down weights
1 parent 637ba50 commit 8de6996

3 files changed

Lines changed: 61 additions & 7 deletions

File tree

include/gpu_backend.h

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,12 @@ struct BnGPUBackend {
4343
// caches. Optional; callers use this for memory-sensitive resident caches.
4444
void *(*buffer_create_quant_only)(void *ctx, const void *data, size_t size,
4545
int type, int rows, int cols);
46+
// Upload quantized Q6_K data and request an FP32 auxiliary cache when the
47+
// backend can fit it. Optional; callers use this for CUDA MoE down paths
48+
// that are faster with resident dequantized weights.
49+
void *(*buffer_create_q6_f32_cache)(void *ctx, const void *data,
50+
size_t size, int type,
51+
int rows, int cols);
4652
void (*buffer_destroy)(void *ctx, void *buffer);
4753

4854
// Upload quantized weight data with fused bias. Returns opaque buffer handle.

src/gpu_cuda.cu

Lines changed: 46 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -6091,7 +6091,8 @@ static int cuda_read_activation(void *vctx, int buf_idx, void *out,
60916091
return err == cudaSuccess ? 0 : -1;
60926092
}
60936093

6094-
static int cuda_buffer_create_f16_cache(BnCudaBuffer *buf) {
6094+
static int cuda_buffer_create_f16_cache(BnCudaBuffer *buf,
6095+
int force_q6_f32) {
60956096
if (!buf || !buf->data || buf->rows <= 0 || buf->cols <= 0 ||
60966097
(buf->cols & 31) != 0)
60976098
return 0;
@@ -6105,6 +6106,7 @@ static int cuda_buffer_create_f16_cache(BnCudaBuffer *buf) {
61056106
return 0;
61066107

61076108
int q6_as_f16 = buf->type == BN_GGUF_TENSOR_Q6_K &&
6109+
!force_q6_f32 &&
61086110
getenv("BN_CUDA_DISABLE_Q6K_CUBLAS_F16") == NULL &&
61096111
getenv("BN_CUDA_ENABLE_Q6K_MOE_DOWN_F32_CACHE") == NULL;
61106112
size_t n = (size_t)buf->rows * (size_t)buf->cols;
@@ -6224,7 +6226,7 @@ static void *cuda_buffer_create_impl(void *vctx, const void *data, size_t size,
62246226
return NULL;
62256227
}
62266228
if (create_aux_cache)
6227-
cuda_buffer_create_f16_cache(buf);
6229+
cuda_buffer_create_f16_cache(buf, create_aux_cache == 2);
62286230
return buf;
62296231
}
62306232

@@ -6239,6 +6241,12 @@ static void *cuda_buffer_create_quant_only(void *vctx, const void *data,
62396241
return cuda_buffer_create_impl(vctx, data, size, type, rows, cols, 0);
62406242
}
62416243

6244+
static void *cuda_buffer_create_q6_f32_cache(void *vctx, const void *data,
6245+
size_t size, int type, int rows,
6246+
int cols) {
6247+
return cuda_buffer_create_impl(vctx, data, size, type, rows, cols, 2);
6248+
}
6249+
62426250
static void cuda_buffer_destroy(void *vctx, void *buffer) {
62436251
BnCudaCtx *ctx = (BnCudaCtx *)vctx;
62446252
if (cuda_ctx_set_device(ctx) != 0) return;
@@ -9396,13 +9404,17 @@ enum {
93969404
BN_CUDA_PROFILE_QKV_MIXED = 64,
93979405
BN_CUDA_PROFILE_READBACK = 65,
93989406
BN_CUDA_PROFILE_LOGITS = 66,
9399-
BN_CUDA_PROFILE_MAX = 67
9407+
BN_CUDA_PROFILE_MOE_GATEUP = 67,
9408+
BN_CUDA_PROFILE_MOE_DOWN = 68,
9409+
BN_CUDA_PROFILE_MAX = 69
94009410
};
94019411
94029412
static const char *cuda_profile_name(int code) {
94039413
if (code == BN_CUDA_PROFILE_QKV_MIXED) return "qkv_mixed";
94049414
if (code == BN_CUDA_PROFILE_READBACK) return "readback";
94059415
if (code == BN_CUDA_PROFILE_LOGITS) return "logits";
9416+
if (code == BN_CUDA_PROFILE_MOE_GATEUP) return "moe_gateup";
9417+
if (code == BN_CUDA_PROFILE_MOE_DOWN) return "moe_down";
94069418
return cuda_op_name(code);
94079419
}
94089420
@@ -10509,6 +10521,15 @@ static int cuda_execute(void *vctx, const void *ops_raw, int n_ops,
1050910521
} else {
1051010522
int use_q4k_q8k_dot =
1051110523
getenv("BN_CUDA_DISABLE_MOE_Q4K_Q8K_DOT") == NULL;
10524+
int profile_moe_internal =
10525+
profile && getenv("BN_CUDA_PROFILE_MOE_INTERNAL");
10526+
cudaEvent_t moe_ev_start = NULL;
10527+
cudaEvent_t moe_ev_stop = NULL;
10528+
if (profile_moe_internal) {
10529+
cudaEventCreate(&moe_ev_start);
10530+
cudaEventCreate(&moe_ev_stop);
10531+
cudaEventRecord(moe_ev_start, ctx->exec_stream);
10532+
}
1051210533
if (use_q4k_q8k_dot) {
1051310534
if (cuda_ensure_q8_k(ctx, dim, 1) != 0) return -1;
1051410535
BnBlockQ8K *xq = (BnBlockQ8K *)ctx->d_q8_k;
@@ -10534,11 +10555,20 @@ static int cuda_execute(void *vctx, const void *ops_raw, int n_ops,
1053410555
(const BnBlockQ4K *)up->data, xq, route, hidden,
1053510556
dim, n_experts, k);
1053610557
}
10558+
if (profile_moe_internal) {
10559+
cudaEventRecord(moe_ev_stop, ctx->exec_stream);
10560+
cudaEventSynchronize(moe_ev_stop);
10561+
float ms = 0.0f;
10562+
cudaEventElapsedTime(&ms, moe_ev_start, moe_ev_stop);
10563+
profile_ops[BN_CUDA_PROFILE_MOE_GATEUP]++;
10564+
profile_ms[BN_CUDA_PROFILE_MOE_GATEUP] += (double)ms;
10565+
cudaEventRecord(moe_ev_start, ctx->exec_stream);
10566+
}
1053710567
if (down_type == BN_GGUF_TENSOR_Q6_K) {
1053810568
int use_q6_float_down =
1053910569
getenv("BN_CUDA_DISABLE_Q6K_FLOAT_MOE_DOWN") == NULL;
10540-
if (getenv("BN_CUDA_ENABLE_Q6K_MOE_DOWN_F32_CACHE") &&
10541-
down->f32_data) {
10570+
if (down->f32_data &&
10571+
getenv("BN_CUDA_DISABLE_Q6K_MOE_DOWN_F32_CACHE") == NULL) {
1054210572
BN_CUDA_LAUNCH(ctx,
1054310573
moe_q6k_down_routed_f32_cache_warp_kernel,
1054410574
down_blocks, route_threads, 0,
@@ -10574,6 +10604,16 @@ static int cuda_execute(void *vctx, const void *ops_raw, int n_ops,
1057410604
out, (const BnBlockQ4K *)down->data, mid_q,
1057510605
route, dim, hidden, n_experts, k);
1057610606
}
10607+
if (profile_moe_internal) {
10608+
cudaEventRecord(moe_ev_stop, ctx->exec_stream);
10609+
cudaEventSynchronize(moe_ev_stop);
10610+
float ms = 0.0f;
10611+
cudaEventElapsedTime(&ms, moe_ev_start, moe_ev_stop);
10612+
profile_ops[BN_CUDA_PROFILE_MOE_DOWN]++;
10613+
profile_ms[BN_CUDA_PROFILE_MOE_DOWN] += (double)ms;
10614+
cudaEventDestroy(moe_ev_start);
10615+
cudaEventDestroy(moe_ev_stop);
10616+
}
1057710617
}
1057810618
}
1057910619
break;
@@ -11453,6 +11493,7 @@ BnGPUBackend *bn_gpu_cuda_create(void) {
1145311493
(void)cublasSetMathMode(ctx->cublas, CUBLAS_TENSOR_OP_MATH);
1145411494
gpu->buffer_create = cuda_buffer_create;
1145511495
gpu->buffer_create_quant_only = cuda_buffer_create_quant_only;
11496+
gpu->buffer_create_q6_f32_cache = cuda_buffer_create_q6_f32_cache;
1145611497
gpu->buffer_destroy = cuda_buffer_destroy;
1145711498
gpu->matvec = cuda_matvec;
1145811499
gpu->matmul = cuda_matmul;

src/model_gpu.c

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -120,10 +120,17 @@ static void *upload_moe_all_proj(BnModel *model,
120120
return NULL;
121121
if ((size_t)n_experts > (size_t)INT_MAX / (size_t)rows)
122122
return NULL;
123-
int force_full_buffer =
123+
int prefer_q6_f32_cache =
124124
proj == 2 && type == BN_GGUF_TENSOR_Q6_K &&
125-
getenv("BN_CUDA_ENABLE_Q6K_MOE_DOWN_F32_CACHE");
125+
gpu->buffer_create_q6_f32_cache &&
126+
getenv("BN_CUDA_DISABLE_Q6K_MOE_DOWN_F32_CACHE") == NULL;
127+
int force_full_buffer =
128+
prefer_q6_f32_cache ||
129+
(proj == 2 && type == BN_GGUF_TENSOR_Q6_K &&
130+
getenv("BN_CUDA_ENABLE_Q6K_MOE_DOWN_F32_CACHE"));
126131
void *(*create_buffer)(void *, const void *, size_t, int, int, int) =
132+
prefer_q6_f32_cache
133+
? gpu->buffer_create_q6_f32_cache :
127134
(!force_full_buffer && gpu->buffer_create_quant_only)
128135
? gpu->buffer_create_quant_only
129136
: gpu->buffer_create;

0 commit comments

Comments
 (0)