Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
70 changes: 68 additions & 2 deletions pytest_asyncio/plugin.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
from collections.abc import (
AsyncIterator,
Awaitable,
Coroutine as CoroutineT,
Generator,
Iterable,
Iterator,
Expand Down Expand Up @@ -276,6 +277,69 @@ def _fixture_synchronizer(
AsyncGenFixtureYieldType = TypeVar("AsyncGenFixtureYieldType")


def _create_task_in_context(
coro: CoroutineT[Any, Any, Any],
loop: AbstractEventLoop,
context: contextvars.Context,
) -> asyncio.Task[Any]:
if sys.version_info >= (3, 11):
return loop.create_task(coro, context=context)

from backports.asyncio.runner._patch import _patch_object
from backports.asyncio.runner.tasks import Task

with (
_patch_object(asyncio.tasks, asyncio.tasks.Task.__name__, Task),
_patch_object(contextvars, contextvars.copy_context.__name__, lambda: context),
):
return loop.create_task(coro)


class _FixtureRunner:
def __init__(self, loop: AbstractEventLoop, context: contextvars.Context) -> None:
self.loop = loop
self.queue: asyncio.Queue[tuple[Awaitable[Any], asyncio.Future[Any]] | None] = (
asyncio.Queue()
)
self._context = context
self._task = None

async def _worker(self) -> None:
while True:
item = await self.queue.get()
if item is None:
break
coro, future = item
try:
retval = await coro
future.set_result(retval)
except Exception as exc:
future.set_exception(exc)

def run(self, func):
return self.loop.run_until_complete(self._run(func))

async def _run(self, func):
if self._task is None:
self._task = _create_task_in_context(
self._worker(), loop=self.loop, context=self._context
)

coro = func()
future = self.loop.create_future()
self.queue.put_nowait((coro, future))
return await future

async def _stop(self):
self.queue.put_nowait(None)
if self._task is not None:
await self._task
self._task = None

def stop(self) -> None:
self.loop.run_until_complete(self._stop())


def _wrap_asyncgen_fixture(
fixture_function: Callable[
AsyncGenFixtureParams, AsyncGeneratorType[AsyncGenFixtureYieldType, Any]
Expand All @@ -295,7 +359,8 @@ async def setup():
return res

context = contextvars.copy_context()
result = runner.run(setup(), context=context)
fixture_runner = _FixtureRunner(loop=runner.get_loop(), context=context)
result = fixture_runner.run(setup)

reset_contextvars = _apply_contextvar_changes(context)

Expand All @@ -312,7 +377,8 @@ async def async_finalizer() -> None:
msg += "Yield only once."
raise ValueError(msg)

runner.run(async_finalizer(), context=context)
fixture_runner.run(async_finalizer)
fixture_runner.stop()
if reset_contextvars is not None:
reset_contextvars()

Expand Down
11 changes: 11 additions & 0 deletions tests/async_fixtures/test_async_fixtures.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,3 +37,14 @@ async def async_fixture_method(self):
@pytest.mark.asyncio
async def test_async_fixture_method(self):
assert self.is_same_instance


@pytest.fixture()
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The test file you found uses the "old" way of writing tests directly in a module. In the past, these kinds of tests caused side effects (e.g. failing to clean up resources resulting in a ResourceWarning) that only appeared much later in the test suite. Therefore, we moved to using Pytester for all pytest-asyncio tests.

I'd appreciate, if you rewrote the test to use Pytester, even though it's mostly boilerplate. You can check the the file test_function_scope.py for an example.

Although it's not strictly necessary, I suggest to use strict mode in those tests (or both, if there's different behavior depending on the mode).

async def setup_and_teardown_tasks():
task = asyncio.current_task()
yield
assert task is asyncio.current_task()


async def test_setup_and_teardown_tasks(setup_and_teardown_tasks):
pass
Loading