Skip to content
Open
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 6 additions & 2 deletions api_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -282,11 +282,15 @@ func TestServerWrapHandler(t *testing.T) {
}

authHeader := w.Header().Get("WWW-Authenticate")
if !strings.Contains(authHeader, "invalid_token") {
if !strings.Contains(authHeader, "invalid_request") {
t.Errorf("Expected WWW-Authenticate header with error, got: %s", authHeader)
}

if !strings.Contains(w.Body.String(), "invalid_token") {
if !strings.Contains(authHeader, "resource_metadata=") {
t.Errorf("Expected WWW-Authenticate header to include resource_metadata, got: %s", authHeader)
}

if !strings.Contains(w.Body.String(), "invalid_request") {
t.Errorf("Expected JSON error response, got: %s", w.Body.String())
}
})
Expand Down
46 changes: 33 additions & 13 deletions handlers.go
Original file line number Diff line number Diff line change
Expand Up @@ -610,37 +610,57 @@ func (h *OAuth2Handler) HandleToken(w http.ResponseWriter, r *http.Request) {

h.logger.Info("OAuth2: Token exchange successful")

// Build response
response := h.buildTokenResponse(token)

// Send response
w.Header().Set("Content-Type", "application/json")
w.Header().Set("Cache-Control", "no-store")
w.Header().Set("Pragma", "no-cache")
w.WriteHeader(http.StatusOK)

if err := json.NewEncoder(w).Encode(response); err != nil {
h.logger.Error("OAuth2: Failed to encode token response: %v", err)
}
}

// buildTokenResponse builds an RFC 6749-style token response with provider-specific behavior.
func (h *OAuth2Handler) buildTokenResponse(token *oauth2.Token) map[string]interface{} {
accessToken := token.AccessToken

if h.config != nil && h.config.Provider == "google" {
if idToken, ok := token.Extra("id_token").(string); ok && idToken != "" {
if !looksLikeJWT(token.AccessToken) {
if h.logger != nil {
h.logger.Info("OAuth2: Google provider detected opaque access token, using id_token as access_token for downstream compatibility")
}
accessToken = idToken
}
}
}

response := map[string]interface{}{
"access_token": token.AccessToken,
"access_token": accessToken,
"token_type": token.TokenType,
"expires_in": int(time.Until(token.Expiry).Seconds()),
}

// Add optional fields
if token.RefreshToken != "" {
response["refresh_token"] = token.RefreshToken
}

// Add ID token if present
if idToken, ok := token.Extra("id_token").(string); ok {
response["id_token"] = idToken
}

// Add scope if present
if scope, ok := token.Extra("scope").(string); ok {
response["scope"] = scope
}

// Send response
w.Header().Set("Content-Type", "application/json")
w.Header().Set("Cache-Control", "no-store")
w.Header().Set("Pragma", "no-cache")
w.WriteHeader(http.StatusOK)
return response
}

if err := json.NewEncoder(w).Encode(response); err != nil {
h.logger.Error("OAuth2: Failed to encode token response: %v", err)
}
func looksLikeJWT(token string) bool {
return strings.Count(token, ".") == 2
}

// showSuccessPage displays a success page after OAuth completion
Expand Down
97 changes: 97 additions & 0 deletions handlers_token_response_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,97 @@
package oauth

import (
"crypto/sha256"
"fmt"
"testing"
"time"

"golang.org/x/oauth2"
)

func tokenHash(v interface{}) string {
s, _ := v.(string)
if s == "" {
return "<empty>"
}
return fmt.Sprintf("%x", sha256.Sum256([]byte(s)))[:16]
}

