Skip to content
5 changes: 3 additions & 2 deletions internal/adapter/health/checker.go
Original file line number Diff line number Diff line change
Expand Up @@ -105,7 +105,9 @@ func (c *HTTPHealthChecker) StartChecking(ctx context.Context) error {
}

func (c *HTTPHealthChecker) StopChecking(ctx context.Context) error {
if !c.isRunning.Load() {
// CAS from true→false is the single guard; concurrent callers that lose the
// race see false and return early without touching stopCh.
if !c.isRunning.CompareAndSwap(true, false) {
return nil
}

Expand All @@ -114,7 +116,6 @@ func (c *HTTPHealthChecker) StopChecking(ctx context.Context) error {
}

close(c.stopCh)
c.isRunning.Store(false)

return nil
}
Expand Down
31 changes: 31 additions & 0 deletions internal/adapter/health/checker_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -714,3 +714,34 @@ func (s *statusCodeHTTPClient) Do(req *http.Request) (*http.Response, error) {
Body: http.NoBody,
}, nil
}

// TestStopChecking_DoubleInvoke verifies concurrent double-stops do not panic.
// Previously, two callers that both passed the isRunning.Load() guard could
// race to close(stopCh), causing a "close of closed channel" panic.
func TestStopChecking_DoubleInvoke(t *testing.T) {
	t.Parallel()

	loggerCfg := &logger.Config{Level: "error", Theme: "default"}
	log, cleanup, err := logger.New(loggerCfg)
	if err != nil {
		// Bail out before deferring cleanup, which may be nil on failure.
		t.Fatalf("logger.New: %v", err)
	}
	defer cleanup()
	styledLogger := logger.NewPlainStyledLogger(log)

	mockRepo := newMockRepository()
	checker := NewHTTPHealthChecker(mockRepo, styledLogger, &mockHTTPClient{statusCode: 200})

	// Start the checker so isRunning == true.
	if err := checker.StartChecking(context.Background()); err != nil {
		t.Fatalf("StartChecking: %v", err)
	}

	// Two concurrent stops — neither should panic. The start gate holds both
	// goroutines until they are all scheduled, so the StopChecking calls
	// genuinely overlap instead of running back to back.
	const stoppers = 2
	start := make(chan struct{})

	var wg sync.WaitGroup
	wg.Add(stoppers)
	for i := 0; i < stoppers; i++ {
		go func() {
			defer wg.Done()
			<-start
			// Only panic-freedom is under test, so the returned error is
			// deliberately ignored.
			_ = checker.StopChecking(context.Background())
		}()
	}
	close(start)
	wg.Wait()
}
20 changes: 11 additions & 9 deletions internal/adapter/proxy/core/retry.go
Original file line number Diff line number Diff line change
Expand Up @@ -210,16 +210,18 @@ func IsConnectionError(err error) bool {
return hasConnectionError(err)
}

// connectionErrors is a last-resort string fallback for errors that have lost
// their type information (e.g. plain errors.New from middleware or test stubs).
// Well-wrapped OS errors are expected to be caught by typed net.Error /
// syscall.Errno checks before this list is consulted, so it only covers cases
// those checks cannot reach.
//
// Each pattern is listed exactly once; duplicate entries add scan work without
// changing which errors match.
var connectionErrors = []string{
	"connection refused",     // syscall.ECONNREFUSED on non-unwrappable paths
	"connection reset",       // syscall.ECONNRESET on non-unwrappable paths
	"no such host",           // *net.DNSError without type chain
	"network is unreachable", // syscall.ENETUNREACH without type chain
	"no route to host",       // syscall.EHOSTUNREACH without type chain
	"connection timed out",   // syscall.ETIMEDOUT text without type chain
	"i/o timeout",            // plain-string timeout errors; net.Error.Timeout() covers wrapped ones
	"dial tcp",               // prefix of a flattened net.OpError produced by a TCP dial
	"connectex:",             // Windows dial error prefix; appears without net.Error wrapping on some paths
}

