Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
192 changes: 192 additions & 0 deletions cmd/root.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,192 @@
package cmd

import (
"fmt"
"log"
"net/http"
"net/url"

"github.com/gptscript-ai/cmd"
"github.com/obot-platform/mcp-oauth-proxy/pkg/proxy"
"github.com/obot-platform/mcp-oauth-proxy/pkg/types"
"github.com/spf13/cobra"
)

var (
version = "dev"
buildTime = "unknown"
)

// RootCmd represents the base command when called without any subcommands
type RootCmd struct {
// Database configuration
DatabaseDSN string `name:"database-dsn" env:"DATABASE_DSN" usage:"Database connection string (PostgreSQL or SQLite file path). If empty, uses SQLite at data/oauth_proxy.db"`

// OAuth Provider configuration
OAuthClientID string `name:"oauth-client-id" env:"OAUTH_CLIENT_ID" usage:"OAuth client ID from your OAuth provider" required:"true"`
OAuthClientSecret string `name:"oauth-client-secret" env:"OAUTH_CLIENT_SECRET" usage:"OAuth client secret from your OAuth provider" required:"true"`
OAuthAuthorizeURL string `name:"oauth-authorize-url" env:"OAUTH_AUTHORIZE_URL" usage:"Authorization endpoint URL from your OAuth provider (e.g., https://accounts.google.com)" required:"true"`

// Scopes and MCP configuration
ScopesSupported string `name:"scopes-supported" env:"SCOPES_SUPPORTED" usage:"Comma-separated list of supported OAuth scopes (e.g., 'openid,profile,email')" required:"true"`
MCPServerURL string `name:"mcp-server-url" env:"MCP_SERVER_URL" usage:"URL of the MCP server to proxy requests to" required:"true"`

// Security configuration
EncryptionKey string `name:"encryption-key" env:"ENCRYPTION_KEY" usage:"Base64-encoded 32-byte AES-256 key for encrypting sensitive data (optional)"`

// Server configuration
Port string `name:"port" env:"PORT" usage:"Port to run the server on" default:"8080"`
Host string `name:"host" env:"HOST" usage:"Host to bind the server to" default:"localhost"`
RoutePrefix string `name:"route-prefix" env:"ROUTE_PREFIX" usage:"Optional prefix for all routes (e.g., '/oauth2')"`

// Logging
Verbose bool `name:"verbose,v" usage:"Enable verbose logging"`
Version bool `name:"version" usage:"Show version information"`

Mode string `name:"mode" env:"MODE" usage:"Mode to run the server in" default:"proxy"`
}

func (c *RootCmd) Run(cobraCmd *cobra.Command, args []string) error {
if c.Version {
fmt.Printf("MCP OAuth Proxy\n")
fmt.Printf("Version: %s\n", version)
fmt.Printf("Built: %s\n", buildTime)
return nil
}

// Configure logging
if c.Verbose {
log.SetFlags(log.LstdFlags | log.Lshortfile)
log.Println("Verbose logging enabled")
}

// Convert CLI config to internal config format
config := &types.Config{
DatabaseDSN: c.DatabaseDSN,
OAuthClientID: c.OAuthClientID,
OAuthClientSecret: c.OAuthClientSecret,
OAuthAuthorizeURL: c.OAuthAuthorizeURL,
ScopesSupported: c.ScopesSupported,
MCPServerURL: c.MCPServerURL,
EncryptionKey: c.EncryptionKey,
Mode: c.Mode,
RoutePrefix: c.RoutePrefix,
}

// Validate configuration
if err := c.validateConfig(); err != nil {
return fmt.Errorf("configuration validation failed: %w", err)
}

// Create OAuth proxy
oauthProxy, err := proxy.NewOAuthProxy(config)
if err != nil {
return fmt.Errorf("failed to create OAuth proxy: %w", err)
}
defer func() {
if err := oauthProxy.Close(); err != nil {
log.Printf("Error closing database: %v", err)
}
}()

// Get HTTP handler
handler := oauthProxy.GetHandler()

// Start server
address := fmt.Sprintf("%s:%s", c.Host, c.Port)
log.Printf("Starting OAuth proxy server on %s", address)
log.Printf("OAuth Provider: %s", c.OAuthAuthorizeURL)
log.Printf("MCP Server: %s", c.MCPServerURL)
log.Printf("Database: %s", c.getDatabaseType())

return http.ListenAndServe(address, handler)
}

