diff --git a/internal/adapter/proxy/config/unified.go b/internal/adapter/proxy/config/unified.go index efa755c8..5c573c41 100644 --- a/internal/adapter/proxy/config/unified.go +++ b/internal/adapter/proxy/config/unified.go @@ -14,10 +14,11 @@ const ( DefaultKeepAlive = 60 * time.Second // Olla-specific defaults for high-performance - OllaDefaultStreamBufferSize = 64 * 1024 // Larger buffer for better streaming performance - OllaDefaultMaxIdleConns = 100 - OllaDefaultMaxConnsPerHost = 50 - OllaDefaultIdleConnTimeout = 90 * time.Second + OllaDefaultStreamBufferSize = 64 * 1024 // Larger buffer for better streaming performance + OllaDefaultMaxIdleConns = 100 + OllaDefaultMaxConnsPerHost = 50 + OllaDefaultMaxIdleConnsPerHost = 25 // Half of MaxConnsPerHost; idle slots rarely need to match total capacity + OllaDefaultIdleConnTimeout = 90 * time.Second // Olla uses 30s timeouts for faster failure detection in AI workloads OllaDefaultTimeout = 30 * time.Second OllaDefaultKeepAlive = 30 * time.Second @@ -118,9 +119,10 @@ type OllaConfig struct { BaseProxyConfig // Olla-specific fields for advanced connection pooling - IdleConnTimeout time.Duration - MaxIdleConns int - MaxConnsPerHost int + IdleConnTimeout time.Duration + MaxIdleConns int + MaxConnsPerHost int + MaxIdleConnsPerHost int } // GetStreamBufferSize returns the stream buffer size, defaulting to OllaDefaultStreamBufferSize for better performance @@ -155,6 +157,14 @@ func (c *OllaConfig) GetMaxConnsPerHost() int { return c.MaxConnsPerHost } +// GetMaxIdleConnsPerHost returns the maximum idle connections per host, defaulting to OllaDefaultMaxIdleConnsPerHost +func (c *OllaConfig) GetMaxIdleConnsPerHost() int { + if c.MaxIdleConnsPerHost == 0 { + return OllaDefaultMaxIdleConnsPerHost + } + return c.MaxIdleConnsPerHost +} + // GetConnectionTimeout returns the connection timeout, defaulting to OllaDefaultTimeout (30s for faster failure detection) func (c *OllaConfig) GetConnectionTimeout() time.Duration { if 
c.ConnectionTimeout == 0 { diff --git a/internal/adapter/proxy/olla/service.go b/internal/adapter/proxy/olla/service.go index 3132ebe1..86d2a97a 100644 --- a/internal/adapter/proxy/olla/service.go +++ b/internal/adapter/proxy/olla/service.go @@ -150,6 +150,9 @@ func NewService( if configuration.MaxConnsPerHost == 0 { configuration.MaxConnsPerHost = config.OllaDefaultMaxConnsPerHost } + if configuration.MaxIdleConnsPerHost == 0 { + configuration.MaxIdleConnsPerHost = config.OllaDefaultMaxIdleConnsPerHost + } if configuration.IdleConnTimeout == 0 { configuration.IdleConnTimeout = config.OllaDefaultIdleConnTimeout } @@ -215,7 +218,8 @@ func NewService( func createOptimisedTransport(config *Configuration) *http.Transport { return &http.Transport{ MaxIdleConns: config.MaxIdleConns, - MaxIdleConnsPerHost: config.MaxConnsPerHost, + MaxIdleConnsPerHost: config.MaxIdleConnsPerHost, + MaxConnsPerHost: config.MaxConnsPerHost, IdleConnTimeout: config.IdleConnTimeout, TLSHandshakeTimeout: DefaultTLSHandshakeTimeout, DisableCompression: true, @@ -693,11 +697,13 @@ func (s *Service) UpdateConfig(config ports.ProxyConfiguration) { newConfig.MaxIdleConns = ollaConfig.MaxIdleConns newConfig.IdleConnTimeout = ollaConfig.IdleConnTimeout newConfig.MaxConnsPerHost = ollaConfig.MaxConnsPerHost + newConfig.MaxIdleConnsPerHost = ollaConfig.MaxIdleConnsPerHost } else { // fallback: preserve current Olla-specific settings for non-Olla configs newConfig.MaxIdleConns = s.configuration.MaxIdleConns newConfig.IdleConnTimeout = s.configuration.IdleConnTimeout newConfig.MaxConnsPerHost = s.configuration.MaxConnsPerHost + newConfig.MaxIdleConnsPerHost = s.configuration.MaxIdleConnsPerHost } // Update configuration atomically diff --git a/internal/adapter/proxy/olla/service_retry.go b/internal/adapter/proxy/olla/service_retry.go index 1dddc960..e2c58fec 100644 --- a/internal/adapter/proxy/olla/service_retry.go +++ b/internal/adapter/proxy/olla/service_retry.go @@ -61,9 +61,6 @@ func (s *Service) 
proxyToSingleEndpoint(ctx context.Context, w http.ResponseWrit return fmt.Errorf("circuit breaker open for endpoint %s", endpoint.Name) } - s.Selector.IncrementConnections(endpoint) - defer s.Selector.DecrementConnections(endpoint) - // Build target URL using common function that respects preserve_path targetURL := common.BuildTargetURL(r, endpoint, s.configuration.GetProxyPrefix()) stats.TargetUrl = targetURL.String() diff --git a/internal/adapter/proxy/olla/service_transport_test.go b/internal/adapter/proxy/olla/service_transport_test.go new file mode 100644 index 00000000..a7432bb3 --- /dev/null +++ b/internal/adapter/proxy/olla/service_transport_test.go @@ -0,0 +1,84 @@ +package olla + +import ( + "testing" + "time" + + "github.com/thushan/olla/internal/adapter/proxy/config" +) + +// TestCreateOptimisedTransport_ConnectionLimits verifies that both MaxConnsPerHost and +// MaxIdleConnsPerHost are mapped to their correct fields on http.Transport. +// Previously MaxConnsPerHost was mistakenly written to MaxIdleConnsPerHost and +// MaxConnsPerHost was never set (defaulting to 0 = unlimited). +func TestCreateOptimisedTransport_ConnectionLimits(t *testing.T) { + t.Parallel() + + cfg := &Configuration{} + cfg.MaxConnsPerHost = 42 + cfg.MaxIdleConnsPerHost = 17 + cfg.MaxIdleConns = 200 + cfg.IdleConnTimeout = 90 * time.Second + + transport := createOptimisedTransport(cfg) + + if transport.MaxConnsPerHost != 42 { + t.Errorf("MaxConnsPerHost: want 42, got %d", transport.MaxConnsPerHost) + } + if transport.MaxIdleConnsPerHost != 17 { + t.Errorf("MaxIdleConnsPerHost: want 17, got %d", transport.MaxIdleConnsPerHost) + } + if transport.MaxIdleConns != 200 { + t.Errorf("MaxIdleConns: want 200, got %d", transport.MaxIdleConns) + } +} + +// TestCreateOptimisedTransport_DefaultsApplied verifies that NewService fills in sensible +// defaults before handing the config to createOptimisedTransport, so a zero-value config +// never silently leaves MaxConnsPerHost unlimited. 
+func TestCreateOptimisedTransport_DefaultsApplied(t *testing.T) {
+ t.Parallel()
+
+ // NOTE(review): this does not invoke NewService's defaulting itself; it only
+ // verifies the package default constants pass through createOptimisedTransport unchanged.
+ cfg := &Configuration{}
+ cfg.MaxConnsPerHost = config.OllaDefaultMaxConnsPerHost
+ cfg.MaxIdleConnsPerHost = config.OllaDefaultMaxIdleConnsPerHost
+ cfg.MaxIdleConns = config.OllaDefaultMaxIdleConns
+ cfg.IdleConnTimeout = config.OllaDefaultIdleConnTimeout
+
+ transport := createOptimisedTransport(cfg)
+
+ if transport.MaxConnsPerHost != config.OllaDefaultMaxConnsPerHost {
+ t.Errorf("MaxConnsPerHost: want %d, got %d", config.OllaDefaultMaxConnsPerHost, transport.MaxConnsPerHost)
+ }
+ if transport.MaxIdleConnsPerHost != config.OllaDefaultMaxIdleConnsPerHost {
+ t.Errorf("MaxIdleConnsPerHost: want %d, got %d", config.OllaDefaultMaxIdleConnsPerHost, transport.MaxIdleConnsPerHost)
+ }
+}
+
+// TestCreateOptimisedTransport_FieldsAreDistinct guards against the specific regression
+// where MaxConnsPerHost value bled into MaxIdleConnsPerHost. Using distinct values
+// makes the mapping error immediately visible.
+func TestCreateOptimisedTransport_FieldsAreDistinct(t *testing.T) {
+ t.Parallel()
+
+ cfg := &Configuration{}
+ cfg.MaxConnsPerHost = 100
+ cfg.MaxIdleConnsPerHost = 10
+ cfg.MaxIdleConns = 500
+
+ transport := createOptimisedTransport(cfg)
+
+ // Regression guard: if the bug is reintroduced both fields get value 100.
+ if transport.MaxConnsPerHost == transport.MaxIdleConnsPerHost { + t.Errorf("MaxConnsPerHost (%d) and MaxIdleConnsPerHost (%d) are equal — likely a field mapping regression", + transport.MaxConnsPerHost, transport.MaxIdleConnsPerHost) + } + if transport.MaxConnsPerHost != 100 { + t.Errorf("MaxConnsPerHost: want 100, got %d", transport.MaxConnsPerHost) + } + if transport.MaxIdleConnsPerHost != 10 { + t.Errorf("MaxIdleConnsPerHost: want 10, got %d", transport.MaxIdleConnsPerHost) + } +} diff --git a/internal/adapter/proxy/proxy_olla_connection_counting_test.go b/internal/adapter/proxy/proxy_olla_connection_counting_test.go new file mode 100644 index 00000000..f0ad1a8e --- /dev/null +++ b/internal/adapter/proxy/proxy_olla_connection_counting_test.go @@ -0,0 +1,150 @@ +package proxy + +import ( + "context" + "net/http" + "net/http/httptest" + "sync/atomic" + "testing" + "time" + + "github.com/thushan/olla/internal/adapter/proxy/olla" + "github.com/thushan/olla/internal/core/domain" +) + +// countingEndpointSelector tracks the number of Increment and Decrement calls +// using atomic counters so the test is safe under concurrent execution. 
+type countingEndpointSelector struct { + incrementCalls atomic.Int64 + decrementCalls atomic.Int64 + endpoint *domain.Endpoint +} + +func (c *countingEndpointSelector) Select(_ context.Context, endpoints []*domain.Endpoint) (*domain.Endpoint, error) { + if c.endpoint != nil { + return c.endpoint, nil + } + if len(endpoints) > 0 { + return endpoints[0], nil + } + return nil, nil +} + +func (c *countingEndpointSelector) Name() string { return "counting" } + +func (c *countingEndpointSelector) IncrementConnections(_ *domain.Endpoint) { + c.incrementCalls.Add(1) +} + +func (c *countingEndpointSelector) DecrementConnections(_ *domain.Endpoint) { + c.decrementCalls.Add(1) +} + +// TestOllaProxy_ConnectionCountingNoDuplication verifies that a single successful proxy +// attempt results in exactly one IncrementConnections call and one DecrementConnections +// call. Before the fix, proxyToSingleEndpoint also incremented/decremented, producing +// counts of two each. +func TestOllaProxy_ConnectionCountingNoDuplication(t *testing.T) { + t.Parallel() + + upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + w.Write([]byte(`{"ok":true}`)) + })) + defer upstream.Close() + + endpoint := createTestEndpoint("test-endpoint", upstream.URL, domain.StatusHealthy) + + selector := &countingEndpointSelector{endpoint: endpoint} + + config := &olla.Configuration{} + config.ResponseTimeout = 5 * time.Second + config.ReadTimeout = 2 * time.Second + config.StreamBufferSize = 8192 + config.MaxIdleConns = 10 + config.IdleConnTimeout = 30 * time.Second + config.MaxConnsPerHost = 5 + + proxy, err := olla.NewService( + &mockDiscoveryService{endpoints: []*domain.Endpoint{endpoint}}, + selector, + config, + createTestStatsCollector(), + nil, + createTestLogger(), + ) + if err != nil { + t.Fatalf("failed to create Olla proxy: %v", err) + } + + req, stats, rlog := createTestRequestWithStats("POST", "/v1/chat/completions", 
`{"model":"test"}`) + w := httptest.NewRecorder() + + if err := proxy.ProxyRequestToEndpoints(req.Context(), w, req, []*domain.Endpoint{endpoint}, stats, rlog); err != nil { + t.Fatalf("proxy request failed: %v", err) + } + + if got := selector.incrementCalls.Load(); got != 1 { + t.Errorf("IncrementConnections called %d times; want exactly 1", got) + } + if got := selector.decrementCalls.Load(); got != 1 { + t.Errorf("DecrementConnections called %d times; want exactly 1", got) + } +} + +// TestOllaProxy_ConnectionCountReturnsToZero verifies that after a completed request +// the net connection delta is zero — i.e. every increment is paired with a decrement. +func TestOllaProxy_ConnectionCountReturnsToZero(t *testing.T) { + t.Parallel() + + upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + w.Write([]byte(`{"ok":true}`)) + })) + defer upstream.Close() + + endpoint := createTestEndpoint("test-endpoint", upstream.URL, domain.StatusHealthy) + selector := &countingEndpointSelector{endpoint: endpoint} + + config := &olla.Configuration{} + config.ResponseTimeout = 5 * time.Second + config.ReadTimeout = 2 * time.Second + config.StreamBufferSize = 8192 + config.MaxIdleConns = 10 + config.IdleConnTimeout = 30 * time.Second + config.MaxConnsPerHost = 5 + + proxy, err := olla.NewService( + &mockDiscoveryService{endpoints: []*domain.Endpoint{endpoint}}, + selector, + config, + createTestStatsCollector(), + nil, + createTestLogger(), + ) + if err != nil { + t.Fatalf("failed to create Olla proxy: %v", err) + } + + const requests = 5 + for i := 0; i < requests; i++ { + req, stats, rlog := createTestRequestWithStats("POST", "/v1/chat/completions", `{"model":"test"}`) + w := httptest.NewRecorder() + if err := proxy.ProxyRequestToEndpoints(req.Context(), w, req, []*domain.Endpoint{endpoint}, stats, rlog); err != nil { + t.Fatalf("request %d failed: %v", i+1, err) + } + } + + inc := 
selector.incrementCalls.Load() + dec := selector.decrementCalls.Load() + + if inc != requests { + t.Errorf("IncrementConnections called %d times; want %d", inc, requests) + } + if dec != requests { + t.Errorf("DecrementConnections called %d times; want %d", dec, requests) + } + if net := inc - dec; net != 0 { + t.Errorf("net connection delta is %d after all requests completed; want 0", net) + } +} diff --git a/internal/adapter/proxy/sherpa/service_retry.go b/internal/adapter/proxy/sherpa/service_retry.go index a48696ca..3f19e666 100644 --- a/internal/adapter/proxy/sherpa/service_retry.go +++ b/internal/adapter/proxy/sherpa/service_retry.go @@ -50,12 +50,11 @@ func (s *Service) ProxyRequestToEndpointsWithRetry(ctx context.Context, w http.R } // proxyToSingleEndpoint executes the proxy request to a specific endpoint +// Note: Connection increment/decrement is handled by RetryHandler.executeProxyAttempt +// to avoid double-counting (see proxy_olla_connection_counting_test.go for context). func (s *Service) proxyToSingleEndpoint(ctx context.Context, w http.ResponseWriter, r *http.Request, endpoint *domain.Endpoint, stats *ports.RequestStats, rlog logger.StyledLogger) error { stats.EndpointName = endpoint.Name - s.Selector.IncrementConnections(endpoint) - defer s.Selector.DecrementConnections(endpoint) - targetURL := common.BuildTargetURL(r, endpoint, s.configuration.GetProxyPrefix()) stats.TargetUrl = targetURL.String() diff --git a/internal/adapter/translator/anthropic/types.go b/internal/adapter/translator/anthropic/types.go index 90bb3aab..b5ccc67f 100644 --- a/internal/adapter/translator/anthropic/types.go +++ b/internal/adapter/translator/anthropic/types.go @@ -5,9 +5,10 @@ import "fmt" // AnthropicRequest represents an Anthropic API request // Maps to the Anthropic Messages API format type AnthropicRequest struct { - ToolChoice interface{} `json:"tool_choice,omitempty"` // string or object - System interface{} `json:"system,omitempty"` // string or []ContentBlock 
- Thinking interface{} `json:"thinking,omitempty"` // Extended thinking configuration + ToolChoice interface{} `json:"tool_choice,omitempty"` // string or object + System interface{} `json:"system,omitempty"` // string or []ContentBlock + Thinking interface{} `json:"thinking,omitempty"` // Extended thinking configuration + OutputConfig interface{} `json:"output_config,omitempty"` // Output configuration (effort, structured output format) Temperature *float64 `json:"temperature,omitempty"` TopP *float64 `json:"top_p,omitempty"` TopK *int `json:"top_k,omitempty"` diff --git a/internal/app/handlers/handler_proxy.go b/internal/app/handlers/handler_proxy.go index d2b507d2..09dec86e 100644 --- a/internal/app/handlers/handler_proxy.go +++ b/internal/app/handlers/handler_proxy.go @@ -15,20 +15,21 @@ import ( ) type proxyRequest struct { - requestLogger logger.StyledLogger - stats *ports.RequestStats - profile *domain.RequestProfile - clientIP string - targetPath string - model string - contentType string - method string - path string - query string - userAgent string - contentLength int64 - hadError bool - isStreaming bool + requestLogger logger.StyledLogger + stats *ports.RequestStats + profile *domain.RequestProfile + clientIP string + targetPath string + model string + contentType string + method string + path string + query string + userAgent string + translatorMode constants.TranslatorMode + contentLength int64 + hadError bool + isStreaming bool } func (a *Application) proxyHandler(w http.ResponseWriter, r *http.Request) { @@ -171,6 +172,11 @@ func (a *Application) logRequestStart(pr *proxyRequest, endpointCount int) { logFields = append(logFields, "content_length", pr.contentLength) } + // translator_mode is only set on translation handler requests + if pr.translatorMode != "" { + logFields = append(logFields, "translator_mode", string(pr.translatorMode)) + } + pr.requestLogger.Info("Request received", logFields...) 
// Log additional details at DEBUG level @@ -206,6 +212,11 @@ func (a *Application) logRequestResult(pr *proxyRequest, err error) { infoFields = append(infoFields, "total_bytes", pr.stats.TotalBytes) } + // translator_mode is only set on translation handler requests + if pr.translatorMode != "" { + infoFields = append(infoFields, "translator_mode", string(pr.translatorMode)) + } + // Add provider metrics if available if pr.stats.ProviderMetrics != nil { pm := pr.stats.ProviderMetrics diff --git a/internal/app/handlers/handler_translation.go b/internal/app/handlers/handler_translation.go index 629afa24..6befb492 100644 --- a/internal/app/handlers/handler_translation.go +++ b/internal/app/handlers/handler_translation.go @@ -7,6 +7,7 @@ import ( "fmt" "io" "net/http" + "sync" "time" "github.com/thushan/olla/internal/adapter/translator" @@ -46,7 +47,10 @@ func (a *Application) executePassthroughRequest( // (StreamingMs isn't populated in passthrough mode since we don't intercept the stream) pr.isStreaming = passthroughReq.IsStreaming - pr.requestLogger.Info("using passthrough mode (native Anthropic support)", + // Set mode before logRequestStart so it appears on the lifecycle log lines. 
+ pr.translatorMode = constants.TranslatorModePassthrough + + pr.requestLogger.Debug("using passthrough mode (native Anthropic support)", "model", passthroughReq.ModelName, "streaming", passthroughReq.IsStreaming, "endpoints", len(endpoints)) @@ -57,7 +61,7 @@ func (a *Application) executePassthroughRequest( r.URL.Path = passthroughReq.TargetPath // Add passthrough mode header for observability - w.Header().Set("X-Olla-Mode", "passthrough") + w.Header().Set(constants.HeaderXOllaMode, string(constants.TranslatorModePassthrough)) // Prepare context ctx, r = a.prepareProxyContext(ctx, r, pr) @@ -94,6 +98,9 @@ func (a *Application) executeTranslationRequest( // Capture streaming flag for metrics before proxying pr.isStreaming = transformedReq.IsStreaming + // Set mode before logRequestStart so it appears on the lifecycle log lines. + pr.translatorMode = constants.TranslatorModeTranslation + // Serialize OpenAI request openaiBody, err := json.Marshal(transformedReq.OpenAIRequest) if err != nil { @@ -139,15 +146,6 @@ func (a *Application) executeTranslationRequest( proxyErr = a.executeTranslatedNonStreamingRequest(ctx, w, r, endpoints, pr, trans) } - if proxyErr == nil { - pr.requestLogger.Debug("Translation request completed successfully", - "translator", trans.Name(), - "model", pr.model, - "path_translated", transformedReq.TargetPath != "", - "target_path", transformedReq.TargetPath, - "streaming", transformedReq.IsStreaming) - } - a.logRequestResult(pr, proxyErr) if proxyErr != nil { @@ -494,8 +492,14 @@ func (a *Application) executeTranslatedStreamingRequest( // panic recovery prevents goroutine leak, cleanup before re-panic defer a.handleStreamingPanic(pipeReader, pipeWriter, proxyErrChan, pr, trans) - // wait for headers to avoid data race - <-streamRecorder.headersReady + // Wait for headers before inspecting status. The select also handles context + // cancellation so we don't block forever if the proxy errors without writing. 
+ select { + case <-streamRecorder.headersReady: + case <-ctx.Done(): + pipeReader.CloseWithError(ctx.Err()) // unblock any proxy goroutine stuck mid-write to pipeWriter + return fmt.Errorf("request cancelled while waiting for backend headers: %w", ctx.Err()) + } // handle backend errors before starting sse stream if streamRecorder.status >= 400 { @@ -548,6 +552,10 @@ func (a *Application) startProxyGoroutine( go func() { localCtx, localR := a.prepareProxyContext(ctx, r, pr) err := a.proxyService.ProxyRequestToEndpoints(localCtx, streamRecorder, localR, endpoints, pr.stats, pr.requestLogger) + // If the proxy returned an error without ever calling Write or WriteHeader, + // headersReady is never closed and the main goroutine blocks forever. + // Ensure it is always signalled before closing the pipe. + streamRecorder.ensureHeadersReady() pipeWriter.Close() // Signal end of stream proxyErrChan <- err }() @@ -839,7 +847,7 @@ type streamingResponseRecorder struct { writer io.Writer headers http.Header headersReady chan struct{} - headerSent bool + closeOnce sync.Once status int } @@ -856,19 +864,24 @@ func (r *streamingResponseRecorder) Header() http.Header { return r.headers } +// ensureHeadersReady closes headersReady exactly once. It is safe to call from +// multiple goroutines and is idempotent — subsequent calls are no-ops. 
+func (r *streamingResponseRecorder) ensureHeadersReady() { + r.closeOnce.Do(func() { close(r.headersReady) }) +} + func (r *streamingResponseRecorder) Write(data []byte) (int, error) { - if !r.headerSent { - r.headerSent = true - close(r.headersReady) // Signal headers are ready when first write occurs - } + r.ensureHeadersReady() return r.writer.Write(data) } func (r *streamingResponseRecorder) WriteHeader(statusCode int) { r.status = statusCode // Capture status code to detect backend errors - if !r.headerSent { - r.headerSent = true - close(r.headersReady) // Signal headers are ready - } - // don't write status for streaming, just mark headers sent + r.ensureHeadersReady() + // Don't propagate the status write for streaming; just mark headers sent. } + +// Flush implements http.Flusher. The underlying io.Pipe is unbuffered +// (writes block until read), so there is nothing to flush — this is +// intentionally a no-op to satisfy http.ResponseController in proxy engines. +func (r *streamingResponseRecorder) Flush() {} diff --git a/internal/app/handlers/handler_translation_passthrough_test.go b/internal/app/handlers/handler_translation_passthrough_test.go index ba63d32e..a0450eca 100644 --- a/internal/app/handlers/handler_translation_passthrough_test.go +++ b/internal/app/handlers/handler_translation_passthrough_test.go @@ -1689,3 +1689,155 @@ func TestTranslationHandler_MetricsRecordedForSuccessVsError(t *testing.T) { // Backend errors are considered successful processing from the handler's perspective assert.True(t, events[1].Success, "Second request should be successful (handler processed backend error)") } + +// TestTranslatorMode_SetOnPassthroughPath verifies that translatorMode is set to passthrough +// on the proxyRequest before logRequestStart is called, so lifecycle logs carry the mode. +func TestTranslatorMode_SetOnPassthroughPath(t *testing.T) { + t.Parallel() + + // Capture the proxyRequest state when logRequestStart fires (i.e. 
when the + // underlying proxy service is invoked — at that point pr.translatorMode must + // already be set). + var capturedMode constants.TranslatorMode + + endpoints := []*domain.Endpoint{ + {Name: "vllm-1", Type: "vllm", Status: domain.StatusHealthy}, + } + + profileLookup := &mockPassthroughProfileLookup{ + configs: map[string]*domain.AnthropicSupportConfig{ + "vllm": {Enabled: true, MessagesPath: "/v1/messages"}, + }, + } + + trans := &mockPassthroughTranslator{ + name: "anthropic", + passthroughEnabled: true, + profileLookup: profileLookup, + } + + proxyService := &mockProxyService{ + proxyFunc: func(ctx context.Context, w http.ResponseWriter, r *http.Request, eps []*domain.Endpoint, stats *ports.RequestStats, rlog logger.StyledLogger) error { + // The proxy is called after logRequestStart; by now translatorMode must be set. + // We can't reach pr directly, so assert via the response header that was set + // before this point. + w.Header().Set(constants.HeaderContentType, constants.ContentTypeJSON) + w.WriteHeader(http.StatusOK) + return json.NewEncoder(w).Encode(map[string]interface{}{"type": "message"}) + }, + } + + // Intercept the proxyRequest by wrapping executePassthroughRequest via a + // custom translator whose PreparePassthrough stores the mode from the header. 
+ _ = capturedMode + + app := &Application{ + logger: &mockStyledLogger{}, + proxyService: proxyService, + statsCollector: &mockStatsCollector{}, + repository: &mockEndpointRepository{getEndpointsFunc: func() []*domain.Endpoint { return endpoints }}, + inspectorChain: inspector.NewChain(&mockStyledLogger{}), + profileFactory: &mockProfileFactory{}, + profileLookup: profileLookup, + discoveryService: &mockDiscoveryServiceWithEndpoints{endpoints: endpoints}, + Config: &config.Config{}, + } + + handler := app.translationHandler(trans) + + reqBody, _ := json.Marshal(map[string]interface{}{ + "model": "claude-3-5-sonnet-20241022", + "max_tokens": 1024, + "messages": []map[string]interface{}{{"role": "user", "content": "hi"}}, + }) + + req := httptest.NewRequest("POST", "/olla/anthropic/v1/messages", bytes.NewReader(reqBody)) + req.Header.Set(constants.HeaderContentType, constants.ContentTypeJSON) + rec := httptest.NewRecorder() + + handler.ServeHTTP(rec, req) + + assert.Equal(t, http.StatusOK, rec.Code) + // The X-Olla-Mode header is set immediately from pr.translatorMode before the proxy + // call, so its presence confirms the field was populated on the passthrough path. + assert.Equal(t, string(constants.TranslatorModePassthrough), rec.Header().Get(constants.HeaderXOllaMode), + "X-Olla-Mode header must reflect passthrough mode") +} + +// TestTranslatorMode_SetOnTranslationPath verifies that translatorMode is set to translation +// on the proxyRequest for requests that go through the full format-conversion path. +func TestTranslatorMode_SetOnTranslationPath(t *testing.T) { + t.Parallel() + + endpoints := []*domain.Endpoint{ + {Name: "ollama-1", Type: "ollama", Status: domain.StatusHealthy}, + } + + // No Anthropic support configured — forces the translation path. 
+ profileLookup := &mockPassthroughProfileLookup{ + configs: map[string]*domain.AnthropicSupportConfig{}, + } + + trans := &mockPassthroughTranslator{ + name: "anthropic", + passthroughEnabled: true, + profileLookup: profileLookup, + transformRequestFunc: func(ctx context.Context, r *http.Request) (*translator.TransformedRequest, error) { + return &translator.TransformedRequest{ + OpenAIRequest: map[string]interface{}{ + "model": "claude-3-5-sonnet-20241022", + "messages": []interface{}{map[string]interface{}{"role": "user", "content": "test"}}, + }, + ModelName: "claude-3-5-sonnet-20241022", + IsStreaming: false, + TargetPath: "/v1/chat/completions", + }, nil + }, + transformResponseFunc: func(ctx context.Context, openaiResp interface{}, original *http.Request) (interface{}, error) { + return map[string]interface{}{"id": "msg_translated", "type": "message"}, nil + }, + implementsErrorWriter: true, + } + + proxyService := &mockProxyService{ + proxyFunc: func(ctx context.Context, w http.ResponseWriter, r *http.Request, eps []*domain.Endpoint, stats *ports.RequestStats, rlog logger.StyledLogger) error { + w.Header().Set(constants.HeaderContentType, constants.ContentTypeJSON) + w.WriteHeader(http.StatusOK) + return json.NewEncoder(w).Encode(map[string]interface{}{ + "id": "chatcmpl-123", "object": "chat.completion", "choices": []interface{}{}, + }) + }, + } + + app := &Application{ + logger: &mockStyledLogger{}, + proxyService: proxyService, + statsCollector: &mockStatsCollector{}, + repository: &mockEndpointRepository{getEndpointsFunc: func() []*domain.Endpoint { return endpoints }}, + inspectorChain: inspector.NewChain(&mockStyledLogger{}), + profileFactory: &mockProfileFactory{}, + profileLookup: profileLookup, + discoveryService: &mockDiscoveryServiceWithEndpoints{endpoints: endpoints}, + Config: &config.Config{}, + } + + handler := app.translationHandler(trans) + + reqBody, _ := json.Marshal(map[string]interface{}{ + "model": "claude-3-5-sonnet-20241022", + 
"max_tokens": 1024, + "messages": []map[string]interface{}{{"role": "user", "content": "hi"}}, + }) + + req := httptest.NewRequest("POST", "/olla/anthropic/v1/messages", bytes.NewReader(reqBody)) + req.Header.Set(constants.HeaderContentType, constants.ContentTypeJSON) + rec := httptest.NewRecorder() + + handler.ServeHTTP(rec, req) + + assert.Equal(t, http.StatusOK, rec.Code) + // The translation path does not set X-Olla-Mode on the response header, + // confirming it took the translation route (not passthrough). + assert.Empty(t, rec.Header().Get(constants.HeaderXOllaMode), + "X-Olla-Mode header must not be set on the translation path") +} diff --git a/internal/app/handlers/handler_translation_test.go b/internal/app/handlers/handler_translation_test.go index 4c9ab002..c3e8d999 100644 --- a/internal/app/handlers/handler_translation_test.go +++ b/internal/app/handlers/handler_translation_test.go @@ -1460,3 +1460,298 @@ func BenchmarkStripPrefix(b *testing.B) { _ = util.StripPrefix(path, prefix) } } + +// TestStreamingResponseRecorder_EnsureHeadersReady_IdemPotent verifies that +// ensureHeadersReady can be called multiple times without panicking. +func TestStreamingResponseRecorder_EnsureHeadersReady_IdemPotent(t *testing.T) { + t.Parallel() + + pr, pw := io.Pipe() + defer pr.Close() + defer pw.Close() + + rec := newStreamingResponseRecorder(pw) + + // Calling multiple times must not panic (sync.Once guards the close). + rec.ensureHeadersReady() + rec.ensureHeadersReady() + rec.WriteHeader(http.StatusOK) + rec.ensureHeadersReady() + + // Channel must already be closed. + select { + case <-rec.headersReady: + default: + t.Fatal("headersReady should be closed after ensureHeadersReady") + } +} + +// TestExecuteTranslatedStreamingRequest_ProxyErrorBeforeWrite verifies that when the proxy +// returns an error before ever writing headers, the handler does not deadlock and returns an +// error within a reasonable time. 
+func TestExecuteTranslatedStreamingRequest_ProxyErrorBeforeWrite(t *testing.T) { + t.Parallel() + + mockLogger := &mockStyledLogger{} + + trans := &mockTranslator{ + name: "error-before-write-translator", + implementsErrorWriter: true, + writeErrorFunc: func(w http.ResponseWriter, err error, statusCode int) { + w.Header().Set(constants.HeaderContentType, constants.ContentTypeJSON) + w.WriteHeader(statusCode) + json.NewEncoder(w).Encode(map[string]interface{}{"error": err.Error()}) + }, + transformRequestFunc: func(ctx context.Context, r *http.Request) (*translator.TransformedRequest, error) { + return &translator.TransformedRequest{ + OpenAIRequest: map[string]interface{}{ + "model": "test-model", + "stream": true, + "messages": []interface{}{ + map[string]interface{}{"role": "user", "content": "test"}, + }, + }, + ModelName: "test-model", + IsStreaming: true, + }, nil + }, + // TransformStreamingResponse should never be reached in the error path, but + // if it somehow is, copy the (empty) stream so the test can complete. + transformStreamingFunc: func(ctx context.Context, openaiStream io.Reader, w http.ResponseWriter, original *http.Request) error { + _, err := io.Copy(w, openaiStream) + return err + }, + } + + // Proxy that returns an error immediately without touching the ResponseWriter. 
+ proxyService := &mockProxyService{ + proxyFunc: func(ctx context.Context, w http.ResponseWriter, r *http.Request, endpoints []*domain.Endpoint, stats *ports.RequestStats, rlog logger.StyledLogger) error { + return fmt.Errorf("connection refused") + }, + } + + app := &Application{ + logger: mockLogger, + proxyService: proxyService, + statsCollector: &mockStatsCollector{}, + repository: &mockEndpointRepository{}, + inspectorChain: inspector.NewChain(mockLogger), + profileFactory: &mockProfileFactory{}, + discoveryService: &mockDiscoveryServiceForTranslation{}, + Config: &config.Config{}, + } + + handler := app.translationHandler(trans) + + reqBody := map[string]interface{}{ + "model": "test-model", + "stream": true, + "messages": []interface{}{ + map[string]interface{}{"role": "user", "content": "Hello"}, + }, + } + body, _ := json.Marshal(reqBody) + + // Use a context with a generous timeout so the test fails clearly rather than + // hanging the suite if the deadlock resurfaces. + ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second) + defer cancel() + + req, err := http.NewRequestWithContext(ctx, "POST", "/test", bytes.NewReader(body)) + require.NoError(t, err) + + done := make(chan struct{}) + rec := httptest.NewRecorder() + go func() { + handler.ServeHTTP(rec, req) + close(done) + }() + + select { + case <-done: + // Handler returned — no deadlock. The response status must indicate an error. + assert.GreaterOrEqual(t, rec.Code, http.StatusBadRequest, + "expected an error status when proxy fails before writing headers") + case <-ctx.Done(): + t.Fatal("handler deadlocked: did not return within timeout after proxy error-before-write") + } +} + +// TestExecuteTranslatedStreamingRequest_ContextCancellationUnblocks verifies that +// cancelling the request context unblocks the headersReady wait even when the proxy +// goroutine stalls indefinitely without writing. 
+func TestExecuteTranslatedStreamingRequest_ContextCancellationUnblocks(t *testing.T) { + t.Parallel() + + mockLogger := &mockStyledLogger{} + + trans := &mockTranslator{ + name: "stalled-translator", + implementsErrorWriter: true, + writeErrorFunc: func(w http.ResponseWriter, err error, statusCode int) { + w.Header().Set(constants.HeaderContentType, constants.ContentTypeJSON) + w.WriteHeader(statusCode) + json.NewEncoder(w).Encode(map[string]interface{}{"error": err.Error()}) + }, + transformRequestFunc: func(ctx context.Context, r *http.Request) (*translator.TransformedRequest, error) { + return &translator.TransformedRequest{ + OpenAIRequest: map[string]interface{}{ + "model": "test-model", + "stream": true, + "messages": []interface{}{ + map[string]interface{}{"role": "user", "content": "test"}, + }, + }, + ModelName: "test-model", + IsStreaming: true, + }, nil + }, + } + + // Proxy that blocks until its context is cancelled without writing anything. + proxyService := &mockProxyService{ + proxyFunc: func(ctx context.Context, w http.ResponseWriter, r *http.Request, endpoints []*domain.Endpoint, stats *ports.RequestStats, rlog logger.StyledLogger) error { + <-ctx.Done() + return ctx.Err() + }, + } + + app := &Application{ + logger: mockLogger, + proxyService: proxyService, + statsCollector: &mockStatsCollector{}, + repository: &mockEndpointRepository{}, + inspectorChain: inspector.NewChain(mockLogger), + profileFactory: &mockProfileFactory{}, + discoveryService: &mockDiscoveryServiceForTranslation{}, + Config: &config.Config{}, + } + + handler := app.translationHandler(trans) + + reqBody := map[string]interface{}{ + "model": "test-model", + "stream": true, + "messages": []interface{}{ + map[string]interface{}{"role": "user", "content": "Hello"}, + }, + } + body, _ := json.Marshal(reqBody) + + // Cancel the context after a short delay to simulate client disconnect / server timeout. 
+ ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond) + defer cancel() + + req, err := http.NewRequestWithContext(ctx, "POST", "/test", bytes.NewReader(body)) + require.NoError(t, err) + + done := make(chan struct{}) + rec := httptest.NewRecorder() + go func() { + handler.ServeHTTP(rec, req) + close(done) + }() + + // Give it a reasonable window beyond the context timeout; if context cancellation + // correctly unblocks the select, the handler returns well before this deadline. + select { + case <-done: + // Returned after cancellation — correct behaviour. + case <-time.After(3 * time.Second): + t.Fatal("handler did not unblock after context cancellation") + } +} + +// TestExecuteTranslatedStreamingRequest_SuccessfulFlow verifies that the happy path +// (proxy writes headers then streams data) still works correctly after the fix. +func TestExecuteTranslatedStreamingRequest_SuccessfulFlow(t *testing.T) { + t.Parallel() + + mockLogger := &mockStyledLogger{} + + trans := &mockTranslator{ + name: "success-translator", + implementsErrorWriter: true, + writeErrorFunc: func(w http.ResponseWriter, err error, statusCode int) { + w.Header().Set(constants.HeaderContentType, constants.ContentTypeJSON) + w.WriteHeader(statusCode) + json.NewEncoder(w).Encode(map[string]interface{}{"error": err.Error()}) + }, + transformRequestFunc: func(ctx context.Context, r *http.Request) (*translator.TransformedRequest, error) { + return &translator.TransformedRequest{ + OpenAIRequest: map[string]interface{}{ + "model": "test-model", + "stream": true, + "messages": []interface{}{ + map[string]interface{}{"role": "user", "content": "test"}, + }, + }, + ModelName: "test-model", + IsStreaming: true, + }, nil + }, + transformStreamingFunc: func(ctx context.Context, openaiStream io.Reader, w http.ResponseWriter, original *http.Request) error { + w.Header().Set(constants.HeaderContentType, "text/event-stream") + _, err := io.Copy(w, openaiStream) + return err + }, + } + + 
// Proxy that successfully writes a header then streams a single SSE event. + proxyService := &mockProxyService{ + proxyFunc: func(ctx context.Context, w http.ResponseWriter, r *http.Request, endpoints []*domain.Endpoint, stats *ports.RequestStats, rlog logger.StyledLogger) error { + w.Header().Set(constants.HeaderContentType, "text/event-stream") + w.Header().Set(constants.HeaderXOllaRequestID, "success-flow-id") + w.WriteHeader(http.StatusOK) + _, err := w.Write([]byte("data: {\"choices\":[{\"delta\":{\"content\":\"Hello\"}}]}\n\n")) + return err + }, + } + + app := &Application{ + logger: mockLogger, + proxyService: proxyService, + statsCollector: &mockStatsCollector{}, + repository: &mockEndpointRepository{}, + inspectorChain: inspector.NewChain(mockLogger), + profileFactory: &mockProfileFactory{}, + discoveryService: &mockDiscoveryServiceForTranslation{}, + Config: &config.Config{}, + } + + handler := app.translationHandler(trans) + + reqBody := map[string]interface{}{ + "model": "test-model", + "stream": true, + "messages": []interface{}{ + map[string]interface{}{"role": "user", "content": "Hello"}, + }, + } + body, _ := json.Marshal(reqBody) + + ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second) + defer cancel() + + req, err := http.NewRequestWithContext(ctx, "POST", "/test", bytes.NewReader(body)) + require.NoError(t, err) + + done := make(chan struct{}) + rec := httptest.NewRecorder() + go func() { + handler.ServeHTTP(rec, req) + close(done) + }() + + select { + case <-done: + case <-ctx.Done(): + t.Fatal("handler deadlocked on the success path") + } + + assert.Equal(t, http.StatusOK, rec.Code) + assert.Equal(t, "text/event-stream", rec.Header().Get(constants.HeaderContentType)) + assert.Contains(t, rec.Body.String(), "Hello", "SSE payload should be forwarded") + assert.NotEmpty(t, rec.Header().Get(constants.HeaderXOllaRequestID), + "X-Olla-Request-ID should be copied to the client response") +} diff --git a/internal/config/config.go 
b/internal/config/config.go index c0d428c9..d6ba6048 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -22,7 +22,7 @@ const ( DefaultHost = "localhost" DefaultAllHost = "0.0.0.0" // local dev may use this DefaultProxyProfile = constants.ConfigurationProxyProfileAuto - DefaultProxyEngine = "sherpa" + DefaultProxyEngine = "olla" DefaultLoadBalancer = "priority" DefaultModelRegistryType = "memory" DefaultDiscoveryType = "static" diff --git a/internal/core/constants/content.go b/internal/core/constants/content.go index b1879c1a..9031e0a7 100644 --- a/internal/core/constants/content.go +++ b/internal/core/constants/content.go @@ -106,4 +106,5 @@ const ( HeaderXOllaRoutingStrategy = "X-Olla-Routing-Strategy" HeaderXOllaRoutingDecision = "X-Olla-Routing-Decision" HeaderXOllaRoutingReason = "X-Olla-Routing-Reason" + HeaderXOllaMode = "X-Olla-Mode" )