Skip to content

Commit d9a236e

Browse files
committed
Break strong ref between Connection and Protocol.
Connection objects currently stay alive and open if they are not explicitly closed or terminated. GC won't happen to them because event loops have a strong reference to their underlying Transport object. By replacing a strong Connection<->Protocol reference with a weak one, we are able to implement Connection.__del__() method that: * issues a warning if a Connection object is being GCed prior to be explicitly closed; * terminates the underlying Protocol and Transport, effectively closing the open network connection to the Postgres server. When in asyncio debug mode (enabled by PYTHONASYNCIODEBUG env variable or explicitly with `loop.set_debug(True)`) Connection objects save the traceback of their origin and later use it to make the GC warning clarer. Addresses #323.
1 parent 17f2079 commit d9a236e

File tree

8 files changed

+142
-12
lines changed

8 files changed

+142
-12
lines changed

asyncpg/connection.py

+47-1
Original file line numberDiff line numberDiff line change
@@ -6,11 +6,14 @@
66

77

88
import asyncio
9+
import asyncpg
910
import collections
1011
import collections.abc
1112
import itertools
1213
import struct
14+
import sys
1315
import time
16+
import traceback
1417
import warnings
1518

1619
from . import compat
@@ -44,7 +47,8 @@ class Connection(metaclass=ConnectionMeta):
4447
'_listeners', '_server_version', '_server_caps',
4548
'_intro_query', '_reset_query', '_proxy',
4649
'_stmt_exclusive_section', '_config', '_params', '_addr',
47-
'_log_listeners', '_cancellations')
50+
'_log_listeners', '_cancellations', '_source_traceback',
51+
'__weakref__')
4852

