diff --git a/internal/api/server.go b/internal/api/server.go index 0325ca30ce..6737a9aea0 100644 --- a/internal/api/server.go +++ b/internal/api/server.go @@ -330,6 +330,8 @@ func (s *Server) setupRoutes() { { v1.GET("/models", s.unifiedModelsHandler(openaiHandlers, claudeCodeHandlers)) v1.POST("/chat/completions", openaiHandlers.ChatCompletions) + // Audio transcription uses a direct multipart passthrough path instead of the JSON translator stack. + v1.POST("/audio/transcriptions", openaiHandlers.AudioTranscriptions) v1.POST("/completions", openaiHandlers.Completions) v1.POST("/messages", claudeCodeHandlers.ClaudeMessages) v1.POST("/messages/count_tokens", claudeCodeHandlers.ClaudeCountTokens) @@ -353,6 +355,7 @@ func (s *Server) setupRoutes() { "message": "CLI Proxy API Server", "endpoints": []string{ "POST /v1/chat/completions", + "POST /v1/audio/transcriptions", "POST /v1/completions", "GET /v1/models", }, diff --git a/internal/api/server_test.go b/internal/api/server_test.go index f5c18aa167..86d0b2fdea 100644 --- a/internal/api/server_test.go +++ b/internal/api/server_test.go @@ -208,3 +208,18 @@ func TestDefaultRequestLoggerFactory_UsesResolvedLogDirectory(t *testing.T) { } } } + +func TestRootEndpointIncludesAudioTranscriptions(t *testing.T) { + server := newTestServer(t) + + req := httptest.NewRequest(http.MethodGet, "/", nil) + resp := httptest.NewRecorder() + server.engine.ServeHTTP(resp, req) + + if resp.Code != http.StatusOK { + t.Fatalf("status = %d, want %d; body=%s", resp.Code, http.StatusOK, resp.Body.String()) + } + if body := resp.Body.String(); !strings.Contains(body, "POST /v1/audio/transcriptions") { + t.Fatalf("response body missing audio transcription endpoint: %s", body) + } +} diff --git a/internal/registry/model_definitions.go b/internal/registry/model_definitions.go index 14e2852ea7..b59b857f69 100644 --- a/internal/registry/model_definitions.go +++ b/internal/registry/model_definitions.go @@ -13,6 +13,7 @@ type staticModelsJSON struct { Vertex []*ModelInfo `json:"vertex"` GeminiCLI []*ModelInfo `json:"gemini-cli"` AIStudio []*ModelInfo `json:"aistudio"` + CodexShared []*ModelInfo `json:"codex-shared,omitempty"` CodexFree []*ModelInfo `json:"codex-free"` CodexTeam []*ModelInfo `json:"codex-team"` CodexPlus []*ModelInfo `json:"codex-plus"` @@ -50,22 +51,22 @@ func GetAIStudioModels() []*ModelInfo { // GetCodexFreeModels returns model definitions for the Codex free plan tier. func GetCodexFreeModels() []*ModelInfo { - return cloneModelInfos(getModels().CodexFree) + return cloneModelInfos(mergeStaticModelSections(getModels().CodexShared, getModels().CodexFree)) } // GetCodexTeamModels returns model definitions for the Codex team plan tier. func GetCodexTeamModels() []*ModelInfo { - return cloneModelInfos(getModels().CodexTeam) + return cloneModelInfos(mergeStaticModelSections(getModels().CodexShared, getModels().CodexTeam)) } // GetCodexPlusModels returns model definitions for the Codex plus plan tier. func GetCodexPlusModels() []*ModelInfo { - return cloneModelInfos(getModels().CodexPlus) + return cloneModelInfos(mergeStaticModelSections(getModels().CodexShared, getModels().CodexPlus)) } // GetCodexProModels returns model definitions for the Codex pro plan tier. func GetCodexProModels() []*ModelInfo { - return cloneModelInfos(getModels().CodexPro) + return cloneModelInfos(mergeStaticModelSections(getModels().CodexShared, getModels().CodexPro)) } // GetQwenModels returns the standard Qwen model definitions. @@ -100,6 +101,31 @@ func cloneModelInfos(models []*ModelInfo) []*ModelInfo { return out } +func mergeStaticModelSections(sections ...[]*ModelInfo) []*ModelInfo { + var out []*ModelInfo + seen := make(map[string]struct{}) + for _, section := range sections { + for _, model := range section { + if model == nil { + continue + } + modelID := strings.TrimSpace(model.ID) + if modelID == "" { + continue + } + if _, exists := seen[modelID]; exists { + continue + } + seen[modelID] = struct{}{} + out = append(out, model) + } + } + if len(out) == 0 { + return nil + } + return out +} + // GetStaticModelDefinitionsByChannel returns static model definitions for a given channel/provider. // It returns nil when the channel is unknown. // @@ -156,7 +182,7 @@ func LookupStaticModelInfo(modelID string) *ModelInfo { data.Vertex, data.GeminiCLI, data.AIStudio, - data.CodexPro, + mergeStaticModelSections(data.CodexShared, data.CodexPro), data.Qwen, data.IFlow, data.Kimi, diff --git a/internal/registry/model_definitions_test.go b/internal/registry/model_definitions_test.go new file mode 100644 index 0000000000..4a8ed45482 --- /dev/null +++ b/internal/registry/model_definitions_test.go @@ -0,0 +1,131 @@ +package registry + +import ( + "strings" + "testing" +) + +func TestCodexTierModelsIncludeSharedTranscriptionModels(t *testing.T) { + testCases := []struct { + name string + models []*ModelInfo + }{ + {name: "free", models: GetCodexFreeModels()}, + {name: "team", models: GetCodexTeamModels()}, + {name: "plus", models: GetCodexPlusModels()}, + {name: "pro", models: GetCodexProModels()}, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + if !modelListContainsID(tc.models, "gpt-4o-mini-transcribe") { + t.Fatalf("%s tier missing gpt-4o-mini-transcribe", tc.name) + } + if !modelListContainsID(tc.models, "whisper-1") { + t.Fatalf("%s tier missing whisper-1", tc.name) + } + }) + } +} + +func TestLookupStaticModelInfoFindsSharedCodexModels(t *testing.T) { + for _, modelID := range []string{"gpt-4o-mini-transcribe", "whisper-1"} { + info := LookupStaticModelInfo(modelID) + if info == nil { + t.Fatalf("LookupStaticModelInfo(%q) = nil", modelID) + } + if info.ID != modelID { + t.Fatalf("LookupStaticModelInfo(%q).ID = %q, want %q", modelID, info.ID, modelID) + } + } +} + +func TestValidateModelsCatalogRejectsCodexSharedOverlap(t *testing.T) { + data := newStaticModelsCatalogFixture() + data.CodexShared = []*ModelInfo{testStaticModel("shared-transcribe")} + data.CodexFree = []*ModelInfo{testStaticModel("shared-transcribe")} + + err := validateModelsCatalog(data) + if err == nil || !strings.Contains(err.Error(), `codex-shared overlaps codex-free on model id "shared-transcribe"`) { + t.Fatalf("validateModelsCatalog() error = %v, want overlap error", err) + } +} + +func TestDetectChangedProvidersTreatsCodexSharedAndDuplicatedSchemasAsEquivalent(t *testing.T) { + oldData := newStaticModelsCatalogFixture() + oldData.CodexFree = []*ModelInfo{ + testStaticModel("free-only"), + testStaticModel("gpt-4o-mini-transcribe"), + testStaticModel("whisper-1"), + } + oldData.CodexTeam = []*ModelInfo{ + testStaticModel("team-only"), + testStaticModel("gpt-4o-mini-transcribe"), + testStaticModel("whisper-1"), + } + oldData.CodexPlus = []*ModelInfo{ + testStaticModel("plus-only"), + testStaticModel("gpt-4o-mini-transcribe"), + testStaticModel("whisper-1"), + } + oldData.CodexPro = []*ModelInfo{ + testStaticModel("pro-only"), + testStaticModel("gpt-4o-mini-transcribe"), + testStaticModel("whisper-1"), + } + + newData := newStaticModelsCatalogFixture() + newData.CodexShared = []*ModelInfo{ + testStaticModel("gpt-4o-mini-transcribe"), + testStaticModel("whisper-1"), + } + newData.CodexFree = []*ModelInfo{testStaticModel("free-only")} + newData.CodexTeam = []*ModelInfo{testStaticModel("team-only")} + newData.CodexPlus = []*ModelInfo{testStaticModel("plus-only")} + newData.CodexPro = []*ModelInfo{testStaticModel("pro-only")} + + changed := detectChangedProviders(oldData, newData) + if stringSliceContains(changed, "codex") { + t.Fatalf("detectChangedProviders() = %v, want codex unchanged", changed) + } +} + +func newStaticModelsCatalogFixture() *staticModelsJSON { + return &staticModelsJSON{ + Claude: []*ModelInfo{testStaticModel("claude-base")}, + Gemini: []*ModelInfo{testStaticModel("gemini-base")}, + Vertex: []*ModelInfo{testStaticModel("vertex-base")}, + GeminiCLI: []*ModelInfo{testStaticModel("gemini-cli-base")}, + AIStudio: []*ModelInfo{testStaticModel("aistudio-base")}, + CodexFree: []*ModelInfo{testStaticModel("free-base")}, + CodexTeam: []*ModelInfo{testStaticModel("team-base")}, + CodexPlus: []*ModelInfo{testStaticModel("plus-base")}, + CodexPro: []*ModelInfo{testStaticModel("pro-base")}, + Qwen: []*ModelInfo{testStaticModel("qwen-base")}, + IFlow: []*ModelInfo{testStaticModel("iflow-base")}, + Kimi: []*ModelInfo{testStaticModel("kimi-base")}, + Antigravity: []*ModelInfo{testStaticModel("antigravity-base")}, + } +} + +func testStaticModel(id string) *ModelInfo { + return &ModelInfo{ID: id, Object: "model"} +} + +func modelListContainsID(models []*ModelInfo, modelID string) bool { + for _, model := range models { + if model != nil && model.ID == modelID { + return true + } + } + return false +} + +func stringSliceContains(values []string, target string) bool { + for _, value := range values { + if value == target { + return true + } + } + return false +} diff --git a/internal/registry/model_updater.go b/internal/registry/model_updater.go index 197f604492..5e2079a055 100644 --- a/internal/registry/model_updater.go +++ b/internal/registry/model_updater.go @@ -7,6 +7,7 @@ import ( "fmt" "io" "net/http" + "sort" "strings" "sync" "time" @@ -209,10 +210,10 @@ func detectChangedProviders(oldData, newData *staticModelsJSON) []string { {"vertex", oldData.Vertex, newData.Vertex}, {"gemini-cli", oldData.GeminiCLI, newData.GeminiCLI}, {"aistudio", oldData.AIStudio, newData.AIStudio}, - {"codex", oldData.CodexFree, newData.CodexFree}, - {"codex", oldData.CodexTeam, newData.CodexTeam}, - {"codex", oldData.CodexPlus, newData.CodexPlus}, - {"codex", oldData.CodexPro, newData.CodexPro}, + {"codex", mergeStaticModelSections(oldData.CodexShared, oldData.CodexFree), mergeStaticModelSections(newData.CodexShared, newData.CodexFree)}, + {"codex", mergeStaticModelSections(oldData.CodexShared, oldData.CodexTeam), mergeStaticModelSections(newData.CodexShared, newData.CodexTeam)}, + {"codex", mergeStaticModelSections(oldData.CodexShared, oldData.CodexPlus), mergeStaticModelSections(newData.CodexShared, newData.CodexPlus)}, + {"codex", mergeStaticModelSections(oldData.CodexShared, oldData.CodexPro), mergeStaticModelSections(newData.CodexShared, newData.CodexPro)}, {"qwen", oldData.Qwen, newData.Qwen}, {"iflow", oldData.IFlow, newData.IFlow}, {"kimi", oldData.Kimi, newData.Kimi}, @@ -225,7 +226,11 @@ func detectChangedProviders(oldData, newData *staticModelsJSON) []string { if seen[s.provider] { continue } - if modelSectionChanged(s.oldList, s.newList) { + sectionChanged := modelSectionChanged(s.oldList, s.newList) + if s.provider == "codex" { + sectionChanged = modelSectionChangedIgnoringOrder(s.oldList, s.newList) + } + if sectionChanged { changed = append(changed, s.provider) seen[s.provider] = true } @@ -249,6 +254,21 @@ func modelSectionChanged(a, b []*ModelInfo) bool { return string(aj) != string(bj) } +func modelSectionChangedIgnoringOrder(a, b []*ModelInfo) bool { + return modelSectionChanged(sortedModelSectionByID(a), sortedModelSectionByID(b)) +} + +func sortedModelSectionByID(models []*ModelInfo) []*ModelInfo { + if len(models) == 0 { + return nil + } + cloned := cloneModelInfos(models) + sort.SliceStable(cloned, func(i, j int) bool { + return strings.TrimSpace(cloned[i].ID) < strings.TrimSpace(cloned[j].ID) + }) + return cloned +} + func notifyModelRefresh(changedProviders []string) { if len(changedProviders) == 0 { return @@ -346,6 +366,24 @@ func validateModelsCatalog(data *staticModelsJSON) error { return err } } + if len(data.CodexShared) > 0 { + if err := validateModelSection("codex-shared", data.CodexShared); err != nil { + return err + } + } + for _, section := range []struct { + name string + models []*ModelInfo + }{ + {name: "codex-free", models: data.CodexFree}, + {name: "codex-team", models: data.CodexTeam}, + {name: "codex-plus", models: data.CodexPlus}, + {name: "codex-pro", models: data.CodexPro}, + } { + if err := validateDisjointModelSections("codex-shared", data.CodexShared, section.name, section.models); err != nil { + return err + } + } return nil } @@ -370,3 +408,33 @@ func validateModelSection(section string, models []*ModelInfo) error { } return nil } + +func validateDisjointModelSections(leftName string, leftModels []*ModelInfo, rightName string, rightModels []*ModelInfo) error { + if len(leftModels) == 0 || len(rightModels) == 0 { + return nil + } + seen := make(map[string]struct{}, len(leftModels)) + for _, model := range leftModels { + if model == nil { + continue + } + modelID := strings.TrimSpace(model.ID) + if modelID == "" { + continue + } + seen[modelID] = struct{}{} + } + for _, model := range rightModels { + if model == nil { + continue + } + modelID := strings.TrimSpace(model.ID) + if modelID == "" { + continue + } + if _, exists := seen[modelID]; exists { + return fmt.Errorf("%s overlaps %s on model id %q", leftName, rightName, modelID) + } + } + return nil +} diff --git a/internal/registry/models/models.json b/internal/registry/models/models.json index 9a30478801..755fc73064 100644 --- a/internal/registry/models/models.json +++ b/internal/registry/models/models.json @@ -1167,6 +1167,36 @@ ] } ], + "codex-shared": [ + { + "id": "gpt-4o-mini-transcribe", + "object": "model", + "created": 1729728000, + "owned_by": "openai", + "type": "openai", + "display_name": "GPT-4o Mini Transcribe", + "version": "gpt-4o-mini-transcribe", + "description": "Fast speech-to-text model for audio transcription.", + "supported_parameters": [ + "language", + "prompt" + ] + }, + { + "id": "whisper-1", + "object": "model", + "created": 1679356800, + "owned_by": "openai", + "type": "openai", + "display_name": "Whisper", + "version": "whisper-1", + "description": "Speech-to-text model for audio transcription.", + "supported_parameters": [ + "language", + "prompt" + ] + } + ], "codex-free": [ { "id": "gpt-5", @@ -2680,4 +2710,4 @@ "max_completion_tokens": 32768 } ] -} \ No newline at end of file +} diff --git a/internal/runtime/executor/codex_executor.go b/internal/runtime/executor/codex_executor.go index 4fb2291900..d2b03e2ef5 100644 --- a/internal/runtime/executor/codex_executor.go +++ b/internal/runtime/executor/codex_executor.go @@ -50,14 +50,7 @@ func (e *CodexExecutor) PrepareRequest(req *http.Request, auth *cliproxyauth.Aut return nil } apiKey, _ := codexCreds(auth) - if strings.TrimSpace(apiKey) != "" { - req.Header.Set("Authorization", "Bearer "+apiKey) - } - var attrs map[string]string - if auth != nil { - attrs = auth.Attributes - } - util.ApplyCustomHeadersFromAttrs(req, attrs) + applyCodexPreparedHeaders(req, auth, apiKey, e.cfg) return nil } @@ -637,8 +630,30 @@ func (e *CodexExecutor) cacheHelper(ctx context.Context, from sdktranslator.Form } func applyCodexHeaders(r *http.Request, auth *cliproxyauth.Auth, token string, stream bool, cfg *config.Config) { + // JSON request paths use this helper so the content type is normalized to application/json. + // The PrepareRequest path intentionally avoids overriding caller-provided content types such as multipart/form-data. + applyCodexPreparedHeaders(r, auth, token, cfg) + if r == nil { + return + } r.Header.Set("Content-Type", "application/json") - r.Header.Set("Authorization", "Bearer "+token) + if stream { + r.Header.Set("Accept", "text/event-stream") + } else { + r.Header.Set("Accept", "application/json") + } +} + +func applyCodexPreparedHeaders(r *http.Request, auth *cliproxyauth.Auth, token string, cfg *config.Config) { + if r == nil { + return + } + if r.Header == nil { + r.Header = make(http.Header) + } + if strings.TrimSpace(token) != "" { + r.Header.Set("Authorization", "Bearer "+token) + } var ginHeaders http.Header if ginCtx, ok := r.Context().Value("gin").(*gin.Context); ok && ginCtx != nil && ginCtx.Request != nil { @@ -649,13 +664,12 @@ func applyCodexHeaders(r *http.Request, auth *cliproxyauth.Auth, token string, s misc.EnsureHeader(r.Header, ginHeaders, "Session_id", uuid.NewString()) cfgUserAgent, _ := codexHeaderDefaults(cfg, auth) ensureHeaderWithConfigPrecedence(r.Header, ginHeaders, "User-Agent", cfgUserAgent, codexUserAgent) - - if stream { - r.Header.Set("Accept", "text/event-stream") - } else { + if strings.TrimSpace(r.Header.Get("Accept")) == "" { r.Header.Set("Accept", "application/json") } - r.Header.Set("Connection", "Keep-Alive") + if strings.TrimSpace(r.Header.Get("Connection")) == "" { + r.Header.Set("Connection", "Keep-Alive") + } isAPIKey := false if auth != nil && auth.Attributes != nil { diff --git a/internal/runtime/executor/codex_executor_test.go b/internal/runtime/executor/codex_executor_test.go new file mode 100644 index 0000000000..d138eb67fd --- /dev/null +++ b/internal/runtime/executor/codex_executor_test.go @@ -0,0 +1,42 @@ +package executor + +import ( + "net/http" + "testing" + + cliproxyauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth" +) + +func TestCodexPrepareRequestPreservesMultipartContentType(t *testing.T) { + req, err := http.NewRequest(http.MethodPost, "https://example.com/backend-api/transcribe", nil) + if err != nil { + t.Fatalf("NewRequest() error = %v", err) + } + req.Header.Set("Content-Type", "multipart/form-data; boundary=test-boundary") + + executor := NewCodexExecutor(nil) + auth := &cliproxyauth.Auth{ + Provider: "codex", + Metadata: map[string]any{ + "account_id": "account-123", + "email": "user@example.com", + }, + } + + if err := executor.PrepareRequest(req, auth); err != nil { + t.Fatalf("PrepareRequest() error = %v", err) + } + + if got := req.Header.Get("Content-Type"); got != "multipart/form-data; boundary=test-boundary" { + t.Fatalf("Content-Type = %q, want %q", got, "multipart/form-data; boundary=test-boundary") + } + if got := req.Header.Get("Chatgpt-Account-Id"); got != "account-123" { + t.Fatalf("Chatgpt-Account-Id = %q, want %q", got, "account-123") + } + if got := req.Header.Get("Originator"); got != "codex_cli_rs" { + t.Fatalf("Originator = %q, want %q", got, "codex_cli_rs") + } + if got := req.Header.Get("Accept"); got != "application/json" { + t.Fatalf("Accept = %q, want %q", got, "application/json") + } +} diff --git a/sdk/api/handlers/handlers.go b/sdk/api/handlers/handlers.go index 0e490e3202..a252fbdf4c 100644 --- a/sdk/api/handlers/handlers.go +++ b/sdk/api/handlers/handlers.go @@ -497,12 +497,7 @@ func (h *BaseAPIHandler) ExecuteWithAuthManager(ctx context.Context, handlerType status = code } } - var addon http.Header - if he, ok := err.(interface{ Headers() http.Header }); ok && he != nil { - if hdr := he.Headers(); hdr != nil { - addon = hdr.Clone() - } - } + addon := filteredErrorHeaders(err) return nil, nil, &interfaces.ErrorMessage{StatusCode: status, Error: err, Addon: addon} } if !PassthroughHeadersEnabled(h.Cfg) { @@ -511,6 +506,30 @@ func (h *BaseAPIHandler) ExecuteWithAuthManager(ctx context.Context, handlerType return resp.Payload, FilterUpstreamHeaders(resp.Headers), nil } +// ExecuteHTTPRequestWithAuthManager executes a raw HTTP request builder via the core auth manager. +// This path is intended for provider-native endpoints that do not go through the JSON translator stack. +func (h *BaseAPIHandler) ExecuteHTTPRequestWithAuthManager(ctx context.Context, modelName string, build coreauth.HTTPRequestBuilder) (*http.Response, *coreauth.Auth, *interfaces.ErrorMessage) { + providers, normalizedModel, errMsg := h.getRequestDetails(modelName) + if errMsg != nil { + return nil, nil, errMsg + } + reqMeta := requestExecutionMetadata(ctx) + reqMeta[coreexecutor.RequestedModelMetadataKey] = normalizedModel + opts := coreexecutor.Options{Metadata: reqMeta} + resp, auth, err := h.AuthManager.ExecuteHTTPRequest(ctx, providers, normalizedModel, normalizedModel, opts, build) + if err != nil { + status := http.StatusInternalServerError + if se, ok := err.(interface{ StatusCode() int }); ok && se != nil { + if code := se.StatusCode(); code > 0 { + status = code + } + } + addon := filteredErrorHeaders(err) + return nil, nil, &interfaces.ErrorMessage{StatusCode: status, Error: err, Addon: addon} + } + return resp, auth, nil +} + // ExecuteCountWithAuthManager executes a non-streaming request via the core auth manager. // This path is the only supported execution route. func (h *BaseAPIHandler) ExecuteCountWithAuthManager(ctx context.Context, handlerType, modelName string, rawJSON []byte, alt string) ([]byte, http.Header, *interfaces.ErrorMessage) { @@ -543,12 +562,7 @@ func (h *BaseAPIHandler) ExecuteCountWithAuthManager(ctx context.Context, handle status = code } } - var addon http.Header - if he, ok := err.(interface{ Headers() http.Header }); ok && he != nil { - if hdr := he.Headers(); hdr != nil { - addon = hdr.Clone() - } - } + addon := filteredErrorHeaders(err) return nil, nil, &interfaces.ErrorMessage{StatusCode: status, Error: err, Addon: addon} } if !PassthroughHeadersEnabled(h.Cfg) { @@ -594,12 +608,7 @@ func (h *BaseAPIHandler) ExecuteStreamWithAuthManager(ctx context.Context, handl status = code } } - var addon http.Header - if he, ok := err.(interface{ Headers() http.Header }); ok && he != nil { - if hdr := he.Headers(); hdr != nil { - addon = hdr.Clone() - } - } + addon := filteredErrorHeaders(err) errChan <- &interfaces.ErrorMessage{StatusCode: status, Error: err, Addon: addon} close(errChan) return nil, nil, errChan @@ -839,6 +848,17 @@ func replaceHeader(dst http.Header, src http.Header) { } } +func filteredErrorHeaders(err error) http.Header { + if err == nil { + return nil + } + he, ok := err.(interface{ Headers() http.Header }) + if !ok || he == nil { + return nil + } + return FilterUpstreamHeaders(he.Headers()) +} + // WriteErrorResponse writes an error message to the response writer using the HTTP status embedded in the message. func (h *BaseAPIHandler) WriteErrorResponse(c *gin.Context, msg *interfaces.ErrorMessage) { status := http.StatusInternalServerError diff --git a/sdk/api/handlers/handlers_error_response_test.go b/sdk/api/handlers/handlers_error_response_test.go index cde4547fff..6cd9353edf 100644 --- a/sdk/api/handlers/handlers_error_response_test.go +++ b/sdk/api/handlers/handlers_error_response_test.go @@ -12,6 +12,21 @@ import ( sdkconfig "github.com/router-for-me/CLIProxyAPI/v6/sdk/config" ) +type headerCarrierError struct { + headers http.Header +} + +func (e *headerCarrierError) Error() string { + return "upstream failure" +} + +func (e *headerCarrierError) Headers() http.Header { + if e == nil { + return nil + } + return e.headers.Clone() +} + func TestWriteErrorResponse_AddonHeadersDisabledByDefault(t *testing.T) { gin.SetMode(gin.TestMode) recorder := httptest.NewRecorder() @@ -66,3 +81,23 @@ func TestWriteErrorResponse_AddonHeadersEnabled(t *testing.T) { t.Fatalf("X-Request-Id = %#v, want %#v", got, []string{"new-1", "new-2"}) } } + +func TestFilteredErrorHeaders_RemovesBlockedHeaders(t *testing.T) { + filtered := filteredErrorHeaders(&headerCarrierError{ + headers: http.Header{ + "Retry-After": {"30"}, + "Set-Cookie": {"session=secret"}, + "Connection": {"keep-alive"}, + }, + }) + + if got := filtered.Get("Retry-After"); got != "30" { + t.Fatalf("Retry-After = %q, want %q", got, "30") + } + if got := filtered.Get("Set-Cookie"); got != "" { + t.Fatalf("Set-Cookie should be filtered, got %q", got) + } + if got := filtered.Get("Connection"); got != "" { + t.Fatalf("Connection should be filtered, got %q", got) + } +} diff --git a/sdk/api/handlers/openai/openai_audio_handlers.go b/sdk/api/handlers/openai/openai_audio_handlers.go new file mode 100644 index 0000000000..5ac215b414 --- /dev/null +++ b/sdk/api/handlers/openai/openai_audio_handlers.go @@ -0,0 +1,1329 @@ +package openai + +import ( + "bufio" + "context" + "encoding/json" + "errors" + "fmt" + "io" + "mime" + "mime/multipart" + "net/http" + "net/textproto" + "os" + "path/filepath" + "sort" + "strconv" + "strings" + "time" + "unicode" + "unicode/utf8" + + "github.com/gin-gonic/gin" + "github.com/router-for-me/CLIProxyAPI/v6/internal/interfaces" + "github.com/router-for-me/CLIProxyAPI/v6/internal/registry" + "github.com/router-for-me/CLIProxyAPI/v6/internal/thinking" + "github.com/router-for-me/CLIProxyAPI/v6/sdk/api/handlers" + coreauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth" +) + +const ( + // Keep the upload cap fixed for this endpoint to preserve current behavior + // without introducing another configuration knob in the same PR. + audioTranscriptionUploadLimitBytes int64 = 32 << 20 + audioTranscriptionNonFileFieldsLimitBytes int64 = 1 << 20 + audioTranscriptionContentSniffBytes = 512 + audioTranscriptionFileFieldName = "file" + audioTranscriptionModelFieldName = "model" + audioTranscriptionResponseFormatFieldName = "response_format" + audioTranscriptionStreamFieldName = "stream" + audioTranscriptionDefaultFilename = "audio.webm" + audioTranscriptionTempFilePattern = "cliproxy-audio-transcription-*" + audioTranscriptionResponseTempFilePattern = "cliproxy-audio-transcription-response-*" + openAIAudioTranscriptionsPath = "/audio/transcriptions" + defaultCodexAudioTranscriptionURL = "https://chatgpt.com/backend-api/transcribe" +) + +var supportedAudioFileExtensions = map[string]struct{}{ + ".flac": {}, + ".m4a": {}, + ".mp3": {}, + ".mp4": {}, + ".mpeg": {}, + ".mpga": {}, + ".ogg": {}, + ".wav": {}, + ".webm": {}, +} + +var supportedAudioMediaTypes = map[string]struct{}{ + "audio/flac": {}, + "audio/m4a": {}, + "audio/mp3": {}, + "audio/mp4": {}, + "audio/mpeg": {}, + "audio/mpga": {}, + "audio/ogg": {}, + "audio/wav": {}, + "audio/webm": {}, + "video/mp4": {}, + "video/webm": {}, +} + +var supportedAudioResponseFormats = []string{ + "diarized_json", + "json", + "srt", + "text", + "verbose_json", + "vtt", +} + +var supportedAudioResponseFormatSet = newAudioResponseFormatSet(supportedAudioResponseFormats...) +var audioJSONResponseFormats = newAudioResponseFormatSet("", "diarized_json", "json", "verbose_json") +var audioRawResponseFormats = newAudioResponseFormatSet("srt", "text", "vtt") + +type audioFormField struct { + Name string + Value string +} + +type audioTranscriptionRequest struct { + Model string + ResponseFormat string + Stream bool + Fields []audioFormField + FileName string + FileContentType string + StagedFilePath string +} + +type audioRequestError struct { + status int + msg string +} + +func (e *audioRequestError) Error() string { + if e == nil { + return "" + } + return e.msg +} + +func (e *audioRequestError) StatusCode() int { + if e == nil || e.status <= 0 { + return http.StatusInternalServerError + } + return e.status +} + +// AudioTranscriptions handles the /v1/audio/transcriptions endpoint. +func (h *OpenAIAPIHandler) AudioTranscriptions(c *gin.Context) { + audioReq, err := parseAudioTranscriptionRequest(c) + if err != nil { + c.JSON(statusCodeOrDefault(err, http.StatusBadRequest), handlers.ErrorResponse{ + Error: handlers.ErrorDetail{ + Message: err.Error(), + Type: "invalid_request_error", + }, + }) + return + } + defer func() { + _ = audioReq.Cleanup() + }() + var pinnedAuthID string + audioReq.Model, pinnedAuthID, err = resolveAudioAutoSelection(h.AuthManager, audioReq.Model) + if err != nil { + c.JSON(statusCodeOrDefault(err, http.StatusBadRequest), handlers.ErrorResponse{ + Error: handlers.ErrorDetail{ + Message: err.Error(), + Type: "invalid_request_error", + }, + }) + return + } + + cliCtx, cliCancel := h.GetContextWithCancel(h, c, context.Background()) + if pinnedAuthID != "" { + cliCtx = handlers.WithPinnedAuthID(cliCtx, pinnedAuthID) + } + upstreamResp, _, errMsg := h.ExecuteHTTPRequestWithAuthManager(cliCtx, audioReq.Model, func(ctx context.Context, auth *coreauth.Auth, upstreamModel string) (*http.Request, error) { + return audioReq.BuildHTTPRequest(ctx, auth, upstreamModel) + }) + if errMsg != nil { + h.WriteErrorResponse(c, errMsg) + cliCancel(errMsg.Error) + return + } + defer func() { + _ = upstreamResp.Body.Close() + }() + + filteredHeaders := handlers.FilterUpstreamHeaders(upstreamResp.Header) + if audioReq.ShouldStreamResponse(filteredHeaders, upstreamResp.Header) { + if err := h.writeAudioStreamingResponse(c, upstreamResp.Body, filteredHeaders, upstreamResp.Header); err != nil { + h.WriteErrorResponse(c, &interfaces.ErrorMessage{ + StatusCode: statusCodeOrDefault(err, http.StatusBadGateway), + Error: err, + }) + cliCancel(err) + return + } + cliCancel(nil) + return + } + + if audioReq.PreserveRawResponse() { + if err := h.writeAudioRawResponse(c, upstreamResp.Body, filteredHeaders, upstreamResp.Header); err != nil { + h.WriteErrorResponse(c, &interfaces.ErrorMessage{ + StatusCode: statusCodeOrDefault(err, http.StatusBadGateway), + Error: err, + }) + cliCancel(err) + return + } + cliCancel(nil) + return + } + + c.Header("Content-Type", "application/json") + if handlers.PassthroughHeadersEnabled(h.Cfg) { + handlers.WriteUpstreamHeaders(c.Writer.Header(), filteredHeaders) + } + stopKeepAlive := h.StartNonStreamingKeepAlive(c, cliCtx) + normalizedPath, err := normalizeAudioTranscriptionResponseFromReader(upstreamResp.Body) + stopKeepAlive() + if err != nil { + h.WriteErrorResponse(c, &interfaces.ErrorMessage{ + StatusCode: statusCodeOrDefault(err, http.StatusBadGateway), + Error: err, + }) + cliCancel(err) + return + } + defer func() { + if normalizedPath != "" { + _ = os.Remove(normalizedPath) + } + }() + + c.Status(http.StatusOK) + if err := writeStagedAudioTranscriptionResponse(c.Writer, normalizedPath); err != nil { + cliCancel(err) + return + } + cliCancel() +} + +func (h *OpenAIAPIHandler) writeAudioStreamingResponse(c *gin.Context, body io.Reader, filteredHeaders, upstreamHeaders http.Header) error { + if c == nil { + return fmt.Errorf("missing response writer") + } + flusher, ok := c.Writer.(http.Flusher) + if !ok { + return &audioRequestError{status: http.StatusInternalServerError, msg: "streaming not supported"} + } + if contentType := audioTranscriptionContentType(filteredHeaders, upstreamHeaders); contentType != "" { + c.Header("Content-Type", contentType) + } else { + c.Header("Content-Type", "text/event-stream") + } + c.Header("Cache-Control", "no-cache") + c.Header("Connection", "keep-alive") + c.Header("Access-Control-Allow-Origin", "*") + if handlers.PassthroughHeadersEnabled(h.Cfg) { + handlers.WriteUpstreamHeaders(c.Writer.Header(), filteredHeaders) + } + c.Status(http.StatusOK) + return copyAudioResponse(c, flusher, body) +} + +func (h *OpenAIAPIHandler) writeAudioRawResponse(c *gin.Context, body io.Reader, filteredHeaders, upstreamHeaders http.Header) error { + if c == nil { + return fmt.Errorf("missing response writer") + } + if contentType := audioTranscriptionContentType(filteredHeaders, upstreamHeaders); contentType != "" { + c.Header("Content-Type", contentType) + } + if handlers.PassthroughHeadersEnabled(h.Cfg) { + handlers.WriteUpstreamHeaders(c.Writer.Header(), filteredHeaders) + } + c.Status(http.StatusOK) + _, err := io.Copy(c.Writer, body) + return err +} + +func copyAudioResponse(c *gin.Context, flusher http.Flusher, body io.Reader) error { + if c == nil || flusher == nil { + return fmt.Errorf("streaming not supported") + } + if body == nil { + return nil + } + buf := make([]byte, 32<<10) + for { + n, err := body.Read(buf) + if n > 0 { + if _, errWrite := c.Writer.Write(buf[:n]); errWrite != nil { + return errWrite + } + flusher.Flush() + } + if errors.Is(err, io.EOF) { + return nil + } + if err != nil { + return err + } + select { + case <-c.Request.Context().Done(): + return c.Request.Context().Err() + default: + } + } +} + +func parseAudioTranscriptionRequest(c *gin.Context) (*audioTranscriptionRequest, error) { + if c == nil || c.Request == nil { + return nil, &audioRequestError{status: http.StatusBadRequest, msg: "missing request"} + } + + reader, err := c.Request.MultipartReader() + if err != nil { + return nil, &audioRequestError{status: http.StatusBadRequest, msg: fmt.Sprintf("invalid multipart form: %v", err)} + } + + audioReq := &audioTranscriptionRequest{} + cleanupOnError := true + var nonFileBytesRead int64 + defer func() { + if cleanupOnError { + _ = audioReq.Cleanup() + } + }() + + hasFile := false + for { + part, errNext := reader.NextPart() + if errors.Is(errNext, io.EOF) { + break + } + if errNext != nil { + return nil, &audioRequestError{status: http.StatusBadRequest, msg: fmt.Sprintf("invalid multipart form: %v", errNext)} + } + if part == nil { + continue + } + + partName := part.FormName() + fileName := part.FileName() + + if fileName == "" { + fieldValue, errRead := readAudioTranscriptionField(part, audioTranscriptionNonFileFieldsLimitBytes-nonFileBytesRead) + _ = part.Close() + if errRead != nil { + return nil, &audioRequestError{status: http.StatusBadRequest, msg: fmt.Sprintf("failed to read form field %q: %v", partName, errRead)} + } + nonFileBytesRead += int64(len(fieldValue)) + if nonFileBytesRead > audioTranscriptionNonFileFieldsLimitBytes { + return nil, &audioRequestError{ + status: http.StatusBadRequest, + msg: fmt.Sprintf("non-file multipart fields exceed %d byte limit", audioTranscriptionNonFileFieldsLimitBytes), + } + } + field := audioFormField{Name: partName, Value: string(fieldValue)} + audioReq.Fields = append(audioReq.Fields, field) + if partName == audioTranscriptionModelFieldName && audioReq.Model == "" { + audioReq.Model = strings.TrimSpace(field.Value) + } + if partName == audioTranscriptionResponseFormatFieldName && audioReq.ResponseFormat == "" { + audioReq.ResponseFormat = strings.TrimSpace(field.Value) + } + if partName == audioTranscriptionStreamFieldName && !audioReq.Stream { + if parsed, errParse := strconv.ParseBool(strings.TrimSpace(field.Value)); errParse == nil { + audioReq.Stream = parsed + } + } + continue + } + + if partName != audioTranscriptionFileFieldName { + _ = part.Close() + continue + } + if hasFile { + _ = part.Close() + return nil, &audioRequestError{status: http.StatusBadRequest, msg: "only one file upload is supported"} + } + + audioReq.FileName = fileName + audioReq.FileContentType = strings.TrimSpace(part.Header.Get("Content-Type")) + if errStage := audioReq.stageFilePart(part); errStage != nil { + _ = part.Close() + return nil, errStage + } + _ = part.Close() + hasFile = true + } + + if audioReq.Model == "" { + return nil, &audioRequestError{status: http.StatusBadRequest, msg: "missing required field: model"} + } + if err := validateAudioResponseFormat(audioReq.ResponseFormat); err != nil { + return nil, err + } + if !hasFile { + return nil, &audioRequestError{status: http.StatusBadRequest, msg: "missing required field: file"} + } + + sort.SliceStable(audioReq.Fields, func(i, j int) bool { + return audioReq.Fields[i].Name < audioReq.Fields[j].Name + }) + + cleanupOnError = false + return audioReq, nil +} + +func readAudioTranscriptionField(part *multipart.Part, remainingBytes int64) ([]byte, error) { + if remainingBytes < 0 { + return nil, fmt.Errorf("non-file multipart fields exceed %d byte limit", audioTranscriptionNonFileFieldsLimitBytes) + } + limitedReader := &io.LimitedReader{R: part, N: remainingBytes + 1} + payload, err := io.ReadAll(limitedReader) + if err != nil { + return nil, err + } + if int64(len(payload)) > remainingBytes { + return nil, fmt.Errorf("non-file multipart fields exceed %d byte limit", audioTranscriptionNonFileFieldsLimitBytes) + } + return payload, nil +} + +func (r *audioTranscriptionRequest) stageFilePart(part *multipart.Part) error { + if r == nil { + return &audioRequestError{status: http.StatusBadRequest, msg: "audio transcription request is empty"} + } + + tempFile, err := os.CreateTemp("", audioTranscriptionTempFilePattern) + if err != nil { + return fmt.Errorf("failed to create temp file for audio upload: %w", err) + } + + tempPath := tempFile.Name() + keepTempFile := false + defer func() { + _ = tempFile.Close() + if !keepTempFile { + _ = os.Remove(tempPath) + } + }() + + limitedReader := &io.LimitedReader{R: part, N: audioTranscriptionUploadLimitBytes + 1} + sniffBuffer := make([]byte, audioTranscriptionContentSniffBytes) + sniffedBytes, readErr := io.ReadFull(limitedReader, sniffBuffer) + switch { + case errors.Is(readErr, io.EOF): + return &audioRequestError{status: http.StatusBadRequest, msg: "uploaded file is empty"} + case errors.Is(readErr, io.ErrUnexpectedEOF): + case readErr != nil: + return &audioRequestError{status: http.StatusBadRequest, msg: fmt.Sprintf("failed to read uploaded file: %v", readErr)} + } + + sniffedContent := sniffBuffer[:sniffedBytes] + detectedContentType := http.DetectContentType(sniffedContent) + resolvedContentType := strings.TrimSpace(r.FileContentType) + if resolvedContentType == "" { + resolvedContentType = detectedContentType + } + + if errValidate := validateAudioFile(r.FileName, resolvedContentType, detectedContentType); errValidate != nil { + return errValidate + } + + if _, errWrite := tempFile.Write(sniffedContent); errWrite != nil { + return fmt.Errorf("failed to stage uploaded file: %w", errWrite) + } + + copiedBytes, errCopy := io.Copy(tempFile, limitedReader) + if errCopy != nil { + return &audioRequestError{status: http.StatusBadRequest, msg: fmt.Sprintf("failed to read uploaded file: %v", errCopy)} + } + + totalBytes := int64(sniffedBytes) + copiedBytes + if totalBytes > audioTranscriptionUploadLimitBytes || limitedReader.N == 0 { + return &audioRequestError{ + status: http.StatusBadRequest, + msg: fmt.Sprintf("uploaded file exceeds %d byte limit", audioTranscriptionUploadLimitBytes), + } + } + + if errClose := tempFile.Close(); errClose != nil { + return fmt.Errorf("failed to finalize staged audio upload: %w", errClose) + } + + r.StagedFilePath = tempPath + r.FileContentType = resolvedContentType + keepTempFile = true + return nil +} + +func validateAudioFile(fileName string, contentTypes ...string) error { + ext := strings.ToLower(filepath.Ext(strings.TrimSpace(fileName))) + if _, ok := supportedAudioFileExtensions[ext]; ok { + return nil + } + for _, contentType := range contentTypes { + mediaType := strings.ToLower(strings.TrimSpace(contentType)) + if mediaType == "" { + continue + } + if parsedMediaType, _, err := mime.ParseMediaType(mediaType); err == nil { + mediaType = strings.ToLower(strings.TrimSpace(parsedMediaType)) + } + if _, ok := supportedAudioMediaTypes[mediaType]; ok { + return nil + } + } + return &audioRequestError{ + status: http.StatusBadRequest, + msg: "unsupported audio format; supported formats are flac, m4a, mp3, mp4, mpeg, mpga, ogg, wav, and webm", + } +} + +func validateAudioResponseFormat(responseFormat string) error { + responseFormat = normalizeAudioResponseFormat(responseFormat) + if isAudioJSONResponseFormat(responseFormat) { + return nil + } + if _, ok := supportedAudioResponseFormatSet[responseFormat]; ok { + return nil + } + return &audioRequestError{ + status: http.StatusBadRequest, + msg: fmt.Sprintf("unsupported response_format %q; supported values are %s", responseFormat, describeAudioResponseFormats()), + } +} + +func (r *audioTranscriptionRequest) Cleanup() error { + if r == nil || r.StagedFilePath == "" { + return nil + } + err := os.Remove(r.StagedFilePath) + if err == nil || errors.Is(err, os.ErrNotExist) { + r.StagedFilePath = "" + return nil + } + return err +} + +func (r *audioTranscriptionRequest) BuildHTTPRequest(ctx context.Context, auth *coreauth.Auth, upstreamModel string) (*http.Request, error) { + if r == nil { + return nil, &audioRequestError{status: http.StatusBadRequest, msg: "audio transcription request is empty"} + } + if strings.TrimSpace(r.StagedFilePath) == "" { + return nil, &audioRequestError{status: http.StatusBadRequest, msg: "audio transcription file is not staged"} + } + + targetURL, err := resolveAudioTranscriptionURL(auth) + if err != nil { + return nil, err + } + + bodyReader, bodyWriter := io.Pipe() + multipartWriter := multipart.NewWriter(bodyWriter) + contentLength, err := r.multipartContentLength(multipartWriter.Boundary(), upstreamModel) + if err != nil { + _ = bodyWriter.Close() + return nil, err + } + + req, err := http.NewRequestWithContext(ctx, http.MethodPost, targetURL, bodyReader) + if err != nil { + _ = bodyWriter.Close() + return nil, err + } + req.Header.Set("Content-Type", multipartWriter.FormDataContentType()) + if accept := audioTranscriptionAcceptHeader(r.Stream, r.ResponseFormat); accept != "" { + req.Header.Set("Accept", accept) + } + req.ContentLength = contentLength + + go r.writeMultipartBody(bodyWriter, multipartWriter, upstreamModel) + return req, nil +} + +func (r *audioTranscriptionRequest) writeMultipartBody(bodyWriter *io.PipeWriter, multipartWriter *multipart.Writer, upstreamModel string) { + defer func() { + if recovered := recover(); recovered != nil { + closeAudioMultipartWriterWithError(bodyWriter, fmt.Errorf("panic while writing audio transcription multipart body: %v", recovered)) + } + }() + if err := r.writeMultipartFields(multipartWriter, upstreamModel); err != nil { + closeAudioMultipartWriterWithError(bodyWriter, err) + return + } + if err := multipartWriter.Close(); err != nil { + closeAudioMultipartWriterWithError(bodyWriter, err) + return + } + _ = bodyWriter.Close() +} + +func (r *audioTranscriptionRequest) writeMultipartFields(multipartWriter *multipart.Writer, upstreamModel string) error { + modelValue := normalizeAudioUpstreamModel(upstreamModel, r.Model) + for _, field := range r.Fields { + fieldValue := field.Value + if field.Name == audioTranscriptionModelFieldName { + fieldValue = modelValue + } + if err := multipartWriter.WriteField(field.Name, fieldValue); err != nil { + return err + } + } + + filename := strings.TrimSpace(r.FileName) + if filename == "" { + filename = audioTranscriptionDefaultFilename + } + + filePartHeader := make(textproto.MIMEHeader) + filePartHeader.Set("Content-Disposition", mime.FormatMediaType("form-data", map[string]string{ + "name": audioTranscriptionFileFieldName, + "filename": filename, + })) + if contentType := strings.TrimSpace(r.FileContentType); contentType != "" { + filePartHeader.Set("Content-Type", contentType) + } + + filePart, err := multipartWriter.CreatePart(filePartHeader) + if err != nil { + return err + } + + stagedFile, err := os.Open(r.StagedFilePath) + if err != nil { + return fmt.Errorf("failed to open staged audio file: %w", err) + } + defer func() { + _ = stagedFile.Close() + }() + + _, err = io.Copy(filePart, stagedFile) + return err +} + +func (r *audioTranscriptionRequest) multipartContentLength(boundary, upstreamModel string) (int64, error) { + if r == nil { + return 0, &audioRequestError{status: http.StatusBadRequest, msg: "audio transcription request is empty"} + } + fileInfo, err := os.Stat(r.StagedFilePath) + if err != nil { + return 0, fmt.Errorf("failed to stat staged audio file: %w", err) + } + counter := &countingWriter{} + multipartWriter := multipart.NewWriter(counter) + if err := multipartWriter.SetBoundary(boundary); err != nil { + return 0, err + } + modelValue := normalizeAudioUpstreamModel(upstreamModel, r.Model) + for _, field := range r.Fields { + fieldValue := field.Value + if field.Name == audioTranscriptionModelFieldName { + fieldValue = modelValue + } + if err := multipartWriter.WriteField(field.Name, fieldValue); err != nil { + return 0, err + } + } + + filename := strings.TrimSpace(r.FileName) + if filename == "" { + filename = audioTranscriptionDefaultFilename + } + filePartHeader := make(textproto.MIMEHeader) + filePartHeader.Set("Content-Disposition", mime.FormatMediaType("form-data", map[string]string{ + "name": audioTranscriptionFileFieldName, + "filename": filename, + })) + if contentType := strings.TrimSpace(r.FileContentType); contentType != "" { + filePartHeader.Set("Content-Type", contentType) + } + if _, err := multipartWriter.CreatePart(filePartHeader); err != nil { + return 0, err + } + counter.n += fileInfo.Size() + if err := multipartWriter.Close(); err != nil { + return 0, err + } + return counter.n, nil +} + +func normalizeAudioUpstreamModel(upstreamModel, fallbackModel string) string { + modelValue := strings.TrimSpace(upstreamModel) + if modelValue == "" { + modelValue = strings.TrimSpace(fallbackModel) + } + baseModel := strings.TrimSpace(thinking.ParseSuffix(modelValue).ModelName) + if baseModel != "" { + return baseModel + } + return modelValue +} + +func audioTranscriptionAcceptHeader(stream bool, responseFormat string) string { + if stream { + return "text/event-stream" + } + if !isAudioJSONResponseFormat(normalizeAudioResponseFormat(responseFormat)) { + return "" + } + return "application/json" +} + +func closeAudioMultipartWriterWithError(bodyWriter *io.PipeWriter, err error) { + if bodyWriter == nil || err == nil { + return + } + _ = bodyWriter.CloseWithError(err) +} + +func newAudioResponseFormatSet(values ...string) map[string]struct{} { + set := make(map[string]struct{}, len(values)) + for _, value := range values { + set[value] = struct{}{} + } + return set +} + +func normalizeAudioResponseFormat(responseFormat string) string { + return strings.ToLower(strings.TrimSpace(responseFormat)) +} + +func isAudioJSONResponseFormat(responseFormat string) bool { + _, ok := audioJSONResponseFormats[normalizeAudioResponseFormat(responseFormat)] + return ok +} + +func isAudioRawResponseFormat(responseFormat string) bool { + _, ok := audioRawResponseFormats[normalizeAudioResponseFormat(responseFormat)] + return ok +} + +func describeAudioResponseFormats() string { + if len(supportedAudioResponseFormats) == 0 { + return "" + } + if len(supportedAudioResponseFormats) == 1 { + return supportedAudioResponseFormats[0] + } + if len(supportedAudioResponseFormats) == 2 { + return supportedAudioResponseFormats[0] + " and " + supportedAudioResponseFormats[1] + } + head := strings.Join(supportedAudioResponseFormats[:len(supportedAudioResponseFormats)-1], ", ") + return head + ", and " + supportedAudioResponseFormats[len(supportedAudioResponseFormats)-1] +} + +func resolveAudioTranscriptionURL(auth *coreauth.Auth) (string, error) { + if auth == nil { + return "", &audioRequestError{status: http.StatusBadGateway, msg: "no auth selected for audio transcription"} + } + if isOpenAICompatibleAuth(auth) { + baseURL := "" + if auth.Attributes != nil { + baseURL = strings.TrimSpace(auth.Attributes["base_url"]) + } + if baseURL == "" { + return "", &audioRequestError{status: http.StatusBadGateway, msg: "selected OpenAI-compatible auth is missing base_url"} + } + return strings.TrimSuffix(baseURL, "/") + openAIAudioTranscriptionsPath, nil + } + if strings.EqualFold(strings.TrimSpace(auth.Provider), "codex") { + if auth.Attributes != nil { + if baseURL := strings.TrimSpace(auth.Attributes["base_url"]); baseURL != "" { + return resolveCodexAudioTranscriptionURL(baseURL), nil + } + } + return defaultCodexAudioTranscriptionURL, nil + } + return "", &audioRequestError{ + status: http.StatusNotImplemented, + msg: fmt.Sprintf("audio transcription is not supported for provider %q", strings.TrimSpace(auth.Provider)), + } +} + +func isOpenAICompatibleAuth(auth *coreauth.Auth) bool { + if auth == nil { + return false + } + if auth.Attributes != nil { + if compatName := strings.TrimSpace(auth.Attributes["compat_name"]); compatName != "" { + return true + } + } + return strings.EqualFold(strings.TrimSpace(auth.Provider), "openai-compatibility") +} + +func resolveAudioAutoModel(manager *coreauth.Manager, modelName string) (string, error) { + model, _, err := resolveAudioAutoSelection(manager, modelName) + return model, err +} + +func resolveAudioAutoSelection(manager *coreauth.Manager, modelName string) (string, string, error) { + parsed := thinking.ParseSuffix(strings.TrimSpace(modelName)) + if parsed.ModelName != "auto" { + return modelName, "", nil + } + selection, err := resolveAutoAudioModelBase(manager) + if err != nil { + return "", "", &audioRequestError{status: http.StatusBadRequest, msg: err.Error()} + } + if parsed.HasSuffix { + return fmt.Sprintf("%s(%s)", selection.RouteModel, parsed.RawSuffix), selection.AuthID, nil + } + return selection.RouteModel, selection.AuthID, nil +} + +type audioAutoModelCandidate struct { + RouteModel string + CreatedAt int64 + AuthID string +} + +func resolveAutoAudioModelBase(manager *coreauth.Manager) (audioAutoModelCandidate, error) { + if manager == nil { + return audioAutoModelCandidate{}, fmt.Errorf("model auto is not supported for audio transcription because the auth manager is unavailable") + } + + candidates := make([]audioAutoModelCandidate, 0) + now := time.Now() + + for _, preview := range manager.PreviewSelectableRouteModels(now, supportsAudioTranscriptionAuth) { + if !audioPreviewSupportsTranscription(preview) { + continue + } + authID := "" + if preview.Auth != nil { + authID = strings.TrimSpace(preview.Auth.ID) + } + candidates = append(candidates, audioAutoModelCandidate{ + RouteModel: preview.RouteModel, + CreatedAt: preview.CreatedAt, + AuthID: authID, + }) + } + + if len(candidates) == 0 { + return audioAutoModelCandidate{}, fmt.Errorf("model auto is not supported for audio transcription because no transcription-capable model is available") + } + + sort.SliceStable(candidates, func(i, j int) bool { + if candidates[i].CreatedAt == candidates[j].CreatedAt { + if candidates[i].RouteModel == candidates[j].RouteModel { + return candidates[i].AuthID < candidates[j].AuthID + } + return candidates[i].RouteModel < candidates[j].RouteModel + } + return candidates[i].CreatedAt > candidates[j].CreatedAt + }) + return candidates[0], nil +} + +func supportsAudioTranscriptionAuth(auth *coreauth.Auth) bool { + _, err := resolveAudioTranscriptionURL(auth) + return err == nil +} + +func audioPreviewSupportsTranscription(preview coreauth.RouteModelPreview) bool { + if len(preview.UpstreamModels) == 0 { + return false + } + for _, upstreamModel := range preview.UpstreamModels { + if !isAudioTranscriptionModel(preview.Auth, upstreamModel) { + return false + } + } + return true +} + +func isAudioTranscriptionModel(auth *coreauth.Auth, modelName string) bool { + baseModel := strings.TrimSpace(thinking.ParseSuffix(modelName).ModelName) + if baseModel == "" { + return false + } + provider := "" + if auth != nil { + provider = strings.TrimSpace(auth.Provider) + } + if info := registry.LookupModelInfo(baseModel, provider); info != nil { + return isAudioTranscriptionModelInfo(info) + } + if info := registry.LookupModelInfo(baseModel); info != nil { + return isAudioTranscriptionModelInfo(info) + } + return isAudioTranscriptionModelName(baseModel) +} + +func isAudioTranscriptionModelInfo(info *registry.ModelInfo) bool { + if info == nil { + return false + } + if len(info.SupportedInputModalities) > 0 || len(info.SupportedOutputModalities) > 0 { + supportsAudioInput := len(info.SupportedInputModalities) == 0 || containsFolded(info.SupportedInputModalities, "audio") + supportsTextOutput := len(info.SupportedOutputModalities) == 0 || containsFolded(info.SupportedOutputModalities, "text") + return supportsAudioInput && supportsTextOutput + } + return isAudioTranscriptionModelName(info.ID) || + isAudioTranscriptionModelName(info.Version) || + isAudioTranscriptionModelName(info.DisplayName) || + isAudioTranscriptionModelName(info.Description) +} + +func isAudioTranscriptionModelName(value string) bool { + value = strings.ToLower(strings.TrimSpace(value)) + if value == "" { + return false + } + return strings.Contains(value, "transcribe") || + strings.Contains(value, "transcription") || + strings.Contains(value, "speech-to-text") || + strings.Contains(value, "whisper") +} + +func containsFolded(values []string, needle string) bool { + needle = strings.TrimSpace(needle) + if needle == "" { + return false + } + for _, value := range values { + if strings.EqualFold(strings.TrimSpace(value), needle) { + return true + } + } + return false +} + +func resolveCodexAudioTranscriptionURL(baseURL string) string { + trimmedBaseURL := strings.TrimRight(strings.TrimSpace(baseURL), "/") + // Configured Codex base URLs are expected to follow the same backend-api root + // convention as the existing Codex HTTP endpoints, where /codex and /transcribe + // are sibling paths under the same root. + if strings.HasSuffix(trimmedBaseURL, "/codex") { + return strings.TrimSuffix(trimmedBaseURL, "/codex") + "/transcribe" + } + return trimmedBaseURL + "/transcribe" +} + +func (r *audioTranscriptionRequest) PreserveRawResponse() bool { + if r == nil { + return false + } + return isAudioRawResponseFormat(r.ResponseFormat) +} + +func (r *audioTranscriptionRequest) ShouldStreamResponse(headers ...http.Header) bool { + if r != nil && r.Stream { + return true + } + return audioTranscriptionIsEventStream(headers...) +} + +func audioTranscriptionContentType(headers ...http.Header) string { + for _, hdr := range headers { + if hdr == nil { + continue + } + if contentType := strings.TrimSpace(hdr.Get("Content-Type")); contentType != "" { + return contentType + } + } + return "" +} + +func audioTranscriptionIsEventStream(headers ...http.Header) bool { + mediaType := audioTranscriptionContentType(headers...) + if mediaType == "" { + return false + } + parsedMediaType, _, err := mime.ParseMediaType(mediaType) + if err == nil { + mediaType = parsedMediaType + } + return strings.EqualFold(strings.TrimSpace(mediaType), "text/event-stream") +} + +func normalizeAudioTranscriptionResponseFromReader(body io.Reader) (string, error) { + stagedPath, err := stageAudioTranscriptionResponse(body) + if err != nil { + return "", err + } + if stagedPath != "" { + defer func() { + _ = os.Remove(stagedPath) + }() + } + + normalizedFile, err := os.CreateTemp("", audioTranscriptionResponseTempFilePattern) + if err != nil { + return "", fmt.Errorf("failed to create normalized transcription response temp file: %w", err) + } + normalizedPath := normalizedFile.Name() + keepNormalizedFile := false + defer func() { + _ = normalizedFile.Close() + if !keepNormalizedFile { + _ = os.Remove(normalizedPath) + } + }() + + switch kind, errKind := detectAudioTranscriptionResponseKind(stagedPath); { + case errKind != nil: + return "", errKind + case kind == audioTranscriptionResponseEmpty: + if _, err := normalizedFile.Write([]byte(`{"text":""}`)); err != nil { + return "", fmt.Errorf("failed to write normalized empty transcription response: %w", err) + } + case kind == audioTranscriptionResponseJSONObject: + hasText, err := audioTranscriptionObjectHasText(stagedPath) + if err != nil { + return "", err + } + sourceFile, err := os.Open(stagedPath) + if err != nil { + return "", fmt.Errorf("failed to open staged upstream transcription response: %w", err) + } + if err := writeNormalizedAudioTranscriptionObject(normalizedFile, sourceFile, !hasText); err != nil { + _ = sourceFile.Close() + return "", err + } + _ = sourceFile.Close() + case kind == audioTranscriptionResponseJSONString: + sourceFile, err := os.Open(stagedPath) + if err != nil { + return "", fmt.Errorf("failed to open staged upstream transcription response: %w", err) + } + text, err := readAudioTranscriptionJSONString(sourceFile) + _ = sourceFile.Close() + if err != nil { + return "", err + } + payload, err := json.Marshal(map[string]string{"text": text}) + if err != nil { + return "", fmt.Errorf("failed to normalize upstream transcription string response: %w", err) + } + if _, err := normalizedFile.Write(payload); err != nil { + return "", fmt.Errorf("failed to write normalized transcription string response: %w", err) + } + default: + sourceFile, err := os.Open(stagedPath) + if err != nil { + return "", fmt.Errorf("failed to open staged upstream transcription response: %w", err) + } + if err := writeWrappedAudioTranscriptionText(normalizedFile, sourceFile); err != nil { + _ = sourceFile.Close() + return "", err + } + _ = sourceFile.Close() + } + + if err := normalizedFile.Close(); err != nil { + return "", fmt.Errorf("failed to finalize normalized transcription response: %w", err) + } + keepNormalizedFile = true + return normalizedPath, nil +} + +func stageAudioTranscriptionResponse(body io.Reader) (string, error) { + if body == nil { + return "", nil + } + tempFile, err := os.CreateTemp("", audioTranscriptionResponseTempFilePattern) + if err != nil { + return "", fmt.Errorf("failed to create temp file for transcription response: %w", err) + } + tempPath := tempFile.Name() + keepTempFile := false + defer func() { + _ = tempFile.Close() + if !keepTempFile { + _ = os.Remove(tempPath) + } + }() + if _, errCopy := io.Copy(tempFile, body); errCopy != nil { + return "", fmt.Errorf("failed to stage upstream transcription response: %w", errCopy) + } + if errClose := tempFile.Close(); errClose != nil { + return "", fmt.Errorf("failed to finalize staged transcription response: %w", errClose) + } + keepTempFile = true + return tempPath, nil +} + +type audioTranscriptionResponseKind int + +const ( + audioTranscriptionResponseEmpty audioTranscriptionResponseKind = iota + audioTranscriptionResponseJSONObject + audioTranscriptionResponseJSONString + audioTranscriptionResponseText +) + +func detectAudioTranscriptionResponseKind(path string) (audioTranscriptionResponseKind, error) { + if strings.TrimSpace(path) == "" { + return audioTranscriptionResponseEmpty, nil + } + file, err := os.Open(path) + if err != nil { + return audioTranscriptionResponseEmpty, fmt.Errorf("failed to open staged upstream transcription response: %w", err) + } + defer func() { + _ = file.Close() + }() + reader := bufio.NewReader(file) + for { + r, _, err := reader.ReadRune() + if errors.Is(err, io.EOF) { + return audioTranscriptionResponseEmpty, nil + } + if err != nil { + return audioTranscriptionResponseEmpty, fmt.Errorf("failed to inspect staged upstream transcription response: %w", err) + } + if unicode.IsSpace(r) { + continue + } + switch r { + case '{': + return audioTranscriptionResponseJSONObject, nil + case '"': + return audioTranscriptionResponseJSONString, nil + default: + return audioTranscriptionResponseText, nil + } + } +} + +func audioTranscriptionObjectHasText(path string) (bool, error) { + file, err := os.Open(path) + if err != nil { + return false, fmt.Errorf("failed to open staged upstream transcription response: %w", err) + } + defer func() { + _ = file.Close() + }() + decoder := json.NewDecoder(file) + tok, err := decoder.Token() + if err != nil { + return false, fmt.Errorf("failed to parse upstream transcription object: %w", err) + } + delim, ok := tok.(json.Delim) + if !ok || delim != '{' { + return false, fmt.Errorf("upstream transcription response is not a JSON object") + } + for decoder.More() { + keyToken, err := decoder.Token() + if err != nil { + return false, fmt.Errorf("failed to parse upstream transcription object key: %w", err) + } + key, ok := keyToken.(string) + if !ok { + return false, fmt.Errorf("upstream transcription object contains a non-string key") + } + var raw json.RawMessage + if err := decoder.Decode(&raw); err != nil { + return false, fmt.Errorf("failed to parse upstream transcription object field %q: %w", key, err) + } + if key == "text" { + return true, nil + } + } + if _, err := decoder.Token(); err != nil { + return false, fmt.Errorf("failed to parse upstream transcription object end: %w", err) + } + return false, nil +} + +func writeNormalizedAudioTranscriptionObject(dst io.Writer, src io.Reader, injectText bool) error { + if dst == nil { + return fmt.Errorf("missing normalized transcription writer") + } + if src == nil { + _, err := io.WriteString(dst, `{"text":""}`) + return err + } + decoder := json.NewDecoder(src) + tok, err := decoder.Token() + if err != nil { + return fmt.Errorf("failed to parse upstream transcription object: %w", err) + } + delim, ok := tok.(json.Delim) + if !ok || delim != '{' { + return fmt.Errorf("upstream transcription response is not a JSON object") + } + if _, err := io.WriteString(dst, "{"); err != nil { + return err + } + wroteField := false + if injectText { + if _, err := io.WriteString(dst, `"text":""`); err != nil { + return err + } + wroteField = true + } + for decoder.More() { + keyToken, err := decoder.Token() + if err != nil { + return fmt.Errorf("failed to parse upstream transcription object key: %w", err) + } + key, ok := keyToken.(string) + if !ok { + return fmt.Errorf("upstream transcription object contains a non-string key") + } + var raw json.RawMessage + if err := decoder.Decode(&raw); err != nil { + return fmt.Errorf("failed to parse upstream transcription object field %q: %w", key, err) + } + if wroteField { + if _, err := io.WriteString(dst, ","); err != nil { + return err + } + } + keyJSON, err := json.Marshal(key) + if err != nil { + return fmt.Errorf("failed to encode transcription object key %q: %w", key, err) + } + if _, err := dst.Write(keyJSON); err != nil { + return err + } + if _, err := io.WriteString(dst, ":"); err != nil { + return err + } + if _, err := dst.Write(raw); err != nil { + return err + } + wroteField = true + } + if _, err := decoder.Token(); err != nil { + return fmt.Errorf("failed to parse upstream transcription object end: %w", err) + } + _, err = io.WriteString(dst, "}") + return err +} + +func readAudioTranscriptionJSONString(src io.Reader) (string, error) { + if src == nil { + return "", nil + } + var text string + if err := json.NewDecoder(src).Decode(&text); err != nil { + return "", fmt.Errorf("failed to decode upstream transcription string response: %w", err) + } + return text, nil +} + +func writeWrappedAudioTranscriptionText(dst io.Writer, src io.Reader) error { + if dst == nil { + return fmt.Errorf("missing normalized transcription writer") + } + if _, err := io.WriteString(dst, `{"text":"`); err != nil { + return err + } + if src != nil { + reader := bufio.NewReader(src) + for { + r, size, err := reader.ReadRune() + if errors.Is(err, io.EOF) { + break + } + if err != nil { + return fmt.Errorf("failed to read staged transcription text response: %w", err) + } + if r == utf8.RuneError && size == 1 { + if _, err := io.WriteString(dst, `\ufffd`); err != nil { + return err + } + continue + } + switch r { + case '\\', '"': + if _, err := io.WriteString(dst, `\`+string(r)); err != nil { + return err + } + case '\b': + if _, err := io.WriteString(dst, `\b`); err != nil { + return err + } + case '\f': + if _, err := io.WriteString(dst, `\f`); err != nil { + return err + } + case '\n': + if _, err := io.WriteString(dst, `\n`); err != nil { + return err + } + case '\r': + if _, err := io.WriteString(dst, `\r`); err != nil { + return err + } + case '\t': + if _, err := io.WriteString(dst, `\t`); err != nil { + return err + } + default: + if r < 0x20 { + if _, err := fmt.Fprintf(dst, "\\u%04x", r); err != nil { + return err + } + continue + } + var encoded [utf8.UTFMax]byte + n := utf8.EncodeRune(encoded[:], r) + if _, err := dst.Write(encoded[:n]); err != nil { + return err + } + } + } + } + _, err := io.WriteString(dst, `"}`) + return err +} + +func writeStagedAudioTranscriptionResponse(dst io.Writer, path string) error { + if dst == nil { + return fmt.Errorf("missing transcription response writer") + } + if strings.TrimSpace(path) == "" { + _, err := io.WriteString(dst, `{"text":""}`) + return err + } + file, err := os.Open(path) + if err != nil { + return fmt.Errorf("failed to open normalized transcription response: %w", err) + } + defer func() { + _ = file.Close() + }() + _, err = io.Copy(dst, file) + return err +} + +func statusCodeOrDefault(err error, fallback int) int { + if err == nil { + return fallback + } + if se, ok := err.(interface{ StatusCode() int }); ok && se != nil { + if code := se.StatusCode(); code > 0 { + return code + } + } + return fallback +} + +type countingWriter struct { + n int64 +} + +func (w *countingWriter) Write(p []byte) (int, error) { + w.n += int64(len(p)) + return len(p), nil +} diff --git a/sdk/api/handlers/openai/openai_audio_handlers_test.go b/sdk/api/handlers/openai/openai_audio_handlers_test.go new file mode 100644 index 0000000000..60da7b414f --- /dev/null +++ b/sdk/api/handlers/openai/openai_audio_handlers_test.go @@ -0,0 +1,1539 @@ +package openai + +import ( + "bytes" + "context" + "encoding/json" + "errors" + "io" + "mime" + "mime/multipart" + "net/http" + "net/http/httptest" + "net/textproto" + "os" + "sort" + "strings" + "sync" + "testing" + "time" + + "github.com/gin-gonic/gin" + internalconfig "github.com/router-for-me/CLIProxyAPI/v6/internal/config" + "github.com/router-for-me/CLIProxyAPI/v6/internal/registry" + "github.com/router-for-me/CLIProxyAPI/v6/sdk/api/handlers" + coreauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth" + coreexecutor "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/executor" + sdkconfig "github.com/router-for-me/CLIProxyAPI/v6/sdk/config" +) + +const audioTestModel = "gpt-4o-mini-transcribe" + +type audioExecutorResponse struct { + status int + body string + contentType string + headers http.Header +} + +type audioCaptureExecutor struct { + id string + mu sync.Mutex + calls int + delay time.Duration + authIDs []string + lastURL string + lastAuthID string + lastFields map[string][]string + lastFileName string + lastFileBody []byte + lastLength int64 + lastAccept string + responses map[string]audioExecutorResponse +} + +type panicWriter struct{} + +func (panicWriter) Write([]byte) (int, error) { + panic("boom") +} + +func (e *audioCaptureExecutor) Identifier() string { + if strings.TrimSpace(e.id) != "" { + return e.id + } + return "openai-compatibility" +} + +func (e *audioCaptureExecutor) Execute(context.Context, *coreauth.Auth, coreexecutor.Request, coreexecutor.Options) (coreexecutor.Response, error) { + return coreexecutor.Response{}, errors.New("not implemented") +} + +func (e *audioCaptureExecutor) ExecuteStream(context.Context, *coreauth.Auth, coreexecutor.Request, coreexecutor.Options) (*coreexecutor.StreamResult, error) { + return nil, errors.New("not implemented") +} + +func (e *audioCaptureExecutor) Refresh(ctx context.Context, auth *coreauth.Auth) (*coreauth.Auth, error) { + return auth, nil +} + +func (e *audioCaptureExecutor) CountTokens(context.Context, *coreauth.Auth, coreexecutor.Request, coreexecutor.Options) (coreexecutor.Response, error) { + return coreexecutor.Response{}, errors.New("not implemented") +} + +func (e *audioCaptureExecutor) HttpRequest(_ context.Context, auth *coreauth.Auth, req *http.Request) (*http.Response, error) { + e.mu.Lock() + e.calls++ + if auth != nil { + e.lastAuthID = auth.ID + e.authIDs = append(e.authIDs, auth.ID) + } + e.lastURL = req.URL.String() + e.lastFields = make(map[string][]string) + e.lastFileName = "" + e.lastFileBody = nil + e.lastLength = req.ContentLength + e.lastAccept = req.Header.Get("Accept") + e.mu.Unlock() + + mediaType, params, err := mime.ParseMediaType(req.Header.Get("Content-Type")) + if err != nil { + return nil, err + } + if mediaType != "multipart/form-data" { + return nil, errors.New("unexpected content type") + } + + reader := multipart.NewReader(req.Body, params["boundary"]) + for { + part, errNext := reader.NextPart() + if errors.Is(errNext, io.EOF) { + break + } + if errNext != nil { + return nil, errNext + } + + payload, errRead := io.ReadAll(part) + _ = part.Close() + if errRead != nil { + return nil, errRead + } + + e.mu.Lock() + if part.FormName() == audioTranscriptionFileFieldName { + e.lastFileName = part.FileName() + e.lastFileBody = payload + } else { + e.lastFields[part.FormName()] = append(e.lastFields[part.FormName()], string(payload)) + } + e.mu.Unlock() + } + + authID := "" + if auth != nil { + authID = auth.ID + } + + respCfg := audioExecutorResponse{ + status: http.StatusOK, + body: "transcribed text", + contentType: "text/plain; charset=utf-8", + } + if e.responses != nil { + if candidate, ok := e.responses[authID]; ok { + respCfg = candidate + } + } + if e.delay > 0 { + time.Sleep(e.delay) + } + if respCfg.status == 0 { + respCfg.status = http.StatusOK + } + if respCfg.contentType == "" { + respCfg.contentType = "text/plain; charset=utf-8" + } + + respHeaders := make(http.Header) + if respCfg.headers != nil { + respHeaders = respCfg.headers.Clone() + } + if strings.TrimSpace(respCfg.contentType) != "" { + respHeaders.Set("Content-Type", respCfg.contentType) + } + + return &http.Response{ + StatusCode: respCfg.status, + Header: respHeaders, + Body: io.NopCloser(strings.NewReader(respCfg.body)), + }, nil +} + +func (e *audioCaptureExecutor) Calls() int { + e.mu.Lock() + defer e.mu.Unlock() + return e.calls +} + +func (e *audioCaptureExecutor) AuthIDs() []string { + e.mu.Lock() + defer e.mu.Unlock() + out := make([]string, len(e.authIDs)) + copy(out, e.authIDs) + return out +} + +func TestAudioTranscriptionsWrapsPlainTextResponse(t *testing.T) { + gin.SetMode(gin.TestMode) + + executor := &audioCaptureExecutor{} + router := newAudioTestRouter(t, executor, &coreauth.Auth{ + ID: "audio-auth", + Provider: "openai-compatibility", + Status: coreauth.StatusActive, + Attributes: map[string]string{ + "base_url": "https://api.example.com/v1", + }, + }) + + req := newAudioMultipartRequest(t, "/v1/audio/transcriptions", map[string]string{ + "model": audioTestModel, + "prompt": "caption this", + "language": "en", + }, audioTranscriptionFileFieldName, "sample.webm", "", []byte("fake-audio")) + resp := httptest.NewRecorder() + router.ServeHTTP(resp, req) + + if resp.Code != http.StatusOK { + t.Fatalf("status = %d, want %d; body=%s", resp.Code, http.StatusOK, resp.Body.String()) + } + if strings.TrimSpace(resp.Body.String()) != `{"text":"transcribed text"}` { + t.Fatalf("body = %s", resp.Body.String()) + } + if executor.Calls() != 1 { + t.Fatalf("executor calls = %d, want 1", executor.Calls()) + } + if executor.lastAuthID != "audio-auth" { + t.Fatalf("last auth id = %q, want %q", executor.lastAuthID, "audio-auth") + } + if executor.lastURL != "https://api.example.com/v1/audio/transcriptions" { + t.Fatalf("last URL = %q, want %q", executor.lastURL, "https://api.example.com/v1/audio/transcriptions") + } + if got := executor.lastFields["model"]; len(got) != 1 || got[0] != audioTestModel { + t.Fatalf("model field = %v", got) + } + if got := executor.lastFields["prompt"]; len(got) != 1 || got[0] != "caption this" { + t.Fatalf("prompt field = %v", got) + } + if got := executor.lastFields["language"]; len(got) != 1 || got[0] != "en" { + t.Fatalf("language field = %v", got) + } + if executor.lastFileName != "sample.webm" { + t.Fatalf("file name = %q, want %q", executor.lastFileName, "sample.webm") + } + if string(executor.lastFileBody) != "fake-audio" { + t.Fatalf("file body = %q, want %q", string(executor.lastFileBody), "fake-audio") + } + if executor.lastLength <= 0 { + t.Fatalf("content length = %d, want > 0", executor.lastLength) + } +} + +func TestAudioTranscriptionsRejectsUnsupportedFileFormat(t *testing.T) { + gin.SetMode(gin.TestMode) + + executor := &audioCaptureExecutor{} + router := newAudioTestRouter(t, executor) + + req := newAudioMultipartRequest(t, "/v1/audio/transcriptions", map[string]string{ + "model": audioTestModel, + }, audioTranscriptionFileFieldName, "notes.txt", "text/plain", []byte("not audio")) + resp := httptest.NewRecorder() + router.ServeHTTP(resp, req) + + if resp.Code != http.StatusBadRequest { + t.Fatalf("status = %d, want %d; body=%s", resp.Code, http.StatusBadRequest, resp.Body.String()) + } + if !strings.Contains(resp.Body.String(), "unsupported audio format") { + t.Fatalf("body = %s", resp.Body.String()) + } + if executor.Calls() != 0 { + t.Fatalf("executor calls = %d, want 0", executor.Calls()) + } +} + +func TestAudioTranscriptionsRejectUnsupportedResponseFormat(t *testing.T) { + gin.SetMode(gin.TestMode) + + executor := &audioCaptureExecutor{} + router := newAudioTestRouter(t, executor) + + req := newAudioMultipartRequest(t, "/v1/audio/transcriptions", map[string]string{ + "model": audioTestModel, + "response_format": "markdown", + }, audioTranscriptionFileFieldName, "sample.webm", "", []byte("fake-audio")) + resp := httptest.NewRecorder() + router.ServeHTTP(resp, req) + + if resp.Code != http.StatusBadRequest { + t.Fatalf("status = %d, want %d; body=%s", resp.Code, http.StatusBadRequest, resp.Body.String()) + } + if !strings.Contains(resp.Body.String(), "unsupported response_format") || !strings.Contains(resp.Body.String(), "markdown") { + t.Fatalf("body = %s", resp.Body.String()) + } + if executor.Calls() != 0 { + t.Fatalf("executor calls = %d, want 0", executor.Calls()) + } +} + +func TestAudioTranscriptionsPreserveExplicitTextResponseFormat(t *testing.T) { + gin.SetMode(gin.TestMode) + + executor := &audioCaptureExecutor{ + responses: map[string]audioExecutorResponse{ + "audio-auth": { + status: http.StatusOK, + body: "raw transcript", + contentType: "text/plain; charset=utf-8", + }, + }, + } + router := newAudioTestRouter(t, executor, &coreauth.Auth{ + ID: "audio-auth", + Provider: "openai-compatibility", + Status: coreauth.StatusActive, + Attributes: map[string]string{ + "base_url": "https://api.example.com/v1", + }, + }) + + req := newAudioMultipartRequest(t, "/v1/audio/transcriptions", map[string]string{ + "model": audioTestModel, + "response_format": "text", + }, audioTranscriptionFileFieldName, "sample.webm", "", []byte("fake-audio")) + resp := httptest.NewRecorder() + router.ServeHTTP(resp, req) + + if resp.Code != http.StatusOK { + t.Fatalf("status = %d, want %d; body=%s", resp.Code, http.StatusOK, resp.Body.String()) + } + if strings.TrimSpace(resp.Body.String()) != "raw transcript" { + t.Fatalf("body = %s", resp.Body.String()) + } + if got := resp.Header().Get("Content-Type"); got != "text/plain; charset=utf-8" { + t.Fatalf("content type = %q, want %q", got, "text/plain; charset=utf-8") + } + if got := executor.lastAccept; got != "" { + t.Fatalf("accept = %q, want empty for raw response format", got) + } +} + +func TestAudioTranscriptionsAcceptDiarizedJSONResponseFormat(t *testing.T) { + gin.SetMode(gin.TestMode) + + executor := &audioCaptureExecutor{ + responses: map[string]audioExecutorResponse{ + "audio-auth": { + status: http.StatusOK, + body: `{"segments":[{"speaker":"A","text":"hello"}],"language":"en"}`, + contentType: "application/json", + }, + }, + } + router := newAudioTestRouter(t, executor, &coreauth.Auth{ + ID: "audio-auth", + Provider: "openai-compatibility", + Status: coreauth.StatusActive, + Attributes: map[string]string{ + "base_url": "https://api.example.com/v1", + }, + }) + + req := newAudioMultipartRequest(t, "/v1/audio/transcriptions", map[string]string{ + "model": audioTestModel, + "response_format": "diarized_json", + }, audioTranscriptionFileFieldName, "sample.webm", "", []byte("fake-audio")) + resp := httptest.NewRecorder() + router.ServeHTTP(resp, req) + + if resp.Code != http.StatusOK { + t.Fatalf("status = %d, want %d; body=%s", resp.Code, http.StatusOK, resp.Body.String()) + } + if got := executor.lastAccept; got != "application/json" { + t.Fatalf("accept = %q, want %q", got, "application/json") + } + + var payload map[string]any + if err := json.Unmarshal(resp.Body.Bytes(), &payload); err != nil { + t.Fatalf("Unmarshal(response): %v", err) + } + if got, ok := payload["text"].(string); !ok || got != "" { + t.Fatalf("text = %#v, want empty string", payload["text"]) + } + if _, ok := payload["segments"]; !ok { + t.Fatalf("segments missing from payload: %v", payload) + } + if got, ok := payload["language"].(string); !ok || got != "en" { + t.Fatalf("language = %#v, want %q", payload["language"], "en") + } +} + +func TestAudioTranscriptionsPreserveExplicitVTTResponseFormat(t *testing.T) { + gin.SetMode(gin.TestMode) + + executor := &audioCaptureExecutor{ + responses: map[string]audioExecutorResponse{ + "audio-auth": { + status: http.StatusOK, + body: "WEBVTT\n\n00:00.000 --> 00:01.000\nHello", + contentType: "text/vtt; charset=utf-8", + }, + }, + } + router := newAudioTestRouter(t, executor, &coreauth.Auth{ + ID: "audio-auth", + Provider: "openai-compatibility", + Status: coreauth.StatusActive, + Attributes: map[string]string{ + "base_url": "https://api.example.com/v1", + }, + }) + + req := newAudioMultipartRequest(t, "/v1/audio/transcriptions", map[string]string{ + "model": audioTestModel, + "response_format": "vtt", + }, audioTranscriptionFileFieldName, "sample.webm", "", []byte("fake-audio")) + resp := httptest.NewRecorder() + router.ServeHTTP(resp, req) + + if resp.Code != http.StatusOK { + t.Fatalf("status = %d, want %d; body=%s", resp.Code, http.StatusOK, resp.Body.String()) + } + if got := resp.Header().Get("Content-Type"); got != "text/vtt; charset=utf-8" { + t.Fatalf("content type = %q, want %q", got, "text/vtt; charset=utf-8") + } + if !strings.Contains(resp.Body.String(), "WEBVTT") { + t.Fatalf("body = %s", resp.Body.String()) + } +} + +func TestAudioTranscriptionsRawResponseDoesNotEmitKeepAlive(t *testing.T) { + gin.SetMode(gin.TestMode) + + executor := &audioCaptureExecutor{ + delay: 2200 * time.Millisecond, + responses: map[string]audioExecutorResponse{ + "audio-auth": { + status: http.StatusOK, + body: "raw transcript", + contentType: "text/plain; charset=utf-8", + }, + }, + } + router := newAudioTestRouterWithConfig(t, &sdkconfig.SDKConfig{ + NonStreamKeepAliveInterval: 1, + }, executor, &coreauth.Auth{ + ID: "audio-auth", + Provider: "openai-compatibility", + Status: coreauth.StatusActive, + Attributes: map[string]string{ + "base_url": "https://api.example.com/v1", + }, + }) + + req := newAudioMultipartRequest(t, "/v1/audio/transcriptions", map[string]string{ + "model": audioTestModel, + "response_format": "text", + }, audioTranscriptionFileFieldName, "sample.webm", "", []byte("fake-audio")) + resp := httptest.NewRecorder() + router.ServeHTTP(resp, req) + + if resp.Code != http.StatusOK { + t.Fatalf("status = %d, want %d; body=%q", resp.Code, http.StatusOK, resp.Body.String()) + } + if resp.Body.String() != "raw transcript" { + t.Fatalf("body = %q, want %q", resp.Body.String(), "raw transcript") + } + if got := resp.Header().Get("Content-Type"); got != "text/plain; charset=utf-8" { + t.Fatalf("content type = %q, want %q", got, "text/plain; charset=utf-8") + } +} + +func TestNormalizeAudioTranscriptionResponseFromReader(t *testing.T) { + testCases := []struct { + name string + body []byte + want string + }{ + { + name: "empty body", + body: nil, + want: `{"text":""}`, + }, + { + name: "existing text field preserved", + body: []byte(`{"text":"hello","segments":[{"id":1}]}`), + want: `{"text":"hello","segments":[{"id":1}]}`, + }, + { + name: "bare json string wrapped", + body: []byte(`"hello"`), + want: `{"text":"hello"}`, + }, + { + name: "raw text wrapped", + body: []byte(`hello`), + want: `{"text":"hello"}`, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + got := normalizeAudioTranscriptionResponseForTest(t, tc.body) + if got != tc.want { + t.Fatalf("normalizeAudioTranscriptionResponseFromReader() = %s, want %s", got, tc.want) + } + }) + } +} + +func TestWriteMultipartBodyRecoversPanic(t *testing.T) { + stagedFile, err := os.CreateTemp("", "audio-transcription-panic-*") + if err != nil { + t.Fatalf("CreateTemp(): %v", err) + } + defer func() { + _ = os.Remove(stagedFile.Name()) + }() + if _, err := stagedFile.Write([]byte("fake-audio")); err != nil { + t.Fatalf("Write(stagedFile): %v", err) + } + if err := stagedFile.Close(); err != nil { + t.Fatalf("Close(stagedFile): %v", err) + } + + req := &audioTranscriptionRequest{ + Model: audioTestModel, + Fields: []audioFormField{{Name: audioTranscriptionModelFieldName, Value: audioTestModel}}, + FileName: "sample.webm", + FileContentType: "audio/webm", + StagedFilePath: stagedFile.Name(), + } + bodyReader, bodyWriter := io.Pipe() + multipartWriter := multipart.NewWriter(panicWriter{}) + done := make(chan struct{}) + go func() { + defer close(done) + req.writeMultipartBody(bodyWriter, multipartWriter, audioTestModel) + }() + + _, err = io.ReadAll(bodyReader) + <-done + if err == nil || !strings.Contains(err.Error(), "panic while writing audio transcription multipart body") { + t.Fatalf("ReadAll(bodyReader) error = %v, want recovered panic error", err) + } +} + +func normalizeAudioTranscriptionResponseForTest(t *testing.T, body []byte) string { + t.Helper() + + var reader io.Reader + if body != nil { + reader = bytes.NewReader(body) + } + normalizedPath, err := normalizeAudioTranscriptionResponseFromReader(reader) + if err != nil { + t.Fatalf("normalizeAudioTranscriptionResponseFromReader(): %v", err) + } + defer func() { + if normalizedPath != "" { + _ = os.Remove(normalizedPath) + } + }() + + payload, err := os.ReadFile(normalizedPath) + if err != nil { + t.Fatalf("ReadFile(normalizedPath): %v", err) + } + return string(payload) +} + +func TestResolveAudioTranscriptionURL_DefaultCodexOAuth(t *testing.T) { + url, err := resolveAudioTranscriptionURL(&coreauth.Auth{ + ID: "codex-auth", + Provider: "codex", + Status: coreauth.StatusActive, + Metadata: map[string]any{ + "email": "user@example.com", + }, + }) + if err != nil { + t.Fatalf("resolveAudioTranscriptionURL() error = %v", err) + } + if url != defaultCodexAudioTranscriptionURL { + t.Fatalf("resolveAudioTranscriptionURL() = %q, want %q", url, defaultCodexAudioTranscriptionURL) + } +} + +func TestResolveAudioTranscriptionURL_ConfiguredCodexBaseURL(t *testing.T) { + url, err := resolveAudioTranscriptionURL(&coreauth.Auth{ + ID: "codex-auth", + Provider: "codex", + Status: coreauth.StatusActive, + Attributes: map[string]string{ + "base_url": "https://chatgpt.com/backend-api/codex", + }, + }) + if err != nil { + t.Fatalf("resolveAudioTranscriptionURL() error = %v", err) + } + if url != "https://chatgpt.com/backend-api/transcribe" { + t.Fatalf("resolveAudioTranscriptionURL() = %q, want %q", url, "https://chatgpt.com/backend-api/transcribe") + } +} + +func TestAudioTranscriptionsStreamTruePassthroughSSE(t *testing.T) { + gin.SetMode(gin.TestMode) + + executor := &audioCaptureExecutor{ + responses: map[string]audioExecutorResponse{ + "audio-auth": { + status: http.StatusOK, + body: "event: transcript.text.delta\ndata: {\"delta\":\"hi\"}\n\n", + contentType: "text/event-stream", + }, + }, + } + router := newAudioTestRouter(t, executor, &coreauth.Auth{ + ID: "audio-auth", + Provider: "openai-compatibility", + Status: coreauth.StatusActive, + Attributes: map[string]string{ + "base_url": "https://api.example.com/v1", + }, + }) + + req := newAudioMultipartRequest(t, "/v1/audio/transcriptions", map[string]string{ + "model": audioTestModel, + "stream": "true", + }, audioTranscriptionFileFieldName, "sample.webm", "", []byte("fake-audio")) + resp := httptest.NewRecorder() + router.ServeHTTP(resp, req) + + if resp.Code != http.StatusOK { + t.Fatalf("status = %d, want %d; body=%s", resp.Code, http.StatusOK, resp.Body.String()) + } + if strings.TrimSpace(resp.Body.String()) != strings.TrimSpace("event: transcript.text.delta\ndata: {\"delta\":\"hi\"}") { + t.Fatalf("body = %s", resp.Body.String()) + } + if got := resp.Header().Get("Content-Type"); got != "text/event-stream" { + t.Fatalf("content type = %q, want %q", got, "text/event-stream") + } + if got := executor.lastAccept; got != "text/event-stream" { + t.Fatalf("accept = %q, want %q", got, "text/event-stream") + } +} + +func TestAudioTranscriptionsRejectOversizedNonFileFields(t *testing.T) { + gin.SetMode(gin.TestMode) + + executor := &audioCaptureExecutor{} + router := newAudioTestRouter(t, executor, &coreauth.Auth{ + ID: "audio-auth", + Provider: "openai-compatibility", + Status: coreauth.StatusActive, + Attributes: map[string]string{ + "base_url": "https://api.example.com/v1", + }, + }) + + req := newAudioMultipartRequest(t, "/v1/audio/transcriptions", map[string]string{ + "model": audioTestModel, + "prompt": strings.Repeat("a", int(audioTranscriptionNonFileFieldsLimitBytes+1)), + }, audioTranscriptionFileFieldName, "sample.webm", "", []byte("fake-audio")) + resp := httptest.NewRecorder() + router.ServeHTTP(resp, req) + + if resp.Code != http.StatusBadRequest { + t.Fatalf("status = %d, want %d; body=%s", resp.Code, http.StatusBadRequest, resp.Body.String()) + } + if !strings.Contains(resp.Body.String(), "non-file multipart fields exceed") { + t.Fatalf("body = %s", resp.Body.String()) + } + if executor.Calls() != 0 { + t.Fatalf("executor calls = %d, want 0", executor.Calls()) + } +} + +func TestAudioTranscriptionsTreatEventStreamAsStreamingResponse(t *testing.T) { + gin.SetMode(gin.TestMode) + + executor := &audioCaptureExecutor{ + responses: map[string]audioExecutorResponse{ + "audio-auth": { + status: http.StatusOK, + body: "event: transcript.text.done\ndata: {\"text\":\"hello\"}\n\n", + contentType: "text/event-stream; charset=utf-8", + }, + }, + } + router := newAudioTestRouter(t, executor, &coreauth.Auth{ + ID: "audio-auth", + Provider: "openai-compatibility", + Status: coreauth.StatusActive, + Attributes: map[string]string{ + "base_url": "https://api.example.com/v1", + }, + }) + + req := newAudioMultipartRequest(t, "/v1/audio/transcriptions", map[string]string{ + "model": audioTestModel, + "stream": "definitely", + }, audioTranscriptionFileFieldName, "sample.webm", "", []byte("fake-audio")) + resp := httptest.NewRecorder() + router.ServeHTTP(resp, req) + + if resp.Code != http.StatusOK { + t.Fatalf("status = %d, want %d; body=%s", resp.Code, http.StatusOK, resp.Body.String()) + } + if got := resp.Header().Get("Content-Type"); got != "text/event-stream; charset=utf-8" { + t.Fatalf("content type = %q, want %q", got, "text/event-stream; charset=utf-8") + } + if !strings.Contains(resp.Body.String(), "event: transcript.text.done") { + t.Fatalf("body = %s", resp.Body.String()) + } +} + +func TestAudioTranscriptionsNormalizeLargeUpstreamResponse(t *testing.T) { + gin.SetMode(gin.TestMode) + + largeTranscript := strings.Repeat("x", 8<<20+1) + executor := &audioCaptureExecutor{ + responses: map[string]audioExecutorResponse{ + "audio-auth": { + status: http.StatusOK, + body: largeTranscript, + contentType: "text/plain; charset=utf-8", + }, + }, + } + router := newAudioTestRouter(t, executor, &coreauth.Auth{ + ID: "audio-auth", + Provider: "openai-compatibility", + Status: coreauth.StatusActive, + Attributes: map[string]string{ + "base_url": "https://api.example.com/v1", + }, + }) + + req := newAudioMultipartRequest(t, "/v1/audio/transcriptions", map[string]string{ + "model": audioTestModel, + }, audioTranscriptionFileFieldName, "sample.webm", "", []byte("fake-audio")) + resp := httptest.NewRecorder() + router.ServeHTTP(resp, req) + + if resp.Code != http.StatusOK { + t.Fatalf("status = %d, want %d; body prefix=%s", resp.Code, http.StatusOK, resp.Body.String()[:min(len(resp.Body.String()), 64)]) + } + var payload map[string]string + if err := json.Unmarshal(resp.Body.Bytes(), &payload); err != nil { + t.Fatalf("Unmarshal(response): %v", err) + } + if got := payload["text"]; len(got) != len(largeTranscript) { + t.Fatalf("text len = %d, want %d", len(got), len(largeTranscript)) + } +} + +func TestAudioTranscriptionsRetriesAcrossAuthsOnRetriableHTTPFailure(t *testing.T) { + gin.SetMode(gin.TestMode) + + executor := &audioCaptureExecutor{ + responses: map[string]audioExecutorResponse{ + "auth1": { + status: http.StatusInternalServerError, + body: `{"error":"temporary failure"}`, + contentType: "application/json", + }, + "auth2": { + status: http.StatusOK, + body: "retried transcription", + contentType: "text/plain; charset=utf-8", + }, + }, + } + router := newAudioTestRouter(t, executor, + &coreauth.Auth{ + ID: "auth1", + Provider: "openai-compatibility", + Status: coreauth.StatusActive, + Attributes: map[string]string{ + "base_url": "https://api.example.com/v1", + }, + }, + &coreauth.Auth{ + ID: "auth2", + Provider: "openai-compatibility", + Status: coreauth.StatusActive, + Attributes: map[string]string{ + "base_url": "https://api.example.com/v1", + }, + }, + ) + + req := newAudioMultipartRequest(t, "/v1/audio/transcriptions", map[string]string{ + "model": audioTestModel, + }, audioTranscriptionFileFieldName, "sample.webm", "", []byte("fake-audio")) + resp := httptest.NewRecorder() + router.ServeHTTP(resp, req) + + if resp.Code != http.StatusOK { + t.Fatalf("status = %d, want %d; body=%s", resp.Code, http.StatusOK, resp.Body.String()) + } + if strings.TrimSpace(resp.Body.String()) != `{"text":"retried transcription"}` { + t.Fatalf("body = %s", resp.Body.String()) + } + if executor.Calls() != 2 { + t.Fatalf("executor calls = %d, want 2", executor.Calls()) + } + if got := executor.AuthIDs(); len(got) != 2 || got[0] != "auth1" || got[1] != "auth2" { + t.Fatalf("auth IDs = %v, want [auth1 auth2]", got) + } +} + +func TestAudioTranscriptionsRetriesAcrossAuthsOnRequestBuildFailure(t *testing.T) { + gin.SetMode(gin.TestMode) + + executor := &audioCaptureExecutor{} + router := newAudioTestRouter(t, executor, + &coreauth.Auth{ + ID: "auth1", + Provider: "openai-compatibility", + Status: coreauth.StatusActive, + }, + &coreauth.Auth{ + ID: "auth2", + Provider: "openai-compatibility", + Status: coreauth.StatusActive, + Attributes: map[string]string{ + "base_url": "https://api.example.com/v1", + }, + }, + ) + + req := newAudioMultipartRequest(t, "/v1/audio/transcriptions", map[string]string{ + "model": audioTestModel, + }, audioTranscriptionFileFieldName, "sample.webm", "", []byte("fake-audio")) + resp := httptest.NewRecorder() + router.ServeHTTP(resp, req) + + if resp.Code != http.StatusOK { + t.Fatalf("status = %d, want %d; body=%s", resp.Code, http.StatusOK, resp.Body.String()) + } + if executor.Calls() != 1 { + t.Fatalf("executor calls = %d, want 1", executor.Calls()) + } + if got := executor.AuthIDs(); len(got) != 1 || got[0] != "auth2" { + t.Fatalf("auth IDs = %v, want [auth2]", got) + } +} + +func TestAudioTranscriptionsCleansUpStagedTempFiles(t *testing.T) { + gin.SetMode(gin.TestMode) + + tempDir := t.TempDir() + t.Setenv("TMPDIR", tempDir) + t.Setenv("TMP", tempDir) + t.Setenv("TEMP", tempDir) + + executor := &audioCaptureExecutor{} + router := newAudioTestRouter(t, executor, &coreauth.Auth{ + ID: "audio-auth", + Provider: "openai-compatibility", + Status: coreauth.StatusActive, + Attributes: map[string]string{ + "base_url": "https://api.example.com/v1", + }, + }) + + req := newAudioMultipartRequest(t, "/v1/audio/transcriptions", map[string]string{ + "model": audioTestModel, + }, audioTranscriptionFileFieldName, "sample.webm", "", []byte("fake-audio")) + resp := httptest.NewRecorder() + router.ServeHTTP(resp, req) + + if resp.Code != http.StatusOK { + t.Fatalf("status = %d, want %d; body=%s", resp.Code, http.StatusOK, resp.Body.String()) + } + + entries, err := os.ReadDir(tempDir) + if err != nil { + t.Fatalf("ReadDir(%s): %v", tempDir, err) + } + if len(entries) != 0 { + names := make([]string, 0, len(entries)) + for _, entry := range entries { + names = append(names, entry.Name()) + } + t.Fatalf("expected temp dir to be empty, got %v", names) + } +} + +func TestAudioTranscriptionsResolveAutoToTranscriptionModel(t *testing.T) { + gin.SetMode(gin.TestMode) + + executor := &audioCaptureExecutor{} + router := newAudioTestRouter(t, executor, &coreauth.Auth{ + ID: "audio-auth", + Provider: "openai-compatibility", + Status: coreauth.StatusActive, + Attributes: map[string]string{ + "base_url": "https://api.example.com/v1", + }, + }) + + reg := registry.GetGlobalRegistry() + reg.RegisterClient("chat-auto-client", "openai", []*registry.ModelInfo{{ + ID: "gpt-5.2", + Created: time.Now().Unix() + 1000, + }}) + t.Cleanup(func() { + reg.UnregisterClient("chat-auto-client") + }) + + req := newAudioMultipartRequest(t, "/v1/audio/transcriptions", map[string]string{ + "model": "auto", + }, audioTranscriptionFileFieldName, "sample.webm", "", []byte("fake-audio")) + resp := httptest.NewRecorder() + router.ServeHTTP(resp, req) + + if resp.Code != http.StatusOK { + t.Fatalf("status = %d, want %d; body=%s", resp.Code, http.StatusOK, resp.Body.String()) + } + if got := executor.lastFields["model"]; len(got) != 1 || got[0] != audioTestModel { + t.Fatalf("model field = %v, want [%s]", got, audioTestModel) + } +} + +func TestAudioTranscriptionsStripThinkingSuffixFromMultipartModel(t *testing.T) { + gin.SetMode(gin.TestMode) + + executor := &audioCaptureExecutor{} + router := newAudioTestRouter(t, executor, &coreauth.Auth{ + ID: "audio-auth", + Provider: "openai-compatibility", + Status: coreauth.StatusActive, + Attributes: map[string]string{ + "base_url": "https://api.example.com/v1", + }, + }) + + req := newAudioMultipartRequest(t, "/v1/audio/transcriptions", map[string]string{ + "model": audioTestModel + "(high)", + }, audioTranscriptionFileFieldName, "sample.webm", "", []byte("fake-audio")) + resp := httptest.NewRecorder() + router.ServeHTTP(resp, req) + + if resp.Code != http.StatusOK { + t.Fatalf("status = %d, want %d; body=%s", resp.Code, http.StatusOK, resp.Body.String()) + } + if got := executor.lastFields["model"]; len(got) != 1 || got[0] != audioTestModel { + t.Fatalf("model field = %v, want [%s]", got, audioTestModel) + } +} + +func TestAudioTranscriptionsPreserveExplicitWhisperModel(t *testing.T) { + gin.SetMode(gin.TestMode) + + executor := &audioCaptureExecutor{} + auth := &coreauth.Auth{ + ID: "audio-auth", + Provider: "openai-compatibility", + Status: coreauth.StatusActive, + Attributes: map[string]string{ + "base_url": "https://api.example.com/v1", + }, + } + router := newAudioTestRouter(t, executor, auth) + + registry.GetGlobalRegistry().RegisterClient(auth.ID, auth.Provider, []*registry.ModelInfo{ + {ID: audioTestModel}, + {ID: "whisper-1"}, + }) + + req := newAudioMultipartRequest(t, "/v1/audio/transcriptions", map[string]string{ + "model": "whisper-1", + }, audioTranscriptionFileFieldName, "sample.webm", "", []byte("fake-audio")) + resp := httptest.NewRecorder() + router.ServeHTTP(resp, req) + + if resp.Code != http.StatusOK { + t.Fatalf("status = %d, want %d; body=%s", resp.Code, http.StatusOK, resp.Body.String()) + } + if executor.lastAuthID != auth.ID { + t.Fatalf("last auth id = %q, want %q", executor.lastAuthID, auth.ID) + } + if got := executor.lastFields["model"]; len(got) != 1 || got[0] != "whisper-1" { + t.Fatalf("model field = %v, want [whisper-1]", got) + } +} + +func TestAudioTranscriptionsResolveAutoToCompatibleAliasModel(t *testing.T) { + gin.SetMode(gin.TestMode) + + executor := &audioCaptureExecutor{id: "pool"} + manager := coreauth.NewManager(nil, nil, nil) + manager.SetConfig(&internalconfig.Config{ + OpenAICompatibility: []internalconfig.OpenAICompatibility{{ + Name: "pool", + Models: []internalconfig.OpenAICompatibilityModel{ + {Name: audioTestModel, Alias: "voice-default"}, + }, + }}, + }) + manager.RegisterExecutor(executor) + + poolAuth := &coreauth.Auth{ + ID: "pool-auth", + Provider: "pool", + Status: coreauth.StatusActive, + Attributes: map[string]string{ + "api_key": "test-key", + "base_url": "https://api.example.com/v1", + "compat_name": "pool", + "provider_key": "pool", + }, + } + if _, err := manager.Register(context.Background(), poolAuth); err != nil { + t.Fatalf("register pool auth: %v", err) + } + registry.GetGlobalRegistry().RegisterClient(poolAuth.ID, "pool", []*registry.ModelInfo{{ + ID: "voice-default", + Created: time.Now().Unix(), + }}) + t.Cleanup(func() { + registry.GetGlobalRegistry().UnregisterClient(poolAuth.ID) + }) + + geminiAuth := &coreauth.Auth{ + ID: "gemini-auth", + Provider: "gemini", + Status: coreauth.StatusActive, + } + if _, err := manager.Register(context.Background(), geminiAuth); err != nil { + t.Fatalf("register gemini auth: %v", err) + } + registry.GetGlobalRegistry().RegisterClient(geminiAuth.ID, "gemini", []*registry.ModelInfo{{ + ID: "gemini-speech-to-text", + Created: time.Now().Unix() + 1000, + }}) + t.Cleanup(func() { + registry.GetGlobalRegistry().UnregisterClient(geminiAuth.ID) + }) + + base := handlers.NewBaseAPIHandlers(&sdkconfig.SDKConfig{}, manager) + h := NewOpenAIAPIHandler(base) + router := gin.New() + router.POST("/v1/audio/transcriptions", h.AudioTranscriptions) + + req := newAudioMultipartRequest(t, "/v1/audio/transcriptions", map[string]string{ + "model": "auto", + }, audioTranscriptionFileFieldName, "sample.webm", "", []byte("fake-audio")) + resp := httptest.NewRecorder() + router.ServeHTTP(resp, req) + + if resp.Code != http.StatusOK { + t.Fatalf("status = %d, want %d; body=%s", resp.Code, http.StatusOK, resp.Body.String()) + } + if executor.lastAuthID != poolAuth.ID { + t.Fatalf("last auth id = %q, want %q", executor.lastAuthID, poolAuth.ID) + } + if got := executor.lastFields["model"]; len(got) != 1 || got[0] != audioTestModel { + t.Fatalf("model field = %v, want [%s]", got, audioTestModel) + } +} + +func TestAudioTranscriptionsResolveAutoToWhisperAliasModel(t *testing.T) { + gin.SetMode(gin.TestMode) + + executor := &audioCaptureExecutor{id: "pool"} + manager := coreauth.NewManager(nil, nil, nil) + manager.SetConfig(&internalconfig.Config{ + OpenAICompatibility: []internalconfig.OpenAICompatibility{{ + Name: "pool", + Models: []internalconfig.OpenAICompatibilityModel{ + {Name: "whisper-1", Alias: "voice-whisper"}, + }, + }}, + }) + manager.RegisterExecutor(executor) + + poolAuth := &coreauth.Auth{ + ID: "pool-auth", + Provider: "pool", + Status: coreauth.StatusActive, + Attributes: map[string]string{ + "api_key": "test-key", + "base_url": "https://api.example.com/v1", + "compat_name": "pool", + "provider_key": "pool", + }, + } + if _, err := manager.Register(context.Background(), poolAuth); err != nil { + t.Fatalf("register pool auth: %v", err) + } + registry.GetGlobalRegistry().RegisterClient(poolAuth.ID, "pool", []*registry.ModelInfo{{ + ID: "voice-whisper", + Created: time.Now().Unix(), + }}) + t.Cleanup(func() { + registry.GetGlobalRegistry().UnregisterClient(poolAuth.ID) + }) + + base := handlers.NewBaseAPIHandlers(&sdkconfig.SDKConfig{}, manager) + h := NewOpenAIAPIHandler(base) + router := gin.New() + router.POST("/v1/audio/transcriptions", h.AudioTranscriptions) + + req := newAudioMultipartRequest(t, "/v1/audio/transcriptions", map[string]string{ + "model": "auto", + }, audioTranscriptionFileFieldName, "sample.webm", "", []byte("fake-audio")) + resp := httptest.NewRecorder() + router.ServeHTTP(resp, req) + + if resp.Code != http.StatusOK { + t.Fatalf("status = %d, want %d; body=%s", resp.Code, http.StatusOK, resp.Body.String()) + } + if executor.lastAuthID != poolAuth.ID { + t.Fatalf("last auth id = %q, want %q", executor.lastAuthID, poolAuth.ID) + } + if got := executor.lastFields["model"]; len(got) != 1 || got[0] != "whisper-1" { + t.Fatalf("model field = %v, want [whisper-1]", got) + } +} + +func TestAudioTranscriptionsResolveAutoSkipsUnavailableAuths(t *testing.T) { + gin.SetMode(gin.TestMode) + + executor := &audioCaptureExecutor{id: "pool"} + manager := coreauth.NewManager(nil, nil, nil) + manager.SetConfig(&internalconfig.Config{ + OpenAICompatibility: []internalconfig.OpenAICompatibility{{ + Name: "pool", + Models: []internalconfig.OpenAICompatibilityModel{ + {Name: audioTestModel, Alias: "voice-new"}, + }, + }}, + }) + manager.RegisterExecutor(executor) + + blockedAuth := &coreauth.Auth{ + ID: "blocked-auth", + Provider: "pool", + Status: coreauth.StatusDisabled, + Attributes: map[string]string{ + "api_key": "test-key", + "base_url": "https://api.example.com/v1", + "compat_name": "pool", + "provider_key": "pool", + }, + } + if _, err := manager.Register(context.Background(), blockedAuth); err != nil { + t.Fatalf("register blocked auth: %v", err) + } + registry.GetGlobalRegistry().RegisterClient(blockedAuth.ID, "pool", []*registry.ModelInfo{{ + ID: "voice-new", + Created: time.Now().Unix() + 1000, + }}) + t.Cleanup(func() { + registry.GetGlobalRegistry().UnregisterClient(blockedAuth.ID) + }) + + activeAuth := &coreauth.Auth{ + ID: "active-auth", + Provider: "pool", + Status: coreauth.StatusActive, + Attributes: map[string]string{ + "api_key": "test-key", + "base_url": "https://api.example.com/v1", + "compat_name": "pool", + "provider_key": "pool", + }, + } + if _, err := manager.Register(context.Background(), activeAuth); err != nil { + t.Fatalf("register active auth: %v", err) + } + registry.GetGlobalRegistry().RegisterClient(activeAuth.ID, "pool", []*registry.ModelInfo{{ + ID: audioTestModel, + Created: time.Now().Unix(), + }}) + t.Cleanup(func() { + registry.GetGlobalRegistry().UnregisterClient(activeAuth.ID) + }) + + base := handlers.NewBaseAPIHandlers(&sdkconfig.SDKConfig{}, manager) + h := NewOpenAIAPIHandler(base) + router := gin.New() + router.POST("/v1/audio/transcriptions", h.AudioTranscriptions) + + req := newAudioMultipartRequest(t, "/v1/audio/transcriptions", map[string]string{ + "model": "auto", + }, audioTranscriptionFileFieldName, "sample.webm", "", []byte("fake-audio")) + resp := httptest.NewRecorder() + router.ServeHTTP(resp, req) + + if resp.Code != http.StatusOK { + t.Fatalf("status = %d, want %d; body=%s", resp.Code, http.StatusOK, resp.Body.String()) + } + if executor.lastAuthID != activeAuth.ID { + t.Fatalf("last auth id = %q, want %q", executor.lastAuthID, activeAuth.ID) + } + if got := executor.lastFields["model"]; len(got) != 1 || got[0] != audioTestModel { + t.Fatalf("model field = %v, want [%s]", got, audioTestModel) + } +} + +func TestAudioTranscriptionsResolveAutoSkipsGloballyCoolingDownAuths(t *testing.T) { + gin.SetMode(gin.TestMode) + + executor := &audioCaptureExecutor{id: "pool"} + manager := coreauth.NewManager(nil, nil, nil) + manager.SetConfig(&internalconfig.Config{ + OpenAICompatibility: []internalconfig.OpenAICompatibility{{ + Name: "pool", + Models: []internalconfig.OpenAICompatibilityModel{ + {Name: audioTestModel, Alias: "voice-new"}, + }, + }}, + }) + manager.RegisterExecutor(executor) + + blockedAuth := &coreauth.Auth{ + ID: "blocked-auth", + Provider: "pool", + Status: coreauth.StatusActive, + Unavailable: true, + NextRetryAfter: time.Now().Add(30 * time.Minute), + Quota: coreauth.QuotaState{ + Exceeded: true, + NextRecoverAt: time.Now().Add(30 * time.Minute), + }, + Attributes: map[string]string{ + "api_key": "test-key", + "base_url": "https://api.example.com/v1", + "compat_name": "pool", + "provider_key": "pool", + }, + } + if _, err := manager.Register(context.Background(), blockedAuth); err != nil { + t.Fatalf("register blocked auth: %v", err) + } + registry.GetGlobalRegistry().RegisterClient(blockedAuth.ID, "pool", []*registry.ModelInfo{{ + ID: "voice-new", + Created: time.Now().Unix() + 1000, + }}) + t.Cleanup(func() { + registry.GetGlobalRegistry().UnregisterClient(blockedAuth.ID) + }) + + activeAuth := &coreauth.Auth{ + ID: "active-auth", + Provider: "pool", + Status: coreauth.StatusActive, + Attributes: map[string]string{ + "api_key": "test-key", + "base_url": "https://api.example.com/v1", + "compat_name": "pool", + "provider_key": "pool", + }, + } + if _, err := manager.Register(context.Background(), activeAuth); err != nil { + t.Fatalf("register active auth: %v", err) + } + registry.GetGlobalRegistry().RegisterClient(activeAuth.ID, "pool", []*registry.ModelInfo{{ + ID: audioTestModel, + Created: time.Now().Unix(), + }}) + t.Cleanup(func() { + registry.GetGlobalRegistry().UnregisterClient(activeAuth.ID) + }) + + base := handlers.NewBaseAPIHandlers(&sdkconfig.SDKConfig{}, manager) + h := NewOpenAIAPIHandler(base) + router := gin.New() + router.POST("/v1/audio/transcriptions", h.AudioTranscriptions) + + req := newAudioMultipartRequest(t, "/v1/audio/transcriptions", map[string]string{ + "model": "auto", + }, audioTranscriptionFileFieldName, "sample.webm", "", []byte("fake-audio")) + resp := httptest.NewRecorder() + router.ServeHTTP(resp, req) + + if resp.Code != http.StatusOK { + t.Fatalf("status = %d, want %d; body=%s", resp.Code, http.StatusOK, resp.Body.String()) + } + if executor.lastAuthID != activeAuth.ID { + t.Fatalf("last auth id = %q, want %q", executor.lastAuthID, activeAuth.ID) + } + if got := executor.lastFields["model"]; len(got) != 1 || got[0] != audioTestModel { + t.Fatalf("model field = %v, want [%s]", got, audioTestModel) + } +} + +func TestResolveAudioAutoModelUsesProviderSpecificMetadata(t *testing.T) { + manager := coreauth.NewManager(nil, nil, nil) + + audioAuth := &coreauth.Auth{ + ID: "provider-aware-audio-auth", + Provider: "compat-audio", + Status: coreauth.StatusActive, + Attributes: map[string]string{ + "api_key": "test-key", + "base_url": "https://api.example.com/v1", + "compat_name": "compat-audio", + "provider_key": "compat-audio", + }, + } + if _, err := manager.Register(context.Background(), audioAuth); err != nil { + t.Fatalf("register audio auth: %v", err) + } + + nonAudioAuth := &coreauth.Auth{ + ID: "provider-aware-text-auth", + Provider: "compat-text", + Status: coreauth.StatusActive, + Attributes: map[string]string{ + "api_key": "test-key", + "base_url": "https://api.example.com/v1", + "compat_name": "compat-text", + "provider_key": "compat-text", + }, + } + if _, err := manager.Register(context.Background(), nonAudioAuth); err != nil { + t.Fatalf("register text auth: %v", err) + } + + reg := registry.GetGlobalRegistry() + reg.RegisterClient(audioAuth.ID, audioAuth.Provider, []*registry.ModelInfo{{ + ID: "shared-transcription-model", + Created: time.Now().Unix(), + SupportedInputModalities: []string{"audio"}, + SupportedOutputModalities: []string{ + "text", + }, + }}) + reg.RegisterClient(nonAudioAuth.ID, nonAudioAuth.Provider, []*registry.ModelInfo{{ + ID: "shared-transcription-model", + Created: time.Now().Unix() + 1, + SupportedInputModalities: []string{"text"}, + SupportedOutputModalities: []string{ + "text", + }, + }}) + t.Cleanup(func() { + reg.UnregisterClient(audioAuth.ID) + reg.UnregisterClient(nonAudioAuth.ID) + }) + + got, err := resolveAudioAutoModel(manager, "auto") + if err != nil { + t.Fatalf("resolveAudioAutoModel(auto) error = %v", err) + } + if got != "shared-transcription-model" { + t.Fatalf("resolveAudioAutoModel(auto) = %q, want %q", got, "shared-transcription-model") + } +} + +func TestAudioTranscriptionsResolveAutoPinsSelectedAuth(t *testing.T) { + gin.SetMode(gin.TestMode) + + audioExecutor := &audioCaptureExecutor{id: "zz-audio"} + textExecutor := &audioCaptureExecutor{ + id: "aa-text", + responses: map[string]audioExecutorResponse{ + "text-auth": { + status: http.StatusBadRequest, + body: `{"error":"model not found"}`, + contentType: "application/json", + }, + }, + } + manager := coreauth.NewManager(nil, nil, nil) + manager.RegisterExecutor(textExecutor) + manager.RegisterExecutor(audioExecutor) + + textAuth := &coreauth.Auth{ + ID: "text-auth", + Provider: "aa-text", + Status: coreauth.StatusActive, + Attributes: map[string]string{ + "api_key": "test-key", + "base_url": "https://text.example.com/v1", + "compat_name": "aa-text", + "provider_key": "aa-text", + }, + } + if _, err := manager.Register(context.Background(), textAuth); err != nil { + t.Fatalf("register text auth: %v", err) + } + + audioAuth := &coreauth.Auth{ + ID: "audio-auth", + Provider: "zz-audio", + Status: coreauth.StatusActive, + Attributes: map[string]string{ + "api_key": "test-key", + "base_url": "https://audio.example.com/v1", + "compat_name": "zz-audio", + "provider_key": "zz-audio", + }, + } + if _, err := manager.Register(context.Background(), audioAuth); err != nil { + t.Fatalf("register audio auth: %v", err) + } + + reg := registry.GetGlobalRegistry() + reg.RegisterClient(textAuth.ID, textAuth.Provider, []*registry.ModelInfo{{ + ID: "shared-transcription-model", + Created: time.Now().Unix() + 1, + SupportedInputModalities: []string{"text"}, + SupportedOutputModalities: []string{ + "text", + }, + }}) + reg.RegisterClient(audioAuth.ID, audioAuth.Provider, []*registry.ModelInfo{{ + ID: "shared-transcription-model", + Created: time.Now().Unix(), + SupportedInputModalities: []string{"audio"}, + SupportedOutputModalities: []string{ + "text", + }, + }}) + t.Cleanup(func() { + reg.UnregisterClient(textAuth.ID) + reg.UnregisterClient(audioAuth.ID) + }) + + base := handlers.NewBaseAPIHandlers(&sdkconfig.SDKConfig{}, manager) + h := NewOpenAIAPIHandler(base) + router := gin.New() + router.POST("/v1/audio/transcriptions", h.AudioTranscriptions) + + req := newAudioMultipartRequest(t, "/v1/audio/transcriptions", map[string]string{ + "model": "auto", + }, audioTranscriptionFileFieldName, "sample.webm", "", []byte("fake-audio")) + resp := httptest.NewRecorder() + router.ServeHTTP(resp, req) + + if resp.Code != http.StatusOK { + t.Fatalf("status = %d, want %d; body=%s", resp.Code, http.StatusOK, resp.Body.String()) + } + if audioExecutor.lastAuthID != audioAuth.ID { + t.Fatalf("audio executor auth id = %q, want %q", audioExecutor.lastAuthID, audioAuth.ID) + } + if textExecutor.Calls() != 0 { + t.Fatalf("text executor calls = %d, want 0", textExecutor.Calls()) + } + if got := audioExecutor.lastFields["model"]; len(got) != 1 || got[0] != "shared-transcription-model" { + t.Fatalf("model field = %v, want [shared-transcription-model]", got) + } +} + +func newAudioTestRouter(t *testing.T, executor coreauth.ProviderExecutor, auths ...*coreauth.Auth) *gin.Engine { + return newAudioTestRouterWithConfig(t, &sdkconfig.SDKConfig{}, executor, auths...) +} + +func newAudioTestRouterWithConfig(t *testing.T, cfg *sdkconfig.SDKConfig, executor coreauth.ProviderExecutor, auths ...*coreauth.Auth) *gin.Engine { + t.Helper() + + manager := coreauth.NewManager(nil, nil, nil) + manager.RegisterExecutor(executor) + + for _, auth := range auths { + current := cloneAudioAuth(auth) + if current.Status == "" { + current.Status = coreauth.StatusActive + } + if current.Provider == "" { + current.Provider = executor.Identifier() + } + if _, err := manager.Register(context.Background(), current); err != nil { + t.Fatalf("Register auth %s: %v", current.ID, err) + } + registry.GetGlobalRegistry().RegisterClient(current.ID, current.Provider, []*registry.ModelInfo{{ID: audioTestModel}}) + t.Cleanup(func() { + registry.GetGlobalRegistry().UnregisterClient(current.ID) + }) + } + + base := handlers.NewBaseAPIHandlers(cfg, manager) + h := NewOpenAIAPIHandler(base) + router := gin.New() + router.POST("/v1/audio/transcriptions", h.AudioTranscriptions) + return router +} + +func newAudioMultipartRequest(t *testing.T, target string, fields map[string]string, fileField, fileName, fileContentType string, fileContent []byte) *http.Request { + t.Helper() + + body := &bytes.Buffer{} + writer := multipart.NewWriter(body) + + fieldNames := make([]string, 0, len(fields)) + for fieldName := range fields { + fieldNames = append(fieldNames, fieldName) + } + sort.Strings(fieldNames) + for _, fieldName := range fieldNames { + if err := writer.WriteField(fieldName, fields[fieldName]); err != nil { + t.Fatalf("WriteField(%s): %v", fieldName, err) + } + } + + var filePart io.Writer + var err error + if fileContentType == "" { + filePart, err = writer.CreateFormFile(fileField, fileName) + } else { + partHeader := make(textproto.MIMEHeader) + partHeader.Set("Content-Disposition", mime.FormatMediaType("form-data", map[string]string{ + "name": fileField, + "filename": fileName, + })) + partHeader.Set("Content-Type", fileContentType) + filePart, err = writer.CreatePart(partHeader) + } + if err != nil { + t.Fatalf("Create file part: %v", err) + } + if _, err := filePart.Write(fileContent); err != nil { + t.Fatalf("Write(file): %v", err) + } + if err := writer.Close(); err != nil { + t.Fatalf("Close writer: %v", err) + } + + req := httptest.NewRequest(http.MethodPost, target, body) + req.Header.Set("Content-Type", writer.FormDataContentType()) + return req +} + +func cloneAudioAuth(auth *coreauth.Auth) *coreauth.Auth { + if auth == nil { + return nil + } + clone := *auth + if auth.Attributes != nil { + clone.Attributes = make(map[string]string, len(auth.Attributes)) + for key, value := range auth.Attributes { + clone.Attributes[key] = value + } + } + if auth.Metadata != nil { + clone.Metadata = make(map[string]any, len(auth.Metadata)) + for key, value := range auth.Metadata { + clone.Metadata[key] = value + } + } + return &clone +} diff --git a/sdk/cliproxy/auth/conductor.go b/sdk/cliproxy/auth/conductor.go index b29e04db8c..e66cd284d9 100644 --- a/sdk/cliproxy/auth/conductor.go +++ b/sdk/cliproxy/auth/conductor.go @@ -42,6 +42,9 @@ type ProviderExecutor interface { HttpRequest(ctx context.Context, auth *Auth, req *http.Request) (*http.Response, error) } +// HTTPRequestBuilder builds a fresh HTTP request for the selected auth and resolved upstream model. +type HTTPRequestBuilder func(context.Context, *Auth, string) (*http.Request, error) + // ExecutionSessionCloser allows executors to release per-session runtime resources. type ExecutionSessionCloser interface { CloseExecutionSession(sessionID string) @@ -420,25 +423,8 @@ func preserveRequestedModelSuffix(requestedModel, resolved string) string { return preserveResolvedModelSuffix(resolved, thinking.ParseSuffix(requestedModel)) } -func (m *Manager) executionModelCandidates(auth *Auth, routeModel string) []string { - return m.prepareExecutionModels(auth, routeModel) -} - func (m *Manager) prepareExecutionModels(auth *Auth, routeModel string) []string { - requestedModel := rewriteModelForAuth(routeModel, auth) - requestedModel = m.applyOAuthModelAlias(auth, requestedModel) - if pool := m.resolveOpenAICompatUpstreamModelPool(auth, requestedModel); len(pool) > 0 { - if len(pool) == 1 { - return pool - } - offset := m.nextModelPoolOffset(openAICompatModelPoolKey(auth, requestedModel), len(pool)) - return rotateStrings(pool, offset) - } - resolved := m.applyAPIKeyModelAlias(auth, requestedModel) - if strings.TrimSpace(resolved) == "" { - resolved = requestedModel - } - return []string{resolved} + return m.resolveExecutionModels(auth, routeModel, true) } func discardStreamChunks(ch <-chan cliproxyexecutor.StreamChunk) { @@ -972,6 +958,40 @@ func (m *Manager) ExecuteStream(ctx context.Context, providers []string, req cli return nil, &Error{Code: "auth_not_found", Message: "no auth available"} } +// ExecuteHTTPRequest selects an auth using the existing scheduler and executes a raw HTTP request. +// selectionModel controls provider/model routing, while resultModel is recorded in auth state updates. +func (m *Manager) ExecuteHTTPRequest(ctx context.Context, providers []string, selectionModel, resultModel string, opts cliproxyexecutor.Options, build HTTPRequestBuilder) (*http.Response, *Auth, error) { + normalized := m.normalizeProviders(providers) + if len(normalized) == 0 { + return nil, nil, &Error{Code: "provider_not_found", Message: "no provider supplied"} + } + if build == nil { + return nil, nil, &Error{Code: "invalid_request", Message: "http request builder is nil"} + } + + _, maxRetryCredentials, maxWait := m.retrySettings() + + var lastErr error + for attempt := 0; ; attempt++ { + resp, auth, errExec := m.executeHTTPMixedOnce(ctx, normalized, selectionModel, resultModel, opts, maxRetryCredentials, build) + if errExec == nil { + return resp, auth, nil + } + lastErr = errExec + wait, shouldRetry := m.shouldRetryAfterError(errExec, attempt, normalized, selectionModel, maxWait) + if !shouldRetry { + break + } + if errWait := waitForCooldown(ctx, wait); errWait != nil { + return nil, nil, errWait + } + } + if lastErr != nil { + return nil, nil, lastErr + } + return nil, nil, &Error{Code: "auth_not_found", Message: "no auth available"} +} + func (m *Manager) executeMixedOnce(ctx context.Context, providers []string, req cliproxyexecutor.Request, opts cliproxyexecutor.Options, maxRetryCredentials int) (cliproxyexecutor.Response, error) { if len(providers) == 0 { return cliproxyexecutor.Response{}, &Error{Code: "provider_not_found", Message: "no provider supplied"} @@ -1164,6 +1184,149 @@ func (m *Manager) executeStreamMixedOnce(ctx context.Context, providers []string } } +func (m *Manager) executeHTTPMixedOnce(ctx context.Context, providers []string, selectionModel, resultModel string, opts cliproxyexecutor.Options, maxRetryCredentials int, build HTTPRequestBuilder) (*http.Response, *Auth, error) { + if len(providers) == 0 { + return nil, nil, &Error{Code: "provider_not_found", Message: "no provider supplied"} + } + if build == nil { + return nil, nil, &Error{Code: "invalid_request", Message: "http request builder is nil"} + } + + routeModel := strings.TrimSpace(selectionModel) + recordModel := strings.TrimSpace(resultModel) + if recordModel == "" { + recordModel = routeModel + } + opts = ensureRequestedModelMetadata(opts, recordModel) + + tried := make(map[string]struct{}) + var lastErr error + for { + if maxRetryCredentials > 0 && len(tried) >= maxRetryCredentials { + if lastErr != nil { + return nil, nil, lastErr + } + return nil, nil, &Error{Code: "auth_not_found", Message: "no auth available"} + } + + auth, executor, provider, errPick := m.pickNextMixed(ctx, providers, routeModel, opts, tried) + if errPick != nil { + if lastErr != nil { + return nil, nil, lastErr + } + return nil, nil, errPick + } + + entry := logEntryWithRequestID(ctx) + debugLogAuthSelection(entry, auth, provider, routeModel) + publishSelectedAuthMetadata(opts.Metadata, auth.ID) + + tried[auth.ID] = struct{}{} + execCtx := ctx + if rt := m.roundTripperFor(auth); rt != nil { + execCtx = context.WithValue(execCtx, roundTripperContextKey{}, rt) + execCtx = context.WithValue(execCtx, "cliproxy.roundtripper", rt) + } + + models := m.prepareExecutionModels(auth, routeModel) + var authErr error + for _, upstreamModel := range models { + httpReq, errBuild := build(execCtx, auth, upstreamModel) + if errBuild != nil { + if errCtx := execCtx.Err(); errCtx != nil { + return nil, auth, errCtx + } + result := Result{ + AuthID: auth.ID, + Provider: provider, + Model: recordModel, + Success: false, + Error: &Error{ + Message: errBuild.Error(), + HTTPStatus: statusCodeFromError(errBuild), + }, + } + m.MarkResult(execCtx, result) + if isImmediateHTTPBuildError(errBuild) { + return nil, auth, errBuild + } + authErr = errBuild + continue + } + + httpResp, errHTTP := executor.HttpRequest(execCtx, auth, httpReq) + if errHTTP != nil { + if errCtx := execCtx.Err(); errCtx != nil { + return nil, auth, errCtx + } + result := Result{ + AuthID: auth.ID, + Provider: provider, + Model: recordModel, + Success: false, + Error: &Error{Message: errHTTP.Error()}, + } + if se, ok := errors.AsType[cliproxyexecutor.StatusError](errHTTP); ok && se != nil { + result.Error.HTTPStatus = se.StatusCode() + } + if ra := retryAfterFromError(errHTTP); ra != nil { + result.RetryAfter = ra + } + m.MarkResult(execCtx, result) + if isRequestInvalidError(errHTTP) { + return nil, auth, errHTTP + } + authErr = errHTTP + continue + } + + if httpResp != nil && httpResp.StatusCode >= http.StatusOK && httpResp.StatusCode < http.StatusMultipleChoices { + m.MarkResult(execCtx, Result{ + AuthID: auth.ID, + Provider: provider, + Model: recordModel, + Success: true, + }) + return httpResp, auth, nil + } + + responseErr := newHTTPResponseError(httpResp) + result := Result{ + AuthID: auth.ID, + Provider: provider, + Model: recordModel, + Success: false, + Error: &Error{ + Message: responseErr.Error(), + HTTPStatus: responseErr.StatusCode(), + }, + } + if ra := retryAfterFromHeaders(responseErr.Headers()); ra != nil { + result.RetryAfter = ra + } + m.MarkResult(execCtx, result) + if isRequestInvalidError(responseErr) { + return nil, auth, responseErr + } + authErr = responseErr + } + if authErr != nil { + lastErr = authErr + continue + } + } +} + +func isImmediateHTTPBuildError(err error) bool { + if err == nil { + return false + } + if status := statusCodeFromError(err); status == http.StatusBadRequest || status == http.StatusUnprocessableEntity { + return true + } + return isRequestInvalidError(err) +} + func ensureRequestedModelMetadata(opts cliproxyexecutor.Options, requestedModel string) cliproxyexecutor.Options { requestedModel = strings.TrimSpace(requestedModel) if requestedModel == "" { @@ -1876,6 +2039,56 @@ func retryAfterFromError(err error) *time.Duration { return new(*retryAfter) } +func retryAfterFromHeaders(headers http.Header) *time.Duration { + if headers == nil { + return nil + } + raw := strings.TrimSpace(headers.Get("Retry-After")) + if raw == "" { + return nil + } + if seconds, err := strconv.Atoi(raw); err == nil && seconds >= 0 { + retryAfter := time.Duration(seconds) * time.Second + return &retryAfter + } + if retryAt, err := http.ParseTime(raw); err == nil { + retryAfter := time.Until(retryAt) + if retryAfter < 0 { + retryAfter = 0 + } + return &retryAfter + } + return nil +} + +const httpResponseErrorBodyLimitBytes int64 = 8 << 20 + +func newHTTPResponseError(resp *http.Response) *HTTPResponseError { + if resp == nil { + return &HTTPResponseError{ + HTTPStatus: http.StatusBadGateway, + Body: []byte("upstream request failed"), + } + } + defer func() { + _ = resp.Body.Close() + }() + limitedReader := &io.LimitedReader{R: resp.Body, N: httpResponseErrorBodyLimitBytes + 1} + body, _ := io.ReadAll(limitedReader) + if int64(len(body)) > httpResponseErrorBodyLimitBytes { + body = body[:httpResponseErrorBodyLimitBytes] + } + status := resp.StatusCode + if status <= 0 { + status = http.StatusBadGateway + } + return &HTTPResponseError{ + HTTPStatus: status, + Body: body, + HeaderMap: resp.Header.Clone(), + } +} + func statusCodeFromResult(err *Error) int { if err == nil { return 0 diff --git a/sdk/cliproxy/auth/http_error.go b/sdk/cliproxy/auth/http_error.go new file mode 100644 index 0000000000..ea00e98df3 --- /dev/null +++ b/sdk/cliproxy/auth/http_error.go @@ -0,0 +1,44 @@ +package auth + +import ( + "net/http" + "strings" +) + +// HTTPResponseError preserves an upstream HTTP failure for caller-side handling. +type HTTPResponseError struct { + HTTPStatus int + Body []byte + HeaderMap http.Header +} + +// Error implements the error interface. +func (e *HTTPResponseError) Error() string { + if e == nil { + return "" + } + message := strings.TrimSpace(string(e.Body)) + if message == "" { + message = http.StatusText(e.HTTPStatus) + } + if message == "" { + message = "upstream request failed" + } + return message +} + +// StatusCode exposes the upstream HTTP status code. +func (e *HTTPResponseError) StatusCode() int { + if e == nil { + return 0 + } + return e.HTTPStatus +} + +// Headers returns a defensive copy of the upstream response headers. +func (e *HTTPResponseError) Headers() http.Header { + if e == nil || e.HeaderMap == nil { + return nil + } + return e.HeaderMap.Clone() +} diff --git a/sdk/cliproxy/auth/http_error_test.go b/sdk/cliproxy/auth/http_error_test.go new file mode 100644 index 0000000000..2b44f34fa7 --- /dev/null +++ b/sdk/cliproxy/auth/http_error_test.go @@ -0,0 +1,32 @@ +package auth + +import ( + "io" + "net/http" + "strings" + "testing" +) + +func TestNewHTTPResponseError_BoundsBodyReads(t *testing.T) { + resp := &http.Response{ + StatusCode: http.StatusBadGateway, + Header: http.Header{ + "Content-Type": []string{"text/plain; charset=utf-8"}, + }, + Body: io.NopCloser(strings.NewReader(strings.Repeat("x", int(httpResponseErrorBodyLimitBytes+1)))), + } + + err := newHTTPResponseError(resp) + if err == nil { + t.Fatalf("newHTTPResponseError() returned nil") + } + if got := int64(len(err.Body)); got != httpResponseErrorBodyLimitBytes { + t.Fatalf("body len = %d, want %d", got, httpResponseErrorBodyLimitBytes) + } + if got := err.StatusCode(); got != http.StatusBadGateway { + t.Fatalf("status = %d, want %d", got, http.StatusBadGateway) + } + if got := err.Headers().Get("Content-Type"); got != "text/plain; charset=utf-8" { + t.Fatalf("content type = %q, want %q", got, "text/plain; charset=utf-8") + } +} diff --git a/sdk/cliproxy/auth/http_request_test.go b/sdk/cliproxy/auth/http_request_test.go new file mode 100644 index 0000000000..2a7cbe6fe2 --- /dev/null +++ b/sdk/cliproxy/auth/http_request_test.go @@ -0,0 +1,217 @@ +package auth + +import ( + "context" + "io" + "net/http" + "strings" + "sync" + "testing" + + internalconfig "github.com/router-for-me/CLIProxyAPI/v6/internal/config" + "github.com/router-for-me/CLIProxyAPI/v6/internal/registry" + cliproxyexecutor "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/executor" +) + +type httpRequestTestExecutor struct { + id string + + mu sync.Mutex + authIDs []string + requestModels []string + statusByModel map[string]int +} + +func (e *httpRequestTestExecutor) Identifier() string { return e.id } + +func (e *httpRequestTestExecutor) Execute(context.Context, *Auth, cliproxyexecutor.Request, cliproxyexecutor.Options) (cliproxyexecutor.Response, error) { + return cliproxyexecutor.Response{}, &Error{HTTPStatus: http.StatusNotImplemented, Message: "Execute not implemented"} +} + +func (e *httpRequestTestExecutor) ExecuteStream(context.Context, *Auth, cliproxyexecutor.Request, cliproxyexecutor.Options) (*cliproxyexecutor.StreamResult, error) { + return nil, &Error{HTTPStatus: http.StatusNotImplemented, Message: "ExecuteStream not implemented"} +} + +func (e *httpRequestTestExecutor) Refresh(_ context.Context, auth *Auth) (*Auth, error) { + return auth, nil +} + +func (e *httpRequestTestExecutor) CountTokens(context.Context, *Auth, cliproxyexecutor.Request, cliproxyexecutor.Options) (cliproxyexecutor.Response, error) { + return cliproxyexecutor.Response{}, &Error{HTTPStatus: http.StatusNotImplemented, Message: "CountTokens not implemented"} +} + +func (e *httpRequestTestExecutor) HttpRequest(_ context.Context, auth *Auth, req *http.Request) (*http.Response, error) { + model := req.Header.Get("X-Model") + + e.mu.Lock() + if auth != nil { + e.authIDs = append(e.authIDs, auth.ID) + } + e.requestModels = append(e.requestModels, model) + status := http.StatusOK + if e.statusByModel != nil { + if candidate, ok := e.statusByModel[model]; ok { + status = candidate + } + } + e.mu.Unlock() + + return &http.Response{ + StatusCode: status, + Header: http.Header{ + "Content-Type": []string{"text/plain; charset=utf-8"}, + "X-Model": []string{model}, + }, + Body: io.NopCloser(strings.NewReader(model)), + }, nil +} + +func (e *httpRequestTestExecutor) AuthIDs() []string { + e.mu.Lock() + defer e.mu.Unlock() + out := make([]string, len(e.authIDs)) + copy(out, e.authIDs) + return out +} + +func (e *httpRequestTestExecutor) RequestModels() []string { + e.mu.Lock() + defer e.mu.Unlock() + out := make([]string, len(e.requestModels)) + copy(out, e.requestModels) + return out +} + +func TestManagerExecuteHTTPRequest_AppliesOpenAICompatModelPool(t *testing.T) { + const alias = "claude-opus-4.66" + + cfg := &internalconfig.Config{ + OpenAICompatibility: []internalconfig.OpenAICompatibility{{ + Name: "pool", + Models: []internalconfig.OpenAICompatibilityModel{ + {Name: "qwen3.5-plus", Alias: alias}, + {Name: "glm-5", Alias: alias}, + }, + }}, + } + + manager := NewManager(nil, nil, nil) + manager.SetConfig(cfg) + + executor := &httpRequestTestExecutor{ + id: "pool", + statusByModel: map[string]int{ + "qwen3.5-plus": http.StatusInternalServerError, + "glm-5": http.StatusOK, + }, + } + manager.RegisterExecutor(executor) + + auth := &Auth{ + ID: "pool-auth", + Provider: "pool", + Status: StatusActive, + Attributes: map[string]string{ + "api_key": "test-key", + "compat_name": "pool", + "provider_key": "pool", + }, + } + if _, err := manager.Register(context.Background(), auth); err != nil { + t.Fatalf("register auth: %v", err) + } + + reg := registry.GetGlobalRegistry() + reg.RegisterClient(auth.ID, "pool", []*registry.ModelInfo{{ID: alias}}) + t.Cleanup(func() { + reg.UnregisterClient(auth.ID) + }) + + var builtModels []string + resp, selectedAuth, err := manager.ExecuteHTTPRequest(context.Background(), []string{"pool"}, alias, alias, cliproxyexecutor.Options{}, func(ctx context.Context, auth *Auth, upstreamModel string) (*http.Request, error) { + builtModels = append(builtModels, upstreamModel) + req, err := http.NewRequestWithContext(ctx, http.MethodPost, "https://example.com/http", nil) + if err != nil { + return nil, err + } + req.Header.Set("X-Model", upstreamModel) + return req, nil + }) + if err != nil { + t.Fatalf("ExecuteHTTPRequest() error = %v", err) + } + if selectedAuth == nil || selectedAuth.ID != auth.ID { + t.Fatalf("selected auth = %#v, want %q", selectedAuth, auth.ID) + } + if got := builtModels; len(got) != 2 || got[0] != "qwen3.5-plus" || got[1] != "glm-5" { + t.Fatalf("built models = %v, want [qwen3.5-plus glm-5]", got) + } + if got := executor.RequestModels(); len(got) != 2 || got[0] != "qwen3.5-plus" || got[1] != "glm-5" { + t.Fatalf("request models = %v, want [qwen3.5-plus glm-5]", got) + } + if resp == nil { + t.Fatalf("expected non-nil response") + } + body, err := io.ReadAll(resp.Body) + if err != nil { + t.Fatalf("ReadAll(response body): %v", err) + } + _ = resp.Body.Close() + if string(body) != "glm-5" { + t.Fatalf("response body = %q, want %q", string(body), "glm-5") + } +} + +func TestManagerExecuteHTTPRequest_RetriesAcrossAuthsOnBuildFailure(t *testing.T) { + manager := NewManager(nil, nil, nil) + + executor := &httpRequestTestExecutor{id: "pool"} + manager.RegisterExecutor(executor) + + auth1 := &Auth{ID: "auth1", Provider: "pool", Status: StatusActive} + auth2 := &Auth{ID: "auth2", Provider: "pool", Status: StatusActive} + if _, err := manager.Register(context.Background(), auth1); err != nil { + t.Fatalf("register auth1: %v", err) + } + if _, err := manager.Register(context.Background(), auth2); err != nil { + t.Fatalf("register auth2: %v", err) + } + + reg := registry.GetGlobalRegistry() + reg.RegisterClient(auth1.ID, "pool", []*registry.ModelInfo{{ID: "test-model"}}) + reg.RegisterClient(auth2.ID, "pool", []*registry.ModelInfo{{ID: "test-model"}}) + t.Cleanup(func() { + reg.UnregisterClient(auth1.ID) + reg.UnregisterClient(auth2.ID) + }) + + var buildAuthIDs []string + resp, selectedAuth, err := manager.ExecuteHTTPRequest(context.Background(), []string{"pool"}, "test-model", "test-model", cliproxyexecutor.Options{}, func(ctx context.Context, auth *Auth, upstreamModel string) (*http.Request, error) { + buildAuthIDs = append(buildAuthIDs, auth.ID) + if auth.ID == "auth1" { + return nil, &Error{HTTPStatus: http.StatusBadGateway, Message: "build failed for auth1"} + } + req, err := http.NewRequestWithContext(ctx, http.MethodPost, "https://example.com/http", nil) + if err != nil { + return nil, err + } + req.Header.Set("X-Model", upstreamModel) + return req, nil + }) + if err != nil { + t.Fatalf("ExecuteHTTPRequest() error = %v", err) + } + if selectedAuth == nil || selectedAuth.ID != "auth2" { + t.Fatalf("selected auth = %#v, want auth2", selectedAuth) + } + if got := buildAuthIDs; len(got) != 2 || got[0] != "auth1" || got[1] != "auth2" { + t.Fatalf("build auth IDs = %v, want [auth1 auth2]", got) + } + if got := executor.AuthIDs(); len(got) != 1 || got[0] != "auth2" { + t.Fatalf("executor auth IDs = %v, want [auth2]", got) + } + if resp == nil { + t.Fatalf("expected non-nil response") + } + _ = resp.Body.Close() +} diff --git a/sdk/cliproxy/auth/model_preview.go b/sdk/cliproxy/auth/model_preview.go new file mode 100644 index 0000000000..e52678aa7e --- /dev/null +++ b/sdk/cliproxy/auth/model_preview.go @@ -0,0 +1,94 @@ +package auth + +import ( + "sort" + "strings" + "time" + + "github.com/router-for-me/CLIProxyAPI/v6/internal/registry" +) + +// RouteModelPreview describes a selectable route model and the upstream models +// it would resolve to without mutating scheduler state. +type RouteModelPreview struct { + Auth *Auth + RouteModel string + CreatedAt int64 + UpstreamModels []string +} + +// ExecutionModelCandidates resolves the upstream model candidates for a route model +// under a specific auth without advancing model-pool rotation state. +func (m *Manager) ExecutionModelCandidates(auth *Auth, routeModel string) []string { + if m == nil { + return nil + } + return m.resolveExecutionModels(auth, routeModel, false) +} + +// PreviewSelectableRouteModels lists currently selectable route models together +// with their upstream execution-model previews. +func (m *Manager) PreviewSelectableRouteModels(now time.Time, authAllowed func(*Auth) bool) []RouteModelPreview { + if m == nil { + return nil + } + reg := registry.GetGlobalRegistry() + previews := make([]RouteModelPreview, 0) + for _, auth := range m.List() { + if auth == nil || strings.TrimSpace(auth.ID) == "" { + continue + } + if authAllowed != nil && !authAllowed(auth) { + continue + } + for _, modelInfo := range reg.GetModelsForClient(auth.ID) { + if modelInfo == nil { + continue + } + routeModel := strings.TrimSpace(modelInfo.ID) + if routeModel == "" { + continue + } + if !IsAuthSelectableForModel(auth, routeModel, now) { + continue + } + upstreamModels := m.ExecutionModelCandidates(auth, routeModel) + if len(upstreamModels) == 0 { + continue + } + previews = append(previews, RouteModelPreview{ + Auth: auth, + RouteModel: routeModel, + CreatedAt: modelInfo.Created, + UpstreamModels: append([]string(nil), upstreamModels...), + }) + } + } + sort.SliceStable(previews, func(i, j int) bool { + if previews[i].Auth != nil && previews[j].Auth != nil && previews[i].Auth.ID != previews[j].Auth.ID { + return previews[i].Auth.ID < previews[j].Auth.ID + } + return previews[i].RouteModel < previews[j].RouteModel + }) + return previews +} + +func (m *Manager) resolveExecutionModels(auth *Auth, routeModel string, rotatePool bool) []string { + requestedModel := rewriteModelForAuth(routeModel, auth) + requestedModel = m.applyOAuthModelAlias(auth, requestedModel) + if pool := m.resolveOpenAICompatUpstreamModelPool(auth, requestedModel); len(pool) > 0 { + if len(pool) == 1 { + return append([]string(nil), pool...) + } + if rotatePool { + offset := m.nextModelPoolOffset(openAICompatModelPoolKey(auth, requestedModel), len(pool)) + return rotateStrings(pool, offset) + } + return rotateStrings(pool, 0) + } + resolved := m.applyAPIKeyModelAlias(auth, requestedModel) + if strings.TrimSpace(resolved) == "" { + resolved = requestedModel + } + return []string{resolved} +} diff --git a/sdk/cliproxy/auth/openai_compat_pool_test.go b/sdk/cliproxy/auth/openai_compat_pool_test.go index 5a5ecb4fe2..5ea393b3b2 100644 --- a/sdk/cliproxy/auth/openai_compat_pool_test.go +++ b/sdk/cliproxy/auth/openai_compat_pool_test.go @@ -222,6 +222,47 @@ func TestManagerExecute_OpenAICompatAliasPoolRotatesWithinAuth(t *testing.T) { } } +func TestManagerExecutionModelCandidates_OpenAICompatAliasPoolDoesNotRotate(t *testing.T) { + alias := "claude-opus-4.66" + executor := &openAICompatPoolExecutor{id: "pool"} + m := newOpenAICompatPoolTestManager(t, alias, []internalconfig.OpenAICompatibilityModel{ + {Name: "qwen3.5-plus", Alias: alias}, + {Name: "glm-5", Alias: alias}, + }, executor) + + auth, ok := m.GetByID("pool-auth-" + t.Name()) + if !ok { + t.Fatal("expected registered auth") + } + + firstPreview := m.ExecutionModelCandidates(auth, alias) + secondPreview := m.ExecutionModelCandidates(auth, alias) + wantPreview := []string{"qwen3.5-plus", "glm-5"} + if len(firstPreview) != len(wantPreview) || len(secondPreview) != len(wantPreview) { + t.Fatalf("preview lengths = %v / %v, want %v", firstPreview, secondPreview, wantPreview) + } + for i := range wantPreview { + if firstPreview[i] != wantPreview[i] { + t.Fatalf("first preview[%d] = %q, want %q", i, firstPreview[i], wantPreview[i]) + } + if secondPreview[i] != wantPreview[i] { + t.Fatalf("second preview[%d] = %q, want %q", i, secondPreview[i], wantPreview[i]) + } + } + + resp, err := m.Execute(context.Background(), []string{"pool"}, cliproxyexecutor.Request{Model: alias}, cliproxyexecutor.Options{}) + if err != nil { + t.Fatalf("execute: %v", err) + } + if string(resp.Payload) != "qwen3.5-plus" { + t.Fatalf("payload = %q, want %q", string(resp.Payload), "qwen3.5-plus") + } + got := executor.ExecuteModels() + if len(got) != 1 || got[0] != "qwen3.5-plus" { + t.Fatalf("execute calls = %v, want [qwen3.5-plus]", got) + } +} + func TestManagerExecute_OpenAICompatAliasPoolStopsOnBadRequest(t *testing.T) { alias := "claude-opus-4.66" invalidErr := &Error{HTTPStatus: http.StatusBadRequest, Message: "invalid_request_error: malformed payload"} diff --git a/sdk/cliproxy/auth/selector.go b/sdk/cliproxy/auth/selector.go index cf79e17337..a8dfede0b4 100644 --- a/sdk/cliproxy/auth/selector.go +++ b/sdk/cliproxy/auth/selector.go @@ -403,7 +403,6 @@ func isAuthBlockedForModel(auth *Auth, model string, now time.Time) (bool, block return false, blockReasonNone, time.Time{} } } - return false, blockReasonNone, time.Time{} } if auth.Unavailable && auth.NextRetryAfter.After(now) { next := auth.NextRetryAfter @@ -420,3 +419,9 @@ func isAuthBlockedForModel(auth *Auth, model string, now time.Time) (bool, block } return false, blockReasonNone, time.Time{} } + +// IsAuthSelectableForModel reports whether auth is currently eligible for selection for model. +func IsAuthSelectableForModel(auth *Auth, model string, now time.Time) bool { + blocked, _, _ := isAuthBlockedForModel(auth, model, now) + return !blocked +} diff --git a/sdk/cliproxy/auth/selector_test.go b/sdk/cliproxy/auth/selector_test.go index 79431a9ada..cd22a03217 100644 --- a/sdk/cliproxy/auth/selector_test.go +++ b/sdk/cliproxy/auth/selector_test.go @@ -311,6 +311,38 @@ func TestIsAuthBlockedForModel_UnavailableWithoutNextRetryIsNotBlocked(t *testin } } +func TestIsAuthBlockedForModel_FallsBackToGlobalCooldownWhenModelStateMissing(t *testing.T) { + t.Parallel() + + now := time.Now() + nextRetry := now.Add(15 * time.Minute) + auth := &Auth{ + ID: "a", + Unavailable: true, + NextRetryAfter: nextRetry, + Quota: QuotaState{ + Exceeded: true, + NextRecoverAt: nextRetry, + }, + ModelStates: map[string]*ModelState{ + "other-model": { + Status: StatusActive, + }, + }, + } + + blocked, reason, next := isAuthBlockedForModel(auth, "missing-model", now) + if !blocked { + t.Fatalf("blocked = false, want true") + } + if reason != blockReasonCooldown { + t.Fatalf("reason = %v, want %v", reason, blockReasonCooldown) + } + if !next.Equal(nextRetry) { + t.Fatalf("next = %v, want %v", next, nextRetry) + } +} + func TestFillFirstSelectorPick_ThinkingSuffixFallsBackToBaseModelState(t *testing.T) { t.Parallel()