@@ -110,15 +110,17 @@ def _construct_metrics(
         prompt_tokens: int,
         completion_tokens: int,
         total_tokens: int,
-        model: Model,
+        fully_qualified_model_id: str,
+        provider_id: str,
     ) -> list[MetricEvent]:
         """Constructs a list of MetricEvent objects containing token usage metrics.

         Args:
             prompt_tokens: Number of tokens in the prompt
             completion_tokens: Number of tokens in the completion
             total_tokens: Total number of tokens used
-            model: Model object containing model_id and provider_id
+            fully_qualified_model_id: The fully qualified model identifier as sent by the client
+            provider_id: The provider identifier

         Returns:
             List of MetricEvent objects with token usage metrics
@@ -144,8 +146,8 @@ def _construct_metrics(
                 timestamp=datetime.now(UTC),
                 unit="tokens",
                 attributes={
-                    "model_id": model.model_id,
-                    "provider_id": model.provider_id,
+                    "model_id": fully_qualified_model_id,
+                    "provider_id": provider_id,
                 },
             )
         )
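
A self-contained sketch of what the reworked _construct_metrics produces, using a stand-in dataclass rather than the real MetricEvent type; names and values here are illustrative only:

    from dataclasses import dataclass, field
    from datetime import UTC, datetime

    @dataclass
    class FakeMetricEvent:  # stand-in for MetricEvent, illustration only
        metric: str
        value: int
        unit: str
        timestamp: datetime
        attributes: dict[str, str] = field(default_factory=dict)

    def construct_metrics_sketch(
        prompt_tokens: int,
        completion_tokens: int,
        total_tokens: int,
        fully_qualified_model_id: str,
        provider_id: str,
    ) -> list[FakeMetricEvent]:
        # One event per counter; the attributes now carry two plain strings
        # instead of fields read off a Model object.
        return [
            FakeMetricEvent(
                metric=name,
                value=value,
                unit="tokens",
                timestamp=datetime.now(UTC),
                attributes={
                    "model_id": fully_qualified_model_id,
                    "provider_id": provider_id,
                },
            )
            for name, value in [
                ("prompt_tokens", prompt_tokens),
                ("completion_tokens", completion_tokens),
                ("total_tokens", total_tokens),
            ]
        ]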
@@ -158,7 +160,9 @@ async def _compute_and_log_token_usage(
         total_tokens: int,
         model: Model,
     ) -> list[MetricInResponse]:
-        metrics = self._construct_metrics(prompt_tokens, completion_tokens, total_tokens, model)
+        metrics = self._construct_metrics(
+            prompt_tokens, completion_tokens, total_tokens, model.model_id, model.provider_id
+        )
         if self.telemetry_enabled:
             for metric in metrics:
                 enqueue_event(metric)
@@ -178,14 +182,25 @@ async def _count_tokens(
         encoded = self.formatter.encode_content(messages)
         return len(encoded.tokens) if encoded and encoded.tokens else 0

-    async def _get_model(self, model_id: str, expected_model_type: str) -> Model:
-        """takes a model id and gets model after ensuring that it is accessible and of the correct type"""
-        model = await self.routing_table.get_model(model_id)
-        if model is None:
+    async def _get_model_provider(self, model_id: str, expected_model_type: str) -> tuple[Inference, str]:
+        model = await self.routing_table.get_object_by_identifier("model", model_id)
+        if model:
+            if model.model_type != expected_model_type:
+                raise ModelTypeError(model_id, model.model_type, expected_model_type)
+
+            provider = await self.routing_table.get_provider_impl(model.identifier)
+            return provider, model.provider_resource_id
+
+        splits = model_id.split("/", maxsplit=1)
+        if len(splits) != 2:
+            raise ModelNotFoundError(model_id)
+
+        provider_id, provider_resource_id = splits
+        if provider_id not in self.routing_table.impls_by_provider_id:
+            logger.warning(f"Provider {provider_id} not found for model {model_id}")
             raise ModelNotFoundError(model_id)
-        if model.model_type != expected_model_type:
-            raise ModelTypeError(model_id, model.model_type, expected_model_type)
-        return model
+
+        return self.routing_table.impls_by_provider_id[provider_id], provider_resource_id

     async def rerank(
         self,
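
The fallback branch is the interesting part: ids not registered in the routing table are treated as fully qualified, i.e. "provider_id/model_name". A runnable sketch of just that parsing (model names invented for illustration):

    def split_fully_qualified(model_id: str) -> tuple[str, str]:
        # Mirrors the fallback in _get_model_provider: one split only, so any
        # further slashes stay inside the provider-scoped model name.
        splits = model_id.split("/", maxsplit=1)
        if len(splits) != 2:
            raise ValueError(f"not a fully qualified model id: {model_id}")
        provider_id, provider_resource_id = splits
        return provider_id, provider_resource_id

    assert split_fully_qualified("openai/gpt-4o-mini") == ("openai", "gpt-4o-mini")
    assert split_fully_qualified("together/meta-llama/Llama-3.3-70B") == (
        "together",
        "meta-llama/Llama-3.3-70B",
    )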
@@ -195,14 +210,8 @@ async def rerank(
         max_num_results: int | None = None,
     ) -> RerankResponse:
         logger.debug(f"InferenceRouter.rerank: {model}")
-        model_obj = await self._get_model(model, ModelType.rerank)
-        provider = await self.routing_table.get_provider_impl(model_obj.identifier)
-        return await provider.rerank(
-            model=model_obj.identifier,
-            query=query,
-            items=items,
-            max_num_results=max_num_results,
-        )
+        provider, provider_resource_id = await self._get_model_provider(model, ModelType.rerank)
+        return await provider.rerank(provider_resource_id, query, items, max_num_results)

     async def openai_completion(
         self,
@@ -211,24 +220,24 @@ async def openai_completion(
         logger.debug(
             f"InferenceRouter.openai_completion: model={params.model}, stream={params.stream}, prompt={params.prompt}",
         )
-        model_obj = await self._get_model(params.model, ModelType.llm)
-
-        # Update params with the resolved model identifier
-        params.model = model_obj.identifier
+        request_model_id = params.model
+        provider, provider_resource_id = await self._get_model_provider(params.model, ModelType.llm)
+        params.model = provider_resource_id

-        provider = await self.routing_table.get_provider_impl(model_obj.identifier)
         if params.stream:
             return await provider.openai_completion(params)
        # TODO: Metrics do NOT work with openai_completion stream=True due to the fact
        # that we do not return an AsyncIterator, our tests expect a stream of chunks we cannot intercept currently.

         response = await provider.openai_completion(params)
+        response.model = request_model_id
         if self.telemetry_enabled:
             metrics = self._construct_metrics(
                 prompt_tokens=response.usage.prompt_tokens,
                 completion_tokens=response.usage.completion_tokens,
                 total_tokens=response.usage.total_tokens,
-                model=model_obj,
+                fully_qualified_model_id=request_model_id,
+                provider_id=provider.__provider_id__,
             )
             for metric in metrics:
                 enqueue_event(metric)
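
The request_model_id bookkeeping implements a round trip: the provider only ever sees its own resource id, while the client gets back exactly the id it sent. A standalone sketch of the flow (the types and names are stand-ins, not from this PR):

    from dataclasses import dataclass

    @dataclass
    class FakeResponse:  # stand-in for an OpenAI-style completion response
        model: str
        text: str

    def route_completion(request_model_id: str) -> FakeResponse:
        provider_id, provider_resource_id = request_model_id.split("/", maxsplit=1)
        # The provider echoes the provider-scoped id in response.model ...
        response = FakeResponse(model=provider_resource_id, text="...")
        # ... so the router restores the client-facing id before returning.
        response.model = request_model_id
        return response

    assert route_completion("openai/gpt-4o-mini").model == "openai/gpt-4o-mini"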
@@ -246,7 +255,9 @@ async def openai_chat_completion(
         logger.debug(
             f"InferenceRouter.openai_chat_completion: model={params.model}, stream={params.stream}, messages={params.messages}",
         )
-        model_obj = await self._get_model(params.model, ModelType.llm)
+        request_model_id = params.model
+        provider, provider_resource_id = await self._get_model_provider(params.model, ModelType.llm)
+        params.model = provider_resource_id

         # Use the OpenAI client for a bit of extra input validation without
         # exposing the OpenAI client itself as part of our API surface
@@ -264,22 +275,20 @@ async def openai_chat_completion(
             params.tool_choice = None
             params.tools = None

-        # Update params with the resolved model identifier
-        params.model = model_obj.identifier
-
-        provider = await self.routing_table.get_provider_impl(model_obj.identifier)
         if params.stream:
             response_stream = await provider.openai_chat_completion(params)

             # For streaming, the provider returns AsyncIterator[OpenAIChatCompletionChunk]
             # We need to add metrics to each chunk and store the final completion
             return self.stream_tokens_and_compute_metrics_openai_chat(
                 response=response_stream,
-                model=model_obj,
+                fully_qualified_model_id=request_model_id,
+                provider_id=provider.__provider_id__,
                 messages=params.messages,
             )

         response = await self._nonstream_openai_chat_completion(provider, params)
+        response.model = request_model_id

         # Store the response with the ID that will be returned to the client
         if self.store:
@@ -290,7 +299,8 @@ async def openai_chat_completion(
290299 prompt_tokens = response .usage .prompt_tokens ,
291300 completion_tokens = response .usage .completion_tokens ,
292301 total_tokens = response .usage .total_tokens ,
293- model = model_obj ,
302+ fully_qualified_model_id = request_model_id ,
303+ provider_id = provider .__provider_id__ ,
294304 )
295305 for metric in metrics :
296306 enqueue_event (metric )
@@ -307,13 +317,13 @@ async def openai_embeddings(
         logger.debug(
             f"InferenceRouter.openai_embeddings: model={params.model}, input_type={type(params.input)}, encoding_format={params.encoding_format}, dimensions={params.dimensions}",
         )
-        model_obj = await self._get_model(params.model, ModelType.embedding)
-
-        # Update model to use resolved identifier
-        params.model = model_obj.identifier
+        request_model_id = params.model
+        provider, provider_resource_id = await self._get_model_provider(params.model, ModelType.embedding)
+        params.model = provider_resource_id

-        provider = await self.routing_table.get_provider_impl(model_obj.identifier)
-        return await provider.openai_embeddings(params)
+        response = await provider.openai_embeddings(params)
+        response.model = request_model_id
+        return response

     async def list_chat_completions(
         self,
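
The same resolve/rewrite/restore idiom now appears in openai_completion, openai_chat_completion, and openai_embeddings. If it spreads further, a small helper could capture it; a hypothetical sketch, not part of this PR:

    from collections.abc import Awaitable, Callable
    from typing import Any

    async def restore_model_id(
        request_model_id: str, call: Callable[[], Awaitable[Any]]
    ) -> Any:
        # Run the provider call, then put the client's fully qualified id
        # back on the response before it leaves the router.
        response = await call()
        response.model = request_model_id
        return response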
@@ -369,7 +379,8 @@ async def stream_tokens_and_compute_metrics(
         self,
         response,
         prompt_tokens,
-        model,
+        fully_qualified_model_id: str,
+        provider_id: str,
         tool_prompt_format: ToolPromptFormat | None = None,
     ) -> AsyncGenerator[ChatCompletionResponseStreamChunk, None] | AsyncGenerator[CompletionResponseStreamChunk, None]:
         completion_text = ""
@@ -407,7 +418,8 @@ async def stream_tokens_and_compute_metrics(
                     prompt_tokens=prompt_tokens,
                     completion_tokens=completion_tokens,
                     total_tokens=total_tokens,
-                    model=model,
+                    fully_qualified_model_id=fully_qualified_model_id,
+                    provider_id=provider_id,
                 )
                 for metric in completion_metrics:
                     if metric.metric in [
@@ -427,7 +439,8 @@ async def stream_tokens_and_compute_metrics(
                 prompt_tokens or 0,
                 completion_tokens or 0,
                 total_tokens,
-                model,
+                fully_qualified_model_id=fully_qualified_model_id,
+                provider_id=provider_id,
             )
             async_metrics = [
                 MetricInResponse(metric=metric.metric, value=metric.value) for metric in completion_metrics
@@ -439,7 +452,8 @@ async def count_tokens_and_compute_metrics(
         self,
         response: ChatCompletionResponse | CompletionResponse,
         prompt_tokens,
-        model,
+        fully_qualified_model_id: str,
+        provider_id: str,
         tool_prompt_format: ToolPromptFormat | None = None,
     ):
         if isinstance(response, ChatCompletionResponse):
@@ -456,7 +470,8 @@ async def count_tokens_and_compute_metrics(
                 prompt_tokens=prompt_tokens,
                 completion_tokens=completion_tokens,
                 total_tokens=total_tokens,
-                model=model,
+                fully_qualified_model_id=fully_qualified_model_id,
+                provider_id=provider_id,
             )
             for metric in completion_metrics:
                 if metric.metric in ["completion_tokens", "total_tokens"]:  # Only log completion and total tokens
@@ -470,14 +485,16 @@ async def count_tokens_and_compute_metrics(
             prompt_tokens or 0,
             completion_tokens or 0,
             total_tokens,
-            model,
+            fully_qualified_model_id=fully_qualified_model_id,
+            provider_id=provider_id,
         )
         return [MetricInResponse(metric=metric.metric, value=metric.value) for metric in metrics]

     async def stream_tokens_and_compute_metrics_openai_chat(
         self,
         response: AsyncIterator[OpenAIChatCompletionChunk],
-        model: Model,
+        fully_qualified_model_id: str,
+        provider_id: str,
         messages: list[OpenAIMessageParam] | None = None,
     ) -> AsyncIterator[OpenAIChatCompletionChunk]:
         """Stream OpenAI chat completion chunks, compute metrics, and store the final completion."""
@@ -497,6 +514,8 @@ async def stream_tokens_and_compute_metrics_openai_chat(
             if created is None and chunk.created:
                 created = chunk.created

+            chunk.model = fully_qualified_model_id
+
             # Accumulate choice data for final assembly
             if chunk.choices:
                 for choice_delta in chunk.choices:
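
In the streaming path, the model field has to be rewritten on every chunk, not just on the assembled final response. A minimal runnable sketch of that pass-through, using a stand-in chunk type rather than OpenAIChatCompletionChunk:

    import asyncio
    from collections.abc import AsyncIterator
    from dataclasses import dataclass

    @dataclass
    class FakeChunk:  # stand-in for OpenAIChatCompletionChunk
        model: str
        delta: str

    async def rewrite_model(
        chunks: AsyncIterator[FakeChunk], fully_qualified_model_id: str
    ) -> AsyncIterator[FakeChunk]:
        async for chunk in chunks:
            chunk.model = fully_qualified_model_id  # every chunk gets the client id
            yield chunk

    async def demo() -> None:
        async def provider_stream() -> AsyncIterator[FakeChunk]:
            for delta in ("Hel", "lo"):
                yield FakeChunk(model="gpt-4o-mini", delta=delta)

        async for chunk in rewrite_model(provider_stream(), "openai/gpt-4o-mini"):
            print(chunk.model, chunk.delta)

    asyncio.run(demo())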
@@ -553,7 +572,8 @@ async def stream_tokens_and_compute_metrics_openai_chat(
                     prompt_tokens=chunk.usage.prompt_tokens,
                     completion_tokens=chunk.usage.completion_tokens,
                     total_tokens=chunk.usage.total_tokens,
-                    model=model,
+                    fully_qualified_model_id=fully_qualified_model_id,
+                    provider_id=provider_id,
                 )
                 for metric in metrics:
                     enqueue_event(metric)
@@ -601,7 +621,7 @@ async def stream_tokens_and_compute_metrics_openai_chat(
                 id=id,
                 choices=assembled_choices,
                 created=created or int(time.time()),
-                model=model.identifier,
+                model=fully_qualified_model_id,
                 object="chat.completion",
             )
             logger.debug(f"InferenceRouter.completion_response: {final_response}")