4953
def __init__(self, protocol, transport, loop,
5054
addr: (str, int) or str,
@@ -98,6 +102,27 @@ def __init__(self, protocol, transport, loop,
98102
# `con.execute()`, and `con.executemany()`.
99103
self._stmt_exclusive_section = _Atomic()
100104

105+
if loop.get_debug():
106+
self._source_traceback = _extract_stack()
107+
else:
108+
self._source_traceback = None
109+
110+
def __del__(self):
111+
if not self.is_closed() and self._protocol is not None:
112+
if self._source_traceback:
113+
msg = "unclosed connection {!r}; created at:\n {}".format(
114+
self, self._source_traceback)
115+
else:
116+
msg = (
117+
"unclosed connection {!r}; run in asyncio debug "
118+
"mode to show the traceback of connection "
119+
"origin".format(self)
120+
)
121+
122+
warnings.warn(msg, ResourceWarning)
123+
if not self._loop.is_closed():
124+
self.terminate()
125+
101126
async def add_listener(self, channel, callback):
102127
"""Add a listener for Postgres notifications.
103128
@@ -1791,4 +1816,25 @@ def _detect_server_capabilities(server_version, connection_settings):
17911816
)
17921817

17931818

1819+
def _extract_stack(limit=10):
1820+
"""Replacement for traceback.extract_stack() that only does the
1821+
necessary work for asyncio debug mode.
1822+
"""
1823+
frame = sys._getframe().f_back
1824+
try:
1825+
stack = traceback.StackSummary.extract(
1826+
traceback.walk_stack(frame), lookup_lines=False)
1827+
finally:
1828+
del frame
1829+
1830+
apg_path = asyncpg.__path__[0]
1831+
i = 0
1832+
while i < len(stack) and stack[i][0].startswith(apg_path):
1833+
i += 1
1834+
stack = stack[i:i + limit]
1835+
1836+
stack.reverse()
1837+
return ''.join(traceback.format_list(stack))
1838+
1839+
17941840
_uid = 0

asyncpg/pool.py

+15-1
Original file line numberDiff line numberDiff line change
@@ -467,7 +467,21 @@ async def _get_new_connection(self):
467467
connection_class=self._connection_class)
468468

469469
if self._init is not None:
470-
await self._init(con)
470+
try:
471+
await self._init(con)
472+
except Exception as ex:
473+
# If a user-defined `init` function fails, we don't
474+
# know if the connection is safe for re-use, hence
475+
# we close it. A new connection will be created
476+
# when `acquire` is called again.
477+
try:
478+
# Use `close()` to close the connection gracefully.
479+
# An exception in `init` isn't necessarily caused
480+
# by an IO or a protocol error. close() will
481+
# do the necessary cleanup via _release_on_close().
482+
await con.close()
483+
finally:
484+
raise ex
471485

472486
return con
473487

asyncpg/protocol/prepared_stmt.pxd

-1
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,6 @@ cdef class PreparedStatementState:
1717
list row_desc
1818
list parameters_desc
1919

20-
BaseProtocol protocol
2120
ConnectionSettings settings
2221

2322
int16_t args_num

asyncpg/protocol/prepared_stmt.pyx

-1
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,6 @@ cdef class PreparedStatementState:
1414
def __cinit__(self, str name, str query, BaseProtocol protocol):
1515
self.name = name
1616
self.query = query
17-
self.protocol = protocol
1817
self.settings = protocol.settings
1918
self.row_desc = self.parameters_desc = None
2019
self.args_codecs = self.rows_codecs = None

asyncpg/protocol/protocol.pxd

+3-1
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,7 @@ cdef class BaseProtocol(CoreProtocol):
3636
object timeout_handle
3737
object timeout_callback
3838
object completed_callback
39-
object connection
39+
object conref
4040
bint is_reading
4141

4242
str last_query
@@ -48,6 +48,8 @@ cdef class BaseProtocol(CoreProtocol):
4848

4949
PreparedStatementState statement
5050

51+
cdef get_connection(self)
52+
5153
cdef _get_timeout_impl(self, timeout)
5254
cdef _check_state(self)
5355
cdef _new_waiter(self, timeout)

asyncpg/protocol/protocol.pyx

+28-6
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@ import codecs
1616
import collections
1717
import socket
1818
import time
19+
import weakref
1920

2021
from libc.stdint cimport int8_t, uint8_t, int16_t, uint16_t, \
2122
int32_t, uint32_t, int64_t, uint64_t, \
@@ -126,7 +127,10 @@ cdef class BaseProtocol(CoreProtocol):
126127
self.create_future = self._create_future_fallback
127128

128129
def set_connection(self, connection):
129-
self.connection = connection
130+
self.conref = weakref.ref(connection)
131+
132+
cdef get_connection(self):
133+
return self.conref()
130134

131135
def get_server_pid(self):
132136
return self.backend_pid
@@ -585,7 +589,21 @@ cdef class BaseProtocol(CoreProtocol):
585589
def _request_cancel(self):
586590
self.cancel_waiter = self.create_future()
587591
self.cancel_sent_waiter = self.create_future()
588-
self.connection._cancel_current_command(self.cancel_sent_waiter)
592+
593+
con = self.get_connection()
594+
if con is not None:
595+
# if 'con' is None it means that the connection object has been
596+
# garbage collected and that the transport will soon be aborted.
597+
con._cancel_current_command(self.cancel_sent_waiter)
598+
else:
599+
self.loop.call_exception_handler({
600+
'message': 'asyncpg.Protocol has no reference to its '
601+
'Connection object and yet a cancellation '
602+
'was requested. Please report this at '
603+
'github.com/magicstack/asyncpg.'
604+
})
605+
self.abort()
606+
589607
self._set_state(PROTOCOL_CANCELLED)
590608

591609
def _on_timeout(self, fut):
@@ -636,7 +654,7 @@ cdef class BaseProtocol(CoreProtocol):
636654

637655
cdef inline _get_timeout_impl(self, timeout):
638656
if timeout is None:
639-
timeout = self.connection._config.command_timeout
657+
timeout = self.get_connection()._config.command_timeout
640658
elif timeout is NO_TIMEOUT:
641659
timeout = None
642660
else:
@@ -688,7 +706,7 @@ cdef class BaseProtocol(CoreProtocol):
688706
'cannot perform operation: another operation is in progress')
689707
self.waiter = self.create_future()
690708
if timeout is not None:
691-
self.timeout_handle = self.connection._loop.call_later(
709+
self.timeout_handle = self.loop.call_later(
692710
timeout, self.timeout_callback, self.waiter)
693711
self.waiter.add_done_callback(self.completed_callback)
694712
return self.waiter
@@ -839,10 +857,14 @@ cdef class BaseProtocol(CoreProtocol):
839857
self.return_extra = False
840858

841859
cdef _on_notice(self, parsed):
842-
self.connection._process_log_message(parsed, self.last_query)
860+
con = self.get_connection()
861+
if con is not None:
862+
con._process_log_message(parsed, self.last_query)
843863

844864
cdef _on_notification(self, pid, channel, payload):
845-
self.connection._process_notification(pid, channel, payload)
865+
con = self.get_connection()
866+
if con is not None:
867+
con._process_notification(pid, channel, payload)
846868

847869
cdef _on_connection_lost(self, exc):
848870
if self.closing:

tests/test_connect.py

+43
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77

88
import asyncio
99
import contextlib
10+
import gc
1011
import ipaddress
1112
import os
1213
import platform
@@ -15,6 +16,7 @@
1516
import tempfile
1617
import textwrap
1718
import unittest
19+
import weakref
1820

1921
import asyncpg
2022
from asyncpg import _testbase as tb
@@ -851,3 +853,44 @@ async def worker():
851853
tasks = [worker() for _ in range(100)]
852854
await asyncio.gather(*tasks, loop=self.loop)
853855
await pool.close()
856+
857+
858+
class TestConnectionGC(tb.ClusterTestCase):
859+
860+
async def _run_no_explicit_close_test(self):
861+
con = await self.connect()
862+
proto = con._protocol
863+
conref = weakref.ref(con)
864+
del con
865+
866+
gc.collect()
867+
gc.collect()
868+
gc.collect()
869+
870+
self.assertIsNone(conref())
871+
self.assertTrue(proto.is_closed())
872+
873+
async def test_no_explicit_close_no_debug(self):
874+
olddebug = self.loop.get_debug()
875+
self.loop.set_debug(False)
876+
try:
877+
with self.assertWarnsRegex(
878+
ResourceWarning,
879+
r'unclosed connection.*run in asyncio debug'):
880+
await self._run_no_explicit_close_test()
881+
finally:
882+
self.loop.set_debug(olddebug)
883+
884+
async def test_no_explicit_close_with_debug(self):
885+
olddebug = self.loop.get_debug()
886+
self.loop.set_debug(True)
887+
try:
888+
with self.assertWarnsRegex(ResourceWarning,
889+
r'unclosed connection') as rw:
890+
await self._run_no_explicit_close_test()
891+
892+
msg = rw.warning.args[0]
893+
self.assertIn(' created at:\n', msg)
894+
self.assertIn('in test_no_explicit_close_with_debug', msg)
895+
finally:
896+
self.loop.set_debug(olddebug)

tests/test_pool.py

+6-1
Original file line numberDiff line numberDiff line change
@@ -292,7 +292,8 @@ class Error(Exception):
292292
pass
293293

294294
async def setup(con):
295-
nonlocal setup_calls
295+
nonlocal setup_calls, last_con
296+
last_con = con
296297
setup_calls += 1
297298
if setup_calls > 1:
298299
cons.append(con)
@@ -302,24 +303,28 @@ async def setup(con):
302303

303304
with self.subTest(method='setup'):
304305
setup_calls = 0
306+
last_con = None
305307
cons = []
306308
async with self.create_pool(database='postgres',
307309
min_size=1, max_size=1,
308310
setup=setup) as pool:
309311
with self.assertRaises(Error):
310312
await pool.acquire()
313+
self.assertTrue(last_con.is_closed())
311314

312315
async with pool.acquire() as con:
313316
self.assertEqual(cons, ['error', con])
314317

315318
with self.subTest(method='init'):
316319
setup_calls = 0
320+
last_con = None
317321
cons = []
318322
async with self.create_pool(database='postgres',
319323
min_size=0, max_size=1,
320324
init=setup) as pool:
321325
with self.assertRaises(Error):
322326
await pool.acquire()
327+
self.assertTrue(last_con.is_closed())
323328

324329
async with pool.acquire() as con:
325330
self.assertEqual(await con.fetchval('select 1::int'), 1)

0 commit comments

Comments
 (0)