@@ -4,9 +4,7 @@
 import httpx
 import structlog
 from fastapi import Header, HTTPException, Request
-from fastapi.responses import JSONResponse

-from codegate.config import Config
 from codegate.pipeline.factory import PipelineFactory
 from codegate.providers.base import BaseProvider, ModelFetchError
 from codegate.providers.litellmshim import LiteLLmShim, sse_stream_generator
@@ -19,11 +17,6 @@ def __init__(
         pipeline_factory: PipelineFactory,
     ):
         completion_handler = LiteLLmShim(stream_generator=sse_stream_generator)
-        config = Config.get_config()
-        if config is not None:
-            provided_urls = config.provider_urls
-            self.lm_studio_url = provided_urls.get("lm_studio", "http://localhost:11434/")
-
         super().__init__(
             OpenAIInputNormalizer(),
             OpenAIOutputNormalizer(),
@@ -39,8 +32,6 @@ def models(self, endpoint: str = None, api_key: str = None) -> List[str]:
         headers = {}
         if api_key:
             headers["Authorization"] = f"Bearer {api_key}"
-        if not endpoint:
-            endpoint = "https://api.openai.com"

         resp = httpx.get(f"{endpoint}/v1/models", headers=headers)

@@ -51,19 +42,32 @@ def models(self, endpoint: str = None, api_key: str = None) -> List[str]:

         return [model["id"] for model in jsonresp.get("data", [])]

+    async def process_request(self, data: dict, api_key: str, request: Request):
+        """
+        Process the request and return the completion stream.
+        """
+        is_fim_request = self._is_fim_request(request, data)
+        try:
+            stream = await self.complete(data, api_key, is_fim_request=is_fim_request)
+        except Exception as e:
+            # check if the exception carries a status code
+            if hasattr(e, "status_code"):
+                logger = structlog.get_logger("codegate")
+                logger.error("Error in OpenAIProvider completion", error=str(e))
+
+                raise HTTPException(status_code=e.status_code, detail=str(e))  # type: ignore
+            else:
+                # otherwise, re-raise the original exception
+                raise e
+        return self._completion_handler.create_response(stream)
+
     def _setup_routes(self):
         """
         Sets up the /chat/completions route for the provider as expected by the
         OpenAI API. Extracts the API key from the "Authorization" header and
         passes it to the completion handler.
         """

-        @self.router.get(f"/{self.provider_route_name}/models")
-        @self.router.get(f"/{self.provider_route_name}/v1/models")
-        async def get_models():
-            # dummy method for lm studio
-            return JSONResponse(status_code=200, content=[])
-
         @self.router.post(f"/{self.provider_route_name}/chat/completions")
         @self.router.post(f"/{self.provider_route_name}/completions")
         @self.router.post(f"/{self.provider_route_name}/v1/chat/completions")
@@ -78,20 +82,4 @@ async def create_completion(
             body = await request.body()
             data = json.loads(body)

-            # if model starts with lm_studio, propagate it
-            if data.get("model", "").startswith("lm_studio"):
-                data["base_url"] = self.lm_studio_url + "/v1/"
-            is_fim_request = self._is_fim_request(request, data)
-            try:
-                stream = await self.complete(data, api_key, is_fim_request=is_fim_request)
-            except Exception as e:
-                # check if we have an status code there
-                if hasattr(e, "status_code"):
-                    logger = structlog.get_logger("codegate")
-                    logger.error("Error in OpenAIProvider completion", error=str(e))
-
-                    raise HTTPException(status_code=e.status_code, detail=str(e))  # type: ignore
-                else:
-                    # just continue raising the exception
-                    raise e
-            return self._completion_handler.create_response(stream)
+            return await self.process_request(data, api_key, request)
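
Taken together, the diff extracts the shared completion and error-translation flow out of the route closure into a reusable `process_request` method, so each route body reduces to parse-and-delegate. Below is a minimal standalone sketch of the resulting pattern; the class and route names are hypothetical stand-ins, and the real `complete`, `_is_fim_request`, and `_completion_handler` live on codegate's `BaseProvider`, so this only illustrates the shape of the refactor:

```python
import json

import structlog
from fastapi import APIRouter, FastAPI, HTTPException, Request

logger = structlog.get_logger("demo")


class DemoProvider:
    """Hypothetical provider illustrating the extracted process_request flow."""

    def __init__(self) -> None:
        self.router = APIRouter()
        self.provider_route_name = "demo"
        self._setup_routes()

    def _is_fim_request(self, request: Request, data: dict) -> bool:
        # Placeholder: the real check inspects the request path and body.
        return False

    async def complete(self, data: dict, api_key: str | None, is_fim_request: bool):
        # Placeholder for the provider-specific completion call.
        return {"echo": data.get("model")}

    async def process_request(self, data: dict, api_key: str | None, request: Request):
        """Shared flow: run the completion and translate upstream errors."""
        is_fim_request = self._is_fim_request(request, data)
        try:
            result = await self.complete(data, api_key, is_fim_request=is_fim_request)
        except Exception as e:
            # Exceptions that carry a status code become HTTP errors;
            # everything else propagates unchanged.
            if hasattr(e, "status_code"):
                logger.error("Error in completion", error=str(e))
                raise HTTPException(status_code=e.status_code, detail=str(e))
            raise
        return result

    def _setup_routes(self) -> None:
        @self.router.post(f"/{self.provider_route_name}/chat/completions")
        async def create_completion(request: Request):
            body = await request.body()
            data = json.loads(body)
            # The route body is now a one-liner; all shared logic lives above.
            return await self.process_request(data, None, request)


app = FastAPI()
app.include_router(DemoProvider().router)
```

Note also that the diff removes the `https://api.openai.com` fallback in `models()`, so callers of that method must now pass the endpoint explicitly.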