3
3
#
4
4
# This source code is licensed under the terms described in the LICENSE file in
5
5
# the root directory of this source tree.
6
+ import asyncio
7
+ from typing import Any
8
+
6
9
from llama_stack .apis .inference import (
7
10
ListOpenAIChatCompletionResponse ,
8
11
OpenAIChatCompletion ,
9
12
OpenAICompletionWithInputMessages ,
10
13
OpenAIMessageParam ,
11
14
Order ,
12
15
)
13
- from llama_stack .core .datatypes import AccessRule
14
- from llama_stack .core . utils . config_dirs import RUNTIME_BASE_DIR
16
+ from llama_stack .core .datatypes import AccessRule , InferenceStoreConfig
17
+ from llama_stack .log import get_logger
15
18
16
19
from ..sqlstore .api import ColumnDefinition , ColumnType
17
20
from ..sqlstore .authorized_sqlstore import AuthorizedSqlStore
18
- from ..sqlstore .sqlstore import SqliteSqlStoreConfig , SqlStoreConfig , sqlstore_impl
21
+ from ..sqlstore .sqlstore import SqlStoreConfig , SqlStoreType , sqlstore_impl
22
+
23
+ logger = get_logger (name = __name__ , category = "inference_store" )
19
24
20
25
21
26
class InferenceStore :
22
def __init__(
    self,
    config: InferenceStoreConfig | SqlStoreConfig,
    policy: list[AccessRule],
):
    """Record configuration and access policy for the inference store.

    The SQL store handle and the write queue are left as ``None`` here;
    ``initialize()`` is what creates them.

    Args:
        config: An ``InferenceStoreConfig``, or (legacy form) a bare
            ``SqlStoreConfig`` which is wrapped into an
            ``InferenceStoreConfig``.
        policy: Access rules, stored on the instance as ``policy``.
    """
    # Backward compatibility: older callers hand over the SQL store config
    # directly, so anything that is not an InferenceStoreConfig is wrapped.
    if isinstance(config, InferenceStoreConfig):
        store_config = config
    else:
        store_config = InferenceStoreConfig(
            sql_store_config=config,
        )

    self.config = store_config
    self.sql_store_config = store_config.sql_store_config
    self.sql_store = None
    self.policy = policy

    # Disable write queue for SQLite to avoid concurrency issues
    self.enable_write_queue = self.sql_store_config.type != SqlStoreType.sqlite

    # Background-writer state: the queue stays None until initialize() runs.
    self._queue: asyncio.Queue[tuple[OpenAIChatCompletion, list[OpenAIMessageParam]]] | None = None
    self._worker_tasks: list[asyncio.Task[Any]] = []
    self._max_write_queue_size: int = store_config.max_write_queue_size
    # Never fewer than one writer, whatever the configured value is.
    self._num_writers: int = max(1, store_config.num_writers)
52
+
31
53
async def initialize (self ):
32
54
"""Create the necessary tables if they don't exist."""
33
55
self .sql_store = AuthorizedSqlStore (sqlstore_impl (self .sql_store_config ))
@@ -42,14 +64,68 @@ async def initialize(self):
42
64
},
43
65
)
44
66
67
+ if self .enable_write_queue :
68
+ self ._queue = asyncio .Queue (maxsize = self ._max_write_queue_size )
69
+ for _ in range (self ._num_writers ):
70
+ self ._worker_tasks .append (asyncio .create_task (self ._worker_loop ()))
71
+ else :
72
+ logger .info ("Write queue disabled for SQLite to avoid concurrency issues" )
73
+
74
async def shutdown(self) -> None:
    """Drain the write queue and stop the background writer tasks.

    Does nothing when no worker tasks exist. Otherwise waits until every
    queued item has been marked done, cancels the workers that are still
    running, waits for all of them to finish and clears the task list.

    Raises:
        Exception: the first non-cancellation error a worker task ended
            with, re-raised only after every worker has been awaited and
            the task list has been cleared.
    """
    if not self._worker_tasks:
        return
    if self._queue is not None:
        await self._queue.join()

    for task in self._worker_tasks:
        if not task.done():
            task.cancel()

    # gather(return_exceptions=True) waits for *all* workers even if one of
    # them ended with an error, and reports worker cancellations as results
    # instead of raising them -- so a CancelledError surfacing from this
    # await means shutdown() itself was cancelled and must propagate.
    outcomes = await asyncio.gather(*self._worker_tasks, return_exceptions=True)
    self._worker_tasks.clear()

    # CancelledError derives from BaseException, so the expected outcome of
    # a cancelled worker is skipped by the Exception check below.
    for outcome in outcomes:
        if isinstance(outcome, Exception):
            raise outcome
88
+
89
async def flush(self) -> None:
    """Wait until every queued write has been processed.

    Returns immediately when the write queue is disabled or has not been
    created yet. Mainly useful in tests.
    """
    if not self.enable_write_queue:
        return
    if self._queue is None:
        return
    await self._queue.join()
93
+
45
94
async def store_chat_completion(
    self, chat_completion: OpenAIChatCompletion, input_messages: list[OpenAIMessageParam]
) -> None:
    """Persist a chat completion together with its input messages.

    With the write queue enabled, the pair is handed to the background
    writers; if the queue is full a warning is logged and the call waits
    for a free slot. With the write queue disabled, the row is written
    inline before this call returns.

    Args:
        chat_completion: The completion to store.
        input_messages: The messages that were sent to produce it.

    Raises:
        ValueError: If the store has not been initialized -- the queue is
            missing on the queued path, or the SQL store is missing on the
            direct path.
    """
    if not self.enable_write_queue:
        # Direct path: fail with the same explicit error as the queued path
        # instead of an AttributeError on a None sql_store further down.
        if self.sql_store is None:
            raise ValueError("Inference store is not initialized")
        await self._write_chat_completion(chat_completion, input_messages)
        return

    if self._queue is None:
        raise ValueError("Inference store is not initialized")
    try:
        self._queue.put_nowait((chat_completion, input_messages))
    except asyncio.QueueFull:
        # Apply back-pressure rather than dropping the record.
        logger.warning(
            f"Write queue full; adding chat completion id={getattr(chat_completion, 'id', '<unknown>')}"
        )
        await self._queue.put((chat_completion, input_messages))
109
+
110
+ async def _worker_loop (self ) -> None :
111
+ assert self ._queue is not None
112
+ while True :
113
+ try :
114
+ item = await self ._queue .get ()
115
+ except asyncio .CancelledError :
116
+ break
117
+ chat_completion , input_messages = item
118
+ try :
119
+ await self ._write_chat_completion (chat_completion , input_messages )
120
+ except Exception as e : # noqa: BLE001
121
+ logger .error (f"Error writing chat completion: { e } " )
122
+ finally :
123
+ self ._queue .task_done ()
124
+
125
+ async def _write_chat_completion (
126
+ self , chat_completion : OpenAIChatCompletion , input_messages : list [OpenAIMessageParam ]
127
+ ) -> None :
51
128
data = chat_completion .model_dump ()
52
-
53
129
await self .sql_store .insert (
54
130
table = "chat_completions" ,
55
131
data = {
0 commit comments