Skip to content

Commit dcca06b

Browse files
committed
Add opt-in F16 CUDA MoE down path
1 parent 986152a commit dcca06b

1 file changed

Lines changed: 43 additions & 0 deletions

File tree

src/gpu_cuda.cu

Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4651,6 +4651,41 @@ static __global__ void moe_q6k_down_routed_f32_cache_warp_kernel(
46514651
out[row] = sum;
46524652
}
46534653

4654+
static __global__ void moe_q6k_down_routed_f16_cache_warp_kernel(
4655+
float *out,
4656+
const __half *down,
4657+
const float *mid,
4658+
const float *route,
4659+
int dim,
4660+
int hidden,
4661+
int n_experts,
4662+
int k) {
4663+
int lane = threadIdx.x & 31;
4664+
int warp = threadIdx.x >> 5;
4665+
int warps_per_block = blockDim.x >> 5;
4666+
int row = blockIdx.x * warps_per_block + warp;
4667+
if (row >= dim) return;
4668+
4669+
float sum = 0.0f;
4670+
for (int slot = 0; slot < k; slot++) {
4671+
int expert = (int)(route[k + slot] + 0.5f);
4672+
if (expert < 0) expert = 0;
4673+
if (expert >= n_experts) expert = n_experts - 1;
4674+
const uint16_t *row_w =
4675+
(const uint16_t *)down +
4676+
((size_t)expert * (size_t)dim + (size_t)row) * (size_t)hidden;
4677+
const float *slot_mid = mid + (size_t)slot * (size_t)hidden;
4678+
float slot_sum = 0.0f;
4679+
for (int c = lane; c < hidden; c += 32)
4680+
slot_sum += cuda_fp16_to_fp32(row_w[c]) * slot_mid[c];
4681+
sum += route[slot] * slot_sum;
4682+
}
4683+
for (int offset = 16; offset > 0; offset >>= 1)
4684+
sum += __shfl_down_sync(0xffffffffu, sum, offset);
4685+
if (lane == 0)
4686+
out[row] = sum;
4687+
}
4688+
46544689
static __global__ void moe_q6k_down_routed_f32_cache_batch_kernel(
46554690
float *out,
46564691
const float *down,
@@ -12830,6 +12865,14 @@ static int cuda_execute(void *vctx, const void *ops_raw, int n_ops,
1283012865
down_blocks, route_threads, 0,
1283112866
out, (const float *)down->f32_data, mid,
1283212867
route, dim, hidden, n_experts, k);
12868+
} else if (down->f16_data &&
12869+
getenv("BN_CUDA_ENABLE_Q6K_MOE_DOWN_F16_CACHE") != NULL &&
12870+
getenv("BN_CUDA_DISABLE_Q6K_MOE_DOWN_F16_CACHE") == NULL) {
12871+
BN_CUDA_LAUNCH(ctx,
12872+
moe_q6k_down_routed_f16_cache_warp_kernel,
12873+
down_blocks, route_threads, 0,
12874+
out, (const __half *)down->f16_data, mid,
12875+
route, dim, hidden, n_experts, k);
1283312876
} else if (use_q6_float_down) {
1283412877
BN_CUDA_LAUNCH(ctx,
1283512878
moe_q6k_down_routed_float_accum_row_kernel,

0 commit comments

Comments
 (0)