func (c *RootCmd) validateConfig() error {
if c.OAuthClientID == "" {
return fmt.Errorf("oauth-client-id is required")
}
if c.OAuthClientSecret == "" {
return fmt.Errorf("oauth-client-secret is required")
}
if c.OAuthAuthorizeURL == "" {
return fmt.Errorf("oauth-authorize-url is required")
}
if c.ScopesSupported == "" {
return fmt.Errorf("scopes-supported is required")
}
if c.MCPServerURL == "" {
return fmt.Errorf("mcp-server-url is required")
}
if c.Mode == proxy.ModeProxy {
if u, err := url.Parse(c.MCPServerURL); err != nil || u.Scheme != "http" && u.Scheme != "https" {
return fmt.Errorf("invalid MCP server URL: %w", err)
} else if u.Path != "" && u.Path != "/" || u.RawQuery != "" || u.Fragment != "" {
return fmt.Errorf("MCP server URL must not contain a path, query, or fragment")
}
}
return nil
}

func (c *RootCmd) getDatabaseType() string {
if c.DatabaseDSN == "" {
return "SQLite (data/oauth_proxy.db)"
}
if len(c.DatabaseDSN) > 10 && (c.DatabaseDSN[:11] == "postgres://" || c.DatabaseDSN[:14] == "postgresql://") {
return "PostgreSQL"
}
return fmt.Sprintf("SQLite (%s)", c.DatabaseDSN)
}

// Customizer interface implementation for additional command customization
func (c *RootCmd) Customize(cobraCmd *cobra.Command) {
cobraCmd.Use = "mcp-oauth-proxy"
cobraCmd.Short = "OAuth 2.1 proxy server for MCP (Model Context Protocol)"
cobraCmd.Long = `MCP OAuth Proxy is a comprehensive OAuth 2.1 proxy server that provides
OAuth authorization server functionality with PostgreSQL/SQLite storage.

This proxy supports multiple OAuth providers (Google, Microsoft, GitHub) and
proxies requests to MCP servers with user context headers.

Examples:
# Start with environment variables
export OAUTH_CLIENT_ID="your-google-client-id"
export OAUTH_CLIENT_SECRET="your-secret"
export OAUTH_AUTHORIZE_URL="https://accounts.google.com"
export SCOPES_SUPPORTED="openid,profile,email"
export MCP_SERVER_URL="http://localhost:3000"
mcp-oauth-proxy

# Start with CLI flags
mcp-oauth-proxy \
--oauth-client-id="your-google-client-id" \
--oauth-client-secret="your-secret" \
--oauth-authorize-url="https://accounts.google.com" \
--scopes-supported="openid,profile,email" \
--mcp-server-url="http://localhost:3000"

# Use PostgreSQL database
mcp-oauth-proxy \
--database-dsn="postgres://user:pass@localhost:5432/oauth_db?sslmode=disable" \
--oauth-client-id="your-client-id" \
# ... other required flags

Configuration:
Configuration values are loaded in this order (later values override earlier ones):
1. Default values
2. Environment variables
3. Command line flags

Database Support:
- PostgreSQL: Full ACID compliance, recommended for production
- SQLite: Zero configuration, perfect for development and small deployments`

cobraCmd.Version = version
}

// Execute is the main entry point for the CLI
func Execute() error {
rootCmd := &RootCmd{}
cobraCmd := cmd.Command(rootCmd)
return cobraCmd.Execute()
}
7 changes: 6 additions & 1 deletion go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,12 @@ module github.com/obot-platform/mcp-oauth-proxy
go 1.25.0

