Skip to content

Commit cd44bbf

Browse files
aron-muonclaude
andcommitted
Refresh upstream tokens transparently instead of forcing re-auth
The upstreamswap middleware returned 401 when upstream access tokens expired, forcing users through full re-authentication even though valid refresh tokens existed in storage. This happened because: 1. Redis/memory storage TTL was set to access token expiry, deleting the entry (and refresh token) when the access token expired 2. Storage returned nil on ErrExpired, discarding the refresh token 3. The middleware had no refresh path — only 401 Fix all three layers: - Add DefaultRefreshTokenTTL (30 days) to storage entry TTL so refresh tokens survive past access token expiry - Return token data alongside ErrExpired from storage so callers can use the refresh token - Add UpstreamTokenRefresher interface and implementation that wraps the upstream OAuth2Provider and storage - Plumb the refresher through Server → EmbeddedAuthServer → Runner → MiddlewareRunner - Update upstreamswap middleware to attempt refresh before returning 401, only requiring re-auth when the refresh token itself fails Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
1 parent bc9b534 commit cd44bbf

File tree

15 files changed

+336
-60
lines changed

15 files changed

+336
-60
lines changed

pkg/auth/upstreamswap/middleware.go

Lines changed: 81 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
package upstreamswap
77

88
import (
9+
"context"
910
"encoding/json"
1011
"errors"
1112
"fmt"
@@ -48,6 +49,10 @@ type MiddlewareParams struct {
4849
// This allows lazy access to the storage, which may not be available at middleware creation time.
4950
type StorageGetter func() storage.UpstreamTokenStorage
5051

52+
// RefresherGetter is a function that returns an upstream token refresher.
53+
// This allows lazy access to the refresher, which may not be available at middleware creation time.
54+
type RefresherGetter func() storage.UpstreamTokenRefresher
55+
5156
// Middleware wraps the upstream swap middleware functionality.
5257
type Middleware struct {
5358
middleware types.MiddlewareFunction
@@ -81,12 +86,13 @@ func CreateMiddleware(config *types.MiddlewareConfig, runner types.MiddlewareRun
8186
return fmt.Errorf("invalid upstream swap configuration: %w", err)
8287
}
8388

84-
// Get storage getter from runner.
85-
// The storage getter is a lazy accessor that checks storage availability at request time,
86-
// so it's always non-nil. Actual storage availability is verified when processing requests.
89+
// Get storage getter and refresher getter from runner.
90+
// These are lazy accessors that check availability at request time,
91+
// so they're always non-nil. Actual availability is verified when processing requests.
8792
storageGetter := runner.GetUpstreamTokenStorage()
93+
refresherGetter := runner.GetUpstreamTokenRefresher()
8894

89-
middleware := createMiddlewareFunc(cfg, storageGetter)
95+
middleware := createMiddlewareFunc(cfg, storageGetter, refresherGetter)
9096

9197
upstreamSwapMw := &Middleware{
9298
middleware: middleware,
@@ -141,7 +147,7 @@ func createCustomInjector(headerName string) injectionFunc {
141147
}
142148

143149
// createMiddlewareFunc creates the actual middleware function.
144-
func createMiddlewareFunc(cfg *Config, storageGetter StorageGetter) types.MiddlewareFunction {
150+
func createMiddlewareFunc(cfg *Config, storageGetter StorageGetter, refresherGetter RefresherGetter) types.MiddlewareFunction {
145151
// Determine injection strategy at startup time
146152
strategy := cfg.HeaderStrategy
147153
if strategy == "" {
@@ -188,35 +194,21 @@ func createMiddlewareFunc(cfg *Config, storageGetter StorageGetter) types.Middle
188194
return
189195
}
190196

191-
// 4. Lookup upstream tokens
192-
tokens, err := stor.GetUpstreamTokens(r.Context(), tsid)
197+
// 4. Lookup upstream tokens, refreshing if expired
198+
tokens, err := getOrRefreshUpstreamTokens(r.Context(), stor, tsid, refresherGetter)
193199
if err != nil {
194200
slog.Warn("Failed to get upstream tokens",
195201
"middleware", "upstreamswap", "error", err)
196-
// Token is expired, was not found, or failed binding validation
197-
// (e.g., subject/client mismatch). All three are client-attributable
198-
// errors that require the caller to re-authenticate with the upstream IdP.
199202
if errors.Is(err, storage.ErrExpired) ||
200203
errors.Is(err, storage.ErrNotFound) ||
201204
errors.Is(err, storage.ErrInvalidBinding) {
202205
writeUpstreamAuthRequired(w)
203206
return
204207
}
205-
// Other storage errors: fail closed to avoid bypassing the token swap
206208
http.Error(w, "authentication service temporarily unavailable", http.StatusServiceUnavailable)
207209
return
208210
}
209211

210-
// 5. Check if expired
211-
// Defense in depth: some storage implementations may return tokens
212-
// without checking expiry (the interface does not require it).
213-
if tokens.IsExpired(time.Now()) {
214-
slog.Warn("Upstream tokens expired",
215-
"middleware", "upstreamswap")
216-
writeUpstreamAuthRequired(w)
217-
return
218-
}
219-
220212
// 6. Inject access token
221213
if tokens.AccessToken == "" {
222214
slog.Warn("Access token is empty",
@@ -233,3 +225,71 @@ func createMiddlewareFunc(cfg *Config, storageGetter StorageGetter) types.Middle
233225
})
234226
}
235227
}
228+
229+
// getOrRefreshUpstreamTokens retrieves upstream tokens from storage, automatically
230+
// refreshing them if expired and a refresh token is available.
231+
func getOrRefreshUpstreamTokens(
232+
ctx context.Context,
233+
stor storage.UpstreamTokenStorage,
234+
sessionID string,
235+
refresherGetter RefresherGetter,
236+
) (*storage.UpstreamTokens, error) {
237+
tokens, err := stor.GetUpstreamTokens(ctx, sessionID)
238+
if err != nil {
239+
// ErrExpired returns tokens (including refresh token) alongside the error.
240+
// Attempt a refresh before giving up.
241+
if errors.Is(err, storage.ErrExpired) && tokens != nil {
242+
if refreshed := tryRefreshUpstreamTokens(ctx, sessionID, tokens, refresherGetter); refreshed != nil {
243+
return refreshed, nil
244+
}
245+
}
246+
return nil, err
247+
}
248+
249+
// Defense in depth: some storage implementations may return tokens
250+
// without checking expiry (the interface does not require it).
251+
if tokens.IsExpired(time.Now()) {
252+
if refreshed := tryRefreshUpstreamTokens(ctx, sessionID, tokens, refresherGetter); refreshed != nil {
253+
return refreshed, nil
254+
}
255+
return nil, storage.ErrExpired
256+
}
257+
258+
return tokens, nil
259+
}
260+
261+
// tryRefreshUpstreamTokens attempts to refresh expired upstream tokens using the
262+
// configured refresher. Returns the refreshed tokens on success, or nil on failure.
263+
func tryRefreshUpstreamTokens(
264+
ctx context.Context,
265+
sessionID string,
266+
expired *storage.UpstreamTokens,
267+
refresherGetter RefresherGetter,
268+
) *storage.UpstreamTokens {
269+
if expired.RefreshToken == "" {
270+
slog.Debug("No refresh token available, cannot refresh upstream tokens",
271+
"middleware", "upstreamswap")
272+
return nil
273+
}
274+
275+
if refresherGetter == nil {
276+
return nil
277+
}
278+
refresher := refresherGetter()
279+
if refresher == nil {
280+
slog.Debug("Token refresher unavailable, cannot refresh upstream tokens",
281+
"middleware", "upstreamswap")
282+
return nil
283+
}
284+
285+
refreshed, err := refresher.RefreshAndStore(ctx, sessionID, expired)
286+
if err != nil {
287+
slog.Warn("Upstream token refresh failed",
288+
"middleware", "upstreamswap", "error", err)
289+
return nil
290+
}
291+
292+
slog.Debug("Successfully refreshed upstream tokens",
293+
"middleware", "upstreamswap")
294+
return refreshed
295+
}

pkg/auth/upstreamswap/middleware_test.go

Lines changed: 14 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -99,7 +99,7 @@ func TestMiddleware_NoIdentity(t *testing.T) {
9999
}
100100

101101
cfg := &Config{}
102-
middleware := createMiddlewareFunc(cfg, storageGetter)
102+
middleware := createMiddlewareFunc(cfg, storageGetter, nil)
103103

104104
var nextCalled bool
105105
nextHandler := http.HandlerFunc(func(_ http.ResponseWriter, r *http.Request) {
@@ -131,7 +131,7 @@ func TestMiddleware_NoTsidClaim(t *testing.T) {
131131
}
132132

133133
cfg := &Config{}
134-
middleware := createMiddlewareFunc(cfg, storageGetter)
134+
middleware := createMiddlewareFunc(cfg, storageGetter, nil)
135135

136136
var nextCalled bool
137137
nextHandler := http.HandlerFunc(func(_ http.ResponseWriter, _ *http.Request) {
@@ -166,7 +166,7 @@ func TestMiddleware_StorageUnavailable(t *testing.T) {
166166
}
167167

168168
cfg := &Config{}
169-
middleware := createMiddlewareFunc(cfg, storageGetter)
169+
middleware := createMiddlewareFunc(cfg, storageGetter, nil)
170170

171171
var nextCalled bool
172172
nextHandler := http.HandlerFunc(func(_ http.ResponseWriter, _ *http.Request) {
@@ -222,7 +222,7 @@ func TestMiddleware_ClientAttributableStorageErrors_Returns401(t *testing.T) {
222222
}
223223

224224
cfg := &Config{}
225-
middleware := createMiddlewareFunc(cfg, storageGetter)
225+
middleware := createMiddlewareFunc(cfg, storageGetter, nil)
226226

227227
var nextCalled bool
228228
nextHandler := http.HandlerFunc(func(_ http.ResponseWriter, _ *http.Request) {
@@ -268,7 +268,7 @@ func TestMiddleware_StorageError(t *testing.T) {
268268
}
269269

270270
cfg := &Config{}
271-
middleware := createMiddlewareFunc(cfg, storageGetter)
271+
middleware := createMiddlewareFunc(cfg, storageGetter, nil)
272272

273273
var nextCalled bool
274274
nextHandler := http.HandlerFunc(func(_ http.ResponseWriter, _ *http.Request) {
@@ -319,7 +319,7 @@ func TestMiddleware_SuccessfulSwap_AccessToken(t *testing.T) {
319319
cfg := &Config{
320320
HeaderStrategy: HeaderStrategyReplace,
321321
}
322-
middleware := createMiddlewareFunc(cfg, storageGetter)
322+
middleware := createMiddlewareFunc(cfg, storageGetter, nil)
323323

324324
var capturedAuthHeader string
325325
nextHandler := http.HandlerFunc(func(_ http.ResponseWriter, r *http.Request) {
@@ -369,7 +369,7 @@ func TestMiddleware_CustomHeader(t *testing.T) {
369369
HeaderStrategy: HeaderStrategyCustom,
370370
CustomHeaderName: "X-Upstream-Token",
371371
}
372-
middleware := createMiddlewareFunc(cfg, storageGetter)
372+
middleware := createMiddlewareFunc(cfg, storageGetter, nil)
373373

374374
var capturedCustomHeader string
375375
var capturedAuthHeader string
@@ -422,7 +422,7 @@ func TestMiddleware_ExpiredTokens_Returns401(t *testing.T) {
422422
}
423423

424424
cfg := &Config{}
425-
middleware := createMiddlewareFunc(cfg, storageGetter)
425+
middleware := createMiddlewareFunc(cfg, storageGetter, nil)
426426

427427
var nextCalled bool
428428
nextHandler := http.HandlerFunc(func(_ http.ResponseWriter, _ *http.Request) {
@@ -473,7 +473,7 @@ func TestMiddleware_EmptySelectedToken(t *testing.T) {
473473
}
474474

475475
cfg := &Config{}
476-
middleware := createMiddlewareFunc(cfg, storageGetter)
476+
middleware := createMiddlewareFunc(cfg, storageGetter, nil)
477477

478478
var nextCalled bool
479479
var capturedAuthHeader string
@@ -563,7 +563,7 @@ func TestMiddlewareWithContext(t *testing.T) {
563563
}
564564

565565
cfg := &Config{}
566-
middleware := createMiddlewareFunc(cfg, storageGetter)
566+
middleware := createMiddlewareFunc(cfg, storageGetter, nil)
567567

568568
// Test that context is properly passed through
569569
var receivedCtx context.Context
@@ -677,6 +677,9 @@ func TestCreateMiddleware(t *testing.T) {
677677
mockRunner.EXPECT().GetUpstreamTokenStorage().Return(func() storage.UpstreamTokenStorage {
678678
return nil // Storage availability is checked at request time
679679
})
680+
mockRunner.EXPECT().GetUpstreamTokenRefresher().Return(func() storage.UpstreamTokenRefresher {
681+
return nil // Refresher availability is checked at request time
682+
})
680683
mockRunner.EXPECT().AddMiddleware(gomock.Any(), gomock.Any()).Do(func(_ string, mw types.Middleware) {
681684
_, ok := mw.(*Middleware)
682685
assert.True(t, ok, "Expected middleware to be of type *upstreamswap.Middleware")
@@ -738,7 +741,7 @@ func TestMiddleware_TsidClaimWrongType(t *testing.T) {
738741
}
739742

740743
cfg := &Config{}
741-
middleware := createMiddlewareFunc(cfg, storageGetter)
744+
middleware := createMiddlewareFunc(cfg, storageGetter, nil)
742745

743746
var nextCalled bool
744747
nextHandler := http.HandlerFunc(func(_ http.ResponseWriter, _ *http.Request) {

pkg/authserver/refresher.go

Lines changed: 83 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,83 @@
1+
// SPDX-FileCopyrightText: Copyright 2025 Stacklok, Inc.
2+
// SPDX-License-Identifier: Apache-2.0
3+
4+
package authserver
5+
6+
import (
7+
"context"
8+
"errors"
9+
"fmt"
10+
"log/slog"
11+
12+
"github.com/stacklok/toolhive/pkg/authserver/storage"
13+
"github.com/stacklok/toolhive/pkg/authserver/upstream"
14+
)
15+
16+
// upstreamTokenRefresher implements storage.UpstreamTokenRefresher by wrapping
17+
// an upstream OAuth2Provider (for token refresh) and UpstreamTokenStorage (for
18+
// persisting the refreshed tokens).
19+
type upstreamTokenRefresher struct {
20+
provider upstream.OAuth2Provider
21+
storage storage.UpstreamTokenStorage
22+
}
23+
24+
// RefreshAndStore refreshes expired upstream tokens using the stored refresh token,
25+
// persists the new tokens, and returns them.
26+
func (r *upstreamTokenRefresher) RefreshAndStore(
27+
ctx context.Context,
28+
sessionID string,
29+
expired *storage.UpstreamTokens,
30+
) (*storage.UpstreamTokens, error) {
31+
if expired == nil {
32+
return nil, errors.New("expired tokens are required")
33+
}
34+
if expired.RefreshToken == "" {
35+
return nil, errors.New("no refresh token available for upstream token refresh")
36+
}
37+
38+
slog.Debug("attempting upstream token refresh",
39+
"session_id", sessionID,
40+
"provider_id", expired.ProviderID,
41+
)
42+
43+
// Refresh tokens via the upstream provider
44+
newTokens, err := r.provider.RefreshTokens(ctx, expired.RefreshToken, expired.UpstreamSubject)
45+
if err != nil {
46+
return nil, fmt.Errorf("upstream token refresh failed: %w", err)
47+
}
48+
49+
// Build updated storage tokens preserving binding fields from the original
50+
updated := &storage.UpstreamTokens{
51+
ProviderID: expired.ProviderID,
52+
AccessToken: newTokens.AccessToken,
53+
RefreshToken: newTokens.RefreshToken,
54+
IDToken: newTokens.IDToken,
55+
ExpiresAt: newTokens.ExpiresAt,
56+
UserID: expired.UserID,
57+
UpstreamSubject: expired.UpstreamSubject,
58+
ClientID: expired.ClientID,
59+
}
60+
61+
// If the provider didn't rotate the refresh token, keep the original
62+
if updated.RefreshToken == "" {
63+
updated.RefreshToken = expired.RefreshToken
64+
}
65+
66+
// Store the refreshed tokens
67+
if err := r.storage.StoreUpstreamTokens(ctx, sessionID, updated); err != nil {
68+
// Log but still return the refreshed tokens — the current request can
69+
// proceed even if storage fails. The next request will retry the refresh.
70+
slog.Warn("failed to store refreshed upstream tokens",
71+
"session_id", sessionID,
72+
"error", err,
73+
)
74+
return updated, nil
75+
}
76+
77+
slog.Debug("upstream tokens refreshed successfully",
78+
"session_id", sessionID,
79+
"provider_id", expired.ProviderID,
80+
)
81+
82+
return updated, nil
83+
}

pkg/authserver/runner/embeddedauthserver.go

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -135,6 +135,12 @@ func (e *EmbeddedAuthServer) IDPTokenStorage() storage.UpstreamTokenStorage {
135135
return e.server.IDPTokenStorage()
136136
}
137137

138+
// UpstreamTokenRefresher returns a refresher that can refresh expired upstream
139+
// tokens using the upstream provider's refresh token grant.
140+
func (e *EmbeddedAuthServer) UpstreamTokenRefresher() storage.UpstreamTokenRefresher {
141+
return e.server.UpstreamTokenRefresher()
142+
}
143+
138144
// createKeyProvider creates a KeyProvider from SigningKeyRunConfig.
139145
// Returns a GeneratingProvider if config is nil or empty (development mode).
140146
func createKeyProvider(cfg *authserver.SigningKeyRunConfig) (keys.KeyProvider, error) {

pkg/authserver/server.go

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,11 @@ type Server interface {
3131
// Returns nil if no upstream IDP is configured.
3232
IDPTokenStorage() storage.UpstreamTokenStorage
3333

34+
// UpstreamTokenRefresher returns a refresher that can refresh expired upstream
35+
// tokens using the upstream provider's refresh token grant.
36+
// Returns nil if no upstream IDP is configured.
37+
UpstreamTokenRefresher() storage.UpstreamTokenRefresher
38+
3439
// Close releases resources held by the server.
3540
Close() error
3641
}

pkg/authserver/server_impl.go

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -161,6 +161,18 @@ func (s *server) IDPTokenStorage() storage.UpstreamTokenStorage {
161161
return s.storage
162162
}
163163

164+
// UpstreamTokenRefresher returns a refresher that wraps the upstream provider
165+
// and storage to transparently refresh expired upstream tokens.
166+
func (s *server) UpstreamTokenRefresher() storage.UpstreamTokenRefresher {
167+
if s.upstreamIDP == nil {
168+
return nil
169+
}
170+
return &upstreamTokenRefresher{
171+
provider: s.upstreamIDP,
172+
storage: s.storage,
173+
}
174+
}
175+
164176
// Close releases resources held by the server.
165177
func (s *server) Close() error {
166178
slog.Debug("closing OAuth authorization server")

0 commit comments

Comments
 (0)