Skip to content

Commit 34eaed6

Browse files
Abandon workitems (#72)
* Initial work - not complete * Fix for crash in worker run loop * Log clarity * Revert comment * Linting fix * Fix tests * Feedback, fix tests * Fix tests --------- Co-authored-by: Bernd Verst <[email protected]>
1 parent 5b453ed commit 34eaed6

File tree

3 files changed

+160
-36
lines changed

3 files changed

+160
-36
lines changed

durabletask/worker.py

Lines changed: 126 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -346,7 +346,7 @@ def __init__(
346346
else:
347347
self._interceptors = None
348348

349-
self._async_worker_manager = _AsyncWorkerManager(self._concurrency_options)
349+
self._async_worker_manager = _AsyncWorkerManager(self._concurrency_options, self._logger)
350350

351351
@property
352352
def concurrency_options(self) -> ConcurrencyOptions:
@@ -533,27 +533,31 @@ def stream_reader():
533533
if work_item.HasField("orchestratorRequest"):
534534
self._async_worker_manager.submit_orchestration(
535535
self._execute_orchestrator,
536+
self._cancel_orchestrator,
536537
work_item.orchestratorRequest,
537538
stub,
538539
work_item.completionToken,
539540
)
540541
elif work_item.HasField("activityRequest"):
541542
self._async_worker_manager.submit_activity(
542543
self._execute_activity,
544+
self._cancel_activity,
543545
work_item.activityRequest,
544546
stub,
545547
work_item.completionToken,
546548
)
547549
elif work_item.HasField("entityRequest"):
548550
self._async_worker_manager.submit_entity_batch(
549551
self._execute_entity_batch,
552+
self._cancel_entity_batch,
550553
work_item.entityRequest,
551554
stub,
552555
work_item.completionToken,
553556
)
554557
elif work_item.HasField("entityRequestV2"):
555558
self._async_worker_manager.submit_entity_batch(
556559
self._execute_entity_batch,
560+
self._cancel_entity_batch,
557561
work_item.entityRequestV2,
558562
stub,
559563
work_item.completionToken
@@ -670,6 +674,19 @@ def _execute_orchestrator(
670674
f"Failed to deliver orchestrator response for '{req.instanceId}' to sidecar: {ex}"
671675
)
672676

677+
def _cancel_orchestrator(
678+
self,
679+
req: pb.OrchestratorRequest,
680+
stub: stubs.TaskHubSidecarServiceStub,
681+
completionToken,
682+
):
683+
stub.AbandonTaskOrchestratorWorkItem(
684+
pb.AbandonOrchestrationTaskRequest(
685+
completionToken=completionToken
686+
)
687+
)
688+
self._logger.info(f"Cancelled orchestration task for invocation ID: {req.instanceId}")
689+
673690
def _execute_activity(
674691
self,
675692
req: pb.ActivityRequest,
@@ -703,6 +720,19 @@ def _execute_activity(
703720
f"Failed to deliver activity response for '{req.name}#{req.taskId}' of orchestration ID '{instance_id}' to sidecar: {ex}"
704721
)
705722

723+
def _cancel_activity(
724+
self,
725+
req: pb.ActivityRequest,
726+
stub: stubs.TaskHubSidecarServiceStub,
727+
completionToken,
728+
):
729+
stub.AbandonTaskActivityWorkItem(
730+
pb.AbandonActivityTaskRequest(
731+
completionToken=completionToken
732+
)
733+
)
734+
self._logger.info(f"Cancelled activity task for task ID: {req.taskId} on orchestration ID: {req.orchestrationInstance.instanceId}")
735+
706736
def _execute_entity_batch(
707737
self,
708738
req: Union[pb.EntityBatchRequest, pb.EntityRequest],
@@ -771,6 +801,19 @@ def _execute_entity_batch(
771801

772802
return batch_result
773803

804+
def _cancel_entity_batch(
805+
self,
806+
req: Union[pb.EntityBatchRequest, pb.EntityRequest],
807+
stub: stubs.TaskHubSidecarServiceStub,
808+
completionToken,
809+
):
810+
stub.AbandonTaskEntityWorkItem(
811+
pb.AbandonEntityTaskRequest(
812+
completionToken=completionToken
813+
)
814+
)
815+
self._logger.info(f"Cancelled entity batch task for instance ID: {req.instanceId}")
816+
774817

775818
class _RuntimeOrchestrationContext(task.OrchestrationContext):
776819
_generator: Optional[Generator[task.Task, Any, Any]]
@@ -1933,8 +1976,10 @@ def _is_suspendable(event: pb.HistoryEvent) -> bool:
19331976

19341977

19351978
class _AsyncWorkerManager:
1936-
def __init__(self, concurrency_options: ConcurrencyOptions):
1979+
def __init__(self, concurrency_options: ConcurrencyOptions, logger: logging.Logger):
19371980
self.concurrency_options = concurrency_options
1981+
self._logger = logger
1982+
19381983
self.activity_semaphore = None
19391984
self.orchestration_semaphore = None
19401985
self.entity_semaphore = None
@@ -2044,17 +2089,51 @@ async def run(self):
20442089
)
20452090

20462091
# Start background consumers for each work type
2047-
if self.activity_queue is not None and self.orchestration_queue is not None \
2048-
and self.entity_batch_queue is not None:
2049-
await asyncio.gather(
2050-
self._consume_queue(self.activity_queue, self.activity_semaphore),
2051-
self._consume_queue(
2052-
self.orchestration_queue, self.orchestration_semaphore
2053-
),
2054-
self._consume_queue(
2055-
self.entity_batch_queue, self.entity_semaphore
2092+
try:
2093+
if self.activity_queue is not None and self.orchestration_queue is not None \
2094+
and self.entity_batch_queue is not None:
2095+
await asyncio.gather(
2096+
self._consume_queue(self.activity_queue, self.activity_semaphore),
2097+
self._consume_queue(
2098+
self.orchestration_queue, self.orchestration_semaphore
2099+
),
2100+
self._consume_queue(
2101+
self.entity_batch_queue, self.entity_semaphore
2102+
)
20562103
)
2057-
)
2104+
except Exception as queue_exception:
2105+
self._logger.error(f"Shutting down worker - Uncaught error in worker manager: {queue_exception}")
2106+
while self.activity_queue is not None and not self.activity_queue.empty():
2107+
try:
2108+
func, cancellation_func, args, kwargs = self.activity_queue.get_nowait()
2109+
await self._run_func(cancellation_func, *args, **kwargs)
2110+
self._logger.error(f"Activity work item args: {args}, kwargs: {kwargs}")
2111+
except asyncio.QueueEmpty:
2112+
# Queue was empty, no cancellation needed
2113+
pass
2114+
except Exception as cancellation_exception:
2115+
self._logger.error(f"Uncaught error while cancelling activity work item: {cancellation_exception}")
2116+
while self.orchestration_queue is not None and not self.orchestration_queue.empty():
2117+
try:
2118+
func, cancellation_func, args, kwargs = self.orchestration_queue.get_nowait()
2119+
await self._run_func(cancellation_func, *args, **kwargs)
2120+
self._logger.error(f"Orchestration work item args: {args}, kwargs: {kwargs}")
2121+
except asyncio.QueueEmpty:
2122+
# Queue was empty, no cancellation needed
2123+
pass
2124+
except Exception as cancellation_exception:
2125+
self._logger.error(f"Uncaught error while cancelling orchestration work item: {cancellation_exception}")
2126+
while self.entity_batch_queue is not None and not self.entity_batch_queue.empty():
2127+
try:
2128+
func, cancellation_func, args, kwargs = self.entity_batch_queue.get_nowait()
2129+
await self._run_func(cancellation_func, *args, **kwargs)
2130+
self._logger.error(f"Entity batch work item args: {args}, kwargs: {kwargs}")
2131+
except asyncio.QueueEmpty:
2132+
# Queue was empty, no cancellation needed
2133+
pass
2134+
except Exception as cancellation_exception:
2135+
self._logger.error(f"Uncaught error while cancelling entity batch work item: {cancellation_exception}")
2136+
self.shutdown()
20582137

20592138
async def _consume_queue(self, queue: asyncio.Queue, semaphore: asyncio.Semaphore):
20602139
# List to track running tasks
@@ -2074,19 +2153,22 @@ async def _consume_queue(self, queue: asyncio.Queue, semaphore: asyncio.Semaphor
20742153
except asyncio.TimeoutError:
20752154
continue
20762155

2077-
func, args, kwargs = work
2156+
func, cancellation_func, args, kwargs = work
20782157
# Create a concurrent task for processing
20792158
task = asyncio.create_task(
2080-
self._process_work_item(semaphore, queue, func, args, kwargs)
2159+
self._process_work_item(semaphore, queue, func, cancellation_func, args, kwargs)
20812160
)
20822161
running_tasks.add(task)
20832162

20842163
async def _process_work_item(
2085-
self, semaphore: asyncio.Semaphore, queue: asyncio.Queue, func, args, kwargs
2164+
self, semaphore: asyncio.Semaphore, queue: asyncio.Queue, func, cancellation_func, args, kwargs
20862165
):
20872166
async with semaphore:
20882167
try:
20892168
await self._run_func(func, *args, **kwargs)
2169+
except Exception as work_exception:
2170+
self._logger.error(f"Uncaught error while processing work item, item will be abandoned: {work_exception}")
2171+
await self._run_func(cancellation_func, *args, **kwargs)
20902172
finally:
20912173
queue.task_done()
20922174

@@ -2105,26 +2187,32 @@ async def _run_func(self, func, *args, **kwargs):
21052187
self.thread_pool, lambda: func(*args, **kwargs)
21062188
)
21072189

2108-
def submit_activity(self, func, *args, **kwargs):
2109-
work_item = (func, args, kwargs)
2190+
def submit_activity(self, func, cancellation_func, *args, **kwargs):
2191+
if self._shutdown:
2192+
raise RuntimeError("Cannot submit new work items after shutdown has been initiated.")
2193+
work_item = (func, cancellation_func, args, kwargs)
21102194
self._ensure_queues_for_current_loop()
21112195
if self.activity_queue is not None:
21122196
self.activity_queue.put_nowait(work_item)
21132197
else:
21142198
# No event loop running, store in pending list
21152199
self._pending_activity_work.append(work_item)
21162200

2117-
def submit_orchestration(self, func, *args, **kwargs):
2118-
work_item = (func, args, kwargs)
2201+
def submit_orchestration(self, func, cancellation_func, *args, **kwargs):
2202+
if self._shutdown:
2203+
raise RuntimeError("Cannot submit new work items after shutdown has been initiated.")
2204+
work_item = (func, cancellation_func, args, kwargs)
21192205
self._ensure_queues_for_current_loop()
21202206
if self.orchestration_queue is not None:
21212207
self.orchestration_queue.put_nowait(work_item)
21222208
else:
21232209
# No event loop running, store in pending list
21242210
self._pending_orchestration_work.append(work_item)
21252211

2126-
def submit_entity_batch(self, func, *args, **kwargs):
2127-
work_item = (func, args, kwargs)
2212+
def submit_entity_batch(self, func, cancellation_func, *args, **kwargs):
2213+
if self._shutdown:
2214+
raise RuntimeError("Cannot submit new work items after shutdown has been initiated.")
2215+
work_item = (func, cancellation_func, args, kwargs)
21282216
self._ensure_queues_for_current_loop()
21292217
if self.entity_batch_queue is not None:
21302218
self.entity_batch_queue.put_nowait(work_item)
@@ -2136,7 +2224,7 @@ def shutdown(self):
21362224
self._shutdown = True
21372225
self.thread_pool.shutdown(wait=True)
21382226

2139-
def reset_for_new_run(self):
2227+
async def reset_for_new_run(self):
21402228
"""Reset the manager state for a new run."""
21412229
self._shutdown = False
21422230
# Clear any existing queues - they'll be recreated when needed
@@ -2145,18 +2233,28 @@ def reset_for_new_run(self):
21452233
# This ensures no items from previous runs remain
21462234
try:
21472235
while not self.activity_queue.empty():
2148-
self.activity_queue.get_nowait()
2149-
except Exception:
2150-
pass
2236+
func, cancellation_func, args, kwargs = self.activity_queue.get_nowait()
2237+
await self._run_func(cancellation_func, *args, **kwargs)
2238+
except Exception as reset_exception:
2239+
self._logger.warning(f"Error while clearing activity queue during reset: {reset_exception}")
21512240
if self.orchestration_queue is not None:
21522241
try:
21532242
while not self.orchestration_queue.empty():
2154-
self.orchestration_queue.get_nowait()
2155-
except Exception:
2156-
pass
2243+
func, cancellation_func, args, kwargs = self.orchestration_queue.get_nowait()
2244+
await self._run_func(cancellation_func, *args, **kwargs)
2245+
except Exception as reset_exception:
2246+
self._logger.warning(f"Error while clearing orchestration queue during reset: {reset_exception}")
2247+
if self.entity_batch_queue is not None:
2248+
try:
2249+
while not self.entity_batch_queue.empty():
2250+
func, cancellation_func, args, kwargs = self.entity_batch_queue.get_nowait()
2251+
await self._run_func(cancellation_func, *args, **kwargs)
2252+
except Exception as reset_exception:
2253+
self._logger.warning(f"Error while clearing entity queue during reset: {reset_exception}")
21572254
# Clear pending work lists
21582255
self._pending_activity_work.clear()
21592256
self._pending_orchestration_work.clear()
2257+
self._pending_entity_batch_work.clear()
21602258

21612259

21622260
# Export public API

tests/durabletask/test_worker_concurrency_loop.py

Lines changed: 17 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -52,13 +52,21 @@ def dummy_orchestrator(req, stub, completionToken):
5252
time.sleep(0.1)
5353
stub.CompleteOrchestratorTask('ok')
5454

55+
def cancel_dummy_orchestrator(req, stub, completionToken):
56+
pass
57+
5558
def dummy_activity(req, stub, completionToken):
5659
time.sleep(0.1)
5760
stub.CompleteActivityTask('ok')
5861

62+
def cancel_dummy_activity(req, stub, completionToken):
63+
pass
64+
5965
# Patch the worker's _execute_orchestrator and _execute_activity
6066
worker._execute_orchestrator = dummy_orchestrator
67+
worker._cancel_orchestrator = cancel_dummy_orchestrator
6168
worker._execute_activity = dummy_activity
69+
worker._cancel_activity = cancel_dummy_activity
6270

6371
orchestrator_requests = [DummyRequest('orchestrator', f'orch{i}') for i in range(3)]
6472
activity_requests = [DummyRequest('activity', f'act{i}') for i in range(4)]
@@ -67,9 +75,9 @@ async def run_test():
6775
# Start the worker manager's run loop in the background
6876
worker_task = asyncio.create_task(worker._async_worker_manager.run())
6977
for req in orchestrator_requests:
70-
worker._async_worker_manager.submit_orchestration(dummy_orchestrator, req, stub, DummyCompletionToken())
78+
worker._async_worker_manager.submit_orchestration(dummy_orchestrator, cancel_dummy_orchestrator, req, stub, DummyCompletionToken())
7179
for req in activity_requests:
72-
worker._async_worker_manager.submit_activity(dummy_activity, req, stub, DummyCompletionToken())
80+
worker._async_worker_manager.submit_activity(dummy_activity, cancel_dummy_activity, req, stub, DummyCompletionToken())
7381
await asyncio.sleep(1.0)
7482
orchestrator_count = sum(1 for t, _ in stub.completed if t == 'orchestrator')
7583
activity_count = sum(1 for t, _ in stub.completed if t == 'activity')
@@ -120,8 +128,8 @@ def fn(*args, **kwargs):
120128

121129
# Submit more work than concurrency allows
122130
for i in range(5):
123-
manager.submit_orchestration(make_work("orch", i))
124-
manager.submit_activity(make_work("act", i))
131+
manager.submit_orchestration(make_work("orch", i), lambda *a, **k: None)
132+
manager.submit_activity(make_work("act", i), lambda *a, **k: None)
125133

126134
# Run the manager loop in a thread (sync context)
127135
def run_manager():
@@ -131,6 +139,11 @@ def run_manager():
131139
t.start()
132140
time.sleep(1.5) # Let work process
133141
manager.shutdown()
142+
143+
# Ensure the queues have been started
144+
if (manager.activity_queue is None or manager.orchestration_queue is None):
145+
raise RuntimeError("Worker manager queues not initialized")
146+
134147
# Unblock the consumers by putting dummy items in the queues
135148
manager.activity_queue.put_nowait((lambda: None, (), {}))
136149
manager.orchestration_queue.put_nowait((lambda: None, (), {}))

tests/durabletask/test_worker_concurrency_loop_async.py

Lines changed: 17 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -50,13 +50,21 @@ async def dummy_orchestrator(req, stub, completionToken):
5050
await asyncio.sleep(0.1)
5151
stub.CompleteOrchestratorTask('ok')
5252

53+
async def cancel_dummy_orchestrator(req, stub, completionToken):
54+
pass
55+
5356
async def dummy_activity(req, stub, completionToken):
5457
await asyncio.sleep(0.1)
5558
stub.CompleteActivityTask('ok')
5659

60+
async def cancel_dummy_activity(req, stub, completionToken):
61+
pass
62+
5763
# Patch the worker's _execute_orchestrator and _execute_activity
58-
grpc_worker._execute_orchestrator = dummy_orchestrator
59-
grpc_worker._execute_activity = dummy_activity
64+
grpc_worker._execute_orchestrator = dummy_orchestrator.__get__(grpc_worker, TaskHubGrpcWorker)
65+
grpc_worker._cancel_orchestrator = cancel_dummy_orchestrator.__get__(grpc_worker, TaskHubGrpcWorker)
66+
grpc_worker._execute_activity = dummy_activity.__get__(grpc_worker, TaskHubGrpcWorker)
67+
grpc_worker._cancel_activity = cancel_dummy_activity.__get__(grpc_worker, TaskHubGrpcWorker)
6068

6169
orchestrator_requests = [DummyRequest('orchestrator', f'orch{i}') for i in range(3)]
6270
activity_requests = [DummyRequest('activity', f'act{i}') for i in range(4)]
@@ -65,10 +73,15 @@ async def run_test():
6573
# Clear stub state before each run
6674
stub.completed.clear()
6775
worker_task = asyncio.create_task(grpc_worker._async_worker_manager.run())
76+
# Need to yield to that thread in order to let it start up on the second run
77+
startup_attempts = 0
78+
while grpc_worker._async_worker_manager._shutdown and startup_attempts < 10:
79+
await asyncio.sleep(0.1)
80+
startup_attempts += 1
6881
for req in orchestrator_requests:
69-
grpc_worker._async_worker_manager.submit_orchestration(dummy_orchestrator, req, stub, DummyCompletionToken())
82+
grpc_worker._async_worker_manager.submit_orchestration(dummy_orchestrator, cancel_dummy_orchestrator, req, stub, DummyCompletionToken())
7083
for req in activity_requests:
71-
grpc_worker._async_worker_manager.submit_activity(dummy_activity, req, stub, DummyCompletionToken())
84+
grpc_worker._async_worker_manager.submit_activity(dummy_activity, cancel_dummy_activity, req, stub, DummyCompletionToken())
7285
await asyncio.sleep(1.0)
7386
orchestrator_count = sum(1 for t, _ in stub.completed if t == 'orchestrator')
7487
activity_count = sum(1 for t, _ in stub.completed if t == 'activity')

0 commit comments

Comments
 (0)