@@ -346,7 +346,7 @@ def __init__(
346346 else :
347347 self ._interceptors = None
348348
349- self ._async_worker_manager = _AsyncWorkerManager (self ._concurrency_options )
349+ self ._async_worker_manager = _AsyncWorkerManager (self ._concurrency_options , self . _logger )
350350
351351 @property
352352 def concurrency_options (self ) -> ConcurrencyOptions :
@@ -533,27 +533,31 @@ def stream_reader():
533533 if work_item .HasField ("orchestratorRequest" ):
534534 self ._async_worker_manager .submit_orchestration (
535535 self ._execute_orchestrator ,
536+ self ._cancel_orchestrator ,
536537 work_item .orchestratorRequest ,
537538 stub ,
538539 work_item .completionToken ,
539540 )
540541 elif work_item .HasField ("activityRequest" ):
541542 self ._async_worker_manager .submit_activity (
542543 self ._execute_activity ,
544+ self ._cancel_activity ,
543545 work_item .activityRequest ,
544546 stub ,
545547 work_item .completionToken ,
546548 )
547549 elif work_item .HasField ("entityRequest" ):
548550 self ._async_worker_manager .submit_entity_batch (
549551 self ._execute_entity_batch ,
552+ self ._cancel_entity_batch ,
550553 work_item .entityRequest ,
551554 stub ,
552555 work_item .completionToken ,
553556 )
554557 elif work_item .HasField ("entityRequestV2" ):
555558 self ._async_worker_manager .submit_entity_batch (
556559 self ._execute_entity_batch ,
560+ self ._cancel_entity_batch ,
557561 work_item .entityRequestV2 ,
558562 stub ,
559563 work_item .completionToken
@@ -670,6 +674,19 @@ def _execute_orchestrator(
670674 f"Failed to deliver orchestrator response for '{ req .instanceId } ' to sidecar: { ex } "
671675 )
672676
677+ def _cancel_orchestrator (
678+ self ,
679+ req : pb .OrchestratorRequest ,
680+ stub : stubs .TaskHubSidecarServiceStub ,
681+ completionToken ,
682+ ):
683+ stub .AbandonTaskOrchestratorWorkItem (
684+ pb .AbandonOrchestrationTaskRequest (
685+ completionToken = completionToken
686+ )
687+ )
688+ self ._logger .info (f"Cancelled orchestration task for invocation ID: { req .instanceId } " )
689+
673690 def _execute_activity (
674691 self ,
675692 req : pb .ActivityRequest ,
@@ -703,6 +720,19 @@ def _execute_activity(
703720 f"Failed to deliver activity response for '{ req .name } #{ req .taskId } ' of orchestration ID '{ instance_id } ' to sidecar: { ex } "
704721 )
705722
723+ def _cancel_activity (
724+ self ,
725+ req : pb .ActivityRequest ,
726+ stub : stubs .TaskHubSidecarServiceStub ,
727+ completionToken ,
728+ ):
729+ stub .AbandonTaskActivityWorkItem (
730+ pb .AbandonActivityTaskRequest (
731+ completionToken = completionToken
732+ )
733+ )
734+ self ._logger .info (f"Cancelled activity task for task ID: { req .taskId } on orchestration ID: { req .orchestrationInstance .instanceId } " )
735+
706736 def _execute_entity_batch (
707737 self ,
708738 req : Union [pb .EntityBatchRequest , pb .EntityRequest ],
@@ -771,6 +801,19 @@ def _execute_entity_batch(
771801
772802 return batch_result
773803
804+ def _cancel_entity_batch (
805+ self ,
806+ req : Union [pb .EntityBatchRequest , pb .EntityRequest ],
807+ stub : stubs .TaskHubSidecarServiceStub ,
808+ completionToken ,
809+ ):
810+ stub .AbandonTaskEntityWorkItem (
811+ pb .AbandonEntityTaskRequest (
812+ completionToken = completionToken
813+ )
814+ )
815+ self ._logger .info (f"Cancelled entity batch task for instance ID: { req .instanceId } " )
816+
774817
775818class _RuntimeOrchestrationContext (task .OrchestrationContext ):
776819 _generator : Optional [Generator [task .Task , Any , Any ]]
@@ -1933,8 +1976,10 @@ def _is_suspendable(event: pb.HistoryEvent) -> bool:
19331976
19341977
19351978class _AsyncWorkerManager :
1936- def __init__ (self , concurrency_options : ConcurrencyOptions ):
1979+ def __init__ (self , concurrency_options : ConcurrencyOptions , logger : logging . Logger ):
19371980 self .concurrency_options = concurrency_options
1981+ self ._logger = logger
1982+
19381983 self .activity_semaphore = None
19391984 self .orchestration_semaphore = None
19401985 self .entity_semaphore = None
@@ -2044,17 +2089,51 @@ async def run(self):
20442089 )
20452090
20462091 # Start background consumers for each work type
2047- if self .activity_queue is not None and self .orchestration_queue is not None \
2048- and self .entity_batch_queue is not None :
2049- await asyncio .gather (
2050- self ._consume_queue (self .activity_queue , self .activity_semaphore ),
2051- self ._consume_queue (
2052- self .orchestration_queue , self .orchestration_semaphore
2053- ),
2054- self ._consume_queue (
2055- self .entity_batch_queue , self .entity_semaphore
2092+ try :
2093+ if self .activity_queue is not None and self .orchestration_queue is not None \
2094+ and self .entity_batch_queue is not None :
2095+ await asyncio .gather (
2096+ self ._consume_queue (self .activity_queue , self .activity_semaphore ),
2097+ self ._consume_queue (
2098+ self .orchestration_queue , self .orchestration_semaphore
2099+ ),
2100+ self ._consume_queue (
2101+ self .entity_batch_queue , self .entity_semaphore
2102+ )
20562103 )
2057- )
2104+ except Exception as queue_exception :
2105+ self ._logger .error (f"Shutting down worker - Uncaught error in worker manager: { queue_exception } " )
2106+ while self .activity_queue is not None and not self .activity_queue .empty ():
2107+ try :
2108+ func , cancellation_func , args , kwargs = self .activity_queue .get_nowait ()
2109+ await self ._run_func (cancellation_func , * args , ** kwargs )
2110+ self ._logger .error (f"Activity work item args: { args } , kwargs: { kwargs } " )
2111+ except asyncio .QueueEmpty :
2112+ # Queue was empty, no cancellation needed
2113+ pass
2114+ except Exception as cancellation_exception :
2115+ self ._logger .error (f"Uncaught error while cancelling activity work item: { cancellation_exception } " )
2116+ while self .orchestration_queue is not None and not self .orchestration_queue .empty ():
2117+ try :
2118+ func , cancellation_func , args , kwargs = self .orchestration_queue .get_nowait ()
2119+ await self ._run_func (cancellation_func , * args , ** kwargs )
2120+ self ._logger .error (f"Orchestration work item args: { args } , kwargs: { kwargs } " )
2121+ except asyncio .QueueEmpty :
2122+ # Queue was empty, no cancellation needed
2123+ pass
2124+ except Exception as cancellation_exception :
2125+ self ._logger .error (f"Uncaught error while cancelling orchestration work item: { cancellation_exception } " )
2126+ while self .entity_batch_queue is not None and not self .entity_batch_queue .empty ():
2127+ try :
2128+ func , cancellation_func , args , kwargs = self .entity_batch_queue .get_nowait ()
2129+ await self ._run_func (cancellation_func , * args , ** kwargs )
2130+ self ._logger .error (f"Entity batch work item args: { args } , kwargs: { kwargs } " )
2131+ except asyncio .QueueEmpty :
2132+ # Queue was empty, no cancellation needed
2133+ pass
2134+ except Exception as cancellation_exception :
2135+ self ._logger .error (f"Uncaught error while cancelling entity batch work item: { cancellation_exception } " )
2136+ self .shutdown ()
20582137
20592138 async def _consume_queue (self , queue : asyncio .Queue , semaphore : asyncio .Semaphore ):
20602139 # List to track running tasks
@@ -2074,19 +2153,22 @@ async def _consume_queue(self, queue: asyncio.Queue, semaphore: asyncio.Semaphor
20742153 except asyncio .TimeoutError :
20752154 continue
20762155
2077- func , args , kwargs = work
2156+ func , cancellation_func , args , kwargs = work
20782157 # Create a concurrent task for processing
20792158 task = asyncio .create_task (
2080- self ._process_work_item (semaphore , queue , func , args , kwargs )
2159+ self ._process_work_item (semaphore , queue , func , cancellation_func , args , kwargs )
20812160 )
20822161 running_tasks .add (task )
20832162
20842163 async def _process_work_item (
2085- self , semaphore : asyncio .Semaphore , queue : asyncio .Queue , func , args , kwargs
2164+ self , semaphore : asyncio .Semaphore , queue : asyncio .Queue , func , cancellation_func , args , kwargs
20862165 ):
20872166 async with semaphore :
20882167 try :
20892168 await self ._run_func (func , * args , ** kwargs )
2169+ except Exception as work_exception :
2170+ self ._logger .error (f"Uncaught error while processing work item, item will be abandoned: { work_exception } " )
2171+ await self ._run_func (cancellation_func , * args , ** kwargs )
20902172 finally :
20912173 queue .task_done ()
20922174
@@ -2105,26 +2187,32 @@ async def _run_func(self, func, *args, **kwargs):
21052187 self .thread_pool , lambda : func (* args , ** kwargs )
21062188 )
21072189
2108- def submit_activity (self , func , * args , ** kwargs ):
2109- work_item = (func , args , kwargs )
2190+ def submit_activity (self , func , cancellation_func , * args , ** kwargs ):
2191+ if self ._shutdown :
2192+ raise RuntimeError ("Cannot submit new work items after shutdown has been initiated." )
2193+ work_item = (func , cancellation_func , args , kwargs )
21102194 self ._ensure_queues_for_current_loop ()
21112195 if self .activity_queue is not None :
21122196 self .activity_queue .put_nowait (work_item )
21132197 else :
21142198 # No event loop running, store in pending list
21152199 self ._pending_activity_work .append (work_item )
21162200
2117- def submit_orchestration (self , func , * args , ** kwargs ):
2118- work_item = (func , args , kwargs )
2201+ def submit_orchestration (self , func , cancellation_func , * args , ** kwargs ):
2202+ if self ._shutdown :
2203+ raise RuntimeError ("Cannot submit new work items after shutdown has been initiated." )
2204+ work_item = (func , cancellation_func , args , kwargs )
21192205 self ._ensure_queues_for_current_loop ()
21202206 if self .orchestration_queue is not None :
21212207 self .orchestration_queue .put_nowait (work_item )
21222208 else :
21232209 # No event loop running, store in pending list
21242210 self ._pending_orchestration_work .append (work_item )
21252211
2126- def submit_entity_batch (self , func , * args , ** kwargs ):
2127- work_item = (func , args , kwargs )
2212+ def submit_entity_batch (self , func , cancellation_func , * args , ** kwargs ):
2213+ if self ._shutdown :
2214+ raise RuntimeError ("Cannot submit new work items after shutdown has been initiated." )
2215+ work_item = (func , cancellation_func , args , kwargs )
21282216 self ._ensure_queues_for_current_loop ()
21292217 if self .entity_batch_queue is not None :
21302218 self .entity_batch_queue .put_nowait (work_item )
@@ -2136,7 +2224,7 @@ def shutdown(self):
21362224 self ._shutdown = True
21372225 self .thread_pool .shutdown (wait = True )
21382226
2139- def reset_for_new_run (self ):
2227+ async def reset_for_new_run (self ):
21402228 """Reset the manager state for a new run."""
21412229 self ._shutdown = False
21422230 # Clear any existing queues - they'll be recreated when needed
@@ -2145,18 +2233,28 @@ def reset_for_new_run(self):
21452233 # This ensures no items from previous runs remain
21462234 try :
21472235 while not self .activity_queue .empty ():
2148- self .activity_queue .get_nowait ()
2149- except Exception :
2150- pass
2236+ func , cancellation_func , args , kwargs = self .activity_queue .get_nowait ()
2237+ await self ._run_func (cancellation_func , * args , ** kwargs )
2238+ except Exception as reset_exception :
2239+ self ._logger .warning (f"Error while clearing activity queue during reset: { reset_exception } " )
21512240 if self .orchestration_queue is not None :
21522241 try :
21532242 while not self .orchestration_queue .empty ():
2154- self .orchestration_queue .get_nowait ()
2155- except Exception :
2156- pass
2243+ func , cancellation_func , args , kwargs = self .orchestration_queue .get_nowait ()
2244+ await self ._run_func (cancellation_func , * args , ** kwargs )
2245+ except Exception as reset_exception :
2246+ self ._logger .warning (f"Error while clearing orchestration queue during reset: { reset_exception } " )
2247+ if self .entity_batch_queue is not None :
2248+ try :
2249+ while not self .entity_batch_queue .empty ():
2250+ func , cancellation_func , args , kwargs = self .entity_batch_queue .get_nowait ()
2251+ await self ._run_func (cancellation_func , * args , ** kwargs )
2252+ except Exception as reset_exception :
2253+ self ._logger .warning (f"Error while clearing entity queue during reset: { reset_exception } " )
21572254 # Clear pending work lists
21582255 self ._pending_activity_work .clear ()
21592256 self ._pending_orchestration_work .clear ()
2257+ self ._pending_entity_batch_work .clear ()
21602258
21612259
21622260# Export public API
0 commit comments