From 2537a88f70beee274178bf322ac7b10122de0722 Mon Sep 17 00:00:00 2001 From: "maye.chy" Date: Thu, 15 May 2025 18:06:38 +0800 Subject: [PATCH 1/3] feat: register support model --- completion.go | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/completion.go b/completion.go index 21d4897c4..eb3831e32 100644 --- a/completion.go +++ b/completion.go @@ -166,6 +166,10 @@ func checkEndpointSupportsModel(endpoint, model string) bool { return !disabledModelsForEndpoints[endpoint][model] } +func RegisterSupportsModel(endpoint, model string) { + disabledModelsForEndpoints[endpoint][model] = true +} + func checkPromptType(prompt any) bool { _, isString := prompt.(string) _, isStringSlice := prompt.([]string) From 45bd3d8a02642b444999fb3f258b557fc14be74e Mon Sep 17 00:00:00 2001 From: "maye.chy" Date: Thu, 15 May 2025 18:15:05 +0800 Subject: [PATCH 2/3] add unit test --- completion_test.go | 24 ++++++++++++++++++++++++ 1 file changed, 24 insertions(+) diff --git a/completion_test.go b/completion_test.go index 27e2d150e..8e3ba14c5 100644 --- a/completion_test.go +++ b/completion_test.go @@ -300,3 +300,27 @@ func TestCompletionWithGPT4oModels(t *testing.T) { }) } } + +func TestRegisterSupportsModel(t *testing.T) { + type args struct { + endpoint string + model string + } + tests := []struct { + name string + args args + }{ + { + name: "Register model ", + args: args{ + endpoint: "/chat/completions", + model: "local-model-3.5", + }, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + openai.RegisterSupportsModel(tt.args.endpoint, tt.args.model) + }) + } +} From 959df8c294397a8e292d661ed6587c594ce7bc2c Mon Sep 17 00:00:00 2001 From: "maye.chy" Date: Thu, 15 May 2025 18:17:07 +0800 Subject: [PATCH 3/3] golint --- completion_test.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/completion_test.go b/completion_test.go index 8e3ba14c5..73fdd9519 100644 --- a/completion_test.go +++ b/completion_test.go @@ -319,7 +319,7 @@ func TestRegisterSupportsModel(t *testing.T) { }, } for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { + t.Run(tt.name, func(_ *testing.T) { openai.RegisterSupportsModel(tt.args.endpoint, tt.args.model) }) }