Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -110,6 +110,7 @@
<version.lib.maven.plugin.project>2.2.1</version.lib.maven.plugin.project>
<version.lib.apiguardian.api>1.1.2</version.lib.apiguardian.api>
<version.lib.groovy>2.4.16</version.lib.groovy>
<version.lib.nimbus.jose.jwt>10.4</version.lib.nimbus.jose.jwt>
<!--
!Version statement! - end
-->
Expand Down Expand Up @@ -1433,6 +1434,12 @@
<artifactId>apiguardian-api</artifactId>
<version>${version.lib.apiguardian.api}</version>
</dependency>
<!-- For JWE testing -->
<dependency>
<groupId>com.nimbusds</groupId>
<artifactId>nimbus-jose-jwt</artifactId>
<version>${version.lib.nimbus.jose.jwt}</version>
</dependency>
</dependencies>
</dependencyManagement>

Expand Down
6 changes: 6 additions & 0 deletions security/jwt/pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -69,5 +69,11 @@
<artifactId>hamcrest-core</artifactId>
<scope>test</scope>
</dependency>
<!-- For JWE testing -->
<dependency>
<groupId>com.nimbusds</groupId>
<artifactId>nimbus-jose-jwt</artifactId>
<scope>test</scope>
</dependency>
</dependencies>
</project>
178 changes: 110 additions & 68 deletions security/jwt/src/main/java/io/helidon/security/jwt/EncryptedJwt.java
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright (c) 2021, 2024 Oracle and/or its affiliates.
* Copyright (c) 2021, 2025 Oracle and/or its affiliates.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand All @@ -16,6 +16,8 @@

package io.helidon.security.jwt;

import java.nio.ByteBuffer;
import java.nio.ByteOrder;
import java.nio.charset.StandardCharsets;
import java.security.InvalidKeyException;
import java.security.MessageDigest;
Expand Down Expand Up @@ -46,6 +48,7 @@
import javax.crypto.spec.SecretKeySpec;

