Skip to content

Commit 02a686c

Browse files
authored
refactor: Enqueue exceptions in ActiveTask (a2aproject#1053)
# Description

Remove `ActiveTask._exception`. Make exception handling fully synchronized in the queues.

Fixes a2aproject#1032 🦕
1 parent 4e01a91 commit 02a686c

4 files changed

Lines changed: 100 additions & 86 deletions

File tree

src/a2a/server/agent_execution/active_task.py

Lines changed: 66 additions & 59 deletions
Original file line number | Diff line number | Diff line change
@@ -104,36 +104,60 @@ async def run(self) -> None:
104104
)
105105
except Exception as e:
106106
logger.exception('Consumer[%s]: Failed', self.active_task._task_id)
107-
async with self.active_task._lock:
108-
await self.active_task._mark_task_as_failed(e)
107+
108+
updated_task = None
109+
task = await self.active_task._task_manager.get_task()
110+
if task:
111+
handled_event = TaskStatusUpdateEvent(
112+
task_id=task.id,
113+
context_id=task.context_id,
114+
status=TaskStatus(
115+
state=TaskState.TASK_STATE_FAILED,
116+
),
117+
)
118+
updated_task = await self._handle_task_event(handled_event)
119+
120+
await self._enqueue_to_subscribers(cast('Event', e), updated_task)
109121

110122
async def _process_event(self, event: Event) -> None:
111123
updated_task = None
124+
handled_event: (
125+
Task
126+
| TaskStatusUpdateEvent
127+
| TaskArtifactUpdateEvent
128+
| PushNotificationEvent
129+
| None
130+
) = None
131+
132+
if isinstance(event, _RequestCompleted):
133+
logger.debug(
134+
'Consumer[%s]: Request completed', self.active_task._task_id
135+
)
136+
self.active_task._request_lock.release()
137+
elif isinstance(event, _RequestStarted):
138+
logger.debug(
139+
'Consumer[%s]: Request started', self.active_task._task_id
140+
)
141+
self.message_to_save = event.request_context.message
142+
elif isinstance(event, BaseException):
143+
raise event
144+
elif isinstance(event, Message):
145+
self._handle_message_event(event)
146+
elif isinstance(
147+
event,
148+
TaskStatusUpdateEvent
149+
| TaskArtifactUpdateEvent
150+
| PushNotificationEvent
151+
| Task,
152+
):
153+
updated_task = await self._handle_task_event(event)
154+
handled_event = updated_task if isinstance(event, Task) else event
112155

113-
try:
114-
if isinstance(event, _RequestCompleted):
115-
logger.debug(
116-
'Consumer[%s]: Request completed', self.active_task._task_id
117-
)
118-
self.active_task._request_lock.release()
119-
elif isinstance(event, _RequestStarted):
120-
logger.debug(
121-
'Consumer[%s]: Request started', self.active_task._task_id
122-
)
123-
self.message_to_save = event.request_context.message
124-
elif isinstance(event, Message):
125-
self._handle_message_event(event)
126-
else:
127-
updated_task = await self._handle_task_event(event)
128-
if isinstance(event, Task):
129-
event = updated_task
130-
131-
if updated_task is not None:
132-
await self._update_task_state(updated_task, event)
133-
self.active_task._task_created.set()
156+
if updated_task is not None and handled_event is not None:
157+
await self._update_task_state(updated_task, handled_event)
158+
self.active_task._task_created.set()
134159

135-
finally:
136-
await self._enqueue_to_subscribers(event, updated_task)
160+
await self._enqueue_to_subscribers(event, updated_task)
137161

