Skip to content
Open
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions api_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -286,6 +286,10 @@ func TestServerWrapHandler(t *testing.T) {
t.Errorf("Expected WWW-Authenticate header with error, got: %s", authHeader)
}

if !strings.Contains(authHeader, "resource_metadata=") {
t.Errorf("Expected WWW-Authenticate header to include resource_metadata, got: %s", authHeader)
}

if !strings.Contains(w.Body.String(), "invalid_token") {
t.Errorf("Expected JSON error response, got: %s", w.Body.String())
}
Expand Down
46 changes: 33 additions & 13 deletions handlers.go
Original file line number Diff line number Diff line change
Expand Up @@ -610,37 +610,57 @@ func (h *OAuth2Handler) HandleToken(w http.ResponseWriter, r *http.Request) {

h.logger.Info("OAuth2: Token exchange successful")

// Build response
response := h.buildTokenResponse(token)

// Send response
w.Header().Set("Content-Type", "application/json")
w.Header().Set("Cache-Control", "no-store")
w.Header().Set("Pragma", "no-cache")
w.WriteHeader(http.StatusOK)

if err := json.NewEncoder(w).Encode(response); err != nil {
h.logger.Error("OAuth2: Failed to encode token response: %v", err)
}
}

// buildTokenResponse builds an RFC 6749-style token response with provider-specific behavior.
func (h *OAuth2Handler) buildTokenResponse(token *oauth2.Token) map[string]interface{} {
accessToken := token.AccessToken

if h.config != nil && h.config.Provider == "google" {
if idToken, ok := token.Extra("id_token").(string); ok && idToken != "" {
if !looksLikeJWT(token.AccessToken) {
if h.logger != nil {
h.logger.Info("OAuth2: Google provider detected opaque access token, using id_token as access_token for downstream compatibility")
}
accessToken = idToken
}
}
}

response := map[string]interface{}{
"access_token": token.AccessToken,
"access_token": accessToken,
"token_type": token.TokenType,
"expires_in": int(time.Until(token.Expiry).Seconds()),
}

// Add optional fields
if token.RefreshToken != "" {
response["refresh_token"] = token.RefreshToken
}

// Add ID token if present
if idToken, ok := token.Extra("id_token").(string); ok {
response["id_token"] = idToken
}

// Add scope if present
if scope, ok := token.Extra("scope").(string); ok {
response["scope"] = scope
}

// Send response
w.Header().Set("Content-Type", "application/json")
w.Header().Set("Cache-Control", "no-store")
w.Header().Set("Pragma", "no-cache")
w.WriteHeader(http.StatusOK)
return response
}

if err := json.NewEncoder(w).Encode(response); err != nil {
h.logger.Error("OAuth2: Failed to encode token response: %v", err)
}
func looksLikeJWT(token string) bool {
return strings.Count(token, ".") == 2
}

// showSuccessPage displays a success page after OAuth completion
Expand Down
83 changes: 83 additions & 0 deletions handlers_token_response_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,83 @@
package oauth

import (
"testing"
"time"

"golang.org/x/oauth2"
)

func TestBuildTokenResponseGoogleOpaqueAccessTokenUsesIDToken(t *testing.T) {
handler := &OAuth2Handler{
config: &OAuth2Config{Provider: "google"},
logger: &defaultLogger{},
}

idToken := "header.payload.signature"
token := (&oauth2.Token{
AccessToken: "ya29.a0ARW5m7Opaque",
TokenType: "Bearer",
RefreshToken: "refresh-token",
Expiry: time.Now().Add(time.Hour),
}).WithExtra(map[string]interface{}{
"id_token": idToken,
"scope": "openid profile email",
})

response := handler.buildTokenResponse(token)

if response["access_token"] != idToken {
t.Fatalf("expected access_token to be mapped to id_token for Google opaque token")
}

if response["id_token"] != idToken {
t.Fatalf("expected id_token in response")
}

if response["refresh_token"] != "refresh-token" {
t.Fatalf("expected refresh_token in response")
}
}

