From 8da9ba9dbfc7c2cf1b8a962ce27180dea4261b76 Mon Sep 17 00:00:00 2001 From: Daishan Peng Date: Tue, 26 Aug 2025 17:26:04 -0700 Subject: [PATCH 1/2] Enhance: Support cookie flow for mcp-ui Signed-off-by: Daishan Peng wip Signed-off-by: Daishan Peng --- cmd/root.go | 190 ++++++++++++++++++ go.mod | 7 +- go.sum | 12 ++ main.go | 28 +-- main_test.go | 27 ++- pkg/db/db.go | 8 - pkg/db/db_test.go | 3 +- pkg/encryption/encryption.go | 66 +++++++ pkg/mcpui/cookies.go | 117 +++++++++++ pkg/mcpui/jwt.go | 166 ++++++++++++++++ pkg/mcpui/jwt_test.go | 85 ++++++++ pkg/mcpui/manager.go | 127 ++++++++++++ pkg/oauth/authorize/authorize.go | 5 + pkg/oauth/callback/callback.go | 82 +++++++- pkg/oauth/success/success.go | 146 ++++++++++++++ pkg/oauth/token/token.go | 36 ++-- pkg/oauth/validate/validatetoken.go | 291 +++++++++++++++++++++++++++- pkg/proxy/proxy.go | 73 +++---- pkg/proxy/proxy_test.go | 161 +++++++++------ pkg/tokens/manager.go | 10 +- pkg/tokens/manager_test.go | 6 +- 21 files changed, 1460 insertions(+), 186 deletions(-) create mode 100644 cmd/root.go create mode 100644 pkg/mcpui/cookies.go create mode 100644 pkg/mcpui/jwt.go create mode 100644 pkg/mcpui/jwt_test.go create mode 100644 pkg/mcpui/manager.go create mode 100644 pkg/oauth/success/success.go diff --git a/cmd/root.go b/cmd/root.go new file mode 100644 index 0000000..911938d --- /dev/null +++ b/cmd/root.go @@ -0,0 +1,190 @@ +package cmd + +import ( + "fmt" + "log" + "net/http" + "net/url" + + "github.com/gptscript-ai/cmd" + "github.com/obot-platform/mcp-oauth-proxy/pkg/proxy" + "github.com/obot-platform/mcp-oauth-proxy/pkg/types" + "github.com/spf13/cobra" +) + +var ( + version = "dev" + buildTime = "unknown" +) + +// RootCmd represents the base command when called without any subcommands +type RootCmd struct { + // Database configuration + DatabaseDSN string `name:"database-dsn" env:"DATABASE_DSN" usage:"Database connection string (PostgreSQL or SQLite file path). If empty, uses SQLite at data/oauth_proxy.db"` + + // OAuth Provider configuration + OAuthClientID string `name:"oauth-client-id" env:"OAUTH_CLIENT_ID" usage:"OAuth client ID from your OAuth provider" required:"true"` + OAuthClientSecret string `name:"oauth-client-secret" env:"OAUTH_CLIENT_SECRET" usage:"OAuth client secret from your OAuth provider" required:"true"` + OAuthAuthorizeURL string `name:"oauth-authorize-url" env:"OAUTH_AUTHORIZE_URL" usage:"Authorization endpoint URL from your OAuth provider (e.g., https://accounts.google.com)" required:"true"` + + // Scopes and MCP configuration + ScopesSupported string `name:"scopes-supported" env:"SCOPES_SUPPORTED" usage:"Comma-separated list of supported OAuth scopes (e.g., 'openid,profile,email')" required:"true"` + MCPServerURL string `name:"mcp-server-url" env:"MCP_SERVER_URL" usage:"URL of the MCP server to proxy requests to" required:"true"` + + // Security configuration + EncryptionKey string `name:"encryption-key" env:"ENCRYPTION_KEY" usage:"Base64-encoded 32-byte AES-256 key for encrypting sensitive data (optional)"` + + // Server configuration + Port string `name:"port" env:"PORT" usage:"Port to run the server on" default:"8080"` + Host string `name:"host" env:"HOST" usage:"Host to bind the server to" default:"localhost"` + + // Logging + Verbose bool `name:"verbose,v" usage:"Enable verbose logging"` + Version bool `name:"version" usage:"Show version information"` + + Mode string `name:"mode" env:"MODE" usage:"Mode to run the server in" default:"proxy"` +} + +func (c *RootCmd) Run(cobraCmd *cobra.Command, args []string) error { + if c.Version { + fmt.Printf("MCP OAuth Proxy\n") + fmt.Printf("Version: %s\n", version) + fmt.Printf("Built: %s\n", buildTime) + return nil + } + + // Configure logging + if c.Verbose { + log.SetFlags(log.LstdFlags | log.Lshortfile) + log.Println("Verbose logging enabled") + } + + // Convert CLI config to internal config format + config := &types.Config{ + DatabaseDSN: c.DatabaseDSN, + OAuthClientID: c.OAuthClientID, + OAuthClientSecret: c.OAuthClientSecret, + OAuthAuthorizeURL: c.OAuthAuthorizeURL, + ScopesSupported: c.ScopesSupported, + MCPServerURL: c.MCPServerURL, + EncryptionKey: c.EncryptionKey, + Mode: c.Mode, + } + + // Validate configuration + if err := c.validateConfig(); err != nil { + return fmt.Errorf("configuration validation failed: %w", err) + } + + // Create OAuth proxy + oauthProxy, err := proxy.NewOAuthProxy(config) + if err != nil { + return fmt.Errorf("failed to create OAuth proxy: %w", err) + } + defer func() { + if err := oauthProxy.Close(); err != nil { + log.Printf("Error closing database: %v", err) + } + }() + + // Get HTTP handler + handler := oauthProxy.GetHandler() + + // Start server + address := fmt.Sprintf("%s:%s", c.Host, c.Port) + log.Printf("Starting OAuth proxy server on %s", address) + log.Printf("OAuth Provider: %s", c.OAuthAuthorizeURL) + log.Printf("MCP Server: %s", c.MCPServerURL) + log.Printf("Database: %s", c.getDatabaseType()) + + return http.ListenAndServe(address, handler) +} + +func (c *RootCmd) validateConfig() error { + if c.OAuthClientID == "" { + return fmt.Errorf("oauth-client-id is required") + } + if c.OAuthClientSecret == "" { + return fmt.Errorf("oauth-client-secret is required") + } + if c.OAuthAuthorizeURL == "" { + return fmt.Errorf("oauth-authorize-url is required") + } + if c.ScopesSupported == "" { + return fmt.Errorf("scopes-supported is required") + } + if c.MCPServerURL == "" { + return fmt.Errorf("mcp-server-url is required") + } + if c.Mode == proxy.ModeProxy { + if u, err := url.Parse(c.MCPServerURL); err != nil || u.Scheme != "http" && u.Scheme != "https" { + return fmt.Errorf("invalid MCP server URL: %w", err) + } else if u.Path != "" && u.Path != "/" || u.RawQuery != "" || u.Fragment != "" { + return fmt.Errorf("MCP server URL must not contain a path, query, or fragment") + } + } + return nil +} + +func (c *RootCmd) getDatabaseType() string { + if c.DatabaseDSN == "" { + return "SQLite (data/oauth_proxy.db)" + } + if len(c.DatabaseDSN) > 10 && (c.DatabaseDSN[:11] == "postgres://" || c.DatabaseDSN[:14] == "postgresql://") { + return "PostgreSQL" + } + return fmt.Sprintf("SQLite (%s)", c.DatabaseDSN) +} + +// Customizer interface implementation for additional command customization +func (c *RootCmd) Customize(cobraCmd *cobra.Command) { + cobraCmd.Use = "mcp-oauth-proxy" + cobraCmd.Short = "OAuth 2.1 proxy server for MCP (Model Context Protocol)" + cobraCmd.Long = `MCP OAuth Proxy is a comprehensive OAuth 2.1 proxy server that provides +OAuth authorization server functionality with PostgreSQL/SQLite storage. + +This proxy supports multiple OAuth providers (Google, Microsoft, GitHub) and +proxies requests to MCP servers with user context headers. + +Examples: + # Start with environment variables + export OAUTH_CLIENT_ID="your-google-client-id" + export OAUTH_CLIENT_SECRET="your-secret" + export OAUTH_AUTHORIZE_URL="https://accounts.google.com" + export SCOPES_SUPPORTED="openid,profile,email" + export MCP_SERVER_URL="http://localhost:3000" + mcp-oauth-proxy + + # Start with CLI flags + mcp-oauth-proxy \ + --oauth-client-id="your-google-client-id" \ + --oauth-client-secret="your-secret" \ + --oauth-authorize-url="https://accounts.google.com" \ + --scopes-supported="openid,profile,email" \ + --mcp-server-url="http://localhost:3000" + + # Use PostgreSQL database + mcp-oauth-proxy \ + --database-dsn="postgres://user:pass@localhost:5432/oauth_db?sslmode=disable" \ + --oauth-client-id="your-client-id" \ + # ... other required flags + +Configuration: + Configuration values are loaded in this order (later values override earlier ones): + 1. Default values + 2. Environment variables + 3. Command line flags + +Database Support: + - PostgreSQL: Full ACID compliance, recommended for production + - SQLite: Zero configuration, perfect for development and small deployments` + + cobraCmd.Version = version +} + +// Execute is the main entry point for the CLI +func Execute() error { + rootCmd := &RootCmd{} + cobraCmd := cmd.Command(rootCmd) + return cobraCmd.Execute() +} diff --git a/go.mod b/go.mod index 3a9644f..f7c580f 100644 --- a/go.mod +++ b/go.mod @@ -3,8 +3,12 @@ module github.com/obot-platform/mcp-oauth-proxy go 1.25.0 require ( + github.com/golang-jwt/jwt/v5 v5.3.0 github.com/gorilla/handlers v1.5.2 + github.com/gptscript-ai/cmd v0.0.0-20250530150401-bc71fddf8070 + github.com/spf13/cobra v1.7.0 github.com/stretchr/testify v1.10.0 + golang.org/x/oauth2 v0.30.0 gorm.io/driver/postgres v1.6.0 gorm.io/driver/sqlite v1.6.0 gorm.io/gorm v1.30.1 @@ -13,6 +17,7 @@ require ( require ( github.com/davecgh/go-spew v1.1.1 // indirect github.com/felixge/httpsnoop v1.0.3 // indirect + github.com/inconshreveable/mousetrap v1.1.0 // indirect github.com/jackc/pgpassfile v1.0.0 // indirect github.com/jackc/pgservicefile v0.0.0-20240606120523-5a60cdf6a761 // indirect github.com/jackc/pgx/v5 v5.7.5 // indirect @@ -23,8 +28,8 @@ require ( github.com/mattn/go-sqlite3 v1.14.32 // indirect github.com/pmezard/go-difflib v1.0.0 // indirect github.com/rogpeppe/go-internal v1.8.0 // indirect + github.com/spf13/pflag v1.0.5 // indirect golang.org/x/crypto v0.41.0 // indirect - golang.org/x/oauth2 v0.30.0 // indirect golang.org/x/sync v0.16.0 // indirect golang.org/x/text v0.28.0 // indirect gopkg.in/yaml.v3 v3.0.1 // indirect diff --git a/go.sum b/go.sum index 76082aa..034505d 100644 --- a/go.sum +++ b/go.sum @@ -1,10 +1,17 @@ +github.com/cpuguy83/go-md2man/v2 v2.0.2/go.mod h1:tgQtvFlXSQOSOSIRvRPT7W67SCa46tRHOmNcaadrF8o= github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/felixge/httpsnoop v1.0.3 h1:s/nj+GCswXYzN5v2DpNMuMQYe+0DDwt5WVCU6CWBdXk= github.com/felixge/httpsnoop v1.0.3/go.mod h1:m8KPJKqk1gH5J9DgRY2ASl2lWCfGKXixSwevea8zH2U= +github.com/golang-jwt/jwt/v5 v5.3.0 h1:pv4AsKCKKZuqlgs5sUmn4x8UlGa0kEVt/puTpKx9vvo= +github.com/golang-jwt/jwt/v5 v5.3.0/go.mod h1:fxCRLWMO43lRc8nhHWY6LGqRcf+1gQWArsqaEUEa5bE= github.com/gorilla/handlers v1.5.2 h1:cLTUSsNkgcwhgRqvCNmdbRWG0A3N4F+M2nWKdScwyEE= github.com/gorilla/handlers v1.5.2/go.mod h1:dX+xVpaxdSw+q0Qek8SSsl3dfMk3jNddUkMzo0GtH0w= +github.com/gptscript-ai/cmd v0.0.0-20250530150401-bc71fddf8070 h1:xm5ZZFraWFwxyE7TBEncCXArubCDZTwG6s5bpMzqhSY= +github.com/gptscript-ai/cmd v0.0.0-20250530150401-bc71fddf8070/go.mod h1:DJAo1xTht1LDkNYFNydVjTHd576TC7MlpsVRl3oloVw= +github.com/inconshreveable/mousetrap v1.1.0 h1:wN+x4NVGpMsO7ErUn/mUI3vEoE6Jt13X2s0bqwp9tc8= +github.com/inconshreveable/mousetrap v1.1.0/go.mod h1:vpF70FUmC8bwa3OWnCshd2FqLfsEA9PFc4w1p2J65bw= github.com/jackc/pgpassfile v1.0.0 h1:/6Hmqy13Ss2zCq62VdNG8tM1wchn8zjSGOBJ6icpsIM= github.com/jackc/pgpassfile v1.0.0/go.mod h1:CEx0iS5ambNFdcRtxPj5JhEz+xB6uRky5eyVu/W2HEg= github.com/jackc/pgservicefile v0.0.0-20240606120523-5a60cdf6a761 h1:iCEnooe7UlwOQYpKFhBabPMi4aNAfoODPEFNiAnClxo= @@ -30,6 +37,11 @@ github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZb github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= github.com/rogpeppe/go-internal v1.8.0 h1:FCbCCtXNOY3UtUuHUYaghJg4y7Fd14rXifAYUAtL9R8= github.com/rogpeppe/go-internal v1.8.0/go.mod h1:WmiCO8CzOY8rg0OYDC4/i/2WRWAB6poM+XZ2dLUbcbE= +github.com/russross/blackfriday/v2 v2.1.0/go.mod h1:+Rmxgy9KzJVeS9/2gXHxylqXiyQDYRxCVz55jmeOWTM= +github.com/spf13/cobra v1.7.0 h1:hyqWnYt1ZQShIddO5kBpj3vu05/++x6tJ6dg8EC572I= +github.com/spf13/cobra v1.7.0/go.mod h1:uLxZILRyS/50WlhOIKD7W6V5bgeIt+4sICxh6uRMrb0= +github.com/spf13/pflag v1.0.5 h1:iy+VFUOCP1a+8yFto/drg2CJ5u0yRoB7fZw3DKv/JXA= +github.com/spf13/pflag v1.0.5/go.mod h1:McXfInJRrz4CZXVZOBLb0bTZqETkiAhM9Iw0y3An2Bg= github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI= github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= diff --git a/main.go b/main.go index 24c58a7..2bfccf4 100644 --- a/main.go +++ b/main.go @@ -1,33 +1,13 @@ package main import ( - "log" - "net/http" + "os" - "github.com/obot-platform/mcp-oauth-proxy/pkg/proxy" + "github.com/obot-platform/mcp-oauth-proxy/cmd" ) func main() { - // Load configuration from environment variables - config, err := proxy.LoadConfigFromEnv() - if err != nil { - log.Fatalf("Failed to load configuration: %v", err) + if err := cmd.Execute(); err != nil { + os.Exit(1) } - - proxy, err := proxy.NewOAuthProxy(config) - if err != nil { - log.Fatalf("Failed to create OAuth proxy: %v", err) - } - defer func() { - if err := proxy.Close(); err != nil { - log.Printf("Error closing database: %v", err) - } - }() - - // Get HTTP handler - handler := proxy.GetHandler() - - // Start server - log.Print("Starting OAuth proxy server on localhost:" + config.Port) - log.Fatal(http.ListenAndServe(":"+config.Port, handler)) } diff --git a/main_test.go b/main_test.go index 5407672..0a56422 100644 --- a/main_test.go +++ b/main_test.go @@ -10,6 +10,7 @@ import ( "time" "github.com/obot-platform/mcp-oauth-proxy/pkg/proxy" + "github.com/obot-platform/mcp-oauth-proxy/pkg/types" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) @@ -62,7 +63,10 @@ func TestIntegrationFlow(t *testing.T) { }() // Create OAuth proxy - config, err := proxy.LoadConfigFromEnv() + config := &types.Config{ + Mode: proxy.ModeProxy, + } + _, err := proxy.NewOAuthProxy(config) if err != nil { log.Fatalf("Failed to load configuration: %v", err) } @@ -180,7 +184,10 @@ func TestOAuthProxyCreation(t *testing.T) { }() // Create OAuth proxy - config, err := proxy.LoadConfigFromEnv() + config := &types.Config{ + Mode: proxy.ModeProxy, + } + _, err := proxy.NewOAuthProxy(config) if err != nil { log.Fatalf("Failed to load configuration: %v", err) } @@ -233,7 +240,10 @@ func TestOAuthProxyStart(t *testing.T) { }() // Create OAuth proxy - config, err := proxy.LoadConfigFromEnv() + config := &types.Config{ + Mode: proxy.ModeProxy, + } + _, err := proxy.NewOAuthProxy(config) if err != nil { log.Fatalf("Failed to load configuration: %v", err) } @@ -296,7 +306,7 @@ func TestForwardAuthIntegrationFlow(t *testing.T) { "OAUTH_AUTHORIZE_URL": "https://accounts.google.com", "SCOPES_SUPPORTED": "openid,profile,email", "PROXY_MODE": "forward_auth", - "PORT": "8082", // Different port to avoid conflicts + "PORT": "8082", // Different port to avoid conflicts "DATABASE_DSN": os.Getenv("TEST_DATABASE_DSN"), // Use test database if available } @@ -320,7 +330,10 @@ func TestForwardAuthIntegrationFlow(t *testing.T) { }() // Create OAuth proxy in forward auth mode - config, err := proxy.LoadConfigFromEnv() + config := &types.Config{ + Mode: proxy.ModeForwardAuth, + } + _, err := proxy.NewOAuthProxy(config) if err != nil { log.Fatalf("Failed to load configuration: %v", err) } @@ -375,7 +388,7 @@ func TestForwardAuthIntegrationFlow(t *testing.T) { // Test that forward auth mode requires authorization for protected endpoints t.Run("ForwardAuthRequiresAuth", func(t *testing.T) { testPaths := []string{"/api", "/data", "/protected", "/mcp", "/test"} - + for _, path := range testPaths { t.Run("Path_"+path, func(t *testing.T) { w := httptest.NewRecorder() @@ -399,7 +412,7 @@ func TestForwardAuthIntegrationFlow(t *testing.T) { // Should get unauthorized (no proxying attempt) assert.Equal(t, http.StatusUnauthorized, w.Code) assert.Contains(t, w.Header().Get("WWW-Authenticate"), "Bearer") - + // Should not have any proxy-related error messages assert.NotContains(t, w.Body.String(), "proxy") assert.NotContains(t, w.Body.String(), "502") diff --git a/pkg/db/db.go b/pkg/db/db.go index a5fef73..612e2b6 100644 --- a/pkg/db/db.go +++ b/pkg/db/db.go @@ -295,14 +295,6 @@ func (d *Store) RevokeToken(token string) error { return result.Error } -// UpdateTokenRefreshToken updates the refresh token for an existing token -func (d *Store) UpdateTokenRefreshToken(accessToken, newRefreshToken string) error { - hashedAccessToken := hashToken(accessToken) - hashedNewRefreshToken := hashToken(newRefreshToken) - - return d.db.Model(&types.TokenData{}).Where("access_token = ?", hashedAccessToken).Update("refresh_token", hashedNewRefreshToken).Error -} - // CleanupExpiredTokens removes expired tokens and authorization codes func (d *Store) CleanupExpiredTokens() error { now := time.Now() diff --git a/pkg/db/db_test.go b/pkg/db/db_test.go index a058068..3fa777a 100644 --- a/pkg/db/db_test.go +++ b/pkg/db/db_test.go @@ -24,6 +24,7 @@ func TestDatabaseOperations(t *testing.T) { if dsn == "" { t.Skip("Skipping database tests: TEST_DATABASE_DSN is not set") } + db, err := New(dsn) if err != nil { t.Skipf("Skipping database tests: %v", err) @@ -234,8 +235,6 @@ func testTokenOperations(t *testing.T, db *Store) { // Test updating refresh token newRefreshTokenData, err := generateRandomString(16) require.NoError(t, err) - err = db.UpdateTokenRefreshToken(accessTokenData, newRefreshTokenData) - require.NoError(t, err) updatedToken, err := db.GetToken(accessTokenData) require.NoError(t, err) diff --git a/pkg/encryption/encryption.go b/pkg/encryption/encryption.go index 6f15690..e220f52 100644 --- a/pkg/encryption/encryption.go +++ b/pkg/encryption/encryption.go @@ -143,3 +143,69 @@ func DecryptPropsIfNeeded(encryptionKey []byte, props map[string]any) (map[strin return result, nil } + +// EncryptString encrypts a string using AES-256-GCM +func EncryptString(encryptionKey []byte, plaintext string) (string, error) { + // Create AES cipher + block, err := aes.NewCipher(encryptionKey) + if err != nil { + return "", fmt.Errorf("failed to create cipher: %w", err) + } + + // Create GCM mode + gcm, err := cipher.NewGCM(block) + if err != nil { + return "", fmt.Errorf("failed to create GCM: %w", err) + } + + // Generate random IV + iv := make([]byte, gcm.NonceSize()) + if _, err := rand.Read(iv); err != nil { + return "", fmt.Errorf("failed to generate IV: %w", err) + } + + // Encrypt the data + ciphertext := gcm.Seal(nil, iv, []byte(plaintext), nil) + + // Combine IV and ciphertext, then base64 encode + combined := append(iv, ciphertext...) + return base64.StdEncoding.EncodeToString(combined), nil +} + +// DecryptString decrypts a string using AES-256-GCM +func DecryptString(encryptionKey []byte, encryptedData string) (string, error) { + // Decode base64 data + combined, err := base64.StdEncoding.DecodeString(encryptedData) + if err != nil { + return "", fmt.Errorf("failed to decode encrypted data: %w", err) + } + + // Create AES cipher + block, err := aes.NewCipher(encryptionKey) + if err != nil { + return "", fmt.Errorf("failed to create cipher: %w", err) + } + + // Create GCM mode + gcm, err := cipher.NewGCM(block) + if err != nil { + return "", fmt.Errorf("failed to create GCM: %w", err) + } + + // Extract IV and ciphertext + ivSize := gcm.NonceSize() + if len(combined) < ivSize { + return "", fmt.Errorf("encrypted data too short") + } + + iv := combined[:ivSize] + ciphertext := combined[ivSize:] + + // Decrypt the data + plaintext, err := gcm.Open(nil, iv, ciphertext, nil) + if err != nil { + return "", fmt.Errorf("failed to decrypt data: %w", err) + } + + return string(plaintext), nil +} diff --git a/pkg/mcpui/cookies.go b/pkg/mcpui/cookies.go new file mode 100644 index 0000000..6c31d21 --- /dev/null +++ b/pkg/mcpui/cookies.go @@ -0,0 +1,117 @@ +package mcpui + +import ( + "fmt" + "net/http" + "strings" +) + +const ( + MCPUICookieName = "mcp-ui-code" + MCPUIRefreshCookieName = "mcp-ui-refresh-code" + DefaultCookieMaxAge = 3600 // 1 hour +) + +// CookieManager handles browser cookies for MCP UI authentication +type CookieManager struct { + httpOnly bool + sameSite http.SameSite +} + +// NewCookieManager creates a new cookie manager +func NewCookieManager() *CookieManager { + return &CookieManager{ + httpOnly: true, + sameSite: http.SameSiteStrictMode, + } +} + +// isSecureRequest determines if the request is over HTTPS +func (c *CookieManager) isSecureRequest(r *http.Request) bool { + // Check if request is HTTPS + if r.TLS != nil { + return true + } + + // Check forwarded headers from reverse proxies + if r.Header.Get("X-Forwarded-Proto") == "https" { + return true + } + + if r.Header.Get("X-Forwarded-Ssl") == "on" { + return true + } + + return false +} + +// getDomain extracts the appropriate domain for cookies +func (c *CookieManager) getDomain(r *http.Request) string { + host := r.Host + + // Remove port if present + if colonIndex := strings.Index(host, ":"); colonIndex != -1 { + host = host[:colonIndex] + } + + // For localhost, don't set domain (allows cookies to work on localhost) + if host == "localhost" || host == "127.0.0.1" { + return "" + } + + return host +} + +// SetMCPUICookie sets the MCP UI authentication cookie containing the bearer token +func (c *CookieManager) SetMCPUICookie(w http.ResponseWriter, r *http.Request, bearerToken string) { + // Encode the bearer token for cookie storage + cookie := &http.Cookie{ + Name: MCPUICookieName, + Value: bearerToken, + Path: "/", + Domain: c.getDomain(r), + MaxAge: DefaultCookieMaxAge, + Secure: c.isSecureRequest(r), + HttpOnly: c.httpOnly, + SameSite: c.sameSite, + } + + http.SetCookie(w, cookie) +} + +// SetMCPUIRefreshCookie sets the refresh token cookie +func (c *CookieManager) SetMCPUIRefreshCookie(w http.ResponseWriter, r *http.Request, refreshToken string) { + // Encode the refresh token for cookie storage + cookie := &http.Cookie{ + Name: MCPUIRefreshCookieName, + Value: refreshToken, + Path: "/", + Domain: c.getDomain(r), + MaxAge: DefaultCookieMaxAge * 24 * 30, // 30 days for refresh token + Secure: c.isSecureRequest(r), + HttpOnly: c.httpOnly, + SameSite: c.sameSite, + } + + http.SetCookie(w, cookie) +} + +// GetMCPUICookie retrieves and decodes the MCP UI authentication cookie +func (c *CookieManager) GetMCPUICookie(r *http.Request) (string, error) { + cookie, err := r.Cookie(MCPUICookieName) + if err != nil { + return "", fmt.Errorf("MCP UI cookie not found: %w", err) + } + + return cookie.Value, nil +} + +// GetMCPUIRefreshCookie retrieves and decodes the MCP UI refresh cookie +func (c *CookieManager) GetMCPUIRefreshCookie(r *http.Request) (string, error) { + cookie, err := r.Cookie(MCPUIRefreshCookieName) + if err != nil { + return "", fmt.Errorf("MCP UI refresh cookie not found: %w", err) + } + + return cookie.Value, nil +} diff --git a/pkg/mcpui/jwt.go b/pkg/mcpui/jwt.go new file mode 100644 index 0000000..fc1e3e4 --- /dev/null +++ b/pkg/mcpui/jwt.go @@ -0,0 +1,166 @@ +package mcpui + +import ( + "crypto/aes" + "crypto/cipher" + "crypto/rand" + "encoding/base64" + "fmt" + "time" + + "github.com/golang-jwt/jwt/v5" +) + +// MCPUICodeClaims represents the JWT claims for MCP UI codes +type MCPUICodeClaims struct { + BearerToken string `json:"bearer_token"` + RefreshToken string `json:"refresh_token"` + jwt.RegisteredClaims +} + +// JWTManager handles JWT operations for MCP UI codes with signing and encryption +type JWTManager struct { + signingKey []byte + encryptionKey []byte +} + +// NewJWTManager creates a new JWT manager with the given signing and encryption keys +func NewJWTManager(signingKey []byte, encryptionKey []byte) *JWTManager { + return &JWTManager{ + signingKey: signingKey, + encryptionKey: encryptionKey, + } +} + +// GenerateMCPUICode creates a signed JWT then encrypts it containing the bearer token with 1-minute expiration +func (j *JWTManager) GenerateMCPUICode(bearerToken string, refreshToken string) (string, error) { + // Create claims with 1-minute expiration + claims := MCPUICodeClaims{ + BearerToken: bearerToken, + RefreshToken: refreshToken, + RegisteredClaims: jwt.RegisteredClaims{ + ExpiresAt: jwt.NewNumericDate(time.Now().Add(1 * time.Minute)), + IssuedAt: jwt.NewNumericDate(time.Now()), + NotBefore: jwt.NewNumericDate(time.Now()), + }, + } + + // Create and sign the JWT first + token := jwt.NewWithClaims(jwt.SigningMethodHS256, claims) + signedJWT, err := token.SignedString(j.signingKey) + if err != nil { + return "", fmt.Errorf("failed to sign JWT token: %w", err) + } + + // Encrypt the signed JWT using AES-GCM + encryptedJWT, err := j.encrypt([]byte(signedJWT)) + if err != nil { + return "", fmt.Errorf("failed to encrypt signed JWT: %w", err) + } + + // Base64 encode the encrypted data for URL safety + return base64.URLEncoding.EncodeToString(encryptedJWT), nil +} + +// ValidateMCPUICode validates and extracts the bearer token from the encrypted and signed JWT +func (j *JWTManager) ValidateMCPUICode(tokenString string) (string, string, error) { + // Base64 decode the encrypted data + encryptedJWT, err := base64.URLEncoding.DecodeString(tokenString) + if err != nil { + return "", "", fmt.Errorf("failed to decode token: %w", err) + } + + // Decrypt the JWT + signedJWT, err := j.decrypt(encryptedJWT) + if err != nil { + return "", "", fmt.Errorf("failed to decrypt JWT: %w", err) + } + + // Parse and validate the signed JWT + token, err := jwt.ParseWithClaims(string(signedJWT), &MCPUICodeClaims{}, func(token *jwt.Token) (interface{}, error) { + // Validate the signing method + if _, ok := token.Method.(*jwt.SigningMethodHMAC); !ok { + return nil, fmt.Errorf("unexpected signing method: %v", token.Header["alg"]) + } + return j.signingKey, nil + }) + + if err != nil { + return "", "", fmt.Errorf("failed to parse JWT token: %w", err) + } + + // Extract claims + if claims, ok := token.Claims.(*MCPUICodeClaims); ok && token.Valid { + return claims.BearerToken, claims.RefreshToken, nil + } + + return "", "", fmt.Errorf("invalid token claims") +} + +// encrypt encrypts data using AES-GCM +func (j *JWTManager) encrypt(data []byte) ([]byte, error) { + // Create AES cipher + block, err := aes.NewCipher(j.encryptionKey) + if err != nil { + return nil, fmt.Errorf("failed to create cipher: %w", err) + } + + // Create GCM + gcm, err := cipher.NewGCM(block) + if err != nil { + return nil, fmt.Errorf("failed to create GCM: %w", err) + } + + // Generate random nonce + nonce := make([]byte, gcm.NonceSize()) + if _, err := rand.Read(nonce); err != nil { + return nil, fmt.Errorf("failed to generate nonce: %w", err) + } + + // Encrypt data + ciphertext := gcm.Seal(nonce, nonce, data, nil) + return ciphertext, nil +} + +// decrypt decrypts data using AES-GCM +func (j *JWTManager) decrypt(encryptedData []byte) ([]byte, error) { + // Create AES cipher + block, err := aes.NewCipher(j.encryptionKey) + if err != nil { + return nil, fmt.Errorf("failed to create cipher: %w", err) + } + + // Create GCM + gcm, err := cipher.NewGCM(block) + if err != nil { + return nil, fmt.Errorf("failed to create GCM: %w", err) + } + + // Check minimum length + nonceSize := gcm.NonceSize() + if len(encryptedData) < nonceSize { + return nil, fmt.Errorf("encrypted data too short") + } + + // Extract nonce and ciphertext + nonce := encryptedData[:nonceSize] + ciphertext := encryptedData[nonceSize:] + + // Decrypt data + plaintext, err := gcm.Open(nil, nonce, ciphertext, nil) + if err != nil { + return nil, fmt.Errorf("failed to decrypt: %w", err) + } + + return plaintext, nil +} + +// EncodeKey encodes a key to base64 for storage +func EncodeKey(key []byte) string { + return base64.StdEncoding.EncodeToString(key) +} + +// DecodeKey decodes a base64-encoded key +func DecodeKey(encodedKey string) ([]byte, error) { + return base64.StdEncoding.DecodeString(encodedKey) +} diff --git a/pkg/mcpui/jwt_test.go b/pkg/mcpui/jwt_test.go new file mode 100644 index 0000000..e2ab65f --- /dev/null +++ b/pkg/mcpui/jwt_test.go @@ -0,0 +1,85 @@ +package mcpui + +import ( + "crypto/rand" + "testing" +) + +func TestJWTManagerEncryption(t *testing.T) { + // Generate a test encryption key + key := make([]byte, 32) + if _, err := rand.Read(key); err != nil { + t.Fatalf("Failed to generate test key: %v", err) + } + + // Create JWT manager (using same key for signing and encryption) + jwtManager := NewJWTManager(key, key) + + // Test bearer token + testBearerToken := "user123:grant456:secret789" + testRefreshToken := "refresh123" + + // Generate encrypted JWT + encryptedJWT, err := jwtManager.GenerateMCPUICode(testBearerToken, testRefreshToken) + if err != nil { + t.Fatalf("Failed to generate MCP UI code: %v", err) + } + + // JWT should not be empty + if encryptedJWT == "" { + t.Fatal("Generated JWT is empty") + } + + // JWT should not contain the bearer token in plaintext + if string(encryptedJWT) == testBearerToken { + t.Fatal("JWT contains plaintext bearer token") + } + + // Validate and decrypt the JWT + extractedBearerToken, extractedRefreshToken, err := jwtManager.ValidateMCPUICode(encryptedJWT) + if err != nil { + t.Fatalf("Failed to validate MCP UI code: %v", err) + } + + // Extracted token should match original + if extractedBearerToken != testBearerToken { + t.Fatalf("Extracted bearer token doesn't match: expected %s, got %s", testBearerToken, extractedBearerToken) + } + + // Extracted refresh token should match original + if extractedRefreshToken != testRefreshToken { + t.Fatalf("Extracted refresh token doesn't match: expected %s, got %s", testRefreshToken, extractedRefreshToken) + } +} + +func TestJWTManagerWrongKey(t *testing.T) { + // Generate two different keys + key1 := make([]byte, 32) + key2 := make([]byte, 32) + if _, err := rand.Read(key1); err != nil { + t.Fatalf("Failed to generate test key1: %v", err) + } + if _, err := rand.Read(key2); err != nil { + t.Fatalf("Failed to generate test key2: %v", err) + } + + // Create JWT managers with different keys + jwtManager1 := NewJWTManager(key1, key1) + jwtManager2 := NewJWTManager(key2, key2) + + // Test bearer token + testBearerToken := "user123:grant456:secret789" + testRefreshToken := "refresh123" + + // Generate JWT with first manager + encryptedJWT, err := jwtManager1.GenerateMCPUICode(testBearerToken, testRefreshToken) + if err != nil { + t.Fatalf("Failed to generate MCP UI code: %v", err) + } + + // Try to validate with second manager (wrong key) - should fail + _, _, err = jwtManager2.ValidateMCPUICode(encryptedJWT) + if err == nil { + t.Fatal("Expected validation to fail with wrong key, but it succeeded") + } +} diff --git a/pkg/mcpui/manager.go b/pkg/mcpui/manager.go new file mode 100644 index 0000000..64ca9fb --- /dev/null +++ b/pkg/mcpui/manager.go @@ -0,0 +1,127 @@ +package mcpui + +import ( + "log" + "net/http" + + "github.com/obot-platform/mcp-oauth-proxy/pkg/providers" + "github.com/obot-platform/mcp-oauth-proxy/pkg/tokens" + "github.com/obot-platform/mcp-oauth-proxy/pkg/types" +) + +// Database interface for MCP UI operations +type Database interface { + GetToken(accessToken string) (*types.TokenData, error) + GetTokenByRefreshToken(refreshToken string) (*types.TokenData, error) +} + +// Manager handles MCP UI authentication flow +type Manager struct { + jwtManager *JWTManager + cookieManager *CookieManager + tokenManager *tokens.TokenManager + providers *providers.Manager + providerName string + clientID string + clientSecret string + encryptionKey []byte + db Database +} + +// NewManager creates a new MCP UI manager +func NewManager(encryptionKey []byte, tokenManager *tokens.TokenManager, providers *providers.Manager, providerName, clientID, clientSecret string, db Database) *Manager { + return &Manager{ + jwtManager: NewJWTManager(encryptionKey, encryptionKey), // Use same key for signing and encryption + cookieManager: NewCookieManager(), + tokenManager: tokenManager, + providers: providers, + providerName: providerName, + clientID: clientID, + clientSecret: clientSecret, + encryptionKey: encryptionKey, + db: db, + } +} + +// HandleMCPUIRequest processes requests with mcp-ui-code parameter +func (m *Manager) HandleMCPUIRequest(w http.ResponseWriter, r *http.Request) (string, bool) { + // Check for mcp-ui-code parameter + mcpUICode := r.URL.Query().Get(MCPUICookieName) + if mcpUICode == "" { + return "", false + } + + // Validate and extract bearer token from JWT + bearerToken, refreshToken, err := m.jwtManager.ValidateMCPUICode(mcpUICode) + if err != nil { + log.Printf("Invalid MCP UI code: %v", err) + // JWT expired or invalid, need to initiate OAuth flow + return "", false + } + + // Set cookies with the bearer token + m.cookieManager.SetMCPUICookie(w, r, bearerToken) + if refreshToken != "" { + m.cookieManager.SetMCPUIRefreshCookie(w, r, refreshToken) + } + + log.Printf("Successfully set MCP UI cookies from JWT") + return bearerToken, true +} + +// CheckCookieAuth checks if request has valid cookie authentication +func (m *Manager) CheckCookieAuth(r *http.Request) (string, bool) { + // Try to get bearer token from cookie + bearerToken, err := m.cookieManager.GetMCPUICookie(r) + if err != nil { + return "", false + } + + // Validate the bearer token + _, err = m.tokenManager.ValidateAccessToken(bearerToken) + if err != nil { + log.Printf("Bearer token from cookie is invalid: %v", err) + // Try to refresh the token + refreshedToken, refreshed := m.tryRefreshToken(r) + if refreshed { + return refreshedToken, true + } + return "", false + } + + return bearerToken, true +} + +// tryRefreshToken attempts to refresh an expired token using the refresh cookie +func (m *Manager) tryRefreshToken(r *http.Request) (string, bool) { + // Get refresh token from cookie + refreshToken, err := m.cookieManager.GetMCPUIRefreshCookie(r) + if err != nil { + log.Printf("No refresh token available: %v", err) + return "", false + } + + // Get provider for token refresh + provider, err := m.providers.GetProvider(m.providerName) + if err != nil { + log.Printf("Failed to get provider for token refresh: %v", err) + return "", false + } + + // Refresh the token + _, err = provider.RefreshToken(r.Context(), refreshToken, m.clientID, m.clientSecret) + if err != nil { + log.Printf("Failed to refresh token: %v", err) + return "", false + } + + // For now, return empty as we need database integration to update grants + // This will be implemented when integrating with the main proxy + log.Printf("Token refresh successful but grant update not implemented in standalone manager") + return "", false +} + +// GenerateMCPUICodeForDownstream creates a JWT for sending to downstream MCP server +func (m *Manager) GenerateMCPUICodeForDownstream(bearerToken, refreshToken string) (string, error) { + return m.jwtManager.GenerateMCPUICode(bearerToken, refreshToken) +} diff --git a/pkg/oauth/authorize/authorize.go b/pkg/oauth/authorize/authorize.go index 098e878..cf03b37 100644 --- a/pkg/oauth/authorize/authorize.go +++ b/pkg/oauth/authorize/authorize.go @@ -126,6 +126,11 @@ func (p *Handler) ServeHTTP(w http.ResponseWriter, r *http.Request) { "code_challenge_method": authReq.CodeChallengeMethod, } + // Add redirect parameter if present for post-auth redirect + if rd := params.Get("rd"); rd != "" { + authData["rd"] = rd + } + if err := p.db.StoreAuthRequest(stateKey, authData); err != nil { handlerutils.JSON(w, http.StatusInternalServerError, types.OAuthError{ Error: "server_error", diff --git a/pkg/oauth/callback/callback.go b/pkg/oauth/callback/callback.go index 28f318f..4806112 100644 --- a/pkg/oauth/callback/callback.go +++ b/pkg/oauth/callback/callback.go @@ -10,6 +10,7 @@ import ( "github.com/obot-platform/mcp-oauth-proxy/pkg/encryption" "github.com/obot-platform/mcp-oauth-proxy/pkg/handlerutils" + "github.com/obot-platform/mcp-oauth-proxy/pkg/mcpui" "github.com/obot-platform/mcp-oauth-proxy/pkg/providers" "github.com/obot-platform/mcp-oauth-proxy/pkg/types" ) @@ -19,6 +20,7 @@ type Store interface { StoreAuthCode(code, grantID, userID string) error GetAuthRequest(key string) (map[string]any, error) DeleteAuthRequest(key string) error + StoreToken(token *types.TokenData) error } type Handler struct { @@ -27,15 +29,22 @@ type Handler struct { encryptionKey []byte clientID string clientSecret string + mcpUIManager MCPUIManager } -func NewHandler(db Store, provider providers.Provider, encryptionKey []byte, clientID, clientSecret string) http.Handler { +// MCPUIManager interface for generating JWT tokens +type MCPUIManager interface { + GenerateMCPUICodeForDownstream(bearerToken, refreshToken string) (string, error) +} + +func NewHandler(db Store, provider providers.Provider, encryptionKey []byte, clientID, clientSecret string, mcpUIManager MCPUIManager) http.Handler { return &Handler{ db: db, provider: provider, encryptionKey: encryptionKey, clientID: clientID, clientSecret: clientSecret, + mcpUIManager: mcpUIManager, } } @@ -59,6 +68,34 @@ func getStringFromMap(data map[string]any, key string) string { return "" } +// setMCPUISessionCookies sets secure HttpOnly session cookies for MCP UI authentication +func (p *Handler) setMCPUISessionCookies(w http.ResponseWriter, r *http.Request, accessToken, refreshToken string) { + // Determine if request is secure + secure := r.TLS != nil || r.Header.Get("X-Forwarded-Proto") == "https" + + // Set access token cookie (1 hour) + http.SetCookie(w, &http.Cookie{ + Name: mcpui.MCPUICookieName, + Value: accessToken, + Path: "/", + HttpOnly: true, + Secure: secure, + SameSite: http.SameSiteStrictMode, + MaxAge: 3600, // 1 hour + }) + + // Set refresh token cookie (30 days) + http.SetCookie(w, &http.Cookie{ + Name: mcpui.MCPUIRefreshCookieName, + Value: refreshToken, + Path: "/", + HttpOnly: true, + Secure: secure, + SameSite: http.SameSiteStrictMode, + MaxAge: 30 * 24 * 3600, // 30 days + }) +} + func (p *Handler) ServeHTTP(w http.ResponseWriter, r *http.Request) { // Handle OAuth callback from external providers code := r.URL.Query().Get("code") @@ -220,6 +257,49 @@ func (p *Handler) ServeHTTP(w http.ResponseWriter, r *http.Request) { return } + if rdValue := getStringFromMap(authData, "rd"); rdValue != "" { + // This is an MCP UI flow - you should be able to issue session cookies + log.Printf("🔄 Processing MCP UI OAuth callback") + + // Generate internal application tokens (separate from OAuth provider tokens) + accessTokenSecret := encryption.GenerateRandomString(32) + accessToken := fmt.Sprintf("%s:%s:%s", userInfo.ID, grantID, accessTokenSecret) + + mcpUIRefreshTokenSecret := encryption.GenerateRandomString(32) + mcpUIRefreshToken := fmt.Sprintf("%s:%s:%s", userInfo.ID, grantID, mcpUIRefreshTokenSecret) + + // Store internal tokens in database + tokenData := &types.TokenData{ + AccessToken: accessToken, + RefreshToken: mcpUIRefreshToken, + ClientID: authReq.ClientID, + UserID: userInfo.ID, + GrantID: grantID, + Scope: authReq.Scope, + ExpiresAt: time.Now().Add(1 * time.Hour), // 1 hour for access token + CreatedAt: time.Now(), + Revoked: false, + } + + if err := p.db.StoreToken(tokenData); err != nil { + handlerutils.JSON(w, http.StatusInternalServerError, types.OAuthError{ + Error: "server_error", + ErrorDescription: "Failed to store tokens", + }) + return + } + + // Set secure HttpOnly session cookies + p.setMCPUISessionCookies(w, r, accessToken, mcpUIRefreshToken) + + // Redirect to success page with original path as parameter + baseURL := handlerutils.GetBaseURL(r) + successURL := fmt.Sprintf("%s/auth/mcp-ui/success?rd=%s", baseURL, url.QueryEscape(rdValue)) + + http.Redirect(w, r, successURL, http.StatusFound) + return + } + // Build the redirect URL back to the client redirectURL := authReq.RedirectURI diff --git a/pkg/oauth/success/success.go b/pkg/oauth/success/success.go new file mode 100644 index 0000000..a567022 --- /dev/null +++ b/pkg/oauth/success/success.go @@ -0,0 +1,146 @@ +package success + +import ( + "fmt" + "html/template" + "net/http" + "net/url" + "strings" + + "github.com/obot-platform/mcp-oauth-proxy/pkg/handlerutils" +) + +// Handler handles the MCP UI authentication success page +type Handler struct{} + +// NewHandler creates a new success page handler +func NewHandler() http.Handler { + return &Handler{} +} + +// successPageTemplate is the HTML template for the success page +const successPageTemplate = ` + + + + + Authentication Success + + + + +
+

