diff --git a/CLAUDE.md b/CLAUDE.md
index 3691674..98b9db2 100644
--- a/CLAUDE.md
+++ b/CLAUDE.md
@@ -114,6 +114,8 @@ olla/
- `internal/adapter/translator/types.go` - PassthroughCapable interface and translator types
- `internal/adapter/translator/anthropic/` - Anthropic translator implementation
- `internal/adapter/stats/translator_collector.go` - Translator metrics collector
+- `internal/adapter/balancer/sticky.go` - Sticky session wrapper
+- `internal/app/handlers/handler_stats_sticky.go` - Sticky session stats endpoint
- `internal/core/constants/translator.go` - TranslatorMode and FallbackReason constants
- `internal/core/ports/stats.go` - StatsCollector interface with translator tracking
- `internal/core/domain/profile_config.go` - AnthropicSupportConfig for backend profiles
@@ -130,6 +132,7 @@ olla/
- `/internal/status/models` - Models status details
- `/internal/stats/models` - Model statistics
- `/internal/stats/translators` - Translator statistics
+- `/internal/stats/sticky` - Sticky session statistics (returns `{"enabled":false}` when disabled)
- `/internal/process` - Process statistics
- `/version` - Version information
@@ -158,6 +161,9 @@ Dynamically registered based on configured translators (e.g., Anthropic Messages
- `X-Olla-Routing-Strategy`: Routing strategy used (when model routing is active)
- `X-Olla-Routing-Decision`: Routing decision made (routed/fallback/rejected)
- `X-Olla-Routing-Reason`: Human-readable reason for routing decision
+- `X-Olla-Sticky-Session`: Sticky session status (hit/miss/repin/disabled)
+- `X-Olla-Sticky-Key-Source`: Key source used (session_header/prefix_hash/auth_header/ip/none)
+- `X-Olla-Session-ID`: Echoed session ID when client supplies one

## Testing

@@ -201,6 +207,7 @@ Always run `make ready` before committing changes.
- **Translator Layer**: Enables API format translation (e.g., OpenAI ↔ Anthropic) with passthrough optimisation for backends with native support
- **Passthrough Mode**: When a backend natively supports the Anthropic Messages API (vLLM, llama.cpp, LM Studio, Ollama), requests bypass translation entirely
- **Translator Metrics**: Thread-safe per-translator statistics tracking passthrough/translation rates, fallback reasons, latency, and streaming breakdown (`internal/adapter/stats/translator_collector.go`)
+- **Sticky Sessions**: Optional decorator on the endpoint selector that pins multi-turn LLM conversations to the backend that handled the first turn, maximising KV-cache reuse. FNV-64a hashed keys, TTL + LRU bounded, purged on routable→non-routable health transitions (`internal/adapter/balancer/sticky.go`)
- **Proxy Engines**: Choose Sherpa (simple) or Olla (high-performance)
- **Load Balancing**: Priority-based recommended for production
- **Version Management**: Build-time version injection via `internal/version`
@@ -212,6 +219,20 @@ Always run `make ready` before committing changes.
- Always run `make ready` before committing
- Use `make help` to see all available commands

+## Dependencies (Endorsed)
+
+```go
+"github.com/docker/go-units" // Human-readable sizes
+"github.com/json-iterator/go" // High-performance JSON encoding/decoding
+"github.com/puzpuzpuz/xsync/v4" // Concurrent maps/counters
+"github.com/tidwall/gjson" // Fast JSON parsing
+"github.com/jellydator/ttlcache/v3" // Time-to-live cache
+"golang.org/x/sync" // errgroup
+"golang.org/x/time" // rate limiting
+```
+
+Do not add additional dependencies unless explicitly asked.
+
## SUB-AGENT DELEGATION

CRITICAL: Always delegate tasks to the appropriate subagent.
Do NOT perform work directly in the main context. diff --git a/config/config.yaml b/config/config.yaml index c6a3294..7a4f767 100644 --- a/config/config.yaml +++ b/config/config.yaml @@ -41,6 +41,20 @@ proxy: response_timeout: 900s read_timeout: 600s + # KV-cache affinity routing (opt-in) + # Routes repeat turns in a conversation to the same backend, maximising KV-cache + # hit rates for long multi-turn sessions at the cost of reduced load distribution. + sticky_sessions: + enabled: false # opt-in: route same-prefix conversations to same backend + idle_ttl_seconds: 600 # 10-min sliding window (refreshed on each matched request) + max_sessions: 10000 # LRU evicts oldest sessions when full + key_sources: # tried in order; first match wins + - "session_header" # X-Olla-Session-ID header (explicit client opt-in) + - "prefix_hash" # FNV-1a hash of first N bytes of messages JSON (best cache locality) + - "auth_header" # Authorization header hash (per-user affinity) + # - "ip" # client IP (opt-in; unreliable behind NAT/Docker) + prefix_hash_bytes: 512 # how many leading bytes of the messages field to hash + # DEPRECATED as of v0.0.16 - These fields are no longer used # max_retries: 3 # Replaced by retry.max_attempts # retry_backoff: 500ms # Now uses intelligent exponential backoff diff --git a/docs/content/concepts/load-balancing.md b/docs/content/concepts/load-balancing.md index 2b68bce..11592ec 100644 --- a/docs/content/concepts/load-balancing.md +++ b/docs/content/concepts/load-balancing.md @@ -15,6 +15,8 @@ Olla provides multiple load balancing strategies to distribute requests across backend endpoints efficiently. Each strategy has specific use cases and characteristics to optimise for different deployment scenarios. +For multi-turn LLM workloads, combine any strategy with [Sticky Sessions](sticky-sessions.md) to preserve KV-cache affinity across turns. + ## Overview Load balancing determines which backend endpoint receives each incoming request. The strategy you choose affects: diff --git a/docs/content/concepts/sticky-sessions.md b/docs/content/concepts/sticky-sessions.md new file mode 100644 index 0000000..00abd3d --- /dev/null +++ b/docs/content/concepts/sticky-sessions.md @@ -0,0 +1,219 @@ +# Sticky Sessions + +> :memo: **Default Configuration** +> ```yaml +> proxy: +> sticky_sessions: +> enabled: false +> ``` +> Sticky sessions are **opt-in**. Set `enabled: true` under `proxy.sticky_sessions` to activate KV-cache affinity routing. +> +> **Environment Variable**: `OLLA_PROXY_STICKY_SESSIONS_ENABLED` + +Sticky sessions route repeat turns in a multi-turn conversation to the same backend endpoint, maximising KV-cache reuse across turns. The feature wraps the configured load balancer as a decorator. The underlying strategy (priority, round-robin, least-connections) is unchanged for new sessions and fallback cases. + +## Why sticky sessions + +Modern LLM backends maintain a KV-cache for the token sequence they have already processed. When the next turn of a conversation lands on the **same** backend, the backend can skip re-ingesting the full context and jump straight to generating new tokens. For long conversations this produces a substantial reduction in both time-to-first-token and compute cost; the benefit scales with context length. + +Without affinity, a load balancer may distribute turn N and turn N+1 to different backends. The receiving backend for turn N+1 has a cold cache and must process the entire prompt from scratch. 
For workloads with short prompts or single-turn completions this overhead is negligible; for chat-style applications with growing context it compounds with every turn. + +## How it works + +On each request, Olla computes a **session key** from one of the configured key sources (see [Key sources](#key-sources) below). The key is looked up in an in-memory LRU/TTL store: + +- **Hit**: the pinned backend is still routable; the request is sent there and the TTL is refreshed. +- **Miss**: no entry exists; the request is forwarded to the underlying balancer, and the selected backend is stored. +- **Repin**: an entry exists but the pinned backend is no longer routable; the underlying balancer selects a new backend and the entry is overwritten. +- **Disabled**: no key source produced a usable key (e.g. no `X-Olla-Session-ID` header and no other sources configured); the request is passed through to the underlying balancer without recording anything. + +```mermaid +sequenceDiagram + participant C as Client + participant O as Olla + participant Store as Session Store + participant B as Backend + + C->>O: POST /olla/proxy/... (turn N+1) + O->>O: Derive session key + O->>Store: Lookup key + alt Cache hit (backend routable) + Store-->>O: Pinned backend URL + O->>B: Forward to pinned backend + B-->>O: Response + O-->>C: Response + X-Olla-Sticky-Session: hit + else Cache miss / repin + Store-->>O: No entry (or dead backend) + O->>O: Delegate to underlying balancer + O->>Store: Store key → selected backend + O->>B: Forward to selected backend + B-->>O: Response + O-->>C: Response + X-Olla-Sticky-Session: miss + end +``` + +## Key sources + +The `key_sources` list is evaluated in order; the first source that produces a non-empty value wins. All keys are scoped to the model name so the same client talking to different models maintains independent session state. + +| Source | How the key is derived | When to prefer it | Caveats | +|---|---|---|---| +| `session_header` | FNV-64a hash of the `X-Olla-Session-ID` request header | Explicit client opt-in; most reliable | Client must send the header consistently | +| `prefix_hash` | FNV-64a hash of the first `prefix_hash_bytes` bytes of the `messages` JSON field | No client changes needed; best cache locality | Two conversations with identical opening messages share a session | +| `auth_header` | FNV-64a hash of the `Authorization` header value | Per-user affinity without client changes | Breaks if the token rotates mid-conversation; unreliable with shared tokens | +| `ip` | Client IP address (extracted via `net.SplitHostPort`) | Simple deployments with no NAT | Unreliable behind NAT, load balancers, or Docker networking | + +All header and token values are hashed before storage; plaintext secrets are never written to the session store. + +The default configuration enables `session_header`, `prefix_hash`, and `auth_header` (in that order) and comments out `ip` because it is unreliable behind typical container networking. Adjust the list to suit your deployment. + +## Session lifecycle and eviction + +Sessions do not live forever. Three mechanisms remove them: + +**Sliding TTL**: every cache hit refreshes the expiry timer. A session that goes idle for longer than `idle_ttl_seconds` is expired automatically. Active conversations are never interrupted mid-session. + +**LRU eviction**: when the store reaches `max_sessions`, the least-recently-used entry is evicted to make room. Under normal load this should never occur; it acts as a safety cap to bound memory usage. 
+ +**Health-based purge**: when the health checker transitions a backend to an unhealthy state, Olla immediately calls `PurgeDeadEndpoints` with the current routable set. Any session entry pointing to the now-dead backend is deleted without waiting for TTL. The next request for that session falls through to the underlying balancer and receives a `repin`. + +!!! note "Busy endpoints are not purged" + A backend in the **Busy** state is still considered routable (`IsRoutable() == true`). Sticky sessions are preserved through Busy transitions; the backend is overloaded but still serving. Only transitions to Unhealthy, Offline, or Unknown trigger a purge. + +```mermaid +stateDiagram-v2 + [*] --> Active: First request (miss) + Active --> Active: Subsequent requests (hit, TTL refreshed) + Active --> Expired: Idle longer than idle_ttl_seconds + Active --> Purged: Backend becomes unhealthy + Active --> Evicted: LRU cap (max_sessions) reached + Expired --> [*] + Purged --> [*] + Evicted --> [*] +``` + +## Response headers + +Olla writes three response headers so clients and operators can observe affinity decisions: + +| Header | Values | Meaning | +|---|---|---| +| `X-Olla-Sticky-Session` | `hit` / `miss` / `repin` / `disabled` | Outcome of the affinity lookup for this request | +| `X-Olla-Sticky-Key-Source` | `session_header` / `prefix_hash` / `auth_header` / `ip` / `none` | Which key source was used (absent when outcome is `disabled`) | +| `X-Olla-Session-ID` | _(echoed from request)_ | Present in the response only when the client sent `X-Olla-Session-ID`; lets stateless clients confirm the header was received | + +Example: first request (miss), client provides explicit session ID: + +```bash +curl -i -X POST http://localhost:40114/olla/proxy/api/chat \ + -H "X-Olla-Session-ID: conv-abc123" \ + -H "Content-Type: application/json" \ + -d '{"model":"llama3.2","messages":[{"role":"user","content":"Hello"}]}' +``` + +```http +HTTP/1.1 200 OK +X-Olla-Endpoint: gpu-server-1 +X-Olla-Sticky-Session: miss +X-Olla-Sticky-Key-Source: session_header +X-Olla-Session-ID: conv-abc123 +``` + +Subsequent request (hit): + +```http +HTTP/1.1 200 OK +X-Olla-Endpoint: gpu-server-1 +X-Olla-Sticky-Session: hit +X-Olla-Sticky-Key-Source: session_header +X-Olla-Session-ID: conv-abc123 +``` + +## Configuration + +All fields live under `proxy.sticky_sessions`: + +```yaml +proxy: + sticky_sessions: + enabled: false # opt-in: set true to activate affinity routing + + idle_ttl_seconds: 600 # sliding TTL in seconds; 0 = sessions never expire by TTL + # (not recommended, sessions accumulate until LRU eviction) + + max_sessions: 10000 # LRU capacity; oldest entries are evicted when full + + key_sources: # ordered cascade, first match wins + - "session_header" # X-Olla-Session-ID header (explicit client opt-in) + - "prefix_hash" # hash of first N bytes of messages JSON + - "auth_header" # hash of Authorization header (per-user affinity) + # - "ip" # client IP, opt-in; unreliable behind NAT/Docker + + prefix_hash_bytes: 512 # bytes of the messages field to hash for prefix_hash source; + # larger values reduce false collisions at a small CPU cost +``` + +The only env var exposed for this feature is `OLLA_PROXY_STICKY_SESSIONS_ENABLED` (boolean). The remaining fields are configuration-file only. 
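To see how the configured cascade resolves for a concrete request, here is a minimal sketch that calls `ComputeStickyKey` directly. It mirrors the test code in this PR; the packages are internal to Olla, so treat it as illustrative rather than a public API:

```go
package main

import (
	"fmt"
	"net/http"

	"github.com/thushan/olla/internal/adapter/balancer"
	"github.com/thushan/olla/internal/config"
)

func main() {
	cfg := config.StickySessionConfig{
		Enabled:         true,
		IdleTTLSeconds:  600,
		MaxSessions:     10000,
		KeySources:      []string{"session_header", "prefix_hash", "auth_header"},
		PrefixHashBytes: 512,
	}

	body := []byte(`{"model":"llama3.2","messages":[{"role":"user","content":"Hello"}]}`)

	req, _ := http.NewRequest(http.MethodPost, "/olla/proxy/api/chat", nil)
	req.Header.Set("X-Olla-Session-ID", "conv-abc123")

	// session_header is first in the cascade and the header is present,
	// so prefix_hash and auth_header are never consulted.
	key, source := balancer.ComputeStickyKey(req, "llama3.2", cfg, body)
	fmt.Println(source) // session_header
	fmt.Println(key)    // 16 hex chars of FNV-64a, then ":llama3.2"
}
```

Remove the `X-Olla-Session-ID` header and the same call falls through to `prefix_hash`, deriving the key from the opening bytes of the `messages` array instead.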
+ +## Observability + +### Stats endpoint + +```bash +curl http://localhost:40114/internal/stats/sticky +``` + +When sticky sessions are **enabled**: + +```json +{ + "enabled": true, + "active_sessions": 142, + "insertions": 1500, + "hits": 9231, + "misses": 1500, + "evictions": 0, + "max_sessions": 10000, + "idle_ttl_seconds": 600 +} +``` + +When sticky sessions are **disabled** (stable shape for scripting): + +```json +{ + "enabled": false +} +``` + +Scripts should branch on the `enabled` field; the endpoint always returns `200 OK` regardless of whether the feature is active. + +## When NOT to use sticky sessions + +Sticky sessions trade load distribution for cache locality. They are not always the right choice: + +- **Stateless or single-turn workloads**: embeddings, one-shot completions, and batch jobs gain nothing from affinity; use the plain load balancer. +- **Model-routing-dominated traffic**: if requests are already hard-routed to specific endpoints by model routing, the sticky wrapper adds overhead with no benefit. +- **Very small deployments**: two endpoints with priority load balancing already behave predictably; adding stickiness is unnecessary complexity. +- **Homogeneous short-prompt workloads**: when prompts are short and vary widely, KV-cache hit rates on the backend are already low; affinity provides little gain and reduces load distribution. +- **Deployments with aggressive autoscaling**: if backends are added and removed frequently, sessions will repin often and the affinity benefit is diluted. + +## Developer notes + +**Decorator pattern**: `StickySessionWrapper` in `internal/adapter/balancer/sticky.go` wraps any `domain.EndpointSelector` implementation. No factory or registry changes are needed to add a new inner balancer; the wrapper is applied in `ProxyServiceWrapper.applyStickySessions()` inside `internal/app/services/proxy.go`. + +**Hashing**: FNV-64a (`hash/fnv`) is used for all key derivation. It is non-cryptographic and intended only as a compact routing hint; collisions are acceptable (two different sessions occasionally land on the same backend). Do not use these keys as security tokens. + +**Import cycle avoidance**: `StickyOutcome` is defined in `internal/core/domain/routing.go` rather than in `internal/adapter/balancer/sticky.go`. This allows `internal/adapter/proxy/core` (the proxy engine shared layer) to read the outcome and write response headers without importing the balancer package, which would create a cycle. + +**Purge wiring order**: `applyStickySessions()` assigns `s.stickyWrapper` and then immediately calls `s.discoverySvc.SetPurgeDeadEndpointsFn(s.PurgeDeadEndpoints)`. This ensures the write of `stickyWrapper` happens-before the health-checker goroutine can observe it via the purge callback. The registration happens inside `ProxyServiceWrapper.Start()`, not at construction time, to respect service startup ordering. + +**Session store**: backed by `github.com/jellydator/ttlcache/v3` with `WithCapacity` (LRU) and `WithTTL` (sliding expiry). The `ttlcache.Get` call inside `Select` refreshes the TTL automatically on every hit. + +Relevant source files: `internal/adapter/balancer/sticky.go`, `internal/core/domain/routing.go`, `internal/app/services/proxy.go`, `internal/app/services/discovery.go`, `internal/app/handlers/handler_proxy.go`, `internal/app/handlers/handler_stats_sticky.go`. 
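To make the decorator wiring concrete, here is a minimal sketch. `selectWithAffinity` is a hypothetical helper (the real injection happens in the handler layer); every other identifier is taken from this PR:

```go
package example

import (
	"context"
	"net/http"

	"github.com/thushan/olla/internal/adapter/balancer"
	"github.com/thushan/olla/internal/core/constants"
	"github.com/thushan/olla/internal/core/domain"
)

// selectWithAffinity injects the sticky key, its source, and an outcome pointer
// into the context, then delegates to the wrapper's Select.
func selectWithAffinity(r *http.Request, key, source string,
	wrapper *balancer.StickySessionWrapper,
	endpoints []*domain.Endpoint) (*domain.Endpoint, *domain.StickyOutcome, error) {

	outcome := &domain.StickyOutcome{}
	ctx := context.WithValue(r.Context(), constants.ContextStickyKeyKey, key)
	ctx = context.WithValue(ctx, constants.ContextStickyKeySourceKey, source)
	ctx = context.WithValue(ctx, constants.ContextStickyOutcomeKey, outcome)

	// On a hit, Select returns the pinned backend; on a miss or repin it
	// delegates to the wrapped balancer and records the chosen backend.
	ep, err := wrapper.Select(ctx, endpoints)
	return ep, outcome, err
}
```

After `Select` returns, `outcome.Result` holds `hit`, `miss`, `repin`, or `disabled`, and `SetStickySessionHeaders` in `internal/adapter/proxy/core` reads the same pointer when writing the `X-Olla-Sticky-*` response headers.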
+ +## See also + +- [Load Balancing](load-balancing.md): underlying strategies that sticky sessions wrap +- [Health Checking](health-checking.md): health states and the routable concept +- [Configuration Reference](../configuration/reference.md#sticky-sessions): complete field reference diff --git a/docs/content/configuration/reference.md b/docs/content/configuration/reference.md index 6a5f1aa..1739fa0 100644 --- a/docs/content/configuration/reference.md +++ b/docs/content/configuration/reference.md @@ -211,6 +211,36 @@ proxy: - "*debug*" # Exclude debug profiles ``` +### Sticky Sessions {#sticky-sessions} + +KV-cache affinity routing for multi-turn LLM conversations. See [Sticky Sessions](../concepts/sticky-sessions.md) for a full explanation. + +| Field | Type | Default | Description | +|-------|------|---------|-------------| +| `sticky_sessions.enabled` | bool | `false` | Enable affinity routing (opt-in) | +| `sticky_sessions.idle_ttl_seconds` | int | `600` | Sliding TTL in seconds; 0 = no TTL expiry | +| `sticky_sessions.max_sessions` | uint64 | `10000` | LRU capacity; oldest entry evicted when full | +| `sticky_sessions.key_sources` | []string | `["session_header","prefix_hash","auth_header"]` | Ordered key source cascade; first match wins | +| `sticky_sessions.prefix_hash_bytes` | int | `512` | Bytes of the messages field to hash for `prefix_hash` | + +**Environment Variable**: `OLLA_PROXY_STICKY_SESSIONS_ENABLED` (only `enabled` is exposed as an env var) + +Example: + +```yaml +proxy: + sticky_sessions: + enabled: true # opt-in + idle_ttl_seconds: 600 # 10-min sliding window + max_sessions: 10000 # LRU cap + key_sources: + - "session_header" # X-Olla-Session-ID header + - "prefix_hash" # hash of messages prefix + - "auth_header" # hash of Authorization header + # - "ip" # client IP (unreliable behind NAT) + prefix_hash_bytes: 512 +``` + ## Discovery Configuration Endpoint discovery and health checking. 
diff --git a/docs/mkdocs.yml b/docs/mkdocs.yml index 06cb5f9..f8d23f6 100644 --- a/docs/mkdocs.yml +++ b/docs/mkdocs.yml @@ -129,6 +129,7 @@ nav: - Concepts: - Overview: concepts/overview.md - Load Balancing: concepts/load-balancing.md + - Sticky Sessions: concepts/sticky-sessions.md - Model Routing: concepts/model-routing.md - Model Aliases: concepts/model-aliases.md - Model Unification: concepts/model-unification.md diff --git a/go.mod b/go.mod index 6bf12a9..35565a8 100644 --- a/go.mod +++ b/go.mod @@ -25,6 +25,7 @@ require ( github.com/containerd/console v1.0.5 // indirect github.com/davecgh/go-spew v1.1.1 // indirect github.com/gookit/color v1.6.0 // indirect + github.com/jellydator/ttlcache/v3 v3.4.0 // indirect github.com/kr/pretty v0.3.1 // indirect github.com/lithammer/fuzzysearch v1.1.8 // indirect github.com/mattn/go-runewidth v0.0.20 // indirect diff --git a/go.sum b/go.sum index 7a23fc1..979e747 100644 --- a/go.sum +++ b/go.sum @@ -36,6 +36,8 @@ github.com/gookit/color v1.4.2/go.mod h1:fqRyamkC1W8uxl+lxCQxOT09l/vYfZ+QeiX3rKQ github.com/gookit/color v1.5.0/go.mod h1:43aQb+Zerm/BWh2GnrgOQm7ffz7tvQXEKV6BFMl7wAo= github.com/gookit/color v1.6.0 h1:JjJXBTk1ETNyqyilJhkTXJYYigHG24TM9Xa2M1xAhRA= github.com/gookit/color v1.6.0/go.mod h1:9ACFc7/1IpHGBW8RwuDm/0YEnhg3dwwXpoMsmtyHfjs= +github.com/jellydator/ttlcache/v3 v3.4.0 h1:YS4P125qQS0tNhtL6aeYkheEaB/m8HCqdMMP4mnWdTY= +github.com/jellydator/ttlcache/v3 v3.4.0/go.mod h1:Hw9EgjymziQD3yGsQdf1FqFdpp7YjFMd4Srg5EJlgD4= github.com/json-iterator/go v1.1.12 h1:PV8peI4a0ysnczrg+LtxykD8LfKY9ML6u2jnxaEnrnM= github.com/json-iterator/go v1.1.12/go.mod h1:e30LSqwooZae/UwlEbR2852Gd8hjQvJoHmT4TnhNGBo= github.com/klauspost/cpuid/v2 v2.0.9/go.mod h1:FInQzS24/EEf25PyTYn52gqo7WaD8xa0213Md/qVLRg= diff --git a/internal/adapter/balancer/sticky.go b/internal/adapter/balancer/sticky.go new file mode 100644 index 0000000..a36f067 --- /dev/null +++ b/internal/adapter/balancer/sticky.go @@ -0,0 +1,285 @@ +package balancer + +import ( + "context" + "hash/fnv" + "log/slog" + "net" + "net/http" + "strings" + "time" + + "github.com/jellydator/ttlcache/v3" + "github.com/thushan/olla/internal/config" + "github.com/thushan/olla/internal/core/constants" + "github.com/thushan/olla/internal/core/domain" + "github.com/tidwall/gjson" +) + +// StickyOutcome is the domain type; alias it here for callers that import +// this package directly (e.g. handler layer). The canonical definition lives in +// core/domain so that adapter/proxy/core can read it without a cycle. +type StickyOutcome = domain.StickyOutcome + +// StickySessionWrapper is a decorator around any EndpointSelector that adds +// KV-cache affinity routing. It remembers which backend last handled a +// conversation (identified by a computed key) and steers subsequent turns +// to the same backend while it remains routable. +// +// TOCTOU on the session store is intentionally acceptable — both racing +// goroutines will select valid backends and the last writer wins, which +// converges quickly in practice. +type StickySessionWrapper struct { + inner domain.EndpointSelector + store *ttlcache.Cache[string, string] + cfg config.StickySessionConfig +} + +// NewStickySessionWrapper wraps inner with sticky session affinity using cfg. +// Call Start() after construction and Stop() on shutdown. 
+func NewStickySessionWrapper(inner domain.EndpointSelector, cfg config.StickySessionConfig) *StickySessionWrapper { + idleTTL := time.Duration(cfg.IdleTTLSeconds) * time.Second + + if cfg.IdleTTLSeconds <= 0 { + // ttlcache treats a zero TTL as no expiration — sessions accumulate until + // capacity pressure forces eviction. Warn so operators notice the config. + slog.Warn("sticky sessions TTL is zero — sessions will never expire by TTL") + } + + store := ttlcache.New[string, string]( + ttlcache.WithTTL[string, string](idleTTL), + ttlcache.WithCapacity[string, string](cfg.MaxSessions), + ) + + return &StickySessionWrapper{ + inner: inner, + store: store, + cfg: cfg, + } +} + +// Start launches the ttlcache background goroutine that handles TTL expiry and +// capacity-based eviction. Must be called before the wrapper is used for routing. +func (s *StickySessionWrapper) Start() { + go s.store.Start() +} + +// Stop shuts down the ttlcache background goroutine. Safe to call multiple times. +func (s *StickySessionWrapper) Stop() { + s.store.Stop() +} + +// Name returns a descriptive name that composes the inner balancer name. +func (s *StickySessionWrapper) Name() string { + return "sticky(" + s.inner.Name() + ")" +} + +// IncrementConnections delegates to the inner selector. +func (s *StickySessionWrapper) IncrementConnections(endpoint *domain.Endpoint) { + s.inner.IncrementConnections(endpoint) +} + +// DecrementConnections delegates to the inner selector. +func (s *StickySessionWrapper) DecrementConnections(endpoint *domain.Endpoint) { + s.inner.DecrementConnections(endpoint) +} + +// Select routes the request to the pinned backend when the affinity key is present +// and the backend is still routable. On a miss or dead-backend, it delegates to +// the inner selector and pins the newly chosen backend for next time. +func (s *StickySessionWrapper) Select(ctx context.Context, endpoints []*domain.Endpoint) (*domain.Endpoint, error) { + outcome, _ := ctx.Value(constants.ContextStickyOutcomeKey).(*StickyOutcome) + + key, _ := ctx.Value(constants.ContextStickyKeyKey).(string) + source, _ := ctx.Value(constants.ContextStickyKeySourceKey).(string) + + if key == "" { + // No affinity key — pass through transparently. + if outcome != nil { + outcome.Result = "disabled" + outcome.Source = "none" + } + return s.inner.Select(ctx, endpoints) + } + + // ttlcache.Get refreshes the sliding TTL automatically. + item := s.store.Get(key) + if item != nil { + pinnedURL := item.Value() + for _, ep := range endpoints { + if ep.Status.IsRoutable() && ep.URLString == pinnedURL { + // Sticky hit — backend is still alive and serving this model. + if outcome != nil { + outcome.Result = "hit" + outcome.Source = source + } + return ep, nil + } + } + // Pinned backend is gone or unhealthy — fall through to repin. + } + + chosen, err := s.inner.Select(ctx, endpoints) + if err != nil { + return nil, err + } + + // Record the affinity mapping for future turns. + s.store.Set(key, chosen.URLString, ttlcache.DefaultTTL) + + result := "miss" + if item != nil { + // We had a pin but the backend was no longer routable. + result = "repin" + } + + if outcome != nil { + outcome.Result = result + outcome.Source = source + } + + return chosen, nil +} + +// PurgeDeadEndpoints removes session entries that point to backends no longer +// present in the provided routable set. Callers can invoke this periodically +// (e.g. on health-check updates) to reclaim store capacity proactively. 
+func (s *StickySessionWrapper) PurgeDeadEndpoints(routable []*domain.Endpoint) { + alive := make(map[string]struct{}, len(routable)) + for _, ep := range routable { + alive[ep.URLString] = struct{}{} + } + + s.store.Range(func(item *ttlcache.Item[string, string]) bool { + if _, ok := alive[item.Value()]; !ok { + s.store.Delete(item.Key()) + } + return true + }) +} + +// ComputeStickyKey derives an affinity key for this request using the configured +// key_sources cascade. The key is model-scoped so the same client talking to +// different models does not cross-contaminate their session state. +// +// Returns ("", "") when no source produces a usable key. +// Exported so handlers can compute the key before invoking the balancer. +func ComputeStickyKey(r *http.Request, modelName string, cfg config.StickySessionConfig, body []byte) (key, source string) { + for _, src := range cfg.KeySources { + var k, s string + switch src { + case "session_header": + k, s = stickyKeyFromSessionHeader(r, modelName) + case "prefix_hash": + k, s = stickyKeyFromPrefixHash(body, modelName, cfg.PrefixHashBytes) + case "auth_header": + k, s = stickyKeyFromAuthHeader(r, modelName) + case "ip": + k, s = stickyKeyFromIP(r, modelName) + } + if k != "" { + return k, s + } + } + + return "", "" +} + +// stickyKeyFromSessionHeader hashes the session ID header with FNV-64a so that +// unbounded client-supplied strings do not inflate cache key memory. +func stickyKeyFromSessionHeader(r *http.Request, modelName string) (string, string) { + v := r.Header.Get(constants.HeaderXOllaSessionID) + if v == "" { + return "", "" + } + h := fnv.New64a() + h.Write([]byte(v)) + return uint64ToHex(h.Sum64()) + ":" + modelName, "session_header" +} + +// stickyKeyFromPrefixHash hashes the first prefixBytes bytes of the messages +// JSON array so requests with identical conversation prefixes are routed together. +func stickyKeyFromPrefixHash(body []byte, modelName string, prefixBytes int) (string, string) { + if len(body) == 0 { + return "", "" + } + raw := gjson.GetBytes(body, "messages").Raw + if raw == "" { + return "", "" + } + limit := prefixBytes + if limit <= 0 || limit > len(raw) { + limit = len(raw) + } + h := fnv.New64a() + h.Write([]byte(raw[:limit])) + return strings.ReplaceAll(modelName, ":", "_") + ":" + uint64ToHex(h.Sum64()), "prefix_hash" +} + +// stickyKeyFromAuthHeader hashes the Authorization header value so that tokens +// are never stored in plaintext inside the session store. +func stickyKeyFromAuthHeader(r *http.Request, modelName string) (string, string) { + v := r.Header.Get("Authorization") + if v == "" { + return "", "" + } + h := fnv.New64a() + h.Write([]byte(v)) + return "auth:" + uint64ToHex(h.Sum64()) + ":" + modelName, "auth_header" +} + +// stickyKeyFromIP extracts the remote host using net.SplitHostPort, which +// handles bracketed IPv6 addresses correctly (strings.LastIndex cannot). +func stickyKeyFromIP(r *http.Request, modelName string) (string, string) { + addr := r.RemoteAddr + if addr == "" { + return "", "" + } + host, _, err := net.SplitHostPort(addr) + if err != nil { + // Bare address with no port — use as-is. + host = addr + } + if host == "" { + return "", "" + } + return "ip:" + host + ":" + modelName, "ip" +} + +// StickyStats holds a point-in-time snapshot of sticky session activity. 
type StickyStats struct {
	Enabled        bool   `json:"enabled"`
	ActiveSessions int    `json:"active_sessions"`
	Insertions     uint64 `json:"insertions"`
	Hits           uint64 `json:"hits"`
	Misses         uint64 `json:"misses"`
	Evictions      uint64 `json:"evictions"`
	MaxSessions    uint64 `json:"max_sessions"`
	IdleTTLSeconds int    `json:"idle_ttl_seconds"`
}

// Stats returns a point-in-time snapshot of the session store metrics.
func (s *StickySessionWrapper) Stats() StickyStats {
	m := s.store.Metrics()
	return StickyStats{
		Enabled:        true,
		ActiveSessions: s.store.Len(),
		Insertions:     m.Insertions,
		Hits:           m.Hits,
		Misses:         m.Misses,
		Evictions:      m.Evictions,
		MaxSessions:    s.cfg.MaxSessions,
		IdleTTLSeconds: s.cfg.IdleTTLSeconds,
	}
}

// uint64ToHex converts a uint64 to a fixed-width hex string without importing fmt,
// keeping key derivation cheap on the hot path.
func uint64ToHex(v uint64) string {
	const digits = "0123456789abcdef"
	buf := make([]byte, 16)
	for i := 15; i >= 0; i-- {
		buf[i] = digits[v&0xf]
		v >>= 4
	}
	return string(buf)
}
diff --git a/internal/adapter/balancer/sticky_test.go b/internal/adapter/balancer/sticky_test.go
new file mode 100644
index 0000000..89d6d5d
--- /dev/null
+++ b/internal/adapter/balancer/sticky_test.go
@@ -0,0 +1,572 @@
package balancer

import (
	"context"
	"fmt"
	"net/http"
	"net/url"
	"sync"
	"testing"
	"time"

	"github.com/stretchr/testify/assert"
	"github.com/stretchr/testify/require"
	"github.com/thushan/olla/internal/config"
	"github.com/thushan/olla/internal/core/constants"
	"github.com/thushan/olla/internal/core/domain"
)

// --- test helpers ---

func defaultStickyConfig() config.StickySessionConfig {
	return config.StickySessionConfig{
		Enabled:         true,
		IdleTTLSeconds:  60,
		MaxSessions:     100,
		KeySources:      []string{"session_header", "prefix_hash", "auth_header", "ip"},
		PrefixHashBytes: 512,
	}
}

func makeEndpoint(name, rawURL string) *domain.Endpoint {
	u, _ := url.Parse(rawURL)
	return &domain.Endpoint{
		Name:      name,
		URL:       u,
		URLString: rawURL,
		Status:    domain.StatusHealthy,
	}
}

func makeWrapper(t *testing.T, cfg config.StickySessionConfig) *StickySessionWrapper {
	t.Helper()
	// a nil statsCollector is safe here: these tests never call
	// IncrementConnections/DecrementConnections (see the shim note below)
	inner := NewRoundRobinSelector(nil)
	w := NewStickySessionWrapper(inner, cfg)
	w.Start()
	t.Cleanup(w.Stop)
	return w
}

// injectKey builds a context carrying the sticky key and an outcome pointer.
func injectKey(parent context.Context, key, source string) (context.Context, *StickyOutcome) {
	outcome := &StickyOutcome{}
	ctx := context.WithValue(parent, constants.ContextStickyKeyKey, key)
	ctx = context.WithValue(ctx, constants.ContextStickyKeySourceKey, source)
	ctx = context.WithValue(ctx, constants.ContextStickyOutcomeKey, outcome)
	return ctx, outcome
}

// --- RoundRobinSelector with nil stats shim ---
// The existing RoundRobinSelector panics on nil statsCollector only inside
// IncrementConnections/DecrementConnections (which call RecordConnection).
// Select itself works fine, so we can use it directly in unit tests where we
// never call those methods.
+ +// --- tests --- + +func TestStickySessionWrapper_Miss(t *testing.T) { + t.Parallel() + + w := makeWrapper(t, defaultStickyConfig()) + ep1 := makeEndpoint("ep1", "http://backend1:8080") + ep2 := makeEndpoint("ep2", "http://backend2:8080") + endpoints := []*domain.Endpoint{ep1, ep2} + + ctx, outcome := injectKey(context.Background(), "sess-abc:llama3", "session_header") + + chosen, err := w.Select(ctx, endpoints) + require.NoError(t, err) + assert.NotNil(t, chosen) + assert.Equal(t, "miss", outcome.Result) + assert.Equal(t, "session_header", outcome.Source) +} + +func TestStickySessionWrapper_Hit(t *testing.T) { + t.Parallel() + + w := makeWrapper(t, defaultStickyConfig()) + ep1 := makeEndpoint("ep1", "http://backend1:8080") + ep2 := makeEndpoint("ep2", "http://backend2:8080") + endpoints := []*domain.Endpoint{ep1, ep2} + + const stickyKey = "sess-hit:llama3" + + // First call — should be a miss and pin a backend. + ctx1, _ := injectKey(context.Background(), stickyKey, "session_header") + first, err := w.Select(ctx1, endpoints) + require.NoError(t, err) + + // Second call with same key — should return the same backend (hit). + ctx2, outcome2 := injectKey(context.Background(), stickyKey, "session_header") + second, err := w.Select(ctx2, endpoints) + require.NoError(t, err) + + assert.Equal(t, first.URLString, second.URLString, "second request should be pinned to the same backend") + assert.Equal(t, "hit", outcome2.Result) +} + +func TestStickySessionWrapper_Repin(t *testing.T) { + t.Parallel() + + w := makeWrapper(t, defaultStickyConfig()) + ep1 := makeEndpoint("ep1", "http://backend1:8080") + ep2 := makeEndpoint("ep2", "http://backend2:8080") + + const stickyKey = "sess-repin:llama3" + + // Pin to ep1. + ctx1, _ := injectKey(context.Background(), stickyKey, "session_header") + first, err := w.Select(ctx1, []*domain.Endpoint{ep1, ep2}) + require.NoError(t, err) + + // Remove pinned backend from routable set — simulate it going offline. + remaining := []*domain.Endpoint{ep1, ep2} + for i, ep := range remaining { + if ep.URLString == first.URLString { + remaining[i] = remaining[len(remaining)-1] + remaining = remaining[:len(remaining)-1] + break + } + } + + ctx2, outcome2 := injectKey(context.Background(), stickyKey, "session_header") + second, err := w.Select(ctx2, remaining) + require.NoError(t, err) + + assert.NotEqual(t, first.URLString, second.URLString, "repin should select a different backend") + assert.Equal(t, "repin", outcome2.Result) +} + +func TestStickySessionWrapper_NoKey(t *testing.T) { + t.Parallel() + + w := makeWrapper(t, defaultStickyConfig()) + ep1 := makeEndpoint("ep1", "http://backend1:8080") + + // Empty key — wrapper should pass through transparently. 
+ ctx := context.Background() + outcome := &StickyOutcome{} + ctx = context.WithValue(ctx, constants.ContextStickyOutcomeKey, outcome) + + chosen, err := w.Select(ctx, []*domain.Endpoint{ep1}) + require.NoError(t, err) + assert.Equal(t, ep1.URLString, chosen.URLString) + assert.Equal(t, "disabled", outcome.Result) + assert.Equal(t, "none", outcome.Source) +} + +func TestStickySessionWrapper_TTLExpiry(t *testing.T) { + if testing.Short() { + t.Skip("skipping TTL test in short mode") + } + t.Parallel() + + cfg := defaultStickyConfig() + cfg.IdleTTLSeconds = 1 // 1 second for a fast test + + w := makeWrapper(t, cfg) + ep1 := makeEndpoint("ep1", "http://backend1:8080") + ep2 := makeEndpoint("ep2", "http://backend2:8080") + endpoints := []*domain.Endpoint{ep1, ep2} + + const stickyKey = "sess-ttl:llama3" + + ctx1, _ := injectKey(context.Background(), stickyKey, "session_header") + first, err := w.Select(ctx1, endpoints) + require.NoError(t, err) + + // Confirm pinned. + ctx2, outcome2 := injectKey(context.Background(), stickyKey, "session_header") + second, _ := w.Select(ctx2, endpoints) + assert.Equal(t, first.URLString, second.URLString) + assert.Equal(t, "hit", outcome2.Result) + + // Poll until the ttlcache TTL expires (1 s) rather than sleeping a fixed + // duration. Cap at 2 s to stay well above the TTL without being brittle. + deadline := time.Now().Add(2 * time.Second) + var outcome3 *StickyOutcome + for time.Now().Before(deadline) { + ctx3, o3 := injectKey(context.Background(), stickyKey, "session_header") + _, err = w.Select(ctx3, endpoints) + require.NoError(t, err) + outcome3 = o3 + if outcome3.Result == "miss" { + break + } + time.Sleep(50 * time.Millisecond) + } + // After TTL the entry is gone, so it's a fresh miss not a repin. + require.NotNil(t, outcome3) + assert.Equal(t, "miss", outcome3.Result) +} + +func TestStickySessionWrapper_ModelScoping(t *testing.T) { + t.Parallel() + + w := makeWrapper(t, defaultStickyConfig()) + ep1 := makeEndpoint("ep1", "http://backend1:8080") + ep2 := makeEndpoint("ep2", "http://backend2:8080") + endpoints := []*domain.Endpoint{ep1, ep2} + + // Same session but different models — keys must be distinct. + keyModel1 := "sess-scope:modelA" + keyModel2 := "sess-scope:modelB" + + ctx1, _ := injectKey(context.Background(), keyModel1, "session_header") + chosenA, err := w.Select(ctx1, endpoints) + require.NoError(t, err) + + // Force the second model to a specific backend so we can assert they differ. + // Simply select a few times — the round-robin inner will distribute. + var chosenB *domain.Endpoint + for range 10 { + ctx2, _ := injectKey(context.Background(), keyModel2, "session_header") + chosenB, _ = w.Select(ctx2, endpoints) + if chosenB.URLString != chosenA.URLString { + break + } + } + // The important assertion: each model key is tracked independently. 
+ ctx3, out3 := injectKey(context.Background(), keyModel1, "session_header") + third, err := w.Select(ctx3, endpoints) + require.NoError(t, err) + assert.Equal(t, chosenA.URLString, third.URLString, "model-scoped key should return same backend") + assert.Equal(t, "hit", out3.Result) +} + +func TestStickySessionWrapper_KeySources_SessionHeader(t *testing.T) { + t.Parallel() + + w := makeWrapper(t, defaultStickyConfig()) + ep1 := makeEndpoint("ep1", "http://backend1:8080") + endpoints := []*domain.Endpoint{ep1} + + req, _ := http.NewRequest(http.MethodPost, "/", nil) + req.Header.Set(constants.HeaderXOllaSessionID, "my-session-id") + + key, source := ComputeStickyKey(req, "llama3", defaultStickyConfig(), nil) + // Fix 1: session_header now hashes the raw value with FNV-64a, producing a + // 16-hex-char prefix. The raw string "my-session-id" must not appear in the key. + assert.Equal(t, "bd95cc3ab55faccc:llama3", key) + assert.NotContains(t, key, "my-session-id") + assert.Equal(t, "session_header", source) + + // Verify it routes + ctx, out := injectKey(context.Background(), key, source) + _, err := w.Select(ctx, endpoints) + require.NoError(t, err) + assert.Equal(t, "miss", out.Result) +} + +func TestStickySessionWrapper_KeySources_PrefixHash(t *testing.T) { + t.Parallel() + + body := []byte(`{"model":"llama3","messages":[{"role":"user","content":"hello"}]}`) + req, _ := http.NewRequest(http.MethodPost, "/", nil) + + key, source := ComputeStickyKey(req, "llama3", defaultStickyConfig(), body) + assert.Equal(t, "prefix_hash", source) + assert.NotEmpty(t, key) +} + +func TestStickySessionWrapper_KeySources_AuthHash(t *testing.T) { + t.Parallel() + + cfg := config.StickySessionConfig{ + KeySources: []string{"auth_header"}, + PrefixHashBytes: 512, + } + req, _ := http.NewRequest(http.MethodPost, "/", nil) + req.Header.Set("Authorization", "Bearer token-xyz") + + key, source := ComputeStickyKey(req, "llama3", cfg, nil) + assert.Equal(t, "auth_header", source) + assert.NotEmpty(t, key) +} + +func TestStickySessionWrapper_KeySources_IP(t *testing.T) { + t.Parallel() + + cfg := config.StickySessionConfig{ + KeySources: []string{"ip"}, + PrefixHashBytes: 512, + } + req, _ := http.NewRequest(http.MethodPost, "/", nil) + req.RemoteAddr = "192.168.1.42:12345" + + key, source := ComputeStickyKey(req, "llama3", cfg, nil) + assert.Equal(t, "ip", source) + assert.Contains(t, key, "192.168.1.42") +} + +func TestStickySessionWrapper_KeySources_NoMatch(t *testing.T) { + t.Parallel() + + cfg := config.StickySessionConfig{ + KeySources: []string{"session_header"}, // header not present + PrefixHashBytes: 512, + } + req, _ := http.NewRequest(http.MethodPost, "/", nil) + + key, source := ComputeStickyKey(req, "llama3", cfg, nil) + assert.Empty(t, key) + assert.Empty(t, source) +} + +func TestComputeStickyKey_ModelScope(t *testing.T) { + t.Parallel() + + req, _ := http.NewRequest(http.MethodPost, "/", nil) + req.Header.Set(constants.HeaderXOllaSessionID, "same-session") + + keyA, _ := ComputeStickyKey(req, "llama3", defaultStickyConfig(), nil) + keyB, _ := ComputeStickyKey(req, "mistral", defaultStickyConfig(), nil) + + assert.NotEqual(t, keyA, keyB, "same session ID with different models must produce different keys") +} + +func TestComputeStickyKey_PrefixHashBytes_Truncation(t *testing.T) { + t.Parallel() + + body := []byte(`{"messages":[{"role":"user","content":"` + string(make([]byte, 1000)) + `"}]}`) + req, _ := http.NewRequest(http.MethodPost, "/", nil) + + cfg := config.StickySessionConfig{ + KeySources: 
[]string{"prefix_hash"}, + PrefixHashBytes: 16, // very small limit + } + key1, _ := ComputeStickyKey(req, "llama3", cfg, body) + + cfg2 := config.StickySessionConfig{ + KeySources: []string{"prefix_hash"}, + PrefixHashBytes: 32, + } + key2, _ := ComputeStickyKey(req, "llama3", cfg2, body) + + // Different prefix lengths should produce different hashes. + assert.NotEqual(t, key1, key2) +} + +func TestStickySessionWrapper_Race(t *testing.T) { + t.Parallel() + + w := makeWrapper(t, defaultStickyConfig()) + ep1 := makeEndpoint("ep1", "http://backend1:8080") + ep2 := makeEndpoint("ep2", "http://backend2:8080") + endpoints := []*domain.Endpoint{ep1, ep2} + + const goroutines = 50 + const iters = 20 + + var wg sync.WaitGroup + wg.Add(goroutines) + + for g := range goroutines { + + go func() { + defer wg.Done() + for range iters { + key := fmt.Sprintf("sess-race-%d:llama3", g%5) // share some keys to exercise contention + ctx, _ := injectKey(context.Background(), key, "session_header") + _, err := w.Select(ctx, endpoints) + if err != nil { + t.Errorf("Select returned error: %v", err) + } + } + }() + } + + wg.Wait() +} + +func TestStickySessionWrapper_PurgeDeadEndpoints(t *testing.T) { + t.Parallel() + + w := makeWrapper(t, defaultStickyConfig()) + ep1 := makeEndpoint("ep1", "http://backend1:8080") + ep2 := makeEndpoint("ep2", "http://backend2:8080") + + // Pin two different sessions to both backends. + ctx1, _ := injectKey(context.Background(), "sess-purge1:m", "session_header") + w.Select(ctx1, []*domain.Endpoint{ep1}) //nolint:errcheck + + ctx2, _ := injectKey(context.Background(), "sess-purge2:m", "session_header") + w.Select(ctx2, []*domain.Endpoint{ep2}) //nolint:errcheck + + // Purge: only ep1 is alive. + w.PurgeDeadEndpoints([]*domain.Endpoint{ep1}) + + // sess-purge2 (pinned to ep2) should be gone → next select is a miss. + ctx3, out3 := injectKey(context.Background(), "sess-purge2:m", "session_header") + _, err := w.Select(ctx3, []*domain.Endpoint{ep1, ep2}) + require.NoError(t, err) + assert.Equal(t, "miss", out3.Result, "session pinned to purged backend should be a fresh miss") + + // sess-purge1 (pinned to ep1) should still be a hit. + ctx4, out4 := injectKey(context.Background(), "sess-purge1:m", "session_header") + _, err = w.Select(ctx4, []*domain.Endpoint{ep1, ep2}) + require.NoError(t, err) + assert.Equal(t, "hit", out4.Result, "session pinned to surviving backend should still hit") +} + +// --- Fix 1: session_header hashing --- + +// TestComputeStickyKey_SessionHeader_IsHashed verifies that the raw session ID +// value never appears in the computed key — only its FNV-64a hex digest does. +func TestComputeStickyKey_SessionHeader_IsHashed(t *testing.T) { + t.Parallel() + + const modelName = "llama3" + req, _ := http.NewRequest(http.MethodPost, "/", nil) + req.Header.Set(constants.HeaderXOllaSessionID, "my-session-id") + + cfg := config.StickySessionConfig{ + KeySources: []string{"session_header"}, + PrefixHashBytes: 512, + } + + key, source := ComputeStickyKey(req, modelName, cfg, nil) + + assert.Equal(t, "session_header", source) + assert.NotContains(t, key, "my-session-id", "raw session ID must not appear in the key") + // 16 hex chars + ":" + modelName + assert.Equal(t, 16+1+len(modelName), len(key), "key must be exactly 16 hex chars + colon + model name") +} + +// TestComputeStickyKey_SessionHeader_LargeValue_IsHashed confirms that arbitrarily +// long session IDs are bounded to a fixed-length key after hashing. 
+func TestComputeStickyKey_SessionHeader_LargeValue_IsHashed(t *testing.T) { + t.Parallel() + + const modelName = "llama3" + largeValue := string(make([]byte, 10000)) // 10 000 zero bytes → valid UTF-8 + req, _ := http.NewRequest(http.MethodPost, "/", nil) + req.Header.Set(constants.HeaderXOllaSessionID, largeValue) + + cfg := config.StickySessionConfig{ + KeySources: []string{"session_header"}, + PrefixHashBytes: 512, + } + + key, source := ComputeStickyKey(req, modelName, cfg, nil) + + assert.Equal(t, "session_header", source) + // Regardless of input length, output is always 16 hex + ":" + model. + assert.Equal(t, 16+1+len(modelName), len(key), "unbounded input must produce a bounded key") +} + +// TestComputeStickyKey_SessionHeader_SameValueSameKey verifies that hashing is +// deterministic — two requests with identical session IDs produce identical keys. +func TestComputeStickyKey_SessionHeader_SameValueSameKey(t *testing.T) { + t.Parallel() + + const modelName = "llama3" + cfg := config.StickySessionConfig{ + KeySources: []string{"session_header"}, + PrefixHashBytes: 512, + } + + req1, _ := http.NewRequest(http.MethodPost, "/", nil) + req1.Header.Set(constants.HeaderXOllaSessionID, "deterministic-session") + + req2, _ := http.NewRequest(http.MethodPost, "/", nil) + req2.Header.Set(constants.HeaderXOllaSessionID, "deterministic-session") + + key1, _ := ComputeStickyKey(req1, modelName, cfg, nil) + key2, _ := ComputeStickyKey(req2, modelName, cfg, nil) + + assert.Equal(t, key1, key2, "same session ID must always produce the same key") +} + +// --- Fix 3: ip key source uses net.SplitHostPort --- + +// TestComputeStickyKey_IP_IPv6Loopback verifies that IPv6 addresses are handled +// correctly — brackets and port are stripped, leaving only the clean host. +func TestComputeStickyKey_IP_IPv6Loopback(t *testing.T) { + t.Parallel() + + cfg := config.StickySessionConfig{ + KeySources: []string{"ip"}, + PrefixHashBytes: 512, + } + req, _ := http.NewRequest(http.MethodPost, "/", nil) + req.RemoteAddr = "[::1]:54321" + + key, source := ComputeStickyKey(req, "llama3", cfg, nil) + + assert.Equal(t, "ip", source) + assert.Contains(t, key, "::1", "clean IPv6 host must appear in key") + assert.NotContains(t, key, "54321", "port must be stripped from key") + assert.NotContains(t, key, "[", "opening bracket must be stripped by net.SplitHostPort") + assert.NotContains(t, key, "]", "closing bracket must be stripped by net.SplitHostPort") +} + +// TestComputeStickyKey_IP_IPv4 verifies that IPv4 address:port is correctly split +// and only the host is included in the key. +func TestComputeStickyKey_IP_IPv4(t *testing.T) { + t.Parallel() + + cfg := config.StickySessionConfig{ + KeySources: []string{"ip"}, + PrefixHashBytes: 512, + } + req, _ := http.NewRequest(http.MethodPost, "/", nil) + req.RemoteAddr = "192.168.1.42:12345" + + key, source := ComputeStickyKey(req, "llama3", cfg, nil) + + assert.Equal(t, "ip", source) + assert.Contains(t, key, "192.168.1.42", "IPv4 host must appear in key") + assert.NotContains(t, key, "12345", "port must be stripped from key") +} + +// TestComputeStickyKey_IP_BareAddress verifies the fallback path where RemoteAddr +// has no port (e.g. a custom listener that omits the port). The address must still +// be usable as a key rather than being discarded. 
+func TestComputeStickyKey_IP_BareAddress(t *testing.T) { + t.Parallel() + + cfg := config.StickySessionConfig{ + KeySources: []string{"ip"}, + PrefixHashBytes: 512, + } + req, _ := http.NewRequest(http.MethodPost, "/", nil) + req.RemoteAddr = "10.0.0.1" // no port + + key, source := ComputeStickyKey(req, "llama3", cfg, nil) + + assert.Equal(t, "ip", source) + assert.NotEmpty(t, key, "bare address without port must still produce a key") + assert.Contains(t, key, "10.0.0.1", "bare host must appear in key") +} + +// --- Fix 2: zero TTL warning --- + +// TestNewStickySessionWrapper_ZeroTTL_NoPanic verifies that constructing a wrapper +// with IdleTTLSeconds == 0 does not panic. The warning is emitted to slog but +// capturing structured log output in tests requires non-trivial plumbing; this +// test focuses on the observable guarantee (no panic, wrapper is usable). +func TestNewStickySessionWrapper_ZeroTTL_NoPanic(t *testing.T) { + t.Parallel() + + cfg := config.StickySessionConfig{ + Enabled: true, + IdleTTLSeconds: 0, // triggers the warning branch + MaxSessions: 10, + KeySources: []string{"session_header"}, + PrefixHashBytes: 512, + } + + // Must not panic. + inner := NewRoundRobinSelector(nil) + w := NewStickySessionWrapper(inner, cfg) + w.Start() + t.Cleanup(w.Stop) + + ep := makeEndpoint("ep1", "http://backend1:8080") + ctx, out := injectKey(context.Background(), "zero-ttl-key:llama3", "session_header") + _, err := w.Select(ctx, []*domain.Endpoint{ep}) + require.NoError(t, err) + assert.Equal(t, "miss", out.Result, "wrapper with zero TTL must still route requests") +} diff --git a/internal/adapter/health/checker.go b/internal/adapter/health/checker.go index cec5af1..a382a88 100644 --- a/internal/adapter/health/checker.go +++ b/internal/adapter/health/checker.go @@ -24,13 +24,14 @@ const ( ) type HTTPHealthChecker struct { - repository domain.EndpointRepository - logger logger.StyledLogger - recoveryCallback RecoveryCallback - healthClient *HealthClient - ticker *time.Ticker - stopCh chan struct{} - isRunning atomic.Bool + repository domain.EndpointRepository + logger logger.StyledLogger + recoveryCallback RecoveryCallback + unhealthyCallback UnhealthyCallback + healthClient *HealthClient + ticker *time.Ticker + stopCh chan struct{} + isRunning atomic.Bool } func NewHTTPHealthChecker(repository domain.EndpointRepository, logger logger.StyledLogger, client HTTPClient) *HTTPHealthChecker { @@ -53,6 +54,13 @@ func (c *HTTPHealthChecker) SetRecoveryCallback(callback RecoveryCallback) { } } +// SetUnhealthyCallback sets the callback to be invoked when an endpoint transitions to an unhealthy state. +func (c *HTTPHealthChecker) SetUnhealthyCallback(callback UnhealthyCallback) { + if callback != nil { + c.unhealthyCallback = callback + } +} + func NewHTTPHealthCheckerWithDefaults(repository domain.EndpointRepository, logger logger.StyledLogger) *HTTPHealthChecker { // We want to enable connection pooling and reuse with some sane defaults client := &http.Client{ @@ -274,6 +282,21 @@ func (c *HTTPHealthChecker) checkEndpoint(ctx context.Context, endpoint *domain. } } + // Trigger unhealthy callback only when an endpoint becomes non-routable from a + // routable state. Busy and Warming are still routable, so a Healthy→Busy transition + // must not evict sticky sessions — that would defeat KV-cache affinity entirely. + // Unknown→Unhealthy is intentionally excluded: nothing could have been pinned to an + // endpoint that was never routable, so there is nothing to purge. 
+ if statusChanged && !newStatus.IsRoutable() && oldStatus.IsRoutable() { + if c.unhealthyCallback != nil { + go func(ep domain.Endpoint) { + callbackCtx, cancel := context.WithTimeout(context.Background(), DefaultRecoveryCallbackTimeout) + defer cancel() + c.unhealthyCallback.OnEndpointUnhealthy(callbackCtx, &ep) + }(endpointCopy) + } + } + c.logHealthCheckResult(endpoint, oldStatus, newStatus, statusChanged, result, nextInterval, err) } diff --git a/internal/adapter/health/checker_test.go b/internal/adapter/health/checker_test.go index 54b1cd3..053f97b 100644 --- a/internal/adapter/health/checker_test.go +++ b/internal/adapter/health/checker_test.go @@ -588,6 +588,113 @@ func TestHTTPHealthChecker_ContextCancellation(t *testing.T) { } } +// nonRetryableError is a plain error (not net.Error) so classifyError returns +// ErrorTypeHTTPError, which makes shouldRetry return false. This avoids the +// exponential-backoff retry delays inside HealthClient.Check during unit tests. +type nonRetryableError struct{} + +func (e *nonRetryableError) Error() string { return "non-retryable test error" } + +// nonRetryingHTTPClient returns a non-retryable error, collapsing the retry +// loop inside HealthClient to a single attempt for fast unit tests. +type nonRetryingHTTPClient struct{} + +func (c *nonRetryingHTTPClient) Do(_ *http.Request) (*http.Response, error) { + return nil, &nonRetryableError{} +} + +// TestUnhealthyCallbackPredicate verifies that the unhealthy callback fires only when +// an endpoint transitions from a routable state to a non-routable one. The previous +// predicate (newStatus != Healthy && oldStatus != Unknown) incorrectly fired on +// Healthy→Busy and Healthy→Warming, evicting sticky sessions unnecessarily. +func TestUnhealthyCallbackPredicate(t *testing.T) { + t.Parallel() + + loggerCfg := &logger.Config{Level: "error", Theme: "default"} + log, cleanup, _ := logger.New(loggerCfg) + defer cleanup() + styledLogger := logger.NewPlainStyledLogger(log) + + makeEndpoint := func(urlStr string, status domain.EndpointStatus) *domain.Endpoint { + u, _ := url.Parse(urlStr) + hcu, _ := url.Parse(urlStr + "/health") + return &domain.Endpoint{ + Name: urlStr, + URL: u, + HealthCheckURL: hcu, + URLString: u.String(), + HealthCheckURLString: hcu.String(), + Status: status, + CheckTimeout: time.Second, + } + } + + // errClient returns a non-retryable error → StatusUnhealthy (non-routable). + // Using a plain error (not net.Error) avoids the retry-backoff delay in HealthClient. + errClient := &nonRetryingHTTPClient{} + // okClient returns HTTP 200 → StatusHealthy (routable) + okClient := &mockHTTPClient{statusCode: 200} + + tests := []struct { + name string + oldStatus domain.EndpointStatus + client HTTPClient + wantFired bool + }{ + // Routable → non-routable: callback must fire. + {name: "Healthy→Unhealthy fires", oldStatus: domain.StatusHealthy, client: errClient, wantFired: true}, + {name: "Busy→Unhealthy fires", oldStatus: domain.StatusBusy, client: errClient, wantFired: true}, + {name: "Warming→Unhealthy fires", oldStatus: domain.StatusWarming, client: errClient, wantFired: true}, + + // Already non-routable → non-routable: nothing was pinned, so no purge. + {name: "Unknown→Unhealthy no fire", oldStatus: domain.StatusUnknown, client: errClient, wantFired: false}, + {name: "Offline→Unhealthy no fire", oldStatus: domain.StatusOffline, client: errClient, wantFired: false}, + + // Routable → routable: keep sticky sessions intact. 
+ {name: "Healthy→Healthy no fire (no change)", oldStatus: domain.StatusHealthy, client: okClient, wantFired: false}, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + + fired := make(chan struct{}, 1) + repo := newMockRepository() + ep := makeEndpoint("http://127.0.0.1:19999", tc.oldStatus) + repo.mu.Lock() + repo.endpoints[ep.URLString] = ep + repo.mu.Unlock() + + checker := NewHTTPHealthChecker(repo, styledLogger, tc.client) + checker.SetUnhealthyCallback(UnhealthyCallbackFunc(func(_ context.Context, _ *domain.Endpoint) { + select { + case fired <- struct{}{}: + default: + } + })) + + ctx := context.Background() + checker.checkEndpoint(ctx, ep) + + // The callback dispatches asynchronously in a goroutine; give it a + // short window to fire before concluding it won't. + var gotFired bool + select { + case <-fired: + gotFired = true + case <-time.After(200 * time.Millisecond): + } + + if gotFired && !tc.wantFired { + t.Errorf("unhealthy callback fired but should not have (old=%s)", tc.oldStatus) + } + if !gotFired && tc.wantFired { + t.Errorf("unhealthy callback not fired but should have (old=%s)", tc.oldStatus) + } + }) + } +} + type panicHTTPClient struct{} func (p *panicHTTPClient) Do(req *http.Request) (*http.Response, error) { diff --git a/internal/adapter/health/recovery_callback.go b/internal/adapter/health/recovery_callback.go index 900f6b2..3ba4f9c 100644 --- a/internal/adapter/health/recovery_callback.go +++ b/internal/adapter/health/recovery_callback.go @@ -24,3 +24,23 @@ type NoOpRecoveryCallback struct{} func (n NoOpRecoveryCallback) OnEndpointRecovered(ctx context.Context, endpoint *domain.Endpoint) error { return nil } + +// UnhealthyCallback is called when an endpoint transitions from healthy (or unknown) to an +// unhealthy state. Callers can use this to proactively clean up state tied to the dead backend +// (e.g. purging sticky session entries) rather than waiting for TTL expiry. +type UnhealthyCallback interface { + OnEndpointUnhealthy(ctx context.Context, endpoint *domain.Endpoint) +} + +// UnhealthyCallbackFunc is a function adapter for UnhealthyCallback +type UnhealthyCallbackFunc func(ctx context.Context, endpoint *domain.Endpoint) + +func (f UnhealthyCallbackFunc) OnEndpointUnhealthy(ctx context.Context, endpoint *domain.Endpoint) { + f(ctx, endpoint) +} + +// NoOpUnhealthyCallback is a no-op implementation of UnhealthyCallback +type NoOpUnhealthyCallback struct{} + +func (n NoOpUnhealthyCallback) OnEndpointUnhealthy(ctx context.Context, endpoint *domain.Endpoint) { +} diff --git a/internal/adapter/proxy/core/common.go b/internal/adapter/proxy/core/common.go index 150468d..280e389 100644 --- a/internal/adapter/proxy/core/common.go +++ b/internal/adapter/proxy/core/common.go @@ -163,6 +163,29 @@ func extractClientIP(r *http.Request) string { return host } +// SetStickySessionHeaders writes sticky session outcome headers before WriteHeader +// is called. It reads the StickyOutcome pointer that was injected into the context +// by the handler layer after the balancer's Select fills it. Must be called before +// w.WriteHeader() — calling it afterwards is a no-op on a committed response. 
+func SetStickySessionHeaders(w http.ResponseWriter, r *http.Request) { + outcome, _ := r.Context().Value(constants.ContextStickyOutcomeKey).(*domain.StickyOutcome) + if outcome == nil { + return + } + h := w.Header() + if outcome.Result != "" { + h.Set(constants.HeaderXOllaStickySession, outcome.Result) + } + if outcome.Source != "" && outcome.Source != "none" { + h.Set(constants.HeaderXOllaStickyKeySource, outcome.Source) + } + if outcome.Source == "session_header" { + if sid := r.Header.Get(constants.HeaderXOllaSessionID); sid != "" { + h.Set(constants.HeaderXOllaSessionID, sid) + } + } +} + // SetResponseHeaders sets common response headers func SetResponseHeaders(w http.ResponseWriter, stats *ports.RequestStats, endpoint *domain.Endpoint) { h := w.Header() diff --git a/internal/adapter/proxy/core/common_test.go b/internal/adapter/proxy/core/common_test.go index 5476eeb..94495d5 100644 --- a/internal/adapter/proxy/core/common_test.go +++ b/internal/adapter/proxy/core/common_test.go @@ -1,6 +1,7 @@ package core import ( + "context" "crypto/tls" "net/http" "net/http/httptest" @@ -691,6 +692,90 @@ func BenchmarkCopyHeaders_WithExistingHeaders(b *testing.B) { } } +func TestSetStickySessionHeaders(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + outcome *domain.StickyOutcome + sessionIDHeader string // value of X-Olla-Session-ID on the request + expectSession string // expected X-Olla-Sticky-Session header value + expectSource string // expected X-Olla-Sticky-Key-Source header value + expectSessionID string // expected X-Olla-Session-ID echo in response + }{ + { + name: "nil_outcome_no_op", + outcome: nil, + expectSession: "", + expectSource: "", + }, + { + name: "hit_with_session_header_source", + outcome: &domain.StickyOutcome{Result: "hit", Source: "session_header"}, + sessionIDHeader: "my-session-123", + expectSession: "hit", + expectSource: "session_header", + expectSessionID: "my-session-123", + }, + { + name: "miss_with_prefix_hash_source", + outcome: &domain.StickyOutcome{Result: "miss", Source: "prefix_hash"}, + expectSession: "miss", + expectSource: "prefix_hash", + }, + { + name: "repin_with_ip_source", + outcome: &domain.StickyOutcome{Result: "repin", Source: "ip"}, + expectSession: "repin", + expectSource: "ip", + }, + { + name: "disabled_source_none_skips_key_source_header", + outcome: &domain.StickyOutcome{Result: "disabled", Source: "none"}, + expectSession: "disabled", + expectSource: "", // "none" must not be written + }, + { + name: "session_header_source_without_session_id_skips_echo", + outcome: &domain.StickyOutcome{Result: "hit", Source: "session_header"}, + // no X-Olla-Session-ID on request + sessionIDHeader: "", + expectSession: "hit", + expectSource: "session_header", + expectSessionID: "", // nothing to echo + }, + { + name: "empty_result_skips_sticky_session_header", + outcome: &domain.StickyOutcome{Result: "", Source: "ip"}, + expectSession: "", + expectSource: "ip", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + + req := httptest.NewRequest(http.MethodPost, "/test", nil) + if tt.sessionIDHeader != "" { + req.Header.Set(constants.HeaderXOllaSessionID, tt.sessionIDHeader) + } + + if tt.outcome != nil { + ctx := context.WithValue(req.Context(), constants.ContextStickyOutcomeKey, tt.outcome) + req = req.WithContext(ctx) + } + + w := httptest.NewRecorder() + SetStickySessionHeaders(w, req) + + assert.Equal(t, tt.expectSession, w.Header().Get(constants.HeaderXOllaStickySession)) + assert.Equal(t, 
tt.expectSource, w.Header().Get(constants.HeaderXOllaStickyKeySource))
+			assert.Equal(t, tt.expectSessionID, w.Header().Get(constants.HeaderXOllaSessionID))
+		})
+	}
+}
+
 // BenchmarkSetResponseHeaders benchmarks the SetResponseHeaders function
 func BenchmarkSetResponseHeaders(b *testing.B) {
 	stats := &ports.RequestStats{
diff --git a/internal/adapter/proxy/olla/service.go b/internal/adapter/proxy/olla/service.go
index 8b1792b..b58ccb1 100644
--- a/internal/adapter/proxy/olla/service.go
+++ b/internal/adapter/proxy/olla/service.go
@@ -560,6 +560,7 @@ func (s *Service) handleSuccessfulResponse(ctx context.Context, w http.ResponseW
 	rlog.Debug("round-trip success", "status", resp.StatusCode)
 
 	core.SetResponseHeaders(w, stats, endpoint)
+	core.SetStickySessionHeaders(w, r)
 
 	// Copy response headers
 	for key, values := range resp.Header {
diff --git a/internal/adapter/proxy/olla/service_retry.go b/internal/adapter/proxy/olla/service_retry.go
index 2650d31..eba0258 100644
--- a/internal/adapter/proxy/olla/service_retry.go
+++ b/internal/adapter/proxy/olla/service_retry.go
@@ -122,6 +122,7 @@ func (s *Service) proxyToSingleEndpoint(ctx context.Context, w http.ResponseWrit
 	rlog.Debug("round-trip success", "status", resp.StatusCode)
 
 	core.SetResponseHeaders(w, stats, endpoint)
+	core.SetStickySessionHeaders(w, r)
 
 	// Copy response headers
 	for key, values := range resp.Header {
diff --git a/internal/adapter/proxy/sherpa/service_retry.go b/internal/adapter/proxy/sherpa/service_retry.go
index 86ccbbe..901b320 100644
--- a/internal/adapter/proxy/sherpa/service_retry.go
+++ b/internal/adapter/proxy/sherpa/service_retry.go
@@ -115,6 +115,7 @@ func (s *Service) proxyToSingleEndpoint(ctx context.Context, w http.ResponseWrit
 	rlog.Debug("round-trip success", "status", resp.StatusCode)
 
 	core.SetResponseHeaders(w, stats, endpoint)
+	core.SetStickySessionHeaders(w, r)
 
 	// Copy response headers
 	for key, values := range resp.Header {
diff --git a/internal/app/app.go b/internal/app/app.go
index b5076ed..64a4784 100644
--- a/internal/app/app.go
+++ b/internal/app/app.go
@@ -106,6 +106,9 @@ func registerServices(manager *services.ServiceManager, cfg *config.Config, logg
 	}
 	if disc, err := registry.GetDiscovery(); err == nil {
 		proxy.SetDiscoveryService(disc)
+		// SetPurgeDeadEndpointsFn is called inside ProxyServiceWrapper.Start() after
+		// stickyWrapper is assigned, so the registration happens-before the health-checker
+		// goroutine reads it. No registration needed here.
 	}
 	if sec, err := registry.GetSecurity(); err == nil {
 		proxy.SetSecurityService(sec)
diff --git a/internal/app/handlers/application.go b/internal/app/handlers/application.go
index eb18e02..3e8b31d 100644
--- a/internal/app/handlers/application.go
+++ b/internal/app/handlers/application.go
@@ -6,6 +6,7 @@ import (
 	"net/http"
 	"time"
 
+	"github.com/thushan/olla/internal/adapter/balancer"
 	"github.com/thushan/olla/internal/adapter/converter"
 	"github.com/thushan/olla/internal/adapter/inspector"
 	"github.com/thushan/olla/internal/adapter/registry"
@@ -84,10 +85,13 @@ type Application struct {
 	profileFactory     profile.ProfileFactory
 	profileLookup      translator.ProfileLookup
 	translatorRegistry *translator.Registry
-	aliasResolver      *registry.AliasResolver
-	server             *http.Server
-	errCh              chan error
-	StartTime          time.Time
+	// stickyStatsFn is non-nil when sticky sessions are enabled. Stored as a
+	// closure so the handler layer needs no reference to the proxy service
+	// wrapper that owns the sticky balancer.
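+	// When nil, the /internal/stats/sticky endpoint reports {"enabled": false}.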
+ stickyStatsFn func() *balancer.StickyStats + aliasResolver *registry.AliasResolver + server *http.Server + errCh chan error + StartTime time.Time } // NewApplication creates a new Application instance with all required dependencies @@ -219,6 +223,12 @@ func (a *Application) GetProfileLookup() translator.ProfileLookup { return a.profileLookup } +// SetStickyStatsFn wires in the sticky session stats provider after construction. +// Called by HTTPService when sticky sessions are enabled. +func (a *Application) SetStickyStatsFn(fn func() *balancer.StickyStats) { + a.stickyStatsFn = fn +} + func (a *Application) RegisterRoutes() { a.registerRoutes() } diff --git a/internal/app/handlers/handler_proxy.go b/internal/app/handlers/handler_proxy.go index 233a4f7..297cd91 100644 --- a/internal/app/handlers/handler_proxy.go +++ b/internal/app/handlers/handler_proxy.go @@ -1,11 +1,14 @@ package handlers import ( + "bytes" "context" "fmt" + "io" "net/http" "time" + "github.com/thushan/olla/internal/adapter/balancer" "github.com/thushan/olla/internal/app/middleware" "github.com/thushan/olla/internal/core/constants" "github.com/thushan/olla/internal/core/domain" @@ -39,6 +42,13 @@ func (a *Application) proxyHandler(w http.ResponseWriter, r *http.Request) { a.analyzeRequest(ctx, r, pr) + // Sticky session key must be computed after analyzeRequest so the model + // name is available; inject into context before endpoint selection. + // The outcome pointer is stored in context; the proxy engine reads it before WriteHeader. + if a.Config.Proxy.StickySessions.Enabled { + ctx, r, _ = a.injectStickyKey(ctx, r, pr.model) + } + endpoints, err := a.getCompatibleEndpoints(ctx, pr) if err != nil { a.handleEndpointError(w, pr, err) @@ -122,6 +132,81 @@ func (a *Application) analyzeRequest(ctx context.Context, r *http.Request, pr *p pr.stats.PathResolutionMs = time.Since(pathResolutionStart).Milliseconds() } +// injectStickyKey computes the affinity key for this request and injects it into the context. +// It reads up to prefix_hash_bytes of the request body (for prefix hashing) then restores the +// body so downstream handlers see it intact. The StickyOutcome pointer lets the wrapper report +// its hit/miss/repin decision back to the handler without an extra context lookup. +func (a *Application) injectStickyKey(ctx context.Context, r *http.Request, modelName string) (context.Context, *http.Request, *balancer.StickyOutcome) { + cfg := a.Config.Proxy.StickySessions + + // Read a small prefix of the body for prefix_hash; restore it afterwards. + // We cap at cfg.PrefixHashBytes+1 to handle the case where the body is exactly + // that length without allocating an oversized buffer. + var bodySnap []byte + if r.Body != nil && r.ContentLength != 0 { + limit := cfg.PrefixHashBytes + if limit <= 0 { + limit = 512 + } + snap, readErr := io.ReadAll(io.LimitReader(r.Body, int64(limit)+1)) + // Restore any bytes already consumed, even on partial-read error, so the + // downstream proxy always sees a complete body. + if len(snap) > 0 { + r.Body = io.NopCloser(io.MultiReader(bytes.NewReader(snap), r.Body)) + } + if readErr == nil { + bodySnap = snap + } + } + + return a.injectStickyKeyWithBody(ctx, r, modelName, bodySnap) +} + +// injectStickyKeyWithBody is the core sticky key injection path when the body has +// already been buffered by the caller (e.g. the translation handler reads the full +// body to extract the model name before we reach this point). 
Passing the bytes in +// avoids a second read/restore cycle on the same reader. +func (a *Application) injectStickyKeyWithBody(ctx context.Context, r *http.Request, modelName string, body []byte) (context.Context, *http.Request, *balancer.StickyOutcome) { + cfg := a.Config.Proxy.StickySessions + stickyKey, stickySource := balancer.ComputeStickyKey(r, modelName, cfg, body) + + outcome := &balancer.StickyOutcome{} + ctx = context.WithValue(ctx, constants.ContextStickyKeyKey, stickyKey) + ctx = context.WithValue(ctx, constants.ContextStickyKeySourceKey, stickySource) + ctx = context.WithValue(ctx, constants.ContextStickyOutcomeKey, outcome) + r = r.WithContext(ctx) + + return ctx, r, outcome +} + +// setStickyResponseHeadersFromRequest reads the StickyOutcome from the request context +// and writes the sticky session headers. Used by sub-handlers that have *http.Request +// but not the outcome pointer directly. Must be called before w.WriteHeader(). +func (a *Application) setStickyResponseHeadersFromRequest(w http.ResponseWriter, r *http.Request) { + outcome, _ := r.Context().Value(constants.ContextStickyOutcomeKey).(*balancer.StickyOutcome) + a.setStickyResponseHeaders(w, r, outcome) +} + +// setStickyResponseHeaders writes sticky session outcome headers so clients can observe +// affinity routing decisions. When the client provided an explicit session ID header, +// we echo it back so stateless clients can track their own session. +func (a *Application) setStickyResponseHeaders(w http.ResponseWriter, r *http.Request, outcome *balancer.StickyOutcome) { + if outcome == nil { + return + } + if outcome.Result != "" { + w.Header().Set(constants.HeaderXOllaStickySession, outcome.Result) + } + if outcome.Source != "" && outcome.Source != "none" { + w.Header().Set(constants.HeaderXOllaStickyKeySource, outcome.Source) + } + if outcome.Source == "session_header" { + if sid := r.Header.Get(constants.HeaderXOllaSessionID); sid != "" { + w.Header().Set(constants.HeaderXOllaSessionID, sid) + } + } +} + func (a *Application) getCompatibleEndpoints(ctx context.Context, pr *proxyRequest) ([]*domain.Endpoint, error) { endpoints, err := a.discoveryService.GetHealthyEndpoints(ctx) if err != nil { diff --git a/internal/app/handlers/handler_stats_sticky.go b/internal/app/handlers/handler_stats_sticky.go new file mode 100644 index 0000000..88b0d44 --- /dev/null +++ b/internal/app/handlers/handler_stats_sticky.go @@ -0,0 +1,40 @@ +package handlers + +import ( + "encoding/json" + "net/http" + + "github.com/thushan/olla/internal/core/constants" +) + +func (a *Application) stickyStatsHandler(w http.ResponseWriter, r *http.Request) { + w.Header().Set(constants.HeaderContentType, constants.ContentTypeJSON) + w.WriteHeader(http.StatusOK) + + if a.stickyStatsFn == nil { + // Sticky sessions are disabled — return a stable JSON shape so callers + // can branch on the "enabled" field rather than on status codes. + if err := json.NewEncoder(w).Encode(struct { + Enabled bool `json:"enabled"` + }{Enabled: false}); err != nil { + a.logger.Error("Failed to encode sticky stats response", "error", err) + } + return + } + + stats := a.stickyStatsFn() + if stats == nil { + // stickyStatsFn is wired but the wrapper was not created because sticky + // sessions are disabled in config — return the same stable shape as the + // nil-function path so callers always branch on "enabled", not status codes. 
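+		// The wire shape in both cases is exactly {"enabled":false}.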
+ if err := json.NewEncoder(w).Encode(struct { + Enabled bool `json:"enabled"` + }{Enabled: false}); err != nil { + a.logger.Error("Failed to encode sticky stats response", "error", err) + } + return + } + if err := json.NewEncoder(w).Encode(stats); err != nil { + a.logger.Error("Failed to encode sticky stats response", "error", err) + } +} diff --git a/internal/app/handlers/handler_translation.go b/internal/app/handlers/handler_translation.go index 0faf935..19ccd27 100644 --- a/internal/app/handlers/handler_translation.go +++ b/internal/app/handlers/handler_translation.go @@ -257,6 +257,13 @@ func (a *Application) translationHandler(trans translator.RequestTranslator) htt // Run through proxy pipeline (inspector, security, routing) a.analyzeRequest(ctx, r, pr) + // Inject sticky session key. bodyBytes is already buffered from the model-name + // extraction above, so pass it directly to avoid a second read/restore cycle. + // The outcome pointer is stored in context; sub-handlers read it before WriteHeader. + if a.Config.Proxy.StickySessions.Enabled { + ctx, r, _ = a.injectStickyKeyWithBody(ctx, r, pr.model, bodyBytes) + } + // Get compatible endpoints for this request endpoints, err := a.getCompatibleEndpoints(ctx, pr) if err != nil { @@ -281,6 +288,7 @@ func (a *Application) translationHandler(trans translator.RequestTranslator) htt // Attempt passthrough if the translator and backends support it. // Returns true when passthrough was used and the request is complete. + // Sticky headers are written by the proxy engine before WriteHeader in this path. if a.tryPassthrough(ctx, w, r, bodyBytes, endpoints, pr, trans) { return } @@ -299,6 +307,7 @@ func (a *Application) translationHandler(trans translator.RequestTranslator) htt return } + // Sticky headers are written inside executeTranslationRequest before WriteHeader. a.executeTranslationRequest(ctx, w, r, endpoints, pr, trans, transformedReq) a.recordTranslatorMetrics(trans, pr, mode, fallbackReason) } @@ -330,7 +339,7 @@ func (a *Application) executeTranslatedNonStreamingRequest( // handle backend errors if recorder.status >= 400 { - return a.handleNonStreamingBackendError(w, recorder, openaiResp, pr, trans) + return a.handleNonStreamingBackendError(w, r, recorder, openaiResp, pr, trans) } // transform and write successful response @@ -354,6 +363,7 @@ func (a *Application) prepareProxyContext(ctx context.Context, r *http.Request, // handleNonStreamingBackendError processes backend errors and writes translated error response func (a *Application) handleNonStreamingBackendError( w http.ResponseWriter, + r *http.Request, recorder *responseRecorder, openaiResp map[string]interface{}, pr *proxyRequest, @@ -369,6 +379,7 @@ func (a *Application) handleNonStreamingBackendError( w.Header().Set(constants.HeaderContentType, constants.ContentTypeJSON) a.copyOllaHeaders(recorder, w) a.setModelHeaderIfMissing(w, pr.model) + a.setStickyResponseHeadersFromRequest(w, r) // Use translator's error formatter if available if errorWriter, ok := trans.(translator.ErrorWriter); ok { @@ -446,6 +457,8 @@ func (a *Application) writeTranslatedSuccessResponse( w.Header().Set(constants.HeaderContentType, constants.ContentTypeJSON) a.copyOllaHeaders(recorder, w) + // Write sticky headers before committing the response. 
+ a.setStickyResponseHeadersFromRequest(w, r) // Serialize and write response respBody, err := json.Marshal(targetResp) @@ -504,13 +517,15 @@ func (a *Application) executeTranslatedStreamingRequest( // handle backend errors before starting sse stream if streamRecorder.status >= 400 { - a.handleStreamingBackendError(w, pipeReader, streamRecorder, proxyErrChan, pr, trans) + a.handleStreamingBackendError(w, r, pipeReader, streamRecorder, proxyErrChan, pr, trans) return nil } // copy olla headers before stream starts a.copyOllaHeaders(streamRecorder, w) a.setModelHeaderIfMissing(w, pr.model) + // Write sticky headers before the first write to w commits the response. + a.setStickyResponseHeadersFromRequest(w, r) // transform stream (blocks until done) and wait for proxy return a.transformStreamAndWaitForProxy(ctx, pipeReader, w, r, proxyErrChan, trans) @@ -592,6 +607,7 @@ func (a *Application) handleStreamingPanic( // handleStreamingBackendError processes backend errors during streaming func (a *Application) handleStreamingBackendError( w http.ResponseWriter, + r *http.Request, pipeReader *io.PipeReader, streamRecorder *streamingResponseRecorder, proxyErrChan chan error, @@ -612,6 +628,7 @@ func (a *Application) handleStreamingBackendError( w.Header().Set(constants.HeaderContentType, constants.ContentTypeJSON) a.copyOllaHeaders(streamRecorder, w) a.setModelHeaderIfMissing(w, pr.model) + a.setStickyResponseHeadersFromRequest(w, r) // Use translator's error formatter if available if errorWriter, ok := trans.(translator.ErrorWriter); ok { diff --git a/internal/app/handlers/server_routes.go b/internal/app/handlers/server_routes.go index 60636c1..cd8d8d4 100644 --- a/internal/app/handlers/server_routes.go +++ b/internal/app/handlers/server_routes.go @@ -35,6 +35,7 @@ func (a *Application) registerRoutes() { a.routeRegistry.RegisterWithMethod("/internal/status/models", a.modelsStatusHandler, "Models status", "GET") a.routeRegistry.RegisterWithMethod("/internal/stats/models", a.modelStatsHandler, "Model statistics", "GET") a.routeRegistry.RegisterWithMethod("/internal/stats/translators", a.translatorStatsHandler, "Translator statistics", "GET") + a.routeRegistry.RegisterWithMethod("/internal/stats/sticky", a.stickyStatsHandler, "Sticky session statistics", "GET") a.routeRegistry.RegisterWithMethod("/internal/process", a.processStatsHandler, "Process status", "GET") a.routeRegistry.RegisterWithMethod("/version", a.versionHandler, "Olla version information", "GET") diff --git a/internal/app/services/discovery.go b/internal/app/services/discovery.go index 0ba36d4..d040d8e 100644 --- a/internal/app/services/discovery.go +++ b/internal/app/services/discovery.go @@ -5,6 +5,7 @@ import ( "errors" "fmt" "net/http" + "sync/atomic" "github.com/thushan/olla/internal/adapter/discovery" "github.com/thushan/olla/internal/adapter/health" @@ -30,6 +31,12 @@ type DiscoveryService struct { modelDiscovery *discovery.ModelDiscoveryService registry domain.ModelRegistry endpointRepo domain.EndpointRepository + // purgeDeadFn, when set, is called with the current routable endpoint list + // whenever a backend transitions to unhealthy. Wired by the proxy layer so sticky + // session entries for dead backends are evicted promptly rather than waiting for TTL. + // Stored atomically because SetPurgeDeadEndpointsFn is called from the main goroutine + // while the health-check goroutine may concurrently invoke the callback. 
+ purgeDeadFn atomic.Pointer[func([]*domain.Endpoint)] } // NewDiscoveryService creates a new discovery service @@ -94,6 +101,21 @@ func (s *DiscoveryService) Start(ctx context.Context) error { s.healthChecker = health.NewHTTPHealthCheckerWithDefaults(s.endpointRepo, s.logger) + // Purge sticky session entries for any backend that goes offline. The purgeFn is + // nil when sticky sessions are disabled, making this callback a cheap no-op then. + s.healthChecker.SetUnhealthyCallback(health.UnhealthyCallbackFunc(func(ctx context.Context, _ *domain.Endpoint) { + fn := s.purgeDeadFn.Load() + if fn == nil { + return + } + routable, err := s.endpointRepo.GetRoutable(ctx) + if err != nil { + s.logger.Warn("Failed to fetch routable endpoints for sticky session purge", "error", err) + return + } + (*fn)(routable) + })) + if err := s.healthChecker.StartChecking(ctx); err != nil { return fmt.Errorf("failed to start health checker: %w", err) } @@ -257,6 +279,13 @@ func (s *DiscoveryService) SetStatsService(statsService *StatsService) { s.statsService = statsService } +// SetPurgeDeadEndpointsFn registers a function to call with the current routable +// endpoint list whenever a backend transitions to unhealthy. Used by the proxy layer +// to evict sticky session entries for dead backends without waiting for TTL expiry. +func (s *DiscoveryService) SetPurgeDeadEndpointsFn(fn func([]*domain.Endpoint)) { + s.purgeDeadFn.Store(&fn) +} + // UpdateEndpointStatus updates the status of an endpoint in the repository func (s *DiscoveryService) UpdateEndpointStatus(ctx context.Context, endpoint *domain.Endpoint) error { if s.endpointRepo == nil { diff --git a/internal/app/services/http.go b/internal/app/services/http.go index 1a736b5..6aadb59 100644 --- a/internal/app/services/http.go +++ b/internal/app/services/http.go @@ -125,6 +125,11 @@ func (s *HTTPService) Start(ctx context.Context) error { } s.application = app + // Wire sticky session stats if enabled — proxySvc holds the wrapper. + if s.proxySvc != nil { + s.application.SetStickyStatsFn(s.proxySvc.StickyStats) + } + s.application.RegisterRoutes() // Wire routes with security middleware diff --git a/internal/app/services/proxy.go b/internal/app/services/proxy.go index c8b273e..d8208ae 100644 --- a/internal/app/services/proxy.go +++ b/internal/app/services/proxy.go @@ -20,9 +20,12 @@ import ( // model. It manages the creation of load balancers and proxy engines, ensuring they // receive validated endpoints from the discovery service. type ProxyServiceWrapper struct { - config *config.ProxyConfig - proxyService ports.ProxyService - loadBalancer domain.EndpointSelector + config *config.ProxyConfig + proxyService ports.ProxyService + loadBalancer domain.EndpointSelector + // stickyWrapper is non-nil when sticky sessions are enabled; held separately + // so Stop() can shut down its background goroutine. 
+ stickyWrapper *balancer.StickySessionWrapper endpointRepo domain.EndpointRepository discoveryService ports.DiscoveryService statsCollector ports.StatsCollector @@ -84,6 +87,8 @@ func (s *ProxyServiceWrapper) Start(ctx context.Context) error { } s.logger.Info("Load balancer created", "type", s.config.LoadBalancer) + s.applyStickySessions() + // Create proxy configuration proxyConfig := s.createProxyConfiguration() @@ -145,8 +150,9 @@ func (s *ProxyServiceWrapper) Start(ctx context.Context) error { func (s *ProxyServiceWrapper) Stop(ctx context.Context) error { s.logger.Info(" Stopping proxy service") - // Most proxy implementations don't need explicit cleanup - // but we provide the hook for future extensions + if s.stickyWrapper != nil { + s.stickyWrapper.Stop() + } defer func() { s.logger.ResetLine() @@ -190,6 +196,24 @@ func (s *ProxyServiceWrapper) GetLoadBalancer() (domain.EndpointSelector, error) return s.loadBalancer, nil } +// StickyStats returns a point-in-time snapshot of sticky session metrics, +// or nil when sticky sessions are disabled or not yet initialised. +func (s *ProxyServiceWrapper) StickyStats() *balancer.StickyStats { + if s.stickyWrapper == nil { + return nil + } + stats := s.stickyWrapper.Stats() + return &stats +} + +// PurgeDeadEndpoints removes sticky session entries that point to backends absent +// from the routable set. It is a no-op when sticky sessions are disabled. +func (s *ProxyServiceWrapper) PurgeDeadEndpoints(routable []*domain.Endpoint) { + if s.stickyWrapper != nil { + s.stickyWrapper.PurgeDeadEndpoints(routable) + } +} + // endpointRepositoryAdapter provides interface adaptation between the domain repository // and the discovery service interface expected by the proxy layer. type endpointRepositoryAdapter struct { @@ -228,3 +252,28 @@ func (s *ProxyServiceWrapper) SetDiscoveryService(discoveryService *DiscoverySer func (s *ProxyServiceWrapper) SetSecurityService(securityService *SecurityService) { s.securityService = securityService } + +// applyStickySessions wraps the current load balancer with KV-cache affinity routing +// when sticky sessions are enabled. It also wires the purge function into the discovery +// service immediately after the wrapper is assigned, ensuring the write of stickyWrapper +// happens-before the health-checker goroutine can observe it via PurgeDeadEndpoints. +func (s *ProxyServiceWrapper) applyStickySessions() { + if !s.config.StickySessions.Enabled { + return + } + sw := balancer.NewStickySessionWrapper(s.loadBalancer, s.config.StickySessions) + sw.Start() + s.stickyWrapper = sw + s.loadBalancer = sw + s.logger.Info("Sticky session affinity enabled", + "idle_ttl_seconds", s.config.StickySessions.IdleTTLSeconds, + "max_sessions", s.config.StickySessions.MaxSessions, + "key_sources", s.config.StickySessions.KeySources) + + // stickyWrapper is now initialised — register the purge hook so that sticky + // session entries for dead backends are evicted on health-check failure rather + // than waiting for the session TTL to expire. 
+ if s.discoverySvc != nil { + s.discoverySvc.SetPurgeDeadEndpointsFn(s.PurgeDeadEndpoints) + } +} diff --git a/internal/config/config.go b/internal/config/config.go index 4f44254..3d86755 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -74,6 +74,13 @@ func DefaultConfig() *Config { ReadTimeout: 120 * time.Second, MaxRetries: 3, RetryBackoff: 500 * time.Millisecond, + StickySessions: StickySessionConfig{ + Enabled: false, + IdleTTLSeconds: 600, + MaxSessions: 10000, + KeySources: []string{"session_header", "prefix_hash", "auth_header"}, + PrefixHashBytes: 512, + }, }, Discovery: DiscoveryConfig{ Type: DefaultDiscoveryType, @@ -353,6 +360,11 @@ func applyEnvOverrides(config *Config) { if val := os.Getenv("OLLA_PROXY_PROFILE"); val != "" { config.Proxy.Profile = val } + if val := os.Getenv("OLLA_PROXY_STICKY_SESSIONS_ENABLED"); val != "" { + if enabled, err := strconv.ParseBool(val); err == nil { + config.Proxy.StickySessions.Enabled = enabled + } + } if val := os.Getenv("OLLA_LOGGING_LEVEL"); val != "" { config.Logging.Level = val } diff --git a/internal/config/types.go b/internal/config/types.go index 902f2bf..6940793 100644 --- a/internal/config/types.go +++ b/internal/config/types.go @@ -73,12 +73,28 @@ type ServerRateLimits struct { TrustProxyHeaders bool `yaml:"trust_proxy_headers"` } +// StickySessionConfig controls KV-cache affinity routing. +// When enabled, the balancer wrapper remembers which backend handled a conversation +// and routes subsequent turns to the same backend, maximising KV-cache reuse. +type StickySessionConfig struct { + // KeySources lists the ordered cascade of strategies used to identify a conversation. + // The first match wins: session_header, prefix_hash, auth_header, ip. + KeySources []string `yaml:"key_sources"` + MaxSessions uint64 `yaml:"max_sessions"` + IdleTTLSeconds int `yaml:"idle_ttl_seconds"` + // PrefixHashBytes is how many leading bytes of the messages field to hash. + // Larger values increase hash uniqueness at the cost of slightly more CPU per request. 
+ PrefixHashBytes int `yaml:"prefix_hash_bytes"` + Enabled bool `yaml:"enabled"` +} + // ProxyConfig holds proxy-specific configuration type ProxyConfig struct { ProfileFilter *domain.FilterConfig `yaml:"profile_filter,omitempty"` Engine string `yaml:"engine"` LoadBalancer string `yaml:"load_balancer"` Profile string `yaml:"profile"` + StickySessions StickySessionConfig `yaml:"sticky_sessions"` ConnectionTimeout time.Duration `yaml:"connection_timeout"` ResponseTimeout time.Duration `yaml:"response_timeout"` ReadTimeout time.Duration `yaml:"read_timeout"` diff --git a/internal/core/constants/content.go b/internal/core/constants/content.go index 9031e0a..61aa953 100644 --- a/internal/core/constants/content.go +++ b/internal/core/constants/content.go @@ -107,4 +107,9 @@ const ( HeaderXOllaRoutingDecision = "X-Olla-Routing-Decision" HeaderXOllaRoutingReason = "X-Olla-Routing-Reason" HeaderXOllaMode = "X-Olla-Mode" + + // Sticky session headers + HeaderXOllaSessionID = "X-Olla-Session-ID" // client-supplied or echoed session identifier + HeaderXOllaStickySession = "X-Olla-Sticky-Session" // hit | miss | repin | disabled + HeaderXOllaStickyKeySource = "X-Olla-Sticky-Key-Source" // session_header | prefix_hash | auth_header | ip | none ) diff --git a/internal/core/constants/context.go b/internal/core/constants/context.go index 9fd4ca0..ad85523 100644 --- a/internal/core/constants/context.go +++ b/internal/core/constants/context.go @@ -1,5 +1,8 @@ package constants +// contextKey is a private type for context keys to prevent collisions with other packages. +type contextKey string + const ( ContextRoutePrefixKey = "route_prefix" // we inject this into the context to allow stripping prefixes for proxy calls ContextRequestIdKey = "request_id" // generataed each proxy_handler request for the request ID @@ -8,6 +11,12 @@ const ( ContextKeyStream = "stream" // indicates whether the response should be streamed or buffered ContextProviderTypeKey = "provider_type" // the provider type for the request, used for routing and load balancing + // Sticky session context keys — set by the handler before balancer selection + // and read back after to surface affinity decisions in response headers. + ContextStickyKeyKey = contextKey("sticky-key") // computed affinity key for this request + ContextStickyKeySourceKey = contextKey("sticky-key-source") // which source produced the key + ContextStickyOutcomeKey = contextKey("sticky-outcome") // *StickyOutcome written by the wrapper + // ContextModelAliasMapKey stores a map[string]string of endpoint URL → actual model name // when a model alias is resolved, allowing the proxy to rewrite the model name in the // request body to match what the selected backend expects diff --git a/internal/core/domain/routing.go b/internal/core/domain/routing.go index 03e8127..4918cf5 100644 --- a/internal/core/domain/routing.go +++ b/internal/core/domain/routing.go @@ -80,3 +80,15 @@ func (rp *RequestProfile) AddSupportedProfile(profileType string) { func (rp *RequestProfile) SetInspectionMeta(key string, value interface{}) { rp.InspectionMeta.Store(key, value) } + +// StickyOutcome carries the result of a sticky session selection back to the +// handler layer via context. The handler allocates it, passes it in context, +// and the proxy engine reads it to write response headers before WriteHeader. +// Defined here (not in adapter/balancer) so that adapter/proxy/core can read it +// without creating an import cycle. 
+type StickyOutcome struct { + // Result is "hit", "miss", "repin", or "disabled". + Result string + // Source is which key source produced the affinity key. + Source string +} diff --git a/makefile b/makefile index 4e2c87d..9dd6324 100644 --- a/makefile +++ b/makefile @@ -18,7 +18,7 @@ LDFLAGS := -ldflags "\ -X '$(PKG).Tool=$(TOOL)' \ -X '$(PKG).User=$(USER)'" -.PHONY: run clean build test test-verbose test-short test-race test-cover bench version install-deps check-deps vet test-script-integration +.PHONY: run clean build test test-verbose test-short test-race test-cover bench version install-deps check-deps vet test-script-integration test-script-sticky # Build the application with version info build: @@ -330,6 +330,11 @@ test-script-integration: @echo "Running integration test scripts..." @cd test/scripts && python integration/test-integration.py $(ARGS) +## test-script-sticky: Run sticky session test scripts (requires running Olla instance) +test-script-sticky: + @echo "Running sticky session test scripts..." + @cd test/scripts && python sticky/test-sticky-sessions.py $(ARGS) + # Show help help: @echo "Available targets:" @@ -373,4 +378,5 @@ help: @echo " check-deps - Check installed tool versions against requirements" @echo " ci - Run full CI pipeline locally" @echo " test-script-integration - Run integration test scripts (requires running Olla)" + @echo " test-script-sticky - Run sticky session test scripts (requires running Olla)" @echo " help - Show this help" \ No newline at end of file diff --git a/readme.md b/readme.md index 04ab691..6fc9f3c 100644 --- a/readme.md +++ b/readme.md @@ -43,6 +43,7 @@ For Large GPU deployments, Enterprise & Data-Centre use, see [TensorFoundry Foun ## Key Features - **🔄 Smart Load Balancing**: [Priority-based routing](https://thushan.github.io/olla/concepts/load-balancing/) with automatic failover and connection retry +- **📌 Sticky Sessions**: [KV-cache-aware affinity routing](https://thushan.github.io/olla/concepts/sticky-sessions/) that pins multi-turn conversations to the same backend - **🔍 Smart Model Unification**: [Per-provider unification + OpenAI-compatible cross-provider routing](https://thushan.github.io/olla/concepts/model-unification/) - **⚡ Dual Proxy Engines**: [Sherpa (simple) and Olla (high-performance)](https://thushan.github.io/olla/concepts/proxy-engines/) - **🎯 Advanced Filtering**: [Profile and model filtering](https://thushan.github.io/olla/configuration/filters/) with glob patterns for precise control diff --git a/test/scripts/integration/test-integration.py b/test/scripts/integration/test-integration.py index 029a104..9b56dba 100644 --- a/test/scripts/integration/test-integration.py +++ b/test/scripts/integration/test-integration.py @@ -70,13 +70,15 @@ def __init__(self, name: str, passed: bool, detail: str = "", phase: str = ""): class IntegrationTester: def __init__(self, base_url: str, timeout: int, verbose: bool, - skip_streaming: bool, skip_anthropic: bool, skip_providers: bool): + skip_streaming: bool, skip_anthropic: bool, skip_providers: bool, + skip_sticky: bool = False): self.base_url = base_url self.timeout = timeout self.verbose = verbose self.skip_streaming = skip_streaming self.skip_anthropic = skip_anthropic self.skip_providers = skip_providers + self.skip_sticky = skip_sticky self.results: List[TestResult] = [] self.endpoints: List[Dict] = [] self.models: List[Dict] = [] @@ -1229,7 +1231,278 @@ def phase_error_handling(self): self._print_result("Missing model field handled", ok, detail) 
self.record(f"{phase}/missing-model", ok, detail, phase) - # -- Phase 10: Summary ---------------------------------------------------- + # -- Phase 10: Sticky Session Headers ------------------------------------- + + def phase_sticky_sessions(self): + self._phase_header(10, "Sticky Session Headers") + phase = "sticky" + + if self.skip_sticky: + self.pcolor(GREY, " [SKIP] Sticky session tests skipped via --skip-sticky") + return + + if not self.selected_model: + self.pcolor(GREY, " [SKIP] No model available for sticky session tests") + return + + # Probe to determine whether sticky sessions are enabled on this instance. + probe = self._post("/olla/proxy/v1/chat/completions", { + "model": self.selected_model, + "messages": [{"role": "user", "content": "ping"}], + "max_tokens": 1, + "stream": False, + }) + if probe is None or probe.status_code != 200: + self.pcolor(GREY, " [SKIP] Probe request failed — cannot determine sticky session state") + return + + sticky_state = probe.headers.get("X-Olla-Sticky-Session", "") + if sticky_state == "disabled" or sticky_state == "": + self.pcolor(YELLOW, " [SKIP] Sticky sessions are disabled on this Olla instance") + self.pcolor(GREY, " Enable with proxy.sticky_sessions.enabled: true in config.yaml") + return + + self.pcolor(GREEN, f" Sticky sessions active (probe returned: {sticky_state})") + + # -- Test 1: explicit session ID creates a pin ----------------------- + session_id = f"integration-test-{int(time.time())}" + headers_with_session = {"X-Olla-Session-ID": session_id} + + body = { + "model": self.selected_model, + "messages": [{"role": "user", "content": "Say one word"}], + "max_tokens": 1, + "stream": False, + } + + r1 = self._post("/olla/proxy/v1/chat/completions", body, headers=headers_with_session) + ok = False + detail = "" + endpoint_pinned = "" + + if r1 is not None and r1.status_code == 200: + sticky1 = r1.headers.get("X-Olla-Sticky-Session", "") + endpoint_pinned = r1.headers.get("X-Olla-Endpoint", "") + ok = sticky1 == "miss" + if ok: + detail = f"first request sticky=miss, endpoint={endpoint_pinned}" + else: + detail = f"expected sticky=miss, got '{sticky1}'" + else: + detail = self._error_detail(r1) + + self._print_result("Session ID first request returns miss", ok, detail) + self.record(f"{phase}/session-first-miss", ok, detail, phase) + + # -- Test 2: second request with same session ID returns hit --------- + r2 = self._post("/olla/proxy/v1/chat/completions", body, headers=headers_with_session) + ok = False + detail = "" + + if r2 is not None and r2.status_code == 200: + sticky2 = r2.headers.get("X-Olla-Sticky-Session", "") + endpoint2 = r2.headers.get("X-Olla-Endpoint", "") + ok = sticky2 == "hit" + if ok: + same_ep = endpoint2 == endpoint_pinned if endpoint_pinned else True + detail = f"sticky=hit, endpoint={endpoint2}" + if endpoint_pinned and not same_ep: + ok = False + detail = f"sticky=hit but endpoint changed: {endpoint_pinned} -> {endpoint2}" + else: + detail = f"expected sticky=hit, got '{sticky2}'" + else: + detail = self._error_detail(r2) + + self._print_result("Session ID second request returns hit", ok, detail) + self.record(f"{phase}/session-second-hit", ok, detail, phase) + + # -- Test 3: session ID echoed back in response ---------------------- + ok = False + detail = "" + if r1 is not None and r1.status_code == 200: + echoed = r1.headers.get("X-Olla-Session-ID", "") + ok = echoed == session_id + detail = (f"echoed={echoed}" if echoed else "X-Olla-Session-ID absent in response") + else: + detail = 
self._error_detail(r1) + + self._print_result("Session ID echoed back in response header", ok, detail) + self.record(f"{phase}/session-id-echoed", ok, detail, phase) + + # -- Test 4: key source is a valid value ----------------------------- + valid_sources = {"session_header", "prefix_hash", "auth_header", "ip", "none"} + ok = False + detail = "" + if r1 is not None and r1.status_code == 200: + source = r1.headers.get("X-Olla-Sticky-Key-Source", "") + ok = source in valid_sources + detail = f"X-Olla-Sticky-Key-Source={source}" if source else "header absent" + if source and not ok: + detail = f"unknown key source: '{source}'" + else: + detail = self._error_detail(r1) + + self._print_result("Key source is a known valid value", ok, detail) + self.record(f"{phase}/key-source-valid", ok, detail, phase) + + # -- Test 5: prefix hash routing — same prompt hits same backend ----- + # Use a deterministic, long system prompt so prefix_hash fires. + long_prompt = ( + "You are a helpful assistant for integration testing of Olla sticky session routing. " + "You specialise in infrastructure, load balancers, and KV cache behaviour. " + "Always respond concisely. This prompt is intentionally long to exercise " + "the prefix hashing logic which operates on the first 512 bytes of the " + "messages JSON payload. Filler text: " + ("a" * 100) + ) + hash_body = { + "model": self.selected_model, + "messages": [ + {"role": "system", "content": long_prompt}, + {"role": "user", "content": "Hello"}, + ], + "max_tokens": 1, + "stream": False, + } + + rh1 = self._post("/olla/proxy/v1/chat/completions", hash_body) + rh2 = self._post("/olla/proxy/v1/chat/completions", hash_body) + ok = False + detail = "" + + if rh1 is not None and rh2 is not None and rh1.status_code == 200 and rh2.status_code == 200: + source_h1 = rh1.headers.get("X-Olla-Sticky-Key-Source", "") + source_h2 = rh2.headers.get("X-Olla-Sticky-Key-Source", "") + sticky_h2 = rh2.headers.get("X-Olla-Sticky-Session", "") + ep_h1 = rh1.headers.get("X-Olla-Endpoint", "") + ep_h2 = rh2.headers.get("X-Olla-Endpoint", "") + + # Both must use prefix_hash as key source and the second should be a hit. 
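+            # Only the second response's sticky result is asserted; depending on
+            # earlier traffic with the same prompt, rh1 may itself be a hit.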
+ source_ok = source_h1 == "prefix_hash" and source_h2 == "prefix_hash" + hit_ok = sticky_h2 == "hit" + + ok = source_ok and hit_ok + detail = ( + f"source1={source_h1}, source2={source_h2}, " + f"sticky2={sticky_h2}, ep1={ep_h1}, ep2={ep_h2}" + ) + if not source_ok: + detail = f"expected prefix_hash key source, got source1={source_h1} source2={source_h2}" + elif not hit_ok: + detail = f"expected sticky=hit on second request, got '{sticky_h2}'" + else: + detail = self._error_detail(rh1 or rh2) + + self._print_result("Same prefix hash routes to same backend", ok, detail) + self.record(f"{phase}/prefix-hash-hits", ok, detail, phase) + + # -- Test 6: two independent sessions pin independently -------------- + session_x = f"integration-x-{int(time.time())}" + session_y = f"integration-y-{int(time.time())}" + + # Establish both sessions + self._post("/olla/proxy/v1/chat/completions", body, + headers={"X-Olla-Session-ID": session_x}) + self._post("/olla/proxy/v1/chat/completions", body, + headers={"X-Olla-Session-ID": session_y}) + + # Follow-up for each session + rx = self._post("/olla/proxy/v1/chat/completions", body, + headers={"X-Olla-Session-ID": session_x}) + ry = self._post("/olla/proxy/v1/chat/completions", body, + headers={"X-Olla-Session-ID": session_y}) + + ok = False + detail = "" + + if (rx is not None and ry is not None + and rx.status_code == 200 and ry.status_code == 200): + sticky_x = rx.headers.get("X-Olla-Sticky-Session", "") + sticky_y = ry.headers.get("X-Olla-Sticky-Session", "") + source_x = rx.headers.get("X-Olla-Sticky-Key-Source", "") + source_y = ry.headers.get("X-Olla-Sticky-Key-Source", "") + + ok = (sticky_x == "hit" and sticky_y == "hit" + and source_x == "session_header" and source_y == "session_header") + detail = ( + f"session_x: sticky={sticky_x} source={source_x}, " + f"session_y: sticky={sticky_y} source={source_y}" + ) + else: + detail = self._error_detail(rx or ry) + + self._print_result("Two independent sessions pin independently", ok, detail) + self.record(f"{phase}/sessions-independent", ok, detail, phase) + + # -- Test 7: model scoping — same session ID, different models pin separately + model_ids = [m.get("id", m.get("name", "")) for m in self.models + if m.get("id", m.get("name", "")) and "embed" not in m.get("id", "").lower()] + if len(model_ids) >= 2: + model_a = model_ids[0] + model_b = model_ids[1] + scope_sid = f"integration-scope-{int(time.time())}" + scope_headers = {"X-Olla-Session-ID": scope_sid} + + # Establish a pin for model_a + self._post("/olla/proxy/v1/chat/completions", + {"model": model_a, "messages": [{"role": "user", "content": "hi"}], "max_tokens": 1}, + headers=scope_headers) + # model_b with the same session ID must get its own independent pin (miss) + rb_first = self._post("/olla/proxy/v1/chat/completions", + {"model": model_b, "messages": [{"role": "user", "content": "hi"}], "max_tokens": 1}, + headers=scope_headers) + # model_a follow-up must still hit + ra_second = self._post("/olla/proxy/v1/chat/completions", + {"model": model_a, "messages": [{"role": "user", "content": "hi"}], "max_tokens": 1}, + headers=scope_headers) + + ok = False + detail = "" + if (rb_first is not None and ra_second is not None + and rb_first.status_code == 200 and ra_second.status_code == 200): + rb_sticky = rb_first.headers.get("X-Olla-Sticky-Session", "") + ra_sticky = ra_second.headers.get("X-Olla-Sticky-Session", "") + # model_b first request should be a miss (new key because model differs) + # model_a second request should be a hit (existing pin) 
+ ok = rb_sticky == "miss" and ra_sticky == "hit" + detail = f"model_b first={rb_sticky} (want miss), model_a second={ra_sticky} (want hit)" + else: + detail = self._error_detail(rb_first or ra_second) + + self._print_result("Model-scoped keys: same session ID, different models pin separately", ok, detail) + self.record(f"{phase}/model-scoping", ok, detail, phase) + else: + self.pcolor(GREY, f" [SKIP] Model scoping test skipped — only {len(model_ids)} model(s) available") + + # -- Test 8: /internal/stats/sticky endpoint -------------------------- + ok = False + detail = "" + try: + sr = requests.get(f"{self.base_url}/internal/stats/sticky", timeout=self.timeout) + if sr.status_code == 200: + data = sr.json() + has_enabled = "enabled" in data + if data.get("enabled"): + # When enabled, expect the ttlcache metric fields to be present + has_fields = all(k in data for k in ("hits", "misses", "evictions", "active_sessions")) + ok = has_enabled and has_fields + detail = (f"enabled=true hits={data.get('hits')} misses={data.get('misses')} " + f"active={data.get('active_sessions')}") + if not has_fields: + detail = f"missing metric fields in response: {list(data.keys())}" + else: + # Disabled — endpoint still returns 200 with enabled:false + ok = has_enabled + detail = "enabled=false (sticky sessions disabled on this instance)" + else: + detail = f"HTTP {sr.status_code}" + except Exception as e: + detail = str(e) + + self._print_result("Sticky stats endpoint returns 200 with expected fields", ok, detail) + self.record(f"{phase}/stats-endpoint", ok, detail, phase) + + # -- Phase 11: Summary ---------------------------------------------------- def print_summary(self) -> bool: print() @@ -1252,6 +1525,7 @@ def print_summary(self) -> bool: "providers": "Provider Routes", "headers": "Response Headers", "errors": "Error Handling", + "sticky": "Sticky Sessions", } for phase_key, label in phase_labels.items(): @@ -1302,6 +1576,8 @@ def main(): help="Skip Anthropic translator tests") parser.add_argument("--skip-providers", action="store_true", help="Skip provider-specific route tests") + parser.add_argument("--skip-sticky", action="store_true", + help="Skip sticky session tests") parser.add_argument("--verbose", action="store_true", help="Show response bodies") @@ -1314,6 +1590,7 @@ def main(): skip_streaming=args.skip_streaming, skip_anthropic=args.skip_anthropic, skip_providers=args.skip_providers, + skip_sticky=args.skip_sticky, ) tester.print_header() @@ -1325,7 +1602,7 @@ def main(): if not tester.discover(): sys.exit(1) - # Phase 2-9: Test phases + # Phase 2-10: Test phases tester.phase_internal_endpoints() tester.phase_unified_models() tester.phase_proxy_endpoints() @@ -1334,8 +1611,9 @@ def main(): tester.phase_provider_routes() tester.phase_response_headers() tester.phase_error_handling() + tester.phase_sticky_sessions() - # Phase 10: Summary + # Phase 11: Summary all_pass = tester.print_summary() sys.exit(0 if all_pass else 1) diff --git a/test/scripts/sticky/README.md b/test/scripts/sticky/README.md new file mode 100644 index 0000000..e1dda29 --- /dev/null +++ b/test/scripts/sticky/README.md @@ -0,0 +1,136 @@ +# Sticky Session Test Scripts + +Validates Olla's KV cache affinity routing (sticky sessions). Tests session pinning via explicit +session IDs, prefix hash routing for identical system prompts, session independence, key source +reporting, and disabled-state behaviour. 
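+
+For a quick manual check outside these scripts, a short snippet along these lines
+exercises the same headers the suite asserts on. The base URL is Olla's default;
+the model name and session ID are placeholders (substitute any loaded chat model):
+
+```python
+"""Minimal sticky-session smoke check (a sketch, not part of the suite)."""
+import requests
+
+BASE = "http://localhost:40114"  # default Olla address
+MODEL = "llama3.2"               # placeholder: substitute a loaded model
+
+body = {
+    "model": MODEL,
+    "messages": [{"role": "user", "content": "ping"}],
+    "max_tokens": 1,
+}
+headers = {"X-Olla-Session-ID": "smoke-1"}  # placeholder session ID
+
+for turn in (1, 2):
+    r = requests.post(f"{BASE}/olla/proxy/v1/chat/completions",
+                      json=body, headers=headers, timeout=30)
+    # Expect miss on turn 1 and hit on turn 2, with the same X-Olla-Endpoint.
+    print(turn,
+          r.headers.get("X-Olla-Sticky-Session"),
+          r.headers.get("X-Olla-Sticky-Key-Source"),
+          r.headers.get("X-Olla-Endpoint"))
+```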
+ +## What's Being Tested + +| Behaviour | Header | Expected values | +|-----------|--------|----------------| +| Session state | `X-Olla-Sticky-Session` | `hit`, `miss`, `repin`, `disabled` | +| Key derivation | `X-Olla-Sticky-Key-Source` | `session_header`, `prefix_hash`, `auth_header`, `ip`, `none` | +| Session echo | `X-Olla-Session-ID` | Echoed back when provided in request | + +### Test Cases + +1. **Header presence** — `X-Olla-Sticky-Session` is always set to a known value. +2. **Key source validity** — `X-Olla-Sticky-Key-Source` is always one of the known valid values (or absent when disabled). +3. **Session ID pin** — First request with `X-Olla-Session-ID` returns `miss`; second with same ID returns `hit` on the same endpoint. +4. **Session ID echo** — The `X-Olla-Session-ID` request header is echoed in the response. +5. **Prefix hash hit** — Identical system prompts (same first 512 bytes) produce a `hit` on the second request. +6. **Session independence** — Two different session IDs each pin to their own endpoint without interfering. +7. **Multi-backend (2+ backends)** — Sessions pinned across multiple backends do not cross-contaminate. + +Affinity tests (3–7) are automatically skipped with an advisory message when `sticky_sessions.enabled: false`. + +## Prerequisites + +- Olla running and reachable (default `http://localhost:40114`) +- Python 3.8+ with `requests` installed (`pip install -r requirements.txt` from `test/scripts/`) +- At least one healthy backend with a loaded model +- `sticky_sessions.enabled: true` in `config.yaml` for affinity tests + +## Enabling Sticky Sessions + +Add the following to `config.yaml` under `proxy:`: + +```yaml +proxy: + sticky_sessions: + enabled: true + idle_ttl_seconds: 600 + max_sessions: 10000 + key_sources: + - "session_header" + - "prefix_hash" + - "auth_header" + prefix_hash_bytes: 512 +``` + +## Usage + +```bash +# Run via Makefile (recommended) +make test-script-sticky + +# Custom URL +make test-script-sticky ARGS="--url http://localhost:11435" + +# Custom URL and model +make test-script-sticky ARGS="--url http://localhost:11435 --model llama3.2" + +# Run directly +cd test/scripts +python sticky/test-sticky-sessions.py +python sticky/test-sticky-sessions.py --url http://localhost:11435 --model phi4:latest +python sticky/test-sticky-sessions.py --verbose +python sticky/test-sticky-sessions.py --skip-stats +``` + +## Expected Output + +``` +======================================================================== + Olla Sticky Session Test + Validates KV cache affinity routing behaviour +======================================================================== + +Checking Olla availability... +[OK] Olla is reachable + +Discovering endpoints... +[OK] Found 2 endpoint(s) +... + +Detecting sticky session status... 
+ [OK] Sticky sessions active (probe returned: miss) + +Header Validation +======================================================================== + X-Olla-Sticky-Session present: [PASS] value='miss' + X-Olla-Sticky-Key-Source valid: [PASS] source='prefix_hash' + +Affinity Tests +======================================================================== + Session ID: miss → hit: [PASS] endpoint=local-ollama + Session ID echoed in response: [PASS] echoed='test-sticky-1234567890' + Prefix hash: same prompt → hit: [PASS] source=prefix_hash endpoint=local-ollama + Two sessions pin independently: [PASS] both sessions independently pinned + +======================================================================== + Results Summary +======================================================================== + + Test Result + ------------------------------------------------ ------ + header/sticky-session-present PASS + header/key-source-valid PASS + affinity/session-id-miss-then-hit PASS + affinity/session-id-echoed PASS + affinity/prefix-hash-hit PASS + affinity/sessions-independent PASS + + Total: 6 | Passed: 6 | Failed: 0 + + All tests passed. +``` + +## Response Headers Reference + +| Header | When Set | Values | +|--------|----------|--------| +| `X-Olla-Sticky-Session` | All proxy requests | `hit` — pinned endpoint served the request | +| | | `miss` — no existing pin, new pin created | +| | | `repin` — previous endpoint unavailable, repinned | +| | | `disabled` — feature is off | +| `X-Olla-Sticky-Key-Source` | When enabled | `session_header`, `prefix_hash`, `auth_header`, `ip`, `none` | +| `X-Olla-Session-ID` | When provided in request | Echoed back unchanged | + +## Exit Codes + +| Code | Meaning | +|------|---------| +| `0` | All tests passed | +| `1` | One or more tests failed, or Olla unreachable | +| `130` | Interrupted by Ctrl+C | diff --git a/test/scripts/sticky/test-sticky-sessions.py b/test/scripts/sticky/test-sticky-sessions.py new file mode 100644 index 0000000..8fd94bb --- /dev/null +++ b/test/scripts/sticky/test-sticky-sessions.py @@ -0,0 +1,715 @@ +#!/usr/bin/env python3 +# -*- coding: utf-8 -*- +""" +Olla Sticky Session Test Script + +Validates that sticky session (KV cache affinity) routing works correctly. +Sends repeated requests with identical content or explicit session IDs and +verifies that the same backend is selected on subsequent turns. + +Auto-discovers available backends and models, then runs a test matrix covering +session ID pinning, prefix-hash affinity, session independence, and header +validation. 
+""" + +import sys +import json +import time +import argparse +import requests +import os +from typing import Dict, List, Optional, Any + +# Fix Windows console encoding for Unicode +if sys.platform == 'win32': + import io + sys.stdout = io.TextIOWrapper(sys.stdout.buffer, encoding='utf-8') + sys.stderr = io.TextIOWrapper(sys.stderr.buffer, encoding='utf-8') + os.environ['PYTHONIOENCODING'] = 'utf-8' + +# ANSI colour codes +RED = '\033[0;31m' +GREEN = '\033[0;32m' +YELLOW = '\033[1;33m' +BLUE = '\033[0;34m' +PURPLE = '\033[0;35m' +CYAN = '\033[0;36m' +WHITE = '\033[1;37m' +GREY = '\033[0;37m' +RESET = '\033[0m' +BOLD = '\033[1m' + +# Configuration +TARGET_URL = "http://localhost:40114" +DEFAULT_TIMEOUT = 30 + +# Valid values for sticky session response headers +VALID_STICKY_RESULTS = {"hit", "miss", "repin", "disabled"} +VALID_KEY_SOURCES = {"session_header", "prefix_hash", "auth_header", "ip", "none"} + +# Long system prompt used for prefix-hash tests — must be > 512 bytes so the +# hash captures real content rather than colliding on short strings. +LONG_SYSTEM_PROMPT = ( + "You are a highly knowledgeable Go programming expert. " + "You specialise in concurrent systems, high-performance proxies, and load " + "balancing algorithms. You write clean, idiomatic Go following effective Go " + "conventions. You always explain the why behind design decisions, not just " + "the what. You are familiar with the Olla project and its hexagonal " + "architecture. When reviewing code you consider correctness, performance, " + "and maintainability equally." +) + +ALTERNATIVE_SYSTEM_PROMPT = ( + "You are a Python data science expert specialising in pandas, NumPy, and " + "scikit-learn. You focus on efficient data pipelines, reproducible research, " + "and clear visualisations. Your answers are concise and include runnable " + "code examples." 
+) + + +class BackendInfo: + """Discovered backend with its health, type, and selected model.""" + __slots__ = ("name", "backend_type", "status", "models", "selected_model") + + def __init__(self, name: str, backend_type: str, status: str): + self.name = name + self.backend_type = backend_type + self.status = status + self.models: List[str] = [] + self.selected_model: Optional[str] = None + + +class TestResult: + """Outcome of a single test case.""" + __slots__ = ("name", "passed", "detail") + + def __init__(self, name: str, passed: bool, detail: str = ""): + self.name = name + self.passed = passed + self.detail = detail + + +class StickySessionTester: + def __init__(self, base_url: str, timeout: int, verbose: bool, model: Optional[str]): + self.base_url = base_url + self.timeout = timeout + self.verbose = verbose + self.forced_model = model + self.backends: List[BackendInfo] = [] + self.results: List[TestResult] = [] + self.sticky_enabled: Optional[bool] = None # detected from first request + + # ── Helpers ───────────────────────────────────────────────────────── + + def pcolor(self, color: str, msg: str, end: str = '\n'): + print(f"{color}{msg}{RESET}", end=end) + sys.stdout.flush() + + def print_header(self): + self.pcolor(PURPLE, "=" * 72) + self.pcolor(PURPLE, f" {CYAN}Olla Sticky Session Test{RESET}") + self.pcolor(PURPLE, f" {GREY}Validates KV cache affinity routing behaviour{RESET}") + self.pcolor(PURPLE, "=" * 72) + print() + + def record(self, name: str, passed: bool, detail: str = "") -> bool: + self.results.append(TestResult(name, passed, detail)) + return passed + + def _chat_headers(self, session_id: Optional[str] = None) -> Dict[str, str]: + h = {"Content-Type": "application/json"} + if session_id: + h["X-Olla-Session-ID"] = session_id + return h + + def _chat_body(self, model: str, system: Optional[str] = None, + user_turn: str = "Reply with one word: yes") -> Dict[str, Any]: + messages: List[Dict[str, str]] = [] + if system: + messages.append({"role": "system", "content": system}) + messages.append({"role": "user", "content": user_turn}) + return { + "model": model, + "messages": messages, + "max_tokens": 5, + } + + def _post(self, body: Dict[str, Any], + headers: Optional[Dict[str, str]] = None) -> Optional[requests.Response]: + """POST to the OpenAI-compatible proxy endpoint. 
Returns None on network error.""" + try: + return requests.post( + f"{self.base_url}/olla/proxy/v1/chat/completions", + headers=headers or {"Content-Type": "application/json"}, + json=body, + timeout=self.timeout, + ) + except Exception as e: + self.pcolor(RED, f" [ERROR] Request failed: {e}") + return None + + # ── Discovery ─────────────────────────────────────────────────────── + + def check_health(self) -> bool: + self.pcolor(YELLOW, "Checking Olla availability...") + try: + r = requests.get(f"{self.base_url}/internal/health", timeout=5) + if r.status_code == 200: + self.pcolor(GREEN, "[OK] Olla is reachable") + return True + except Exception: + pass + self.pcolor(RED, f"[FAIL] Cannot reach Olla at {self.base_url}") + return False + + def discover(self) -> bool: + """Discover endpoints and pick a model for testing.""" + self.pcolor(YELLOW, "Discovering endpoints...") + try: + r = requests.get(f"{self.base_url}/internal/status/endpoints", timeout=self.timeout) + r.raise_for_status() + endpoints = r.json().get("endpoints", []) + except Exception as e: + self.pcolor(RED, f"[FAIL] Endpoint discovery failed: {e}") + return False + + ep_map: Dict[str, BackendInfo] = {} + for ep in endpoints: + name = ep.get("name", "unknown") + btype = ep.get("type", "unknown") + status = ep.get("status", "unknown") + bi = BackendInfo(name, btype, status) + ep_map[name] = bi + self.backends.append(bi) + + if not self.backends: + self.pcolor(RED, "[FAIL] No endpoints discovered") + return False + + self.pcolor(GREEN, f"[OK] Found {len(self.backends)} endpoint(s)") + + # Fetch models + self.pcolor(YELLOW, "Discovering models...") + try: + r = requests.get(f"{self.base_url}/internal/status/models", timeout=self.timeout) + r.raise_for_status() + data = r.json() + recent = data.get("recent_models", []) + except Exception as e: + self.pcolor(RED, f"[FAIL] Model discovery failed: {e}") + return False + + # Build URL → endpoint name reverse map + url_to_name: Dict[str, str] = {} + for ep in endpoints: + url = ep.get("url", "") + name = ep.get("name", "") + if url and name: + url_to_name[url.rstrip("/")] = name + + type_to_backends: Dict[str, List[str]] = {} + for ep in endpoints: + btype = ep.get("type", "") + name = ep.get("name", "") + if btype and name: + type_to_backends.setdefault(btype, []).append(name) + + for m in recent: + model_name = m.get("name", "") + model_endpoints = m.get("endpoints", []) + model_type = m.get("type", "") + matched = False + + for ep_url in model_endpoints: + ep_name = url_to_name.get(ep_url.rstrip("/")) + if ep_name and ep_name in ep_map: + ep_map[ep_name].models.append(model_name) + matched = True + + if not matched and model_type and model_type in type_to_backends: + for ep_name in type_to_backends[model_type]: + if ep_name in ep_map: + ep_map[ep_name].models.append(model_name) + + # Select one non-embedding model per backend + for bi in self.backends: + if self.forced_model: + bi.selected_model = self.forced_model + else: + candidates = [m for m in bi.models if "embed" not in m.lower()] + if candidates: + bi.selected_model = candidates[0] + + self._print_discovery_summary() + return True + + def _print_discovery_summary(self): + print() + self.pcolor(WHITE, "Configuration Summary") + self.pcolor(PURPLE, "-" * 72) + header = f" {'Backend':<20} {'Type':<18} {'Status':<10} {'Model'}" + self.pcolor(GREY, header) + self.pcolor(GREY, f" {'-'*20} {'-'*18} {'-'*10} {'-'*24}") + + for bi in self.backends: + status_c = GREEN if bi.status == "healthy" else RED + status_str = 
f"{status_c}{bi.status:<10}{RESET}" + model_str = bi.selected_model or "(none)" + print(f" {bi.name:<20} {bi.backend_type:<18} {status_str} {model_str}") + + def _testable(self) -> List[BackendInfo]: + return [b for b in self.backends if b.status == "healthy" and b.selected_model] + + # ── Feature detection ──────────────────────────────────────────────── + + def detect_sticky_enabled(self, model: str) -> bool: + """ + Make one probe request and inspect X-Olla-Sticky-Session. + Returns True if sticky sessions are active, False if disabled. + """ + self.pcolor(YELLOW, "\nDetecting sticky session status...") + r = self._post(self._chat_body(model)) + if r is None or r.status_code != 200: + self.pcolor(RED, " [FAIL] Probe request failed — cannot detect feature status") + return False + + val = r.headers.get("X-Olla-Sticky-Session", "") + if val == "disabled": + self.pcolor(YELLOW, ( + " [WARN] Sticky sessions are disabled in Olla config.\n" + " Affinity tests will be skipped.\n" + " To enable, add to config.yaml under proxy:\n" + " sticky_sessions:\n" + " enabled: true" + )) + self.sticky_enabled = False + return False + + if val in VALID_STICKY_RESULTS: + self.pcolor(GREEN, f" [OK] Sticky sessions active (probe returned: {val})") + self.sticky_enabled = True + return True + + # Header absent — old binary or feature not built + self.pcolor(YELLOW, f" [WARN] X-Olla-Sticky-Session header absent (value: '{val}')") + self.sticky_enabled = False + return False + + # ── Test Sections ──────────────────────────────────────────────────── + + def run_header_validation(self, model: str): + """ + Basic header presence and value validation — runs regardless of whether + sticky sessions are enabled, because the 'disabled' value is also valid. + """ + print() + self.pcolor(WHITE, "Header Validation") + self.pcolor(PURPLE, "=" * 72) + + self._test_sticky_header_present(model) + self._test_key_source_valid(model) + + def run_affinity_tests(self, model: str): + """Core affinity tests — only meaningful when sticky sessions are enabled.""" + print() + self.pcolor(WHITE, "Affinity Tests") + self.pcolor(PURPLE, "=" * 72) + + if not self.sticky_enabled: + self.pcolor(GREY, " Skipped — sticky sessions are disabled") + return + + self._test_session_id_miss_then_hit(model) + self._test_session_id_echoed(model) + self._test_prefix_hash_hit(model) + self._test_sessions_independent(model) + + def run_multi_backend_tests(self): + """ + Tests that only make sense when multiple healthy backends are available, + e.g. verifying that different session IDs can land on different backends. + Skipped silently when only one backend is present. 
+ """ + testable = self._testable() + if len(testable) < 2: + return + + print() + self.pcolor(WHITE, "Multi-Backend Tests") + self.pcolor(PURPLE, "=" * 72) + + if not self.sticky_enabled: + self.pcolor(GREY, " Skipped — sticky sessions are disabled") + return + + self._test_different_sessions_may_use_different_backends(testable) + + # ── Individual Tests ───────────────────────────────────────────────── + + def _test_sticky_header_present(self, model: str): + label = "header/sticky-session-present" + self.pcolor(YELLOW, " X-Olla-Sticky-Session present: ", end="") + + r = self._post(self._chat_body(model)) + if r is None: + print(f"{RED}[FAIL]{RESET}") + self.record(label, False, "request failed") + return + + val = r.headers.get("X-Olla-Sticky-Session", "") + ok = val in VALID_STICKY_RESULTS + status_str = f"{GREEN}[PASS]{RESET}" if ok else f"{RED}[FAIL]{RESET}" + detail = f" value='{val}'" if ok else f" got '{val}', expected one of {VALID_STICKY_RESULTS}" + print(f"{status_str}{GREY}{detail}{RESET}") + self.record(label, ok, detail.strip()) + + def _test_key_source_valid(self, model: str): + label = "header/key-source-valid" + self.pcolor(YELLOW, " X-Olla-Sticky-Key-Source valid: ", end="") + + r = self._post(self._chat_body(model)) + if r is None: + print(f"{RED}[FAIL]{RESET}") + self.record(label, False, "request failed") + return + + sticky_val = r.headers.get("X-Olla-Sticky-Session", "") + source_val = r.headers.get("X-Olla-Sticky-Key-Source", "") + + # When disabled, source header is absent — that is correct behaviour. + if sticky_val == "disabled": + print(f"{GREEN}[PASS]{RESET}{GREY} (disabled — source header absent as expected){RESET}") + self.record(label, True, "disabled") + return + + ok = source_val in VALID_KEY_SOURCES + status_str = f"{GREEN}[PASS]{RESET}" if ok else f"{RED}[FAIL]{RESET}" + detail = f" source='{source_val}'" if ok else f" got '{source_val}', expected one of {VALID_KEY_SOURCES}" + print(f"{status_str}{GREY}{detail}{RESET}") + self.record(label, ok, detail.strip()) + + def _test_session_id_miss_then_hit(self, model: str): + label = "affinity/session-id-miss-then-hit" + self.pcolor(YELLOW, " Session ID: miss → hit: ", end="") + + session_id = f"test-sticky-{int(time.time())}" + headers = self._chat_headers(session_id) + body = self._chat_body(model, system=LONG_SYSTEM_PROMPT) + + r1 = self._post(body, headers) + if r1 is None or r1.status_code != 200: + print(f"{RED}[FAIL]{RESET}{GREY} turn-1 request failed{RESET}") + self.record(label, False, "turn-1 failed") + return + + r2 = self._post(body, headers) + if r2 is None or r2.status_code != 200: + print(f"{RED}[FAIL]{RESET}{GREY} turn-2 request failed{RESET}") + self.record(label, False, "turn-2 failed") + return + + t1_sticky = r1.headers.get("X-Olla-Sticky-Session", "") + t2_sticky = r2.headers.get("X-Olla-Sticky-Session", "") + t1_ep = r1.headers.get("X-Olla-Endpoint", "") + t2_ep = r2.headers.get("X-Olla-Endpoint", "") + + notes = [] + ok = True + + if t1_sticky != "miss": + ok = False + notes.append(f"turn-1 expected 'miss', got '{t1_sticky}'") + if t2_sticky != "hit": + ok = False + notes.append(f"turn-2 expected 'hit', got '{t2_sticky}'") + if t1_ep and t2_ep and t1_ep != t2_ep: + ok = False + notes.append(f"endpoint changed: {t1_ep} → {t2_ep}") + + status_str = f"{GREEN}[PASS]{RESET}" if ok else f"{RED}[FAIL]{RESET}" + detail = f" {' | '.join(notes)}" if notes else f" endpoint={t1_ep or 'n/a'}" + print(f"{status_str}{GREY}{detail}{RESET}") + self.record(label, ok, "; ".join(notes)) + + def 
_test_session_id_echoed(self, model: str): + label = "affinity/session-id-echoed" + self.pcolor(YELLOW, " Session ID echoed in response: ", end="") + + session_id = f"echo-test-{int(time.time())}" + r = self._post( + self._chat_body(model), + self._chat_headers(session_id), + ) + if r is None or r.status_code != 200: + print(f"{RED}[FAIL]{RESET}") + self.record(label, False, "request failed") + return + + echoed = r.headers.get("X-Olla-Session-ID", "") + ok = echoed == session_id + status_str = f"{GREEN}[PASS]{RESET}" if ok else f"{RED}[FAIL]{RESET}" + detail = f" echoed='{echoed}'" if ok else f" expected '{session_id}', got '{echoed}'" + print(f"{status_str}{GREY}{detail}{RESET}") + self.record(label, ok, detail.strip()) + + def _test_prefix_hash_hit(self, model: str): + label = "affinity/prefix-hash-hit" + self.pcolor(YELLOW, " Prefix hash: same prompt → hit: ", end="") + + # No session header — key must come from prefix hash. + body = self._chat_body(model, system=LONG_SYSTEM_PROMPT) + + r1 = self._post(body) + if r1 is None or r1.status_code != 200: + print(f"{RED}[FAIL]{RESET}{GREY} turn-1 failed{RESET}") + self.record(label, False, "turn-1 failed") + return + + r2 = self._post(body) + if r2 is None or r2.status_code != 200: + print(f"{RED}[FAIL]{RESET}{GREY} turn-2 failed{RESET}") + self.record(label, False, "turn-2 failed") + return + + t1_source = r1.headers.get("X-Olla-Sticky-Key-Source", "") + t2_sticky = r2.headers.get("X-Olla-Sticky-Session", "") + t1_ep = r1.headers.get("X-Olla-Endpoint", "") + t2_ep = r2.headers.get("X-Olla-Endpoint", "") + + notes = [] + ok = True + + if t1_source not in ("prefix_hash", "auth_header"): + # auth_header is also acceptable when Authorization header is present + notes.append(f"turn-1 source='{t1_source}' (expected prefix_hash or auth_header)") + if t2_sticky != "hit": + ok = False + notes.append(f"turn-2 expected 'hit', got '{t2_sticky}'") + if t1_ep and t2_ep and t1_ep != t2_ep: + ok = False + notes.append(f"endpoint changed: {t1_ep} → {t2_ep}") + + status_str = f"{GREEN}[PASS]{RESET}" if ok else f"{RED}[FAIL]{RESET}" + detail = f" {' | '.join(notes)}" if notes else f" source={t1_source} endpoint={t1_ep or 'n/a'}" + print(f"{status_str}{GREY}{detail}{RESET}") + self.record(label, ok, "; ".join(notes)) + + def _test_sessions_independent(self, model: str): + label = "affinity/sessions-independent" + self.pcolor(YELLOW, " Two sessions pin independently: ", end="") + + ts = int(time.time()) + sid_a = f"session-alpha-{ts}" + sid_b = f"session-beta-{ts}" + body = self._chat_body(model, system=LONG_SYSTEM_PROMPT) + + # Pin each session with a first (miss) request, then confirm second is a hit. 
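+        # The four requests below are interleaved (A, B, A, B) so that a hit for
+        # one session cannot be an artefact of the other's pin: with
+        # session_header first in key_sources, each session ID maps to its own key.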
+        ra1 = self._post(body, self._chat_headers(sid_a))
+        rb1 = self._post(body, self._chat_headers(sid_b))
+        ra2 = self._post(body, self._chat_headers(sid_a))
+        rb2 = self._post(body, self._chat_headers(sid_b))
+
+        responses = [ra1, rb1, ra2, rb2]
+        if any(r is None or r.status_code != 200 for r in responses):
+            print(f"{RED}[FAIL]{RESET}{GREY} one or more requests failed{RESET}")
+            self.record(label, False, "request failed")
+            return
+
+        a2_sticky = ra2.headers.get("X-Olla-Sticky-Session", "")  # type: ignore[union-attr]
+        b2_sticky = rb2.headers.get("X-Olla-Sticky-Session", "")  # type: ignore[union-attr]
+
+        notes = []
+        ok = True
+
+        if a2_sticky != "hit":
+            ok = False
+            notes.append(f"session-A turn-2 expected 'hit', got '{a2_sticky}'")
+        if b2_sticky != "hit":
+            ok = False
+            notes.append(f"session-B turn-2 expected 'hit', got '{b2_sticky}'")
+
+        status_str = f"{GREEN}[PASS]{RESET}" if ok else f"{RED}[FAIL]{RESET}"
+        detail = f" {' | '.join(notes)}" if notes else " both sessions independently pinned"
+        print(f"{status_str}{GREY}{detail}{RESET}")
+        self.record(label, ok, "; ".join(notes))
+
+    def _test_different_sessions_may_use_different_backends(self, testable: List[BackendInfo]):
+        """
+        With multiple backends, two sessions with different system prompts *may*
+        land on different backends. We cannot assert that they must: the load
+        balancer may legitimately assign both to the same backend. We only
+        verify that both pin (second request is a hit) without interfering
+        with each other.
+        """
+        label = "multi-backend/sessions-do-not-interfere"
+        self.pcolor(YELLOW, "  Multi-backend: no cross-session interference: ", end="")
+
+        model = testable[0].selected_model
+        ts = int(time.time())
+        body_a = self._chat_body(model, system=LONG_SYSTEM_PROMPT)
+        body_b = self._chat_body(model, system=ALTERNATIVE_SYSTEM_PROMPT)
+
+        sid_a = f"multi-alpha-{ts}"
+        sid_b = f"multi-beta-{ts}"
+
+        ra1 = self._post(body_a, self._chat_headers(sid_a))
+        rb1 = self._post(body_b, self._chat_headers(sid_b))
+        ra2 = self._post(body_a, self._chat_headers(sid_a))
+        rb2 = self._post(body_b, self._chat_headers(sid_b))
+
+        responses = [ra1, rb1, ra2, rb2]
+        if any(r is None or r.status_code != 200 for r in responses):
+            print(f"{RED}[FAIL]{RESET}{GREY} one or more requests failed{RESET}")
+            self.record(label, False, "request failed")
+            return
+
+        a2_sticky = ra2.headers.get("X-Olla-Sticky-Session", "")  # type: ignore[union-attr]
+        b2_sticky = rb2.headers.get("X-Olla-Sticky-Session", "")  # type: ignore[union-attr]
+        a_ep = ra1.headers.get("X-Olla-Endpoint", "")  # type: ignore[union-attr]
+        b_ep = rb1.headers.get("X-Olla-Endpoint", "")  # type: ignore[union-attr]
+
+        notes = []
+        ok = True
+
+        if a2_sticky != "hit":
+            ok = False
+            notes.append(f"session-A not pinned (got '{a2_sticky}')")
+        if b2_sticky != "hit":
+            ok = False
+            notes.append(f"session-B not pinned (got '{b2_sticky}')")
+
+        ep_note = f" A→{a_ep or '?'} B→{b_ep or '?'}" if a_ep != b_ep else f" both→{a_ep or '?'}"
+        status_str = f"{GREEN}[PASS]{RESET}" if ok else f"{RED}[FAIL]{RESET}"
+        detail = f" {' | '.join(notes)}{ep_note}" if notes else ep_note
+        print(f"{status_str}{GREY}{detail}{RESET}")
+        self.record(label, ok, "; ".join(notes))
+
+    # ── Stats ────────────────────────────────────────────────────────────
+
+    def report_sticky_stats(self):
+        """Fetch and display the /internal/stats/sticky endpoint if available."""
+        print()
+        self.pcolor(WHITE, "Sticky Session Stats")
+        self.pcolor(PURPLE, "=" * 72)
+
+        try:
+            r = requests.get(f"{self.base_url}/internal/stats/sticky",
timeout=self.timeout)
+            if r.status_code == 404:
+                self.pcolor(GREY, "  Stats endpoint not available (HTTP 404)")
+                return
+            r.raise_for_status()
+            data = r.json()
+        except Exception as e:
+            self.pcolor(GREY, f"  Could not fetch sticky stats: {e}")
+            return
+
+        # The endpoint returns {"enabled": false} when sticky sessions are off.
+        if data.get("enabled") is False:
+            self.pcolor(GREY, "  Sticky sessions disabled; no statistics to report")
+            return
+
+        hits = data.get("hits", "N/A")
+        misses = data.get("misses", "N/A")
+        evictions = data.get("evictions", "N/A")
+        sessions = data.get("active_sessions", "N/A")
+
+        self.pcolor(GREY, f"  Active sessions: {CYAN}{sessions}{RESET}")
+        self.pcolor(GREY, f"  Cache hits:      {GREEN}{hits}{RESET}")
+        self.pcolor(GREY, f"  Cache misses:    {YELLOW}{misses}{RESET}")
+        self.pcolor(GREY, f"  Evictions:       {GREY}{evictions}{RESET}")
+
+    # ── Summary ──────────────────────────────────────────────────────────
+
+    def print_summary(self) -> bool:
+        print()
+        self.pcolor(PURPLE, "=" * 72)
+        self.pcolor(WHITE, f"  {BOLD}Results Summary{RESET}")
+        self.pcolor(PURPLE, "=" * 72)
+        print()
+
+        if self.results:
+            header = f"  {'Test':<48} {'Result'}"
+            self.pcolor(GREY, header)
+            self.pcolor(GREY, f"  {'-'*48} {'-'*6}")
+
+            for r in self.results:
+                c = GREEN if r.passed else RED
+                mark = "PASS" if r.passed else "FAIL"
+                detail = f"  {GREY}({r.detail}){RESET}" if r.detail and self.verbose else ""
+                label_short = (r.name[:46] + "..") if len(r.name) > 48 else r.name
+                print(f"  {label_short:<48} {c}{mark}{RESET}{detail}")
+
+        passed = sum(1 for r in self.results if r.passed)
+        failed = sum(1 for r in self.results if not r.passed)
+        total = len(self.results)
+
+        print()
+        self.pcolor(GREY, f"  Total: {total} | ", end="")
+        self.pcolor(GREEN, f"Passed: {passed}", end="")
+        self.pcolor(GREY, " | ", end="")
+        self.pcolor(RED if failed else GREEN, f"Failed: {failed}")
+        print()
+
+        all_pass = failed == 0
+        if all_pass:
+            self.pcolor(GREEN, "  All tests passed.")
+        else:
+            self.pcolor(RED, f"  {failed} test(s) failed.")
+
+        return all_pass
+
+
+def main():
+    parser = argparse.ArgumentParser(
+        description="Test Olla sticky session (KV cache affinity) routing"
+    )
+    parser.add_argument("--url", default=TARGET_URL,
+                        help=f"Olla base URL (default: {TARGET_URL})")
+    parser.add_argument("--timeout", type=int, default=DEFAULT_TIMEOUT,
+                        help=f"Request timeout in seconds (default: {DEFAULT_TIMEOUT})")
+    parser.add_argument("--model",
+                        help="Force a specific model name instead of auto-discovery")
+    parser.add_argument("--skip-stats", action="store_true",
+                        help="Skip the sticky stats endpoint check")
+    parser.add_argument("--verbose", action="store_true",
+                        help="Show test detail in summary table")
+    args = parser.parse_args()
+
+    tester = StickySessionTester(args.url, args.timeout, args.verbose, args.model)
+    tester.print_header()
+
+    # Phase 1: Health and discovery
+    if not tester.check_health():
+        sys.exit(1)
+    print()
+    if not tester.discover():
+        sys.exit(1)
+
+    testable = tester._testable()
+    if not testable:
+        tester.pcolor(RED, "\n[FAIL] No healthy backends with models available for testing")
+        sys.exit(1)
+
+    model = testable[0].selected_model
+
+    # Phase 2: Feature detection
+    tester.detect_sticky_enabled(model)
+
+    # Phase 3: Header validation (runs even when disabled)
+    tester.run_header_validation(model)
+
+    # Phase 4: Affinity tests (skipped when disabled)
+    tester.run_affinity_tests(model)
+
+    # Phase 5: Multi-backend tests
+    tester.run_multi_backend_tests()
+
+    # Phase 6: Stats endpoint
+    if not args.skip_stats:
+        tester.report_sticky_stats()
+
+    # Phase 7: Summary
+    all_pass = tester.print_summary()
+    sys.exit(0 if all_pass else 1)
+
+
+if __name__ == "__main__":
+    try:
+        main()
+ except KeyboardInterrupt: + print(f"\n{YELLOW}Test interrupted by user (Ctrl+C){RESET}") + sys.exit(130)
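+
+# Example invocations (script filename and model name are illustrative):
+#   python test_sticky_sessions.py
+#   python test_sticky_sessions.py --url http://localhost:40114 --verbose
+#   python test_sticky_sessions.py --model llama3.1 --skip-stats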