66package upstreamswap
77
88import (
9+ "context"
910 "encoding/json"
1011 "errors"
1112 "fmt"
@@ -48,6 +49,10 @@ type MiddlewareParams struct {
4849// This allows lazy access to the storage, which may not be available at middleware creation time.
4950type StorageGetter func () storage.UpstreamTokenStorage
5051
52+ // RefresherGetter is a function that returns an upstream token refresher.
53+ // This allows lazy access to the refresher, which may not be available at middleware creation time.
54+ type RefresherGetter func () storage.UpstreamTokenRefresher
55+
5156// Middleware wraps the upstream swap middleware functionality.
5257type Middleware struct {
5358 middleware types.MiddlewareFunction
@@ -81,12 +86,13 @@ func CreateMiddleware(config *types.MiddlewareConfig, runner types.MiddlewareRun
8186 return fmt .Errorf ("invalid upstream swap configuration: %w" , err )
8287 }
8388
84- // Get storage getter from runner.
85- // The storage getter is a lazy accessor that checks storage availability at request time,
86- // so it's always non-nil. Actual storage availability is verified when processing requests.
89+ // Get storage getter and refresher getter from runner.
90+ // These are lazy accessors that check availability at request time,
91+ // so they're always non-nil. Actual availability is verified when processing requests.
8792 storageGetter := runner .GetUpstreamTokenStorage ()
93+ refresherGetter := runner .GetUpstreamTokenRefresher ()
8894
89- middleware := createMiddlewareFunc (cfg , storageGetter )
95+ middleware := createMiddlewareFunc (cfg , storageGetter , refresherGetter )
9096
9197 upstreamSwapMw := & Middleware {
9298 middleware : middleware ,
@@ -141,7 +147,7 @@ func createCustomInjector(headerName string) injectionFunc {
141147}
142148
143149// createMiddlewareFunc creates the actual middleware function.
144- func createMiddlewareFunc (cfg * Config , storageGetter StorageGetter ) types.MiddlewareFunction {
150+ func createMiddlewareFunc (cfg * Config , storageGetter StorageGetter , refresherGetter RefresherGetter ) types.MiddlewareFunction {
145151 // Determine injection strategy at startup time
146152 strategy := cfg .HeaderStrategy
147153 if strategy == "" {
@@ -188,35 +194,21 @@ func createMiddlewareFunc(cfg *Config, storageGetter StorageGetter) types.Middle
188194 return
189195 }
190196
191- // 4. Lookup upstream tokens
192- tokens , err := stor . GetUpstreamTokens (r .Context (), tsid )
197+ // 4. Lookup upstream tokens, refreshing if expired
198+ tokens , err := getOrRefreshUpstreamTokens (r .Context (), stor , tsid , refresherGetter )
193199 if err != nil {
194200 slog .Warn ("Failed to get upstream tokens" ,
195201 "middleware" , "upstreamswap" , "error" , err )
196- // Token is expired, was not found, or failed binding validation
197- // (e.g., subject/client mismatch). All three are client-attributable
198- // errors that require the caller to re-authenticate with the upstream IdP.
199202 if errors .Is (err , storage .ErrExpired ) ||
200203 errors .Is (err , storage .ErrNotFound ) ||
201204 errors .Is (err , storage .ErrInvalidBinding ) {
202205 writeUpstreamAuthRequired (w )
203206 return
204207 }
205- // Other storage errors: fail closed to avoid bypassing the token swap
206208 http .Error (w , "authentication service temporarily unavailable" , http .StatusServiceUnavailable )
207209 return
208210 }
209211
210- // 5. Check if expired
211- // Defense in depth: some storage implementations may return tokens
212- // without checking expiry (the interface does not require it).
213- if tokens .IsExpired (time .Now ()) {
214- slog .Warn ("Upstream tokens expired" ,
215- "middleware" , "upstreamswap" )
216- writeUpstreamAuthRequired (w )
217- return
218- }
219-
220212 // 6. Inject access token
221213 if tokens .AccessToken == "" {
222214 slog .Warn ("Access token is empty" ,
@@ -233,3 +225,71 @@ func createMiddlewareFunc(cfg *Config, storageGetter StorageGetter) types.Middle
233225 })
234226 }
235227}
228+
229+ // getOrRefreshUpstreamTokens retrieves upstream tokens from storage, automatically
230+ // refreshing them if expired and a refresh token is available.
231+ func getOrRefreshUpstreamTokens (
232+ ctx context.Context ,
233+ stor storage.UpstreamTokenStorage ,
234+ sessionID string ,
235+ refresherGetter RefresherGetter ,
236+ ) (* storage.UpstreamTokens , error ) {
237+ tokens , err := stor .GetUpstreamTokens (ctx , sessionID )
238+ if err != nil {
239+ // ErrExpired returns tokens (including refresh token) alongside the error.
240+ // Attempt a refresh before giving up.
241+ if errors .Is (err , storage .ErrExpired ) && tokens != nil {
242+ if refreshed := tryRefreshUpstreamTokens (ctx , sessionID , tokens , refresherGetter ); refreshed != nil {
243+ return refreshed , nil
244+ }
245+ }
246+ return nil , err
247+ }
248+
249+ // Defense in depth: some storage implementations may return tokens
250+ // without checking expiry (the interface does not require it).
251+ if tokens .IsExpired (time .Now ()) {
252+ if refreshed := tryRefreshUpstreamTokens (ctx , sessionID , tokens , refresherGetter ); refreshed != nil {
253+ return refreshed , nil
254+ }
255+ return nil , storage .ErrExpired
256+ }
257+
258+ return tokens , nil
259+ }
260+
261+ // tryRefreshUpstreamTokens attempts to refresh expired upstream tokens using the
262+ // configured refresher. Returns the refreshed tokens on success, or nil on failure.
263+ func tryRefreshUpstreamTokens (
264+ ctx context.Context ,
265+ sessionID string ,
266+ expired * storage.UpstreamTokens ,
267+ refresherGetter RefresherGetter ,
268+ ) * storage.UpstreamTokens {
269+ if expired .RefreshToken == "" {
270+ slog .Debug ("No refresh token available, cannot refresh upstream tokens" ,
271+ "middleware" , "upstreamswap" )
272+ return nil
273+ }
274+
275+ if refresherGetter == nil {
276+ return nil
277+ }
278+ refresher := refresherGetter ()
279+ if refresher == nil {
280+ slog .Debug ("Token refresher unavailable, cannot refresh upstream tokens" ,
281+ "middleware" , "upstreamswap" )
282+ return nil
283+ }
284+
285+ refreshed , err := refresher .RefreshAndStore (ctx , sessionID , expired )
286+ if err != nil {
287+ slog .Warn ("Upstream token refresh failed" ,
288+ "middleware" , "upstreamswap" , "error" , err )
289+ return nil
290+ }
291+
292+ slog .Debug ("Successfully refreshed upstream tokens" ,
293+ "middleware" , "upstreamswap" )
294+ return refreshed
295+ }
0 commit comments