
Commit 757add1

Merge pull request #1456 from kvcache-ai/support-smt-glm4
Support SmallThinker and GLM4-MoE
2 parents 1677e90 + 1334ddc commit 757add1

35 files changed: +3934 −74 lines

.gitignore

Lines changed: 5 additions & 1 deletion
```diff
@@ -26,4 +26,8 @@ ktransformers/tests/chat_txt.txt
 mmlu_result*
 ktransformers/ktransformers_ext/cuda_musa/
 test_prompt.txt
-csrc/demo
+csrc/demo
+build*
+CMakeFiles/
+kvc2/
+sched/
```

README.md

Lines changed: 3 additions & 8 deletions
```diff
@@ -23,19 +23,14 @@ Our vision for KTransformers is to serve as a flexible platform for experimentin
 
 <h2 id="Updates">🔥 Updates</h2>
 
+* **July 26, 2025**: Support SmallThinker and GLM4-MoE. ([Tutorial](./doc/en/SmallThinker_and_Glm4moe.md))
 * **July 11, 2025**: Support Kimi-K2. ([Tutorial](./doc/en/Kimi-K2.md))
-
 * **June 30, 2025**: Support 3-layer (GPU-CPU-Disk) [prefix cache](./doc/en/prefix_cache.md) reuse.
-
 * **May 14, 2025**: Support Intel Arc GPU ([Tutorial](./doc/en/xpu.md)).
-
 * **Apr 29, 2025**: Support AMX-Int8、 AMX-BF16 and Qwen3MoE ([Tutorial](./doc/en/AMX.md))
 
 https://github.com/user-attachments/assets/fafe8aec-4e22-49a8-8553-59fb5c6b00a2
 
-
-
-
 * **Apr 9, 2025**: Experimental support for LLaMA 4 models ([Tutorial](./doc/en/llama4.md)).
 * **Apr 2, 2025**: Support Multi-concurrency. ([Tutorial](./doc/en/balance-serve.md)).
 
@@ -65,7 +60,7 @@ https://github.com/user-attachments/assets/ebd70bfa-b2c1-4abb-ae3b-296ed38aa285
 </p>
 
 - **[NEW!!!] Local 671B DeepSeek-Coder-V3/R1:** Running its Q4_K_M version using only 14GB VRAM and 382GB DRAM([Tutorial](./doc/en/DeepseekR1_V3_tutorial.md)).
-
+
 	- Prefill Speed (tokens/s):
 	- KTransformers: 54.21 (32 cores) → 74.362 (dual-socket, 2×32 cores) → 255.26 (optimized AMX-based MoE kernel, V0.3 only) → 286.55 (selectively using 6 experts, V0.3 only)
 	- Compared to 10.31 tokens/s in llama.cpp with 2×32 cores, achieving up to **27.79× speedup**.
@@ -131,7 +126,6 @@ we have already supported vendors:
 - Kunpeng
 - AMD
 
-
 ### 📥 Installation
 
 To install KTransformers, follow the official [Installation Guide](https://kvcache-ai.github.io/ktransformers/en/install.html).
@@ -201,3 +195,4 @@ If you have any questions, feel free to open an issue. Alternatively, you can jo
 <h2 id="FAQ">🙋 FAQ</h2>
 
 Some common questions are answered in the [FAQ](doc/en/FAQ.md).
+
```

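The **27.79× speedup** quoted in the README context above is simply the best V0.3 prefill figure divided by the llama.cpp baseline:

$$\frac{286.55~\text{tokens/s}}{10.31~\text{tokens/s}} \approx 27.79\times$$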
csrc/balance_serve/CMakeLists.txt

Lines changed: 4 additions & 4 deletions
```diff
@@ -10,10 +10,10 @@ message(STATUS "Using compiler: ${CMAKE_CXX_COMPILER}")
 project(balance_serve VERSION 0.1.0)
 
 set(CMAKE_CXX_STANDARD 20)
-# set(CMAKE_CXX_FLAGS "-Og -march=native -Wall -Wextra -g -fPIC")
-# set(CMAKE_BUILD_TYPE "Debug")
-set(CMAKE_CXX_FLAGS "-O3 -march=native -Wall -Wextra -fPIC")
-set(CMAKE_BUILD_TYPE "Release")
+set(CMAKE_CXX_FLAGS "-Og -march=native -Wall -Wextra -g -fPIC")
+set(CMAKE_BUILD_TYPE "Debug")
+# set(CMAKE_CXX_FLAGS "-O3 -march=native -Wall -Wextra -fPIC")
+# set(CMAKE_BUILD_TYPE "Release")
 
 
 if(NOT DEFINED _GLIBCXX_USE_CXX11_ABI)
```

csrc/balance_serve/sched/model_config.h

Lines changed: 2 additions & 4 deletions
```diff
@@ -15,16 +15,14 @@ using ModelName = std::string;
 class ModelConfig {
 public:
   DimSize hidden_size;
-  DimSize intermediate_size;
   size_t max_position_embeddings;
-  std::string model_type;
   size_t num_attention_heads;
   size_t num_hidden_layers;
   size_t num_key_value_heads;
   size_t vocab_size;
 
-  NLOHMANN_DEFINE_TYPE_INTRUSIVE(ModelConfig, hidden_size, intermediate_size,
-                                 max_position_embeddings, model_type,
+  NLOHMANN_DEFINE_TYPE_INTRUSIVE(ModelConfig, hidden_size,
+                                 max_position_embeddings,
                                  num_attention_heads, num_hidden_layers,
                                  num_key_value_heads, vocab_size);
 
```

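The field removals above also change JSON parsing behaviour: `NLOHMANN_DEFINE_TYPE_INTRUSIVE` generates a `from_json` that requires every listed member to be present, so dropping `intermediate_size` and `model_type` from the macro lets the scheduler load model `config.json` files that omit those keys. A minimal sketch of the effect; the numbers below are illustrative only and do not come from any real model config:

```cpp
#include <nlohmann/json.hpp>
#include "model_config.h"  // csrc/balance_serve/sched/model_config.h

// With the trimmed member list, a config without "intermediate_size" or
// "model_type" now deserializes; previously the generated from_json would
// throw because every listed key is looked up with .at().
ModelConfig load_example_config() {
  return nlohmann::json::parse(R"({
    "hidden_size": 4096,
    "max_position_embeddings": 32768,
    "num_attention_heads": 32,
    "num_hidden_layers": 48,
    "num_key_value_heads": 8,
    "vocab_size": 151936
  })").get<ModelConfig>();
}
```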
csrc/ktransformers_ext/ext_bindings.cpp

Lines changed: 4 additions & 4 deletions
```diff
@@ -683,12 +683,12 @@ PYBIND11_MODULE(cpuinfer_ext, m) {
   py::class_<MOEConfig>(moe_module, "MOEConfig")
       .def(py::init([](int expert_num, int routed_expert_num, int hidden_size,
                        int intermediate_size, int stride, int group_min_len,
-                       int group_max_len, intptr_t gate_proj,
+                       int group_max_len, bool use_silu, intptr_t gate_proj,
                        intptr_t up_proj, intptr_t down_proj, int gate_type,
                        int up_type, int down_type, int hidden_type) {
         return MOEConfig(expert_num, routed_expert_num, hidden_size,
                          intermediate_size, stride, group_min_len,
-                         group_max_len, (void *)gate_proj, (void *)up_proj,
+                         group_max_len, use_silu, (void *)gate_proj, (void *)up_proj,
                          (void *)down_proj, (ggml_type)gate_type,
                          (ggml_type)up_type, (ggml_type)down_type,
                          (ggml_type)hidden_type);
@@ -703,11 +703,11 @@ PYBIND11_MODULE(cpuinfer_ext, m) {
   py::class_<AMX_MOEConfig>(moe_module, "AMX_MOEConfig")
       .def(py::init([](int expert_num, int routed_expert_num, int hidden_size,
                        int intermediate_size,
-                       int max_len, intptr_t gate_proj,
+                       int max_len, bool use_silu, intptr_t gate_proj,
                        intptr_t up_proj, intptr_t down_proj) {
         return AMX_MOEConfig(expert_num, routed_expert_num, hidden_size,
                              intermediate_size,
-                             max_len, (void *)gate_proj,
+                             max_len, use_silu, (void *)gate_proj,
                              (void *)up_proj, (void *)down_proj);
       }));
```

csrc/ktransformers_ext/operators/amx/moe.hpp

Lines changed: 37 additions & 13 deletions
```diff
@@ -69,22 +69,29 @@ static inline __m512 act_fn(__m512 gate_val, __m512 up_val) {
   return _mm512_mul_ps(act_val, up_val);
 }
 
+static inline __m512 relu_act_fn(__m512 gate_val, __m512 up_val) {
+  __m512 zero_vec = _mm512_setzero_ps();
+  __m512 act_val = _mm512_max_ps(zero_vec, gate_val);
+  return _mm512_mul_ps(act_val, up_val);
+}
+
 struct AMX_MOEConfig {
   int expert_num;
   int routed_expert_num;
   int hidden_size;
   int intermediate_size;
   int max_len;
+  bool use_silu;
   void *gate_proj;
   void *up_proj;
   void *down_proj;
 
   AMX_MOEConfig() {}
 
-  AMX_MOEConfig(int expert_num, int routed_expert_num, int hidden_size, int intermediate_size, int max_len,
+  AMX_MOEConfig(int expert_num, int routed_expert_num, int hidden_size, int intermediate_size, int max_len, bool use_silu,
                 void *gate_proj, void *up_proj, void *down_proj)
       : expert_num(expert_num), routed_expert_num(routed_expert_num), hidden_size(hidden_size),
-        intermediate_size(intermediate_size), max_len(max_len), gate_proj(gate_proj), up_proj(up_proj),
+        intermediate_size(intermediate_size), max_len(max_len), use_silu(use_silu), gate_proj(gate_proj), up_proj(up_proj),
         down_proj(down_proj) {}
 };
 
@@ -336,18 +343,35 @@ template <class T> class AMX_MOE {
         gate_bc_[expert_idx]->to_mat(m_local_num_[expert_idx], m_local_gate_output_ptr_[expert_idx], ith, nth);
         up_bc_[expert_idx]->to_mat(m_local_num_[expert_idx], m_local_up_output_ptr_[expert_idx], ith, nth);
         auto [n_start, n_end] = T::split_range_n(config_.intermediate_size, ith, nth);
-        for (int i = 0; i < m_local_num_[expert_idx]; i++) {
-          ggml_bf16_t *gate_output_ptr = &m_local_gate_output_ptr_[expert_idx][i * config_.intermediate_size];
-          ggml_bf16_t *up_output_ptr = &m_local_up_output_ptr_[expert_idx][i * config_.intermediate_size];
-          for (int j = n_start; j < n_end; j += 32) {
-            __m512 gate_val0, gate_val1, up_val0, up_val1;
-            avx512_32xbf16_to_32xfp32((__m512i *)(gate_output_ptr + j), &gate_val0, &gate_val1);
-            avx512_32xbf16_to_32xfp32((__m512i *)(up_output_ptr + j), &up_val0, &up_val1);
-            __m512 result0 = act_fn(gate_val0, up_val0);
-            __m512 result1 = act_fn(gate_val1, up_val1);
-            avx512_32xfp32_to_32xbf16(&result0, &result1, (__m512i *)(gate_output_ptr + j));
-          }
+        if (config_.use_silu) {
+          for (int i = 0; i < m_local_num_[expert_idx]; i++) {
+            ggml_bf16_t *gate_output_ptr = &m_local_gate_output_ptr_[expert_idx][i * config_.intermediate_size];
+            ggml_bf16_t *up_output_ptr = &m_local_up_output_ptr_[expert_idx][i * config_.intermediate_size];
+            for (int j = n_start; j < n_end; j += 32) {
+              __m512 gate_val0, gate_val1, up_val0, up_val1;
+              avx512_32xbf16_to_32xfp32((__m512i *)(gate_output_ptr + j), &gate_val0, &gate_val1);
+              avx512_32xbf16_to_32xfp32((__m512i *)(up_output_ptr + j), &up_val0, &up_val1);
+              __m512 result0 = act_fn(gate_val0, up_val0);
+              __m512 result1 = act_fn(gate_val1, up_val1);
+              avx512_32xfp32_to_32xbf16(&result0, &result1, (__m512i *)(gate_output_ptr + j));
+            }
+          }
+        }
+        else {
+          for (int i = 0; i < m_local_num_[expert_idx]; i++) {
+            ggml_bf16_t *gate_output_ptr = &m_local_gate_output_ptr_[expert_idx][i * config_.intermediate_size];
+            ggml_bf16_t *up_output_ptr = &m_local_up_output_ptr_[expert_idx][i * config_.intermediate_size];
+            for (int j = n_start; j < n_end; j += 32) {
+              __m512 gate_val0, gate_val1, up_val0, up_val1;
+              avx512_32xbf16_to_32xfp32((__m512i *)(gate_output_ptr + j), &gate_val0, &gate_val1);
+              avx512_32xbf16_to_32xfp32((__m512i *)(up_output_ptr + j), &up_val0, &up_val1);
+              __m512 result0 = relu_act_fn(gate_val0, up_val0);
+              __m512 result1 = relu_act_fn(gate_val1, up_val1);
+              avx512_32xfp32_to_32xbf16(&result0, &result1, (__m512i *)(gate_output_ptr + j));
+            }
+          }
         }
+
       },
       nullptr);
   backend->do_work_stealing_job(
```

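Both branches of the new dispatch above compute the same element-wise gated product and differ only in the gate activation. Writing the gate output as $g_j$ and the up-projection output as $u_j$ (matching the scalar `act_fn` / `act_fn_relu` added in `llamafile/moe.cpp` below):

$$
\text{out}_j =
\begin{cases}
\operatorname{SiLU}(g_j)\,u_j = \dfrac{g_j}{1 + e^{-g_j}}\,u_j, & \texttt{use\_silu} = \texttt{true},\\[6pt]
\operatorname{ReLU}(g_j)\,u_j = \max(0,\, g_j)\,u_j, & \texttt{use\_silu} = \texttt{false}.
\end{cases}
$$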
csrc/ktransformers_ext/operators/llamafile/moe.cpp

Lines changed: 27 additions & 4 deletions
```diff
@@ -10,6 +10,7 @@
 #include "moe.h"
 #include <iostream>
 #include <cstdint>
+#include <math.h>
 
 #ifdef USE_NUMA
 #include <numa.h>
@@ -134,6 +135,14 @@ static float act_fn(float x) {
     return x / (1.0f + expf(-x));
 }
 
+static float act_fn_relu(float x) {
+    if(x > 0.0){
+        return x;
+    } else {
+        return 0.0;
+    }
+}
+
 void MOE::forward_one(int k, const uint64_t* expert_ids, const float* weights, const void* input, void* output, Backend* backend) {
     const void* gate_input_ptr;
     const void* up_input_ptr;
@@ -182,8 +191,16 @@ void MOE::forward_one(int k, const uint64_t* expert_ids, const float* weights, c
 
         float* up_output_ptr = s_up_output_[expert_idx] + ith * config_.stride;
         llamafile_sgemm(config_.stride, 1, config_.hidden_size / ggml_blck_size(config_.up_type), up_proj_ptr, config_.hidden_size / ggml_blck_size(config_.up_type), up_input_ptr, config_.hidden_size / ggml_blck_size(config_.up_type), up_output_ptr, config_.stride, 0, 1, GGML_TASK_TYPE_COMPUTE, config_.up_type, ggml_internal_get_type_traits(config_.up_type).vec_dot_type, GGML_TYPE_F32, GGML_PREC_DEFAULT);
-        for (int i = ith * config_.stride; i < (ith + 1) * config_.stride; i++) {
-            s_intermediate_fp32_[expert_idx][i] = act_fn(s_gate_output_[expert_idx][i]) * s_up_output_[expert_idx][i];
+        if(config_.use_silu){
+            // use silu as act fn
+            for (int i = ith * config_.stride; i < (ith + 1) * config_.stride; i++) {
+                s_intermediate_fp32_[expert_idx][i] = act_fn(s_gate_output_[expert_idx][i]) * s_up_output_[expert_idx][i];
+            }
+        } else {
+            // use relu as act fn
+            for (int i = ith * config_.stride; i < (ith + 1) * config_.stride; i++) {
+                s_intermediate_fp32_[expert_idx][i] = act_fn_relu(s_gate_output_[expert_idx][i]) * s_up_output_[expert_idx][i];
+            }
         }
         if (config_.stride % ggml_blck_size(ggml_internal_get_type_traits(config_.down_type).vec_dot_type) == 0) {
             float* intermediate_fp32_ptr = s_intermediate_fp32_[expert_idx] + ith * config_.stride;
@@ -304,8 +321,14 @@ void MOE::forward_many(int qlen, int k, const uint64_t* expert_ids, const float*
         float* up_output_ptr = m_local_up_output_ptr_[expert_idx] + ith * stride;
         llamafile_sgemm(stride, m_local_num_[expert_idx], config_.hidden_size / ggml_blck_size(config_.up_type), up_proj_ptr, config_.hidden_size / ggml_blck_size(config_.up_type), up_input_ptr, config_.hidden_size / ggml_blck_size(config_.up_type), up_output_ptr, config_.intermediate_size, 0, 1, GGML_TASK_TYPE_COMPUTE, config_.up_type, ggml_internal_get_type_traits(config_.up_type).vec_dot_type, GGML_TYPE_F32, GGML_PREC_DEFAULT);
         for (int i = 0; i < m_local_num_[expert_idx]; i++) {
-            for (int j = ith * stride; j < (ith + 1) * stride; j++) {
-                m_local_intermediate_fp32_ptr_[expert_idx][i * config_.intermediate_size + j] = act_fn(m_local_gate_output_ptr_[expert_idx][i * config_.intermediate_size + j]) * m_local_up_output_ptr_[expert_idx][i * config_.intermediate_size + j];
+            if(config_.use_silu){
+                for (int j = ith * stride; j < (ith + 1) * stride; j++) {
+                    m_local_intermediate_fp32_ptr_[expert_idx][i * config_.intermediate_size + j] = act_fn(m_local_gate_output_ptr_[expert_idx][i * config_.intermediate_size + j]) * m_local_up_output_ptr_[expert_idx][i * config_.intermediate_size + j];
+                }
+            } else {
+                for (int j = ith * stride; j < (ith + 1) * stride; j++) {
+                    m_local_intermediate_fp32_ptr_[expert_idx][i * config_.intermediate_size + j] = act_fn_relu(m_local_gate_output_ptr_[expert_idx][i * config_.intermediate_size + j]) * m_local_up_output_ptr_[expert_idx][i * config_.intermediate_size + j];
+                }
             }
             float* intermediate_fp32_ptr = m_local_intermediate_fp32_ptr_[expert_idx] + i * config_.intermediate_size + ith * stride;
             void* down_input_ptr = m_local_down_input_ptr_[expert_idx] + i * config_.intermediate_size * ggml_type_size(ggml_internal_get_type_traits(config_.down_type).vec_dot_type) / ggml_blck_size(ggml_internal_get_type_traits(config_.down_type).vec_dot_type) + ith * stride * ggml_type_size(ggml_internal_get_type_traits(config_.down_type).vec_dot_type) / ggml_blck_size(ggml_internal_get_type_traits(config_.down_type).vec_dot_type);
```

csrc/ktransformers_ext/operators/llamafile/moe.h

Lines changed: 3 additions & 2 deletions
```diff
@@ -32,6 +32,7 @@ struct MOEConfig {
     int stride;
     int group_min_len;
     int group_max_len;
+    bool use_silu;
     void* gate_proj;
     void* up_proj;
     void* down_proj;
@@ -42,8 +43,8 @@ struct MOEConfig {
 
     MOEConfig() {}
 
-    MOEConfig(int expert_num, int routed_expert_num, int hidden_size, int intermediate_size, int stride, int group_min_len, int group_max_len, void* gate_proj, void* up_proj, void* down_proj, ggml_type gate_type, ggml_type up_type, ggml_type down_type, ggml_type hidden_type)
-        : expert_num(expert_num), routed_expert_num(routed_expert_num), hidden_size(hidden_size), intermediate_size(intermediate_size), stride(stride), group_min_len(group_min_len), group_max_len(group_max_len), gate_proj(gate_proj), up_proj(up_proj), down_proj(down_proj), gate_type(gate_type), up_type(up_type), down_type(down_type), hidden_type(hidden_type) {}
+    MOEConfig(int expert_num, int routed_expert_num, int hidden_size, int intermediate_size, int stride, int group_min_len, int group_max_len, bool use_silu, void* gate_proj, void* up_proj, void* down_proj, ggml_type gate_type, ggml_type up_type, ggml_type down_type, ggml_type hidden_type)
+        : expert_num(expert_num), routed_expert_num(routed_expert_num), hidden_size(hidden_size), intermediate_size(intermediate_size), stride(stride), group_min_len(group_min_len), group_max_len(group_max_len), use_silu(use_silu), gate_proj(gate_proj), up_proj(up_proj), down_proj(down_proj), gate_type(gate_type), up_type(up_type), down_type(down_type), hidden_type(hidden_type) {}
 };
 
 class MOE {
```

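For illustration, a minimal sketch of how the extended `MOEConfig` constructor might be populated for a ReLU-gated model such as the ones this PR targets; every size, pointer, and the `GGML_TYPE_F32` choices below are placeholders rather than values taken from this commit:

```cpp
#include "moe.h"  // csrc/ktransformers_ext/operators/llamafile/moe.h

// Hypothetical expert-weight buffers; real callers pass pointers to the
// packed gate/up/down projection weights (forwarded as intptr_t through
// the cpuinfer_ext bindings shown above).
void* gate_proj = nullptr;
void* up_proj   = nullptr;
void* down_proj = nullptr;

// use_silu = false selects the new ReLU gate path; true keeps the SiLU gate.
MOEConfig cfg(/*expert_num=*/32, /*routed_expert_num=*/4,
              /*hidden_size=*/4096, /*intermediate_size=*/1408,
              /*stride=*/16, /*group_min_len=*/10, /*group_max_len=*/1024,
              /*use_silu=*/false,
              gate_proj, up_proj, down_proj,
              /*gate_type=*/GGML_TYPE_F32, /*up_type=*/GGML_TYPE_F32,
              /*down_type=*/GGML_TYPE_F32, /*hidden_type=*/GGML_TYPE_F32);
```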
doc/en/SmallThinker_and_Glm4moe.md

Lines changed: 117 additions & 0 deletions
# SmallThinker & GLM-4-MoE Support for KTransformers

## Introduction

### Overview
We are excited to announce that **KTransformers now supports both SmallThinker and GLM-4-MoE**.

- **SmallThinker-21B (bf16)**: ~26 TPS **on a dual-socket CPU with one consumer-grade GPU**, requiring ~84 GB DRAM.
- **GLM-4-MoE 110B (bf16)**: ~11 TPS **on a dual-socket CPU with one consumer-grade GPU**, requiring ~440 GB DRAM.
- **GLM-4-MoE 110B (AMX INT8)**: prefill ~309 TPS / decode ~16 TPS **on a dual-socket CPU with one consumer-grade GPU**, requiring ~220 GB DRAM.

### Model & Resource Links
- **SmallThinker-21B**
  - *(to be announced)*
- **GLM-4-MoE 110B**
  - *(to be announced)*

---

## Installation Guide

### 1. Resource Requirements

| Model                     | Precision | Experts | DRAM Needed | GPU Memory Needed\* | TPS (approx.) |
| ------------------------- | --------- | ------- | ----------- | ------------------- | ------------- |
| SmallThinker-21B          | bf16      | 32      | \~42 GB     | 14 GB               | \~26 TPS      |
| GLM-4-MoE 110B            | bf16      | 128     | \~220 GB    | 14 GB               | \~11 TPS      |
| GLM-4-MoE 110B (AMX INT8) | int8      | 128     | \~220 GB    | 14 GB               | \~16 TPS      |

\* Exact GPU memory depends on sequence length, batch size, and kernels used.

### 2. Prepare Models

```bash
# Example: download original safetensors (adjust to your paths/repos)
# (Fill in actual repos/filenames yourself)

# SmallThinker-21B
huggingface-cli download --resume-download placeholder-org/Model-TBA \
  --local-dir ./Model-TBA

# GLM-4-MoE 110B
huggingface-cli download --resume-download placeholder-org/Model-TBA \
  --local-dir ./Model-TBA
```

### 3. Install KTransformers

Follow the official [Installation Guide](https://kvcache-ai.github.io/ktransformers/en/install.html).

```bash
pip install ktransformers  # or from source if you need bleeding-edge features
```

### 4. Run SmallThinker-21B Inference Server

```bash
python ktransformers/server/main.py \
  --port 10021 \
  --model_path /abs/path/to/SmallThinker-21B-bf16 \
  --model_name SmallThinkerForCausalLM \
  --optimize_config_path ktransformers/optimize/optimize_rules/SmallThinker-serve.yaml \
  --max_new_tokens 1024 \
  --cache_lens 32768 \
  --chunk_size 256 \
  --max_batch_size 4 \
  --backend_type balance_serve
```

### 5. Run GLM-4-MoE 110B Inference Server

```bash
python ktransformers/server/main.py \
  --port 10110 \
  --model_name Glm4MoeForCausalLM \
  --model_path /abs/path/to/GLM-4-MoE-110B-bf16 \
  --optimize_config_path ktransformers/optimize/optimize_rules/Glm4Moe-serve.yaml \
  --max_new_tokens 1024 \
  --cache_lens 32768 \
  --chunk_size 256 \
  --max_batch_size 4 \
  --backend_type balance_serve
```

### 6. Access Server

```bash
curl -X POST http://localhost:10021/v1/chat/completions \
  -H "accept: application/json" \
  -H "Content-Type: application/json" \
  -d '{
    "messages": [
      {"role": "user", "content": "hello"}
    ],
    "model": "SmallThinker-21B",
    "temperature": 0.3,
    "top_p": 1.0,
    "stream": true
  }'
```

```bash
curl -X POST http://localhost:10110/v1/chat/completions \
  -H "accept: application/json" \
  -H "Content-Type: application/json" \
  -d '{
    "messages": [
      {"role": "user", "content": "hello"}
    ],
    "model": "GLM-4-MoE-110B",
    "temperature": 0.3,
    "top_p": 1.0,
    "stream": true
  }'
```