require (
github.com/golang-jwt/jwt/v5 v5.3.0
github.com/gorilla/handlers v1.5.2
github.com/gptscript-ai/cmd v0.0.0-20250530150401-bc71fddf8070
github.com/spf13/cobra v1.7.0
github.com/stretchr/testify v1.10.0
golang.org/x/oauth2 v0.30.0
gorm.io/driver/postgres v1.6.0
gorm.io/driver/sqlite v1.6.0
gorm.io/gorm v1.30.1
Expand All @@ -13,6 +17,7 @@ require (
require (
github.com/davecgh/go-spew v1.1.1 // indirect
github.com/felixge/httpsnoop v1.0.3 // indirect
github.com/inconshreveable/mousetrap v1.1.0 // indirect
github.com/jackc/pgpassfile v1.0.0 // indirect
github.com/jackc/pgservicefile v0.0.0-20240606120523-5a60cdf6a761 // indirect
github.com/jackc/pgx/v5 v5.7.5 // indirect
Expand All @@ -23,8 +28,8 @@ require (
github.com/mattn/go-sqlite3 v1.14.32 // indirect
github.com/pmezard/go-difflib v1.0.0 // indirect
github.com/rogpeppe/go-internal v1.8.0 // indirect
github.com/spf13/pflag v1.0.5 // indirect
golang.org/x/crypto v0.41.0 // indirect
golang.org/x/oauth2 v0.30.0 // indirect
golang.org/x/sync v0.16.0 // indirect
golang.org/x/text v0.28.0 // indirect
gopkg.in/yaml.v3 v3.0.1 // indirect
Expand Down
12 changes: 12 additions & 0 deletions go.sum
Original file line number Diff line number Diff line change
@@ -1,10 +1,17 @@
github.com/cpuguy83/go-md2man/v2 v2.0.2/go.mod h1:tgQtvFlXSQOSOSIRvRPT7W67SCa46tRHOmNcaadrF8o=
github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/felixge/httpsnoop v1.0.3 h1:s/nj+GCswXYzN5v2DpNMuMQYe+0DDwt5WVCU6CWBdXk=
github.com/felixge/httpsnoop v1.0.3/go.mod h1:m8KPJKqk1gH5J9DgRY2ASl2lWCfGKXixSwevea8zH2U=
github.com/golang-jwt/jwt/v5 v5.3.0 h1:pv4AsKCKKZuqlgs5sUmn4x8UlGa0kEVt/puTpKx9vvo=
github.com/golang-jwt/jwt/v5 v5.3.0/go.mod h1:fxCRLWMO43lRc8nhHWY6LGqRcf+1gQWArsqaEUEa5bE=
github.com/gorilla/handlers v1.5.2 h1:cLTUSsNkgcwhgRqvCNmdbRWG0A3N4F+M2nWKdScwyEE=
github.com/gorilla/handlers v1.5.2/go.mod h1:dX+xVpaxdSw+q0Qek8SSsl3dfMk3jNddUkMzo0GtH0w=
github.com/gptscript-ai/cmd v0.0.0-20250530150401-bc71fddf8070 h1:xm5ZZFraWFwxyE7TBEncCXArubCDZTwG6s5bpMzqhSY=
github.com/gptscript-ai/cmd v0.0.0-20250530150401-bc71fddf8070/go.mod h1:DJAo1xTht1LDkNYFNydVjTHd576TC7MlpsVRl3oloVw=
github.com/inconshreveable/mousetrap v1.1.0 h1:wN+x4NVGpMsO7ErUn/mUI3vEoE6Jt13X2s0bqwp9tc8=
github.com/inconshreveable/mousetrap v1.1.0/go.mod h1:vpF70FUmC8bwa3OWnCshd2FqLfsEA9PFc4w1p2J65bw=
github.com/jackc/pgpassfile v1.0.0 h1:/6Hmqy13Ss2zCq62VdNG8tM1wchn8zjSGOBJ6icpsIM=
github.com/jackc/pgpassfile v1.0.0/go.mod h1:CEx0iS5ambNFdcRtxPj5JhEz+xB6uRky5eyVu/W2HEg=
github.com/jackc/pgservicefile v0.0.0-20240606120523-5a60cdf6a761 h1:iCEnooe7UlwOQYpKFhBabPMi4aNAfoODPEFNiAnClxo=
Expand All @@ -30,6 +37,11 @@ github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZb
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
github.com/rogpeppe/go-internal v1.8.0 h1:FCbCCtXNOY3UtUuHUYaghJg4y7Fd14rXifAYUAtL9R8=
github.com/rogpeppe/go-internal v1.8.0/go.mod h1:WmiCO8CzOY8rg0OYDC4/i/2WRWAB6poM+XZ2dLUbcbE=
github.com/russross/blackfriday/v2 v2.1.0/go.mod h1:+Rmxgy9KzJVeS9/2gXHxylqXiyQDYRxCVz55jmeOWTM=
github.com/spf13/cobra v1.7.0 h1:hyqWnYt1ZQShIddO5kBpj3vu05/++x6tJ6dg8EC572I=
github.com/spf13/cobra v1.7.0/go.mod h1:uLxZILRyS/50WlhOIKD7W6V5bgeIt+4sICxh6uRMrb0=
github.com/spf13/pflag v1.0.5 h1:iy+VFUOCP1a+8yFto/drg2CJ5u0yRoB7fZw3DKv/JXA=
github.com/spf13/pflag v1.0.5/go.mod h1:McXfInJRrz4CZXVZOBLb0bTZqETkiAhM9Iw0y3An2Bg=
github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI=
github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
Expand Down
30 changes: 6 additions & 24 deletions main.go
Original file line number Diff line number Diff line change
@@ -1,33 +1,15 @@
package main

import (
"log"
"net/http"
"fmt"
"os"

"github.com/obot-platform/mcp-oauth-proxy/pkg/proxy"
"github.com/obot-platform/mcp-oauth-proxy/cmd"
)

func main() {
// Load configuration from environment variables
config, err := proxy.LoadConfigFromEnv()
if err != nil {
log.Fatalf("Failed to load configuration: %v", err)
if err := cmd.Execute(); err != nil {
fmt.Println(err)
os.Exit(1)
}

proxy, err := proxy.NewOAuthProxy(config)
if err != nil {
log.Fatalf("Failed to create OAuth proxy: %v", err)
}
defer func() {
if err := proxy.Close(); err != nil {
log.Printf("Error closing database: %v", err)
}
}()

// Get HTTP handler
handler := proxy.GetHandler()

// Start server
log.Print("Starting OAuth proxy server on localhost:" + config.Port)
log.Fatal(http.ListenAndServe(":"+config.Port, handler))
}
27 changes: 20 additions & 7 deletions main_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ import (
"time"

"github.com/obot-platform/mcp-oauth-proxy/pkg/proxy"
"github.com/obot-platform/mcp-oauth-proxy/pkg/types"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
Expand Down Expand Up @@ -62,7 +63,10 @@ func TestIntegrationFlow(t *testing.T) {
}()

// Create OAuth proxy
config, err := proxy.LoadConfigFromEnv()
config := &types.Config{
Mode: proxy.ModeProxy,
}
_, err := proxy.NewOAuthProxy(config)
if err != nil {
log.Fatalf("Failed to load configuration: %v", err)
}
Expand Down Expand Up @@ -180,7 +184,10 @@ func TestOAuthProxyCreation(t *testing.T) {
}()

// Create OAuth proxy
config, err := proxy.LoadConfigFromEnv()
config := &types.Config{
Mode: proxy.ModeProxy,
}
_, err := proxy.NewOAuthProxy(config)
if err != nil {
log.Fatalf("Failed to load configuration: %v", err)
}
Expand Down Expand Up @@ -233,7 +240,10 @@ func TestOAuthProxyStart(t *testing.T) {
}()

// Create OAuth proxy
config, err := proxy.LoadConfigFromEnv()
config := &types.Config{
Mode: proxy.ModeProxy,
}
_, err := proxy.NewOAuthProxy(config)
if err != nil {
log.Fatalf("Failed to load configuration: %v", err)
}
Expand Down Expand Up @@ -296,7 +306,7 @@ func TestForwardAuthIntegrationFlow(t *testing.T) {
"OAUTH_AUTHORIZE_URL": "https://accounts.google.com",
"SCOPES_SUPPORTED": "openid,profile,email",
"PROXY_MODE": "forward_auth",
"PORT": "8082", // Different port to avoid conflicts
"PORT": "8082", // Different port to avoid conflicts
"DATABASE_DSN": os.Getenv("TEST_DATABASE_DSN"), // Use test database if available
}

Expand All @@ -320,7 +330,10 @@ func TestForwardAuthIntegrationFlow(t *testing.T) {
}()

// Create OAuth proxy in forward auth mode
config, err := proxy.LoadConfigFromEnv()
config := &types.Config{
Mode: proxy.ModeForwardAuth,
}
_, err := proxy.NewOAuthProxy(config)
if err != nil {
log.Fatalf("Failed to load configuration: %v", err)
}
Expand Down Expand Up @@ -375,7 +388,7 @@ func TestForwardAuthIntegrationFlow(t *testing.T) {
// Test that forward auth mode requires authorization for protected endpoints
t.Run("ForwardAuthRequiresAuth", func(t *testing.T) {
testPaths := []string{"/api", "/data", "/protected", "/mcp", "/test"}

for _, path := range testPaths {
t.Run("Path_"+path, func(t *testing.T) {
w := httptest.NewRecorder()
Expand All @@ -399,7 +412,7 @@ func TestForwardAuthIntegrationFlow(t *testing.T) {
// Should get unauthorized (no proxying attempt)
assert.Equal(t, http.StatusUnauthorized, w.Code)
assert.Contains(t, w.Header().Get("WWW-Authenticate"), "Bearer")

// Should not have any proxy-related error messages
assert.NotContains(t, w.Body.String(), "proxy")
assert.NotContains(t, w.Body.String(), "502")
Expand Down
8 changes: 0 additions & 8 deletions pkg/db/db.go
Original file line number Diff line number Diff line change
Expand Up @@ -295,14 +295,6 @@ func (d *Store) RevokeToken(token string) error {
return result.Error
}

// UpdateTokenRefreshToken updates the refresh token for an existing token
func (d *Store) UpdateTokenRefreshToken(accessToken, newRefreshToken string) error {
hashedAccessToken := hashToken(accessToken)
hashedNewRefreshToken := hashToken(newRefreshToken)

return d.db.Model(&types.TokenData{}).Where("access_token = ?", hashedAccessToken).Update("refresh_token", hashedNewRefreshToken).Error
}

// CleanupExpiredTokens removes expired tokens and authorization codes
func (d *Store) CleanupExpiredTokens() error {
now := time.Now()
Expand Down
Loading
Loading