Skip to content

Commit 2cc622d

Browse files
committed
Stack CUDA shared MoE gate-up
1 parent fb9c252 commit 2cc622d

5 files changed

Lines changed: 42 additions & 14 deletions

File tree

include/backend_model.h

Lines changed: 1 addition & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -36,6 +36,7 @@ typedef enum {
3636
BN_BACKEND_HANDLE_FFN_DOWN_PREFILL = 23,
3737
BN_BACKEND_HANDLE_SHARED_EXPERT_GATE = 24,
3838
BN_BACKEND_HANDLE_MOE_ROUTER_DIFF = 25,
39+
BN_BACKEND_HANDLE_SHARED_GATEUP_STACKED = 26,
3940
} BnBackendHandleRole;
4041

4142
BnBackendModel *bn_backend_model_create(void);

src/model_gpu.c

Lines changed: 12 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -238,6 +238,18 @@ int bn_model_upload_weights(BnModel *model, BnGPUBackend *gpu) {
238238
return -1;
239239
}
240240

241+
void *shared_gateup_stacked_gpu =
242+
bn_backend_layout_upload_stacked2(
243+
gpu, &lw->shared.shared_gate, &lw->shared.shared_up);
244+
if (register_gpu_handle(model, l,
245+
BN_BACKEND_HANDLE_SHARED_GATEUP_STACKED,
246+
shared_gateup_stacked_gpu) != 0) {
247+
if (shared_gateup_stacked_gpu)
248+
gpu->buffer_destroy(gpu->ctx, shared_gateup_stacked_gpu);
249+
bn_model_release_gpu(model);
250+
return -1;
251+
}
252+
241253
void *ssm_qkvz_stacked_gpu =
242254
bn_backend_layout_upload_stacked2(gpu, &lw->ssm.wqkv, &lw->ssm.wz);
243255
if (register_gpu_handle(model, l, BN_BACKEND_HANDLE_SSM_QKVZ_STACKED,

src/transformer/gpu_emit.c

Lines changed: 26 additions & 14 deletions
Original file line number | Diff line number | Diff line change
@@ -1560,20 +1560,32 @@ void bn_transformer_gpu_emit_context_moe(BnTransformerGPUEmitContext *ctx,
15601560
}
15611561

15621562
if (lw->shared.shared_gate.data && shared && shared->shared_gate) {
1563-
uint32_t shared_gate_flags =
1564-
lw->shared.shared_gate.type == BN_GGUF_TENSOR_Q4_K ? 1u : 0u;
1565-
uint32_t shared_up_flags =
1566-
lw->shared.shared_up.type == BN_GGUF_TENSOR_Q4_K ? 1u : 0u;
1567-
emit_context_matvec_flags(
1568-
ctx, lw->shared.shared_gate.type,
1569-
shared->shared_gate,
1570-
BN_GPU_VALUE_XB, BN_GPU_VALUE_HB, lw->shared.shared_gate.rows,
1571-
lw->shared.shared_gate.cols, 0, shared_gate_flags);
1572-
emit_context_matvec_flags(
1573-
ctx, lw->shared.shared_up.type,
1574-
shared->shared_up,
1575-
BN_GPU_VALUE_XB, BN_GPU_VALUE_HB2, lw->shared.shared_up.rows,
1576-
lw->shared.shared_up.cols, 0, shared_up_flags);
1563+
if (shared->shared_gateup_stacked) {
1564+
emit_context_matvec_split(
1565+
ctx, lw->shared.shared_gate.type,
1566+
shared->shared_gateup_stacked,
1567+
BN_GPU_VALUE_XB, BN_GPU_VALUE_HB, BN_GPU_VALUE_HB2, -1,
1568+
lw->shared.shared_gate.rows + lw->shared.shared_up.rows,
1569+
lw->shared.shared_gate.cols, lw->shared.shared_gate.rows,
1570+
0, 0, 0, 0);
1571+
} else {
1572+
uint32_t shared_gate_flags =
1573+
lw->shared.shared_gate.type == BN_GGUF_TENSOR_Q4_K ? 1u : 0u;
1574+
uint32_t shared_up_flags =
1575+
lw->shared.shared_up.type == BN_GGUF_TENSOR_Q4_K ? 1u : 0u;
1576+
emit_context_matvec_flags(
1577+
ctx, lw->shared.shared_gate.type,
1578+
shared->shared_gate,
1579+
BN_GPU_VALUE_XB, BN_GPU_VALUE_HB,
1580+
lw->shared.shared_gate.rows, lw->shared.shared_gate.cols, 0,
1581+
shared_gate_flags);
1582+
emit_context_matvec_flags(
1583+
ctx, lw->shared.shared_up.type,
1584+
shared->shared_up,
1585+
BN_GPU_VALUE_XB, BN_GPU_VALUE_HB2,
1586+
lw->shared.shared_up.rows, lw->shared.shared_up.cols, 0,
1587+
shared_up_flags);
1588+
}
15771589
bn_transformer_gpu_emit_context_activation(
15781590
ctx, BN_GPU_VALUE_HB, BN_GPU_VALUE_HB2,
15791591
lw->shared.shared_gate.rows, 0, BN_GPU_IR_ACTIVATION_SILU);

src/transformer/gpu_internal.h

Lines changed: 1 addition & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -77,6 +77,7 @@ typedef struct {
7777
void *shared_up;
7878
void *shared_down;
7979
void *shared_expert_gate;
80+
void *shared_gateup_stacked;
8081
} BnTransformerGPUMoESharedResources;
8182

8283
typedef struct {

src/transformer/gpu_resources.c

Lines changed: 2 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -177,5 +177,7 @@ bn_transformer_gpu_resolve_moe_shared_resources(
177177
.shared_down = qweight_backend_buf(backend, &lw->shared.shared_down),
178178
.shared_expert_gate = backend_handle_or(
179179
backend, layer, BN_BACKEND_HANDLE_SHARED_EXPERT_GATE),
180+
.shared_gateup_stacked = backend_handle_or(
181+
backend, layer, BN_BACKEND_HANDLE_SHARED_GATEUP_STACKED),
180182
};
181183
}

0 commit comments

Comments
 (0)