Skip to content

Commit ff2e2d2

Browse files
Create new URI when replacing
1 parent 46975ec commit ff2e2d2

File tree

2 files changed

+159
-11
lines changed

2 files changed

+159
-11
lines changed

packages/smithy-http/src/smithy_http/aio/protocols.py

Lines changed: 22 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
from io import BytesIO
44
from typing import Any
55

6+
from smithy_core import URI as _URI
67
from smithy_core.aio.interfaces import ClientProtocol
78
from smithy_core.codecs import Codec
89
from smithy_core.deserializers import DeserializeableShape
@@ -36,17 +37,27 @@ def set_service_endpoint(
3637
endpoint: Endpoint,
3738
) -> HTTPRequest:
3839
uri = endpoint.uri
39-
uri_builder = request.destination
40-
41-
if uri.scheme:
42-
uri_builder.scheme = uri.scheme
43-
if uri.host:
44-
uri_builder.host = uri.host
45-
if uri.port and uri.port > -1:
46-
uri_builder.port = uri.port
47-
if uri.path:
48-
uri_builder.path = os.path.join(uri.path, uri_builder.path or "")
49-
# TODO: merge headers from the endpoint properties bag
40+
previous = request.destination
41+
42+
path = previous.path or uri.path
43+
if uri.path is not None and previous.path is not None:
44+
path = os.path.join(uri.path, previous.path.lstrip("/"))
45+
46+
query = previous.query or uri.query
47+
if uri.query is not None and previous.query is not None:
48+
query = f"{uri.query}&{previous.query}"
49+
50+
request.destination = _URI(
51+
scheme=uri.scheme,
52+
username=uri.username or previous.username,
53+
password=uri.password or previous.password,
54+
host=uri.host,
55+
port=uri.port or previous.port,
56+
path=path,
57+
query=query,
58+
fragment=uri.fragment or previous.fragment,
59+
)
60+
5061
return request
5162

5263

Lines changed: 137 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,137 @@
1+
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
2+
# SPDX-License-Identifier: Apache-2.0
3+
4+
from typing import Any
5+
6+
import pytest
7+
from smithy_core import URI
8+
from smithy_core.documents import TypeRegistry
9+
from smithy_core.endpoints import Endpoint
10+
from smithy_core.interfaces import TypedProperties
11+
from smithy_core.interfaces import URI as URIInterface
12+
from smithy_core.schemas import APIOperation
13+
from smithy_core.shapes import ShapeID
14+
from smithy_http import Fields
15+
from smithy_http.aio import HTTPRequest
16+
from smithy_http.aio.interfaces import HTTPRequest as HTTPRequestInterface
17+
from smithy_http.aio.interfaces import HTTPResponse as HTTPResponseInterface
18+
from smithy_http.aio.protocols import HttpClientProtocol
19+
20+
21+
class TestProtocol(HttpClientProtocol):
22+
_id = ShapeID("ns.foo#bar")
23+
24+
@property
25+
def id(self) -> ShapeID:
26+
return self._id
27+
28+
def serialize_request(
29+
self,
30+
*,
31+
operation: APIOperation[Any, Any],
32+
input: Any,
33+
endpoint: URIInterface,
34+
context: TypedProperties,
35+
) -> HTTPRequestInterface:
36+
raise Exception("This is only for tests.")
37+
38+
def deserialize_response(
39+
self,
40+
*,
41+
operation: APIOperation[Any, Any],
42+
request: HTTPRequestInterface,
43+
response: HTTPResponseInterface,
44+
error_registry: TypeRegistry,
45+
context: TypedProperties,
46+
) -> Any:
47+
raise Exception("This is only for tests.")
48+
49+
50+
@pytest.mark.parametrize(
51+
"request_uri,endpoint_uri,expected",
52+
[
53+
(
54+
URI(host="com.example", path="/foo"),
55+
URI(host="com.example", path="/bar"),
56+
URI(host="com.example", path="/bar/foo"),
57+
),
58+
(
59+
URI(host="com.example"),
60+
URI(host="com.example", path="/bar"),
61+
URI(host="com.example", path="/bar"),
62+
),
63+
(
64+
URI(host="com.example", path="/foo"),
65+
URI(host="com.example"),
66+
URI(host="com.example", path="/foo"),
67+
),
68+
(
69+
URI(host="com.example", scheme="http"),
70+
URI(host="com.example", scheme="https"),
71+
URI(host="com.example", scheme="https"),
72+
),
73+
(
74+
URI(host="com.example", username="name", password="password"),
75+
URI(host="com.example", username="othername", password="otherpassword"),
76+
URI(host="com.example", username="othername", password="otherpassword"),
77+
),
78+
(
79+
URI(host="com.example", username="name", password="password"),
80+
URI(host="com.example"),
81+
URI(host="com.example", username="name", password="password"),
82+
),
83+
(
84+
URI(host="com.example", port=8080),
85+
URI(host="com.example", port=8000),
86+
URI(host="com.example", port=8000),
87+
),
88+
(
89+
URI(host="com.example", port=8080),
90+
URI(host="com.example"),
91+
URI(host="com.example", port=8080),
92+
),
93+
(
94+
URI(host="com.example", query="foo=bar"),
95+
URI(host="com.example"),
96+
URI(host="com.example", query="foo=bar"),
97+
),
98+
(
99+
URI(host="com.example"),
100+
URI(host="com.example", query="spam"),
101+
URI(host="com.example", query="spam"),
102+
),
103+
(
104+
URI(host="com.example", query="foo=bar"),
105+
URI(host="com.example", query="spam"),
106+
URI(host="com.example", query="spam&foo=bar"),
107+
),
108+
(
109+
URI(host="com.example", fragment="header"),
110+
URI(host="com.example", fragment="footer"),
111+
URI(host="com.example", fragment="footer"),
112+
),
113+
(
114+
URI(host="com.example"),
115+
URI(host="com.example", fragment="footer"),
116+
URI(host="com.example", fragment="footer"),
117+
),
118+
(
119+
URI(host="com.example", fragment="header"),
120+
URI(host="com.example"),
121+
URI(host="com.example", fragment="header"),
122+
),
123+
],
124+
)
125+
def test_http_protocol_joins_uris(
126+
request_uri: URI, endpoint_uri: URI, expected: URI
127+
) -> None:
128+
protocol = TestProtocol()
129+
request = HTTPRequest(
130+
destination=request_uri,
131+
method="GET",
132+
fields=Fields(),
133+
)
134+
endpoint = Endpoint(uri=endpoint_uri)
135+
updated_request = protocol.set_service_endpoint(request=request, endpoint=endpoint)
136+
actual = updated_request.destination
137+
assert actual == expected

0 commit comments

Comments
 (0)