Changes from all commits (35 commits)
f9c069c
Modularize fused experts and integrate PPLX kernels (#15956)
bnellnm May 14, 2025
8568650
[CI] Disable Failing Tests (#18165)
robertgshaw2-redhat May 14, 2025
749f792
[Frontend] decrease import time of vllm.multimodal (#18031)
davidxia May 14, 2025
d93c976
[Kernel] Have rotary embeddings support tensors (#18046)
LucasWilkinson May 14, 2025
2fc9075
[V1] Structured Outputs + Thinking compatibility (#16577)
aarnphm May 14, 2025
7974736
Add support for loading torchao models with `AOPerModuleConfig` (#17826)
jerryzh168 May 14, 2025
78aa341
[CI] Fix race condition in test_kv_cache_events test (#18169)
russellb May 14, 2025
2142035
[V1] Support multiple kv connectors (#17564)
mgoin May 14, 2025
09f106a
Upload vllm index for the rc builds (#18173)
atalman May 14, 2025
f25e0d1
[Bugfix]: make most of `test_openai_schema.py` pass (#17664)
davidxia May 15, 2025
e60f550
[v1] Support multiple KV cache groups in GPU model runner (#17945)
heheda12345 May 15, 2025
65334ef
[V1][Metrics] Remove unused code (#18158)
markmc May 15, 2025
afe3236
[Chore] astral's ty (#18116)
aarnphm May 15, 2025
2dff093
[Misc] add lobe-chat support (#18177)
reidliu41 May 15, 2025
83f74c6
[Fix][ROCm] Enforce eager for all encoder-decoder models on ROCm (#18…
ProExpertProg May 15, 2025
26d0419
Update deprecated type hinting in `models` (#18132)
hmellor May 15, 2025
e6b8e65
[Bugfix] Fix fp8 tests for triton_unified_attention for Triton 3.3 (#…
tdoublep May 15, 2025
4f07a64
Support custom implementations of VideoLoader backends. (#18091)
huachenheli May 15, 2025
420caf7
[UT] Add ut for none hash (#17892)
andyxning May 15, 2025
dd2a945
[Model] Allow the use of sliding window in Qwen2 (#17772)
inkcherry May 15, 2025
70f8b96
[Bugfix] Fix FusedMoEPrepareAndFinalize for cuda-disalike backends (#…
MengqingCao May 15, 2025
de71fec
[CI] don't skip fixed `test_kv_cache_events()` (#18183)
davidxia May 15, 2025
a8f5aec
[V1] Update zmq socket creation in nixl connector (#18148)
russellb May 15, 2025
a9944aa
fix: typos (#18151)
omahs May 15, 2025
07ad271
Update deprecated type hinting in `model_loader` (#18130)
hmellor May 15, 2025
451da4b
add tools into TokenizeChatRequest (#18187)
hustxiayang May 15, 2025
01c2233
[Kernel] [V1] Fix performance regression for triton unified attention…
tdoublep May 15, 2025
566ec04
Adding "Basic Models Test" and "Multi-Modal Models Test (Extended) 3"…
Alexei-V-Ivanov-AMD May 15, 2025
51ff154
Improve examples rendering in docs and GitHub (#18203)
hmellor May 15, 2025
2aa5470
[Frontend] Fix chat template content format detection (#18190)
schoennenbeck May 15, 2025
fadb8d5
[Bugfix]Change the exception thrown by call_hf_processor from Runtime…
Abatom May 15, 2025
9254052
[Bugfix] [ROCm]: Remove assertion logic when using AITER fused moe in…
tjtanaa May 15, 2025
e3f3aee
[Misc] Avoid cuda graph log when sizes still match (#18202)
NickLucche May 15, 2025
61c0b12
triton kernel fusion for EAGLE
leo-cf-tian May 15, 2025
c89f9ca
include all state updates
leo-cf-tian May 15, 2025
8 changes: 8 additions & 0 deletions .buildkite/scripts/hardware_ci/run-amd-test.sh
@@ -82,6 +82,14 @@ if [[ $commands == *"pytest -v -s basic_correctness/test_basic_correctness.py"*
commands=${commands//"pytest -v -s basic_correctness/test_basic_correctness.py"/"VLLM_USE_TRITON_FLASH_ATTN=0 pytest -v -s basic_correctness/test_basic_correctness.py"}
fi

if [[ $commands == *"pytest -v -s models/test_registry.py"* ]]; then
commands=${commands//"pytest -v -s models/test_registry.py"/"pytest -v -s models/test_registry.py -k 'not BambaForCausalLM and not GritLM and not Mamba2ForCausalLM and not Zamba2ForCausalLM'"}
fi

if [[ $commands == *"VLLM_USE_V1=0 pytest -v -s models/test_initialization.py -k 'not llama4 and not plamo2'"* ]]; then
commands=${commands//"VLLM_USE_V1=0 pytest -v -s models/test_initialization.py -k 'not llama4 and not plamo2'"/"VLLM_USE_V1=0 pytest -v -s models/test_initialization.py -k 'not llama4 and not plamo2 and not BambaForCausalLM and not Gemma2ForCausalLM and not Grok1ModelForCausalLM and not Zamba2ForCausalLM and not Gemma2Model and not GritLM'"}
fi

if [[ $commands == *"pytest -v -s compile/test_basic_correctness.py"* ]]; then
commands=${commands//"pytest -v -s compile/test_basic_correctness.py"/"VLLM_USE_TRITON_FLASH_ATTN=0 pytest -v -s compile/test_basic_correctness.py"}
fi
1 change: 1 addition & 0 deletions .buildkite/scripts/upload-wheels.sh
@@ -75,3 +75,4 @@ else
fi

aws s3 cp "$wheel" "s3://vllm-wheels/$version/"
aws s3 cp index.html "s3://vllm-wheels/$version/vllm/index.html"
7 changes: 3 additions & 4 deletions .buildkite/test-pipeline.yaml
@@ -216,7 +216,6 @@ steps:
- pytest -v -s v1/spec_decode
- pytest -v -s v1/kv_connector/unit
- pytest -v -s v1/test_serial_utils.py
- pytest -v -s v1/test_stats.py
- pytest -v -s v1/test_utils.py
- pytest -v -s v1/test_oracle.py
# TODO: accuracy does not match, whether setting
@@ -456,7 +455,7 @@ steps:
##### models test #####

- label: Basic Models Test # 24min
mirror_hardwares: [amdexperimental]
mirror_hardwares: [amdexperimental, amdproduction]
torch_nightly: true
source_file_dependencies:
- vllm/
@@ -528,7 +527,7 @@ steps:
- pytest -v -s models/multimodal/generation/test_common.py -m 'split(group=0) and not core_model'

- label: Multi-Modal Models Test (Extended) 3
mirror_hardwares: [amdexperimental]
mirror_hardwares: [amdexperimental, amdproduction]
optional: true
source_file_dependencies:
- vllm/
@@ -538,7 +537,7 @@
- pytest -v -s models/multimodal/generation/test_common.py -m 'split(group=1) and not core_model'

- label: Quantized Models Test
mirror_hardwares: [amdexperimental]
mirror_hardwares: [amdexperimental, amdproduction]
source_file_dependencies:
- vllm/model_executor/layers/quantization
- tests/models/quantization
3 changes: 3 additions & 0 deletions csrc/activation_kernels.cu
@@ -70,6 +70,9 @@ __device__ __forceinline__ T gelu_tanh_kernel(const T& x) {
int64_t num_tokens = input.numel() / input.size(-1); \
dim3 grid(num_tokens); \
dim3 block(std::min(d, 1024)); \
if (num_tokens == 0) { \
return; \
} \
const at::cuda::OptionalCUDAGuard device_guard(device_of(input)); \
const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); \
VLLM_DISPATCH_FLOATING_TYPES( \
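The guard added above returns before the kernel launch when the activation input has zero tokens. A minimal standalone sketch (assuming standard CUDA launch semantics; illustrative, not part of the diff) of the failure mode it avoids:

#include <cstdio>
#include <cuda_runtime.h>

__global__ void noop_kernel() {}

int main() {
  int num_tokens = 0;         // empty batch
  dim3 grid(num_tokens);      // grid.x == 0
  dim3 block(256);
  noop_kernel<<<grid, block>>>();
  // A zero-sized grid is an invalid launch configuration, not a no-op;
  // typical runtimes report "invalid configuration argument" here.
  printf("launch status: %s\n", cudaGetErrorString(cudaGetLastError()));
  // Returning early when num_tokens == 0, as the macro now does, skips the
  // launch entirely and avoids this error.
  return 0;
}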
4 changes: 2 additions & 2 deletions csrc/attention/attention_kernels.cuh
@@ -172,7 +172,7 @@ __device__ void paged_attention_kernel(

// Load the query to registers.
// Each thread in a thread group has a different part of the query.
// For example, if the the thread group size is 4, then the first thread in
// For example, if the thread group size is 4, then the first thread in
// the group has 0, 4, 8, ... th vectors of the query, and the second thread
// has 1, 5, 9, ... th vectors of the query, and so on. NOTE(woosuk): Because
// q is split from a qkv tensor, it may not be contiguous.
@@ -259,7 +259,7 @@ __device__ void paged_attention_kernel(

// Load a key to registers.
// Each thread in a thread group has a different part of the key.
// For example, if the the thread group size is 4, then the first thread in
// For example, if the thread group size is 4, then the first thread in
// the group has 0, 4, 8, ... th vectors of the key, and the second thread
// has 1, 5, 9, ... th vectors of the key, and so on.
for (int i = 0; i < NUM_TOKENS_PER_THREAD_GROUP; i++) {
14 changes: 14 additions & 0 deletions csrc/dispatch_utils.h
@@ -65,5 +65,19 @@
AT_DISPATCH_CASE(at::ScalarType::Int, __VA_ARGS__) \
AT_DISPATCH_CASE(at::ScalarType::Long, __VA_ARGS__)

#define VLLM_DISPATCH_CASE_INTEGRAL_AND_UNSIGNED_TYPES(...) \
AT_DISPATCH_CASE(at::ScalarType::Byte, __VA_ARGS__) \
AT_DISPATCH_CASE(at::ScalarType::Char, __VA_ARGS__) \
AT_DISPATCH_CASE(at::ScalarType::Short, __VA_ARGS__) \
AT_DISPATCH_CASE(at::ScalarType::Int, __VA_ARGS__) \
AT_DISPATCH_CASE(at::ScalarType::Long, __VA_ARGS__) \
AT_DISPATCH_CASE(at::ScalarType::UInt16, __VA_ARGS__) \
AT_DISPATCH_CASE(at::ScalarType::UInt32, __VA_ARGS__) \
AT_DISPATCH_CASE(at::ScalarType::UInt64, __VA_ARGS__)

#define VLLM_DISPATCH_INTEGRAL_TYPES(TYPE, NAME, ...) \
AT_DISPATCH_SWITCH(TYPE, NAME, VLLM_DISPATCH_CASE_INTEGRAL_TYPES(__VA_ARGS__))

#define VLLM_DISPATCH_INTEGRAL_AND_UNSIGNED_TYPES(TYPE, NAME, ...) \
AT_DISPATCH_SWITCH( \
TYPE, NAME, VLLM_DISPATCH_CASE_INTEGRAL_AND_UNSIGNED_TYPES(__VA_ARGS__))
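A hypothetical usage sketch (the function name, body, and include path are illustrative assumptions, not from the diff): the new macro lets one launcher body service both signed and unsigned index dtypes, with `scalar_t` bound per dispatch case.

#include <torch/torch.h>
#include "dispatch_utils.h"  // assumed location of the macros defined above

// Sums token->expert ids regardless of whether topk_ids was allocated as
// int32, int64, uint32, etc.
int64_t sum_topk_ids(const torch::Tensor& topk_ids) {
  int64_t total = 0;
  VLLM_DISPATCH_INTEGRAL_AND_UNSIGNED_TYPES(
      topk_ids.scalar_type(), "sum_topk_ids", [&] {
        const scalar_t* ids = topk_ids.data_ptr<scalar_t>();
        for (int64_t i = 0; i < topk_ids.numel(); ++i) {
          total += static_cast<int64_t>(ids[i]);
        }
      });
  return total;
}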
8 changes: 4 additions & 4 deletions csrc/moe/moe_align_sum_kernels.cu
@@ -326,7 +326,7 @@ void moe_align_block_size(torch::Tensor topk_ids, int64_t num_experts,
}

if (use_global_memory) {
VLLM_DISPATCH_INTEGRAL_TYPES(
VLLM_DISPATCH_INTEGRAL_AND_UNSIGNED_TYPES(
topk_ids.scalar_type(), "moe_align_block_size_global_mem_kernel", [&] {
// calc needed amount of shared mem for `tokens_cnts` and `cumsum`
// tensors
@@ -351,7 +351,7 @@ void moe_align_block_size(torch::Tensor topk_ids, int64_t num_experts,
cumsum_buffer.data_ptr<int32_t>());
});
} else if (use_i16) {
VLLM_DISPATCH_INTEGRAL_TYPES(
VLLM_DISPATCH_INTEGRAL_AND_UNSIGNED_TYPES(
topk_ids.scalar_type(), "moe_align_block_size_kernel", [&] {
// set dynamic shared mem
auto kernel =
@@ -366,7 +366,7 @@ void moe_align_block_size(torch::Tensor topk_ids, int64_t num_experts,
topk_ids.numel());
});
} else {
VLLM_DISPATCH_INTEGRAL_TYPES(
VLLM_DISPATCH_INTEGRAL_AND_UNSIGNED_TYPES(
topk_ids.scalar_type(), "moe_align_block_size_kernel", [&] {
auto kernel =
vllm::moe::moe_align_block_size_kernel<scalar_t, int32_t>;
@@ -391,7 +391,7 @@ void sgl_moe_align_block_size(torch::Tensor topk_ids, int64_t num_experts,
TORCH_CHECK(num_experts == 256,
"sgl_moe_align_block_size kernel only supports deepseek v3.");

VLLM_DISPATCH_INTEGRAL_TYPES(
VLLM_DISPATCH_INTEGRAL_AND_UNSIGNED_TYPES(
topk_ids.scalar_type(), "sgl_moe_align_block_size_kernel", [&] {
// calc needed amount of shared mem for `cumsum` tensors
auto options_int =
63 changes: 45 additions & 18 deletions csrc/moe/topk_softmax_kernels.cu
@@ -108,9 +108,17 @@ __launch_bounds__(TPB) __global__
}
}

template <int TPB>
__launch_bounds__(TPB) __global__ void moeTopK(const float* inputs_after_softmax, const bool* finished, float* output,
int* indices, int* source_rows, const int num_experts, const int k, const int start_expert, const int end_expert)
template <int TPB, typename IndType>
__launch_bounds__(TPB) __global__ void moeTopK(
const float* inputs_after_softmax,
const bool* finished,
float* output,
IndType* indices,
int* source_rows,
const int num_experts,
const int k,
const int start_expert,
const int end_expert)
{

using cub_kvp = cub::KeyValuePair<int, float>;
@@ -182,9 +190,9 @@ __launch_bounds__(TPB) __global__ void moeTopK(const float* inputs_after_softmax
2) This implementation assumes k is small, but will work for any k.
*/

template <int VPT, int NUM_EXPERTS, int WARPS_PER_CTA, int BYTES_PER_LDG>
template <int VPT, int NUM_EXPERTS, int WARPS_PER_CTA, int BYTES_PER_LDG, typename IndType>
__launch_bounds__(WARPS_PER_CTA* WARP_SIZE) __global__
void topkGatingSoftmax(const float* input, const bool* finished, float* output, const int num_rows, int* indices,
void topkGatingSoftmax(const float* input, const bool* finished, float* output, const int num_rows, IndType* indices,
int* source_rows, const int k, const int start_expert, const int end_expert)
{
// We begin by enforcing compile time assertions and setting up compile time constants.
@@ -397,8 +405,8 @@ struct TopkConstants
};
} // namespace detail

template <int EXPERTS, int WARPS_PER_TB>
void topkGatingSoftmaxLauncherHelper(const float* input, const bool* finished, float* output, int* indices,
template <int EXPERTS, int WARPS_PER_TB, typename IndType>
void topkGatingSoftmaxLauncherHelper(const float* input, const bool* finished, float* output, IndType* indices,
int* source_row, const int num_rows, const int k, const int start_expert, const int end_expert, cudaStream_t stream)
{
static constexpr std::size_t MAX_BYTES_PER_LDG = 16;
@@ -421,10 +429,11 @@ void topkGatingSoftmaxLauncherHelper(const float* input, const bool* finished, f
token_expert_indices, num_tokens, topk, 0, num_experts, \
stream);

template <typename IndType>
void topkGatingSoftmaxKernelLauncher(
const float* gating_output,
float* topk_weights,
int* topk_indicies,
IndType* topk_indicies,
int* token_expert_indices,
float* softmax_workspace,
const int num_tokens,
@@ -493,14 +502,32 @@ void topk_softmax(
const at::cuda::OptionalCUDAGuard device_guard(device_of(gating_output));
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
torch::Tensor softmax_workspace = torch::empty({workspace_size}, gating_output.options());
vllm::moe::topkGatingSoftmaxKernelLauncher(
gating_output.data_ptr<float>(),
topk_weights.data_ptr<float>(),
topk_indices.data_ptr<int>(),
token_expert_indices.data_ptr<int>(),
softmax_workspace.data_ptr<float>(),
num_tokens,
num_experts,
topk,
stream);

if(topk_indices.scalar_type() == at::ScalarType::Int)
{
vllm::moe::topkGatingSoftmaxKernelLauncher(
gating_output.data_ptr<float>(),
topk_weights.data_ptr<float>(),
topk_indices.data_ptr<int>(),
token_expert_indices.data_ptr<int>(),
softmax_workspace.data_ptr<float>(),
num_tokens,
num_experts,
topk,
stream);
}
else
{
assert(topk_indices.scalar_type() == at::ScalarType::UInt32);
vllm::moe::topkGatingSoftmaxKernelLauncher(
gating_output.data_ptr<float>(),
topk_weights.data_ptr<float>(),
topk_indices.data_ptr<uint32_t>(),
token_expert_indices.data_ptr<int>(),
softmax_workspace.data_ptr<float>(),
num_tokens,
num_experts,
topk,
stream);
}
}
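The host-side branch above keys off `topk_indices.scalar_type()`, so callers may now pass either int32 or uint32 index tensors; anything else trips the UInt32 assert. A small allocation sketch (shapes and tensor options are assumptions, not from the diff):

#include <torch/torch.h>

int main() {
  const int64_t num_tokens = 4, topk = 2;
  // Either dtype selects a valid launcher instantiation.
  torch::Tensor idx_i32 = torch::empty(
      {num_tokens, topk}, torch::dtype(at::ScalarType::Int).device(torch::kCUDA));
  torch::Tensor idx_u32 = torch::empty(
      {num_tokens, topk}, torch::dtype(at::ScalarType::UInt32).device(torch::kCUDA));
  // idx_i32.data_ptr<int>() and idx_u32.data_ptr<uint32_t>() are the pointers
  // the two branches hand to the launcher.
  return 0;
}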
40 changes: 28 additions & 12 deletions csrc/pos_encoding_kernels.cu
@@ -44,15 +44,17 @@ inline __device__ void apply_rotary_embedding(
// head_size]
const scalar_t* cache_ptr, const int head_size, const int num_heads,
const int num_kv_heads, const int rot_dim, const int token_idx,
const int64_t query_stride, const int64_t key_stride) {
const int64_t query_stride, const int64_t key_stride,
const int64_t head_stride) {
const int embed_dim = rot_dim / 2;
const scalar_t* cos_ptr = cache_ptr;
const scalar_t* sin_ptr = cache_ptr + embed_dim;

const int nq = num_heads * embed_dim;
for (int i = threadIdx.x; i < nq; i += blockDim.x) {
const int head_idx = i / embed_dim;
const int64_t token_head = token_idx * query_stride + head_idx * head_size;
const int64_t token_head =
token_idx * query_stride + head_idx * head_stride;
const int rot_offset = i % embed_dim;
apply_token_rotary_embedding<scalar_t, IS_NEOX>(
query + token_head, cos_ptr, sin_ptr, rot_offset, embed_dim);
@@ -62,7 +64,8 @@ inline __device__ void apply_rotary_embedding(
const int nk = num_kv_heads * embed_dim;
for (int i = threadIdx.x; i < nk; i += blockDim.x) {
const int head_idx = i / embed_dim;
const int64_t token_head = token_idx * key_stride + head_idx * head_size;
const int64_t token_head =
token_idx * key_stride + head_idx * head_stride;
const int rot_offset = i % embed_dim;
apply_token_rotary_embedding<scalar_t, IS_NEOX>(
key + token_head, cos_ptr, sin_ptr, rot_offset, embed_dim);
@@ -84,15 +87,16 @@ __global__ void rotary_embedding_kernel(
const scalar_t* __restrict__ cos_sin_cache, // [max_position, 2, rot_dim //
// 2]
const int rot_dim, const int64_t query_stride, const int64_t key_stride,
const int num_heads, const int num_kv_heads, const int head_size) {
const int64_t head_stride, const int num_heads, const int num_kv_heads,
const int head_size) {
// Each thread block is responsible for one token.
const int token_idx = blockIdx.x;
int64_t pos = positions[token_idx];
const scalar_t* cache_ptr = cos_sin_cache + pos * rot_dim;

apply_rotary_embedding<scalar_t, IS_NEOX>(
query, key, cache_ptr, head_size, num_heads, num_kv_heads, rot_dim,
token_idx, query_stride, key_stride);
token_idx, query_stride, key_stride, head_stride);
}

template <typename scalar_t, bool IS_NEOX>
@@ -109,9 +113,9 @@ __global__ void batched_rotary_embedding_kernel(
const scalar_t* __restrict__ cos_sin_cache, // [max_position, 2, rot_dim //
// 2]
const int64_t* __restrict__ cos_sin_cache_offsets, // [batch_size, seq_len]
// or [num_tokens]
const int rot_dim, const int64_t query_stride, const int64_t key_stride,
const int num_heads, const int num_kv_heads, const int head_size) {
const int64_t head_stride, const int num_heads, const int num_kv_heads,
const int head_size) {
// Each thread block is responsible for one token.
const int token_idx = blockIdx.x;
int64_t pos = positions[token_idx];
Expand All @@ -121,7 +125,7 @@ __global__ void batched_rotary_embedding_kernel(

apply_rotary_embedding<scalar_t, IS_NEOX>(
query, key, cache_ptr, head_size, num_heads, num_kv_heads, rot_dim,
token_idx, query_stride, key_stride);
token_idx, query_stride, key_stride, head_stride);
}

} // namespace vllm
@@ -179,6 +183,12 @@ void rotary_embedding(
int seq_dim_idx = positions_ndim - 1;
int64_t query_stride = query.stride(seq_dim_idx);
int64_t key_stride = key.has_value() ? key->stride(seq_dim_idx) : 0;
// Determine head stride: for [*, heads, head_size] layouts, use the stride
// of the heads dim (query.stride(-2)); for flat [*, heads*head_size]
// layouts, heads are contiguous blocks of size head_size.
int query_ndim = query.dim();
int64_t head_stride =
(query_ndim == positions_ndim + 2) ? query.stride(-2) : head_size;

dim3 grid(num_tokens);
dim3 block(std::min<int64_t>(num_heads * rot_dim / 2, 512));
@@ -190,14 +200,14 @@
positions.data_ptr<int64_t>(), query.data_ptr<scalar_t>(),
key.has_value() ? key->data_ptr<scalar_t>() : nullptr,
cos_sin_cache.data_ptr<scalar_t>(), rot_dim, query_stride, key_stride,
num_heads, num_kv_heads, head_size);
head_stride, num_heads, num_kv_heads, head_size);
} else {
vllm::rotary_embedding_kernel<scalar_t, false>
<<<grid, block, 0, stream>>>(
positions.data_ptr<int64_t>(), query.data_ptr<scalar_t>(),
key.has_value() ? key->data_ptr<scalar_t>() : nullptr,
cos_sin_cache.data_ptr<scalar_t>(), rot_dim, query_stride,
key_stride, num_heads, num_kv_heads, head_size);
key_stride, head_stride, num_heads, num_kv_heads, head_size);
}
});
}
@@ -263,6 +273,12 @@ void batched_rotary_embedding(
int seq_dim_idx = positions_ndim - 1;
int64_t query_stride = query.stride(seq_dim_idx);
int64_t key_stride = key.has_value() ? key->stride(seq_dim_idx) : 0;
// Determine head stride: for [*, heads, head_size] layouts, use the stride
// of the heads dim (query.stride(-2)); for flat [*, heads*head_size]
// layouts, heads are contiguous blocks of size head_size.
int query_ndim = query.dim();
int64_t head_stride =
(query_ndim == positions_ndim + 2) ? query.stride(-2) : head_size;

dim3 grid(num_tokens);
dim3 block(std::min<int64_t>(num_heads * rot_dim / 2, 512));
@@ -276,15 +292,15 @@
key.has_value() ? key->data_ptr<scalar_t>() : nullptr,
cos_sin_cache.data_ptr<scalar_t>(),
cos_sin_cache_offsets.data_ptr<int64_t>(), rot_dim, query_stride,
key_stride, num_heads, num_kv_heads, head_size);
key_stride, head_stride, num_heads, num_kv_heads, head_size);
} else {
vllm::batched_rotary_embedding_kernel<scalar_t, false>
<<<grid, block, 0, stream>>>(
positions.data_ptr<int64_t>(), query.data_ptr<scalar_t>(),
key.has_value() ? key->data_ptr<scalar_t>() : nullptr,
cos_sin_cache.data_ptr<scalar_t>(),
cos_sin_cache_offsets.data_ptr<int64_t>(), rot_dim, query_stride,
key_stride, num_heads, num_kv_heads, head_size);
key_stride, head_stride, num_heads, num_kv_heads, head_size);
}
});
}
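A small sketch (shapes assumed, not part of the diff) of how the head-stride logic above resolves for the two accepted query layouts, including a strided view where head_stride genuinely differs from head_size:

#include <torch/torch.h>

int main() {
  const int64_t num_tokens = 8, num_heads = 4, head_size = 64;
  const int64_t positions_ndim = 1;  // positions: [num_tokens]

  // Flat [*, heads*head_size]: heads are contiguous blocks of head_size.
  torch::Tensor q_flat = torch::zeros({num_tokens, num_heads * head_size});
  int64_t flat_stride =
      (q_flat.dim() == positions_ndim + 2) ? q_flat.stride(-2) : head_size;

  // [*, heads, head_size] sliced out of a per-head-packed qkv tensor: the
  // heads dim is strided, so head_idx * head_size would index incorrectly.
  torch::Tensor qkv = torch::zeros({num_tokens, num_heads, 3 * head_size});
  torch::Tensor q = qkv.narrow(/*dim=*/-1, /*start=*/0, /*length=*/head_size);
  int64_t view_stride =
      (q.dim() == positions_ndim + 2) ? q.stride(-2) : head_size;

  TORCH_CHECK(flat_stride == head_size);      // 64
  TORCH_CHECK(view_stride == 3 * head_size);  // 192
  return 0;
}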
1 change: 1 addition & 0 deletions docs/source/deployment/frameworks/index.md
@@ -10,6 +10,7 @@ chatbox
dify
dstack
helm
lobe-chat
lws
modal
open-webui
13 changes: 13 additions & 0 deletions docs/source/deployment/frameworks/lobe-chat.md
@@ -0,0 +1,13 @@
(deployment-lobe-chat)=

# Lobe Chat

[Lobe Chat](https://github.com/lobehub/lobe-chat) is an open-source, modern-design ChatGPT/LLM UI and framework.

It supports speech synthesis, multi-modal interaction, and an extensible (function-calling) plugin system.

It offers one-click free deployment of a private OpenAI ChatGPT/Claude/Gemini/Groq/Ollama chat application.

It supports vLLM as an AI model provider to efficiently serve large language models.

For details, see the tutorial [Using vLLM in LobeChat](https://lobehub.com/docs/usage/providers/vllm).