Skip to content

Tolerate empty structured payloads #507

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Jun 8, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
37 changes: 32 additions & 5 deletions packages/smithy-http/src/smithy_http/deserializers.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,17 +84,44 @@ def read_struct(
)
case Binding.PAYLOAD:
assert binding_matcher.payload_member is not None # noqa: S101
deserializer = self._create_payload_deserializer(
binding_matcher.payload_member
)
consumer(binding_matcher.payload_member, deserializer)
if self._should_read_payload(binding_matcher.payload_member):
deserializer = self._create_payload_deserializer(
binding_matcher.payload_member
)
consumer(binding_matcher.payload_member, deserializer)
case _:
pass

if binding_matcher.has_body:
if binding_matcher.has_body and not self._has_empty_body(
self._response, self._body
):
deserializer = self._create_body_deserializer()
deserializer.read_struct(schema, consumer)

def _should_read_payload(self, schema: Schema) -> bool:
if schema.shape_type not in (
ShapeType.LIST,
ShapeType.MAP,
ShapeType.UNION,
ShapeType.STRUCTURE,
):
return True
return not self._has_empty_body(self._response, self._body)

def _has_empty_body(
self, response: HTTPResponse, body: "SyncStreamingBlob | None"
) -> bool:
if "content-length" in response.fields:
return int(response.fields["content-length"].as_string()) == 0
if isinstance(body, bytes | bytearray):
return len(body) == 0
if (seek := getattr(self._body, "seek", None)) is not None:
content_length = seek(0, 2)
if content_length == 0:
return True
seek(0, 0)
return False

def _create_payload_deserializer(self, payload_member: Schema) -> ShapeDeserializer:
if payload_member.shape_type in (ShapeType.BLOB, ShapeType.STRING):
body = self._body if self._body is not None else self._response.body
Expand Down
62 changes: 57 additions & 5 deletions packages/smithy-http/tests/unit/test_serializers.py
Original file line number Diff line number Diff line change
Expand Up @@ -605,7 +605,7 @@ def _consumer(schema: Schema, de: ShapeDeserializer) -> None:

@dataclass
class HTTPStringPayload:
payload: str
payload: str | None = None

ID: ClassVar[ShapeID] = ShapeID("com.smithy#HTTPStringPayload")
SCHEMA: ClassVar[Schema] = Schema.collection(
Expand All @@ -620,7 +620,8 @@ def serialize(self, serializer: ShapeSerializer) -> None:
self.serialize_members(s)

def serialize_members(self, serializer: ShapeSerializer) -> None:
serializer.write_string(self.SCHEMA.members["payload"], self.payload)
if self.payload is not None:
serializer.write_string(self.SCHEMA.members["payload"], self.payload)

@classmethod
def deserialize(cls, deserializer: ShapeDeserializer) -> Self:
Expand Down Expand Up @@ -713,7 +714,7 @@ def _consumer(schema: Schema, de: ShapeDeserializer) -> None:

@dataclass
class HTTPStructuredPayload:
payload: HTTPStringPayload
payload: HTTPStringPayload | None = None

ID: ClassVar[ShapeID] = ShapeID("com.smithy#HTTPStructuredPayload")
SCHEMA: ClassVar[Schema] = Schema.collection(
Expand All @@ -732,7 +733,8 @@ def serialize(self, serializer: ShapeSerializer) -> None:
self.serialize_members(s)

def serialize_members(self, serializer: ShapeSerializer) -> None:
serializer.write_struct(self.SCHEMA.members["payload"], self.payload)
if self.payload is not None:
serializer.write_struct(self.SCHEMA.members["payload"], self.payload)

@classmethod
def deserialize(cls, deserializer: ShapeDeserializer) -> Self:
Expand Down Expand Up @@ -1590,6 +1592,53 @@ def payload_cases() -> list[HTTPMessageTestCase]:
HTTPStructuredPayload(payload=HTTPStringPayload(payload="foo")),
HTTPMessage(body=BytesIO(b'{"payload":"foo"}')),
),
HTTPMessageTestCase(
HTTPStructuredPayload(HTTPStringPayload()),
HTTPMessage(body=BytesIO(b"{}")),
),
]


class NonSeekableBytesReader:
def __init__(self, data: bytes) -> None:
self._wrapped = BytesIO(data)

def read(self, size: int = -1, /) -> bytes:
return self._wrapped.read(size)


def response_payload_cases() -> list[HTTPMessageTestCase]:
return [
HTTPMessageTestCase(
HTTPStructuredPayload(),
HTTPMessage(body=b""),
),
HTTPMessageTestCase(
HTTPStructuredPayload(),
HTTPMessage(body=BytesIO(b"")),
),
HTTPMessageTestCase(
HTTPStructuredPayload(),
HTTPMessage(
body=NonSeekableBytesReader(b""),
fields=tuples_to_fields([("content-length", "0")]),
),
),
HTTPMessageTestCase(
HTTPImplicitPayload(),
HTTPMessage(body=b""),
),
HTTPMessageTestCase(
HTTPImplicitPayload(),
HTTPMessage(body=BytesIO(b"")),
),
HTTPMessageTestCase(
HTTPImplicitPayload(),
HTTPMessage(
body=NonSeekableBytesReader(b""),
fields=tuples_to_fields([("content-length", "0")]),
),
),
]


Expand Down Expand Up @@ -1706,7 +1755,10 @@ async def test_serialize_response_omitting_empty_payload() -> None:


RESPONSE_DESER_CASES: list[HTTPMessageTestCase] = (
header_cases() + empty_prefix_header_deser_cases() + payload_cases()
header_cases()
+ empty_prefix_header_deser_cases()
+ payload_cases()
+ response_payload_cases()
)


Expand Down