Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
44 changes: 24 additions & 20 deletions dramatiq/asyncio.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,10 @@
import asyncio
import concurrent.futures
import functools
import gc
import logging
import threading
from types import TracebackType
from typing import Awaitable, Callable, Optional, ParamSpec, TypeVar

from .threading import Interrupt
Expand Down Expand Up @@ -133,10 +135,15 @@ def run_coroutine(self, coro: Awaitable[R]) -> R:
raise RuntimeError("Event loop is not running.")

done = threading.Event()
result_container: list[R] = []
exception_container: list[tuple[BaseException, TracebackType | None]] = []

async def wrapped_coro() -> R:
async def wrapped_coro() -> None:
try:
return await coro
result = await coro
result_container.append(result)
except BaseException as e:
exception_container.append((e, e.__traceback__))
finally:
done.set()

Expand All @@ -146,28 +153,13 @@ async def wrapped_coro() -> R:
try:
# Use a timeout to be able to catch asynchronously
# raised dramatiq exceptions (Interrupt).
return future.result(timeout=self.interrupt_check_ival)
future.result(timeout=self.interrupt_check_ival)
break
except (
# TODO replace with built-in TimeoutError once 3.10 support dropped.
concurrent.futures.TimeoutError
):
# NOTE: TimeoutError caught here could be from future.result() timing out (i.e. future not done yet),
# or a TimeoutError raised inside the future itself (future is done).
if not future.done():
# future not done, so .result() must've timed out. continue to wait again.
continue

# If execution reaches here, it means a TimeoutError was caught above, and the future is done.
# There are 3 possibilities here:
# 1. TimeoutError was raised inside the future. This will re-raise it.
# 2. First .result() call timed out, but the future completed by the time .done() was called.
# a) This will return the future's result, or
# b) raise the Exception that happened in the future.
return future.result(timeout=0)
# This is outside the 'except' block to avoid any
# "During handling of the above exception, another exception occurred" messages.
# zero timeout used because future is now done.

continue
except Interrupt as e:
# Asynchronously raised from another thread: cancel the
# future.
Expand All @@ -179,3 +171,15 @@ async def wrapped_coro() -> R:
if not done.wait(timeout=1.0):
raise RuntimeError("Timed out while waiting for coroutine.") from e
raise

exc_value: BaseException | None
if exception_container:
exc_value, exc_tb = exception_container[0]
exception_container.clear()
try:
raise exc_value.with_traceback(exc_tb)
finally:
exc_value = None
exc_tb = None

return result_container[0]
50 changes: 49 additions & 1 deletion tests/middleware/test_asyncio.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
from __future__ import annotations

import asyncio
import logging
import weakref
from threading import get_ident
from unittest import mock

Expand All @@ -13,6 +15,7 @@
get_event_loop_thread,
set_event_loop_thread,
)
from dramatiq.brokers.stub import StubBroker
from dramatiq.logging import get_logger
from dramatiq.middleware import CurrentMessage
from dramatiq.middleware.asyncio import AsyncIO
Expand Down Expand Up @@ -82,7 +85,9 @@ async def raise_error():
assert e.traceback[-1].name == "raise_actual_error"


def test_event_loop_thread_run_coroutine_timeout_exception(started_thread: EventLoopThread):
def test_event_loop_thread_run_coroutine_timeout_exception(
started_thread: EventLoopThread,
):
"""Test that TimeoutError in coroutine doesn't lead to infinite loop.

Regression test for https://github.com/Bogdanp/dramatiq/issues/791
Expand All @@ -100,6 +105,49 @@ async def raise_error():
started_thread.run_coroutine(coro)


def test_run_coroutine_exception_doesnt_leak(stub_broker: StubBroker, caplog: pytest.LogCaptureFixture):
    """Check that a payload attached to an async actor's exception is freed.

    An async actor raises an exception that carries a ``Payload`` object and
    records a weak reference to that payload on every attempt.  After the
    worker has processed the message and the broker has been flushed, every
    recorded weak reference must be dead.

    No ``gc.collect()`` call is made here, so the check only passes when the
    payloads are released without a garbage collection pass.
    """
    # Disable log capturing to prevent pytest from holding traceback references
    # via captured LogRecord.exc_info
    caplog.set_level(logging.CRITICAL)

    stub_broker.add_middleware(AsyncIO())

    class Payload:
        # "__weakref__" must be listed: a class that defines __slots__ does
        # not support weak references unless it declares this slot, and
        # weakref.ref(payload) in the actor would raise TypeError without it.
        __slots__ = ("data", "__weakref__")

        def __init__(self, data: bytes) -> None:
            self.data = data

    class PayloadError(Exception):
        """Exception that keeps a strong reference to its payload."""

        def __init__(self, payload: Payload) -> None:
            self.payload = payload
            super().__init__(payload)

    # One weak reference per actor invocation.
    weak_refs: list[weakref.ref[Payload]] = []

    @actor(max_retries=1, max_backoff=1)
    async def failing_actor():
        payload = Payload(b"x" * 1024)
        weak_refs.append(weakref.ref(payload))
        raise PayloadError(payload)

    with worker(stub_broker, worker_timeout=100, worker_threads=1) as stub_worker:
        failing_actor.send()
        stub_broker.join(failing_actor.queue_name, fail_fast=False)
        stub_worker.join()

    # StubBroker overrides MessageProxy.clear_exception to do nothing.
    # Simulate here the normal behavior of clearing exception references.
    for message in stub_broker.dead_letters:
        del message._exception

    stub_broker.flush_all()

    # Guard against a vacuous pass: if the actor never got as far as
    # recording a payload, the loop below would assert nothing.
    assert weak_refs, "Actor did not record any payload"

    # Check that payloads were collected
    for weak_ref in weak_refs:
        assert weak_ref() is None, "Payload object still alive"


@pytest.mark.skipif(
threading.current_platform not in threading.supported_platforms,
reason="Threading not supported on this platform.",
Expand Down
Loading