diff --git a/src/llama_stack/core/routers/inference.py b/src/llama_stack/core/routers/inference.py
index d532bc622d..c548dd14c4 100644
--- a/src/llama_stack/core/routers/inference.py
+++ b/src/llama_stack/core/routers/inference.py
@@ -110,7 +110,8 @@ def _construct_metrics(
         prompt_tokens: int,
         completion_tokens: int,
         total_tokens: int,
-        model: Model,
+        fully_qualified_model_id: str,
+        provider_id: str,
     ) -> list[MetricEvent]:
         """Constructs a list of MetricEvent objects containing token usage metrics.
 
@@ -118,7 +119,8 @@ def _construct_metrics(
             prompt_tokens: Number of tokens in the prompt
             completion_tokens: Number of tokens in the completion
             total_tokens: Total number of tokens used
-            model: Model object containing model_id and provider_id
+            fully_qualified_model_id: The fully qualified model identifier as requested by the client
+            provider_id: The provider identifier
 
         Returns:
             List of MetricEvent objects with token usage metrics
@@ -144,8 +146,8 @@ def _construct_metrics(
                     timestamp=datetime.now(UTC),
                     unit="tokens",
                     attributes={
-                        "model_id": model.model_id,
-                        "provider_id": model.provider_id,
+                        "model_id": fully_qualified_model_id,
+                        "provider_id": provider_id,
                     },
                 )
             )
@@ -158,7 +160,9 @@ async def _compute_and_log_token_usage(
         total_tokens: int,
         model: Model,
     ) -> list[MetricInResponse]:
-        metrics = self._construct_metrics(prompt_tokens, completion_tokens, total_tokens, model)
+        metrics = self._construct_metrics(
+            prompt_tokens, completion_tokens, total_tokens, model.model_id, model.provider_id
+        )
         if self.telemetry_enabled:
             for metric in metrics:
                 enqueue_event(metric)
@@ -178,14 +182,25 @@ async def _count_tokens(
         encoded = self.formatter.encode_content(messages)
         return len(encoded.tokens) if encoded and encoded.tokens else 0
 
-    async def _get_model(self, model_id: str, expected_model_type: str) -> Model:
-        """takes a model id and gets model after ensuring that it is accessible and of the correct type"""
-        model = await self.routing_table.get_model(model_id)
-        if model is None:
+    async def _get_model_provider(self, model_id: str, expected_model_type: str) -> tuple[Inference, str]:
+        model = await self.routing_table.get_object_by_identifier("model", model_id)
+        if model:
+            if model.model_type != expected_model_type:
+                raise ModelTypeError(model_id, model.model_type, expected_model_type)
+
+            provider = await self.routing_table.get_provider_impl(model.identifier)
+            return provider, model.provider_resource_id
+
+        splits = model_id.split("/", maxsplit=1)
+        if len(splits) != 2:
+            raise ModelNotFoundError(model_id)
+
+        provider_id, provider_resource_id = splits
+        if provider_id not in self.routing_table.impls_by_provider_id:
+            logger.warning(f"Provider {provider_id} not found for model {model_id}")
             raise ModelNotFoundError(model_id)
-        if model.model_type != expected_model_type:
-            raise ModelTypeError(model_id, model.model_type, expected_model_type)
-        return model
+
+        return self.routing_table.impls_by_provider_id[provider_id], provider_resource_id
 
     async def rerank(
         self,
@@ -195,14 +210,8 @@ async def rerank(
         max_num_results: int | None = None,
     ) -> RerankResponse:
         logger.debug(f"InferenceRouter.rerank: {model}")
-        model_obj = await self._get_model(model, ModelType.rerank)
-        provider = await self.routing_table.get_provider_impl(model_obj.identifier)
-        return await provider.rerank(
-            model=model_obj.identifier,
-            query=query,
-            items=items,
-            max_num_results=max_num_results,
-        )
+        provider, provider_resource_id = await self._get_model_provider(model, ModelType.rerank)
+        return await provider.rerank(provider_resource_id, query, items, max_num_results)
 
     async def openai_completion(
         self,
@@ -211,24 +220,24 @@ async def openai_completion(
         logger.debug(
             f"InferenceRouter.openai_completion: model={params.model}, stream={params.stream}, prompt={params.prompt}",
         )
-        model_obj = await self._get_model(params.model, ModelType.llm)
-
-        # Update params with the resolved model identifier
-        params.model = model_obj.identifier
+        request_model_id = params.model
+        provider, provider_resource_id = await self._get_model_provider(params.model, ModelType.llm)
+        params.model = provider_resource_id
 
-        provider = await self.routing_table.get_provider_impl(model_obj.identifier)
         if params.stream:
             return await provider.openai_completion(params)
 
         # TODO: Metrics do NOT work with openai_completion stream=True due to the fact
         # that we do not return an AsyncIterator, our tests expect a stream of chunks we cannot intercept currently.
         response = await provider.openai_completion(params)
+        response.model = request_model_id
         if self.telemetry_enabled:
             metrics = self._construct_metrics(
                 prompt_tokens=response.usage.prompt_tokens,
                 completion_tokens=response.usage.completion_tokens,
                 total_tokens=response.usage.total_tokens,
-                model=model_obj,
+                fully_qualified_model_id=request_model_id,
+                provider_id=provider.__provider_id__,
             )
             for metric in metrics:
                 enqueue_event(metric)
@@ -246,7 +255,9 @@ async def openai_chat_completion(
         logger.debug(
             f"InferenceRouter.openai_chat_completion: model={params.model}, stream={params.stream}, messages={params.messages}",
         )
-        model_obj = await self._get_model(params.model, ModelType.llm)
+        request_model_id = params.model
+        provider, provider_resource_id = await self._get_model_provider(params.model, ModelType.llm)
+        params.model = provider_resource_id
 
         # Use the OpenAI client for a bit of extra input validation without
         # exposing the OpenAI client itself as part of our API surface
@@ -264,10 +275,6 @@ async def openai_chat_completion(
             params.tool_choice = None
             params.tools = None
 
-        # Update params with the resolved model identifier
-        params.model = model_obj.identifier
-
-        provider = await self.routing_table.get_provider_impl(model_obj.identifier)
         if params.stream:
             response_stream = await provider.openai_chat_completion(params)
 
@@ -275,11 +282,13 @@ async def openai_chat_completion(
             # We need to add metrics to each chunk and store the final completion
             return self.stream_tokens_and_compute_metrics_openai_chat(
                 response=response_stream,
-                model=model_obj,
+                fully_qualified_model_id=request_model_id,
+                provider_id=provider.__provider_id__,
                 messages=params.messages,
             )
 
         response = await self._nonstream_openai_chat_completion(provider, params)
+        response.model = request_model_id
 
         # Store the response with the ID that will be returned to the client
         if self.store:
@@ -290,7 +299,8 @@ async def openai_chat_completion(
                 prompt_tokens=response.usage.prompt_tokens,
                 completion_tokens=response.usage.completion_tokens,
                 total_tokens=response.usage.total_tokens,
-                model=model_obj,
+                fully_qualified_model_id=request_model_id,
+                provider_id=provider.__provider_id__,
             )
             for metric in metrics:
                 enqueue_event(metric)
@@ -307,13 +317,13 @@ async def openai_embeddings(
         logger.debug(
             f"InferenceRouter.openai_embeddings: model={params.model}, input_type={type(params.input)}, encoding_format={params.encoding_format}, dimensions={params.dimensions}",
         )
-        model_obj = await self._get_model(params.model, ModelType.embedding)
-
-        # Update model to use resolved identifier
-        params.model = model_obj.identifier
+        request_model_id = params.model
+        provider, provider_resource_id = await self._get_model_provider(params.model, ModelType.embedding)
+        params.model = provider_resource_id
 
-        provider = await self.routing_table.get_provider_impl(model_obj.identifier)
-        return await provider.openai_embeddings(params)
+        response = await provider.openai_embeddings(params)
+        response.model = request_model_id
+        return response
 
     async def list_chat_completions(
         self,
@@ -369,7 +379,8 @@ async def stream_tokens_and_compute_metrics(
         self,
         response,
         prompt_tokens,
-        model,
+        fully_qualified_model_id: str,
+        provider_id: str,
         tool_prompt_format: ToolPromptFormat | None = None,
     ) -> AsyncGenerator[ChatCompletionResponseStreamChunk, None] | AsyncGenerator[CompletionResponseStreamChunk, None]:
         completion_text = ""
@@ -407,7 +418,8 @@ async def stream_tokens_and_compute_metrics(
                         prompt_tokens=prompt_tokens,
                         completion_tokens=completion_tokens,
                         total_tokens=total_tokens,
-                        model=model,
+                        fully_qualified_model_id=fully_qualified_model_id,
+                        provider_id=provider_id,
                     )
                     for metric in completion_metrics:
                         if metric.metric in [
@@ -427,7 +439,8 @@ async def stream_tokens_and_compute_metrics(
                 prompt_tokens or 0,
                 completion_tokens or 0,
                 total_tokens,
-                model,
+                fully_qualified_model_id=fully_qualified_model_id,
+                provider_id=provider_id,
             )
             async_metrics = [
                 MetricInResponse(metric=metric.metric, value=metric.value) for metric in completion_metrics
             ]
@@ -439,7 +452,8 @@ async def count_tokens_and_compute_metrics(
         self,
         response: ChatCompletionResponse | CompletionResponse,
         prompt_tokens,
-        model,
+        fully_qualified_model_id: str,
+        provider_id: str,
         tool_prompt_format: ToolPromptFormat | None = None,
     ):
         if isinstance(response, ChatCompletionResponse):
@@ -456,7 +470,8 @@ async def count_tokens_and_compute_metrics(
             prompt_tokens=prompt_tokens,
             completion_tokens=completion_tokens,
             total_tokens=total_tokens,
-            model=model,
+            fully_qualified_model_id=fully_qualified_model_id,
+            provider_id=provider_id,
         )
         for metric in completion_metrics:
             if metric.metric in ["completion_tokens", "total_tokens"]:  # Only log completion and total tokens
@@ -470,14 +485,16 @@ async def count_tokens_and_compute_metrics(
                 prompt_tokens or 0,
                 completion_tokens or 0,
                 total_tokens,
-                model,
+                fully_qualified_model_id=fully_qualified_model_id,
+                provider_id=provider_id,
             )
             return [MetricInResponse(metric=metric.metric, value=metric.value) for metric in metrics]
 
     async def stream_tokens_and_compute_metrics_openai_chat(
         self,
         response: AsyncIterator[OpenAIChatCompletionChunk],
-        model: Model,
+        fully_qualified_model_id: str,
+        provider_id: str,
         messages: list[OpenAIMessageParam] | None = None,
     ) -> AsyncIterator[OpenAIChatCompletionChunk]:
         """Stream OpenAI chat completion chunks, compute metrics, and store the final completion."""
@@ -497,6 +514,8 @@ async def stream_tokens_and_compute_metrics_openai_chat(
                 if created is None and chunk.created:
                     created = chunk.created
 
+                chunk.model = fully_qualified_model_id
+
                 # Accumulate choice data for final assembly
                 if chunk.choices:
                     for choice_delta in chunk.choices:
@@ -553,7 +572,8 @@ async def stream_tokens_and_compute_metrics_openai_chat(
                         prompt_tokens=chunk.usage.prompt_tokens,
                         completion_tokens=chunk.usage.completion_tokens,
                         total_tokens=chunk.usage.total_tokens,
-                        model=model,
+                        fully_qualified_model_id=fully_qualified_model_id,
+                        provider_id=provider_id,
                     )
                     for metric in metrics:
                         enqueue_event(metric)
@@ -601,7 +621,7 @@ async def stream_tokens_and_compute_metrics_openai_chat(
                 id=id,
                 choices=assembled_choices,
                 created=created or int(time.time()),
-                model=model.identifier,
+                model=fully_qualified_model_id,
                 object="chat.completion",
             )
             logger.debug(f"InferenceRouter.completion_response: {final_response}")
diff --git a/src/llama_stack/providers/utils/inference/embedding_mixin.py b/src/llama_stack/providers/utils/inference/embedding_mixin.py
index c959b9c19d..bab495eef1 100644
--- a/src/llama_stack/providers/utils/inference/embedding_mixin.py
+++ b/src/llama_stack/providers/utils/inference/embedding_mixin.py
@@ -46,8 +46,7 @@ async def openai_embeddings(
             raise ValueError("Empty list not supported")
 
         # Get the model and generate embeddings
-        model_obj = await self.model_store.get_model(params.model)
-        embedding_model = await self._load_sentence_transformer_model(model_obj.provider_resource_id)
+        embedding_model = await self._load_sentence_transformer_model(params.model)
         embeddings = await asyncio.to_thread(embedding_model.encode, input_list, show_progress_bar=False)
 
         # Convert embeddings to the requested format
diff --git a/src/llama_stack/providers/utils/inference/openai_mixin.py b/src/llama_stack/providers/utils/inference/openai_mixin.py
index bbd3d2e109..7c2200b139 100644
--- a/src/llama_stack/providers/utils/inference/openai_mixin.py
+++ b/src/llama_stack/providers/utils/inference/openai_mixin.py
@@ -226,8 +226,11 @@ async def _get_provider_model_id(self, model: str) -> str:
         :param model: The registered model name/identifier
         :return: The provider-specific model ID (e.g., "gpt-4")
         """
-        # Look up the registered model to get the provider-specific model ID
         # self.model_store is injected by the distribution system at runtime
+        if not await self.model_store.has_model(model):  # type: ignore[attr-defined]
+            return model
+
+        # Look up the registered model to get the provider-specific model ID
         model_obj: Model = await self.model_store.get_model(model)  # type: ignore[attr-defined]
         # provider_resource_id is str | None, but we expect it to be str for OpenAI calls
         if model_obj.provider_resource_id is None:
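# Reviewer note (not part of the patch): minimal sketch of the lookup order that the
# OpenAIMixin change above introduces - prefer the registered provider_resource_id,
# otherwise pass the requested id through to the provider unchanged. FakeModelStore is
# an assumption used only to keep the example self-contained; it is not the real store.
import asyncio
from dataclasses import dataclass, field


@dataclass
class FakeModelStore:
    # fully qualified model id -> provider_resource_id
    registered: dict[str, str] = field(default_factory=dict)

    async def has_model(self, model: str) -> bool:
        return model in self.registered

    async def get_provider_resource_id(self, model: str) -> str:
        return self.registered[model]


async def get_provider_model_id(store: FakeModelStore, model: str) -> str:
    # Unregistered ids are forwarded verbatim; registered ones resolve to the provider's id.
    if not await store.has_model(model):
        return model
    return await store.get_provider_resource_id(model)


if __name__ == "__main__":
    store = FakeModelStore(registered={"openai/gpt-4": "gpt-4"})
    assert asyncio.run(get_provider_model_id(store, "openai/gpt-4")) == "gpt-4"
    assert asyncio.run(get_provider_model_id(store, "gpt-4o-mini")) == "gpt-4o-mini"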
+ """ + if not isinstance(client_with_models, LlamaStackAsLibraryClient): + pytest.skip("Test requires library client for provider-level patching") + + client = client_with_models + + # Use a model format that follows provider_id/model_id convention + # We'll use anthropic as an example since it's a remote provider that + # benefits from this pattern + test_model_id = "anthropic/claude-3-5-sonnet-20241022" + + # First, verify the model is NOT registered + registered_models = {m.identifier for m in client.models.list()} + assert test_model_id not in registered_models, f"Model {test_model_id} should not be pre-registered for this test" + + # Check if anthropic provider is available in ci-tests + providers = {p.provider_id: p for p in client.providers.list()} + if "anthropic" not in providers: + pytest.skip("Anthropic provider not configured in ci-tests - cannot test unregistered model routing") + + # Get the actual provider implementation from the library client's stack + inference_router = client.async_client.impls.get(Api.inference) + if not inference_router: + raise RuntimeError("No inference router found") + + # The inference router's routing_table.impls_by_provider_id should have anthropic + # Let's patch the anthropic provider's openai_chat_completion method + # to avoid making real API calls + mock_response = OpenAIChatCompletionWithMetrics( + id="chatcmpl-test-123", + created=1234567890, + model="claude-3-5-sonnet-20241022", + choices=[ + OpenAIChoice( + index=0, + finish_reason="stop", + message=OpenAIAssistantMessageParam( + content="Mocked response to test routing", + ), + ) + ], + usage=OpenAIChatCompletionUsage( + prompt_tokens=5, + completion_tokens=10, + total_tokens=15, + ), + ) + + # Get the routing table from the inference router + routing_table = inference_router.routing_table + + # Patch the anthropic provider's openai_chat_completion method + anthropic_provider = routing_table.impls_by_provider_id.get("anthropic") + if not anthropic_provider: + raise RuntimeError("Anthropic provider not found in routing table even though it's in providers list") + + with patch.object( + anthropic_provider, + "openai_chat_completion", + new_callable=AsyncMock, + return_value=mock_response, + ) as mock_method: + # Make the request with the unregistered model + response = client.chat.completions.create( + model=test_model_id, + messages=[ + { + "role": "user", + "content": "Test message for unregistered model routing", + } + ], + stream=False, + ) + + # Verify the provider's method was called + assert mock_method.called, "Provider's openai_chat_completion should have been called" + + # Verify the response came through + assert response.choices[0].message.content == "Mocked response to test routing" + + # Verify that the router passed the correct model to the provider + # (without the "anthropic/" prefix) + call_args = mock_method.call_args + params = call_args[0][0] # First positional argument is the params object + assert params.model == "claude-3-5-sonnet-20241022", ( + f"Provider should receive model without provider prefix, got {params.model}" + ) diff --git a/tests/integration/telemetry/test_completions.py b/tests/integration/telemetry/test_completions.py index 77ca4d51ce..a542730bea 100644 --- a/tests/integration/telemetry/test_completions.py +++ b/tests/integration/telemetry/test_completions.py @@ -64,10 +64,11 @@ def test_telemetry_format_completeness(mock_otlp_collector, llama_stack_client, # Verify spans spans = mock_otlp_collector.get_spans() - assert len(spans) == 5 + # Expected 
diff --git a/tests/integration/telemetry/test_completions.py b/tests/integration/telemetry/test_completions.py
index 77ca4d51ce..a542730bea 100644
--- a/tests/integration/telemetry/test_completions.py
+++ b/tests/integration/telemetry/test_completions.py
@@ -64,10 +64,11 @@ def test_telemetry_format_completeness(mock_otlp_collector, llama_stack_client,
 
     # Verify spans
     spans = mock_otlp_collector.get_spans()
-    assert len(spans) == 5
+    # Expected spans: 1 root span + 3 autotraced method calls from routing/inference
+    assert len(spans) == 4, f"Expected 4 spans, got {len(spans)}"
 
-    # we only need this captured one time
-    logged_model_id = None
+    # Collect all model_ids found in spans
+    logged_model_ids = []
 
     for span in spans:
         attrs = span.attributes
@@ -87,10 +88,10 @@ def test_telemetry_format_completeness(mock_otlp_collector, llama_stack_client,
             args = json.loads(attrs["__args__"])
 
             if "model_id" in args:
-                logged_model_id = args["model_id"]
+                logged_model_ids.append(args["model_id"])
 
-    assert logged_model_id is not None
-    assert logged_model_id == text_model_id
+    # At least one span should capture the fully qualified model ID
+    assert text_model_id in logged_model_ids, f"Expected to find {text_model_id} in spans, but got {logged_model_ids}"
 
     # TODO: re-enable this once metrics get fixed
     """