From c08f23b6da38924ce8e87bf928a93d022353a0aa Mon Sep 17 00:00:00 2001 From: leecz Date: Mon, 9 Mar 2026 19:14:57 +0800 Subject: [PATCH 1/2] feat: add sticky session routing via X-Session-ID header Route requests with the same X-Session-ID to the same auth account, leveraging the existing pinnedAuthID mechanism. Bindings expire after 1 hour and are cleared on 429/5xx errors for automatic failover. - New stickyStore (in-memory, sync.RWMutex) with TTL-based expiration - Integrated into executeMixedOnce and executeStreamMixedOnce - Explicit pinned_auth_id takes precedence over sticky binding - Periodic cleanup piggybacked on existing autoRefresh ticker - X-Session-ID header extracted in handler metadata layer Co-Authored-By: Claude Opus 4.6 --- sdk/api/handlers/handlers.go | 16 +++++ sdk/cliproxy/auth/conductor.go | 100 +++++++++++++++++++++++++++++-- sdk/cliproxy/auth/sticky.go | 68 +++++++++++++++++++++ sdk/cliproxy/auth/sticky_test.go | 86 ++++++++++++++++++++++++++ sdk/cliproxy/executor/types.go | 2 + 5 files changed, 266 insertions(+), 6 deletions(-) create mode 100644 sdk/cliproxy/auth/sticky.go create mode 100644 sdk/cliproxy/auth/sticky_test.go diff --git a/sdk/api/handlers/handlers.go b/sdk/api/handlers/handlers.go index 0e490e3202..321be75fc0 100644 --- a/sdk/api/handlers/handlers.go +++ b/sdk/api/handlers/handlers.go @@ -208,6 +208,11 @@ func requestExecutionMetadata(ctx context.Context) map[string]any { if executionSessionID := executionSessionIDFromContext(ctx); executionSessionID != "" { meta[coreexecutor.ExecutionSessionMetadataKey] = executionSessionID } + // Sticky session: forward X-Session-ID header so the conductor can pin + // subsequent requests from the same session to the same auth account. + if sessionID := stickySessionIDFromHeader(ctx); sessionID != "" { + meta[coreexecutor.StickySessionMetadataKey] = sessionID + } return meta } @@ -252,6 +257,17 @@ func executionSessionIDFromContext(ctx context.Context) string { } } +func stickySessionIDFromHeader(ctx context.Context) string { + if ctx == nil { + return "" + } + ginCtx, ok := ctx.Value("gin").(*gin.Context) + if !ok || ginCtx == nil || ginCtx.Request == nil { + return "" + } + return strings.TrimSpace(ginCtx.GetHeader("X-Session-ID")) +} + // BaseAPIHandler contains the handlers for API endpoints. // It holds a pool of clients to interact with the backend service and manages // load balancing, client selection, and configuration. diff --git a/sdk/cliproxy/auth/conductor.go b/sdk/cliproxy/auth/conductor.go index b29e04db8c..676699753e 100644 --- a/sdk/cliproxy/auth/conductor.go +++ b/sdk/cliproxy/auth/conductor.go @@ -160,6 +160,9 @@ type Manager struct { // Optional HTTP RoundTripper provider injected by host. rtProvider RoundTripperProvider + // sticky maintains session-to-auth bindings for sticky routing. + sticky *stickyStore + // Auto refresh state refreshCancel context.CancelFunc refreshSemaphore chan struct{} @@ -181,6 +184,7 @@ func NewManager(store Store, selector Selector, hook Hook) *Manager { auths: make(map[string]*Auth), providerOffsets: make(map[string]int), modelPoolOffsets: make(map[string]int), + sticky: newStickyStore(), refreshSemaphore: make(chan struct{}, refreshMaxConcurrency), } // atomic.Value requires non-nil initial value. @@ -483,7 +487,7 @@ func readStreamBootstrap(ctx context.Context, ch <-chan cliproxyexecutor.StreamC } } -func (m *Manager) wrapStreamResult(ctx context.Context, auth *Auth, provider, routeModel string, headers http.Header, buffered []cliproxyexecutor.StreamChunk, remaining <-chan cliproxyexecutor.StreamChunk) *cliproxyexecutor.StreamResult { +func (m *Manager) wrapStreamResult(ctx context.Context, auth *Auth, provider, routeModel string, headers http.Header, buffered []cliproxyexecutor.StreamChunk, remaining <-chan cliproxyexecutor.StreamChunk, stickySessionID string) *cliproxyexecutor.StreamResult { out := make(chan cliproxyexecutor.StreamChunk) go func() { defer close(out) @@ -497,6 +501,12 @@ func (m *Manager) wrapStreamResult(ctx context.Context, auth *Auth, provider, ro rerr.HTTPStatus = se.StatusCode() } m.MarkResult(ctx, Result{AuthID: auth.ID, Provider: provider, Model: routeModel, Success: false, Error: rerr}) + // Sticky session: clear binding on mid-stream error. + if stickySessionID != "" { + if sc := statusCodeFromResult(rerr); sc == 429 || sc >= 500 { + m.sticky.Delete(stickySessionID) + } + } } if !forward { return false @@ -532,7 +542,7 @@ func (m *Manager) wrapStreamResult(ctx context.Context, auth *Auth, provider, ro return &cliproxyexecutor.StreamResult{Headers: headers, Chunks: out} } -func (m *Manager) executeStreamWithModelPool(ctx context.Context, executor ProviderExecutor, auth *Auth, provider string, req cliproxyexecutor.Request, opts cliproxyexecutor.Options, routeModel string) (*cliproxyexecutor.StreamResult, error) { +func (m *Manager) executeStreamWithModelPool(ctx context.Context, executor ProviderExecutor, auth *Auth, provider string, req cliproxyexecutor.Request, opts cliproxyexecutor.Options, routeModel string, stickySessionID string) (*cliproxyexecutor.StreamResult, error) { if executor == nil { return nil, &Error{Code: "executor_not_found", Message: "executor not registered"} } @@ -592,7 +602,7 @@ func (m *Manager) executeStreamWithModelPool(ctx context.Context, executor Provi errCh := make(chan cliproxyexecutor.StreamChunk, 1) errCh <- cliproxyexecutor.StreamChunk{Err: bootstrapErr} close(errCh) - return m.wrapStreamResult(ctx, auth.Clone(), provider, routeModel, streamResult.Headers, nil, errCh), nil + return m.wrapStreamResult(ctx, auth.Clone(), provider, routeModel, streamResult.Headers, nil, errCh, stickySessionID), nil } if closed && len(buffered) == 0 { @@ -606,7 +616,7 @@ func (m *Manager) executeStreamWithModelPool(ctx context.Context, executor Provi errCh := make(chan cliproxyexecutor.StreamChunk, 1) errCh <- cliproxyexecutor.StreamChunk{Err: emptyErr} close(errCh) - return m.wrapStreamResult(ctx, auth.Clone(), provider, routeModel, streamResult.Headers, nil, errCh), nil + return m.wrapStreamResult(ctx, auth.Clone(), provider, routeModel, streamResult.Headers, nil, errCh, stickySessionID), nil } remaining := streamResult.Chunks @@ -615,7 +625,7 @@ func (m *Manager) executeStreamWithModelPool(ctx context.Context, executor Provi close(closedCh) remaining = closedCh } - return m.wrapStreamResult(ctx, auth.Clone(), provider, routeModel, streamResult.Headers, buffered, remaining), nil + return m.wrapStreamResult(ctx, auth.Clone(), provider, routeModel, streamResult.Headers, buffered, remaining, stickySessionID), nil } if lastErr == nil { lastErr = &Error{Code: "auth_not_found", Message: "no upstream model available"} @@ -978,6 +988,18 @@ func (m *Manager) executeMixedOnce(ctx context.Context, providers []string, req } routeModel := req.Model opts = ensureRequestedModelMetadata(opts, routeModel) + + // Sticky session: resolve session→auth binding before pick loop. + // An explicit pinned_auth_id takes precedence over sticky binding. + stickySessionID := stickySessionIDFromMetadata(opts.Metadata) + if stickySessionID != "" { + if _, alreadyPinned := opts.Metadata[cliproxyexecutor.PinnedAuthMetadataKey]; !alreadyPinned { + if boundAuth, found := m.sticky.Get(stickySessionID); found { + opts.Metadata[cliproxyexecutor.PinnedAuthMetadataKey] = boundAuth + } + } + } + tried := make(map[string]struct{}) var lastErr error for { @@ -1025,12 +1047,24 @@ func (m *Manager) executeMixedOnce(ctx context.Context, providers []string, req result.RetryAfter = ra } m.MarkResult(execCtx, result) + // Sticky session: clear binding on rate-limit or server error so next + // request falls back to normal auth selection. + if stickySessionID != "" { + if sc := statusCodeFromResult(result.Error); sc == 429 || sc >= 500 { + m.sticky.Delete(stickySessionID) + delete(opts.Metadata, cliproxyexecutor.PinnedAuthMetadataKey) + } + } if isRequestInvalidError(errExec) { return cliproxyexecutor.Response{}, errExec } authErr = errExec continue } + // Sticky session: bind session to the auth that succeeded. + if stickySessionID != "" { + m.sticky.Set(stickySessionID, auth.ID, stickySessionTTL) + } m.MarkResult(execCtx, result) return resp, nil } @@ -1122,6 +1156,18 @@ func (m *Manager) executeStreamMixedOnce(ctx context.Context, providers []string } routeModel := req.Model opts = ensureRequestedModelMetadata(opts, routeModel) + + // Sticky session: resolve session→auth binding before pick loop. + // An explicit pinned_auth_id takes precedence over sticky binding. + stickySessionID := stickySessionIDFromMetadata(opts.Metadata) + if stickySessionID != "" { + if _, alreadyPinned := opts.Metadata[cliproxyexecutor.PinnedAuthMetadataKey]; !alreadyPinned { + if boundAuth, found := m.sticky.Get(stickySessionID); found { + opts.Metadata[cliproxyexecutor.PinnedAuthMetadataKey] = boundAuth + } + } + } + tried := make(map[string]struct{}) var lastErr error for { @@ -1149,17 +1195,32 @@ func (m *Manager) executeStreamMixedOnce(ctx context.Context, providers []string execCtx = context.WithValue(execCtx, roundTripperContextKey{}, rt) execCtx = context.WithValue(execCtx, "cliproxy.roundtripper", rt) } - streamResult, errStream := m.executeStreamWithModelPool(execCtx, executor, auth, provider, req, opts, routeModel) + streamResult, errStream := m.executeStreamWithModelPool(execCtx, executor, auth, provider, req, opts, routeModel, stickySessionID) if errStream != nil { if errCtx := execCtx.Err(); errCtx != nil { return nil, errCtx } + // Sticky session: clear binding on rate-limit or server error. + if stickySessionID != "" { + sc := 0 + if se, ok := errors.AsType[cliproxyexecutor.StatusError](errStream); ok && se != nil { + sc = se.StatusCode() + } + if sc == 429 || sc >= 500 { + m.sticky.Delete(stickySessionID) + delete(opts.Metadata, cliproxyexecutor.PinnedAuthMetadataKey) + } + } if isRequestInvalidError(errStream) { return nil, errStream } lastErr = errStream continue } + // Sticky session: bind session to the auth that started streaming. + if stickySessionID != "" { + m.sticky.Set(stickySessionID, auth.ID, stickySessionTTL) + } return streamResult, nil } } @@ -1221,6 +1282,27 @@ func pinnedAuthIDFromMetadata(meta map[string]any) string { } } +// stickySessionTTL is the duration a session-to-auth binding remains valid. +const stickySessionTTL = time.Hour + +func stickySessionIDFromMetadata(meta map[string]any) string { + if len(meta) == 0 { + return "" + } + raw, ok := meta[cliproxyexecutor.StickySessionMetadataKey] + if !ok || raw == nil { + return "" + } + switch val := raw.(type) { + case string: + return strings.TrimSpace(val) + case []byte: + return strings.TrimSpace(string(val)) + default: + return "" + } +} + func publishSelectedAuthMetadata(meta map[string]any, authID string) { if len(meta) == 0 { return @@ -2331,6 +2413,7 @@ func (m *Manager) StartAutoRefresh(parent context.Context, interval time.Duratio go func() { ticker := time.NewTicker(interval) defer ticker.Stop() + stickyCleanupCounter := 0 m.checkRefreshes(ctx) for { select { @@ -2338,6 +2421,11 @@ func (m *Manager) StartAutoRefresh(parent context.Context, interval time.Duratio return case <-ticker.C: m.checkRefreshes(ctx) + stickyCleanupCounter++ + if stickyCleanupCounter >= 60 { // ~every 5 min at default 5s interval + m.sticky.Cleanup() + stickyCleanupCounter = 0 + } } } }() diff --git a/sdk/cliproxy/auth/sticky.go b/sdk/cliproxy/auth/sticky.go new file mode 100644 index 0000000000..252e35b950 --- /dev/null +++ b/sdk/cliproxy/auth/sticky.go @@ -0,0 +1,68 @@ +package auth + +import ( + "sync" + "time" +) + +// stickyStore maintains session-to-auth bindings so that requests carrying the +// same session ID are routed to the same auth/account. Entries expire after a +// configurable TTL and are garbage-collected by Cleanup. +type stickyStore struct { + mu sync.RWMutex + entries map[string]stickyEntry +} + +type stickyEntry struct { + authID string + expiresAt time.Time +} + +func newStickyStore() *stickyStore { + return &stickyStore{entries: make(map[string]stickyEntry)} +} + +// Get returns the bound auth ID for the given session, if it exists and has not +// expired. +func (s *stickyStore) Get(sessionID string) (string, bool) { + s.mu.RLock() + defer s.mu.RUnlock() + e, ok := s.entries[sessionID] + if !ok || time.Now().After(e.expiresAt) { + return "", false + } + return e.authID, true +} + +// Set binds a session to an auth ID with the specified TTL. +func (s *stickyStore) Set(sessionID, authID string, ttl time.Duration) { + s.mu.Lock() + s.entries[sessionID] = stickyEntry{authID: authID, expiresAt: time.Now().Add(ttl)} + s.mu.Unlock() +} + +// Delete removes the binding for the given session ID. +func (s *stickyStore) Delete(sessionID string) { + s.mu.Lock() + delete(s.entries, sessionID) + s.mu.Unlock() +} + +// Cleanup removes all expired entries. +func (s *stickyStore) Cleanup() { + now := time.Now() + s.mu.Lock() + for k, e := range s.entries { + if now.After(e.expiresAt) { + delete(s.entries, k) + } + } + s.mu.Unlock() +} + +// Len returns the number of entries (including possibly-expired ones). +func (s *stickyStore) Len() int { + s.mu.RLock() + defer s.mu.RUnlock() + return len(s.entries) +} diff --git a/sdk/cliproxy/auth/sticky_test.go b/sdk/cliproxy/auth/sticky_test.go new file mode 100644 index 0000000000..7b84740d91 --- /dev/null +++ b/sdk/cliproxy/auth/sticky_test.go @@ -0,0 +1,86 @@ +package auth + +import ( + "testing" + "time" +) + +func TestStickyStore_SetAndGet(t *testing.T) { + s := newStickyStore() + s.Set("sess-1", "auth-ai", time.Hour) + + got, ok := s.Get("sess-1") + if !ok || got != "auth-ai" { + t.Fatalf("expected auth-ai, got %q (ok=%v)", got, ok) + } +} + +func TestStickyStore_GetMiss(t *testing.T) { + s := newStickyStore() + _, ok := s.Get("nonexistent") + if ok { + t.Fatal("expected miss for nonexistent session") + } +} + +func TestStickyStore_GetExpired(t *testing.T) { + s := newStickyStore() + s.Set("sess-1", "auth-ai", time.Millisecond) + time.Sleep(2 * time.Millisecond) + + _, ok := s.Get("sess-1") + if ok { + t.Fatal("expected miss for expired entry") + } +} + +func TestStickyStore_Delete(t *testing.T) { + s := newStickyStore() + s.Set("sess-1", "auth-ai", time.Hour) + s.Delete("sess-1") + + _, ok := s.Get("sess-1") + if ok { + t.Fatal("expected miss after delete") + } +} + +func TestStickyStore_Overwrite(t *testing.T) { + s := newStickyStore() + s.Set("sess-1", "auth-ai", time.Hour) + s.Set("sess-1", "auth-cc", time.Hour) + + got, ok := s.Get("sess-1") + if !ok || got != "auth-cc" { + t.Fatalf("expected auth-cc after overwrite, got %q", got) + } +} + +func TestStickyStore_Cleanup(t *testing.T) { + s := newStickyStore() + s.Set("expired", "auth-ai", time.Millisecond) + s.Set("alive", "auth-cc", time.Hour) + time.Sleep(2 * time.Millisecond) + + s.Cleanup() + + if s.Len() != 1 { + t.Fatalf("expected 1 entry after cleanup, got %d", s.Len()) + } + _, ok := s.Get("alive") + if !ok { + t.Fatal("alive entry should still exist") + } +} + +func TestStickyStore_Len(t *testing.T) { + s := newStickyStore() + if s.Len() != 0 { + t.Fatalf("expected 0, got %d", s.Len()) + } + s.Set("a", "x", time.Hour) + s.Set("b", "y", time.Hour) + if s.Len() != 2 { + t.Fatalf("expected 2, got %d", s.Len()) + } +} diff --git a/sdk/cliproxy/executor/types.go b/sdk/cliproxy/executor/types.go index 4ea8103947..88535c4d45 100644 --- a/sdk/cliproxy/executor/types.go +++ b/sdk/cliproxy/executor/types.go @@ -19,6 +19,8 @@ const ( SelectedAuthCallbackMetadataKey = "selected_auth_callback" // ExecutionSessionMetadataKey identifies a long-lived downstream execution session. ExecutionSessionMetadataKey = "execution_session_id" + // StickySessionMetadataKey carries the session ID for sticky auth routing. + StickySessionMetadataKey = "sticky_session_id" ) // Request encapsulates the translated payload that will be sent to a provider executor. From dc08d09b9c5b3db160601fa09188001674a370b4 Mon Sep 17 00:00:00 2001 From: leecz Date: Fri, 13 Mar 2026 14:45:59 +0800 Subject: [PATCH 2/2] fix: address review feedback on sticky session implementation MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - Scope sticky binding by model: key store on sessionID|model composite key so the same session ID with different models gets independent auth bindings (fixes cross-model auth_not_found) - Defer stream binding until genuine success: add bool return to executeStreamWithModelPool so bootstrap errors and empty streams don't poison the sticky store - Rename header X-Session-ID → X-CLIProxyAPI-Session-ID to avoid collision with generic gateway/proxy headers - Update comments in sticky.go to reference new header name - Add TestStickyKey and TestStickyStore_CompositeKey Addresses luispater's blocking review on PR #1998. Co-Authored-By: Claude Opus 4.6 --- sdk/api/handlers/handlers.go | 8 ++- sdk/cliproxy/auth/conductor.go | 107 ++++++++++++++++--------------- sdk/cliproxy/auth/sticky.go | 21 +++++- sdk/cliproxy/auth/sticky_test.go | 70 ++++++++++++++++++++ 4 files changed, 149 insertions(+), 57 deletions(-) diff --git a/sdk/api/handlers/handlers.go b/sdk/api/handlers/handlers.go index 321be75fc0..dc89a7360c 100644 --- a/sdk/api/handlers/handlers.go +++ b/sdk/api/handlers/handlers.go @@ -208,7 +208,7 @@ func requestExecutionMetadata(ctx context.Context) map[string]any { if executionSessionID := executionSessionIDFromContext(ctx); executionSessionID != "" { meta[coreexecutor.ExecutionSessionMetadataKey] = executionSessionID } - // Sticky session: forward X-Session-ID header so the conductor can pin + // Sticky session: forward X-CLIProxyAPI-Session-ID header so the conductor can pin // subsequent requests from the same session to the same auth account. if sessionID := stickySessionIDFromHeader(ctx); sessionID != "" { meta[coreexecutor.StickySessionMetadataKey] = sessionID @@ -265,7 +265,11 @@ func stickySessionIDFromHeader(ctx context.Context) string { if !ok || ginCtx == nil || ginCtx.Request == nil { return "" } - return strings.TrimSpace(ginCtx.GetHeader("X-Session-ID")) + id := strings.TrimSpace(ginCtx.GetHeader("X-CLIProxyAPI-Session-ID")) + if len(id) > coreauth.StickyMaxSessionIDLen { + return "" + } + return id } // BaseAPIHandler contains the handlers for API endpoints. diff --git a/sdk/cliproxy/auth/conductor.go b/sdk/cliproxy/auth/conductor.go index 676699753e..a2b70b4b45 100644 --- a/sdk/cliproxy/auth/conductor.go +++ b/sdk/cliproxy/auth/conductor.go @@ -487,7 +487,7 @@ func readStreamBootstrap(ctx context.Context, ch <-chan cliproxyexecutor.StreamC } } -func (m *Manager) wrapStreamResult(ctx context.Context, auth *Auth, provider, routeModel string, headers http.Header, buffered []cliproxyexecutor.StreamChunk, remaining <-chan cliproxyexecutor.StreamChunk, stickySessionID string) *cliproxyexecutor.StreamResult { +func (m *Manager) wrapStreamResult(ctx context.Context, auth *Auth, provider, routeModel string, headers http.Header, buffered []cliproxyexecutor.StreamChunk, remaining <-chan cliproxyexecutor.StreamChunk, stickyKey string) *cliproxyexecutor.StreamResult { out := make(chan cliproxyexecutor.StreamChunk) go func() { defer close(out) @@ -502,9 +502,9 @@ func (m *Manager) wrapStreamResult(ctx context.Context, auth *Auth, provider, ro } m.MarkResult(ctx, Result{AuthID: auth.ID, Provider: provider, Model: routeModel, Success: false, Error: rerr}) // Sticky session: clear binding on mid-stream error. - if stickySessionID != "" { + if stickyKey != "" { if sc := statusCodeFromResult(rerr); sc == 429 || sc >= 500 { - m.sticky.Delete(stickySessionID) + m.sticky.Delete(stickyKey) } } } @@ -542,9 +542,9 @@ func (m *Manager) wrapStreamResult(ctx context.Context, auth *Auth, provider, ro return &cliproxyexecutor.StreamResult{Headers: headers, Chunks: out} } -func (m *Manager) executeStreamWithModelPool(ctx context.Context, executor ProviderExecutor, auth *Auth, provider string, req cliproxyexecutor.Request, opts cliproxyexecutor.Options, routeModel string, stickySessionID string) (*cliproxyexecutor.StreamResult, error) { +func (m *Manager) executeStreamWithModelPool(ctx context.Context, executor ProviderExecutor, auth *Auth, provider string, req cliproxyexecutor.Request, opts cliproxyexecutor.Options, routeModel string, stickyKey string) (*cliproxyexecutor.StreamResult, bool, error) { if executor == nil { - return nil, &Error{Code: "executor_not_found", Message: "executor not registered"} + return nil, false, &Error{Code: "executor_not_found", Message: "executor not registered"} } execModels := m.prepareExecutionModels(auth, routeModel) var lastErr error @@ -554,7 +554,7 @@ func (m *Manager) executeStreamWithModelPool(ctx context.Context, executor Provi streamResult, errStream := executor.ExecuteStream(ctx, auth, execReq, opts) if errStream != nil { if errCtx := ctx.Err(); errCtx != nil { - return nil, errCtx + return nil, false, errCtx } rerr := &Error{Message: errStream.Error()} if se, ok := errors.AsType[cliproxyexecutor.StatusError](errStream); ok && se != nil { @@ -564,7 +564,7 @@ func (m *Manager) executeStreamWithModelPool(ctx context.Context, executor Provi result.RetryAfter = retryAfterFromError(errStream) m.MarkResult(ctx, result) if isRequestInvalidError(errStream) { - return nil, errStream + return nil, false, errStream } lastErr = errStream continue @@ -574,7 +574,7 @@ func (m *Manager) executeStreamWithModelPool(ctx context.Context, executor Provi if bootstrapErr != nil { if errCtx := ctx.Err(); errCtx != nil { discardStreamChunks(streamResult.Chunks) - return nil, errCtx + return nil, false, errCtx } if isRequestInvalidError(bootstrapErr) { rerr := &Error{Message: bootstrapErr.Error()} @@ -585,7 +585,7 @@ func (m *Manager) executeStreamWithModelPool(ctx context.Context, executor Provi result.RetryAfter = retryAfterFromError(bootstrapErr) m.MarkResult(ctx, result) discardStreamChunks(streamResult.Chunks) - return nil, bootstrapErr + return nil, false, bootstrapErr } if idx < len(execModels)-1 { rerr := &Error{Message: bootstrapErr.Error()} @@ -602,7 +602,7 @@ func (m *Manager) executeStreamWithModelPool(ctx context.Context, executor Provi errCh := make(chan cliproxyexecutor.StreamChunk, 1) errCh <- cliproxyexecutor.StreamChunk{Err: bootstrapErr} close(errCh) - return m.wrapStreamResult(ctx, auth.Clone(), provider, routeModel, streamResult.Headers, nil, errCh, stickySessionID), nil + return m.wrapStreamResult(ctx, auth.Clone(), provider, routeModel, streamResult.Headers, nil, errCh, stickyKey), false, nil } if closed && len(buffered) == 0 { @@ -616,7 +616,7 @@ func (m *Manager) executeStreamWithModelPool(ctx context.Context, executor Provi errCh := make(chan cliproxyexecutor.StreamChunk, 1) errCh <- cliproxyexecutor.StreamChunk{Err: emptyErr} close(errCh) - return m.wrapStreamResult(ctx, auth.Clone(), provider, routeModel, streamResult.Headers, nil, errCh, stickySessionID), nil + return m.wrapStreamResult(ctx, auth.Clone(), provider, routeModel, streamResult.Headers, nil, errCh, stickyKey), false, nil } remaining := streamResult.Chunks @@ -625,12 +625,12 @@ func (m *Manager) executeStreamWithModelPool(ctx context.Context, executor Provi close(closedCh) remaining = closedCh } - return m.wrapStreamResult(ctx, auth.Clone(), provider, routeModel, streamResult.Headers, buffered, remaining, stickySessionID), nil + return m.wrapStreamResult(ctx, auth.Clone(), provider, routeModel, streamResult.Headers, buffered, remaining, stickyKey), true, nil } if lastErr == nil { lastErr = &Error{Code: "auth_not_found", Message: "no upstream model available"} } - return nil, lastErr + return nil, false, lastErr } func (m *Manager) rebuildAPIKeyModelAliasFromRuntimeConfig() { @@ -989,16 +989,8 @@ func (m *Manager) executeMixedOnce(ctx context.Context, providers []string, req routeModel := req.Model opts = ensureRequestedModelMetadata(opts, routeModel) - // Sticky session: resolve session→auth binding before pick loop. - // An explicit pinned_auth_id takes precedence over sticky binding. - stickySessionID := stickySessionIDFromMetadata(opts.Metadata) - if stickySessionID != "" { - if _, alreadyPinned := opts.Metadata[cliproxyexecutor.PinnedAuthMetadataKey]; !alreadyPinned { - if boundAuth, found := m.sticky.Get(stickySessionID); found { - opts.Metadata[cliproxyexecutor.PinnedAuthMetadataKey] = boundAuth - } - } - } + stickySessionID := m.resolveStickySession(opts.Metadata, routeModel) + sk := stickyKey(stickySessionID, routeModel) tried := make(map[string]struct{}) var lastErr error @@ -1049,9 +1041,9 @@ func (m *Manager) executeMixedOnce(ctx context.Context, providers []string, req m.MarkResult(execCtx, result) // Sticky session: clear binding on rate-limit or server error so next // request falls back to normal auth selection. - if stickySessionID != "" { + if sk != "" { if sc := statusCodeFromResult(result.Error); sc == 429 || sc >= 500 { - m.sticky.Delete(stickySessionID) + m.sticky.Delete(sk) delete(opts.Metadata, cliproxyexecutor.PinnedAuthMetadataKey) } } @@ -1062,8 +1054,8 @@ func (m *Manager) executeMixedOnce(ctx context.Context, providers []string, req continue } // Sticky session: bind session to the auth that succeeded. - if stickySessionID != "" { - m.sticky.Set(stickySessionID, auth.ID, stickySessionTTL) + if sk != "" { + m.sticky.Set(sk, auth.ID, stickySessionTTL) } m.MarkResult(execCtx, result) return resp, nil @@ -1157,16 +1149,8 @@ func (m *Manager) executeStreamMixedOnce(ctx context.Context, providers []string routeModel := req.Model opts = ensureRequestedModelMetadata(opts, routeModel) - // Sticky session: resolve session→auth binding before pick loop. - // An explicit pinned_auth_id takes precedence over sticky binding. - stickySessionID := stickySessionIDFromMetadata(opts.Metadata) - if stickySessionID != "" { - if _, alreadyPinned := opts.Metadata[cliproxyexecutor.PinnedAuthMetadataKey]; !alreadyPinned { - if boundAuth, found := m.sticky.Get(stickySessionID); found { - opts.Metadata[cliproxyexecutor.PinnedAuthMetadataKey] = boundAuth - } - } - } + stickySessionID := m.resolveStickySession(opts.Metadata, routeModel) + sk := stickyKey(stickySessionID, routeModel) tried := make(map[string]struct{}) var lastErr error @@ -1195,19 +1179,15 @@ func (m *Manager) executeStreamMixedOnce(ctx context.Context, providers []string execCtx = context.WithValue(execCtx, roundTripperContextKey{}, rt) execCtx = context.WithValue(execCtx, "cliproxy.roundtripper", rt) } - streamResult, errStream := m.executeStreamWithModelPool(execCtx, executor, auth, provider, req, opts, routeModel, stickySessionID) + streamResult, streamOK, errStream := m.executeStreamWithModelPool(execCtx, executor, auth, provider, req, opts, routeModel, sk) if errStream != nil { if errCtx := execCtx.Err(); errCtx != nil { return nil, errCtx } // Sticky session: clear binding on rate-limit or server error. - if stickySessionID != "" { - sc := 0 - if se, ok := errors.AsType[cliproxyexecutor.StatusError](errStream); ok && se != nil { - sc = se.StatusCode() - } - if sc == 429 || sc >= 500 { - m.sticky.Delete(stickySessionID) + if sk != "" { + if sc := statusCodeFromError(errStream); sc == 429 || sc >= 500 { + m.sticky.Delete(sk) delete(opts.Metadata, cliproxyexecutor.PinnedAuthMetadataKey) } } @@ -1217,9 +1197,9 @@ func (m *Manager) executeStreamMixedOnce(ctx context.Context, providers []string lastErr = errStream continue } - // Sticky session: bind session to the auth that started streaming. - if stickySessionID != "" { - m.sticky.Set(stickySessionID, auth.ID, stickySessionTTL) + // Sticky session: only bind on genuine stream success. + if sk != "" && streamOK { + m.sticky.Set(sk, auth.ID, stickySessionTTL) } return streamResult, nil } @@ -1285,6 +1265,12 @@ func pinnedAuthIDFromMetadata(meta map[string]any) string { // stickySessionTTL is the duration a session-to-auth binding remains valid. const stickySessionTTL = time.Hour +// stickyKey builds a composite store key from session ID and model so that +// the same session ID used with different models gets independent bindings. +func stickyKey(sessionID, model string) string { + return sessionID + "|" + model +} + func stickySessionIDFromMetadata(meta map[string]any) string { if len(meta) == 0 { return "" @@ -1303,6 +1289,23 @@ func stickySessionIDFromMetadata(meta map[string]any) string { } } +// resolveStickySession extracts the sticky session ID from metadata and, if a +// valid binding exists, sets the pinned auth ID. An explicit pinned_auth_id +// in the metadata takes precedence over any sticky binding. +// routeModel scopes the lookup so the same session ID with different models +// gets independent auth bindings. +func (m *Manager) resolveStickySession(meta map[string]any, routeModel string) string { + id := stickySessionIDFromMetadata(meta) + if id != "" { + if _, alreadyPinned := meta[cliproxyexecutor.PinnedAuthMetadataKey]; !alreadyPinned { + if boundAuth, found := m.sticky.Get(stickyKey(id, routeModel)); found { + meta[cliproxyexecutor.PinnedAuthMetadataKey] = boundAuth + } + } + } + return id +} + func publishSelectedAuthMetadata(meta map[string]any, authID string) { if len(meta) == 0 { return @@ -2413,7 +2416,8 @@ func (m *Manager) StartAutoRefresh(parent context.Context, interval time.Duratio go func() { ticker := time.NewTicker(interval) defer ticker.Stop() - stickyCleanupCounter := 0 + lastStickyCleanup := time.Now() + const stickyCleanupInterval = 5 * time.Minute m.checkRefreshes(ctx) for { select { @@ -2421,10 +2425,9 @@ func (m *Manager) StartAutoRefresh(parent context.Context, interval time.Duratio return case <-ticker.C: m.checkRefreshes(ctx) - stickyCleanupCounter++ - if stickyCleanupCounter >= 60 { // ~every 5 min at default 5s interval + if time.Since(lastStickyCleanup) >= stickyCleanupInterval { m.sticky.Cleanup() - stickyCleanupCounter = 0 + lastStickyCleanup = time.Now() } } } diff --git a/sdk/cliproxy/auth/sticky.go b/sdk/cliproxy/auth/sticky.go index 252e35b950..c0eeede850 100644 --- a/sdk/cliproxy/auth/sticky.go +++ b/sdk/cliproxy/auth/sticky.go @@ -8,9 +8,13 @@ import ( // stickyStore maintains session-to-auth bindings so that requests carrying the // same session ID are routed to the same auth/account. Entries expire after a // configurable TTL and are garbage-collected by Cleanup. +// +// maxEntries caps the number of stored bindings to prevent memory exhaustion +// from untrusted X-CLIProxyAPI-Session-ID headers. type stickyStore struct { - mu sync.RWMutex - entries map[string]stickyEntry + mu sync.RWMutex + entries map[string]stickyEntry + maxEntries int } type stickyEntry struct { @@ -18,8 +22,14 @@ type stickyEntry struct { expiresAt time.Time } +// stickyMaxEntries is the upper bound on stored session bindings. +const stickyMaxEntries = 10_000 + +// StickyMaxSessionIDLen limits the accepted X-CLIProxyAPI-Session-ID length. +const StickyMaxSessionIDLen = 256 + func newStickyStore() *stickyStore { - return &stickyStore{entries: make(map[string]stickyEntry)} + return &stickyStore{entries: make(map[string]stickyEntry), maxEntries: stickyMaxEntries} } // Get returns the bound auth ID for the given session, if it exists and has not @@ -35,8 +45,13 @@ func (s *stickyStore) Get(sessionID string) (string, bool) { } // Set binds a session to an auth ID with the specified TTL. +// If the store is at capacity, the write is silently dropped. func (s *stickyStore) Set(sessionID, authID string, ttl time.Duration) { s.mu.Lock() + if _, exists := s.entries[sessionID]; !exists && len(s.entries) >= s.maxEntries { + s.mu.Unlock() + return + } s.entries[sessionID] = stickyEntry{authID: authID, expiresAt: time.Now().Add(ttl)} s.mu.Unlock() } diff --git a/sdk/cliproxy/auth/sticky_test.go b/sdk/cliproxy/auth/sticky_test.go index 7b84740d91..5b1d67512c 100644 --- a/sdk/cliproxy/auth/sticky_test.go +++ b/sdk/cliproxy/auth/sticky_test.go @@ -84,3 +84,73 @@ func TestStickyStore_Len(t *testing.T) { t.Fatalf("expected 2, got %d", s.Len()) } } + +func TestStickyStore_MaxEntries(t *testing.T) { + s := newStickyStore() + s.maxEntries = 2 + + s.Set("a", "x", time.Hour) + s.Set("b", "y", time.Hour) + s.Set("c", "z", time.Hour) // should be silently dropped + + if s.Len() != 2 { + t.Fatalf("expected 2 (capped), got %d", s.Len()) + } + if _, ok := s.Get("c"); ok { + t.Fatal("entry 'c' should have been dropped due to capacity") + } + // overwriting existing entry should still work at capacity + s.Set("a", "updated", time.Hour) + got, ok := s.Get("a") + if !ok || got != "updated" { + t.Fatalf("expected 'updated' for overwrite at capacity, got %q (ok=%v)", got, ok) + } +} + +func TestStickyKey(t *testing.T) { + cases := []struct { + sessionID string + model string + want string + }{ + {"sess-1", "claude-3-opus", "sess-1|claude-3-opus"}, + {"sess-1", "claude-3-sonnet", "sess-1|claude-3-sonnet"}, + {"", "claude-3-opus", "|claude-3-opus"}, + {"sess-1", "", "sess-1|"}, + } + for _, tc := range cases { + got := stickyKey(tc.sessionID, tc.model) + if got != tc.want { + t.Errorf("stickyKey(%q, %q) = %q, want %q", tc.sessionID, tc.model, got, tc.want) + } + } +} + +func TestStickyStore_CompositeKey(t *testing.T) { + s := newStickyStore() + + // Same session ID, different models → independent bindings + k1 := stickyKey("sess-1", "claude-3-opus") + k2 := stickyKey("sess-1", "claude-3-sonnet") + + s.Set(k1, "auth-cc", time.Hour) + s.Set(k2, "auth-ai", time.Hour) + + got1, ok1 := s.Get(k1) + if !ok1 || got1 != "auth-cc" { + t.Fatalf("expected auth-cc for opus key, got %q (ok=%v)", got1, ok1) + } + got2, ok2 := s.Get(k2) + if !ok2 || got2 != "auth-ai" { + t.Fatalf("expected auth-ai for sonnet key, got %q (ok=%v)", got2, ok2) + } + + // Deleting one doesn't affect the other + s.Delete(k1) + if _, ok := s.Get(k1); ok { + t.Fatal("expected miss after deleting opus key") + } + if _, ok := s.Get(k2); !ok { + t.Fatal("sonnet key should still exist after deleting opus key") + } +}