-
Notifications
You must be signed in to change notification settings - Fork 17
/
Copy pathjwt.go
311 lines (270 loc) · 11 KB
/
jwt.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
// Copyright (c) HashiCorp, Inc.
// SPDX-License-Identifier: MPL-2.0
package jwt
import (
"context"
"encoding/json"
"errors"
"fmt"
"time"
"github.com/go-jose/go-jose/v3"
"github.com/go-jose/go-jose/v3/jwt"
)
const (
	// DefaultLeewaySeconds is the leeway, in seconds, applied by default when a
	// missing "nbf" (Not Before) or "exp" (Expiration Time) claim is derived
	// from the token's other time-based claims during validation. It is an
	// untyped constant so it can be used in both integer and floating-point
	// contexts.
	DefaultLeewaySeconds = 150
)
// Validator validates JSON Web Tokens (JWT) by providing signature
// verification and claims set validation. Validator can contain either
// a single or multiple KeySets and will attempt to verify the JWT by iterating
// through the configured KeySets.
//
// Construct a Validator with NewValidator. A zero-value Validator has no
// KeySets and therefore cannot successfully verify any token.
type Validator struct {
// keySets are tried in order during validation; the first one that verifies
// the token's signature is the one whose claims are used.
keySets []KeySet
}
// NewValidator returns a Validator that verifies JWT signatures using the
// given KeySets. At least one KeySet must be provided, and none may be nil.
// When several KeySets are given, validation tries them in the order provided
// and accepts the token's signature if any one of them verifies it.
//
// The keySets slice is copied, so modifying the caller's slice afterwards
// (for example after calling NewValidator(ks...)) does not affect the
// returned Validator.
func NewValidator(keySets ...KeySet) (*Validator, error) {
	if len(keySets) == 0 {
		return nil, errors.New("must provide at least one key set")
	}
	for _, keySet := range keySets {
		if keySet == nil {
			return nil, errors.New("keySet must not be nil")
		}
	}
	return &Validator{
		// Detach from the caller's backing array so the nil checks above
		// remain true for the lifetime of the Validator.
		keySets: append([]KeySet(nil), keySets...),
	}, nil
}
// Expected defines the expected claims values to assert when validating a JWT.
// For claims that involve validation of the JWT with respect to time, leeway
// fields are provided to account for potential clock skew and for deriving
// time-based claims that are missing from the token.
type Expected struct {
// The expected JWT "iss" (issuer) claim value. If empty, validation is skipped.
Issuer string
// The expected JWT "sub" (subject) claim value. If empty, validation is skipped.
Subject string
// The expected JWT "jti" (JWT ID) claim value. If empty, validation is skipped.
ID string
// The list of expected JWT "aud" (audience) claim values to match against.
// The JWT claim will be considered valid if it matches any of the expected
// audiences. If empty, validation is skipped.
Audiences []string
// SigningAlgorithms provides the list of expected JWS "alg" (algorithm) header
// parameter values to match against. The JWS header parameter will be considered
// valid if it matches any of the expected signing algorithms. The following
// algorithms are supported: RS256, RS384, RS512, ES256, ES384, ES512, PS256,
// PS384, PS512, EdDSA. If empty, defaults to RS256.
SigningAlgorithms []Alg
// NotBeforeLeeway is only used when the token has neither a "nbf" (Not Before)
// nor an "iat" (Issued At) claim. In that case the effective "nbf" is derived
// as the "exp" (Expiration Time) minus this leeway. If the duration is zero or
// not provided, a default leeway of 150 seconds (DefaultLeewaySeconds) will be
// used. If the duration is negative, no leeway will be used. The value is
// truncated to whole seconds. Clock skew when checking "nbf" itself is
// controlled by ClockSkewLeeway.
NotBeforeLeeway time.Duration
// ExpirationLeeway is only used when the token has no "exp" (Expiration Time)
// claim. In that case the effective "exp" is derived as the later of the
// "iat" (Issued At) and "nbf" (Not Before) claims plus this leeway. If the
// duration is zero or not provided, a default leeway of 150 seconds
// (DefaultLeewaySeconds) will be used. If the duration is negative, no leeway
// will be used. The value is truncated to whole seconds. Clock skew when
// checking "exp" itself is controlled by ClockSkewLeeway.
ExpirationLeeway time.Duration
// ClockSkewLeeway provides the option to set an amount of leeway to use when
// validating the "nbf" (Not Before), "exp" (Expiration Time), and "iat" (Issued At)
// claims against the current time. If the duration is zero or not provided, a
// default leeway of 60 seconds (go-jose's jwt.DefaultLeeway) will be used. If
// the duration is negative, no leeway will be used.
ClockSkewLeeway time.Duration
// Now provides the option to specify a func for determining what the current time is.
// The func will be used to provide the current time when validating a JWT with respect to
// the "nbf" (Not Before), "exp" (Expiration Time), and "iat" (Issued At) claims. If not
// provided, defaults to returning time.Now().
Now func() time.Time
}
// Validate checks a JWT in JWS compact serialization form and, when it is
// valid, returns all of its claims.
//
// A token is accepted only when all of the following hold:
//  1. One of the Validator's KeySets verifies its signature.
//  2. Its JWS "alg" header and its claims set agree with expected.
//  3. It is valid at the current time: the current time lies within the
//     "nbf" (Not Before) and "exp" (Expiration Time) bounds (inclusive) and is
//     not before the "iat" (Issued At) time, with configurable leeway. At least
//     one of "nbf", "exp", or "iat" must be present in the token. The current
//     time comes from expected.Now when set, and time.Now otherwise.
func (v *Validator) Validate(ctx context.Context, token string, expected Expected) (map[string]interface{}, error) {
	// Tokens lacking every time-based claim are rejected by this entry point.
	const allowMissingTimeClaims = false
	return v.validateAll(ctx, token, expected, allowMissingTimeClaims)
}
// ValidateAllowMissingIatNbfExp checks a JWT in JWS compact serialization form
// and, when it is valid, returns all of its claims. It behaves like Validate,
// except that tokens carrying none of the time-based claims are accepted.
//
// A token is accepted only when all of the following hold:
//  1. One of the Validator's KeySets verifies its signature.
//  2. Its JWS "alg" header and its claims set agree with expected.
//  3. If any of "nbf" (Not Before), "exp" (Expiration Time), or "iat" (Issued
//     At) is present, the token is valid at the current time: the current time
//     lies within the "nbf" and "exp" bounds (inclusive) and is not before the
//     "iat" time, with configurable leeway. When all three are missing, this
//     time check is skipped entirely. The current time comes from expected.Now
//     when set, and time.Now otherwise.
func (v *Validator) ValidateAllowMissingIatNbfExp(ctx context.Context, token string, expected Expected) (map[string]interface{}, error) {
	// Tokens lacking every time-based claim skip the time checks here.
	const allowMissingTimeClaims = true
	return v.validateAll(ctx, token, expected, allowMissingTimeClaims)
}
// validateAll implements both Validate and ValidateAllowMissingIatNbfExp. It
// returns all claims from the token when the token is valid.
//
// Validation proceeds in this order:
//  1. The signature is verified against each configured KeySet in turn. The
//     first KeySet that succeeds is used. If none succeed, the error from the
//     last KeySet is returned.
//  2. The JWS "alg" header is checked against expected.SigningAlgorithms.
//  3. The "iss", "sub", "jti", and "aud" claims are checked against expected.
//  4. The "iat", "nbf", and "exp" claims are checked against the current time
//     (expected.Now if set, otherwise time.Now), allowing for clock skew.
//     Missing "exp"/"nbf" values are first derived from the claims present.
//
// A zero value counts as absent for "iat", "nbf", and "exp". If none of them is
// present, step 4 is skipped when allowMissingIatExpNbf is true. When it is
// false, such a token is rejected.
func (v *Validator) validateAll(ctx context.Context, token string, expected Expected, allowMissingIatExpNbf bool) (map[string]interface{}, error) {
	// A zero-value Validator (not created via NewValidator) has no key sets.
	// Report that directly rather than wrapping a nil error below.
	if len(v.keySets) == 0 {
		return nil, errors.New("error verifying token signature: no key sets configured")
	}

	// Verify the signature first so that all subsequent validation is performed
	// against verified claims.
	var (
		allClaims map[string]interface{}
		err       error
		verified  bool
	)
	for _, keySet := range v.keySets {
		if allClaims, err = keySet.VerifySignature(ctx, token); err == nil {
			verified = true
			break
		}
	}
	if !verified {
		return nil, fmt.Errorf("error verifying token signature: %w", err)
	}

	// Validate the signing algorithm in the JWS header.
	if err := validateSigningAlgorithm(token, expected.SigningAlgorithms); err != nil {
		return nil, fmt.Errorf("invalid algorithm (alg) header parameter: %w", err)
	}

	// Decode the registered JWT claims by round-tripping the verified claims
	// through JSON into jwt.Claims.
	allClaimsJSON, err := json.Marshal(allClaims)
	if err != nil {
		return nil, fmt.Errorf("error encoding token claims: %w", err)
	}
	var claims jwt.Claims
	if err := json.Unmarshal(allClaimsJSON, &claims); err != nil {
		return nil, fmt.Errorf("error decoding registered claims: %w", err)
	}

	// Validate claims by asserting that they're as expected.
	if expected.Issuer != "" && expected.Issuer != claims.Issuer {
		return nil, errors.New("invalid issuer (iss) claim")
	}
	if expected.Subject != "" && expected.Subject != claims.Subject {
		return nil, errors.New("invalid subject (sub) claim")
	}
	if expected.ID != "" && expected.ID != claims.ID {
		return nil, errors.New("invalid ID (jti) claim")
	}
	if err := validateAudience(expected.Audiences, claims.Audience); err != nil {
		return nil, fmt.Errorf("invalid audience (aud) claim: %w", err)
	}

	// Absent time claims are treated as zero, and zero is treated as "not set"
	// by the logic below.
	var iat, nbf, exp jwt.NumericDate
	if claims.IssuedAt != nil {
		iat = *claims.IssuedAt
	}
	if claims.NotBefore != nil {
		nbf = *claims.NotBefore
	}
	if claims.Expiry != nil {
		exp = *claims.Expiry
	}

	// At least one of "iat", "nbf", or "exp" is required unless
	// allowMissingIatExpNbf is set, in which case the time checks are skipped.
	if iat == 0 && nbf == 0 && exp == 0 {
		if allowMissingIatExpNbf {
			return allClaims, nil
		}
		return nil, errors.New("no issued at (iat), not before (nbf), or expiration time (exp) claims in token")
	}

	// effectiveLeeway resolves a configured leeway: a negative value disables
	// leeway, zero selects def, and a positive value is used as given.
	effectiveLeeway := func(configured, def time.Duration) time.Duration {
		switch {
		case configured < 0:
			return 0
		case configured == 0:
			return def
		default:
			return configured
		}
	}
	const defaultLeeway = DefaultLeewaySeconds * time.Second

	// If "exp" is not set, derive it from the later of "iat" and "nbf" plus the
	// expiration leeway (truncated to whole seconds).
	if exp == 0 {
		start := iat
		if nbf > iat {
			start = nbf
		}
		leeway := effectiveLeeway(expected.ExpirationLeeway, defaultLeeway)
		exp = start + jwt.NumericDate(int64(leeway.Seconds()))
	}

	// If "nbf" is not set, use "iat" when present. Otherwise derive it from the
	// (possibly derived) "exp" minus the not-before leeway (truncated to whole
	// seconds). This must run after the "exp" derivation above.
	if nbf == 0 {
		if iat != 0 {
			nbf = iat
		} else {
			leeway := effectiveLeeway(expected.NotBeforeLeeway, defaultLeeway)
			nbf = exp - jwt.NumericDate(int64(leeway.Seconds()))
		}
	}

	// Clock skew leeway applies to every time-based check below.
	skew := effectiveLeeway(expected.ClockSkewLeeway, jwt.DefaultLeeway)

	// Validate "nbf", "exp", and "iat" with respect to the current time.
	now := time.Now()
	if expected.Now != nil {
		now = expected.Now()
	}
	if now.Add(skew).Before(time.Unix(int64(nbf), 0)) {
		return nil, errors.New("invalid not before (nbf) claim: token not yet valid")
	}
	if now.Add(-skew).After(time.Unix(int64(exp), 0)) {
		return nil, errors.New("invalid expiration time (exp) claim: token is expired")
	}
	if now.Add(skew).Before(time.Unix(int64(iat), 0)) {
		return nil, errors.New("invalid issued at (iat) claim: token issued in the future")
	}
	return allClaims, nil
}
// validateSigningAlgorithm requires token to carry exactly one non-empty JWS
// signature whose "alg" (Algorithm) header parameter is one of
// expectedAlgorithms. An empty expectedAlgorithms means only RS256 is
// accepted.
//
// Every entry of expectedAlgorithms must itself be supported, as decided by
// SupportedSigningAlgorithm; its error is returned unchanged. Errors from
// parsing the token are also returned unchanged.
func validateSigningAlgorithm(token string, expectedAlgorithms []Alg) error {
	if err := SupportedSigningAlgorithm(expectedAlgorithms...); err != nil {
		return err
	}

	parsed, err := jose.ParseSigned(token)
	if err != nil {
		return err
	}

	// Reject unsigned tokens (no signatures, or one empty signature) and
	// tokens carrying more than one signature.
	switch count := len(parsed.Signatures); {
	case count == 0, count == 1 && len(parsed.Signatures[0].Signature) == 0:
		return fmt.Errorf("token must be signed")
	case count > 1:
		return fmt.Errorf("token with multiple signatures not supported")
	}

	allowed := expectedAlgorithms
	if len(allowed) == 0 {
		allowed = []Alg{RS256}
	}

	got := Alg(parsed.Signatures[0].Header.Algorithm)
	for _, want := range allowed {
		if got == want {
			return nil
		}
	}
	return fmt.Errorf("token signed with unexpected algorithm")
}
// validateAudience succeeds when at least one value in expectedAudiences also
// appears in audClaim. It returns an error when none match. When
// expectedAudiences is empty, the check is skipped and nil is returned.
func validateAudience(expectedAudiences, audClaim []string) error {
	if len(expectedAudiences) > 0 {
		for _, want := range expectedAudiences {
			if contains(audClaim, want) {
				return nil
			}
		}
		return errors.New("audience claim does not match any expected audience")
	}
	// No expected audiences configured: nothing to enforce.
	return nil
}

// contains reports whether st is an element of sl, using exact
// (case-sensitive) string comparison.
func contains(sl []string, st string) bool {
	for i := 0; i < len(sl); i++ {
		if sl[i] == st {
			return true
		}
	}
	return false
}