func TestBuildTokenResponseGoogleJWTAccessTokenPreserved(t *testing.T) {
handler := &OAuth2Handler{
config: &OAuth2Config{Provider: "google"},
logger: &defaultLogger{},
}

jwtAccessToken := "jwt.access.token"
token := (&oauth2.Token{
AccessToken: jwtAccessToken,
TokenType: "Bearer",
Expiry: time.Now().Add(time.Hour),
}).WithExtra(map[string]interface{}{
"id_token": "id.token.value",
})

response := handler.buildTokenResponse(token)

if response["access_token"] != jwtAccessToken {
t.Fatalf("expected JWT access_token to remain unchanged")
}
}

func TestBuildTokenResponseNonGoogleAccessTokenPreserved(t *testing.T) {
handler := &OAuth2Handler{
config: &OAuth2Config{Provider: "okta"},
logger: &defaultLogger{},
}

token := (&oauth2.Token{
AccessToken: "opaque-token-value",
TokenType: "Bearer",
Expiry: time.Now().Add(time.Hour),
}).WithExtra(map[string]interface{}{
"id_token": "id.token.value",
})

response := handler.buildTokenResponse(token)

if response["access_token"] != "opaque-token-value" {
t.Fatalf("expected non-Google access_token to remain unchanged")
}
}
14 changes: 9 additions & 5 deletions oauth.go
Original file line number Diff line number Diff line change
Expand Up @@ -263,11 +263,13 @@ func (s *Server) WrapHandler(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
authHeader := r.Header.Get("Authorization")
if authHeader == "" || len(authHeader) < 7 || authHeader[:7] != "Bearer " {
s.logger.Info("OAuth: No bearer token provided, returning 401 with discovery info")
s.logger.Debug("OAuth: No bearer token provided, returning 401 with discovery info")

metadataURL := s.GetProtectedResourceMetadataURL()
w.Header().Add("WWW-Authenticate", `Bearer realm="OAuth", error="invalid_token", error_description="Missing or invalid access token"`)
w.Header().Add("WWW-Authenticate", fmt.Sprintf(`resource_metadata="%s"`, metadataURL))
w.Header().Set("WWW-Authenticate", fmt.Sprintf(
`Bearer realm="OAuth", error="invalid_token", error_description="Missing or invalid access token", resource_metadata="%s"`,
metadataURL,
))
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(http.StatusUnauthorized)

Expand All @@ -287,8 +289,10 @@ func (s *Server) WrapHandler(next http.Handler) http.Handler {
s.logger.Info("OAuth: Token validation failed: %v", err)

metadataURL := s.GetProtectedResourceMetadataURL()
w.Header().Add("WWW-Authenticate", `Bearer realm="OAuth", error="invalid_token", error_description="Authentication failed"`)
w.Header().Add("WWW-Authenticate", fmt.Sprintf(`resource_metadata="%s"`, metadataURL))
w.Header().Set("WWW-Authenticate", fmt.Sprintf(
`Bearer realm="OAuth", error="invalid_token", error_description="Authentication failed", resource_metadata="%s"`,
metadataURL,
))
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(http.StatusUnauthorized)

Expand Down
100 changes: 95 additions & 5 deletions provider/provider.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,11 @@ package provider
import (
"context"
"crypto/tls"
"encoding/json"
"fmt"
"io"
"net/http"
"net/url"
"strings"
"sync"
"time"
Expand Down Expand Up @@ -54,13 +57,17 @@ type HMACValidator struct {

// OIDCValidator validates JWT tokens using OIDC/JWKS (Okta, Google, Azure)
type OIDCValidator struct {
verifier *oidc.IDTokenVerifier
provider *oidc.Provider
audience string
TokenValidators []func(claims jwt.MapClaims) error
logger Logger
verifier *oidc.IDTokenVerifier
provider *oidc.Provider
audience string
providerName string
skipAudienceCheck bool
TokenValidators []func(claims jwt.MapClaims) error
logger Logger
}

var googleTokenInfoURL = "https://oauth2.googleapis.com/tokeninfo"

// Initialize sets up the HMAC validator with JWT secret and audience
func (v *HMACValidator) Initialize(cfg *Config) error {
v.secretOnce.Do(func() {
Expand Down Expand Up @@ -173,6 +180,8 @@ func (v *OIDCValidator) Initialize(cfg *Config) error {
v.logger = &noOpLogger{}
}
v.audience = cfg.Audience
v.providerName = strings.ToLower(cfg.Provider)
v.skipAudienceCheck = cfg.SkipAudienceCheck

// Use standard library context with timeout
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
Expand Down Expand Up @@ -226,14 +235,22 @@ func (v *OIDCValidator) Initialize(cfg *Config) error {
func (v *OIDCValidator) ValidateToken(ctx context.Context, tokenString string) (*User, error) {
// Remove Bearer prefix if present
tokenString = strings.TrimPrefix(tokenString, "Bearer ")
tokenString = strings.TrimSpace(tokenString)

// Use incoming context with timeout for OIDC provider call
ctx, cancel := context.WithTimeout(ctx, 10*time.Second)
defer cancel()

if v.providerName == "google" && !looksLikeJWT(tokenString) {
return v.validateGoogleOpaqueToken(ctx, tokenString)
Comment thread
coderabbitai[bot] marked this conversation as resolved.
Outdated
}

// go-oidc handles RSA signature validation, JWKS fetching, and key rotation
idToken, err := v.verifier.Verify(ctx, tokenString)
if err != nil {
if v.providerName == "google" && isMalformedJWTError(err) {
return v.validateGoogleOpaqueToken(ctx, tokenString)
}
return nil, fmt.Errorf("token verification failed: %w", err)
}

Expand Down Expand Up @@ -277,6 +294,79 @@ func (v *OIDCValidator) ValidateToken(ctx context.Context, tokenString string) (
}, nil
}

func (v *OIDCValidator) validateGoogleOpaqueToken(ctx context.Context, tokenString string) (*User, error) {
endpoint := fmt.Sprintf("%s?access_token=%s", googleTokenInfoURL, url.QueryEscape(tokenString))

req, err := http.NewRequestWithContext(ctx, http.MethodGet, endpoint, nil)
if err != nil {
return nil, fmt.Errorf("failed to create google tokeninfo request: %w", err)
}

client := &http.Client{Timeout: 10 * time.Second}
resp, err := client.Do(req)
if err != nil {
return nil, fmt.Errorf("google tokeninfo request failed: %w", err)
Comment thread
coderabbitai[bot] marked this conversation as resolved.
}
defer func() { _ = resp.Body.Close() }()

body, err := io.ReadAll(resp.Body)
if err != nil {
return nil, fmt.Errorf("failed reading google tokeninfo response: %w", err)
}

if resp.StatusCode != http.StatusOK {
return nil, fmt.Errorf("google tokeninfo validation failed: status %d", resp.StatusCode)
}

var claims map[string]interface{}
if err := json.Unmarshal(body, &claims); err != nil {
return nil, fmt.Errorf("failed parsing google tokeninfo response: %w", err)
}

return v.userFromGoogleTokenInfoClaims(claims)
}

func (v *OIDCValidator) userFromGoogleTokenInfoClaims(claims map[string]interface{}) (*User, error) {
aud, _ := claims["aud"].(string)
if !v.skipAudienceCheck {
if aud == "" {
return nil, fmt.Errorf("missing audience claim")
}
if aud != v.audience {
return nil, fmt.Errorf("invalid audience: expected %s, got %s", v.audience, aud)
}
}

subject, _ := claims["sub"].(string)
if subject == "" {
return nil, fmt.Errorf("missing subject in token")
}
Comment thread
coderabbitai[bot] marked this conversation as resolved.
Outdated

email, _ := claims["email"].(string)
username := email
if username == "" {
username = subject
}

return &User{
Subject: subject,
Username: username,
Email: email,
}, nil
}

func looksLikeJWT(token string) bool {
return strings.Count(token, ".") == 2
}

func isMalformedJWTError(err error) bool {
if err == nil {
return false
}
msg := err.Error()
return strings.Contains(msg, "malformed jwt") || strings.Contains(msg, "compact JWS format must have three parts")
}

// validateAudience validates the audience claim matches the expected value for OIDC tokens
func (v *OIDCValidator) validateAudience(claims jwt.MapClaims) error {
// Extract audience claim (can be string or []string)
Expand Down
Loading