116 changes: 68 additions & 48 deletions src/llama_stack/core/routers/inference.py
@@ -110,15 +110,17 @@ def _construct_metrics(
prompt_tokens: int,
completion_tokens: int,
total_tokens: int,
model: Model,
fully_qualified_model_id: str,
provider_id: str,
) -> list[MetricEvent]:
"""Constructs a list of MetricEvent objects containing token usage metrics.

Args:
prompt_tokens: Number of tokens in the prompt
completion_tokens: Number of tokens in the completion
total_tokens: Total number of tokens used
model: Model object containing model_id and provider_id
fully_qualified_model_id: The fully qualified model identifier
provider_id: The provider identifier

Returns:
List of MetricEvent objects with token usage metrics
@@ -144,8 +146,8 @@ def _construct_metrics(
timestamp=datetime.now(UTC),
unit="tokens",
attributes={
"model_id": model.model_id,
"provider_id": model.provider_id,
"model_id": fully_qualified_model_id,
"provider_id": provider_id,
},
)
)
@@ -158,7 +160,9 @@ async def _compute_and_log_token_usage(
total_tokens: int,
model: Model,
) -> list[MetricInResponse]:
metrics = self._construct_metrics(prompt_tokens, completion_tokens, total_tokens, model)
metrics = self._construct_metrics(
prompt_tokens, completion_tokens, total_tokens, model.model_id, model.provider_id
)
if self.telemetry_enabled:
for metric in metrics:
enqueue_event(metric)
@@ -178,14 +182,25 @@ async def _count_tokens(
encoded = self.formatter.encode_content(messages)
return len(encoded.tokens) if encoded and encoded.tokens else 0

async def _get_model(self, model_id: str, expected_model_type: str) -> Model:
"""takes a model id and gets model after ensuring that it is accessible and of the correct type"""
model = await self.routing_table.get_model(model_id)
if model is None:
async def _get_model_provider(self, model_id: str, expected_model_type: str) -> tuple[Inference, str]:
model = await self.routing_table.get_object_by_identifier("model", model_id)
if model:
if model.model_type != expected_model_type:
raise ModelTypeError(model_id, model.model_type, expected_model_type)

provider = await self.routing_table.get_provider_impl(model.identifier)
return provider, model.provider_resource_id

splits = model_id.split("/", maxsplit=1)
if len(splits) != 2:
raise ModelNotFoundError(model_id)

provider_id, provider_resource_id = splits
if provider_id not in self.routing_table.impls_by_provider_id:
logger.warning(f"Provider {provider_id} not found for model {model_id}")
raise ModelNotFoundError(model_id)
if model.model_type != expected_model_type:
raise ModelTypeError(model_id, model.model_type, expected_model_type)
return model

return self.routing_table.impls_by_provider_id[provider_id], provider_resource_id
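The fallback branch above is the behavioral change worth noting: a registered model still resolves through the routing table exactly as before, but an unregistered identifier is now interpreted as `provider_id/provider_resource_id`, split on the first "/", and routed directly to that (already configured) provider. A tiny, self-contained mirror of just that parsing rule (the example identifier is invented):

# Standalone mirror of the fallback split in _get_model_provider above;
# no llama_stack imports, and the identifier is made up for illustration.
def split_provider_scoped_id(model_id: str) -> tuple[str, str]:
    parts = model_id.split("/", maxsplit=1)
    if len(parts) != 2:
        raise ValueError(f"not a provider-scoped identifier: {model_id}")
    provider_id, provider_resource_id = parts
    return provider_id, provider_resource_id


assert split_provider_scoped_id("together/meta-llama/Llama-3.1-8B-Instruct") == (
    "together",
    "meta-llama/Llama-3.1-8B-Instruct",
)

In the router itself the split result is only accepted when `provider_id` appears in `impls_by_provider_id`; otherwise a `ModelNotFoundError` is raised, as shown above.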

async def rerank(
self,
@@ -195,14 +210,8 @@ async def rerank(
max_num_results: int | None = None,
) -> RerankResponse:
logger.debug(f"InferenceRouter.rerank: {model}")
model_obj = await self._get_model(model, ModelType.rerank)
provider = await self.routing_table.get_provider_impl(model_obj.identifier)
return await provider.rerank(
model=model_obj.identifier,
query=query,
items=items,
max_num_results=max_num_results,
)
provider, provider_resource_id = await self._get_model_provider(model, ModelType.rerank)
return await provider.rerank(provider_resource_id, query, items, max_num_results)

async def openai_completion(
self,
@@ -211,24 +220,24 @@ async def openai_completion(
logger.debug(
f"InferenceRouter.openai_completion: model={params.model}, stream={params.stream}, prompt={params.prompt}",
)
model_obj = await self._get_model(params.model, ModelType.llm)

# Update params with the resolved model identifier
params.model = model_obj.identifier
request_model_id = params.model
provider, provider_resource_id = await self._get_model_provider(params.model, ModelType.llm)
params.model = provider_resource_id

provider = await self.routing_table.get_provider_impl(model_obj.identifier)
if params.stream:
return await provider.openai_completion(params)
# TODO: Metrics do NOT work with openai_completion stream=True because we do not
# return an AsyncIterator of our own; the tests expect a stream of chunks that we
# currently cannot intercept.

response = await provider.openai_completion(params)
response.model = request_model_id
if self.telemetry_enabled:
metrics = self._construct_metrics(
prompt_tokens=response.usage.prompt_tokens,
completion_tokens=response.usage.completion_tokens,
total_tokens=response.usage.total_tokens,
model=model_obj,
fully_qualified_model_id=request_model_id,
provider_id=provider.__provider_id__,
)
for metric in metrics:
enqueue_event(metric)
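This non-streaming completion path establishes the pattern repeated below for chat completions and embeddings: keep the identifier the client sent (`request_model_id`), hand the provider its native id, then set the client-facing id back on the response and tag the metrics with it. A runnable toy version of that stash-and-restore mapping follows; every class and id in it is an invented stand-in, not llama_stack code.

# Runnable toy version of the id-mapping pattern used above; all names here
# are invented stand-ins for the real router, provider, and params objects.
import asyncio
from dataclasses import dataclass


@dataclass
class ToyParams:
    model: str
    prompt: str


@dataclass
class ToyResponse:
    model: str
    text: str


class ToyProvider:
    async def openai_completion(self, params: ToyParams) -> ToyResponse:
        # A real provider only understands its native model id.
        return ToyResponse(model=params.model, text=f"echo: {params.prompt}")


async def call_with_id_mapping(
    provider: ToyProvider, params: ToyParams, provider_resource_id: str
) -> ToyResponse:
    request_model_id = params.model          # id as the client sent it
    params.model = provider_resource_id      # id the provider understands
    response = await provider.openai_completion(params)
    response.model = request_model_id        # client sees the id it asked for
    return response


params = ToyParams(model="openai/gpt-4o-mini", prompt="hi")
result = asyncio.run(call_with_id_mapping(ToyProvider(), params, "gpt-4o-mini"))
print(result.model)  # -> openai/gpt-4o-mini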
@@ -246,7 +255,9 @@ async def openai_chat_completion(
logger.debug(
f"InferenceRouter.openai_chat_completion: model={params.model}, stream={params.stream}, messages={params.messages}",
)
model_obj = await self._get_model(params.model, ModelType.llm)
request_model_id = params.model
provider, provider_resource_id = await self._get_model_provider(params.model, ModelType.llm)
params.model = provider_resource_id

# Use the OpenAI client for a bit of extra input validation without
# exposing the OpenAI client itself as part of our API surface
@@ -264,22 +275,20 @@ async def openai_chat_completion(
params.tool_choice = None
params.tools = None

# Update params with the resolved model identifier
params.model = model_obj.identifier

provider = await self.routing_table.get_provider_impl(model_obj.identifier)
if params.stream:
response_stream = await provider.openai_chat_completion(params)

# For streaming, the provider returns AsyncIterator[OpenAIChatCompletionChunk]
# We need to add metrics to each chunk and store the final completion
return self.stream_tokens_and_compute_metrics_openai_chat(
response=response_stream,
model=model_obj,
fully_qualified_model_id=request_model_id,
provider_id=provider.__provider_id__,
messages=params.messages,
)

response = await self._nonstream_openai_chat_completion(provider, params)
response.model = request_model_id

# Store the response with the ID that will be returned to the client
if self.store:
@@ -290,7 +299,8 @@ async def openai_chat_completion(
prompt_tokens=response.usage.prompt_tokens,
completion_tokens=response.usage.completion_tokens,
total_tokens=response.usage.total_tokens,
model=model_obj,
fully_qualified_model_id=request_model_id,
provider_id=provider.__provider_id__,
)
for metric in metrics:
enqueue_event(metric)
@@ -307,13 +317,13 @@ async def openai_embeddings(
logger.debug(
f"InferenceRouter.openai_embeddings: model={params.model}, input_type={type(params.input)}, encoding_format={params.encoding_format}, dimensions={params.dimensions}",
)
model_obj = await self._get_model(params.model, ModelType.embedding)

# Update model to use resolved identifier
params.model = model_obj.identifier
request_model_id = params.model
provider, provider_resource_id = await self._get_model_provider(params.model, ModelType.embedding)
params.model = provider_resource_id

provider = await self.routing_table.get_provider_impl(model_obj.identifier)
return await provider.openai_embeddings(params)
response = await provider.openai_embeddings(params)
response.model = request_model_id
return response

async def list_chat_completions(
self,
@@ -369,7 +379,8 @@ async def stream_tokens_and_compute_metrics(
self,
response,
prompt_tokens,
model,
fully_qualified_model_id: str,
provider_id: str,
tool_prompt_format: ToolPromptFormat | None = None,
) -> AsyncGenerator[ChatCompletionResponseStreamChunk, None] | AsyncGenerator[CompletionResponseStreamChunk, None]:
completion_text = ""
@@ -407,7 +418,8 @@ async def stream_tokens_and_compute_metrics(
prompt_tokens=prompt_tokens,
completion_tokens=completion_tokens,
total_tokens=total_tokens,
model=model,
fully_qualified_model_id=fully_qualified_model_id,
provider_id=provider_id,
)
for metric in completion_metrics:
if metric.metric in [
@@ -427,7 +439,8 @@ async def stream_tokens_and_compute_metrics(
prompt_tokens or 0,
completion_tokens or 0,
total_tokens,
model,
fully_qualified_model_id=fully_qualified_model_id,
provider_id=provider_id,
)
async_metrics = [
MetricInResponse(metric=metric.metric, value=metric.value) for metric in completion_metrics
@@ -439,7 +452,8 @@ async def count_tokens_and_compute_metrics(
self,
response: ChatCompletionResponse | CompletionResponse,
prompt_tokens,
model,
fully_qualified_model_id: str,
provider_id: str,
tool_prompt_format: ToolPromptFormat | None = None,
):
if isinstance(response, ChatCompletionResponse):
@@ -456,7 +470,8 @@ async def count_tokens_and_compute_metrics(
prompt_tokens=prompt_tokens,
completion_tokens=completion_tokens,
total_tokens=total_tokens,
model=model,
fully_qualified_model_id=fully_qualified_model_id,
provider_id=provider_id,
)
for metric in completion_metrics:
if metric.metric in ["completion_tokens", "total_tokens"]: # Only log completion and total tokens
@@ -470,14 +485,16 @@ async def count_tokens_and_compute_metrics(
prompt_tokens or 0,
completion_tokens or 0,
total_tokens,
model,
fully_qualified_model_id=fully_qualified_model_id,
provider_id=provider_id,
)
return [MetricInResponse(metric=metric.metric, value=metric.value) for metric in metrics]

async def stream_tokens_and_compute_metrics_openai_chat(
self,
response: AsyncIterator[OpenAIChatCompletionChunk],
model: Model,
fully_qualified_model_id: str,
provider_id: str,
messages: list[OpenAIMessageParam] | None = None,
) -> AsyncIterator[OpenAIChatCompletionChunk]:
"""Stream OpenAI chat completion chunks, compute metrics, and store the final completion."""
@@ -497,6 +514,8 @@ async def stream_tokens_and_compute_metrics_openai_chat(
if created is None and chunk.created:
created = chunk.created

chunk.model = fully_qualified_model_id

# Accumulate choice data for final assembly
if chunk.choices:
for choice_delta in chunk.choices:
@@ -553,7 +572,8 @@ async def stream_tokens_and_compute_metrics_openai_chat(
prompt_tokens=chunk.usage.prompt_tokens,
completion_tokens=chunk.usage.completion_tokens,
total_tokens=chunk.usage.total_tokens,
model=model,
fully_qualified_model_id=fully_qualified_model_id,
provider_id=provider_id,
)
for metric in metrics:
enqueue_event(metric)
@@ -601,7 +621,7 @@ async def stream_tokens_and_compute_metrics_openai_chat(
id=id,
choices=assembled_choices,
created=created or int(time.time()),
model=model.identifier,
model=fully_qualified_model_id,
object="chat.completion",
)
logger.debug(f"InferenceRouter.completion_response: {final_response}")
@@ -46,8 +46,7 @@ async def openai_embeddings(
raise ValueError("Empty list not supported")

# Get the model and generate embeddings
model_obj = await self.model_store.get_model(params.model)
embedding_model = await self._load_sentence_transformer_model(model_obj.provider_resource_id)
embedding_model = await self._load_sentence_transformer_model(params.model)
embeddings = await asyncio.to_thread(embedding_model.encode, input_list, show_progress_bar=False)

# Convert embeddings to the requested format
5 changes: 4 additions & 1 deletion src/llama_stack/providers/utils/inference/openai_mixin.py
@@ -226,8 +226,11 @@ async def _get_provider_model_id(self, model: str) -> str:
:param model: The registered model name/identifier
:return: The provider-specific model ID (e.g., "gpt-4")
"""
# Look up the registered model to get the provider-specific model ID
# self.model_store is injected by the distribution system at runtime
if not await self.model_store.has_model(model): # type: ignore[attr-defined]
return model

# Look up the registered model to get the provider-specific model ID
model_obj: Model = await self.model_store.get_model(model) # type: ignore[attr-defined]
# provider_resource_id is str | None, but we expect it to be str for OpenAI calls
if model_obj.provider_resource_id is None:
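The net effect of the new guard is a passthrough: an identifier the model store has never registered is forwarded to the provider verbatim, while registered aliases are still translated to their provider_resource_id. A runnable toy illustration of that precedence (the store class and ids are invented, and the real model_store API differs in detail):

import asyncio


class ToyModelStore:
    # Invented stand-in for the injected model_store.
    def __init__(self, alias_to_provider_model_id: dict[str, str]):
        self._aliases = alias_to_provider_model_id

    async def has_model(self, model: str) -> bool:
        return model in self._aliases

    async def get_provider_model_id(self, model: str) -> str:
        return self._aliases[model]


async def resolve(store: ToyModelStore, model: str) -> str:
    # Unregistered identifiers pass through unchanged; registered aliases translate.
    if not await store.has_model(model):
        return model
    return await store.get_provider_model_id(model)


store = ToyModelStore({"my-gpt": "gpt-4o-mini"})
print(asyncio.run(resolve(store, "my-gpt")))       # -> gpt-4o-mini
print(asyncio.run(resolve(store, "gpt-4o-mini")))  # -> gpt-4o-mini (passthrough)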
9 changes: 3 additions & 6 deletions tests/integration/inference/test_openai_embeddings.py
@@ -161,8 +161,7 @@ def test_openai_embeddings_single_string(compat_client, client_with_models, embe

assert response.object == "list"

# Handle provider-scoped model identifiers (e.g., sentence-transformers/nomic-ai/nomic-embed-text-v1.5)
assert response.model == embedding_model_id or response.model.endswith(f"/{embedding_model_id}")
assert response.model == embedding_model_id
assert len(response.data) == 1
assert response.data[0].object == "embedding"
assert response.data[0].index == 0
@@ -186,8 +185,7 @@ def test_openai_embeddings_multiple_strings(compat_client, client_with_models, e

assert response.object == "list"

# Handle provider-scoped model identifiers (e.g., sentence-transformers/nomic-ai/nomic-embed-text-v1.5)
assert response.model == embedding_model_id or response.model.endswith(f"/{embedding_model_id}")
assert response.model == embedding_model_id
assert len(response.data) == len(input_texts)

for i, embedding_data in enumerate(response.data):
@@ -365,8 +363,7 @@ def test_openai_embeddings_base64_batch_processing(compat_client, client_with_mo
# Validate response structure
assert response.object == "list"

# Handle provider-scoped model identifiers (e.g., sentence-transformers/nomic-ai/nomic-embed-text-v1.5)
assert response.model == embedding_model_id or response.model.endswith(f"/{embedding_model_id}")
assert response.model == embedding_model_id
assert len(response.data) == len(input_texts)

# Validate each embedding in the batch
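These assertion changes follow directly from the router now echoing the requested identifier on every response: the looser endswith fallback is no longer needed, so the tests can require exact equality. For example, with a hypothetical embedding model id (the id below is invented, and compat_client is the same OpenAI-compatible client fixture used by these tests):

# Hypothetical call shape mirroring the tests above; the model id is invented.
response = compat_client.embeddings.create(
    model="sentence-transformers/all-MiniLM-L6-v2",
    input="hello world",
)
assert response.model == "sentence-transformers/all-MiniLM-L6-v2"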