Skip to content
This repository was archived by the owner on Sep 30, 2024. It is now read-only.

Commit 1b1229c

Browse files
authored
feat/API: implement /models and /models/{modelId} using TypeSpec (#64421)
Fixes CODY-3085 Fixes CODY-3086 Previously, there was no way for OpenAI clients to list the available models on Sourcegraph or query metadata about a given model ID ("model ref" using our internal terminology). This PR fixes that problem AND additionally adds infrastructure to auto-generate Go models from a TypeSpec specification. [TypeSpec](https://typespec.io/) is an IDL to document REST APIs, created by Microsoft. Historically, the Go code in this repository has been the single source of truth about what exact JSON structures are expected in HTTP request/response pairs in our REST endpoints. This new TypeSpec infrastructure allows us to document these shapes at a higher abstraction level, which has several benefits including automatic OpenAPI generation, which we can use to generate docs on sourcegraph.com/docs or automatically generate client bindings in TypeScript (among many other use-cases). I am planning to write an RFC to propose we start using TypeSpec for new REST endpoints going forward. If the RFC is not approved then we can just delete the new `tools/typespec_codegen` directory and keep the generated code in the repo. It won't be a big difference in the end compared our current manual approach of writing Go structs for HTTP APIs. <!-- PR description tips: https://www.notion.so/sourcegraph/Write-a-good-pull-request-description-610a7fd3e613496eb76f450db5a49b6e --> ## Test plan See test cases. I additionally wrote a basic python script with the official OpenAI client to test that it works with this endpoint. First, I ran `sg start minimal`. 
Then I wrote this script ```py import os from openai import OpenAI from dotenv import load_dotenv import httpx load_dotenv() openai = OpenAI( # base_url="https://api.openai.com/v1", # api_key=os.getenv("OPENAI_API_KEY"), base_url="https://sourcegraph.test:3443/api/v1", api_key=os.getenv("SRC_ACCESS_TOKEN"), http_client=httpx.Client(verify=False) ) def main(): response = openai.models.list() for model in response.data: print(model.id) if __name__ == "__main__": main() ``` Finally, I ran ``` ❯ python3 models.py anthropic::unknown::claude-3-haiku-20240307 anthropic::unknown::claude-3-sonnet-20240229 fireworks::unknown::starcoder ``` <!-- REQUIRED; info at https://docs-legacy.sourcegraph.com/dev/background-information/testing_principles --> ## Changelog * New `GET /.api/llm/models` and `GET /.api/llm/models/{modelId}` REST API endpoints to list available LLM models on the instance and to get information about a given model. These endpoints are compatible with the `/models` and `/models/{modelId}` endpoints from OpenAI. <!-- OPTIONAL; info at https://www.notion.so/sourcegraph/Writing-a-changelog-entry-dd997f411d524caabf0d8d38a24a878c -->
1 parent c222523 commit 1b1229c

30 files changed

+1088
-77
lines changed

.bazelignore

+1
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@ client/web/node_modules
2626
client/web-sveltekit/node_modules
2727
client/wildcard/node_modules
2828
internal/appliance/frontend/maintenance/node_modules
29+
internal/openapi/node_modules
2930

3031
cmd/symbols/internal/squirrel/test_repos/starlark
3132

cmd/frontend/internal/llmapi/BUILD.bazel

+9-4
Original file line numberDiff line numberDiff line change
@@ -4,15 +4,17 @@ load("//dev:go_defs.bzl", "go_test")
44
go_library(
55
name = "llmapi",
66
srcs = [
7-
"chat_completions_handler.go",
8-
"chat_completions_models.go",
7+
"handler_chat_completions.go",
8+
"handler_models.go",
9+
"handler_models_modelid.go",
910
"httpapi.go",
1011
],
1112
importpath = "github.com/sourcegraph/sourcegraph/cmd/frontend/internal/llmapi",
1213
visibility = ["//cmd/frontend:__subpackages__"],
1314
deps = [
1415
"//internal/completions/types",
1516
"//internal/modelconfig/types",
17+
"//internal/openapi/goapi",
1618
"//lib/errors",
1719
"@com_github_google_uuid//:uuid",
1820
"@com_github_gorilla_mux//:mux",
@@ -23,8 +25,9 @@ go_library(
2325
go_test(
2426
name = "llmapi_test",
2527
srcs = [
26-
"chat_completions_handler_test.go",
27-
"chat_completions_models_test.go",
28+
"handler_chat_completions_test.go",
29+
"handler_models_modelid_test.go",
30+
"handler_models_test.go",
2831
"utils_test.go",
2932
],
3033
data = glob(["golly-recordings/**"]),
@@ -33,8 +36,10 @@ go_test(
3336
"//internal/golly",
3437
"//internal/httpcli",
3538
"//internal/modelconfig/types",
39+
"//internal/openapi/goapi",
3640
"@com_github_gorilla_mux//:mux",
3741
"@com_github_hexops_autogold_v2//:autogold",
3842
"@com_github_stretchr_testify//assert",
43+
"@com_github_stretchr_testify//require",
3944
],
4045
)

cmd/frontend/internal/llmapi/chat_completions_handler.go renamed to cmd/frontend/internal/llmapi/handler_chat_completions.go

+12-19
Original file line numberDiff line numberDiff line change
@@ -11,13 +11,13 @@ import (
1111
"time"
1212

1313
"github.com/google/uuid"
14-
"github.com/sourcegraph/log"
1514
sglog "github.com/sourcegraph/log"
1615

1716
"github.com/sourcegraph/sourcegraph/lib/errors"
1817

1918
completions "github.com/sourcegraph/sourcegraph/internal/completions/types"
2019
types "github.com/sourcegraph/sourcegraph/internal/modelconfig/types"
20+
"github.com/sourcegraph/sourcegraph/internal/openapi/goapi"
2121
)
2222

2323
// chatCompletionsHandler implements the REST endpoint /chat/completions
@@ -34,12 +34,10 @@ type chatCompletionsHandler struct {
3434
GetModelConfig GetModelConfigurationFunc
3535
}
3636

37-
type GetModelConfigurationFunc func() (*types.ModelConfiguration, error)
38-
3937
var _ http.Handler = (*chatCompletionsHandler)(nil)
4038

4139
func (h *chatCompletionsHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
42-
var chatCompletionRequest CreateChatCompletionRequest
40+
var chatCompletionRequest goapi.CreateChatCompletionRequest
4341
body, err := io.ReadAll(r.Body)
4442
if err != nil {
4543
http.Error(w, fmt.Sprintf("io.ReadAll: %v", err), http.StatusInternalServerError)
@@ -78,19 +76,14 @@ func (h *chatCompletionsHandler) ServeHTTP(w http.ResponseWriter, r *http.Reques
7876

7977
chatCompletionResponse := transformToOpenAIResponse(sgResp, chatCompletionRequest)
8078

81-
w.Header().Set("Content-Type", "application/json")
82-
w.WriteHeader(http.StatusOK)
83-
if err = json.NewEncoder(w).Encode(chatCompletionResponse); err != nil {
84-
h.logger.Error("writing /chat/completions response body", log.Error(err))
85-
}
86-
79+
serveJSON(w, r, h.logger, chatCompletionResponse)
8780
}
8881

8982
// validateRequestedModel checks that are only use the modelref syntax
9083
// (${ProviderID}::${APIVersionID}::${ModelID}). If the user passes the old
9184
// syntax `${ProviderID}/${ModelID}`, then we try to return a helpful error
9285
// message suggesting to use the new modelref syntax.
93-
func validateRequestedModel(chatCompletionRequest CreateChatCompletionRequest, modelConfig *types.ModelConfiguration) string {
86+
func validateRequestedModel(chatCompletionRequest goapi.CreateChatCompletionRequest, modelConfig *types.ModelConfiguration) string {
9487
closestModelRef := ""
9588
for _, model := range modelConfig.Models {
9689
if string(model.ModelRef) == chatCompletionRequest.Model {
@@ -109,7 +102,7 @@ func validateRequestedModel(chatCompletionRequest CreateChatCompletionRequest, m
109102
return fmt.Sprintf("model %s is not supported%s", chatCompletionRequest.Model, didYouMean)
110103
}
111104

112-
func validateChatCompletionRequest(chatCompletionRequest CreateChatCompletionRequest) string {
105+
func validateChatCompletionRequest(chatCompletionRequest goapi.CreateChatCompletionRequest) string {
113106

114107
if chatCompletionRequest.N != nil && *chatCompletionRequest.N != 1 {
115108
return "n must be nil or 1"
@@ -148,7 +141,7 @@ func validateChatCompletionRequest(chatCompletionRequest CreateChatCompletionReq
148141
return ""
149142
}
150143

151-
func transformToSGRequest(openAIReq CreateChatCompletionRequest) completions.CodyCompletionRequestParameters {
144+
func transformToSGRequest(openAIReq goapi.CreateChatCompletionRequest) completions.CodyCompletionRequestParameters {
152145
maxTokens := 16 // Default in OpenAI openapi.yaml spec
153146
if openAIReq.MaxTokens != nil {
154147
maxTokens = *openAIReq.MaxTokens
@@ -178,7 +171,7 @@ func transformToSGRequest(openAIReq CreateChatCompletionRequest) completions.Cod
178171
}
179172
}
180173

181-
func transformMessages(messages []ChatCompletionRequestMessage) []completions.Message {
174+
func transformMessages(messages []goapi.ChatCompletionRequestMessage) []completions.Message {
182175
// Transform OpenAI messages to Sourcegraph format
183176
transformed := make([]completions.Message, len(messages))
184177
for i, msg := range messages {
@@ -248,23 +241,23 @@ func (h *chatCompletionsHandler) forwardToAPIHandler(sgReq completions.CodyCompl
248241
return &sgResp, nil
249242
}
250243

251-
func transformToOpenAIResponse(sgResp *completions.CompletionResponse, openAIReq CreateChatCompletionRequest) CreateChatCompletionResponse {
252-
return CreateChatCompletionResponse{
244+
func transformToOpenAIResponse(sgResp *completions.CompletionResponse, openAIReq goapi.CreateChatCompletionRequest) goapi.CreateChatCompletionResponse {
245+
return goapi.CreateChatCompletionResponse{
253246
ID: "chat-" + generateUUID(),
254247
Object: "chat.completion",
255248
Created: time.Now().Unix(),
256249
Model: openAIReq.Model,
257-
Choices: []ChatCompletionChoice{
250+
Choices: []goapi.ChatCompletionChoice{
258251
{
259252
Index: 0,
260-
Message: ChatCompletionResponseMessage{
253+
Message: goapi.ChatCompletionResponseMessage{
261254
Role: "assistant",
262255
Content: sgResp.Completion,
263256
},
264257
FinishReason: sgResp.StopReason,
265258
},
266259
},
267-
Usage: CompletionUsage{},
260+
Usage: goapi.CompletionUsage{},
268261
}
269262
}
270263

cmd/frontend/internal/llmapi/chat_completions_handler_test.go renamed to cmd/frontend/internal/llmapi/handler_chat_completions_test.go

+2-1
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@ import (
1010
"github.com/stretchr/testify/assert"
1111

1212
types "github.com/sourcegraph/sourcegraph/internal/modelconfig/types"
13+
"github.com/sourcegraph/sourcegraph/internal/openapi/goapi"
1314
)
1415

1516
func TestChatCompletionsHandler(t *testing.T) {
@@ -89,7 +90,7 @@ func TestChatCompletionsHandler(t *testing.T) {
8990
t.Fatalf("Expected status code %d, got %d. Body: %s%s", http.StatusOK, rr.Code, rr.Body.String(), extraMessage)
9091
}
9192

92-
var resp CreateChatCompletionResponse
93+
var resp goapi.CreateChatCompletionResponse
9394
responseBytes := rr.Body.Bytes()
9495
err := json.Unmarshal(responseBytes, &resp)
9596
if err != nil {
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,41 @@
package llmapi

import (
	"fmt"
	"net/http"

	sglog "github.com/sourcegraph/log"

	"github.com/sourcegraph/sourcegraph/internal/openapi/goapi"
)

// modelsHandler implements the REST endpoint GET /models, which lists the LLM
// models available on this Sourcegraph instance using the OpenAI-compatible
// response shape (`object: "list"` wrapping `object: "model"` entries).
type modelsHandler struct {
	logger sglog.Logger
	// GetModelConfig returns the instance's current model configuration.
	GetModelConfig GetModelConfigurationFunc
}

var _ http.Handler = (*modelsHandler)(nil)

func (m *modelsHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
	currentModelConfig, err := m.GetModelConfig()
	if err != nil {
		http.Error(w, fmt.Sprintf("modelConfigSvc.Get: %v", err), http.StatusInternalServerError)
		return
	}

	// Pre-size and start from a non-nil slice so that an instance with zero
	// configured models serializes as `"data": []` rather than `"data": null`
	// (a nil slice encodes as JSON null), matching OpenAI's /models contract.
	data := make([]goapi.Model, 0, len(currentModelConfig.Models))
	for _, model := range currentModelConfig.Models {
		data = append(data, goapi.Model{
			Object:  "model",
			Id:      string(model.ModelRef),
			OwnedBy: string(model.ModelRef.ProviderID()),
		})
	}

	response := goapi.ListModelsResponse{
		Object: "list",
		Data:   data,
	}

	serveJSON(w, r, m.logger, response)
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,48 @@
package llmapi

import (
	"fmt"
	"net/http"

	"github.com/gorilla/mux"
	sglog "github.com/sourcegraph/log"

	"github.com/sourcegraph/sourcegraph/internal/openapi/goapi"
)

// modelsModelIDHandler implements the REST endpoint GET /models/{modelId},
// returning OpenAI-compatible metadata for a single model identified by its
// model ref, or 404 when no configured model matches.
type modelsModelIDHandler struct {
	logger sglog.Logger
	// GetModelConfig returns the instance's current model configuration.
	GetModelConfig GetModelConfigurationFunc
}

var _ http.Handler = (*modelsModelIDHandler)(nil)

func (m *modelsModelIDHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
	modelId := mux.Vars(r)["modelId"]
	if modelId == "" {
		http.Error(w, "modelId is required", http.StatusBadRequest)
		return
	}

	currentModelConfig, err := m.GetModelConfig()
	if err != nil {
		http.Error(w, fmt.Sprintf("modelConfigSvc.Get: %v", err), http.StatusInternalServerError)
		return
	}

	// Linear scan is fine here: the model list is small (instance config).
	for _, model := range currentModelConfig.Models {
		if string(model.ModelRef) != modelId {
			continue
		}
		serveJSON(w, r, m.logger, goapi.Model{
			Object:  "model",
			Id:      string(model.ModelRef),
			OwnedBy: string(model.ModelRef.ProviderID()),
		})
		return
	}

	http.Error(w, "Model not found", http.StatusNotFound)
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,78 @@
package llmapi

import (
	"encoding/json"
	"net/http"
	"net/http/httptest"
	"testing"

	"github.com/stretchr/testify/assert"
	"github.com/stretchr/testify/require"

	types "github.com/sourcegraph/sourcegraph/internal/modelconfig/types"
	"github.com/sourcegraph/sourcegraph/internal/openapi/goapi"
)

// TestModelsModelIDHandler exercises GET /.api/llm/models/{modelId} against a
// mocked model configuration, covering both an existing model (200 + JSON
// body) and an unknown model (404).
func TestModelsModelIDHandler(t *testing.T) {
	mockModels := []types.Model{
		{
			ModelRef:    "anthropic::unknown::claude-3-sonnet-20240229",
			DisplayName: "Claude 3 Sonnet",
		},
		{
			ModelRef:    "openai::unknown::gpt-4",
			DisplayName: "GPT-4",
		},
	}

	c := newTest(t, func() (*types.ModelConfiguration, error) {
		return &types.ModelConfiguration{Models: mockModels}, nil
	})

	testCases := []struct {
		name           string
		modelID        string
		expectedStatus int
		// expectedModel is nil when no JSON body should be asserted.
		expectedModel *goapi.Model
	}{
		{
			name:           "Existing model",
			modelID:        "anthropic::unknown::claude-3-sonnet-20240229",
			expectedStatus: http.StatusOK,
			expectedModel: &goapi.Model{
				Id:      "anthropic::unknown::claude-3-sonnet-20240229",
				Object:  "model",
				OwnedBy: "anthropic",
			},
		},
		{
			name:           "Non-existent model",
			modelID:        "non-existent-model",
			expectedStatus: http.StatusNotFound,
			expectedModel:  nil,
		},
	}

	for _, tc := range testCases {
		t.Run(tc.name, func(t *testing.T) {
			req, err := http.NewRequest("GET", "/.api/llm/models/"+tc.modelID, nil)
			require.NoError(t, err)

			rr := httptest.NewRecorder()

			c.Handler.ServeHTTP(rr, req)

			assert.Equal(t, tc.expectedStatus, rr.Code)
			if rr.Code == http.StatusOK {
				assert.Equal(t, "application/json", rr.Header().Get("Content-Type"))
			}

			if tc.expectedModel != nil {
				var response goapi.Model
				err = json.Unmarshal(rr.Body.Bytes(), &response)
				require.NoError(t, err)
				// testify's Equal takes (t, expected, actual); the original
				// had the arguments swapped, which inverts failure messages.
				assert.Equal(t, *tc.expectedModel, response)
			}
		})
	}
}

0 commit comments

Comments
 (0)