138162
def _handle_message_event(self, event: Message) -> None:
139163
if self.task_mode is True:
@@ -286,9 +310,6 @@ class ActiveTask:
286310
- `self._lock` (asyncio.Lock) ensures mutually exclusive access for critical
287311
lifecycle state changes, such as starting the task, subscribing, and
288312
determining if cleanup is safe to trigger.
289-
290-
mutation to the observable result state (like `_exception`,
291-
or `_is_finished`) notifies waiting coroutines (like `wait()`).
292313
- `self._is_finished` (asyncio.Event) provides a thread-safe, non-blocking way
293314
for external observers and internal loops to check if the ActiveTask has
294315
permanently ceased execution and closed its queues.
@@ -349,10 +370,6 @@ def __init__(
349370
# Protected by `_lock`.
350371
self._reference_count = 0
351372

352-
# Holds any fatal exception that crashed the producer or consumer.
353-
# TODO: Synchronize exception handling (ideally mix it in the queue).
354-
self._exception: Exception | None = None
355-
356373
# Queue for incoming requests
357374
self._request_queue: AsyncQueue[tuple[RequestContext, uuid.UUID]] = (
358375
_create_async_queue()
@@ -481,22 +498,17 @@ async def _run_producer(self) -> None:
481498
_RequestStarted(request_id, request_context),
482499
)
483500
)
484-
485501
await self._agent_executor.execute(
486502
request_context, self._event_queue_agent
487503
)
488504
logger.debug(
489505
'Producer[%s]: Execution finished successfully',
490506
self._task_id,
491507
)
492-
finally:
493-
logger.debug(
494-
'Producer[%s]: Enqueuing request completed event',
495-
self._task_id,
496-
)
497508
await self._event_queue_agent.enqueue_event(
498509
cast('Event', _RequestCompleted(request_id))
499510
)
511+
finally:
500512
self._request_queue.task_done()
501513
except asyncio.CancelledError:
502514
logger.debug('Producer[%s]: Cancelled', self._task_id)
@@ -516,8 +528,7 @@ async def _run_producer(self) -> None:
516528
request_context.context_id or '',
517529
)
518530
self._task_created.set()
519-
async with self._lock:
520-
await self._mark_task_as_failed(e)
531+
await self._event_queue_agent.enqueue_event(cast('Event', e))
521532

522533
finally:
523534
self._request_queue.shutdown(immediate=True)
@@ -537,7 +548,7 @@ async def _run_consumer(self) -> None:
537548
logger.debug('Consumer[%s]: Finishing', self._task_id)
538549
await self._maybe_cleanup()
539550