func hasConnectionError(err error) bool {
Expand Down
160 changes: 29 additions & 131 deletions internal/adapter/proxy/olla/service.go
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ import (
"net/url"
"runtime"
"runtime/debug"
"sync"
"sync/atomic"
"time"

Expand Down Expand Up @@ -66,22 +67,23 @@ type Service struct {
*core.BaseProxyComponents

// Object pools for zero-allocation operations
bufferPool *pool.Pool[*[]byte]
requestPool *pool.Pool[*requestContext]
responsePool *pool.Pool[[]byte]
errorPool *pool.Pool[*errorContext]
bufferPool *pool.Pool[*[]byte]
requestPool *pool.Pool[*requestContext]
errorPool *pool.Pool[*errorContext]

transport *http.Transport
configuration *Configuration
retryHandler *core.RetryHandler

// Cleanup management
cleanupTicker *time.Ticker
cleanupStop chan struct{}

// Per-endpoint connection pools and circuit breakers
endpointPools xsync.Map[string, *connectionPool]
circuitBreakers xsync.Map[string, *circuitBreaker]

// Cleanup management
cleanupOnce sync.Once
}

// connectionPool isolates HTTP transport instances per endpoint
Expand Down Expand Up @@ -177,13 +179,6 @@ func NewService(
return nil, fmt.Errorf("failed to create request pool: %w", err)
}

responsePool, err := pool.NewLitePool(func() []byte {
return make([]byte, 32*1024) // 32KB for response bodies
})
if err != nil {
return nil, fmt.Errorf("failed to create response pool: %w", err)
}

errorPool, err := pool.NewLitePool(func() *errorContext {
return &errorContext{}
})
Expand All @@ -197,7 +192,6 @@ func NewService(
BaseProxyComponents: base,
bufferPool: bufferPool,
requestPool: requestPool,
responsePool: responsePool,
errorPool: errorPool,
transport: transport,
configuration: configuration,
Expand Down Expand Up @@ -328,104 +322,6 @@ func (s *Service) ProxyRequestToEndpoints(ctx context.Context, w http.ResponseWr
return s.ProxyRequestToEndpointsWithRetry(ctx, w, r, endpoints, stats, rlog)
}

// proxyToSingleEndpointLegacy retained for reference during migration
// TODO: Remove after retry logic stability confirmed
//
// It proxies r to one endpoint chosen by selectEndpointWithCircuitBreaker:
// the target URL is built, a proxy request is prepared and executed, and the
// backend response is handed to handleSuccessfulResponse.
//
// Returns:
//   - common.ErrNoHealthyEndpoints when endpoints is empty;
//   - a plain error when selectEndpointWithCircuitBreaker yields no endpoint;
//   - a wrapped error when prepareProxyRequest fails;
//   - the error from executeBackendRequest, unchanged;
//   - otherwise whatever handleSuccessfulResponse returns.
//
// err is a named result so the deferred recover can pass its address to
// handlePanic.
func (s *Service) proxyToSingleEndpointLegacy(ctx context.Context, w http.ResponseWriter, r *http.Request, endpoints []*domain.Endpoint, stats *ports.RequestStats, rlog logger.StyledLogger) (err error) {
// Get request context from pool
reqCtx := s.requestPool.Get()
defer s.requestPool.Put(reqCtx)

reqCtx.requestID = stats.RequestID
reqCtx.startTime = stats.StartTime

// Panic recovery
// Registered after the pool Put above, so on unwind it runs before reqCtx
// is handed back to the pool (defers execute last-in, first-out).
defer func() {
if rec := recover(); rec != nil {
s.handlePanic(ctx, w, r, stats, rlog, rec, &err)
}
}()

s.IncrementRequests()

// Use context logger if available, fallback to provided logger
ctxLogger := middleware.GetLogger(ctx)
if ctxLogger != nil {
ctxLogger.Debug("Olla proxy request started",
"method", r.Method,
"url", r.URL.String(),
"endpoint_count", len(endpoints))
} else {
rlog.Debug("proxy request started", "method", r.Method, "url", r.URL.String())
}

// Nothing to route to: record the failure with a nil endpoint and stop.
if len(endpoints) == 0 {
if ctxLogger != nil {
ctxLogger.Error("No healthy endpoints available for request")
} else {
rlog.Error("no healthy endpoints available")
}
s.RecordFailure(ctx, nil, time.Since(stats.StartTime), common.ErrNoHealthyEndpoints)
return common.ErrNoHealthyEndpoints
}

if ctxLogger != nil {
ctxLogger.Debug("Using provided endpoints", "count", len(endpoints))
} else {
rlog.Debug("using provided endpoints", "count", len(endpoints))
}

// Select endpoint with circuit breaker check
endpoint, cb := s.selectEndpointWithCircuitBreaker(endpoints, rlog)
if endpoint == nil {
// Note: the recorded failure and the returned error carry different messages.
s.RecordFailure(ctx, nil, time.Since(stats.StartTime), errors.New("all endpoints circuit breakers open"))
return errors.New("all endpoints unavailable due to circuit breakers")
}

stats.EndpointName = endpoint.Name
reqCtx.endpoint = endpoint.Name

// Track connections
// The decrement is deferred so it runs on every return path below.
s.Selector.IncrementConnections(endpoint)
defer s.Selector.DecrementConnections(endpoint)

// Build target URL
targetURL := s.buildTargetURL(r, endpoint)
stats.TargetUrl = targetURL.String()
reqCtx.targetURL = targetURL.String()

if ctxLogger != nil {
ctxLogger.Info("Request dispatching",
"endpoint", endpoint.Name,
"target", stats.TargetUrl,
"model", stats.Model)
} else {
rlog.Info("Request dispatching", "endpoint", endpoint.Name, "target", stats.TargetUrl, "model", stats.Model)
}

// Create and prepare proxy request
// Rewrite model name in request body if this is an alias-resolved request
core.RewriteModelForAlias(ctx, r, endpoint)

proxyReq, err := s.prepareProxyRequest(ctx, r, targetURL, stats)
if err != nil {
// Count the failure against both the circuit breaker and the service stats.
cb.RecordFailure()
s.RecordFailure(ctx, endpoint, time.Since(stats.StartTime), err)
return fmt.Errorf("failed to create proxy request: %w", err)
}

// This message goes to rlog only; ctxLogger is not consulted here.
rlog.Debug("created proxy request")

// Execute backend request
// NOTE(review): no failure is recorded on this error path here; presumably
// executeBackendRequest does so itself via cb and stats — confirm.
resp, err := s.executeBackendRequest(ctx, endpoint, proxyReq, cb, stats, rlog)
if err != nil {
return err
}
defer resp.Body.Close()

// Handle successful response
return s.handleSuccessfulResponse(ctx, w, r, resp, endpoint, cb, stats, rlog)
}

// handlePanic handles panic recovery in proxy requests
func (s *Service) handlePanic(ctx context.Context, w http.ResponseWriter, r *http.Request, stats *ports.RequestStats, rlog logger.StyledLogger, rec interface{}, err *error) {
s.RecordFailure(ctx, nil, time.Since(stats.StartTime), fmt.Errorf("panic: %v", rec))
Expand Down Expand Up @@ -483,7 +379,7 @@ func (s *Service) prepareProxyRequest(ctx context.Context, r *http.Request, targ
stats.HeaderProcessingMs = time.Since(headerStart).Milliseconds()

// Add model header
if model, ok := ctx.Value("model").(string); ok && model != "" {
if model, ok := ctx.Value(constants.ContextModelKey).(string); ok && model != "" {
proxyReq.Header.Set("X-Model", model)
stats.Model = model
}
Expand Down Expand Up @@ -777,29 +673,31 @@ func (s *Service) cleanupUnusedResources() {
}
}

// Cleanup cleans up resources. Safe to call more than once.
//
// All teardown runs inside cleanupOnce so that cleanupStop is closed exactly
// once; closing an already-closed channel panics. Later calls return
// immediately without touching any state.
func (s *Service) Cleanup() {
	s.cleanupOnce.Do(func() {
		// Stop cleanup goroutine
		if s.cleanupStop != nil {
			close(s.cleanupStop)
		}
		if s.cleanupTicker != nil {
			s.cleanupTicker.Stop()
		}

		// Close all endpoint pools
		s.endpointPools.Range(func(key string, pool *connectionPool) bool {
			pool.transport.CloseIdleConnections()
			return true
		})

		// Drop the per-endpoint state only after idle connections are closed.
		s.endpointPools.Clear()
		s.circuitBreakers.Clear()

		s.BaseProxyComponents.Shutdown()

		// force GC to clean up
		runtime.GC()

		s.Logger.Debug("Olla proxy service cleaned up")
	})
}
23 changes: 23 additions & 0 deletions internal/adapter/proxy/olla/service_leak_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -367,6 +367,29 @@ func (m *mockStatsCollector) GetEndpointStats() map[string]ports.EndpointStats {
// GetSecurityStats always returns the zero value of ports.SecurityStats.
func (m *mockStatsCollector) GetSecurityStats() ports.SecurityStats {
	var zero ports.SecurityStats
	return zero
}
// GetConnectionStats always returns a nil map.
func (m *mockStatsCollector) GetConnectionStats() map[string]int64 {
	var none map[string]int64
	return none
}

// TestCleanup_DoubleInvoke verifies that calling Cleanup twice does not panic.
// Previously, the second call would close an already-closed channel.
func TestCleanup_DoubleInvoke(t *testing.T) {
t.Parallel()

s := &Service{
BaseProxyComponents: &core.BaseProxyComponents{
Logger: createTestLogger(),
},
configuration: &Configuration{},
endpointPools: *xsync.NewMap[string, *connectionPool](),
circuitBreakers: *xsync.NewMap[string, *circuitBreaker](),
cleanupTicker: time.NewTicker(time.Hour),
cleanupStop: make(chan struct{}),
}

go s.cleanupLoop()

// Neither call should panic.
s.Cleanup()
s.Cleanup()
}

func createTestLogger() logger.StyledLogger {
loggerCfg := &logger.Config{Level: "error", Theme: "default"}
log, _, _ := logger.New(loggerCfg)
Expand Down
4 changes: 2 additions & 2 deletions internal/adapter/proxy/proxy_headers_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,7 @@ func TestProxyResponseHeaders(t *testing.T) {
// Test with model in context
t.Run("with model", func(t *testing.T) {
req := httptest.NewRequest("GET", "/test", nil)
ctx := context.WithValue(req.Context(), "model", "llama3.2:3b")
ctx := context.WithValue(req.Context(), constants.ContextModelKey, "llama3.2:3b")
req = req.WithContext(ctx)

w := httptest.NewRecorder()
Expand Down Expand Up @@ -127,7 +127,7 @@ func TestProxyResponseHeaders_NoOverride(t *testing.T) {
proxy, _ := sherpa.NewService(discovery, selector, config, createTestStatsCollector(), nil, createTestLogger())

req := httptest.NewRequest("GET", "/test", nil)
ctx := context.WithValue(req.Context(), "model", "real-model")
ctx := context.WithValue(req.Context(), constants.ContextModelKey, "real-model")
req = req.WithContext(ctx)

w := httptest.NewRecorder()
Expand Down
2 changes: 1 addition & 1 deletion internal/adapter/proxy/sherpa/service_retry.go
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,7 @@ func (s *Service) proxyToSingleEndpoint(ctx context.Context, w http.ResponseWrit
stats.HeaderProcessingMs = time.Since(headerStart).Milliseconds()

// Add model header if available
if model, ok := ctx.Value("model").(string); ok && model != "" {
if model, ok := ctx.Value(constants.ContextModelKey).(string); ok && model != "" {
proxyReq.Header.Set(constants.HeaderXModel, model)
stats.Model = model
}
Expand Down
2 changes: 1 addition & 1 deletion internal/app/handlers/handler_translation.go
Original file line number Diff line number Diff line change
Expand Up @@ -349,7 +349,7 @@ func (a *Application) executeTranslatedNonStreamingRequest(
// prepareProxyContext sets up context with model, routing decision, and alias rewrite map
func (a *Application) prepareProxyContext(ctx context.Context, r *http.Request, pr *proxyRequest) (context.Context, *http.Request) {
if pr.model != "" {
ctx = context.WithValue(ctx, "model", pr.model)
ctx = context.WithValue(ctx, constants.ContextModelKey, pr.model)
r = r.WithContext(ctx)
}

Expand Down
4 changes: 2 additions & 2 deletions internal/app/handlers/handler_translation_alias_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,7 @@ func TestPrepareProxyContext_NoAliasMapWhenNoneStored(t *testing.T) {
assert.Nil(t, rawMap, "alias rewrite map should not be present for non-alias requests")

// Model should still be set in context
assert.Equal(t, "llama3.1:8b", r.Context().Value("model"))
assert.Equal(t, "llama3.1:8b", r.Context().Value(constants.ContextModelKey))
}

func TestPrepareProxyContext_NilProfile(t *testing.T) {
Expand All @@ -100,5 +100,5 @@ func TestPrepareProxyContext_NilProfile(t *testing.T) {
rawMap := r.Context().Value(constants.ContextModelAliasMapKey)
assert.Nil(t, rawMap, "alias rewrite map should not be present when profile is nil")

assert.Equal(t, "llama3.1:8b", r.Context().Value("model"))
assert.Equal(t, "llama3.1:8b", r.Context().Value(constants.ContextModelKey))
}
5 changes: 5 additions & 0 deletions internal/core/constants/context.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,11 @@ const (
ContextKeyStream = "stream" // indicates whether the response should be streamed or buffered
ContextProviderTypeKey = "provider_type" // the provider type for the request, used for routing and load balancing

// ContextModelKey carries the resolved model name through the proxy pipeline.
// Using a typed key prevents accidental collisions with plain-string keys from
// third-party middleware that might also use "model".
ContextModelKey = contextKey("model")

// Sticky session context keys — set by the handler before balancer selection
// and read back after to surface affinity decisions in response headers.
ContextStickyKeyKey = contextKey("sticky-key") // computed affinity key for this request
Expand Down
Loading