Skip to content

Commit 200b8dc

Browse files
committed
Add opt-in CUDA routed MoE FFN
1 parent b57dfab commit 200b8dc

7 files changed

Lines changed: 338 additions & 1 deletion

File tree

include/backend_model.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,9 @@ typedef enum {
3838
BN_BACKEND_HANDLE_MOE_ROUTER_DIFF = 25,
3939
BN_BACKEND_HANDLE_SHARED_GATEUP_STACKED = 26,
4040
BN_BACKEND_HANDLE_MOE_ROUTER = 27,
41+
BN_BACKEND_HANDLE_MOE_GATE_ALL = 28,
42+
BN_BACKEND_HANDLE_MOE_UP_ALL = 29,
43+
BN_BACKEND_HANDLE_MOE_DOWN_ALL = 30,
4144
} BnBackendHandleRole;
4245

4346
BnBackendModel *bn_backend_model_create(void);

src/gpu_cuda.cu

Lines changed: 147 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
#include <stdlib.h>
1717
#include <string.h>
1818
#include <time.h>
19+
#include <limits.h>
1920

2021
typedef struct {
2122
void *data;
@@ -3444,6 +3445,87 @@ static __global__ void moe_route_topk_kernel(float *route,
34443445
route[k + i] = (float)selected[i];
34453446
}
34463447

3448+
/*
 * Routed MoE gate/up projection fused with SiLU and the routing weight.
 *
 * Work decomposition: one warp per (slot, row) pair, where slot is in [0, k)
 * and row is in [0, hidden). Warps whose flat index is >= hidden * k exit as
 * a whole, so the full-mask shuffles below never see a partially exited warp.
 *
 * route layout as read here: route[0 .. k) holds the per-slot weights and
 * route[k .. 2k) holds the selected expert ids stored as floats.
 *
 * gate/up are indexed as [expert][row][cols / BN_QK_K] Q4_K blocks; xq is
 * indexed with eight Q8_1 blocks per Q4_K block (xq + b * 8).
 *
 * Output: mid[slot * hidden + row] = route_weight * silu(gate . x) * (up . x).
 */
static __global__ void moe_q4k_gateup_routed_mid_kernel(
    float *mid,
    const BnBlockQ4K *gate,
    const BnBlockQ4K *up,
    const BnCudaBlockQ8_1 *xq,
    const float *route,
    int hidden,
    int cols,
    int n_experts,
    int k) {
    const int lane_id = threadIdx.x & 31;
    const int warp_id = threadIdx.x >> 5;
    const int block_warps = blockDim.x >> 5;
    const int work_item = blockIdx.x * block_warps + warp_id;
    if (work_item >= hidden * k) return;

    /* Split the flat work index into (slot, row). */
    const int slot = work_item / hidden;
    const int row = work_item % hidden;

    /* Expert ids are stored as floats: round to nearest, then clamp. */
    int expert_id = (int)(route[k + slot] + 0.5f);
    expert_id = (expert_id < 0) ? 0 : expert_id;
    expert_id = (expert_id >= n_experts) ? n_experts - 1 : expert_id;
    const float slot_weight = route[slot];

    const int blocks_per_row = cols / BN_QK_K;
    /* Lanes 0-15 start at block 0, lanes 16-31 at block 1; both step by 2. */
    const int first_block = lane_id / 16;
    const int quant_offset = 2 * (lane_id & 15);

    const size_t row_index = (size_t)expert_id * (size_t)hidden + (size_t)row;
    const size_t block_base = row_index * (size_t)blocks_per_row;
    const BnBlockQ4K *gate_row = gate + block_base;
    const BnBlockQ4K *up_row = up + block_base;

    float gate_acc = 0.0f;
    float up_acc = 0.0f;
    int b = first_block;
    while (b < blocks_per_row) {
        const BnCudaBlockQ8_1 *x_blocks = xq + (size_t)b * 8;
        gate_acc += cuda_vec_dot_q4k_q8_1(&gate_row[b], x_blocks, quant_offset);
        up_acc += cuda_vec_dot_q4k_q8_1(&up_row[b], x_blocks, quant_offset);
        b += 2;
    }

    /* Warp reduction; lane 0 ends up holding both complete dot products. */
    for (int delta = 16; delta > 0; delta >>= 1) {
        gate_acc += __shfl_down_sync(0xffffffffu, gate_acc, delta);
        up_acc += __shfl_down_sync(0xffffffffu, up_acc, delta);
    }

    if (lane_id != 0) return;

    const float silu = gate_acc / (1.0f + __expf(-gate_acc));
    const size_t out_index = (size_t)slot * (size_t)hidden + (size_t)row;
    mid[out_index] = slot_weight * silu * up_acc;
}
3494+
3495+
/*
 * Routed MoE down projection, summed over all selected experts.
 *
 * Work decomposition: one warp per output row in [0, dim). Warps past the
 * last row exit as a whole, so the full-mask shuffle below is safe.
 *
 * For each slot in [0, k) the expert id is read from route[k + slot] (stored
 * as a float), rounded and clamped to [0, n_experts - 1]. down is indexed as
 * [expert][row][hidden / BN_QK_K] Q6_K blocks, and mid_q as
 * [slot][hidden / BN_QK_K] Q8_K blocks.
 *
 * The route weights (route[0 .. k)) are not read by this kernel.
 *
 * Output: out[row] is overwritten with the sum of the k per-slot dot products.
 */
static __global__ void moe_q6k_down_routed_q8k_accum_kernel(
    float *out,
    const BnBlockQ6K *down,
    const BnBlockQ8K *mid_q,
    const float *route,
    int dim,
    int hidden,
    int n_experts,
    int k) {
    const int lane_id = threadIdx.x & 31;
    const int warp_id = threadIdx.x >> 5;
    const int block_warps = blockDim.x >> 5;
    const int row = blockIdx.x * block_warps + warp_id;
    if (row >= dim) return;

    const int blocks_per_row = hidden / BN_QK_K;
    float acc = 0.0f;

    int slot = 0;
    while (slot < k) {
        /* Decode the expert id: round to nearest, then clamp into range. */
        int expert_id = (int)(route[k + slot] + 0.5f);
        expert_id = (expert_id < 0) ? 0 : expert_id;
        expert_id = (expert_id >= n_experts) ? n_experts - 1 : expert_id;

        const size_t weight_row = (size_t)expert_id * (size_t)dim + (size_t)row;
        const BnBlockQ6K *w_blocks = down + weight_row * (size_t)blocks_per_row;
        const BnBlockQ8K *x_blocks =
            mid_q + (size_t)slot * (size_t)blocks_per_row;

        /* Each lane handles every 32nd block of the row. */
        for (int b = lane_id; b < blocks_per_row; b += 32) {
            acc += cuda_vec_dot_q6k_q8k(&w_blocks[b], x_blocks + b);
        }
        ++slot;
    }

    /* Warp reduction; lane 0 ends up with the full row sum. */
    for (int delta = 16; delta > 0; delta >>= 1) {
        acc += __shfl_down_sync(0xffffffffu, acc, delta);
    }

    if (lane_id != 0) return;
    out[row] = acc;
}
3528+
34473529
static __device__ __forceinline__ float cuda_fast_exp(float x) {
34483530
x = fminf(88.7f, fmaxf(-87.3f, x));
34493531
float n_f = floorf(x * 1.4426950409f + 0.5f);
@@ -5173,6 +5255,10 @@ static int cuda_init_activations(void *vctx, const void *config_ptr) {
51735255
moe_scratch = c->n_experts;
51745256
if (2 * c->n_experts_active > moe_scratch)
51755257
moe_scratch = 2 * c->n_experts_active;
5258+
if (c->n_experts_active > 0 &&
5259+
c->moe_intermediate_size <= INT_MAX / c->n_experts_active &&
5260+
c->moe_intermediate_size * c->n_experts_active > moe_scratch)
5261+
moe_scratch = c->moe_intermediate_size * c->n_experts_active;
51765262
sizes[BN_GPU_VALUE_MOE_HB] =
51775263
(size_t)moe_scratch * sizeof(float);
51785264
sizes[BN_GPU_VALUE_MOE_HB2] =
@@ -8369,6 +8455,7 @@ static const char *cuda_op_name(int code) {
83698455
case BN_GPU_CODE_SILU_ACT: return "silu_act";
83708456
case BN_GPU_CODE_RELU2_ACT: return "relu2_act";
83718457
case BN_GPU_CODE_MOE_ROUTE_TOPK: return "moe_route_topk";
8458+
case BN_GPU_CODE_MOE_ROUTED_FFN: return "moe_routed_ffn";
83728459
case BN_GPU_CODE_ROPE: return "rope";
83738460
case BN_GPU_CODE_ROPE_QK: return "rope_qk";
83748461
case BN_GPU_CODE_GQA_SCORES: return "gqa_scores";
@@ -8434,6 +8521,7 @@ static int cuda_op_reads_buf(const BnGPUOp *op, int buf) {
84348521
case BN_GPU_CODE_Q5K_MATVEC_SPLIT:
84358522
case BN_GPU_CODE_FUSED_GATEUP_SILU:
84368523
case BN_GPU_CODE_MOE_ROUTE_TOPK:
8524+
case BN_GPU_CODE_MOE_ROUTED_FFN:
84378525
case BN_GPU_CODE_RMSNORM:
84388526
case BN_GPU_CODE_PER_HEAD_RMSNORM:
84398527
case BN_GPU_CODE_COPY:
@@ -9433,6 +9521,65 @@ static int cuda_execute(void *vctx, const void *ops_raw, int n_ops,
94339521
route, logits, n_experts, k);
94349522
break;
94359523
}
9524+
case BN_GPU_CODE_MOE_ROUTED_FFN: {
9525+
BnCudaBuffer *gate = (BnCudaBuffer *)op->W_buf;
9526+
BnCudaBuffer *up = (BnCudaBuffer *)op->W_buf2;
9527+
BnCudaBuffer *down = (BnCudaBuffer *)op->W_buf3;
9528+
float *in = cuda_act(ctx, op->buf_in);
9529+
float *route = cuda_act(ctx, op->buf_aux);
9530+
int mid_buf = (int)op->p[4];
9531+
float *mid = cuda_act(ctx, mid_buf);
9532+
float *out = cuda_act(ctx, op->buf_out);
9533+
int hidden = (int)op->p[0];
9534+
int n_experts = (int)op->p[1];
9535+
int k = (int)op->p[2];
9536+
int down_type = (int)op->p[3];
9537+
int dim = op->cols;
9538+
if (!gate || !gate->data || !up || !up->data ||
9539+
!down || !down->data || !in || !route || !mid || !out ||
9540+
op->type != BN_GGUF_TENSOR_Q4_K ||
9541+
down_type != BN_GGUF_TENSOR_Q6_K ||
9542+
dim <= 0 || hidden <= 0 || n_experts <= 0 || k <= 0 ||
9543+
(dim % BN_QK_K) != 0 || (hidden % BN_QK_K) != 0 ||
9544+
gate->type != BN_GGUF_TENSOR_Q4_K ||
9545+
up->type != BN_GGUF_TENSOR_Q4_K ||
9546+
down->type != BN_GGUF_TENSOR_Q6_K ||
9547+
gate->rows < hidden * n_experts || gate->cols < dim ||
9548+
up->rows < hidden * n_experts || up->cols < dim ||
9549+
down->rows < dim * n_experts || down->cols < hidden ||
9550+
ctx->act_sizes[op->buf_aux] <
9551+
(size_t)(2 * k) * sizeof(float) ||
9552+
ctx->act_sizes[mid_buf] <
9553+
(size_t)k * (size_t)hidden * sizeof(float) ||
9554+
ctx->act_sizes[op->buf_out] < (size_t)dim * sizeof(float))
9555+
return -1;
9556+
if (cuda_ensure_q8_1(ctx, dim) != 0) return -1;
9557+
BnCudaBlockQ8_1 *xq = (BnCudaBlockQ8_1 *)ctx->d_q8_1;
9558+
BN_CUDA_LAUNCH(ctx, quantize_q8_1_kernel,
9559+
(dim + 31) / 32, 32, 0, xq, in, dim);
9560+
{
9561+
int route_threads = 256;
9562+
int warps = route_threads / 32;
9563+
int gateup_tasks = hidden * k;
9564+
int gateup_blocks = (gateup_tasks + warps - 1) / warps;
9565+
BN_CUDA_LAUNCH(ctx, moe_q4k_gateup_routed_mid_kernel,
9566+
gateup_blocks, route_threads, 0,
9567+
mid, (const BnBlockQ4K *)gate->data,
9568+
(const BnBlockQ4K *)up->data, xq, route, hidden, dim,
9569+
n_experts, k);
9570+
if (cuda_ensure_q8_k(ctx, hidden, k) != 0) return -1;
9571+
BnBlockQ8K *mid_q = (BnBlockQ8K *)ctx->d_q8_k;
9572+
BN_CUDA_LAUNCH(ctx, quantize_q8k_batch_kernel,
9573+
dim3(hidden / BN_QK_K, k, 1), BN_QK_K, 0,
9574+
mid_q, mid, hidden, k);
9575+
int down_blocks = (dim + warps - 1) / warps;
9576+
BN_CUDA_LAUNCH(ctx, moe_q6k_down_routed_q8k_accum_kernel,
9577+
down_blocks, route_threads, 0,
9578+
out, (const BnBlockQ6K *)down->data, mid_q, route,
9579+
dim, hidden, n_experts, k);
9580+
}
9581+
break;
9582+
}
94369583
case BN_GPU_CODE_RMSNORM: {
94379584
BnCudaBuffer *w = (BnCudaBuffer *)op->W_buf;
94389585
float *in = cuda_act(ctx, op->buf_in);

src/gpu_shader_ir_internal.h

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -82,14 +82,17 @@ typedef enum {
8282
BN_GPU_CODE_RELU2_ACT,
8383
BN_GPU_CODE_WEIGHTED_ADD_SIGMOID,
8484
BN_GPU_CODE_MOE_ROUTE_TOPK,
85+
BN_GPU_CODE_MOE_ROUTED_FFN,
8586
} BnGPUOpCode;
8687

8788
// A single backend shader command in the lowered forward pass.
8889
typedef struct BnGPUOp {
8990
int op_kind; // BnGPUOpKind semantic op; 0 = infer from op_code
9091
int op_code; // BnGPUOpCode concrete shader operation
9192
int type; // BN_GGUF_TENSOR_* (matvec only, -1 otherwise)
92-
void *W_buf; // weight buffer handle (matvec only, NULL otherwise)
93+
void *W_buf; // primary weight buffer handle
94+
void *W_buf2; // optional secondary weight buffer handle
95+
void *W_buf3; // optional tertiary weight buffer handle
9396
int buf_in; // BN_GPU_VALUE_* primary input
9497
int buf_out; // BN_GPU_VALUE_* output
9598
int buf_aux; // secondary BN_GPU_VALUE_* (-1 if unused)
@@ -134,6 +137,7 @@ static inline BnGPUOpKind bn_gpu_op_kind_from_code(int code) {
134137
return BN_GPU_OP_COPY;
135138
case BN_GPU_CODE_FUSED_GATEUP_SILU:
136139
case BN_GPU_CODE_MOE_ROUTE_TOPK:
140+
case BN_GPU_CODE_MOE_ROUTED_FFN:
137141
return BN_GPU_OP_FFN;
138142
case BN_GPU_CODE_SSM_CONV_SILU:
139143
case BN_GPU_CODE_SSM_L2NORM:

src/model_gpu.c

Lines changed: 81 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,8 +2,10 @@
22
#include "backend_layout.h"
33
#include "backend_model.h"
44
#include "gpu_backend.h"
5+
#include "moe_internal.h"
56
#include <stdlib.h>
67
#include <stdint.h>
8+
#include <limits.h>
79

810
static int checked_mul_size(size_t a, size_t b, size_t *out) {
911
if (a != 0 && b > SIZE_MAX / a) return -1;
@@ -69,6 +71,59 @@ static void *upload_moe_router_diff2(BnGPUBackend *gpu,
6971
return handle;
7072
}
7173

74+
static void *upload_moe_all_proj(BnModel *model,
75+
BnGPUBackend *gpu,
76+
const BnMoEExpertMap *em,
77+
int proj,
78+
int n_experts) {
79+
if (!model || !gpu || !em || n_experts <= 0)
80+
return NULL;
81+
size_t offset = 0;
82+
size_t expert_bytes = 0;
83+
if (bn_moe_proj_info(em, 0, proj, &offset, &expert_bytes) != 0 ||
84+
expert_bytes == 0)
85+
return NULL;
86+
size_t stride = 0;
87+
int type = 0;
88+
int rows = 0;
89+
int cols = 0;
90+
switch (proj) {
91+
case 0:
92+
stride = em->gate_stride ? em->gate_stride : em->expert_gate_bytes;
93+
type = em->gate_type;
94+
rows = em->gate_rows;
95+
cols = em->gate_cols;
96+
break;
97+
case 1:
98+
stride = em->up_stride ? em->up_stride : em->expert_up_bytes;
99+
type = em->up_type;
100+
rows = em->up_rows;
101+
cols = em->up_cols;
102+
break;
103+
case 2:
104+
stride = em->down_stride ? em->down_stride : em->expert_down_bytes;
105+
type = em->down_type;
106+
rows = em->down_rows;
107+
cols = em->down_cols;
108+
break;
109+
default:
110+
return NULL;
111+
}
112+
if (stride != expert_bytes)
113+
return NULL;
114+
const uint8_t *base = bn_moe_mmap_base_for_proj(
115+
bn_model_moe_io(model), em, proj);
116+
if (!base)
117+
return NULL;
118+
size_t total_bytes = 0;
119+
if (checked_mul_size(expert_bytes, (size_t)n_experts, &total_bytes) != 0)
120+
return NULL;
121+
if ((size_t)n_experts > (size_t)INT_MAX / (size_t)rows)
122+
return NULL;
123+
return gpu->buffer_create(gpu->ctx, base + offset, total_bytes,
124+
type, rows * n_experts, cols);
125+
}
126+
72127
int bn_model_upload_weights(BnModel *model, BnGPUBackend *gpu) {
73128
if (!model || !gpu || !gpu->buffer_create) return -1;
74129
if (bn_model_ensure_backend(model) != 0) return -1;
@@ -144,6 +199,20 @@ int bn_model_upload_weights(BnModel *model, BnGPUBackend *gpu) {
144199
(size_t)c->n_experts * (size_t)c->dim * sizeof(float),
145200
BN_GGUF_TENSOR_F32, c->n_experts, c->dim)
146201
: NULL;
202+
int upload_moe_all = lw->moe.router_weight &&
203+
getenv("BN_CUDA_ENABLE_MOE_ROUTED_FFN");
204+
void *moe_gate_all_gpu = upload_moe_all
205+
? upload_moe_all_proj(model, gpu, &lw->moe.expert_map, 0,
206+
c->n_experts)
207+
: NULL;
208+
void *moe_up_all_gpu = upload_moe_all
209+
? upload_moe_all_proj(model, gpu, &lw->moe.expert_map, 1,
210+
c->n_experts)
211+
: NULL;
212+
void *moe_down_all_gpu = upload_moe_all
213+
? upload_moe_all_proj(model, gpu, &lw->moe.expert_map, 2,
214+
c->n_experts)
215+
: NULL;
147216
void *shared_expert_gate_gpu = lw->shared.shared_expert_gate
148217
? gpu->buffer_create(
149218
gpu->ctx, lw->shared.shared_expert_gate,
@@ -158,6 +227,12 @@ int bn_model_upload_weights(BnModel *model, BnGPUBackend *gpu) {
158227
moe_router_diff_gpu) != 0 ||
159228
register_gpu_handle(model, l, BN_BACKEND_HANDLE_MOE_ROUTER,
160229
moe_router_gpu) != 0 ||
230+
register_gpu_handle(model, l, BN_BACKEND_HANDLE_MOE_GATE_ALL,
231+
moe_gate_all_gpu) != 0 ||
232+
register_gpu_handle(model, l, BN_BACKEND_HANDLE_MOE_UP_ALL,
233+
moe_up_all_gpu) != 0 ||
234+
register_gpu_handle(model, l, BN_BACKEND_HANDLE_MOE_DOWN_ALL,
235+
moe_down_all_gpu) != 0 ||
161236
register_gpu_handle(model, l, BN_BACKEND_HANDLE_SHARED_EXPERT_GATE,
162237
shared_expert_gate_gpu) != 0) {
163238
if (attn_norm_gpu) gpu->buffer_destroy(gpu->ctx, attn_norm_gpu);
@@ -166,6 +241,12 @@ int bn_model_upload_weights(BnModel *model, BnGPUBackend *gpu) {
166241
gpu->buffer_destroy(gpu->ctx, moe_router_diff_gpu);
167242
if (moe_router_gpu)
168243
gpu->buffer_destroy(gpu->ctx, moe_router_gpu);
244+
if (moe_gate_all_gpu)
245+
gpu->buffer_destroy(gpu->ctx, moe_gate_all_gpu);
246+
if (moe_up_all_gpu)
247+
gpu->buffer_destroy(gpu->ctx, moe_up_all_gpu);
248+
if (moe_down_all_gpu)
249+
gpu->buffer_destroy(gpu->ctx, moe_down_all_gpu);
169250
if (shared_expert_gate_gpu)
170251
gpu->buffer_destroy(gpu->ctx, shared_expert_gate_gpu);
171252
bn_model_release_gpu(model);

src/transformer/gpu.c

Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -939,6 +939,50 @@ static float *bn_transformer_gpu_forward_impl(BnModel *m, BnSession *sess,
939939
backend, l, BN_BACKEND_HANDLE_MOE_ROUTER);
940940
int gpu_route_topk =
941941
moe_router && !getenv("BN_CUDA_DISABLE_MOE_ROUTER_TOPK");
942+
void *moe_gate_all = bn_backend_model_handle(
943+
backend, l, BN_BACKEND_HANDLE_MOE_GATE_ALL);
944+
void *moe_up_all = bn_backend_model_handle(
945+
backend, l, BN_BACKEND_HANDLE_MOE_UP_ALL);
946+
void *moe_down_all = bn_backend_model_handle(
947+
backend, l, BN_BACKEND_HANDLE_MOE_DOWN_ALL);
948+
int gpu_routed_ffn =
949+
gpu_route_topk && moe_gate_all && moe_up_all && moe_down_all &&
950+
getenv("BN_CUDA_ENABLE_MOE_ROUTED_FFN") &&
951+
!c->has_shared_expert &&
952+
lw->moe.expert_map.gate_type == BN_GGUF_TENSOR_Q4_K &&
953+
lw->moe.expert_map.up_type == BN_GGUF_TENSOR_Q4_K &&
954+
lw->moe.expert_map.down_type == BN_GGUF_TENSOR_Q6_K &&
955+
lw->moe.expert_map.gate_rows == c->moe_intermediate_size &&
956+
lw->moe.expert_map.up_rows == c->moe_intermediate_size &&
957+
lw->moe.expert_map.gate_cols == dim &&
958+
lw->moe.expert_map.up_cols == dim &&
959+
lw->moe.expert_map.down_rows == dim &&
960+
lw->moe.expert_map.down_cols == c->moe_intermediate_size &&
961+
!getenv("BN_CUDA_DISABLE_MOE_ROUTED_FFN");
962+
if (gpu_routed_ffn) {
963+
if (bn_transformer_gpu_emit_context_moe_route_topk(
964+
&emit, moe_router, BN_GPU_VALUE_XB,
965+
BN_GPU_VALUE_MOE_HB, BN_GPU_VALUE_MOE_HB2,
966+
dim, c->n_experts, c->n_experts_active) != 0)
967+
return bn_transformer_gpu_reject_forward(
968+
&emit, "gpu moe route emit failed");
969+
if (bn_transformer_gpu_emit_context_moe_routed_ffn(
970+
&emit, moe_gate_all, moe_up_all, moe_down_all,
971+
BN_GPU_VALUE_XB, BN_GPU_VALUE_MOE_HB2,
972+
BN_GPU_VALUE_MOE_HB, BN_GPU_VALUE_MOE_OUT,
973+
lw->moe.expert_map.gate_type,
974+
lw->moe.expert_map.down_type, dim,
975+
c->moe_intermediate_size, c->n_experts,
976+
c->n_experts_active) != 0)
977+
return bn_transformer_gpu_reject_forward(
978+
&emit, "gpu moe routed ffn emit failed");
979+
bn_transformer_gpu_emit_context_residual_add(
980+
&emit, BN_GPU_VALUE_X, BN_GPU_VALUE_MOE_OUT, dim);
981+
bn_transformer_gpu_emit_context_rmsnorm(
982+
&emit, next_norm, BN_GPU_VALUE_X, BN_GPU_VALUE_XB, dim,
983+
u_eps);
984+
continue;
985+
}
942986
int did_gpu_route_topk = 0;
943987
if (gpu_route_topk) {
944988
if (bn_transformer_gpu_emit_context_moe_route_topk(

0 commit comments

Comments
 (0)