Skip to content

Commit 3142282

Browse files
authored
fix(signals): support Celery protocol v2 in create_user_task (#422)
* fix(signals): support Celery protocol v2 in create_user_task
1 parent 55075f0 commit 3142282

File tree

3 files changed

+172
-20
lines changed

3 files changed

+172
-20
lines changed

tests/test_signals.py

Lines changed: 92 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,8 +18,9 @@
1818

1919
from user_tasks import user_task_stopped
2020
from user_tasks.models import UserTaskStatus
21-
from user_tasks.signals import start_user_task
21+
from user_tasks.signals import celery_app, create_user_task, start_user_task
2222
from user_tasks.tasks import UserTask
23+
from user_tasks.utils import extract_proto2_embed, extract_proto2_headers, proto2_to_proto1
2324

2425
User = auth.get_user_model()
2526

@@ -189,6 +190,31 @@ def test_non_user_task_success(self):
189190
statuses = UserTaskStatus.objects.all()
190191
assert not statuses
191192

193+
def test_create_user_task_protocol_v2(self):
194+
"""The create_user_task signal handler should work with Celery protocol version 2."""
195+
196+
original_protocol = getattr(celery_app.conf, 'task_protocol', 1)
197+
celery_app.conf.task_protocol = 2
198+
try:
199+
body = (
200+
[self.user.id, 'Argument'],
201+
{},
202+
{'callbacks': [], 'errbacks': [], 'task_chain': None, 'chord': None}
203+
)
204+
headers = {
205+
'task_id': 'tid', 'retries': 0, 'eta': None, 'expires': None,
206+
'group': None, 'timelimit': [None, None], 'task': 'test_signals.sample_task'
207+
}
208+
create_user_task(sender='test_signals.sample_task', body=body, headers=headers)
209+
statuses = UserTaskStatus.objects.all()
210+
assert len(statuses) == 1
211+
status = statuses[0]
212+
assert status.task_class == 'test_signals.sample_task'
213+
assert status.user_id == self.user.id
214+
assert status.name == 'SampleTask: Argument'
215+
finally:
216+
celery_app.conf.task_protocol = original_protocol
217+
192218
def _create_user_task(self, eager):
193219
"""Create a task based on UserTaskMixin and verify some assertions about its corresponding status."""
194220
result = sample_task.delay(self.user.id, 'Argument')
@@ -530,3 +556,68 @@ def test_connections_not_closed_when_we_cant_get_a_connection(self, mock_close_o
530556
with mock.patch('user_tasks.signals.transaction.get_connection', side_effect=Exception):
531557
start_user_task(sender=SampleTask)
532558
assert mock_close_old_connections.called is False
559+
560+
561+
class TestUtils:
562+
"""
563+
Unit tests for utility functions in user_tasks/utils.py.
564+
"""
565+
566+
def test_extract_proto2_headers(self):
567+
headers = extract_proto2_headers(
568+
task_id='abc123', retries=2, eta='2025-05-30T12:00:00',
569+
expires=None, group='group1', timelimit=[10, 20],
570+
task='my_task', extra='ignored')
571+
assert headers == {
572+
'id': 'abc123',
573+
'task': 'my_task',
574+
'retries': 2,
575+
'eta': '2025-05-30T12:00:00',
576+
'expires': None,
577+
'utc': True,
578+
'taskset': 'group1',
579+
'timelimit': [10, 20],
580+
}
581+
582+
def test_extract_proto2_embed(self):
583+
embed = extract_proto2_embed(
584+
callbacks=['cb'], errbacks=['eb'], task_chain=['a', 'b'],
585+
chord='chord1', extra='ignored')
586+
assert embed == {
587+
'callbacks': ['cb'],
588+
'errbacks': ['eb'],
589+
'chain': ['a', 'b'],
590+
'chord': 'chord1',
591+
}
592+
embed = extract_proto2_embed()
593+
assert embed == {
594+
'callbacks': [],
595+
'errbacks': [],
596+
'chain': None,
597+
'chord': None
598+
}
599+
600+
def test_proto2_to_proto1(self, monkeypatch):
601+
monkeypatch.setattr(
602+
'user_tasks.utils.chain',
603+
lambda x: f'chain({x})'
604+
)
605+
body = (
606+
[1, 2],
607+
{'foo': 'bar'},
608+
{'callbacks': ['cb'], 'errbacks': ['eb'],
609+
'task_chain': ['a'], 'chord': 'ch'}
610+
)
611+
headers = {
612+
'task_id': 'tid', 'retries': 1, 'eta': 'eta', 'expires': 'exp',
613+
'group': 'grp', 'timelimit': [1, 2], 'task': 't',
614+
'extra': 'ignored'
615+
}
616+
result = proto2_to_proto1(body, headers)
617+
assert result['id'] == 'tid'
618+
assert result['args'] == [1, 2]
619+
assert result['kwargs'] == {'foo': 'bar'}
620+
assert result['callbacks'] == ['cb', "chain(['a'])"]
621+
assert result['errbacks'] == ['eb']
622+
assert 'chain' not in result
623+
assert result['chord'] == 'ch'

user_tasks/signals.py

Lines changed: 29 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
import logging
66
from uuid import uuid4
77

8+
from celery import current_app as celery_app
89
from celery.signals import before_task_publish, task_failure, task_prerun, task_retry, task_success
910

1011
from django.contrib.auth import get_user_model
@@ -16,40 +17,49 @@
1617
from .exceptions import TaskCanceledException
1718
from .models import UserTaskStatus
1819
from .tasks import UserTaskMixin
20+
from .utils import proto2_to_proto1
1921

2022
LOGGER = logging.getLogger(__name__)
2123

2224

2325
@before_task_publish.connect
24-
def create_user_task(sender=None, body=None, **kwargs):
26+
def create_user_task(sender=None, body=None, headers=None, **kwargs):
2527
"""
2628
Create a :py:class:`UserTaskStatus` record for each :py:class:`UserTaskMixin`.
2729
2830
Also creates a :py:class:`UserTaskStatus` for each chain, chord, or group containing
2931
the new :py:class:`UserTaskMixin`.
32+
33+
Supports Celery protocol v1 and v2.
3034
"""
3135
try:
3236
task_class = import_string(sender)
3337
except ImportError:
3438
return
35-
if issubclass(task_class.__class__, UserTaskMixin):
36-
arguments_dict = task_class.arguments_as_dict(*body['args'], **body['kwargs'])
37-
user_id = _get_user_id(arguments_dict)
38-
task_id = body['id']
39-
if body.get('callbacks', []):
40-
_create_chain_entry(user_id, task_id, task_class, body['args'], body['kwargs'], body['callbacks'])
41-
return
42-
if body.get('chord', None):
43-
_create_chord_entry(task_id, task_class, body, user_id)
44-
return
45-
parent = _get_or_create_group_parent(body, user_id)
46-
name = task_class.generate_name(arguments_dict)
47-
total_steps = task_class.calculate_total_steps(arguments_dict)
48-
UserTaskStatus.objects.get_or_create(
49-
task_id=task_id, defaults={'user_id': user_id, 'parent': parent, 'name': name, 'task_class': sender,
50-
'total_steps': total_steps})
51-
if parent:
52-
parent.increment_total_steps(total_steps)
39+
40+
if celery_app.conf.task_protocol == 2 and isinstance(body, tuple):
41+
body = proto2_to_proto1(body, headers or {})
42+
43+
if not issubclass(task_class.__class__, UserTaskMixin):
44+
return
45+
46+
arguments_dict = task_class.arguments_as_dict(*body['args'], **body['kwargs'])
47+
user_id = _get_user_id(arguments_dict)
48+
task_id = body['id']
49+
if body.get('callbacks', []):
50+
_create_chain_entry(user_id, task_id, task_class, body['args'], body['kwargs'], body['callbacks'])
51+
return
52+
if body.get('chord', None):
53+
_create_chord_entry(task_id, task_class, body, user_id)
54+
return
55+
parent = _get_or_create_group_parent(body, user_id)
56+
name = task_class.generate_name(arguments_dict)
57+
total_steps = task_class.calculate_total_steps(arguments_dict)
58+
UserTaskStatus.objects.get_or_create(
59+
task_id=task_id, defaults={'user_id': user_id, 'parent': parent, 'name': name, 'task_class': sender,
60+
'total_steps': total_steps})
61+
if parent:
62+
parent.increment_total_steps(total_steps)
5363

5464

5565
def _create_chain_entry(user_id, task_id, task_class, args, kwargs, callbacks, parent=None):

user_tasks/utils.py

Lines changed: 51 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,51 @@
1+
"""
2+
Utility functions for handling Celery task protocol compatibility.
3+
"""
4+
5+
from celery import chain
6+
7+
8+
def proto2_to_proto1(body, headers):
9+
"""
10+
Convert a protocol v2 task body and headers to protocol v1 format.
11+
"""
12+
args, kwargs, embed = body
13+
embedded = extract_proto2_embed(**embed)
14+
chained = embedded.pop("chain", None)
15+
new_body = dict(
16+
extract_proto2_headers(**headers),
17+
args=args,
18+
kwargs=kwargs,
19+
**embedded,
20+
)
21+
if chained:
22+
new_body["callbacks"].append(chain(chained))
23+
return new_body
24+
25+
26+
def extract_proto2_headers(task_id, retries, eta, expires, group, timelimit, task, **_):
27+
"""
28+
Extract relevant headers from protocol v2 format.
29+
"""
30+
return {
31+
"id": task_id,
32+
"task": task,
33+
"retries": retries,
34+
"eta": eta,
35+
"expires": expires,
36+
"utc": True,
37+
"taskset": group,
38+
"timelimit": timelimit,
39+
}
40+
41+
42+
def extract_proto2_embed(callbacks=None, errbacks=None, task_chain=None, chord=None, **_):
43+
"""
44+
Extract embedded task metadata.
45+
"""
46+
return {
47+
"callbacks": callbacks or [],
48+
"errbacks": errbacks or [],
49+
"chain": task_chain,
50+
"chord": chord,
51+
}

0 commit comments

Comments
 (0)