20
20
21
21
import java .net .URL ;
22
22
import java .text .ParseException ;
23
+ import java .util .Set ;
24
+ import java .util .HashSet ;
23
25
import java .util .Arrays ;
24
26
import java .util .Date ;
25
27
import java .util .List ;
43
45
import com .nimbusds .jose .proc .JWSVerificationKeySelector ;
44
46
import com .nimbusds .jose .proc .SecurityContext ;
45
47
import com .nimbusds .jwt .SignedJWT ;
48
+ import com .nimbusds .jwt .JWTClaimsSet ;
46
49
import com .nimbusds .jwt .proc .ConfigurableJWTProcessor ;
47
50
48
51
public abstract class RangerJwtAuthHandler implements RangerAuthHandler {
49
52
private static final Logger LOG = LoggerFactory .getLogger (RangerJwtAuthHandler .class );
50
53
51
54
private JWSVerifier verifier = null ;
55
+ protected SignedJWT signedJWT = null ;
52
56
private String jwksProviderUrl = null ;
53
57
public static final String TYPE = "ranger-jwt" ; // Constant that identifies the authentication mechanism.
54
58
public static final String KEY_PROVIDER_URL = "jwks.provider-url" ; // JWKS provider URL
55
59
public static final String KEY_JWT_PUBLIC_KEY = "jwt.public-key" ; // JWT token provider public key
56
60
public static final String KEY_JWT_COOKIE_NAME = "jwt.cookie-name" ; // JWT cookie name
57
61
public static final String KEY_JWT_AUDIENCES = "jwt.audiences" ;
58
62
public static final String JWT_AUTHZ_PREFIX = "Bearer " ;
63
+ public static final String CUSTOM_JWT_CLAIM_GROUP_KEY_PARAM = "custom.jwt.claim.group.key" ;
64
+ public static final String CUSTOM_JWT_CLAIM_GROUP_KEY_VALUE_DEFAULT = "knox.groups" ;
59
65
66
+ public String CUSTOM_JWT_CLAIM_GROUP_KEY_VALUE = null ;
60
67
protected List <String > audiences = null ;
61
68
protected JWKSource <SecurityContext > keySource = null ;
62
69
@@ -76,6 +83,7 @@ public void initialize(final Properties config) throws Exception {
76
83
77
84
// optional configurations
78
85
String pemPublicKey = config .getProperty (KEY_JWT_PUBLIC_KEY );
86
+ CUSTOM_JWT_CLAIM_GROUP_KEY_VALUE = config .getProperty (CUSTOM_JWT_CLAIM_GROUP_KEY_PARAM , CUSTOM_JWT_CLAIM_GROUP_KEY_VALUE_DEFAULT );
79
87
80
88
// setup JWT provider public key if configured
81
89
if (StringUtils .isNotBlank (pemPublicKey )) {
@@ -112,30 +120,36 @@ protected AuthenticationToken authenticate(final String jwtAuthHeader, final Str
112
120
113
121
if (StringUtils .isNotBlank (serializedJWT )) {
114
122
try {
115
- final SignedJWT jwtToken = SignedJWT .parse (serializedJWT );
116
- boolean valid = validateToken (jwtToken );
123
+ signedJWT = SignedJWT .parse (serializedJWT );
124
+ JWTClaimsSet claimsSet = getJWTClaimsSet ();
125
+
126
+ if (LOG .isDebugEnabled ()){
127
+ LOG .debug ("RangerJwtAuthHandler.authenticate(): JWTClaimsSet - {}" , claimsSet );
128
+ }
129
+
130
+ boolean valid = validateToken ();
117
131
if (valid ) {
118
132
String userName ;
119
133
120
134
if (StringUtils .isNotBlank (doAsUser )) {
121
135
userName = doAsUser .trim ();
122
136
} else {
123
- userName = jwtToken . getJWTClaimsSet () .getSubject ();
137
+ userName = claimsSet .getSubject ();
124
138
}
125
139
126
140
if (LOG .isDebugEnabled ()) {
127
141
LOG .debug ("RangerJwtAuthHandler.authenticate(): Issuing AuthenticationToken for user: [{}]" , userName );
128
- LOG .debug ("RangerJwtAuthHandler.authenticate(): Authentication successful for user [{}] and doAs user is [{}]" , jwtToken . getJWTClaimsSet () .getSubject (), doAsUser );
142
+ LOG .debug ("RangerJwtAuthHandler.authenticate(): Authentication successful for user [{}] and doAs user is [{}]" , claimsSet .getSubject (), doAsUser );
129
143
}
130
144
token = new AuthenticationToken (userName , userName , TYPE );
131
145
} else {
132
- LOG .warn ("RangerJwtAuthHandler.authenticate(): Validation failed for JWT token : [{}] " , jwtToken .serialize ());
146
+ LOG .warn ("RangerJwtAuthHandler.authenticate(): Validation failed for JWT: [{}] " , signedJWT .serialize ());
133
147
}
134
- } catch (ParseException pe ) {
135
- LOG .warn ("RangerJwtAuthHandler.authenticate(): Unable to parse the JWT token " , pe );
148
+ } catch (ParseException | RuntimeException exp ) {
149
+ LOG .warn ("RangerJwtAuthHandler.authenticate(): Unable to parse the JWT" , exp );
136
150
}
137
151
} else {
138
- LOG .warn ("RangerJwtAuthHandler.authenticate(): JWT token not found. " );
152
+ LOG .warn ("RangerJwtAuthHandler.authenticate(): JWT not found" );
139
153
}
140
154
}
141
155
@@ -145,6 +159,31 @@ protected AuthenticationToken authenticate(final String jwtAuthHeader, final Str
145
159
146
160
return token ;
147
161
}
162
+
163
+ protected JWTClaimsSet getJWTClaimsSet () throws ParseException {
164
+ return signedJWT .getJWTClaimsSet ();
165
+ }
166
+
167
+ public Set <String > getGroupsFromClaimSet () {
168
+ List <String > groupsClaim = null ;
169
+ try {
170
+ groupsClaim = (List <String >) getJWTClaimsSet ().getClaim (CUSTOM_JWT_CLAIM_GROUP_KEY_VALUE );
171
+ } catch (ParseException e ) {
172
+ LOG .error ("Unable to parse JWT claim set" , e );
173
+ }
174
+
175
+ if (groupsClaim == null ) {
176
+ LOG .warn ("No group claim found!" );
177
+ return new HashSet <>();
178
+ }
179
+
180
+ Set <String > groups = new HashSet <>(groupsClaim );
181
+ if (LOG .isDebugEnabled ()) {
182
+ LOG .debug ("Groups present in Claim [{}]: {}" , CUSTOM_JWT_CLAIM_GROUP_KEY_VALUE , groups );
183
+ }
184
+ return groups ;
185
+ }
186
+
148
187
149
188
protected String getJWT (final String jwtAuthHeader , final String jwtCookie ) {
150
189
String serializedJWT = null ;
@@ -171,19 +210,18 @@ protected String getJWT(final String jwtAuthHeader, final String jwtCookie) {
171
210
* implementation through submethods used within but also allows for the
172
211
* override of the entire token validation algorithm.
173
212
*
174
- * @param jwtToken the token to validate
175
213
* @return true if valid
176
214
*/
177
- protected boolean validateToken (final SignedJWT jwtToken ) {
178
- boolean expValid = validateExpiration (jwtToken );
215
+ protected boolean validateToken () throws ParseException {
216
+ boolean expValid = validateExpiration ();
179
217
boolean sigValid = false ;
180
218
boolean audValid = false ;
181
219
182
220
if (expValid ) {
183
- sigValid = validateSignature (jwtToken );
221
+ sigValid = validateSignature ();
184
222
185
223
if (sigValid ) {
186
- audValid = validateAudiences (jwtToken );
224
+ audValid = validateAudiences ();
187
225
}
188
226
}
189
227
@@ -195,41 +233,40 @@ protected boolean validateToken(final SignedJWT jwtToken) {
195
233
}
196
234
197
235
/**
198
- * Verify the signature of the JWT token in this method. This method depends on
236
+ * Verify the signature of the JWT in this method. This method depends on
199
237
* the public key that was established during init based upon the provisioned
200
238
* public key. Override this method in subclasses in order to customize the
201
239
* signature verification behavior.
202
240
*
203
- * @param jwtToken the token that contains the signature to be validated
204
241
* @return valid true if signature verifies successfully; false otherwise
205
242
*/
206
- protected boolean validateSignature (final SignedJWT jwtToken ) {
243
+ protected boolean validateSignature () {
207
244
boolean valid = false ;
208
245
209
- if (JWSObject .State .SIGNED == jwtToken .getState ()) {
246
+ if (JWSObject .State .SIGNED == signedJWT .getState ()) {
210
247
if (LOG .isDebugEnabled ()) {
211
- LOG .debug ("JWT token is in a SIGNED state" );
248
+ LOG .debug ("JWT is in a SIGNED state" );
212
249
}
213
250
214
- if (jwtToken .getSignature () != null ) {
251
+ if (signedJWT .getSignature () != null ) {
215
252
try {
216
253
if (StringUtils .isNotBlank (jwksProviderUrl )) {
217
- JWSKeySelector <SecurityContext > keySelector = new JWSVerificationKeySelector <>(jwtToken .getHeader ().getAlgorithm (), keySource );
254
+ JWSKeySelector <SecurityContext > keySelector = new JWSVerificationKeySelector <>(signedJWT .getHeader ().getAlgorithm (), keySource );
218
255
219
256
// Create a JWT processor for the access tokens
220
257
ConfigurableJWTProcessor <SecurityContext > jwtProcessor = getJwtProcessor (keySelector );
221
258
222
259
// Process the token
223
- jwtProcessor .process (jwtToken , null );
260
+ jwtProcessor .process (signedJWT , null );
224
261
valid = true ;
225
262
if (LOG .isDebugEnabled ()) {
226
- LOG .debug ("JWT token has been successfully verified." );
263
+ LOG .debug ("JWT has been successfully verified." );
227
264
}
228
265
} else if (verifier != null ) {
229
- if (jwtToken .verify (verifier )) {
266
+ if (signedJWT .verify (verifier )) {
230
267
valid = true ;
231
268
if (LOG .isDebugEnabled ()) {
232
- LOG .debug ("JWT token has been successfully verified." );
269
+ LOG .debug ("JWT has been successfully verified." );
233
270
}
234
271
} else {
235
272
LOG .warn ("JWT signature verification failed." );
@@ -257,61 +294,51 @@ protected boolean validateSignature(final SignedJWT jwtToken) {
257
294
* token claims list for audience. Override this method in subclasses in order
258
295
* to customize the audience validation behavior.
259
296
*
260
- * @param jwtToken the JWT token where the allowed audiences will be found
261
297
* @return true if an expected audience is present, otherwise false
262
298
*/
263
- protected boolean validateAudiences (final SignedJWT jwtToken ) {
299
+ protected boolean validateAudiences () throws ParseException {
264
300
boolean valid = false ;
265
- try {
266
- List <String > tokenAudienceList = jwtToken .getJWTClaimsSet ().getAudience ();
267
- // if there were no expected audiences configured then just
268
- // consider any audience acceptable
269
- if (audiences == null ) {
270
- valid = true ;
271
- } else {
272
- // if any of the configured audiences is found then consider it
273
- // acceptable
274
- for (String aud : tokenAudienceList ) {
275
- if (audiences .contains (aud )) {
276
- if (LOG .isDebugEnabled ()) {
277
- LOG .debug ("JWT token audience has been successfully validated." );
278
- }
279
- valid = true ;
280
- break ;
301
+ JWTClaimsSet claimsSet = getJWTClaimsSet ();
302
+ List <String > tokenAudienceList = claimsSet .getAudience ();
303
+ // if there were no expected audiences configured then just consider any audience acceptable
304
+ if (audiences == null ) {
305
+ valid = true ;
306
+ } else {
307
+ // if any of the configured audiences is found then consider it acceptable
308
+ for (String aud : tokenAudienceList ) {
309
+ if (audiences .contains (aud )) {
310
+ if (LOG .isDebugEnabled ()) {
311
+ LOG .debug ("JWT audience has been successfully validated." );
281
312
}
282
- }
283
- if (!valid ) {
284
- LOG .warn ("JWT audience validation failed." );
313
+ valid = true ;
314
+ break ;
285
315
}
286
316
}
287
- } catch (ParseException pe ) {
288
- LOG .warn ("Unable to parse the JWT token." , pe );
317
+ }
318
+
319
+ if (!valid ) {
320
+ LOG .warn ("JWT audience validation failed." );
289
321
}
290
322
return valid ;
291
323
}
292
324
293
325
/**
294
- * Validate that the expiration time of the JWT token has not been violated. If
295
- * it has then throw an AuthenticationException. Override this method in
326
+ * Validate that the expiration time of the JWT has not been violated. If
327
+ * it has, then throw an AuthenticationException. Override this method in
296
328
* subclasses in order to customize the expiration validation behavior.
297
329
*
298
- * @param jwtToken the token that contains the expiration date to validate
299
330
* @return valid true if the token has not expired; false otherwise
300
331
*/
301
- protected boolean validateExpiration (final SignedJWT jwtToken ) {
332
+ protected boolean validateExpiration () throws ParseException {
302
333
boolean valid = false ;
303
- try {
304
- Date expires = jwtToken .getJWTClaimsSet ().getExpirationTime ();
305
- if (expires == null || new Date ().before (expires )) {
306
- valid = true ;
307
- if (LOG .isDebugEnabled ()) {
308
- LOG .debug ("JWT token expiration date has been successfully validated." );
309
- }
310
- } else {
311
- LOG .warn ("JWT token provided is expired." );
334
+ Date expires = getJWTClaimsSet ().getExpirationTime ();
335
+ if (expires == null || new Date ().before (expires )) {
336
+ valid = true ;
337
+ if (LOG .isDebugEnabled ()) {
338
+ LOG .debug ("JWT expiration date has been successfully validated." );
312
339
}
313
- } catch ( ParseException pe ) {
314
- LOG .warn ("Failed to validate JWT expiry." , pe );
340
+ } else {
341
+ LOG .warn ("JWT provided has expired." );
315
342
}
316
343
317
344
return valid ;
0 commit comments