Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
69 changes: 63 additions & 6 deletions docs/guide/testing-taskiq.md
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,7 @@ the same interface as a real broker, but it doesn't send tasks actually.
Let's define a task.

```python
from your_project.taskiq import broker
from your_project.tkq import broker

@broker.task
async def parse_int(val: str) -> int:
Expand All @@ -107,7 +107,7 @@ And that's it. Test should pass.
What if you want to test a function that uses task. Let's define such function.

```python
from your_project.taskiq import broker
from your_project.tkq import broker

@broker.task
async def parse_int(val: str) -> int:
Expand All @@ -129,6 +129,63 @@ async def test_add_one():
assert await parse_and_add_one("11") == 12
```

### Unawaitable tasks

When a function calls an asynchronous task but doesn't await its result,
it can be challenging to test.

In such cases, the `InMemoryBroker` provides two convenient ways to help you:
the `await_inplace` constructor parameter and the `wait_all` method.

Consider the following example where we define a task and a function that calls it:

```python
from your_project.tkq import broker

@broker.task
async def parse_int(val: str) -> int:
return int(val)


async def parse_int_later(val: str) -> int:
await parse_int.kiq(val)
return 1
```

To test this function, we can do two things:

1. By setting the `await_inplace=True` parameter when creating the broker.
In that case all tasks will be automatically awaited as soon as they are called.
In such a way you don't need to manually call the `wait_result` in your code.

To set it up, define the broker as the following:

```python
...
broker = InMemoryBroker(await_inplace=True)
...

```

With this setup all `await function.kiq()` calls will behave similarly to `await function()`, but
with dependency injection and all taskiq-related functionality.

2. Alternatively, you can manually await all tasks after invoking the
target function by using the `wait_all` method.
This gives you more control over when to wait for tasks to complete.

```python
from your_project.tkq import broker

@pytest.mark.anyio
async def test_add_one():
# Call the function that triggers the async task
assert await parse_int_later("11") == 1
await broker.wait_all() # Waits for all tasks to complete
# At that time we can guarantee that all sent tasks
# have been completed and do all the assertions.
```

## Dependency injection

If you use dependencies in your tasks, you may think that this can become a problem. But it's not.
Expand All @@ -146,7 +203,7 @@ from typing import Annotated
from pathlib import Path
from taskiq import TaskiqDepends

from your_project.taskiq import broker
from your_project.tkq import broker


@broker.task
Expand All @@ -161,7 +218,7 @@ async def modify_path(some_path: Annotated[Path, TaskiqDepends()]):
from pathlib import Path
from taskiq import TaskiqDepends

from your_project.taskiq import broker
from your_project.tkq import broker


@broker.task
Expand All @@ -177,7 +234,7 @@ expected dependencies manually as function's arguments or key-word arguments.

```python
import pytest
from your_project.taskiq import broker
from your_project.tkq import broker

from pathlib import Path

Expand All @@ -193,7 +250,7 @@ must mutate dependency_context before calling a task. We suggest to do it in fix

```python
import pytest
from your_project.taskiq import broker
from your_project.tkq import broker
from pathlib import Path


Expand Down
20 changes: 19 additions & 1 deletion taskiq/brokers/inmemory_broker.py
Original file line number Diff line number Diff line change
Expand Up @@ -127,6 +127,7 @@ def __init__(
cast_types: bool = True,
max_async_tasks: int = 30,
propagate_exceptions: bool = True,
await_inplace: bool = False,
) -> None:
super().__init__()
self.result_backend = InmemoryResultBackend(
Expand All @@ -140,6 +141,7 @@ def __init__(
max_async_tasks=max_async_tasks,
propagate_exceptions=propagate_exceptions,
)
self.await_inplace = await_inplace
self._running_tasks: "Set[asyncio.Task[Any]]" = set()

async def kick(self, message: BrokerMessage) -> None:
Expand All @@ -156,7 +158,12 @@ async def kick(self, message: BrokerMessage) -> None:
if target_task is None:
raise TaskiqError("Unknown task.")

task = asyncio.create_task(self.receiver.callback(message=message.message))
receiver_cb = self.receiver.callback(message=message.message)
if self.await_inplace:
await receiver_cb
return

task = asyncio.create_task(receiver_cb)
self._running_tasks.add(task)
task.add_done_callback(self._running_tasks.discard)

Expand All @@ -171,6 +178,17 @@ def listen(self) -> AsyncGenerator[bytes, None]:
"""
raise RuntimeError("Inmemory brokers cannot listen.")

async def wait_all(self) -> None:
"""
Wait for all currently running tasks to complete.

Useful when used in testing and you need to await all sent tasks
before asserting results.
"""
to_await = list(self._running_tasks)
for task in to_await:
await task

async def startup(self) -> None:
"""Runs startup events for client and worker side."""
for event in (TaskiqEvents.CLIENT_STARTUP, TaskiqEvents.WORKER_STARTUP):
Expand Down
36 changes: 36 additions & 0 deletions tests/brokers/test_inmemory.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,3 +85,39 @@ async def test_task() -> str:

result = await task.wait_result()
assert result.return_value == test_value


@pytest.mark.anyio
async def test_inline_awaits() -> None:
broker = InMemoryBroker(await_inplace=True)
slept = False

@broker.task
async def test_task() -> None:
nonlocal slept
await asyncio.sleep(0.2)
slept = True

task = await test_task.kiq()
assert slept
assert await task.is_ready()
assert not broker._running_tasks


@pytest.mark.anyio
async def test_wait_all() -> None:
broker = InMemoryBroker()
slept = False

@broker.task
async def test_task() -> None:
nonlocal slept
await asyncio.sleep(0.2)
slept = True

task = await test_task.kiq()
assert not slept
await broker.wait_all()
assert slept
assert await task.is_ready()
assert not broker._running_tasks
Loading