@@ -104,36 +104,60 @@ async def run(self) -> None:
104104 )
105105 except Exception as e :
106106 logger .exception ('Consumer[%s]: Failed' , self .active_task ._task_id )
107- async with self .active_task ._lock :
108- await self .active_task ._mark_task_as_failed (e )
107+
108+ updated_task = None
109+ task = await self .active_task ._task_manager .get_task ()
110+ if task :
111+ handled_event = TaskStatusUpdateEvent (
112+ task_id = task .id ,
113+ context_id = task .context_id ,
114+ status = TaskStatus (
115+ state = TaskState .TASK_STATE_FAILED ,
116+ ),
117+ )
118+ updated_task = await self ._handle_task_event (handled_event )
119+
120+ await self ._enqueue_to_subscribers (cast ('Event' , e ), updated_task )
109121
110122 async def _process_event (self , event : Event ) -> None :
111123 updated_task = None
124+ handled_event : (
125+ Task
126+ | TaskStatusUpdateEvent
127+ | TaskArtifactUpdateEvent
128+ | PushNotificationEvent
129+ | None
130+ ) = None
131+
132+ if isinstance (event , _RequestCompleted ):
133+ logger .debug (
134+ 'Consumer[%s]: Request completed' , self .active_task ._task_id
135+ )
136+ self .active_task ._request_lock .release ()
137+ elif isinstance (event , _RequestStarted ):
138+ logger .debug (
139+ 'Consumer[%s]: Request started' , self .active_task ._task_id
140+ )
141+ self .message_to_save = event .request_context .message
142+ elif isinstance (event , BaseException ):
143+ raise event
144+ elif isinstance (event , Message ):
145+ self ._handle_message_event (event )
146+ elif isinstance (
147+ event ,
148+ TaskStatusUpdateEvent
149+ | TaskArtifactUpdateEvent
150+ | PushNotificationEvent
151+ | Task ,
152+ ):
153+ updated_task = await self ._handle_task_event (event )
154+ handled_event = updated_task if isinstance (event , Task ) else event
112155
113- try :
114- if isinstance (event , _RequestCompleted ):
115- logger .debug (
116- 'Consumer[%s]: Request completed' , self .active_task ._task_id
117- )
118- self .active_task ._request_lock .release ()
119- elif isinstance (event , _RequestStarted ):
120- logger .debug (
121- 'Consumer[%s]: Request started' , self .active_task ._task_id
122- )
123- self .message_to_save = event .request_context .message
124- elif isinstance (event , Message ):
125- self ._handle_message_event (event )
126- else :
127- updated_task = await self ._handle_task_event (event )
128- if isinstance (event , Task ):
129- event = updated_task
130-
131- if updated_task is not None :
132- await self ._update_task_state (updated_task , event )
133- self .active_task ._task_created .set ()
156+ if updated_task is not None and handled_event is not None :
157+ await self ._update_task_state (updated_task , handled_event )
158+ self .active_task ._task_created .set ()
134159
135- finally :
136- await self ._enqueue_to_subscribers (event , updated_task )
160+ await self ._enqueue_to_subscribers (event , updated_task )
137161
138162 def _handle_message_event (self , event : Message ) -> None :
139163 if self .task_mode is True :
@@ -286,9 +310,6 @@ class ActiveTask:
286310 - `self._lock` (asyncio.Lock) ensures mutually exclusive access for critical
287311 lifecycle state changes, such as starting the task, subscribing, and
288312 determining if cleanup is safe to trigger.
289-
290- mutation to the observable result state (like `_exception`,
291- or `_is_finished`) notifies waiting coroutines (like `wait()`).
292313 - `self._is_finished` (asyncio.Event) provides a thread-safe, non-blocking way
293314 for external observers and internal loops to check if the ActiveTask has
294315 permanently ceased execution and closed its queues.
@@ -349,10 +370,6 @@ def __init__(
349370 # Protected by `_lock`.
350371 self ._reference_count = 0
351372
352- # Holds any fatal exception that crashed the producer or consumer.
353- # TODO: Synchronize exception handling (ideally mix it in the queue).
354- self ._exception : Exception | None = None
355-
356373 # Queue for incoming requests
357374 self ._request_queue : AsyncQueue [tuple [RequestContext , uuid .UUID ]] = (
358375 _create_async_queue ()
@@ -481,22 +498,17 @@ async def _run_producer(self) -> None:
481498 _RequestStarted (request_id , request_context ),
482499 )
483500 )
484-
485501 await self ._agent_executor .execute (
486502 request_context , self ._event_queue_agent
487503 )
488504 logger .debug (
489505 'Producer[%s]: Execution finished successfully' ,
490506 self ._task_id ,
491507 )
492- finally :
493- logger .debug (
494- 'Producer[%s]: Enqueuing request completed event' ,
495- self ._task_id ,
496- )
497508 await self ._event_queue_agent .enqueue_event (
498509 cast ('Event' , _RequestCompleted (request_id ))
499510 )
511+ finally :
500512 self ._request_queue .task_done ()
501513 except asyncio .CancelledError :
502514 logger .debug ('Producer[%s]: Cancelled' , self ._task_id )
@@ -516,8 +528,7 @@ async def _run_producer(self) -> None:
516528 request_context .context_id or '' ,
517529 )
518530 self ._task_created .set ()
519- async with self ._lock :
520- await self ._mark_task_as_failed (e )
531+ await self ._event_queue_agent .enqueue_event (cast ('Event' , e ))
521532
522533 finally :
523534 self ._request_queue .shutdown (immediate = True )
@@ -537,7 +548,7 @@ async def _run_consumer(self) -> None:
537548 logger .debug ('Consumer[%s]: Finishing' , self ._task_id )
538549 await self ._maybe_cleanup ()
539550
540- async def subscribe ( # noqa: PLR0912, PLR0915
551+ async def subscribe (
541552 self ,
542553 * ,
543554 request : RequestContext | None = None ,
@@ -554,12 +565,6 @@ async def subscribe( # noqa: PLR0912, PLR0915
554565 logger .debug ('Subscribe[%s]: New subscriber' , self ._task_id )
555566
556567 async with self ._lock :
557- if self ._exception :
558- logger .debug (
559- 'Subscribe[%s]: Failed, exception already set' ,
560- self ._task_id ,
561- )
562- raise self ._exception
563568 if self ._is_finished .is_set ():
564569 raise InvalidParamsError (
565570 f'Task { self ._task_id } is already completed.'
@@ -585,17 +590,23 @@ async def subscribe( # noqa: PLR0912, PLR0915
585590
586591 while True :
587592 try :
588- if self ._exception :
589- raise self ._exception
590-
591593 dequeued = await tapped_queue .dequeue_event ()
592594 event , updated_task = cast ('Any' , dequeued )
593595 logger .debug (
594- 'Subscriber[%s]\n Dequeued event %s \ n Updated task %s\n ' ,
596+ 'Subscriber[%s] Dequeued event [%s]: \n %s \ n Updated task: \n %s\n ' ,
595597 self ._task_id ,
598+ type (event ).__name__ ,
596599 event ,
597600 updated_task ,
598601 )
602+ if isinstance (event , BaseException ):
603+ logger .debug (
604+ 'Subscriber[%s]: Raising exception: %s' ,
605+ self ._task_id ,
606+ event ,
607+ )
608+ raise event
609+
599610 if replace_status_update_with_task and isinstance (
600611 event , TaskStatusUpdateEvent
601612 ):
@@ -605,8 +616,6 @@ async def subscribe( # noqa: PLR0912, PLR0915
605616 updated_task ,
606617 )
607618 event = updated_task
608- if self ._exception :
609- raise self ._exception from None
610619 if isinstance (event , _RequestCompleted ):
611620 if (
612621 request_id is not None
@@ -629,8 +638,6 @@ async def subscribe( # noqa: PLR0912, PLR0915
629638 finally :
630639 tapped_queue .task_done ()
631640 except (QueueShutDown , asyncio .CancelledError ):
632- if self ._exception :
633- raise self ._exception from None
634641 break
635642 finally :
636643 logger .debug ('Subscribe[%s]: Unsubscribing' , self ._task_id )
@@ -714,9 +721,9 @@ async def _maybe_cleanup(self) -> None:
714721 logger .debug ('Cleanup[%s]: Triggering cleanup' , self ._task_id )
715722 self ._on_cleanup (self )
716723
717- async def _mark_task_as_failed (self , exception : Exception ) -> None :
718- if self . _exception is None :
719- self . _exception = exception
724+ async def _mark_task_as_failed (self , exception : Exception ) -> Task | None :
725+ logger . debug ( 'Marking task %s as failed: %s' , self . _task_id , exception )
726+ task = None
720727 if self ._task_created .is_set ():
721728 try :
722729 task = await self ._task_manager .get_task ()
@@ -732,10 +739,10 @@ async def _mark_task_as_failed(self, exception: Exception) -> None:
732739 )
733740 except QueueShutDown :
734741 pass
742+ return task
735743
736744 async def get_task (self ) -> Task :
737745 """Get task from db."""
738- # TODO: THERE IS ZERO CONCURRENCY SAFETY HERE (Except inital task creation).
739746 await self ._task_created .wait ()
740747 task = await self ._task_manager .get_task ()
741748 if not task :
0 commit comments