diff --git a/api_test.go b/api_test.go index d52c88e..9131854 100644 --- a/api_test.go +++ b/api_test.go @@ -90,6 +90,18 @@ func TestWithOAuth(t *testing.T) { t.Fatal("MCP server creation failed") } + req := httptest.NewRequest("GET", "/callback?code=test-code&state=test-state", nil) + w := httptest.NewRecorder() + mux.ServeHTTP(w, req) + + if w.Code != http.StatusFound { + t.Fatalf("Expected /callback compatibility redirect to return 302, got %d", w.Code) + } + + if location := w.Header().Get("Location"); location != "/oauth/callback?code=test-code&state=test-state" { + t.Fatalf("Expected redirect to /oauth/callback preserving query params, got %q", location) + } + t.Logf("✅ WithOAuth() works in proxy mode") }) @@ -286,6 +298,10 @@ func TestServerWrapHandler(t *testing.T) { t.Errorf("Expected WWW-Authenticate header with error, got: %s", authHeader) } + if !strings.Contains(authHeader, "resource_metadata=") { + t.Errorf("Expected WWW-Authenticate header to include resource_metadata, got: %s", authHeader) + } + if !strings.Contains(w.Body.String(), "invalid_token") { t.Errorf("Expected JSON error response, got: %s", w.Body.String()) } diff --git a/fixed_redirect_test.go b/fixed_redirect_test.go index a4ef20a..6c8cb8a 100644 --- a/fixed_redirect_test.go +++ b/fixed_redirect_test.go @@ -5,6 +5,46 @@ import ( "testing" ) +func TestConfiguredFixedRedirectURICompatibility(t *testing.T) { + tests := []struct { + name string + config *OAuth2Config + expected string + }{ + { + name: "explicit fixed redirect URI wins", + config: &OAuth2Config{ + FixedRedirectURI: "https://server.example.com/oauth/callback", + RedirectURIs: "https://old.example.com/oauth/callback", + }, + expected: "https://server.example.com/oauth/callback", + }, + { + name: "single legacy redirect URI acts as fixed redirect", + config: &OAuth2Config{ + RedirectURIs: "https://server.example.com/oauth/callback", + }, + expected: "https://server.example.com/oauth/callback", + }, + { + name: "multiple redirect URIs do not imply fixed redirect", + config: &OAuth2Config{ + RedirectURIs: "https://a.example.com/callback,https://b.example.com/callback", + }, + expected: "", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + handler := &OAuth2Handler{config: tt.config} + if got := handler.configuredFixedRedirectURI(); got != tt.expected { + t.Fatalf("configuredFixedRedirectURI() = %q, want %q", got, tt.expected) + } + }) + } +} + func TestFixedRedirectModeLocalhostOnly(t *testing.T) { key := make([]byte, 32) _, _ = rand.Read(key) diff --git a/handlers.go b/handlers.go index 51eea76..639537f 100644 --- a/handlers.go +++ b/handlers.go @@ -350,9 +350,11 @@ func (h *OAuth2Handler) HandleAuthorize(w http.ResponseWriter, r *http.Request) } } - if redirectURI == "" && h.config.FixedRedirectURI != "" { + fixedRedirectURI := h.configuredFixedRedirectURI() + + if redirectURI == "" && fixedRedirectURI != "" { // Fixed redirect mode: Use server's redirect URI to OAuth provider, proxy back to client - h.logger.Info("OAuth2: Fixed redirect mode - using server URI: %s (will proxy to client: %s)", h.config.FixedRedirectURI, clientRedirectURI) + h.logger.Info("OAuth2: Fixed redirect mode - using server URI: %s (will proxy to client: %s)", fixedRedirectURI, clientRedirectURI) // Validate client redirect URI format and security if clientRedirectURI == "" { @@ -396,7 +398,7 @@ func (h *OAuth2Handler) HandleAuthorize(w http.ResponseWriter, r *http.Request) http.Error(w, "Invalid redirect_uri for fixed redirect mode", http.StatusBadRequest) return } - redirectURI = strings.TrimSpace(h.config.FixedRedirectURI) + redirectURI = fixedRedirectURI h.logger.Info("OAuth2: Validated client redirect URI for proxy: %s", clientRedirectURI) // For fixed redirect mode, create signed state with client redirect URI // Create state data with redirect URI @@ -490,7 +492,7 @@ func (h *OAuth2Handler) HandleCallback(w http.ResponseWriter, r *http.Request) { } // If using fixed redirect URI, handle proxy callback - if h.config.FixedRedirectURI != "" { + if h.configuredFixedRedirectURI() != "" { // Verify and decode signed state parameter stateData, err := h.verifyState(state) if err != nil { @@ -610,8 +612,8 @@ func (h *OAuth2Handler) HandleToken(w http.ResponseWriter, r *http.Request) { // Set redirect URI for token exchange redirectURI := clientRedirectURI - if h.config.FixedRedirectURI != "" && !h.isValidRedirectURI(clientRedirectURI) { - redirectURI = strings.TrimSpace(h.config.FixedRedirectURI) + if fixedRedirectURI := h.configuredFixedRedirectURI(); fixedRedirectURI != "" && !h.isValidRedirectURI(clientRedirectURI) { + redirectURI = fixedRedirectURI h.logger.Info("OAuth2: Token exchange using fixed redirect URI: %s", redirectURI) } @@ -673,35 +675,55 @@ func (h *OAuth2Handler) HandleToken(w http.ResponseWriter, r *http.Request) { h.logger.Info("OAuth2: Token exchange successful") - // Build response + response := h.buildTokenResponse(token) + + // Send response + h.addSecurityHeaders(w) + w.Header().Set("Content-Type", "application/json") + + if err := json.NewEncoder(w).Encode(response); err != nil { + h.logger.Error("OAuth2: Failed to encode token response: %v", err) + } +} + +// buildTokenResponse builds an RFC 6749-style token response with provider-specific behavior. +func (h *OAuth2Handler) buildTokenResponse(token *oauth2.Token) map[string]interface{} { + accessToken := token.AccessToken + + if h.config != nil && h.config.Provider == "google" { + if idToken, ok := token.Extra("id_token").(string); ok && idToken != "" { + if !looksLikeJWT(token.AccessToken) { + if h.logger != nil { + h.logger.Info("OAuth2: Google provider detected opaque access token, using id_token as access_token for downstream compatibility") + } + accessToken = idToken + } + } + } + response := map[string]interface{}{ - "access_token": token.AccessToken, + "access_token": accessToken, "token_type": token.TokenType, "expires_in": int(time.Until(token.Expiry).Seconds()), } - // Add optional fields if token.RefreshToken != "" { response["refresh_token"] = token.RefreshToken } - // Add ID token if present if idToken, ok := token.Extra("id_token").(string); ok { response["id_token"] = idToken } - // Add scope if present if scope, ok := token.Extra("scope").(string); ok { response["scope"] = scope } - // Send response - h.addSecurityHeaders(w) - w.Header().Set("Content-Type", "application/json") + return response +} - if err := json.NewEncoder(w).Encode(response); err != nil { - h.logger.Error("OAuth2: Failed to encode token response: %v", err) - } +func looksLikeJWT(token string) bool { + return strings.Count(token, ".") == 2 } // showSuccessPage displays a success page after OAuth completion @@ -736,6 +758,22 @@ func truncateString(s string, maxLen int) string { return s[:maxLen] + "..." } +// configuredFixedRedirectURI returns the effective fixed redirect URI. +// Prefer the explicit FixedRedirectURI, but preserve the historical behavior +// where a single RedirectURIs value also acted as fixed redirect mode. +func (h *OAuth2Handler) configuredFixedRedirectURI() string { + if h == nil || h.config == nil { + return "" + } + if trimmed := strings.TrimSpace(h.config.FixedRedirectURI); trimmed != "" { + return trimmed + } + if redirects := strings.TrimSpace(h.config.RedirectURIs); redirects != "" && !strings.Contains(redirects, ",") { + return redirects + } + return "" +} + // pkceTransport adds PKCE code_verifier to token exchange requests type pkceTransport struct { base http.RoundTripper diff --git a/handlers_token_response_test.go b/handlers_token_response_test.go new file mode 100644 index 0000000..c552f88 --- /dev/null +++ b/handlers_token_response_test.go @@ -0,0 +1,97 @@ +package oauth + +import ( + "crypto/sha256" + "fmt" + "testing" + "time" + + "golang.org/x/oauth2" +) + +func tokenHash(v interface{}) string { + s, _ := v.(string) + if s == "" { + return "" + } + return fmt.Sprintf("%x", sha256.Sum256([]byte(s)))[:16] +} + +func TestBuildTokenResponse(t *testing.T) { + tests := []struct { + name string + provider string + accessToken string + idToken string + refreshToken string + scope string + expectedAccessToken string + expectedIDToken string + expectedRefreshToken string + }{ + { + name: "google opaque access token uses id token", + provider: "google", + accessToken: "ya29.a0ARW5m7Opaque", + idToken: "header.payload.signature", + refreshToken: "refresh-token", + scope: "openid profile email", + expectedAccessToken: "header.payload.signature", + expectedIDToken: "header.payload.signature", + expectedRefreshToken: "refresh-token", + }, + { + name: "google jwt access token preserved", + provider: "google", + accessToken: "jwt.access.token", + idToken: "id.token.value", + expectedAccessToken: "jwt.access.token", + expectedIDToken: "id.token.value", + }, + { + name: "non google access token preserved", + provider: "okta", + accessToken: "opaque-token-value", + idToken: "id.token.value", + expectedAccessToken: "opaque-token-value", + expectedIDToken: "id.token.value", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + handler := &OAuth2Handler{ + config: &OAuth2Config{Provider: tt.provider}, + logger: &defaultLogger{}, + } + + extra := map[string]interface{}{ + "id_token": tt.idToken, + } + if tt.scope != "" { + extra["scope"] = tt.scope + } + + token := (&oauth2.Token{ + AccessToken: tt.accessToken, + TokenType: "Bearer", + RefreshToken: tt.refreshToken, + Expiry: time.Now().Add(time.Hour), + }).WithExtra(extra) + + response := handler.buildTokenResponse(token) + + if got := response["access_token"]; got != tt.expectedAccessToken { + t.Fatalf("access_token hash = %s, want hash %s", tokenHash(got), tokenHash(tt.expectedAccessToken)) + } + + if got := response["id_token"]; got != tt.expectedIDToken { + t.Fatalf("id_token hash = %s, want hash %s", tokenHash(got), tokenHash(tt.expectedIDToken)) + } + + if got := response["refresh_token"]; got != tt.expectedRefreshToken && !(got == nil && tt.expectedRefreshToken == "") { + t.Fatalf("refresh_token hash = %s, want hash %s", tokenHash(got), tokenHash(tt.expectedRefreshToken)) + } + }) + } +} diff --git a/oauth.go b/oauth.go index f6afb50..586cdce 100644 --- a/oauth.go +++ b/oauth.go @@ -94,6 +94,7 @@ func (s *Server) RegisterHandlers(mux *http.ServeMux) { mux.HandleFunc("/.well-known/oauth-protected-resource", s.handler.HandleProtectedResourceMetadata) mux.HandleFunc("/.well-known/jwks.json", s.handler.HandleJWKS) mux.HandleFunc("/oauth/authorize", s.handler.HandleAuthorize) + mux.HandleFunc("/callback", s.handler.HandleCallbackRedirect) mux.HandleFunc("/oauth/callback", s.handler.HandleCallback) mux.HandleFunc("/oauth/token", s.handler.HandleToken) mux.HandleFunc("/oauth/register", s.handler.HandleRegister) @@ -277,11 +278,15 @@ func (s *Server) WrapHandler(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { authHeader := r.Header.Get("Authorization") if authHeader == "" || len(authHeader) < 7 || authHeader[:7] != "Bearer " { - s.logger.Info("OAuth: No bearer token provided, returning 401 with discovery info") + s.logger.Debug("OAuth: No bearer token provided, returning 401 with discovery info") metadataURL := s.GetProtectedResourceMetadataURL() - w.Header().Add("WWW-Authenticate", `Bearer realm="OAuth", error="invalid_token", error_description="Missing or invalid access token"`) - w.Header().Add("WWW-Authenticate", fmt.Sprintf(`resource_metadata="%s"`, metadataURL)) + w.Header().Set("WWW-Authenticate", fmt.Sprintf( + // Preserve the historical Bearer challenge here for clients that + // key off invalid_token when bootstrapping OAuth discovery. + `Bearer realm="OAuth", error="invalid_token", error_description="Missing or invalid access token", resource_metadata="%s"`, + metadataURL, + )) w.Header().Set("Content-Type", "application/json") w.WriteHeader(http.StatusUnauthorized) @@ -301,8 +306,10 @@ func (s *Server) WrapHandler(next http.Handler) http.Handler { s.logger.Info("OAuth: Token validation failed: %v", err) metadataURL := s.GetProtectedResourceMetadataURL() - w.Header().Add("WWW-Authenticate", `Bearer realm="OAuth", error="invalid_token", error_description="Authentication failed"`) - w.Header().Add("WWW-Authenticate", fmt.Sprintf(`resource_metadata="%s"`, metadataURL)) + w.Header().Set("WWW-Authenticate", fmt.Sprintf( + `Bearer realm="OAuth", error="invalid_token", error_description="Authentication failed", resource_metadata="%s"`, + metadataURL, + )) w.Header().Set("Content-Type", "application/json") w.WriteHeader(http.StatusUnauthorized) diff --git a/provider/provider.go b/provider/provider.go index ab6a20f..bfd8f91 100644 --- a/provider/provider.go +++ b/provider/provider.go @@ -3,8 +3,12 @@ package provider import ( "context" "crypto/tls" + "encoding/json" + "errors" "fmt" + "io" "net/http" + "net/url" "strings" "sync" "time" @@ -55,13 +59,17 @@ type HMACValidator struct { // OIDCValidator validates JWT tokens using OIDC/JWKS (Okta, Google, Azure) type OIDCValidator struct { - verifier *oidc.IDTokenVerifier - provider *oidc.Provider - audience string - TokenValidators []func(claims jwt.MapClaims) error - logger Logger + verifier *oidc.IDTokenVerifier + provider *oidc.Provider + audience string + providerName string + skipAudienceCheck bool + TokenValidators []func(claims jwt.MapClaims) error + logger Logger } +var googleTokenInfoURL = "https://oauth2.googleapis.com/tokeninfo" + // Initialize sets up the HMAC validator with JWT secret and audience func (v *HMACValidator) Initialize(cfg *Config) error { v.secretOnce.Do(func() { @@ -180,6 +188,8 @@ func (v *OIDCValidator) Initialize(cfg *Config) error { v.logger = &noOpLogger{} } v.audience = cfg.Audience + v.providerName = strings.ToLower(cfg.Provider) + v.skipAudienceCheck = cfg.SkipAudienceCheck // Use standard library context with timeout ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second) @@ -233,14 +243,22 @@ func (v *OIDCValidator) Initialize(cfg *Config) error { func (v *OIDCValidator) ValidateToken(ctx context.Context, tokenString string) (*User, error) { // Remove Bearer prefix if present tokenString = strings.TrimPrefix(tokenString, "Bearer ") + tokenString = strings.TrimSpace(tokenString) // Use incoming context with timeout for OIDC provider call ctx, cancel := context.WithTimeout(ctx, 10*time.Second) defer cancel() + if v.shouldUseGoogleOpaqueFallback(tokenString, nil) { + return v.validateGoogleOpaqueToken(ctx, tokenString) + } + // go-oidc handles RSA signature validation, JWKS fetching, and key rotation idToken, err := v.verifier.Verify(ctx, tokenString) if err != nil { + if v.shouldUseGoogleOpaqueFallback(tokenString, err) { + return v.validateGoogleOpaqueToken(ctx, tokenString) + } return nil, fmt.Errorf("token verification failed: %w", err) } @@ -270,11 +288,8 @@ func (v *OIDCValidator) ValidateToken(ctx context.Context, tokenString string) ( } // Run extra validation functions - for i, fn := range v.TokenValidators { - err := fn(rawClaims) - if err != nil { - return nil, fmt.Errorf("validation function %d failed with error: %w", i, err) - } + if err := v.runTokenValidators(rawClaims); err != nil { + return nil, err } user := &User{ @@ -293,6 +308,103 @@ func (v *OIDCValidator) ValidateToken(ctx context.Context, tokenString string) ( return user, nil } +func (v *OIDCValidator) validateGoogleOpaqueToken(ctx context.Context, tokenString string) (*User, error) { + endpoint := fmt.Sprintf("%s?access_token=%s", googleTokenInfoURL, url.QueryEscape(tokenString)) + + req, err := http.NewRequestWithContext(ctx, http.MethodGet, endpoint, nil) + if err != nil { + return nil, fmt.Errorf("failed to create google tokeninfo request: %w", err) + } + + client := &http.Client{Timeout: 10 * time.Second} + resp, err := client.Do(req) + if err != nil { + var urlErr *url.Error + if errors.As(err, &urlErr) { + err = urlErr.Err + } + return nil, fmt.Errorf("google tokeninfo request failed: %w", err) + } + defer func() { _ = resp.Body.Close() }() + + body, err := io.ReadAll(resp.Body) + if err != nil { + return nil, fmt.Errorf("failed reading google tokeninfo response: %w", err) + } + + if resp.StatusCode != http.StatusOK { + return nil, fmt.Errorf("google tokeninfo validation failed: status %d", resp.StatusCode) + } + + var claims jwt.MapClaims + if err := json.Unmarshal(body, &claims); err != nil { + return nil, fmt.Errorf("failed parsing google tokeninfo response: %w", err) + } + + if err := v.runTokenValidators(claims); err != nil { + return nil, err + } + + return v.userFromGoogleTokenInfoClaims(claims) +} + +func (v *OIDCValidator) userFromGoogleTokenInfoClaims(claims jwt.MapClaims) (*User, error) { + subject, _ := claims["sub"].(string) + if subject == "" { + subject, _ = claims["user_id"].(string) + } + if subject == "" { + return nil, fmt.Errorf("missing subject in token") + } + + email, _ := claims["email"].(string) + username := email + if username == "" { + username = subject + } + + return &User{ + Subject: subject, + Username: username, + Email: email, + }, nil +} + +func (v *OIDCValidator) runTokenValidators(claims jwt.MapClaims) error { + for i, fn := range v.TokenValidators { + if err := fn(claims); err != nil { + return fmt.Errorf("validation function %d failed with error: %w", i, err) + } + } + return nil +} + +func (v *OIDCValidator) shouldUseGoogleOpaqueFallback(token string, err error) bool { + if v.providerName != "google" || !isGoogleOpaqueCandidate(token) { + return false + } + if err == nil { + return !looksLikeJWT(token) + } + return isMalformedJWTError(err) +} + +func isGoogleOpaqueCandidate(token string) bool { + return strings.HasPrefix(token, "ya29.") +} + +func looksLikeJWT(token string) bool { + return strings.Count(token, ".") == 2 +} + +func isMalformedJWTError(err error) bool { + if err == nil { + return false + } + msg := err.Error() + return strings.Contains(msg, "malformed jwt") || strings.Contains(msg, "compact JWS format must have three parts") +} + // validateAudience validates the audience claim matches the expected value for OIDC tokens func (v *OIDCValidator) validateAudience(claims jwt.MapClaims) error { // Extract audience claim (can be string or []string) diff --git a/provider/provider_test.go b/provider/provider_test.go index 7f0bd25..4026a6d 100644 --- a/provider/provider_test.go +++ b/provider/provider_test.go @@ -2,6 +2,8 @@ package provider import ( "context" + "net/http" + "net/http/httptest" "strings" "testing" "time" @@ -371,3 +373,298 @@ func TestOIDCValidator_InitializationValidation(t *testing.T) { }) } } + +func TestOIDCValidator_GoogleTokenInfoClaimsValidation(t *testing.T) { + tests := []struct { + name string + validator *OIDCValidator + claims jwt.MapClaims + expectErr bool + errContains string + expectedSubject string + expectedUsername string + expectedEmail string + }{ + { + name: "valid claims with email", + validator: &OIDCValidator{ + audience: "my-client-id.apps.googleusercontent.com", + skipAudienceCheck: false, + }, + claims: jwt.MapClaims{ + "aud": "my-client-id.apps.googleusercontent.com", + "sub": "user-123", + "email": "user@example.com", + }, + expectedSubject: "user-123", + expectedUsername: "user@example.com", + expectedEmail: "user@example.com", + }, + { + name: "valid claims without email", + validator: &OIDCValidator{ + audience: "my-client-id.apps.googleusercontent.com", + skipAudienceCheck: false, + }, + claims: jwt.MapClaims{ + "aud": "my-client-id.apps.googleusercontent.com", + "sub": "user-456", + }, + expectedSubject: "user-456", + expectedUsername: "user-456", + expectedEmail: "", + }, + { + name: "valid tokeninfo access token claims with user_id fallback", + validator: &OIDCValidator{ + audience: "my-client-id.apps.googleusercontent.com", + skipAudienceCheck: false, + }, + claims: jwt.MapClaims{ + "aud": "my-client-id.apps.googleusercontent.com", + "user_id": "google-user-123", + "email": "user@example.com", + "scope": "openid email profile", + "issued_to": "my-client-id.apps.googleusercontent.com", + }, + expectedSubject: "google-user-123", + expectedUsername: "user@example.com", + expectedEmail: "user@example.com", + }, + { + name: "missing subject", + validator: &OIDCValidator{ + audience: "my-client-id.apps.googleusercontent.com", + skipAudienceCheck: false, + }, + claims: jwt.MapClaims{ + "aud": "my-client-id.apps.googleusercontent.com", + }, + expectErr: true, + errContains: "missing subject", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + user, err := tt.validator.userFromGoogleTokenInfoClaims(tt.claims) + + if tt.expectErr { + if err == nil { + t.Fatalf("expected error, got nil") + } + if tt.errContains != "" && !strings.Contains(err.Error(), tt.errContains) { + t.Fatalf("expected error containing %q, got %q", tt.errContains, err.Error()) + } + return + } + + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + if user.Subject != tt.expectedSubject { + t.Fatalf("expected subject %q, got %q", tt.expectedSubject, user.Subject) + } + + if user.Username != tt.expectedUsername { + t.Fatalf("expected username %q, got %q", tt.expectedUsername, user.Username) + } + + if user.Email != tt.expectedEmail { + t.Fatalf("expected email %q, got %q", tt.expectedEmail, user.Email) + } + }) + } +} + +func TestOIDCValidator_ValidateGoogleOpaqueTokenRunsTokenValidators(t *testing.T) { + originalURL := googleTokenInfoURL + t.Cleanup(func() { + googleTokenInfoURL = originalURL + }) + + requestErrCh := make(chan string, 1) + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if got := r.URL.Query().Get("access_token"); got != "opaque-token" { + requestErrCh <- got + http.Error(w, "unexpected access_token", http.StatusBadRequest) + return + } + w.Header().Set("Content-Type", "application/json") + _, _ = w.Write([]byte(`{"aud":"my-client-id.apps.googleusercontent.com","user_id":"google-user-123","email":"user@example.com"}`)) + })) + defer server.Close() + + googleTokenInfoURL = server.URL + + t.Run("validator success", func(t *testing.T) { + called := false + validator := &OIDCValidator{ + TokenValidators: []func(claims jwt.MapClaims) error{ + func(claims jwt.MapClaims) error { + called = true + if claims["aud"] != "my-client-id.apps.googleusercontent.com" { + t.Fatalf("unexpected audience claim: %v", claims["aud"]) + } + return nil + }, + }, + } + + user, err := validator.validateGoogleOpaqueToken(context.Background(), "opaque-token") + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + select { + case got := <-requestErrCh: + t.Fatalf("expected access_token query parameter %q, got %q", "opaque-token", got) + default: + } + if !called { + t.Fatal("expected token validators to be called") + } + if user.Subject != "google-user-123" { + t.Fatalf("expected subject %q, got %q", "google-user-123", user.Subject) + } + }) + + t.Run("validator failure", func(t *testing.T) { + validator := &OIDCValidator{ + TokenValidators: []func(claims jwt.MapClaims) error{ + func(claims jwt.MapClaims) error { + return validatorError("audience rejected") + }, + }, + } + + _, err := validator.validateGoogleOpaqueToken(context.Background(), "opaque-token") + if err == nil { + t.Fatal("expected error, got nil") + } + select { + case got := <-requestErrCh: + t.Fatalf("expected access_token query parameter %q, got %q", "opaque-token", got) + default: + } + if !strings.Contains(err.Error(), "validation function 0 failed with error: audience rejected") { + t.Fatalf("unexpected error: %v", err) + } + }) +} + +func TestOIDCValidator_ValidateGoogleOpaqueTokenAudienceHandling(t *testing.T) { + originalURL := googleTokenInfoURL + t.Cleanup(func() { + googleTokenInfoURL = originalURL + }) + + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "application/json") + _, _ = w.Write([]byte(`{"aud":"different-audience","user_id":"google-user-123","email":"user@example.com"}`)) + })) + defer server.Close() + + googleTokenInfoURL = server.URL + + t.Run("skip audience check", func(t *testing.T) { + validator := &OIDCValidator{ + audience: "my-client-id.apps.googleusercontent.com", + skipAudienceCheck: true, + } + + user, err := validator.validateGoogleOpaqueToken(context.Background(), "opaque-token") + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if user.Subject != "google-user-123" { + t.Fatalf("expected subject %q, got %q", "google-user-123", user.Subject) + } + }) + + t.Run("enforce audience check", func(t *testing.T) { + validator := &OIDCValidator{ + audience: "my-client-id.apps.googleusercontent.com", + TokenValidators: []func(claims jwt.MapClaims) error{}, + } + validator.TokenValidators = append(validator.TokenValidators, validator.validateAudience) + + _, err := validator.validateGoogleOpaqueToken(context.Background(), "opaque-token") + if err == nil { + t.Fatal("expected error, got nil") + } + if !strings.Contains(err.Error(), "invalid audience: expected my-client-id.apps.googleusercontent.com, got different-audience") { + t.Fatalf("unexpected error: %v", err) + } + }) +} + +func TestOIDCValidator_ValidateGoogleOpaqueTokenNetworkErrorDoesNotLeakToken(t *testing.T) { + originalURL := googleTokenInfoURL + t.Cleanup(func() { + googleTokenInfoURL = originalURL + }) + + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {})) + googleTokenInfoURL = server.URL + server.Close() + + token := "ya29.sensitive-access-token" + validator := &OIDCValidator{} + + _, err := validator.validateGoogleOpaqueToken(context.Background(), token) + if err == nil { + t.Fatal("expected error, got nil") + } + if strings.Contains(err.Error(), token) { + t.Fatalf("expected sanitized error without token, got: %v", err) + } +} + +func TestOIDCValidator_ShouldUseGoogleOpaqueFallback(t *testing.T) { + validator := &OIDCValidator{providerName: "google"} + + tests := []struct { + name string + token string + err error + want bool + }{ + { + name: "google opaque token candidate without verifier error", + token: "ya29.a0ARW5m7Opaque", + want: true, + }, + { + name: "non google opaque token is rejected", + token: "opaque-token", + want: false, + }, + { + name: "malformed jwt error only falls back for google candidate", + token: "ya29.a0ARW5m7Opaque", + err: validatorError("malformed jwt"), + want: true, + }, + { + name: "malformed jwt error does not fall back for arbitrary token", + token: "opaque-token", + err: validatorError("malformed jwt"), + want: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if got := validator.shouldUseGoogleOpaqueFallback(tt.token, tt.err); got != tt.want { + t.Fatalf("shouldUseGoogleOpaqueFallback(%q, %v) = %v, want %v", tt.token, tt.err, got, tt.want) + } + }) + } +} + +type validatorError string + +func (e validatorError) Error() string { + return string(e) +}