Skip to content
Open
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 16 additions & 0 deletions api_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,18 @@ func TestWithOAuth(t *testing.T) {
t.Fatal("MCP server creation failed")
}

req := httptest.NewRequest("GET", "/callback?code=test-code&state=test-state", nil)
w := httptest.NewRecorder()
mux.ServeHTTP(w, req)

if w.Code != http.StatusFound {
t.Fatalf("Expected /callback compatibility redirect to return 302, got %d", w.Code)
}

if location := w.Header().Get("Location"); location != "/oauth/callback?code=test-code&state=test-state" {
t.Fatalf("Expected redirect to /oauth/callback preserving query params, got %q", location)
}

t.Logf("✅ WithOAuth() works in proxy mode")
})

Expand Down Expand Up @@ -286,6 +298,10 @@ func TestServerWrapHandler(t *testing.T) {
t.Errorf("Expected WWW-Authenticate header with error, got: %s", authHeader)
}

if !strings.Contains(authHeader, "resource_metadata=") {
t.Errorf("Expected WWW-Authenticate header to include resource_metadata, got: %s", authHeader)
}

if !strings.Contains(w.Body.String(), "invalid_token") {
t.Errorf("Expected JSON error response, got: %s", w.Body.String())
}
Expand Down
40 changes: 40 additions & 0 deletions fixed_redirect_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,46 @@ import (
"testing"
)

func TestConfiguredFixedRedirectURICompatibility(t *testing.T) {
tests := []struct {
name string
config *OAuth2Config
expected string
}{
{
name: "explicit fixed redirect URI wins",
config: &OAuth2Config{
FixedRedirectURI: "https://server.example.com/oauth/callback",
RedirectURIs: "https://old.example.com/oauth/callback",
},
expected: "https://server.example.com/oauth/callback",
},
{
name: "single legacy redirect URI acts as fixed redirect",
config: &OAuth2Config{
RedirectURIs: "https://server.example.com/oauth/callback",
},
expected: "https://server.example.com/oauth/callback",
},
{
name: "multiple redirect URIs do not imply fixed redirect",
config: &OAuth2Config{
RedirectURIs: "https://a.example.com/callback,https://b.example.com/callback",
},
expected: "",
},
}

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
handler := &OAuth2Handler{config: tt.config}
if got := handler.configuredFixedRedirectURI(); got != tt.expected {
t.Fatalf("configuredFixedRedirectURI() = %q, want %q", got, tt.expected)
}
})
}
}

func TestFixedRedirectModeLocalhostOnly(t *testing.T) {
key := make([]byte, 32)
_, _ = rand.Read(key)
Expand Down
76 changes: 57 additions & 19 deletions handlers.go
Original file line number Diff line number Diff line change
Expand Up @@ -319,9 +319,11 @@ func (h *OAuth2Handler) HandleAuthorize(w http.ResponseWriter, r *http.Request)
}
}

