Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

server: add upstream multi-tenant auth support #208

Draft
wants to merge 3 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 16 additions & 2 deletions agent/config/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -226,10 +226,13 @@ func (c *TLSConfig) Load() (*tls.Config, error) {

type ConnectConfig struct {
// URL is the Piko server URL to connect to.
URL string
URL string `json:"url" yaml:"url"`

// Token is a token to authenticate with the Piko server.
Token string
Token string `json:"token" yaml:"token"`

// TenantID is the ID of the agent tenant (optional).
TenantID string `json:"tenant_id" yaml:"tenant_id"`

// Timeout is the timeout attempting to connect to the Piko server on
// boot.
Expand Down Expand Up @@ -272,6 +275,17 @@ Piko server 'upstream' port.`,
Token is a token to authenticate with the Piko server.`,
)

fs.StringVar(
&c.TenantID,
"connect.tenant-id",
c.TenantID,
`
Tenant ID of the agent.

Tenants can be used to configure different authentication mechanisms and keys
for different upstream services.`,
)

fs.DurationVar(
&c.Timeout,
"connect.timeout",
Expand Down
1 change: 1 addition & 0 deletions cli/agent/command.go
Original file line number Diff line number Diff line change
Expand Up @@ -116,6 +116,7 @@ func runAgent(conf *config.Config, logger log.Logger) error {
upstream := &client.Upstream{
URL: connectURL,
Token: conf.Connect.Token,
TenantID: conf.Connect.TenantID,
TLSConfig: connectTLSConfig,
Logger: logger.WithSubsystem("client"),
}
Expand Down
7 changes: 7 additions & 0 deletions client/upstream.go
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,12 @@ type Upstream struct {
// Defaults to no authentication.
Token string

// TenantID configures the tenant to authenticate the listener with the
// Piko server.
//
// Defaults to no tenant (optional).
TenantID string

// TLSConfig specifies the TLS configuration to use with the Piko server.
//
// If nil, the default configuration is used.
Expand Down Expand Up @@ -110,6 +116,7 @@ func (u *Upstream) connect(ctx context.Context, endpointID string) (*yamux.Sessi
ctx,
url,
websocket.WithToken(u.Token),
websocket.WithTenantID(u.TenantID),
websocket.WithTLSConfig(u.TLSConfig),
)
if err == nil {
Expand Down
38 changes: 38 additions & 0 deletions pkg/auth/multi_tenant_verifier.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
package auth

type MultiTenantVerifier struct {
defaultVerifier Verifier
tenantVerifiers map[string]Verifier
}

func NewMultiTenantVerifier(
defaultVerifier Verifier,
tenantVerifiers map[string]Verifier,
) *MultiTenantVerifier {
return &MultiTenantVerifier{
defaultVerifier: defaultVerifier,
tenantVerifiers: tenantVerifiers,
}
}

func (v *MultiTenantVerifier) Verify(token string, tenantID string) (*Token, error) {
if tenantID == "" {
return v.defaultVerifier.Verify(token)
}

if v.tenantVerifiers == nil {
return nil, ErrUnknownTenant
}

verifier, ok := v.tenantVerifiers[tenantID]
if !ok {
return nil, ErrUnknownTenant
}

t, err := verifier.Verify(token)
if err != nil {
return nil, err
}
t.TenantID = tenantID
return t, nil
}
75 changes: 75 additions & 0 deletions pkg/auth/multi_tenant_verifier_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,75 @@
package auth

import (
"testing"
"time"

"github.com/golang-jwt/jwt/v5"
"github.com/stretchr/testify/assert"
)

func TestMultiTenantVerifier(t *testing.T) {
defaultSecretKey := generateTestHSKey(t)
tenant1SecretKey := generateTestHSKey(t)
tenant2SecretKey := generateTestHSKey(t)

endpointClaims := JWTClaims{
RegisteredClaims: jwt.RegisteredClaims{
ExpiresAt: jwt.NewNumericDate(time.Now().Add(time.Hour)),
},
}

defaultVerifier := NewJWTVerifier(&LoadedConfig{
HMACSecretKey: defaultSecretKey,
})
tenantVerifiers := map[string]Verifier{
"tenant-1": NewJWTVerifier(&LoadedConfig{
HMACSecretKey: tenant1SecretKey,
}),
"tenant-2": NewJWTVerifier(&LoadedConfig{
HMACSecretKey: tenant2SecretKey,
}),
}
verifier := NewMultiTenantVerifier(defaultVerifier, tenantVerifiers)

t.Run("default", func(t *testing.T) {
token := jwt.NewWithClaims(jwt.SigningMethodHS256, endpointClaims)
tokenString, err := token.SignedString([]byte(defaultSecretKey))
assert.NoError(t, err)

parsedToken, err := verifier.Verify(tokenString, "")
assert.NoError(t, err)

assert.Equal(t, parsedToken.TenantID, "")
})

t.Run("tenant", func(t *testing.T) {
token := jwt.NewWithClaims(jwt.SigningMethodHS256, endpointClaims)
tokenString, err := token.SignedString([]byte(tenant1SecretKey))
assert.NoError(t, err)

parsedToken, err := verifier.Verify(tokenString, "tenant-1")
assert.NoError(t, err)

assert.Equal(t, parsedToken.TenantID, "tenant-1")
})

// Tests tenant 1 using tenants 2 key.
t.Run("incorrect key", func(t *testing.T) {
token := jwt.NewWithClaims(jwt.SigningMethodHS256, endpointClaims)
tokenString, err := token.SignedString([]byte(tenant2SecretKey))
assert.NoError(t, err)

_, err = verifier.Verify(tokenString, "tenant-1")
assert.Equal(t, ErrInvalidToken, err)
})

t.Run("unknown tenant", func(t *testing.T) {
token := jwt.NewWithClaims(jwt.SigningMethodHS256, endpointClaims)
tokenString, err := token.SignedString([]byte(defaultSecretKey))
assert.NoError(t, err)

_, err = verifier.Verify(tokenString, "unknown")
assert.Equal(t, ErrUnknownTenant, err)
})
}
8 changes: 6 additions & 2 deletions pkg/auth/verifier.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,9 @@ import (
)

var (
ErrInvalidToken = errors.New("invalid token")
ErrExpiredToken = errors.New("expired token")
ErrInvalidToken = errors.New("invalid token")
ErrExpiredToken = errors.New("expired token")
ErrUnknownTenant = errors.New("unknown tenant")
)

// Token represents an authenticated Piko token.
Expand All @@ -21,6 +22,9 @@ type Token struct {
// to access (either connect to or listen on). If empty then all endpoints
// are allowed.
Endpoints []string

// TenantID is the ID of the client tenant.
TenantID string
}

// EndpointPermitted returns whether the token it permitted to access the
Expand Down
23 changes: 20 additions & 3 deletions pkg/middleware/auth.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,11 +18,11 @@ const (

// Auth is middleware to verify token requests.
type Auth struct {
verifier auth.Verifier
verifier *auth.MultiTenantVerifier
logger log.Logger
}

func NewAuth(verifier auth.Verifier, logger log.Logger) *Auth {
func NewAuth(verifier *auth.MultiTenantVerifier, logger log.Logger) *Auth {
return &Auth{
verifier: verifier,
logger: logger,
Expand All @@ -38,7 +38,9 @@ func (m *Auth) Verify(c *gin.Context) {
return
}

token, err := m.verifier.Verify(tokenString)
tenantID := m.parseTenant(c)

token, err := m.verifier.Verify(tokenString, tenantID)
if err != nil {
if errors.Is(err, auth.ErrInvalidToken) {
m.logger.Warn(
Expand All @@ -62,6 +64,17 @@ func (m *Auth) Verify(c *gin.Context) {
)
return
}
if errors.Is(err, auth.ErrUnknownTenant) {
m.logger.Warn(
"auth unknwon tenant",
zap.Error(err),
)
c.AbortWithStatusJSON(
http.StatusUnauthorized,
gin.H{"error": "unknown tenant"},
)
return
}

m.logger.Warn(
"unknown verification error",
Expand Down Expand Up @@ -114,3 +127,7 @@ func (m *Auth) parseToken(c *gin.Context) (string, bool) {

return tokenString, true
}

func (m *Auth) parseTenant(c *gin.Context) string {
return c.Request.Header.Get("x-piko-tenant-id")
}
81 changes: 71 additions & 10 deletions pkg/middleware/auth_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -31,15 +31,15 @@ type errorMessage struct {

func TestAuth(t *testing.T) {
t.Run("authorization ok", func(t *testing.T) {
verifier := &fakeVerifier{
verifier := auth.NewMultiTenantVerifier(&fakeVerifier{
handler: func(token string) (*auth.Token, error) {
assert.Equal(t, "123", token)
return &auth.Token{
Expiry: time.Now().Add(time.Hour),
Endpoints: []string{"e1", "e2", "e3"},
}, nil
},
}
}, nil)
m := NewAuth(verifier, log.NewNopLogger())

w := httptest.NewRecorder()
Expand All @@ -59,15 +59,15 @@ func TestAuth(t *testing.T) {
})

t.Run("x-piko-authorization ok", func(t *testing.T) {
verifier := &fakeVerifier{
verifier := auth.NewMultiTenantVerifier(&fakeVerifier{
handler: func(token string) (*auth.Token, error) {
assert.Equal(t, "123", token)
return &auth.Token{
Expiry: time.Now().Add(time.Hour),
Endpoints: []string{"e1", "e2", "e3"},
}, nil
},
}
}, nil)
m := NewAuth(verifier, log.NewNopLogger())

w := httptest.NewRecorder()
Expand All @@ -89,13 +89,49 @@ func TestAuth(t *testing.T) {
assert.Equal(t, []string{"e1", "e2", "e3"}, token.(*auth.Token).Endpoints)
})

t.Run("tenant ok", func(t *testing.T) {
verifier := auth.NewMultiTenantVerifier(&fakeVerifier{
handler: func(_ string) (*auth.Token, error) {
t.Error("expected tenant")
return nil, fmt.Errorf("expected tenant")
},
}, map[string]auth.Verifier{
"tenant-1": &fakeVerifier{
handler: func(token string) (*auth.Token, error) {
assert.Equal(t, "123", token)
return &auth.Token{
Expiry: time.Now().Add(time.Hour),
Endpoints: []string{"e1", "e2", "e3"},
}, nil
},
},
})
m := NewAuth(verifier, log.NewNopLogger())

w := httptest.NewRecorder()
c, _ := gin.CreateTestContext(w)
c.Request = httptest.NewRequest("GET", "http://example.com/foo", nil)
c.Request.Header.Add("Authorization", "Bearer 123")
c.Request.Header.Add("x-piko-tenant-id", "tenant-1")

m.Verify(c)

resp := w.Result()
assert.Equal(t, http.StatusOK, resp.StatusCode)

// Verify the token was added to context.
token, ok := c.Get(TokenContextKey)
assert.True(t, ok)
assert.Equal(t, []string{"e1", "e2", "e3"}, token.(*auth.Token).Endpoints)
})

t.Run("invalid token", func(t *testing.T) {
verifier := &fakeVerifier{
verifier := auth.NewMultiTenantVerifier(&fakeVerifier{
handler: func(token string) (*auth.Token, error) {
assert.Equal(t, "123", token)
return &auth.Token{}, fmt.Errorf("foo: %w", auth.ErrInvalidToken)
},
}
}, nil)
m := NewAuth(verifier, log.NewNopLogger())

w := httptest.NewRecorder()
Expand All @@ -114,12 +150,12 @@ func TestAuth(t *testing.T) {
})

t.Run("expired token", func(t *testing.T) {
verifier := &fakeVerifier{
verifier := auth.NewMultiTenantVerifier(&fakeVerifier{
handler: func(token string) (*auth.Token, error) {
assert.Equal(t, "123", token)
return &auth.Token{}, fmt.Errorf("foo: %w", auth.ErrExpiredToken)
},
}
}, nil)
m := NewAuth(verifier, log.NewNopLogger())

w := httptest.NewRecorder()
Expand All @@ -137,13 +173,38 @@ func TestAuth(t *testing.T) {
assert.Equal(t, "expired token", errMessage.Error)
})

t.Run("unknown tenant", func(t *testing.T) {
verifier := auth.NewMultiTenantVerifier(&fakeVerifier{
handler: func(token string) (*auth.Token, error) {
assert.Equal(t, "123", token)
return &auth.Token{}, fmt.Errorf("foo: %w", auth.ErrExpiredToken)
},
}, nil)
m := NewAuth(verifier, log.NewNopLogger())

w := httptest.NewRecorder()
c, _ := gin.CreateTestContext(w)
c.Request = httptest.NewRequest("GET", "http://example.com/foo", nil)
c.Request.Header.Add("Authorization", "Bearer 123")
c.Request.Header.Add("x-piko-tenant-id", "tenant-123")

m.Verify(c)

resp := w.Result()
assert.Equal(t, http.StatusUnauthorized, resp.StatusCode)

var errMessage errorMessage
assert.NoError(t, json.NewDecoder(resp.Body).Decode(&errMessage))
assert.Equal(t, "unknown tenant", errMessage.Error)
})

t.Run("unknown error", func(t *testing.T) {
verifier := &fakeVerifier{
verifier := auth.NewMultiTenantVerifier(&fakeVerifier{
handler: func(token string) (*auth.Token, error) {
assert.Equal(t, "123", token)
return &auth.Token{}, fmt.Errorf("unknown")
},
}
}, nil)
m := NewAuth(verifier, log.NewNopLogger())

w := httptest.NewRecorder()
Expand Down
Loading
Loading