@@ -3958,6 +3958,46 @@ static __global__ void moe_q6k_down_routed_float_accum_row_kernel(
39583958 out[row] = partial[0 ];
39593959}
39603960
// Routed MoE down-projection reading float weights: each block produces one
// output element as the route-weighted sum of k row-by-vector dot products.
//
// Launch layout: 1-D grid with one block per output row (row == blockIdx.x)
// and a 1-D block. Uses 1 KiB of static shared memory. Any block width is
// accepted; power-of-two widths up to 256 keep the plain tree-reduction
// summation order.
//
// Parameters (layouts as indexed below):
//   out       out[row] is overwritten with the result for this row.
//   down      weights addressed as [expert][row][col], row length `hidden`.
//   mid       activations addressed as [slot][col], row length `hidden`.
//   route     2*k floats: route[0..k) are the per-slot weights and
//             route[k..2k) are the expert ids stored as floats.
//   dim       number of output rows; blocks with row >= dim do nothing.
//   hidden    length of every dot product.
//   n_experts expert ids are clamped into [0, n_experts - 1].
//             NOTE(review): assumes n_experts >= 1 - confirm at the caller.
//   k         number of routed slots.
static __global__ void moe_q6k_down_routed_f32_cache_row_kernel(
    float *out,
    const float *down,
    const float *mid,
    const float *route,
    int dim,
    int hidden,
    int n_experts,
    int k) {
    // Capacity of the shared scratch array used by the block reduction.
    enum { kPartialSlots = 256 };

    int row = blockIdx.x;
    int tid = threadIdx.x;
    int n_threads = blockDim.x;
    // row depends only on blockIdx, so a whole block leaves together and no
    // thread is left behind at the barriers further down.
    if (row >= dim) return;

    float sum = 0.0f;
    for (int slot = 0; slot < k; slot++) {
        // Expert ids arrive as floats: round to nearest, then clamp.
        int expert = (int)(route[k + slot] + 0.5f);
        if (expert < 0) expert = 0;
        if (expert >= n_experts) expert = n_experts - 1;
        const float *row_w =
            down + ((size_t)expert * (size_t)dim + (size_t)row) *
                       (size_t)hidden;
        const float *slot_mid = mid + (size_t)slot * (size_t)hidden;
        // Threads of the block stride over the columns of this row.
        float slot_sum = 0.0f;
        for (int c = tid; c < hidden; c += n_threads)
            slot_sum += row_w[c] * slot_mid[c];
        sum += route[slot] * slot_sum;
    }

    __shared__ float partial[kPartialSlots];
    // Guarded store: a block wider than the scratch array must not write
    // past its end.
    if (tid < kPartialSlots)
        partial[tid] = sum;
    __syncthreads();

    // Blocks wider than the scratch array fold their extra threads in, one
    // window of kPartialSlots threads per pass. The loop bounds are uniform
    // across the block, so every thread reaches every barrier.
    for (int base = kPartialSlots; base < n_threads; base += kPartialSlots) {
        if (tid >= base && tid < base + kPartialSlots)
            partial[tid - base] += sum;
        __syncthreads();
    }

    // Tree reduction over the live entries. Starting from the next power of
    // two and bounds-checking the partner index keeps every partial sum even
    // when the block width is not a power of two.
    int active = n_threads < kPartialSlots ? n_threads : kPartialSlots;
    int stride = 1;
    while (stride < active) stride <<= 1;
    for (stride >>= 1; stride > 0; stride >>= 1) {
        if (tid < stride && tid + stride < active)
            partial[tid] += partial[tid + stride];
        __syncthreads();
    }
    if (tid == 0)
        out[row] = partial[0];
}
4000+
39614001static __global__ void moe_q6k_down_routed_q8k_accum_batch_kernel (
39624002 float *out,
39634003 const BnBlockQ6K *down,
@@ -6070,7 +6110,8 @@ static int cuda_buffer_create_f16_cache(BnCudaBuffer *buf) {
60706110 return 0 ;
60716111
60726112 int q6_as_f16 = buf->type == BN_GGUF_TENSOR_Q6_K &&
6073- getenv (" BN_CUDA_DISABLE_Q6K_CUBLAS_F16" ) == NULL ;
6113+ getenv (" BN_CUDA_DISABLE_Q6K_CUBLAS_F16" ) == NULL &&
6114+ getenv (" BN_CUDA_ENABLE_Q6K_MOE_DOWN_F32_CACHE" ) == NULL ;
60746115 size_t n = (size_t )buf->rows * (size_t )buf->cols ;
60756116 size_t bytes = n * (buf->type == BN_GGUF_TENSOR_Q6_K && !q6_as_f16
60766117 ? sizeof (float )
@@ -10508,7 +10549,14 @@ static int cuda_execute(void *vctx, const void *ops_raw, int n_ops,
1050810549 dim3 (hidden / BN_QK_K, k, 1 ), BN_QK_K, 0 ,
1050910550 mid_q, mid, hidden, k);
1051010551 if (down_type == BN_GGUF_TENSOR_Q6_K) {
10511- if (getenv (" BN_CUDA_ENABLE_Q6K_FLOAT_MOE_DOWN" )) {
10552+ if (getenv (" BN_CUDA_ENABLE_Q6K_MOE_DOWN_F32_CACHE" ) &&
10553+ down->f32_data ) {
10554+ BN_CUDA_LAUNCH (ctx,
10555+ moe_q6k_down_routed_f32_cache_row_kernel,
10556+ dim, route_threads, 0 ,
10557+ out, (const float *)down->f32_data , mid,
10558+ route, dim, hidden, n_experts, k);
10559+ } else if (getenv (" BN_CUDA_ENABLE_Q6K_FLOAT_MOE_DOWN" )) {
1051210560 BN_CUDA_LAUNCH (ctx,
1051310561 moe_q6k_down_routed_float_accum_row_kernel,
1051410562 dim, route_threads, 0 ,
0 commit comments