14
14
import threading
15
15
import warnings
16
16
from abc import ABC , abstractmethod
17
+ from collections .abc import Iterator , Sequence
17
18
from contextlib import contextmanager
18
19
from dataclasses import dataclass , field
19
20
from datetime import datetime , timedelta , timezone
20
21
from typing import (
21
22
Any ,
22
23
Callable ,
23
- Dict ,
24
- Iterator ,
25
24
NoReturn ,
26
25
Optional ,
27
- Sequence ,
28
- Tuple ,
29
- Type ,
30
26
Union ,
31
27
)
32
28
33
29
import google .protobuf .duration_pb2
34
30
import google .protobuf .timestamp_pb2
35
31
36
32
import temporalio .activity
37
- import temporalio .api .common .v1
38
- import temporalio .bridge .client
39
- import temporalio .bridge .proto
40
- import temporalio .bridge .proto .activity_result
41
- import temporalio .bridge .proto .activity_task
42
- import temporalio .bridge .proto .common
43
33
import temporalio .bridge .runtime
44
34
import temporalio .bridge .worker
45
35
import temporalio .client
@@ -76,7 +66,7 @@ def __init__(
76
66
self ._task_queue = task_queue
77
67
self ._activity_executor = activity_executor
78
68
self ._shared_state_manager = shared_state_manager
79
- self ._running_activities : Dict [bytes , _RunningActivity ] = {}
69
+ self ._running_activities : dict [bytes , _RunningActivity ] = {}
80
70
self ._data_converter = data_converter
81
71
self ._interceptors = interceptors
82
72
self ._metric_meter = metric_meter
@@ -90,7 +80,7 @@ def __init__(
90
80
self ._client = client
91
81
92
82
# Validate and build activity dict
93
- self ._activities : Dict [str , temporalio .activity ._Definition ] = {}
83
+ self ._activities : dict [str , temporalio .activity ._Definition ] = {}
94
84
self ._dynamic_activity : Optional [temporalio .activity ._Definition ] = None
95
85
for activity in activities :
96
86
# Get definition
@@ -178,7 +168,7 @@ async def raise_from_exception_queue() -> NoReturn:
178
168
self ._handle_cancel_activity_task (task .task_token , task .cancel )
179
169
else :
180
170
raise RuntimeError (f"Unrecognized activity task: { task } " )
181
- except temporalio .bridge .worker .PollShutdownError :
171
+ except temporalio .bridge .worker .PollShutdownError : # type: ignore[reportPrivateLocalImportUsage]
182
172
exception_task .cancel ()
183
173
return
184
174
except Exception as err :
@@ -195,12 +185,12 @@ async def drain_poll_queue(self) -> None:
195
185
try :
196
186
# Just take all tasks and say we can't handle them
197
187
task = await self ._bridge_worker ().poll_activity_task ()
198
- completion = temporalio .bridge .proto .ActivityTaskCompletion (
188
+ completion = temporalio .bridge .proto .ActivityTaskCompletion ( # type: ignore[reportAttributeAccessIssue]
199
189
task_token = task .task_token
200
190
)
201
191
completion .result .failed .failure .message = "Worker shutting down"
202
192
await self ._bridge_worker ().complete_activity_task (completion )
203
- except temporalio .bridge .worker .PollShutdownError :
193
+ except temporalio .bridge .worker .PollShutdownError : # type: ignore[reportPrivateLocalImportUsage]
204
194
return
205
195
206
196
# Only call this after run()/drain_poll_queue() have returned. This will not
@@ -214,7 +204,9 @@ async def wait_all_completed(self) -> None:
214
204
await asyncio .gather (* running_tasks , return_exceptions = False )
215
205
216
206
def _handle_cancel_activity_task (
217
- self , task_token : bytes , cancel : temporalio .bridge .proto .activity_task .Cancel
207
+ self ,
208
+ task_token : bytes ,
209
+ cancel : temporalio .bridge .proto .activity_task .Cancel , # type: ignore[reportAttributeAccessIssue]
218
210
) -> None :
219
211
"""Request cancellation of a running activity task."""
220
212
activity = self ._running_activities .get (task_token )
@@ -262,7 +254,9 @@ async def _heartbeat_async(
262
254
263
255
# Perform the heartbeat
264
256
try :
265
- heartbeat = temporalio .bridge .proto .ActivityHeartbeat (task_token = task_token )
257
+ heartbeat = temporalio .bridge .proto .ActivityHeartbeat ( # type: ignore[reportAttributeAccessIssue]
258
+ task_token = task_token
259
+ )
266
260
if details :
267
261
# Convert to core payloads
268
262
heartbeat .details .extend (await self ._data_converter .encode (details ))
@@ -284,7 +278,7 @@ async def _heartbeat_async(
284
278
async def _handle_start_activity_task (
285
279
self ,
286
280
task_token : bytes ,
287
- start : temporalio .bridge .proto .activity_task .Start ,
281
+ start : temporalio .bridge .proto .activity_task .Start , # type: ignore[reportAttributeAccessIssue]
288
282
running_activity : _RunningActivity ,
289
283
) -> None :
290
284
"""Handle a start activity task.
@@ -296,7 +290,7 @@ async def _handle_start_activity_task(
296
290
# We choose to surround interceptor creation and activity invocation in
297
291
# a try block so we can mark the workflow as failed on any error instead
298
292
# of having error handling in the interceptor
299
- completion = temporalio .bridge .proto .ActivityTaskCompletion (
293
+ completion = temporalio .bridge .proto .ActivityTaskCompletion ( # type: ignore[reportAttributeAccessIssue]
300
294
task_token = task_token
301
295
)
302
296
try :
@@ -413,7 +407,7 @@ async def _handle_start_activity_task(
413
407
414
408
async def _execute_activity (
415
409
self ,
416
- start : temporalio .bridge .proto .activity_task .Start ,
410
+ start : temporalio .bridge .proto .activity_task .Start , # type: ignore[reportAttributeAccessIssue]
417
411
running_activity : _RunningActivity ,
418
412
task_token : bytes ,
419
413
) -> Any :
@@ -649,14 +643,14 @@ class _ThreadExceptionRaiser:
649
643
def __init__ (self ) -> None :
650
644
self ._lock = threading .Lock ()
651
645
self ._thread_id : Optional [int ] = None
652
- self ._pending_exception : Optional [Type [Exception ]] = None
646
+ self ._pending_exception : Optional [type [Exception ]] = None
653
647
self ._shield_depth = 0
654
648
655
649
def set_thread_id (self , thread_id : int ) -> None :
656
650
with self ._lock :
657
651
self ._thread_id = thread_id
658
652
659
- def raise_in_thread (self , exc_type : Type [Exception ]) -> None :
653
+ def raise_in_thread (self , exc_type : type [Exception ]) -> None :
660
654
with self ._lock :
661
655
self ._pending_exception = exc_type
662
656
self ._raise_in_thread_if_pending_unlocked ()
@@ -812,7 +806,7 @@ def _execute_sync_activity(
812
806
cancelled_event : threading .Event ,
813
807
worker_shutdown_event : threading .Event ,
814
808
payload_converter_class_or_instance : Union [
815
- Type [temporalio .converter .PayloadConverter ],
809
+ type [temporalio .converter .PayloadConverter ],
816
810
temporalio .converter .PayloadConverter ,
817
811
],
818
812
runtime_metric_meter : Optional [temporalio .common .MetricMeter ],
@@ -824,13 +818,10 @@ def _execute_sync_activity(
824
818
thread_id = threading .current_thread ().ident
825
819
if thread_id is not None :
826
820
cancel_thread_raiser .set_thread_id (thread_id )
827
- heartbeat_fn : Callable [..., None ]
828
821
if isinstance (heartbeat , SharedHeartbeatSender ):
829
- # To make mypy happy
830
- heartbeat_sender = heartbeat
831
- heartbeat_fn = lambda * details : heartbeat_sender .send_heartbeat (
832
- info .task_token , * details
833
- )
822
+
823
+ def heartbeat_fn (* details : Any ) -> None :
824
+ heartbeat .send_heartbeat (info .task_token , * details )
834
825
else :
835
826
heartbeat_fn = heartbeat
836
827
temporalio .activity ._Context .set (
@@ -940,11 +931,11 @@ def __init__(
940
931
self ._mgr = mgr
941
932
self ._queue_poller_executor = queue_poller_executor
942
933
# 1000 in-flight heartbeats should be plenty
943
- self ._heartbeat_queue : queue .Queue [Tuple [bytes , Sequence [Any ]]] = mgr .Queue (
934
+ self ._heartbeat_queue : queue .Queue [tuple [bytes , Sequence [Any ]]] = mgr .Queue (
944
935
1000
945
936
)
946
- self ._heartbeats : Dict [bytes , Callable [..., None ]] = {}
947
- self ._heartbeat_completions : Dict [bytes , Callable ] = {}
937
+ self ._heartbeats : dict [bytes , Callable [..., None ]] = {}
938
+ self ._heartbeat_completions : dict [bytes , Callable ] = {}
948
939
949
940
def new_event (self ) -> threading .Event :
950
941
return self ._mgr .Event ()
@@ -1002,7 +993,7 @@ def _heartbeat_processor(self) -> None:
1002
993
1003
994
class _MultiprocessingSharedHeartbeatSender (SharedHeartbeatSender ):
1004
995
def __init__ (
1005
- self , heartbeat_queue : queue .Queue [Tuple [bytes , Sequence [Any ]]]
996
+ self , heartbeat_queue : queue .Queue [tuple [bytes , Sequence [Any ]]]
1006
997
) -> None :
1007
998
super ().__init__ ()
1008
999
self ._heartbeat_queue = heartbeat_queue
0 commit comments