Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions .env.example
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,8 @@ AUTH_USERS=

SOURCE_ALLOWLIST_CIDRS=
DEST_PORT_ALLOWLIST=443
DEST_HOST_ALLOWLIST=
DEST_SUFFIX_ALLOWLIST=.anthropic.com,.openai.com,.openrouter.ai,.chatgpt.com
DEST_HOST_ALLOWLIST=storage.googleapis.com
DEST_SUFFIX_ALLOWLIST=.claude.ai,.claude.com,.anthropic.com,.openai.com,.openrouter.ai,.chatgpt.com,.github.com,.githubusercontent.com,.githubcopilot.com,.ghcr.io
ALLOW_EMPTY_DEST_ALLOWLIST=false
ALLOW_PRIVATE_DESTINATIONS=false

Expand Down
6 changes: 5 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -60,8 +60,12 @@ cp deploy/vps.example.yaml deploy/vps.yaml
- `.openai.com`
- `.openrouter.ai`
- `.chatgpt.com`
- `.github.com`
- `.githubusercontent.com`
- `.githubcopilot.com`
- `.ghcr.io`

默认示例还会通过 `runtime.dest_host_allowlist` 放行精确主机 `storage.googleapis.com`,用于兼容 Claude Code 仍在迁移中的 legacy 下载路径。
默认示例还会通过 `runtime.dest_host_allowlist` 放行精确主机 `storage.googleapis.com`,用于兼容 Claude Code 仍在迁移中的 legacy 下载路径;同时默认放行 Github 常见 API / 下载域族、远程 GitHub MCP 的 `.githubcopilot.com`,以及本地 Docker 方式安装 GitHub MCP 时会用到的 `.ghcr.io`,避免用户还要手动补白名单

不建议直接把白名单放空或改成近似全放开。更稳的做法是按产品域族放行,比如 Claude 用 `.anthropic.com`、`.claude.com`、`.claude.ai`,再补必要的精确 host。

Expand Down
6 changes: 5 additions & 1 deletion README_EN.md
Original file line number Diff line number Diff line change
Expand Up @@ -61,8 +61,12 @@ The sample already includes common model service domains:
- `.openai.com`
- `.openrouter.ai`
- `.chatgpt.com`
- `.github.com`
- `.githubusercontent.com`
- `.githubcopilot.com`
- `.ghcr.io`

The default sample also allowlists the exact host `storage.googleapis.com` through `runtime.dest_host_allowlist` to cover Claude Code's legacy download path while that migration is still in progress.
The default sample also allowlists the exact host `storage.googleapis.com` through `runtime.dest_host_allowlist` to cover Claude Code's legacy download path while that migration is still in progress. It also includes the common GitHub API and download domain families, `.githubcopilot.com` for the remote GitHub MCP server, and `.ghcr.io` for the local Docker-based GitHub MCP install path so those flows do not require manual allowlist edits.

Avoid removing the allowlist entirely or approximating a wildcard. A better default is to allow vendor product domain families such as `.anthropic.com`, `.claude.com`, and `.claude.ai`, then add a small number of exact hosts only when needed.

Expand Down
4 changes: 4 additions & 0 deletions SEND_THIS_TO_LLM.md
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,10 @@
- `.openai.com`
- `.openrouter.ai`
- `.chatgpt.com`
- `.github.com`
- `.githubusercontent.com`
- `.githubcopilot.com`
- `.ghcr.io`
- 默认精确 host 放行:
- `storage.googleapis.com`
- 客户端通过 SSH 隧道访问服务端
Expand Down
4 changes: 4 additions & 0 deletions deploy/vps.example.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,10 @@ runtime:
- .openai.com
- .openrouter.ai
- .chatgpt.com
- .github.com
- .githubusercontent.com
- .githubcopilot.com
- .ghcr.io
max_conns_per_ip: 128
server_read_header_timeout: 10s
server_idle_timeout: 90s
Expand Down
28 changes: 28 additions & 0 deletions internal/deploy/deploy_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,34 @@ func TestRenderVPSUsersHashesPlaintextPasswords(t *testing.T) {
}
}