540-
async def subscribe( # noqa: PLR0912, PLR0915
551+
async def subscribe(
541552
self,
542553
*,
543554
request: RequestContext | None = None,
@@ -554,12 +565,6 @@ async def subscribe( # noqa: PLR0912, PLR0915
554565
logger.debug('Subscribe[%s]: New subscriber', self._task_id)
555566

556567
async with self._lock:
557-
if self._exception:
558-
logger.debug(
559-
'Subscribe[%s]: Failed, exception already set',
560-
self._task_id,
561-
)
562-
raise self._exception
563568
if self._is_finished.is_set():
564569
raise InvalidParamsError(
565570
f'Task {self._task_id} is already completed.'
@@ -585,17 +590,23 @@ async def subscribe( # noqa: PLR0912, PLR0915
585590

586591
while True:
587592
try:
588-
if self._exception:
589-
raise self._exception
590-
591593
dequeued = await tapped_queue.dequeue_event()
592594
event, updated_task = cast('Any', dequeued)
593595
logger.debug(
594-
'Subscriber[%s]\nDequeued event %s\nUpdated task %s\n',
596+
'Subscriber[%s] Dequeued event [%s]:\n %s\nUpdated task:\n%s\n',
595597
self._task_id,
598+
type(event).__name__,
596599
event,
597600
updated_task,
598601
)
602+
if isinstance(event, BaseException):
603+
logger.debug(
604+
'Subscriber[%s]: Raising exception: %s',
605+
self._task_id,
606+
event,
607+
)
608+
raise event
609+
599610
if replace_status_update_with_task and isinstance(
600611
event, TaskStatusUpdateEvent
601612
):
@@ -605,8 +616,6 @@ async def subscribe( # noqa: PLR0912, PLR0915
605616
updated_task,
606617
)
607618
event = updated_task
608-
if self._exception:
609-
raise self._exception from None
610619
if isinstance(event, _RequestCompleted):
611620
if (
612621
request_id is not None
@@ -629,8 +638,6 @@ async def subscribe( # noqa: PLR0912, PLR0915
629638
finally:
630639
tapped_queue.task_done()
631640
except (QueueShutDown, asyncio.CancelledError):
632-
if self._exception:
633-
raise self._exception from None
634641
break
635642
finally:
636643
logger.debug('Subscribe[%s]: Unsubscribing', self._task_id)
@@ -714,9 +721,9 @@ async def _maybe_cleanup(self) -> None:
714721
logger.debug('Cleanup[%s]: Triggering cleanup', self._task_id)
715722
self._on_cleanup(self)
716723

717-
async def _mark_task_as_failed(self, exception: Exception) -> None:
718-
if self._exception is None:
719-
self._exception = exception
724+
async def _mark_task_as_failed(self, exception: Exception) -> Task | None:
725+
logger.debug('Marking task %s as failed: %s', self._task_id, exception)
726+
task = None
720727
if self._task_created.is_set():
721728
try:
722729
task = await self._task_manager.get_task()
@@ -732,10 +739,10 @@ async def _mark_task_as_failed(self, exception: Exception) -> None:
732739
)
733740
except QueueShutDown:
734741
pass
742+
return task
735743

736744
async def get_task(self) -> Task:
737745
"""Get task from db."""
738-
# TODO: THERE IS ZERO CONCURRENCY SAFETY HERE (Except inital task creation).
739746
await self._task_created.wait()
740747
task = await self._task_manager.get_task()
741748
if not task:

src/a2a/server/request_handlers/default_request_handler_v2.py

Lines changed: 0 additions & 8 deletions
Original file line number | Diff line number | Diff line change
@@ -37,7 +37,6 @@
3737
SubscribeToTaskRequest,
3838
Task,
3939
TaskPushNotificationConfig,
40-
TaskStatusUpdateEvent,
4140
)
4241
from a2a.utils.errors import (
4342
ExtendedAgentCardNotConfiguredError,
@@ -252,13 +251,6 @@ async def on_message_send( # noqa: D102
252251
type(event).__name__,
253252
event,
254253
)
255-
if isinstance(event, TaskStatusUpdateEvent):
256-
self._validate_task_id_match(task_id, event.task_id)
257-
event = await active_task.get_task()
258-
logger.debug(
259-
'Replaced TaskStatusUpdateEvent with Task: %s', event
260-
)
261-
262254
if isinstance(event, Task) and (
263255
params.configuration.return_immediately
264256
or event.status.state

tests/integration/test_scenarios.py

Lines changed: 9 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -461,6 +461,15 @@ async def cancel(
461461
(task,) = (await client.list_tasks(ListTasksRequest())).tasks
462462
assert task.status.state == TaskState.TASK_STATE_FAILED
463463

464+
if streaming:
465+
with pytest.raises(
466+
InvalidParamsError,
467+
match='Task .* is already completed',
468+
):
469+
await client.subscribe(
470+
SubscribeToTaskRequest(id=task.id)
471+
).__anext__()
472+
464473

465474
# Scenario 12/13: Exception after initial event
466475
@pytest.mark.timeout(2.0)

tests/server/agent_execution/test_active_task.py

Lines changed: 25 additions & 19 deletions
Original file line number | Diff line number | Diff line change
@@ -316,26 +316,42 @@ async def test_active_task_subscribe_exception_handling(
316316
active_task: ActiveTask,
317317
agent_executor: Mock,
318318
request_context: Mock,
319+
task_manager: Mock,
319320
) -> None:
320321
"""Test exception handling in subscribe."""
321-
agent_executor.execute = AsyncMock(
322-
side_effect=ValueError('Producer failure')
322+
event = asyncio.Event()
323+
324+
task_manager.get_task.return_value = Task(
325+
id='test-task-id',
326+
status=TaskStatus(state=TaskState.TASK_STATE_SUBMITTED),
323327
)
324328

329+
async def execute_mock(req, q):
330+
await q.enqueue_event(
331+
Task(
332+
id='test-task-id',
333+
status=TaskStatus(state=TaskState.TASK_STATE_SUBMITTED),
334+
)
335+
)
336+
await event.wait()
337+
raise ValueError('Producer failure')
338+
339+
agent_executor.execute = AsyncMock(side_effect=execute_mock)
340+
325341
await active_task.enqueue_request(request_context)
326342
await active_task.start(
327343
call_context=ServerCallContext(), create_task_if_missing=True
328344
)
329345

330-
# Give it a moment to fail
331-
for _ in range(10):
332-
if active_task._exception:
333-
break
334-
await asyncio.sleep(0.05)
346+
subscriber = active_task.subscribe()
347+
task = await anext(subscriber)
348+
assert task.status.state == TaskState.TASK_STATE_SUBMITTED
349+
350+
# Now trigger the exception
351+
event.set()
335352

336353
with pytest.raises(ValueError, match='Producer failure'):
337-
async for _ in active_task.subscribe():
338-
pass
354+
await anext(subscriber)
339355

340356
@pytest.mark.asyncio
341357
async def test_active_task_cancel_not_started(
@@ -766,16 +782,6 @@ async def test_active_task_maybe_cleanup_not_finished(
766782
await active_task._maybe_cleanup()
767783
on_cleanup.assert_not_called()
768784

769-
@pytest.mark.asyncio
770-
async def test_active_task_subscribe_exception_already_set(
771-
self, active_task: ActiveTask
772-
) -> None:
773-
"""Test subscribe when exception is already set."""
774-
active_task._exception = ValueError('Pre-existing error')
775-
with pytest.raises(ValueError, match='Pre-existing error'):
776-
async for _ in active_task.subscribe():
777-
pass
778-
779785
@pytest.mark.asyncio
780786
async def test_active_task_subscribe_inner_exception(
781787
self,

0 commit comments

Comments
 (0)