Skip to content

Commit cfe91e2

Browse files
authored
chore: include deployment_id in jwt (#659)
1 parent 19e09d5 commit cfe91e2

File tree

8 files changed

+118
-40
lines changed

8 files changed

+118
-40
lines changed

internal/apirouter/auth_middleware.go

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -107,19 +107,19 @@ func APIKeyOrTenantJWTAuthMiddleware(apiKey string, jwtKey string) gin.HandlerFu
107107
}
108108

109109
// Try JWT auth
110-
tokenTenantID, err := JWT.ExtractTenantID(jwtKey, token)
110+
claims, err := JWT.Extract(jwtKey, token)
111111
if err != nil {
112112
c.AbortWithStatus(http.StatusUnauthorized)
113113
return
114114
}
115115

116116
// If tenantID param exists, verify it matches token
117-
if paramTenantID := c.Param("tenantID"); paramTenantID != "" && paramTenantID != tokenTenantID {
117+
if paramTenantID := c.Param("tenantID"); paramTenantID != "" && paramTenantID != claims.TenantID {
118118
c.AbortWithStatus(http.StatusUnauthorized)
119119
return
120120
}
121121

122-
c.Set("tenantID", tokenTenantID)
122+
c.Set("tenantID", claims.TenantID)
123123
c.Set(authRoleKey, RoleTenant)
124124
c.Next()
125125
}
@@ -143,19 +143,19 @@ func TenantJWTAuthMiddleware(apiKey string, jwtKey string) gin.HandlerFunc {
143143
return
144144
}
145145

146-
tokenTenantID, err := JWT.ExtractTenantID(jwtKey, token)
146+
claims, err := JWT.Extract(jwtKey, token)
147147
if err != nil {
148148
c.AbortWithStatus(http.StatusUnauthorized)
149149
return
150150
}
151151

152152
// If tenantID param exists, verify it matches token
153-
if paramTenantID := c.Param("tenantID"); paramTenantID != "" && paramTenantID != tokenTenantID {
153+
if paramTenantID := c.Param("tenantID"); paramTenantID != "" && paramTenantID != claims.TenantID {
154154
c.AbortWithStatus(http.StatusUnauthorized)
155155
return
156156
}
157157

158-
c.Set("tenantID", tokenTenantID)
158+
c.Set("tenantID", claims.TenantID)
159159
c.Set(authRoleKey, RoleTenant)
160160
c.Next()
161161
}

