diff --git a/src/aleph/chains/connector.py b/src/aleph/chains/connector.py index 9d5c2aff3..dd11be6e8 100644 --- a/src/aleph/chains/connector.py +++ b/src/aleph/chains/connector.py @@ -1,5 +1,6 @@ import asyncio import logging +from contextlib import AbstractAsyncContextManager, AsyncExitStack from typing import Dict, Self, Union from aleph_message.models import Chain @@ -39,6 +40,23 @@ def __init__( self.readers = {} self.writers = {} + self._exit_stack = AsyncExitStack() + + async def __aenter__(self) -> Self: + await self._exit_stack.__aenter__() + # A connector can be both a reader and a writer; dedupe by identity so + # __aexit__ runs at most once per instance. + seen: set[int] = set() + for connector in (*self.readers.values(), *self.writers.values()): + if id(connector) in seen: + continue + seen.add(id(connector)) + if isinstance(connector, AbstractAsyncContextManager): + await self._exit_stack.enter_async_context(connector) + return self + + async def __aexit__(self, exc_type, exc_val, exc_tb) -> None: + await self._exit_stack.__aexit__(exc_type, exc_val, exc_tb) @classmethod async def new( diff --git a/src/aleph/chains/ethereum.py b/src/aleph/chains/ethereum.py index 218b31bc0..41aa18e9a 100644 --- a/src/aleph/chains/ethereum.py +++ b/src/aleph/chains/ethereum.py @@ -108,6 +108,13 @@ def __init__( pending_tx_publisher=pending_tx_publisher, ) + async def __aenter__(self) -> Self: + return self + + async def __aexit__(self, exc_type, exc_val, exc_tb) -> None: + # Closes the aiohttp ClientSession cached by AsyncHTTPProvider. + await self.web3_client.provider.disconnect() + @classmethod async def new( cls, @@ -377,68 +384,61 @@ async def broadcast_messages( ) async def packer(self, config: Config): - try: - pri_key = HexBytes(config.ethereum.private_key.value) - account = Account.from_key(pri_key) - address = account.address + pri_key = HexBytes(config.ethereum.private_key.value) + account = Account.from_key(pri_key) + address = account.address - LOGGER.info("Ethereum Connector set up with address %s" % address) - i = 0 - while True: - with self.session_factory() as session: - # Wait for sync operations to complete - if (count_pending_txs(session=session, chain=Chain.ETH)) or ( - count_pending_messages(session=session, chain=Chain.ETH) - ) > 1000: - await asyncio.sleep(30) - continue - - if i >= 100: - await asyncio.sleep(30) # wait three (!!) blocks - i = 0 - - nonce = await self.web3_client.eth.get_transaction_count( - account.address - ) + LOGGER.info("Ethereum Connector set up with address %s" % address) + i = 0 + while True: + with self.session_factory() as session: + # Wait for sync operations to complete + if (count_pending_txs(session=session, chain=Chain.ETH)) or ( + count_pending_messages(session=session, chain=Chain.ETH) + ) > 1000: + await asyncio.sleep(30) + continue + + if i >= 100: + await asyncio.sleep(30) # wait three (!!) blocks + i = 0 + + nonce = await self.web3_client.eth.get_transaction_count( + account.address + ) - # Collect all unconfirmed messages using pagination - max_unconfirmed = config.aleph.jobs.max_unconfirmed_messages.value - all_messages = [] - offset = 0 - while True: - batch = list( - get_unconfirmed_messages( - session=session, - limit=500, - offset=offset, - ) + # Collect all unconfirmed messages using pagination + max_unconfirmed = config.aleph.jobs.max_unconfirmed_messages.value + all_messages = [] + offset = 0 + while True: + batch = list( + get_unconfirmed_messages( + session=session, + limit=500, + offset=offset, ) - if not batch: - break - all_messages.extend(batch) - offset += len(batch) - if len(batch) < 500 or len(all_messages) >= max_unconfirmed: - break - all_messages = all_messages[:max_unconfirmed] - - if all_messages: - LOGGER.info( - "Chain sync: %d unconfirmed messages" % len(all_messages) ) + if not batch: + break + all_messages.extend(batch) + offset += len(batch) + if len(batch) < 500 or len(all_messages) >= max_unconfirmed: + break + all_messages = all_messages[:max_unconfirmed] - try: - response = await self.broadcast_messages( - account=account, - messages=all_messages, - nonce=nonce, - ) - LOGGER.info("Broadcast %r on %s" % (response, Chain.ETH.value)) - except Exception: - LOGGER.exception( - "Error while broadcasting messages to Ethereum" - ) + if all_messages: + LOGGER.info("Chain sync: %d unconfirmed messages" % len(all_messages)) + + try: + response = await self.broadcast_messages( + account=account, + messages=all_messages, + nonce=nonce, + ) + LOGGER.info("Broadcast %r on %s" % (response, Chain.ETH.value)) + except Exception: + LOGGER.exception("Error while broadcasting messages to Ethereum") - await asyncio.sleep(config.ethereum.commit_delay.value) - i += 1 - finally: - await self.web3_client.provider.disconnect() + await asyncio.sleep(config.ethereum.commit_delay.value) + i += 1 diff --git a/src/aleph/commands.py b/src/aleph/commands.py index 4cee56357..79f247d0b 100644 --- a/src/aleph/commands.py +++ b/src/aleph/commands.py @@ -184,6 +184,7 @@ async def main(args: List[str]) -> None: pending_tx_publisher=pending_tx_publisher, chain_data_service=chain_data_service, ) + await stack.enter_async_context(chain_connector) await repair_node( storage_service=storage_service, session_factory=session_factory diff --git a/src/aleph/jobs/job_utils.py b/src/aleph/jobs/job_utils.py index ca7eaabf6..5b254910e 100644 --- a/src/aleph/jobs/job_utils.py +++ b/src/aleph/jobs/job_utils.py @@ -157,8 +157,12 @@ def __init__(self, mq_queue: aio_pika.abc.AbstractQueue): self._event = asyncio.Event() async def _check_for_message(self): - async with self.mq_queue.iterator(no_ack=True) as queue_iter: - async for _ in queue_iter: + # Manual ack (no_ack=False) so that prefetched messages still in the + # iterator buffer on shutdown are nack'd with requeue=True by aio_pika, + # instead of being logged as "lost for consumer with no_ack". + async with self.mq_queue.iterator(no_ack=False) as queue_iter: + async for message in queue_iter: + await message.ack() self._event.set() async def __aenter__(self):