func TestNormalizeVPSConfigAppliesDefaultGitHubAllowlist(t *testing.T) {
spec := VPSConfig{
Users: []ProxyUser{
{Username: "alice", Password: "change-me"},
},
}

normalized, err := normalizeVPSConfig(spec)
if err != nil {
t.Fatalf("normalizeVPSConfig() error = %v", err)
}
if !contains(normalized.Runtime.DestHostAllowlist, "storage.googleapis.com") {
t.Fatalf("default host allowlist = %#v, want storage.googleapis.com", normalized.Runtime.DestHostAllowlist)
}
if !contains(normalized.Runtime.DestSuffixAllowlist, ".github.com") {
t.Fatalf("default suffix allowlist = %#v, want .github.com", normalized.Runtime.DestSuffixAllowlist)
}
if !contains(normalized.Runtime.DestSuffixAllowlist, ".githubusercontent.com") {
t.Fatalf("default suffix allowlist = %#v, want .githubusercontent.com", normalized.Runtime.DestSuffixAllowlist)
}
if !contains(normalized.Runtime.DestSuffixAllowlist, ".githubcopilot.com") {
t.Fatalf("default suffix allowlist = %#v, want .githubcopilot.com", normalized.Runtime.DestSuffixAllowlist)
}
if !contains(normalized.Runtime.DestSuffixAllowlist, ".ghcr.io") {
t.Fatalf("default suffix allowlist = %#v, want .ghcr.io", normalized.Runtime.DestSuffixAllowlist)
}
}