internal/apirouter/auth_middleware_test.go

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -176,7 +176,7 @@ func TestAPIKeyOrTenantJWTAuthMiddleware(t *testing.T) {
176176
c.Params = []gin.Param{{Key: "tenantID", Value: "different_tenant"}}
177177

178178
// Create JWT token for tenantID
179-
token, err := apirouter.JWT.New(jwtSecret, tenantID)
179+
token, err := apirouter.JWT.New(jwtSecret, apirouter.JWTClaims{TenantID: tenantID})
180180
if err != nil {
181181
t.Fatal(err)
182182
}
@@ -201,7 +201,7 @@ func TestAPIKeyOrTenantJWTAuthMiddleware(t *testing.T) {
201201
c.Params = []gin.Param{{Key: "tenantID", Value: tenantID}}
202202

203203
// Create JWT token for tenantID
204-
token, err := apirouter.JWT.New(jwtSecret, tenantID)
204+
token, err := apirouter.JWT.New(jwtSecret, apirouter.JWTClaims{TenantID: tenantID})
205205
if err != nil {
206206
t.Fatal(err)
207207
}
@@ -251,7 +251,7 @@ func TestAPIKeyOrTenantJWTAuthMiddleware(t *testing.T) {
251251
}
252252

253253
func newJWTToken(t *testing.T, secret string, tenantID string) string {
254-
token, err := apirouter.JWT.New(secret, tenantID)
254+
token, err := apirouter.JWT.New(secret, apirouter.JWTClaims{TenantID: tenantID})
255255
if err != nil {
256256
t.Fatal(err)
257257
}

internal/apirouter/jwt.go

Lines changed: 32 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -19,18 +19,28 @@ var (
1919
ErrInvalidToken = errors.New("invalid token")
2020
)
2121

22-
func (_ jsonwebtoken) New(jwtSecret string, tenantID string) (string, error) {
22+
// JWTClaims contains the custom claims for JWT tokens
23+
type JWTClaims struct {
24+
TenantID string
25+
DeploymentID string
26+
}
27+
28+
func (_ jsonwebtoken) New(jwtSecret string, claims JWTClaims) (string, error) {
2329
now := time.Now()
24-
token := jwt.NewWithClaims(signingMethod, jwt.MapClaims{
30+
mapClaims := jwt.MapClaims{
2531
"iss": issuer,
26-
"sub": tenantID,
32+
"sub": claims.TenantID,
2733
"iat": now.Unix(),
2834
"exp": now.Add(24 * time.Hour).Unix(),
29-
})
35+
}
36+
if claims.DeploymentID != "" {
37+
mapClaims["deployment_id"] = claims.DeploymentID
38+
}
39+
token := jwt.NewWithClaims(signingMethod, mapClaims)
3040
return token.SignedString([]byte(jwtSecret))
3141
}
3242

33-
func (_ jsonwebtoken) ExtractTenantID(jwtSecret string, tokenString string) (string, error) {
43+
func (_ jsonwebtoken) Extract(jwtSecret string, tokenString string) (JWTClaims, error) {
3444
token, err := jwt.Parse(
3545
tokenString,
3646
func(token *jwt.Token) (interface{}, error) {
@@ -39,14 +49,28 @@ func (_ jsonwebtoken) ExtractTenantID(jwtSecret string, tokenString string) (str
3949
jwt.WithIssuer(issuer),
4050
)
4151
if err != nil || !token.Valid {
42-
return "", ErrInvalidToken
52+
return JWTClaims{}, ErrInvalidToken
53+
}
54+
55+
claims, ok := token.Claims.(jwt.MapClaims)
56+
if !ok {
57+
return JWTClaims{}, ErrInvalidToken
4358
}
4459

4560
tenantID, err := token.Claims.GetSubject()
4661
if err != nil {
47-
return "", ErrInvalidToken
62+
return JWTClaims{}, ErrInvalidToken
63+
}
64+
65+
var deploymentID string
66+
if did, ok := claims["deployment_id"].(string); ok {
67+
deploymentID = did
4868
}
49-
return tenantID, nil
69+
70+
return JWTClaims{
71+
TenantID: tenantID,
72+
DeploymentID: deploymentID,
73+
}, nil
5074
}
5175

5276
func (_ jsonwebtoken) Verify(jwtSecret string, tokenString string, tenantID string) (bool, error) {

internal/apirouter/jwt_test.go

Lines changed: 53 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -16,18 +16,34 @@ func TestJWT(t *testing.T) {
1616
const issuer = "outpost"
1717
const jwtKey = "supersecret"
1818
const tenantID = "tenantID"
19+
const deploymentID = "deployment123"
1920
var signingMethod = jwt.SigningMethodHS256
2021

2122
t.Run("should generate a new jwt token", func(t *testing.T) {
2223
t.Parallel()
23-
token, err := apirouter.JWT.New(jwtKey, tenantID)
24+
token, err := apirouter.JWT.New(jwtKey, apirouter.JWTClaims{TenantID: tenantID})
2425
assert.Nil(t, err)
2526
assert.NotEqual(t, "", token)
2627
})
2728

29+
t.Run("should generate a new jwt token with deployment_id", func(t *testing.T) {
30+
t.Parallel()
31+
token, err := apirouter.JWT.New(jwtKey, apirouter.JWTClaims{
32+
TenantID: tenantID,
33+
DeploymentID: deploymentID,
34+
})
35+
assert.Nil(t, err)
36+
assert.NotEqual(t, "", token)
37+
38+
// Verify deployment_id is in the token
39+
claims, err := apirouter.JWT.Extract(jwtKey, token)
40+
assert.Nil(t, err)
41+
assert.Equal(t, deploymentID, claims.DeploymentID)
42+
})
43+
2844
t.Run("should verify a valid jwt token", func(t *testing.T) {
2945
t.Parallel()
30-
token, err := apirouter.JWT.New(jwtKey, tenantID)
46+
token, err := apirouter.JWT.New(jwtKey, apirouter.JWTClaims{TenantID: tenantID})
3147
if err != nil {
3248
t.Fatal(err)
3349
}
@@ -36,24 +52,51 @@ func TestJWT(t *testing.T) {
3652
assert.True(t, valid)
3753
})
3854

39-
t.Run("should extract tenantID from valid token", func(t *testing.T) {
55+
t.Run("should extract claims from valid token", func(t *testing.T) {
4056
t.Parallel()
41-
token, err := apirouter.JWT.New(jwtKey, tenantID)
57+
token, err := apirouter.JWT.New(jwtKey, apirouter.JWTClaims{TenantID: tenantID})
4258
if err != nil {
4359
t.Fatal(err)
4460
}
45-
extractedTenantID, err := apirouter.JWT.ExtractTenantID(jwtKey, token)
61+
claims, err := apirouter.JWT.Extract(jwtKey, token)
4662
assert.Nil(t, err)
47-
assert.Equal(t, tenantID, extractedTenantID)
63+
assert.Equal(t, tenantID, claims.TenantID)
64+
assert.Equal(t, "", claims.DeploymentID)
4865
})
4966

50-
t.Run("should fail to extract tenantID from invalid token", func(t *testing.T) {
67+
t.Run("should fail to extract claims from invalid token", func(t *testing.T) {
5168
t.Parallel()
52-
_, err := apirouter.JWT.ExtractTenantID(jwtKey, "invalid_token")
69+
_, err := apirouter.JWT.Extract(jwtKey, "invalid_token")
5370
assert.ErrorIs(t, err, apirouter.ErrInvalidToken)
5471
})
5572

56-
t.Run("should fail to extract tenantID from token with invalid issuer", func(t *testing.T) {
73+
t.Run("should extract all claims from valid token", func(t *testing.T) {
74+
t.Parallel()
75+
token, err := apirouter.JWT.New(jwtKey, apirouter.JWTClaims{
76+
TenantID: tenantID,
77+
DeploymentID: deploymentID,
78+
})
79+
if err != nil {
80+
t.Fatal(err)
81+
}
82+
claims, err := apirouter.JWT.Extract(jwtKey, token)
83+
assert.Nil(t, err)
84+
assert.Equal(t, tenantID, claims.TenantID)
85+
assert.Equal(t, deploymentID, claims.DeploymentID)
86+
})
87+
88+
t.Run("should return empty deployment_id when not in token", func(t *testing.T) {
89+
t.Parallel()
90+
token, err := apirouter.JWT.New(jwtKey, apirouter.JWTClaims{TenantID: tenantID})
91+
if err != nil {
92+
t.Fatal(err)
93+
}
94+
claims, err := apirouter.JWT.Extract(jwtKey, token)
95+
assert.Nil(t, err)
96+
assert.Equal(t, "", claims.DeploymentID)
97+
})
98+
99+
t.Run("should fail to extract claims from token with invalid issuer", func(t *testing.T) {
57100
t.Parallel()
58101
now := time.Now()
59102
jwtToken := jwt.NewWithClaims(signingMethod, jwt.MapClaims{
@@ -66,7 +109,7 @@ func TestJWT(t *testing.T) {
66109
if err != nil {
67110
t.Fatal(err)
68111
}
69-
_, err = apirouter.JWT.ExtractTenantID(jwtKey, token)
112+
_, err = apirouter.JWT.Extract(jwtKey, token)
70113
assert.ErrorIs(t, err, apirouter.ErrInvalidToken)
71114
})
72115

internal/apirouter/router.go

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,7 @@ type RouterConfig struct {
4949
ServiceName string
5050
APIKey string
5151
JWTSecret string
52+
DeploymentID string
5253
Topics []string
5354
Registry destregistry.Registry
5455
PortalConfig portal.PortalConfig
@@ -138,7 +139,7 @@ func NewRouter(
138139
apiRouter := r.Group("/api/v1")
139140
apiRouter.Use(SetTenantIDMiddleware())
140141

141-
tenantHandlers := NewTenantHandlers(logger, telemetry, cfg.JWTSecret, entityStore)
142+
tenantHandlers := NewTenantHandlers(logger, telemetry, cfg.JWTSecret, cfg.DeploymentID, entityStore)
142143
destinationHandlers := NewDestinationHandlers(logger, telemetry, entityStore, cfg.Topics, cfg.Registry)
143144
publishHandlers := NewPublishHandlers(logger, publishmqEventHandler)
144145
logHandlers := NewLogHandlers(logger, logStore)

internal/apirouter/router_test.go

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -109,7 +109,7 @@ func TestRouterWithAPIKey(t *testing.T) {
109109
router, _, _ := setupTestRouter(t, apiKey, jwtSecret)
110110

111111
tenantID := "tenantID"
112-
validToken, err := apirouter.JWT.New(jwtSecret, tenantID)
112+
validToken, err := apirouter.JWT.New(jwtSecret, apirouter.JWTClaims{TenantID: tenantID})
113113
if err != nil {
114114
t.Fatal(err)
115115
}
@@ -201,7 +201,7 @@ func TestRouterWithoutAPIKey(t *testing.T) {
201201
router, _, _ := setupTestRouter(t, apiKey, jwtSecret)
202202

203203
tenantID := "tenantID"
204-
validToken, err := apirouter.JWT.New(jwtSecret, tenantID)
204+
validToken, err := apirouter.JWT.New(jwtSecret, apirouter.JWTClaims{TenantID: tenantID})
205205
if err != nil {
206206
t.Fatal(err)
207207
}

internal/apirouter/tenant_handlers.go

Lines changed: 19 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -13,23 +13,26 @@ import (
1313
)
1414

1515
type TenantHandlers struct {
16-
logger *logging.Logger
17-
telemetry telemetry.Telemetry
18-
jwtSecret string
19-
entityStore models.EntityStore
16+
logger *logging.Logger
17+
telemetry telemetry.Telemetry
18+
jwtSecret string
19+
deploymentID string
20+
entityStore models.EntityStore
2021
}
2122

2223
func NewTenantHandlers(
2324
logger *logging.Logger,
2425
telemetry telemetry.Telemetry,
2526
jwtSecret string,
27+
deploymentID string,
2628
entityStore models.EntityStore,
2729
) *TenantHandlers {
2830
return &TenantHandlers{
29-
logger: logger,
30-
telemetry: telemetry,
31-
jwtSecret: jwtSecret,
32-
entityStore: entityStore,
31+
logger: logger,
32+
telemetry: telemetry,
33+
jwtSecret: jwtSecret,
34+
deploymentID: deploymentID,
35+
entityStore: entityStore,
3336
}
3437
}
3538

@@ -180,7 +183,10 @@ func (h *TenantHandlers) RetrieveToken(c *gin.Context) {
180183
if tenant == nil {
181184
return
182185
}
183-
jwtToken, err := JWT.New(h.jwtSecret, tenant.ID)
186+
jwtToken, err := JWT.New(h.jwtSecret, JWTClaims{
187+
TenantID: tenant.ID,
188+
DeploymentID: h.deploymentID,
189+
})
184190
if err != nil {
185191
AbortWithError(c, http.StatusInternalServerError, NewErrInternalServer(err))
186192
return
@@ -193,7 +199,10 @@ func (h *TenantHandlers) RetrievePortal(c *gin.Context) {
193199
if tenant == nil {
194200
return
195201
}
196-
jwtToken, err := JWT.New(h.jwtSecret, tenant.ID)
202+
jwtToken, err := JWT.New(h.jwtSecret, JWTClaims{
203+
TenantID: tenant.ID,
204+
DeploymentID: h.deploymentID,
205+
})
197206
if err != nil {
198207
AbortWithError(c, http.StatusInternalServerError, NewErrInternalServer(err))
199208
return

internal/services/builder.go

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -199,6 +199,7 @@ func (b *ServiceBuilder) BuildAPIWorkers(baseRouter *gin.Engine) error {
199199
ServiceName: b.cfg.OpenTelemetry.GetServiceName(),
200200
APIKey: b.cfg.APIKey,
201201
JWTSecret: b.cfg.APIJWTSecret,
202+
DeploymentID: b.cfg.DeploymentID,
202203
Topics: b.cfg.Topics,
203204
Registry: svc.destRegistry,
204205
PortalConfig: b.cfg.GetPortalConfig(),

0 commit comments

Comments
 (0)