Skip to content

Commit 2a72400

Browse files
Support multiple proxy servers in Forwarded header parsing (#782)
* Support multiple proxy servers in Forwarded header parsing * Update CHANGELOG * use regex and simplify * use integer regex * Use compiled regexes --------- Co-authored-by: vincentsarago <[email protected]>
1 parent 48c4dea commit 2a72400

File tree

3 files changed

+53
-22
lines changed

3 files changed

+53
-22
lines changed

CHANGES.md

+5-1
Original file line numberDiff line numberDiff line change
@@ -2,11 +2,15 @@
22

33
## [Unreleased]
44

5-
## Changed
5+
### Changed
66

77
* use `string` type instead of python `datetime.datetime` for datetime parameter in `BaseSearchGetRequest`, `ItemCollectionUri` and `BaseCollectionSearchGetRequest` GET models
88
* rename `filter` to `filter_expr` for `FilterExtensionGetRequest` and `FilterExtensionPostRequest` attributes to avoid conflict with python filter method
99

10+
### Fixed
11+
12+
* Support multiple proxy servers in the `forwarded` header in `ProxyHeaderMiddleware` ([#782](https://github.com/stac-utils/stac-fastapi/pull/782))
13+
1014
## [3.0.5] - 2025-01-10
1115

1216
### Removed

stac_fastapi/api/stac_fastapi/api/middleware.py

+20-21
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
"""Api middleware."""
22

3+
import contextlib
34
import re
45
import typing
56
from http.client import HTTP_PORT, HTTPS_PORT
@@ -44,6 +45,10 @@ def __init__(
4445
)
4546

4647

48+
_PROTO_HEADER_REGEX = re.compile(r"proto=(?P<proto>http(s)?)")
49+
_HOST_HEADER_REGEX = re.compile(r"host=(?P<host>[\w.-]+)(:(?P<port>\d{1,5}))?")
50+
51+
4752
class ProxyHeaderMiddleware:
4853
"""Account for forwarding headers when deriving base URL.
4954
@@ -68,11 +73,13 @@ async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
6873
proto == "https" and port != HTTPS_PORT
6974
):
7075
port_suffix = f":{port}"
76+
7177
scope["headers"] = self._replace_header_value_by_name(
7278
scope,
7379
"host",
7480
f"{domain}{port_suffix}",
7581
)
82+
7683
await self.app(scope, receive, send)
7784

7885
def _get_forwarded_url_parts(self, scope: Scope) -> Tuple[str]:
@@ -87,31 +94,23 @@ def _get_forwarded_url_parts(self, scope: Scope) -> Tuple[str]:
8794
else:
8895
domain = header_host_parts[0]
8996
port = None
90-
forwarded = self._get_header_value_by_name(scope, "forwarded")
91-
if forwarded is not None:
92-
parts = forwarded.split(";")
93-
for part in parts:
94-
if len(part) > 0 and re.search("=", part):
95-
key, value = part.split("=")
96-
if key == "proto":
97-
proto = value
98-
elif key == "host":
99-
host_parts = value.split(":")
100-
domain = host_parts[0]
101-
try:
102-
port = int(host_parts[1]) if len(host_parts) == 2 else None
103-
except ValueError:
104-
# ignore ports that are not valid integers
105-
pass
97+
98+
if forwarded := self._get_header_value_by_name(scope, "forwarded"):
99+
for proxy in forwarded.split(","):
100+
if (proto_expr := _PROTO_HEADER_REGEX.search(proxy)) and (
101+
host_expr := _HOST_HEADER_REGEX.search(proxy)
102+
):
103+
proto = proto_expr.group("proto")
104+
domain = host_expr.group("host")
105+
port_str = host_expr.group("port") # None if not present in the match
106+
106107
else:
107108
domain = self._get_header_value_by_name(scope, "x-forwarded-host", domain)
108109
proto = self._get_header_value_by_name(scope, "x-forwarded-proto", proto)
109110
port_str = self._get_header_value_by_name(scope, "x-forwarded-port", port)
110-
try:
111-
port = int(port_str) if port_str is not None else None
112-
except ValueError:
113-
# ignore ports that are not valid integers
114-
pass
111+
112+
with contextlib.suppress(ValueError): # ignore ports that are not valid integers
113+
port = int(port_str) if port_str is not None else port
115114

116115
return (proto, domain, port)
117116

stac_fastapi/api/tests/test_middleware.py

+28
Original file line numberDiff line numberDiff line change
@@ -155,6 +155,34 @@ def test_replace_header_value_by_name(
155155
},
156156
("https", "test", 1234),
157157
),
158+
(
159+
{
160+
"scheme": "http",
161+
"server": ["testserver", 80],
162+
"headers": [
163+
(
164+
b"forwarded",
165+
# two proxy servers added an entry, we want to use the last one
166+
b"proto=https;host=test:1234,proto=https;host=second-server:1111",
167+
)
168+
],
169+
},
170+
("https", "second-server", 1111),
171+
),
172+
(
173+
{
174+
"scheme": "http",
175+
"server": ["testserver", 80],
176+
"headers": [
177+
(
178+
b"forwarded",
179+
# check when host and port are inverted
180+
b"host=test:1234;proto=https",
181+
)
182+
],
183+
},
184+
("https", "test", 1234),
185+
),
158186
],
159187
)
160188
def test_get_forwarded_url_parts(

0 commit comments

Comments
 (0)