import io.helidon.common.Errors;
import io.helidon.common.LazyValue;
import io.helidon.security.jwt.jwk.Jwk;
import io.helidon.security.jwt.jwk.JwkEC;
import io.helidon.security.jwt.jwk.JwkKeys;
Expand All @@ -61,8 +64,11 @@
*/
public final class EncryptedJwt {

private static final byte[] EMPTY_BYTES = new byte[0];
private static final Map<SupportedEncryption, AesAlgorithm> CONTENT_ENCRYPTION;

private static final LazyValue<SecureRandom> RANDOM = LazyValue.create(SecureRandom::new);

private static final Pattern JWE_PATTERN = Pattern
.compile("(^[\\S]+)\\.([\\S]+)\\.([\\S]+)\\.([\\S]+)\\.([\\S]+$)");
private static final Base64.Decoder URL_DECODER = Base64.getUrlDecoder();
Expand Down Expand Up @@ -175,7 +181,7 @@ public static EncryptedJwt parseToken(String token) {
* </ul>
*
* @param header parsed JWT header
* @param token String with the token
* @param token String with the token
* @return Encrypted jwt parts
* @throws RuntimeException in case of invalid content, see {@link Errors.ErrorMessagesException}
*/
Expand All @@ -198,9 +204,9 @@ public static EncryptedJwt parseToken(JwtHeaders header, String token) {
/**
* Add validator of kek algorithm to the collection of validators.
*
* @param validators collection of validators
* @param expectedKekAlg audience key encryption key algorithm
* @param mandatory whether the alg field is mandatory in the token
* @param validators collection of validators
* @param expectedKekAlg audience key encryption key algorithm
* @param mandatory whether the alg field is mandatory in the token
*/
public static void addKekValidator(Collection<Validator<EncryptedJwt>> validators, String expectedKekAlg, boolean mandatory) {
validators.add((encryptedJwt, collector) -> {
Expand Down Expand Up @@ -379,21 +385,38 @@ public SignedJwt decrypt(JwkKeys jwkKeys, Jwk defaultJwk) {

errors.collect().checkValid();

byte[] decryptedKey = unwrapRsa(supportedAlgorithm, privateKey, encryptedKey);
//Base64 headers are used as an aad. This aad has to be in US_ASCII encoding.
EncryptionParts encryptionParts = new EncryptionParts(decryptedKey,
iv,
headerBase64.getBytes(StandardCharsets.US_ASCII),
encryptedPayload,
authTag);
AesAlgorithm aesAlgorithm;
try {
SupportedEncryption supportedEncryption = SupportedEncryption.getValue(enc);
aesAlgorithm = CONTENT_ENCRYPTION.get(supportedEncryption);
} catch (IllegalArgumentException e) {
throw new JwtException("Unsupported content encryption: " + enc);
}

byte[] decryptedKey = unwrapRsa(supportedAlgorithm, privateKey, encryptedKey);
byte[] macKey;
byte[] encKey;
if (aesAlgorithm.hasHmac()) {
int keySizeInBytes = aesAlgorithm.keySize / 8;
macKey = new byte[keySizeInBytes];
encKey = new byte[keySizeInBytes];
System.arraycopy(decryptedKey, 0, macKey, 0, keySizeInBytes);
System.arraycopy(decryptedKey, keySizeInBytes, encKey, 0, keySizeInBytes);
} else {
encKey = decryptedKey;
macKey = EMPTY_BYTES;
}
//Base64 headers are used as an aad. This aad has to be in US_ASCII encoding.
EncryptionParts encryptionParts = new EncryptionParts(encKey,
macKey,
iv,
headerBase64.getBytes(StandardCharsets.US_ASCII),
encryptedPayload, authTag);

String decryptedPayload = new String(aesAlgorithm.decrypt(encryptionParts), StandardCharsets.UTF_8);
Arrays.fill(decryptedKey, (byte) 0); //clear the decrypted key from the memory
Arrays.fill(macKey, (byte) 0);
Arrays.fill(encKey, (byte) 0);
return SignedJwt.parseToken(decryptedPayload);
}

Expand Down Expand Up @@ -561,9 +584,13 @@ public EncryptedJwt build() {
//Base64 headers are used as an aad. This aad has to be in US_ASCII encoding.
EncryptionParts encryptionParts = contentEncryption.encrypt(jwt.tokenContent().getBytes(StandardCharsets.UTF_8),
headersBase64.getBytes(StandardCharsets.US_ASCII));
byte[] mac = encryptionParts.mac();
byte[] aesKey = encryptionParts.key();
byte[] result = new byte[mac.length + aesKey.length];
System.arraycopy(mac, 0, result, 0, mac.length);
System.arraycopy(aesKey, 0, result, mac.length, aesKey.length);

byte[] encryptedAesKey = wrapRsa(algorithm, publicKey, aesKey);
byte[] encryptedAesKey = wrapRsa(algorithm, publicKey, result);
String token = tokenBuilder.append(headersBase64).append(".")
.append(encode(encryptedAesKey)).append(".")
.append(encode(encryptionParts.iv())).append(".")
Expand Down Expand Up @@ -696,8 +723,6 @@ static SupportedEncryption getValue(String value) {

private static class AesAlgorithm {

private static final SecureRandom RANDOM = new SecureRandom();

private final String cipher;
private final int keySize;
private final int ivSize;
Expand All @@ -711,16 +736,16 @@ private AesAlgorithm(String cipher, int keySize, int ivSize) {
EncryptionParts encrypt(byte[] plainContent, byte[] aad) {
try {
KeyGenerator kgen = KeyGenerator.getInstance("AES");
kgen.init(keySize, RANDOM);
kgen.init(keySize, RANDOM.get());
SecretKey secretKey = kgen.generateKey();
byte[] iv = new byte[ivSize];
RANDOM.nextBytes(iv);
EncryptionParts encryptionParts = new EncryptionParts(secretKey.getEncoded(), iv, aad, null, null);
RANDOM.get().nextBytes(iv);
EncryptionParts encryptionParts = new EncryptionParts(secretKey.getEncoded(), EMPTY_BYTES, iv, aad, null, null);
Cipher cipher = Cipher.getInstance(this.cipher);
cipher.init(Cipher.ENCRYPT_MODE, secretKey, createParameterSpec(encryptionParts));
postCipherConstruct(cipher, encryptionParts);
byte[] encryptedContent = cipher.doFinal(plainContent);
return new EncryptionParts(secretKey.getEncoded(), iv, aad, encryptedContent, null);
return new EncryptionParts(secretKey.getEncoded(), EMPTY_BYTES, iv, aad, encryptedContent, null);
} catch (Exception e) {
throw new JwtException("Exception during content encryption", e);
}
Expand All @@ -747,6 +772,13 @@ protected AlgorithmParameterSpec createParameterSpec(EncryptionParts encryptionP
return new IvParameterSpec(encryptionParts.iv());
}

boolean hasHmac() {
return false;
}

int keySize() {
return keySize;
}
}

private static class AesAlgorithmWithHmac extends AesAlgorithm {
Expand All @@ -761,24 +793,32 @@ private AesAlgorithmWithHmac(String cipher, int keySize, int ivSize, String hmac
@Override
public EncryptionParts encrypt(byte[] plainContent, byte[] aad) {
EncryptionParts encryptionParts = super.encrypt(plainContent, aad);
encryptionParts = createMacKey(encryptionParts);
byte[] authTag = sign(encryptionParts);
return new EncryptionParts(encryptionParts.key(),
encryptionParts.mac(),
encryptionParts.iv(),
encryptionParts.aad(),
encryptionParts.encryptedContent(),
authTag);
}

private EncryptionParts createMacKey(EncryptionParts encryptionParts) {
byte[] mac = new byte[keySize() / 8];
RANDOM.get().nextBytes(mac);
return new EncryptionParts(encryptionParts.key(),
mac,
encryptionParts.iv(),
encryptionParts.aad(),
encryptionParts.encryptedContent(),
encryptionParts.authTag());
}

private byte[] sign(EncryptionParts parts) {
try {
Mac mac = macInstance();
mac.init(new SecretKeySpec(parts.key(), "AES"));
mac.update(parts.aad());
mac.update(parts.encryptedContent());
return mac.doFinal();
} catch (InvalidKeyException e) {
throw new JwtException("Exception occurred while HMAC signature");
}
byte[] fullHmacTag = computeMac(parts);
byte[] authKey = new byte[fullHmacTag.length / 2];
System.arraycopy(fullHmacTag, 0, authKey, 0, authKey.length);
return authKey;
}

@Override
Expand All @@ -790,15 +830,44 @@ public byte[] decrypt(EncryptionParts encryptionParts) {
}

private boolean verifySignature(EncryptionParts encryptionParts) {
byte[] fullHmacTag = computeMac(encryptionParts);
byte[] authKey = new byte[fullHmacTag.length / 2];
System.arraycopy(fullHmacTag, 0, authKey, 0, authKey.length);
return MessageDigest.isEqual(authKey, encryptionParts.authTag());
}

private byte[] computeMac(EncryptionParts encryptionParts) {
try {
/*
The octet string AL is equal to the number of bits in the
Additional Authenticated Data A expressed as a 64-bit unsigned
big-endian integer.
*/
byte[] al = ByteBuffer.allocate(8)
.order(ByteOrder.BIG_ENDIAN)
.putLong(encryptionParts.aad().length * 8L)
.array();

/*
A message Authentication Tag T is computed by applying HMAC to the following data, in order:
the Additional Authenticated Data A,
the Initialization Vector IV,
the ciphertext E computed in the previous step, and
the octet string AL defined above.

We denote the output
of the MAC computed in this step as M. The first T_LEN octets of
M are used as T.
*/
Mac mac = macInstance();
mac.init(new SecretKeySpec(encryptionParts.key(), "AES"));
mac.init(new SecretKeySpec(encryptionParts.mac(), "AES"));
mac.update(encryptionParts.aad());
mac.update(encryptionParts.iv());
mac.update(encryptionParts.encryptedContent());
byte[] authKey = mac.doFinal();
return MessageDigest.isEqual(authKey, encryptionParts.authTag());
mac.update(al);
return mac.doFinal();
} catch (InvalidKeyException e) {
throw new JwtException("Exception occurred while HMAC signature.");
throw new JwtException("Exception occurred while calculating HMAC signature.");
}
}

Expand All @@ -809,6 +878,11 @@ private Mac macInstance() {
throw new JwtException("Could not find MAC instance: " + hmac);
}
}

@Override
boolean hasHmac() {
return true;
}
}

private static class AesGcmAlgorithm extends AesAlgorithm {
Expand All @@ -827,10 +901,10 @@ public EncryptionParts encrypt(byte[] plainContent, byte[] aad) {
System.arraycopy(wholeEncryptedContent, 0, encryptedContent, 0, encryptedContent.length);
System.arraycopy(wholeEncryptedContent, length, authTag, 0, authTag.length);
return new EncryptionParts(encryptionParts.key(),
encryptionParts.mac(),
encryptionParts.iv(),
encryptionParts.aad(),
encryptedContent,
authTag);
encryptedContent, authTag);
}

@Override
Expand All @@ -843,6 +917,7 @@ byte[] decrypt(EncryptionParts encryptionParts) {
System.arraycopy(encryptedPayload, 0, result, 0, epl);
System.arraycopy(authTag, 0, result, epl, al);
EncryptionParts newEncParts = new EncryptionParts(encryptionParts.key(),
encryptionParts.mac(),
encryptionParts.iv(),
encryptionParts.aad(),
result,
Expand All @@ -861,41 +936,8 @@ protected void postCipherConstruct(Cipher cipher, EncryptionParts encryptionPart
}
}

private static final class EncryptionParts {
private record EncryptionParts(byte[] key, byte[] mac, byte[] iv, byte[] aad, byte[] encryptedContent, byte[] authTag) {

private final byte[] key;
private final byte[] iv;
private final byte[] aad;
private final byte[] encryptedContent;
private final byte[] authTag;

private EncryptionParts(byte[] key, byte[] iv, byte[] aad, byte[] encryptedContent, byte[] authTag) {
this.key = key;
this.iv = iv;
this.aad = aad;
this.encryptedContent = encryptedContent;
this.authTag = authTag;
}

public byte[] key() {
return key;
}

public byte[] iv() {
return iv;
}

public byte[] aad() {
return aad;
}

public byte[] encryptedContent() {
return encryptedContent;
}

public byte[] authTag() {
return authTag;
}
}

}
Loading
Loading