Skip to content

Commit 028d32d

Browse files
committed
Use 4-row CUDA MoE Q4K batch gateup
1 parent 02d9bb4 commit 028d32d

1 file changed

Lines changed: 17 additions & 10 deletions

File tree

src/gpu_cuda.cu

Lines changed: 17 additions & 10 deletions
Original file line number | Diff line number | Diff line change
@@ -3744,14 +3744,16 @@ static __global__ void moe_q4k_gateup_routed_mid_q8k_4row_kernel(
37443744
}
37453745
}
37463746

3747-
static __global__ void moe_q4k_gateup_routed_mid_q8k_batch_kernel(
3747+
static __global__ void moe_q4k_gateup_routed_mid_q8k_4row_batch_kernel(
37483748
float *mid, const BnBlockQ4K *gate, const BnBlockQ4K *up,
37493749
const BnBlockQ8K *xq, const int *indices, int hidden, int cols,
37503750
int n_experts, int k, int n_tokens) {
37513751
int lane = threadIdx.x & 31;
37523752
int warp = threadIdx.x >> 5;
37533753
int warps_per_block = blockDim.x >> 5;
3754-
int task = blockIdx.x * warps_per_block + warp;
3754+
int lane_group = lane >> 3;
3755+
int sublane = lane & 7;
3756+
int task = (blockIdx.x * warps_per_block + warp) * 4 + lane_group;
37553757
int total_tasks = n_tokens * k * hidden;
37563758
if (task >= total_tasks) return;
37573759

@@ -3770,15 +3772,18 @@ static __global__ void moe_q4k_gateup_routed_mid_q8k_batch_kernel(
37703772
const BnBlockQ8K *token_xq = xq + (size_t)token * (size_t)n_bpr;
37713773
float gate_sum = 0.0f;
37723774
float up_sum = 0.0f;
3773-
for (int b = lane; b < n_bpr; b += 32) {
3775+
for (int b = sublane; b < n_bpr; b += 8) {
37743776
gate_sum += cuda_vec_dot_q4k_q8k(&gate_blocks[b], token_xq + b);
37753777
up_sum += cuda_vec_dot_q4k_q8k(&up_blocks[b], token_xq + b);
37763778
}
3777-
for (int offset = 16; offset > 0; offset >>= 1) {
3778-
gate_sum += __shfl_down_sync(0xffffffffu, gate_sum, offset);
3779-
up_sum += __shfl_down_sync(0xffffffffu, up_sum, offset);
3780-
}
3781-
if (lane == 0) {
3779+
unsigned mask = 0xffu << (lane_group * 8);
3780+
gate_sum += __shfl_down_sync(mask, gate_sum, 4);
3781+
gate_sum += __shfl_down_sync(mask, gate_sum, 2);
3782+
gate_sum += __shfl_down_sync(mask, gate_sum, 1);
3783+
up_sum += __shfl_down_sync(mask, up_sum, 4);
3784+
up_sum += __shfl_down_sync(mask, up_sum, 2);
3785+
up_sum += __shfl_down_sync(mask, up_sum, 1);
3786+
if (sublane == 0) {
37823787
float silu = gate_sum / (1.0f + __expf(-gate_sum));
37833788
mid[((size_t)token * (size_t)k + (size_t)slot) *
37843789
(size_t)hidden + (size_t)row] = silu * up_sum;
@@ -7885,7 +7890,7 @@ static int cuda_moe_routed_ffn_batch(void *vctx, float *out,
78857890
(const BnBlockQ8_0 *)up->data, d_full_x, d_indices,
78867891
d_weights, hidden_dim, dim, n_experts, k, n_tokens);
78877892
} else {
7888-
if (getenv("BN_CUDA_ENABLE_Q4K_Q8K_DOT")) {
7893+
if (getenv("BN_CUDA_DISABLE_Q4K_Q8K_DOT") == NULL) {
78897894
if (cuda_ensure_q8_k(ctx, dim, n_tokens) != 0)
78907895
return -1;
78917896
BnBlockQ8K *xq = (BnBlockQ8K *)ctx->d_q8_k;
@@ -7899,7 +7904,9 @@ static int cuda_moe_routed_ffn_batch(void *vctx, float *out,
78997904
cudaGetErrorString(err));
79007905
return -1;
79017906
}
7902-
moe_q4k_gateup_routed_mid_q8k_batch_kernel<<<gateup_blocks, threads, 0>>>(
7907+
int gateup4_tasks = (gateup_tasks + 3) / 4;
7908+
int gateup4_blocks = (gateup4_tasks + warps - 1) / warps;
7909+
moe_q4k_gateup_routed_mid_q8k_4row_batch_kernel<<<gateup4_blocks, threads, 0>>>(
79037910
d_mid, (const BnBlockQ4K *)gate->data,
79047911
(const BnBlockQ4K *)up->data, xq, d_indices,
79057912
hidden_dim, dim, n_experts, k, n_tokens);

0 commit comments

Comments (0)