|
16 | 16 | #include <stdlib.h> |
17 | 17 | #include <string.h> |
18 | 18 | #include <time.h> |
| 19 | +#include <limits.h> |
19 | 20 |
|
20 | 21 | typedef struct { |
21 | 22 | void *data; |
@@ -3444,6 +3445,87 @@ static __global__ void moe_route_topk_kernel(float *route, |
3444 | 3445 | route[k + i] = (float)selected[i]; |
3445 | 3446 | } |
3446 | 3447 |
|
/*
 * Routed MoE gate/up projection fused with SiLU gating.
 *
 * Work split: one warp per work item, where
 *   item = blockIdx.x * (blockDim.x / 32) + warp index, item < hidden * k.
 * item / hidden is the routing slot and the remainder is the row inside the
 * expert chosen for that slot.
 *
 * route is read as two consecutive runs of k floats: route[slot] is the
 * slot's weight and route[k + slot] is the expert id stored as a float. The
 * id is converted with (int)(value + 0.5f) and clamped to
 * [0, n_experts - 1].
 *
 * Each weight row holds cols / BN_QK_K Q4_K blocks. Lanes 0-15 visit the
 * even-numbered blocks and lanes 16-31 the odd-numbered ones; every lane
 * passes its own sub-offset (2 * (lane & 15)) to cuda_vec_dot_q4k_q8_1.
 * The quantized input is advanced by 8 Q8_1 blocks per Q4_K block.
 *
 * Lane 0 stores route_weight * silu(gate . x) * (up . x) into
 * mid[slot * hidden + row].
 *
 * The warp reductions use the full 0xffffffff mask, so blockDim.x is
 * expected to be a multiple of 32 (the exit test is uniform per warp).
 */
static __global__ void moe_q4k_gateup_routed_mid_kernel(
    float *mid,
    const BnBlockQ4K *gate,
    const BnBlockQ4K *up,
    const BnCudaBlockQ8_1 *xq,
    const float *route,
    int hidden,
    int cols,
    int n_experts,
    int k) {
    const int lane_id = threadIdx.x & 31;
    const int warp_id = threadIdx.x >> 5;
    const int item = blockIdx.x * (blockDim.x >> 5) + warp_id;
    if (item >= hidden * k) return;

    /* Decode (slot, row) from the flat work-item index. */
    const int slot = item / hidden;
    const int row = item - slot * hidden;

    /* Expert ids travel as floats; convert and clamp to a valid index. */
    int expert = (int)(route[k + slot] + 0.5f);
    expert = (expert < 0) ? 0 : expert;
    expert = (expert >= n_experts) ? n_experts - 1 : expert;
    const float route_weight = route[slot];

    const int blocks_per_row = cols / BN_QK_K;
    const int sub_offset = 2 * (lane_id & 15);
    const size_t row_index = (size_t)expert * (size_t)hidden + (size_t)row;
    const BnBlockQ4K *gate_row = gate + row_index * (size_t)blocks_per_row;
    const BnBlockQ4K *up_row = up + row_index * (size_t)blocks_per_row;

    /* Half-warps interleave over the blocks of the row (stride 2). */
    float gate_acc = 0.0f;
    float up_acc = 0.0f;
    for (int blk = lane_id / 16; blk < blocks_per_row; blk += 2) {
        const BnCudaBlockQ8_1 *x_blk = xq + (size_t)blk * 8;
        gate_acc += cuda_vec_dot_q4k_q8_1(gate_row + blk, x_blk, sub_offset);
        up_acc += cuda_vec_dot_q4k_q8_1(up_row + blk, x_blk, sub_offset);
    }

    /* Fold the 32 per-lane partial sums down into lane 0. */
    for (int delta = 16; delta > 0; delta >>= 1) {
        gate_acc += __shfl_down_sync(0xffffffffu, gate_acc, delta);
        up_acc += __shfl_down_sync(0xffffffffu, up_acc, delta);
    }

    if (lane_id != 0) return;

    const float silu = gate_acc / (1.0f + __expf(-gate_acc));
    mid[(size_t)slot * (size_t)hidden + (size_t)row] =
        route_weight * silu * up_acc;
}
| 3494 | + |
/*
 * Routed MoE down projection over Q6_K weights and Q8_K activations.
 *
 * Work split: one warp per output row, where
 *   row = blockIdx.x * (blockDim.x / 32) + warp index, row < dim.
 *
 * For each of the k routing slots the expert id is read from
 * route[k + slot] (stored as a float, converted with (int)(value + 0.5f)
 * and clamped to [0, n_experts - 1]). The warp then dots that expert's
 * weight row against the slot's quantized activations, mid_q advanced by
 * slot * (hidden / BN_QK_K) blocks. Lanes stride over the blocks of the row
 * in steps of 32, and all slots are summed into one running total.
 *
 * route[0 .. k) is not read by this kernel.
 *
 * Lane 0 assigns the reduced total to out[row]; the previous contents of
 * out[row] are overwritten, not added to.
 *
 * The warp reduction uses the full 0xffffffff mask, so blockDim.x is
 * expected to be a multiple of 32 (the exit test is uniform per warp).
 */
static __global__ void moe_q6k_down_routed_q8k_accum_kernel(
    float *out,
    const BnBlockQ6K *down,
    const BnBlockQ8K *mid_q,
    const float *route,
    int dim,
    int hidden,
    int n_experts,
    int k) {
    const int lane_id = threadIdx.x & 31;
    const int warp_id = threadIdx.x >> 5;
    const int row = blockIdx.x * (blockDim.x >> 5) + warp_id;
    if (row >= dim) return;

    const int blocks_per_row = hidden / BN_QK_K;
    float acc = 0.0f;

    int slot = 0;
    while (slot < k) {
        /* Expert ids travel as floats; convert and clamp to a valid index. */
        int expert = (int)(route[k + slot] + 0.5f);
        expert = (expert < 0) ? 0 : expert;
        expert = (expert >= n_experts) ? n_experts - 1 : expert;

        const size_t weight_row = (size_t)expert * (size_t)dim + (size_t)row;
        const BnBlockQ6K *w_row = down + weight_row * (size_t)blocks_per_row;
        const BnBlockQ8K *act_row =
            mid_q + (size_t)slot * (size_t)blocks_per_row;

        /* Each lane takes every 32nd block of the row. */
        for (int blk = lane_id; blk < blocks_per_row; blk += 32) {
            acc += cuda_vec_dot_q6k_q8k(w_row + blk, act_row + blk);
        }
        ++slot;
    }

    /* Fold the 32 per-lane partial sums down into lane 0. */
    for (int delta = 16; delta > 0; delta >>= 1) {
        acc += __shfl_down_sync(0xffffffffu, acc, delta);
    }

    if (lane_id == 0) {
        out[row] = acc;
    }
}
| 3528 | + |
3447 | 3529 | static __device__ __forceinline__ float cuda_fast_exp(float x) { |
3448 | 3530 | x = fminf(88.7f, fmaxf(-87.3f, x)); |
3449 | 3531 | float n_f = floorf(x * 1.4426950409f + 0.5f); |
@@ -5173,6 +5255,10 @@ static int cuda_init_activations(void *vctx, const void *config_ptr) { |
5173 | 5255 | moe_scratch = c->n_experts; |
5174 | 5256 | if (2 * c->n_experts_active > moe_scratch) |
5175 | 5257 | moe_scratch = 2 * c->n_experts_active; |
| 5258 | + if (c->n_experts_active > 0 && |
| 5259 | + c->moe_intermediate_size <= INT_MAX / c->n_experts_active && |
| 5260 | + c->moe_intermediate_size * c->n_experts_active > moe_scratch) |
| 5261 | + moe_scratch = c->moe_intermediate_size * c->n_experts_active; |
5176 | 5262 | sizes[BN_GPU_VALUE_MOE_HB] = |
5177 | 5263 | (size_t)moe_scratch * sizeof(float); |
5178 | 5264 | sizes[BN_GPU_VALUE_MOE_HB2] = |
@@ -8369,6 +8455,7 @@ static const char *cuda_op_name(int code) { |
8369 | 8455 | case BN_GPU_CODE_SILU_ACT: return "silu_act"; |
8370 | 8456 | case BN_GPU_CODE_RELU2_ACT: return "relu2_act"; |
8371 | 8457 | case BN_GPU_CODE_MOE_ROUTE_TOPK: return "moe_route_topk"; |
| 8458 | + case BN_GPU_CODE_MOE_ROUTED_FFN: return "moe_routed_ffn"; |
8372 | 8459 | case BN_GPU_CODE_ROPE: return "rope"; |
8373 | 8460 | case BN_GPU_CODE_ROPE_QK: return "rope_qk"; |
8374 | 8461 | case BN_GPU_CODE_GQA_SCORES: return "gqa_scores"; |
@@ -8434,6 +8521,7 @@ static int cuda_op_reads_buf(const BnGPUOp *op, int buf) { |
8434 | 8521 | case BN_GPU_CODE_Q5K_MATVEC_SPLIT: |
8435 | 8522 | case BN_GPU_CODE_FUSED_GATEUP_SILU: |
8436 | 8523 | case BN_GPU_CODE_MOE_ROUTE_TOPK: |
| 8524 | + case BN_GPU_CODE_MOE_ROUTED_FFN: |
8437 | 8525 | case BN_GPU_CODE_RMSNORM: |
8438 | 8526 | case BN_GPU_CODE_PER_HEAD_RMSNORM: |
8439 | 8527 | case BN_GPU_CODE_COPY: |
@@ -9433,6 +9521,65 @@ static int cuda_execute(void *vctx, const void *ops_raw, int n_ops, |
9433 | 9521 | route, logits, n_experts, k); |
9434 | 9522 | break; |
9435 | 9523 | } |
| 9524 | + case BN_GPU_CODE_MOE_ROUTED_FFN: { |
| 9525 | + BnCudaBuffer *gate = (BnCudaBuffer *)op->W_buf; |
| 9526 | + BnCudaBuffer *up = (BnCudaBuffer *)op->W_buf2; |
| 9527 | + BnCudaBuffer *down = (BnCudaBuffer *)op->W_buf3; |
| 9528 | + float *in = cuda_act(ctx, op->buf_in); |
| 9529 | + float *route = cuda_act(ctx, op->buf_aux); |
| 9530 | + int mid_buf = (int)op->p[4]; |
| 9531 | + float *mid = cuda_act(ctx, mid_buf); |
| 9532 | + float *out = cuda_act(ctx, op->buf_out); |
| 9533 | + int hidden = (int)op->p[0]; |
| 9534 | + int n_experts = (int)op->p[1]; |
| 9535 | + int k = (int)op->p[2]; |
| 9536 | + int down_type = (int)op->p[3]; |
| 9537 | + int dim = op->cols; |
| 9538 | + if (!gate || !gate->data || !up || !up->data || |
| 9539 | + !down || !down->data || !in || !route || !mid || !out || |
| 9540 | + op->type != BN_GGUF_TENSOR_Q4_K || |
| 9541 | + down_type != BN_GGUF_TENSOR_Q6_K || |
| 9542 | + dim <= 0 || hidden <= 0 || n_experts <= 0 || k <= 0 || |
| 9543 | + (dim % BN_QK_K) != 0 || (hidden % BN_QK_K) != 0 || |
| 9544 | + gate->type != BN_GGUF_TENSOR_Q4_K || |
| 9545 | + up->type != BN_GGUF_TENSOR_Q4_K || |
| 9546 | + down->type != BN_GGUF_TENSOR_Q6_K || |
| 9547 | + gate->rows < hidden * n_experts || gate->cols < dim || |
| 9548 | + up->rows < hidden * n_experts || up->cols < dim || |
| 9549 | + down->rows < dim * n_experts || down->cols < hidden || |
| 9550 | + ctx->act_sizes[op->buf_aux] < |
| 9551 | + (size_t)(2 * k) * sizeof(float) || |
| 9552 | + ctx->act_sizes[mid_buf] < |
| 9553 | + (size_t)k * (size_t)hidden * sizeof(float) || |
| 9554 | + ctx->act_sizes[op->buf_out] < (size_t)dim * sizeof(float)) |
| 9555 | + return -1; |
| 9556 | + if (cuda_ensure_q8_1(ctx, dim) != 0) return -1; |
| 9557 | + BnCudaBlockQ8_1 *xq = (BnCudaBlockQ8_1 *)ctx->d_q8_1; |
| 9558 | + BN_CUDA_LAUNCH(ctx, quantize_q8_1_kernel, |
| 9559 | + (dim + 31) / 32, 32, 0, xq, in, dim); |
| 9560 | + { |
| 9561 | + int route_threads = 256; |
| 9562 | + int warps = route_threads / 32; |
| 9563 | + int gateup_tasks = hidden * k; |
| 9564 | + int gateup_blocks = (gateup_tasks + warps - 1) / warps; |
| 9565 | + BN_CUDA_LAUNCH(ctx, moe_q4k_gateup_routed_mid_kernel, |
| 9566 | + gateup_blocks, route_threads, 0, |
| 9567 | + mid, (const BnBlockQ4K *)gate->data, |
| 9568 | + (const BnBlockQ4K *)up->data, xq, route, hidden, dim, |
| 9569 | + n_experts, k); |
| 9570 | + if (cuda_ensure_q8_k(ctx, hidden, k) != 0) return -1; |
| 9571 | + BnBlockQ8K *mid_q = (BnBlockQ8K *)ctx->d_q8_k; |
| 9572 | + BN_CUDA_LAUNCH(ctx, quantize_q8k_batch_kernel, |
| 9573 | + dim3(hidden / BN_QK_K, k, 1), BN_QK_K, 0, |
| 9574 | + mid_q, mid, hidden, k); |
| 9575 | + int down_blocks = (dim + warps - 1) / warps; |
| 9576 | + BN_CUDA_LAUNCH(ctx, moe_q6k_down_routed_q8k_accum_kernel, |
| 9577 | + down_blocks, route_threads, 0, |
| 9578 | + out, (const BnBlockQ6K *)down->data, mid_q, route, |
| 9579 | + dim, hidden, n_experts, k); |
| 9580 | + } |
| 9581 | + break; |
| 9582 | + } |
9436 | 9583 | case BN_GPU_CODE_RMSNORM: { |
9437 | 9584 | BnCudaBuffer *w = (BnCudaBuffer *)op->W_buf; |
9438 | 9585 | float *in = cuda_act(ctx, op->buf_in); |
|
0 commit comments