diff --git a/asynch/connection.py b/asynch/connection.py index 27383dc..b052b84 100644 --- a/asynch/connection.py +++ b/asynch/connection.py @@ -178,32 +178,25 @@ async def ping(self) -> None: msg = f"Ping has failed for {self}" raise ConnectionError(msg) - async def _refresh(self) -> None: - """Refresh the connection. + async def is_live(self) -> bool: + """Checks if the connection is live. - Attempting to ping and if failed, - then trying to connect again. - If the reconnection does not work, - an Exception is propagated. + Attempts to ping and returns True if successful. :raises ConnectionError: 1. refreshing created, i.e., not opened connection 2. refreshing already closed connection - :return: None + :return: True if the connection is alive, otherwise False. """ - - if self.status == ConnectionStatus.created: - msg = f"the {self} is not opened to be refreshed" - raise ConnectionError(msg) - if self.status == ConnectionStatus.closed: - msg = f"the {self} is already closed" - raise ConnectionError(msg) + if self.status == ConnectionStatus.created or self.status == ConnectionStatus.closed: + return False try: await self.ping() + return True except ConnectionError: - await self.connect() + return False async def rollback(self): raise NotSupportedError diff --git a/asynch/pool.py b/asynch/pool.py index 7aa1f38..3dd0cb0 100644 --- a/asynch/pool.py +++ b/asynch/pool.py @@ -2,7 +2,7 @@ import logging from collections import deque from collections.abc import AsyncIterator -from contextlib import asynccontextmanager, suppress +from contextlib import asynccontextmanager from typing import Optional from asynch.connection import Connection @@ -132,7 +132,7 @@ def maxsize(self) -> int: def minsize(self) -> int: return self._minsize - async def _create_connection(self) -> None: + async def _create_connection(self) -> Connection: if self._pool_size == self._maxsize: raise AsynchPoolError(f"{self} is already full") if self._pool_size > self._maxsize: @@ -143,11 +143,15 @@ async def _create_connection(self) -> None: try: await conn.ping() - self._free_connections.append(conn) + return conn except ConnectionError as e: msg = f"failed to create a {conn} for {self}" raise AsynchPoolError(msg) from e + async def _create_and_release_connection(self) -> None: + conn = await self._create_connection() + self._free_connections.append(conn) + def _pop_connection(self) -> Connection: if not self._free_connections: raise AsynchPoolError(f"no free connection in {self}") @@ -156,8 +160,8 @@ def _pop_connection(self) -> Connection: async def _get_fresh_connection(self) -> Optional[Connection]: while self._free_connections: conn = self._pop_connection() - with suppress(ConnectionError): - await conn._refresh() + logger.debug(f"Testing connection {conn}") + if await conn.is_live(): return conn return None @@ -166,8 +170,8 @@ async def _acquire_connection(self) -> Connection: self._acquired_connections.append(conn) return conn - await self._create_connection() - conn = self._pop_connection() + logger.debug("No free connection in pool. Creating new connection.") + conn = await self._create_connection() self._acquired_connections.append(conn) return conn @@ -176,13 +180,9 @@ async def _release_connection(self, conn: Connection) -> None: raise AsynchPoolError(f"the connection {conn} does not belong to {self}") self._acquired_connections.remove(conn) - try: - await conn._refresh() - except ConnectionError as e: - msg = f"the {conn} is invalidated" - raise AsynchPoolError(msg) from e - - self._free_connections.append(conn) + if await conn.is_live(): + logger.debug(f"Releasing connection {conn}") + self._free_connections.append(conn) async def _init_connections(self, n: int, *, strict: bool = False) -> None: if n < 0: @@ -199,7 +199,7 @@ async def _init_connections(self, n: int, *, strict: bool = False) -> None: # it is possible that the `_create_connection` may not create `n` connections tasks: list[asyncio.Task] = [ - asyncio.create_task(self._create_connection()) for _ in range(n) + asyncio.create_task(self._create_and_release_connection()) for _ in range(n) ] # that is why possible exceptions from the `_create_connection` are also gathered if strict and any( @@ -226,10 +226,15 @@ async def connection(self) -> AsyncIterator[Connection]: :return: a free connection from the pool :rtype: Connection """ + logger.debug( + f"Acquiring connection from Pool ({len(self._free_connections)} free connections, {len(self._acquired_connections)} acquired connections)" + ) async with self._sem: async with self._lock: conn = await self._acquire_connection() + logger.debug(f"Acquired connection {conn}") + try: yield conn finally: diff --git a/asynch/proto/connection.py b/asynch/proto/connection.py index a599d74..ed3381c 100644 --- a/asynch/proto/connection.py +++ b/asynch/proto/connection.py @@ -577,9 +577,10 @@ async def disconnect(self): async def connect(self): if self.connected: await self.disconnect() - logger.debug("Connecting. Database: %s. User: %s", self.database, self.user) for host, port in self.hosts: - logger.debug("Connecting to %s:%s", host, port) + logger.debug( + "Connecting to %s:%s Database: %s. User: %s", host, port, self.database, self.user + ) return await self._init_connection(host, port) async def execute( diff --git a/asynch/proto/streams/buffered.py b/asynch/proto/streams/buffered.py index e435773..79d5f3f 100644 --- a/asynch/proto/streams/buffered.py +++ b/asynch/proto/streams/buffered.py @@ -61,6 +61,9 @@ async def write_fixed_strings(self, data, length): async def close(self) -> None: if not self.writer: return + if self.writer.is_closing(): + return + self.writer.close() await self.writer.wait_closed() diff --git a/tests/test_reconnection.py b/tests/test_reconnection.py index be47abf..e8d8d14 100644 --- a/tests/test_reconnection.py +++ b/tests/test_reconnection.py @@ -40,7 +40,7 @@ async def proxy(request): @pytest.fixture() async def proxy_pool(proxy): - async with Pool(minsize=1, maxsize=1, dsn=CONNECTION_DSN.replace("9000", "9001")) as pool: + async with Pool(minsize=1, maxsize=2, dsn=CONNECTION_DSN.replace("9000", "9001")) as pool: yield pool @@ -66,6 +66,30 @@ async def test_close_disconnected_connection(proxy_pool): await asyncio.sleep(TIMEOUT * 2) +@pytest.mark.asyncio +async def test_connection_reuse(proxy_pool): + async def execute_sleep(): + async with proxy_pool.connection() as c: + async with c.cursor() as cursor: + await cursor.execute("SELECT sleep(0.1)") + + await asyncio.gather(execute_sleep(), execute_sleep()) + + # There are two live connections in the pool. + assert proxy_pool.free_connections == 2 + + logger.info(f"Killing {proxy_pool._free_connections[0]}") + await proxy_pool._free_connections[0]._connection.writer.close() + + async with proxy_pool.connection() as c: + async with c.cursor() as cursor: + await cursor.execute("SELECT 1") + + # The first connection was not live anymore and was closed. The second connection was reused. + # There is now only one connection in the pool. + assert proxy_pool.free_connections == 1 + + async def reader_to_writer(name: str, graceful: bool, reader: StreamReader, writer: StreamWriter): while True: try: