Commit 6c9b508

Add LM Studio provider (#837)
* Add LM Studio provider

  We were using the OpenAI provider to interface with LM Studio, since the two are very similar. For muxing, however, we need to clearly distinguish which provider a request should be routed to, so a dedicated provider class makes that disambiguation easier.

* Delete conditional to add LM Studio URL
1 parent d24c989 commit 6c9b508

File tree (4 files changed, +85 -34 lines):

  src/codegate/providers/lm_studio/provider.py
  src/codegate/providers/openai/provider.py
  src/codegate/server.py
  tests/test_server.py
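The routing rationale in the commit message can be illustrated with a small sketch. Nothing below is codegate code; the mapping and helper names are hypothetical, and only the route prefixes "openai" and "lm_studio" come from the diffs that follow. The point is that once LM Studio has its own provider class, the URL prefix alone identifies the upstream, so a mux no longer has to inspect the model name.

# Hypothetical sketch (not codegate code) of the disambiguation the commit
# message describes: the route prefix alone identifies the upstream provider.
ROUTE_PREFIX_TO_PROVIDER = {
    "openai": "OpenAIProvider",
    "lm_studio": "LmStudioProvider",  # prefix introduced by this commit
}


def provider_for_path(path: str) -> str:
    """Map a request path such as '/lm_studio/v1/chat/completions' to a provider."""
    prefix = path.lstrip("/").split("/", 1)[0]
    return ROUTE_PREFIX_TO_PROVIDER[prefix]


print(provider_for_path("/lm_studio/v1/chat/completions"))  # -> LmStudioProvider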
src/codegate/providers/lm_studio/provider.py (new file)

@@ -0,0 +1,56 @@
+import json
+
+from fastapi import Header, HTTPException, Request
+from fastapi.responses import JSONResponse
+
+from codegate.config import Config
+from codegate.pipeline.factory import PipelineFactory
+from codegate.providers.openai.provider import OpenAIProvider
+
+
+class LmStudioProvider(OpenAIProvider):
+    def __init__(
+        self,
+        pipeline_factory: PipelineFactory,
+    ):
+        config = Config.get_config()
+        if config is not None:
+            provided_urls = config.provider_urls
+            self.lm_studio_url = provided_urls.get("lm_studio", "http://localhost:11434/")
+
+        super().__init__(pipeline_factory)
+
+    @property
+    def provider_route_name(self) -> str:
+        return "lm_studio"
+
+    def _setup_routes(self):
+        """
+        Sets up the /chat/completions route for the provider as expected by the
+        LM Studio API. Extracts the API key from the "Authorization" header and
+        passes it to the completion handler.
+        """
+
+        @self.router.get(f"/{self.provider_route_name}/models")
+        @self.router.get(f"/{self.provider_route_name}/v1/models")
+        async def get_models():
+            # dummy method for lm studio
+            return JSONResponse(status_code=200, content=[])
+
+        @self.router.post(f"/{self.provider_route_name}/chat/completions")
+        @self.router.post(f"/{self.provider_route_name}/completions")
+        @self.router.post(f"/{self.provider_route_name}/v1/chat/completions")
+        async def create_completion(
+            request: Request,
+            authorization: str = Header(..., description="Bearer token"),
+        ):
+            if not authorization.startswith("Bearer "):
+                raise HTTPException(status_code=401, detail="Invalid authorization header")
+
+            api_key = authorization.split(" ")[1]
+            body = await request.body()
+            data = json.loads(body)
+
+            data["base_url"] = self.lm_studio_url + "/v1/"
+
+            return await self.process_request(data, api_key, request)
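For reference, a request against the new route could look like the sketch below. The codegate host and port (localhost:8989), the placeholder token, and the model name are assumptions for illustration; only the /lm_studio/v1/chat/completions path and the required "Bearer ..." Authorization header come from the handler above.

import httpx

# Minimal client sketch against the new LM Studio route (host, port, token,
# and model name are assumptions; LM Studio itself may not validate the key,
# but the handler above rejects requests without a "Bearer ..." header).
resp = httpx.post(
    "http://localhost:8989/lm_studio/v1/chat/completions",
    headers={"Authorization": "Bearer unused"},
    json={
        "model": "qwen2.5-coder-7b-instruct",  # whatever model LM Studio has loaded
        "messages": [{"role": "user", "content": "Say hello"}],
        "stream": False,
    },
    timeout=60.0,
)
print(resp.status_code, resp.text[:200])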

src/codegate/providers/openai/provider.py

+20 -32

@@ -4,9 +4,7 @@
 import httpx
 import structlog
 from fastapi import Header, HTTPException, Request
-from fastapi.responses import JSONResponse

-from codegate.config import Config
 from codegate.pipeline.factory import PipelineFactory
 from codegate.providers.base import BaseProvider, ModelFetchError
 from codegate.providers.litellmshim import LiteLLmShim, sse_stream_generator
@@ -19,11 +17,6 @@ def __init__(
         pipeline_factory: PipelineFactory,
     ):
         completion_handler = LiteLLmShim(stream_generator=sse_stream_generator)
-        config = Config.get_config()
-        if config is not None:
-            provided_urls = config.provider_urls
-            self.lm_studio_url = provided_urls.get("lm_studio", "http://localhost:11434/")
-
         super().__init__(
             OpenAIInputNormalizer(),
             OpenAIOutputNormalizer(),
@@ -39,8 +32,6 @@ def models(self, endpoint: str = None, api_key: str = None) -> List[str]:
         headers = {}
         if api_key:
             headers["Authorization"] = f"Bearer {api_key}"
-        if not endpoint:
-            endpoint = "https://api.openai.com"

         resp = httpx.get(f"{endpoint}/v1/models", headers=headers)

@@ -51,19 +42,32 @@ def models(self, endpoint: str = None, api_key: str = None) -> List[str]:

         return [model["id"] for model in jsonresp.get("data", [])]

+    async def process_request(self, data: dict, api_key: str, request: Request):
+        """
+        Process the request and return the completion stream
+        """
+        is_fim_request = self._is_fim_request(request, data)
+        try:
+            stream = await self.complete(data, api_key, is_fim_request=is_fim_request)
+        except Exception as e:
+            #  check if we have an status code there
+            if hasattr(e, "status_code"):
+                logger = structlog.get_logger("codegate")
+                logger.error("Error in OpenAIProvider completion", error=str(e))
+
+                raise HTTPException(status_code=e.status_code, detail=str(e))  # type: ignore
+            else:
+                # just continue raising the exception
+                raise e
+        return self._completion_handler.create_response(stream)
+
     def _setup_routes(self):
         """
         Sets up the /chat/completions route for the provider as expected by the
         OpenAI API. Extracts the API key from the "Authorization" header and
         passes it to the completion handler.
         """

-        @self.router.get(f"/{self.provider_route_name}/models")
-        @self.router.get(f"/{self.provider_route_name}/v1/models")
-        async def get_models():
-            # dummy method for lm studio
-            return JSONResponse(status_code=200, content=[])
-
         @self.router.post(f"/{self.provider_route_name}/chat/completions")
         @self.router.post(f"/{self.provider_route_name}/completions")
         @self.router.post(f"/{self.provider_route_name}/v1/chat/completions")
@@ -78,20 +82,4 @@ async def create_completion(
             body = await request.body()
             data = json.loads(body)

-            # if model starts with lm_studio, propagate it
-            if data.get("model", "").startswith("lm_studio"):
-                data["base_url"] = self.lm_studio_url + "/v1/"
-            is_fim_request = self._is_fim_request(request, data)
-            try:
-                stream = await self.complete(data, api_key, is_fim_request=is_fim_request)
-            except Exception as e:
-                #  check if we have an status code there
-                if hasattr(e, "status_code"):
-                    logger = structlog.get_logger("codegate")
-                    logger.error("Error in OpenAIProvider completion", error=str(e))
-
-                    raise HTTPException(status_code=e.status_code, detail=str(e))  # type: ignore
-                else:
-                    # just continue raising the exception
-                    raise e
-            return self._completion_handler.create_response(stream)
+            return await self.process_request(data, api_key, request)
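The net effect of this hunk is that FIM detection, the completion call, and error translation now live in a single process_request helper, so a route handler only has to prepare the request body before delegating. A hypothetical sketch, not part of this commit, of how another OpenAI-compatible backend could reuse the helper (the route name and upstream URL below are made up):

import json

from fastapi import Header, Request

from codegate.providers.openai.provider import OpenAIProvider


# Hypothetical subclass illustrating the reuse pattern process_request enables:
# the handler only prepares the body; completion and error handling are inherited.
class MyLocalProvider(OpenAIProvider):
    @property
    def provider_route_name(self) -> str:
        return "my_local"

    def _setup_routes(self):
        @self.router.post(f"/{self.provider_route_name}/v1/chat/completions")
        async def create_completion(
            request: Request,
            authorization: str = Header(..., description="Bearer token"),
        ):
            api_key = authorization.split(" ")[1]
            data = json.loads(await request.body())
            data["base_url"] = "http://localhost:8080/v1/"  # assumed upstream
            return await self.process_request(data, api_key, request)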

src/codegate/server.py

+7

@@ -13,6 +13,7 @@
 from codegate.pipeline.factory import PipelineFactory
 from codegate.providers.anthropic.provider import AnthropicProvider
 from codegate.providers.llamacpp.provider import LlamaCppProvider
+from codegate.providers.lm_studio.provider import LmStudioProvider
 from codegate.providers.ollama.provider import OllamaProvider
 from codegate.providers.openai.provider import OpenAIProvider
 from codegate.providers.registry import ProviderRegistry, get_provider_registry
@@ -96,6 +97,12 @@ async def log_user_agent(request: Request, call_next):
             pipeline_factory,
         ),
     )
+    registry.add_provider(
+        "lm_studio",
+        LmStudioProvider(
+            pipeline_factory,
+        ),
+    )

     # Create and add system routes
     system_router = APIRouter(tags=["System"])

tests/test_server.py

+2 -2

@@ -108,8 +108,8 @@ def test_provider_registration(mock_registry, mock_secrets_mgr, mock_pipeline_fa
     # Verify all providers were registered
     registry_instance = mock_registry.return_value
     assert (
-        registry_instance.add_provider.call_count == 5
-    )  # openai, anthropic, llamacpp, vllm, ollama
+        registry_instance.add_provider.call_count == 6
+    )  # openai, anthropic, llamacpp, vllm, ollama, lm_studio

     # Verify specific providers were registered
     provider_names = [call.args[0] for call in registry_instance.add_provider.call_args_list]
