diff --git a/aio_pika/robust_channel.py b/aio_pika/robust_channel.py index 8dbd0f02..300b9243 100644 --- a/aio_pika/robust_channel.py +++ b/aio_pika/robust_channel.py @@ -2,7 +2,7 @@ import warnings from collections import defaultdict from itertools import chain -from typing import Any, DefaultDict, Dict, Optional, Set, Type, Union +from typing import Any, DefaultDict, Dict, Optional, Set, Type, Union, cast from warnings import warn import aiormq @@ -31,7 +31,7 @@ class RobustChannel(Channel, AbstractRobustChannel): # type: ignore RESTORE_RETRY_DELAY: int = 2 - _exchanges: DefaultDict[str, Set[AbstractRobustExchange]] + _exchanges: DefaultDict[str, AbstractRobustExchange] _queues: DefaultDict[str, Set[RobustQueue]] default_exchange: RobustExchange @@ -60,7 +60,7 @@ def __init__( on_return_raises=on_return_raises, ) - self._exchanges = defaultdict(set) + self._exchanges = defaultdict() self._queues = defaultdict(set) self._prefetch_count: int = 0 self._prefetch_size: int = 0 @@ -136,7 +136,7 @@ async def _on_open(self) -> None: if not hasattr(self, "default_exchange"): await super()._on_open() - exchanges = tuple(chain(*self._exchanges.values())) + exchanges = self._exchanges.values() queues = tuple(chain(*self._queues.values())) channel = await self.get_underlay_channel() @@ -200,6 +200,11 @@ async def declare_exchange( Set to False for temporary exchanges that should not be restored. """ await self.ready() + # Passive is True so expecting the exchange to be already declared + # if we can just return it instead of creating a new class instance + if passive and name in self._exchanges: + return self._exchanges[name] + exchange = ( await super().declare_exchange( name=name, @@ -212,12 +217,12 @@ async def declare_exchange( timeout=timeout, ) ) + exchange = cast(AbstractRobustExchange, exchange) if not internal and robust: - # noinspection PyTypeChecker - self._exchanges[name].add(exchange) # type: ignore + self._exchanges[name] = exchange - return exchange # type: ignore + return exchange async def exchange_delete( self, @@ -254,7 +259,12 @@ async def declare_queue( Set to False for temporary queues that should not be restored. """ await self.ready() - queue: RobustQueue = await super().declare_queue( # type: ignore + # Passive is True so expecting the queue to be already declared + # if we can just return it instead of creating a new class instance + if passive and name and name in self._queues: + return list(self._queues[name])[0] + + queue = await super().declare_queue( name=name, durable=durable, exclusive=exclusive, @@ -263,6 +273,8 @@ async def declare_queue( arguments=arguments, timeout=timeout, ) + queue = cast(RobustQueue, queue) + if robust: self._queues[queue.name].add(queue) return queue diff --git a/tests/test_amqp_robust.py b/tests/test_amqp_robust.py index 4545232b..7a335d62 100644 --- a/tests/test_amqp_robust.py +++ b/tests/test_amqp_robust.py @@ -129,6 +129,21 @@ async def test_channel_can_be_closed(self, connection): assert channel.is_closed + async def test_get_exchange(self, connection, declare_exchange): + channel = await self.create_channel(connection) + name = get_random_name("passive", "exchange") + + with pytest.raises(aio_pika.exceptions.ChannelNotFoundEntity): + await channel.get_exchange(name) + + channel = await self.create_channel(connection) + exchange = await declare_exchange( + name, auto_delete=True, channel=channel, + ) + exchange_passive = await channel.get_exchange(name) + + assert exchange.name is exchange_passive.name + class TestCaseAmqpNoConfirmsRobust(TestCaseAmqpNoConfirms): pass