if redirectURI == "" && h.config.FixedRedirectURI != "" {
fixedRedirectURI := h.configuredFixedRedirectURI()

if redirectURI == "" && fixedRedirectURI != "" {
// Fixed redirect mode: Use server's redirect URI to OAuth provider, proxy back to client
h.logger.Info("OAuth2: Fixed redirect mode - using server URI: %s (will proxy to client: %s)", h.config.FixedRedirectURI, clientRedirectURI)
h.logger.Info("OAuth2: Fixed redirect mode - using server URI: %s (will proxy to client: %s)", fixedRedirectURI, clientRedirectURI)

// Validate client redirect URI format and security
if clientRedirectURI == "" {
Expand Down Expand Up @@ -365,7 +367,7 @@ func (h *OAuth2Handler) HandleAuthorize(w http.ResponseWriter, r *http.Request)
http.Error(w, fmt.Sprintf("Invalid redirect_uri for fixed redirect mode: %s", clientRedirectURI), http.StatusBadRequest)
return
}
redirectURI = strings.TrimSpace(h.config.FixedRedirectURI)
redirectURI = fixedRedirectURI
h.logger.Info("OAuth2: Validated client redirect URI for proxy: %s", clientRedirectURI)
// For fixed redirect mode, create signed state with client redirect URI
// Create state data with redirect URI
Expand Down Expand Up @@ -456,7 +458,7 @@ func (h *OAuth2Handler) HandleCallback(w http.ResponseWriter, r *http.Request) {
}

// If using fixed redirect URI, handle proxy callback
if h.config.FixedRedirectURI != "" {
if h.configuredFixedRedirectURI() != "" {
// Verify and decode signed state parameter
stateData, err := h.verifyState(state)
if err != nil {
Expand Down Expand Up @@ -553,8 +555,8 @@ func (h *OAuth2Handler) HandleToken(w http.ResponseWriter, r *http.Request) {

// Set redirect URI for token exchange
redirectURI := clientRedirectURI
if h.config.FixedRedirectURI != "" && !h.isValidRedirectURI(clientRedirectURI) {
redirectURI = strings.TrimSpace(h.config.FixedRedirectURI)
if fixedRedirectURI := h.configuredFixedRedirectURI(); fixedRedirectURI != "" && !h.isValidRedirectURI(clientRedirectURI) {
redirectURI = fixedRedirectURI
h.logger.Info("OAuth2: Token exchange using fixed redirect URI: %s", redirectURI)
}

Expand Down Expand Up @@ -610,37 +612,57 @@ func (h *OAuth2Handler) HandleToken(w http.ResponseWriter, r *http.Request) {

h.logger.Info("OAuth2: Token exchange successful")

// Build response
response := h.buildTokenResponse(token)

// Send response
w.Header().Set("Content-Type", "application/json")
w.Header().Set("Cache-Control", "no-store")
w.Header().Set("Pragma", "no-cache")
w.WriteHeader(http.StatusOK)

if err := json.NewEncoder(w).Encode(response); err != nil {
h.logger.Error("OAuth2: Failed to encode token response: %v", err)
}
}

// buildTokenResponse builds an RFC 6749-style token response with provider-specific behavior.
func (h *OAuth2Handler) buildTokenResponse(token *oauth2.Token) map[string]interface{} {
accessToken := token.AccessToken

if h.config != nil && h.config.Provider == "google" {
if idToken, ok := token.Extra("id_token").(string); ok && idToken != "" {
if !looksLikeJWT(token.AccessToken) {
if h.logger != nil {
h.logger.Info("OAuth2: Google provider detected opaque access token, using id_token as access_token for downstream compatibility")
}
accessToken = idToken
}
}
}

response := map[string]interface{}{
"access_token": token.AccessToken,
"access_token": accessToken,
"token_type": token.TokenType,
"expires_in": int(time.Until(token.Expiry).Seconds()),
}

// Add optional fields
if token.RefreshToken != "" {
response["refresh_token"] = token.RefreshToken
}

// Add ID token if present
if idToken, ok := token.Extra("id_token").(string); ok {
response["id_token"] = idToken
}

// Add scope if present
if scope, ok := token.Extra("scope").(string); ok {
response["scope"] = scope
}

// Send response
w.Header().Set("Content-Type", "application/json")
w.Header().Set("Cache-Control", "no-store")
w.Header().Set("Pragma", "no-cache")
w.WriteHeader(http.StatusOK)
return response
}

if err := json.NewEncoder(w).Encode(response); err != nil {
h.logger.Error("OAuth2: Failed to encode token response: %v", err)
}
func looksLikeJWT(token string) bool {
return strings.Count(token, ".") == 2
}

// showSuccessPage displays a success page after OAuth completion
Expand Down Expand Up @@ -676,6 +698,22 @@ func truncateString(s string, maxLen int) string {
return s[:maxLen] + "..."
}

// configuredFixedRedirectURI returns the effective fixed redirect URI.
// Prefer the explicit FixedRedirectURI, but preserve the historical behavior
// where a single RedirectURIs value also acted as fixed redirect mode.
func (h *OAuth2Handler) configuredFixedRedirectURI() string {
if h == nil || h.config == nil {
return ""
}
if trimmed := strings.TrimSpace(h.config.FixedRedirectURI); trimmed != "" {
return trimmed
}
if redirects := strings.TrimSpace(h.config.RedirectURIs); redirects != "" && !strings.Contains(redirects, ",") {
return redirects
}
return ""
}

// pkceTransport adds PKCE code_verifier to token exchange requests
type pkceTransport struct {
base http.RoundTripper
Expand Down
97 changes: 97 additions & 0 deletions handlers_token_response_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,97 @@
package oauth

import (
"crypto/sha256"
"fmt"
"testing"
"time"

"golang.org/x/oauth2"
)

func tokenHash(v interface{}) string {
s, _ := v.(string)
if s == "" {
return "<empty>"
}
return fmt.Sprintf("%x", sha256.Sum256([]byte(s)))[:16]
}

func TestBuildTokenResponse(t *testing.T) {
tests := []struct {
name string
provider string
accessToken string
idToken string
refreshToken string
scope string
expectedAccessToken string
expectedIDToken string
expectedRefreshToken string
}{
{
name: "google opaque access token uses id token",
provider: "google",
accessToken: "ya29.a0ARW5m7Opaque",
idToken: "header.payload.signature",
refreshToken: "refresh-token",
scope: "openid profile email",
expectedAccessToken: "header.payload.signature",
expectedIDToken: "header.payload.signature",
expectedRefreshToken: "refresh-token",
},
{
name: "google jwt access token preserved",
provider: "google",
accessToken: "jwt.access.token",
idToken: "id.token.value",
expectedAccessToken: "jwt.access.token",
expectedIDToken: "id.token.value",
},
{
name: "non google access token preserved",
provider: "okta",
accessToken: "opaque-token-value",
idToken: "id.token.value",
expectedAccessToken: "opaque-token-value",
expectedIDToken: "id.token.value",
},
}

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
handler := &OAuth2Handler{
config: &OAuth2Config{Provider: tt.provider},
logger: &defaultLogger{},
}

extra := map[string]interface{}{
"id_token": tt.idToken,
}
if tt.scope != "" {
extra["scope"] = tt.scope
}

token := (&oauth2.Token{
AccessToken: tt.accessToken,
TokenType: "Bearer",
RefreshToken: tt.refreshToken,
Expiry: time.Now().Add(time.Hour),
}).WithExtra(extra)

response := handler.buildTokenResponse(token)

if got := response["access_token"]; got != tt.expectedAccessToken {
t.Fatalf("access_token hash = %s, want hash %s", tokenHash(got), tokenHash(tt.expectedAccessToken))
}

if got := response["id_token"]; got != tt.expectedIDToken {
t.Fatalf("id_token hash = %s, want hash %s", tokenHash(got), tokenHash(tt.expectedIDToken))
}

if got := response["refresh_token"]; got != tt.expectedRefreshToken && !(got == nil && tt.expectedRefreshToken == "") {
t.Fatalf("refresh_token hash = %s, want hash %s", tokenHash(got), tokenHash(tt.expectedRefreshToken))
}
Comment thread
coderabbitai[bot] marked this conversation as resolved.
})
}
}
17 changes: 12 additions & 5 deletions oauth.go
Original file line number Diff line number Diff line change
Expand Up @@ -94,6 +94,7 @@ func (s *Server) RegisterHandlers(mux *http.ServeMux) {
mux.HandleFunc("/.well-known/oauth-protected-resource", s.handler.HandleProtectedResourceMetadata)
mux.HandleFunc("/.well-known/jwks.json", s.handler.HandleJWKS)
mux.HandleFunc("/oauth/authorize", s.handler.HandleAuthorize)
mux.HandleFunc("/callback", s.handler.HandleCallbackRedirect)
mux.HandleFunc("/oauth/callback", s.handler.HandleCallback)
mux.HandleFunc("/oauth/token", s.handler.HandleToken)
mux.HandleFunc("/oauth/register", s.handler.HandleRegister)
Expand Down Expand Up @@ -263,11 +264,15 @@ func (s *Server) WrapHandler(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
authHeader := r.Header.Get("Authorization")
if authHeader == "" || len(authHeader) < 7 || authHeader[:7] != "Bearer " {
s.logger.Info("OAuth: No bearer token provided, returning 401 with discovery info")
s.logger.Debug("OAuth: No bearer token provided, returning 401 with discovery info")

metadataURL := s.GetProtectedResourceMetadataURL()
w.Header().Add("WWW-Authenticate", `Bearer realm="OAuth", error="invalid_token", error_description="Missing or invalid access token"`)
w.Header().Add("WWW-Authenticate", fmt.Sprintf(`resource_metadata="%s"`, metadataURL))
w.Header().Set("WWW-Authenticate", fmt.Sprintf(
// Preserve the historical Bearer challenge here for clients that
// key off invalid_token when bootstrapping OAuth discovery.
`Bearer realm="OAuth", error="invalid_token", error_description="Missing or invalid access token", resource_metadata="%s"`,
metadataURL,
))
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(http.StatusUnauthorized)

Expand All @@ -287,8 +292,10 @@ func (s *Server) WrapHandler(next http.Handler) http.Handler {
s.logger.Info("OAuth: Token validation failed: %v", err)

metadataURL := s.GetProtectedResourceMetadataURL()
w.Header().Add("WWW-Authenticate", `Bearer realm="OAuth", error="invalid_token", error_description="Authentication failed"`)
w.Header().Add("WWW-Authenticate", fmt.Sprintf(`resource_metadata="%s"`, metadataURL))
w.Header().Set("WWW-Authenticate", fmt.Sprintf(
`Bearer realm="OAuth", error="invalid_token", error_description="Authentication failed", resource_metadata="%s"`,
metadataURL,
))
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(http.StatusUnauthorized)

Expand Down
Loading
Loading