func TestRenderClientArtifacts(t *testing.T) {
spec := ClientConfig{
InstallDir: "/home/test/.config/codex-gateway",
Expand Down
2 changes: 1 addition & 1 deletion internal/deploy/spec.go
Original file line number Diff line number Diff line change
Expand Up @@ -255,7 +255,7 @@ func normalizeVPSConfig(spec VPSConfig) (VPSConfig, error) {
runtime.DestHostAllowlist = []string{"storage.googleapis.com"}
}
if len(runtime.DestSuffixAllowlist) == 0 {
runtime.DestSuffixAllowlist = []string{".claude.ai", ".claude.com", ".anthropic.com", ".openai.com", ".openrouter.ai", ".chatgpt.com"}
runtime.DestSuffixAllowlist = []string{".claude.ai", ".claude.com", ".anthropic.com", ".openai.com", ".openrouter.ai", ".chatgpt.com", ".github.com", ".githubusercontent.com", ".githubcopilot.com", ".ghcr.io"}
}
if runtime.MaxConnsPerIP == 0 {
runtime.MaxConnsPerIP = 128
Expand Down
103 changes: 91 additions & 12 deletions internal/proxy/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,8 @@ var metricDurationBuckets = []time.Duration{
30 * time.Second,
}

const defaultResolvedDialFallbackDelay = 250 * time.Millisecond

type Metrics struct {
startedAt time.Time

Expand Down Expand Up @@ -215,6 +217,7 @@ type Options struct {
Metrics *Metrics

UpstreamDialTimeout time.Duration
UpstreamDialFallbackDelay time.Duration
UpstreamTLSHandshakeTimeout time.Duration
UpstreamResponseHeaderTimeout time.Duration
IdleTimeout time.Duration
Expand All @@ -230,12 +233,19 @@ type Handler struct {
limiter *limiter.ConcurrencyLimiter
metrics *Metrics

dialer net.Dialer
transport *http.Transport
tunnelIdleTimeout time.Duration
dialer net.Dialer
baseDialContext func(context.Context, string, string) (net.Conn, error)
transport *http.Transport
tunnelIdleTimeout time.Duration
resolvedDialFallbackDelay time.Duration
}

func NewHandler(options Options) *Handler {
fallbackDelay := options.UpstreamDialFallbackDelay
if fallbackDelay <= 0 {
fallbackDelay = defaultResolvedDialFallbackDelay
}

handler := &Handler{
appLogger: options.AppLogger,
auditLogger: options.AuditLogger,
Expand All @@ -246,8 +256,10 @@ func NewHandler(options Options) *Handler {
Timeout: options.UpstreamDialTimeout,
KeepAlive: 30 * time.Second,
},
tunnelIdleTimeout: options.TunnelIdleTimeout,
tunnelIdleTimeout: options.TunnelIdleTimeout,
resolvedDialFallbackDelay: fallbackDelay,
}
handler.baseDialContext = handler.dialer.DialContext

handler.transport = &http.Transport{
Proxy: nil,
Expand Down Expand Up @@ -403,22 +415,89 @@ func (h *Handler) dialContext(ctx context.Context, network, address string) (net
recordDialTrace(ctx, trace)
return conn, nil
}
return h.dialer.DialContext(ctx, network, address)
return h.baseDialContext(ctx, network, address)
}

func (h *Handler) dialResolvedTargets(ctx context.Context, network string, addresses []string) (net.Conn, dialTrace, error) {
if len(addresses) == 0 {
return nil, dialTrace{}, errors.New("no dial targets")
}

// We pre-resolve addresses for policy enforcement, so we need our own
// fallback race to avoid stalling on a slow first IP (for example, broken IPv6).
attemptCtx, cancel := context.WithCancel(ctx)
defer cancel()

type dialResult struct {
address string
conn net.Conn
err error
}

results := make(chan dialResult, len(addresses))
launched := 0
completed := 0
nextToLaunch := 0

startAttempt := func(address string) {
launched++
go func() {
conn, err := h.baseDialContext(attemptCtx, network, address)
if err == nil {
select {
case results <- dialResult{address: address, conn: conn}:
case <-attemptCtx.Done():
Comment on lines +447 to +449
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P1 Badge Close redundant successful fallback connections

When multiple dialResolvedTargets attempts succeed around the same time, each successful goroutine can enqueue its net.Conn into results, but the caller returns after the first success and never drains/cleans the remaining queued connections. In dual-stack or multi-A record cases this leaks upstream sockets per request until GC/fd pressure becomes significant, because the extra successful conns are not closed anywhere after being sent.

Useful? React with 👍 / 👎.

_ = conn.Close()
}
return
}
select {
case results <- dialResult{address: address, err: err}:
case <-attemptCtx.Done():
}
}()
}

startAttempt(addresses[nextToLaunch])
nextToLaunch++

var fallback <-chan time.Time
if nextToLaunch < len(addresses) {
fallback = time.After(h.resolvedDialFallbackDelay)
}

var lastErr error
for index, address := range addresses {
conn, err := h.dialer.DialContext(ctx, network, address)
if err == nil {
return conn, dialTrace{selectedAddress: address, attempts: index + 1}, nil
}
lastErr = err
if ctx.Err() != nil {

for completed < launched || nextToLaunch < len(addresses) {
select {
case result := <-results:
completed++
if result.err == nil {
cancel()
return result.conn, dialTrace{selectedAddress: result.address, attempts: launched}, nil
}
lastErr = result.err
if ctx.Err() != nil {
return nil, dialTrace{}, ctx.Err()
}
if completed == launched && nextToLaunch < len(addresses) {
startAttempt(addresses[nextToLaunch])
nextToLaunch++
if nextToLaunch < len(addresses) {
fallback = time.After(h.resolvedDialFallbackDelay)
} else {
fallback = nil
}
}
case <-fallback:
startAttempt(addresses[nextToLaunch])
nextToLaunch++
if nextToLaunch < len(addresses) {
fallback = time.After(h.resolvedDialFallbackDelay)
} else {
fallback = nil
}
case <-ctx.Done():
return nil, dialTrace{}, ctx.Err()
}
}
Expand Down
92 changes: 92 additions & 0 deletions internal/proxy/server_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package proxy
import (
"bufio"
"bytes"
"context"
"encoding/base64"
"fmt"
"io"
Expand Down Expand Up @@ -154,6 +155,97 @@ func TestHTTPForwardFallsBackToNextResolvedAddress(t *testing.T) {
}
}

func TestHTTPForwardDoesNotBlockOnSlowFirstResolvedAddress(t *testing.T) {
upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
_, _ = w.Write([]byte("race-ok"))
}))
defer upstream.Close()

_, port := listenerAddrPort(t, upstream.Listener.Addr())

appLogs := &bytes.Buffer{}
auditLogs := &bytes.Buffer{}
loggers, err := logging.NewWithWriters("debug", "json", appLogs, auditLogs)
if err != nil {
t.Fatalf("NewWithWriters() error = %v", err)
}

hash, err := auth.HashPassword(testPassword, bcrypt.MinCost)
if err != nil {
t.Fatalf("HashPassword() error = %v", err)
}

handler := NewHandler(Options{
AppLogger: loggers.App,
AuditLogger: loggers.Audit,
AccessLogEnabled: true,
AuthStore: auth.NewMapStore(map[string]string{testUsername: hash}),
Limiter: limiter.New(4),
UpstreamDialFallbackDelay: 25 * time.Millisecond,
Policy: Policy{
AllowedPorts: map[uint16]struct{}{port: {}},
HostMatcher: netutil.NewHostMatcher(map[string]struct{}{testHostName: {}}, nil),
Resolver: staticResolver{
testHostName: {
netip.MustParseAddr("127.0.0.2"),
netip.MustParseAddr("127.0.0.3"),
},
},
AllowPrivate: true,
},
SourceIPs: netutil.NewPrefixMatcher(nil),
Metrics: NewMetrics(),
UpstreamDialTimeout: 500 * time.Millisecond,
UpstreamTLSHandshakeTimeout: 500 * time.Millisecond,
UpstreamResponseHeaderTimeout: 500 * time.Millisecond,
IdleTimeout: 500 * time.Millisecond,
TunnelIdleTimeout: 500 * time.Millisecond,
})

actualDialer := net.Dialer{Timeout: time.Second}
upstreamAddr := upstream.Listener.Addr().String()
handler.baseDialContext = func(ctx context.Context, network, address string) (net.Conn, error) {
host, _, err := net.SplitHostPort(address)
if err != nil {
return nil, err
}
switch host {
case "127.0.0.2":
select {
case <-time.After(300 * time.Millisecond):
return nil, context.DeadlineExceeded
case <-ctx.Done():
return nil, ctx.Err()
}
case "127.0.0.3":
return actualDialer.DialContext(ctx, network, upstreamAddr)
default:
return actualDialer.DialContext(ctx, network, address)
}
}

server := httptest.NewServer(handler)
defer server.Close()

startedAt := time.Now()
response := doRawHTTP(t, proxyAddress(server), rawForwardRequest(port))
elapsed := time.Since(startedAt)

if response.StatusCode != http.StatusOK {
t.Fatalf("status = %d, want %d", response.StatusCode, http.StatusOK)
}
if body := strings.TrimSpace(response.Body); body != "race-ok" {
t.Fatalf("body = %q, want %q", body, "race-ok")
}
if elapsed >= 250*time.Millisecond {
t.Fatalf("request took %v, want fallback before slow dial finishes", elapsed)
}
if !strings.Contains(auditLogs.String(), "\"resolved_ip\":\"127.0.0.3\"") {
t.Fatalf("audit log missing fast fallback address: %s", auditLogs.String())
}
}

func TestConcurrencyLimitReturns429(t *testing.T) {
started := make(chan struct{})
release := make(chan struct{})
Expand Down
Loading