func TestBuildTokenResponse(t *testing.T) {
tests := []struct {
name string
provider string
accessToken string
idToken string
refreshToken string
scope string
expectedAccessToken string
expectedIDToken string
expectedRefreshToken string
}{
{
name: "google opaque access token uses id token",
provider: "google",
accessToken: "ya29.a0ARW5m7Opaque",
idToken: "header.payload.signature",
refreshToken: "refresh-token",
scope: "openid profile email",
expectedAccessToken: "header.payload.signature",
expectedIDToken: "header.payload.signature",
expectedRefreshToken: "refresh-token",
},
{
name: "google jwt access token preserved",
provider: "google",
accessToken: "jwt.access.token",
idToken: "id.token.value",
expectedAccessToken: "jwt.access.token",
expectedIDToken: "id.token.value",
},
{
name: "non google access token preserved",
provider: "okta",
accessToken: "opaque-token-value",
idToken: "id.token.value",
expectedAccessToken: "opaque-token-value",
expectedIDToken: "id.token.value",
},
}

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
handler := &OAuth2Handler{
config: &OAuth2Config{Provider: tt.provider},
logger: &defaultLogger{},
}

extra := map[string]interface{}{
"id_token": tt.idToken,
}
if tt.scope != "" {
extra["scope"] = tt.scope
}

token := (&oauth2.Token{
AccessToken: tt.accessToken,
TokenType: "Bearer",
RefreshToken: tt.refreshToken,
Expiry: time.Now().Add(time.Hour),
}).WithExtra(extra)

response := handler.buildTokenResponse(token)

if got := response["access_token"]; got != tt.expectedAccessToken {
t.Fatalf("access_token hash = %s, want hash %s", tokenHash(got), tokenHash(tt.expectedAccessToken))
}

if got := response["id_token"]; got != tt.expectedIDToken {
t.Fatalf("id_token hash = %s, want hash %s", tokenHash(got), tokenHash(tt.expectedIDToken))
}

if got := response["refresh_token"]; got != tt.expectedRefreshToken && !(got == nil && tt.expectedRefreshToken == "") {
t.Fatalf("refresh_token hash = %s, want hash %s", tokenHash(got), tokenHash(tt.expectedRefreshToken))
}
Comment thread
coderabbitai[bot] marked this conversation as resolved.
})
}
}
18 changes: 11 additions & 7 deletions oauth.go
Original file line number Diff line number Diff line change
Expand Up @@ -263,17 +263,19 @@ func (s *Server) WrapHandler(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
authHeader := r.Header.Get("Authorization")
if authHeader == "" || len(authHeader) < 7 || authHeader[:7] != "Bearer " {
s.logger.Info("OAuth: No bearer token provided, returning 401 with discovery info")
s.logger.Debug("OAuth: No bearer token provided, returning 401 with discovery info")

metadataURL := s.GetProtectedResourceMetadataURL()
w.Header().Add("WWW-Authenticate", `Bearer realm="OAuth", error="invalid_token", error_description="Missing or invalid access token"`)
w.Header().Add("WWW-Authenticate", fmt.Sprintf(`resource_metadata="%s"`, metadataURL))
w.Header().Set("WWW-Authenticate", fmt.Sprintf(
`Bearer realm="OAuth", error="invalid_request", error_description="Bearer token required", resource_metadata="%s"`,
metadataURL,
))
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(http.StatusUnauthorized)

if err := json.NewEncoder(w).Encode(oauthErrorResponse{
Error: "invalid_token",
ErrorDescription: "Missing or invalid access token",
Error: "invalid_request",
ErrorDescription: "Bearer token required",
}); err != nil {
s.logger.Error("Error encoding OAuth error response: %v", err)
}
Expand All @@ -287,8 +289,10 @@ func (s *Server) WrapHandler(next http.Handler) http.Handler {
s.logger.Info("OAuth: Token validation failed: %v", err)

metadataURL := s.GetProtectedResourceMetadataURL()
w.Header().Add("WWW-Authenticate", `Bearer realm="OAuth", error="invalid_token", error_description="Authentication failed"`)
w.Header().Add("WWW-Authenticate", fmt.Sprintf(`resource_metadata="%s"`, metadataURL))
w.Header().Set("WWW-Authenticate", fmt.Sprintf(
`Bearer realm="OAuth", error="invalid_token", error_description="Authentication failed", resource_metadata="%s"`,
metadataURL,
))
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(http.StatusUnauthorized)

Expand Down
113 changes: 103 additions & 10 deletions provider/provider.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,11 @@ package provider
import (
"context"
"crypto/tls"
"encoding/json"
"fmt"
"io"
"net/http"
"net/url"
"strings"
"sync"
"time"
Expand Down Expand Up @@ -54,13 +57,17 @@ type HMACValidator struct {

// OIDCValidator validates JWT tokens using OIDC/JWKS (Okta, Google, Azure)
type OIDCValidator struct {
verifier *oidc.IDTokenVerifier
provider *oidc.Provider
audience string
TokenValidators []func(claims jwt.MapClaims) error
logger Logger
verifier *oidc.IDTokenVerifier
provider *oidc.Provider
audience string
providerName string
skipAudienceCheck bool
TokenValidators []func(claims jwt.MapClaims) error
logger Logger
}

var googleTokenInfoURL = "https://oauth2.googleapis.com/tokeninfo"

// Initialize sets up the HMAC validator with JWT secret and audience
func (v *HMACValidator) Initialize(cfg *Config) error {
v.secretOnce.Do(func() {
Expand Down Expand Up @@ -173,6 +180,8 @@ func (v *OIDCValidator) Initialize(cfg *Config) error {
v.logger = &noOpLogger{}
}
v.audience = cfg.Audience
v.providerName = strings.ToLower(cfg.Provider)
v.skipAudienceCheck = cfg.SkipAudienceCheck

// Use standard library context with timeout
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
Expand Down Expand Up @@ -226,14 +235,22 @@ func (v *OIDCValidator) Initialize(cfg *Config) error {
func (v *OIDCValidator) ValidateToken(ctx context.Context, tokenString string) (*User, error) {
// Remove Bearer prefix if present
tokenString = strings.TrimPrefix(tokenString, "Bearer ")
tokenString = strings.TrimSpace(tokenString)

// Use incoming context with timeout for OIDC provider call
ctx, cancel := context.WithTimeout(ctx, 10*time.Second)
defer cancel()

if v.providerName == "google" && !looksLikeJWT(tokenString) {
return v.validateGoogleOpaqueToken(ctx, tokenString)
Comment thread
coderabbitai[bot] marked this conversation as resolved.
Outdated
}

// go-oidc handles RSA signature validation, JWKS fetching, and key rotation
idToken, err := v.verifier.Verify(ctx, tokenString)
if err != nil {
if v.providerName == "google" && isMalformedJWTError(err) {
return v.validateGoogleOpaqueToken(ctx, tokenString)
}
return nil, fmt.Errorf("token verification failed: %w", err)
}

Expand Down Expand Up @@ -263,11 +280,8 @@ func (v *OIDCValidator) ValidateToken(ctx context.Context, tokenString string) (
}

// Run extra validation functions
for i, fn := range v.TokenValidators {
err := fn(rawClaims)
if err != nil {
return nil, fmt.Errorf("validation function %d failed with error: %w", i, err)
}
if err := v.runTokenValidators(rawClaims); err != nil {
return nil, err
}

return &User{
Expand All @@ -277,6 +291,85 @@ func (v *OIDCValidator) ValidateToken(ctx context.Context, tokenString string) (
}, nil
}

func (v *OIDCValidator) validateGoogleOpaqueToken(ctx context.Context, tokenString string) (*User, error) {
endpoint := fmt.Sprintf("%s?access_token=%s", googleTokenInfoURL, url.QueryEscape(tokenString))

req, err := http.NewRequestWithContext(ctx, http.MethodGet, endpoint, nil)
if err != nil {
return nil, fmt.Errorf("failed to create google tokeninfo request: %w", err)
}

client := &http.Client{Timeout: 10 * time.Second}
resp, err := client.Do(req)
if err != nil {
return nil, fmt.Errorf("google tokeninfo request failed: %w", err)
Comment thread
coderabbitai[bot] marked this conversation as resolved.
}
defer func() { _ = resp.Body.Close() }()

body, err := io.ReadAll(resp.Body)
if err != nil {
return nil, fmt.Errorf("failed reading google tokeninfo response: %w", err)
}

if resp.StatusCode != http.StatusOK {
return nil, fmt.Errorf("google tokeninfo validation failed: status %d", resp.StatusCode)
}

var claims jwt.MapClaims
if err := json.Unmarshal(body, &claims); err != nil {
return nil, fmt.Errorf("failed parsing google tokeninfo response: %w", err)
}

if err := v.runTokenValidators(claims); err != nil {
return nil, err
}

return v.userFromGoogleTokenInfoClaims(claims)
}

func (v *OIDCValidator) userFromGoogleTokenInfoClaims(claims jwt.MapClaims) (*User, error) {
subject, _ := claims["sub"].(string)
if subject == "" {
subject, _ = claims["user_id"].(string)
}
if subject == "" {
return nil, fmt.Errorf("missing subject in token")
}

email, _ := claims["email"].(string)
username := email
if username == "" {
username = subject
}

return &User{
Subject: subject,
Username: username,
Email: email,
}, nil
}

func (v *OIDCValidator) runTokenValidators(claims jwt.MapClaims) error {
for i, fn := range v.TokenValidators {
if err := fn(claims); err != nil {
return fmt.Errorf("validation function %d failed with error: %w", i, err)
}
}
return nil
}

func looksLikeJWT(token string) bool {
return strings.Count(token, ".") == 2
}

func isMalformedJWTError(err error) bool {
if err == nil {
return false
}
msg := err.Error()
return strings.Contains(msg, "malformed jwt") || strings.Contains(msg, "compact JWS format must have three parts")
}

// validateAudience validates the audience claim matches the expected value for OIDC tokens
func (v *OIDCValidator) validateAudience(claims jwt.MapClaims) error {
// Extract audience claim (can be string or []string)
Expand Down
Loading