diff --git a/aio_pika/connection.py b/aio_pika/connection.py index ac73767b..4827be73 100644 --- a/aio_pika/connection.py +++ b/aio_pika/connection.py @@ -315,7 +315,7 @@ async def main(): query=kw ) - connection = connection_class(url, loop=loop) + connection = connection_class(url, loop=loop, **kwargs) await connection.connect( timeout=timeout, client_properties=client_properties diff --git a/aio_pika/exceptions.py b/aio_pika/exceptions.py index 9af7bb6b..41e17741 100644 --- a/aio_pika/exceptions.py +++ b/aio_pika/exceptions.py @@ -40,6 +40,10 @@ class QueueEmpty(AMQPError, asyncio.QueueEmpty): pass +class MaxReconnectAttemptsReached(Exception): + pass + + __all__ = ( 'AMQPChannelError', 'AMQPConnectionError', @@ -53,6 +57,7 @@ class QueueEmpty(AMQPError, asyncio.QueueEmpty): 'DuplicateConsumerTag', 'IncompatibleProtocolError', 'InvalidFrameError', + 'MaxReconnectAttemptsReached', 'MessageProcessError', 'MethodNotImplemented', 'ProbableAuthenticationError', diff --git a/aio_pika/robust_connection.py b/aio_pika/robust_connection.py index ddcc784e..ad0b0446 100644 --- a/aio_pika/robust_connection.py +++ b/aio_pika/robust_connection.py @@ -4,7 +4,7 @@ from typing import Callable, Type from aiormq.connection import parse_bool, parse_int -from .exceptions import CONNECTION_EXCEPTIONS +from .exceptions import CONNECTION_EXCEPTIONS, MaxReconnectAttemptsReached from .connection import Connection, connect, ConnectionType from .tools import CallbackCollection from .types import TimeoutType @@ -29,6 +29,7 @@ class RobustConnection(Connection): CHANNEL_CLASS = RobustChannel KWARGS_TYPES = ( + ('max_reconnect_attempts', parse_int, '0'), ('reconnect_interval', parse_int, '5'), ('fail_fast', parse_bool, '1'), ) @@ -43,7 +44,9 @@ def __init__(self, url, loop=None, **kwargs): self.fail_fast = self.kwargs['fail_fast'] self.__channels = set() + self._reconnect_attempt = None self._reconnect_callbacks = CallbackCollection() + self._stop_callbacks = CallbackCollection() self._closed = False @property @@ -77,6 +80,9 @@ def add_reconnect_callback(self, callback: Callable[[], None]): self._reconnect_callbacks.add(callback) + def add_stop_callback(self, callback: Callable[[Exception], None]): + self._stop_callbacks.add(callback) + async def connect(self, timeout: TimeoutType = None, **kwargs): if kwargs: # Store connect kwargs for reconnects @@ -104,6 +110,16 @@ async def reconnect(self): if self.is_closed: return + if self.kwargs['max_reconnect_attempts'] > 0: + if self._reconnect_attempt is None: + self._reconnect_attempt = 1 + else: + self._reconnect_attempt += 1 + + if self._reconnect_attempt > self.kwargs['max_reconnect_attempts']: + await self.close(MaxReconnectAttemptsReached()) + return + try: await super().connect() except CONNECTION_EXCEPTIONS: @@ -131,6 +147,7 @@ def channel(self, channel_number: int = None, return channel async def _on_reconnect(self): + self._reconnect_attempt = None for number, channel in self._channels.items(): try: await channel.on_reconnect(self, number) @@ -151,6 +168,7 @@ async def close(self, exc=asyncio.CancelledError): return self._closed = True + self._stop_callbacks(exc) if self.connection is None: return diff --git a/tests/test_amqp.py b/tests/test_amqp.py index 83d848d4..cfd4aecc 100644 --- a/tests/test_amqp.py +++ b/tests/test_amqp.py @@ -1231,7 +1231,7 @@ async def test_on_return_raises(self): ) for _ in range(100): - with pytest.raises(aio_pika.exceptions.DeliveryError) as e: + with pytest.raises(aio_pika.exceptions.DeliveryError): await channel.default_exchange.publish( Message(body=body), routing_key=queue_name, ) diff --git a/tests/test_amqp_robust.py b/tests/test_amqp_robust.py index e884155e..a7f4043e 100644 --- a/tests/test_amqp_robust.py +++ b/tests/test_amqp_robust.py @@ -8,6 +8,7 @@ from aiormq import ChannelLockedResource from aio_pika import connect_robust, Message +from aio_pika.exceptions import MaxReconnectAttemptsReached from aio_pika.robust_channel import RobustChannel from aio_pika.robust_connection import RobustConnection from aio_pika.robust_queue import RobustQueue @@ -27,6 +28,7 @@ def __init__(self, *, loop, shost='127.0.0.1', sport, self.src_port = sport self.dst_host = dhost self.dst_port = dport + self._run_task = None self.connections = set() async def _pipe(self, reader: asyncio.StreamReader, @@ -54,12 +56,19 @@ async def handle_client(self, creader: asyncio.StreamReader, ]) async def start(self): - return await asyncio.start_server( + self._run_task = await asyncio.start_server( self.handle_client, host=self.src_host, port=self.src_port, loop=self.loop, ) + return self._run_task + + async def stop(self): + assert self._run_task is not None + self._run_task.close() + await self.disconnect() + self._run_task = None async def disconnect(self): tasks = list() @@ -72,7 +81,8 @@ async def close(writer): writer = self.connections.pop() # type: asyncio.StreamWriter tasks.append(self.loop.create_task(close(writer))) - await asyncio.wait(tasks) + if tasks: + await asyncio.wait(tasks) class TestCase(AMQPTestCase): @@ -84,7 +94,7 @@ def get_unused_port() -> int: sock.close() return port - async def create_connection(self, cleanup=True): + async def create_connection(self, cleanup=True, max_reconnect_attempts=0): self.proxy = Proxy( dhost=AMQP_URL.host, dport=AMQP_URL.port, @@ -98,7 +108,11 @@ async def create_connection(self, cleanup=True): self.proxy.src_host ).with_port( self.proxy.src_port - ).update_query(reconnect_interval=1) + ).update_query( + reconnect_interval=1 + ).update_query( + max_reconnect_attempts=max_reconnect_attempts + ) client = await connect_robust(str(url), loop=self.loop) @@ -210,6 +224,28 @@ async def reader(): assert len(shared) == 10 + async def test_robust_reconnect_max_attempts(self): + client = await self.create_connection(max_reconnect_attempts=2) + self.assertIsInstance(client, RobustConnection) + + first_close = asyncio.Future() + stopped = asyncio.Future() + + def stop_callback(exc): + assert isinstance(exc, MaxReconnectAttemptsReached) + stopped.set_result(True) + + def close_callback(f): + first_close.set_result(True) + + client.add_stop_callback(stop_callback) + client.connection.closing.add_done_callback(close_callback) + await self.proxy.stop() + await first_close + # 1 interval before first try and 2 after attempts + await asyncio.wait_for(stopped, + timeout=client.reconnect_interval * 3 + 0.1) + async def test_channel_locked_resource2(self): ch1 = await self.create_channel() ch2 = await self.create_channel()