diff --git a/pom.xml b/pom.xml index cc59f7e6740..c58ef79b325 100644 --- a/pom.xml +++ b/pom.xml @@ -110,6 +110,7 @@ 2.2.1 1.1.2 2.4.16 + 10.4 @@ -1433,6 +1434,12 @@ apiguardian-api ${version.lib.apiguardian.api} + + + com.nimbusds + nimbus-jose-jwt + ${version.lib.nimbus.jose.jwt} + diff --git a/security/jwt/pom.xml b/security/jwt/pom.xml index 52f3a6f8676..896457f4b89 100644 --- a/security/jwt/pom.xml +++ b/security/jwt/pom.xml @@ -69,5 +69,11 @@ hamcrest-core test + + + com.nimbusds + nimbus-jose-jwt + test + diff --git a/security/jwt/src/main/java/io/helidon/security/jwt/EncryptedJwt.java b/security/jwt/src/main/java/io/helidon/security/jwt/EncryptedJwt.java index faecd9b632d..dcc9cfa6d5e 100644 --- a/security/jwt/src/main/java/io/helidon/security/jwt/EncryptedJwt.java +++ b/security/jwt/src/main/java/io/helidon/security/jwt/EncryptedJwt.java @@ -1,5 +1,5 @@ /* - * Copyright (c) 2021, 2024 Oracle and/or its affiliates. + * Copyright (c) 2021, 2025 Oracle and/or its affiliates. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -16,6 +16,8 @@ package io.helidon.security.jwt; +import java.nio.ByteBuffer; +import java.nio.ByteOrder; import java.nio.charset.StandardCharsets; import java.security.InvalidKeyException; import java.security.MessageDigest; @@ -46,6 +48,7 @@ import javax.crypto.spec.SecretKeySpec; import io.helidon.common.Errors; +import io.helidon.common.LazyValue; import io.helidon.security.jwt.jwk.Jwk; import io.helidon.security.jwt.jwk.JwkEC; import io.helidon.security.jwt.jwk.JwkKeys; @@ -61,8 +64,11 @@ */ public final class EncryptedJwt { + private static final byte[] EMPTY_BYTES = new byte[0]; private static final Map CONTENT_ENCRYPTION; + private static final LazyValue RANDOM = LazyValue.create(SecureRandom::new); + private static final Pattern JWE_PATTERN = Pattern .compile("(^[\\S]+)\\.([\\S]+)\\.([\\S]+)\\.([\\S]+)\\.([\\S]+$)"); private static final Base64.Decoder URL_DECODER = Base64.getUrlDecoder(); @@ -175,7 +181,7 @@ public static EncryptedJwt parseToken(String token) { * * * @param header parsed JWT header - * @param token String with the token + * @param token String with the token * @return Encrypted jwt parts * @throws RuntimeException in case of invalid content, see {@link Errors.ErrorMessagesException} */ @@ -198,9 +204,9 @@ public static EncryptedJwt parseToken(JwtHeaders header, String token) { /** * Add validator of kek algorithm to the collection of validators. * - * @param validators collection of validators - * @param expectedKekAlg audience key encryption key algorithm - * @param mandatory whether the alg field is mandatory in the token + * @param validators collection of validators + * @param expectedKekAlg audience key encryption key algorithm + * @param mandatory whether the alg field is mandatory in the token */ public static void addKekValidator(Collection> validators, String expectedKekAlg, boolean mandatory) { validators.add((encryptedJwt, collector) -> { @@ -379,13 +385,6 @@ public SignedJwt decrypt(JwkKeys jwkKeys, Jwk defaultJwk) { errors.collect().checkValid(); - byte[] decryptedKey = unwrapRsa(supportedAlgorithm, privateKey, encryptedKey); - //Base64 headers are used as an aad. This aad has to be in US_ASCII encoding. - EncryptionParts encryptionParts = new EncryptionParts(decryptedKey, - iv, - headerBase64.getBytes(StandardCharsets.US_ASCII), - encryptedPayload, - authTag); AesAlgorithm aesAlgorithm; try { SupportedEncryption supportedEncryption = SupportedEncryption.getValue(enc); @@ -393,7 +392,31 @@ public SignedJwt decrypt(JwkKeys jwkKeys, Jwk defaultJwk) { } catch (IllegalArgumentException e) { throw new JwtException("Unsupported content encryption: " + enc); } + + byte[] decryptedKey = unwrapRsa(supportedAlgorithm, privateKey, encryptedKey); + byte[] macKey; + byte[] encKey; + if (aesAlgorithm.hasHmac()) { + int keySizeInBytes = aesAlgorithm.keySize / 8; + macKey = new byte[keySizeInBytes]; + encKey = new byte[keySizeInBytes]; + System.arraycopy(decryptedKey, 0, macKey, 0, keySizeInBytes); + System.arraycopy(decryptedKey, keySizeInBytes, encKey, 0, keySizeInBytes); + } else { + encKey = decryptedKey; + macKey = EMPTY_BYTES; + } + //Base64 headers are used as an aad. This aad has to be in US_ASCII encoding. + EncryptionParts encryptionParts = new EncryptionParts(encKey, + macKey, + iv, + headerBase64.getBytes(StandardCharsets.US_ASCII), + encryptedPayload, authTag); + String decryptedPayload = new String(aesAlgorithm.decrypt(encryptionParts), StandardCharsets.UTF_8); + Arrays.fill(decryptedKey, (byte) 0); //clear the decrypted key from the memory + Arrays.fill(macKey, (byte) 0); + Arrays.fill(encKey, (byte) 0); return SignedJwt.parseToken(decryptedPayload); } @@ -561,9 +584,13 @@ public EncryptedJwt build() { //Base64 headers are used as an aad. This aad has to be in US_ASCII encoding. EncryptionParts encryptionParts = contentEncryption.encrypt(jwt.tokenContent().getBytes(StandardCharsets.UTF_8), headersBase64.getBytes(StandardCharsets.US_ASCII)); + byte[] mac = encryptionParts.mac(); byte[] aesKey = encryptionParts.key(); + byte[] result = new byte[mac.length + aesKey.length]; + System.arraycopy(mac, 0, result, 0, mac.length); + System.arraycopy(aesKey, 0, result, mac.length, aesKey.length); - byte[] encryptedAesKey = wrapRsa(algorithm, publicKey, aesKey); + byte[] encryptedAesKey = wrapRsa(algorithm, publicKey, result); String token = tokenBuilder.append(headersBase64).append(".") .append(encode(encryptedAesKey)).append(".") .append(encode(encryptionParts.iv())).append(".") @@ -696,8 +723,6 @@ static SupportedEncryption getValue(String value) { private static class AesAlgorithm { - private static final SecureRandom RANDOM = new SecureRandom(); - private final String cipher; private final int keySize; private final int ivSize; @@ -711,16 +736,16 @@ private AesAlgorithm(String cipher, int keySize, int ivSize) { EncryptionParts encrypt(byte[] plainContent, byte[] aad) { try { KeyGenerator kgen = KeyGenerator.getInstance("AES"); - kgen.init(keySize, RANDOM); + kgen.init(keySize, RANDOM.get()); SecretKey secretKey = kgen.generateKey(); byte[] iv = new byte[ivSize]; - RANDOM.nextBytes(iv); - EncryptionParts encryptionParts = new EncryptionParts(secretKey.getEncoded(), iv, aad, null, null); + RANDOM.get().nextBytes(iv); + EncryptionParts encryptionParts = new EncryptionParts(secretKey.getEncoded(), EMPTY_BYTES, iv, aad, null, null); Cipher cipher = Cipher.getInstance(this.cipher); cipher.init(Cipher.ENCRYPT_MODE, secretKey, createParameterSpec(encryptionParts)); postCipherConstruct(cipher, encryptionParts); byte[] encryptedContent = cipher.doFinal(plainContent); - return new EncryptionParts(secretKey.getEncoded(), iv, aad, encryptedContent, null); + return new EncryptionParts(secretKey.getEncoded(), EMPTY_BYTES, iv, aad, encryptedContent, null); } catch (Exception e) { throw new JwtException("Exception during content encryption", e); } @@ -747,6 +772,13 @@ protected AlgorithmParameterSpec createParameterSpec(EncryptionParts encryptionP return new IvParameterSpec(encryptionParts.iv()); } + boolean hasHmac() { + return false; + } + + int keySize() { + return keySize; + } } private static class AesAlgorithmWithHmac extends AesAlgorithm { @@ -761,24 +793,32 @@ private AesAlgorithmWithHmac(String cipher, int keySize, int ivSize, String hmac @Override public EncryptionParts encrypt(byte[] plainContent, byte[] aad) { EncryptionParts encryptionParts = super.encrypt(plainContent, aad); + encryptionParts = createMacKey(encryptionParts); byte[] authTag = sign(encryptionParts); return new EncryptionParts(encryptionParts.key(), + encryptionParts.mac(), encryptionParts.iv(), encryptionParts.aad(), encryptionParts.encryptedContent(), authTag); } + private EncryptionParts createMacKey(EncryptionParts encryptionParts) { + byte[] mac = new byte[keySize() / 8]; + RANDOM.get().nextBytes(mac); + return new EncryptionParts(encryptionParts.key(), + mac, + encryptionParts.iv(), + encryptionParts.aad(), + encryptionParts.encryptedContent(), + encryptionParts.authTag()); + } + private byte[] sign(EncryptionParts parts) { - try { - Mac mac = macInstance(); - mac.init(new SecretKeySpec(parts.key(), "AES")); - mac.update(parts.aad()); - mac.update(parts.encryptedContent()); - return mac.doFinal(); - } catch (InvalidKeyException e) { - throw new JwtException("Exception occurred while HMAC signature"); - } + byte[] fullHmacTag = computeMac(parts); + byte[] authKey = new byte[fullHmacTag.length / 2]; + System.arraycopy(fullHmacTag, 0, authKey, 0, authKey.length); + return authKey; } @Override @@ -790,15 +830,44 @@ public byte[] decrypt(EncryptionParts encryptionParts) { } private boolean verifySignature(EncryptionParts encryptionParts) { + byte[] fullHmacTag = computeMac(encryptionParts); + byte[] authKey = new byte[fullHmacTag.length / 2]; + System.arraycopy(fullHmacTag, 0, authKey, 0, authKey.length); + return MessageDigest.isEqual(authKey, encryptionParts.authTag()); + } + + private byte[] computeMac(EncryptionParts encryptionParts) { try { + /* + The octet string AL is equal to the number of bits in the + Additional Authenticated Data A expressed as a 64-bit unsigned + big-endian integer. + */ + byte[] al = ByteBuffer.allocate(8) + .order(ByteOrder.BIG_ENDIAN) + .putLong(encryptionParts.aad().length * 8L) + .array(); + + /* + A message Authentication Tag T is computed by applying HMAC to the following data, in order: + the Additional Authenticated Data A, + the Initialization Vector IV, + the ciphertext E computed in the previous step, and + the octet string AL defined above. + + We denote the output + of the MAC computed in this step as M. The first T_LEN octets of + M are used as T. + */ Mac mac = macInstance(); - mac.init(new SecretKeySpec(encryptionParts.key(), "AES")); + mac.init(new SecretKeySpec(encryptionParts.mac(), "AES")); mac.update(encryptionParts.aad()); + mac.update(encryptionParts.iv()); mac.update(encryptionParts.encryptedContent()); - byte[] authKey = mac.doFinal(); - return MessageDigest.isEqual(authKey, encryptionParts.authTag()); + mac.update(al); + return mac.doFinal(); } catch (InvalidKeyException e) { - throw new JwtException("Exception occurred while HMAC signature."); + throw new JwtException("Exception occurred while calculating HMAC signature."); } } @@ -809,6 +878,11 @@ private Mac macInstance() { throw new JwtException("Could not find MAC instance: " + hmac); } } + + @Override + boolean hasHmac() { + return true; + } } private static class AesGcmAlgorithm extends AesAlgorithm { @@ -827,10 +901,10 @@ public EncryptionParts encrypt(byte[] plainContent, byte[] aad) { System.arraycopy(wholeEncryptedContent, 0, encryptedContent, 0, encryptedContent.length); System.arraycopy(wholeEncryptedContent, length, authTag, 0, authTag.length); return new EncryptionParts(encryptionParts.key(), + encryptionParts.mac(), encryptionParts.iv(), encryptionParts.aad(), - encryptedContent, - authTag); + encryptedContent, authTag); } @Override @@ -843,6 +917,7 @@ byte[] decrypt(EncryptionParts encryptionParts) { System.arraycopy(encryptedPayload, 0, result, 0, epl); System.arraycopy(authTag, 0, result, epl, al); EncryptionParts newEncParts = new EncryptionParts(encryptionParts.key(), + encryptionParts.mac(), encryptionParts.iv(), encryptionParts.aad(), result, @@ -861,41 +936,8 @@ protected void postCipherConstruct(Cipher cipher, EncryptionParts encryptionPart } } - private static final class EncryptionParts { + private record EncryptionParts(byte[] key, byte[] mac, byte[] iv, byte[] aad, byte[] encryptedContent, byte[] authTag) { - private final byte[] key; - private final byte[] iv; - private final byte[] aad; - private final byte[] encryptedContent; - private final byte[] authTag; - - private EncryptionParts(byte[] key, byte[] iv, byte[] aad, byte[] encryptedContent, byte[] authTag) { - this.key = key; - this.iv = iv; - this.aad = aad; - this.encryptedContent = encryptedContent; - this.authTag = authTag; - } - - public byte[] key() { - return key; - } - - public byte[] iv() { - return iv; - } - - public byte[] aad() { - return aad; - } - - public byte[] encryptedContent() { - return encryptedContent; - } - - public byte[] authTag() { - return authTag; - } } } diff --git a/security/jwt/src/test/java/io/helidon/security/jwt/EncryptedJwtTest.java b/security/jwt/src/test/java/io/helidon/security/jwt/EncryptedJwtTest.java index eabf8177b87..bda132d556d 100644 --- a/security/jwt/src/test/java/io/helidon/security/jwt/EncryptedJwtTest.java +++ b/security/jwt/src/test/java/io/helidon/security/jwt/EncryptedJwtTest.java @@ -1,5 +1,5 @@ /* - * Copyright (c) 2021, 2022 Oracle and/or its affiliates. + * Copyright (c) 2021, 2025 Oracle and/or its affiliates. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -16,12 +16,25 @@ package io.helidon.security.jwt; +import java.security.interfaces.RSAPrivateKey; +import java.security.interfaces.RSAPublicKey; +import java.text.ParseException; import java.util.Optional; import io.helidon.common.configurable.Resource; import io.helidon.security.jwt.EncryptedJwt.SupportedAlgorithm; import io.helidon.security.jwt.jwk.JwkKeys; - +import io.helidon.security.jwt.jwk.JwkRSA; + +import com.nimbusds.jose.EncryptionMethod; +import com.nimbusds.jose.JOSEException; +import com.nimbusds.jose.JWEAlgorithm; +import com.nimbusds.jose.JWEHeader; +import com.nimbusds.jose.JWEObject; +import com.nimbusds.jose.Payload; +import com.nimbusds.jose.crypto.RSADecrypter; +import com.nimbusds.jose.crypto.RSAEncrypter; +import com.nimbusds.jwt.SignedJWT; import jakarta.json.JsonObject; import org.junit.jupiter.api.BeforeAll; import org.junit.jupiter.api.Test; @@ -127,4 +140,39 @@ public void testCustomEncryptAndDecrypt() { assertThat(decryptedOne.headerJson(), is(decryptedTwo.headerJson())); } + @Test + void testNimbusToHelidon() throws ParseException, JOSEException { + JwkRSA jwk = (JwkRSA) jwkKeys.forKeyId("RS_512").orElseThrow(); + RSAPublicKey publicKey = (RSAPublicKey) jwk.publicKey(); + + Payload payload = new Payload(SignedJWT.parse(signedJwt.tokenContent())); + JWEHeader header = new JWEHeader(JWEAlgorithm.RSA_OAEP_256, EncryptionMethod.A256CBC_HS512); + JWEObject jweObject = new JWEObject(header, payload); + jweObject.encrypt(new RSAEncrypter(publicKey)); + String serializedJweFromNimbus = jweObject.serialize(); + + EncryptedJwt encryptedJwt = parseToken(serializedJweFromNimbus); + SignedJwt decrypted = encryptedJwt.decrypt(jwk); + + assertThat(decrypted.payloadJson().toString(), is(signedJwt.payloadJson().toString())); + } + + @Test + void testHelidonToNimbus() throws ParseException, JOSEException { + JwkRSA jwk = (JwkRSA) jwkKeys.forKeyId("RS_512").orElseThrow(); + RSAPrivateKey privateKey = (RSAPrivateKey) jwk.privateKey().orElseThrow(); + + EncryptedJwt encryptedJwt = builder(signedJwt) + .jwk(jwk) + .algorithm(SupportedAlgorithm.RSA_OAEP_256) + .encryption(SupportedEncryption.A256CBC_HS512) + .build(); + + JWEObject jweObject = JWEObject.parse(encryptedJwt.token()); + jweObject.decrypt(new RSADecrypter(privateKey)); + SignedJWT signedJWT = SignedJWT.parse(jweObject.getPayload().toString()); + + assertThat(signedJWT.getPayload().toString(), is(signedJwt.payloadJson().toString())); + } + } \ No newline at end of file diff --git a/security/providers/oidc-common/src/main/java/io/helidon/security/providers/oidc/common/BaseBuilder.java b/security/providers/oidc-common/src/main/java/io/helidon/security/providers/oidc/common/BaseBuilder.java index bc652b4c08f..43735cc7231 100644 --- a/security/providers/oidc-common/src/main/java/io/helidon/security/providers/oidc/common/BaseBuilder.java +++ b/security/providers/oidc-common/src/main/java/io/helidon/security/providers/oidc/common/BaseBuilder.java @@ -62,6 +62,7 @@ abstract class BaseBuilder, T> implements Builder contentKeyDecryptionKeys() { + return Optional.empty(); + } + /** * A fluent API {@link io.helidon.common.Builder} to build instances of {@link TenantConfig}. */ diff --git a/security/providers/oidc-common/src/main/java/io/helidon/security/providers/oidc/common/TenantConfigImpl.java b/security/providers/oidc-common/src/main/java/io/helidon/security/providers/oidc/common/TenantConfigImpl.java index 3a637712071..343948f4cc7 100644 --- a/security/providers/oidc-common/src/main/java/io/helidon/security/providers/oidc/common/TenantConfigImpl.java +++ b/security/providers/oidc-common/src/main/java/io/helidon/security/providers/oidc/common/TenantConfigImpl.java @@ -1,5 +1,5 @@ /* - * Copyright (c) 2023 Oracle and/or its affiliates. + * Copyright (c) 2023, 2025 Oracle and/or its affiliates. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -45,6 +45,7 @@ class TenantConfigImpl implements TenantConfig { private final OidcConfig.ClientAuthentication tokenEndpointAuthentication; private final Duration clientTimeout; private final JwkKeys signJwk; + private final JwkKeys contentKeyDecryptionKeys; private final String clientSecret; private final URI introspectUri; private final URI logoutEndpointUri; @@ -73,6 +74,7 @@ class TenantConfigImpl implements TenantConfig { this.clientSecret = builder.clientSecret(); this.signJwk = builder.signJwk(); + this.contentKeyDecryptionKeys = builder.contentKeyDecryptionKeys(); this.oidcMetadata = builder.oidcMetadata(); this.useWellKnown = builder.useWellKnown(); @@ -201,4 +203,8 @@ public boolean useWellKnown() { return useWellKnown; } + @Override + public Optional contentKeyDecryptionKeys() { + return Optional.ofNullable(contentKeyDecryptionKeys); + } } diff --git a/security/providers/oidc-common/src/main/java/module-info.java b/security/providers/oidc-common/src/main/java/module-info.java index 294aa672f6d..d4685398a25 100644 --- a/security/providers/oidc-common/src/main/java/module-info.java +++ b/security/providers/oidc-common/src/main/java/module-info.java @@ -1,5 +1,5 @@ /* - * Copyright (c) 2018, 2023 Oracle and/or its affiliates. + * Copyright (c) 2018, 2025 Oracle and/or its affiliates. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -22,6 +22,7 @@ requires io.helidon.common.context; requires io.helidon.common.crypto; requires io.helidon.common.parameters; + requires io.helidon.common.pki; requires io.helidon.cors; requires io.helidon.http.media.jsonp; requires io.helidon.security.providers.common; diff --git a/security/providers/oidc/src/main/java/io/helidon/security/providers/oidc/OidcFeature.java b/security/providers/oidc/src/main/java/io/helidon/security/providers/oidc/OidcFeature.java index 9e5d593a272..e1acaf306ab 100644 --- a/security/providers/oidc/src/main/java/io/helidon/security/providers/oidc/OidcFeature.java +++ b/security/providers/oidc/src/main/java/io/helidon/security/providers/oidc/OidcFeature.java @@ -48,8 +48,11 @@ import io.helidon.http.Status; import io.helidon.security.Security; import io.helidon.security.SecurityException; +import io.helidon.security.jwt.EncryptedJwt; import io.helidon.security.jwt.Jwt; +import io.helidon.security.jwt.JwtHeaders; import io.helidon.security.jwt.SignedJwt; +import io.helidon.security.jwt.jwk.JwkKeys; import io.helidon.security.providers.oidc.common.OidcConfig; import io.helidon.security.providers.oidc.common.OidcCookieHandler; import io.helidon.security.providers.oidc.common.Tenant; @@ -438,7 +441,7 @@ private void processCodeWithTenant(String code, ServerRequest req, ServerRespons if (response.status().family() == Status.Family.SUCCESSFUL) { try { JsonObject jsonObject = response.as(JsonObject.class); - processJsonResponse(req, res, jsonObject, tenantName, stateCookie); + processJsonResponse(req, res, jsonObject, tenantName, stateCookie, tenant); } catch (Exception e) { processError(res, e, "Failed to read JSON from response"); } @@ -498,16 +501,26 @@ private String processJsonResponse(ServerRequest req, ServerResponse res, JsonObject json, String tenantName, - JsonObject stateCookie) { + JsonObject stateCookie, + Tenant tenant) { String accessToken = json.getString("access_token"); String idToken = json.getString("id_token", null); String refreshToken = json.getString("refresh_token", null); + JwtHeaders jwtHeaders = JwtHeaders.parseToken(idToken); + TenantConfig tenantConfig = tenant.tenantConfig(); - Jwt idTokenJwt = SignedJwt.parseToken(idToken).getJwt(); + Jwt idTokenJwt; + if (jwtHeaders.encryption().isPresent() && tenantConfig.contentKeyDecryptionKeys().isPresent()) { + EncryptedJwt encryptedJwt = EncryptedJwt.parseToken(jwtHeaders, idToken); + JwkKeys jwkKeys = tenantConfig.contentKeyDecryptionKeys().get(); + idTokenJwt = encryptedJwt.decrypt(jwkKeys, jwkKeys.keys().getFirst()).getJwt(); + } else { + idTokenJwt = SignedJwt.parseToken(idToken).getJwt(); + } String nonceOriginal = stateCookie.getString("nonce"); - String nonceAccess = idTokenJwt.nonce() + String nonceIdToken = idTokenJwt.nonce() .orElseThrow(() -> new IllegalStateException("Nonce is required to be present in the id token")); - if (!nonceAccess.equals(nonceOriginal)) { + if (!nonceIdToken.equals(nonceOriginal)) { throw new IllegalStateException("Original nonce and the one obtained from id token does not match"); } diff --git a/security/providers/oidc/src/main/java/io/helidon/security/providers/oidc/TenantAuthenticationHandler.java b/security/providers/oidc/src/main/java/io/helidon/security/providers/oidc/TenantAuthenticationHandler.java index ce0f90e4820..1b4f8b208d1 100644 --- a/security/providers/oidc/src/main/java/io/helidon/security/providers/oidc/TenantAuthenticationHandler.java +++ b/security/providers/oidc/src/main/java/io/helidon/security/providers/oidc/TenantAuthenticationHandler.java @@ -57,8 +57,10 @@ import io.helidon.security.SecurityResponse; import io.helidon.security.Subject; import io.helidon.security.abac.scope.ScopeValidator; +import io.helidon.security.jwt.EncryptedJwt; import io.helidon.security.jwt.Jwt; import io.helidon.security.jwt.JwtException; +import io.helidon.security.jwt.JwtHeaders; import io.helidon.security.jwt.JwtUtil; import io.helidon.security.jwt.JwtValidator; import io.helidon.security.jwt.SignedJwt; @@ -518,9 +520,24 @@ private String encode(String state) { } private AuthenticationResponse validateIdToken(String tenantId, ProviderRequest providerRequest, String idToken) { + JwtHeaders jwtHeaders = JwtHeaders.parseToken(idToken); SignedJwt signedJwt; try { - signedJwt = SignedJwt.parseToken(idToken); + if (jwtHeaders.encryption().isPresent()) { + if (tenantConfig.contentKeyDecryptionKeys().isEmpty() + || tenantConfig.contentKeyDecryptionKeys().get().keys().isEmpty()) { + if (LOGGER.isLoggable(System.Logger.Level.DEBUG)) { + LOGGER.log(System.Logger.Level.DEBUG, "There are no content key decryption keys configured " + + "for the tenant: " + tenantId); + } + return AuthenticationResponse.failed("Failed to decrypt ID Token"); + } + EncryptedJwt encryptedJwt = EncryptedJwt.parseToken(jwtHeaders, idToken); + JwkKeys jwkKeys = tenantConfig.contentKeyDecryptionKeys().get(); + signedJwt = encryptedJwt.decrypt(jwkKeys, jwkKeys.keys().getFirst()); + } else { + signedJwt = SignedJwt.parseToken(idToken); + } } catch (Exception e) { //invalid token if (LOGGER.isLoggable(System.Logger.Level.DEBUG)) { @@ -576,9 +593,24 @@ private AuthenticationResponse validateAccessToken(String tenantId, ProviderRequest providerRequest, String token, Jwt idToken) { + JwtHeaders jwtHeaders = JwtHeaders.parseToken(token); SignedJwt signedJwt; try { - signedJwt = SignedJwt.parseToken(token); + if (jwtHeaders.encryption().isPresent()) { + if (tenantConfig.contentKeyDecryptionKeys().isEmpty() + || tenantConfig.contentKeyDecryptionKeys().get().keys().isEmpty()) { + if (LOGGER.isLoggable(System.Logger.Level.DEBUG)) { + LOGGER.log(System.Logger.Level.DEBUG, "There are no content key decryption keys configured " + + "for the tenant: " + tenantId); + } + return AuthenticationResponse.failed("Failed to decrypt Access Token"); + } + EncryptedJwt encryptedJwt = EncryptedJwt.parseToken(jwtHeaders, token); + JwkKeys jwkKeys = tenantConfig.contentKeyDecryptionKeys().get(); + signedJwt = encryptedJwt.decrypt(jwkKeys, jwkKeys.keys().getFirst()); + } else { + signedJwt = SignedJwt.parseToken(token); + } } catch (Exception e) { //invalid token if (LOGGER.isLoggable(System.Logger.Level.DEBUG)) {