
Commit 74c9774

ashwinb and ehhuang committed
fix(inference): enable routing of models with provider_data alone (#3928)
This PR enables routing of fully qualified model IDs of the form `provider_id/model_id` even when the models are not registered with the Stack.

Here's the situation: assume a remote inference provider that works only when users supply their own API keys via the `X-LlamaStack-Provider-Data` header. By definition, we cannot list its models and hence cannot update our routing registry. But because we now _require_ a provider ID in model IDs, we can identify which provider to route to and let that provider decide. Note that we still try to look up our registry, since it may contain a pre-registered alias; we just no longer fail outright when the lookup comes up empty.

Also updated the inference router so that responses carry the _exact_ model ID that the request had.

## Test Plan

Added an integration test.

Closes #3929

---------

Co-authored-by: ehhuang <[email protected]>
1 parent 6822dec commit 74c9774
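To illustrate the behavior described above, here is a rough client-side sketch (not part of the commit): it calls a Llama Stack server through the OpenAI-compatible client with a fully qualified `provider_id/model_id` and per-request credentials in the provider-data header. The base URL, the `openai_api_key` field name inside the header payload, and the model name are illustrative assumptions, not values taken from this PR.

```python
import json

from openai import OpenAI  # OpenAI-compatible client pointed at a Llama Stack server

# Assumption: the stack's OpenAI-compatible endpoint is served at this local URL.
client = OpenAI(
    base_url="http://localhost:8321/v1/openai/v1",  # hypothetical Llama Stack base URL
    api_key="none",
    # Per-request credentials travel in the provider-data header; the exact JSON
    # key ("openai_api_key") is an assumption and depends on the provider adapter.
    default_headers={
        "X-LlamaStack-Provider-Data": json.dumps({"openai_api_key": "sk-..."})
    },
)

# "openai/gpt-4o-mini" is a fully qualified provider_id/model_id. With this change it
# routes to the `openai` provider even if the model was never registered with the Stack,
# and the response echoes back the exact model string from the request.
response = client.chat.completions.create(
    model="openai/gpt-4o-mini",
    messages=[{"role": "user", "content": "Hello"}],
)
print(response.model)  # expected: "openai/gpt-4o-mini"
```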

File tree

6 files changed: +216 -63 lines


src/llama_stack/core/routers/inference.py

Lines changed: 68 additions & 48 deletions
@@ -110,15 +110,17 @@ def _construct_metrics(
         prompt_tokens: int,
         completion_tokens: int,
         total_tokens: int,
-        model: Model,
+        fully_qualified_model_id: str,
+        provider_id: str,
     ) -> list[MetricEvent]:
         """Constructs a list of MetricEvent objects containing token usage metrics.
 
         Args:
             prompt_tokens: Number of tokens in the prompt
             completion_tokens: Number of tokens in the completion
             total_tokens: Total number of tokens used
-            model: Model object containing model_id and provider_id
+            fully_qualified_model_id:
+            provider_id: The provider identifier
 
         Returns:
             List of MetricEvent objects with token usage metrics
@@ -144,8 +146,8 @@ def _construct_metrics(
                     timestamp=datetime.now(UTC),
                     unit="tokens",
                     attributes={
-                        "model_id": model.model_id,
-                        "provider_id": model.provider_id,
+                        "model_id": fully_qualified_model_id,
+                        "provider_id": provider_id,
                     },
                 )
             )
@@ -158,7 +160,9 @@ async def _compute_and_log_token_usage(
         total_tokens: int,
         model: Model,
     ) -> list[MetricInResponse]:
-        metrics = self._construct_metrics(prompt_tokens, completion_tokens, total_tokens, model)
+        metrics = self._construct_metrics(
+            prompt_tokens, completion_tokens, total_tokens, model.model_id, model.provider_id
+        )
         if self.telemetry_enabled:
             for metric in metrics:
                 enqueue_event(metric)
@@ -178,14 +182,25 @@ async def _count_tokens(
         encoded = self.formatter.encode_content(messages)
         return len(encoded.tokens) if encoded and encoded.tokens else 0
 
-    async def _get_model(self, model_id: str, expected_model_type: str) -> Model:
-        """takes a model id and gets model after ensuring that it is accessible and of the correct type"""
-        model = await self.routing_table.get_model(model_id)
-        if model is None:
+    async def _get_model_provider(self, model_id: str, expected_model_type: str) -> tuple[Inference, str]:
+        model = await self.routing_table.get_object_by_identifier("model", model_id)
+        if model:
+            if model.model_type != expected_model_type:
+                raise ModelTypeError(model_id, model.model_type, expected_model_type)
+
+            provider = await self.routing_table.get_provider_impl(model.identifier)
+            return provider, model.provider_resource_id
+
+        splits = model_id.split("/", maxsplit=1)
+        if len(splits) != 2:
+            raise ModelNotFoundError(model_id)
+
+        provider_id, provider_resource_id = splits
+        if provider_id not in self.routing_table.impls_by_provider_id:
+            logger.warning(f"Provider {provider_id} not found for model {model_id}")
             raise ModelNotFoundError(model_id)
-        if model.model_type != expected_model_type:
-            raise ModelTypeError(model_id, model.model_type, expected_model_type)
-        return model
+
+        return self.routing_table.impls_by_provider_id[provider_id], provider_resource_id
 
     async def rerank(
         self,
@@ -195,14 +210,8 @@ async def rerank(
         max_num_results: int | None = None,
     ) -> RerankResponse:
         logger.debug(f"InferenceRouter.rerank: {model}")
-        model_obj = await self._get_model(model, ModelType.rerank)
-        provider = await self.routing_table.get_provider_impl(model_obj.identifier)
-        return await provider.rerank(
-            model=model_obj.identifier,
-            query=query,
-            items=items,
-            max_num_results=max_num_results,
-        )
+        provider, provider_resource_id = await self._get_model_provider(model, ModelType.rerank)
+        return await provider.rerank(provider_resource_id, query, items, max_num_results)
 
     async def openai_completion(
         self,
@@ -211,24 +220,24 @@ async def openai_completion(
         logger.debug(
             f"InferenceRouter.openai_completion: model={params.model}, stream={params.stream}, prompt={params.prompt}",
         )
-        model_obj = await self._get_model(params.model, ModelType.llm)
-
-        # Update params with the resolved model identifier
-        params.model = model_obj.identifier
+        request_model_id = params.model
+        provider, provider_resource_id = await self._get_model_provider(params.model, ModelType.llm)
+        params.model = provider_resource_id
 
-        provider = await self.routing_table.get_provider_impl(model_obj.identifier)
         if params.stream:
             return await provider.openai_completion(params)
         # TODO: Metrics do NOT work with openai_completion stream=True due to the fact
         # that we do not return an AsyncIterator, our tests expect a stream of chunks we cannot intercept currently.
 
         response = await provider.openai_completion(params)
+        response.model = request_model_id
         if self.telemetry_enabled:
             metrics = self._construct_metrics(
                 prompt_tokens=response.usage.prompt_tokens,
                 completion_tokens=response.usage.completion_tokens,
                 total_tokens=response.usage.total_tokens,
-                model=model_obj,
+                fully_qualified_model_id=request_model_id,
+                provider_id=provider.__provider_id__,
             )
             for metric in metrics:
                 enqueue_event(metric)
@@ -246,7 +255,9 @@ async def openai_chat_completion(
         logger.debug(
             f"InferenceRouter.openai_chat_completion: model={params.model}, stream={params.stream}, messages={params.messages}",
        )
-        model_obj = await self._get_model(params.model, ModelType.llm)
+        request_model_id = params.model
+        provider, provider_resource_id = await self._get_model_provider(params.model, ModelType.llm)
+        params.model = provider_resource_id
 
         # Use the OpenAI client for a bit of extra input validation without
         # exposing the OpenAI client itself as part of our API surface
@@ -264,22 +275,20 @@ async def openai_chat_completion(
         params.tool_choice = None
         params.tools = None
 
-        # Update params with the resolved model identifier
-        params.model = model_obj.identifier
-
-        provider = await self.routing_table.get_provider_impl(model_obj.identifier)
         if params.stream:
             response_stream = await provider.openai_chat_completion(params)
 
             # For streaming, the provider returns AsyncIterator[OpenAIChatCompletionChunk]
             # We need to add metrics to each chunk and store the final completion
             return self.stream_tokens_and_compute_metrics_openai_chat(
                 response=response_stream,
-                model=model_obj,
+                fully_qualified_model_id=request_model_id,
+                provider_id=provider.__provider_id__,
                 messages=params.messages,
             )
 
         response = await self._nonstream_openai_chat_completion(provider, params)
+        response.model = request_model_id
 
         # Store the response with the ID that will be returned to the client
         if self.store:
@@ -290,7 +299,8 @@ async def openai_chat_completion(
                 prompt_tokens=response.usage.prompt_tokens,
                 completion_tokens=response.usage.completion_tokens,
                 total_tokens=response.usage.total_tokens,
-                model=model_obj,
+                fully_qualified_model_id=request_model_id,
+                provider_id=provider.__provider_id__,
             )
             for metric in metrics:
                 enqueue_event(metric)
@@ -307,13 +317,13 @@ async def openai_embeddings(
         logger.debug(
             f"InferenceRouter.openai_embeddings: model={params.model}, input_type={type(params.input)}, encoding_format={params.encoding_format}, dimensions={params.dimensions}",
        )
-        model_obj = await self._get_model(params.model, ModelType.embedding)
-
-        # Update model to use resolved identifier
-        params.model = model_obj.identifier
+        request_model_id = params.model
+        provider, provider_resource_id = await self._get_model_provider(params.model, ModelType.embedding)
+        params.model = provider_resource_id
 
-        provider = await self.routing_table.get_provider_impl(model_obj.identifier)
-        return await provider.openai_embeddings(params)
+        response = await provider.openai_embeddings(params)
+        response.model = request_model_id
+        return response
 
     async def list_chat_completions(
         self,
@@ -369,7 +379,8 @@ async def stream_tokens_and_compute_metrics(
         self,
         response,
         prompt_tokens,
-        model,
+        fully_qualified_model_id: str,
+        provider_id: str,
         tool_prompt_format: ToolPromptFormat | None = None,
     ) -> AsyncGenerator[ChatCompletionResponseStreamChunk, None] | AsyncGenerator[CompletionResponseStreamChunk, None]:
         completion_text = ""
@@ -407,7 +418,8 @@ async def stream_tokens_and_compute_metrics(
                         prompt_tokens=prompt_tokens,
                         completion_tokens=completion_tokens,
                         total_tokens=total_tokens,
-                        model=model,
+                        fully_qualified_model_id=fully_qualified_model_id,
+                        provider_id=provider_id,
                     )
                     for metric in completion_metrics:
                         if metric.metric in [
@@ -427,7 +439,8 @@ async def stream_tokens_and_compute_metrics(
             prompt_tokens or 0,
             completion_tokens or 0,
             total_tokens,
-            model,
+            fully_qualified_model_id=fully_qualified_model_id,
+            provider_id=provider_id,
         )
         async_metrics = [
             MetricInResponse(metric=metric.metric, value=metric.value) for metric in completion_metrics
@@ -439,7 +452,8 @@ async def count_tokens_and_compute_metrics(
         self,
         response: ChatCompletionResponse | CompletionResponse,
         prompt_tokens,
-        model,
+        fully_qualified_model_id: str,
+        provider_id: str,
         tool_prompt_format: ToolPromptFormat | None = None,
     ):
         if isinstance(response, ChatCompletionResponse):
@@ -456,7 +470,8 @@ async def count_tokens_and_compute_metrics(
                 prompt_tokens=prompt_tokens,
                 completion_tokens=completion_tokens,
                 total_tokens=total_tokens,
-                model=model,
+                fully_qualified_model_id=fully_qualified_model_id,
+                provider_id=provider_id,
             )
             for metric in completion_metrics:
                 if metric.metric in ["completion_tokens", "total_tokens"]:  # Only log completion and total tokens
@@ -470,14 +485,16 @@ async def count_tokens_and_compute_metrics(
             prompt_tokens or 0,
             completion_tokens or 0,
             total_tokens,
-            model,
+            fully_qualified_model_id=fully_qualified_model_id,
+            provider_id=provider_id,
         )
         return [MetricInResponse(metric=metric.metric, value=metric.value) for metric in metrics]
 
     async def stream_tokens_and_compute_metrics_openai_chat(
         self,
         response: AsyncIterator[OpenAIChatCompletionChunk],
-        model: Model,
+        fully_qualified_model_id: str,
+        provider_id: str,
         messages: list[OpenAIMessageParam] | None = None,
     ) -> AsyncIterator[OpenAIChatCompletionChunk]:
         """Stream OpenAI chat completion chunks, compute metrics, and store the final completion."""
@@ -497,6 +514,8 @@ async def stream_tokens_and_compute_metrics_openai_chat(
             if created is None and chunk.created:
                 created = chunk.created
 
+            chunk.model = fully_qualified_model_id
+
             # Accumulate choice data for final assembly
             if chunk.choices:
                 for choice_delta in chunk.choices:
@@ -553,7 +572,8 @@ async def stream_tokens_and_compute_metrics_openai_chat(
                     prompt_tokens=chunk.usage.prompt_tokens,
                     completion_tokens=chunk.usage.completion_tokens,
                     total_tokens=chunk.usage.total_tokens,
-                    model=model,
+                    model_id=fully_qualified_model_id,
+                    provider_id=provider_id,
                 )
                 for metric in metrics:
                     enqueue_event(metric)
@@ -601,7 +621,7 @@ async def stream_tokens_and_compute_metrics_openai_chat(
                 id=id,
                 choices=assembled_choices,
                 created=created or int(time.time()),
-                model=model.identifier,
+                model=fully_qualified_model_id,
                 object="chat.completion",
             )
             logger.debug(f"InferenceRouter.completion_response: {final_response}")

src/llama_stack/providers/utils/inference/embedding_mixin.py

Lines changed: 1 addition & 2 deletions
@@ -46,8 +46,7 @@ async def openai_embeddings(
             raise ValueError("Empty list not supported")
 
         # Get the model and generate embeddings
-        model_obj = await self.model_store.get_model(params.model)
-        embedding_model = await self._load_sentence_transformer_model(model_obj.provider_resource_id)
+        embedding_model = await self._load_sentence_transformer_model(params.model)
         embeddings = await asyncio.to_thread(embedding_model.encode, input_list, show_progress_bar=False)
 
         # Convert embeddings to the requested format

src/llama_stack/providers/utils/inference/openai_mixin.py

Lines changed: 4 additions & 1 deletion
@@ -226,8 +226,11 @@ async def _get_provider_model_id(self, model: str) -> str:
         :param model: The registered model name/identifier
         :return: The provider-specific model ID (e.g., "gpt-4")
         """
-        # Look up the registered model to get the provider-specific model ID
         # self.model_store is injected by the distribution system at runtime
+        if not await self.model_store.has_model(model):  # type: ignore[attr-defined]
+            return model
+
+        # Look up the registered model to get the provider-specific model ID
         model_obj: Model = await self.model_store.get_model(model)  # type: ignore[attr-defined]
         # provider_resource_id is str | None, but we expect it to be str for OpenAI calls
         if model_obj.provider_resource_id is None:

tests/integration/inference/test_openai_embeddings.py

Lines changed: 3 additions & 6 deletions
@@ -161,8 +161,7 @@ def test_openai_embeddings_single_string(compat_client, client_with_models, embe
 
     assert response.object == "list"
 
-    # Handle provider-scoped model identifiers (e.g., sentence-transformers/nomic-ai/nomic-embed-text-v1.5)
-    assert response.model == embedding_model_id or response.model.endswith(f"/{embedding_model_id}")
+    assert response.model == embedding_model_id
     assert len(response.data) == 1
     assert response.data[0].object == "embedding"
     assert response.data[0].index == 0
@@ -186,8 +185,7 @@ def test_openai_embeddings_multiple_strings(compat_client, client_with_models, e
 
     assert response.object == "list"
 
-    # Handle provider-scoped model identifiers (e.g., sentence-transformers/nomic-ai/nomic-embed-text-v1.5)
-    assert response.model == embedding_model_id or response.model.endswith(f"/{embedding_model_id}")
+    assert response.model == embedding_model_id
     assert len(response.data) == len(input_texts)
 
     for i, embedding_data in enumerate(response.data):
@@ -365,8 +363,7 @@ def test_openai_embeddings_base64_batch_processing(compat_client, client_with_mo
     # Validate response structure
     assert response.object == "list"
 
-    # Handle provider-scoped model identifiers (e.g., sentence-transformers/nomic-ai/nomic-embed-text-v1.5)
-    assert response.model == embedding_model_id or response.model.endswith(f"/{embedding_model_id}")
+    assert response.model == embedding_model_id
    assert len(response.data) == len(input_texts)
 
     # Validate each embedding in the batch
