
Commit ae449e1

chore: introduce write queue for inference_store
# What does this PR do?

Adds a write worker queue for writes to the inference store. This avoids overwhelming request processing with slow inference writes.

## Test Plan

Benchmark:

```
cd docs/source/distributions/k8s-benchmark
# start mock server
python openai-mock-server.py --port 8000
# start stack server
uv run --with llama-stack python -m llama_stack.core.server.server docs/source/distributions/k8s-benchmark/stack_run_config.yaml
# run benchmark script
uv run python3 benchmark.py --duration 120 --concurrent 50 --base-url=http://localhost:8321/v1/openai/v1 --model=vllm-inference/meta-llama/Llama-3.2-3B-Instruct
```

Before:

```
============================================================
BENCHMARK RESULTS

Response Time Statistics:
  Mean: 1.111s
  Median: 0.982s
  Min: 0.466s
  Max: 15.190s
  Std Dev: 1.091s

Percentiles:
  P50: 0.982s
  P90: 1.281s
  P95: 1.439s
  P99: 5.476s

Time to First Token (TTFT) Statistics:
  Mean: 0.474s
  Median: 0.347s
  Min: 0.175s
  Max: 15.129s
  Std Dev: 0.819s

TTFT Percentiles:
  P50: 0.347s
  P90: 0.661s
  P95: 0.762s
  P99: 2.788s

Streaming Statistics:
  Mean chunks per response: 67.2
  Total chunks received: 122154
============================================================
Total time: 120.00s
Concurrent users: 50
Total requests: 1919
Successful requests: 1819
Failed requests: 100
Success rate: 94.8%
Requests per second: 15.16

Errors (showing first 5):
  Request error:
  Request error:
  Request error:
  Request error:
  Request error:

Benchmark completed. Stopping server (PID: 679)...
Server stopped.
```

After:

```
============================================================
BENCHMARK RESULTS

Response Time Statistics:
  Mean: 1.085s
  Median: 1.089s
  Min: 0.451s
  Max: 2.002s
  Std Dev: 0.212s

Percentiles:
  P50: 1.089s
  P90: 1.343s
  P95: 1.409s
  P99: 1.617s

Time to First Token (TTFT) Statistics:
  Mean: 0.407s
  Median: 0.361s
  Min: 0.182s
  Max: 1.178s
  Std Dev: 0.175s

TTFT Percentiles:
  P50: 0.361s
  P90: 0.644s
  P95: 0.744s
  P99: 0.932s

Streaming Statistics:
  Mean chunks per response: 66.8
  Total chunks received: 367240
============================================================
Total time: 120.00s
Concurrent users: 50
Total requests: 5495
Successful requests: 5495
Failed requests: 0
Success rate: 100.0%
Requests per second: 45.79

Benchmark completed. Stopping server (PID: 97169)...
Server stopped.
```
1 parent 28696c3 commit ae449e1
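The change is an instance of the standard bounded asyncio producer/consumer queue. A minimal self-contained sketch of the idea (illustrative only; `slow_write` and the sizes are placeholders, not code from this commit):

```python
import asyncio

async def main() -> None:
    queue: asyncio.Queue[int] = asyncio.Queue(maxsize=100)

    async def slow_write(item: int) -> None:
        await asyncio.sleep(0.01)  # stand-in for a slow database insert

    async def worker() -> None:
        # Each worker drains the queue; task_done() lets join() track progress.
        while True:
            item = await queue.get()
            try:
                await slow_write(item)
            finally:
                queue.task_done()

    workers = [asyncio.create_task(worker()) for _ in range(4)]

    # The request path only enqueues. If the queue is full, fall back to an
    # awaiting put(): backpressure instead of dropping the write.
    for i in range(500):
        try:
            queue.put_nowait(i)
        except asyncio.QueueFull:
            await queue.put(i)

    await queue.join()  # wait until every queued write has been processed
    for w in workers:
        w.cancel()
    await asyncio.gather(*workers, return_exceptions=True)

asyncio.run(main())
```

The request handler pays only the cost of an enqueue; the slow inserts happen on background tasks, which is what removes the long tail (Max 15.190s → 2.002s) in the benchmark above.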

File tree

8 files changed (+139, -25 lines)


.gitignore

Lines changed: 1 addition & 0 deletions
```diff
@@ -30,3 +30,4 @@ AGENTS.md
 server.log
 CLAUDE.md
 .claude/
+*.log
```

docs/source/distributions/k8s-benchmark/benchmark.py

Lines changed: 10 additions & 9 deletions
```diff
@@ -58,14 +58,6 @@ def print_summary(self):
 
         print(f"\n{'='*60}")
         print(f"BENCHMARK RESULTS")
-        print(f"{'='*60}")
-        print(f"Total time: {total_time:.2f}s")
-        print(f"Concurrent users: {self.concurrent_users}")
-        print(f"Total requests: {self.total_requests}")
-        print(f"Successful requests: {self.success_count}")
-        print(f"Failed requests: {len(self.errors)}")
-        print(f"Success rate: {success_rate:.1f}%")
-        print(f"Requests per second: {self.success_count / total_time:.2f}")
 
         print(f"\nResponse Time Statistics:")
         print(f"  Mean: {statistics.mean(self.response_times):.3f}s")
@@ -106,6 +98,15 @@ def print_summary(self):
         print(f"  Mean chunks per response: {statistics.mean(self.chunks_received):.1f}")
         print(f"  Total chunks received: {sum(self.chunks_received)}")
 
+        print(f"{'='*60}")
+        print(f"Total time: {total_time:.2f}s")
+        print(f"Concurrent users: {self.concurrent_users}")
+        print(f"Total requests: {self.total_requests}")
+        print(f"Successful requests: {self.success_count}")
+        print(f"Failed requests: {len(self.errors)}")
+        print(f"Success rate: {success_rate:.1f}%")
+        print(f"Requests per second: {self.success_count / total_time:.2f}")
+
         if self.errors:
             print(f"\nErrors (showing first 5):")
             for error in self.errors[:5]:
@@ -215,7 +216,7 @@ async def progress_reporter():
             await asyncio.sleep(1)  # Report every second
             if time.time() >= last_report_time + 10:  # Report every 10 seconds
                 elapsed = time.time() - stats.start_time
-                print(f"Completed: {stats.total_requests} requests in {elapsed:.1f}s")
+                print(f"Completed: {stats.total_requests} requests in {elapsed:.1f}s, RPS: {stats.total_requests / elapsed:.1f}")
                 last_report_time = time.time()
         except asyncio.CancelledError:
             break
```

docs/source/distributions/k8s-benchmark/stack_run_config.yaml

Lines changed: 9 additions & 0 deletions
```diff
@@ -2,6 +2,7 @@ version: '2'
 image_name: kubernetes-benchmark-demo
 apis:
 - agents
+- files
 - inference
 - safety
 - telemetry
@@ -19,6 +20,14 @@ providers:
   - provider_id: sentence-transformers
     provider_type: inline::sentence-transformers
     config: {}
+  files:
+  - provider_id: meta-reference-files
+    provider_type: inline::localfs
+    config:
+      storage_dir: ${env.FILES_STORAGE_DIR:=~/.llama/distributions/starter/files}
+      metadata_store:
+        type: sqlite
+        db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/starter}/files_metadata.db
   vector_io:
   - provider_id: ${env.ENABLE_CHROMADB:+chromadb}
     provider_type: remote::chromadb
```

llama_stack/core/datatypes.py

Lines changed: 10 additions & 3 deletions
```diff
@@ -431,6 +431,12 @@ class ServerConfig(BaseModel):
     )
 
 
+class InferenceStoreConfig(BaseModel):
+    sql_store_config: SqlStoreConfig
+    max_write_queue_size: int = Field(default=10000, description="Max queued writes for inference store")
+    num_writers: int = Field(default=4, description="Number of concurrent background writers")
+
+
 class StackRunConfig(BaseModel):
     version: int = LLAMA_STACK_RUN_CONFIG_VERSION
 
@@ -464,11 +470,12 @@ class StackRunConfig(BaseModel):
         a default SQLite store will be used.""",
     )
 
-    inference_store: SqlStoreConfig | None = Field(
+    inference_store: InferenceStoreConfig | SqlStoreConfig | None = Field(
         default=None,
         description="""
-Configuration for the persistence store used by the inference API. If not specified,
-a default SQLite store will be used.""",
+Configuration for the persistence store used by the inference API. Can be either a
+InferenceStoreConfig (with queue tuning parameters) or a SqlStoreConfig (deprecated).
+If not specified, a default SQLite store will be used.""",
     )
 
     # registry of "resources" in the distribution
```
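With this change a run config can either keep passing a bare SqlStoreConfig or wrap it in the new InferenceStoreConfig to tune the queue. A hypothetical YAML snippet (the field names `sql_store_config`, `max_write_queue_size`, and `num_writers` come from this diff; the postgres connection fields are illustrative placeholders, not part of this commit):

```yaml
inference_store:
  sql_store_config:
    type: postgres
    host: ${env.POSTGRES_HOST:=localhost}
    port: ${env.POSTGRES_PORT:=5432}
    db: ${env.POSTGRES_DB:=llamastack}
    user: ${env.POSTGRES_USER:=llamastack}
    password: ${env.POSTGRES_PASSWORD:=llamastack}
  max_write_queue_size: 10000
  num_writers: 4
```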

llama_stack/core/routers/__init__.py

Lines changed: 4 additions & 1 deletion
```diff
@@ -78,7 +78,10 @@ async def get_auto_router_impl(
 
     # TODO: move pass configs to routers instead
     if api == Api.inference and run_config.inference_store:
-        inference_store = InferenceStore(run_config.inference_store, policy)
+        inference_store = InferenceStore(
+            config=run_config.inference_store,
+            policy=policy,
+        )
         await inference_store.initialize()
         api_to_dep_impl["store"] = inference_store
 
```

llama_stack/core/routers/inference.py

Lines changed: 5 additions & 0 deletions
```diff
@@ -90,6 +90,11 @@ async def initialize(self) -> None:
 
     async def shutdown(self) -> None:
         logger.debug("InferenceRouter.shutdown")
+        if self.store:
+            try:
+                await self.store.shutdown()
+            except Exception as e:
+                logger.warning(f"Error during InferenceStore shutdown: {e}")
 
     async def register_model(
         self,
```

llama_stack/providers/utils/inference/inference_store.py

Lines changed: 88 additions & 12 deletions
```diff
@@ -3,31 +3,53 @@
 #
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.
+import asyncio
+from typing import Any
+
 from llama_stack.apis.inference import (
     ListOpenAIChatCompletionResponse,
     OpenAIChatCompletion,
     OpenAICompletionWithInputMessages,
     OpenAIMessageParam,
     Order,
 )
-from llama_stack.core.datatypes import AccessRule
-from llama_stack.core.utils.config_dirs import RUNTIME_BASE_DIR
+from llama_stack.core.datatypes import AccessRule, InferenceStoreConfig
+from llama_stack.log import get_logger
 
 from ..sqlstore.api import ColumnDefinition, ColumnType
 from ..sqlstore.authorized_sqlstore import AuthorizedSqlStore
-from ..sqlstore.sqlstore import SqliteSqlStoreConfig, SqlStoreConfig, sqlstore_impl
+from ..sqlstore.sqlstore import SqlStoreConfig, SqlStoreType, sqlstore_impl
+
+logger = get_logger(name=__name__, category="inference_store")
 
 
 class InferenceStore:
-    def __init__(self, sql_store_config: SqlStoreConfig, policy: list[AccessRule]):
-        if not sql_store_config:
-            sql_store_config = SqliteSqlStoreConfig(
-                db_path=(RUNTIME_BASE_DIR / "sqlstore.db").as_posix(),
+    def __init__(
+        self,
+        config: InferenceStoreConfig | SqlStoreConfig,
+        policy: list[AccessRule],
+    ):
+        # Handle backward compatibility
+        if not isinstance(config, InferenceStoreConfig):
+            # Legacy: SqlStoreConfig passed directly as config
+            config = InferenceStoreConfig(
+                sql_store_config=config,
             )
-        self.sql_store_config = sql_store_config
+
+        self.config = config
+        self.sql_store_config = config.sql_store_config
         self.sql_store = None
         self.policy = policy
 
+        # Disable write queue for SQLite to avoid concurrency issues
+        self.enable_write_queue = self.sql_store_config.type != SqlStoreType.sqlite
+
+        # Async write queue and worker control
+        self._queue: asyncio.Queue[tuple[OpenAIChatCompletion, list[OpenAIMessageParam]]] | None = None
+        self._worker_tasks: list[asyncio.Task[Any]] = []
+        self._max_write_queue_size: int = config.max_write_queue_size
+        self._num_writers: int = max(1, config.num_writers)
+
     async def initialize(self):
         """Create the necessary tables if they don't exist."""
         self.sql_store = AuthorizedSqlStore(sqlstore_impl(self.sql_store_config))
@@ -42,14 +64,68 @@ async def initialize(self):
             },
         )
 
+        if self.enable_write_queue:
+            self._queue = asyncio.Queue(maxsize=self._max_write_queue_size)
+            for _ in range(self._num_writers):
+                self._worker_tasks.append(asyncio.create_task(self._worker_loop()))
+        else:
+            logger.info("Write queue disabled for SQLite to avoid concurrency issues")
+
+    async def shutdown(self) -> None:
+        if not self._worker_tasks:
+            return
+        if self._queue is not None:
+            await self._queue.join()
+        for t in self._worker_tasks:
+            if not t.done():
+                t.cancel()
+        for t in self._worker_tasks:
+            try:
+                await t
+            except asyncio.CancelledError:
+                pass
+        self._worker_tasks.clear()
+
+    async def flush(self) -> None:
+        """Wait for all queued writes to complete. Useful for testing."""
+        if self.enable_write_queue and self._queue is not None:
+            await self._queue.join()
+
     async def store_chat_completion(
         self, chat_completion: OpenAIChatCompletion, input_messages: list[OpenAIMessageParam]
     ) -> None:
-        if not self.sql_store:
-            raise ValueError("Inference store is not initialized")
-
+        if self.enable_write_queue:
+            if self._queue is None:
+                raise ValueError("Inference store is not initialized")
+            try:
+                self._queue.put_nowait((chat_completion, input_messages))
+            except asyncio.QueueFull:
+                logger.warning(
+                    f"Write queue full; adding chat completion id={getattr(chat_completion, 'id', '<unknown>')}"
+                )
+                await self._queue.put((chat_completion, input_messages))
+        else:
+            await self._write_chat_completion(chat_completion, input_messages)
+
+    async def _worker_loop(self) -> None:
+        assert self._queue is not None
+        while True:
+            try:
+                item = await self._queue.get()
+            except asyncio.CancelledError:
+                break
+            chat_completion, input_messages = item
+            try:
+                await self._write_chat_completion(chat_completion, input_messages)
+            except Exception as e:  # noqa: BLE001
+                logger.error(f"Error writing chat completion: {e}")
+            finally:
+                self._queue.task_done()
+
+    async def _write_chat_completion(
+        self, chat_completion: OpenAIChatCompletion, input_messages: list[OpenAIMessageParam]
+    ) -> None:
         data = chat_completion.model_dump()
-
         await self.sql_store.insert(
             table="chat_completions",
             data={
```
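Because writes are now asynchronous, readers that need read-your-write semantics must drain the queue first; the tests below do this via `flush()`. A hypothetical helper showing the pattern (`make_store` and `make_completion` stand in for real test fixtures and are not part of this commit):

```python
async def assert_completion_persisted(make_store, make_completion) -> None:
    store = await make_store()  # an initialized InferenceStore with the queue enabled
    completion, input_messages = make_completion()

    # Enqueued, not yet written: a list immediately after this call may miss it.
    await store.store_chat_completion(completion, input_messages)

    # Drain the write queue; after this, the read is guaranteed to see the row.
    await store.flush()
    result = await store.list_chat_completions(limit=1)
    assert result.data[0].id == completion.id
```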

tests/unit/utils/inference/test_inference_store.py

Lines changed: 12 additions & 0 deletions
```diff
@@ -65,6 +65,9 @@ async def test_inference_store_pagination_basic():
         input_messages = [OpenAIUserMessageParam(role="user", content=f"Test message for {completion_id}")]
         await store.store_chat_completion(completion, input_messages)
 
+    # Wait for all queued writes to complete
+    await store.flush()
+
     # Test 1: First page with limit=2, descending order (default)
     result = await store.list_chat_completions(limit=2, order=Order.desc)
     assert len(result.data) == 2
@@ -108,6 +111,9 @@ async def test_inference_store_pagination_ascending():
         input_messages = [OpenAIUserMessageParam(role="user", content=f"Test message for {completion_id}")]
         await store.store_chat_completion(completion, input_messages)
 
+    # Wait for all queued writes to complete
+    await store.flush()
+
     # Test ascending order pagination
     result = await store.list_chat_completions(limit=1, order=Order.asc)
     assert len(result.data) == 1
@@ -143,6 +149,9 @@ async def test_inference_store_pagination_with_model_filter():
         input_messages = [OpenAIUserMessageParam(role="user", content=f"Test message for {completion_id}")]
         await store.store_chat_completion(completion, input_messages)
 
+    # Wait for all queued writes to complete
+    await store.flush()
+
     # Test pagination with model filter
     result = await store.list_chat_completions(limit=1, model="model-a", order=Order.desc)
     assert len(result.data) == 1
@@ -190,6 +199,9 @@ async def test_inference_store_pagination_no_limit():
         input_messages = [OpenAIUserMessageParam(role="user", content=f"Test message for {completion_id}")]
         await store.store_chat_completion(completion, input_messages)
 
+    # Wait for all queued writes to complete
+    await store.flush()
+
     # Test without limit
     result = await store.list_chat_completions(order=Order.desc)
     assert len(result.data) == 2
```
