@@ -4651,6 +4651,41 @@ static __global__ void moe_q6k_down_routed_f32_cache_warp_kernel(
46514651 out[row] = sum;
46524652}
46534653
4654+ static __global__ void moe_q6k_down_routed_f16_cache_warp_kernel(
4655+ float *out,
4656+ const __half *down,
4657+ const float *mid,
4658+ const float *route,
4659+ int dim,
4660+ int hidden,
4661+ int n_experts,
4662+ int k) {
4663+ int lane = threadIdx.x & 31;
4664+ int warp = threadIdx.x >> 5;
4665+ int warps_per_block = blockDim.x >> 5;
4666+ int row = blockIdx.x * warps_per_block + warp;
4667+ if (row >= dim) return;
4668+
4669+ float sum = 0.0f;
4670+ for (int slot = 0; slot < k; slot++) {
4671+ int expert = (int)(route[k + slot] + 0.5f);
4672+ if (expert < 0) expert = 0;
4673+ if (expert >= n_experts) expert = n_experts - 1;
4674+ const uint16_t *row_w =
4675+ (const uint16_t *)down +
4676+ ((size_t)expert * (size_t)dim + (size_t)row) * (size_t)hidden;
4677+ const float *slot_mid = mid + (size_t)slot * (size_t)hidden;
4678+ float slot_sum = 0.0f;
4679+ for (int c = lane; c < hidden; c += 32)
4680+ slot_sum += cuda_fp16_to_fp32(row_w[c]) * slot_mid[c];
4681+ sum += route[slot] * slot_sum;
4682+ }
4683+ for (int offset = 16; offset > 0; offset >>= 1)
4684+ sum += __shfl_down_sync(0xffffffffu, sum, offset);
4685+ if (lane == 0)
4686+ out[row] = sum;
4687+ }
4688+
46544689static __global__ void moe_q6k_down_routed_f32_cache_batch_kernel(
46554690 float *out,
46564691 const float *down,
@@ -12830,6 +12865,14 @@ static int cuda_execute(void *vctx, const void *ops_raw, int n_ops,
1283012865 down_blocks, route_threads, 0,
1283112866 out, (const float *)down->f32_data, mid,
1283212867 route, dim, hidden, n_experts, k);
12868+ } else if (down->f16_data &&
12869+ getenv("BN_CUDA_ENABLE_Q6K_MOE_DOWN_F16_CACHE") != NULL &&
12870+ getenv("BN_CUDA_DISABLE_Q6K_MOE_DOWN_F16_CACHE") == NULL) {
12871+ BN_CUDA_LAUNCH(ctx,
12872+ moe_q6k_down_routed_f16_cache_warp_kernel,
12873+ down_blocks, route_threads, 0,
12874+ out, (const __half *)down->f16_data, mid,
12875+ route, dim, hidden, n_experts, k);
1283312876 } else if (use_q6_float_down) {
1283412877 BN_CUDA_LAUNCH(ctx,
1283512878 moe_q6k_down_routed_float_accum_row_kernel,
0 commit comments