Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 20 additions & 8 deletions aio_pika/robust_channel.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import warnings
from collections import defaultdict
from itertools import chain
from typing import Any, DefaultDict, Dict, Optional, Set, Type, Union
from typing import Any, DefaultDict, Dict, Optional, Set, Type, Union, cast
from warnings import warn

import aiormq
Expand Down Expand Up @@ -31,7 +31,7 @@ class RobustChannel(Channel, AbstractRobustChannel): # type: ignore

RESTORE_RETRY_DELAY: int = 2

_exchanges: DefaultDict[str, Set[AbstractRobustExchange]]
_exchanges: DefaultDict[str, AbstractRobustExchange]
_queues: DefaultDict[str, Set[RobustQueue]]
default_exchange: RobustExchange

Expand Down Expand Up @@ -60,7 +60,7 @@ def __init__(
on_return_raises=on_return_raises,
)

self._exchanges = defaultdict(set)
self._exchanges = defaultdict()
self._queues = defaultdict(set)
self._prefetch_count: int = 0
self._prefetch_size: int = 0
Expand Down Expand Up @@ -136,7 +136,7 @@ async def _on_open(self) -> None:
if not hasattr(self, "default_exchange"):
await super()._on_open()

exchanges = tuple(chain(*self._exchanges.values()))
exchanges = self._exchanges.values()
queues = tuple(chain(*self._queues.values()))
channel = await self.get_underlay_channel()

Expand Down Expand Up @@ -200,6 +200,11 @@ async def declare_exchange(
Set to False for temporary exchanges that should not be restored.
"""
await self.ready()
# Passive is True so expecting the exchange to be already declared
# if we can just return it instead of creating a new class instance
if passive and name in self._exchanges:
return self._exchanges[name]

exchange = (
await super().declare_exchange(
name=name,
Expand All @@ -212,12 +217,12 @@ async def declare_exchange(
timeout=timeout,
)
)
exchange = cast(AbstractRobustExchange, exchange)

if not internal and robust:
# noinspection PyTypeChecker
self._exchanges[name].add(exchange) # type: ignore
self._exchanges[name] = exchange

return exchange # type: ignore
return exchange

async def exchange_delete(
self,
Expand Down Expand Up @@ -254,7 +259,12 @@ async def declare_queue(
Set to False for temporary queues that should not be restored.
"""
await self.ready()
queue: RobustQueue = await super().declare_queue( # type: ignore
# Passive is True so expecting the queue to be already declared
# if we can just return it instead of creating a new class instance
if passive and name and name in self._queues:
return list(self._queues[name])[0]

queue = await super().declare_queue(
name=name,
durable=durable,
exclusive=exclusive,
Expand All @@ -263,6 +273,8 @@ async def declare_queue(
arguments=arguments,
timeout=timeout,
)
queue = cast(RobustQueue, queue)

if robust:
self._queues[queue.name].add(queue)
return queue
Expand Down
15 changes: 15 additions & 0 deletions tests/test_amqp_robust.py
Original file line number Diff line number Diff line change
Expand Up @@ -129,6 +129,21 @@ async def test_channel_can_be_closed(self, connection):

assert channel.is_closed

async def test_get_exchange(self, connection, declare_exchange):
channel = await self.create_channel(connection)
name = get_random_name("passive", "exchange")

with pytest.raises(aio_pika.exceptions.ChannelNotFoundEntity):
await channel.get_exchange(name)

channel = await self.create_channel(connection)
exchange = await declare_exchange(
name, auto_delete=True, channel=channel,
)
exchange_passive = await channel.get_exchange(name)

assert exchange.name is exchange_passive.name


class TestCaseAmqpNoConfirmsRobust(TestCaseAmqpNoConfirms):
pass
Expand Down