Skip to content

Commit ce83e03

Browse files
Return AccessDenied Error Code when failing to decrypt credentials (#178)
1 parent 1fee342 commit ce83e03

File tree

5 files changed

+89
-10
lines changed

5 files changed

+89
-10
lines changed

src/cloudformation_cli_python_lib/cipher.py

+14-3
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
import uuid
44
from typing import Optional
55

6+
# boto3, botocore, aws_encryption_sdk don't have stub files
67
import boto3 # type: ignore
78

89
import aws_encryption_sdk # type: ignore
@@ -16,6 +17,7 @@
1617
)
1718
from botocore.session import Session, get_session # type: ignore
1819

20+
from .exceptions import _EncryptionError
1921
from .utils import Credentials
2022

2123

@@ -63,7 +65,7 @@ def decrypt_credentials(
6365
try:
6466
credentials_data = json.loads(encrypted_credentials)
6567
return Credentials(**credentials_data)
66-
except json.JSONDecodeError:
68+
except (json.JSONDecodeError, TypeError, ValueError):
6769
return None
6870

6971
try:
@@ -72,10 +74,19 @@ def decrypt_credentials(
7274
key_provider=self._key_provider,
7375
)
7476
credentials_data = json.loads(decrypted_credentials.decode("UTF-8"))
77+
if credentials_data is None:
78+
raise _EncryptionError(
79+
"Failed to decrypt credentials. Decrypted credentials are 'null'."
80+
)
7581

7682
return Credentials(**credentials_data)
77-
except (json.JSONDecodeError, AWSEncryptionSDKClientError) as e:
78-
raise RuntimeError("Failed to decrypt credentials.") from e
83+
except (
84+
AWSEncryptionSDKClientError,
85+
json.JSONDecodeError,
86+
TypeError,
87+
ValueError,
88+
) as e:
89+
raise _EncryptionError("Failed to decrypt credentials.") from e
7990

8091
@staticmethod
8192
def _get_assume_role_session(

src/cloudformation_cli_python_lib/exceptions.py

+4
Original file line numberDiff line numberDiff line change
@@ -98,3 +98,7 @@ def __init__(self, type_name: str, message: str):
9898

9999
class Unknown(_HandlerError):
100100
pass
101+
102+
103+
class _EncryptionError(Exception):
104+
pass

src/cloudformation_cli_python_lib/hook.py

+13-5
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,13 @@
77

88
from .boto3_proxy import SessionProxy, _get_boto_session
99
from .cipher import Cipher, KmsCipher
10-
from .exceptions import InternalFailure, InvalidRequest, _HandlerError
10+
from .exceptions import (
11+
AccessDenied,
12+
InternalFailure,
13+
InvalidRequest,
14+
_EncryptionError,
15+
_HandlerError,
16+
)
1117
from .interface import (
1218
BaseHookHandlerRequest,
1319
HandlerErrorCode,
@@ -180,6 +186,9 @@ def _parse_request(
180186
# credentials are used when rescheduling, so can't zero them out (for now)
181187
invocation_point = HookInvocationPoint[event.actionInvocationPoint]
182188
callback_context = event.requestContext.callbackContext or {}
189+
except _EncryptionError as e:
190+
LOG.exception("Failed to decrypt credentials")
191+
raise AccessDenied(f"{e} ({type(e).__name__})") from e
183192
except Exception as e:
184193
LOG.exception("Invalid request")
185194
raise InvalidRequest(f"{e} ({type(e).__name__})") from e
@@ -228,7 +237,6 @@ def print_or_log(message: str) -> None:
228237
print(message)
229238
traceback.print_exc()
230239

231-
event: Optional[HookInvocationRequest] = None
232240
try:
233241
sessions, invocation_point, callback, event = self._parse_request(
234242
event_data
@@ -276,12 +284,12 @@ def print_or_log(message: str) -> None:
276284
# use the raw event_data as a last-ditch attempt to call back if the
277285
# request is invalid
278286
return self._create_progress_response(
279-
progress, event
287+
progress, event_data
280288
)._serialize() # pylint: disable=protected-access
281289

282290
@staticmethod
283291
def _create_progress_response(
284-
progress_event: ProgressEvent, request: Optional[HookInvocationRequest]
292+
progress_event: ProgressEvent, request: Optional[MutableMapping[str, Any]]
285293
) -> HookProgressEvent:
286294
response = HookProgressEvent(Hook._get_hook_status(progress_event.status))
287295
response.result = progress_event.result
@@ -291,7 +299,7 @@ def _create_progress_response(
291299
response.callbackDelaySeconds = progress_event.callbackDelaySeconds
292300
response.errorCode = progress_event.errorCode
293301
if request:
294-
response.clientRequestToken = request.clientRequestToken
302+
response.clientRequestToken = request.get("clientRequestToken")
295303
return response
296304

297305
@staticmethod

tests/lib/cipher_test.py

+25-1
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33

44
import pytest
55
from cloudformation_cli_python_lib.cipher import KmsCipher
6+
from cloudformation_cli_python_lib.exceptions import _EncryptionError
67
from cloudformation_cli_python_lib.utils import Credentials
78

89
from aws_encryption_sdk.exceptions import AWSEncryptionSDKClientError
@@ -54,7 +55,7 @@ def test_decrypt_credentials_fail():
5455
), patch(
5556
"cloudformation_cli_python_lib.cipher.aws_encryption_sdk.EncryptionSDKClient.decrypt"
5657
) as mock_decrypt, pytest.raises(
57-
RuntimeError
58+
_EncryptionError
5859
) as excinfo:
5960
mock_decrypt.side_effect = AWSEncryptionSDKClientError()
6061
cipher = KmsCipher("encryptionKeyArn", "encryptionKeyRole")
@@ -64,6 +65,29 @@ def test_decrypt_credentials_fail():
6465
assert str(excinfo.value) == "Failed to decrypt credentials."
6566

6667

68+
def test_decrypt_credentials_returns_null_fail():
69+
with patch(
70+
"cloudformation_cli_python_lib.cipher.aws_encryption_sdk.StrictAwsKmsMasterKeyProvider",
71+
autospec=True,
72+
), patch(
73+
"cloudformation_cli_python_lib.cipher.aws_encryption_sdk.EncryptionSDKClient.decrypt"
74+
) as mock_decrypt, pytest.raises(
75+
_EncryptionError
76+
) as excinfo:
77+
mock_decrypt.return_value = (
78+
b"null",
79+
Mock(),
80+
)
81+
cipher = KmsCipher("encryptionKeyArn", "encryptionKeyRole")
82+
cipher.decrypt_credentials(
83+
"ewogICAgICAgICAgICAiYWNjZXNzS2V5SWQiOiAiSUFTQVlLODM1R0FJRkhBSEVJMjMiLAogICAg"
84+
)
85+
assert (
86+
str(excinfo.value)
87+
== "Failed to decrypt credentials. Decrypted credentials are 'null'."
88+
)
89+
90+
6791
@pytest.mark.parametrize(
6892
"encryption_key_arn,encryption_key_role",
6993
[

tests/lib/hook_test.py

+33-1
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,11 @@
66

77
import pytest
88
from cloudformation_cli_python_lib import Hook
9-
from cloudformation_cli_python_lib.exceptions import InternalFailure, InvalidRequest
9+
from cloudformation_cli_python_lib.exceptions import (
10+
InternalFailure,
11+
InvalidRequest,
12+
_EncryptionError,
13+
)
1014
from cloudformation_cli_python_lib.hook import _ensure_serialize
1115
from cloudformation_cli_python_lib.interface import (
1216
BaseModel,
@@ -194,6 +198,34 @@ def test_entrypoint_success_without_caller_provider_creds():
194198
assert event == expected
195199

196200

201+
def test_entrypoint_encryption_error_raises_access_denied():
202+
@dataclass
203+
class TypeConfigurationModel(BaseModel):
204+
a_string: str
205+
206+
@classmethod
207+
def _deserialize(cls, json_data):
208+
return cls("test")
209+
210+
hook = Hook(Mock(), TypeConfigurationModel)
211+
212+
with patch(
213+
"cloudformation_cli_python_lib.hook.HookProviderLogHandler.setup"
214+
), patch("cloudformation_cli_python_lib.hook.MetricsPublisherProxy"), patch(
215+
"cloudformation_cli_python_lib.hook.KmsCipher.decrypt_credentials"
216+
) as mock_cipher:
217+
mock_cipher.side_effect = _EncryptionError("Failed to decrypt credentials.")
218+
event = hook.__call__.__wrapped__( # pylint: disable=no-member
219+
hook, ENTRYPOINT_PAYLOAD, None
220+
)
221+
222+
assert event["errorCode"] == "AccessDenied"
223+
assert event["hookStatus"] == "FAILED"
224+
assert event["callbackDelaySeconds"] == 0
225+
assert event["clientRequestToken"] == "4b90a7e4-b790-456b-a937-0cfdfa211dfe"
226+
assert "Failed to decrypt credentials" in event["message"]
227+
228+
197229
def test_cast_hook_request_invalid_request(hook):
198230
request = HookInvocationRequest.deserialize(ENTRYPOINT_PAYLOAD)
199231
request.requestData = None

0 commit comments

Comments
 (0)