55import asyncio
66import concurrent .futures
77import json
8+ import threading
89from dataclasses import dataclass
910from typing import (
1011 Any ,
3233import temporalio .common
3334import temporalio .converter
3435import temporalio .nexus
35- from temporalio .exceptions import ApplicationError , WorkflowAlreadyStartedError
36+ from temporalio .exceptions import (
37+ ApplicationError ,
38+ WorkflowAlreadyStartedError ,
39+ )
3640from temporalio .nexus import Info , logger
3741from temporalio .service import RPCError , RPCStatusCode
3842
4145_TEMPORAL_FAILURE_PROTO_TYPE = "temporal.api.failure.v1.Failure"
4246
4347
48+ @dataclass
49+ class _RunningNexusTask :
50+ task : asyncio .Task [Any ]
51+ cancellation : _NexusTaskCancellation
52+
53+ def cancel (self , reason : str ):
54+ self .cancellation .cancel (reason )
55+ self .task .cancel ()
56+
57+
4458class _NexusWorker :
4559 def __init__ (
4660 self ,
@@ -65,7 +79,7 @@ def __init__(
6579 self ._interceptors = interceptors
6680 # TODO(nexus-preview): metric_meter
6781 self ._metric_meter = metric_meter
68- self ._running_tasks : dict [bytes , asyncio . Task [ Any ] ] = {}
82+ self ._running_tasks : dict [bytes , _RunningNexusTask ] = {}
6983 self ._fail_worker_exception_queue : asyncio .Queue [Exception ] = asyncio .Queue ()
7084
7185 async def run (self ) -> None :
@@ -90,21 +104,31 @@ async def raise_from_exception_queue() -> NoReturn:
90104 if nexus_task .HasField ("task" ):
91105 task = nexus_task .task
92106 if task .request .HasField ("start_operation" ):
93- self ._running_tasks [task .task_token ] = asyncio .create_task (
107+ task_cancellation = _NexusTaskCancellation ()
108+ start_op_task = asyncio .create_task (
94109 self ._handle_start_operation_task (
95110 task .task_token ,
96111 task .request .start_operation ,
97112 dict (task .request .header ),
113+ task_cancellation ,
98114 )
99115 )
116+ self ._running_tasks [task .task_token ] = _RunningNexusTask (
117+ start_op_task , task_cancellation
118+ )
100119 elif task .request .HasField ("cancel_operation" ):
101- self ._running_tasks [task .task_token ] = asyncio .create_task (
120+ task_cancellation = _NexusTaskCancellation ()
121+ cancel_op_task = asyncio .create_task (
102122 self ._handle_cancel_operation_task (
103123 task .task_token ,
104124 task .request .cancel_operation ,
105125 dict (task .request .header ),
126+ task_cancellation ,
106127 )
107128 )
129+ self ._running_tasks [task .task_token ] = _RunningNexusTask (
130+ cancel_op_task , task_cancellation
131+ )
108132 else :
109133 raise NotImplementedError (
110134 f"Invalid Nexus task request: { task .request } "
@@ -113,8 +137,12 @@ async def raise_from_exception_queue() -> NoReturn:
113137 if running_task := self ._running_tasks .get (
114138 nexus_task .cancel_task .task_token
115139 ):
116- # TODO(nexus-prerelease): when do we remove the entry from _running_operations?
117- running_task .cancel ()
140+ reason = (
141+ temporalio .bridge .proto .nexus .NexusTaskCancelReason .Name (
142+ nexus_task .cancel_task .reason
143+ )
144+ )
145+ running_task .cancel (reason )
118146 else :
119147 logger .debug (
120148 f"Received cancel_task but no running task exists for "
@@ -147,7 +175,10 @@ async def drain_poll_queue(self) -> None:
147175 # Only call this after run()/drain_poll_queue() have returned. This will not
148176 # raise an exception.
149177 async def wait_all_completed (self ) -> None :
150- await asyncio .gather (* self ._running_tasks .values (), return_exceptions = True )
178+ running_tasks = [
179+ running_task .task for running_task in self ._running_tasks .values ()
180+ ]
181+ await asyncio .gather (* running_tasks , return_exceptions = True )
151182
152183 # TODO(nexus-preview): stack trace pruning. See sdk-typescript NexusHandler.execute
153184 # "Any call up to this function and including this one will be trimmed out of stack traces.""
@@ -157,6 +188,7 @@ async def _handle_cancel_operation_task(
157188 task_token : bytes ,
158189 request : temporalio .api .nexus .v1 .CancelOperationRequest ,
159190 headers : Mapping [str , str ],
191+ task_cancellation : nexusrpc .handler .OperationTaskCancellation ,
160192 ) -> None :
161193 """Handle a cancel operation task.
162194
@@ -168,6 +200,7 @@ async def _handle_cancel_operation_task(
168200 service = request .service ,
169201 operation = request .operation ,
170202 headers = headers ,
203+ task_cancellation = task_cancellation ,
171204 )
172205 temporalio .nexus ._operation_context ._TemporalCancelOperationContext (
173206 info = lambda : Info (task_queue = self ._task_queue ),
@@ -177,6 +210,11 @@ async def _handle_cancel_operation_task(
177210 try :
178211 try :
179212 await self ._handler .cancel_operation (ctx , request .operation_token )
213+ except asyncio .CancelledError :
214+ completion = temporalio .bridge .proto .nexus .NexusTaskCompletion (
215+ task_token = task_token ,
216+ ack_cancel = task_cancellation .is_cancelled (),
217+ )
180218 except BaseException as err :
181219 logger .warning ("Failed to execute Nexus cancel operation method" )
182220 completion = temporalio .bridge .proto .nexus .NexusTaskCompletion (
@@ -209,6 +247,7 @@ async def _handle_start_operation_task(
209247 task_token : bytes ,
210248 start_request : temporalio .api .nexus .v1 .StartOperationRequest ,
211249 headers : Mapping [str , str ],
250+ task_cancellation : nexusrpc .handler .OperationTaskCancellation ,
212251 ) -> None :
213252 """Handle a start operation task.
214253
@@ -217,7 +256,14 @@ async def _handle_start_operation_task(
217256 """
218257 try :
219258 try :
220- start_response = await self ._start_operation (start_request , headers )
259+ start_response = await self ._start_operation (
260+ start_request , headers , task_cancellation
261+ )
262+ except asyncio .CancelledError :
263+ completion = temporalio .bridge .proto .nexus .NexusTaskCompletion (
264+ task_token = task_token ,
265+ ack_cancel = task_cancellation .is_cancelled (),
266+ )
221267 except BaseException as err :
222268 logger .warning ("Failed to execute Nexus start operation method" )
223269 completion = temporalio .bridge .proto .nexus .NexusTaskCompletion (
@@ -226,6 +272,7 @@ async def _handle_start_operation_task(
226272 _exception_to_handler_error (err )
227273 ),
228274 )
275+
229276 if isinstance (err , concurrent .futures .BrokenExecutor ):
230277 self ._fail_worker_exception_queue .put_nowait (err )
231278 else :
@@ -235,6 +282,7 @@ async def _handle_start_operation_task(
235282 start_operation = start_response
236283 ),
237284 )
285+
238286 await self ._bridge_worker ().complete_nexus_task (completion )
239287 except Exception :
240288 logger .exception ("Failed to send Nexus task completion" )
@@ -250,6 +298,7 @@ async def _start_operation(
250298 self ,
251299 start_request : temporalio .api .nexus .v1 .StartOperationRequest ,
252300 headers : Mapping [str , str ],
301+ cancellation : nexusrpc .handler .OperationTaskCancellation ,
253302 ) -> temporalio .api .nexus .v1 .StartOperationResponse :
254303 """Invoke the Nexus handler's start_operation method and construct the StartOperationResponse.
255304
@@ -268,6 +317,7 @@ async def _start_operation(
268317 for link in start_request .links
269318 ],
270319 callback_headers = dict (start_request .callback_header ),
320+ task_cancellation = cancellation ,
271321 )
272322 temporalio .nexus ._operation_context ._TemporalStartOperationContext (
273323 nexus_context = ctx ,
@@ -517,3 +567,33 @@ def _exception_to_handler_error(err: BaseException) -> nexusrpc.HandlerError:
517567 )
518568 handler_err .__cause__ = err
519569 return handler_err
570+
571+
572+ class _NexusTaskCancellation (nexusrpc .handler .OperationTaskCancellation ):
573+ def __init__ (self ):
574+ self ._thread_evt = threading .Event ()
575+ self ._async_evt = asyncio .Event ()
576+ self ._lock = threading .Lock ()
577+ self ._reason : Optional [str ] = None
578+
579+ def is_cancelled (self ) -> bool :
580+ return self ._thread_evt .is_set ()
581+
582+ def cancellation_reason (self ) -> Optional [str ]:
583+ with self ._lock :
584+ return self ._reason
585+
586+ def wait_until_cancelled_sync (self , timeout : float | None = None ) -> bool :
587+ return self ._thread_evt .wait (timeout )
588+
589+ async def wait_until_cancelled (self ) -> None :
590+ await self ._async_evt .wait ()
591+
592+ def cancel (self , reason : str ) -> bool :
593+ with self ._lock :
594+ if self ._thread_evt .is_set ():
595+ return False
596+ self ._reason = reason
597+ self ._thread_evt .set ()
598+ self ._async_evt .set ()
599+ return True
0 commit comments