Skip to content

WIP: feat/support decryption key alias #50

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 6 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 19 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -116,7 +116,7 @@ decrypted_response_payload = decrypt_payload(body, config)

##### Configuring the JWE Encryption <a name="configuring-the-jwe-encryption"></a>

`jwe_encryption` needs a config dictionary to instruct how to decrypt/decrypt the payloads. Example:
`jwe_encryption` needs a config dictionary to instruct how to encrypt/decrypt the payloads. Example:

```json
{
Expand All @@ -135,6 +135,15 @@ decrypted_response_payload = decrypt_payload(body, config)
"decryptionKey": "./path/to/your/private.key",
}
```
You can also pass in a PKCS12 file with the password to decrypt it:
```json
{
// .... rest of the config

"decryptionKey": "./path/to/your/certStore.p12", // or "keyStore": "./path/to/your/certStore.p12",
"decryptionKeyPassword": "the-password", // or "keyStorePassword": "the-password",
}
```

The above can be either stored to a file or passed to 'JweEncryptionConfig' as dictionary:
```python
Expand Down Expand Up @@ -274,6 +283,15 @@ decrypted_response_payload = decrypt_payload(body, config)
"oaepPaddingDigestAlgorithm": "SHA256"
}
```
You can also pass in a PKCS12 file with the password to decrypt it:
```json
{
// .... rest of the config

"decryptionKey": "./path/to/your/certStore.p12", // or "keyStore": "./path/to/your/certStore.p12",
"decryptionKeyPassword": "the-password", // or "keyStorePassword": "the-password",
}
```

