Skip to content

Commit deee9ea

Browse files
committed
Support enqueue_on_commit option
1 parent 161b6ac commit deee9ea

File tree

2 files changed

+116
-30
lines changed

2 files changed

+116
-30
lines changed

django_tasks/backends/celery/backend.py

Lines changed: 25 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,18 +1,19 @@
1-
from dataclasses import dataclass
1+
from functools import partial
22
from typing import Any, Iterable, TypeVar
33

44
from celery import shared_task
55
from celery.app import default_app
66
from celery.local import Proxy as CeleryTaskProxy
77
from django.apps import apps
88
from django.core.checks import ERROR, CheckMessage
9+
from django.db import transaction
910
from django.utils import timezone
11+
from kombu.utils.uuid import uuid
1012
from typing_extensions import ParamSpec
1113

1214
from django_tasks.backends.base import BaseTaskBackend
1315
from django_tasks.task import MAX_PRIORITY, MIN_PRIORITY, ResultStatus, TaskResult
1416
from django_tasks.task import Task as BaseTask
15-
from django_tasks.utils import json_normalize
1617

1718
if not default_app:
1819
from django_tasks.backends.celery.app import app as celery_app
@@ -44,7 +45,6 @@ def _map_priority(value: int) -> int:
4445
return mapped_value
4546

4647

