|
18 | 18 |
|
19 | 19 | from user_tasks import user_task_stopped |
20 | 20 | from user_tasks.models import UserTaskStatus |
21 | | -from user_tasks.signals import start_user_task |
| 21 | +from user_tasks.signals import celery_app, create_user_task, start_user_task |
22 | 22 | from user_tasks.tasks import UserTask |
| 23 | +from user_tasks.utils import extract_proto2_embed, extract_proto2_headers, proto2_to_proto1 |
23 | 24 |
|
24 | 25 | User = auth.get_user_model() |
25 | 26 |
|
@@ -189,6 +190,31 @@ def test_non_user_task_success(self): |
189 | 190 | statuses = UserTaskStatus.objects.all() |
190 | 191 | assert not statuses |
191 | 192 |
|
| 193 | + def test_create_user_task_protocol_v2(self): |
| 194 | + """The create_user_task signal handler should work with Celery protocol version 2.""" |
| 195 | + |
| 196 | + original_protocol = getattr(celery_app.conf, 'task_protocol', 1) |
| 197 | + celery_app.conf.task_protocol = 2 |
| 198 | + try: |
| 199 | + body = ( |
| 200 | + [self.user.id, 'Argument'], |
| 201 | + {}, |
| 202 | + {'callbacks': [], 'errbacks': [], 'task_chain': None, 'chord': None} |
| 203 | + ) |
| 204 | + headers = { |
| 205 | + 'task_id': 'tid', 'retries': 0, 'eta': None, 'expires': None, |
| 206 | + 'group': None, 'timelimit': [None, None], 'task': 'test_signals.sample_task' |
| 207 | + } |
| 208 | + create_user_task(sender='test_signals.sample_task', body=body, headers=headers) |
| 209 | + statuses = UserTaskStatus.objects.all() |
| 210 | + assert len(statuses) == 1 |
| 211 | + status = statuses[0] |
| 212 | + assert status.task_class == 'test_signals.sample_task' |
| 213 | + assert status.user_id == self.user.id |
| 214 | + assert status.name == 'SampleTask: Argument' |
| 215 | + finally: |
| 216 | + celery_app.conf.task_protocol = original_protocol |
| 217 | + |
192 | 218 | def _create_user_task(self, eager): |
193 | 219 | """Create a task based on UserTaskMixin and verify some assertions about its corresponding status.""" |
194 | 220 | result = sample_task.delay(self.user.id, 'Argument') |
@@ -530,3 +556,68 @@ def test_connections_not_closed_when_we_cant_get_a_connection(self, mock_close_o |
530 | 556 | with mock.patch('user_tasks.signals.transaction.get_connection', side_effect=Exception): |
531 | 557 | start_user_task(sender=SampleTask) |
532 | 558 | assert mock_close_old_connections.called is False |
| 559 | + |
| 560 | + |
| 561 | +class TestUtils: |
| 562 | + """ |
| 563 | + Unit tests for utility functions in user_tasks/utils.py. |
| 564 | + """ |
| 565 | + |
| 566 | + def test_extract_proto2_headers(self): |
| 567 | + headers = extract_proto2_headers( |
| 568 | + task_id='abc123', retries=2, eta='2025-05-30T12:00:00', |
| 569 | + expires=None, group='group1', timelimit=[10, 20], |
| 570 | + task='my_task', extra='ignored') |
| 571 | + assert headers == { |
| 572 | + 'id': 'abc123', |
| 573 | + 'task': 'my_task', |
| 574 | + 'retries': 2, |
| 575 | + 'eta': '2025-05-30T12:00:00', |
| 576 | + 'expires': None, |
| 577 | + 'utc': True, |
| 578 | + 'taskset': 'group1', |
| 579 | + 'timelimit': [10, 20], |
| 580 | + } |
| 581 | + |
| 582 | + def test_extract_proto2_embed(self): |
| 583 | + embed = extract_proto2_embed( |
| 584 | + callbacks=['cb'], errbacks=['eb'], task_chain=['a', 'b'], |
| 585 | + chord='chord1', extra='ignored') |
| 586 | + assert embed == { |
| 587 | + 'callbacks': ['cb'], |
| 588 | + 'errbacks': ['eb'], |
| 589 | + 'chain': ['a', 'b'], |
| 590 | + 'chord': 'chord1', |
| 591 | + } |
| 592 | + embed = extract_proto2_embed() |
| 593 | + assert embed == { |
| 594 | + 'callbacks': [], |
| 595 | + 'errbacks': [], |
| 596 | + 'chain': None, |
| 597 | + 'chord': None |
| 598 | + } |
| 599 | + |
| 600 | + def test_proto2_to_proto1(self, monkeypatch): |
| 601 | + monkeypatch.setattr( |
| 602 | + 'user_tasks.utils.chain', |
| 603 | + lambda x: f'chain({x})' |
| 604 | + ) |
| 605 | + body = ( |
| 606 | + [1, 2], |
| 607 | + {'foo': 'bar'}, |
| 608 | + {'callbacks': ['cb'], 'errbacks': ['eb'], |
| 609 | + 'task_chain': ['a'], 'chord': 'ch'} |
| 610 | + ) |
| 611 | + headers = { |
| 612 | + 'task_id': 'tid', 'retries': 1, 'eta': 'eta', 'expires': 'exp', |
| 613 | + 'group': 'grp', 'timelimit': [1, 2], 'task': 't', |
| 614 | + 'extra': 'ignored' |
| 615 | + } |
| 616 | + result = proto2_to_proto1(body, headers) |
| 617 | + assert result['id'] == 'tid' |
| 618 | + assert result['args'] == [1, 2] |
| 619 | + assert result['kwargs'] == {'foo': 'bar'} |
| 620 | + assert result['callbacks'] == ['cb', "chain(['a'])"] |
| 621 | + assert result['errbacks'] == ['eb'] |
| 622 | + assert 'chain' not in result |
| 623 | + assert result['chord'] == 'ch' |
0 commit comments