|
| 1 | +from dataclasses import dataclass |
1 | 2 | from typing import TypeVar
|
2 | 3 |
|
3 | 4 | from celery import shared_task
|
4 | 5 | from celery.app import default_app
|
5 | 6 | from celery.local import Proxy as CeleryTaskProxy
|
| 7 | +from django.utils import timezone |
6 | 8 | from typing_extensions import ParamSpec
|
7 | 9 |
|
8 | 10 | from django_tasks.backends.base import BaseTaskBackend
|
9 |
| -from django_tasks.task import Task, TaskResult |
| 11 | +from django_tasks.task import ResultStatus, TaskResult |
| 12 | +from django_tasks.task import Task as BaseTask |
| 13 | +from django_tasks.utils import json_normalize |
10 | 14 |
|
11 | 15 | if not default_app:
|
12 | 16 | from django_tasks.backends.celery.app import app as celery_app
|
| 17 | + |
13 | 18 | celery_app.set_default()
|
14 | 19 |
|
15 | 20 |
|
16 | 21 | T = TypeVar("T")
|
17 | 22 | P = ParamSpec("P")
|
18 | 23 |
|
19 | 24 |
|
20 |
| -class CeleryTask(Task): |
21 |
| - |
22 |
| - celery_task: CeleryTaskProxy |
| 25 | +@dataclass |
| 26 | +class Task(BaseTask[P, T]): |
| 27 | + celery_task: CeleryTaskProxy = None |
23 | 28 | """Celery proxy to the task in the current celery app task registry."""
|
24 | 29 |
|
25 | 30 | def __post_init__(self) -> None:
|
26 |
| - # TODO: allow passing extra celery specific parameters? |
27 | 31 | celery_task = shared_task()(self.func)
|
28 | 32 | self.celery_task = celery_task
|
29 | 33 | return super().__post_init__()
|
30 | 34 |
|
31 | 35 |
|
32 | 36 | class CeleryBackend(BaseTaskBackend):
|
33 |
| - task_class = CeleryTask |
| 37 | + task_class = Task |
34 | 38 | supports_defer = True
|
35 | 39 |
|
36 | 40 | def enqueue(
|
37 |
| - self, task: Task[P, T], args: P.args, kwargs: P.kwargs |
| 41 | + self, |
| 42 | + task: Task[P, T], # type: ignore[override] |
| 43 | + args: P.args, |
| 44 | + kwargs: P.kwargs, |
38 | 45 | ) -> TaskResult[T]:
|
39 | 46 | self.validate_task(task)
|
40 | 47 |
|
41 |
| - apply_async_kwargs = { |
| 48 | + apply_async_kwargs: P.kwargs = { |
42 | 49 | "eta": task.run_after,
|
43 | 50 | }
|
44 | 51 | if task.queue_name:
|
45 | 52 | apply_async_kwargs["queue"] = task.queue_name
|
46 | 53 | if task.priority:
|
47 | 54 | apply_async_kwargs["priority"] = task.priority
|
48 |
| - task.celery_task.apply_async(args, kwargs=kwargs, **apply_async_kwargs) |
| 55 | + |
| 56 | + # TODO: a Celery result backend is required to get additional information |
| 57 | + async_result = task.celery_task.apply_async( |
| 58 | + args, kwargs=kwargs, **apply_async_kwargs |
| 59 | + ) |
| 60 | + task_result = TaskResult[T]( |
| 61 | + task=task, |
| 62 | + id=async_result.id, |
| 63 | + status=ResultStatus.NEW, |
| 64 | + enqueued_at=timezone.now(), |
| 65 | + finished_at=None, |
| 66 | + args=json_normalize(args), |
| 67 | + kwargs=json_normalize(kwargs), |
| 68 | + backend=self.alias, |
| 69 | + ) |
| 70 | + return task_result |
0 commit comments