diff --git a/packages/smithy-http/src/smithy_http/deserializers.py b/packages/smithy-http/src/smithy_http/deserializers.py index 5a12de2b..8a0f3fef 100644 --- a/packages/smithy-http/src/smithy_http/deserializers.py +++ b/packages/smithy-http/src/smithy_http/deserializers.py @@ -84,17 +84,44 @@ def read_struct( ) case Binding.PAYLOAD: assert binding_matcher.payload_member is not None # noqa: S101 - deserializer = self._create_payload_deserializer( - binding_matcher.payload_member - ) - consumer(binding_matcher.payload_member, deserializer) + if self._should_read_payload(binding_matcher.payload_member): + deserializer = self._create_payload_deserializer( + binding_matcher.payload_member + ) + consumer(binding_matcher.payload_member, deserializer) case _: pass - if binding_matcher.has_body: + if binding_matcher.has_body and not self._has_empty_body( + self._response, self._body + ): deserializer = self._create_body_deserializer() deserializer.read_struct(schema, consumer) + def _should_read_payload(self, schema: Schema) -> bool: + if schema.shape_type not in ( + ShapeType.LIST, + ShapeType.MAP, + ShapeType.UNION, + ShapeType.STRUCTURE, + ): + return True + return not self._has_empty_body(self._response, self._body) + + def _has_empty_body( + self, response: HTTPResponse, body: "SyncStreamingBlob | None" + ) -> bool: + if "content-length" in response.fields: + return int(response.fields["content-length"].as_string()) == 0 + if isinstance(body, bytes | bytearray): + return len(body) == 0 + if (seek := getattr(self._body, "seek", None)) is not None: + content_length = seek(0, 2) + if content_length == 0: + return True + seek(0, 0) + return False + def _create_payload_deserializer(self, payload_member: Schema) -> ShapeDeserializer: if payload_member.shape_type in (ShapeType.BLOB, ShapeType.STRING): body = self._body if self._body is not None else self._response.body diff --git a/packages/smithy-http/tests/unit/test_serializers.py b/packages/smithy-http/tests/unit/test_serializers.py index 0dd5eea6..19137c87 100644 --- a/packages/smithy-http/tests/unit/test_serializers.py +++ b/packages/smithy-http/tests/unit/test_serializers.py @@ -605,7 +605,7 @@ def _consumer(schema: Schema, de: ShapeDeserializer) -> None: @dataclass class HTTPStringPayload: - payload: str + payload: str | None = None ID: ClassVar[ShapeID] = ShapeID("com.smithy#HTTPStringPayload") SCHEMA: ClassVar[Schema] = Schema.collection( @@ -620,7 +620,8 @@ def serialize(self, serializer: ShapeSerializer) -> None: self.serialize_members(s) def serialize_members(self, serializer: ShapeSerializer) -> None: - serializer.write_string(self.SCHEMA.members["payload"], self.payload) + if self.payload is not None: + serializer.write_string(self.SCHEMA.members["payload"], self.payload) @classmethod def deserialize(cls, deserializer: ShapeDeserializer) -> Self: @@ -713,7 +714,7 @@ def _consumer(schema: Schema, de: ShapeDeserializer) -> None: @dataclass class HTTPStructuredPayload: - payload: HTTPStringPayload + payload: HTTPStringPayload | None = None ID: ClassVar[ShapeID] = ShapeID("com.smithy#HTTPStructuredPayload") SCHEMA: ClassVar[Schema] = Schema.collection( @@ -732,7 +733,8 @@ def serialize(self, serializer: ShapeSerializer) -> None: self.serialize_members(s) def serialize_members(self, serializer: ShapeSerializer) -> None: - serializer.write_struct(self.SCHEMA.members["payload"], self.payload) + if self.payload is not None: + serializer.write_struct(self.SCHEMA.members["payload"], self.payload) @classmethod def deserialize(cls, deserializer: ShapeDeserializer) -> Self: @@ -1590,6 +1592,53 @@ def payload_cases() -> list[HTTPMessageTestCase]: HTTPStructuredPayload(payload=HTTPStringPayload(payload="foo")), HTTPMessage(body=BytesIO(b'{"payload":"foo"}')), ), + HTTPMessageTestCase( + HTTPStructuredPayload(HTTPStringPayload()), + HTTPMessage(body=BytesIO(b"{}")), + ), + ] + + +class NonSeekableBytesReader: + def __init__(self, data: bytes) -> None: + self._wrapped = BytesIO(data) + + def read(self, size: int = -1, /) -> bytes: + return self._wrapped.read(size) + + +def response_payload_cases() -> list[HTTPMessageTestCase]: + return [ + HTTPMessageTestCase( + HTTPStructuredPayload(), + HTTPMessage(body=b""), + ), + HTTPMessageTestCase( + HTTPStructuredPayload(), + HTTPMessage(body=BytesIO(b"")), + ), + HTTPMessageTestCase( + HTTPStructuredPayload(), + HTTPMessage( + body=NonSeekableBytesReader(b""), + fields=tuples_to_fields([("content-length", "0")]), + ), + ), + HTTPMessageTestCase( + HTTPImplicitPayload(), + HTTPMessage(body=b""), + ), + HTTPMessageTestCase( + HTTPImplicitPayload(), + HTTPMessage(body=BytesIO(b"")), + ), + HTTPMessageTestCase( + HTTPImplicitPayload(), + HTTPMessage( + body=NonSeekableBytesReader(b""), + fields=tuples_to_fields([("content-length", "0")]), + ), + ), ] @@ -1706,7 +1755,10 @@ async def test_serialize_response_omitting_empty_payload() -> None: RESPONSE_DESER_CASES: list[HTTPMessageTestCase] = ( - header_cases() + empty_prefix_header_deser_cases() + payload_cases() + header_cases() + + empty_prefix_header_deser_cases() + + payload_cases() + + response_payload_cases() )