Skip to content

Commit 5d1630d

Browse files
Nexus task cancellation (#1204)
* WIP: update to latest nexus patterns and try out task cancellation
* fix up failing nexus test. Use deterministic as_completed in research_manager. Finish implementing nexus task cancellation
* Apply formatting. Fix linter typing error. Remove some WIP elements that weren't relevant
* Fix test to properly reference new cancellation details string
* use threading.Event and logs to make sure test covers cancellation reason accuracy
* remove unused field in _NexusTaskCancellation
* fix typo. require cancel reason.
* Update nexus-rpc dependency to 1.2.0
1 parent ceb7058 commit 5d1630d

File tree

13 files changed

+224
-182
lines changed

13 files changed

+224
-182
lines changed

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@ license = "MIT"
99
license-files = ["LICENSE"]
1010
keywords = ["temporal", "workflow"]
1111
dependencies = [
12-
"nexus-rpc==1.1.0",
12+
"nexus-rpc==1.2.0",
1313
"protobuf>=3.20,<7.0.0",
1414
"python-dateutil>=2.8.2,<3 ; python_version < '3.11'",
1515
"types-protobuf>=3.20",

temporalio/nexus/_decorators.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -117,7 +117,7 @@ async def _start(
117117
return WorkflowRunOperationHandler(_start)
118118

119119
method_name = get_callable_name(start)
120-
nexusrpc.set_operation_definition(
120+
nexusrpc.set_operation(
121121
operation_handler_factory,
122122
nexusrpc.Operation(
123123
name=name or method_name,

temporalio/nexus/_operation_handlers.py

Lines changed: 0 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -10,13 +10,10 @@
1010
HandlerError,
1111
HandlerErrorType,
1212
InputT,
13-
OperationInfo,
1413
OutputT,
1514
)
1615
from nexusrpc.handler import (
1716
CancelOperationContext,
18-
FetchOperationInfoContext,
19-
FetchOperationResultContext,
2017
OperationHandler,
2118
StartOperationContext,
2219
StartOperationResultAsync,
@@ -81,22 +78,6 @@ async def cancel(self, ctx: CancelOperationContext, token: str) -> None:
8178
"""Cancel the operation, by cancelling the workflow."""
8279
await _cancel_workflow(token)
8380

84-
async def fetch_info(
85-
self, ctx: FetchOperationInfoContext, token: str
86-
) -> OperationInfo:
87-
"""Fetch operation info (not supported for Temporal Nexus operations)."""
88-
raise NotImplementedError(
89-
"Temporal Nexus operation handlers do not support fetching operation info."
90-
)
91-
92-
async def fetch_result(
93-
self, ctx: FetchOperationResultContext, token: str
94-
) -> OutputT:
95-
"""Fetch operation result (not supported for Temporal Nexus operations)."""
96-
raise NotImplementedError(
97-
"Temporal Nexus operation handlers do not support fetching the operation result."
98-
)
99-
10081

10182
async def _cancel_workflow(
10283
token: str,

temporalio/nexus/_util.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -129,12 +129,12 @@ def get_operation_factory(
129129
130130
``obj`` should be a decorated operation start method.
131131
"""
132-
op_defn = nexusrpc.get_operation_definition(obj)
132+
op_defn = nexusrpc.get_operation(obj)
133133
if op_defn:
134134
factory = obj
135135
else:
136136
if factory := getattr(obj, "__nexus_operation_factory__", None):
137-
op_defn = nexusrpc.get_operation_definition(factory)
137+
op_defn = nexusrpc.get_operation(factory)
138138
if not isinstance(op_defn, nexusrpc.Operation):
139139
return None, None
140140
return factory, op_defn

temporalio/worker/_nexus.py

Lines changed: 88 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
import asyncio
66
import concurrent.futures
77
import json
8+
import threading
89
from dataclasses import dataclass
910
from typing import (
1011
Any,
@@ -32,7 +33,10 @@
3233
import temporalio.common
3334
import temporalio.converter
3435
import temporalio.nexus
35-
from temporalio.exceptions import ApplicationError, WorkflowAlreadyStartedError
36+
from temporalio.exceptions import (
37+
ApplicationError,
38+
WorkflowAlreadyStartedError,
39+
)
3640
from temporalio.nexus import Info, logger
3741
from temporalio.service import RPCError, RPCStatusCode
3842

@@ -41,6 +45,16 @@
4145
_TEMPORAL_FAILURE_PROTO_TYPE = "temporal.api.failure.v1.Failure"
4246

4347

48+
@dataclass
49+
class _RunningNexusTask:
50+
task: asyncio.Task[Any]
51+
cancellation: _NexusTaskCancellation
52+
53+
def cancel(self, reason: str):
54+
self.cancellation.cancel(reason)
55+
self.task.cancel()
56+
57+
4458
class _NexusWorker:
4559
def __init__(
4660
self,
@@ -65,7 +79,7 @@ def __init__(
6579
self._interceptors = interceptors
6680
# TODO(nexus-preview): metric_meter
6781
self._metric_meter = metric_meter
68-
self._running_tasks: dict[bytes, asyncio.Task[Any]] = {}
82+
self._running_tasks: dict[bytes, _RunningNexusTask] = {}
6983
self._fail_worker_exception_queue: asyncio.Queue[Exception] = asyncio.Queue()
7084

7185
async def run(self) -> None:
@@ -90,21 +104,31 @@ async def raise_from_exception_queue() -> NoReturn:
90104
if nexus_task.HasField("task"):
91105
task = nexus_task.task
92106
if task.request.HasField("start_operation"):
93-
self._running_tasks[task.task_token] = asyncio.create_task(
107+
task_cancellation = _NexusTaskCancellation()
108+
start_op_task = asyncio.create_task(
94109
self._handle_start_operation_task(
95110
task.task_token,
96111
task.request.start_operation,
97112
dict(task.request.header),
113+
task_cancellation,
98114
)
99115
)
116+
self._running_tasks[task.task_token] = _RunningNexusTask(
117+
start_op_task, task_cancellation
118+
)
100119
elif task.request.HasField("cancel_operation"):
101-
self._running_tasks[task.task_token] = asyncio.create_task(
120+
task_cancellation = _NexusTaskCancellation()
121+
cancel_op_task = asyncio.create_task(
102122
self._handle_cancel_operation_task(
103123
task.task_token,
104124
task.request.cancel_operation,
105125
dict(task.request.header),
126+
task_cancellation,
106127
)
107128
)
129+
self._running_tasks[task.task_token] = _RunningNexusTask(
130+
cancel_op_task, task_cancellation
131+
)
108132
else:
109133
raise NotImplementedError(
110134
f"Invalid Nexus task request: {task.request}"
@@ -113,8 +137,12 @@ async def raise_from_exception_queue() -> NoReturn:
113137
if running_task := self._running_tasks.get(
114138
nexus_task.cancel_task.task_token
115139
):
116-
# TODO(nexus-prerelease): when do we remove the entry from _running_operations?
117-
running_task.cancel()
140+
reason = (
141+
temporalio.bridge.proto.nexus.NexusTaskCancelReason.Name(
142+
nexus_task.cancel_task.reason
143+
)
144+
)
145+
running_task.cancel(reason)
118146
else:
119147
logger.debug(
120148
f"Received cancel_task but no running task exists for "
@@ -147,7 +175,10 @@ async def drain_poll_queue(self) -> None:
147175
# Only call this after run()/drain_poll_queue() have returned. This will not
148176
# raise an exception.
149177
async def wait_all_completed(self) -> None:
150-
await asyncio.gather(*self._running_tasks.values(), return_exceptions=True)
178+
running_tasks = [
179+
running_task.task for running_task in self._running_tasks.values()
180+
]
181+
await asyncio.gather(*running_tasks, return_exceptions=True)
151182

152183
# TODO(nexus-preview): stack trace pruning. See sdk-typescript NexusHandler.execute
153184
# "Any call up to this function and including this one will be trimmed out of stack traces."
@@ -157,6 +188,7 @@ async def _handle_cancel_operation_task(
157188
task_token: bytes,
158189
request: temporalio.api.nexus.v1.CancelOperationRequest,
159190
headers: Mapping[str, str],
191+
task_cancellation: nexusrpc.handler.OperationTaskCancellation,
160192
) -> None:
161193
"""Handle a cancel operation task.
162194
@@ -168,6 +200,7 @@ async def _handle_cancel_operation_task(
168200
service=request.service,
169201
operation=request.operation,
170202
headers=headers,
203+
task_cancellation=task_cancellation,
171204
)
172205
temporalio.nexus._operation_context._TemporalCancelOperationContext(
173206
info=lambda: Info(task_queue=self._task_queue),
@@ -177,6 +210,11 @@ async def _handle_cancel_operation_task(
177210
try:
178211
try:
179212
await self._handler.cancel_operation(ctx, request.operation_token)
213+
except asyncio.CancelledError:
214+
completion = temporalio.bridge.proto.nexus.NexusTaskCompletion(
215+
task_token=task_token,
216+
ack_cancel=task_cancellation.is_cancelled(),
217+
)
180218
except BaseException as err:
181219
logger.warning("Failed to execute Nexus cancel operation method")
182220
completion = temporalio.bridge.proto.nexus.NexusTaskCompletion(
@@ -209,6 +247,7 @@ async def _handle_start_operation_task(
209247
task_token: bytes,
210248
start_request: temporalio.api.nexus.v1.StartOperationRequest,
211249
headers: Mapping[str, str],
250+
task_cancellation: nexusrpc.handler.OperationTaskCancellation,
212251
) -> None:
213252
"""Handle a start operation task.
214253
@@ -217,7 +256,14 @@ async def _handle_start_operation_task(
217256
"""
218257
try:
219258
try:
220-
start_response = await self._start_operation(start_request, headers)
259+
start_response = await self._start_operation(
260+
start_request, headers, task_cancellation
261+
)
262+
except asyncio.CancelledError:
263+
completion = temporalio.bridge.proto.nexus.NexusTaskCompletion(
264+
task_token=task_token,
265+
ack_cancel=task_cancellation.is_cancelled(),
266+
)
221267
except BaseException as err:
222268
logger.warning("Failed to execute Nexus start operation method")
223269
completion = temporalio.bridge.proto.nexus.NexusTaskCompletion(
@@ -226,6 +272,7 @@ async def _handle_start_operation_task(
226272
_exception_to_handler_error(err)
227273
),
228274
)
275+
229276
if isinstance(err, concurrent.futures.BrokenExecutor):
230277
self._fail_worker_exception_queue.put_nowait(err)
231278
else:
@@ -235,6 +282,7 @@ async def _handle_start_operation_task(
235282
start_operation=start_response
236283
),
237284
)
285+
238286
await self._bridge_worker().complete_nexus_task(completion)
239287
except Exception:
240288
logger.exception("Failed to send Nexus task completion")
@@ -250,6 +298,7 @@ async def _start_operation(
250298
self,
251299
start_request: temporalio.api.nexus.v1.StartOperationRequest,
252300
headers: Mapping[str, str],
301+
cancellation: nexusrpc.handler.OperationTaskCancellation,
253302
) -> temporalio.api.nexus.v1.StartOperationResponse:
254303
"""Invoke the Nexus handler's start_operation method and construct the StartOperationResponse.
255304
@@ -268,6 +317,7 @@ async def _start_operation(
268317
for link in start_request.links
269318
],
270319
callback_headers=dict(start_request.callback_header),
320+
task_cancellation=cancellation,
271321
)
272322
temporalio.nexus._operation_context._TemporalStartOperationContext(
273323
nexus_context=ctx,
@@ -517,3 +567,33 @@ def _exception_to_handler_error(err: BaseException) -> nexusrpc.HandlerError:
517567
)
518568
handler_err.__cause__ = err
519569
return handler_err
570+
571+
572+
class _NexusTaskCancellation(nexusrpc.handler.OperationTaskCancellation):
573+
def __init__(self):
574+
self._thread_evt = threading.Event()
575+
self._async_evt = asyncio.Event()
576+
self._lock = threading.Lock()
577+
self._reason: Optional[str] = None
578+
579+
def is_cancelled(self) -> bool:
580+
return self._thread_evt.is_set()
581+
582+
def cancellation_reason(self) -> Optional[str]:
583+
with self._lock:
584+
return self._reason
585+
586+
def wait_until_cancelled_sync(self, timeout: float | None = None) -> bool:
587+
return self._thread_evt.wait(timeout)
588+
589+
async def wait_until_cancelled(self) -> None:
590+
await self._async_evt.wait()
591+
592+
def cancel(self, reason: str) -> bool:
593+
with self._lock:
594+
if self._thread_evt.is_set():
595+
return False
596+
self._reason = reason
597+
self._thread_evt.set()
598+
self._async_evt.set()
599+
return True

temporalio/workflow.py

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5417,6 +5417,22 @@ async def start_operation(
54175417
headers: Optional[Mapping[str, str]] = None,
54185418
) -> NexusOperationHandle[OutputT]: ...
54195419

5420+
# Overload for operation_handler
5421+
@overload
5422+
@abstractmethod
5423+
async def start_operation(
5424+
self,
5425+
operation: Callable[
5426+
[ServiceHandlerT], nexusrpc.handler.OperationHandler[InputT, OutputT]
5427+
],
5428+
input: InputT,
5429+
*,
5430+
output_type: Optional[Type[OutputT]] = None,
5431+
schedule_to_close_timeout: Optional[timedelta] = None,
5432+
cancellation_type: NexusOperationCancellationType = NexusOperationCancellationType.WAIT_COMPLETED,
5433+
headers: Optional[Mapping[str, str]] = None,
5434+
) -> NexusOperationHandle[OutputT]: ...
5435+
54205436
@abstractmethod
54215437
async def start_operation(
54225438
self,
@@ -5527,6 +5543,23 @@ async def execute_operation(
55275543
headers: Optional[Mapping[str, str]] = None,
55285544
) -> OutputT: ...
55295545

5546+
# Overload for operation_handler
5547+
@overload
5548+
@abstractmethod
5549+
async def execute_operation(
5550+
self,
5551+
operation: Callable[
5552+
[ServiceT],
5553+
nexusrpc.handler.OperationHandler[InputT, OutputT],
5554+
],
5555+
input: InputT,
5556+
*,
5557+
output_type: Optional[Type[OutputT]] = None,
5558+
schedule_to_close_timeout: Optional[timedelta] = None,
5559+
cancellation_type: NexusOperationCancellationType = NexusOperationCancellationType.WAIT_COMPLETED,
5560+
headers: Optional[Mapping[str, str]] = None,
5561+
) -> OutputT: ...
5562+
55305563
@abstractmethod
55315564
async def execute_operation(
55325565
self,

tests/contrib/openai_agents/research_agents/research_manager.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44

55
from agents import Runner, custom_span, gen_trace_id, trace
66

7+
import temporalio.workflow
78
from tests.contrib.openai_agents.research_agents.planner_agent import (
89
WebSearchItem,
910
WebSearchPlan,
@@ -45,7 +46,7 @@ async def _perform_searches(self, search_plan: WebSearchPlan) -> list[str]:
4546
asyncio.create_task(self._search(item)) for item in search_plan.searches
4647
]
4748
results = []
48-
for task in asyncio.as_completed(tasks):
49+
for task in temporalio.workflow.as_completed(tasks):
4950
result = await task
5051
if result is not None:
5152
results.append(result)

tests/nexus/test_dynamic_creation_of_user_handler_classes.py

Lines changed: 2 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -44,20 +44,6 @@ async def cancel(
4444
) -> None:
4545
raise NotImplementedError
4646

47-
async def fetch_info(
48-
self,
49-
ctx: nexusrpc.handler.FetchOperationInfoContext,
50-
token: str,
51-
) -> nexusrpc.OperationInfo:
52-
raise NotImplementedError
53-
54-
async def fetch_result(
55-
self,
56-
ctx: nexusrpc.handler.FetchOperationResultContext,
57-
token: str,
58-
) -> int:
59-
raise NotImplementedError
60-
6147

6248
@nexusrpc.handler.service_handler
6349
class MyServiceHandlerWithWorkflowRunOperation:
@@ -78,8 +64,8 @@ async def test_run_nexus_service_from_programmatically_created_service_handler(
7864
service_handler = nexusrpc.handler._core.ServiceHandler(
7965
service=nexusrpc.ServiceDefinition(
8066
name="MyService",
81-
operations={
82-
"increment": nexusrpc.Operation[int, int](
67+
operation_definitions={
68+
"increment": nexusrpc.OperationDefinition[int, int](
8369
name="increment",
8470
method_name="increment",
8571
input_type=int,

0 commit comments

Comments
 (0)