From f9a613b584c05b4637ff92692f50a3679c2ae6f7 Mon Sep 17 00:00:00 2001 From: Andrew Dunstall Date: Mon, 23 Dec 2024 13:47:47 +0000 Subject: [PATCH 1/3] client: add upstream tenant id option --- client/upstream.go | 7 +++++++ pkg/websocket/conn.go | 14 ++++++++++++++ 2 files changed, 21 insertions(+) diff --git a/client/upstream.go b/client/upstream.go index 7575a8c..dbb6193 100644 --- a/client/upstream.go +++ b/client/upstream.go @@ -41,6 +41,12 @@ type Upstream struct { // Defaults to no authentication. Token string + // TenantID configures the tenant to authenticate the listener with the + // Piko server. + // + // Defaults to no tenant (optional). + TenantID string + // TLSConfig specifies the TLS configuration to use with the Piko server. // // If nil, the default configuration is used. @@ -110,6 +116,7 @@ func (u *Upstream) connect(ctx context.Context, endpointID string) (*yamux.Sessi ctx, url, websocket.WithToken(u.Token), + websocket.WithTenantID(u.TenantID), websocket.WithTLSConfig(u.TLSConfig), ) if err == nil { diff --git a/pkg/websocket/conn.go b/pkg/websocket/conn.go index c9cdb40..a053318 100644 --- a/pkg/websocket/conn.go +++ b/pkg/websocket/conn.go @@ -49,6 +49,7 @@ func (e *RetryableError) Error() string { type dialOptions struct { token string + tenantID string tlsConfig *tls.Config } @@ -66,6 +67,16 @@ func WithToken(token string) DialOption { return tokenOption(token) } +type tenantIDOption string + +func (o tenantIDOption) apply(opts *dialOptions) { + opts.tenantID = string(o) +} + +func WithTenantID(tenantID string) DialOption { + return tenantIDOption(tenantID) +} + type tlsConfigOption struct { TLSConfig *tls.Config } @@ -113,6 +124,9 @@ func Dial(ctx context.Context, url string, opts ...DialOption) (*Conn, error) { if options.token != "" { header.Set("Authorization", "Bearer "+options.token) } + if options.tenantID != "" { + header.Set("x-piko-tenant-id", options.tenantID) + } wsConn, resp, err := dialer.DialContext( ctx, url, header, From 3a05d0611d8e716972b02cdee8e56e84a7d42c63 Mon Sep 17 00:00:00 2001 From: Andrew Dunstall Date: Mon, 23 Dec 2024 13:55:51 +0000 Subject: [PATCH 2/3] agent: add upstream tenant id config --- agent/config/config.go | 18 ++++++++++++++++-- cli/agent/command.go | 1 + 2 files changed, 17 insertions(+), 2 deletions(-) diff --git a/agent/config/config.go b/agent/config/config.go index f382c12..7fabe2b 100644 --- a/agent/config/config.go +++ b/agent/config/config.go @@ -226,10 +226,13 @@ func (c *TLSConfig) Load() (*tls.Config, error) { type ConnectConfig struct { // URL is the Piko server URL to connect to. - URL string + URL string `json:"url" yaml:"url"` // Token is a token to authenticate with the Piko server. - Token string + Token string `json:"token" yaml:"token"` + + // TenantID is the ID of the agent tenant (optional). + TenantID string `json:"tenant_id" yaml:"tenant_id"` // Timeout is the timeout attempting to connect to the Piko server on // boot. @@ -272,6 +275,17 @@ Piko server 'upstream' port.`, Token is a token to authenticate with the Piko server.`, ) + fs.StringVar( + &c.TenantID, + "connect.tenant-id", + c.TenantID, + ` +Tenant ID of the agent. + +Tenants can be used to configure different authentication mechanisms and keys +for different upstream services.`, + ) + fs.DurationVar( &c.Timeout, "connect.timeout", diff --git a/cli/agent/command.go b/cli/agent/command.go index 1217dab..815a418 100644 --- a/cli/agent/command.go +++ b/cli/agent/command.go @@ -116,6 +116,7 @@ func runAgent(conf *config.Config, logger log.Logger) error { upstream := &client.Upstream{ URL: connectURL, Token: conf.Connect.Token, + TenantID: conf.Connect.TenantID, TLSConfig: connectTLSConfig, Logger: logger.WithSubsystem("client"), } From 66224be653a26dda12da9d264eec1d1161e4cf61 Mon Sep 17 00:00:00 2001 From: Andrew Dunstall Date: Mon, 23 Dec 2024 14:29:29 +0000 Subject: [PATCH 3/3] server: add upstream multi-tenant auth support --- pkg/auth/multi_tenant_verifier.go | 38 ++++++++++++ pkg/auth/multi_tenant_verifier_test.go | 75 ++++++++++++++++++++++++ pkg/auth/verifier.go | 8 ++- pkg/middleware/auth.go | 23 +++++++- pkg/middleware/auth_test.go | 81 ++++++++++++++++++++++---- server/admin/server.go | 2 +- server/admin/server_test.go | 8 +-- server/config/config.go | 30 ++++++++++ server/config/config_test.go | 22 +++++++ server/proxy/server.go | 2 +- server/proxy/server_test.go | 16 ++--- server/server.go | 29 +++++++-- server/upstream/server.go | 6 +- server/upstream/server_test.go | 20 +++---- 14 files changed, 314 insertions(+), 46 deletions(-) create mode 100644 pkg/auth/multi_tenant_verifier.go create mode 100644 pkg/auth/multi_tenant_verifier_test.go diff --git a/pkg/auth/multi_tenant_verifier.go b/pkg/auth/multi_tenant_verifier.go new file mode 100644 index 0000000..ab609d1 --- /dev/null +++ b/pkg/auth/multi_tenant_verifier.go @@ -0,0 +1,38 @@ +package auth + +type MultiTenantVerifier struct { + defaultVerifier Verifier + tenantVerifiers map[string]Verifier +} + +func NewMultiTenantVerifier( + defaultVerifier Verifier, + tenantVerifiers map[string]Verifier, +) *MultiTenantVerifier { + return &MultiTenantVerifier{ + defaultVerifier: defaultVerifier, + tenantVerifiers: tenantVerifiers, + } +} + +func (v *MultiTenantVerifier) Verify(token string, tenantID string) (*Token, error) { + if tenantID == "" { + return v.defaultVerifier.Verify(token) + } + + if v.tenantVerifiers == nil { + return nil, ErrUnknownTenant + } + + verifier, ok := v.tenantVerifiers[tenantID] + if !ok { + return nil, ErrUnknownTenant + } + + t, err := verifier.Verify(token) + if err != nil { + return nil, err + } + t.TenantID = tenantID + return t, nil +} diff --git a/pkg/auth/multi_tenant_verifier_test.go b/pkg/auth/multi_tenant_verifier_test.go new file mode 100644 index 0000000..fc4ea7d --- /dev/null +++ b/pkg/auth/multi_tenant_verifier_test.go @@ -0,0 +1,75 @@ +package auth + +import ( + "testing" + "time" + + "github.com/golang-jwt/jwt/v5" + "github.com/stretchr/testify/assert" +) + +func TestMultiTenantVerifier(t *testing.T) { + defaultSecretKey := generateTestHSKey(t) + tenant1SecretKey := generateTestHSKey(t) + tenant2SecretKey := generateTestHSKey(t) + + endpointClaims := JWTClaims{ + RegisteredClaims: jwt.RegisteredClaims{ + ExpiresAt: jwt.NewNumericDate(time.Now().Add(time.Hour)), + }, + } + + defaultVerifier := NewJWTVerifier(&LoadedConfig{ + HMACSecretKey: defaultSecretKey, + }) + tenantVerifiers := map[string]Verifier{ + "tenant-1": NewJWTVerifier(&LoadedConfig{ + HMACSecretKey: tenant1SecretKey, + }), + "tenant-2": NewJWTVerifier(&LoadedConfig{ + HMACSecretKey: tenant2SecretKey, + }), + } + verifier := NewMultiTenantVerifier(defaultVerifier, tenantVerifiers) + + t.Run("default", func(t *testing.T) { + token := jwt.NewWithClaims(jwt.SigningMethodHS256, endpointClaims) + tokenString, err := token.SignedString([]byte(defaultSecretKey)) + assert.NoError(t, err) + + parsedToken, err := verifier.Verify(tokenString, "") + assert.NoError(t, err) + + assert.Equal(t, parsedToken.TenantID, "") + }) + + t.Run("tenant", func(t *testing.T) { + token := jwt.NewWithClaims(jwt.SigningMethodHS256, endpointClaims) + tokenString, err := token.SignedString([]byte(tenant1SecretKey)) + assert.NoError(t, err) + + parsedToken, err := verifier.Verify(tokenString, "tenant-1") + assert.NoError(t, err) + + assert.Equal(t, parsedToken.TenantID, "tenant-1") + }) + + // Tests tenant 1 using tenants 2 key. + t.Run("incorrect key", func(t *testing.T) { + token := jwt.NewWithClaims(jwt.SigningMethodHS256, endpointClaims) + tokenString, err := token.SignedString([]byte(tenant2SecretKey)) + assert.NoError(t, err) + + _, err = verifier.Verify(tokenString, "tenant-1") + assert.Equal(t, ErrInvalidToken, err) + }) + + t.Run("unknown tenant", func(t *testing.T) { + token := jwt.NewWithClaims(jwt.SigningMethodHS256, endpointClaims) + tokenString, err := token.SignedString([]byte(defaultSecretKey)) + assert.NoError(t, err) + + _, err = verifier.Verify(tokenString, "unknown") + assert.Equal(t, ErrUnknownTenant, err) + }) +} diff --git a/pkg/auth/verifier.go b/pkg/auth/verifier.go index 75a74b6..ff0f8bb 100644 --- a/pkg/auth/verifier.go +++ b/pkg/auth/verifier.go @@ -7,8 +7,9 @@ import ( ) var ( - ErrInvalidToken = errors.New("invalid token") - ErrExpiredToken = errors.New("expired token") + ErrInvalidToken = errors.New("invalid token") + ErrExpiredToken = errors.New("expired token") + ErrUnknownTenant = errors.New("unknown tenant") ) // Token represents an authenticated Piko token. @@ -21,6 +22,9 @@ type Token struct { // to access (either connect to or listen on). If empty then all endpoints // are allowed. Endpoints []string + + // TenantID is the ID of the client tenant. + TenantID string } // EndpointPermitted returns whether the token it permitted to access the diff --git a/pkg/middleware/auth.go b/pkg/middleware/auth.go index 44b1440..7057baa 100644 --- a/pkg/middleware/auth.go +++ b/pkg/middleware/auth.go @@ -18,11 +18,11 @@ const ( // Auth is middleware to verify token requests. type Auth struct { - verifier auth.Verifier + verifier *auth.MultiTenantVerifier logger log.Logger } -func NewAuth(verifier auth.Verifier, logger log.Logger) *Auth { +func NewAuth(verifier *auth.MultiTenantVerifier, logger log.Logger) *Auth { return &Auth{ verifier: verifier, logger: logger, @@ -38,7 +38,9 @@ func (m *Auth) Verify(c *gin.Context) { return } - token, err := m.verifier.Verify(tokenString) + tenantID := m.parseTenant(c) + + token, err := m.verifier.Verify(tokenString, tenantID) if err != nil { if errors.Is(err, auth.ErrInvalidToken) { m.logger.Warn( @@ -62,6 +64,17 @@ func (m *Auth) Verify(c *gin.Context) { ) return } + if errors.Is(err, auth.ErrUnknownTenant) { + m.logger.Warn( + "auth unknwon tenant", + zap.Error(err), + ) + c.AbortWithStatusJSON( + http.StatusUnauthorized, + gin.H{"error": "unknown tenant"}, + ) + return + } m.logger.Warn( "unknown verification error", @@ -114,3 +127,7 @@ func (m *Auth) parseToken(c *gin.Context) (string, bool) { return tokenString, true } + +func (m *Auth) parseTenant(c *gin.Context) string { + return c.Request.Header.Get("x-piko-tenant-id") +} diff --git a/pkg/middleware/auth_test.go b/pkg/middleware/auth_test.go index 41f78eb..e67568c 100644 --- a/pkg/middleware/auth_test.go +++ b/pkg/middleware/auth_test.go @@ -31,7 +31,7 @@ type errorMessage struct { func TestAuth(t *testing.T) { t.Run("authorization ok", func(t *testing.T) { - verifier := &fakeVerifier{ + verifier := auth.NewMultiTenantVerifier(&fakeVerifier{ handler: func(token string) (*auth.Token, error) { assert.Equal(t, "123", token) return &auth.Token{ @@ -39,7 +39,7 @@ func TestAuth(t *testing.T) { Endpoints: []string{"e1", "e2", "e3"}, }, nil }, - } + }, nil) m := NewAuth(verifier, log.NewNopLogger()) w := httptest.NewRecorder() @@ -59,7 +59,7 @@ func TestAuth(t *testing.T) { }) t.Run("x-piko-authorization ok", func(t *testing.T) { - verifier := &fakeVerifier{ + verifier := auth.NewMultiTenantVerifier(&fakeVerifier{ handler: func(token string) (*auth.Token, error) { assert.Equal(t, "123", token) return &auth.Token{ @@ -67,7 +67,7 @@ func TestAuth(t *testing.T) { Endpoints: []string{"e1", "e2", "e3"}, }, nil }, - } + }, nil) m := NewAuth(verifier, log.NewNopLogger()) w := httptest.NewRecorder() @@ -89,13 +89,49 @@ func TestAuth(t *testing.T) { assert.Equal(t, []string{"e1", "e2", "e3"}, token.(*auth.Token).Endpoints) }) + t.Run("tenant ok", func(t *testing.T) { + verifier := auth.NewMultiTenantVerifier(&fakeVerifier{ + handler: func(_ string) (*auth.Token, error) { + t.Error("expected tenant") + return nil, fmt.Errorf("expected tenant") + }, + }, map[string]auth.Verifier{ + "tenant-1": &fakeVerifier{ + handler: func(token string) (*auth.Token, error) { + assert.Equal(t, "123", token) + return &auth.Token{ + Expiry: time.Now().Add(time.Hour), + Endpoints: []string{"e1", "e2", "e3"}, + }, nil + }, + }, + }) + m := NewAuth(verifier, log.NewNopLogger()) + + w := httptest.NewRecorder() + c, _ := gin.CreateTestContext(w) + c.Request = httptest.NewRequest("GET", "http://example.com/foo", nil) + c.Request.Header.Add("Authorization", "Bearer 123") + c.Request.Header.Add("x-piko-tenant-id", "tenant-1") + + m.Verify(c) + + resp := w.Result() + assert.Equal(t, http.StatusOK, resp.StatusCode) + + // Verify the token was added to context. + token, ok := c.Get(TokenContextKey) + assert.True(t, ok) + assert.Equal(t, []string{"e1", "e2", "e3"}, token.(*auth.Token).Endpoints) + }) + t.Run("invalid token", func(t *testing.T) { - verifier := &fakeVerifier{ + verifier := auth.NewMultiTenantVerifier(&fakeVerifier{ handler: func(token string) (*auth.Token, error) { assert.Equal(t, "123", token) return &auth.Token{}, fmt.Errorf("foo: %w", auth.ErrInvalidToken) }, - } + }, nil) m := NewAuth(verifier, log.NewNopLogger()) w := httptest.NewRecorder() @@ -114,12 +150,12 @@ func TestAuth(t *testing.T) { }) t.Run("expired token", func(t *testing.T) { - verifier := &fakeVerifier{ + verifier := auth.NewMultiTenantVerifier(&fakeVerifier{ handler: func(token string) (*auth.Token, error) { assert.Equal(t, "123", token) return &auth.Token{}, fmt.Errorf("foo: %w", auth.ErrExpiredToken) }, - } + }, nil) m := NewAuth(verifier, log.NewNopLogger()) w := httptest.NewRecorder() @@ -137,13 +173,38 @@ func TestAuth(t *testing.T) { assert.Equal(t, "expired token", errMessage.Error) }) + t.Run("unknown tenant", func(t *testing.T) { + verifier := auth.NewMultiTenantVerifier(&fakeVerifier{ + handler: func(token string) (*auth.Token, error) { + assert.Equal(t, "123", token) + return &auth.Token{}, fmt.Errorf("foo: %w", auth.ErrExpiredToken) + }, + }, nil) + m := NewAuth(verifier, log.NewNopLogger()) + + w := httptest.NewRecorder() + c, _ := gin.CreateTestContext(w) + c.Request = httptest.NewRequest("GET", "http://example.com/foo", nil) + c.Request.Header.Add("Authorization", "Bearer 123") + c.Request.Header.Add("x-piko-tenant-id", "tenant-123") + + m.Verify(c) + + resp := w.Result() + assert.Equal(t, http.StatusUnauthorized, resp.StatusCode) + + var errMessage errorMessage + assert.NoError(t, json.NewDecoder(resp.Body).Decode(&errMessage)) + assert.Equal(t, "unknown tenant", errMessage.Error) + }) + t.Run("unknown error", func(t *testing.T) { - verifier := &fakeVerifier{ + verifier := auth.NewMultiTenantVerifier(&fakeVerifier{ handler: func(token string) (*auth.Token, error) { assert.Equal(t, "123", token) return &auth.Token{}, fmt.Errorf("unknown") }, - } + }, nil) m := NewAuth(verifier, log.NewNopLogger()) w := httptest.NewRecorder() diff --git a/server/admin/server.go b/server/admin/server.go index 306b8f7..795fcb4 100644 --- a/server/admin/server.go +++ b/server/admin/server.go @@ -43,7 +43,7 @@ type Server struct { func NewServer( clusterState *cluster.State, registry *prometheus.Registry, - verifier auth.Verifier, + verifier *auth.MultiTenantVerifier, tlsConfig *tls.Config, logger log.Logger, ) *Server { diff --git a/server/admin/server_test.go b/server/admin/server_test.go index 36ee39d..9d69c87 100644 --- a/server/admin/server_test.go +++ b/server/admin/server_test.go @@ -235,14 +235,14 @@ func TestServer_Authentication(t *testing.T) { ln, err := net.Listen("tcp", "127.0.0.1:0") require.NoError(t, err) - verifier := &fakeVerifier{ + verifier := auth.NewMultiTenantVerifier(&fakeVerifier{ handler: func(token string) (*auth.Token, error) { assert.Equal(t, "123", token) return &auth.Token{ Expiry: time.Now().Add(time.Hour), }, nil }, - } + }, nil) s := NewServer( nil, @@ -272,12 +272,12 @@ func TestServer_Authentication(t *testing.T) { ln, err := net.Listen("tcp", "127.0.0.1:0") require.NoError(t, err) - verifier := &fakeVerifier{ + verifier := auth.NewMultiTenantVerifier(&fakeVerifier{ handler: func(token string) (*auth.Token, error) { assert.Equal(t, "123", token) return nil, auth.ErrInvalidToken }, - } + }, nil) s := NewServer( nil, diff --git a/server/config/config.go b/server/config/config.go index eec52e2..f8be2cb 100644 --- a/server/config/config.go +++ b/server/config/config.go @@ -84,6 +84,24 @@ keys and values, including the request line.`, ) } +type TenantConfig struct { + ID string `json:"id" yaml:"id"` + + Auth auth.Config `json:"auth" yaml:"auth"` +} + +func (c *TenantConfig) Validate() error { + if c.ID == "" { + return fmt.Errorf("missing tenant id") + } + if !c.Auth.Enabled() { + // Require tenants to be authenticated (theres no point otherwise). + return fmt.Errorf("tenant auth disabled") + } + + return nil +} + type ProxyConfig struct { // BindAddr is the address to bind to listen for incoming HTTP connections. BindAddr string `json:"bind_addr" yaml:"bind_addr"` @@ -177,6 +195,8 @@ type UpstreamConfig struct { Auth auth.Config `json:"auth" yaml:"auth"` TLS TLSConfig `json:"tls" yaml:"tls"` + + Tenants []TenantConfig `json:"tenants" yaml:"tenants"` } func (c *UpstreamConfig) Validate() error { @@ -186,6 +206,16 @@ func (c *UpstreamConfig) Validate() error { if err := c.TLS.Validate(); err != nil { return fmt.Errorf("tls: %w", err) } + if !c.Auth.Enabled() && len(c.Tenants) > 0 { + // Require default authentication when enabling tenants (theres no point + // otherwise). + return fmt.Errorf("cannot have tenants without default authentication") + } + for _, tenant := range c.Tenants { + if err := tenant.Validate(); err != nil { + return fmt.Errorf("tenant: %w", err) + } + } return nil } diff --git a/server/config/config_test.go b/server/config/config_test.go index f356095..318fbed 100644 --- a/server/config/config_test.go +++ b/server/config/config_test.go @@ -63,6 +63,14 @@ upstream: cert: /piko/cert.pem key: /piko/key.pem + tenants: + - id: tenant-1 + auth: + hmac_secret_key: hmac-secret-key + - id: tenant-2 + auth: + rsa_public_key: rsa-public-key + admin: bind_addr: 10.15.104.25:8002 advertise_addr: 1.2.3.4:8002 @@ -154,6 +162,20 @@ grace_period: 2m Cert: "/piko/cert.pem", Key: "/piko/key.pem", }, + Tenants: []TenantConfig{ + { + ID: "tenant-1", + Auth: auth.Config{ + HMACSecretKey: "hmac-secret-key", + }, + }, + { + ID: "tenant-2", + Auth: auth.Config{ + RSAPublicKey: "rsa-public-key", + }, + }, + }, }, Admin: AdminConfig{ BindAddr: "10.15.104.25:8002", diff --git a/server/proxy/server.go b/server/proxy/server.go index 99c00c2..996cd6c 100644 --- a/server/proxy/server.go +++ b/server/proxy/server.go @@ -33,7 +33,7 @@ func NewServer( upstreams upstream.Manager, proxyConfig config.ProxyConfig, registry *prometheus.Registry, - verifier auth.Verifier, + verifier *auth.MultiTenantVerifier, tlsConfig *tls.Config, logger log.Logger, ) *Server { diff --git a/server/proxy/server_test.go b/server/proxy/server_test.go index c581d07..228ae78 100644 --- a/server/proxy/server_test.go +++ b/server/proxy/server_test.go @@ -430,7 +430,7 @@ func TestServer_Authentication(t *testing.T) { )) defer upstreamServer.Close() - verifier := &fakeVerifier{ + verifier := auth.NewMultiTenantVerifier(&fakeVerifier{ handler: func(token string) (*auth.Token, error) { assert.Equal(t, "123", token) return &auth.Token{ @@ -438,7 +438,7 @@ func TestServer_Authentication(t *testing.T) { Endpoints: []string{"my-endpoint"}, }, nil }, - } + }, nil) ln, err := net.Listen("tcp", "127.0.0.1:0") require.NoError(t, err) @@ -484,7 +484,7 @@ func TestServer_Authentication(t *testing.T) { }) t.Run("endpoint not permitted", func(t *testing.T) { - verifier := &fakeVerifier{ + verifier := auth.NewMultiTenantVerifier(&fakeVerifier{ handler: func(token string) (*auth.Token, error) { assert.Equal(t, "123", token) return &auth.Token{ @@ -492,7 +492,7 @@ func TestServer_Authentication(t *testing.T) { Endpoints: []string{"my-endpoint"}, }, nil }, - } + }, nil) ln, err := net.Listen("tcp", "127.0.0.1:0") require.NoError(t, err) @@ -544,14 +544,14 @@ func TestServer_Authentication(t *testing.T) { )) defer upstreamServer.Close() - verifier := &fakeVerifier{ + verifier := auth.NewMultiTenantVerifier(&fakeVerifier{ handler: func(token string) (*auth.Token, error) { assert.Equal(t, "123", token) return &auth.Token{ Expiry: time.Now().Add(time.Hour), }, nil }, - } + }, nil) ln, err := net.Listen("tcp", "127.0.0.1:0") require.NoError(t, err) @@ -597,12 +597,12 @@ func TestServer_Authentication(t *testing.T) { }) t.Run("unauthenticated", func(t *testing.T) { - verifier := &fakeVerifier{ + verifier := auth.NewMultiTenantVerifier(&fakeVerifier{ handler: func(token string) (*auth.Token, error) { assert.Equal(t, "123", token) return nil, auth.ErrInvalidToken }, - } + }, nil) ln, err := net.Listen("tcp", "127.0.0.1:0") require.NoError(t, err) diff --git a/server/server.go b/server/server.go index 7a2443a..dbe3465 100644 --- a/server/server.go +++ b/server/server.go @@ -116,13 +116,15 @@ func NewServer(conf *config.Config, logger log.Logger) (*Server, error) { // Proxy server. - var proxyVerifier auth.Verifier + var proxyVerifier *auth.MultiTenantVerifier if conf.Proxy.Auth.Enabled() { verifierConf, err := conf.Proxy.Auth.Load() if err != nil { return nil, fmt.Errorf("proxy: load auth: %w", err) } - proxyVerifier = auth.NewJWTVerifier(verifierConf) + proxyVerifier = auth.NewMultiTenantVerifier( + auth.NewJWTVerifier(verifierConf), nil, + ) } proxyTLSConfig, err := conf.Proxy.TLS.Load() if err != nil { @@ -139,13 +141,26 @@ func NewServer(conf *config.Config, logger log.Logger) (*Server, error) { // Upstream server. - var upstreamVerifier auth.Verifier + var upstreamVerifier *auth.MultiTenantVerifier if conf.Upstream.Auth.Enabled() { verifierConf, err := conf.Upstream.Auth.Load() if err != nil { return nil, fmt.Errorf("upstream: load auth: %w", err) } - upstreamVerifier = auth.NewJWTVerifier(verifierConf) + defaultUpstreamVerifier := auth.NewJWTVerifier(verifierConf) + + upstreamTenantVerifiers := make(map[string]auth.Verifier) + for _, tenantConf := range conf.Upstream.Tenants { + tenantVerifierConf, err := tenantConf.Auth.Load() + if err != nil { + return nil, fmt.Errorf("upstream: tenant %s: load auth: %w", tenantConf.ID, err) + } + upstreamTenantVerifiers[tenantConf.ID] = auth.NewJWTVerifier(tenantVerifierConf) + } + + upstreamVerifier = auth.NewMultiTenantVerifier( + defaultUpstreamVerifier, upstreamTenantVerifiers, + ) } upstreamTLSConfig, err := conf.Upstream.TLS.Load() if err != nil { @@ -160,13 +175,15 @@ func NewServer(conf *config.Config, logger log.Logger) (*Server, error) { // Admin server. - var adminVerifier auth.Verifier + var adminVerifier *auth.MultiTenantVerifier if conf.Admin.Auth.Enabled() { verifierConf, err := conf.Admin.Auth.Load() if err != nil { return nil, fmt.Errorf("admin: load auth: %w", err) } - adminVerifier = auth.NewJWTVerifier(verifierConf) + adminVerifier = auth.NewMultiTenantVerifier( + auth.NewJWTVerifier(verifierConf), nil, + ) } adminTLSConfig, err := conf.Admin.TLS.Load() if err != nil { diff --git a/server/upstream/server.go b/server/upstream/server.go index e51f45f..9c92a9b 100644 --- a/server/upstream/server.go +++ b/server/upstream/server.go @@ -36,7 +36,7 @@ type Server struct { func NewServer( upstreams Manager, - verifier auth.Verifier, + verifier *auth.MultiTenantVerifier, tlsConfig *tls.Config, logger log.Logger, ) *Server { @@ -102,6 +102,7 @@ func (s *Server) Shutdown(ctx context.Context) error { func (s *Server) upstreamRoute(c *gin.Context) { endpointID := c.Param("endpointID") + var tenantID string token, ok := c.Get(middleware.TokenContextKey) if ok { // If the token contains a set of permitted endpoints, verify the @@ -121,6 +122,7 @@ func (s *Server) upstreamRoute(c *gin.Context) { ) return } + tenantID = endpointToken.TenantID } wsConn, err := s.websocketUpgrader.Upgrade(c.Writer, c.Request, nil) @@ -136,11 +138,13 @@ func (s *Server) upstreamRoute(c *gin.Context) { "upstream connected", zap.String("endpoint-id", endpointID), zap.String("client-ip", c.ClientIP()), + zap.String("tenant-id", tenantID), ) defer s.logger.Info( "upstream disconnected", zap.String("endpoint-id", endpointID), zap.String("client-ip", c.ClientIP()), + zap.String("tenant-id", tenantID), ) ctx := s.ctx diff --git a/server/upstream/server_test.go b/server/upstream/server_test.go index c40c545..7f84609 100644 --- a/server/upstream/server_test.go +++ b/server/upstream/server_test.go @@ -118,7 +118,7 @@ func TestServer_Authentication(t *testing.T) { manager := newFakeManager() - verifier := &fakeVerifier{ + verifier := auth.NewMultiTenantVerifier(&fakeVerifier{ handler: func(token string) (*auth.Token, error) { assert.Equal(t, "123", token) return &auth.Token{ @@ -126,7 +126,7 @@ func TestServer_Authentication(t *testing.T) { Endpoints: []string{"my-endpoint"}, }, nil }, - } + }, nil) s := NewServer(manager, verifier, nil, log.NewNopLogger()) go func() { @@ -156,7 +156,7 @@ func TestServer_Authentication(t *testing.T) { manager := newFakeManager() - verifier := &fakeVerifier{ + verifier := auth.NewMultiTenantVerifier(&fakeVerifier{ handler: func(token string) (*auth.Token, error) { assert.Equal(t, "123", token) return &auth.Token{ @@ -165,7 +165,7 @@ func TestServer_Authentication(t *testing.T) { Endpoints: []string{"my-endpoint"}, }, nil }, - } + }, nil) s := NewServer(manager, verifier, nil, log.NewNopLogger()) go func() { @@ -196,7 +196,7 @@ func TestServer_Authentication(t *testing.T) { manager := newFakeManager() - verifier := &fakeVerifier{ + verifier := auth.NewMultiTenantVerifier(&fakeVerifier{ handler: func(token string) (*auth.Token, error) { assert.Equal(t, "123", token) return &auth.Token{ @@ -204,7 +204,7 @@ func TestServer_Authentication(t *testing.T) { Endpoints: []string{"foo"}, }, nil }, - } + }, nil) s := NewServer(manager, verifier, nil, log.NewNopLogger()) go func() { @@ -228,14 +228,14 @@ func TestServer_Authentication(t *testing.T) { manager := newFakeManager() - verifier := &fakeVerifier{ + verifier := auth.NewMultiTenantVerifier(&fakeVerifier{ handler: func(token string) (*auth.Token, error) { assert.Equal(t, "123", token) return &auth.Token{ Expiry: time.Now().Add(time.Hour), }, nil }, - } + }, nil) s := NewServer(manager, verifier, nil, log.NewNopLogger()) go func() { @@ -265,12 +265,12 @@ func TestServer_Authentication(t *testing.T) { manager := newFakeManager() - verifier := &fakeVerifier{ + verifier := auth.NewMultiTenantVerifier(&fakeVerifier{ handler: func(token string) (*auth.Token, error) { assert.Equal(t, "123", token) return nil, auth.ErrInvalidToken }, - } + }, nil) s := NewServer(manager, verifier, nil, log.NewNopLogger()) go func() {