The above can be either stored to a file or passed to 'FieldLevelEncryptionConfig' as dictionary:
```python
Expand Down
27 changes: 27 additions & 0 deletions client_encryption/encryption_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,33 @@ def write_encryption_certificate(certificate_path, certificate, cert_type):
with open(certificate_path, "wb") as f:
f.write(certificate.public_bytes(cert_type))

def __get_config_value(config, key_aliases):
"""
Helper for getting config values which have aliases
Args:
config: The config dictionary
key_aliases: List of key aliases
"""

for key in key_aliases:
if key in config:
return config[key]
return None

def load_decryption_key_from_config(config):
"""
Helper for reading decryption key from the config
Args:
config: The config dictionary
"""

key_file_path = __get_config_value(config, ["decryptionKey", "keyStore"])
if not key_file_path:
return None

password = __get_config_value(config, ["decryptionKeyPassword", "keyStorePassword"])

return load_decryption_key(key_file_path, password)

def load_decryption_key(key_file_path, decryption_key_password=None):
"""Load a RSA decryption key."""
Expand Down
8 changes: 2 additions & 6 deletions client_encryption/field_level_encryption_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from cryptography.hazmat.primitives.serialization import PublicFormat, Encoding

from client_encryption import encoding_utils
from client_encryption.encryption_utils import load_encryption_certificate, load_decryption_key, validate_hash_algorithm
from client_encryption.encryption_utils import load_encryption_certificate, load_decryption_key_from_config, validate_hash_algorithm


class FieldLevelEncryptionConfig(object):
Expand Down Expand Up @@ -42,11 +42,7 @@ def __init__(self, conf):
self._encryption_certificate_fingerprint = None
self._encryption_certificate_type = None

if "decryptionKey" in json_config:
decryption_key_password = json_config.get("decryptionKeyPassword", None)
self._decryption_key = load_decryption_key(json_config["decryptionKey"], decryption_key_password)
else:
self._decryption_key = None
self._decryption_key = load_decryption_key_from_config(json_config)

self._oaep_padding_digest_algorithm = validate_hash_algorithm(json_config["oaepPaddingDigestAlgorithm"])

Expand Down
8 changes: 2 additions & 6 deletions client_encryption/jwe_encryption_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from cryptography.hazmat.primitives.serialization import PublicFormat, Encoding

from client_encryption.encoding_utils import ClientEncoding
from client_encryption.encryption_utils import load_encryption_certificate, load_decryption_key
from client_encryption.encryption_utils import load_encryption_certificate, load_decryption_key_from_config


class JweEncryptionConfig(object):
Expand Down Expand Up @@ -37,11 +37,7 @@ def __init__(self, conf):
self._encryption_key_fingerprint = None
self._encryption_certificate_type = None

if "decryptionKey" in json_config:
decryption_key_password = json_config.get("decryptionKeyPassword", None)
self._decryption_key = load_decryption_key(json_config["decryptionKey"], decryption_key_password)
else:
self._decryption_key = None
self._decryption_key = load_decryption_key_from_config(json_config)

self._encrypted_value_field_name = json_config["encryptedValueFieldName"]

Expand Down
75 changes: 75 additions & 0 deletions tests/test_encryption_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,6 +107,81 @@ def test_load_decryption_key_file_does_not_exist(self):

self.assertRaises(PrivateKeyError, to_test.load_decryption_key, key_path)

def test_load_decryption_key_from_config_with_password_aliases(self):
test_cases = [
# p12 files
("decryptionKey", "decryptionKeyPassword", "keys/test_key.p12", self._pkcs12),
("decryptionKey", "keyStorePassword", "keys/test_key.p12", self._pkcs12),
("keyStore", "decryptionKeyPassword", "keys/test_key.p12", self._pkcs12),
("keyStore", "keyStorePassword", "keys/test_key.p12", self._pkcs12),

# der files
("keyStore", None, "keys/test_key_pkcs8-2048.der", self._pkcs8_2048),
("decryptionKey", None, "keys/test_key_pkcs8-2048.der", self._pkcs8_2048),

# pem files
("keyStore", None, "keys/test_key_pkcs8-2048.pem", self._pkcs8_2048),
("decryptionKey", None, "keys/test_key_pkcs8-2048.pem", self._pkcs8_2048),
]

for key_field, password_field, file, expected_key in test_cases:
with self.subTest(key=key_field, password_field=password_field, file=file, expected_key=expected_key):
config = {
key_field: resource_path(file),
}

if password_field is not None:
config[password_field] = "Password1"

key = to_test.load_decryption_key_from_config(config)

self.assertIsNotNone(key)
self.assertIsInstance(key, RSA.RsaKey, "Must be RSA key")
self.assertEqual(expected_key, self.__strip_key(key), "Decryption key does not match")

def test_load_decryption_key_from_config_invalid_file_or_password(self):
test_cases = [
# valid p12 files with wrong password
("decryptionKey", "decryptionKeyPassword", "wrong-password", "keys/invalid.p12"),
("decryptionKey", "keyStorePassword", "wrong-password","keys/invalid.p12"),
("keyStore", "decryptionKeyPassword", "wrong-password","keys/invalid.p12"),
("keyStore", "keyStorePassword", "wrong-password", "keys/invalid.p12"),

# invalid p12 files
("decryptionKey", "decryptionKeyPassword", "Password1", "keys/invalid.p12"),
("decryptionKey", "keyStorePassword", "Password1","keys/invalid.p12"),
("keyStore", "decryptionKeyPassword", "Password1","keys/invalid.p12"),
("keyStore", "keyStorePassword", "Password1", "keys/invalid.p12"),

# invalid der files
("keyStore", None, None, "keys/invalid-2048.der"),
("decryptionKey", None, None, "keys/invalid-2048.der"),

# invalid pem files
("keyStore", None, None, "keys/invalid-2048.pem"),
("decryptionKey", None, None, "keys/invalid-2048.pem"),
]

for key_field, password_field, password, file in test_cases:
with self.subTest(key=key_field, password_field=password_field, password=password, file=file):
config = {
key_field: resource_path(file),
}

if password is not None:
config[password_field] = password

self.assertRaises(PrivateKeyError, to_test.load_decryption_key, config)

def test_load_decryption_key_from_config_no_key_field(self):
"""Test load_decryption_key_from_config returns None when no key field present"""
configs = [{}, {"not-decryptionKey": "value"}]
for config in configs:
with self.subTest(config):
key = to_test.load_decryption_key_from_config(config)
self.assertIsNone(key)


def test_load_hash_algorithm(self):
hash_algo = to_test.load_hash_algorithm("SHA224")

Expand Down
28 changes: 28 additions & 0 deletions tests/test_field_level_encryption_config.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import unittest
from unittest.mock import patch, MagicMock
from tests import resource_path, get_mastercard_config_for_test
import json
import client_encryption.field_level_encryption_config as to_test
Expand Down Expand Up @@ -136,6 +137,33 @@ def test_load_config_decryption_key_file_not_found(self):

self.assertRaises(PrivateKeyError, to_test.FieldLevelEncryptionConfig, wrong_json)

@patch('client_encryption.field_level_encryption_config.load_decryption_key_from_config')
def test_load_config_uses_load_decryption_key_from_config(self, mock_load_key):
"""Test that FieldLevelEncryptionConfig uses load_decryption_key_from_config"""
mock_keys = [MagicMock(), None]

for mock_key in mock_keys:
with self.subTest(mock_key=mock_key):
mock_load_key.return_value = mock_key

json_conf = json.loads(self._test_config_file)
conf = to_test.FieldLevelEncryptionConfig(json_conf)

mock_load_key.assert_called_with(json_conf)
self.assertEqual(conf.decryption_key, mock_key)

@patch('client_encryption.field_level_encryption_config.load_decryption_key_from_config')
def test_load_config_propagates_key_loading_exceptions(self, mock_load_key):
"""Test that FieldLevelEncryptionConfig propagates exceptions from load_decryption_key_from_config"""
mock_load_key.side_effect = PrivateKeyError("some error")

json_conf = json.loads(self._test_config_file)

with self.assertRaises(PrivateKeyError):
to_test.FieldLevelEncryptionConfig(json_conf)

mock_load_key.assert_called_once_with(json_conf)

def test_load_config_missing_oaep_padding_algorithm(self):
wrong_json = json.loads(self._test_config_file)
del wrong_json["oaepPaddingDigestAlgorithm"]
Expand Down
28 changes: 28 additions & 0 deletions tests/test_jwe_encryption_config.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import json
import unittest
from unittest.mock import patch, MagicMock

from Crypto.PublicKey import RSA

Expand Down Expand Up @@ -106,6 +107,33 @@ def test_load_config_decryption_key_file_not_found(self):

self.assertRaises(PrivateKeyError, to_test.JweEncryptionConfig, wrong_json)

@patch('client_encryption.jwe_encryption_config.load_decryption_key_from_config')
def test_load_config_uses_load_decryption_key_from_config(self, mock_load_key):
"""Test that JweEncryptionConfig uses load_decryption_key_from_config"""
mock_keys = [MagicMock(), None]

for mock_key in mock_keys:
with self.subTest(mock_key=mock_key):
mock_load_key.return_value = mock_key

json_conf = json.loads(self._test_config_file)
conf = to_test.JweEncryptionConfig(json_conf)

mock_load_key.assert_called_with(json_conf)
self.assertEqual(conf.decryption_key, mock_key)

@patch('client_encryption.jwe_encryption_config.load_decryption_key_from_config')
def test_load_config_propagates_key_loading_exceptions(self, mock_load_key):
"""Test that JweEncryptionConfig propagates exceptions from load_decryption_key_from_config"""
mock_load_key.side_effect = PrivateKeyError("some error")

json_conf = json.loads(self._test_config_file)

with self.assertRaises(PrivateKeyError):
to_test.JweEncryptionConfig(json_conf)

mock_load_key.assert_called_once_with(json_conf)

def __check_configuration(self, conf, encoding=ClientEncoding.BASE64, oaep_algo="SHA256"):
self.assertIsNotNone(conf.paths["$"], "No resource to encrypt/decrypt fields of is set")
resource = conf.paths["$"]
Expand Down