Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 17 additions & 7 deletions internal/adapter/proxy/config/unified.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,10 +14,11 @@ const (
DefaultKeepAlive = 60 * time.Second

// Olla-specific defaults for high-performance
OllaDefaultStreamBufferSize = 64 * 1024 // Larger buffer for better streaming performance
OllaDefaultMaxIdleConns = 100
OllaDefaultMaxConnsPerHost = 50
OllaDefaultIdleConnTimeout = 90 * time.Second
OllaDefaultStreamBufferSize = 64 * 1024 // Larger buffer for better streaming performance
OllaDefaultMaxIdleConns = 100
OllaDefaultMaxConnsPerHost = 50
OllaDefaultMaxIdleConnsPerHost = 25 // Half of MaxConnsPerHost; idle slots rarely need to match total capacity
OllaDefaultIdleConnTimeout = 90 * time.Second
// Olla uses 30s timeouts for faster failure detection in AI workloads
OllaDefaultTimeout = 30 * time.Second
OllaDefaultKeepAlive = 30 * time.Second
Expand Down Expand Up @@ -118,9 +119,10 @@ type OllaConfig struct {
BaseProxyConfig

// Olla-specific fields for advanced connection pooling
IdleConnTimeout time.Duration
MaxIdleConns int
MaxConnsPerHost int
IdleConnTimeout time.Duration
MaxIdleConns int
MaxConnsPerHost int
MaxIdleConnsPerHost int
}

// GetStreamBufferSize returns the stream buffer size, defaulting to OllaDefaultStreamBufferSize for better performance
Expand Down Expand Up @@ -155,6 +157,14 @@ func (c *OllaConfig) GetMaxConnsPerHost() int {
return c.MaxConnsPerHost
}

// GetMaxIdleConnsPerHost reports the maximum number of idle connections kept
// per upstream host. An unset (zero) value falls back to
// OllaDefaultMaxIdleConnsPerHost.
func (c *OllaConfig) GetMaxIdleConnsPerHost() int {
	if v := c.MaxIdleConnsPerHost; v != 0 {
		return v
	}
	return OllaDefaultMaxIdleConnsPerHost
}

// GetConnectionTimeout returns the connection timeout, defaulting to OllaDefaultTimeout (30s for faster failure detection)
func (c *OllaConfig) GetConnectionTimeout() time.Duration {
if c.ConnectionTimeout == 0 {
Expand Down
8 changes: 7 additions & 1 deletion internal/adapter/proxy/olla/service.go
Original file line number Diff line number Diff line change
Expand Up @@ -150,6 +150,9 @@ func NewService(
if configuration.MaxConnsPerHost == 0 {
configuration.MaxConnsPerHost = config.OllaDefaultMaxConnsPerHost
}
if configuration.MaxIdleConnsPerHost == 0 {
configuration.MaxIdleConnsPerHost = config.OllaDefaultMaxIdleConnsPerHost
}
if configuration.IdleConnTimeout == 0 {
configuration.IdleConnTimeout = config.OllaDefaultIdleConnTimeout
}
Expand Down Expand Up @@ -215,7 +218,8 @@ func NewService(
// createOptimisedTransport builds the HTTP transport used for upstream proxy
// requests. Note the two distinct per-host limits: MaxConnsPerHost caps total
// concurrent connections to a single upstream, while MaxIdleConnsPerHost caps
// only the idle (keep-alive) connections retained for reuse — they are mapped
// from separate Configuration fields and must not be conflated.
func createOptimisedTransport(config *Configuration) *http.Transport {
	return &http.Transport{
		MaxIdleConns:        config.MaxIdleConns,
		MaxIdleConnsPerHost: config.MaxIdleConnsPerHost,
		MaxConnsPerHost:     config.MaxConnsPerHost,
		IdleConnTimeout:     config.IdleConnTimeout,
		TLSHandshakeTimeout: DefaultTLSHandshakeTimeout,
		// Compression disabled — presumably to avoid per-chunk overhead on
		// streamed responses; confirm before changing.
		DisableCompression: true,
Expand Down Expand Up @@ -693,11 +697,13 @@ func (s *Service) UpdateConfig(config ports.ProxyConfiguration) {
newConfig.MaxIdleConns = ollaConfig.MaxIdleConns
newConfig.IdleConnTimeout = ollaConfig.IdleConnTimeout
newConfig.MaxConnsPerHost = ollaConfig.MaxConnsPerHost
newConfig.MaxIdleConnsPerHost = ollaConfig.MaxIdleConnsPerHost
} else {
// fallback: preserve current Olla-specific settings for non-Olla configs
newConfig.MaxIdleConns = s.configuration.MaxIdleConns
newConfig.IdleConnTimeout = s.configuration.IdleConnTimeout
newConfig.MaxConnsPerHost = s.configuration.MaxConnsPerHost
newConfig.MaxIdleConnsPerHost = s.configuration.MaxIdleConnsPerHost
}

// Update configuration atomically
Expand Down
3 changes: 0 additions & 3 deletions internal/adapter/proxy/olla/service_retry.go
Original file line number Diff line number Diff line change
Expand Up @@ -61,9 +61,6 @@ func (s *Service) proxyToSingleEndpoint(ctx context.Context, w http.ResponseWrit
return fmt.Errorf("circuit breaker open for endpoint %s", endpoint.Name)
}

s.Selector.IncrementConnections(endpoint)
defer s.Selector.DecrementConnections(endpoint)

// Build target URL using common function that respects preserve_path
targetURL := common.BuildTargetURL(r, endpoint, s.configuration.GetProxyPrefix())
stats.TargetUrl = targetURL.String()
Expand Down
84 changes: 84 additions & 0 deletions internal/adapter/proxy/olla/service_transport_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,84 @@
package olla

import (
"testing"
"time"

"github.com/thushan/olla/internal/adapter/proxy/config"
)

// TestCreateOptimisedTransport_ConnectionLimits verifies that both MaxConnsPerHost and
// MaxIdleConnsPerHost are mapped to their correct fields on http.Transport.
// Previously MaxConnsPerHost was mistakenly written to MaxIdleConnsPerHost and
// MaxConnsPerHost was never set (defaulting to 0 = unlimited).
func TestCreateOptimisedTransport_ConnectionLimits(t *testing.T) {
	t.Parallel()

	var c Configuration
	c.MaxConnsPerHost = 42
	c.MaxIdleConnsPerHost = 17
	c.MaxIdleConns = 200
	c.IdleConnTimeout = 90 * time.Second

	tr := createOptimisedTransport(&c)

	if got := tr.MaxConnsPerHost; got != 42 {
		t.Errorf("MaxConnsPerHost: want 42, got %d", got)
	}
	if got := tr.MaxIdleConnsPerHost; got != 17 {
		t.Errorf("MaxIdleConnsPerHost: want 17, got %d", got)
	}
	if got := tr.MaxIdleConns; got != 200 {
		t.Errorf("MaxIdleConns: want 200, got %d", got)
	}
}

// TestCreateOptimisedTransport_DefaultsApplied verifies that a configuration
// populated with the package-level Olla default constants maps each of them
// onto http.Transport unchanged, so the documented defaults are what actually
// take effect on the wire. (NewService is responsible for filling these
// defaults into a zero-value config; that wiring is exercised separately.)
func TestCreateOptimisedTransport_DefaultsApplied(t *testing.T) {
	t.Parallel()

	cfg := &Configuration{}
	cfg.MaxConnsPerHost = config.OllaDefaultMaxConnsPerHost
	cfg.MaxIdleConnsPerHost = config.OllaDefaultMaxIdleConnsPerHost
	cfg.MaxIdleConns = config.OllaDefaultMaxIdleConns
	cfg.IdleConnTimeout = config.OllaDefaultIdleConnTimeout

	transport := createOptimisedTransport(cfg)

	if transport.MaxConnsPerHost != config.OllaDefaultMaxConnsPerHost {
		t.Errorf("MaxConnsPerHost: want %d, got %d", config.OllaDefaultMaxConnsPerHost, transport.MaxConnsPerHost)
	}
	if transport.MaxIdleConnsPerHost != config.OllaDefaultMaxIdleConnsPerHost {
		t.Errorf("MaxIdleConnsPerHost: want %d, got %d", config.OllaDefaultMaxIdleConnsPerHost, transport.MaxIdleConnsPerHost)
	}
	// Every default we set should be asserted; previously MaxIdleConns and
	// IdleConnTimeout were configured but never checked.
	if transport.MaxIdleConns != config.OllaDefaultMaxIdleConns {
		t.Errorf("MaxIdleConns: want %d, got %d", config.OllaDefaultMaxIdleConns, transport.MaxIdleConns)
	}
	if transport.IdleConnTimeout != config.OllaDefaultIdleConnTimeout {
		t.Errorf("IdleConnTimeout: want %v, got %v", config.OllaDefaultIdleConnTimeout, transport.IdleConnTimeout)
	}
}

// TestCreateOptimisedTransport_FieldsAreDistinct guards against the specific regression
// where the MaxConnsPerHost value bled into MaxIdleConnsPerHost. Deliberately distinct
// input values make any field-mapping mix-up immediately visible.
func TestCreateOptimisedTransport_FieldsAreDistinct(t *testing.T) {
	t.Parallel()

	var c Configuration
	c.MaxConnsPerHost = 100
	c.MaxIdleConnsPerHost = 10
	c.MaxIdleConns = 500

	tr := createOptimisedTransport(&c)

	// Regression guard: if the bug is reintroduced both fields get value 100.
	if tr.MaxConnsPerHost == tr.MaxIdleConnsPerHost {
		t.Errorf("MaxConnsPerHost (%d) and MaxIdleConnsPerHost (%d) are equal — likely a field mapping regression",
			tr.MaxConnsPerHost, tr.MaxIdleConnsPerHost)
	}
	if got := tr.MaxConnsPerHost; got != 100 {
		t.Errorf("MaxConnsPerHost: want 100, got %d", got)
	}
	if got := tr.MaxIdleConnsPerHost; got != 10 {
		t.Errorf("MaxIdleConnsPerHost: want 10, got %d", got)
	}
}
150 changes: 150 additions & 0 deletions internal/adapter/proxy/proxy_olla_connection_counting_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,150 @@
package proxy

import (
	"context"
	"errors"
	"net/http"
	"net/http/httptest"
	"sync/atomic"
	"testing"
	"time"

	"github.com/thushan/olla/internal/adapter/proxy/olla"
	"github.com/thushan/olla/internal/core/domain"
)

// countingEndpointSelector tracks the number of Increment and Decrement calls
// using atomic counters so the test is safe under concurrent execution.
type countingEndpointSelector struct {
	incrementCalls atomic.Int64 // total IncrementConnections invocations
	decrementCalls atomic.Int64 // total DecrementConnections invocations
	// When non-nil, Select always returns this endpoint, ignoring candidates.
	endpoint *domain.Endpoint
}

// Select returns the pinned endpoint when one is configured, otherwise the
// first endpoint in the candidate list. An empty candidate list is reported
// as an error: the previous (nil, nil) return was a latent nil-dereference
// for any caller that trusted a nil error.
func (c *countingEndpointSelector) Select(_ context.Context, endpoints []*domain.Endpoint) (*domain.Endpoint, error) {
	if c.endpoint != nil {
		return c.endpoint, nil
	}
	if len(endpoints) > 0 {
		return endpoints[0], nil
	}
	return nil, errors.New("countingEndpointSelector: no endpoints available")
}

// Name returns the selector's identifier.
func (c *countingEndpointSelector) Name() string { return "counting" }

// IncrementConnections records one increment call; the endpoint argument is ignored.
func (c *countingEndpointSelector) IncrementConnections(_ *domain.Endpoint) {
	c.incrementCalls.Add(1)
}

// DecrementConnections records one decrement call; the endpoint argument is ignored.
func (c *countingEndpointSelector) DecrementConnections(_ *domain.Endpoint) {
	c.decrementCalls.Add(1)
}

// TestOllaProxy_ConnectionCountingNoDuplication verifies that a single successful proxy
// attempt results in exactly one IncrementConnections call and one DecrementConnections
// call. Before the fix, proxyToSingleEndpoint also incremented/decremented, producing
// counts of two each.
func TestOllaProxy_ConnectionCountingNoDuplication(t *testing.T) {
	t.Parallel()

	upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
		w.WriteHeader(http.StatusOK)
		w.Write([]byte(`{"ok":true}`))
	}))
	defer upstream.Close()

	endpoint := createTestEndpoint("test-endpoint", upstream.URL, domain.StatusHealthy)
	selector := &countingEndpointSelector{endpoint: endpoint}

	// Named cfg (not config) to avoid shadowing the config package name used
	// elsewhere in this package's tests.
	var cfg olla.Configuration
	cfg.ResponseTimeout = 5 * time.Second
	cfg.ReadTimeout = 2 * time.Second
	cfg.StreamBufferSize = 8192
	cfg.MaxIdleConns = 10
	cfg.IdleConnTimeout = 30 * time.Second
	cfg.MaxConnsPerHost = 5

	svc, err := olla.NewService(
		&mockDiscoveryService{endpoints: []*domain.Endpoint{endpoint}},
		selector,
		&cfg,
		createTestStatsCollector(),
		nil,
		createTestLogger(),
	)
	if err != nil {
		t.Fatalf("failed to create Olla proxy: %v", err)
	}

	req, stats, rlog := createTestRequestWithStats("POST", "/v1/chat/completions", `{"model":"test"}`)
	rec := httptest.NewRecorder()

	if err := svc.ProxyRequestToEndpoints(req.Context(), rec, req, []*domain.Endpoint{endpoint}, stats, rlog); err != nil {
		t.Fatalf("proxy request failed: %v", err)
	}

	if got := selector.incrementCalls.Load(); got != 1 {
		t.Errorf("IncrementConnections called %d times; want exactly 1", got)
	}
	if got := selector.decrementCalls.Load(); got != 1 {
		t.Errorf("DecrementConnections called %d times; want exactly 1", got)
	}
}

// TestOllaProxy_ConnectionCountReturnsToZero verifies that after a completed request
// the net connection delta is zero — i.e. every increment is paired with a decrement.
func TestOllaProxy_ConnectionCountReturnsToZero(t *testing.T) {
	t.Parallel()

	upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
		w.WriteHeader(http.StatusOK)
		w.Write([]byte(`{"ok":true}`))
	}))
	defer upstream.Close()

	endpoint := createTestEndpoint("test-endpoint", upstream.URL, domain.StatusHealthy)
	selector := &countingEndpointSelector{endpoint: endpoint}

	var cfg olla.Configuration
	cfg.ResponseTimeout = 5 * time.Second
	cfg.ReadTimeout = 2 * time.Second
	cfg.StreamBufferSize = 8192
	cfg.MaxIdleConns = 10
	cfg.IdleConnTimeout = 30 * time.Second
	cfg.MaxConnsPerHost = 5

	svc, err := olla.NewService(
		&mockDiscoveryService{endpoints: []*domain.Endpoint{endpoint}},
		selector,
		&cfg,
		createTestStatsCollector(),
		nil,
		createTestLogger(),
	)
	if err != nil {
		t.Fatalf("failed to create Olla proxy: %v", err)
	}

	const requests = 5
	for i := 1; i <= requests; i++ {
		req, stats, rlog := createTestRequestWithStats("POST", "/v1/chat/completions", `{"model":"test"}`)
		rec := httptest.NewRecorder()
		if err := svc.ProxyRequestToEndpoints(req.Context(), rec, req, []*domain.Endpoint{endpoint}, stats, rlog); err != nil {
			t.Fatalf("request %d failed: %v", i, err)
		}
	}

	inc := selector.incrementCalls.Load()
	dec := selector.decrementCalls.Load()

	if inc != requests {
		t.Errorf("IncrementConnections called %d times; want %d", inc, requests)
	}
	if dec != requests {
		t.Errorf("DecrementConnections called %d times; want %d", dec, requests)
	}
	if net := inc - dec; net != 0 {
		t.Errorf("net connection delta is %d after all requests completed; want 0", net)
	}
}
5 changes: 2 additions & 3 deletions internal/adapter/proxy/sherpa/service_retry.go
Original file line number Diff line number Diff line change
Expand Up @@ -50,12 +50,11 @@ func (s *Service) ProxyRequestToEndpointsWithRetry(ctx context.Context, w http.R
}

// proxyToSingleEndpoint executes the proxy request to a specific endpoint
// Note: Connection increment/decrement is handled by RetryHandler.executeProxyAttempt
// to avoid double-counting (see proxy_olla_connection_counting_test.go for context).
func (s *Service) proxyToSingleEndpoint(ctx context.Context, w http.ResponseWriter, r *http.Request, endpoint *domain.Endpoint, stats *ports.RequestStats, rlog logger.StyledLogger) error {
stats.EndpointName = endpoint.Name

s.Selector.IncrementConnections(endpoint)
defer s.Selector.DecrementConnections(endpoint)

targetURL := common.BuildTargetURL(r, endpoint, s.configuration.GetProxyPrefix())

stats.TargetUrl = targetURL.String()
Expand Down
7 changes: 4 additions & 3 deletions internal/adapter/translator/anthropic/types.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,10 @@ import "fmt"
// AnthropicRequest represents an Anthropic API request
// Maps to the Anthropic Messages API format
type AnthropicRequest struct {
ToolChoice interface{} `json:"tool_choice,omitempty"` // string or object
System interface{} `json:"system,omitempty"` // string or []ContentBlock
Thinking interface{} `json:"thinking,omitempty"` // Extended thinking configuration
ToolChoice interface{} `json:"tool_choice,omitempty"` // string or object
System interface{} `json:"system,omitempty"` // string or []ContentBlock
Thinking interface{} `json:"thinking,omitempty"` // Extended thinking configuration
OutputConfig interface{} `json:"output_config,omitempty"` // Output configuration (effort, structured output format)
Temperature *float64 `json:"temperature,omitempty"`
TopP *float64 `json:"top_p,omitempty"`
TopK *int `json:"top_k,omitempty"`
Expand Down
Loading
Loading