diff --git a/internal/adapter/health/checker.go b/internal/adapter/health/checker.go index a382a88..73c0d87 100644 --- a/internal/adapter/health/checker.go +++ b/internal/adapter/health/checker.go @@ -105,7 +105,9 @@ func (c *HTTPHealthChecker) StartChecking(ctx context.Context) error { } func (c *HTTPHealthChecker) StopChecking(ctx context.Context) error { - if !c.isRunning.Load() { + // CAS from true→false is the single guard; concurrent callers that lose the + // race see false and return early without touching stopCh. + if !c.isRunning.CompareAndSwap(true, false) { return nil } @@ -114,7 +116,6 @@ func (c *HTTPHealthChecker) StopChecking(ctx context.Context) error { } close(c.stopCh) - c.isRunning.Store(false) return nil } diff --git a/internal/adapter/health/checker_test.go b/internal/adapter/health/checker_test.go index 053f97b..ed9f5c9 100644 --- a/internal/adapter/health/checker_test.go +++ b/internal/adapter/health/checker_test.go @@ -714,3 +714,34 @@ func (s *statusCodeHTTPClient) Do(req *http.Request) (*http.Response, error) { Body: http.NoBody, }, nil } + +// TestStopChecking_DoubleInvoke verifies concurrent double-stops do not panic. +// Previously, two callers that both passed the isRunning.Load() guard could +// race to close(stopCh), causing a "close of closed channel" panic. +func TestStopChecking_DoubleInvoke(t *testing.T) { + t.Parallel() + + loggerCfg := &logger.Config{Level: "error", Theme: "default"} + log, cleanup, _ := logger.New(loggerCfg) + defer cleanup() + styledLogger := logger.NewPlainStyledLogger(log) + + mockRepo := newMockRepository() + checker := NewHTTPHealthChecker(mockRepo, styledLogger, &mockHTTPClient{statusCode: 200}) + + // Start the checker so isRunning == true. + if err := checker.StartChecking(context.Background()); err != nil { + t.Fatalf("StartChecking: %v", err) + } + + // Two concurrent stops — neither should panic. + var wg sync.WaitGroup + wg.Add(2) + for range 2 { + go func() { + defer wg.Done() + _ = checker.StopChecking(context.Background()) + }() + } + wg.Wait() +} diff --git a/internal/adapter/proxy/core/retry.go b/internal/adapter/proxy/core/retry.go index 4fbb157..9ad01a2 100644 --- a/internal/adapter/proxy/core/retry.go +++ b/internal/adapter/proxy/core/retry.go @@ -210,16 +210,18 @@ func IsConnectionError(err error) bool { return hasConnectionError(err) } +// connectionErrors is a last-resort string fallback for errors that have lost +// their type information (e.g. plain errors.New from middleware or test stubs). +// Well-wrapped OS errors are already caught by the net.Error / syscall.Errno +// branches above, so this list covers only cases those branches cannot reach. var connectionErrors = []string{ - "connection refused", - "connection reset", - "no such host", - "network is unreachable", - "no route to host", - "connection timed out", - "i/o timeout", - "dial tcp", - "connectex:", + "connection refused", // syscall.ECONNREFUSED on non-unwrappable paths + "connection reset", // syscall.ECONNRESET on non-unwrappable paths + "no such host", // *net.DNSError without type chain + "network is unreachable", // syscall.ENETUNREACH without type chain + "no route to host", // syscall.EHOSTUNREACH without type chain + "i/o timeout", // plain-string timeout errors; net.Error.Timeout() covers wrapped ones + "connectex:", // Windows dial error prefix; appears without net.Error wrapping on some paths } func hasConnectionError(err error) bool { diff --git a/internal/adapter/proxy/olla/service.go b/internal/adapter/proxy/olla/service.go index b58ccb1..eecda19 100644 --- a/internal/adapter/proxy/olla/service.go +++ b/internal/adapter/proxy/olla/service.go @@ -31,6 +31,7 @@ import ( "net/url" "runtime" "runtime/debug" + "sync" "sync/atomic" "time" @@ -66,22 +67,23 @@ type Service struct { *core.BaseProxyComponents // Object pools for zero-allocation operations - bufferPool *pool.Pool[*[]byte] - requestPool *pool.Pool[*requestContext] - responsePool *pool.Pool[[]byte] - errorPool *pool.Pool[*errorContext] + bufferPool *pool.Pool[*[]byte] + requestPool *pool.Pool[*requestContext] + errorPool *pool.Pool[*errorContext] transport *http.Transport configuration *Configuration retryHandler *core.RetryHandler - // Cleanup management cleanupTicker *time.Ticker cleanupStop chan struct{} // Per-endpoint connection pools and circuit breakers endpointPools xsync.Map[string, *connectionPool] circuitBreakers xsync.Map[string, *circuitBreaker] + + // Cleanup management + cleanupOnce sync.Once } // connectionPool isolates HTTP transport instances per endpoint @@ -177,13 +179,6 @@ func NewService( return nil, fmt.Errorf("failed to create request pool: %w", err) } - responsePool, err := pool.NewLitePool(func() []byte { - return make([]byte, 32*1024) // 32KB for response bodies - }) - if err != nil { - return nil, fmt.Errorf("failed to create response pool: %w", err) - } - errorPool, err := pool.NewLitePool(func() *errorContext { return &errorContext{} }) @@ -197,7 +192,6 @@ func NewService( BaseProxyComponents: base, bufferPool: bufferPool, requestPool: requestPool, - responsePool: responsePool, errorPool: errorPool, transport: transport, configuration: configuration, @@ -328,104 +322,6 @@ func (s *Service) ProxyRequestToEndpoints(ctx context.Context, w http.ResponseWr return s.ProxyRequestToEndpointsWithRetry(ctx, w, r, endpoints, stats, rlog) } -// proxyToSingleEndpointLegacy retained for reference during migration -// TODO: Remove after retry logic stability confirmed -func (s *Service) proxyToSingleEndpointLegacy(ctx context.Context, w http.ResponseWriter, r *http.Request, endpoints []*domain.Endpoint, stats *ports.RequestStats, rlog logger.StyledLogger) (err error) { - // Get request context from pool - reqCtx := s.requestPool.Get() - defer s.requestPool.Put(reqCtx) - - reqCtx.requestID = stats.RequestID - reqCtx.startTime = stats.StartTime - - // Panic recovery - defer func() { - if rec := recover(); rec != nil { - s.handlePanic(ctx, w, r, stats, rlog, rec, &err) - } - }() - - s.IncrementRequests() - - // Use context logger if available, fallback to provided logger - ctxLogger := middleware.GetLogger(ctx) - if ctxLogger != nil { - ctxLogger.Debug("Olla proxy request started", - "method", r.Method, - "url", r.URL.String(), - "endpoint_count", len(endpoints)) - } else { - rlog.Debug("proxy request started", "method", r.Method, "url", r.URL.String()) - } - - if len(endpoints) == 0 { - if ctxLogger != nil { - ctxLogger.Error("No healthy endpoints available for request") - } else { - rlog.Error("no healthy endpoints available") - } - s.RecordFailure(ctx, nil, time.Since(stats.StartTime), common.ErrNoHealthyEndpoints) - return common.ErrNoHealthyEndpoints - } - - if ctxLogger != nil { - ctxLogger.Debug("Using provided endpoints", "count", len(endpoints)) - } else { - rlog.Debug("using provided endpoints", "count", len(endpoints)) - } - - // Select endpoint with circuit breaker check - endpoint, cb := s.selectEndpointWithCircuitBreaker(endpoints, rlog) - if endpoint == nil { - s.RecordFailure(ctx, nil, time.Since(stats.StartTime), errors.New("all endpoints circuit breakers open")) - return errors.New("all endpoints unavailable due to circuit breakers") - } - - stats.EndpointName = endpoint.Name - reqCtx.endpoint = endpoint.Name - - // Track connections - s.Selector.IncrementConnections(endpoint) - defer s.Selector.DecrementConnections(endpoint) - - // Build target URL - targetURL := s.buildTargetURL(r, endpoint) - stats.TargetUrl = targetURL.String() - reqCtx.targetURL = targetURL.String() - - if ctxLogger != nil { - ctxLogger.Info("Request dispatching", - "endpoint", endpoint.Name, - "target", stats.TargetUrl, - "model", stats.Model) - } else { - rlog.Info("Request dispatching", "endpoint", endpoint.Name, "target", stats.TargetUrl, "model", stats.Model) - } - - // Create and prepare proxy request - // Rewrite model name in request body if this is an alias-resolved request - core.RewriteModelForAlias(ctx, r, endpoint) - - proxyReq, err := s.prepareProxyRequest(ctx, r, targetURL, stats) - if err != nil { - cb.RecordFailure() - s.RecordFailure(ctx, endpoint, time.Since(stats.StartTime), err) - return fmt.Errorf("failed to create proxy request: %w", err) - } - - rlog.Debug("created proxy request") - - // Execute backend request - resp, err := s.executeBackendRequest(ctx, endpoint, proxyReq, cb, stats, rlog) - if err != nil { - return err - } - defer resp.Body.Close() - - // Handle successful response - return s.handleSuccessfulResponse(ctx, w, r, resp, endpoint, cb, stats, rlog) -} - // handlePanic handles panic recovery in proxy requests func (s *Service) handlePanic(ctx context.Context, w http.ResponseWriter, r *http.Request, stats *ports.RequestStats, rlog logger.StyledLogger, rec interface{}, err *error) { s.RecordFailure(ctx, nil, time.Since(stats.StartTime), fmt.Errorf("panic: %v", rec)) @@ -483,7 +379,7 @@ func (s *Service) prepareProxyRequest(ctx context.Context, r *http.Request, targ stats.HeaderProcessingMs = time.Since(headerStart).Milliseconds() // Add model header - if model, ok := ctx.Value("model").(string); ok && model != "" { + if model, ok := ctx.Value(constants.ContextModelKey).(string); ok && model != "" { proxyReq.Header.Set("X-Model", model) stats.Model = model } @@ -777,29 +673,31 @@ func (s *Service) cleanupUnusedResources() { } } -// Cleanup cleans up resources +// Cleanup cleans up resources. Safe to call more than once. func (s *Service) Cleanup() { - // Stop cleanup goroutine - if s.cleanupStop != nil { - close(s.cleanupStop) - } - if s.cleanupTicker != nil { - s.cleanupTicker.Stop() - } + s.cleanupOnce.Do(func() { + // Stop cleanup goroutine + if s.cleanupStop != nil { + close(s.cleanupStop) + } + if s.cleanupTicker != nil { + s.cleanupTicker.Stop() + } - // Close all endpoint pools - s.endpointPools.Range(func(key string, pool *connectionPool) bool { - pool.transport.CloseIdleConnections() - return true - }) + // Close all endpoint pools + s.endpointPools.Range(func(key string, pool *connectionPool) bool { + pool.transport.CloseIdleConnections() + return true + }) - s.endpointPools.Clear() - s.circuitBreakers.Clear() + s.endpointPools.Clear() + s.circuitBreakers.Clear() - s.BaseProxyComponents.Shutdown() + s.BaseProxyComponents.Shutdown() - // force GC to clean up - runtime.GC() + // force GC to clean up + runtime.GC() - s.Logger.Debug("Olla proxy service cleaned up") + s.Logger.Debug("Olla proxy service cleaned up") + }) } diff --git a/internal/adapter/proxy/olla/service_leak_test.go b/internal/adapter/proxy/olla/service_leak_test.go index 60882dd..06f0552 100644 --- a/internal/adapter/proxy/olla/service_leak_test.go +++ b/internal/adapter/proxy/olla/service_leak_test.go @@ -367,6 +367,29 @@ func (m *mockStatsCollector) GetEndpointStats() map[string]ports.EndpointStats { func (m *mockStatsCollector) GetSecurityStats() ports.SecurityStats { return ports.SecurityStats{} } func (m *mockStatsCollector) GetConnectionStats() map[string]int64 { return nil } +// TestCleanup_DoubleInvoke verifies that calling Cleanup twice does not panic. +// Previously, the second call would close an already-closed channel. +func TestCleanup_DoubleInvoke(t *testing.T) { + t.Parallel() + + s := &Service{ + BaseProxyComponents: &core.BaseProxyComponents{ + Logger: createTestLogger(), + }, + configuration: &Configuration{}, + endpointPools: *xsync.NewMap[string, *connectionPool](), + circuitBreakers: *xsync.NewMap[string, *circuitBreaker](), + cleanupTicker: time.NewTicker(time.Hour), + cleanupStop: make(chan struct{}), + } + + go s.cleanupLoop() + + // Neither call should panic. + s.Cleanup() + s.Cleanup() +} + func createTestLogger() logger.StyledLogger { loggerCfg := &logger.Config{Level: "error", Theme: "default"} log, _, _ := logger.New(loggerCfg) diff --git a/internal/adapter/proxy/proxy_headers_test.go b/internal/adapter/proxy/proxy_headers_test.go index 94039d1..e6c162a 100644 --- a/internal/adapter/proxy/proxy_headers_test.go +++ b/internal/adapter/proxy/proxy_headers_test.go @@ -83,7 +83,7 @@ func TestProxyResponseHeaders(t *testing.T) { // Test with model in context t.Run("with model", func(t *testing.T) { req := httptest.NewRequest("GET", "/test", nil) - ctx := context.WithValue(req.Context(), "model", "llama3.2:3b") + ctx := context.WithValue(req.Context(), constants.ContextModelKey, "llama3.2:3b") req = req.WithContext(ctx) w := httptest.NewRecorder() @@ -127,7 +127,7 @@ func TestProxyResponseHeaders_NoOverride(t *testing.T) { proxy, _ := sherpa.NewService(discovery, selector, config, createTestStatsCollector(), nil, createTestLogger()) req := httptest.NewRequest("GET", "/test", nil) - ctx := context.WithValue(req.Context(), "model", "real-model") + ctx := context.WithValue(req.Context(), constants.ContextModelKey, "real-model") req = req.WithContext(ctx) w := httptest.NewRecorder() diff --git a/internal/adapter/proxy/sherpa/service_retry.go b/internal/adapter/proxy/sherpa/service_retry.go index 901b320..8b9601c 100644 --- a/internal/adapter/proxy/sherpa/service_retry.go +++ b/internal/adapter/proxy/sherpa/service_retry.go @@ -86,7 +86,7 @@ func (s *Service) proxyToSingleEndpoint(ctx context.Context, w http.ResponseWrit stats.HeaderProcessingMs = time.Since(headerStart).Milliseconds() // Add model header if available - if model, ok := ctx.Value("model").(string); ok && model != "" { + if model, ok := ctx.Value(constants.ContextModelKey).(string); ok && model != "" { proxyReq.Header.Set(constants.HeaderXModel, model) stats.Model = model } diff --git a/internal/app/handlers/handler_translation.go b/internal/app/handlers/handler_translation.go index f5593f9..b54fef1 100644 --- a/internal/app/handlers/handler_translation.go +++ b/internal/app/handlers/handler_translation.go @@ -349,7 +349,7 @@ func (a *Application) executeTranslatedNonStreamingRequest( // prepareProxyContext sets up context with model, routing decision, and alias rewrite map func (a *Application) prepareProxyContext(ctx context.Context, r *http.Request, pr *proxyRequest) (context.Context, *http.Request) { if pr.model != "" { - ctx = context.WithValue(ctx, "model", pr.model) + ctx = context.WithValue(ctx, constants.ContextModelKey, pr.model) r = r.WithContext(ctx) } diff --git a/internal/app/handlers/handler_translation_alias_test.go b/internal/app/handlers/handler_translation_alias_test.go index 6db0a0d..06b3340 100644 --- a/internal/app/handlers/handler_translation_alias_test.go +++ b/internal/app/handlers/handler_translation_alias_test.go @@ -77,7 +77,7 @@ func TestPrepareProxyContext_NoAliasMapWhenNoneStored(t *testing.T) { assert.Nil(t, rawMap, "alias rewrite map should not be present for non-alias requests") // Model should still be set in context - assert.Equal(t, "llama3.1:8b", r.Context().Value("model")) + assert.Equal(t, "llama3.1:8b", r.Context().Value(constants.ContextModelKey)) } func TestPrepareProxyContext_NilProfile(t *testing.T) { @@ -100,5 +100,5 @@ func TestPrepareProxyContext_NilProfile(t *testing.T) { rawMap := r.Context().Value(constants.ContextModelAliasMapKey) assert.Nil(t, rawMap, "alias rewrite map should not be present when profile is nil") - assert.Equal(t, "llama3.1:8b", r.Context().Value("model")) + assert.Equal(t, "llama3.1:8b", r.Context().Value(constants.ContextModelKey)) } diff --git a/internal/core/constants/context.go b/internal/core/constants/context.go index ad85523..c160b90 100644 --- a/internal/core/constants/context.go +++ b/internal/core/constants/context.go @@ -11,6 +11,11 @@ const ( ContextKeyStream = "stream" // indicates whether the response should be streamed or buffered ContextProviderTypeKey = "provider_type" // the provider type for the request, used for routing and load balancing + // ContextModelKey carries the resolved model name through the proxy pipeline. + // Using a typed key prevents accidental collisions with plain-string keys from + // third-party middleware that might also use "model". + ContextModelKey = contextKey("model") + // Sticky session context keys — set by the handler before balancer selection // and read back after to surface affinity decisions in response headers. ContextStickyKeyKey = contextKey("sticky-key") // computed affinity key for this request