Authentication successful.

+

Redirecting in 5 seconds...

+

Continue manually

+
+ +` + +// ServeHTTP handles the success page request +func (h *Handler) ServeHTTP(w http.ResponseWriter, r *http.Request) { + // Get the redirect path from query parameter + rdParam := r.URL.Query().Get("rd") + if rdParam == "" { + rdParam = "/" // Default to root if no redirect path provided + } + + // Parse and clean the redirect path + redirectURL, err := url.Parse(rdParam) + if err != nil || redirectURL.IsAbs() { + // If invalid or absolute URL, default to root + rdParam = "/" + } else { + // Remove mcp-ui-code parameter from query string + query := redirectURL.Query() + query.Del("mcp-ui-code") + query.Del("mcp-ui-refresh-code") // Also remove refresh code if present + + // Rebuild the path properly + if len(query) == 0 { + // If no query parameters left, just use the path + rdParam = redirectURL.Path + } else { + // If there are still query parameters, include them + redirectURL.RawQuery = query.Encode() + rdParam = redirectURL.String() + } + } + + // Ensure redirect path starts with / + if !strings.HasPrefix(rdParam, "/") { + rdParam = "/" + rdParam + } + + // Build full redirect URL + baseURL := handlerutils.GetBaseURL(r) + fullRedirectURL := baseURL + rdParam + + // Prepare template data + templateData := struct { + RedirectURL string + DisplayPath string + }{ + RedirectURL: fullRedirectURL, + DisplayPath: rdParam, // This is already cleaned above + } + + // Parse and execute template + tmpl, err := template.New("success").Parse(successPageTemplate) + if err != nil { + http.Error(w, "Failed to render page", http.StatusInternalServerError) + return + } + + // Set content type and render the page + w.Header().Set("Content-Type", "text/html; charset=utf-8") + w.WriteHeader(http.StatusOK) + + if err := tmpl.Execute(w, templateData); err != nil { + _, _ = fmt.Fprintf(w, "Error rendering page: %v", err) + } +} diff --git a/pkg/oauth/token/token.go b/pkg/oauth/token/token.go index ac34bcb..bf73109 100644 --- a/pkg/oauth/token/token.go +++ b/pkg/oauth/token/token.go @@ -216,7 +216,6 @@ func (p *Handler) handleAuthorizationCodeGrant(w http.ResponseWriter, r *http.Re // Generate refresh token in format: userId:grantId:refreshTokenSecret refreshTokenSecret := encryption.GenerateRandomString(32) - refreshToken := fmt.Sprintf("%s:%s:%s", userID, grantID, refreshTokenSecret) // Store tokens in database @@ -306,28 +305,26 @@ func (p *Handler) handleRefreshTokenGrant(w http.ResponseWriter, r *http.Request return } - // Props are stored in the grant and will be accessed when needed - // For simple string token generation, we don't need to decrypt them here - // Generate new access token in format: userId:grantId:accessTokenSecret accessTokenSecret := encryption.GenerateRandomString(32) accessToken := fmt.Sprintf("%s:%s:%s", tokenData.UserID, tokenData.GrantID, accessTokenSecret) - // Generate new refresh token in format: userId:grantId:refreshTokenSecret (OAuth 2.1 refresh token rotation) + // Generate new refresh token refreshTokenSecret := encryption.GenerateRandomString(32) - - newRefreshToken := fmt.Sprintf("%s:%s:%s", tokenData.UserID, tokenData.GrantID, refreshTokenSecret) + refreshToken = fmt.Sprintf("%s:%s:%s", tokenData.UserID, tokenData.GrantID, refreshTokenSecret) + refreshTokenExpiresAt := time.Now().Add(30 * 24 * time.Hour) // 30 days from now // Store new token in database (replaces the old one) newTokenData := &types.TokenData{ - AccessToken: accessToken, - RefreshToken: newRefreshToken, - ClientID: clientID, - UserID: tokenData.UserID, - GrantID: tokenData.GrantID, - Scope: tokenData.Scope, - ExpiresAt: time.Now().Add(time.Duration(3600) * time.Second), - CreatedAt: time.Now(), + AccessToken: accessToken, + RefreshToken: refreshToken, + ClientID: clientID, + UserID: tokenData.UserID, + GrantID: tokenData.GrantID, + Scope: tokenData.Scope, + ExpiresAt: time.Now().Add(time.Duration(3600) * time.Second), + RefreshTokenExpiresAt: refreshTokenExpiresAt, + CreatedAt: time.Now(), } if err := p.db.StoreToken(newTokenData); err != nil { @@ -339,17 +336,16 @@ func (p *Handler) handleRefreshTokenGrant(w http.ResponseWriter, r *http.Request return } - // Revoke the old token - if err := p.db.RevokeToken(tokenData.AccessToken); err != nil { - log.Printf("Failed to revoke old token: %v", err) - // Don't fail the request, but log the error + // Revoke the old refresh token + if err := p.db.RevokeToken(refreshToken); err != nil { + log.Printf("Failed to revoke old refresh token: %v", err) } response := types.TokenResponse{ AccessToken: accessToken, TokenType: "Bearer", ExpiresIn: 3600, - RefreshToken: newRefreshToken, + RefreshToken: refreshToken, Scope: tokenData.Scope, } diff --git a/pkg/oauth/validate/validatetoken.go b/pkg/oauth/validate/validatetoken.go index 26f8a3c..09969e9 100644 --- a/pkg/oauth/validate/validatetoken.go +++ b/pkg/oauth/validate/validatetoken.go @@ -2,31 +2,232 @@ package validate import ( "context" + "crypto/sha256" + "encoding/base64" "fmt" + "log" "net/http" "strings" + "time" "github.com/obot-platform/mcp-oauth-proxy/pkg/encryption" "github.com/obot-platform/mcp-oauth-proxy/pkg/handlerutils" + "github.com/obot-platform/mcp-oauth-proxy/pkg/mcpui" + "github.com/obot-platform/mcp-oauth-proxy/pkg/providers" "github.com/obot-platform/mcp-oauth-proxy/pkg/tokens" + "github.com/obot-platform/mcp-oauth-proxy/pkg/types" ) type TokenValidator struct { - tokenManager *tokens.TokenManager - encryptionKey []byte + tokenManager *tokens.TokenManager + encryptionKey []byte + mcpUIManager MCPUIManager // Optional MCP UI manager for JWT handling + db TokenStore // Database for refresh operations + provider providers.Provider // OAuth provider for generating auth URLs + clientID string // OAuth client ID + clientSecret string // OAuth client secret + scopesSupported []string // Supported OAuth scopes } -func NewTokenValidator(tokenManager *tokens.TokenManager, encryptionKey []byte) *TokenValidator { +// TokenStore interface for database operations needed by validator +type TokenStore interface { + GetToken(accessToken string) (*types.TokenData, error) + GetTokenByRefreshToken(refreshToken string) (*types.TokenData, error) + StoreToken(token *types.TokenData) error + RevokeToken(token string) error + StoreAuthRequest(key string, data map[string]any) error +} + +// MCPUIManager interface for JWT handling +type MCPUIManager interface { + HandleMCPUIRequest(w http.ResponseWriter, r *http.Request) (string, bool) +} + +func NewTokenValidator(tokenManager *tokens.TokenManager, encryptionKey []byte, db TokenStore, provider providers.Provider, clientID, clientSecret string, scopesSupported []string) *TokenValidator { return &TokenValidator{ - tokenManager: tokenManager, - encryptionKey: encryptionKey, + tokenManager: tokenManager, + encryptionKey: encryptionKey, + db: db, + provider: provider, + clientID: clientID, + clientSecret: clientSecret, + scopesSupported: scopesSupported, + } +} + +func (p *TokenValidator) SetMCPUIManager(manager MCPUIManager) { + p.mcpUIManager = manager +} + +func (p *TokenValidator) SetOAuthConfig(provider providers.Provider, clientID, clientSecret string, scopesSupported []string) { + p.provider = provider + p.clientID = clientID + p.clientSecret = clientSecret + p.scopesSupported = scopesSupported +} + +// generatePKCE generates PKCE code verifier and challenge +func generatePKCE() (codeVerifier, codeChallenge string) { + // Generate a random code verifier (43-128 characters, base64url) + codeVerifier = encryption.GenerateRandomString(32) // This generates 32 bytes -> ~43 chars in base64 + + // Generate code challenge using S256 method + hash := sha256.Sum256([]byte(codeVerifier)) + codeChallenge = base64.RawURLEncoding.EncodeToString(hash[:]) + + return codeVerifier, codeChallenge +} + +// validateCookieAuth handles cookie-based authentication with refresh capability +// Returns (tokenInfo, bearerToken, success) +func (p *TokenValidator) validateCookieAuth(w http.ResponseWriter, r *http.Request, bearerToken string) (*tokens.TokenInfo, string, bool) { + // Validate the bearer token directly + tokenInfo, err := p.tokenManager.GetTokenInfo(bearerToken) + if err == nil { + return tokenInfo, bearerToken, true + } + + // Step 2: Token is invalid, check for refresh token + refreshCookie, refreshErr := r.Cookie(mcpui.MCPUIRefreshCookieName) + if refreshErr != nil || refreshCookie.Value == "" { + return nil, "", false } + + refreshToken := refreshCookie.Value + + // Step 3: Start refresh process + return p.performMCPUIRefresh(w, r, bearerToken, refreshToken) +} + +// performMCPUIRefresh handles the MCP UI token refresh process +// Returns (tokenInfo, bearerToken, success) +func (p *TokenValidator) performMCPUIRefresh(w http.ResponseWriter, r *http.Request, accessToken, refreshToken string) (*tokens.TokenInfo, string, bool) { + // Get token data by MCP UI refresh token + tokenData, err := p.db.GetTokenByRefreshToken(refreshToken) + if err != nil { + return nil, "", false + } + + return p.rotateMCPUITokens(w, r, tokenData) +} + +// rotateMCPUITokens rotates access token and MCP UI refresh token, keeping regular refresh token unchanged +// Returns (tokenInfo, bearerToken, success) +func (p *TokenValidator) rotateMCPUITokens(w http.ResponseWriter, r *http.Request, tokenData *types.TokenData) (*tokens.TokenInfo, string, bool) { + // Generate new access token + accessTokenSecret := encryption.GenerateRandomString(32) + newAccessToken := fmt.Sprintf("%s:%s:%s", tokenData.UserID, tokenData.GrantID, accessTokenSecret) + + // Generate new MCP UI refresh token + mcpUIRefreshTokenSecret := encryption.GenerateRandomString(32) + newMCPUIRefreshToken := fmt.Sprintf("%s:%s:%s", tokenData.UserID, tokenData.GrantID, mcpUIRefreshTokenSecret) + + // Create updated token data (regular refresh token stays unchanged) + updatedTokenData := &types.TokenData{ + AccessToken: newAccessToken, + RefreshToken: newMCPUIRefreshToken, + ClientID: tokenData.ClientID, + UserID: tokenData.UserID, + GrantID: tokenData.GrantID, + Scope: tokenData.Scope, + ExpiresAt: time.Now().Add(1 * time.Hour), // 1 hour + RefreshTokenExpiresAt: tokenData.RefreshTokenExpiresAt, // Keep unchanged + CreatedAt: time.Now(), + Revoked: false, + } + + // Store updated tokens in database + if err := p.db.StoreToken(updatedTokenData); err != nil { + return nil, "", false + } + + // Revoke old access token + if err := p.db.RevokeToken(tokenData.AccessToken); err != nil { + log.Printf("❌ Failed to revoke old access token: %v", err) + } + + // Set cookies with new tokens + p.setCookiesForRefresh(w, r, newAccessToken, newMCPUIRefreshToken) + + // Validate the new access token and return token info + newTokenInfo, err := p.tokenManager.GetTokenInfo(newAccessToken) + if err != nil { + log.Printf("❌ Failed to validate newly created access token: %v", err) + return nil, "", false + } + + return newTokenInfo, newAccessToken, true +} + +// setCookiesForRefresh sets the MCP UI cookies after successful refresh +func (p *TokenValidator) setCookiesForRefresh(w http.ResponseWriter, r *http.Request, accessToken, refreshToken string) { + // Set access token cookie + http.SetCookie(w, &http.Cookie{ + Name: mcpui.MCPUICookieName, + Value: accessToken, + Path: "/", + HttpOnly: true, + SameSite: http.SameSiteStrictMode, + Secure: r.TLS != nil, + MaxAge: 3600, // 1 hour + }) + + // Set refresh token cookie + http.SetCookie(w, &http.Cookie{ + Name: mcpui.MCPUIRefreshCookieName, + Value: refreshToken, + Path: "/", + HttpOnly: true, + SameSite: http.SameSiteStrictMode, + Secure: r.TLS != nil, + MaxAge: 30 * 24 * 3600, // 30 days + }) } func (p *TokenValidator) WithTokenValidation(next http.HandlerFunc) http.HandlerFunc { return func(w http.ResponseWriter, r *http.Request) { authHeader := r.Header.Get("Authorization") if authHeader == "" { + // Try cookie-based authentication with refresh capability + var bearerTokenFromCookie string + mcpUICodeCookie, err := r.Cookie(mcpui.MCPUICookieName) + if err == nil && mcpUICodeCookie.Value != "" { + bearerTokenFromCookie = mcpUICodeCookie.Value + } + + if tokenInfo, bearerToken, success := p.validateCookieAuth(w, r, bearerTokenFromCookie); success { + // Add token info to request context + ctx := context.WithValue(r.Context(), tokenInfoKey{}, tokenInfo) + ctx = context.WithValue(ctx, bearerTokenKey{}, bearerToken) + next(w, r.WithContext(ctx)) + return + } else if bearerTokenFromCookie != "" { + p.handleOauthFlow(w, r) + return + } + + // Fall back to URL parameter (contains JWT) + mcpUICodeParam := r.URL.Query().Get("mcp-ui-code") + if mcpUICodeParam != "" { + bearerToken, handled := p.mcpUIManager.HandleMCPUIRequest(w, r) + if handled { + tokenInfo, err := p.tokenManager.GetTokenInfo(bearerToken) + if err != nil { + p.handleOauthFlow(w, r) + return + } + + // Add token info to request context + ctx := context.WithValue(r.Context(), tokenInfoKey{}, tokenInfo) + ctx = context.WithValue(ctx, bearerTokenKey{}, bearerToken) + next(w, r.WithContext(ctx)) + return + } else { + p.handleOauthFlow(w, r) + return + } + } + // Return 401 with proper WWW-Authenticate header resourceMetadataUrl := fmt.Sprintf("%s/.well-known/oauth-protected-resource", handlerutils.GetBaseURL(r)) wwwAuthValue := fmt.Sprintf(`Bearer error="invalid_token", error_description="Missing Authorization header", resource_metadata="%s"`, resourceMetadataUrl) @@ -81,12 +282,90 @@ func (p *TokenValidator) WithTokenValidation(next http.HandlerFunc) http.Handler tokenInfo.Props = decryptedProps } - next(w, r.WithContext(context.WithValue(r.Context(), tokenInfoKey{}, tokenInfo))) + // Store both tokenInfo and the original bearer token string + ctx := context.WithValue(r.Context(), tokenInfoKey{}, tokenInfo) + ctx = context.WithValue(ctx, bearerTokenKey{}, token) + next(w, r.WithContext(ctx)) } } +func (p *TokenValidator) handleOauthFlow(w http.ResponseWriter, r *http.Request) { + // Check if this is a browser request by looking at user agent + userAgent := r.Header.Get("User-Agent") + if userAgent == "" { + // Not a browser request, return 401 + resourceMetadataUrl := fmt.Sprintf("%s/.well-known/oauth-protected-resource", handlerutils.GetBaseURL(r)) + wwwAuthValue := fmt.Sprintf(`Bearer error="invalid_token", error_description="Authentication required", resource_metadata="%s"`, resourceMetadataUrl) + w.Header().Set("WWW-Authenticate", wwwAuthValue) + handlerutils.JSON(w, http.StatusUnauthorized, map[string]string{ + "error": "invalid_token", + "error_description": "Authentication required", + }) + return + } + + // Check if OAuth provider is configured + if p.provider == nil || p.clientID == "" || p.clientSecret == "" { + handlerutils.JSON(w, http.StatusInternalServerError, map[string]string{ + "error": "server_error", + "error_description": "OAuth provider not configured", + }) + return + } + + // Generate PKCE parameters + codeVerifier, codeChallenge := generatePKCE() + + // Generate a random state key for this auth request + stateKey := encryption.GenerateRandomString(32) + + // Get current path for post-auth redirect + currentPath := r.URL.Path + if r.URL.RawQuery != "" { + currentPath += "?" + r.URL.RawQuery + } + + // Store the auth request data in the database with PKCE parameters + authData := map[string]any{ + "response_type": "code", + "client_id": p.clientID, + "redirect_uri": fmt.Sprintf("%s/callback", handlerutils.GetBaseURL(r)), + "scope": strings.Join(p.scopesSupported, " "), + "state": stateKey, + "code_challenge": codeChallenge, + "code_challenge_method": "S256", + "code_verifier": codeVerifier, // Store for later use in token exchange + "rd": currentPath, // Store original path for post-auth redirect + } + + if err := p.db.StoreAuthRequest(stateKey, authData); err != nil { + handlerutils.JSON(w, http.StatusInternalServerError, map[string]string{ + "error": "server_error", + "error_description": "Failed to store authorization request", + }) + return + } + + // Build the authorization URL directly to OAuth provider (not our /authorize) + redirectURI := fmt.Sprintf("%s/callback", handlerutils.GetBaseURL(r)) + scope := strings.Join(p.scopesSupported, " ") + + // Generate authorization URL with PKCE + authURL := p.provider.GetAuthorizationURL(p.clientID, redirectURI, scope, stateKey) + + http.Redirect(w, r, authURL, http.StatusFound) +} + func GetTokenInfo(r *http.Request) *tokens.TokenInfo { return r.Context().Value(tokenInfoKey{}).(*tokens.TokenInfo) } +func GetBearerToken(r *http.Request) string { + if token := r.Context().Value(bearerTokenKey{}); token != nil { + return token.(string) + } + return "" +} + type tokenInfoKey struct{} +type bearerTokenKey struct{} diff --git a/pkg/proxy/proxy.go b/pkg/proxy/proxy.go index d5aba5e..35345e5 100644 --- a/pkg/proxy/proxy.go +++ b/pkg/proxy/proxy.go @@ -18,6 +18,7 @@ import ( "github.com/obot-platform/mcp-oauth-proxy/pkg/db" "github.com/obot-platform/mcp-oauth-proxy/pkg/encryption" "github.com/obot-platform/mcp-oauth-proxy/pkg/handlerutils" + "github.com/obot-platform/mcp-oauth-proxy/pkg/mcpui" "github.com/obot-platform/mcp-oauth-proxy/pkg/oauth/authorize" "github.com/obot-platform/mcp-oauth-proxy/pkg/oauth/callback" "github.com/obot-platform/mcp-oauth-proxy/pkg/oauth/register" @@ -31,12 +32,8 @@ import ( "golang.org/x/oauth2" ) -const ( - ModeProxy = "proxy" - ModeForwardAuth = "forward_auth" -) - type OAuthProxy struct { + mcpUIManager *mcpui.Manager metadata *types.OAuthMetadata db *db.Store rateLimiter *ratelimit.RateLimiter @@ -52,43 +49,10 @@ type OAuthProxy struct { cancel context.CancelFunc } -// LoadConfigFromEnv loads configuration from environment variables -func LoadConfigFromEnv() (*types.Config, error) { - config := &types.Config{ - DatabaseDSN: os.Getenv("DATABASE_DSN"), - OAuthClientID: os.Getenv("OAUTH_CLIENT_ID"), - OAuthClientSecret: os.Getenv("OAUTH_CLIENT_SECRET"), - OAuthAuthorizeURL: os.Getenv("OAUTH_AUTHORIZE_URL"), - ScopesSupported: os.Getenv("SCOPES_SUPPORTED"), - EncryptionKey: os.Getenv("ENCRYPTION_KEY"), - MCPServerURL: os.Getenv("MCP_SERVER_URL"), - Mode: os.Getenv("PROXY_MODE"), - Port: os.Getenv("PORT"), - } - - if config.Port == "" { - config.Port = "8080" - } - - switch config.Mode { - case "": - fmt.Println("Defaulting to proxy mode") - config.Mode = ModeProxy - case ModeProxy, ModeForwardAuth: - default: - return nil, fmt.Errorf("invalid mode: %s", config.Mode) - } - - if config.Mode == ModeProxy { - if u, err := url.Parse(config.MCPServerURL); err != nil || u.Scheme != "http" && u.Scheme != "https" { - return nil, fmt.Errorf("invalid MCP server URL: %w", err) - } else if u.Path != "" && u.Path != "/" || u.RawQuery != "" || u.Fragment != "" { - return nil, fmt.Errorf("MCP server URL must not contain a path, query, or fragment") - } - } - - return config, nil -} +const ( + ModeProxy = "proxy" + ModeForwardAuth = "forward_auth" +) func NewOAuthProxy(config *types.Config) (*OAuthProxy, error) { databaseDSN := config.DatabaseDSN @@ -131,6 +95,20 @@ func NewOAuthProxy(config *types.Config) (*OAuthProxy, error) { return nil, fmt.Errorf("failed to initialize token manager: %w", err) } + encryptionKey, err := base64.StdEncoding.DecodeString(config.EncryptionKey) + if err != nil { + return nil, fmt.Errorf("failed to decode encryption key: %w", err) + } + mcpUIManager := mcpui.NewManager( + encryptionKey, // Use encryption key for JWE encryption + tokenManager, + providerManager, + provider, + config.OAuthClientID, + config.OAuthClientSecret, + db, // Add database for refresh token operations + ) + // Split and trim scopes to handle whitespace scopesSupported := ParseScopesSupported(config.ScopesSupported) @@ -144,12 +122,8 @@ func NewOAuthProxy(config *types.Config) (*OAuthProxy, error) { RegistrationEndpointAuthMethodsSupported: []string{"client_secret_post"}, } - encryptionKey, err := base64.StdEncoding.DecodeString(config.EncryptionKey) - if err != nil { - log.Fatalf("Failed to decode encryption key: %v", err) - } - return &OAuthProxy{ + mcpUIManager: mcpUIManager, metadata: metadata, db: db, rateLimiter: rateLimiter, @@ -220,9 +194,9 @@ func (p *OAuthProxy) SetupRoutes(mux *http.ServeMux) { authorizeHandler := authorize.NewHandler(p.db, provider, p.metadata.ScopesSupported, p.GetOAuthClientID(), p.GetOAuthClientSecret()) tokenHandler := token.NewHandler(p.db) - callbackHandler := callback.NewHandler(p.db, provider, p.encryptionKey, p.GetOAuthClientID(), p.GetOAuthClientSecret()) + callbackHandler := callback.NewHandler(p.db, provider, p.encryptionKey, p.GetOAuthClientID(), p.GetOAuthClientSecret(), p.mcpUIManager) revokeHandler := revoke.NewHandler(p.db) - tokenValidator := validate.NewTokenValidator(p.tokenManager, p.encryptionKey) + tokenValidator := validate.NewTokenValidator(p.tokenManager, p.encryptionKey, p.db, provider, p.GetOAuthClientID(), p.GetOAuthClientSecret(), p.metadata.ScopesSupported) mux.HandleFunc("GET /health", p.withCORS(p.healthHandler)) @@ -236,7 +210,6 @@ func (p *OAuthProxy) SetupRoutes(mux *http.ServeMux) { // Metadata endpoints mux.HandleFunc("GET /.well-known/oauth-authorization-server", p.withCORS(p.oauthMetadataHandler)) mux.HandleFunc("GET /.well-known/oauth-protected-resource", p.withCORS(p.protectedResourceMetadataHandler)) - mux.HandleFunc("GET /.well-known/oauth-protected-resource/{path...}", p.withCORS(p.protectedResourceMetadataHandler)) // Protect everything else mux.HandleFunc("/{path...}", p.withCORS(p.withRateLimit(tokenValidator.WithTokenValidation(p.mcpProxyHandler)))) diff --git a/pkg/proxy/proxy_test.go b/pkg/proxy/proxy_test.go index c4c1e17..4215a5c 100644 --- a/pkg/proxy/proxy_test.go +++ b/pkg/proxy/proxy_test.go @@ -6,6 +6,7 @@ import ( "os" "testing" + "github.com/obot-platform/mcp-oauth-proxy/pkg/types" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) @@ -22,82 +23,106 @@ func TestLoadConfigFromEnv(t *testing.T) { defer func() { for key, value := range originalVars { if value != "" { - os.Setenv(key, value) + _ = os.Setenv(key, value) } else { - os.Unsetenv(key) + _ = os.Unsetenv(key) } } }() t.Run("DefaultMode", func(t *testing.T) { - os.Unsetenv("PROXY_MODE") - os.Setenv("MCP_SERVER_URL", "http://localhost:8081") + _ = os.Unsetenv("PROXY_MODE") + _ = os.Setenv("MCP_SERVER_URL", "http://localhost:8081") - config, err := LoadConfigFromEnv() + config := &types.Config{ + Mode: ModeProxy, + } + _, err := NewOAuthProxy(config) require.NoError(t, err) assert.Equal(t, ModeProxy, config.Mode) }) t.Run("DefaultPort", func(t *testing.T) { - os.Unsetenv("PORT") - os.Setenv("PROXY_MODE", ModeForwardAuth) + _ = os.Unsetenv("PORT") + _ = os.Setenv("PROXY_MODE", ModeForwardAuth) - config, err := LoadConfigFromEnv() + config := &types.Config{ + Mode: ModeForwardAuth, + } + _, err := NewOAuthProxy(config) require.NoError(t, err) assert.Equal(t, "8080", config.Port) }) t.Run("CustomPort", func(t *testing.T) { - os.Setenv("PORT", "9090") - os.Setenv("PROXY_MODE", ModeForwardAuth) + _ = os.Setenv("PORT", "9090") + _ = os.Setenv("PROXY_MODE", ModeForwardAuth) - config, err := LoadConfigFromEnv() + config := &types.Config{ + Mode: ModeForwardAuth, + } + _, err := NewOAuthProxy(config) require.NoError(t, err) assert.Equal(t, "9090", config.Port) }) t.Run("ValidProxyMode", func(t *testing.T) { - os.Setenv("PROXY_MODE", ModeProxy) - os.Setenv("MCP_SERVER_URL", "http://localhost:8081") + _ = os.Setenv("PROXY_MODE", ModeProxy) + _ = os.Setenv("MCP_SERVER_URL", "http://localhost:8081") - config, err := LoadConfigFromEnv() + config := &types.Config{ + Mode: ModeProxy, + } + _, err := NewOAuthProxy(config) require.NoError(t, err) assert.Equal(t, ModeProxy, config.Mode) }) t.Run("ValidForwardAuthMode", func(t *testing.T) { - os.Setenv("PROXY_MODE", ModeForwardAuth) + _ = os.Setenv("PROXY_MODE", ModeForwardAuth) os.Unsetenv("MCP_SERVER_URL") // Not required for forward auth mode - config, err := LoadConfigFromEnv() + config := &types.Config{ + Mode: ModeForwardAuth, + } + _, err := NewOAuthProxy(config) require.NoError(t, err) assert.Equal(t, ModeForwardAuth, config.Mode) }) t.Run("InvalidMode", func(t *testing.T) { - os.Setenv("PROXY_MODE", "invalid_mode") + _ = os.Setenv("PROXY_MODE", "invalid_mode") - config, err := LoadConfigFromEnv() + config := &types.Config{ + Mode: "invalid_mode", + } + _, err := NewOAuthProxy(config) assert.Error(t, err) assert.Nil(t, config) assert.Contains(t, err.Error(), "invalid mode: invalid_mode") }) t.Run("ProxyModeRequiresMCPServerURL", func(t *testing.T) { - os.Setenv("PROXY_MODE", ModeProxy) + _ = os.Setenv("PROXY_MODE", ModeProxy) os.Unsetenv("MCP_SERVER_URL") - config, err := LoadConfigFromEnv() + config := &types.Config{ + Mode: ModeProxy, + } + _, err := NewOAuthProxy(config) assert.Error(t, err) assert.Nil(t, config) assert.Contains(t, err.Error(), "invalid MCP server URL") }) t.Run("ForwardAuthModeDoesNotRequireMCPServerURL", func(t *testing.T) { - os.Setenv("PROXY_MODE", ModeForwardAuth) + _ = os.Setenv("PROXY_MODE", ModeForwardAuth) os.Unsetenv("MCP_SERVER_URL") - config, err := LoadConfigFromEnv() + config := &types.Config{ + Mode: ModeForwardAuth, + } + _, err := NewOAuthProxy(config) require.NoError(t, err) assert.Equal(t, ModeForwardAuth, config.Mode) assert.Empty(t, config.MCPServerURL) @@ -116,10 +141,13 @@ func TestLoadConfigFromEnv(t *testing.T) { for _, tc := range testCases { t.Run(tc.name, func(t *testing.T) { - os.Setenv("PROXY_MODE", ModeProxy) - os.Setenv("MCP_SERVER_URL", tc.url) + _ = os.Setenv("PROXY_MODE", ModeProxy) + _ = os.Setenv("MCP_SERVER_URL", tc.url) - config, err := LoadConfigFromEnv() + config := &types.Config{ + Mode: ModeProxy, + } + _, err := NewOAuthProxy(config) require.NoError(t, err) assert.Equal(t, tc.url, config.MCPServerURL) }) @@ -140,10 +168,13 @@ func TestLoadConfigFromEnv(t *testing.T) { for _, tc := range testCases { t.Run(tc.name, func(t *testing.T) { - os.Setenv("PROXY_MODE", ModeProxy) - os.Setenv("MCP_SERVER_URL", tc.url) + _ = os.Setenv("PROXY_MODE", ModeProxy) + _ = os.Setenv("MCP_SERVER_URL", tc.url) - config, err := LoadConfigFromEnv() + config := &types.Config{ + Mode: ModeProxy, + } + _, err := NewOAuthProxy(config) assert.Error(t, err) assert.Nil(t, config) }) @@ -258,7 +289,7 @@ func TestOAuthProxyCreationWithModes(t *testing.T) { defer func() { for key, value := range oldVars { if value != "" { - os.Setenv(key, value) + _ = os.Setenv(key, value) } else { os.Unsetenv(key) } @@ -267,14 +298,17 @@ func TestOAuthProxyCreationWithModes(t *testing.T) { // Set base environment for key, value := range baseEnvVars { - os.Setenv(key, value) + _ = os.Setenv(key, value) } t.Run("ProxyMode", func(t *testing.T) { - os.Setenv("PROXY_MODE", ModeProxy) - os.Setenv("MCP_SERVER_URL", "http://localhost:8081") + _ = os.Setenv("PROXY_MODE", ModeProxy) + _ = os.Setenv("MCP_SERVER_URL", "http://localhost:8081") - config, err := LoadConfigFromEnv() + config := &types.Config{ + Mode: ModeProxy, + } + _, err := NewOAuthProxy(config) require.NoError(t, err) proxy, err := NewOAuthProxy(config) @@ -291,10 +325,13 @@ func TestOAuthProxyCreationWithModes(t *testing.T) { }) t.Run("ForwardAuthMode", func(t *testing.T) { - os.Setenv("PROXY_MODE", ModeForwardAuth) + _ = os.Setenv("PROXY_MODE", ModeForwardAuth) os.Unsetenv("MCP_SERVER_URL") // Not required for forward auth - config, err := LoadConfigFromEnv() + config := &types.Config{ + Mode: ModeForwardAuth, + } + _, err := NewOAuthProxy(config) require.NoError(t, err) proxy, err := NewOAuthProxy(config) @@ -331,14 +368,14 @@ func TestForwardAuthModeIntegration(t *testing.T) { oldVars := make(map[string]string) for key, value := range testEnvVars { oldVars[key] = os.Getenv(key) - os.Setenv(key, value) + _ = os.Setenv(key, value) } // Restore environment after test defer func() { for key, value := range oldVars { if value != "" { - os.Setenv(key, value) + _ = os.Setenv(key, value) } else { os.Unsetenv(key) } @@ -346,7 +383,10 @@ func TestForwardAuthModeIntegration(t *testing.T) { }() // Create OAuth proxy in forward auth mode - config, err := LoadConfigFromEnv() + config := &types.Config{ + Mode: ModeForwardAuth, + } + _, err := NewOAuthProxy(config) require.NoError(t, err) require.Equal(t, ModeForwardAuth, config.Mode) @@ -409,11 +449,11 @@ func TestForwardAuthModeIntegration(t *testing.T) { // TestModeSpecificValidation tests that validation rules work correctly for different modes func TestModeSpecificValidation(t *testing.T) { testCases := []struct { - name string - mode string - mcpServerURL string - expectError bool - errorContains string + name string + mode string + mcpServerURL string + expectError bool + errorContains string }{ { name: "ProxyModeValidURL", @@ -456,10 +496,10 @@ func TestModeSpecificValidation(t *testing.T) { errorContains: "must not contain a path, query", }, { - name: "ForwardAuthNoURL", - mode: ModeForwardAuth, + name: "ForwardAuthNoURL", + mode: ModeForwardAuth, mcpServerURL: "", - expectError: false, + expectError: false, }, { name: "ForwardAuthWithURL", @@ -480,27 +520,30 @@ func TestModeSpecificValidation(t *testing.T) { originalURL := os.Getenv("MCP_SERVER_URL") defer func() { if originalMode != "" { - os.Setenv("PROXY_MODE", originalMode) + _ = os.Setenv("PROXY_MODE", originalMode) } else { - os.Unsetenv("PROXY_MODE") + _ = os.Unsetenv("PROXY_MODE") } if originalURL != "" { - os.Setenv("MCP_SERVER_URL", originalURL) + _ = os.Setenv("MCP_SERVER_URL", originalURL) } else { - os.Unsetenv("MCP_SERVER_URL") + _ = os.Unsetenv("MCP_SERVER_URL") } }() for _, tc := range testCases { t.Run(tc.name, func(t *testing.T) { - os.Setenv("PROXY_MODE", tc.mode) + _ = os.Setenv("PROXY_MODE", tc.mode) if tc.mcpServerURL != "" { - os.Setenv("MCP_SERVER_URL", tc.mcpServerURL) + _ = os.Setenv("MCP_SERVER_URL", tc.mcpServerURL) } else { - os.Unsetenv("MCP_SERVER_URL") + _ = os.Unsetenv("MCP_SERVER_URL") } - config, err := LoadConfigFromEnv() + config := &types.Config{ + Mode: tc.mode, + } + _, err := NewOAuthProxy(config) if tc.expectError { assert.Error(t, err) @@ -521,7 +564,7 @@ func TestModeSpecificValidation(t *testing.T) { // TestHeaderOverwriting tests that headers are properly overwritten func TestHeaderOverwriting(t *testing.T) { header := make(http.Header) - + // Set initial headers header.Set("X-Forwarded-User", "old-user") header.Set("X-Forwarded-Email", "old@example.com") @@ -545,9 +588,9 @@ func TestHeaderOverwriting(t *testing.T) { func TestSpecialCharactersInHeaders(t *testing.T) { header := make(http.Header) props := map[string]any{ - "user_id": "user@domain.com", - "email": "test+tag@example.com", - "name": "John O'Doe", + "user_id": "user@domain.com", + "email": "test+tag@example.com", + "name": "John O'Doe", "access_token": "token-with-special-chars_123", } @@ -577,4 +620,4 @@ func BenchmarkSetHeaders(t *testing.B) { } setHeaders(header, props) } -} \ No newline at end of file +} diff --git a/pkg/tokens/manager.go b/pkg/tokens/manager.go index 078c2d0..b2e49ec 100644 --- a/pkg/tokens/manager.go +++ b/pkg/tokens/manager.go @@ -10,7 +10,7 @@ import ( // TokenManager handles token generation and validation type TokenManager struct { - db Database + DB Database } // Database interface for token operations @@ -29,13 +29,13 @@ type TokenClaims struct { // NewTokenManager creates a new token manager func NewTokenManager(db Database) (*TokenManager, error) { return &TokenManager{ - db: db, + DB: db, }, nil } // ValidateAccessToken validates and parses a simple string access token func (tm *TokenManager) ValidateAccessToken(tokenString string) (*TokenClaims, error) { - if tm.db == nil { + if tm.DB == nil { return nil, fmt.Errorf("database not configured for token validation") } @@ -49,7 +49,7 @@ func (tm *TokenManager) ValidateAccessToken(tokenString string) (*TokenClaims, e grantID := parts[1] // Get token data from database - tokenData, err := tm.db.GetToken(tokenString) + tokenData, err := tm.DB.GetToken(tokenString) if err != nil { return nil, fmt.Errorf("token not found: %w", err) } @@ -65,7 +65,7 @@ func (tm *TokenManager) ValidateAccessToken(tokenString string) (*TokenClaims, e } // Get the grant to access props - grant, err := tm.db.GetGrant(grantID, userID) + grant, err := tm.DB.GetGrant(grantID, userID) if err != nil { return nil, fmt.Errorf("grant not found: %w", err) } diff --git a/pkg/tokens/manager_test.go b/pkg/tokens/manager_test.go index 8e7ebd8..7e687e3 100644 --- a/pkg/tokens/manager_test.go +++ b/pkg/tokens/manager_test.go @@ -52,7 +52,7 @@ func TestTokenManager(t *testing.T) { tokenManager, err := NewTokenManager(mockDB) require.NoError(t, err) - assert.Equal(t, mockDB, tokenManager.db) + assert.Equal(t, mockDB, tokenManager.DB) }) t.Run("TestValidateAccessToken", func(t *testing.T) { @@ -61,13 +61,13 @@ func TestTokenManager(t *testing.T) { require.NoError(t, err) // Test without database - tokenManager.db = nil + tokenManager.DB = nil _, err = tokenManager.ValidateAccessToken("test_token") assert.Error(t, err) assert.Contains(t, err.Error(), "database not configured") // Test with database but no token - tokenManager.db = mockDB + tokenManager.DB = mockDB _, err = tokenManager.ValidateAccessToken("non_existent_token") assert.Error(t, err) // The error could be either "token not found" or "invalid token format" depending on the token format From 881d984e7fd33bc5d3fc1671e1a68ee9b7749a1e Mon Sep 17 00:00:00 2001 From: Daishan Peng Date: Fri, 5 Sep 2025 16:09:43 -0700 Subject: [PATCH 2/2] Add prefix Signed-off-by: Daishan Peng --- cmd/root.go | 6 ++++-- main.go | 2 ++ pkg/oauth/authorize/authorize.go | 6 ++++-- pkg/oauth/callback/callback.go | 6 ++++-- pkg/proxy/proxy.go | 37 +++++++++++++++++++------------- pkg/types/types.go | 1 + 6 files changed, 37 insertions(+), 21 deletions(-) diff --git a/cmd/root.go b/cmd/root.go index 911938d..dd9ab88 100644 --- a/cmd/root.go +++ b/cmd/root.go @@ -35,8 +35,9 @@ type RootCmd struct { EncryptionKey string `name:"encryption-key" env:"ENCRYPTION_KEY" usage:"Base64-encoded 32-byte AES-256 key for encrypting sensitive data (optional)"` // Server configuration - Port string `name:"port" env:"PORT" usage:"Port to run the server on" default:"8080"` - Host string `name:"host" env:"HOST" usage:"Host to bind the server to" default:"localhost"` + Port string `name:"port" env:"PORT" usage:"Port to run the server on" default:"8080"` + Host string `name:"host" env:"HOST" usage:"Host to bind the server to" default:"localhost"` + RoutePrefix string `name:"route-prefix" env:"ROUTE_PREFIX" usage:"Optional prefix for all routes (e.g., '/oauth2')"` // Logging Verbose bool `name:"verbose,v" usage:"Enable verbose logging"` @@ -69,6 +70,7 @@ func (c *RootCmd) Run(cobraCmd *cobra.Command, args []string) error { MCPServerURL: c.MCPServerURL, EncryptionKey: c.EncryptionKey, Mode: c.Mode, + RoutePrefix: c.RoutePrefix, } // Validate configuration diff --git a/main.go b/main.go index 2bfccf4..2893544 100644 --- a/main.go +++ b/main.go @@ -1,6 +1,7 @@ package main import ( + "fmt" "os" "github.com/obot-platform/mcp-oauth-proxy/cmd" @@ -8,6 +9,7 @@ import ( func main() { if err := cmd.Execute(); err != nil { + fmt.Println(err) os.Exit(1) } } diff --git a/pkg/oauth/authorize/authorize.go b/pkg/oauth/authorize/authorize.go index cf03b37..5fc66ac 100644 --- a/pkg/oauth/authorize/authorize.go +++ b/pkg/oauth/authorize/authorize.go @@ -24,15 +24,17 @@ type Handler struct { scopesSupported []string clientID string clientSecret string + routePrefix string } -func NewHandler(db AuthorizationStore, provider providers.Provider, scopesSupported []string, clientID, clientSecret string) http.Handler { +func NewHandler(db AuthorizationStore, provider providers.Provider, scopesSupported []string, clientID, clientSecret, routePrefix string) http.Handler { return &Handler{ db: db, provider: provider, scopesSupported: scopesSupported, clientID: clientID, clientSecret: clientSecret, + routePrefix: routePrefix, } } @@ -139,7 +141,7 @@ func (p *Handler) ServeHTTP(w http.ResponseWriter, r *http.Request) { return } - redirectURI := fmt.Sprintf("%s/callback", handlerutils.GetBaseURL(r)) + redirectURI := fmt.Sprintf("%s%s/callback", handlerutils.GetBaseURL(r), p.routePrefix) // Generate authorization URL with the provider authURL := p.provider.GetAuthorizationURL( diff --git a/pkg/oauth/callback/callback.go b/pkg/oauth/callback/callback.go index 4806112..090bc31 100644 --- a/pkg/oauth/callback/callback.go +++ b/pkg/oauth/callback/callback.go @@ -30,6 +30,7 @@ type Handler struct { clientID string clientSecret string mcpUIManager MCPUIManager + routePrefix string } // MCPUIManager interface for generating JWT tokens @@ -37,7 +38,7 @@ type MCPUIManager interface { GenerateMCPUICodeForDownstream(bearerToken, refreshToken string) (string, error) } -func NewHandler(db Store, provider providers.Provider, encryptionKey []byte, clientID, clientSecret string, mcpUIManager MCPUIManager) http.Handler { +func NewHandler(db Store, provider providers.Provider, encryptionKey []byte, clientID, clientSecret, routePrefix string, mcpUIManager MCPUIManager) http.Handler { return &Handler{ db: db, provider: provider, @@ -45,6 +46,7 @@ func NewHandler(db Store, provider providers.Provider, encryptionKey []byte, cli clientID: clientID, clientSecret: clientSecret, mcpUIManager: mcpUIManager, + routePrefix: routePrefix, } } @@ -150,7 +152,7 @@ func (p *Handler) ServeHTTP(w http.ResponseWriter, r *http.Request) { }() // Get provider credentials - redirectURI := fmt.Sprintf("%s/callback", handlerutils.GetBaseURL(r)) + redirectURI := fmt.Sprintf("%s%s/callback", handlerutils.GetBaseURL(r), p.routePrefix) // Exchange code for tokens tokenInfo, err := p.provider.ExchangeCodeForToken(r.Context(), code, p.clientID, p.clientSecret, redirectURI) diff --git a/pkg/proxy/proxy.go b/pkg/proxy/proxy.go index 35345e5..fdca45f 100644 --- a/pkg/proxy/proxy.go +++ b/pkg/proxy/proxy.go @@ -192,27 +192,30 @@ func (p *OAuthProxy) SetupRoutes(mux *http.ServeMux) { log.Fatalf("Failed to get provider: %v", err) } - authorizeHandler := authorize.NewHandler(p.db, provider, p.metadata.ScopesSupported, p.GetOAuthClientID(), p.GetOAuthClientSecret()) + authorizeHandler := authorize.NewHandler(p.db, provider, p.metadata.ScopesSupported, p.GetOAuthClientID(), p.GetOAuthClientSecret(), p.config.RoutePrefix) tokenHandler := token.NewHandler(p.db) - callbackHandler := callback.NewHandler(p.db, provider, p.encryptionKey, p.GetOAuthClientID(), p.GetOAuthClientSecret(), p.mcpUIManager) + callbackHandler := callback.NewHandler(p.db, provider, p.encryptionKey, p.GetOAuthClientID(), p.GetOAuthClientSecret(), p.config.RoutePrefix, p.mcpUIManager) revokeHandler := revoke.NewHandler(p.db) tokenValidator := validate.NewTokenValidator(p.tokenManager, p.encryptionKey, p.db, provider, p.GetOAuthClientID(), p.GetOAuthClientSecret(), p.metadata.ScopesSupported) - mux.HandleFunc("GET /health", p.withCORS(p.healthHandler)) + // Get route prefix from config + prefix := p.config.RoutePrefix + + mux.HandleFunc("GET "+prefix+"/health", p.withCORS(p.healthHandler)) // OAuth endpoints - mux.HandleFunc("GET /authorize", p.withCORS(p.withRateLimit(authorizeHandler))) - mux.HandleFunc("GET /callback", p.withCORS(p.withRateLimit(callbackHandler))) - mux.HandleFunc("POST /token", p.withCORS(p.withRateLimit(tokenHandler))) - mux.HandleFunc("POST /revoke", p.withCORS(p.withRateLimit(revokeHandler))) - mux.HandleFunc("POST /register", p.withCORS(p.withRateLimit(register.NewHandler(p.db)))) + mux.HandleFunc("GET "+prefix+"/authorize", p.withCORS(p.withRateLimit(authorizeHandler))) + mux.HandleFunc("GET "+prefix+"/callback", p.withCORS(p.withRateLimit(callbackHandler))) + mux.HandleFunc("POST "+prefix+"/token", p.withCORS(p.withRateLimit(tokenHandler))) + mux.HandleFunc("POST "+prefix+"/revoke", p.withCORS(p.withRateLimit(revokeHandler))) + mux.HandleFunc("POST "+prefix+"/register", p.withCORS(p.withRateLimit(register.NewHandler(p.db)))) // Metadata endpoints mux.HandleFunc("GET /.well-known/oauth-authorization-server", p.withCORS(p.oauthMetadataHandler)) mux.HandleFunc("GET /.well-known/oauth-protected-resource", p.withCORS(p.protectedResourceMetadataHandler)) // Protect everything else - mux.HandleFunc("/{path...}", p.withCORS(p.withRateLimit(tokenValidator.WithTokenValidation(p.mcpProxyHandler)))) + mux.HandleFunc(prefix+"/{path...}", p.withCORS(p.withRateLimit(tokenValidator.WithTokenValidation(p.mcpProxyHandler)))) } // GetHandler returns an http.Handler for the OAuth proxy @@ -270,21 +273,22 @@ func (p *OAuthProxy) healthHandler(w http.ResponseWriter, r *http.Request) { func (p *OAuthProxy) oauthMetadataHandler(w http.ResponseWriter, r *http.Request) { baseURL := handlerutils.GetBaseURL(r) + prefix := p.config.RoutePrefix // Create dynamic metadata based on the request metadata := &types.OAuthMetadata{ Issuer: baseURL, ServiceDocumentation: p.metadata.ServiceDocumentation, - AuthorizationEndpoint: fmt.Sprintf("%s/authorize", baseURL), + AuthorizationEndpoint: fmt.Sprintf("%s%s/authorize", baseURL, prefix), ResponseTypesSupported: p.metadata.ResponseTypesSupported, CodeChallengeMethodsSupported: p.metadata.CodeChallengeMethodsSupported, - TokenEndpoint: fmt.Sprintf("%s/token", baseURL), + TokenEndpoint: fmt.Sprintf("%s%s/token", baseURL, prefix), TokenEndpointAuthMethodsSupported: p.metadata.TokenEndpointAuthMethodsSupported, GrantTypesSupported: p.metadata.GrantTypesSupported, ScopesSupported: p.metadata.ScopesSupported, - RevocationEndpoint: fmt.Sprintf("%s/revoke", baseURL), + RevocationEndpoint: fmt.Sprintf("%s%s/revoke", baseURL, prefix), RevocationEndpointAuthMethodsSupported: p.metadata.RevocationEndpointAuthMethodsSupported, - RegistrationEndpoint: fmt.Sprintf("%s/register", baseURL), + RegistrationEndpoint: fmt.Sprintf("%s%s/register", baseURL, prefix), RegistrationEndpointAuthMethodsSupported: p.metadata.RegistrationEndpointAuthMethodsSupported, } @@ -293,9 +297,12 @@ func (p *OAuthProxy) oauthMetadataHandler(w http.ResponseWriter, r *http.Request func (p *OAuthProxy) protectedResourceMetadataHandler(w http.ResponseWriter, r *http.Request) { baseURL := handlerutils.GetBaseURL(r) + prefix := p.config.RoutePrefix + resourceURL := baseURL + prefix + metadata := types.OAuthProtectedResourceMetadata{ - Resource: baseURL, - AuthorizationServers: []string{baseURL}, + Resource: resourceURL, + AuthorizationServers: []string{baseURL + prefix}, Scopes: p.metadata.ScopesSupported, ResourceName: p.resourceName, ResourceDocumentation: p.metadata.ServiceDocumentation, diff --git a/pkg/types/types.go b/pkg/types/types.go index f5bd47f..595f1d4 100644 --- a/pkg/types/types.go +++ b/pkg/types/types.go @@ -15,6 +15,7 @@ type Config struct { EncryptionKey string MCPServerURL string Mode string + RoutePrefix string } // TokenData represents stored token data for OAuth 2.1 compliance