Skip to content

Commit 6835e7b

Browse files
Wrap the Rust HTTP client with make_deferred_yieldable (#18903)
Wrap the Rust HTTP client with `make_deferred_yieldable` so downstream usage doesn't need to use `PreserveLoggingContext()` or `make_deferred_yieldable`. > it seems like we should have some wrapper around it that uses [`make_deferred_yieldable(...)`](https://github.com/element-hq/synapse/blob/40edb10a98ae24c637b7a9cf6a3003bf6fa48b5f/docs/log_contexts.md#where-you-create-a-new-awaitable-make-it-follow-the-rules) to make things right so we don't have to do this in the downstream code. > > *-- @MadLittleMods, #18357 (comment) Spawning from wanting to [remove `PreserveLoggingContext()` from the codebase](#18870) and thinking that we [shouldn't have to pollute all downstream usage with `PreserveLoggingContext()` or `make_deferred_yieldable`](#18357 (comment)) Part of #18905 (Remove `sentinel` logcontext where we log in Synapse)
1 parent d27ff16 commit 6835e7b

File tree

7 files changed

+272
-18
lines changed

7 files changed

+272
-18
lines changed

changelog.d/18903.misc

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
Wrap the Rust HTTP client with `make_deferred_yieldable` so it follows Synapse logcontext rules.

rust/src/http_client.rs

Lines changed: 19 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
* <https://www.gnu.org/licenses/agpl-3.0.html>.
1313
*/
1414

15-
use std::{collections::HashMap, future::Future};
15+
use std::{collections::HashMap, future::Future, sync::OnceLock};
1616

1717
use anyhow::Context;
1818
use futures::TryStreamExt;
@@ -299,5 +299,22 @@ where
299299
});
300300
});
301301

302-
Ok(deferred)
302+
// Make the deferred follow the Synapse logcontext rules
303+
make_deferred_yieldable(py, &deferred)
304+
}
305+
306+
static MAKE_DEFERRED_YIELDABLE: OnceLock<pyo3::Py<pyo3::PyAny>> = OnceLock::new();
307+
308+
/// Given a deferred, make it follow the Synapse logcontext rules
309+
fn make_deferred_yieldable<'py>(
310+
py: Python<'py>,
311+
deferred: &Bound<'py, PyAny>,
312+
) -> PyResult<Bound<'py, PyAny>> {
313+
let make_deferred_yieldable = MAKE_DEFERRED_YIELDABLE.get_or_init(|| {
314+
let sys = PyModule::import(py, "synapse.logging.context").unwrap();
315+
let func = sys.getattr("make_deferred_yieldable").unwrap().unbind();
316+
func
317+
});
318+
319+
make_deferred_yieldable.call1(py, (deferred,))?.extract(py)
303320
}

synapse/api/auth/mas.py

