diff --git a/README.md b/README.md index d63a761..ddd6008 100644 --- a/README.md +++ b/README.md @@ -116,7 +116,7 @@ decrypted_response_payload = decrypt_payload(body, config) ##### Configuring the JWE Encryption -`jwe_encryption` needs a config dictionary to instruct how to decrypt/decrypt the payloads. Example: +`jwe_encryption` needs a config dictionary to instruct how to encrypt/decrypt the payloads. Example: ```json { @@ -135,6 +135,15 @@ decrypted_response_payload = decrypt_payload(body, config) "decryptionKey": "./path/to/your/private.key", } ``` +You can also pass in a PKCS12 file with the password to decrypt it: +```json +{ + // .... rest of the config + + "decryptionKey": "./path/to/your/certStore.p12", // or "keyStore": "./path/to/your/certStore.p12", + "decryptionKeyPassword": "the-password", // or "keyStorePassword": "the-password", +} +``` The above can be either stored to a file or passed to 'JweEncryptionConfig' as dictionary: ```python @@ -274,6 +283,15 @@ decrypted_response_payload = decrypt_payload(body, config) "oaepPaddingDigestAlgorithm": "SHA256" } ``` +You can also pass in a PKCS12 file with the password to decrypt it: +```json +{ + // .... rest of the config + + "decryptionKey": "./path/to/your/certStore.p12", // or "keyStore": "./path/to/your/certStore.p12", + "decryptionKeyPassword": "the-password", // or "keyStorePassword": "the-password", +} +``` The above can be either stored to a file or passed to 'FieldLevelEncryptionConfig' as dictionary: ```python diff --git a/client_encryption/encryption_utils.py b/client_encryption/encryption_utils.py index ce01b1a..adc7277 100644 --- a/client_encryption/encryption_utils.py +++ b/client_encryption/encryption_utils.py @@ -45,6 +45,33 @@ def write_encryption_certificate(certificate_path, certificate, cert_type): with open(certificate_path, "wb") as f: f.write(certificate.public_bytes(cert_type)) +def __get_config_value(config, key_aliases): + """ + Helper for getting config values which have aliases + Args: + config: The config dictionary + key_aliases: List of key aliases + """ + + for key in key_aliases: + if key in config: + return config[key] + return None + +def load_decryption_key_from_config(config): + """ + Helper for reading decryption key from the config + Args: + config: The config dictionary + """ + + key_file_path = __get_config_value(config, ["decryptionKey", "keyStore"]) + if not key_file_path: + return None + + password = __get_config_value(config, ["decryptionKeyPassword", "keyStorePassword"]) + + return load_decryption_key(key_file_path, password) def load_decryption_key(key_file_path, decryption_key_password=None): """Load a RSA decryption key.""" diff --git a/client_encryption/field_level_encryption_config.py b/client_encryption/field_level_encryption_config.py index 707d61c..f5f9f62 100644 --- a/client_encryption/field_level_encryption_config.py +++ b/client_encryption/field_level_encryption_config.py @@ -3,7 +3,7 @@ from cryptography.hazmat.primitives.serialization import PublicFormat, Encoding from client_encryption import encoding_utils -from client_encryption.encryption_utils import load_encryption_certificate, load_decryption_key, validate_hash_algorithm +from client_encryption.encryption_utils import load_encryption_certificate, load_decryption_key_from_config, validate_hash_algorithm class FieldLevelEncryptionConfig(object): @@ -42,11 +42,7 @@ def __init__(self, conf): self._encryption_certificate_fingerprint = None self._encryption_certificate_type = None - if "decryptionKey" in json_config: - decryption_key_password = json_config.get("decryptionKeyPassword", None) - self._decryption_key = load_decryption_key(json_config["decryptionKey"], decryption_key_password) - else: - self._decryption_key = None + self._decryption_key = load_decryption_key_from_config(json_config) self._oaep_padding_digest_algorithm = validate_hash_algorithm(json_config["oaepPaddingDigestAlgorithm"]) diff --git a/client_encryption/jwe_encryption_config.py b/client_encryption/jwe_encryption_config.py index c14a46d..8b65424 100644 --- a/client_encryption/jwe_encryption_config.py +++ b/client_encryption/jwe_encryption_config.py @@ -3,7 +3,7 @@ from cryptography.hazmat.primitives.serialization import PublicFormat, Encoding from client_encryption.encoding_utils import ClientEncoding -from client_encryption.encryption_utils import load_encryption_certificate, load_decryption_key +from client_encryption.encryption_utils import load_encryption_certificate, load_decryption_key_from_config class JweEncryptionConfig(object): @@ -37,11 +37,7 @@ def __init__(self, conf): self._encryption_key_fingerprint = None self._encryption_certificate_type = None - if "decryptionKey" in json_config: - decryption_key_password = json_config.get("decryptionKeyPassword", None) - self._decryption_key = load_decryption_key(json_config["decryptionKey"], decryption_key_password) - else: - self._decryption_key = None + self._decryption_key = load_decryption_key_from_config(json_config) self._encrypted_value_field_name = json_config["encryptedValueFieldName"] diff --git a/tests/test_encryption_utils.py b/tests/test_encryption_utils.py index aa02b79..023041b 100644 --- a/tests/test_encryption_utils.py +++ b/tests/test_encryption_utils.py @@ -107,6 +107,81 @@ def test_load_decryption_key_file_does_not_exist(self): self.assertRaises(PrivateKeyError, to_test.load_decryption_key, key_path) + def test_load_decryption_key_from_config_with_password_aliases(self): + test_cases = [ + # p12 files + ("decryptionKey", "decryptionKeyPassword", "keys/test_key.p12", self._pkcs12), + ("decryptionKey", "keyStorePassword", "keys/test_key.p12", self._pkcs12), + ("keyStore", "decryptionKeyPassword", "keys/test_key.p12", self._pkcs12), + ("keyStore", "keyStorePassword", "keys/test_key.p12", self._pkcs12), + + # der files + ("keyStore", None, "keys/test_key_pkcs8-2048.der", self._pkcs8_2048), + ("decryptionKey", None, "keys/test_key_pkcs8-2048.der", self._pkcs8_2048), + + # pem files + ("keyStore", None, "keys/test_key_pkcs8-2048.pem", self._pkcs8_2048), + ("decryptionKey", None, "keys/test_key_pkcs8-2048.pem", self._pkcs8_2048), + ] + + for key_field, password_field, file, expected_key in test_cases: + with self.subTest(key=key_field, password_field=password_field, file=file, expected_key=expected_key): + config = { + key_field: resource_path(file), + } + + if password_field is not None: + config[password_field] = "Password1" + + key = to_test.load_decryption_key_from_config(config) + + self.assertIsNotNone(key) + self.assertIsInstance(key, RSA.RsaKey, "Must be RSA key") + self.assertEqual(expected_key, self.__strip_key(key), "Decryption key does not match") + + def test_load_decryption_key_from_config_invalid_file_or_password(self): + test_cases = [ + # valid p12 files with wrong password + ("decryptionKey", "decryptionKeyPassword", "wrong-password", "keys/invalid.p12"), + ("decryptionKey", "keyStorePassword", "wrong-password","keys/invalid.p12"), + ("keyStore", "decryptionKeyPassword", "wrong-password","keys/invalid.p12"), + ("keyStore", "keyStorePassword", "wrong-password", "keys/invalid.p12"), + + # invalid p12 files + ("decryptionKey", "decryptionKeyPassword", "Password1", "keys/invalid.p12"), + ("decryptionKey", "keyStorePassword", "Password1","keys/invalid.p12"), + ("keyStore", "decryptionKeyPassword", "Password1","keys/invalid.p12"), + ("keyStore", "keyStorePassword", "Password1", "keys/invalid.p12"), + + # invalid der files + ("keyStore", None, None, "keys/invalid-2048.der"), + ("decryptionKey", None, None, "keys/invalid-2048.der"), + + # invalid pem files + ("keyStore", None, None, "keys/invalid-2048.pem"), + ("decryptionKey", None, None, "keys/invalid-2048.pem"), + ] + + for key_field, password_field, password, file in test_cases: + with self.subTest(key=key_field, password_field=password_field, password=password, file=file): + config = { + key_field: resource_path(file), + } + + if password is not None: + config[password_field] = password + + self.assertRaises(PrivateKeyError, to_test.load_decryption_key, config) + + def test_load_decryption_key_from_config_no_key_field(self): + """Test load_decryption_key_from_config returns None when no key field present""" + configs = [{}, {"not-decryptionKey": "value"}] + for config in configs: + with self.subTest(config): + key = to_test.load_decryption_key_from_config(config) + self.assertIsNone(key) + + def test_load_hash_algorithm(self): hash_algo = to_test.load_hash_algorithm("SHA224") diff --git a/tests/test_field_level_encryption_config.py b/tests/test_field_level_encryption_config.py index 1eac567..13a96ef 100644 --- a/tests/test_field_level_encryption_config.py +++ b/tests/test_field_level_encryption_config.py @@ -1,4 +1,5 @@ import unittest +from unittest.mock import patch, MagicMock from tests import resource_path, get_mastercard_config_for_test import json import client_encryption.field_level_encryption_config as to_test @@ -136,6 +137,33 @@ def test_load_config_decryption_key_file_not_found(self): self.assertRaises(PrivateKeyError, to_test.FieldLevelEncryptionConfig, wrong_json) + @patch('client_encryption.field_level_encryption_config.load_decryption_key_from_config') + def test_load_config_uses_load_decryption_key_from_config(self, mock_load_key): + """Test that FieldLevelEncryptionConfig uses load_decryption_key_from_config""" + mock_keys = [MagicMock(), None] + + for mock_key in mock_keys: + with self.subTest(mock_key=mock_key): + mock_load_key.return_value = mock_key + + json_conf = json.loads(self._test_config_file) + conf = to_test.FieldLevelEncryptionConfig(json_conf) + + mock_load_key.assert_called_with(json_conf) + self.assertEqual(conf.decryption_key, mock_key) + + @patch('client_encryption.field_level_encryption_config.load_decryption_key_from_config') + def test_load_config_propagates_key_loading_exceptions(self, mock_load_key): + """Test that FieldLevelEncryptionConfig propagates exceptions from load_decryption_key_from_config""" + mock_load_key.side_effect = PrivateKeyError("some error") + + json_conf = json.loads(self._test_config_file) + + with self.assertRaises(PrivateKeyError): + to_test.FieldLevelEncryptionConfig(json_conf) + + mock_load_key.assert_called_once_with(json_conf) + def test_load_config_missing_oaep_padding_algorithm(self): wrong_json = json.loads(self._test_config_file) del wrong_json["oaepPaddingDigestAlgorithm"] diff --git a/tests/test_jwe_encryption_config.py b/tests/test_jwe_encryption_config.py index fbe305d..b82bd8a 100644 --- a/tests/test_jwe_encryption_config.py +++ b/tests/test_jwe_encryption_config.py @@ -1,5 +1,6 @@ import json import unittest +from unittest.mock import patch, MagicMock from Crypto.PublicKey import RSA @@ -106,6 +107,33 @@ def test_load_config_decryption_key_file_not_found(self): self.assertRaises(PrivateKeyError, to_test.JweEncryptionConfig, wrong_json) + @patch('client_encryption.jwe_encryption_config.load_decryption_key_from_config') + def test_load_config_uses_load_decryption_key_from_config(self, mock_load_key): + """Test that JweEncryptionConfig uses load_decryption_key_from_config""" + mock_keys = [MagicMock(), None] + + for mock_key in mock_keys: + with self.subTest(mock_key=mock_key): + mock_load_key.return_value = mock_key + + json_conf = json.loads(self._test_config_file) + conf = to_test.JweEncryptionConfig(json_conf) + + mock_load_key.assert_called_with(json_conf) + self.assertEqual(conf.decryption_key, mock_key) + + @patch('client_encryption.jwe_encryption_config.load_decryption_key_from_config') + def test_load_config_propagates_key_loading_exceptions(self, mock_load_key): + """Test that JweEncryptionConfig propagates exceptions from load_decryption_key_from_config""" + mock_load_key.side_effect = PrivateKeyError("some error") + + json_conf = json.loads(self._test_config_file) + + with self.assertRaises(PrivateKeyError): + to_test.JweEncryptionConfig(json_conf) + + mock_load_key.assert_called_once_with(json_conf) + def __check_configuration(self, conf, encoding=ClientEncoding.BASE64, oaep_algo="SHA256"): self.assertIsNotNone(conf.paths["$"], "No resource to encrypt/decrypt fields of is set") resource = conf.paths["$"]