Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions tornado/httpserver.py
Original file line number Diff line number Diff line change
Expand Up @@ -331,6 +331,9 @@ def __str__(self) -> str:

def _apply_xheaders(self, headers: httputil.HTTPHeaders) -> None:
"""Rewrite the ``remote_ip`` and ``protocol`` fields."""
# Only trust proxies in trusted_downstream
if self.remote_ip not in self.trusted_downstream:
return
# Squid uses X-Forwarded-For, others use X-Real-Ip
ip = headers.get("X-Forwarded-For", self.remote_ip)
# Skip trusted downstream hosts in X-Forwarded-For list
Expand Down
24 changes: 23 additions & 1 deletion tornado/test/httpserver_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -546,7 +546,7 @@ def test_invalid_content_length(self):
yield self.stream.read_until_close()


class XHeaderTest(HandlerBaseTestCase):
class XHeaderDirectProxyNotTrustedTest(HandlerBaseTestCase):
class Handler(RequestHandler):
def get(self):
self.set_header("request-version", self.request.version)
Expand All @@ -560,6 +560,27 @@ def get(self):
def get_httpserver_options(self):
return dict(xheaders=True, trusted_downstream=["5.5.5.5"])

def test_direct_proxy_not_trusted(self):
valid_ipv4_list = {"X-Forwarded-For": "4.4.4.4, 5.5.5.5"}
self.assertEqual(
self.fetch_json("/", headers=valid_ipv4_list)["remote_ip"], "127.0.0.1"
)


class XHeaderTest(HandlerBaseTestCase):
class Handler(RequestHandler):
def get(self):
self.set_header("request-version", self.request.version)
self.write(
dict(
remote_ip=self.request.remote_ip,
remote_protocol=self.request.protocol,
)
)

def get_httpserver_options(self):
return dict(xheaders=True, trusted_downstream=["127.0.0.1", "5.5.5.5"])

def test_ip_headers(self):
self.assertEqual(self.fetch_json("/")["remote_ip"], "127.0.0.1")

Expand Down Expand Up @@ -648,6 +669,7 @@ def get_app(self):
def get_httpserver_options(self):
output = super(SSLXHeaderTest, self).get_httpserver_options()
output["xheaders"] = True
output["trusted_downstream"] = ["127.0.0.1"]
return output

def test_request_without_xprotocol(self):
Expand Down