Lines changed: 6 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,6 @@
3333
UnrecognizedRequestError,
3434
)
3535
from synapse.http.site import SynapseRequest
36-
from synapse.logging.context import PreserveLoggingContext
3736
from synapse.logging.opentracing import (
3837
active_span,
3938
force_tracing,
@@ -229,13 +228,12 @@ async def _introspect_token(
229228
try:
230229
with start_active_span("mas-introspect-token"):
231230
inject_request_headers(raw_headers)
232-
with PreserveLoggingContext():
233-
resp_body = await self._rust_http_client.post(
234-
url=self._introspection_endpoint,
235-
response_limit=1 * 1024 * 1024,
236-
headers=raw_headers,
237-
request_body=body,
238-
)
231+
resp_body = await self._rust_http_client.post(
232+
url=self._introspection_endpoint,
233+
response_limit=1 * 1024 * 1024,
234+
headers=raw_headers,
235+
request_body=body,
236+
)
239237
except HttpResponseException as e:
240238
end_time = self._clock.time()
241239
introspection_response_timer.labels(

synapse/api/auth/msc3861_delegated.py

Lines changed: 6 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,6 @@
3838
UnrecognizedRequestError,
3939
)
4040
from synapse.http.site import SynapseRequest
41-
from synapse.logging.context import PreserveLoggingContext
4241
from synapse.logging.opentracing import (
4342
active_span,
4443
force_tracing,
@@ -327,13 +326,12 @@ async def _introspect_token(
327326
try:
328327
with start_active_span("mas-introspect-token"):
329328
inject_request_headers(raw_headers)
330-
with PreserveLoggingContext():
331-
resp_body = await self._rust_http_client.post(
332-
url=uri,
333-
response_limit=1 * 1024 * 1024,
334-
headers=raw_headers,
335-
request_body=body,
336-
)
329+
resp_body = await self._rust_http_client.post(
330+
url=uri,
331+
response_limit=1 * 1024 * 1024,
332+
headers=raw_headers,
333+
request_body=body,
334+
)
337335
except HttpResponseException as e:
338336
end_time = self._clock.time()
339337
introspection_response_timer.labels(

synapse/synapse_rust/http_client.pyi

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,10 @@ from twisted.internet.defer import Deferred
1717
from synapse.types import ISynapseReactor
1818

1919
class HttpClient:
20+
"""
21+
The returned deferreds follow Synapse logcontext rules.
22+
"""
23+
2024
def __init__(self, reactor: ISynapseReactor, user_agent: str) -> None: ...
2125
def get(self, url: str, response_limit: int) -> Deferred[bytes]: ...
2226
def post(

tests/synapse_rust/__init__.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,11 @@
1+
# This file is licensed under the Affero General Public License (AGPL) version 3.
2+
#
3+
# Copyright (C) 2025 New Vector, Ltd
4+
#
5+
# This program is free software: you can redistribute it and/or modify
6+
# it under the terms of the GNU Affero General Public License as
7+
# published by the Free Software Foundation, either version 3 of the
8+
# License, or (at your option) any later version.
9+
#
10+
# See the GNU Affero General Public License for more details:
11+
# <https://www.gnu.org/licenses/agpl-3.0.html>.
Lines changed: 225 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,225 @@
1+
# This file is licensed under the Affero General Public License (AGPL) version 3.
2+
#
3+
# Copyright (C) 2025 New Vector, Ltd
4+
#
5+
# This program is free software: you can redistribute it and/or modify
6+
# it under the terms of the GNU Affero General Public License as
7+
# published by the Free Software Foundation, either version 3 of the
8+
# License, or (at your option) any later version.
9+
#
10+
# See the GNU Affero General Public License for more details:
11+
# <https://www.gnu.org/licenses/agpl-3.0.html>.
12+
13+
import json
14+
import logging
15+
import threading
16+
import time
17+
from http.server import BaseHTTPRequestHandler, HTTPServer
18+
from typing import Any, Coroutine, Generator, TypeVar, Union
19+
20+
from twisted.internet.defer import Deferred, ensureDeferred
21+
from twisted.internet.testing import MemoryReactor
22+
23+
from synapse.logging.context import (
24+
LoggingContext,
25+
PreserveLoggingContext,
26+
_Sentinel,
27+
current_context,
28+
run_in_background,
29+
)
30+
from synapse.server import HomeServer
31+
from synapse.synapse_rust.http_client import HttpClient
32+
from synapse.util.clock import Clock
33+
from synapse.util.json import json_decoder
34+
35+
from tests.unittest import HomeserverTestCase
36+
37+
logger = logging.getLogger(__name__)
38+
39+
T = TypeVar("T")
40+
41+
42+
class StubRequestHandler(BaseHTTPRequestHandler):
43+
server: "StubServer"
44+
45+
def do_GET(self) -> None:
46+
self.server.calls += 1
47+
48+
self.send_response(200)
49+
self.send_header("Content-Type", "application/json")
50+
self.end_headers()
51+
self.wfile.write(json.dumps({"ok": True}).encode("utf-8"))
52+
53+
def log_message(self, format: str, *args: Any) -> None:
54+
# Don't log anything; by default, the server logs to stderr
55+
pass
56+
57+
58+
class StubServer(HTTPServer):
59+
"""A stub HTTP server that we can send requests to for testing.
60+
61+
This opens a real HTTP server on a random port, on a separate thread.
62+
"""
63+
64+
calls: int = 0
65+
"""How many times has the endpoint been requested."""
66+
67+
_thread: threading.Thread
68+
69+
def __init__(self) -> None:
70+
super().__init__(("127.0.0.1", 0), StubRequestHandler)
71+
72+
self._thread = threading.Thread(
73+
target=self.serve_forever,
74+
name="StubServer",
75+
kwargs={"poll_interval": 0.01},
76+
daemon=True,
77+
)
78+
self._thread.start()
79+
80+
def shutdown(self) -> None:
81+
super().shutdown()
82+
self._thread.join()
83+
84+
@property
85+
def endpoint(self) -> str:
86+
return f"http://127.0.0.1:{self.server_port}/"
87+
88+
89+
class HttpClientTestCase(HomeserverTestCase):
90+
def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
91+
hs = self.setup_test_homeserver()
92+
93+
# XXX: We must create the Rust HTTP client before we call `reactor.run()` below.
94+
# Twisted's `MemoryReactor` doesn't invoke `callWhenRunning` callbacks if it's
95+
# already running and we rely on that to start the Tokio thread pool in Rust. In
96+
# the future, this may not matter, see https://github.com/twisted/twisted/pull/12514
97+
self._http_client = hs.get_proxied_http_client()
98+
self._rust_http_client = HttpClient(
99+
reactor=hs.get_reactor(),
100+
user_agent=self._http_client.user_agent.decode("utf8"),
101+
)
102+
103+
# This triggers the server startup hooks, which starts the Tokio thread pool
104+
reactor.run()
105+
106+
return hs
107+
108+
def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
109+
self.server = StubServer()
110+
111+
def tearDown(self) -> None:
112+
# MemoryReactor doesn't trigger the shutdown phases, and we want the
113+
# Tokio thread pool to be stopped
114+
# XXX: This logic should probably get moved somewhere else
115+
shutdown_triggers = self.reactor.triggers.get("shutdown", {})
116+
for phase in ["before", "during", "after"]:
117+
triggers = shutdown_triggers.get(phase, [])
118+
for callbable, args, kwargs in triggers:
119+
callbable(*args, **kwargs)
120+
121+
def till_deferred_has_result(
122+
self,
123+
awaitable: Union[
124+
"Coroutine[Deferred[Any], Any, T]",
125+
"Generator[Deferred[Any], Any, T]",
126+
"Deferred[T]",
127+
],
128+
) -> "Deferred[T]":
129+
"""Wait until a deferred has a result.
130+
131+
This is useful because the Rust HTTP client will resolve the deferred
132+
using reactor.callFromThread, which are only run when we call
133+
reactor.advance.
134+
"""
135+
deferred = ensureDeferred(awaitable)
136+
tries = 0
137+
while not deferred.called:
138+
time.sleep(0.1)
139+
self.reactor.advance(0)
140+
tries += 1
141+
if tries > 100:
142+
raise Exception("Timed out waiting for deferred to resolve")
143+
144+
return deferred
145+
146+
def _check_current_logcontext(self, expected_logcontext_string: str) -> None:
147+
context = current_context()
148+
assert isinstance(context, LoggingContext) or isinstance(context, _Sentinel), (
149+
f"Expected LoggingContext({expected_logcontext_string}) but saw {context}"
150+
)
151+
self.assertEqual(
152+
str(context),
153+
expected_logcontext_string,
154+
f"Expected LoggingContext({expected_logcontext_string}) but saw {context}",
155+
)
156+
157+
def test_request_response(self) -> None:
158+
"""
159+
Test to make sure we can make a basic request and get the expected
160+
response.
161+
"""
162+
163+
async def do_request() -> None:
164+
resp_body = await self._rust_http_client.get(
165+
url=self.server.endpoint,
166+
response_limit=1 * 1024 * 1024,
167+
)
168+
raw_response = json_decoder.decode(resp_body.decode("utf-8"))
169+
self.assertEqual(raw_response, {"ok": True})
170+
171+
self.get_success(self.till_deferred_has_result(do_request()))
172+
self.assertEqual(self.server.calls, 1)
173+
174+
async def test_logging_context(self) -> None:
175+
"""
176+
Test to make sure the `LoggingContext` (logcontext) is handled correctly
177+
when making requests.
178+
"""
179+
# Sanity check that we start in the sentinel context
180+
self._check_current_logcontext("sentinel")
181+
182+
callback_finished = False
183+
184+
async def do_request() -> None:
185+
nonlocal callback_finished
186+
try:
187+
# Should have the same logcontext as the caller
188+
self._check_current_logcontext("foo")
189+
190+
with LoggingContext(name="competing", server_name="test_server"):
191+
# Make the actual request
192+
await self._rust_http_client.get(
193+
url=self.server.endpoint,
194+
response_limit=1 * 1024 * 1024,
195+
)
196+
self._check_current_logcontext("competing")
197+
198+
# Back to the caller's context outside of the `LoggingContext` block
199+
self._check_current_logcontext("foo")
200+
finally:
201+
# When exceptions happen, we still want to mark the callback as finished
202+
# so that the test can complete and we see the underlying error.
203+
callback_finished = True
204+
205+
with LoggingContext(name="foo", server_name="test_server"):
206+
# Fire off the function, but don't wait on it.
207+
run_in_background(do_request)
208+
209+
# Now wait for the function under test to have run
210+
with PreserveLoggingContext():
211+
while not callback_finished:
212+
# await self.hs.get_clock().sleep(0)
213+
time.sleep(0.1)
214+
self.reactor.advance(0)
215+
216+
# check that the logcontext is left in a sane state.
217+
self._check_current_logcontext("foo")
218+
219+
self.assertTrue(
220+
callback_finished,
221+
"Callback never finished which means the test probably didn't wait long enough",
222+
)
223+
224+
# Back to the sentinel context
225+
self._check_current_logcontext("sentinel")

0 commit comments

Comments
 (0)