@@ -3744,14 +3744,16 @@ static __global__ void moe_q4k_gateup_routed_mid_q8k_4row_kernel(
37443744 }
37453745}
37463746
3747- static __global__ void moe_q4k_gateup_routed_mid_q8k_batch_kernel (
3747+ static __global__ void moe_q4k_gateup_routed_mid_q8k_4row_batch_kernel (
37483748 float *mid, const BnBlockQ4K *gate, const BnBlockQ4K *up,
37493749 const BnBlockQ8K *xq, const int *indices, int hidden, int cols,
37503750 int n_experts, int k, int n_tokens) {
37513751 int lane = threadIdx .x & 31 ;
37523752 int warp = threadIdx .x >> 5 ;
37533753 int warps_per_block = blockDim .x >> 5 ;
3754- int task = blockIdx .x * warps_per_block + warp;
3754+ int lane_group = lane >> 3 ;
3755+ int sublane = lane & 7 ;
3756+ int task = (blockIdx .x * warps_per_block + warp) * 4 + lane_group;
37553757 int total_tasks = n_tokens * k * hidden;
37563758 if (task >= total_tasks) return ;
37573759
@@ -3770,15 +3772,18 @@ static __global__ void moe_q4k_gateup_routed_mid_q8k_batch_kernel(
37703772 const BnBlockQ8K *token_xq = xq + (size_t )token * (size_t )n_bpr;
37713773 float gate_sum = 0 .0f ;
37723774 float up_sum = 0 .0f ;
3773- for (int b = lane ; b < n_bpr; b += 32 ) {
3775+ for (int b = sublane ; b < n_bpr; b += 8 ) {
37743776 gate_sum += cuda_vec_dot_q4k_q8k (&gate_blocks[b], token_xq + b);
37753777 up_sum += cuda_vec_dot_q4k_q8k (&up_blocks[b], token_xq + b);
37763778 }
3777- for (int offset = 16 ; offset > 0 ; offset >>= 1 ) {
3778- gate_sum += __shfl_down_sync (0xffffffffu , gate_sum, offset);
3779- up_sum += __shfl_down_sync (0xffffffffu , up_sum, offset);
3780- }
3781- if (lane == 0 ) {
3779+ unsigned mask = 0xffu << (lane_group * 8 );
3780+ gate_sum += __shfl_down_sync (mask, gate_sum, 4 );
3781+ gate_sum += __shfl_down_sync (mask, gate_sum, 2 );
3782+ gate_sum += __shfl_down_sync (mask, gate_sum, 1 );
3783+ up_sum += __shfl_down_sync (mask, up_sum, 4 );
3784+ up_sum += __shfl_down_sync (mask, up_sum, 2 );
3785+ up_sum += __shfl_down_sync (mask, up_sum, 1 );
3786+ if (sublane == 0 ) {
37823787 float silu = gate_sum / (1 .0f + __expf (-gate_sum));
37833788 mid[((size_t )token * (size_t )k + (size_t )slot) *
37843789 (size_t )hidden + (size_t )row] = silu * up_sum;
@@ -7885,7 +7890,7 @@ static int cuda_moe_routed_ffn_batch(void *vctx, float *out,
78857890 (const BnBlockQ8_0 *)up->data , d_full_x, d_indices,
78867891 d_weights, hidden_dim, dim, n_experts, k, n_tokens);
78877892 } else {
7888- if (getenv (" BN_CUDA_ENABLE_Q4K_Q8K_DOT " ) ) {
7893+ if (getenv (" BN_CUDA_DISABLE_Q4K_Q8K_DOT " ) == NULL ) {
78897894 if (cuda_ensure_q8_k (ctx, dim, n_tokens) != 0 )
78907895 return -1 ;
78917896 BnBlockQ8K *xq = (BnBlockQ8K *)ctx->d_q8_k ;
@@ -7899,7 +7904,9 @@ static int cuda_moe_routed_ffn_batch(void *vctx, float *out,
78997904 cudaGetErrorString (err));
79007905 return -1 ;
79017906 }
7902- moe_q4k_gateup_routed_mid_q8k_batch_kernel<<<gateup_blocks, threads, 0 >>> (
7907+ int gateup4_tasks = (gateup_tasks + 3 ) / 4 ;
7908+ int gateup4_blocks = (gateup4_tasks + warps - 1 ) / warps;
7909+ moe_q4k_gateup_routed_mid_q8k_4row_batch_kernel<<<gateup4_blocks, threads, 0 >>> (
79037910 d_mid, (const BnBlockQ4K *)gate->data ,
79047911 (const BnBlockQ4K *)up->data , xq, d_indices,
79057912 hidden_dim, dim, n_experts, k, n_tokens);
0 commit comments