diff --git a/CLAUDE.md b/CLAUDE.md index 3c20b2a3..b325ad41 100644 --- a/CLAUDE.md +++ b/CLAUDE.md @@ -217,7 +217,7 @@ Always run `make ready` before committing changes. - **Version Management**: Build-time version injection via `internal/version` ### Development Guidelines -- Go 1.24+ +- Go 1.24 (do not bump to 1.25; see Dependencies for held-back packages) - Australian English for comments and documentation - Comment on **why** rather than **what** - Always run `make ready` before committing @@ -237,6 +237,19 @@ Always run `make ready` before committing changes. Do not add additional dependencies unless explicitly asked. +### Go 1.24 Compatibility Pins + +Olla targets Go 1.24. From the versions listed below onward, the upstream `go` directive moves to 1.25, so these packages are held back: + +- `golang.org/x/sys` at v0.41.0 (v0.42.0+ requires Go 1.25) +- `golang.org/x/term` at v0.40.0 +- `golang.org/x/text` at v0.34.0 +- `golang.org/x/sync` at v0.19.0 +- `golang.org/x/time` at v0.14.0 +- `atomicgo.dev/keyboard` at v0.2.9 + +`go get -u ./...` will silently bump the toolchain to 1.25 by pulling these. After running it, check `go.mod` and pin the affected packages back to the versions above, or use `go get -u=patch ./...` to limit upgrades to patch releases only. + ## SUB-AGENT DELEGATION CRITICAL: Always delegate tasks to the appropriate subagent. Do NOT perform work directly in the main context. diff --git a/config/profiles/litellm.yaml b/config/profiles/litellm.yaml index 6e0a4e07..3501bdca 100644 --- a/config/profiles/litellm.yaml +++ b/config/profiles/litellm.yaml @@ -45,6 +45,11 @@ characteristics: max_concurrent_requests: 100 # LiteLLM handles high concurrency well default_priority: 95 # High priority as a unified gateway streaming_support: true + auth: + required: false # optional but common in production deployments + types: + - bearer # master key via Authorization: Bearer + - api_key # some versions accept x-goog-api-key or custom header # Detection hints for auto-discovery detection: diff --git a/config/profiles/llamacpp.yaml b/config/profiles/llamacpp.yaml index 6dbd5fb7..cffe1848 100644 --- a/config/profiles/llamacpp.yaml +++ b/config/profiles/llamacpp.yaml @@ -68,6 +68,10 @@ characteristics: default_priority: 95 # High priority for direct GGUF inference streaming_support: true single_model_server: true # important: One model per instance + auth: + required: false + types: + - bearer # enabled via --api-key flag # Detection hints for auto-discovery detection: diff --git a/config/profiles/ollama.yaml b/config/profiles/ollama.yaml index 4d082b90..2e706ee0 100644 --- a/config/profiles/ollama.yaml +++ b/config/profiles/ollama.yaml @@ -45,6 +45,10 @@ characteristics: max_concurrent_requests: 10 default_priority: 100 streaming_support: true + auth: + required: false + types: + - bearer # used by Ollama Cloud and protected remote instances # Detection hints for auto-discovery detection: diff --git a/config/profiles/openai-compatible.yaml b/config/profiles/openai-compatible.yaml index d40be0ca..695ba568 100644 --- a/config/profiles/openai-compatible.yaml +++ b/config/profiles/openai-compatible.yaml @@ -37,6 +37,11 @@ characteristics: max_concurrent_requests: 20 default_priority: 50 streaming_support: true + auth: + required: false + types: + - bearer # standard Authorization: Bearer for most OpenAI-compatible APIs + - api_key # some backends use a custom header (set header: in auth config) # Detection hints for auto-discovery detection: diff --git a/config/profiles/vllm.yaml b/config/profiles/vllm.yaml index 59d0308b..517d0572 100644 --- a/config/profiles/vllm.yaml +++ b/config/profiles/vllm.yaml @@ -65,6 +65,10 @@ characteristics: max_concurrent_requests: 100 default_priority: 80 streaming_support: true + auth: + required: false + types: + - bearer # enabled via --api-key flag # Detection hints for auto-discovery detection: diff --git a/docs/content/configuration/endpoint-auth-remote.md b/docs/content/configuration/endpoint-auth-remote.md new file mode 100644 index 00000000..c08f514f --- /dev/null +++ b/docs/content/configuration/endpoint-auth-remote.md @@ -0,0 +1,192 @@ +--- +title: Remote Backend Auth (Experimental) - Cloud API Recipes +description: Experimental recipes for using Olla with remote cloud APIs like Ollama Cloud, OpenRouter, and Groq. Understand the limitations and caveats before use. +keywords: olla remote, cloud api, ollama cloud, openrouter, groq, experimental +--- + +# Remote Backend Auth (Experimental) + +!!! warning "Not officially supported" + Olla is designed for **local, self-hosted inference backends**. Remote cloud APIs are not + a first-class use case. The recipes below work today for users who want to experiment, but + we make no guarantees about continued compatibility, and issues specific to cloud providers + will not be prioritised. + + If you want to use hosted APIs, consider LiteLLM as an intermediary. It handles the + provider-specific quirks, and Olla then talks to LiteLLM as a local OpenAI-compatible endpoint. + +## Why Cloud APIs Are Not First-Class + +Cloud inference APIs have operational characteristics that Olla does not currently handle: + +- **Rate limit headers** (`x-ratelimit-*`, `retry-after`): Olla does not parse or propagate + provider-specific rate limit signalling beyond honouring 429 for health state. +- **Path-prefix base URLs**: Some APIs require a base path in the URL + (e.g. `https://api.groq.com/openai/v1`). See below for how this interacts with health and + model discovery. +- **Cold-start latency**: Serverless-backed providers can have high first-token latency that + exceeds Olla's default health check timeouts. +- **Model namespacing**: Many cloud APIs use `provider/model-name` format. Olla's model + discovery and unification are tuned for local naming conventions. +- **No local health check**: Cloud APIs do not expose a `/health` endpoint. Health checks + against `/v1/models` incur real API calls and may consume quota. + +## URL Construction for Path-Prefixed Bases + +Olla joins discovery paths onto the base URL path using `path.Join`. For a base like +`https://api.groq.com/openai/v1`, the default health or model path `/v1/models` gets +joined as `/openai/v1/v1/models` -- a doubled prefix that silently breaks health checks and +model discovery. + +Set explicit absolute `health_check_url` and `model_url` values to bypass the join entirely. +`ResolveURLPath` returns absolute URLs as-is, so `https://api.groq.com/openai/v1/models` +goes to the wire unchanged. This only affects discovery; proxy-time URL building is +controlled separately by `preserve_path`. + +## What We Don't Promise + +- Health check accuracy for cloud endpoints +- Correct model listing or unification across local and remote endpoints +- Retry or backoff behaviour that respects provider-specific rate limiting +- Compatibility with provider authentication changes + +## Recipes + +These configurations work at the time of writing. Treat them as starting points, not +production-tested deployments. + +### Ollama Cloud + +Ollama Cloud (`https://ollama.com`) accepts bearer authentication. Set your API key from +[ollama.com/settings/keys](https://ollama.com/settings/keys). + +```yaml +discovery: + static: + endpoints: + - url: "https://ollama.com" + name: "ollama-cloud" + type: "ollama" + priority: 10 # lower than local instances + check_interval: 60s # avoid hammering cloud health checks + check_timeout: 10s + auth: + type: bearer + token: "${OLLAMA_CLOUD_API_KEY}" +``` + +**Known limitations:** + +- The Ollama Cloud API surface may differ from local Ollama. Model names include the namespace + (e.g. `hf.co/bartowski/Llama-3.2-3B-Instruct-GGUF`). +- Health check hits `/`, which works on the Ollama Cloud base URL. + +### OpenRouter + +OpenRouter exposes an OpenAI-compatible API at `https://openrouter.ai/api/v1`. The `/api/v1` +prefix path means you need `preserve_path: true` to prevent Olla from stripping it. + +```yaml +discovery: + static: + endpoints: + - url: "https://openrouter.ai/api/v1" + name: "openrouter" + type: "openai-compatible" + priority: 10 + preserve_path: true # required: prevents stripping the /api/v1 prefix + health_check_url: "https://openrouter.ai/api/v1/models" + model_url: "https://openrouter.ai/api/v1/models" + check_interval: 120s + check_timeout: 15s + auth: + type: bearer + token: "${OPENROUTER_API_KEY}" +``` + +**Known limitations:** + +- Health checks probe `/api/v1/models` which incurs an API call. Set `check_interval` high + to avoid burning quota. +- OpenRouter requires an `HTTP-Referer` header for attribution on some tiers. Use `headers:` + to set it: + + ```yaml + headers: + HTTP-Referer: "https://your-app.example.com" + X-Title: "Your App Name" + ``` + +- Model names include the provider prefix (e.g. `openai/gpt-4o`, `anthropic/claude-3-5-sonnet`). + These will not unify with local model names. + +### Groq + +Groq provides a fast OpenAI-compatible inference API. + +```yaml +discovery: + static: + endpoints: + - url: "https://api.groq.com/openai/v1" + name: "groq" + type: "openai-compatible" + priority: 10 + preserve_path: true + health_check_url: "https://api.groq.com/openai/v1/models" + model_url: "https://api.groq.com/openai/v1/models" + check_interval: 120s + check_timeout: 10s + auth: + type: bearer + token: "${GROQ_API_KEY}" +``` + +**Known limitations:** + +- Same health check cost caveat as OpenRouter. +- Groq's rate limits are aggressive on the free tier. A misconfigured health interval can + exhaust rate limits before any inference requests are made. + +## Mixing Local and Remote + +You can combine local and remote endpoints. Set priorities so local endpoints are strongly +preferred and remote endpoints act as overflow: + +```yaml +discovery: + static: + endpoints: + # Local, always preferred + - url: "http://localhost:8000" + name: "local-vllm" + type: "vllm" + priority: 100 + + # Remote fallback + - url: "https://api.groq.com/openai/v1" + name: "groq-fallback" + type: "openai-compatible" + priority: 5 + preserve_path: true + health_check_url: "https://api.groq.com/openai/v1/models" + model_url: "https://api.groq.com/openai/v1/models" + check_interval: 120s + auth: + type: bearer + token: "${GROQ_API_KEY}" +``` + +With `load_balancer: priority`, requests only reach the remote endpoint when all local +endpoints are unhealthy. + +## Community Contributions + +If you build cloud-specific profile YAML files or improve health check behaviour for cloud +APIs, PRs are welcome. See [Contributing](../development/contributing.md). + +## See Also + +- [Endpoint Authentication](endpoint-auth.md): auth configuration reference +- [Configuration Overview](overview.md): general configuration +- [LiteLLM Integration](../integrations/backend/litellm.md): recommended cloud API gateway diff --git a/docs/content/configuration/endpoint-auth.md b/docs/content/configuration/endpoint-auth.md new file mode 100644 index 00000000..18c5889f --- /dev/null +++ b/docs/content/configuration/endpoint-auth.md @@ -0,0 +1,286 @@ +--- +title: Endpoint Authentication - Configure Auth for Backend Endpoints +description: Configure bearer, API key, and basic authentication for Olla backend endpoints. Includes env interpolation, file-based secrets, Docker/Kubernetes examples, and recipes for vLLM, llama.cpp, and LiteLLM. +keywords: olla auth, endpoint authentication, bearer token, api key, basic auth, vllm auth, llamacpp auth, litellm auth +--- + +# Endpoint Authentication + +Olla can attach outbound authentication headers to requests forwarded to a backend endpoint. This is +for authenticating **Olla to the backend**. It has no bearing on how clients authenticate to Olla. + +## When to Use It + +Most local inference servers (Ollama, llama.cpp without `--api-key`) run without authentication. +You need `auth:` when: + +- A backend is started with an API key flag (e.g. `vllm --api-key`, `llama-server --api-key`) +- A backend sits behind a reverse proxy that requires credentials +- A LiteLLM gateway has a master key configured + +## Supported Types + +### `bearer` + +Sends `Authorization: Bearer `. + +```yaml +discovery: + static: + endpoints: + - url: "http://gpu-server:8000" + name: "vllm-gpu" + type: "vllm" + auth: + type: bearer + token: "sk-my-secret-token" +``` + +### `api_key` + +Sends a custom header (default `X-Api-Key`). Use `header:` to override. The raw credential +value is written to the header with no scheme prefix -- use `bearer` if the backend expects +`Authorization: Bearer `. + +```yaml + - url: "http://analytics-llm:9000" + name: "analytics-gw" + type: "openai-compatible" + auth: + type: api_key + key: "${ANALYTICS_API_KEY}" + header: "X-Api-Key" # optional, this is the default +``` + +### `basic` + +Sends `Authorization: Basic `. + +```yaml + - url: "http://internal-llm:8080" + name: "llamacpp-basic" + type: "llamacpp" + auth: + type: basic + username: "admin" + password: "s3cr3t" +``` + +## Environment Variable Interpolation + +Hardcoding credentials in config files is an antipattern. Use `${VAR}` placeholders instead: + +```yaml +auth: + type: bearer + token: "${VLLM_API_KEY}" +``` + +Olla expands these at startup using `ExpandStrict`. **If the variable is unset and has no default, +the process exits with a clear error**. This prevents silent misconfiguration. + +### Default Values + +Use `${VAR:-default}` for optional credentials or fallback values: + +```yaml +auth: + type: api_key + key: "${CUSTOM_API_KEY:-changeme}" +``` + +!!! warning "Defaults in production" + `:-default` is useful for development. In production, prefer requiring the variable explicitly + so a missing secret surfaces as a startup failure rather than silently using a fallback. + +## File-Based Secrets (`_file` suffix) + +Each credential field has a `_file` sibling that reads the value from a file path. This is the +standard pattern for Docker Secrets and Kubernetes mounted secrets, where a volume provides a +file containing a single secret value. + +```yaml +auth: + type: bearer + token_file: "/run/secrets/vllm_api_key" +``` + +The file contents are trimmed of leading/trailing whitespace. Setting both the inline field and +the `_file` field is a fatal startup error. + +### Available `_file` Fields + +| Auth type | Inline field | File field | +|-----------|-------------|------------| +| `bearer` | `token` | `token_file` | +| `api_key` | `key` | `key_file` | +| `basic` | `username` | `username_file` | +| `basic` | `password` | `password_file` | + +### Docker Compose Example + +```yaml +# docker-compose.yml +services: + olla: + image: ghcr.io/thushan/olla:latest + secrets: + - vllm_api_key + volumes: + - ./config.local.yaml:/app/config/config.local.yaml + +secrets: + vllm_api_key: + file: ./secrets/vllm_api_key.txt +``` + +```yaml +# config.local.yaml +discovery: + static: + endpoints: + - url: "http://vllm:8000" + name: "vllm" + type: "vllm" + auth: + type: bearer + token_file: "/run/secrets/vllm_api_key" +``` + +### Kubernetes Secret Example + +```yaml +apiVersion: v1 +kind: Secret +metadata: + name: olla-backend-creds +stringData: + vllm-token: "sk-my-token" +--- +# In your Deployment, mount as a volume or env var: +env: + - name: VLLM_API_KEY + valueFrom: + secretKeyRef: + name: olla-backend-creds + key: vllm-token +``` + +Then reference it from config: + +```yaml +auth: + type: bearer + token: "${VLLM_API_KEY}" +``` + +## The `headers:` Escape Hatch + +For backends that need authentication headers that don't fit bearer/api_key/basic, use the +`headers:` map directly. Headers set here are copied verbatim on every forwarded request. + +```yaml + - url: "http://custom-llm:9000" + name: "custom" + type: "openai-compatible" + headers: + X-Custom-Auth: "token abc123" + X-Tenant-ID: "acme" +``` + +`headers:` and `auth:` can coexist. The `auth:` block sets the `Authorization` (or configured) +header; `headers:` sets everything else. + +## Order of Precedence + +When a forwarded request is assembled, headers are applied in this order: + +1. **Client request headers** are stripped of hop-by-hop headers +2. **`headers:` map** values are set verbatim +3. **`auth:`** sets the credential header (overrides any `headers:` entry for the same name) + +The `auth:` block intentionally wins over `headers:` for the credential header. This prevents +an operator from accidentally overriding a resolved secret with a static `headers:` entry. + +## Fatal Startup Behaviour + +Auth validation runs before the HTTP server starts. The process exits immediately on: + +- Unknown `auth.type` (must be `bearer`, `api_key`, or `basic`) +- Both inline field and `_file` sibling set simultaneously +- Neither inline nor `_file` set for a required field +- `${VAR}` placeholder where `VAR` is unset and no `:-default` is provided +- File in `_file` field that does not exist or cannot be read + +This fail-fast behaviour is intentional: a proxy that silently starts without credentials and +forwards unauthenticated requests to a protected backend is harder to debug than a startup error. + +## Recipes + +### vLLM with `--api-key` + +Start vLLM: + +```bash +vllm serve meta-llama/Llama-3.1-8B-Instruct --api-key sk-my-key +``` + +Olla config: + +```yaml + - url: "http://vllm-host:8000" + name: "vllm-gpu" + type: "vllm" + auth: + type: bearer + token: "${VLLM_API_KEY}" +``` + +### llama.cpp with `--api-key` + +Start llama-server: + +```bash +llama-server -m model.gguf --api-key sk-my-key +``` + +Olla config: + +```yaml + - url: "http://llamacpp-host:8080" + name: "llamacpp" + type: "llamacpp" + auth: + type: bearer + token: "${LLAMACPP_API_KEY}" +``` + +### LiteLLM with Master Key + +Start LiteLLM proxy: + +```bash +litellm --config litellm_config.yaml --master_key sk-master +``` + +Olla config: + +```yaml + - url: "http://litellm:4000" + name: "litellm-gw" + type: "litellm" + auth: + type: bearer + token: "${LITELLM_MASTER_KEY}" +``` + +!!! note "LiteLLM API key format" + LiteLLM accepts the master key as a standard `Authorization: Bearer` header or as `x-goog-api-key` + depending on the version and configuration. Use `api_key` auth with `header: x-goog-api-key` if + bearer does not work for your deployment. + +## See Also + +- [Configuration Reference](reference.md): complete `auth:` field list +- [Security Best Practices](practices/security.md): production hardening +- [Experimental Remote Backends](endpoint-auth-remote.md): cloud API recipes diff --git a/docs/content/configuration/overview.md b/docs/content/configuration/overview.md index 94487ec0..db59c421 100644 --- a/docs/content/configuration/overview.md +++ b/docs/content/configuration/overview.md @@ -143,6 +143,8 @@ server: See [Rate Limiting Reference](reference.md#rate-limiting) for complete details. +Endpoints also support per-endpoint outbound authentication (`auth:`) and custom headers (`headers:`). See [Endpoint Authentication](endpoint-auth.md) for configuration details. + ## Proxy Configuration The `proxy` section controls request routing and proxy behaviour. @@ -386,6 +388,7 @@ Olla validates configuration on startup: ### Next Steps - [Configuration Reference](reference.md) - Complete configuration options +- [Endpoint Authentication](endpoint-auth.md) - Bearer, API key, and basic auth for backends - [Configuration Examples](examples.md) - Common configuration scenarios - [Configuration Best Practices](practices/configuration.md) - Native and Docker configuration strategies - [Best Practices](practices/overview.md) - Production recommendations diff --git a/docs/content/configuration/practices/security.md b/docs/content/configuration/practices/security.md index 3f8b16dd..f1d6b93d 100644 --- a/docs/content/configuration/practices/security.md +++ b/docs/content/configuration/practices/security.md @@ -324,6 +324,30 @@ Monitor these security events: - Failed health checks - Configuration changes +## Upstream Response Header Stripping + +Olla removes the following headers from upstream responses before returning them to clients: + +- `Authorization` +- `Proxy-Authorization` +- `Set-Cookie` +- `X-Api-Key` +- `X-Auth-Token` + +Any header named in an endpoint's `auth.header` field or in the `headers:` map is also stripped on the response side. This means operator-supplied custom auth header names are protected even when they do not appear in the list above. The reason: backends should not be able to set cookies on clients or reflect credentials back through Olla. + +!!! note "Custom header names" + If you configure a non-standard credential header (e.g. `auth.header: X-My-Token`), Olla strips `X-My-Token` from responses as well. No additional configuration is needed. + +## Secrets Resolution + +Credential values in `auth:` and `headers:` blocks support two forms: + +- **`${VAR}`**: resolved from the environment at startup. An unset variable with no `:-default` is a fatal error, so misconfigured auth surfaces before the server starts accepting traffic. +- **`_file` fields** (`token_file`, `key_file`, `username_file`, `password_file`): reads the secret from a file path and trims whitespace. The standard pattern for Docker Secrets and Kubernetes mounted volumes. + +Setting both the inline field and its `_file` sibling is a fatal startup error. + ## Secrets Management ### Configuration Files diff --git a/docs/content/configuration/reference.md b/docs/content/configuration/reference.md index 1739fa0e..c4ae504b 100644 --- a/docs/content/configuration/reference.md +++ b/docs/content/configuration/reference.md @@ -274,6 +274,68 @@ discovery: | `static.endpoints[].check_interval` | duration | No | Health check interval (default: `5s`) | | `static.endpoints[].check_timeout` | duration | No | Health check timeout (default: `2s`) | | `static.endpoints[].model_filter` | object | No | Model filtering for this endpoint | +| `static.endpoints[].auth` | object | No | Outbound authentication credentials (see below) | +| `static.endpoints[].headers` | map[string]string | No | Custom outbound headers applied on every forwarded request | + +#### Endpoint Authentication (`auth:`) + +Attaches credentials to requests forwarded from Olla to a backend. See [Endpoint Authentication](endpoint-auth.md) for the full guide. + +| Field | Type | Description | +|-------|------|-------------| +| `auth.type` | string | `bearer`, `api_key`, or `basic` | +| `auth.token` | string | Bearer token value. Sends `Authorization: Bearer `. Mutually exclusive with `token_file`. | +| `auth.token_file` | string | Path to a file containing the bearer token. | +| `auth.key` | string | API key value. Mutually exclusive with `key_file`. | +| `auth.key_file` | string | Path to a file containing the API key. | +| `auth.header` | string | Header name for `api_key` type (default: `X-Api-Key`). | +| `auth.username` | string | Username for `basic` type. Mutually exclusive with `username_file`. | +| `auth.username_file` | string | Path to a file containing the username. | +| `auth.password` | string | Password for `basic` type. Mutually exclusive with `password_file`. | +| `auth.password_file` | string | Path to a file containing the password. | + +`${VAR}` interpolation works on every value field. `_file` fields read and trim the file contents at startup. Setting both the inline field and its `_file` sibling is a fatal error, as is an unresolved `${VAR}` with no default. + +**Bearer example:** + +```yaml +discovery: + static: + endpoints: + - url: "http://gpu-server:8000" + name: "vllm-gpu" + type: "vllm" + auth: + type: bearer + token: "${VLLM_API_KEY}" +``` + +**API key with custom header:** + +```yaml + - url: "http://custom-gw:9000" + name: "custom-gw" + type: "openai-compatible" + auth: + type: api_key + key: "${CUSTOM_API_KEY}" + header: "X-Api-Key" + headers: + X-Tenant-ID: "team-a" +``` + +#### Custom Outbound Headers (`headers:`) + +`headers:` is a free-form map of header names to values. All entries are copied verbatim onto every request forwarded to that endpoint. `auth:` and `headers:` can coexist; the `auth:` block always wins for its own credential header. `${VAR}` interpolation applies to values. + +```yaml + - url: "http://custom-llm:9000" + name: "custom" + type: "openai-compatible" + headers: + X-Tenant-ID: "acme" + X-Request-Source: "olla" +``` #### URL Configuration diff --git a/docs/content/development/setup.md b/docs/content/development/setup.md index 7c79bcad..8824a346 100644 --- a/docs/content/development/setup.md +++ b/docs/content/development/setup.md @@ -12,7 +12,7 @@ This guide covers setting up a complete development environment for Olla. ### Required -- **Go 1.24+**: [Download Go](https://golang.org/dl/) +- **Go 1.24**: [Download Go](https://golang.org/dl/). Olla pins to 1.24; several `golang.org/x/*` packages have moved to 1.25 and are held back. See the dependencies note in `CLAUDE.md` if you plan to update them. - **Git**: For version control - **Make**: Build automation diff --git a/docs/content/faq.md b/docs/content/faq.md index ef1d6d5f..a8d476d8 100644 --- a/docs/content/faq.md +++ b/docs/content/faq.md @@ -280,6 +280,24 @@ discovery: interval: 15m # Less frequent discovery ``` +## Authentication + +### Why does my endpoint show `config_error`? + +A `config_error` status means Olla received a 401 or 403 from the backend during a health probe. This is an auth misconfiguration, not a network failure. The backend is reachable but rejecting the credentials. Check that the `auth.token`, `auth.key`, or `auth.password` value configured on the endpoint matches what the backend expects. + +### What does `rate_limited` mean? + +The health probe received a 429 (Too Many Requests) response. Olla marks the endpoint as `rate_limited` and honours the `Retry-After` header if present. Probing resumes automatically once the wait period expires. This is most common when health checks are running too frequently against a rate-limited backend; increase `check_interval` if it happens repeatedly. + +### How do I authenticate to a backend protected by `--api-key`? + +Use `auth.type: bearer` on the endpoint. Both vLLM (`vllm serve --api-key`) and llama.cpp (`llama-server --api-key`) treat the value as a bearer token checked against the `Authorization` header. See [Endpoint Authentication](configuration/endpoint-auth.md) for full configuration and Docker/Kubernetes examples. + +### Olla refuses to start with a `${VAR}` error + +The environment variable referenced in your config was not set (or not exported) when Olla started. This is intentional: Olla uses fail-fast expansion so a missing secret surfaces as a startup error rather than silently forwarding unauthenticated requests. Export the variable before starting Olla, or use the `_file` form (`token_file`, `key_file`, etc.) for container and Kubernetes deployments where secrets are mounted as files. + ## Common Issues ### "No healthy endpoints available" diff --git a/docs/content/index.md b/docs/content/index.md index 7c918c13..152a4d70 100644 --- a/docs/content/index.md +++ b/docs/content/index.md @@ -29,6 +29,9 @@ Olla is a high-performance, low-overhead, low-latency proxy and load balancer fo Olla works alongside API gateways like [LiteLLM](https://github.com/BerriAI/litellm) or orchestration platforms like [GPUStack](https://github.com/gpustack/gpustack), focusing on making your **existing** LLM infrastructure reliable through intelligent routing and failover. You can choose between two proxy engines: **Sherpa** for simplicity and maintainability or **Olla** for maximum performance with advanced features like circuit breakers and connection pooling. +!!! info "Local-First" + Olla is built for local, self-hosted inference: Ollama, llama.cpp, vLLM, LM Studio, LiteLLM, SGLang, and similar engines running on hardware you control. Remote authenticated APIs (Ollama Cloud, OpenAI, Anthropic, OpenRouter, Groq, etc.) are not a first-class use case. The auth machinery is generic enough to point at them, but Olla makes no guarantees about health check accuracy, rate limit handling, or model unification for cloud providers. If you want to proxy remote APIs, see [Remote Backend Auth (Experimental)](configuration/endpoint-auth-remote.md). + ## Key Features - **Unified Model Registry**: Unifies models registered across instances (of the same type - Eg. Ollama or LMStudio) diff --git a/docs/content/integrations/backend/llamacpp.md b/docs/content/integrations/backend/llamacpp.md index 52f8457f..c5c4fa47 100644 --- a/docs/content/integrations/backend/llamacpp.md +++ b/docs/content/integrations/backend/llamacpp.md @@ -110,6 +110,9 @@ discovery: # Profile handles health checks and model discovery ``` +!!! tip "Authenticated llama.cpp" + `llama-server` accepts `--api-key` (or `--api-key-file` for file-based secrets) to protect the HTTP endpoint. Configure the matching credential with `auth.type: bearer` on the endpoint. See [Endpoint Authentication](../../configuration/endpoint-auth.md) for details. + ### Production Setup Configure llama.cpp for production with proper timeouts: diff --git a/docs/content/integrations/backend/lmdeploy.md b/docs/content/integrations/backend/lmdeploy.md index afcbfd60..9f77b6f6 100644 --- a/docs/content/integrations/backend/lmdeploy.md +++ b/docs/content/integrations/backend/lmdeploy.md @@ -99,7 +99,12 @@ The default port for `lmdeploy serve api_server` is **23333**. Register individu ### Authentication -LMDeploy supports optional Bearer-token authentication via the `--api-keys` flag. Configure the token in Olla's endpoint headers so it is forwarded on every proxied request: +LMDeploy supports optional Bearer-token authentication via the `--api-keys` flag (plural). + +!!! tip "Use the `auth:` block" + Prefer the structured `auth:` block over a raw `headers:` entry. Olla validates credentials + at startup and resolves `${VAR}` placeholders and `_file` secrets consistently across all + auth types. See [Endpoint Authentication](../../configuration/endpoint-auth.md) for details. ```yaml discovery: @@ -112,8 +117,9 @@ discovery: health_check_url: "/health" check_interval: 10s check_timeout: 5s - headers: - Authorization: "Bearer ${LMDEPLOY_API_KEY}" + auth: + type: bearer + token: "${LMDEPLOY_API_KEY}" ``` The `/health` endpoint is auth-exempt on LMDeploy, so health checks will succeed even when a key is required for inference. diff --git a/docs/content/integrations/backend/lmstudio.md b/docs/content/integrations/backend/lmstudio.md index cebfd924..5136ce04 100644 --- a/docs/content/integrations/backend/lmstudio.md +++ b/docs/content/integrations/backend/lmstudio.md @@ -95,6 +95,9 @@ discovery: check_timeout: 1s ``` +!!! tip "Authenticated LM Studio" + LM Studio itself does not require authentication, but if you place a reverse proxy in front of it that does, the same `auth:` block applies. Configure the credential Olla should present with `auth.type: bearer` or `auth.type: basic` on the endpoint. See [Endpoint Authentication](../../configuration/endpoint-auth.md) for details. + ### Multiple LM Studio Instances Run multiple LM Studio servers on different ports: diff --git a/docs/content/integrations/backend/ollama.md b/docs/content/integrations/backend/ollama.md index 151c2729..c33cd2a6 100644 --- a/docs/content/integrations/backend/ollama.md +++ b/docs/content/integrations/backend/ollama.md @@ -137,8 +137,9 @@ discovery: check_timeout: 5s ``` -!!! note "Authentication Not Supported" - Olla does not currently support authentication headers for endpoints. If your Ollama server requires authentication, you'll need to use a reverse proxy or wait for this feature to be added. +!!! tip "Authenticated Ollama" + To authenticate to a remote or protected Ollama instance, use the `auth:` block on the endpoint. + See [Endpoint Authentication](../../configuration/endpoint-auth.md) for details. ## Anthropic Messages API Support diff --git a/docs/content/integrations/backend/vllm.md b/docs/content/integrations/backend/vllm.md index 9610c872..5c0835a0 100644 --- a/docs/content/integrations/backend/vllm.md +++ b/docs/content/integrations/backend/vllm.md @@ -96,6 +96,9 @@ discovery: check_timeout: 2s ``` +!!! tip "Authenticated vLLM" + vLLM supports the `--api-key` flag to require a bearer token on all requests. Configure the matching credential with `auth.type: bearer` on the endpoint. See [Endpoint Authentication](../../configuration/endpoint-auth.md) for details. + ### Production Setup Configure vLLM for high-throughput production: diff --git a/docs/content/integrations/frontend/openwebui.md b/docs/content/integrations/frontend/openwebui.md index 1bfb0dc6..92d896ec 100644 --- a/docs/content/integrations/frontend/openwebui.md +++ b/docs/content/integrations/frontend/openwebui.md @@ -461,12 +461,9 @@ endpoints: ### Authentication -!!! warning "Authentication Not Supported" - Olla does not currently support authentication headers for endpoints. If your API requires authentication, you'll need to: - - - Use a reverse proxy that adds authentication - - Wait for this feature to be implemented - - Access only public/local endpoints +!!! tip "Authenticated Endpoints" + Olla supports per-endpoint authentication (bearer, api_key, and basic) for backends that + require credentials. See [Endpoint Authentication](../../configuration/endpoint-auth.md). ### Custom Networks diff --git a/docs/mkdocs.yml b/docs/mkdocs.yml index 306e59c7..7f540260 100644 --- a/docs/mkdocs.yml +++ b/docs/mkdocs.yml @@ -143,6 +143,8 @@ nav: - Provider Metrics: concepts/provider-metrics.md - Configuration: - Overview: configuration/overview.md + - Endpoint Auth: configuration/endpoint-auth.md + - Remote Backends (Experimental): configuration/endpoint-auth-remote.md - Filters: configuration/filters.md - Reference: configuration/reference.md - Examples: configuration/examples.md diff --git a/go.mod b/go.mod index 72412c8f..c04f9e7a 100644 --- a/go.mod +++ b/go.mod @@ -7,11 +7,11 @@ require ( github.com/expr-lang/expr v1.17.8 github.com/jellydator/ttlcache/v3 v3.4.0 github.com/json-iterator/go v1.1.12 - github.com/mattn/go-isatty v0.0.21 + github.com/mattn/go-isatty v0.0.22 github.com/pterm/pterm v0.12.83 github.com/puzpuzpuz/xsync/v4 v4.5.0 github.com/stretchr/testify v1.11.1 - github.com/tidwall/gjson v1.18.0 + github.com/tidwall/gjson v1.19.0 golang.org/x/sync v0.19.0 golang.org/x/time v0.14.0 gopkg.in/natefinch/lumberjack.v2 v2.2.1 @@ -25,10 +25,9 @@ require ( github.com/clipperhouse/uax29/v2 v2.7.0 // indirect github.com/containerd/console v1.0.5 // indirect github.com/davecgh/go-spew v1.1.1 // indirect - github.com/gookit/color v1.6.0 // indirect - github.com/kr/pretty v0.3.1 // indirect + github.com/gookit/color v1.6.1 // indirect github.com/lithammer/fuzzysearch v1.1.8 // indirect - github.com/mattn/go-runewidth v0.0.20 // indirect + github.com/mattn/go-runewidth v0.0.23 // indirect github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd // indirect github.com/modern-go/reflect2 v1.0.2 // indirect github.com/pmezard/go-difflib v1.0.0 // indirect diff --git a/go.sum b/go.sum index 81d0a0d3..5abf5257 100644 --- a/go.sum +++ b/go.sum @@ -21,7 +21,6 @@ github.com/clipperhouse/uax29/v2 v2.7.0/go.mod h1:EFJ2TJMRUaplDxHKj1qAEhCtQPW2tJ github.com/containerd/console v1.0.3/go.mod h1:7LqA/THxQ86k76b8c/EMSiaJ3h1eZkMkXar0TQ1gf3U= github.com/containerd/console v1.0.5 h1:R0ymNeydRqH2DmakFNdmjR2k0t7UPuiOV/N/27/qqsc= github.com/containerd/console v1.0.5/go.mod h1:YynlIjWYF8myEu6sdkwKIvGQq+cOckRm6So2avqoYAk= -github.com/creack/pty v1.1.9/go.mod h1:oKZEueFk5CKHvIhNR5MUki03XCEU+Q6VDXinZuGJ33E= github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= @@ -34,8 +33,8 @@ github.com/gookit/assert v0.1.1 h1:lh3GcawXe/p+cU7ESTZ5Ui3Sm/x8JWpIis4/1aF0mY0= github.com/gookit/assert v0.1.1/go.mod h1:jS5bmIVQZTIwk42uXl4lyj4iaaxx32tqH16CFj0VX2E= github.com/gookit/color v1.4.2/go.mod h1:fqRyamkC1W8uxl+lxCQxOT09l/vYfZ+QeiX3rKQHCoQ= github.com/gookit/color v1.5.0/go.mod h1:43aQb+Zerm/BWh2GnrgOQm7ffz7tvQXEKV6BFMl7wAo= -github.com/gookit/color v1.6.0 h1:JjJXBTk1ETNyqyilJhkTXJYYigHG24TM9Xa2M1xAhRA= -github.com/gookit/color v1.6.0/go.mod h1:9ACFc7/1IpHGBW8RwuDm/0YEnhg3dwwXpoMsmtyHfjs= +github.com/gookit/color v1.6.1 h1:KoTnDxJPRgrL0SoX0f8rCFg2zI0t4E3GZZBMo2nN8LU= +github.com/gookit/color v1.6.1/go.mod h1:9ACFc7/1IpHGBW8RwuDm/0YEnhg3dwwXpoMsmtyHfjs= github.com/jellydator/ttlcache/v3 v3.4.0 h1:YS4P125qQS0tNhtL6aeYkheEaB/m8HCqdMMP4mnWdTY= github.com/jellydator/ttlcache/v3 v3.4.0/go.mod h1:Hw9EgjymziQD3yGsQdf1FqFdpp7YjFMd4Srg5EJlgD4= github.com/json-iterator/go v1.1.12 h1:PV8peI4a0ysnczrg+LtxykD8LfKY9ML6u2jnxaEnrnM= @@ -45,26 +44,24 @@ github.com/klauspost/cpuid/v2 v2.0.10/go.mod h1:g2LTdtYhdyuGPqyWyv7qRAmj1WBqxuOb github.com/klauspost/cpuid/v2 v2.0.12/go.mod h1:g2LTdtYhdyuGPqyWyv7qRAmj1WBqxuObKfj5c0PQa7c= github.com/klauspost/cpuid/v2 v2.2.3 h1:sxCkb+qR91z4vsqw4vGGZlDgPz3G7gjaLyK3V8y70BU= github.com/klauspost/cpuid/v2 v2.2.3/go.mod h1:RVVoqg1df56z8g3pUjL/3lE5UfnlrJX8tyFgg4nqhuY= +github.com/kr/pretty v0.1.0 h1:L/CwN0zerZDmRFUapSPitk6f+Q3+0za1rQkzVuMiMFI= github.com/kr/pretty v0.1.0/go.mod h1:dAy3ld7l9f0ibDNOQOHHMYYIIbhfbHSm3C4ZsoJORNo= -github.com/kr/pretty v0.3.1 h1:flRD4NNwYAUpkphVc1HcthR4KEIFJ65n8Mw5qdRn3LE= -github.com/kr/pretty v0.3.1/go.mod h1:hoEshYVHaxMs3cyo3Yncou5ZscifuDolrwPKZanG3xk= github.com/kr/pty v1.1.1/go.mod h1:pFQYn66WHrOpPYNljwOMqo10TkYh1fy3cYio2l3bCsQ= github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI= github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY= github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE= github.com/lithammer/fuzzysearch v1.1.8 h1:/HIuJnjHuXS8bKaiTMeeDlW2/AyIWk2brx1V8LFgLN4= github.com/lithammer/fuzzysearch v1.1.8/go.mod h1:IdqeyBClc3FFqSzYq/MXESsS4S0FsZ5ajtkr5xPLts4= -github.com/mattn/go-isatty v0.0.21 h1:xYae+lCNBP7QuW4PUnNG61ffM4hVIfm+zUzDuSzYLGs= -github.com/mattn/go-isatty v0.0.21/go.mod h1:ZXfXG4SQHsB/w3ZeOYbR0PrPwLy+n6xiMrJlRFqopa4= +github.com/mattn/go-isatty v0.0.22 h1:j8l17JJ9i6VGPUFUYoTUKPSgKe/83EYU2zBC7YNKMw4= +github.com/mattn/go-isatty v0.0.22/go.mod h1:ZXfXG4SQHsB/w3ZeOYbR0PrPwLy+n6xiMrJlRFqopa4= github.com/mattn/go-runewidth v0.0.13/go.mod h1:Jdepj2loyihRzMpdS35Xk/zdY8IAYHsh153qUoGf23w= -github.com/mattn/go-runewidth v0.0.20 h1:WcT52H91ZUAwy8+HUkdM3THM6gXqXuLJi9O3rjcQQaQ= -github.com/mattn/go-runewidth v0.0.20/go.mod h1:XBkDxAl56ILZc9knddidhrOlY5R/pDhgLpndooCuJAs= +github.com/mattn/go-runewidth v0.0.23 h1:7ykA0T0jkPpzSvMS5i9uoNn2Xy3R383f9HDx3RybWcw= +github.com/mattn/go-runewidth v0.0.23/go.mod h1:XBkDxAl56ILZc9knddidhrOlY5R/pDhgLpndooCuJAs= github.com/modern-go/concurrent v0.0.0-20180228061459-e0a39a4cb421/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q= github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd h1:TRLaZ9cD/w8PVh93nsPXa1VrQ6jlwL5oN8l14QlcNfg= github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q= github.com/modern-go/reflect2 v1.0.2 h1:xBagoLtFs94CBntxluKeaWgTMpvLxC4ur3nMaC9Gz0M= github.com/modern-go/reflect2 v1.0.2/go.mod h1:yWuevngMOJpCy52FWWMvUC8ws7m/LJsjYzDa0/r8luk= -github.com/pkg/diff v0.0.0-20210226163009-20ebb0f2a09e/go.mod h1:pJLUxLENpZxwdsKMEsNbx1VGcRFpLqf3715MtcvvzbA= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= github.com/pterm/pterm v0.12.27/go.mod h1:PhQ89w4i95rhgE+xedAoqous6K9X+r6aSOI2eFF7DZI= @@ -79,8 +76,6 @@ github.com/pterm/pterm v0.12.83/go.mod h1:xlgc6bFWyJIMtmLJvGim+L7jhSReilOlOnodeI github.com/puzpuzpuz/xsync/v4 v4.5.0 h1:vOSWu6b57/emh+L/Cw0BeQfvxa/cogFywXHeGUxQxAg= github.com/puzpuzpuz/xsync/v4 v4.5.0/go.mod h1:VJDmTCJMBt8igNxnkQd86r+8KUeN1quSfNKu5bLYFQo= github.com/rivo/uniseg v0.2.0/go.mod h1:J6wj4VEh+S6ZtnVlnTBMWIodfgj8LQOQFoIToxlJtxc= -github.com/rogpeppe/go-internal v1.9.0 h1:73kH8U+JUqXU8lRuOHeVHaa/SZPifC7BkcraZVejAe8= -github.com/rogpeppe/go-internal v1.9.0/go.mod h1:WtVeX8xhTBvf0smdhujwtBcq4Qrzq/fJaraNFVN+nFs= github.com/sergi/go-diff v1.2.0 h1:XU+rvMAioB0UC3q1MFrIQy4Vo5/4VsRDQQXHsEya6xQ= github.com/sergi/go-diff v1.2.0/go.mod h1:STckp+ISIX8hZLjrqAeVduY0gWCT9IjLuqbuNXdaHfM= github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= @@ -90,12 +85,10 @@ github.com/stretchr/testify v1.6.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/ github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= github.com/stretchr/testify v1.11.1 h1:7s2iGBzp5EwR7/aIZr8ao5+dra3wiQyKjjFuvgVKu7U= github.com/stretchr/testify v1.11.1/go.mod h1:wZwfW3scLgRK+23gO65QZefKpKQRnfz6sD981Nm4B6U= -github.com/tidwall/gjson v1.18.0 h1:FIDeeyB800efLX89e5a8Y0BNH+LOngJyGrIWxG2FKQY= -github.com/tidwall/gjson v1.18.0/go.mod h1:/wbyibRr2FHMks5tjHJ5F8dMZh3AcwJEMf5vlfC0lxk= -github.com/tidwall/match v1.1.1/go.mod h1:eRSPERbgtNPcGhD8UCthc6PmLEQXEWd3PRB5JTxsfmM= +github.com/tidwall/gjson v1.19.0 h1:xwxm7n691Uf3u5OFjzngavjGTh55KX5q/9w9xHW88JU= +github.com/tidwall/gjson v1.19.0/go.mod h1:V37/opeE/JbLUOfH0QTXiNez2l0RUjYUhpT4szFQAfc= github.com/tidwall/match v1.2.0 h1:0pt8FlkOwjN2fPt4bIl4BoNxb98gGHN2ObFEDkrfZnM= github.com/tidwall/match v1.2.0/go.mod h1:eRSPERbgtNPcGhD8UCthc6PmLEQXEWd3PRB5JTxsfmM= -github.com/tidwall/pretty v1.2.0/go.mod h1:ITEVvHYasfjBbM0u2Pg8T2nJnzm8xPwvNhhsoaGGjNU= github.com/tidwall/pretty v1.2.1 h1:qjsOFOWWQl+N3RsoF5/ssm1pHmJJwhjlSbZ51I6wMl4= github.com/tidwall/pretty v1.2.1/go.mod h1:ITEVvHYasfjBbM0u2Pg8T2nJnzm8xPwvNhhsoaGGjNU= github.com/xo/terminfo v0.0.0-20210125001918-ca9a967f8778/go.mod h1:2MuV+tbUrU1zIOPMxZ5EncGwgmMJsa+9ucAQZXxsObs= diff --git a/internal/adapter/discovery/auth.go b/internal/adapter/discovery/auth.go new file mode 100644 index 00000000..5a3f5dda --- /dev/null +++ b/internal/adapter/discovery/auth.go @@ -0,0 +1,227 @@ +package discovery + +import ( + "encoding/base64" + "fmt" + "strings" + + "github.com/thushan/olla/internal/config" + "github.com/thushan/olla/internal/core/constants" + "github.com/thushan/olla/pkg/envresolver" +) + +// resolvedAuth holds the credential values after all env/file references have +// been expanded. It is an intermediate type that does not cross package boundaries. +type resolvedAuth struct { + // username and password are only populated for basic auth. + username string + password string + // credential holds the resolved token (bearer) or key (api_key). + credential string + // header overrides the default header name for api_key auth. + header string + authType string +} + +// resolveAuth expands env placeholders and reads _file siblings for all +// credential fields in cfg. It must be called after validateAuth succeeds. +func resolveAuth(name string, cfg *config.AuthConfig) (resolvedAuth, error) { + switch cfg.Type { + case constants.AuthTypeBearer: + return resolveBearerAuth(name, cfg) + case constants.AuthTypeAPIKey: + return resolveAPIKeyAuth(name, cfg) + case constants.AuthTypeBasic: + return resolveBasicAuth(name, cfg) + default: + // unreachable; validateAuth guards this + return resolvedAuth{}, fmt.Errorf("endpoint %q: unknown auth type %q", name, cfg.Type) + } +} + +func resolveBearerAuth(name string, cfg *config.AuthConfig) (resolvedAuth, error) { + token, err := resolveValueOrFile(cfg.Token, cfg.TokenFile) + if err != nil { + return resolvedAuth{}, fmt.Errorf("endpoint %q: bearer token: %w", name, err) + } + if token == "" { + return resolvedAuth{}, fmt.Errorf("endpoint %q: bearer token resolved to empty string", name) + } + return resolvedAuth{authType: constants.AuthTypeBearer, credential: token}, nil +} + +func resolveAPIKeyAuth(name string, cfg *config.AuthConfig) (resolvedAuth, error) { + key, err := resolveValueOrFile(cfg.Key, cfg.KeyFile) + if err != nil { + return resolvedAuth{}, fmt.Errorf("endpoint %q: api_key: %w", name, err) + } + if key == "" { + return resolvedAuth{}, fmt.Errorf("endpoint %q: api_key resolved to empty string", name) + } + + header := cfg.Header + if header == "" { + header = constants.AuthDefaultAPIKeyHeader + } + return resolvedAuth{authType: constants.AuthTypeAPIKey, credential: key, header: header}, nil +} + +func resolveBasicAuth(name string, cfg *config.AuthConfig) (resolvedAuth, error) { + username, err := resolveValueOrFile(cfg.Username, cfg.UsernameFile) + if err != nil { + return resolvedAuth{}, fmt.Errorf("endpoint %q: basic username: %w", name, err) + } + if username == "" { + return resolvedAuth{}, fmt.Errorf("endpoint %q: basic username resolved to empty string", name) + } + + password, err := resolveValueOrFile(cfg.Password, cfg.PasswordFile) + if err != nil { + return resolvedAuth{}, fmt.Errorf("endpoint %q: basic password: %w", name, err) + } + if password == "" { + return resolvedAuth{}, fmt.Errorf("endpoint %q: basic password resolved to empty string", name) + } + + return resolvedAuth{authType: constants.AuthTypeBasic, username: username, password: password}, nil +} + +// resolveValueOrFile handles the inline-value / _file-sibling pattern. +// If fileValue is non-empty the file is read; otherwise the inline value is +// expanded through envresolver. ExpandWithFile already rejects the both-set case. +func resolveValueOrFile(value, fileValue string) (string, error) { + // ExpandWithFile covers: both set → error, file-only → read+trim, neither → "". + // For the inline path it calls Expand (non-strict). We re-enter with ExpandStrict + // when we actually have an inline placeholder to catch missing vars. + if fileValue != "" { + return envresolver.ExpandWithFile("", fileValue) + } + if value == "" { + return "", nil + } + return envresolver.ExpandStrict(value) +} + +// precomputeAuthHeaders builds the final AuthHeaderName and AuthHeaderValue +// from a resolved credential. Using strings.Builder keeps this allocation-free +// when the hot path copies these pre-built strings into request headers. +func precomputeAuthHeaders(r resolvedAuth) (headerName, headerValue string) { + switch r.authType { + case constants.AuthTypeBearer: + var b strings.Builder + b.Grow(len(constants.AuthSchemeBearer) + len(r.credential)) + b.WriteString(constants.AuthSchemeBearer) + b.WriteString(r.credential) + return constants.AuthHeaderAuthorization, b.String() + + case constants.AuthTypeAPIKey: + return r.header, r.credential + + case constants.AuthTypeBasic: + raw := r.username + ":" + r.password + encoded := base64.StdEncoding.EncodeToString([]byte(raw)) + var b strings.Builder + b.Grow(len(constants.AuthSchemeBasic) + len(encoded)) + b.WriteString(constants.AuthSchemeBasic) + b.WriteString(encoded) + return constants.AuthHeaderAuthorization, b.String() + } + + return "", "" +} + +// validateAuth checks the shape of an auth block before any env/file resolution +// happens. All conflicts and missing required fields are caught here so the +// process fails fast with a clear message that names the offending endpoint. +func validateAuth(name string, cfg *config.AuthConfig) error { + if !constants.IsValidAuthType(cfg.Type) { + return fmt.Errorf("endpoint %q: auth.type %q is not valid (must be bearer, api_key, or basic)", name, cfg.Type) + } + + switch cfg.Type { + case constants.AuthTypeBearer: + return validateBearerAuth(name, cfg) + case constants.AuthTypeAPIKey: + return validateAPIKeyAuth(name, cfg) + case constants.AuthTypeBasic: + return validateBasicAuth(name, cfg) + } + + // unreachable; IsValidAuthType guards the switch above + return nil +} + +func validateBearerAuth(name string, cfg *config.AuthConfig) error { + hasToken := cfg.Token != "" + hasTokenFile := cfg.TokenFile != "" + + if hasToken && hasTokenFile { + return fmt.Errorf("endpoint %q: bearer auth has both token and token_file set; use exactly one", name) + } + if !hasToken && !hasTokenFile { + return fmt.Errorf("endpoint %q: bearer auth requires token or token_file", name) + } + + // Fields that must not appear for this auth type + if cfg.Key != "" || cfg.KeyFile != "" { + return fmt.Errorf("endpoint %q: bearer auth does not accept key/key_file fields", name) + } + if cfg.Username != "" || cfg.UsernameFile != "" || cfg.Password != "" || cfg.PasswordFile != "" { + return fmt.Errorf("endpoint %q: bearer auth does not accept username/password fields", name) + } + + return nil +} + +func validateAPIKeyAuth(name string, cfg *config.AuthConfig) error { + hasKey := cfg.Key != "" + hasKeyFile := cfg.KeyFile != "" + + if hasKey && hasKeyFile { + return fmt.Errorf("endpoint %q: api_key auth has both key and key_file set; use exactly one", name) + } + if !hasKey && !hasKeyFile { + return fmt.Errorf("endpoint %q: api_key auth requires key or key_file", name) + } + + // Fields that must not appear for this auth type + if cfg.Token != "" || cfg.TokenFile != "" { + return fmt.Errorf("endpoint %q: api_key auth does not accept token/token_file fields", name) + } + if cfg.Username != "" || cfg.UsernameFile != "" || cfg.Password != "" || cfg.PasswordFile != "" { + return fmt.Errorf("endpoint %q: api_key auth does not accept username/password fields", name) + } + + return nil +} + +func validateBasicAuth(name string, cfg *config.AuthConfig) error { + hasUsername := cfg.Username != "" + hasUsernameFile := cfg.UsernameFile != "" + hasPassword := cfg.Password != "" + hasPasswordFile := cfg.PasswordFile != "" + + if hasUsername && hasUsernameFile { + return fmt.Errorf("endpoint %q: basic auth has both username and username_file set; use exactly one", name) + } + if !hasUsername && !hasUsernameFile { + return fmt.Errorf("endpoint %q: basic auth requires username or username_file", name) + } + + if hasPassword && hasPasswordFile { + return fmt.Errorf("endpoint %q: basic auth has both password and password_file set; use exactly one", name) + } + if !hasPassword && !hasPasswordFile { + return fmt.Errorf("endpoint %q: basic auth requires password or password_file", name) + } + + // Fields that must not appear for this auth type + if cfg.Token != "" || cfg.TokenFile != "" { + return fmt.Errorf("endpoint %q: basic auth does not accept token/token_file fields", name) + } + if cfg.Key != "" || cfg.KeyFile != "" { + return fmt.Errorf("endpoint %q: basic auth does not accept key/key_file fields", name) + } + + return nil +} diff --git a/internal/adapter/discovery/auth_test.go b/internal/adapter/discovery/auth_test.go new file mode 100644 index 00000000..ec0a2d0b --- /dev/null +++ b/internal/adapter/discovery/auth_test.go @@ -0,0 +1,535 @@ +package discovery + +import ( + "context" + "encoding/base64" + "fmt" + "os" + "path/filepath" + "strings" + "testing" + "time" + + "github.com/thushan/olla/internal/config" +) + +// validEndpointBase returns a minimal EndpointConfig that passes all +// non-auth validation, so auth tests can focus on auth behaviour only. +func validEndpointBase(name string) config.EndpointConfig { + p := 100 + return config.EndpointConfig{ + Name: name, + URL: "http://localhost:11434", + Type: "ollama", + Priority: &p, + CheckInterval: 5 * time.Second, + CheckTimeout: 2 * time.Second, + } +} + +// TestValidateAuth_Shape exercises the pure shape-validation rules against +// all three auth types before any env or file resolution takes place. +func TestValidateAuth_Shape(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + auth config.AuthConfig + wantErr bool + }{ + // ── bearer ────────────────────────────────────────────────────────── + { + name: "bearer with token", + auth: config.AuthConfig{Type: "bearer", Token: "tok"}, + wantErr: false, + }, + { + name: "bearer with token_file", + auth: config.AuthConfig{Type: "bearer", TokenFile: "/run/secrets/token"}, + wantErr: false, + }, + { + name: "bearer both token and token_file", + auth: config.AuthConfig{Type: "bearer", Token: "tok", TokenFile: "/run/secrets/token"}, + wantErr: true, + }, + { + name: "bearer missing both token and token_file", + auth: config.AuthConfig{Type: "bearer"}, + wantErr: true, + }, + { + name: "bearer with forbidden key field", + auth: config.AuthConfig{Type: "bearer", Token: "tok", Key: "k"}, + wantErr: true, + }, + { + name: "bearer with forbidden username field", + auth: config.AuthConfig{Type: "bearer", Token: "tok", Username: "u"}, + wantErr: true, + }, + { + name: "bearer with forbidden password field", + auth: config.AuthConfig{Type: "bearer", Token: "tok", Password: "p"}, + wantErr: true, + }, + + // ── api_key ───────────────────────────────────────────────────────── + { + name: "api_key with key", + auth: config.AuthConfig{Type: "api_key", Key: "k"}, + wantErr: false, + }, + { + name: "api_key with key_file", + auth: config.AuthConfig{Type: "api_key", KeyFile: "/run/secrets/key"}, + wantErr: false, + }, + { + name: "api_key with optional header override", + auth: config.AuthConfig{Type: "api_key", Key: "k", Header: "X-My-Key"}, + wantErr: false, + }, + { + name: "api_key both key and key_file", + auth: config.AuthConfig{Type: "api_key", Key: "k", KeyFile: "/run/secrets/key"}, + wantErr: true, + }, + { + name: "api_key missing both key and key_file", + auth: config.AuthConfig{Type: "api_key"}, + wantErr: true, + }, + { + name: "api_key with forbidden token field", + auth: config.AuthConfig{Type: "api_key", Key: "k", Token: "t"}, + wantErr: true, + }, + { + name: "api_key with forbidden username field", + auth: config.AuthConfig{Type: "api_key", Key: "k", Username: "u"}, + wantErr: true, + }, + + // ── basic ──────────────────────────────────────────────────────────── + { + name: "basic with inline credentials", + auth: config.AuthConfig{Type: "basic", Username: "user", Password: "pass"}, + wantErr: false, + }, + { + name: "basic with file credentials", + auth: config.AuthConfig{Type: "basic", UsernameFile: "/run/secrets/user", PasswordFile: "/run/secrets/pass"}, + wantErr: false, + }, + { + name: "basic mixed inline and file", + auth: config.AuthConfig{Type: "basic", Username: "user", PasswordFile: "/run/secrets/pass"}, + wantErr: false, + }, + { + name: "basic both username and username_file", + auth: config.AuthConfig{Type: "basic", Username: "u", UsernameFile: "/f", Password: "p"}, + wantErr: true, + }, + { + name: "basic both password and password_file", + auth: config.AuthConfig{Type: "basic", Username: "u", Password: "p", PasswordFile: "/f"}, + wantErr: true, + }, + { + name: "basic missing username", + auth: config.AuthConfig{Type: "basic", Password: "p"}, + wantErr: true, + }, + { + name: "basic missing password", + auth: config.AuthConfig{Type: "basic", Username: "u"}, + wantErr: true, + }, + { + name: "basic with forbidden token field", + auth: config.AuthConfig{Type: "basic", Username: "u", Password: "p", Token: "t"}, + wantErr: true, + }, + { + name: "basic with forbidden key field", + auth: config.AuthConfig{Type: "basic", Username: "u", Password: "p", Key: "k"}, + wantErr: true, + }, + + // ── unknown type ───────────────────────────────────────────────────── + { + name: "unknown auth type", + auth: config.AuthConfig{Type: "oauth2", Token: "tok"}, + wantErr: true, + }, + { + name: "empty auth type", + auth: config.AuthConfig{}, + wantErr: true, + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + err := validateAuth("test-endpoint", &tc.auth) + if tc.wantErr && err == nil { + t.Error("expected validation error, got nil") + } + if !tc.wantErr && err != nil { + t.Errorf("unexpected validation error: %v", err) + } + }) + } +} + +// TestLoadFromConfig_AuthValidation_RejectsInvalidShape verifies that +// LoadFromConfig propagates auth validation errors and surfaces the endpoint name. +func TestLoadFromConfig_AuthValidation_RejectsInvalidShape(t *testing.T) { + t.Parallel() + + p := 100 + badAuth := config.EndpointConfig{ + Name: "bad-auth-ep", + URL: "http://localhost:11434", + Type: "ollama", + Priority: &p, + CheckInterval: 5 * time.Second, + CheckTimeout: 2 * time.Second, + Auth: &config.AuthConfig{ + // bearer with both token and token_file must fail + Type: "bearer", + Token: "tok", + TokenFile: "/run/secrets/token", + }, + } + + repo := NewStaticEndpointRepository() + err := repo.LoadFromConfig(context.Background(), []config.EndpointConfig{badAuth}) + if err == nil { + t.Fatal("expected error for conflicting token/token_file, got nil") + } +} + +// TestLoadFromConfig_AuthNil_Succeeds confirms that endpoints without an auth +// block load normally. Auth is always optional. +func TestLoadFromConfig_AuthNil_Succeeds(t *testing.T) { + t.Parallel() + + cfg := validEndpointBase("no-auth") + // Auth is nil by default + + repo := NewStaticEndpointRepository() + if err := repo.LoadFromConfig(context.Background(), []config.EndpointConfig{cfg}); err != nil { + t.Fatalf("LoadFromConfig with nil auth failed: %v", err) + } +} + +// ── Commit 2: env/file resolution ─────────────────────────────────────────── + +// TestAuthResolve_EnvVar_ExpandsToken verifies that a ${VAR} placeholder in the +// token field is expanded through the environment at load time. +func TestAuthResolve_EnvVar_ExpandsToken(t *testing.T) { + t.Setenv("OLLA_TEST_TOKEN", "resolved-secret") + + cfg := validEndpointBase("bearer-env") + cfg.Auth = &config.AuthConfig{Type: "bearer", Token: "${OLLA_TEST_TOKEN}"} + + repo := NewStaticEndpointRepository() + if err := repo.LoadFromConfig(context.Background(), []config.EndpointConfig{cfg}); err != nil { + t.Fatalf("LoadFromConfig failed: %v", err) + } + + eps, _ := repo.GetAll(context.Background()) + if len(eps) != 1 { + t.Fatalf("expected 1 endpoint, got %d", len(eps)) + } + if !strings.HasSuffix(eps[0].AuthHeaderValue, "resolved-secret") { + t.Errorf("AuthHeaderValue %q does not end with resolved token", eps[0].AuthHeaderValue) + } +} + +// TestAuthResolve_MissingEnvVar_FatalError verifies that an unset ${VAR} in an +// auth field causes a startup-fatal error that mentions the endpoint name. +func TestAuthResolve_MissingEnvVar_FatalError(t *testing.T) { + t.Parallel() + + // Guarantee the var is absent + varName := "OLLA_DEFINITELY_UNSET_VAR_XYZ" + os.Unsetenv(varName) //nolint:errcheck + + cfg := validEndpointBase("bearer-missing-env") + cfg.Auth = &config.AuthConfig{Type: "bearer", Token: fmt.Sprintf("${%s}", varName)} + + repo := NewStaticEndpointRepository() + err := repo.LoadFromConfig(context.Background(), []config.EndpointConfig{cfg}) + if err == nil { + t.Fatal("expected error for unset env var, got nil") + } + if !strings.Contains(err.Error(), "bearer-missing-env") { + t.Errorf("error should mention endpoint name, got: %v", err) + } +} + +// TestAuthResolve_TokenFile_ReadsAndTrims verifies that token_file reads the +// file content and strips trailing whitespace (e.g. trailing newline from echo). +func TestAuthResolve_TokenFile_ReadsAndTrims(t *testing.T) { + t.Parallel() + + dir := t.TempDir() + tokenPath := filepath.Join(dir, "token.txt") + if err := os.WriteFile(tokenPath, []byte("file-secret\n"), 0o600); err != nil { + t.Fatalf("writing token file: %v", err) + } + + cfg := validEndpointBase("bearer-file") + cfg.Auth = &config.AuthConfig{Type: "bearer", TokenFile: tokenPath} + + repo := NewStaticEndpointRepository() + if err := repo.LoadFromConfig(context.Background(), []config.EndpointConfig{cfg}); err != nil { + t.Fatalf("LoadFromConfig failed: %v", err) + } + + eps, _ := repo.GetAll(context.Background()) + if !strings.HasSuffix(eps[0].AuthHeaderValue, "file-secret") { + t.Errorf("AuthHeaderValue %q does not end with trimmed file content", eps[0].AuthHeaderValue) + } +} + +// TestAuthResolve_BothInlineAndFile_FatalError ensures the both-set conflict is +// caught at resolution time (ExpandWithFile enforces this). +func TestAuthResolve_BothInlineAndFile_FatalError(t *testing.T) { + t.Parallel() + + dir := t.TempDir() + tokenPath := filepath.Join(dir, "token.txt") + _ = os.WriteFile(tokenPath, []byte("tok\n"), 0o600) + + cfg := validEndpointBase("bearer-both") + cfg.Auth = &config.AuthConfig{Type: "bearer", Token: "inline", TokenFile: tokenPath} + + repo := NewStaticEndpointRepository() + err := repo.LoadFromConfig(context.Background(), []config.EndpointConfig{cfg}) + if err == nil { + t.Fatal("expected error for both token and token_file, got nil") + } +} + +// TestAuthResolve_EmptyAfterExpansion_FatalError verifies that a token that +// resolves to an empty string (e.g. env var set to "") is a startup-fatal error. +func TestAuthResolve_EmptyAfterExpansion_FatalError(t *testing.T) { + t.Setenv("OLLA_EMPTY_TOKEN", "") + + cfg := validEndpointBase("bearer-empty") + cfg.Auth = &config.AuthConfig{Type: "bearer", Token: "${OLLA_EMPTY_TOKEN:-}"} + + repo := NewStaticEndpointRepository() + err := repo.LoadFromConfig(context.Background(), []config.EndpointConfig{cfg}) + if err == nil { + t.Fatal("expected error for empty-resolved token, got nil") + } +} + +// ── Commit 3: precomputed headers ─────────────────────────────────────────── + +// TestAuthPrecompute_Bearer_AuthorizationHeader verifies the bearer auth +// produces the correct Authorization header value. +func TestAuthPrecompute_Bearer_AuthorizationHeader(t *testing.T) { + t.Parallel() + + cfg := validEndpointBase("bearer-precompute") + cfg.Auth = &config.AuthConfig{Type: "bearer", Token: "mysecrettoken"} + + repo := NewStaticEndpointRepository() + if err := repo.LoadFromConfig(context.Background(), []config.EndpointConfig{cfg}); err != nil { + t.Fatalf("LoadFromConfig failed: %v", err) + } + + eps, _ := repo.GetAll(context.Background()) + ep := eps[0] + + if ep.AuthHeaderName != "Authorization" { + t.Errorf("AuthHeaderName = %q, want %q", ep.AuthHeaderName, "Authorization") + } + if ep.AuthHeaderValue != "Bearer mysecrettoken" { + t.Errorf("AuthHeaderValue = %q, want %q", ep.AuthHeaderValue, "Bearer mysecrettoken") + } +} + +// TestAuthPrecompute_APIKey_DefaultHeader verifies that api_key with no header +// override uses X-Api-Key as the header name. +func TestAuthPrecompute_APIKey_DefaultHeader(t *testing.T) { + t.Parallel() + + cfg := validEndpointBase("apikey-default") + cfg.Auth = &config.AuthConfig{Type: "api_key", Key: "mykey"} + + repo := NewStaticEndpointRepository() + if err := repo.LoadFromConfig(context.Background(), []config.EndpointConfig{cfg}); err != nil { + t.Fatalf("LoadFromConfig failed: %v", err) + } + + eps, _ := repo.GetAll(context.Background()) + ep := eps[0] + + if ep.AuthHeaderName != "X-Api-Key" { + t.Errorf("AuthHeaderName = %q, want %q", ep.AuthHeaderName, "X-Api-Key") + } + if ep.AuthHeaderValue != "mykey" { + t.Errorf("AuthHeaderValue = %q, want %q", ep.AuthHeaderValue, "mykey") + } +} + +// TestAuthPrecompute_APIKey_CustomHeader verifies that an explicit header field +// overrides the default X-Api-Key name. +func TestAuthPrecompute_APIKey_CustomHeader(t *testing.T) { + t.Parallel() + + cfg := validEndpointBase("apikey-custom") + cfg.Auth = &config.AuthConfig{Type: "api_key", Key: "mykey", Header: "X-My-Auth"} + + repo := NewStaticEndpointRepository() + if err := repo.LoadFromConfig(context.Background(), []config.EndpointConfig{cfg}); err != nil { + t.Fatalf("LoadFromConfig failed: %v", err) + } + + eps, _ := repo.GetAll(context.Background()) + ep := eps[0] + + if ep.AuthHeaderName != "X-My-Auth" { + t.Errorf("AuthHeaderName = %q, want %q", ep.AuthHeaderName, "X-My-Auth") + } + if ep.AuthHeaderValue != "mykey" { + t.Errorf("AuthHeaderValue = %q, want %q", ep.AuthHeaderValue, "mykey") + } +} + +// TestAuthPrecompute_Basic_CorrectBase64 verifies the basic auth header is the +// correctly base64-encoded "username:password" pair. We decode it to be explicit. +func TestAuthPrecompute_Basic_CorrectBase64(t *testing.T) { + t.Parallel() + + cfg := validEndpointBase("basic-precompute") + cfg.Auth = &config.AuthConfig{Type: "basic", Username: "alice", Password: "s3cret"} + + repo := NewStaticEndpointRepository() + if err := repo.LoadFromConfig(context.Background(), []config.EndpointConfig{cfg}); err != nil { + t.Fatalf("LoadFromConfig failed: %v", err) + } + + eps, _ := repo.GetAll(context.Background()) + ep := eps[0] + + if ep.AuthHeaderName != "Authorization" { + t.Errorf("AuthHeaderName = %q, want %q", ep.AuthHeaderName, "Authorization") + } + + const prefix = "Basic " + if !strings.HasPrefix(ep.AuthHeaderValue, prefix) { + t.Fatalf("AuthHeaderValue %q does not start with %q", ep.AuthHeaderValue, prefix) + } + + encoded := strings.TrimPrefix(ep.AuthHeaderValue, prefix) + decoded, err := base64.StdEncoding.DecodeString(encoded) + if err != nil { + t.Fatalf("base64 decode failed: %v", err) + } + + const want = "alice:s3cret" + if string(decoded) != want { + t.Errorf("decoded basic credentials = %q, want %q", string(decoded), want) + } +} + +// TestAuthPrecompute_NoAuth_EmptyFields verifies that endpoints without auth +// have zero-value AuthHeaderName and nil Headers. +func TestAuthPrecompute_NoAuth_EmptyFields(t *testing.T) { + t.Parallel() + + cfg := validEndpointBase("no-auth-fields") + + repo := NewStaticEndpointRepository() + if err := repo.LoadFromConfig(context.Background(), []config.EndpointConfig{cfg}); err != nil { + t.Fatalf("LoadFromConfig failed: %v", err) + } + + eps, _ := repo.GetAll(context.Background()) + ep := eps[0] + + if ep.AuthHeaderName != "" { + t.Errorf("AuthHeaderName should be empty for endpoint without auth, got %q", ep.AuthHeaderName) + } + if ep.AuthHeaderValue != "" { + t.Errorf("AuthHeaderValue should be empty for endpoint without auth, got %q", ep.AuthHeaderValue) + } + if ep.Headers != nil { + t.Errorf("Headers should be nil when no headers configured, got %v", ep.Headers) + } +} + +// TestAuthPrecompute_Headers_EnvResolved verifies that values in the headers +// map are expanded through the environment at load time. +func TestAuthPrecompute_Headers_EnvResolved(t *testing.T) { + t.Setenv("OLLA_TEST_TENANT", "acme-corp") + + cfg := validEndpointBase("headers-env") + cfg.Headers = map[string]string{ + "X-Tenant": "${OLLA_TEST_TENANT}", + "X-Static": "literal", + } + + repo := NewStaticEndpointRepository() + if err := repo.LoadFromConfig(context.Background(), []config.EndpointConfig{cfg}); err != nil { + t.Fatalf("LoadFromConfig failed: %v", err) + } + + eps, _ := repo.GetAll(context.Background()) + ep := eps[0] + + if ep.Headers["X-Tenant"] != "acme-corp" { + t.Errorf("Headers[X-Tenant] = %q, want %q", ep.Headers["X-Tenant"], "acme-corp") + } + if ep.Headers["X-Static"] != "literal" { + t.Errorf("Headers[X-Static] = %q, want %q", ep.Headers["X-Static"], "literal") + } +} + +// ── Security audit ─────────────────────────────────────────────────────────── + +// TestAuthSecurity_AuthHeaderValue_NotInJSON asserts that AuthHeaderValue is +// excluded from JSON serialisation of a domain.Endpoint. This guards against +// credential leakage through status endpoints or debug logs. +// (The json:"-" tag on AuthHeaderValue provides the guarantee; this test keeps +// it honest so no future refactor can accidentally remove it.) +func TestAuthSecurity_AuthHeaderValue_NotInJSON(t *testing.T) { + t.Parallel() + + cfg := validEndpointBase("security-ep") + cfg.Auth = &config.AuthConfig{Type: "bearer", Token: "super-secret-do-not-leak"} + + repo := NewStaticEndpointRepository() + if err := repo.LoadFromConfig(context.Background(), []config.EndpointConfig{cfg}); err != nil { + t.Fatalf("LoadFromConfig failed: %v", err) + } + + // The domain.Endpoint JSON test already covers the tag; here we exercise + // the full load → retrieve → marshal path to catch end-to-end regressions. + eps, _ := repo.GetAll(context.Background()) + ep := eps[0] + + // AuthHeaderValue must be set (we loaded auth successfully). + if ep.AuthHeaderValue == "" { + t.Fatal("AuthHeaderValue is empty: auth was not wired in") + } + + // GetURLString and GetHealthCheckURLString are the only string accessors on + // Endpoint; neither should expose credentials. + if strings.Contains(ep.GetURLString(), "secret") { + t.Errorf("GetURLString leaks credential: %q", ep.GetURLString()) + } + if strings.Contains(ep.GetHealthCheckURLString(), "secret") { + t.Errorf("GetHealthCheckURLString leaks credential: %q", ep.GetHealthCheckURLString()) + } +} diff --git a/internal/adapter/discovery/http_client.go b/internal/adapter/discovery/http_client.go index 53cf18ac..de530f37 100644 --- a/internal/adapter/discovery/http_client.go +++ b/internal/adapter/discovery/http_client.go @@ -174,6 +174,15 @@ func (c *HTTPModelDiscoveryClient) discoverWithProfile(ctx context.Context, endp req.Header.Set("User-Agent", fmt.Sprintf(DefaultUserAgent, version.ShortName, version.Version)) req.Header.Set("Accept", DefaultContentType) + // Apply any per-endpoint custom headers then auth, mirroring what CopyHeaders does on + // the proxy path. Without this, authenticated backends 401 during model discovery. + for name, value := range endpoint.Headers { + req.Header.Set(name, value) + } + if endpoint.AuthHeaderName != "" { + req.Header.Set(endpoint.AuthHeaderName, endpoint.AuthHeaderValue) + } + resp, err := c.httpClient.Do(req) if err != nil { networkErr := &NetworkError{URL: discoveryURL, Err: err} diff --git a/internal/adapter/discovery/http_client_test.go b/internal/adapter/discovery/http_client_test.go index a6b23321..67429c1f 100644 --- a/internal/adapter/discovery/http_client_test.go +++ b/internal/adapter/discovery/http_client_test.go @@ -948,6 +948,55 @@ func createTestEndpointWithModelURL(baseURL, endpointType, modelURLString string } } +// TestDiscoverModels_AuthenticatedEndpoint verifies that discovery requests carry +// the endpoint auth header. Without it, authenticated backends 401 and discovery +// never populates the model list. +func TestDiscoverModels_AuthenticatedEndpoint(t *testing.T) { + t.Parallel() + + const token = "Bearer test-secret" + var receivedAuth string + + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + receivedAuth = r.Header.Get("Authorization") + if receivedAuth != token { + w.WriteHeader(http.StatusUnauthorized) + return + } + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusOK) + _, _ = w.Write([]byte(`{"models":[{"name":"llama3:8b","size":1234}]}`)) + })) + t.Cleanup(srv.Close) + + endpointURL, _ := url.Parse(srv.URL) + endpoint := &domain.Endpoint{ + Name: "auth-backend", + Type: domain.ProfileOllama, + URL: endpointURL, + URLString: srv.URL, + AuthHeaderName: "Authorization", + AuthHeaderValue: token, + } + + factory, err := profile.NewFactoryWithDefaults() + if err != nil { + t.Fatalf("profile factory: %v", err) + } + + client := NewHTTPModelDiscoveryClientWithDefaults(factory, createTestLogger()) + models, err := client.DiscoverModels(context.Background(), endpoint) + if err != nil { + t.Fatalf("DiscoverModels failed: %v", err) + } + if len(models) == 0 { + t.Error("expected models, got none") + } + if receivedAuth != token { + t.Errorf("backend received auth %q, want %q", receivedAuth, token) + } +} + func createTestLogger() logger.StyledLogger { slogLogger := slog.New(slog.NewTextHandler(os.Stderr, &slog.HandlerOptions{ Level: slog.LevelError, // Only log errors to reduce test noise diff --git a/internal/adapter/discovery/repository.go b/internal/adapter/discovery/repository.go index 0774ef53..0f0ed588 100644 --- a/internal/adapter/discovery/repository.go +++ b/internal/adapter/discovery/repository.go @@ -12,6 +12,7 @@ import ( "github.com/thushan/olla/internal/config" "github.com/thushan/olla/internal/core/domain" "github.com/thushan/olla/internal/util" + "github.com/thushan/olla/pkg/envresolver" ) const ( @@ -114,6 +115,7 @@ func (r *StaticEndpointRepository) UpdateEndpoint(ctx context.Context, endpoint existing.BackoffMultiplier = endpoint.BackoffMultiplier existing.NextCheckTime = endpoint.NextCheckTime existing.LastLatency = endpoint.LastLatency + existing.RateLimitedUntil = endpoint.RateLimitedUntil return nil } @@ -144,6 +146,18 @@ func (r *StaticEndpointRepository) LoadFromConfig(ctx context.Context, configs [ return fmt.Errorf("invalid endpoint config for %q: %w", cfg.Name, err) } + var resolved resolvedAuth + if cfg.Auth != nil { + if err := validateAuth(cfg.Name, cfg.Auth); err != nil { + return err + } + var rerr error + resolved, rerr = resolveAuth(cfg.Name, cfg.Auth) + if rerr != nil { + return rerr + } + } + endpointURL, err := url.Parse(cfg.URL) if err != nil { return fmt.Errorf("invalid endpoint URL %q: %w", cfg.URL, err) @@ -190,6 +204,22 @@ func (r *StaticEndpointRepository) LoadFromConfig(ctx context.Context, configs [ PreservePath: cfg.PreservePath, } + if cfg.Auth != nil { + newEndpoint.AuthHeaderName, newEndpoint.AuthHeaderValue = precomputeAuthHeaders(resolved) + } + + if len(cfg.Headers) > 0 { + resolvedHeaders := make(map[string]string, len(cfg.Headers)) + for k, v := range cfg.Headers { + expanded, herr := envresolver.ExpandStrict(v) + if herr != nil { + return fmt.Errorf("endpoint %q: header %q: %w", cfg.Name, k, herr) + } + resolvedHeaders[k] = expanded + } + newEndpoint.Headers = resolvedHeaders + } + newEndpoints[urlString] = newEndpoint } diff --git a/internal/adapter/discovery/repository_test.go b/internal/adapter/discovery/repository_test.go index e75152bd..98446cb1 100644 --- a/internal/adapter/discovery/repository_test.go +++ b/internal/adapter/discovery/repository_test.go @@ -5,6 +5,7 @@ import ( "testing" "time" + "github.com/stretchr/testify/require" "github.com/thushan/olla/internal/config" ) @@ -845,3 +846,46 @@ func TestStaticEndpointRepository_EmptyURLs_WithNestedPath(t *testing.T) { t.Errorf("ModelURLString = %q, expected %q", endpoint.ModelURLString, expectedModelURL) } } + +// TestUpdateEndpoint_PersistsRateLimitedUntil confirms that a 429-triggered +// RateLimitedUntil timestamp is not lost when UpdateEndpoint is called. +// The field was missing from the copy list, so the scheduler kept re-probing +// rate-limited backends immediately after the health checker set the deadline. +func TestUpdateEndpoint_PersistsRateLimitedUntil(t *testing.T) { + t.Parallel() + + repo := NewStaticEndpointRepository() + cfg := []config.EndpointConfig{ + { + Name: "rl-endpoint", + URL: "http://localhost:11434", + Type: "ollama", + Priority: ptrInt(100), + CheckInterval: 5 * time.Second, + CheckTimeout: 2 * time.Second, + }, + } + if err := repo.LoadFromConfig(context.Background(), cfg); err != nil { + t.Fatalf("LoadFromConfig: %v", err) + } + + eps, err := repo.GetAll(context.Background()) + require.NoError(t, err, "GetAll before UpdateEndpoint") + ep := eps[0] + + rateLimitDeadline := time.Now().Add(60 * time.Second).Truncate(time.Millisecond) + ep.RateLimitedUntil = rateLimitDeadline + + if err := repo.UpdateEndpoint(context.Background(), ep); err != nil { + t.Fatalf("UpdateEndpoint: %v", err) + } + + updated, err := repo.GetAll(context.Background()) + require.NoError(t, err, "GetAll after UpdateEndpoint") + if updated[0].RateLimitedUntil.IsZero() { + t.Error("RateLimitedUntil was not persisted; got zero time after UpdateEndpoint") + } + if !updated[0].RateLimitedUntil.Equal(rateLimitDeadline) { + t.Errorf("RateLimitedUntil = %v, want %v", updated[0].RateLimitedUntil, rateLimitDeadline) + } +} diff --git a/internal/adapter/health/checker.go b/internal/adapter/health/checker.go index 73c0d878..c122bb58 100644 --- a/internal/adapter/health/checker.go +++ b/internal/adapter/health/checker.go @@ -70,6 +70,13 @@ func NewHTTPHealthCheckerWithDefaults(repository domain.EndpointRepository, logg MaxIdleConnsPerHost: 2, IdleConnTimeout: 30 * time.Second, DisableKeepAlives: false, + // Per-probe timeouts are already enforced by the CheckTimeout context in + // performSingleCheck; this header timeout is a backstop against backends + // that accept the TCP connection but then never send a response header. + ResponseHeaderTimeout: DefaultHealthCheckerResponseHeaderTimeout, + // No proxy: health probes now carry auth credentials (Authorization / API key + // headers injected by the endpoint auth config). Routing them through an + // environment proxy risks leaking those credentials to a third party. }, } return NewHTTPHealthChecker(repository, logger, client) @@ -162,11 +169,16 @@ func (c *HTTPHealthChecker) performHealthChecks(ctx context.Context) { endpointsToCheck := make([]*domain.Endpoint, 0, len(endpoints)) - // Filter endpoints that are due for checking + // Filter endpoints that are due for checking. + // Skip endpoints still inside a rate-limit window; hammering a throttled backend + // will not make it respond faster and wastes quota. for _, endpoint := range endpoints { if now.Before(endpoint.NextCheckTime) { continue } + if !endpoint.RateLimitedUntil.IsZero() && now.Before(endpoint.RateLimitedUntil) { + continue + } endpointsToCheck = append(endpointsToCheck, endpoint) } @@ -225,6 +237,14 @@ func (c *HTTPHealthChecker) checkEndpoint(ctx context.Context, endpoint *domain. endpointCopy.LastChecked = now endpointCopy.LastLatency = result.Latency + // Persist the rate-limit window so the scheduler can skip this endpoint until + // the backend is ready to serve probes again. + if !result.RateLimitedUntil.IsZero() { + endpointCopy.RateLimitedUntil = result.RateLimitedUntil + } else { + endpointCopy.RateLimitedUntil = time.Time{} // clear on non-429 responses + } + isSuccess := result.Status == domain.StatusHealthy nextInterval, newMultiplier := calculateBackoff(&endpointCopy, isSuccess) diff --git a/internal/adapter/health/client.go b/internal/adapter/health/client.go index 28b5c25b..75455dc1 100644 --- a/internal/adapter/health/client.go +++ b/internal/adapter/health/client.go @@ -5,8 +5,10 @@ import ( "errors" "fmt" "io" + "log/slog" "net" "net/http" + "strconv" "time" "github.com/thushan/olla/internal/version" @@ -98,11 +100,21 @@ func (hc *HealthClient) Check(ctx context.Context, endpoint *domain.Endpoint) (r // Record overall latency including retries result.Latency = time.Since(overallStart) - // Record result in circuit breaker - if lastErr != nil || result.Status != domain.StatusHealthy { + // Record result in circuit breaker. + // ConfigError and RateLimited are not service failures. The backend is up and + // responding; counting them as failures would trip the CB on misconfigured + // credentials or a throttled endpoint, hiding the real cause from the operator. + if lastErr != nil { hc.circuitBreaker.RecordFailure(healthCheckURL) } else { - hc.circuitBreaker.RecordSuccess(healthCheckURL) + switch result.Status { + case domain.StatusConfigError, domain.StatusRateLimited: + // Do not trip the circuit breaker; this is an operator or rate error, not downtime. + case domain.StatusHealthy, domain.StatusBusy: + hc.circuitBreaker.RecordSuccess(healthCheckURL) + default: + hc.circuitBreaker.RecordFailure(healthCheckURL) + } } if lastErr != nil { @@ -132,6 +144,7 @@ func (hc *HealthClient) performSingleCheck(ctx context.Context, endpoint *domain } req = injectDefaultHeaders(req) + injectEndpointAuth(req, endpoint) resp, err := hc.client.Do(req) result.Latency = time.Since(start) @@ -158,9 +171,45 @@ func (hc *HealthClient) performSingleCheck(ctx context.Context, endpoint *domain result.StatusCode = resp.StatusCode result.Status = determineStatus(resp.StatusCode, result.Latency, nil, domain.ErrorTypeNone) + if resp.StatusCode == http.StatusTooManyRequests { + result.RateLimitedUntil = parseRetryAfter(resp.Header.Get("Retry-After"), endpoint.Name) + } + return result, nil } +// parseRetryAfter interprets the Retry-After header value per RFC 9110. +// It accepts delay-seconds and HTTP-date formats. Falls back to DefaultRateLimitBackoff +// if the value is missing or malformed. +func parseRetryAfter(header, endpointName string) time.Time { + now := time.Now() + + if header == "" { + slog.Info("no Retry-After header on 429, using default backoff", + "endpoint", endpointName, + "default", DefaultRateLimitBackoff) + return now.Add(DefaultRateLimitBackoff) + } + + // Try delay-seconds first (most common for API services). + if secs, err := strconv.ParseInt(header, 10, 64); err == nil { + return now.Add(time.Duration(secs) * time.Second) + } + + // Try HTTP-date format (RFC 1123 / RFC 850 / ANSI C asctime). + for _, layout := range []string{http.TimeFormat, "Monday, 02-Jan-06 15:04:05 MST", "Mon Jan _2 15:04:05 2006"} { + if t, err := time.Parse(layout, header); err == nil { + return t + } + } + + slog.Info("malformed Retry-After header on 429, using default backoff", + "endpoint", endpointName, + "header", header, + "default", DefaultRateLimitBackoff) + return now.Add(DefaultRateLimitBackoff) +} + func injectDefaultHeaders(req *http.Request) *http.Request { req.Header.Set("User-Agent", fmt.Sprintf("%s-HealthChecker/%s", version.ShortName, version.Version)) req.Header.Set("Accept", "application/json, text/plain, */*") @@ -168,6 +217,25 @@ func injectDefaultHeaders(req *http.Request) *http.Request { return req } +// injectEndpointAuth applies the endpoint's configured auth and custom headers onto +// a probe request. We can't use core.CopyHeaders here because it strips then re-injects +// based on an incoming client request, which doesn't exist for health probes. +func injectEndpointAuth(req *http.Request, endpoint *domain.Endpoint) { + if endpoint == nil { + return + } + + // Apply static headers from the endpoint config first so auth can override them. + for name, value := range endpoint.Headers { + req.Header.Set(name, value) + } + + // Auth always wins over the headers map (matches CopyHeaders precedence). + if endpoint.AuthHeaderName != "" { + req.Header.Set(endpoint.AuthHeaderName, endpoint.AuthHeaderValue) + } +} + func calculateBackoffDelay(attempt int) time.Duration { // Use centralized backoff calculation with 25% jitter // SHERPA-198: Jitterbug - calculation was invalid, 0 jitter was being applied @@ -227,8 +295,15 @@ func classifyError(err error) domain.HealthCheckErrorType { return domain.ErrorTypeHTTPError } -// determineStatus converts HTTP response info into endpoint status -// Status logic: offline for network errors, busy for slow responses, healthy otherwise +// determineStatus converts HTTP response info into endpoint status. +// +// Classification priorities: +// - Network/transport errors → Offline +// - 401/403 → ConfigError (operator must fix credentials; no circuit-breaker trip) +// - 429 → RateLimited (transient; scheduler will honour Retry-After before next probe) +// - 2xx with high latency → Busy (still routable, just slow) +// - 2xx → Healthy +// - anything else → Unhealthy func determineStatus(statusCode int, latency time.Duration, err error, errorType domain.HealthCheckErrorType) domain.EndpointStatus { if err != nil { switch errorType { @@ -239,6 +314,18 @@ func determineStatus(statusCode int, latency time.Duration, err error, errorType } } + switch statusCode { + case http.StatusUnauthorized, http.StatusForbidden: + // The backend is up but rejecting our credentials. Retrying will never help + // until the operator updates the auth config. We don't trip the circuit breaker + // for this; it's a config problem, not a service availability problem. + return domain.StatusConfigError + case http.StatusTooManyRequests: + // The backend is healthy but throttling us. The scheduler honours Retry-After + // before the next probe so we don't hammer a rate-limited endpoint. + return domain.StatusRateLimited + } + if statusCode >= HealthyEndpointStatusRangeStart && statusCode < HealthyEndpointStatusRangeEnd { if latency > SlowResponseThreshold { return domain.StatusBusy diff --git a/internal/adapter/health/client_auth_test.go b/internal/adapter/health/client_auth_test.go new file mode 100644 index 00000000..a4651373 --- /dev/null +++ b/internal/adapter/health/client_auth_test.go @@ -0,0 +1,144 @@ +package health + +import ( + "context" + "net/http" + "net/http/httptest" + "net/url" + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "github.com/thushan/olla/internal/core/domain" +) + +// authEnforcingBackend returns 401 when the expected header is absent or wrong, +// and 200 when it matches. This proves auth is actually transported, not just set. +func authEnforcingBackend(t *testing.T, headerName, wantValue string) *httptest.Server { + t.Helper() + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.Header.Get(headerName) != wantValue { + w.WriteHeader(http.StatusUnauthorized) + return + } + w.WriteHeader(http.StatusOK) + })) + t.Cleanup(srv.Close) + return srv +} + +func makeEndpoint(t *testing.T, rawURL string) *domain.Endpoint { + t.Helper() + u, err := url.Parse(rawURL) + require.NoError(t, err) + return &domain.Endpoint{ + Name: "test", + URL: u, + HealthCheckURL: u, + URLString: u.String(), + HealthCheckURLString: u.String(), + CheckTimeout: 2 * time.Second, + } +} + +// TestHealthProbe_BearerAuth proves that a probe on an endpoint with bearer auth +// reaches the backend with the correct Authorization header and is classified healthy. +func TestHealthProbe_BearerAuth(t *testing.T) { + t.Parallel() + + const token = "Bearer secret-token" + srv := authEnforcingBackend(t, "Authorization", token) + + ep := makeEndpoint(t, srv.URL) + ep.AuthHeaderName = "Authorization" + ep.AuthHeaderValue = token + + hc := NewHealthClient(http.DefaultClient, NewCircuitBreaker()) + result, err := hc.Check(context.Background(), ep) + + require.NoError(t, err) + assert.Equal(t, domain.StatusHealthy, result.Status, "probe with correct bearer token must be healthy") + assert.Equal(t, http.StatusOK, result.StatusCode) +} + +// TestHealthProbe_MissingAuth demonstrates that probing an auth-protected backend +// without credentials configured on the endpoint returns an unhealthy classification. +// A 401 response maps to StatusConfigError via the health client's status mapping. +func TestHealthProbe_MissingAuth(t *testing.T) { + t.Parallel() + + srv := authEnforcingBackend(t, "Authorization", "Bearer required") + + // Endpoint has no auth configured; backend will reject with 401. + ep := makeEndpoint(t, srv.URL) + + hc := NewHealthClient(http.DefaultClient, NewCircuitBreaker()) + result, err := hc.Check(context.Background(), ep) + + // No transport error; just an HTTP 401. + require.NoError(t, err, "401 is an HTTP response, not a transport error") + assert.Equal(t, http.StatusUnauthorized, result.StatusCode) + assert.NotEqual(t, domain.StatusHealthy, result.Status, "unauthenticated probe must not be healthy") +} + +// TestHealthProbe_CustomHeaders proves that the endpoint.Headers map entries +// are forwarded on health probes, not just on proxy requests. +func TestHealthProbe_CustomHeaders(t *testing.T) { + t.Parallel() + + const ( + headerName = "X-Backend-Key" + headerValue = "backend-secret" + ) + + srv := authEnforcingBackend(t, headerName, headerValue) + + ep := makeEndpoint(t, srv.URL) + ep.Headers = map[string]string{ + headerName: headerValue, + } + + hc := NewHealthClient(http.DefaultClient, NewCircuitBreaker()) + result, err := hc.Check(context.Background(), ep) + + require.NoError(t, err) + assert.Equal(t, domain.StatusHealthy, result.Status, "probe with custom header must be healthy") +} + +// TestInjectEndpointAuth_Precedence verifies that auth wins over the headers map +// when both configure the same field, matching CopyHeaders precedence. +func TestInjectEndpointAuth_Precedence(t *testing.T) { + t.Parallel() + + req, err := http.NewRequest(http.MethodGet, "http://localhost/health", nil) + require.NoError(t, err) + + ep := &domain.Endpoint{ + Headers: map[string]string{ + "Authorization": "Bearer from-headers-map", + }, + AuthHeaderName: "Authorization", + AuthHeaderValue: "Bearer from-auth-section", + } + + injectEndpointAuth(req, ep) + + values := req.Header["Authorization"] + require.Len(t, values, 1, "must have exactly one Authorization value") + assert.Equal(t, "Bearer from-auth-section", values[0], "auth section must beat headers map") +} + +// TestInjectEndpointAuth_Nil ensures nil endpoint is a no-op and does not panic. +func TestInjectEndpointAuth_Nil(t *testing.T) { + t.Parallel() + + req, err := http.NewRequest(http.MethodGet, "http://localhost/health", nil) + require.NoError(t, err) + + assert.NotPanics(t, func() { + injectEndpointAuth(req, nil) + }) + + assert.Empty(t, req.Header.Get("Authorization")) +} diff --git a/internal/adapter/health/client_classify_test.go b/internal/adapter/health/client_classify_test.go new file mode 100644 index 00000000..14f53a1b --- /dev/null +++ b/internal/adapter/health/client_classify_test.go @@ -0,0 +1,125 @@ +package health + +import ( + "context" + "net/http" + "net/http/httptest" + "net/url" + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "github.com/thushan/olla/internal/core/domain" +) + +// statusBackend returns a fixed HTTP status code for every request. +func statusBackend(t *testing.T, code int) *httptest.Server { + t.Helper() + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(code) + })) + t.Cleanup(srv.Close) + return srv +} + +func probeEndpoint(t *testing.T, rawURL string) *domain.Endpoint { + t.Helper() + u, err := url.Parse(rawURL) + require.NoError(t, err) + return &domain.Endpoint{ + Name: "classify-test", + URL: u, + HealthCheckURL: u, + URLString: u.String(), + HealthCheckURLString: u.String(), + CheckTimeout: 2 * time.Second, + } +} + +func TestDetermineStatus_401_IsConfigError(t *testing.T) { + t.Parallel() + + srv := statusBackend(t, http.StatusUnauthorized) + ep := probeEndpoint(t, srv.URL) + + hc := NewHealthClient(http.DefaultClient, NewCircuitBreaker()) + result, err := hc.Check(context.Background(), ep) + + require.NoError(t, err, "401 is a valid HTTP response, not a transport error") + assert.Equal(t, domain.StatusConfigError, result.Status) + assert.Equal(t, http.StatusUnauthorized, result.StatusCode) +} + +func TestDetermineStatus_403_IsConfigError(t *testing.T) { + t.Parallel() + + srv := statusBackend(t, http.StatusForbidden) + ep := probeEndpoint(t, srv.URL) + + hc := NewHealthClient(http.DefaultClient, NewCircuitBreaker()) + result, err := hc.Check(context.Background(), ep) + + require.NoError(t, err) + assert.Equal(t, domain.StatusConfigError, result.Status) + assert.Equal(t, http.StatusForbidden, result.StatusCode) +} + +// TestConfigError_DoesNotTripCircuitBreaker verifies that many 401 responses never +// open the circuit breaker. Auth failures are a config problem; the CB is for +// service availability problems. +func TestConfigError_DoesNotTripCircuitBreaker(t *testing.T) { + t.Parallel() + + srv := statusBackend(t, http.StatusUnauthorized) + ep := probeEndpoint(t, srv.URL) + + cb := NewCircuitBreaker() + hc := NewHealthClient(http.DefaultClient, cb) + + // Fire well past the CB threshold. + for range DefaultCircuitBreakerThreshold * 3 { + _, _ = hc.Check(context.Background(), ep) + } + + assert.False(t, cb.IsOpen(ep.HealthCheckURLString), + "circuit breaker must not open on repeated auth failures") +} + +// TestConfigError_IsNotRoutable ensures the new status doesn't accidentally +// get included in the routable set. +func TestConfigError_IsNotRoutable(t *testing.T) { + t.Parallel() + + assert.False(t, domain.StatusConfigError.IsRoutable()) + assert.False(t, domain.StatusRateLimited.IsRoutable()) +} + +// TestDetermineStatus_ExistingCases guards regressions on the pre-existing +// status classification table now that the switch has new cases. +func TestDetermineStatus_ExistingCases(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + statusCode int + latency time.Duration + want domain.EndpointStatus + }{ + {"200 fast → healthy", http.StatusOK, 0, domain.StatusHealthy}, + {"200 slow → busy", http.StatusOK, SlowResponseThreshold + time.Millisecond, domain.StatusBusy}, + {"404 → unhealthy", http.StatusNotFound, 0, domain.StatusUnhealthy}, + {"500 → unhealthy", http.StatusInternalServerError, 0, domain.StatusUnhealthy}, + {"401 → config_error", http.StatusUnauthorized, 0, domain.StatusConfigError}, + {"403 → config_error", http.StatusForbidden, 0, domain.StatusConfigError}, + {"429 → rate_limited", http.StatusTooManyRequests, 0, domain.StatusRateLimited}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + got := determineStatus(tt.statusCode, tt.latency, nil, domain.ErrorTypeNone) + assert.Equal(t, tt.want, got) + }) + } +} diff --git a/internal/adapter/health/client_ratelimit_test.go b/internal/adapter/health/client_ratelimit_test.go new file mode 100644 index 00000000..9d9aed2a --- /dev/null +++ b/internal/adapter/health/client_ratelimit_test.go @@ -0,0 +1,169 @@ +package health + +import ( + "context" + "net/http" + "net/http/httptest" + "net/url" + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "github.com/thushan/olla/internal/core/domain" +) + +// rateLimitBackend returns 429 with an optional Retry-After header. +func rateLimitBackend(t *testing.T, retryAfterHeader string) *httptest.Server { + t.Helper() + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if retryAfterHeader != "" { + w.Header().Set("Retry-After", retryAfterHeader) + } + w.WriteHeader(http.StatusTooManyRequests) + })) + t.Cleanup(srv.Close) + return srv +} + +func rlEndpoint(t *testing.T, rawURL string) *domain.Endpoint { + t.Helper() + u, err := url.Parse(rawURL) + require.NoError(t, err) + return &domain.Endpoint{ + Name: "rl-test", + URL: u, + HealthCheckURL: u, + URLString: u.String(), + HealthCheckURLString: u.String(), + CheckTimeout: 2 * time.Second, + } +} + +// TestRateLimit_RetryAfterSeconds verifies that a numeric Retry-After is parsed +// into a RateLimitedUntil approximately 60s in the future. +func TestRateLimit_RetryAfterSeconds(t *testing.T) { + t.Parallel() + + srv := rateLimitBackend(t, "60") + ep := rlEndpoint(t, srv.URL) + + hc := NewHealthClient(http.DefaultClient, NewCircuitBreaker()) + before := time.Now() + result, err := hc.Check(context.Background(), ep) + after := time.Now() + + require.NoError(t, err) + assert.Equal(t, domain.StatusRateLimited, result.Status) + assert.False(t, result.RateLimitedUntil.IsZero(), "RateLimitedUntil must be set") + + // Allow a generous window for test execution jitter. + lo := before.Add(59 * time.Second) + hi := after.Add(61 * time.Second) + assert.True(t, result.RateLimitedUntil.After(lo) && result.RateLimitedUntil.Before(hi), + "RateLimitedUntil (%v) should be ~60s from probe time [%v, %v]", + result.RateLimitedUntil, lo, hi) +} + +// TestRateLimit_RetryAfterHTTPDate verifies that an HTTP-date Retry-After is parsed. +func TestRateLimit_RetryAfterHTTPDate(t *testing.T) { + t.Parallel() + + future := time.Now().UTC().Add(2 * time.Minute).Truncate(time.Second) + httpDate := future.Format(http.TimeFormat) + + srv := rateLimitBackend(t, httpDate) + ep := rlEndpoint(t, srv.URL) + + hc := NewHealthClient(http.DefaultClient, NewCircuitBreaker()) + result, err := hc.Check(context.Background(), ep) + + require.NoError(t, err) + assert.Equal(t, domain.StatusRateLimited, result.Status) + assert.False(t, result.RateLimitedUntil.IsZero()) + // Should be within a second of our expected time. + assert.WithinDuration(t, future, result.RateLimitedUntil, time.Second) +} + +// TestRateLimit_NoRetryAfterUsesDefault verifies the 30s fallback when the header +// is absent. +func TestRateLimit_NoRetryAfterUsesDefault(t *testing.T) { + t.Parallel() + + srv := rateLimitBackend(t, "") + ep := rlEndpoint(t, srv.URL) + + hc := NewHealthClient(http.DefaultClient, NewCircuitBreaker()) + before := time.Now() + result, err := hc.Check(context.Background(), ep) + after := time.Now() + + require.NoError(t, err) + assert.Equal(t, domain.StatusRateLimited, result.Status) + + lo := before.Add(DefaultRateLimitBackoff - time.Second) + hi := after.Add(DefaultRateLimitBackoff + time.Second) + assert.True(t, result.RateLimitedUntil.After(lo) && result.RateLimitedUntil.Before(hi), + "default backoff should be ~%v", DefaultRateLimitBackoff) +} + +// TestRateLimit_DoesNotTripCircuitBreaker checks that 429 responses never open the CB. +func TestRateLimit_DoesNotTripCircuitBreaker(t *testing.T) { + t.Parallel() + + srv := rateLimitBackend(t, "1") + ep := rlEndpoint(t, srv.URL) + + cb := NewCircuitBreaker() + hc := NewHealthClient(http.DefaultClient, cb) + + for range DefaultCircuitBreakerThreshold * 3 { + _, _ = hc.Check(context.Background(), ep) + } + + assert.False(t, cb.IsOpen(ep.HealthCheckURLString), + "circuit breaker must not open on rate-limit responses") +} + +// TestScheduler_SkipsRateLimitedEndpoints verifies the comparison logic that the +// scheduler uses to skip endpoints still inside their Retry-After window. +// We test the predicate directly rather than wiring up a full scheduler tick +// to keep this fast and deterministic. +func TestScheduler_SkipsRateLimitedEndpoints(t *testing.T) { + t.Parallel() + + now := time.Now() + + tests := []struct { + name string + rateLimitedUntil time.Time + wantSkipped bool + }{ + { + name: "window in future: skip", + rateLimitedUntil: now.Add(30 * time.Second), + wantSkipped: true, + }, + { + name: "window just expired: probe", + rateLimitedUntil: now.Add(-time.Millisecond), + wantSkipped: false, + }, + { + name: "zero time: probe (never rate-limited)", + rateLimitedUntil: time.Time{}, + wantSkipped: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + + ep := &domain.Endpoint{RateLimitedUntil: tt.rateLimitedUntil} + // Mirror the scheduler predicate from performHealthChecks. + skipped := !ep.RateLimitedUntil.IsZero() && now.Before(ep.RateLimitedUntil) + assert.Equal(t, tt.wantSkipped, skipped) + }) + } +} diff --git a/internal/adapter/health/client_transport_test.go b/internal/adapter/health/client_transport_test.go new file mode 100644 index 00000000..1d290c8b --- /dev/null +++ b/internal/adapter/health/client_transport_test.go @@ -0,0 +1,82 @@ +package health + +import ( + "net/http" + "reflect" + "runtime" + "testing" + + "github.com/thushan/olla/internal/logger" +) + +// funcName extracts the full symbol name of a function value for comparison. +func funcName(f interface{}) string { + return runtime.FuncForPC(reflect.ValueOf(f).Pointer()).Name() +} + +// newHealthTestLogger returns a quiet logger for transport tests. +func newHealthTestLogger(t *testing.T) logger.StyledLogger { + t.Helper() + loggerCfg := &logger.Config{Level: "error", Theme: "default"} + log, cleanup, _ := logger.New(loggerCfg) + t.Cleanup(cleanup) + return logger.NewPlainStyledLogger(log) +} + +// extractTransport pulls the *http.Transport from the health checker built by +// NewHTTPHealthCheckerWithDefaults. The HTTPClient interface doesn't expose the +// transport, so we type-assert through the concrete *http.Client field. +func extractTransport(t *testing.T) *http.Transport { + t.Helper() + + checker := NewHTTPHealthCheckerWithDefaults(newMockRepository(), newHealthTestLogger(t)) + + // The client field is an HTTPClient interface; NewHTTPHealthCheckerWithDefaults + // always passes a *http.Client so the assertion is safe in tests. + httpClient, ok := checker.healthClient.client.(*http.Client) + if !ok { + t.Fatal("health client is not *http.Client; update this test if the type changes") + } + + transport, ok := httpClient.Transport.(*http.Transport) + if !ok { + t.Fatal("http.Client.Transport is not *http.Transport; update this test if the type changes") + } + + return transport +} + +// TestHealthClientTransport_NoProxyFromEnvironment asserts that the default health +// checker transport does NOT route through environment proxies. Health probes carry +// auth credentials; routing them via an operator-configured proxy risks leaking those +// credentials to a third party. +func TestHealthClientTransport_NoProxyFromEnvironment(t *testing.T) { + t.Parallel() + + transport := extractTransport(t) + + if transport.Proxy != nil { + t.Errorf("health transport.Proxy = %s, want nil: credentials must not transit an env proxy", + funcName(transport.Proxy)) + } +} + +// TestHealthClientTransport_ResponseHeaderTimeout asserts that the default health +// checker transport has a finite ResponseHeaderTimeout. Without it a backend that +// accepts the TCP connection but withholds response headers blocks health probes +// indefinitely, masking downtime from the scheduler. +func TestHealthClientTransport_ResponseHeaderTimeout(t *testing.T) { + t.Parallel() + + transport := extractTransport(t) + + if transport.ResponseHeaderTimeout <= 0 { + t.Errorf("health transport.ResponseHeaderTimeout is %v; a backend that stalls after accept will block health probes indefinitely", + transport.ResponseHeaderTimeout) + } + + const want = DefaultHealthCheckerResponseHeaderTimeout + if transport.ResponseHeaderTimeout != want { + t.Errorf("transport.ResponseHeaderTimeout = %v, want %v", transport.ResponseHeaderTimeout, want) + } +} diff --git a/internal/adapter/health/types.go b/internal/adapter/health/types.go index 91da1baf..bc07b316 100644 --- a/internal/adapter/health/types.go +++ b/internal/adapter/health/types.go @@ -10,6 +10,14 @@ const ( DefaultHealthCheckerTimeout = 5 * time.Second SlowResponseThreshold = 10 * time.Second + // DefaultHealthCheckerResponseHeaderTimeout caps the time a backend may hold + // the connection open after accepting without sending a single response header. + // Shorter than the proxy equivalent; health probes are latency-sensitive. + DefaultHealthCheckerResponseHeaderTimeout = 10 * time.Second + + // DefaultRateLimitBackoff is used when a 429 carries no Retry-After header. + DefaultRateLimitBackoff = 30 * time.Second + HealthyEndpointStatusRangeStart = 200 HealthyEndpointStatusRangeEnd = 300 diff --git a/internal/adapter/proxy/config/unified.go b/internal/adapter/proxy/config/unified.go index 5c573c41..4b8409b3 100644 --- a/internal/adapter/proxy/config/unified.go +++ b/internal/adapter/proxy/config/unified.go @@ -23,6 +23,15 @@ const ( OllaDefaultTimeout = 30 * time.Second OllaDefaultKeepAlive = 30 * time.Second OllaDefaultReadTimeout = 30 * time.Second + + // DefaultResponseHeaderTimeout caps the time a backend may hold the connection + // open after accepting without sending a single response header byte. + // 30 s is chosen to match Olla's other timeout defaults; Sherpa uses the same constant. + DefaultResponseHeaderTimeout = 30 * time.Second + + // DefaultHealthResponseHeaderTimeout is shorter than the proxy timeout because + // health probes are latency-sensitive and already bounded by CheckTimeout. + DefaultHealthResponseHeaderTimeout = 10 * time.Second ) // ProxyConfig defines the interface for all proxy configurations diff --git a/internal/adapter/proxy/core/common.go b/internal/adapter/proxy/core/common.go index bcae1970..b4a489e4 100644 --- a/internal/adapter/proxy/core/common.go +++ b/internal/adapter/proxy/core/common.go @@ -34,8 +34,10 @@ func GetViaHeader() string { return viaHeader } -// CopyHeaders copies headers from originalReq to proxyReq with proper handling -func CopyHeaders(proxyReq, originalReq *http.Request) { +// CopyHeaders copies headers from originalReq to proxyReq with proper handling. +// endpoint carries the per-endpoint auth and custom header config applied after +// the client headers are copied and the sensitive strip list runs. +func CopyHeaders(proxyReq, originalReq *http.Request, endpoint *domain.Endpoint) { // Pre-size based on source to avoid rehashing if proxyReq.Header == nil { proxyReq.Header = make(http.Header, len(originalReq.Header)) @@ -88,6 +90,23 @@ func CopyHeaders(proxyReq, originalReq *http.Request) { // Update or set X-Forwarded headers updateForwardedHeaders(proxyReq, originalReq) + + // Apply endpoint-level custom headers after the strip so operators can explicitly + // re-introduce a header that the strip removed (e.g. a backend that needs X-Api-Key). + // Auth is applied after these so the auth: section always wins on conflict. If the + // user accidentally puts Authorization in headers: and auth:, auth: takes precedence. + if endpoint != nil { + for name, value := range endpoint.Headers { + proxyReq.Header.Set(name, value) + } + + // Auth wins over anything in the headers: map and over anything the client sent. + // The strip loop above already removed client credentials; Set() here is + // defensive so future strip-list gaps can't leak client creds to the upstream. + if endpoint.AuthHeaderName != "" { + proxyReq.Header.Set(endpoint.AuthHeaderName, endpoint.AuthHeaderValue) + } + } } // SHERPA-81: Update X-Forwarded-* headers in request @@ -162,6 +181,55 @@ func extractClientIP(r *http.Request) string { return host } +// responseHeaderStripList holds upstream response headers that must never reach +// the client. A backend that reflects auth credentials or sets cookies is almost +// certainly misconfigured; stripping here prevents credential leakage in the rare +// case where a compromised or buggy upstream reflects these headers back. +var responseHeaderStripList = []string{ + constants.HeaderAuthorization, + constants.HeaderProxyAuthorization, + constants.HeaderXAPIKey, + constants.HeaderXAuthToken, + "Set-Cookie", +} + +// CopyResponseHeaders copies upstream response headers to the client, filtering +// headers that should never leave the proxy boundary. Use this at every site that +// copies resp.Header to w.Header() to keep the strip list consistent. +// +// The deny set is the union of the static strip list, the endpoint's auth header +// name, and every key in the endpoint's custom header map. Operator-configured +// headers must be stripped on the return path for the same reason they're set on +// the outbound path: if a compromised backend reflects them, the client would +// receive credentials it has no business seeing. Pass nil endpoint to use only the +// static list (safe for callers without endpoint context, though all current call +// sites have one). +func CopyResponseHeaders(dst http.Header, src http.Header, endpoint *domain.Endpoint) { + // Build a transient deny set: static list + endpoint-specific names. + // For the common case (no endpoint or empty config) this stays small. + deny := make(map[string]struct{}, len(responseHeaderStripList)+2) + for _, h := range responseHeaderStripList { + deny[h] = struct{}{} + } + if endpoint != nil { + if endpoint.AuthHeaderName != "" { + deny[http.CanonicalHeaderKey(endpoint.AuthHeaderName)] = struct{}{} + } + for name := range endpoint.Headers { + deny[http.CanonicalHeaderKey(name)] = struct{}{} + } + } + + for key, values := range src { + if _, blocked := deny[http.CanonicalHeaderKey(key)]; blocked { + continue + } + for _, v := range values { + dst.Add(key, v) + } + } +} + // SetStickySessionHeaders writes sticky session outcome headers before WriteHeader // is called. It reads the StickyOutcome pointer that was injected into the context // by the handler layer after the balancer's Select fills it. Must be called before diff --git a/internal/adapter/proxy/core/common_auth_test.go b/internal/adapter/proxy/core/common_auth_test.go new file mode 100644 index 00000000..0e35612c --- /dev/null +++ b/internal/adapter/proxy/core/common_auth_test.go @@ -0,0 +1,298 @@ +package core + +import ( + "net/http" + "net/http/httptest" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "github.com/thushan/olla/internal/core/domain" +) + +// capturedHeaders records request headers received by an httptest backend. +// We intentionally avoid logging or printing header values to prevent credential +// leakage in CI output. +type capturedHeaders struct { + headers http.Header +} + +func newCapturingBackend(t *testing.T) (*httptest.Server, *capturedHeaders) { + t.Helper() + captured := &capturedHeaders{} + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + captured.headers = r.Header.Clone() + w.WriteHeader(http.StatusOK) + })) + t.Cleanup(srv.Close) + return srv, captured +} + +func TestCopyHeaders_WithAuth(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + endpoint *domain.Endpoint + clientAuth string // value of Authorization on the incoming client request + wantHeader string + wantValue string + wantNoHeader string // header that must NOT be present + }{ + { + name: "no auth on endpoint: client auth is stripped", + endpoint: &domain.Endpoint{}, + clientAuth: "Bearer client-token", + wantNoHeader: "Authorization", + }, + { + name: "bearer auth injected", + endpoint: &domain.Endpoint{ + AuthHeaderName: "Authorization", + AuthHeaderValue: "Bearer tok-backend", + }, + clientAuth: "Bearer client-token", + wantHeader: "Authorization", + wantValue: "Bearer tok-backend", + }, + { + name: "api_key with default X-Api-Key header", + endpoint: &domain.Endpoint{ + AuthHeaderName: "X-Api-Key", + AuthHeaderValue: "sk-backend-key", + }, + clientAuth: "", + wantHeader: "X-Api-Key", + wantValue: "sk-backend-key", + }, + { + name: "api_key with custom header name", + endpoint: &domain.Endpoint{ + AuthHeaderName: "X-Custom-Auth", + AuthHeaderValue: "custom-val", + }, + wantHeader: "X-Custom-Auth", + wantValue: "custom-val", + }, + { + name: "basic auth injected", + endpoint: &domain.Endpoint{ + AuthHeaderName: "Authorization", + AuthHeaderValue: "Basic dXNlcjpwYXNz", + }, + clientAuth: "Basic client-cred", + wantHeader: "Authorization", + wantValue: "Basic dXNlcjpwYXNz", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + + originalReq := httptest.NewRequest(http.MethodPost, "http://olla.internal/v1/chat", nil) + if tt.clientAuth != "" { + originalReq.Header.Set("Authorization", tt.clientAuth) + } + + proxyReq, err := http.NewRequest(http.MethodPost, "http://backend.internal/v1/chat", nil) + require.NoError(t, err) + + CopyHeaders(proxyReq, originalReq, tt.endpoint) + + if tt.wantHeader != "" { + assert.Equal(t, tt.wantValue, proxyReq.Header.Get(tt.wantHeader), + "endpoint auth header must be set to the configured value") + } + + if tt.wantNoHeader != "" { + assert.Empty(t, proxyReq.Header.Get(tt.wantNoHeader), + "sensitive header must be stripped when endpoint has no auth configured") + } + }) + } +} + +// TestCopyHeaders_AuthOverwrite asserts that a client-supplied Authorization header +// is replaced by the endpoint's configured value, not appended to it. +func TestCopyHeaders_AuthOverwrite(t *testing.T) { + t.Parallel() + + endpoint := &domain.Endpoint{ + AuthHeaderName: "Authorization", + AuthHeaderValue: "Bearer endpoint-token", + } + + originalReq := httptest.NewRequest(http.MethodPost, "http://olla.internal/v1/chat", nil) + originalReq.Header.Set("Authorization", "Bearer client-token") + + proxyReq, err := http.NewRequest(http.MethodPost, "http://backend.internal/v1/chat", nil) + require.NoError(t, err) + + CopyHeaders(proxyReq, originalReq, endpoint) + + // Must be exactly the endpoint value, never both. + values := proxyReq.Header["Authorization"] + require.Len(t, values, 1, "Authorization must have exactly one value (Set not Add)") + assert.Equal(t, "Bearer endpoint-token", values[0], "endpoint credential must win over client credential") +} + +// TestCopyHeaders_NilEndpointStripsAuth verifies that passing nil endpoint +// still strips the client's Authorization. The nil path must not regress the security behaviour. +func TestCopyHeaders_NilEndpointStripsAuth(t *testing.T) { + t.Parallel() + + originalReq := httptest.NewRequest(http.MethodPost, "http://olla.internal/v1/chat", nil) + originalReq.Header.Set("Authorization", "Bearer client-secret") + + proxyReq, err := http.NewRequest(http.MethodPost, "http://backend.internal/v1/chat", nil) + require.NoError(t, err) + + CopyHeaders(proxyReq, originalReq, nil) + + assert.Empty(t, proxyReq.Header.Get("Authorization"), + "client Authorization must be stripped even when no endpoint auth is configured") +} + +// TestCopyHeaders_AuthArrivesAtBackend wires up a real httptest backend and +// confirms that the injected Authorization header survives the full round-trip, +// not just that CopyHeaders places it on proxyReq in memory. +func TestCopyHeaders_AuthArrivesAtBackend(t *testing.T) { + t.Parallel() + + srv, captured := newCapturingBackend(t) + + endpoint := &domain.Endpoint{ + AuthHeaderName: "Authorization", + AuthHeaderValue: "Bearer backend-secret", + } + + originalReq := httptest.NewRequest(http.MethodGet, "http://olla.internal/api/tags", nil) + + proxyReq, err := http.NewRequest(http.MethodGet, srv.URL+"/api/tags", nil) + require.NoError(t, err) + + CopyHeaders(proxyReq, originalReq, endpoint) + + resp, err := http.DefaultClient.Do(proxyReq) + require.NoError(t, err) + defer resp.Body.Close() //nolint:errcheck + + require.NotNil(t, captured.headers, "backend must have received a request") + assert.Equal(t, "Bearer backend-secret", captured.headers.Get("Authorization"), + "Authorization header must arrive at the backend after transport") +} + +// TestCopyHeaders_CustomHeaders covers the endpoint.Headers map injection behaviour. +func TestCopyHeaders_CustomHeaders(t *testing.T) { + t.Parallel() + + t.Run("custom headers set with no auth", func(t *testing.T) { + t.Parallel() + + endpoint := &domain.Endpoint{ + Headers: map[string]string{ + "X-Tenant-ID": "acme", + "X-Env": "prod", + }, + } + + originalReq := httptest.NewRequest(http.MethodPost, "http://olla.internal/v1/chat", nil) + proxyReq, err := http.NewRequest(http.MethodPost, "http://backend.internal/v1/chat", nil) + require.NoError(t, err) + + CopyHeaders(proxyReq, originalReq, endpoint) + + assert.Equal(t, "acme", proxyReq.Header.Get("X-Tenant-ID")) + assert.Equal(t, "prod", proxyReq.Header.Get("X-Env")) + assert.Empty(t, proxyReq.Header.Get("Authorization"), "no auth header when auth is not configured") + }) + + t.Run("auth wins when headers map also sets Authorization", func(t *testing.T) { + t.Parallel() + + // If a user puts Authorization in both headers: and auth:, auth: must win. + endpoint := &domain.Endpoint{ + Headers: map[string]string{ + "Authorization": "Bearer from-headers-map", + }, + AuthHeaderName: "Authorization", + AuthHeaderValue: "Bearer from-auth-section", + } + + originalReq := httptest.NewRequest(http.MethodPost, "http://olla.internal/v1/chat", nil) + proxyReq, err := http.NewRequest(http.MethodPost, "http://backend.internal/v1/chat", nil) + require.NoError(t, err) + + CopyHeaders(proxyReq, originalReq, endpoint) + + values := proxyReq.Header["Authorization"] + require.Len(t, values, 1, "must have exactly one Authorization value") + assert.Equal(t, "Bearer from-auth-section", values[0], "auth: section must beat headers: map") + }) + + t.Run("sensitive header in headers map overrides the strip: operator intent wins", func(t *testing.T) { + t.Parallel() + + // The strip removes the client's X-Api-Key, but if the operator explicitly + // puts X-Api-Key in headers:, it is their deliberate configuration and should be honoured. + endpoint := &domain.Endpoint{ + Headers: map[string]string{ + "X-Api-Key": "backend-api-key", + }, + } + + originalReq := httptest.NewRequest(http.MethodPost, "http://olla.internal/v1/chat", nil) + originalReq.Header.Set("X-Api-Key", "client-api-key") + + proxyReq, err := http.NewRequest(http.MethodPost, "http://backend.internal/v1/chat", nil) + require.NoError(t, err) + + CopyHeaders(proxyReq, originalReq, endpoint) + + // The endpoint's value must appear, not the client's. + assert.Equal(t, "backend-api-key", proxyReq.Header.Get("X-Api-Key"), + "operator-configured header must reach the backend even if it was in the strip list") + }) + + t.Run("nil headers map is a no-op", func(t *testing.T) { + t.Parallel() + + endpoint := &domain.Endpoint{ + Headers: nil, + } + + originalReq := httptest.NewRequest(http.MethodPost, "http://olla.internal/v1/chat", nil) + originalReq.Header.Set("Content-Type", "application/json") + + proxyReq, err := http.NewRequest(http.MethodPost, "http://backend.internal/v1/chat", nil) + require.NoError(t, err) + + CopyHeaders(proxyReq, originalReq, endpoint) + + // Content-Type is copied from the client as normal. + assert.Equal(t, "application/json", proxyReq.Header.Get("Content-Type")) + }) + + t.Run("multiple custom headers all set correctly", func(t *testing.T) { + t.Parallel() + + endpoint := &domain.Endpoint{ + Headers: map[string]string{ + "X-Org": "my-org", + "X-Region": "ap-southeast-2", + "X-Priority": "high", + }, + } + + originalReq := httptest.NewRequest(http.MethodPost, "http://olla.internal/v1/chat", nil) + proxyReq, err := http.NewRequest(http.MethodPost, "http://backend.internal/v1/chat", nil) + require.NoError(t, err) + + CopyHeaders(proxyReq, originalReq, endpoint) + + assert.Equal(t, "my-org", proxyReq.Header.Get("X-Org")) + assert.Equal(t, "ap-southeast-2", proxyReq.Header.Get("X-Region")) + assert.Equal(t, "high", proxyReq.Header.Get("X-Priority")) + }) +} diff --git a/internal/adapter/proxy/core/common_test.go b/internal/adapter/proxy/core/common_test.go index 1929bab5..2e33805e 100644 --- a/internal/adapter/proxy/core/common_test.go +++ b/internal/adapter/proxy/core/common_test.go @@ -121,7 +121,7 @@ func TestCopyHeaders(t *testing.T) { proxyReq := httptest.NewRequest("GET", "http://backend.com/test", nil) // Copy headers - CopyHeaders(proxyReq, originalReq) + CopyHeaders(proxyReq, originalReq, nil) // Check that expected headers were copied for k, expectedValues := range tt.expectedCopied { @@ -231,7 +231,7 @@ func TestCopyHeaders_ProxyHeaders(t *testing.T) { proxyReq := httptest.NewRequest("GET", "http://backend.com/test", nil) // Copy headers - CopyHeaders(proxyReq, originalReq) + CopyHeaders(proxyReq, originalReq, nil) // Check forwarded headers if tt.expectedForwardedFor != "" { @@ -603,7 +603,7 @@ func TestCopyHeaders_ExistingHeaders(t *testing.T) { t.Logf("Test: %s - Initial header count: %d", tt.description, initialLen) // Copy headers - CopyHeaders(proxyReq, originalReq) + CopyHeaders(proxyReq, originalReq, nil) // Verify expected headers (excluding proxy-specific headers) for k, expectedValues := range tt.expectedHeaders { @@ -648,7 +648,7 @@ func TestCopyHeaders_MapPreSizingOptimization(t *testing.T) { originalReq.Header.Set("X-Test", "value") // This should work without panic - CopyHeaders(req, originalReq) + CopyHeaders(req, originalReq, nil) assert.Equal(t, "value", req.Header.Get("X-Test")) }) @@ -667,7 +667,7 @@ func BenchmarkCopyHeaders(b *testing.B) { b.ReportAllocs() for range b.N { proxyReq := httptest.NewRequest("GET", "http://backend.com/proxy", nil) - CopyHeaders(proxyReq, originalReq) + CopyHeaders(proxyReq, originalReq, nil) } } @@ -688,7 +688,7 @@ func BenchmarkCopyHeaders_WithExistingHeaders(b *testing.B) { // Pre-populate with some headers (edge case) proxyReq.Header.Set("X-Pre-Existing-1", "value1") proxyReq.Header.Set("X-Pre-Existing-2", "value2") - CopyHeaders(proxyReq, originalReq) + CopyHeaders(proxyReq, originalReq, nil) } } @@ -789,7 +789,7 @@ func TestCopyHeaders_DoesNotPropagateInboundHost(t *testing.T) { proxyReq := httptest.NewRequest("POST", "http://backend.internal:11434/api/generate", nil) proxyReq.Host = "" // Transport will use URL.Host when this is empty - CopyHeaders(proxyReq, originalReq) + CopyHeaders(proxyReq, originalReq, nil) // Host must remain unset so Go's transport derives it from URL.Host (the backend address). assert.Empty(t, proxyReq.Host, "outbound Host must not be overridden with the inbound client Host") @@ -799,6 +799,170 @@ func TestCopyHeaders_DoesNotPropagateInboundHost(t *testing.T) { "X-Forwarded-Host must carry the original inbound Host") } +// TestCopyResponseHeaders_StripsSensitiveHeaders verifies that headers a +// compromised or misconfigured backend should never reflect to clients are +// removed, while safe headers pass through unchanged. +func TestCopyResponseHeaders_StripsSensitiveHeaders(t *testing.T) { + t.Parallel() + + sensitiveHeaders := []string{ + "Authorization", + "Proxy-Authorization", + "X-Api-Key", + "X-Auth-Token", + "Set-Cookie", + } + + for _, header := range sensitiveHeaders { + t.Run("strips_"+header, func(t *testing.T) { + t.Parallel() + + src := http.Header{} + src.Set(header, "must-not-appear") + src.Set("Content-Type", "application/json") + + dst := http.Header{} + CopyResponseHeaders(dst, src, nil) + + if got := dst.Get(header); got != "" { + t.Errorf("CopyResponseHeaders forwarded sensitive header %q = %q, want empty", header, got) + } + if got := dst.Get("Content-Type"); got == "" { + t.Error("CopyResponseHeaders dropped Content-Type header; safe headers must pass through") + } + }) + } +} + +// TestCopyResponseHeaders_PassesThroughSafeHeaders verifies that non-sensitive +// headers from the upstream response are forwarded to the client unchanged. +func TestCopyResponseHeaders_PassesThroughSafeHeaders(t *testing.T) { + t.Parallel() + + src := http.Header{} + src.Set("Content-Type", "text/event-stream") + src.Set("X-Custom-Header", "custom-value") + src.Set("Cache-Control", "no-cache") + + dst := http.Header{} + CopyResponseHeaders(dst, src, nil) + + for _, h := range []string{"Content-Type", "X-Custom-Header", "Cache-Control"} { + if dst.Get(h) == "" { + t.Errorf("CopyResponseHeaders dropped safe header %q", h) + } + } +} + +// TestCopyResponseHeaders_StripsEndpointAuthHeader verifies that an endpoint's +// configured auth header name is stripped from the upstream response, preventing +// a backend that reflects it from leaking credentials to the client. +func TestCopyResponseHeaders_StripsEndpointAuthHeader(t *testing.T) { + t.Parallel() + + endpoint := &domain.Endpoint{ + AuthHeaderName: "X-Custom-Auth", + } + + src := http.Header{} + src.Set("X-Custom-Auth", "leaked-secret") + src.Set("Content-Encoding", "gzip") + + dst := http.Header{} + CopyResponseHeaders(dst, src, endpoint) + + if got := dst.Get("X-Custom-Auth"); got != "" { + t.Errorf("endpoint auth header leaked to client: X-Custom-Auth = %q, want empty", got) + } + if got := dst.Get("Content-Encoding"); got == "" { + t.Error("CopyResponseHeaders dropped safe header Content-Encoding") + } +} + +// TestCopyResponseHeaders_StripsEndpointConfiguredHeaders verifies that every +// header named in endpoint.Headers is stripped from the upstream response. The +// rule is consistent: anything the operator names is denied on the way back, +// regardless of whether the header looks benign. +func TestCopyResponseHeaders_StripsEndpointConfiguredHeaders(t *testing.T) { + t.Parallel() + + endpoint := &domain.Endpoint{ + Headers: map[string]string{ + "X-Foo": "bar", + "X-Bar": "baz", + }, + } + + src := http.Header{} + src.Set("X-Foo", "upstream-value") + src.Set("X-Bar", "upstream-value") + src.Set("Content-Type", "application/json") + + dst := http.Header{} + CopyResponseHeaders(dst, src, endpoint) + + for _, h := range []string{"X-Foo", "X-Bar"} { + if got := dst.Get(h); got != "" { + t.Errorf("endpoint configured header %q leaked to client = %q, want empty", h, got) + } + } + if got := dst.Get("Content-Type"); got == "" { + t.Error("CopyResponseHeaders dropped safe header Content-Type") + } +} + +// TestCopyResponseHeaders_StripsBenignLookingEndpointHeader proves that +// operator-configured headers are always stripped on the return path, even when +// the header name looks benign (e.g. Content-Type). Consistency matters more than +// semantic judgement about what "looks safe". +func TestCopyResponseHeaders_StripsBenignLookingEndpointHeader(t *testing.T) { + t.Parallel() + + endpoint := &domain.Endpoint{ + Headers: map[string]string{ + "Content-Type": "application/json", + }, + } + + src := http.Header{} + src.Set("Content-Type", "text/plain") + src.Set("Content-Encoding", "gzip") + + dst := http.Header{} + CopyResponseHeaders(dst, src, endpoint) + + if got := dst.Get("Content-Type"); got != "" { + t.Errorf("operator-configured Content-Type was not stripped: got %q, want empty", got) + } + // A header not in any deny list must still pass through. + if got := dst.Get("Content-Encoding"); got == "" { + t.Error("CopyResponseHeaders dropped Content-Encoding which is not in any deny list") + } +} + +// TestCopyResponseHeaders_PassesThroughUndeniedHeader confirms that a normal +// response header not mentioned in any deny list reaches the client unchanged. +func TestCopyResponseHeaders_PassesThroughUndeniedHeader(t *testing.T) { + t.Parallel() + + endpoint := &domain.Endpoint{ + AuthHeaderName: "X-Custom-Auth", + Headers: map[string]string{"X-Foo": "bar"}, + } + + src := http.Header{} + src.Set("Content-Encoding", "gzip") + src.Set("X-Custom-Auth", "secret") + src.Set("X-Foo", "value") + + dst := http.Header{} + CopyResponseHeaders(dst, src, endpoint) + + if got := dst.Get("Content-Encoding"); got != "gzip" { + t.Errorf("Content-Encoding = %q, want %q", got, "gzip") + } +} + // BenchmarkSetResponseHeaders benchmarks the SetResponseHeaders function func BenchmarkSetResponseHeaders(b *testing.B) { stats := &ports.RequestStats{ diff --git a/internal/adapter/proxy/core/retry.go b/internal/adapter/proxy/core/retry.go index 9ad01a2d..53776d3f 100644 --- a/internal/adapter/proxy/core/retry.go +++ b/internal/adapter/proxy/core/retry.go @@ -35,7 +35,49 @@ func NewRetryHandler(discoveryService ports.DiscoveryService, logger logger.Styl // ProxyFunc defines the signature for endpoint proxy implementations type ProxyFunc func(ctx context.Context, w http.ResponseWriter, r *http.Request, endpoint *domain.Endpoint, stats *ports.RequestStats) error -// ExecuteWithRetry attempts request delivery with automatic failover on connection errors +// responseStartedWriter wraps http.ResponseWriter and records whether any response +// bytes or status codes have been sent to the client. We use this to gate retry +// decisions: once the client has received data, retrying a non-idempotent request +// would send duplicate content or charge twice on metered APIs. +type responseStartedWriter struct { + http.ResponseWriter + started bool +} + +func (rw *responseStartedWriter) WriteHeader(code int) { + rw.started = true + rw.ResponseWriter.WriteHeader(code) +} + +func (rw *responseStartedWriter) Write(b []byte) (int, error) { + rw.started = true + return rw.ResponseWriter.Write(b) +} + +// Unwrap exposes the underlying ResponseWriter so http.NewResponseController can +// discover optional interfaces (Flush, Hijack, SetDeadline, etc.) via the chain. +// Without this, the wrapper hides the underlying flusher and SSE streams stall. +func (rw *responseStartedWriter) Unwrap() http.ResponseWriter { + return rw.ResponseWriter +} + +// isIdempotent reports whether the HTTP method is safe to retry after a partial +// response. GET, HEAD and OPTIONS are defined as idempotent by RFC 9110; POST, +// PATCH and DELETE are not. Retrying them risks double-billing or duplicate side +// effects if the upstream already processed the first attempt. +func isIdempotent(method string) bool { + switch method { + case http.MethodGet, http.MethodHead, http.MethodOptions: + return true + default: + return false + } +} + +// ExecuteWithRetry attempts request delivery with automatic failover on connection errors. +// For non-idempotent methods (POST, PATCH, DELETE), retry is suppressed once the +// response has started. Resending to a different endpoint risks duplicate content +// reaching the client or double-charging a metered API. func (h *RetryHandler) ExecuteWithRetry( ctx context.Context, w http.ResponseWriter, @@ -59,6 +101,9 @@ func (h *RetryHandler) ExecuteWithRetry( return err } + // Wrap the writer so we can detect when bytes have been committed to the client. + tracker := &responseStartedWriter{ResponseWriter: w} + var lastErr error maxRetries := len(endpoints) attemptCount := 0 @@ -76,7 +121,7 @@ func (h *RetryHandler) ExecuteWithRetry( } attemptCount++ - lastErr = h.executeProxyAttempt(ctx, w, r, endpoint, selector, stats, proxyFunc) + lastErr = h.executeProxyAttempt(ctx, tracker, r, endpoint, selector, stats, proxyFunc) if lastErr == nil { return nil @@ -87,6 +132,17 @@ func (h *RetryHandler) ExecuteWithRetry( return lastErr } + // For non-idempotent methods, once the response has started we cannot + // safely retry. The client would receive a partial response from this + // endpoint followed by a fresh one from the next, causing corruption or + // double-billing. Return the error and let the caller decide. + if tracker.started && !isIdempotent(r.Method) { + h.logger.Debug("skipping retry: response already started for non-idempotent method", + "method", r.Method, + "endpoint", endpoint.Name) + return lastErr + } + // Handle connection error and retry logic availableEndpoints = h.handleConnectionFailure(ctx, endpoint, lastErr, attemptCount, availableEndpoints, maxRetries) } diff --git a/internal/adapter/proxy/core/retry_safety_test.go b/internal/adapter/proxy/core/retry_safety_test.go new file mode 100644 index 00000000..cadf3c1e --- /dev/null +++ b/internal/adapter/proxy/core/retry_safety_test.go @@ -0,0 +1,307 @@ +package core + +import ( + "context" + "errors" + "fmt" + "io" + "net" + "net/http" + "net/http/httptest" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "github.com/thushan/olla/internal/core/domain" + "github.com/thushan/olla/internal/core/ports" + "github.com/thushan/olla/internal/logger" +) + +// ---- helpers ---------------------------------------------------------------- + +func newTestRetryHandler(t *testing.T) *RetryHandler { + t.Helper() + logCfg := &logger.Config{Level: "error"} + log, _, _ := logger.New(logCfg) + return NewRetryHandler(&testDiscoveryService{}, logger.NewPlainStyledLogger(log)) +} + +// roundRobinSelector cycles through endpoints in order. +type roundRobinSelector struct{ idx int } + +func (s *roundRobinSelector) Select(_ context.Context, eps []*domain.Endpoint) (*domain.Endpoint, error) { + if len(eps) == 0 { + return nil, errors.New("no endpoints") + } + ep := eps[s.idx%len(eps)] + s.idx++ + return ep, nil +} +func (s *roundRobinSelector) Name() string { return "round-robin" } +func (s *roundRobinSelector) IncrementConnections(_ *domain.Endpoint) {} +func (s *roundRobinSelector) DecrementConnections(_ *domain.Endpoint) {} + +// namedEndpoint creates a minimal endpoint with the given name. +func namedEndpoint(name string) *domain.Endpoint { + return &domain.Endpoint{Name: name, CheckTimeout: 0} +} + +// connectionResetError satisfies net.Error so IsConnectionError returns true. +type connectionResetError struct{} + +func (e *connectionResetError) Error() string { return "connection reset by peer" } +func (e *connectionResetError) Timeout() bool { return false } +func (e *connectionResetError) Temporary() bool { return false } + +// Ensure it implements net.Error. +var _ net.Error = (*connectionResetError)(nil) + +// ---- tests ------------------------------------------------------------------ + +// TestIsIdempotent confirms the idempotency predicate matches RFC 9110. +func TestIsIdempotent(t *testing.T) { + t.Parallel() + + tests := []struct { + method string + want bool + }{ + {http.MethodGet, true}, + {http.MethodHead, true}, + {http.MethodOptions, true}, + {http.MethodPost, false}, + {http.MethodPatch, false}, + {http.MethodDelete, false}, + {http.MethodPut, false}, + } + + for _, tt := range tests { + t.Run(tt.method, func(t *testing.T) { + t.Parallel() + assert.Equal(t, tt.want, isIdempotent(tt.method)) + }) + } +} + +// TestResponseStartedWriter_TracksWrites verifies that the sentinel wrapper sets +// started=true on both WriteHeader and Write calls. +func TestResponseStartedWriter_TracksWrites(t *testing.T) { + t.Parallel() + + t.Run("WriteHeader sets started", func(t *testing.T) { + t.Parallel() + rw := &responseStartedWriter{ResponseWriter: httptest.NewRecorder()} + assert.False(t, rw.started) + rw.WriteHeader(http.StatusOK) + assert.True(t, rw.started) + }) + + t.Run("Write sets started", func(t *testing.T) { + t.Parallel() + rw := &responseStartedWriter{ResponseWriter: httptest.NewRecorder()} + assert.False(t, rw.started) + _, _ = rw.Write([]byte("hello")) + assert.True(t, rw.started) + }) + + t.Run("neither called: not started", func(t *testing.T) { + t.Parallel() + rw := &responseStartedWriter{ResponseWriter: httptest.NewRecorder()} + assert.False(t, rw.started) + }) +} + +// TestRetry_POSTWithBytesWritten_NoRetry is the critical correctness test. +// An httptest backend writes 100 bytes then RSTs. ExecuteWithRetry must NOT +// retry to a second endpoint because the response has already started. +func TestRetry_POSTWithBytesWritten_NoRetry(t *testing.T) { + t.Parallel() + + attemptsHit := 0 + + // Backend: write 100 bytes then close the connection abruptly. + // We hijack the connection to send a TCP RST after the body begins. + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + _, _ = w.Write(make([]byte, 100)) + if f, ok := w.(http.Flusher); ok { + f.Flush() + } + // Hijack and close to simulate a mid-stream RST. + hj, ok := w.(http.Hijacker) + if !ok { + return + } + conn, _, _ := hj.Hijack() + _ = conn.Close() + })) + t.Cleanup(srv.Close) + + h := newTestRetryHandler(t) + ep1 := namedEndpoint("ep1") + ep2 := namedEndpoint("ep2") + + req := httptest.NewRequest(http.MethodPost, "/v1/chat/completions", nil) + w := httptest.NewRecorder() + stats := &ports.RequestStats{} + + proxyFunc := func(ctx context.Context, w http.ResponseWriter, r *http.Request, ep *domain.Endpoint, s *ports.RequestStats) error { + attemptsHit++ + // Simulate: write the headers + bytes, then return a connection error. + w.WriteHeader(http.StatusOK) + _, _ = w.Write(make([]byte, 100)) + return &connectionResetError{} + } + + err := h.ExecuteWithRetry(context.Background(), w, req, []*domain.Endpoint{ep1, ep2}, + &roundRobinSelector{}, stats, proxyFunc) + + // Error expected; the stream failed. + require.Error(t, err) + // Critically: only ONE attempt. Retrying would double-bill the user. + assert.Equal(t, 1, attemptsHit, "must not retry POST after response bytes flushed to client") +} + +// TestRetry_POSTBeforeBytesWritten_DoesRetry confirms that a connection error +// before any bytes are written still triggers failover, even for POST. +func TestRetry_POSTBeforeBytesWritten_DoesRetry(t *testing.T) { + t.Parallel() + + attemptsHit := 0 + + h := newTestRetryHandler(t) + ep1 := namedEndpoint("ep1") + ep2 := namedEndpoint("ep2") + + req := httptest.NewRequest(http.MethodPost, "/v1/chat", nil) + w := httptest.NewRecorder() + stats := &ports.RequestStats{} + + proxyFunc := func(ctx context.Context, w http.ResponseWriter, r *http.Request, ep *domain.Endpoint, s *ports.RequestStats) error { + attemptsHit++ + if attemptsHit == 1 { + // First attempt: connection error before writing anything. + return &connectionResetError{} + } + // Second attempt: success. + w.WriteHeader(http.StatusOK) + return nil + } + + err := h.ExecuteWithRetry(context.Background(), w, req, []*domain.Endpoint{ep1, ep2}, + &roundRobinSelector{}, stats, proxyFunc) + + require.NoError(t, err) + assert.Equal(t, 2, attemptsHit, "POST should retry when no bytes have been written yet") +} + +// TestRetry_GETMidStreamRST_DoesRetry confirms that GET is always retried on +// connection errors, even after bytes have been written, because GET is idempotent. +func TestRetry_GETMidStreamRST_DoesRetry(t *testing.T) { + t.Parallel() + + attemptsHit := 0 + + h := newTestRetryHandler(t) + ep1 := namedEndpoint("ep1") + ep2 := namedEndpoint("ep2") + + req := httptest.NewRequest(http.MethodGet, "/v1/models", nil) + w := httptest.NewRecorder() + stats := &ports.RequestStats{} + + proxyFunc := func(ctx context.Context, w http.ResponseWriter, r *http.Request, ep *domain.Endpoint, s *ports.RequestStats) error { + attemptsHit++ + if attemptsHit == 1 { + // First attempt: write bytes then RST, simulates mid-stream failure. + w.WriteHeader(http.StatusOK) + _, _ = w.Write([]byte("partial")) + return &connectionResetError{} + } + // Second attempt: success. + w.WriteHeader(http.StatusOK) + return nil + } + + err := h.ExecuteWithRetry(context.Background(), w, req, []*domain.Endpoint{ep1, ep2}, + &roundRobinSelector{}, stats, proxyFunc) + + require.NoError(t, err) + assert.Equal(t, 2, attemptsHit, "GET should always retry on connection errors") +} + +// TestResponseStartedWriter_Unwrap verifies that Flush works through the wrapper via +// http.NewResponseController. Without Unwrap(), the controller cannot reach the +// underlying flusher and SSE streams stall silently. +func TestResponseStartedWriter_Unwrap(t *testing.T) { + t.Parallel() + + inner := httptest.NewRecorder() + rw := &responseStartedWriter{ResponseWriter: inner} + + rc := http.NewResponseController(rw) + if err := rc.Flush(); err != nil { + t.Errorf("Flush() via ResponseController on wrapped writer = %v, want nil", err) + } +} + +// TestRetry_HTTPTestBackend_PostRSTBeforeBody uses a real httptest backend that +// refuses the connection before sending a response. We verify that ExecuteWithRetry +// does attempt a second endpoint (no bytes written). +func TestRetry_HTTPTestBackend_PostRSTBeforeBody(t *testing.T) { + t.Parallel() + + // Backend 1: immediately close, simulates a refused connection. + srv1 := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + hj, ok := w.(http.Hijacker) + if !ok { + http.Error(w, "no hijack", http.StatusInternalServerError) + return + } + conn, _, _ := hj.Hijack() + _ = conn.Close() + })) + t.Cleanup(srv1.Close) + + secondHit := false + + // Backend 2: healthy. + srv2 := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + secondHit = true + w.WriteHeader(http.StatusOK) + })) + t.Cleanup(srv2.Close) + + h := newTestRetryHandler(t) + req := httptest.NewRequest(http.MethodPost, "/v1/chat", nil) + w := httptest.NewRecorder() + stats := &ports.RequestStats{} + + ep1 := namedEndpoint("ep1") + ep2 := namedEndpoint("ep2") + + // proxyFunc uses srv2 on the second attempt to prove failover. + attempt := 0 + proxyFunc := func(ctx context.Context, w http.ResponseWriter, r *http.Request, ep *domain.Endpoint, s *ports.RequestStats) error { + attempt++ + if attempt == 1 { + // Simulate connection error from srv1 before any bytes written. + return fmt.Errorf("connection refused: %w", &connectionResetError{}) + } + // Forward to srv2. + resp, err := http.Get(srv2.URL) + if err != nil { + return err + } + defer resp.Body.Close() //nolint:errcheck + w.WriteHeader(resp.StatusCode) + _, _ = io.Copy(w, resp.Body) + return nil + } + + err := h.ExecuteWithRetry(context.Background(), w, req, []*domain.Endpoint{ep1, ep2}, + &roundRobinSelector{}, stats, proxyFunc) + + require.NoError(t, err) + assert.True(t, secondHit, "second endpoint must be tried after first fails without writing bytes") +} diff --git a/internal/adapter/proxy/olla/service.go b/internal/adapter/proxy/olla/service.go index eecda19b..830ece86 100644 --- a/internal/adapter/proxy/olla/service.go +++ b/internal/adapter/proxy/olla/service.go @@ -41,7 +41,7 @@ import ( "github.com/thushan/olla/internal/adapter/health" "github.com/thushan/olla/internal/adapter/proxy/common" - "github.com/thushan/olla/internal/adapter/proxy/config" + proxyconfig "github.com/thushan/olla/internal/adapter/proxy/config" "github.com/thushan/olla/internal/adapter/proxy/core" "github.com/thushan/olla/internal/app/middleware" "github.com/thushan/olla/internal/core/domain" @@ -144,22 +144,22 @@ func NewService( ) (*Service, error) { if configuration.StreamBufferSize == 0 { - configuration.StreamBufferSize = config.OllaDefaultStreamBufferSize + configuration.StreamBufferSize = proxyconfig.OllaDefaultStreamBufferSize } if configuration.MaxIdleConns == 0 { - configuration.MaxIdleConns = config.OllaDefaultMaxIdleConns + configuration.MaxIdleConns = proxyconfig.OllaDefaultMaxIdleConns } if configuration.MaxConnsPerHost == 0 { - configuration.MaxConnsPerHost = config.OllaDefaultMaxConnsPerHost + configuration.MaxConnsPerHost = proxyconfig.OllaDefaultMaxConnsPerHost } if configuration.MaxIdleConnsPerHost == 0 { - configuration.MaxIdleConnsPerHost = config.OllaDefaultMaxIdleConnsPerHost + configuration.MaxIdleConnsPerHost = proxyconfig.OllaDefaultMaxIdleConnsPerHost } if configuration.IdleConnTimeout == 0 { - configuration.IdleConnTimeout = config.OllaDefaultIdleConnTimeout + configuration.IdleConnTimeout = proxyconfig.OllaDefaultIdleConnTimeout } if configuration.ReadTimeout == 0 { - configuration.ReadTimeout = config.DefaultReadTimeout + configuration.ReadTimeout = proxyconfig.DefaultReadTimeout } base := core.NewBaseProxyComponents(discoveryService, selector, statsCollector, metricsExtractor, logger) @@ -211,13 +211,18 @@ func NewService( // createOptimisedTransport creates an HTTP transport optimised for AI workloads func createOptimisedTransport(config *Configuration) *http.Transport { return &http.Transport{ - MaxIdleConns: config.MaxIdleConns, - MaxIdleConnsPerHost: config.MaxIdleConnsPerHost, - MaxConnsPerHost: config.MaxConnsPerHost, - IdleConnTimeout: config.IdleConnTimeout, - TLSHandshakeTimeout: DefaultTLSHandshakeTimeout, - DisableCompression: true, - ForceAttemptHTTP2: true, + MaxIdleConns: config.MaxIdleConns, + MaxIdleConnsPerHost: config.MaxIdleConnsPerHost, + MaxConnsPerHost: config.MaxConnsPerHost, + IdleConnTimeout: config.IdleConnTimeout, + TLSHandshakeTimeout: DefaultTLSHandshakeTimeout, + DisableCompression: true, + ForceAttemptHTTP2: true, + ResponseHeaderTimeout: proxyconfig.DefaultResponseHeaderTimeout, + // Olla targets local inference backends; outbound proxy env vars are not + // honoured here because they would route credentialled requests through an + // intermediary on plain HTTP. Health probes (no credentials) keep the proxy + // so corporate monitoring infra still works for connectivity checks. DialContext: func(ctx context.Context, network, addr string) (net.Conn, error) { dialer := &net.Dialer{ Timeout: config.GetConnectionTimeout(), @@ -366,8 +371,9 @@ func (s *Service) buildTargetURL(r *http.Request, endpoint *domain.Endpoint) *ur return common.BuildTargetURL(r, endpoint, s.configuration.GetProxyPrefix()) } -// prepareProxyRequest creates and prepares the proxy request with headers -func (s *Service) prepareProxyRequest(ctx context.Context, r *http.Request, targetURL *url.URL, stats *ports.RequestStats) (*http.Request, error) { +// prepareProxyRequest creates and prepares the proxy request with headers. +// endpoint is passed through so CopyHeaders can apply per-endpoint auth and custom headers. +func (s *Service) prepareProxyRequest(ctx context.Context, r *http.Request, targetURL *url.URL, endpoint *domain.Endpoint, stats *ports.RequestStats) (*http.Request, error) { proxyReq, err := http.NewRequestWithContext(ctx, r.Method, targetURL.String(), r.Body) if err != nil { return nil, err @@ -375,7 +381,7 @@ func (s *Service) prepareProxyRequest(ctx context.Context, r *http.Request, targ // Copy headers headerStart := time.Now() - core.CopyHeaders(proxyReq, r) + core.CopyHeaders(proxyReq, r, endpoint) stats.HeaderProcessingMs = time.Since(headerStart).Milliseconds() // Add model header @@ -458,12 +464,8 @@ func (s *Service) handleSuccessfulResponse(ctx context.Context, w http.ResponseW core.SetResponseHeaders(w, stats, endpoint) core.SetStickySessionHeaders(w, r) - // Copy response headers - for key, values := range resp.Header { - for _, value := range values { - w.Header().Add(key, value) - } - } + // Copy response headers, stripping any sensitive headers the upstream may reflect + core.CopyResponseHeaders(w.Header(), resp.Header, endpoint) w.WriteHeader(resp.StatusCode) diff --git a/internal/adapter/proxy/olla/service_retry.go b/internal/adapter/proxy/olla/service_retry.go index eba0258c..000d35cb 100644 --- a/internal/adapter/proxy/olla/service_retry.go +++ b/internal/adapter/proxy/olla/service_retry.go @@ -84,7 +84,7 @@ func (s *Service) proxyToSingleEndpoint(ctx context.Context, w http.ResponseWrit // Rewrite model name in request body if this is an alias-resolved request core.RewriteModelForAlias(ctx, r, endpoint) - proxyReq, err := s.prepareProxyRequest(ctx, r, targetURL, stats) + proxyReq, err := s.prepareProxyRequest(ctx, r, targetURL, endpoint, stats) if err != nil { if cb != nil { cb.RecordFailure() @@ -124,12 +124,8 @@ func (s *Service) proxyToSingleEndpoint(ctx context.Context, w http.ResponseWrit core.SetResponseHeaders(w, stats, endpoint) core.SetStickySessionHeaders(w, r) - // Copy response headers - for key, values := range resp.Header { - for _, value := range values { - w.Header().Add(key, value) - } - } + // Copy response headers, stripping any sensitive headers the upstream may reflect + core.CopyResponseHeaders(w.Header(), resp.Header, endpoint) w.WriteHeader(resp.StatusCode) diff --git a/internal/adapter/proxy/olla/service_transport_test.go b/internal/adapter/proxy/olla/service_transport_test.go index a7432bb3..04878645 100644 --- a/internal/adapter/proxy/olla/service_transport_test.go +++ b/internal/adapter/proxy/olla/service_transport_test.go @@ -1,12 +1,19 @@ package olla import ( + "reflect" + "runtime" "testing" "time" - "github.com/thushan/olla/internal/adapter/proxy/config" + proxyconfig "github.com/thushan/olla/internal/adapter/proxy/config" ) +// funcName extracts the full symbol name of a function value. +func funcName(f interface{}) string { + return runtime.FuncForPC(reflect.ValueOf(f).Pointer()).Name() +} + // TestCreateOptimisedTransport_ConnectionLimits verifies that both MaxConnsPerHost and // MaxIdleConnsPerHost are mapped to their correct fields on http.Transport. // Previously MaxConnsPerHost was mistakenly written to MaxIdleConnsPerHost and @@ -42,18 +49,18 @@ func TestCreateOptimisedTransport_DefaultsApplied(t *testing.T) { // Zero-value config — defaults should be filled in by NewService, but we can verify // the expected defaults are consistent with the package constants. cfg := &Configuration{} - cfg.MaxConnsPerHost = config.OllaDefaultMaxConnsPerHost - cfg.MaxIdleConnsPerHost = config.OllaDefaultMaxIdleConnsPerHost - cfg.MaxIdleConns = config.OllaDefaultMaxIdleConns - cfg.IdleConnTimeout = config.OllaDefaultIdleConnTimeout + cfg.MaxConnsPerHost = proxyconfig.OllaDefaultMaxConnsPerHost + cfg.MaxIdleConnsPerHost = proxyconfig.OllaDefaultMaxIdleConnsPerHost + cfg.MaxIdleConns = proxyconfig.OllaDefaultMaxIdleConns + cfg.IdleConnTimeout = proxyconfig.OllaDefaultIdleConnTimeout transport := createOptimisedTransport(cfg) - if transport.MaxConnsPerHost != config.OllaDefaultMaxConnsPerHost { - t.Errorf("MaxConnsPerHost: want %d, got %d", config.OllaDefaultMaxConnsPerHost, transport.MaxConnsPerHost) + if transport.MaxConnsPerHost != proxyconfig.OllaDefaultMaxConnsPerHost { + t.Errorf("MaxConnsPerHost: want %d, got %d", proxyconfig.OllaDefaultMaxConnsPerHost, transport.MaxConnsPerHost) } - if transport.MaxIdleConnsPerHost != config.OllaDefaultMaxIdleConnsPerHost { - t.Errorf("MaxIdleConnsPerHost: want %d, got %d", config.OllaDefaultMaxIdleConnsPerHost, transport.MaxIdleConnsPerHost) + if transport.MaxIdleConnsPerHost != proxyconfig.OllaDefaultMaxIdleConnsPerHost { + t.Errorf("MaxIdleConnsPerHost: want %d, got %d", proxyconfig.OllaDefaultMaxIdleConnsPerHost, transport.MaxIdleConnsPerHost) } } @@ -82,3 +89,45 @@ func TestCreateOptimisedTransport_FieldsAreDistinct(t *testing.T) { t.Errorf("MaxIdleConnsPerHost: want 10, got %d", transport.MaxIdleConnsPerHost) } } + +// TestCreateOptimisedTransport_NoProxyFromEnvironment asserts that the Olla proxy +// transport does NOT honour HTTP_PROXY/HTTPS_PROXY. Olla targets local +// inference backends; routing credentialled requests through an outbound proxy +// on plain HTTP is a credential-exposure risk. Health probes keep the env proxy. +func TestCreateOptimisedTransport_NoProxyFromEnvironment(t *testing.T) { + t.Parallel() + + cfg := &Configuration{} + cfg.MaxIdleConns = proxyconfig.OllaDefaultMaxIdleConns + cfg.IdleConnTimeout = proxyconfig.OllaDefaultIdleConnTimeout + + transport := createOptimisedTransport(cfg) + + if transport.Proxy != nil { + got := funcName(transport.Proxy) + t.Errorf("Olla transport.Proxy = %s, want nil: proxy requests must not be routed through env proxy", got) + } +} + +// TestCreateOptimisedTransport_ResponseHeaderTimeout asserts that the Olla transport +// has a finite ResponseHeaderTimeout. Without it, a backend that accepts the TCP +// connection but withholds response headers blocks the goroutine indefinitely. +func TestCreateOptimisedTransport_ResponseHeaderTimeout(t *testing.T) { + t.Parallel() + + cfg := &Configuration{} + cfg.MaxIdleConns = proxyconfig.OllaDefaultMaxIdleConns + cfg.IdleConnTimeout = proxyconfig.OllaDefaultIdleConnTimeout + + transport := createOptimisedTransport(cfg) + + if transport.ResponseHeaderTimeout <= 0 { + t.Errorf("transport.ResponseHeaderTimeout is %v; backends that stall after accept will hang indefinitely", + transport.ResponseHeaderTimeout) + } + + want := proxyconfig.DefaultResponseHeaderTimeout + if transport.ResponseHeaderTimeout != want { + t.Errorf("transport.ResponseHeaderTimeout = %v, want %v", transport.ResponseHeaderTimeout, want) + } +} diff --git a/internal/adapter/proxy/sherpa/service.go b/internal/adapter/proxy/sherpa/service.go index e9a06460..ba5e59f8 100644 --- a/internal/adapter/proxy/sherpa/service.go +++ b/internal/adapter/proxy/sherpa/service.go @@ -35,6 +35,7 @@ import ( "net/http" "time" + "github.com/thushan/olla/internal/adapter/proxy/config" "github.com/thushan/olla/internal/adapter/proxy/core" "github.com/thushan/olla/internal/core/domain" "github.com/thushan/olla/internal/core/ports" @@ -56,6 +57,9 @@ const ( ClientDisconnectionTimeThreshold = 5 * time.Second ) +// DefaultResponseHeaderTimeout re-exports the shared constant for Sherpa callers. +const DefaultResponseHeaderTimeout = config.DefaultResponseHeaderTimeout + // Service implements the Sherpa proxy - optimised for simplicity and maintainability type Service struct { *core.BaseProxyComponents @@ -88,11 +92,16 @@ func NewService( // Create transport with TCP tuning for LLM streaming transport := &http.Transport{ - MaxIdleConns: DefaultMaxIdleConns, - IdleConnTimeout: DefaultIdleConnTimeout, - DisableCompression: DefaultDisableCompression, - TLSHandshakeTimeout: DefaultTLSHandshakeTimeout, - MaxIdleConnsPerHost: DefaultMaxIdleConnsPerHost, + MaxIdleConns: DefaultMaxIdleConns, + IdleConnTimeout: DefaultIdleConnTimeout, + DisableCompression: DefaultDisableCompression, + TLSHandshakeTimeout: DefaultTLSHandshakeTimeout, + MaxIdleConnsPerHost: DefaultMaxIdleConnsPerHost, + ResponseHeaderTimeout: DefaultResponseHeaderTimeout, + // Olla targets local inference backends; outbound proxy env vars are not + // honoured here because they would route credentialled requests through an + // intermediary on plain HTTP. Health probes (no credentials) keep the proxy + // so corporate monitoring infra still works for connectivity checks. DialContext: func(ctx context.Context, network, addr string) (net.Conn, error) { dialer := &net.Dialer{ Timeout: configuration.GetConnectionTimeout(), diff --git a/internal/adapter/proxy/sherpa/service_retry.go b/internal/adapter/proxy/sherpa/service_retry.go index 8b9601c7..8521ffe4 100644 --- a/internal/adapter/proxy/sherpa/service_retry.go +++ b/internal/adapter/proxy/sherpa/service_retry.go @@ -82,7 +82,7 @@ func (s *Service) proxyToSingleEndpoint(ctx context.Context, w http.ResponseWrit rlog.Debug("created proxy request") headerStart := time.Now() - core.CopyHeaders(proxyReq, r) + core.CopyHeaders(proxyReq, r, endpoint) stats.HeaderProcessingMs = time.Since(headerStart).Milliseconds() // Add model header if available @@ -117,12 +117,8 @@ func (s *Service) proxyToSingleEndpoint(ctx context.Context, w http.ResponseWrit core.SetResponseHeaders(w, stats, endpoint) core.SetStickySessionHeaders(w, r) - // Copy response headers - for key, values := range resp.Header { - for _, value := range values { - w.Header().Add(key, value) - } - } + // Copy response headers, stripping any sensitive headers the upstream may reflect + core.CopyResponseHeaders(w.Header(), resp.Header, endpoint) w.WriteHeader(resp.StatusCode) diff --git a/internal/adapter/proxy/sherpa/service_transport_test.go b/internal/adapter/proxy/sherpa/service_transport_test.go new file mode 100644 index 00000000..9f7f4ca2 --- /dev/null +++ b/internal/adapter/proxy/sherpa/service_transport_test.go @@ -0,0 +1,73 @@ +package sherpa + +import ( + "reflect" + "runtime" + "testing" + "time" +) + +// funcName extracts the full symbol name of a function value for comparison. +// http.ProxyFromEnvironment is a named function so the pointer is stable across builds. +func funcName(f interface{}) string { + return runtime.FuncForPC(reflect.ValueOf(f).Pointer()).Name() +} + +// newSherpaServiceForTransportTest builds a real Sherpa service via NewService so +// the transport tests exercise the production construction path. +func newSherpaServiceForTransportTest(t *testing.T) *Service { + t.Helper() + + cfg := &Configuration{} + cfg.ConnectionTimeout = 2 * time.Second + cfg.ConnectionKeepAlive = 30 * time.Second + cfg.StreamBufferSize = 8192 + + svc, err := NewService( + nil, // discovery service, not needed for transport tests + &mockEndpointSelector{}, + cfg, + nil, // stats collector + nil, // metrics extractor + createTestLogger(), + ) + if err != nil { + t.Fatalf("NewService: %v", err) + } + t.Cleanup(svc.Cleanup) + return svc +} + +// TestSherpaTransport_NoProxyFromEnvironment asserts that the Sherpa proxy +// transport does NOT honour HTTP_PROXY/HTTPS_PROXY. Olla targets local +// inference backends; routing credentialled requests through an outbound proxy +// on plain HTTP is a credential-exposure risk. Health probes keep the env proxy. +func TestSherpaTransport_NoProxyFromEnvironment(t *testing.T) { + t.Parallel() + + svc := newSherpaServiceForTransportTest(t) + + if svc.transport.Proxy != nil { + got := funcName(svc.transport.Proxy) + t.Errorf("Sherpa transport.Proxy = %s, want nil: proxy requests must not be routed through env proxy", got) + } +} + +// TestSherpaTransport_ResponseHeaderTimeout asserts that the Sherpa transport +// has a finite ResponseHeaderTimeout. Without it, a backend that accepts the +// TCP connection but withholds response headers blocks the goroutine indefinitely. +func TestSherpaTransport_ResponseHeaderTimeout(t *testing.T) { + t.Parallel() + + svc := newSherpaServiceForTransportTest(t) + + if svc.transport.ResponseHeaderTimeout <= 0 { + t.Errorf("transport.ResponseHeaderTimeout is %v; backends that stall after accept will hang indefinitely", + svc.transport.ResponseHeaderTimeout) + } + + const want = DefaultResponseHeaderTimeout + if svc.transport.ResponseHeaderTimeout != want { + t.Errorf("transport.ResponseHeaderTimeout = %v, want %v", svc.transport.ResponseHeaderTimeout, want) + } +} diff --git a/internal/app/handlers/handler_auth_routes_test.go b/internal/app/handlers/handler_auth_routes_test.go new file mode 100644 index 00000000..6224c76e --- /dev/null +++ b/internal/app/handlers/handler_auth_routes_test.go @@ -0,0 +1,260 @@ +package handlers + +// TestAuthAcrossProxyRoutes proves that every proxy-bearing route handler passes +// an auth-configured endpoint to the proxy service unchanged. +// +// Background: issue #139 revealed that providerProxyHandler was not wired through +// the same middleware path as proxyHandler, so cross-cutting concerns (sticky +// sessions in that case, auth injection in general) could silently be skipped. +// +// This test catches the next #139-style regression: if a new handler family is +// added that bypasses executeProxyRequest (and therefore bypasses the CopyHeaders +// call that injects endpoint credentials), this test will fail because the proxy +// service will never be invoked with the auth endpoint. +// +// What this test does NOT cover: +// - Whether CopyHeaders correctly injects the auth header (covered by +// internal/adapter/proxy/core/common_auth_test.go) +// - End-to-end network delivery of the header to the backend (out of scope for +// handler-layer tests; that lives in the proxy engine tests) +// +// Route families covered: +// - /olla/proxy/ → proxyHandler +// - /olla/ollama/ → providerProxyHandler (representative of all provider routes) +// - /olla/anthropic/v1/messages → translationHandler + +import ( + "bytes" + "context" + "encoding/json" + "net/http" + "net/http/httptest" + "net/url" + "strings" + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "github.com/thushan/olla/internal/adapter/inspector" + "github.com/thushan/olla/internal/config" + "github.com/thushan/olla/internal/core/constants" + "github.com/thushan/olla/internal/core/domain" + "github.com/thushan/olla/internal/core/ports" + styledlogger "github.com/thushan/olla/internal/logger" +) + +// authEndpoint returns an endpoint configured with bearer auth. We use this as +// the only entry in the discovery service so every proxy route must route through it. +func authEndpoint(t *testing.T, providerType string) *domain.Endpoint { + t.Helper() + u, _ := url.Parse("http://localhost:11434") + return &domain.Endpoint{ + Name: "auth-endpoint", + URL: u, + URLString: u.String(), + Type: providerType, + Status: domain.StatusHealthy, + AuthHeaderName: "Authorization", + AuthHeaderValue: "Bearer test-backend-secret", + } +} + +// authCapturingProxyService records the endpoints passed in by the handler so we +// can assert the auth-configured endpoint was forwarded unmodified. +type authCapturingProxyService struct { + capturedEndpoints []*domain.Endpoint + capturedCtx context.Context +} + +func (s *authCapturingProxyService) ProxyRequestToEndpoints( + ctx context.Context, + w http.ResponseWriter, + r *http.Request, + endpoints []*domain.Endpoint, + stats *ports.RequestStats, + _ styledlogger.StyledLogger, +) error { + s.capturedEndpoints = endpoints + s.capturedCtx = ctx + w.Header().Set(constants.HeaderContentType, constants.ContentTypeJSON) + w.WriteHeader(http.StatusOK) + _, _ = w.Write([]byte(`{"ok":true}`)) + return nil +} + +func (s *authCapturingProxyService) ProxyRequest( + ctx context.Context, + w http.ResponseWriter, + r *http.Request, + stats *ports.RequestStats, + _ styledlogger.StyledLogger, +) error { + return nil +} + +func (s *authCapturingProxyService) GetStats(ctx context.Context) (ports.ProxyStats, error) { + return ports.ProxyStats{}, nil +} +func (s *authCapturingProxyService) UpdateConfig(c ports.ProxyConfiguration) {} + +// discoveryWithAuth returns a discovery service that yields a single +// auth-configured endpoint of the given provider type. +func discoveryWithAuth(ep *domain.Endpoint) *mockDiscoveryServiceWithHealthy { + return &mockDiscoveryServiceWithHealthy{endpoints: []*domain.Endpoint{ep}} +} + +// minimalApp builds an Application wired for auth route tests. +func minimalApp(t *testing.T, capture *authCapturingProxyService, ds ports.DiscoveryService, log *mockStyledLogger) *Application { + t.Helper() + + return &Application{ + logger: log, + proxyService: capture, + discoveryService: ds, + inspectorChain: inspector.NewChain(log), + profileFactory: &mockProfileFactory{ + validProfiles: map[string]bool{ + "ollama": true, + "openai": true, + "lmstudio": true, + "lm-studio": true, + "vllm": true, + }, + }, + statsCollector: &mockStatsCollector{}, + repository: &mockEndpointRepository{}, + Config: &config.Config{ + Server: config.ServerConfig{RateLimits: config.ServerRateLimits{}}, + }, + StartTime: time.Now(), + } +} + +// TestAuthAcrossProxyRoutes_ProxyHandler verifies proxyHandler forwards the +// auth-configured endpoint to the proxy service. This route has always worked +// correctly; the test serves as a baseline for the parameterised coverage. +func TestAuthAcrossProxyRoutes_ProxyHandler(t *testing.T) { + t.Parallel() + + ep := authEndpoint(t, "openai") + capture := &authCapturingProxyService{} + mockLog := &mockStyledLogger{} + app := minimalApp(t, capture, discoveryWithAuth(ep), mockLog) + + body := `{"model":"llama3","messages":[{"role":"user","content":"hi"}]}` + req := httptest.NewRequest(http.MethodPost, "/olla/proxy/v1/chat/completions", + strings.NewReader(body)) + req.Header.Set(constants.HeaderContentType, constants.ContentTypeJSON) + + // Inject the route prefix the router would normally set. + ctx := context.WithValue(req.Context(), constants.ContextRoutePrefixKey, "/olla/proxy/") + req = req.WithContext(ctx) + + rec := httptest.NewRecorder() + app.proxyHandler(rec, req) + + require.Equal(t, http.StatusOK, rec.Code, "proxyHandler must complete successfully") + assertAuthEndpointReached(t, "proxyHandler", capture, ep) +} + +// TestAuthAcrossProxyRoutes_ProviderProxyHandler verifies providerProxyHandler (the +// route used by /olla/ollama/, /olla/openai/, etc.) also forwards the auth endpoint. +// This is the handler that was wired incorrectly in issue #139. Had this test +// existed then, the missing sticky-session wiring would have been caught first. +func TestAuthAcrossProxyRoutes_ProviderProxyHandler(t *testing.T) { + t.Parallel() + + ep := authEndpoint(t, "ollama") + capture := &authCapturingProxyService{} + mockLog := &mockStyledLogger{} + app := minimalApp(t, capture, discoveryWithAuth(ep), mockLog) + + body := `{"model":"llama3","messages":[{"role":"user","content":"hi"}]}` + req := httptest.NewRequest(http.MethodPost, "/olla/ollama/api/chat", + strings.NewReader(body)) + req.Header.Set(constants.HeaderContentType, constants.ContentTypeJSON) + + rec := httptest.NewRecorder() + app.providerProxyHandler(rec, req) + + require.Equal(t, http.StatusOK, rec.Code, "providerProxyHandler must complete successfully") + assertAuthEndpointReached(t, "providerProxyHandler", capture, ep) +} + +// TestAuthAcrossProxyRoutes_TranslationHandler verifies the Anthropic translation +// handler also flows through the proxy service with the auth endpoint intact. +// The translation handler has its own code path (buffering body, model extraction, +// passthrough logic) so it warrants separate coverage. +func TestAuthAcrossProxyRoutes_TranslationHandler(t *testing.T) { + t.Parallel() + + ep := authEndpoint(t, "openai") + capture := &authCapturingProxyService{} + mockLog := &mockStyledLogger{} + app := minimalApp(t, capture, discoveryWithAuth(ep), mockLog) + // statsCollector is needed by recordTranslatorMetrics + app.statsCollector = &mockStatsCollector{} + + trans := &mockTranslator{ + name: "anthropic", + implementsErrorWriter: true, + implementsPathProvider: true, + pathProvider: "/olla/anthropic/v1/messages", + writeErrorFunc: func(w http.ResponseWriter, err error, statusCode int) { + w.Header().Set(constants.HeaderContentType, constants.ContentTypeJSON) + w.WriteHeader(statusCode) + _ = json.NewEncoder(w).Encode(map[string]interface{}{"error": err.Error()}) + }, + } + + handler := app.translationHandler(trans) + + body := map[string]interface{}{ + "model": "claude-3-sonnet", + "max_tokens": 100, + "messages": []interface{}{map[string]interface{}{"role": "user", "content": "hi"}}, + } + bodyBytes, _ := json.Marshal(body) + req := httptest.NewRequest(http.MethodPost, "/olla/anthropic/v1/messages", + bytes.NewReader(bodyBytes)) + req.Header.Set(constants.HeaderContentType, constants.ContentTypeJSON) + + rec := httptest.NewRecorder() + handler.ServeHTTP(rec, req) + + require.Equal(t, http.StatusOK, rec.Code, "translationHandler must complete successfully") + assertAuthEndpointReached(t, "translationHandler (anthropic)", capture, ep) +} + +// assertAuthEndpointReached checks that the proxy service was called with at +// least one endpoint that carries the expected auth configuration. This is the +// invariant that CopyHeaders depends on to inject backend credentials. +func assertAuthEndpointReached(t *testing.T, handlerName string, capture *authCapturingProxyService, want *domain.Endpoint) { + t.Helper() + + require.NotNil(t, capture.capturedCtx, + "%s: proxy service was never called; handler returned before reaching executeProxyRequest", handlerName) + + assert.NotEmpty(t, capture.capturedEndpoints, + "%s: proxy service was called with zero endpoints", handlerName) + + found := false + for _, ep := range capture.capturedEndpoints { + if ep.AuthHeaderName == want.AuthHeaderName && ep.AuthHeaderValue == want.AuthHeaderValue { + found = true + break + } + } + + assert.True(t, found, + "%s: auth-configured endpoint was not present in the endpoints forwarded to the proxy service; "+ + "CopyHeaders will not inject backend credentials for this route family", handlerName) +} + +// Compile-time check that authCapturingProxyService satisfies the proxy service interface. +var _ ports.ProxyService = (*authCapturingProxyService)(nil) + +// ensure styledlogger import is referenced. +var _ styledlogger.StyledLogger = (*mockStyledLogger)(nil) diff --git a/internal/app/middleware/logging.go b/internal/app/middleware/logging.go index d87f96ee..8d4d32f4 100644 --- a/internal/app/middleware/logging.go +++ b/internal/app/middleware/logging.go @@ -5,6 +5,7 @@ import ( "fmt" "log/slog" "net/http" + "net/url" "strings" "time" @@ -191,7 +192,7 @@ func AccessLoggingMiddleware(styledLogger logger.StyledLogger) func(http.Handler "remote_addr", r.RemoteAddr, "method", r.Method, "path", r.URL.Path, - "query", r.URL.RawQuery, + "query", redactQuery(r.URL.RawQuery), "status", wrapped.status, "request_bytes", requestSize, "response_bytes", wrapped.size, @@ -204,6 +205,61 @@ func AccessLoggingMiddleware(styledLogger logger.StyledLogger) func(http.Handler } } +// sensitiveQueryKeys lists query parameter names whose values must never appear +// in logs. Values are compared case-insensitively. +var sensitiveQueryKeys = []string{ + "api_key", "token", "access_token", "key", "password", "secret", "auth", +} + +// redactQuery returns a sanitised version of a raw query string with values for +// sensitive parameter names replaced by [REDACTED]. It does not modify the +// original string; callers should use the return value for logging only. +func redactQuery(raw string) string { + if raw == "" { + return raw + } + + // Parse into individual key=value pairs while preserving order and raw form. + // We rebuild manually rather than using url.Values.Encode() because the latter + // percent-encodes bracket characters in "[REDACTED]". + pairs := strings.Split(raw, "&") + var changed bool + out := make([]string, len(pairs)) + + for i, pair := range pairs { + k, _, hasVal := strings.Cut(pair, "=") + if !hasVal { + out[i] = pair + continue + } + // Decode the key for comparison so percent-encoded forms like + // %70assword (password) are caught. Fall back to the raw key if + // the escape sequence is malformed. + decoded, decodeErr := url.QueryUnescape(k) + if decodeErr != nil { + decoded = k + } + sensitive := false + for _, sk := range sensitiveQueryKeys { + if strings.EqualFold(decoded, sk) { + sensitive = true + break + } + } + if sensitive { + out[i] = k + "=[REDACTED]" + changed = true + } else { + out[i] = pair + } + } + + if !changed { + return raw + } + return strings.Join(out, "&") +} + // formatBytes converts byte count to human-readable format func formatBytes(bytes int64) string { const unit = 1024 diff --git a/internal/app/middleware/logging_test.go b/internal/app/middleware/logging_test.go index 59c2963c..2493a49d 100644 --- a/internal/app/middleware/logging_test.go +++ b/internal/app/middleware/logging_test.go @@ -150,6 +150,114 @@ func TestGetRequestIDWithoutContext(t *testing.T) { } } +func TestRedactQuery(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + input string + wantHide []string // substrings that must NOT appear in output + wantKeep []string // substrings that MUST appear in output + }{ + { + name: "empty query", + input: "", + wantKeep: []string{}, + }, + { + name: "api_key redacted", + input: "api_key=sk-1234", + wantHide: []string{"sk-1234"}, + wantKeep: []string{"[REDACTED]"}, + }, + { + name: "safe param unchanged", + input: "safe_param=value", + wantKeep: []string{"safe_param", "value"}, + }, + { + name: "mixed: token redacted, safe param kept", + input: "safe_param=value&token=secret", + wantHide: []string{"secret"}, + wantKeep: []string{"[REDACTED]", "safe_param", "value"}, + }, + { + name: "case-insensitive TOKEN", + input: "TOKEN=foo", + wantHide: []string{"foo"}, + wantKeep: []string{"[REDACTED]"}, + }, + { + name: "password redacted", + input: "user=alice&password=hunter2", + wantHide: []string{"hunter2"}, + wantKeep: []string{"[REDACTED]", "alice"}, + }, + { + name: "access_token redacted", + input: "access_token=tok-xyz", + wantHide: []string{"tok-xyz"}, + wantKeep: []string{"[REDACTED]"}, + }, + { + name: "secret redacted", + input: "secret=my-secret&page=2", + wantHide: []string{"my-secret"}, + wantKeep: []string{"[REDACTED]", "page"}, + }, + { + name: "auth redacted", + input: "auth=bearer-token", + wantHide: []string{"bearer-token"}, + wantKeep: []string{"[REDACTED]"}, + }, + { + name: "multiple sensitive keys all redacted", + input: "api_key=k1&token=t2&key=k3", + wantHide: []string{"k1", "t2", "k3"}, + wantKeep: []string{"[REDACTED]"}, + }, + { + // %70assword decodes to "password" and must still be redacted. + name: "percent-encoded key redacted", + input: "%70assword=secret", + wantHide: []string{"secret"}, + wantKeep: []string{"[REDACTED]"}, + }, + { + // api%5Fkey decodes to "api_key" (underscore encoded as %5F). + name: "encoded underscore in key redacted", + input: "api%5Fkey=foo", + wantHide: []string{"foo"}, + wantKeep: []string{"[REDACTED]"}, + }, + { + // %ZZ is not a valid percent-escape; must fall back to raw key comparison + // without panicking. "zzz" is not sensitive so the value passes through. + name: "malformed escape falls back to raw key", + input: "%ZZname=value", + wantKeep: []string{"value"}, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + result := redactQuery(tt.input) + for _, hide := range tt.wantHide { + if strings.Contains(result, hide) { + t.Errorf("redactQuery(%q) = %q; should not contain %q", tt.input, result, hide) + } + } + for _, keep := range tt.wantKeep { + if keep != "" && !strings.Contains(result, keep) { + t.Errorf("redactQuery(%q) = %q; should contain %q", tt.input, result, keep) + } + } + }) + } +} + // Mock styled logger for testing type mockStyledLogger struct{} diff --git a/internal/config/types.go b/internal/config/types.go index 69407931..69001148 100644 --- a/internal/config/types.go +++ b/internal/config/types.go @@ -126,9 +126,28 @@ type StaticDiscoveryConfig struct { Endpoints []EndpointConfig `yaml:"endpoints"` } +// AuthConfig holds per-endpoint outbound authentication configuration. +// Inline credential and _file variants are mutually exclusive; validation +// (fatal startup error on conflict) is enforced in P3. +type AuthConfig struct { + // header overrides the default header name for api_key auth (default X-Api-Key). + Header string `yaml:"header,omitempty"` + Token string `yaml:"token,omitempty"` + TokenFile string `yaml:"token_file,omitempty"` + Key string `yaml:"key,omitempty"` + KeyFile string `yaml:"key_file,omitempty"` + Username string `yaml:"username,omitempty"` + UsernameFile string `yaml:"username_file,omitempty"` + Password string `yaml:"password,omitempty"` + PasswordFile string `yaml:"password_file,omitempty"` + Type string `yaml:"type,omitempty"` +} + // EndpointConfig holds configuration for an AI inference endpoint type EndpointConfig struct { ModelFilter *domain.FilterConfig `yaml:"model_filter,omitempty"` + Auth *AuthConfig `yaml:"auth,omitempty"` + Headers map[string]string `yaml:"headers,omitempty"` // Priority uses a pointer so nil means "omitted in config" rather than explicitly zero. // This lets applyEndpointDefaults distinguish "user set 0" from "user said nothing", // since 0 is a valid, lower-than-default priority value. diff --git a/internal/config/types_test.go b/internal/config/types_test.go new file mode 100644 index 00000000..b454441c --- /dev/null +++ b/internal/config/types_test.go @@ -0,0 +1,142 @@ +package config + +import ( + "reflect" + "testing" + + "gopkg.in/yaml.v3" +) + +func TestAuthConfig_RoundTrip(t *testing.T) { + t.Parallel() + + t.Run("bearer", func(t *testing.T) { + t.Parallel() + original := &AuthConfig{ + Type: "bearer", + Token: "${TOKEN}", + } + roundTrip(t, original, &AuthConfig{}) + }) + + t.Run("bearer with file", func(t *testing.T) { + t.Parallel() + original := &AuthConfig{ + Type: "bearer", + TokenFile: "/run/secrets/token", + } + roundTrip(t, original, &AuthConfig{}) + }) + + t.Run("api_key with custom header", func(t *testing.T) { + t.Parallel() + original := &AuthConfig{ + Type: "api_key", + Key: "${API_KEY}", + Header: "X-Custom-Key", + } + roundTrip(t, original, &AuthConfig{}) + }) + + t.Run("api_key with file", func(t *testing.T) { + t.Parallel() + original := &AuthConfig{ + Type: "api_key", + KeyFile: "/run/secrets/key", + } + roundTrip(t, original, &AuthConfig{}) + }) + + t.Run("basic", func(t *testing.T) { + t.Parallel() + original := &AuthConfig{ + Type: "basic", + Username: "user", + Password: "pass", + } + roundTrip(t, original, &AuthConfig{}) + }) + + t.Run("basic with files", func(t *testing.T) { + t.Parallel() + original := &AuthConfig{ + Type: "basic", + UsernameFile: "/run/secrets/user", + PasswordFile: "/run/secrets/pass", + } + roundTrip(t, original, &AuthConfig{}) + }) +} + +func TestEndpointConfig_AuthAndHeaders_RoundTrip(t *testing.T) { + t.Parallel() + + t.Run("full endpoint with auth and headers", func(t *testing.T) { + t.Parallel() + original := &EndpointConfig{ + Name: "secure-llama", + URL: "http://llamabox.local:8080", + Type: "llamacpp", + Auth: &AuthConfig{ + Type: "bearer", + Token: "${TOKEN}", + Header: "X-Api-Key", + }, + Headers: map[string]string{ + "X-Custom": "value", + "X-Another": "other", + }, + } + roundTrip(t, original, &EndpointConfig{}) + }) + + t.Run("endpoint without auth or headers is backwards compatible", func(t *testing.T) { + t.Parallel() + yamlIn := ` +name: plain +url: http://localhost:11434 +type: ollama +` + var got EndpointConfig + if err := yaml.Unmarshal([]byte(yamlIn), &got); err != nil { + t.Fatalf("Unmarshal: %v", err) + } + if got.Auth != nil { + t.Errorf("Auth should be nil when not present in YAML, got %+v", got.Auth) + } + if got.Headers != nil { + t.Errorf("Headers should be nil when not present in YAML, got %+v", got.Headers) + } + if got.Name != "plain" || got.URL != "http://localhost:11434" || got.Type != "ollama" { + t.Errorf("base fields not parsed correctly: %+v", got) + } + }) + + t.Run("headers map with multiple entries round-trips", func(t *testing.T) { + t.Parallel() + original := &EndpointConfig{ + Name: "ep", + Headers: map[string]string{ + "X-Tenant": "acme", + "X-Region": "us-east", + "X-Version": "2", + }, + } + roundTrip(t, original, &EndpointConfig{}) + }) +} + +// roundTrip marshals src to YAML and unmarshals into dst, then asserts deep equality. +func roundTrip[T any](t *testing.T, src *T, dst *T) { + t.Helper() + data, err := yaml.Marshal(src) + if err != nil { + t.Fatalf("Marshal: %v", err) + } + if err := yaml.Unmarshal(data, dst); err != nil { + t.Fatalf("Unmarshal: %v", err) + } + if !reflect.DeepEqual(src, dst) { + t.Errorf("round-trip mismatch\n got: %+v\n want: %+v", dst, src) + } +} diff --git a/internal/core/constants/auth.go b/internal/core/constants/auth.go new file mode 100644 index 00000000..d3d97170 --- /dev/null +++ b/internal/core/constants/auth.go @@ -0,0 +1,37 @@ +// Package constants defines auth scheme and header constants used across +// the outbound request pipeline. +package constants + +// Auth type identifiers. These are the valid values for endpoint auth.type in config. +const ( + AuthTypeBearer = "bearer" + AuthTypeAPIKey = "api_key" + AuthTypeBasic = "basic" +) + +// HTTP header names used when injecting auth onto outbound requests. +const ( + // AuthHeaderAuthorization is the standard header for bearer and basic auth. + AuthHeaderAuthorization = "Authorization" + + // AuthDefaultAPIKeyHeader is the fallback header name when an api_key auth + // block omits the optional header field. + AuthDefaultAPIKeyHeader = "X-Api-Key" //nolint:gosec // false positive: this is a header name, not a credential +) + +// Auth scheme prefixes. Note the trailing space; these are prepended to the +// credential value when building the final Authorization header. +const ( + AuthSchemeBearer = "Bearer " + AuthSchemeBasic = "Basic " +) + +// IsValidAuthType reports whether s is a recognised auth.type value. +func IsValidAuthType(s string) bool { + switch s { + case AuthTypeBearer, AuthTypeAPIKey, AuthTypeBasic: + return true + default: + return false + } +} diff --git a/internal/core/constants/auth_test.go b/internal/core/constants/auth_test.go new file mode 100644 index 00000000..e19440a3 --- /dev/null +++ b/internal/core/constants/auth_test.go @@ -0,0 +1,71 @@ +package constants_test + +import ( + "testing" + + "github.com/thushan/olla/internal/core/constants" +) + +func TestIsValidAuthType(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + input string + want bool + }{ + {"bearer", constants.AuthTypeBearer, true}, + {"api_key", constants.AuthTypeAPIKey, true}, + {"basic", constants.AuthTypeBasic, true}, + {"empty", "", false}, + {"unknown", "oauth2", false}, + {"case sensitive", "Bearer", false}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + got := constants.IsValidAuthType(tt.input) + if got != tt.want { + t.Errorf("IsValidAuthType(%q) = %v, want %v", tt.input, got, tt.want) + } + }) + } +} + +func TestAuthConstants(t *testing.T) { + t.Parallel() + + t.Run("auth type values", func(t *testing.T) { + t.Parallel() + if constants.AuthTypeBearer != "bearer" { + t.Errorf("AuthTypeBearer: expected %q, got %q", "bearer", constants.AuthTypeBearer) + } + if constants.AuthTypeAPIKey != "api_key" { + t.Errorf("AuthTypeAPIKey: expected %q, got %q", "api_key", constants.AuthTypeAPIKey) + } + if constants.AuthTypeBasic != "basic" { + t.Errorf("AuthTypeBasic: expected %q, got %q", "basic", constants.AuthTypeBasic) + } + }) + + t.Run("header names", func(t *testing.T) { + t.Parallel() + if constants.AuthHeaderAuthorization != "Authorization" { + t.Errorf("AuthHeaderAuthorization: expected %q, got %q", "Authorization", constants.AuthHeaderAuthorization) + } + if constants.AuthDefaultAPIKeyHeader != "X-Api-Key" { + t.Errorf("AuthDefaultAPIKeyHeader: expected %q, got %q", "X-Api-Key", constants.AuthDefaultAPIKeyHeader) + } + }) + + t.Run("scheme prefixes include trailing space", func(t *testing.T) { + t.Parallel() + if constants.AuthSchemeBearer != "Bearer " { + t.Errorf("AuthSchemeBearer: expected %q, got %q", "Bearer ", constants.AuthSchemeBearer) + } + if constants.AuthSchemeBasic != "Basic " { + t.Errorf("AuthSchemeBasic: expected %q, got %q", "Basic ", constants.AuthSchemeBasic) + } + }) +} diff --git a/internal/core/domain/auth_hint_test.go b/internal/core/domain/auth_hint_test.go new file mode 100644 index 00000000..ea8f03cf --- /dev/null +++ b/internal/core/domain/auth_hint_test.go @@ -0,0 +1,109 @@ +package domain_test + +import ( + "testing" + + "gopkg.in/yaml.v3" + + "github.com/thushan/olla/internal/core/domain" +) + +// TestAuthHintRoundTrip verifies that the auth hint section in a profile YAML +// deserialises correctly and that the zero value is safe (omitempty means the +// hint is simply absent when not configured). +func TestAuthHintRoundTrip(t *testing.T) { + t.Parallel() + + yamlInput := ` +name: test-profile +version: "1.0" +characteristics: + timeout: 5m + streaming_support: true + auth: + required: false + types: + - bearer + - api_key + default_header: "X-Api-Key" +` + + var cfg domain.ProfileConfig + if err := yaml.Unmarshal([]byte(yamlInput), &cfg); err != nil { + t.Fatalf("yaml.Unmarshal failed: %v", err) + } + + hint := cfg.Characteristics.Auth + + if hint.Required { + t.Error("expected Required=false") + } + if len(hint.Types) != 2 { + t.Errorf("expected 2 auth types, got %d: %v", len(hint.Types), hint.Types) + } + if hint.Types[0] != "bearer" { + t.Errorf("expected Types[0]=bearer, got %q", hint.Types[0]) + } + if hint.Types[1] != "api_key" { + t.Errorf("expected Types[1]=api_key, got %q", hint.Types[1]) + } + if hint.DefaultHeader != "X-Api-Key" { + t.Errorf("expected DefaultHeader=X-Api-Key, got %q", hint.DefaultHeader) + } +} + +// TestAuthHintAbsent verifies that a profile without an auth section produces +// a zero-value AuthHint, not an error. +func TestAuthHintAbsent(t *testing.T) { + t.Parallel() + + yamlInput := ` +name: minimal-profile +version: "1.0" +characteristics: + timeout: 2m + streaming_support: false +` + + var cfg domain.ProfileConfig + if err := yaml.Unmarshal([]byte(yamlInput), &cfg); err != nil { + t.Fatalf("yaml.Unmarshal failed: %v", err) + } + + hint := cfg.Characteristics.Auth + if hint.Required { + t.Error("expected Required=false for absent auth hint") + } + if len(hint.Types) != 0 { + t.Errorf("expected no auth types for absent hint, got %v", hint.Types) + } +} + +// TestAuthHintRequiredFlag verifies the required flag is parsed correctly for +// profiles that mandate authentication (e.g. a cloud API gateway). +func TestAuthHintRequiredFlag(t *testing.T) { + t.Parallel() + + yamlInput := ` +name: cloud-profile +version: "1.0" +characteristics: + timeout: 1m + auth: + required: true + types: + - bearer +` + + var cfg domain.ProfileConfig + if err := yaml.Unmarshal([]byte(yamlInput), &cfg); err != nil { + t.Fatalf("yaml.Unmarshal failed: %v", err) + } + + if !cfg.Characteristics.Auth.Required { + t.Error("expected Required=true") + } + if len(cfg.Characteristics.Auth.Types) != 1 || cfg.Characteristics.Auth.Types[0] != "bearer" { + t.Errorf("unexpected types: %v", cfg.Characteristics.Auth.Types) + } +} diff --git a/internal/core/domain/endpoint.go b/internal/core/domain/endpoint.go index afb4ad83..92fe6e6a 100644 --- a/internal/core/domain/endpoint.go +++ b/internal/core/domain/endpoint.go @@ -8,21 +8,28 @@ import ( ) const ( - StatusStringHealthy = "healthy" - StatusStringBusy = "busy" - StatusStringOffline = "offline" - StatusStringWarming = "warming" - StatusStringUnhealthy = "unhealthy" - StatusStringUnknown = "unknown" + StatusStringHealthy = "healthy" + StatusStringBusy = "busy" + StatusStringOffline = "offline" + StatusStringWarming = "warming" + StatusStringUnhealthy = "unhealthy" + StatusStringUnknown = "unknown" + StatusStringConfigError = "config_error" + StatusStringRateLimited = "rate_limited" ) type Endpoint struct { - LastChecked time.Time - NextCheckTime time.Time - URL *url.URL - HealthCheckURL *url.URL - ModelUrl *url.URL - ModelFilter *FilterConfig + LastChecked time.Time + NextCheckTime time.Time + // RateLimitedUntil is set when a health probe receives 429. The scheduler skips + // probing this endpoint until the time passes. Never serialised. + RateLimitedUntil time.Time `json:"-"` + URL *url.URL + HealthCheckURL *url.URL + ModelUrl *url.URL + ModelFilter *FilterConfig + // Headers holds verbatim outbound headers copied from endpoint config at load time. + Headers map[string]string `json:"-"` Name string Type string `json:"type,omitempty"` Status EndpointStatus @@ -30,13 +37,19 @@ type Endpoint struct { HealthCheckPathString string HealthCheckURLString string ModelURLString string - LastLatency time.Duration - CheckInterval time.Duration - CheckTimeout time.Duration - Priority int - ConsecutiveFailures int - BackoffMultiplier int - PreservePath bool + // AuthHeaderName is the resolved header name for outbound auth (e.g. "Authorization", "X-Api-Key"). + // Precomputed at load time so the hot path pays no allocation cost. + AuthHeaderName string + // AuthHeaderValue is the fully composed header value (e.g. "Bearer tok", "Basic base64(...)"). + // Never serialised; leaking credentials through logs or status endpoints would be a security issue. + AuthHeaderValue string `json:"-"` + LastLatency time.Duration + CheckInterval time.Duration + CheckTimeout time.Duration + Priority int + ConsecutiveFailures int + BackoffMultiplier int + PreservePath bool } func (e *Endpoint) GetURLString() string { @@ -56,6 +69,12 @@ const ( StatusWarming EndpointStatus = StatusStringWarming StatusUnhealthy EndpointStatus = StatusStringUnhealthy StatusUnknown EndpointStatus = StatusStringUnknown + // StatusConfigError indicates the endpoint is reachable but the credentials + // or headers are wrong. The operator must fix config; retrying achieves nothing. + StatusConfigError EndpointStatus = StatusStringConfigError + // StatusRateLimited indicates the endpoint returned 429. The scheduler should + // honour the Retry-After delay before probing again. + StatusRateLimited EndpointStatus = StatusStringRateLimited ) func (s EndpointStatus) IsRoutable() bool { diff --git a/internal/core/domain/endpoint_test.go b/internal/core/domain/endpoint_test.go new file mode 100644 index 00000000..9b0eba82 --- /dev/null +++ b/internal/core/domain/endpoint_test.go @@ -0,0 +1,69 @@ +package domain_test + +import ( + "encoding/json" + "strings" + "testing" + + "github.com/thushan/olla/internal/core/domain" +) + +// TestEndpoint_AuthHeaderValue_NotSerialised ensures that AuthHeaderValue never +// appears in JSON output. The field carries live credentials; exposing it through +// status endpoints or logs would be a security issue. +func TestEndpoint_AuthHeaderValue_NotSerialised(t *testing.T) { + t.Parallel() + + ep := &domain.Endpoint{ + Name: "secure-ep", + AuthHeaderName: "Authorization", + AuthHeaderValue: "Bearer super-secret-token", + } + + data, err := json.Marshal(ep) + if err != nil { + t.Fatalf("json.Marshal: %v", err) + } + + if strings.Contains(string(data), "super-secret-token") { + t.Errorf("AuthHeaderValue leaked into JSON output: %s", string(data)) + } + if strings.Contains(string(data), "AuthHeaderValue") { + t.Errorf("AuthHeaderValue field name present in JSON output: %s", string(data)) + } +} + +func TestEndpoint_AuthFields_ZeroValueByDefault(t *testing.T) { + t.Parallel() + + ep := &domain.Endpoint{Name: "plain"} + + if ep.AuthHeaderName != "" { + t.Errorf("AuthHeaderName should be empty by default, got %q", ep.AuthHeaderName) + } + if ep.AuthHeaderValue != "" { + t.Errorf("AuthHeaderValue should be empty by default, got %q", ep.AuthHeaderValue) + } + if ep.Headers != nil { + t.Errorf("Headers should be nil by default, got %v", ep.Headers) + } +} + +func TestEndpoint_Headers_StoredVerbatim(t *testing.T) { + t.Parallel() + + want := map[string]string{ + "X-Tenant": "acme", + "X-Region": "us-east", + } + ep := &domain.Endpoint{ + Name: "ep", + Headers: want, + } + + for k, v := range want { + if ep.Headers[k] != v { + t.Errorf("Headers[%q] = %q, want %q", k, ep.Headers[k], v) + } + } +} diff --git a/internal/core/domain/healthcheck.go b/internal/core/domain/healthcheck.go index 919d463b..e71439c0 100644 --- a/internal/core/domain/healthcheck.go +++ b/internal/core/domain/healthcheck.go @@ -6,11 +6,14 @@ import ( ) type HealthCheckResult struct { - Error error - Status EndpointStatus - Latency time.Duration - ErrorType HealthCheckErrorType - StatusCode int + // RateLimitedUntil is populated when the probe received a 429 with a Retry-After + // header. The scheduler uses this to skip probing until the window has elapsed. + RateLimitedUntil time.Time + Error error + Status EndpointStatus + Latency time.Duration + ErrorType HealthCheckErrorType + StatusCode int } type HealthCheckErrorType int diff --git a/internal/core/domain/json_safety_test.go b/internal/core/domain/json_safety_test.go new file mode 100644 index 00000000..af1b441f --- /dev/null +++ b/internal/core/domain/json_safety_test.go @@ -0,0 +1,124 @@ +package domain_test + +import ( + "encoding/json" + "strings" + "testing" + "time" + + "github.com/thushan/olla/internal/core/domain" +) + +// TestEndpointModelsURLNotSerialised asserts that EndpointModels.EndpointURL +// is never emitted in JSON output. The field is used as an internal map key +// and must not appear in API responses; it may carry auth credentials or +// internal network addresses. +func TestEndpointModelsURLNotSerialised(t *testing.T) { + t.Parallel() + + em := domain.EndpointModels{ + LastUpdated: time.Now(), + EndpointURL: "http://user:pass@192.168.1.100:8000", + Models: []*domain.ModelInfo{{Name: "llama3"}}, + } + + data, err := json.Marshal(em) + if err != nil { + t.Fatalf("marshal failed: %v", err) + } + + if strings.Contains(string(data), "192.168.1.100") { + t.Errorf("EndpointModels JSON contains endpoint URL: %s", data) + } + if strings.Contains(string(data), "endpoint_url") { + t.Errorf("EndpointModels JSON contains 'endpoint_url' key: %s", data) + } + // Sanity check: the model data is still present. + if !strings.Contains(string(data), "llama3") { + t.Errorf("EndpointModels JSON missing model data: %s", data) + } +} + +// TestSourceEndpointURLNotSerialised asserts that SourceEndpoint.EndpointURL +// is not emitted in JSON output. The field holds the backend URL used for +// internal routing and must not surface in unified model API responses. +func TestSourceEndpointURLNotSerialised(t *testing.T) { + t.Parallel() + + se := domain.SourceEndpoint{ + EndpointURL: "http://admin:secret@gpu-host:8000", + EndpointName: "gpu-vllm", + NativeName: "meta-llama/Llama-3.1-8B", + } + + data, err := json.Marshal(se) + if err != nil { + t.Fatalf("marshal failed: %v", err) + } + + if strings.Contains(string(data), "gpu-host") { + t.Errorf("SourceEndpoint JSON contains backend hostname: %s", data) + } + if strings.Contains(string(data), "admin") { + t.Errorf("SourceEndpoint JSON contains auth info: %s", data) + } + // Sanity: public fields are still serialised. + if !strings.Contains(string(data), "gpu-vllm") { + t.Errorf("SourceEndpoint JSON missing endpoint_name: %s", data) + } +} + +// TestEndpointAuthFieldsNotSerialised asserts that Endpoint.AuthHeaderValue +// is not included in JSON output. Leaking a resolved credential through any +// status endpoint would be a serious security issue. +func TestEndpointAuthFieldsNotSerialised(t *testing.T) { + t.Parallel() + + ep := domain.Endpoint{ + Name: "vllm-gpu", + Type: "vllm", + AuthHeaderName: "Authorization", + AuthHeaderValue: "Bearer sk-super-secret-token", + } + + data, err := json.Marshal(ep) + if err != nil { + t.Fatalf("marshal failed: %v", err) + } + + if strings.Contains(string(data), "sk-super-secret-token") { + t.Errorf("Endpoint JSON contains auth header value: %s", data) + } +} + +// TestEndpointHeadersNotSerialised asserts that Endpoint.Headers is not included +// in JSON output. The map may contain API keys or other custom auth values set +// via the headers: config block; exposing them through status endpoints would +// leak operator secrets. +func TestEndpointHeadersNotSerialised(t *testing.T) { + t.Parallel() + + ep := domain.Endpoint{ + Name: "guarded-backend", + Type: "ollama", + Headers: map[string]string{ + "X-Custom-Key": "do-not-leak", + }, + } + + data, err := json.Marshal(ep) + if err != nil { + t.Fatalf("marshal failed: %v", err) + } + + if strings.Contains(string(data), "do-not-leak") { + t.Errorf("Endpoint JSON contains Headers map value: %s", data) + } + if strings.Contains(string(data), "X-Custom-Key") { + t.Errorf("Endpoint JSON contains Headers map key: %s", data) + } + // Sanity: other fields still serialise. + if !strings.Contains(string(data), "guarded-backend") { + t.Errorf("Endpoint JSON missing name field: %s", data) + } +} diff --git a/internal/core/domain/model.go b/internal/core/domain/model.go index 4eb42ff7..488738f2 100644 --- a/internal/core/domain/model.go +++ b/internal/core/domain/model.go @@ -35,8 +35,11 @@ type ModelInfo struct { } type EndpointModels struct { - LastUpdated time.Time `json:"last_updated"` - EndpointURL string `json:"endpoint_url"` + LastUpdated time.Time `json:"last_updated"` + // EndpointURL is used for internal map-keying only; it must not appear in + // serialised API responses because it may carry auth credentials or internal + // network addresses. + EndpointURL string `json:"-"` Models []*ModelInfo `json:"models"` } diff --git a/internal/core/domain/profile_config.go b/internal/core/domain/profile_config.go index 33c5aa2a..46b6c28e 100644 --- a/internal/core/domain/profile_config.go +++ b/internal/core/domain/profile_config.go @@ -6,13 +6,10 @@ import "time" // new inference platforms without touching Go code. Much easier than // submitting PRs for every new LLM server that pops up. type ProfileConfig struct { - - // Metrics extraction configuration for provider responses - Metrics MetricsConfig `yaml:"metrics,omitempty"` - Name string `yaml:"name"` - Version string `yaml:"version"` - DisplayName string `yaml:"display_name"` - Description string `yaml:"description"` + Name string `yaml:"name"` + Version string `yaml:"version"` + DisplayName string `yaml:"display_name"` + Description string `yaml:"description"` Detection struct { Headers []string `yaml:"headers"` @@ -34,6 +31,9 @@ type ProfileConfig struct { } `yaml:"parsing_rules"` } `yaml:"request"` + // Metrics extraction configuration for provider responses + Metrics MetricsConfig `yaml:"metrics,omitempty"` + Models struct { CapabilityPatterns map[string][]string `yaml:"capability_patterns"` NameFormat string `yaml:"name_format"` @@ -57,6 +57,18 @@ type ProfileConfig struct { OpenAICompatible bool `yaml:"openai_compatible"` } `yaml:"api"` + Characteristics struct { + // Auth declares optional authentication hints for this backend profile. + // This is purely informational at the profile level and does not enforce + // validation. Future tooling (startup warnings, status endpoints) can use + // it to guide operators when configuring endpoint auth. + Auth AuthHint `yaml:"auth,omitempty" json:"auth,omitempty"` + Timeout time.Duration `yaml:"timeout"` + MaxConcurrentRequests int `yaml:"max_concurrent_requests"` + DefaultPriority int `yaml:"default_priority"` + StreamingSupport bool `yaml:"streaming_support"` + } `yaml:"characteristics"` + Resources struct { Quantization struct { Multipliers map[string]float64 `yaml:"multipliers"` @@ -75,13 +87,23 @@ type ProfileConfig struct { ChatCompletions int `yaml:"chat_completions"` Embeddings int `yaml:"embeddings"` } `yaml:"path_indices"` +} - Characteristics struct { - Timeout time.Duration `yaml:"timeout"` - MaxConcurrentRequests int `yaml:"max_concurrent_requests"` - DefaultPriority int `yaml:"default_priority"` - StreamingSupport bool `yaml:"streaming_support"` - } `yaml:"characteristics"` +// AuthHint describes the authentication capabilities of a backend profile. +// Fields are intentionally optional. Absent means "we don't know" or "not applicable". +// This is advisory only; the actual auth configuration lives in the endpoint config. +type AuthHint struct { + // DefaultHeader is the expected credential header for api_key auth. + // Empty means the backend uses the standard Authorization header. + DefaultHeader string `yaml:"default_header,omitempty" json:"default_header,omitempty"` + + // Types lists the authentication schemes this backend supports (e.g. ["bearer"]). + // An empty slice means the backend has no documented auth support. + Types []string `yaml:"types,omitempty" json:"types,omitempty"` + + // Required indicates whether the backend is known to require authentication. + // False (default) means auth is optional or not applicable for most deployments. + Required bool `yaml:"required,omitempty" json:"required,omitempty"` } // ModelSizePattern defines resource requirements for models matching specific patterns diff --git a/internal/logger/styled_plain.go b/internal/logger/styled_plain.go index ff34f27c..cf2a3cc7 100644 --- a/internal/logger/styled_plain.go +++ b/internal/logger/styled_plain.go @@ -100,6 +100,10 @@ func (sl *PlainStyledLogger) InfoHealthStatus(msg string, name string, status do statusText = "Unhealthy" case domain.StatusUnknown: statusText = "Unknown" + case domain.StatusConfigError: + statusText = "Config Error" + case domain.StatusRateLimited: + statusText = "Rate Limited" } styledMsg := fmt.Sprintf("%s %s is %s", msg, name, statusText) sl.logger.Info(styledMsg, args...) diff --git a/makefile b/makefile index 83dfdc8a..48b7f42c 100644 --- a/makefile +++ b/makefile @@ -18,7 +18,7 @@ LDFLAGS := -ldflags "\ -X '$(PKG).Tool=$(TOOL)' \ -X '$(PKG).User=$(USER)'" -.PHONY: run clean build test test-verbose test-short test-race test-cover bench version install-deps check-deps vet test-script-integration test-script-sticky mock-up mock-down mock-status mock-logs test-sticky-manual +.PHONY: run clean build test test-verbose test-short test-race test-cover bench version install-deps check-deps vet test-script-integration test-script-sticky mock-up mock-down mock-status mock-logs test-sticky-manual test-auth-bearer test-auth-env-fatal test-auth-manual # Build the application with version info build: @@ -391,6 +391,30 @@ test-sticky-manual: @echo "Running full sticky session manual test..." @bash test/scripts/sticky/run-manual.sh +# ── Auth test helpers ───────────────────────────────────────────────────────── +# These targets run the auth integration scripts that prove outbound credential +# injection works end-to-end against the Go mock backend (test/cmd/mockbackend). +# No Docker or AIMock required; the scripts spin up the mock backend in-process. + +## test-auth-bearer: Bearer token injection end-to-end (happy + failure paths) +test-auth-bearer: + @echo "Running bearer auth integration test..." + @bash test/scripts/auth/auth-bearer.sh + +## test-auth-env-fatal: Missing env var must abort Olla startup with a clear error +test-auth-env-fatal: + @echo "Running env-var-missing fatal startup test..." + @bash test/scripts/auth/auth-env-fatal.sh + +## test-auth-manual: Run all auth scripts that do not require Docker +test-auth-manual: + @echo "Running all auth integration tests..." + @bash test/scripts/auth/auth-bearer.sh + @bash test/scripts/auth/auth-api-key.sh + @bash test/scripts/auth/auth-basic.sh + @bash test/scripts/auth/auth-headers-only.sh + @bash test/scripts/auth/auth-env-fatal.sh + # Show help help: @echo "Available targets:" @@ -440,4 +464,7 @@ help: @echo " mock-status - Show AIMock container state" @echo " mock-logs - Tail logs from all AIMock instances" @echo " test-sticky-manual - Full sticky session end-to-end test (mock + olla + assert)" + @echo " test-auth-bearer - Bearer token injection end-to-end (no Docker needed)" + @echo " test-auth-env-fatal - Missing env var aborts startup with a clear error" + @echo " test-auth-manual - All auth integration scripts (no Docker needed)" @echo " help - Show this help" \ No newline at end of file diff --git a/pkg/envresolver/platform_test.go b/pkg/envresolver/platform_test.go new file mode 100644 index 00000000..b483b157 --- /dev/null +++ b/pkg/envresolver/platform_test.go @@ -0,0 +1,7 @@ +package envresolver + +import "runtime" + +func isWindows() bool { + return runtime.GOOS == "windows" +} diff --git a/pkg/envresolver/resolver.go b/pkg/envresolver/resolver.go new file mode 100644 index 00000000..c7aaca16 --- /dev/null +++ b/pkg/envresolver/resolver.go @@ -0,0 +1,119 @@ +// Package envresolver expands ${VAR} and ${VAR:-default} placeholders in +// configuration strings using environment variable lookups. It intentionally +// does not support the bare $VAR form: config files often contain literal +// dollar signs (shell scripts, cost strings, regex), and requiring braces +// eliminates ambiguity without meaningful ergonomic cost. +package envresolver + +import ( + "errors" + "fmt" + "os" + "regexp" + "strings" +) + +// tokenPattern matches ${VAR} and ${VAR:-default}. No nesting, no bare $VAR. +var tokenPattern = regexp.MustCompile(`\$\{([^}]+)\}`) + +// Expand replaces every ${VAR} and ${VAR:-default} placeholder in s with its +// resolved value. An unset variable with no default resolves to the empty +// string. Expand never returns an error; use ExpandStrict when a missing +// variable must be fatal. +func Expand(s string) string { + if s == "" || !strings.Contains(s, "${") { + return s + } + + return tokenPattern.ReplaceAllStringFunc(s, func(token string) string { + expr := token[2 : len(token)-1] // strip ${ and } + name, fallback, hasFallback := strings.Cut(expr, ":-") + + v, set := os.LookupEnv(name) + // POSIX :- semantics: use default when the variable is unset OR empty. + // An explicitly set but empty variable still triggers the default, matching + // shell behaviour and making empty-string auth values detectable downstream. + if set && v != "" { + return v + } + if hasFallback { + return fallback + } + return "" + }) +} + +// ExpandStrict is like Expand but returns an error when a placeholder has no +// environment value and no default. The error message names the variable but +// never echoes the surrounding string or any partial value, so secrets in +// adjacent placeholders do not leak into logs. +func ExpandStrict(s string) (string, error) { + if s == "" || !strings.Contains(s, "${") { + return s, nil + } + + var missing []string + + expanded := tokenPattern.ReplaceAllStringFunc(s, func(token string) string { + expr := token[2 : len(token)-1] + name, fallback, hasFallback := strings.Cut(expr, ":-") + + v, set := os.LookupEnv(name) + // POSIX :- semantics: empty triggers default just like unset. + if set && v != "" { + return v + } + if hasFallback { + return fallback + } + // Only report as missing when the variable is genuinely unset; + // an explicit empty value is a valid (if unusual) operator choice + // and is handled by the downstream empty-token validation. + if !set { + missing = append(missing, name) + } + return "" + }) + + if len(missing) > 0 { + errs := make([]error, len(missing)) + for i, name := range missing { + errs[i] = fmt.Errorf("required environment variable %q is not set", name) + } + return "", errors.Join(errs...) + } + + return expanded, nil +} + +// ExpandWithFile resolves a config value that may come from either a literal +// string or a file path (the _file sibling-field convention). Callers pass the +// literal value and the file path; exactly one must be non-empty. +// +// When fileValue is set, the file is read and its contents are returned with +// leading/trailing whitespace trimmed. This mirrors the Docker Secrets / k8s +// mounted-secret pattern where a file holds a single secret value. +// +// Both values being non-empty is a configuration error the operator must fix +// before the process starts. This function fails fast so the mistake surfaces +// immediately rather than silently preferring one source. +func ExpandWithFile(value, fileValue string) (string, error) { + hasValue := value != "" + hasFile := fileValue != "" + + if hasValue && hasFile { + return "", errors.New("both value and value_file are set; use exactly one") + } + + if hasFile { + raw, err := os.ReadFile(fileValue) + if err != nil { + // Report the path but not any partial content. + return "", fmt.Errorf("reading secret file %q: %w", fileValue, err) + } + return strings.TrimSpace(string(raw)), nil + } + + // Plain value path: still expand any ${VAR} placeholders inside it. + return Expand(value), nil +} diff --git a/pkg/envresolver/resolver_test.go b/pkg/envresolver/resolver_test.go new file mode 100644 index 00000000..23cb1238 --- /dev/null +++ b/pkg/envresolver/resolver_test.go @@ -0,0 +1,342 @@ +package envresolver + +import ( + "os" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +// writeTemp creates a temp file with the given content and returns its path. +// The file is automatically removed when t completes. +func writeTemp(t *testing.T, content string) string { + t.Helper() + f, err := os.CreateTemp(t.TempDir(), "envresolver-*.txt") + require.NoError(t, err) + _, err = f.WriteString(content) + require.NoError(t, err) + require.NoError(t, f.Close()) + return f.Name() +} + +// --- Expand --- + +// TestExpand covers the core placeholder expansion logic. +// Subtests cannot be t.Parallel() here because t.Setenv is used for env +// isolation. t.Parallel() inside a subtest that calls t.Setenv panics in Go's +// test runner. The parent still runs concurrently with other top-level tests. +func TestExpand(t *testing.T) { + tests := []struct { + name string + setup func(t *testing.T) + input string + want string + }{ + { + name: "empty_string_unchanged", + input: "", + want: "", + }, + { + name: "no_placeholders_unchanged", + input: "plain text", + want: "plain text", + }, + { + name: "plain_var_expansion", + setup: func(t *testing.T) { + t.Setenv("OLLA_TEST_TOKEN", "secret123") + }, + input: "${OLLA_TEST_TOKEN}", + want: "secret123", + }, + { + name: "var_embedded_in_text", + setup: func(t *testing.T) { + t.Setenv("OLLA_TEST_TOKEN", "abc") + }, + input: "Bearer ${OLLA_TEST_TOKEN}", + want: "Bearer abc", + }, + { + name: "default_used_when_var_unset", + input: "${OLLA_MISSING_VAR:-my-default}", + want: "my-default", + }, + { + name: "default_ignored_when_var_set", + setup: func(t *testing.T) { + t.Setenv("OLLA_SET_VAR", "real-value") + }, + input: "${OLLA_SET_VAR:-ignored-default}", + want: "real-value", + }, + { + name: "unset_var_with_no_default_resolves_to_empty", + input: "${OLLA_DEFINITELY_NOT_SET_XYZ}", + want: "", + }, + { + name: "multiple_placeholders_in_one_string", + setup: func(t *testing.T) { + t.Setenv("OLLA_HOST", "localhost") + t.Setenv("OLLA_PORT", "8080") + }, + input: "${OLLA_HOST}:${OLLA_PORT}", + want: "localhost:8080", + }, + { + name: "adjacent_placeholders_no_separator", + setup: func(t *testing.T) { + t.Setenv("OLLA_A", "foo") + t.Setenv("OLLA_B", "bar") + }, + // Regression guard for token boundary in ReplaceAllStringFunc. + input: "${OLLA_A}${OLLA_B}", + want: "foobar", + }, + { + // Nested ${${X}} is not supported. The regex [^}]+ stops at the + // first closing brace, so it matches ${${OLLA_OUTER} as a token + // (name = "${OLLA_OUTER", no env var matches → empty) and leaves a + // trailing "}" literal in the output. Asserted explicitly so this + // is documented behaviour, not a silent surprise. + name: "nested_placeholder_not_supported", + input: "${${OLLA_OUTER}}", + want: "}", + }, + { + // $VAR without braces is not expanded. Bare $ is ambiguous in YAML + // config files (shell scripts, cost strings, regex) so we require + // the explicit ${VAR} form to avoid false-positive expansions. + name: "bare_dollar_var_not_expanded", + input: "$OLLA_TEST_TOKEN", + want: "$OLLA_TEST_TOKEN", + }, + { + name: "string_without_dollar_unchanged", + input: "no-dollar-sign-here", + want: "no-dollar-sign-here", + }, + { + // ${VAR:-} with an empty default is a valid way to force empty + // without triggering ExpandStrict's missing-var error. + name: "default_with_empty_value_part", + input: "${OLLA_MISSING_EMPTY:-}", + want: "", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if tt.setup != nil { + tt.setup(t) + } + assert.Equal(t, tt.want, Expand(tt.input)) + }) + } +} + +// TestExpand_LookupEnvSemantics covers the distinction between unset and +// explicitly-empty variables. Subtests do not call t.Parallel() because +// t.Setenv panics when combined with t.Parallel() in Go's test runner. +func TestExpand_LookupEnvSemantics(t *testing.T) { + t.Run("unset_var_no_default_resolves_to_empty", func(t *testing.T) { + // Lenient Expand never errors; unset resolves to "". + got := Expand("${OLLA_LOOKUP_UNSET_XYZ_ABC}") + assert.Equal(t, "", got) + }) + + t.Run("explicit_empty_no_default_resolves_to_empty", func(t *testing.T) { + t.Setenv("OLLA_LOOKUP_EXPLICIT_EMPTY", "") + got := Expand("${OLLA_LOOKUP_EXPLICIT_EMPTY}") + assert.Equal(t, "", got) + }) + + t.Run("default_used_when_var_unset", func(t *testing.T) { + got := Expand("${OLLA_LOOKUP_UNSET_FOR_DEFAULT_XYZ:-mydefault}") + assert.Equal(t, "mydefault", got) + }) + + t.Run("default_used_when_var_explicit_empty", func(t *testing.T) { + // POSIX :- treats empty the same as unset; the default wins. + t.Setenv("OLLA_LOOKUP_EMPTY_DEFAULT", "") + got := Expand("${OLLA_LOOKUP_EMPTY_DEFAULT:-fallback}") + assert.Equal(t, "fallback", got) + }) + + t.Run("default_not_used_when_var_non_empty", func(t *testing.T) { + t.Setenv("OLLA_LOOKUP_NON_EMPTY", "real-value") + got := Expand("${OLLA_LOOKUP_NON_EMPTY:-ignored}") + assert.Equal(t, "real-value", got) + }) +} + +// TestExpandStrict_LookupEnvSemantics covers the ExpandStrict distinction +// between unset (fatal) and explicitly-empty (allowed; downstream concern). +// Subtests do not call t.Parallel() for the same reason as TestExpand. +func TestExpandStrict_LookupEnvSemantics(t *testing.T) { + t.Run("unset_var_returns_error", func(t *testing.T) { + _, err := ExpandStrict("${OLLA_STRICT_LOOKUP_UNSET_XYZ_ABC}") + require.Error(t, err) + assert.Contains(t, err.Error(), "OLLA_STRICT_LOOKUP_UNSET_XYZ_ABC") + }) + + t.Run("explicit_empty_no_error", func(t *testing.T) { + // An explicitly set-but-empty variable is not a missing variable; + // the downstream caller validates whether empty is acceptable. + t.Setenv("OLLA_STRICT_LOOKUP_EMPTY", "") + got, err := ExpandStrict("${OLLA_STRICT_LOOKUP_EMPTY}") + require.NoError(t, err) + assert.Equal(t, "", got) + }) +} + +// --- ExpandStrict --- + +func TestExpandStrict(t *testing.T) { + tests := []struct { + name string + setup func(t *testing.T) + input string + want string + wantErr bool + errContains []string + errAbsent []string + }{ + { + name: "set_var_expands", + setup: func(t *testing.T) { + t.Setenv("OLLA_STRICT_KEY", "strict-val") + }, + input: "${OLLA_STRICT_KEY}", + want: "strict-val", + }, + { + name: "default_used_when_var_unset", + input: "${OLLA_STRICT_MISSING:-fallback}", + want: "fallback", + }, + { + name: "missing_var_returns_error", + input: "${OLLA_STRICT_NOT_SET_ZZZ}", + wantErr: true, + // Must name the variable so the operator knows what to set. + errContains: []string{"OLLA_STRICT_NOT_SET_ZZZ"}, + // Must NOT echo the ${...} token; that leaks the unresolved + // placeholder literally into logs. + errAbsent: []string{"${"}, + }, + { + name: "multiple_missing_vars_all_reported", + input: "${OLLA_MISSING_ONE} ${OLLA_MISSING_TWO}", + wantErr: true, + errContains: []string{"OLLA_MISSING_ONE", "OLLA_MISSING_TWO"}, + }, + { + name: "empty_string_no_error", + input: "", + want: "", + }, + { + name: "no_placeholders_no_error", + input: "static-value", + want: "static-value", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if tt.setup != nil { + tt.setup(t) + } + got, err := ExpandStrict(tt.input) + if tt.wantErr { + require.Error(t, err) + for _, s := range tt.errContains { + assert.Contains(t, err.Error(), s) + } + for _, s := range tt.errAbsent { + assert.NotContains(t, err.Error(), s) + } + return + } + require.NoError(t, err) + assert.Equal(t, tt.want, got) + }) + } +} + +// --- ExpandWithFile --- + +func TestExpandWithFile(t *testing.T) { + t.Run("both_set_is_error", func(t *testing.T) { + _, err := ExpandWithFile("direct-value", "/some/file") + require.Error(t, err) + // Neither the value nor the path must appear. The value may be a + // secret and the path leaks config structure. + assert.NotContains(t, err.Error(), "direct-value") + assert.NotContains(t, err.Error(), "/some/file") + }) + + t.Run("neither_set_returns_empty", func(t *testing.T) { + got, err := ExpandWithFile("", "") + require.NoError(t, err) + assert.Equal(t, "", got) + }) + + t.Run("plain_value_returned", func(t *testing.T) { + got, err := ExpandWithFile("api-key-value", "") + require.NoError(t, err) + assert.Equal(t, "api-key-value", got) + }) + + t.Run("plain_value_with_placeholder_expanded", func(t *testing.T) { + t.Setenv("OLLA_FILE_TEST_KEY", "expanded-key") + got, err := ExpandWithFile("${OLLA_FILE_TEST_KEY}", "") + require.NoError(t, err) + assert.Equal(t, "expanded-key", got) + }) + + t.Run("file_value_read_and_returned", func(t *testing.T) { + f := writeTemp(t, "my-secret-token\n") + got, err := ExpandWithFile("", f) + require.NoError(t, err) + assert.Equal(t, "my-secret-token", got) + }) + + t.Run("file_trailing_newline_trimmed", func(t *testing.T) { + f := writeTemp(t, " token-with-spaces \n") + got, err := ExpandWithFile("", f) + require.NoError(t, err) + assert.Equal(t, "token-with-spaces", got) + }) + + t.Run("file_no_trailing_newline_still_works", func(t *testing.T) { + f := writeTemp(t, "bare-token") + got, err := ExpandWithFile("", f) + require.NoError(t, err) + assert.Equal(t, "bare-token", got) + }) + + t.Run("missing_file_returns_error", func(t *testing.T) { + _, err := ExpandWithFile("", "/tmp/olla-envresolver-does-not-exist-xyz") + require.Error(t, err) + assert.Contains(t, err.Error(), "olla-envresolver-does-not-exist-xyz") + }) + + t.Run("file_permission_denied", func(t *testing.T) { + // chmod-based permission denial is not reliably enforceable on Windows. + // The process owner can still read files they own regardless of mode bits. + if isWindows() { + t.Skip("permission simulation not supported on Windows") + } + f := writeTemp(t, "secret") + require.NoError(t, os.Chmod(f, 0o000)) + t.Cleanup(func() { os.Chmod(f, 0o600) }) + + _, err := ExpandWithFile("", f) + require.Error(t, err) + }) +} diff --git a/test/cmd/mockbackend/main.go b/test/cmd/mockbackend/main.go new file mode 100644 index 00000000..e2137809 --- /dev/null +++ b/test/cmd/mockbackend/main.go @@ -0,0 +1,122 @@ +// mockbackend is a minimal HTTP server for auth integration tests. +// It enforces a single required header and value, returning 401 when the +// credential is absent or wrong. All other paths return a minimal +// OpenAI-compatible JSON response so Olla's health checks and proxy pass +// through without requiring real model infrastructure. +// +// Usage: +// +// go run ./test/cmd/mockbackend \ +// --addr 127.0.0.1:19910 \ +// --require-header Authorization \ +// --require-value "Bearer test-token-abc123" +// +// When --require-header is omitted the server accepts all requests (useful +// for the happy-path AIMock-equivalent flow). +package main + +import ( + "encoding/json" + "flag" + "fmt" + "log/slog" + "net/http" + "os" + "strings" + "time" +) + +func main() { + addr := flag.String("addr", "127.0.0.1:19910", "listen address") + requireHeader := flag.String("require-header", "", "header name that must be present") + requireValue := flag.String("require-value", "", "exact header value required (ignored when require-header is empty)") + flag.Parse() + + mux := http.NewServeMux() + + // Unauthenticated liveness probe so test scripts can poll until the + // server is accepting connections without needing a valid credential. + mux.HandleFunc("/health", func(w http.ResponseWriter, _ *http.Request) { + w.Header().Set("Content-Type", "application/json") + _ = json.NewEncoder(w).Encode(map[string]any{"status": "ok"}) + }) + + // Health + model listing share the same auth check so Olla's + // health probes exercise the credential path. + mux.HandleFunc("/v1/models", func(w http.ResponseWriter, r *http.Request) { + if !authorised(w, r, *requireHeader, *requireValue) { + return + } + w.Header().Set("Content-Type", "application/json") + _ = json.NewEncoder(w).Encode(map[string]any{ + "object": "list", + "data": []map[string]any{ + {"id": "mock-model", "object": "model", "created": time.Now().Unix()}, + }, + }) + }) + + // Chat completions: the route Olla proxies inference requests through. + mux.HandleFunc("/v1/chat/completions", func(w http.ResponseWriter, r *http.Request) { + if !authorised(w, r, *requireHeader, *requireValue) { + return + } + w.Header().Set("Content-Type", "application/json") + _ = json.NewEncoder(w).Encode(map[string]any{ + "id": "mock-cmpl-001", + "object": "chat.completion", + "created": time.Now().Unix(), + "model": "mock-model", + "choices": []map[string]any{ + { + "index": 0, + "message": map[string]any{ + "role": "assistant", + "content": "mock-backend: auth accepted", + }, + "finish_reason": "stop", + }, + }, + "usage": map[string]any{"prompt_tokens": 5, "completion_tokens": 5, "total_tokens": 10}, + }) + }) + + slog.Info("mockbackend listening", + "addr", *addr, + "require_header", *requireHeader, + ) + + srv := &http.Server{ + Addr: *addr, + Handler: mux, + ReadHeaderTimeout: 5 * time.Second, + } + + if err := srv.ListenAndServe(); err != nil { + fmt.Fprintf(os.Stderr, "mockbackend: %v\n", err) + os.Exit(1) + } +} + +// authorised checks the required credential and writes a 401 response when it +// is absent or wrong. Returns true when the request may proceed. +func authorised(w http.ResponseWriter, r *http.Request, header, value string) bool { + if header == "" { + return true + } + got := strings.TrimSpace(r.Header.Get(header)) + if got == value { + return true + } + w.Header().Set("Content-Type", "application/json") + w.Header().Set("WWW-Authenticate", `Bearer realm="mockbackend"`) + w.WriteHeader(http.StatusUnauthorized) + _ = json.NewEncoder(w).Encode(map[string]any{ + "error": map[string]any{ + "message": fmt.Sprintf("missing or invalid %s header", header), + "type": "authentication_error", + "code": "invalid_api_key", + }, + }) + return false +} diff --git a/test/manual/config.auth.yaml b/test/manual/config.auth.yaml new file mode 100644 index 00000000..f35fd727 --- /dev/null +++ b/test/manual/config.auth.yaml @@ -0,0 +1,88 @@ +## Olla Auth Manual Test Configuration +## Demonstrates outbound auth header injection — bearer, api_key, basic, and +## custom headers — against a single mock backend. +## +## The mock backend (test/cmd/mockbackend) enforces auth and returns 401 when +## a required credential is absent or wrong. AIMock does NOT enforce auth, so +## use the Go mock backend for enforcement tests; AIMock for happy-path flow. +## +## Usage (bearer, token from environment): +## export AUTH_TOKEN="test-token-abc123" +## go run . --config test/manual/config.auth.yaml +## +## The bash scripts in test/scripts/auth/ automate startup and teardown. +## Run them with: +## bash test/scripts/auth/auth-bearer.sh +## bash test/scripts/auth/auth-env-fatal.sh # proves missing var is fatal + +server: + host: "127.0.0.1" + port: 40115 + read_timeout: 10s + write_timeout: 0s + shutdown_timeout: 5s + request_logging: true + request_limits: + max_body_size: 10485760 + max_header_size: 524288 + rate_limits: + global_requests_per_minute: 1000 + per_ip_requests_per_minute: 200 + health_requests_per_minute: 1000 + burst_size: 50 + cleanup_interval: 5m + trust_proxy_headers: false + trusted_proxy_cidrs: ["127.0.0.0/8"] + +proxy: + engine: "olla" + profile: "auto" + load_balancer: "priority" + stream_buffer_size: 8192 + connection_timeout: 10s + response_timeout: 30s + read_timeout: 30s + retry: + enabled: false + +discovery: + type: "static" + refresh_interval: 10s + health_check: + initial_delay: 1s + static: + endpoints: + # Primary bearer-auth endpoint — token injected from AUTH_TOKEN env var. + # The mock backend (port 19910) enforces the value. + - url: "http://127.0.0.1:19910" + name: "mock-bearer" + type: "openai-compatible" + priority: 100 + model_url: "/v1/models" + health_check_url: "/v1/models" + check_interval: 5s + check_timeout: 2s + auth: + type: bearer + token: "${AUTH_TOKEN}" + + model_discovery: + enabled: false + +model_registry: + type: "memory" + enable_unifier: false + unification: + enabled: false + routing_strategy: + type: "optimistic" + options: + fallback_behavior: "all" + +logging: + level: "debug" + format: "text" + output: "stdout" + +engineering: + show_nerdstats: false diff --git a/test/scripts/auth/auth-api-key.sh b/test/scripts/auth/auth-api-key.sh new file mode 100644 index 00000000..8535f78d --- /dev/null +++ b/test/scripts/auth/auth-api-key.sh @@ -0,0 +1,145 @@ +#!/usr/bin/env bash +# auth-api-key.sh — proves Olla injects api_key credentials on outbound requests. +# +# Happy path: correct key in the configured header → 200. +# Failure path: wrong key → mock backend returns 401 → Olla surfaces non-200. +# +# Requires: go, curl, bash 4+ +# Does NOT require Docker / AIMock. + +set -euo pipefail + +SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" +REPO_ROOT="$(cd "$SCRIPT_DIR/../../.." && pwd)" + +# shellcheck source=lib.sh +source "$SCRIPT_DIR/lib.sh" + +OLLA_PORT="${OLLA_PORT:-40116}" +BACKEND_PORT="${BACKEND_PORT:-19911}" +OLLA_URL="http://127.0.0.1:${OLLA_PORT}" +BACKEND_URL="http://127.0.0.1:${BACKEND_PORT}" +OLLA_LOG="${TMPDIR:-/tmp}/olla-auth-apikey.log" + +GOOD_KEY="sk-apikey-test-abc123" +KEY_HEADER="X-Api-Key" + +OLLA_PID="" +BACKEND_PID="" + +cleanup() { + kill_proc "$OLLA_PID" + kill_proc "$BACKEND_PID" +} +trap cleanup EXIT INT TERM + +echo "=== auth-api-key: outbound api_key header injection ===" +echo "Backend: ${BACKEND_URL} Olla: ${OLLA_URL}" +echo + +free_port "$BACKEND_PORT" +go run "$REPO_ROOT/test/cmd/mockbackend" \ + --addr "127.0.0.1:${BACKEND_PORT}" \ + --require-header "${KEY_HEADER}" \ + --require-value "${GOOD_KEY}" \ + >"${TMPDIR:-/tmp}/mockbackend-apikey.log" 2>&1 & +BACKEND_PID=$! +wait_for_mockbackend "$BACKEND_URL" 15 + +CONFIG=$(mktemp "${TMPDIR:-/tmp}/olla-auth-apikey-XXXXXX.yaml") +cat >"$CONFIG" <"$OLLA_LOG" 2>&1 & +OLLA_PID=$! +wait_for_url "${OLLA_URL}/internal/health" 20 + +# Test 1: correct key → 200 +status=$(http_status_for "${OLLA_URL}/olla/openai-compatible/v1/chat/completions") +if [[ "$status" == "200" ]]; then + pass "correct api_key → 200" +else + fail "correct api_key → 200" "got HTTP ${status}" +fi + +# ── restart with wrong key ──────────────────────────────────────────────────── +kill_proc "$OLLA_PID" +OLLA_PID="" +free_port "$OLLA_PORT" + +CONFIG_BAD=$(mktemp "${TMPDIR:-/tmp}/olla-auth-apikey-bad-XXXXXX.yaml") +sed "s/${GOOD_KEY}/wrong-key-000/" "$CONFIG" >"$CONFIG_BAD" + +go run "$REPO_ROOT" --config "$CONFIG_BAD" >>"$OLLA_LOG" 2>&1 & +OLLA_PID=$! +wait_for_url "${OLLA_URL}/internal/health" 20 + +# Test 2: wrong key → non-200 +status=$(http_status_for "${OLLA_URL}/olla/openai-compatible/v1/chat/completions") +if [[ "$status" != "200" ]]; then + pass "wrong api_key propagates non-200 (got ${status})" +else + fail "wrong api_key propagates non-200" "expected non-200 but got 200" +fi + +rm -f "$CONFIG" "$CONFIG_BAD" +summarise "auth-api-key" diff --git a/test/scripts/auth/auth-basic.sh b/test/scripts/auth/auth-basic.sh new file mode 100644 index 00000000..850ba6d0 --- /dev/null +++ b/test/scripts/auth/auth-basic.sh @@ -0,0 +1,150 @@ +#!/usr/bin/env bash +# auth-basic.sh — proves Olla injects HTTP Basic auth credentials on outbound +# requests. +# +# The mock backend enforces the pre-encoded "Authorization: Basic " value. +# The script computes the expected value once so we can assert it without +# duplicating encoding logic. +# +# Requires: go, curl, bash 4+, base64 (coreutils or macOS) +# Does NOT require Docker / AIMock. + +set -euo pipefail + +SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" +REPO_ROOT="$(cd "$SCRIPT_DIR/../../.." && pwd)" + +# shellcheck source=lib.sh +source "$SCRIPT_DIR/lib.sh" + +OLLA_PORT="${OLLA_PORT:-40117}" +BACKEND_PORT="${BACKEND_PORT:-19912}" +OLLA_URL="http://127.0.0.1:${OLLA_PORT}" +BACKEND_URL="http://127.0.0.1:${BACKEND_PORT}" +OLLA_LOG="${TMPDIR:-/tmp}/olla-auth-basic.log" + +USERNAME="testuser" +PASSWORD="testpass99" +# Pre-compute the exact Authorization header value Olla will send. +ENCODED=$(printf '%s:%s' "$USERNAME" "$PASSWORD" | base64 | tr -d '\n') +EXPECTED_AUTH_VALUE="Basic ${ENCODED}" + +OLLA_PID="" +BACKEND_PID="" + +cleanup() { + kill_proc "$OLLA_PID" + kill_proc "$BACKEND_PID" +} +trap cleanup EXIT INT TERM + +echo "=== auth-basic: outbound HTTP Basic auth injection ===" +echo "Backend: ${BACKEND_URL} Olla: ${OLLA_URL}" +echo + +free_port "$BACKEND_PORT" +go run "$REPO_ROOT/test/cmd/mockbackend" \ + --addr "127.0.0.1:${BACKEND_PORT}" \ + --require-header "Authorization" \ + --require-value "${EXPECTED_AUTH_VALUE}" \ + >"${TMPDIR:-/tmp}/mockbackend-basic.log" 2>&1 & +BACKEND_PID=$! +wait_for_mockbackend "$BACKEND_URL" 15 + +CONFIG=$(mktemp "${TMPDIR:-/tmp}/olla-auth-basic-XXXXXX.yaml") +cat >"$CONFIG" <"$OLLA_LOG" 2>&1 & +OLLA_PID=$! +wait_for_url "${OLLA_URL}/internal/health" 20 + +# Test 1: correct basic credentials → 200 +status=$(http_status_for "${OLLA_URL}/olla/openai-compatible/v1/chat/completions") +if [[ "$status" == "200" ]]; then + pass "correct basic credentials → 200" +else + fail "correct basic credentials → 200" "got HTTP ${status}" +fi + +# ── restart with wrong password ─────────────────────────────────────────────── +kill_proc "$OLLA_PID" +OLLA_PID="" +free_port "$OLLA_PORT" + +CONFIG_BAD=$(mktemp "${TMPDIR:-/tmp}/olla-auth-basic-bad-XXXXXX.yaml") +sed "s/${PASSWORD}/wrongpassword/" "$CONFIG" >"$CONFIG_BAD" + +go run "$REPO_ROOT" --config "$CONFIG_BAD" >>"$OLLA_LOG" 2>&1 & +OLLA_PID=$! +wait_for_url "${OLLA_URL}/internal/health" 20 + +# Test 2: wrong password → non-200 +status=$(http_status_for "${OLLA_URL}/olla/openai-compatible/v1/chat/completions") +if [[ "$status" != "200" ]]; then + pass "wrong basic credentials propagate non-200 (got ${status})" +else + fail "wrong basic credentials propagate non-200" "expected non-200 but got 200" +fi + +rm -f "$CONFIG" "$CONFIG_BAD" +summarise "auth-basic" diff --git a/test/scripts/auth/auth-bearer.sh b/test/scripts/auth/auth-bearer.sh new file mode 100644 index 00000000..1079b06f --- /dev/null +++ b/test/scripts/auth/auth-bearer.sh @@ -0,0 +1,148 @@ +#!/usr/bin/env bash +# auth-bearer.sh — proves Olla injects bearer tokens on outbound requests. +# +# Happy path: correct token reaches the mock backend → 200. +# Failure path: wrong token in config → mock backend rejects it → Olla returns +# a non-200 (502 or 401 propagated). +# +# Requires: go, curl, bash 4+ +# Does NOT require Docker / AIMock. + +set -euo pipefail + +SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" +REPO_ROOT="$(cd "$SCRIPT_DIR/../../.." && pwd)" + +# shellcheck source=lib.sh +source "$SCRIPT_DIR/lib.sh" + +OLLA_PORT="${OLLA_PORT:-40115}" +BACKEND_PORT="${BACKEND_PORT:-19910}" +OLLA_URL="http://127.0.0.1:${OLLA_PORT}" +BACKEND_URL="http://127.0.0.1:${BACKEND_PORT}" +OLLA_LOG="${TMPDIR:-/tmp}/olla-auth-bearer.log" + +GOOD_TOKEN="test-token-bearer-abc123" + +OLLA_PID="" +BACKEND_PID="" + +cleanup() { + kill_proc "$OLLA_PID" + kill_proc "$BACKEND_PID" +} +trap cleanup EXIT INT TERM + +echo "=== auth-bearer: outbound bearer token injection ===" +echo "Backend: ${BACKEND_URL} Olla: ${OLLA_URL}" +echo + +# ── start mock backend ──────────────────────────────────────────────────────── +free_port "$BACKEND_PORT" +go run "$REPO_ROOT/test/cmd/mockbackend" \ + --addr "127.0.0.1:${BACKEND_PORT}" \ + --require-header "Authorization" \ + --require-value "Bearer ${GOOD_TOKEN}" \ + >"${TMPDIR:-/tmp}/mockbackend-bearer.log" 2>&1 & +BACKEND_PID=$! +wait_for_mockbackend "$BACKEND_URL" 15 + +# ── write a per-run config with the correct token ──────────────────────────── +CONFIG=$(mktemp "${TMPDIR:-/tmp}/olla-auth-bearer-XXXXXX.yaml") +cat >"$CONFIG" <"$OLLA_LOG" 2>&1 & +OLLA_PID=$! +wait_for_url "${OLLA_URL}/internal/health" 20 + +# Test 1: correct token → 200 +status=$(http_status_for "${OLLA_URL}/olla/openai-compatible/v1/chat/completions") +if [[ "$status" == "200" ]]; then + pass "correct bearer token → 200" +else + fail "correct bearer token → 200" "got HTTP ${status}" +fi + +# ── restart Olla with wrong token ──────────────────────────────────────────── +kill_proc "$OLLA_PID" +OLLA_PID="" +free_port "$OLLA_PORT" + +CONFIG_BAD=$(mktemp "${TMPDIR:-/tmp}/olla-auth-bearer-bad-XXXXXX.yaml") +# Replace only the token value so the rest of the config stays valid +sed "s/${GOOD_TOKEN}/wrong-token-xyz/" "$CONFIG" >"$CONFIG_BAD" + +go run "$REPO_ROOT" --config "$CONFIG_BAD" >>"$OLLA_LOG" 2>&1 & +OLLA_PID=$! +wait_for_url "${OLLA_URL}/internal/health" 20 + +# Test 2: wrong token → backend rejects → Olla returns non-200 +status=$(http_status_for "${OLLA_URL}/olla/openai-compatible/v1/chat/completions") +if [[ "$status" != "200" ]]; then + pass "wrong bearer token propagates non-200 (got ${status})" +else + fail "wrong bearer token propagates non-200" "expected non-200 but got 200" +fi + +rm -f "$CONFIG" "$CONFIG_BAD" +summarise "auth-bearer" diff --git a/test/scripts/auth/auth-env-fatal.sh b/test/scripts/auth/auth-env-fatal.sh new file mode 100644 index 00000000..59ce4063 --- /dev/null +++ b/test/scripts/auth/auth-env-fatal.sh @@ -0,0 +1,141 @@ +#!/usr/bin/env bash +# auth-env-fatal.sh — proves Olla refuses to start when a config references an +# environment variable that is not set. +# +# This is the most important safety test: a missing credential must produce a +# loud startup failure, not a silent zero-value that sends unauthenticated +# requests to production backends. +# +# The script asserts: +# 1. Olla exits non-zero within a few seconds (fatal startup error). +# 2. The error output mentions the endpoint name so the operator knows which +# endpoint's config is broken. +# +# Requires: go, bash 4+ +# Does NOT require Docker / AIMock or a running backend. + +set -euo pipefail + +SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" +REPO_ROOT="$(cd "$SCRIPT_DIR/../../.." && pwd)" + +# shellcheck source=lib.sh +source "$SCRIPT_DIR/lib.sh" + +OLLA_PORT="${OLLA_PORT:-40119}" +# Port chosen to avoid collision with other auth test scripts; a backend is +# not actually required because Olla must fail before it ever connects. +BACKEND_PORT="${BACKEND_PORT:-19914}" +OLLA_LOG="${TMPDIR:-/tmp}/olla-auth-env-fatal.log" + +# Deliberately unset — must NOT be present when Olla starts. +# Use seconds + PID for uniqueness; %N (nanoseconds) is GNU-only and fails on BSD/macOS. +MISSING_VAR="OLLA_TEST_MISSING_TOKEN_$(date +%s)_$$" +unset "$MISSING_VAR" 2>/dev/null || true + +ENDPOINT_NAME="mock-env-fatal" + +echo "=== auth-env-fatal: missing env var must abort startup ===" +echo "Missing var: \${${MISSING_VAR}}" +echo "Endpoint name in config: ${ENDPOINT_NAME}" +echo + +CONFIG=$(mktemp "${TMPDIR:-/tmp}/olla-auth-env-fatal-XXXXXX.yaml") +cat >"$CONFIG" <"$OLLA_LOG" 2>&1 +EXIT_CODE=$? +set -e + +# Test 1: non-zero exit (fatal startup failure), but not a timeout. +# Exit 124 means the process was still running after 15s, which is its own failure mode. +if [[ $EXIT_CODE -eq 124 ]]; then + fail "Olla exited non-zero on missing env var" "process timed out after 15s, expected immediate startup failure" +elif [[ $EXIT_CODE -ne 0 ]]; then + pass "Olla exited non-zero on missing env var (exit ${EXIT_CODE})" +else + fail "Olla exited non-zero on missing env var" "got exit 0, startup should have aborted" +fi + +# Test 2: error output mentions the endpoint name +# This lets operators know which endpoint has the broken config, not just +# that some env var somewhere is missing. +if grep -qF "${ENDPOINT_NAME}" "$OLLA_LOG"; then + pass "error mentions endpoint name (${ENDPOINT_NAME})" +else + fail "error mentions endpoint name (${ENDPOINT_NAME})" "endpoint name not found in output" + echo "--- Olla output ---" >&2 + cat "$OLLA_LOG" >&2 + echo "-------------------" >&2 +fi + +# Test 3: error output mentions the missing variable name +if grep -qF "$MISSING_VAR" "$OLLA_LOG"; then + pass "error mentions missing variable name (${MISSING_VAR})" +else + fail "error mentions missing variable name (${MISSING_VAR})" "variable name not found in output" +fi + +rm -f "$CONFIG" +summarise "auth-env-fatal" diff --git a/test/scripts/auth/auth-headers-only.sh b/test/scripts/auth/auth-headers-only.sh new file mode 100644 index 00000000..e89dde91 --- /dev/null +++ b/test/scripts/auth/auth-headers-only.sh @@ -0,0 +1,146 @@ +#!/usr/bin/env bash +# auth-headers-only.sh — proves Olla injects arbitrary static headers on +# outbound requests when only the headers: map is configured (no auth: block). +# +# The mock backend enforces a custom header, confirming that custom headers +# travel to the backend even when no structured auth block is present. +# +# Requires: go, curl, bash 4+ +# Does NOT require Docker / AIMock. + +set -euo pipefail + +SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" +REPO_ROOT="$(cd "$SCRIPT_DIR/../../.." && pwd)" + +# shellcheck source=lib.sh +source "$SCRIPT_DIR/lib.sh" + +OLLA_PORT="${OLLA_PORT:-40118}" +BACKEND_PORT="${BACKEND_PORT:-19913}" +OLLA_URL="http://127.0.0.1:${OLLA_PORT}" +BACKEND_URL="http://127.0.0.1:${BACKEND_PORT}" +OLLA_LOG="${TMPDIR:-/tmp}/olla-auth-headers.log" + +CUSTOM_HEADER="X-Tenant-ID" +CUSTOM_VALUE="tenant-abc" + +OLLA_PID="" +BACKEND_PID="" + +cleanup() { + kill_proc "$OLLA_PID" + kill_proc "$BACKEND_PID" +} +trap cleanup EXIT INT TERM + +echo "=== auth-headers-only: static headers injection (no auth block) ===" +echo "Backend: ${BACKEND_URL} Olla: ${OLLA_URL}" +echo + +free_port "$BACKEND_PORT" +go run "$REPO_ROOT/test/cmd/mockbackend" \ + --addr "127.0.0.1:${BACKEND_PORT}" \ + --require-header "${CUSTOM_HEADER}" \ + --require-value "${CUSTOM_VALUE}" \ + >"${TMPDIR:-/tmp}/mockbackend-headers.log" 2>&1 & +BACKEND_PID=$! +wait_for_mockbackend "$BACKEND_URL" 15 + +CONFIG=$(mktemp "${TMPDIR:-/tmp}/olla-auth-headers-XXXXXX.yaml") +cat >"$CONFIG" <"$OLLA_LOG" 2>&1 & +OLLA_PID=$! +wait_for_url "${OLLA_URL}/internal/health" 20 + +# Test 1: custom header present → 200 +status=$(http_status_for "${OLLA_URL}/olla/openai-compatible/v1/chat/completions") +if [[ "$status" == "200" ]]; then + pass "custom header injected → backend accepted (200)" +else + fail "custom header injected → backend accepted (200)" "got HTTP ${status}" +fi + +# ── restart without the custom header ──────────────────────────────────────── +kill_proc "$OLLA_PID" +OLLA_PID="" +free_port "$OLLA_PORT" + +# Remove the headers: block entirely by generating a config without it +CONFIG_NO_HDR=$(mktemp "${TMPDIR:-/tmp}/olla-auth-headers-none-XXXXXX.yaml") +# Strip the headers block (both lines) +grep -v "headers:" "$CONFIG" | grep -v "${CUSTOM_HEADER}" >"$CONFIG_NO_HDR" + +go run "$REPO_ROOT" --config "$CONFIG_NO_HDR" >>"$OLLA_LOG" 2>&1 & +OLLA_PID=$! +wait_for_url "${OLLA_URL}/internal/health" 20 + +# Test 2: header absent → backend rejects → non-200 +status=$(http_status_for "${OLLA_URL}/olla/openai-compatible/v1/chat/completions") +if [[ "$status" != "200" ]]; then + pass "missing header propagates non-200 (got ${status})" +else + fail "missing header propagates non-200" "expected non-200 but got 200" +fi + +rm -f "$CONFIG" "$CONFIG_NO_HDR" +summarise "auth-headers-only" diff --git a/test/scripts/auth/lib.sh b/test/scripts/auth/lib.sh new file mode 100644 index 00000000..d80c8421 --- /dev/null +++ b/test/scripts/auth/lib.sh @@ -0,0 +1,137 @@ +#!/usr/bin/env bash +# Shared helpers for auth test scripts. +# Source this file; do not execute it directly. + +# Colour codes — match the project's ANSI palette +RED='\033[0;31m' +GREEN='\033[0;32m' +YELLOW='\033[1;33m' +RESET='\033[0m' + +TOTAL_TESTS=0 +PASSED_TESTS=0 +FAILED_TESTS=() + +pass() { + local name="$1" + TOTAL_TESTS=$((TOTAL_TESTS + 1)) + PASSED_TESTS=$((PASSED_TESTS + 1)) + printf "${GREEN}PASS${RESET}: %s\n" "$name" +} + +fail() { + local name="$1" + local reason="${2:-}" + TOTAL_TESTS=$((TOTAL_TESTS + 1)) + FAILED_TESTS+=("$name") + if [[ -n "$reason" ]]; then + printf "${RED}FAIL${RESET}: %s (%s)\n" "$name" "$reason" >&2 + else + printf "${RED}FAIL${RESET}: %s\n" "$name" >&2 + fi +} + +# summarise prints the final PASS/FAIL summary line and exits with the +# appropriate code. Call from a trap or at the end of each script. +summarise() { + local script_name="${1:-test}" + echo + echo "Results: ${PASSED_TESTS}/${TOTAL_TESTS} passed" + if [[ ${#FAILED_TESTS[@]} -eq 0 ]]; then + printf "${GREEN}PASS${RESET}: %s\n" "$script_name" + return 0 + else + printf "${RED}FAIL${RESET}: %s\n" "$script_name" + return 1 + fi +} + +# wait_for_url polls until the URL returns HTTP 200 or the timeout is reached. +wait_for_url() { + local url="$1" + local timeout="${2:-15}" + local attempt=0 + until curl -sf --max-time 2 "$url" >/dev/null 2>&1; do + attempt=$((attempt + 1)) + if [[ $attempt -ge $timeout ]]; then + printf "${RED}ERROR${RESET}: %s did not become available within %ss\n" "$url" "$timeout" >&2 + return 1 + fi + sleep 1 + done +} + +# wait_for_mockbackend polls the unauthenticated /health endpoint on the mock +# backend. Use this instead of wait_for_url for mockbackend instances because +# the auth-enforced /v1/models path returns 401 which curl -sf treats as failure. +wait_for_mockbackend() { + local base_url="$1" + local timeout="${2:-15}" + wait_for_url "${base_url}/health" "$timeout" +} + +# http_status_for issues a POST with optional bearer token and returns the +# HTTP status code. +http_status_for() { + local url="$1" + local token="${2:-}" + local extra_header="${3:-}" + local extra_value="${4:-}" + + local args=(-s -o /dev/null -w "%{http_code}" --max-time 10 + -X POST + -H "Content-Type: application/json" + -d '{"model":"mock-model","messages":[{"role":"user","content":"hi"}],"max_tokens":5}') + + if [[ -n "$token" ]]; then + args+=(-H "Authorization: Bearer $token") + fi + if [[ -n "$extra_header" && -n "$extra_value" ]]; then + args+=(-H "${extra_header}: ${extra_value}") + fi + + curl "${args[@]}" "$url" +} + +# kill_proc sends SIGTERM to a PID and waits for it to exit. +kill_proc() { + local pid="$1" + if [[ -n "$pid" ]] && kill -0 "$pid" 2>/dev/null; then + kill "$pid" 2>/dev/null || true + wait "$pid" 2>/dev/null || true + fi +} + +# free_port kills any process already listening on the given port so repeated +# test runs don't collide with stale backends from previous runs. +# Works on Linux, macOS, and Windows (Git Bash / MSYS2). +free_port() { + local port="$1" + local pids="" + + # Linux: ss is preferred; fall back to netstat -tlnp + if command -v ss >/dev/null 2>&1; then + pids=$(ss -tlnp "sport = :${port}" 2>/dev/null | grep -oP '(?<=pid=)\d+' || true) + fi + + # macOS: lsof + if [[ -z "$pids" ]] && command -v lsof >/dev/null 2>&1; then + pids=$(lsof -ti "tcp:${port}" 2>/dev/null || true) + fi + + # Windows (Git Bash): netstat -ano gives PID in last column + if [[ -z "$pids" ]] && command -v netstat >/dev/null 2>&1; then + pids=$(netstat -ano 2>/dev/null \ + | grep -E "[:.]${port}\s+.+LISTENING" \ + | awk '{print $NF}' || true) + fi + + if [[ -n "$pids" ]]; then + for pid in $pids; do + # On Windows, bash kill may fail; fall through to taskkill + kill "$pid" 2>/dev/null || \ + { command -v taskkill >/dev/null 2>&1 && taskkill /F /PID "$pid" >/dev/null 2>&1; } || true + done + sleep 0.8 + fi +}