diff --git a/sdk/cliproxy/auth/conductor.go b/sdk/cliproxy/auth/conductor.go index b29e04db8c..9f46c7cf4a 100644 --- a/sdk/cliproxy/auth/conductor.go +++ b/sdk/cliproxy/auth/conductor.go @@ -421,10 +421,6 @@ func preserveRequestedModelSuffix(requestedModel, resolved string) string { } func (m *Manager) executionModelCandidates(auth *Auth, routeModel string) []string { - return m.prepareExecutionModels(auth, routeModel) -} - -func (m *Manager) prepareExecutionModels(auth *Auth, routeModel string) []string { requestedModel := rewriteModelForAuth(routeModel, auth) requestedModel = m.applyOAuthModelAlias(auth, requestedModel) if pool := m.resolveOpenAICompatUpstreamModelPool(auth, requestedModel); len(pool) > 0 { @@ -441,6 +437,46 @@ func (m *Manager) prepareExecutionModels(auth *Auth, routeModel string) []string return []string{resolved} } +func executionResultModel(routeModel, upstreamModel string, pooled bool) string { + if pooled { + if resolved := strings.TrimSpace(upstreamModel); resolved != "" { + return resolved + } + } + if requested := strings.TrimSpace(routeModel); requested != "" { + return requested + } + return strings.TrimSpace(upstreamModel) +} + +func filterExecutionModels(auth *Auth, routeModel string, candidates []string, pooled bool) []string { + if len(candidates) == 0 { + return nil + } + now := time.Now() + out := make([]string, 0, len(candidates)) + for _, upstreamModel := range candidates { + stateModel := executionResultModel(routeModel, upstreamModel, pooled) + blocked, _, _ := isAuthBlockedForModel(auth, stateModel, now) + if blocked { + continue + } + out = append(out, upstreamModel) + } + return out +} + +func (m *Manager) preparedExecutionModels(auth *Auth, routeModel string) ([]string, bool) { + candidates := m.executionModelCandidates(auth, routeModel) + pooled := len(candidates) > 1 + return filterExecutionModels(auth, routeModel, candidates, pooled), pooled +} + +func (m *Manager) prepareExecutionModels(auth *Auth, routeModel string) []string { + models, _ := m.preparedExecutionModels(auth, routeModel) + return models +} + func discardStreamChunks(ch <-chan cliproxyexecutor.StreamChunk) { if ch == nil { return @@ -451,6 +487,59 @@ func discardStreamChunks(ch <-chan cliproxyexecutor.StreamChunk) { }() } +type streamBootstrapError struct { + cause error + headers http.Header +} + +func cloneHTTPHeader(headers http.Header) http.Header { + if headers == nil { + return nil + } + return headers.Clone() +} + +func newStreamBootstrapError(err error, headers http.Header) error { + if err == nil { + return nil + } + return &streamBootstrapError{ + cause: err, + headers: cloneHTTPHeader(headers), + } +} + +func (e *streamBootstrapError) Error() string { + if e == nil || e.cause == nil { + return "" + } + return e.cause.Error() +} + +func (e *streamBootstrapError) Unwrap() error { + if e == nil { + return nil + } + return e.cause +} + +func (e *streamBootstrapError) Headers() http.Header { + if e == nil { + return nil + } + return cloneHTTPHeader(e.headers) +} + +func streamErrorResult(headers http.Header, err error) *cliproxyexecutor.StreamResult { + ch := make(chan cliproxyexecutor.StreamChunk, 1) + ch <- cliproxyexecutor.StreamChunk{Err: err} + close(ch) + return &cliproxyexecutor.StreamResult{ + Headers: cloneHTTPHeader(headers), + Chunks: ch, + } +} + func readStreamBootstrap(ctx context.Context, ch <-chan cliproxyexecutor.StreamChunk) ([]cliproxyexecutor.StreamChunk, bool, error) { if ch == nil { return nil, true, nil @@ -483,7 +572,7 @@ func readStreamBootstrap(ctx context.Context, ch <-chan cliproxyexecutor.StreamC } } -func (m *Manager) wrapStreamResult(ctx context.Context, auth *Auth, provider, routeModel string, headers http.Header, buffered []cliproxyexecutor.StreamChunk, remaining <-chan cliproxyexecutor.StreamChunk) *cliproxyexecutor.StreamResult { +func (m *Manager) wrapStreamResult(ctx context.Context, auth *Auth, provider, resultModel string, headers http.Header, buffered []cliproxyexecutor.StreamChunk, remaining <-chan cliproxyexecutor.StreamChunk) *cliproxyexecutor.StreamResult { out := make(chan cliproxyexecutor.StreamChunk) go func() { defer close(out) @@ -496,7 +585,7 @@ func (m *Manager) wrapStreamResult(ctx context.Context, auth *Auth, provider, ro if se, ok := errors.AsType[cliproxyexecutor.StatusError](chunk.Err); ok && se != nil { rerr.HTTPStatus = se.StatusCode() } - m.MarkResult(ctx, Result{AuthID: auth.ID, Provider: provider, Model: routeModel, Success: false, Error: rerr}) + m.MarkResult(ctx, Result{AuthID: auth.ID, Provider: provider, Model: resultModel, Success: false, Error: rerr}) } if !forward { return false @@ -526,19 +615,19 @@ func (m *Manager) wrapStreamResult(ctx context.Context, auth *Auth, provider, ro } } if !failed { - m.MarkResult(ctx, Result{AuthID: auth.ID, Provider: provider, Model: routeModel, Success: true}) + m.MarkResult(ctx, Result{AuthID: auth.ID, Provider: provider, Model: resultModel, Success: true}) } }() return &cliproxyexecutor.StreamResult{Headers: headers, Chunks: out} } -func (m *Manager) executeStreamWithModelPool(ctx context.Context, executor ProviderExecutor, auth *Auth, provider string, req cliproxyexecutor.Request, opts cliproxyexecutor.Options, routeModel string) (*cliproxyexecutor.StreamResult, error) { +func (m *Manager) executeStreamWithModelPool(ctx context.Context, executor ProviderExecutor, auth *Auth, provider string, req cliproxyexecutor.Request, opts cliproxyexecutor.Options, routeModel string, execModels []string, pooled bool) (*cliproxyexecutor.StreamResult, error) { if executor == nil { return nil, &Error{Code: "executor_not_found", Message: "executor not registered"} } - execModels := m.prepareExecutionModels(auth, routeModel) var lastErr error for idx, execModel := range execModels { + resultModel := executionResultModel(routeModel, execModel, pooled) execReq := req execReq.Model = execModel streamResult, errStream := executor.ExecuteStream(ctx, auth, execReq, opts) @@ -550,7 +639,7 @@ func (m *Manager) executeStreamWithModelPool(ctx context.Context, executor Provi if se, ok := errors.AsType[cliproxyexecutor.StatusError](errStream); ok && se != nil { rerr.HTTPStatus = se.StatusCode() } - result := Result{AuthID: auth.ID, Provider: provider, Model: routeModel, Success: false, Error: rerr} + result := Result{AuthID: auth.ID, Provider: provider, Model: resultModel, Success: false, Error: rerr} result.RetryAfter = retryAfterFromError(errStream) m.MarkResult(ctx, result) if isRequestInvalidError(errStream) { @@ -571,7 +660,7 @@ func (m *Manager) executeStreamWithModelPool(ctx context.Context, executor Provi if se, ok := errors.AsType[cliproxyexecutor.StatusError](bootstrapErr); ok && se != nil { rerr.HTTPStatus = se.StatusCode() } - result := Result{AuthID: auth.ID, Provider: provider, Model: routeModel, Success: false, Error: rerr} + result := Result{AuthID: auth.ID, Provider: provider, Model: resultModel, Success: false, Error: rerr} result.RetryAfter = retryAfterFromError(bootstrapErr) m.MarkResult(ctx, result) discardStreamChunks(streamResult.Chunks) @@ -582,31 +671,33 @@ func (m *Manager) executeStreamWithModelPool(ctx context.Context, executor Provi if se, ok := errors.AsType[cliproxyexecutor.StatusError](bootstrapErr); ok && se != nil { rerr.HTTPStatus = se.StatusCode() } - result := Result{AuthID: auth.ID, Provider: provider, Model: routeModel, Success: false, Error: rerr} + result := Result{AuthID: auth.ID, Provider: provider, Model: resultModel, Success: false, Error: rerr} result.RetryAfter = retryAfterFromError(bootstrapErr) m.MarkResult(ctx, result) discardStreamChunks(streamResult.Chunks) lastErr = bootstrapErr continue } - errCh := make(chan cliproxyexecutor.StreamChunk, 1) - errCh <- cliproxyexecutor.StreamChunk{Err: bootstrapErr} - close(errCh) - return m.wrapStreamResult(ctx, auth.Clone(), provider, routeModel, streamResult.Headers, nil, errCh), nil + rerr := &Error{Message: bootstrapErr.Error()} + if se, ok := errors.AsType[cliproxyexecutor.StatusError](bootstrapErr); ok && se != nil { + rerr.HTTPStatus = se.StatusCode() + } + result := Result{AuthID: auth.ID, Provider: provider, Model: resultModel, Success: false, Error: rerr} + result.RetryAfter = retryAfterFromError(bootstrapErr) + m.MarkResult(ctx, result) + discardStreamChunks(streamResult.Chunks) + return nil, newStreamBootstrapError(bootstrapErr, streamResult.Headers) } if closed && len(buffered) == 0 { emptyErr := &Error{Code: "empty_stream", Message: "upstream stream closed before first payload", Retryable: true} - result := Result{AuthID: auth.ID, Provider: provider, Model: routeModel, Success: false, Error: emptyErr} + result := Result{AuthID: auth.ID, Provider: provider, Model: resultModel, Success: false, Error: emptyErr} m.MarkResult(ctx, result) if idx < len(execModels)-1 { lastErr = emptyErr continue } - errCh := make(chan cliproxyexecutor.StreamChunk, 1) - errCh <- cliproxyexecutor.StreamChunk{Err: emptyErr} - close(errCh) - return m.wrapStreamResult(ctx, auth.Clone(), provider, routeModel, streamResult.Headers, nil, errCh), nil + return nil, newStreamBootstrapError(emptyErr, streamResult.Headers) } remaining := streamResult.Chunks @@ -615,7 +706,7 @@ func (m *Manager) executeStreamWithModelPool(ctx context.Context, executor Provi close(closedCh) remaining = closedCh } - return m.wrapStreamResult(ctx, auth.Clone(), provider, routeModel, streamResult.Headers, buffered, remaining), nil + return m.wrapStreamResult(ctx, auth.Clone(), provider, resultModel, streamResult.Headers, buffered, remaining), nil } if lastErr == nil { lastErr = &Error{Code: "auth_not_found", Message: "no upstream model available"} @@ -979,9 +1070,10 @@ func (m *Manager) executeMixedOnce(ctx context.Context, providers []string, req routeModel := req.Model opts = ensureRequestedModelMetadata(opts, routeModel) tried := make(map[string]struct{}) + attempted := make(map[string]struct{}) var lastErr error for { - if maxRetryCredentials > 0 && len(tried) >= maxRetryCredentials { + if maxRetryCredentials > 0 && len(attempted) >= maxRetryCredentials { if lastErr != nil { return cliproxyexecutor.Response{}, lastErr } @@ -1006,13 +1098,18 @@ func (m *Manager) executeMixedOnce(ctx context.Context, providers []string, req execCtx = context.WithValue(execCtx, "cliproxy.roundtripper", rt) } - models := m.prepareExecutionModels(auth, routeModel) + models, pooled := m.preparedExecutionModels(auth, routeModel) + if len(models) == 0 { + continue + } + attempted[auth.ID] = struct{}{} var authErr error for _, upstreamModel := range models { + resultModel := executionResultModel(routeModel, upstreamModel, pooled) execReq := req execReq.Model = upstreamModel resp, errExec := executor.Execute(execCtx, auth, execReq, opts) - result := Result{AuthID: auth.ID, Provider: provider, Model: routeModel, Success: errExec == nil} + result := Result{AuthID: auth.ID, Provider: provider, Model: resultModel, Success: errExec == nil} if errExec != nil { if errCtx := execCtx.Err(); errCtx != nil { return cliproxyexecutor.Response{}, errCtx @@ -1051,9 +1148,10 @@ func (m *Manager) executeCountMixedOnce(ctx context.Context, providers []string, routeModel := req.Model opts = ensureRequestedModelMetadata(opts, routeModel) tried := make(map[string]struct{}) + attempted := make(map[string]struct{}) var lastErr error for { - if maxRetryCredentials > 0 && len(tried) >= maxRetryCredentials { + if maxRetryCredentials > 0 && len(attempted) >= maxRetryCredentials { if lastErr != nil { return cliproxyexecutor.Response{}, lastErr } @@ -1078,13 +1176,18 @@ func (m *Manager) executeCountMixedOnce(ctx context.Context, providers []string, execCtx = context.WithValue(execCtx, "cliproxy.roundtripper", rt) } - models := m.prepareExecutionModels(auth, routeModel) + models, pooled := m.preparedExecutionModels(auth, routeModel) + if len(models) == 0 { + continue + } + attempted[auth.ID] = struct{}{} var authErr error for _, upstreamModel := range models { + resultModel := executionResultModel(routeModel, upstreamModel, pooled) execReq := req execReq.Model = upstreamModel resp, errExec := executor.CountTokens(execCtx, auth, execReq, opts) - result := Result{AuthID: auth.ID, Provider: provider, Model: routeModel, Success: errExec == nil} + result := Result{AuthID: auth.ID, Provider: provider, Model: resultModel, Success: errExec == nil} if errExec != nil { if errCtx := execCtx.Err(); errCtx != nil { return cliproxyexecutor.Response{}, errCtx @@ -1096,14 +1199,14 @@ func (m *Manager) executeCountMixedOnce(ctx context.Context, providers []string, if ra := retryAfterFromError(errExec); ra != nil { result.RetryAfter = ra } - m.hook.OnResult(execCtx, result) + m.MarkResult(execCtx, result) if isRequestInvalidError(errExec) { return cliproxyexecutor.Response{}, errExec } authErr = errExec continue } - m.hook.OnResult(execCtx, result) + m.MarkResult(execCtx, result) return resp, nil } if authErr != nil { @@ -1123,10 +1226,15 @@ func (m *Manager) executeStreamMixedOnce(ctx context.Context, providers []string routeModel := req.Model opts = ensureRequestedModelMetadata(opts, routeModel) tried := make(map[string]struct{}) + attempted := make(map[string]struct{}) var lastErr error for { - if maxRetryCredentials > 0 && len(tried) >= maxRetryCredentials { + if maxRetryCredentials > 0 && len(attempted) >= maxRetryCredentials { if lastErr != nil { + var bootstrapErr *streamBootstrapError + if errors.As(lastErr, &bootstrapErr) && bootstrapErr != nil { + return streamErrorResult(bootstrapErr.Headers(), bootstrapErr.cause), nil + } return nil, lastErr } return nil, &Error{Code: "auth_not_found", Message: "no auth available"} @@ -1134,6 +1242,10 @@ func (m *Manager) executeStreamMixedOnce(ctx context.Context, providers []string auth, executor, provider, errPick := m.pickNextMixed(ctx, providers, routeModel, opts, tried) if errPick != nil { if lastErr != nil { + var bootstrapErr *streamBootstrapError + if errors.As(lastErr, &bootstrapErr) && bootstrapErr != nil { + return streamErrorResult(bootstrapErr.Headers(), bootstrapErr.cause), nil + } return nil, lastErr } return nil, errPick @@ -1149,7 +1261,12 @@ func (m *Manager) executeStreamMixedOnce(ctx context.Context, providers []string execCtx = context.WithValue(execCtx, roundTripperContextKey{}, rt) execCtx = context.WithValue(execCtx, "cliproxy.roundtripper", rt) } - streamResult, errStream := m.executeStreamWithModelPool(execCtx, executor, auth, provider, req, opts, routeModel) + models, pooled := m.preparedExecutionModels(auth, routeModel) + if len(models) == 0 { + continue + } + attempted[auth.ID] = struct{}{} + streamResult, errStream := m.executeStreamWithModelPool(execCtx, executor, auth, provider, req, opts, routeModel, models, pooled) if errStream != nil { if errCtx := execCtx.Err(); errCtx != nil { return nil, errCtx @@ -1627,53 +1744,60 @@ func (m *Manager) MarkResult(ctx context.Context, result Result) { } statusCode := statusCodeFromResult(result.Error) - switch statusCode { - case 401: - next := now.Add(30 * time.Minute) - state.NextRetryAfter = next - suspendReason = "unauthorized" - shouldSuspendModel = true - case 402, 403: - next := now.Add(30 * time.Minute) - state.NextRetryAfter = next - suspendReason = "payment_required" - shouldSuspendModel = true - case 404: + if isModelSupportResultError(result.Error) { next := now.Add(12 * time.Hour) state.NextRetryAfter = next - suspendReason = "not_found" + suspendReason = "model_not_supported" shouldSuspendModel = true - case 429: - var next time.Time - backoffLevel := state.Quota.BackoffLevel - if result.RetryAfter != nil { - next = now.Add(*result.RetryAfter) - } else { - cooldown, nextLevel := nextQuotaCooldown(backoffLevel, quotaCooldownDisabledForAuth(auth)) - if cooldown > 0 { - next = now.Add(cooldown) + } else { + switch statusCode { + case 401: + next := now.Add(30 * time.Minute) + state.NextRetryAfter = next + suspendReason = "unauthorized" + shouldSuspendModel = true + case 402, 403: + next := now.Add(30 * time.Minute) + state.NextRetryAfter = next + suspendReason = "payment_required" + shouldSuspendModel = true + case 404: + next := now.Add(12 * time.Hour) + state.NextRetryAfter = next + suspendReason = "not_found" + shouldSuspendModel = true + case 429: + var next time.Time + backoffLevel := state.Quota.BackoffLevel + if result.RetryAfter != nil { + next = now.Add(*result.RetryAfter) + } else { + cooldown, nextLevel := nextQuotaCooldown(backoffLevel, quotaCooldownDisabledForAuth(auth)) + if cooldown > 0 { + next = now.Add(cooldown) + } + backoffLevel = nextLevel } - backoffLevel = nextLevel - } - state.NextRetryAfter = next - state.Quota = QuotaState{ - Exceeded: true, - Reason: "quota", - NextRecoverAt: next, - BackoffLevel: backoffLevel, - } - suspendReason = "quota" - shouldSuspendModel = true - setModelQuota = true - case 408, 500, 502, 503, 504: - if quotaCooldownDisabledForAuth(auth) { - state.NextRetryAfter = time.Time{} - } else { - next := now.Add(1 * time.Minute) state.NextRetryAfter = next + state.Quota = QuotaState{ + Exceeded: true, + Reason: "quota", + NextRecoverAt: next, + BackoffLevel: backoffLevel, + } + suspendReason = "quota" + shouldSuspendModel = true + setModelQuota = true + case 408, 500, 502, 503, 504: + if quotaCooldownDisabledForAuth(auth) { + state.NextRetryAfter = time.Time{} + } else { + next := now.Add(1 * time.Minute) + state.NextRetryAfter = next + } + default: + state.NextRetryAfter = time.Time{} } - default: - state.NextRetryAfter = time.Time{} } auth.Status = StatusError @@ -1883,14 +2007,65 @@ func statusCodeFromResult(err *Error) int { return err.StatusCode() } +func isModelSupportErrorMessage(message string) bool { + lower := strings.ToLower(strings.TrimSpace(message)) + if lower == "" { + return false + } + patterns := [...]string{ + "model_not_supported", + "requested model is not supported", + "requested model is unsupported", + "requested model is unavailable", + "model is not supported", + "model not supported", + "unsupported model", + "model unavailable", + "not available for your plan", + "not available for your account", + } + for _, pattern := range patterns { + if strings.Contains(lower, pattern) { + return true + } + } + return false +} + +func isModelSupportError(err error) bool { + if err == nil { + return false + } + status := statusCodeFromError(err) + if status != http.StatusBadRequest && status != http.StatusUnprocessableEntity { + return false + } + return isModelSupportErrorMessage(err.Error()) +} + +func isModelSupportResultError(err *Error) bool { + if err == nil { + return false + } + status := statusCodeFromResult(err) + if status != http.StatusBadRequest && status != http.StatusUnprocessableEntity { + return false + } + return isModelSupportErrorMessage(err.Message) +} + // isRequestInvalidError returns true if the error represents a client request // error that should not be retried. Specifically, it treats 400 responses with // "invalid_request_error" and all 422 responses as request-shape failures, -// where switching auths or pooled upstream models will not help. +// where switching auths or pooled upstream models will not help. Model-support +// errors are excluded so routing can fall through to another auth or upstream. func isRequestInvalidError(err error) bool { if err == nil { return false } + if isModelSupportError(err) { + return false + } status := statusCodeFromError(err) switch status { case http.StatusBadRequest: diff --git a/sdk/cliproxy/auth/conductor_overrides_test.go b/sdk/cliproxy/auth/conductor_overrides_test.go index 7aca49da64..3ad0ce676b 100644 --- a/sdk/cliproxy/auth/conductor_overrides_test.go +++ b/sdk/cliproxy/auth/conductor_overrides_test.go @@ -108,6 +108,76 @@ func (e *credentialRetryLimitExecutor) Calls() int { return e.calls } +type authFallbackExecutor struct { + id string + + mu sync.Mutex + executeCalls []string + streamCalls []string + executeErrors map[string]error + streamFirstErrors map[string]error +} + +func (e *authFallbackExecutor) Identifier() string { + return e.id +} + +func (e *authFallbackExecutor) Execute(_ context.Context, auth *Auth, _ cliproxyexecutor.Request, _ cliproxyexecutor.Options) (cliproxyexecutor.Response, error) { + e.mu.Lock() + e.executeCalls = append(e.executeCalls, auth.ID) + err := e.executeErrors[auth.ID] + e.mu.Unlock() + if err != nil { + return cliproxyexecutor.Response{}, err + } + return cliproxyexecutor.Response{Payload: []byte(auth.ID)}, nil +} + +func (e *authFallbackExecutor) ExecuteStream(_ context.Context, auth *Auth, _ cliproxyexecutor.Request, _ cliproxyexecutor.Options) (*cliproxyexecutor.StreamResult, error) { + e.mu.Lock() + e.streamCalls = append(e.streamCalls, auth.ID) + err := e.streamFirstErrors[auth.ID] + e.mu.Unlock() + + ch := make(chan cliproxyexecutor.StreamChunk, 1) + if err != nil { + ch <- cliproxyexecutor.StreamChunk{Err: err} + close(ch) + return &cliproxyexecutor.StreamResult{Headers: http.Header{"X-Auth": {auth.ID}}, Chunks: ch}, nil + } + ch <- cliproxyexecutor.StreamChunk{Payload: []byte(auth.ID)} + close(ch) + return &cliproxyexecutor.StreamResult{Headers: http.Header{"X-Auth": {auth.ID}}, Chunks: ch}, nil +} + +func (e *authFallbackExecutor) Refresh(_ context.Context, auth *Auth) (*Auth, error) { + return auth, nil +} + +func (e *authFallbackExecutor) CountTokens(context.Context, *Auth, cliproxyexecutor.Request, cliproxyexecutor.Options) (cliproxyexecutor.Response, error) { + return cliproxyexecutor.Response{}, &Error{HTTPStatus: 500, Message: "not implemented"} +} + +func (e *authFallbackExecutor) HttpRequest(context.Context, *Auth, *http.Request) (*http.Response, error) { + return nil, nil +} + +func (e *authFallbackExecutor) ExecuteCalls() []string { + e.mu.Lock() + defer e.mu.Unlock() + out := make([]string, len(e.executeCalls)) + copy(out, e.executeCalls) + return out +} + +func (e *authFallbackExecutor) StreamCalls() []string { + e.mu.Lock() + defer e.mu.Unlock() + out := make([]string, len(e.streamCalls)) + copy(out, e.streamCalls) + return out +} + func newCredentialRetryLimitTestManager(t *testing.T, maxRetryCredentials int) (*Manager, *credentialRetryLimitExecutor) { t.Helper() @@ -191,6 +261,153 @@ func TestManager_MaxRetryCredentials_LimitsCrossCredentialRetries(t *testing.T) } } +func TestManager_ModelSupportBadRequest_FallsBackAndSuspendsAuth(t *testing.T) { + m := NewManager(nil, nil, nil) + executor := &authFallbackExecutor{ + id: "claude", + executeErrors: map[string]error{ + "aa-bad-auth": &Error{ + HTTPStatus: http.StatusBadRequest, + Message: "invalid_request_error: The requested model is not supported.", + }, + }, + } + m.RegisterExecutor(executor) + + model := "claude-opus-4-6" + badAuth := &Auth{ID: "aa-bad-auth", Provider: "claude"} + goodAuth := &Auth{ID: "bb-good-auth", Provider: "claude"} + + reg := registry.GetGlobalRegistry() + reg.RegisterClient(badAuth.ID, "claude", []*registry.ModelInfo{{ID: model}}) + reg.RegisterClient(goodAuth.ID, "claude", []*registry.ModelInfo{{ID: model}}) + t.Cleanup(func() { + reg.UnregisterClient(badAuth.ID) + reg.UnregisterClient(goodAuth.ID) + }) + + if _, errRegister := m.Register(context.Background(), badAuth); errRegister != nil { + t.Fatalf("register bad auth: %v", errRegister) + } + if _, errRegister := m.Register(context.Background(), goodAuth); errRegister != nil { + t.Fatalf("register good auth: %v", errRegister) + } + + request := cliproxyexecutor.Request{Model: model} + for i := 0; i < 2; i++ { + resp, errExecute := m.Execute(context.Background(), []string{"claude"}, request, cliproxyexecutor.Options{}) + if errExecute != nil { + t.Fatalf("execute %d error = %v, want success", i, errExecute) + } + if string(resp.Payload) != goodAuth.ID { + t.Fatalf("execute %d payload = %q, want %q", i, string(resp.Payload), goodAuth.ID) + } + } + + got := executor.ExecuteCalls() + want := []string{badAuth.ID, goodAuth.ID, goodAuth.ID} + if len(got) != len(want) { + t.Fatalf("execute calls = %v, want %v", got, want) + } + for i := range want { + if got[i] != want[i] { + t.Fatalf("execute call %d auth = %q, want %q", i, got[i], want[i]) + } + } + + updatedBad, ok := m.GetByID(badAuth.ID) + if !ok || updatedBad == nil { + t.Fatalf("expected bad auth to remain registered") + } + state := updatedBad.ModelStates[model] + if state == nil { + t.Fatalf("expected model state for %q", model) + } + if !state.Unavailable { + t.Fatalf("expected bad auth model state to be unavailable") + } + if state.NextRetryAfter.IsZero() { + t.Fatalf("expected bad auth model state cooldown to be set") + } +} + +func TestManagerExecuteStream_ModelSupportBadRequestFallsBackAndSuspendsAuth(t *testing.T) { + m := NewManager(nil, nil, nil) + executor := &authFallbackExecutor{ + id: "claude", + streamFirstErrors: map[string]error{ + "aa-bad-auth": &Error{ + HTTPStatus: http.StatusBadRequest, + Message: "invalid_request_error: The requested model is not supported.", + }, + }, + } + m.RegisterExecutor(executor) + + model := "claude-opus-4-6" + badAuth := &Auth{ID: "aa-bad-auth", Provider: "claude"} + goodAuth := &Auth{ID: "bb-good-auth", Provider: "claude"} + + reg := registry.GetGlobalRegistry() + reg.RegisterClient(badAuth.ID, "claude", []*registry.ModelInfo{{ID: model}}) + reg.RegisterClient(goodAuth.ID, "claude", []*registry.ModelInfo{{ID: model}}) + t.Cleanup(func() { + reg.UnregisterClient(badAuth.ID) + reg.UnregisterClient(goodAuth.ID) + }) + + if _, errRegister := m.Register(context.Background(), badAuth); errRegister != nil { + t.Fatalf("register bad auth: %v", errRegister) + } + if _, errRegister := m.Register(context.Background(), goodAuth); errRegister != nil { + t.Fatalf("register good auth: %v", errRegister) + } + + request := cliproxyexecutor.Request{Model: model} + for i := 0; i < 2; i++ { + streamResult, errExecute := m.ExecuteStream(context.Background(), []string{"claude"}, request, cliproxyexecutor.Options{}) + if errExecute != nil { + t.Fatalf("execute stream %d error = %v, want success", i, errExecute) + } + var payload []byte + for chunk := range streamResult.Chunks { + if chunk.Err != nil { + t.Fatalf("execute stream %d chunk error = %v, want success", i, chunk.Err) + } + payload = append(payload, chunk.Payload...) + } + if string(payload) != goodAuth.ID { + t.Fatalf("execute stream %d payload = %q, want %q", i, string(payload), goodAuth.ID) + } + } + + got := executor.StreamCalls() + want := []string{badAuth.ID, goodAuth.ID, goodAuth.ID} + if len(got) != len(want) { + t.Fatalf("stream calls = %v, want %v", got, want) + } + for i := range want { + if got[i] != want[i] { + t.Fatalf("stream call %d auth = %q, want %q", i, got[i], want[i]) + } + } + + updatedBad, ok := m.GetByID(badAuth.ID) + if !ok || updatedBad == nil { + t.Fatalf("expected bad auth to remain registered") + } + state := updatedBad.ModelStates[model] + if state == nil { + t.Fatalf("expected model state for %q", model) + } + if !state.Unavailable { + t.Fatalf("expected bad auth model state to be unavailable") + } + if state.NextRetryAfter.IsZero() { + t.Fatalf("expected bad auth model state cooldown to be set") + } +} + func TestManager_MarkResult_RespectsAuthDisableCoolingOverride(t *testing.T) { prev := quotaCooldownDisabled.Load() quotaCooldownDisabled.Store(false) diff --git a/sdk/cliproxy/auth/openai_compat_pool_test.go b/sdk/cliproxy/auth/openai_compat_pool_test.go index 5a5ecb4fe2..9a977aae3d 100644 --- a/sdk/cliproxy/auth/openai_compat_pool_test.go +++ b/sdk/cliproxy/auth/openai_compat_pool_test.go @@ -3,6 +3,7 @@ package auth import ( "context" "net/http" + "strings" "sync" "testing" @@ -116,6 +117,47 @@ func (e *openAICompatPoolExecutor) StreamModels() []string { return out } +type authScopedOpenAICompatPoolExecutor struct { + id string + + mu sync.Mutex + executeCalls []string +} + +func (e *authScopedOpenAICompatPoolExecutor) Identifier() string { return e.id } + +func (e *authScopedOpenAICompatPoolExecutor) Execute(_ context.Context, auth *Auth, req cliproxyexecutor.Request, _ cliproxyexecutor.Options) (cliproxyexecutor.Response, error) { + call := auth.ID + "|" + req.Model + e.mu.Lock() + e.executeCalls = append(e.executeCalls, call) + e.mu.Unlock() + return cliproxyexecutor.Response{Payload: []byte(call)}, nil +} + +func (e *authScopedOpenAICompatPoolExecutor) ExecuteStream(context.Context, *Auth, cliproxyexecutor.Request, cliproxyexecutor.Options) (*cliproxyexecutor.StreamResult, error) { + return nil, &Error{HTTPStatus: http.StatusNotImplemented, Message: "ExecuteStream not implemented"} +} + +func (e *authScopedOpenAICompatPoolExecutor) Refresh(_ context.Context, auth *Auth) (*Auth, error) { + return auth, nil +} + +func (e *authScopedOpenAICompatPoolExecutor) CountTokens(context.Context, *Auth, cliproxyexecutor.Request, cliproxyexecutor.Options) (cliproxyexecutor.Response, error) { + return cliproxyexecutor.Response{}, &Error{HTTPStatus: http.StatusNotImplemented, Message: "CountTokens not implemented"} +} + +func (e *authScopedOpenAICompatPoolExecutor) HttpRequest(context.Context, *Auth, *http.Request) (*http.Response, error) { + return nil, &Error{HTTPStatus: http.StatusNotImplemented, Message: "HttpRequest not implemented"} +} + +func (e *authScopedOpenAICompatPoolExecutor) ExecuteCalls() []string { + e.mu.Lock() + defer e.mu.Unlock() + out := make([]string, len(e.executeCalls)) + copy(out, e.executeCalls) + return out +} + func newOpenAICompatPoolTestManager(t *testing.T, alias string, models []internalconfig.OpenAICompatibilityModel, executor *openAICompatPoolExecutor) *Manager { t.Helper() cfg := &internalconfig.Config{ @@ -153,6 +195,21 @@ func newOpenAICompatPoolTestManager(t *testing.T, alias string, models []interna return m } +func readOpenAICompatStreamPayload(t *testing.T, streamResult *cliproxyexecutor.StreamResult) string { + t.Helper() + if streamResult == nil { + t.Fatal("expected stream result") + } + var payload []byte + for chunk := range streamResult.Chunks { + if chunk.Err != nil { + t.Fatalf("unexpected stream error: %v", chunk.Err) + } + payload = append(payload, chunk.Payload...) + } + return string(payload) +} + func TestManagerExecuteCount_OpenAICompatAliasPoolStopsOnInvalidRequest(t *testing.T) { alias := "claude-opus-4.66" invalidErr := &Error{HTTPStatus: http.StatusUnprocessableEntity, Message: "unprocessable entity"} @@ -243,6 +300,87 @@ func TestManagerExecute_OpenAICompatAliasPoolStopsOnBadRequest(t *testing.T) { t.Fatalf("execute calls = %v, want only first invalid model", got) } } + +func TestManagerExecute_OpenAICompatAliasPoolFallsBackOnModelSupportBadRequest(t *testing.T) { + alias := "claude-opus-4.66" + modelSupportErr := &Error{ + HTTPStatus: http.StatusBadRequest, + Message: "invalid_request_error: The requested model is not supported.", + } + executor := &openAICompatPoolExecutor{ + id: "pool", + executeErrors: map[string]error{"qwen3.5-plus": modelSupportErr}, + } + m := newOpenAICompatPoolTestManager(t, alias, []internalconfig.OpenAICompatibilityModel{ + {Name: "qwen3.5-plus", Alias: alias}, + {Name: "glm-5", Alias: alias}, + }, executor) + + resp, err := m.Execute(context.Background(), []string{"pool"}, cliproxyexecutor.Request{Model: alias}, cliproxyexecutor.Options{}) + if err != nil { + t.Fatalf("execute error = %v, want fallback success", err) + } + if string(resp.Payload) != "glm-5" { + t.Fatalf("payload = %q, want %q", string(resp.Payload), "glm-5") + } + got := executor.ExecuteModels() + want := []string{"qwen3.5-plus", "glm-5"} + if len(got) != len(want) { + t.Fatalf("execute calls = %v, want %v", got, want) + } + for i := range want { + if got[i] != want[i] { + t.Fatalf("execute call %d model = %q, want %q", i, got[i], want[i]) + } + } + + updated, ok := m.GetByID("pool-auth-" + t.Name()) + if !ok || updated == nil { + t.Fatalf("expected auth to remain registered") + } + state := updated.ModelStates["qwen3.5-plus"] + if state == nil { + t.Fatalf("expected suspended upstream model state") + } + if !state.Unavailable || state.NextRetryAfter.IsZero() { + t.Fatalf("expected upstream model suspension, got %+v", state) + } +} + +func TestManagerExecute_OpenAICompatAliasPoolFallsBackOnModelSupportUnprocessableEntity(t *testing.T) { + alias := "claude-opus-4.66" + modelSupportErr := &Error{ + HTTPStatus: http.StatusUnprocessableEntity, + Message: "The requested model is not supported.", + } + executor := &openAICompatPoolExecutor{ + id: "pool", + executeErrors: map[string]error{"qwen3.5-plus": modelSupportErr}, + } + m := newOpenAICompatPoolTestManager(t, alias, []internalconfig.OpenAICompatibilityModel{ + {Name: "qwen3.5-plus", Alias: alias}, + {Name: "glm-5", Alias: alias}, + }, executor) + + resp, err := m.Execute(context.Background(), []string{"pool"}, cliproxyexecutor.Request{Model: alias}, cliproxyexecutor.Options{}) + if err != nil { + t.Fatalf("execute error = %v, want fallback success", err) + } + if string(resp.Payload) != "glm-5" { + t.Fatalf("payload = %q, want %q", string(resp.Payload), "glm-5") + } + got := executor.ExecuteModels() + want := []string{"qwen3.5-plus", "glm-5"} + if len(got) != len(want) { + t.Fatalf("execute calls = %v, want %v", got, want) + } + for i := range want { + if got[i] != want[i] { + t.Fatalf("execute call %d model = %q, want %q", i, got[i], want[i]) + } + } +} + func TestManagerExecute_OpenAICompatAliasPoolFallsBackWithinSameAuth(t *testing.T) { alias := "claude-opus-4.66" executor := &openAICompatPoolExecutor{ @@ -364,6 +502,84 @@ func TestManagerExecuteStream_OpenAICompatAliasPoolStopsOnInvalidRequest(t *test t.Fatalf("stream calls = %v, want only first invalid model", got) } } + +func TestManagerExecute_OpenAICompatAliasPoolSkipsSuspendedUpstreamOnLaterRequests(t *testing.T) { + alias := "claude-opus-4.66" + modelSupportErr := &Error{ + HTTPStatus: http.StatusBadRequest, + Message: "invalid_request_error: The requested model is not supported.", + } + executor := &openAICompatPoolExecutor{ + id: "pool", + executeErrors: map[string]error{"qwen3.5-plus": modelSupportErr}, + } + m := newOpenAICompatPoolTestManager(t, alias, []internalconfig.OpenAICompatibilityModel{ + {Name: "qwen3.5-plus", Alias: alias}, + {Name: "glm-5", Alias: alias}, + }, executor) + + for i := 0; i < 3; i++ { + resp, err := m.Execute(context.Background(), []string{"pool"}, cliproxyexecutor.Request{Model: alias}, cliproxyexecutor.Options{}) + if err != nil { + t.Fatalf("execute %d: %v", i, err) + } + if string(resp.Payload) != "glm-5" { + t.Fatalf("execute %d payload = %q, want %q", i, string(resp.Payload), "glm-5") + } + } + + got := executor.ExecuteModels() + want := []string{"qwen3.5-plus", "glm-5", "glm-5", "glm-5"} + if len(got) != len(want) { + t.Fatalf("execute calls = %v, want %v", got, want) + } + for i := range want { + if got[i] != want[i] { + t.Fatalf("execute call %d model = %q, want %q", i, got[i], want[i]) + } + } +} + +func TestManagerExecuteStream_OpenAICompatAliasPoolSkipsSuspendedUpstreamOnLaterRequests(t *testing.T) { + alias := "claude-opus-4.66" + modelSupportErr := &Error{ + HTTPStatus: http.StatusUnprocessableEntity, + Message: "The requested model is not supported.", + } + executor := &openAICompatPoolExecutor{ + id: "pool", + streamFirstErrors: map[string]error{"qwen3.5-plus": modelSupportErr}, + } + m := newOpenAICompatPoolTestManager(t, alias, []internalconfig.OpenAICompatibilityModel{ + {Name: "qwen3.5-plus", Alias: alias}, + {Name: "glm-5", Alias: alias}, + }, executor) + + for i := 0; i < 3; i++ { + streamResult, err := m.ExecuteStream(context.Background(), []string{"pool"}, cliproxyexecutor.Request{Model: alias}, cliproxyexecutor.Options{}) + if err != nil { + t.Fatalf("execute stream %d: %v", i, err) + } + if payload := readOpenAICompatStreamPayload(t, streamResult); payload != "glm-5" { + t.Fatalf("execute stream %d payload = %q, want %q", i, payload, "glm-5") + } + if gotHeader := streamResult.Headers.Get("X-Model"); gotHeader != "glm-5" { + t.Fatalf("execute stream %d header X-Model = %q, want %q", i, gotHeader, "glm-5") + } + } + + got := executor.StreamModels() + want := []string{"qwen3.5-plus", "glm-5", "glm-5", "glm-5"} + if len(got) != len(want) { + t.Fatalf("stream calls = %v, want %v", got, want) + } + for i := range want { + if got[i] != want[i] { + t.Fatalf("stream call %d model = %q, want %q", i, got[i], want[i]) + } + } +} + func TestManagerExecuteCount_OpenAICompatAliasPoolRotatesWithinAuth(t *testing.T) { alias := "claude-opus-4.66" executor := &openAICompatPoolExecutor{id: "pool"} @@ -391,6 +607,127 @@ func TestManagerExecuteCount_OpenAICompatAliasPoolRotatesWithinAuth(t *testing.T } } +func TestManagerExecuteCount_OpenAICompatAliasPoolSkipsSuspendedUpstreamOnLaterRequests(t *testing.T) { + alias := "claude-opus-4.66" + modelSupportErr := &Error{ + HTTPStatus: http.StatusBadRequest, + Message: "invalid_request_error: The requested model is unsupported.", + } + executor := &openAICompatPoolExecutor{ + id: "pool", + countErrors: map[string]error{"qwen3.5-plus": modelSupportErr}, + } + m := newOpenAICompatPoolTestManager(t, alias, []internalconfig.OpenAICompatibilityModel{ + {Name: "qwen3.5-plus", Alias: alias}, + {Name: "glm-5", Alias: alias}, + }, executor) + + for i := 0; i < 3; i++ { + resp, err := m.ExecuteCount(context.Background(), []string{"pool"}, cliproxyexecutor.Request{Model: alias}, cliproxyexecutor.Options{}) + if err != nil { + t.Fatalf("execute count %d: %v", i, err) + } + if string(resp.Payload) != "glm-5" { + t.Fatalf("execute count %d payload = %q, want %q", i, string(resp.Payload), "glm-5") + } + } + + got := executor.CountModels() + want := []string{"qwen3.5-plus", "glm-5", "glm-5", "glm-5"} + if len(got) != len(want) { + t.Fatalf("count calls = %v, want %v", got, want) + } + for i := range want { + if got[i] != want[i] { + t.Fatalf("count call %d model = %q, want %q", i, got[i], want[i]) + } + } +} + +func TestManagerExecute_OpenAICompatAliasPoolBlockedAuthDoesNotConsumeRetryBudget(t *testing.T) { + alias := "claude-opus-4.66" + cfg := &internalconfig.Config{ + OpenAICompatibility: []internalconfig.OpenAICompatibility{{ + Name: "pool", + Models: []internalconfig.OpenAICompatibilityModel{ + {Name: "qwen3.5-plus", Alias: alias}, + {Name: "glm-5", Alias: alias}, + }, + }}, + } + m := NewManager(nil, nil, nil) + m.SetConfig(cfg) + m.SetRetryConfig(0, 0, 1) + + executor := &authScopedOpenAICompatPoolExecutor{id: "pool"} + m.RegisterExecutor(executor) + + badAuth := &Auth{ + ID: "aa-blocked-auth", + Provider: "pool", + Status: StatusActive, + Attributes: map[string]string{ + "api_key": "bad-key", + "compat_name": "pool", + "provider_key": "pool", + }, + } + goodAuth := &Auth{ + ID: "bb-good-auth", + Provider: "pool", + Status: StatusActive, + Attributes: map[string]string{ + "api_key": "good-key", + "compat_name": "pool", + "provider_key": "pool", + }, + } + if _, err := m.Register(context.Background(), badAuth); err != nil { + t.Fatalf("register bad auth: %v", err) + } + if _, err := m.Register(context.Background(), goodAuth); err != nil { + t.Fatalf("register good auth: %v", err) + } + + reg := registry.GetGlobalRegistry() + reg.RegisterClient(badAuth.ID, "pool", []*registry.ModelInfo{{ID: alias}}) + reg.RegisterClient(goodAuth.ID, "pool", []*registry.ModelInfo{{ID: alias}}) + t.Cleanup(func() { + reg.UnregisterClient(badAuth.ID) + reg.UnregisterClient(goodAuth.ID) + }) + + modelSupportErr := &Error{ + HTTPStatus: http.StatusBadRequest, + Message: "invalid_request_error: The requested model is not supported.", + } + for _, upstreamModel := range []string{"qwen3.5-plus", "glm-5"} { + m.MarkResult(context.Background(), Result{ + AuthID: badAuth.ID, + Provider: "pool", + Model: upstreamModel, + Success: false, + Error: modelSupportErr, + }) + } + + resp, err := m.Execute(context.Background(), []string{"pool"}, cliproxyexecutor.Request{Model: alias}, cliproxyexecutor.Options{}) + if err != nil { + t.Fatalf("execute error = %v, want success via fallback auth", err) + } + if !strings.HasPrefix(string(resp.Payload), goodAuth.ID+"|") { + t.Fatalf("payload = %q, want auth %q", string(resp.Payload), goodAuth.ID) + } + + got := executor.ExecuteCalls() + if len(got) != 1 { + t.Fatalf("execute calls = %v, want only one real execution on fallback auth", got) + } + if !strings.HasPrefix(got[0], goodAuth.ID+"|") { + t.Fatalf("execute call = %q, want fallback auth %q", got[0], goodAuth.ID) + } +} + func TestManagerExecuteStream_OpenAICompatAliasPoolStopsOnInvalidBootstrap(t *testing.T) { alias := "claude-opus-4.66" invalidErr := &Error{HTTPStatus: http.StatusBadRequest, Message: "invalid_request_error: malformed payload"}