47-
@dataclass
4848
class Task(BaseTask[P, T]):
4949
celery_task: CeleryTaskProxy = None
5050
"""Celery proxy to the task in the current celery app task registry."""
@@ -77,19 +77,35 @@ def enqueue(
7777
priority = _map_priority(task.priority)
7878
apply_async_kwargs["priority"] = priority
7979

80+
task_id = uuid()
81+
apply_async_kwargs["task_id"] = task_id
82+
83+
if self._get_enqueue_on_commit_for_task(task):
84+
transaction.on_commit(
85+
partial(
86+
task.celery_task.apply_async,
87+
args,
88+
kwargs=kwargs,
89+
**apply_async_kwargs,
90+
)
91+
)
92+
else:
93+
task.celery_task.apply_async(args, kwargs=kwargs, **apply_async_kwargs)
94+
95+
# TODO: send task_enqueued signal
96+
# TODO: link a task to trigger the task_finished signal?
97+
# TODO: consider using DBTaskResult for results?
98+
8099
# TODO: a Celery result backend is required to get additional information
81-
async_result = task.celery_task.apply_async(
82-
args, kwargs=kwargs, **apply_async_kwargs
83-
)
84100
task_result = TaskResult[T](
85101
task=task,
86-
id=async_result.id,
102+
id=task_id,
87103
status=ResultStatus.NEW,
88104
enqueued_at=timezone.now(),
89105
started_at=None,
90106
finished_at=None,
91-
args=json_normalize(args),
92-
kwargs=json_normalize(kwargs),
107+
args=args,
108+
kwargs=kwargs,
93109
backend=self.alias,
94110
)
95111
return task_result

tests/tests/test_celery_backend.py

Lines changed: 91 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,8 @@
33

44
from celery import Celery
55
from celery.result import AsyncResult
6-
from django.test import TestCase, override_settings
6+
from django.db import transaction
7+
from django.test import TestCase, TransactionTestCase, override_settings
78
from django.utils import timezone
89

910
from django_tasks import ResultStatus, default_task_backend, task, tasks
@@ -16,6 +17,10 @@ def noop_task(*args: tuple, **kwargs: dict) -> None:
1617
return None
1718

1819

20+
def enqueue_on_commit_task(*args: tuple, **kwargs: dict) -> None:
21+
pass
22+
23+
1924
@override_settings(
2025
TASKS={
2126
"default": {
@@ -24,10 +29,13 @@ def noop_task(*args: tuple, **kwargs: dict) -> None:
2429
}
2530
}
2631
)
27-
class CeleryBackendTestCase(TestCase):
32+
class CeleryBackendTestCase(TransactionTestCase):
2833
def setUp(self) -> None:
2934
# register task during setup so it is registered as a Celery task
3035
self.task = task()(noop_task)
36+
self.enqueue_on_commit_task = task(enqueue_on_commit=True)(
37+
enqueue_on_commit_task
38+
)
3139

3240
def test_using_correct_backend(self) -> None:
3341
self.assertEqual(default_task_backend, tasks["default"])
@@ -43,7 +51,7 @@ def test_celery_backend_app_missing(self) -> None:
4351
errors = list(default_task_backend.check())
4452

4553
self.assertEqual(len(errors), 1)
46-
self.assertIn("django_tasks.backends.celery", errors[0].hint)
54+
self.assertIn("django_tasks.backends.celery", errors[0].hint) # type:ignore[arg-type]
4755

4856
def test_enqueue_task(self) -> None:
4957
task = self.task
@@ -53,52 +61,114 @@ def test_enqueue_task(self) -> None:
5361
from django_tasks.backends.celery.app import app as celery_app
5462

5563
self.assertEqual(task.celery_task.app, celery_app) # type: ignore[attr-defined]
56-
with patch("celery.app.task.Task.apply_async") as mock_apply_async:
57-
mock_apply_async.return_value = AsyncResult(id="123")
58-
result = default_task_backend.enqueue(task, (1,), {"two": 3})
64+
task_id = "123"
65+
with patch("django_tasks.backends.celery.backend.uuid", return_value=task_id):
66+
with patch("celery.app.task.Task.apply_async") as mock_apply_async:
67+
mock_apply_async.return_value = AsyncResult(id=task_id)
68+
result = default_task_backend.enqueue(task, (1,), {"two": 3})
5969

60-
self.assertEqual(result.id, "123")
70+
self.assertEqual(result.id, task_id)
6171
self.assertEqual(result.status, ResultStatus.NEW)
6272
self.assertIsNone(result.started_at)
6373
self.assertIsNone(result.finished_at)
6474
with self.assertRaisesMessage(ValueError, "Task has not finished yet"):
65-
result.result # noqa:B018
75+
result.return_value # noqa:B018
6676
self.assertEqual(result.task, task)
67-
self.assertEqual(result.args, [1])
77+
self.assertEqual(result.args, (1,))
6878
self.assertEqual(result.kwargs, {"two": 3})
6979
expected_priority = _map_priority(DEFAULT_PRIORITY)
7080
mock_apply_async.assert_called_once_with(
7181
(1,),
7282
kwargs={"two": 3},
83+
task_id=task_id,
7384
eta=None,
7485
priority=expected_priority,
7586
queue=DEFAULT_QUEUE_NAME,
7687
)
7788

7889
def test_using_additional_params(self) -> None:
79-
with patch("celery.app.task.Task.apply_async") as mock_apply_async:
80-
mock_apply_async.return_value = AsyncResult(id="123")
81-
run_after = timezone.now() + timedelta(hours=10)
82-
result = self.task.using(
83-
run_after=run_after, priority=75, queue_name="queue-1"
84-
).enqueue()
90+
task_id = "123"
91+
with patch("django_tasks.backends.celery.backend.uuid", return_value=task_id):
92+
with patch("celery.app.task.Task.apply_async") as mock_apply_async:
93+
mock_apply_async.return_value = AsyncResult(id=task_id)
94+
run_after = timezone.now() + timedelta(hours=10)
95+
result = self.task.using(
96+
run_after=run_after, priority=75, queue_name="queue-1"
97+
).enqueue()
8598

86-
self.assertEqual(result.id, "123")
99+
self.assertEqual(result.id, task_id)
87100
self.assertEqual(result.status, ResultStatus.NEW)
88101
mock_apply_async.assert_called_once_with(
89-
(), kwargs={}, eta=run_after, priority=7, queue="queue-1"
102+
[], kwargs={}, task_id=task_id, eta=run_after, priority=7, queue="queue-1"
90103
)
91104

92105
def test_priority_mapping(self) -> None:
93106
for priority, expected in [(-100, 0), (-50, 2), (0, 4), (75, 7), (100, 9)]:
94-
with patch("celery.app.task.Task.apply_async") as mock_apply_async:
95-
mock_apply_async.return_value = AsyncResult(id="123")
96-
self.task.using(priority=priority).enqueue()
107+
task_id = "123"
108+
with patch(
109+
"django_tasks.backends.celery.backend.uuid", return_value=task_id
110+
):
111+
with patch("celery.app.task.Task.apply_async") as mock_apply_async:
112+
mock_apply_async.return_value = AsyncResult(id=task_id)
113+
self.task.using(priority=priority).enqueue()
97114

98115
mock_apply_async.assert_called_with(
99-
(), kwargs={}, eta=None, priority=expected, queue=DEFAULT_QUEUE_NAME
116+
[],
117+
kwargs={},
118+
task_id=task_id,
119+
eta=None,
120+
priority=expected,
121+
queue=DEFAULT_QUEUE_NAME,
100122
)
101123

124+
@override_settings(
125+
TASKS={
126+
"default": {
127+
"BACKEND": "django_tasks.backends.celery.CeleryBackend",
128+
"ENQUEUE_ON_COMMIT": True,
129+
}
130+
}
131+
)
132+
def test_wait_until_transaction_commit(self) -> None:
133+
self.assertTrue(default_task_backend.enqueue_on_commit)
134+
self.assertTrue(default_task_backend._get_enqueue_on_commit_for_task(self.task))
135+
136+
with patch("celery.app.task.Task.apply_async") as mock_apply_async:
137+
mock_apply_async.return_value = AsyncResult(id="task_id")
138+
with transaction.atomic():
139+
self.task.enqueue()
140+
assert not mock_apply_async.called
141+
142+
mock_apply_async.assert_called_once()
143+
144+
@override_settings(
145+
TASKS={
146+
"default": {
147+
"BACKEND": "django_tasks.backends.celery.CeleryBackend",
148+
}
149+
}
150+
)
151+
def test_wait_until_transaction_by_default(self) -> None:
152+
self.assertTrue(default_task_backend.enqueue_on_commit)
153+
self.assertTrue(default_task_backend._get_enqueue_on_commit_for_task(self.task))
154+
155+
@override_settings(
156+
TASKS={
157+
"default": {
158+
"BACKEND": "django_tasks.backends.celery.CeleryBackend",
159+
"ENQUEUE_ON_COMMIT": False,
160+
}
161+
}
162+
)
163+
def test_task_specific_enqueue_on_commit(self) -> None:
164+
self.assertFalse(default_task_backend.enqueue_on_commit)
165+
self.assertTrue(self.enqueue_on_commit_task.enqueue_on_commit)
166+
self.assertTrue(
167+
default_task_backend._get_enqueue_on_commit_for_task(
168+
self.enqueue_on_commit_task
169+
)
170+
)
171+
102172

103173
@override_settings(
104174
TASKS={

0 commit comments

Comments
 (0)