Skip to content

Commit 8141b93

Browse files
ioistiredelprans
andauthored
Add support for connection termination listeners (#525)
The new `Connection.add_termination_listener()` method can be used to register callbacks to be invoked when a connection has been terminated. Co-authored-by: Elvis Pranskevichus <[email protected]>
1 parent b081320 commit 8141b93

File tree

3 files changed

+105
-12
lines changed

3 files changed

+105
-12
lines changed

asyncpg/_testbase/fuzzer.py

+18-10
Original file line numberDiff line numberDiff line change
@@ -145,6 +145,10 @@ def _close_connection(self, connection):
145145
if conn_task is not None:
146146
conn_task.cancel()
147147

148+
def close_all_connections(self):
149+
for conn in list(self.connections):
150+
self.loop.call_soon_threadsafe(self._close_connection, conn)
151+
148152

149153
class Connection:
150154
def __init__(self, client_sock, backend_sock, proxy):
@@ -215,10 +219,11 @@ async def _read(self, sock, n):
215219
else:
216220
return read_task.result()
217221
finally:
218-
if not read_task.done():
219-
read_task.cancel()
220-
if not conn_event_task.done():
221-
conn_event_task.cancel()
222+
if not self.loop.is_closed():
223+
if not read_task.done():
224+
read_task.cancel()
225+
if not conn_event_task.done():
226+
conn_event_task.cancel()
222227

223228
async def _write(self, sock, data):
224229
write_task = asyncio.ensure_future(
@@ -236,10 +241,11 @@ async def _write(self, sock, data):
236241
else:
237242
return write_task.result()
238243
finally:
239-
if not write_task.done():
240-
write_task.cancel()
241-
if not conn_event_task.done():
242-
conn_event_task.cancel()
244+
if not self.loop.is_closed():
245+
if not write_task.done():
246+
write_task.cancel()
247+
if not conn_event_task.done():
248+
conn_event_task.cancel()
243249

244250
async def proxy_to_backend(self):
245251
buf = None
@@ -264,7 +270,8 @@ async def proxy_to_backend(self):
264270
pass
265271

266272
finally:
267-
self.loop.call_soon(self.close)
273+
if not self.loop.is_closed():
274+
self.loop.call_soon(self.close)
268275

269276
async def proxy_from_backend(self):
270277
buf = None
@@ -289,4 +296,5 @@ async def proxy_from_backend(self):
289296
pass
290297

291298
finally:
292-
self.loop.call_soon(self.close)
299+
if not self.loop.is_closed():
300+
self.loop.call_soon(self.close)

asyncpg/connection.py

+45-2
Original file line numberDiff line numberDiff line change
@@ -46,8 +46,8 @@ class Connection(metaclass=ConnectionMeta):
4646
'_listeners', '_server_version', '_server_caps',
4747
'_intro_query', '_reset_query', '_proxy',
4848
'_stmt_exclusive_section', '_config', '_params', '_addr',
49-
'_log_listeners', '_cancellations', '_source_traceback',
50-
'__weakref__')
49+
'_log_listeners', '_termination_listeners', '_cancellations',
50+
'_source_traceback', '__weakref__')
5151

5252
def __init__(self, protocol, transport, loop,
5353
addr: (str, int) or str,
@@ -78,6 +78,7 @@ def __init__(self, protocol, transport, loop,
7878
self._listeners = {}
7979
self._log_listeners = set()
8080
self._cancellations = set()
81+
self._termination_listeners = set()
8182

8283
settings = self._protocol.get_settings()
8384
ver_string = settings.server_version
@@ -178,6 +179,28 @@ def remove_log_listener(self, callback):
178179
"""
179180
self._log_listeners.discard(callback)
180181

182+
def add_termination_listener(self, callback):
183+
"""Add a listener that will be called when the connection is closed.
184+
185+
:param callable callback:
186+
A callable receiving one argument:
187+
**connection**: a Connection the callback is registered with.
188+
189+
.. versionadded:: 0.21.0
190+
"""
191+
self._termination_listeners.add(callback)
192+
193+
def remove_termination_listener(self, callback):
194+
"""Remove a listening callback for connection termination.
195+
196+
:param callable callback:
197+
The callable that was passed to
198+
:meth:`Connection.add_termination_listener`.
199+
200+
.. versionadded:: 0.21.0
201+
"""
202+
self._termination_listeners.discard(callback)
203+
181204
def get_server_pid(self):
182205
"""Return the PID of the Postgres server the connection is bound to."""
183206
return self._protocol.get_server_pid()
@@ -1120,6 +1143,7 @@ def _abort(self):
11201143
self._protocol = None
11211144

11221145
def _cleanup(self):
1146+
self._call_termination_listeners()
11231147
# Free the resources associated with this connection.
11241148
# This must be called when a connection is terminated.
11251149

@@ -1237,6 +1261,25 @@ def _call_log_listener(self, cb, con_ref, message):
12371261
'exception': ex
12381262
})
12391263

1264+
def _call_termination_listeners(self):
1265+
if not self._termination_listeners:
1266+
return
1267+
1268+
con_ref = self._unwrap()
1269+
for cb in self._termination_listeners:
1270+
try:
1271+
cb(con_ref)
1272+
except Exception as ex:
1273+
self._loop.call_exception_handler({
1274+
'message': (
1275+
'Unhandled exception in asyncpg connection '
1276+
'termination listener callback {!r}'.format(cb)
1277+
),
1278+
'exception': ex
1279+
})
1280+
1281+
self._termination_listeners.clear()
1282+
12401283
def _process_notification(self, pid, channel, payload):
12411284
if channel not in self._listeners:
12421285
return

tests/test_listeners.py

+42
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,10 @@
66

77

88
import asyncio
9+
import os
10+
import platform
11+
import sys
12+
import unittest
913

1014
from asyncpg import _testbase as tb
1115
from asyncpg import exceptions
@@ -272,3 +276,41 @@ def listener1(*args):
272276
pass
273277

274278
con.add_log_listener(listener1)
279+
280+
281+
@unittest.skipIf(os.environ.get('PGHOST'), 'using remote cluster for testing')
282+
@unittest.skipIf(
283+
platform.system() == 'Windows' and
284+
sys.version_info >= (3, 8),
285+
'not compatible with ProactorEventLoop which is default in Python 3.8')
286+
class TestConnectionTerminationListener(tb.ProxiedClusterTestCase):
287+
288+
async def test_connection_termination_callback_called_on_remote(self):
289+
290+
called = False
291+
292+
def close_cb(con):
293+
nonlocal called
294+
called = True
295+
296+
con = await self.connect()
297+
con.add_termination_listener(close_cb)
298+
self.proxy.close_all_connections()
299+
try:
300+
await con.fetchval('SELECT 1')
301+
except Exception:
302+
pass
303+
self.assertTrue(called)
304+
305+
async def test_connection_termination_callback_called_on_local(self):
306+
307+
called = False
308+
309+
def close_cb(con):
310+
nonlocal called
311+
called = True
312+
313+
con = await self.connect()
314+
con.add_termination_listener(close_cb)
315+
await con.close()
316+
self.assertTrue(called)

0 commit comments

Comments
 (0)