-
Notifications
You must be signed in to change notification settings - Fork 85
/
Copy pathtest_activity.py
179 lines (146 loc) · 5.1 KB
/
test_activity.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
import asyncio
import sys
import threading
import time
from contextvars import copy_context
from temporalio import activity
from temporalio.exceptions import CancelledError
from temporalio.testing import ActivityEnvironment
async def test_activity_env_async():
    """Run an async activity in ActivityEnvironment and cancel it mid-flight.

    Checks that heartbeats reach ``env.on_heartbeat`` both from the activity
    body and from a task the activity spawns with ``asyncio.create_task``, and
    that ``env.cancel()`` surfaces inside the activity as
    ``asyncio.CancelledError``, after which the activity can still heartbeat
    and return normally.
    """
    started = asyncio.Event()
    recorded = []

    async def heartbeat_from_task():
        # Ensure heartbeat/info still work inside a create_task'd coroutine.
        activity.heartbeat(f"task, type: {activity.info().activity_type}")

    async def do_stuff(param: str) -> str:
        activity.heartbeat(f"param: {param}")
        await asyncio.create_task(heartbeat_from_task())
        try:
            started.set()
            # Never resolves; only cancellation can get us past this await.
            await asyncio.Future()
            raise RuntimeError("Unreachable")
        except asyncio.CancelledError:
            activity.heartbeat("cancelled")
        return "done"

    env = ActivityEnvironment()
    # Keep only the first heartbeat detail of each call.
    env.on_heartbeat = lambda *details: recorded.append(details[0])

    running = asyncio.create_task(env.run(do_stuff, "param1"))
    await started.wait()
    env.cancel()

    assert await running == "done"
    assert recorded == ["param: param1", "task, type: unknown", "cancelled"]
def test_activity_env_sync():
    """Run a sync (threaded) activity in ActivityEnvironment and cancel it.

    Checks that heartbeats work from the activity thread and from a child
    thread running in a copied context. Also checks that no exception escapes
    into a block guarded by ``activity.shield_thread_cancel_exception()``,
    while cancellation is still delivered as
    ``temporalio.exceptions.CancelledError`` outside that block.
    """
    waiting = threading.Event()
    properly_cancelled = False

    def do_stuff(param: str) -> None:
        nonlocal properly_cancelled
        activity.heartbeat(f"param: {param}")

        # Ensure heartbeating works from another thread that runs inside a
        # copy of this thread's context.
        context = copy_context()

        def via_thread():
            activity.heartbeat(f"task, type: {activity.info().activity_type}")

        child = threading.Thread(target=context.run, args=[via_thread])
        child.start()
        child.join()

        # Tell the test body we are about to wait for cancellation.
        waiting.set()
        try:
            # Confirm shielding works: nothing, not even the cancellation,
            # may be raised inside this shielded block.
            with activity.shield_thread_cancel_exception():
                try:
                    while not activity.is_cancelled():
                        time.sleep(0.2)
                    # Keep sleeping after cancellation is observed; the
                    # shield should still hold the exception back here.
                    time.sleep(0.2)
                except BaseException as err:
                    # Explicit BaseException (instead of a bare except) makes
                    # "catch anything" intentional; chaining keeps the leaked
                    # exception in the traceback.
                    raise RuntimeError("Unexpected") from err
        except CancelledError:
            properly_cancelled = True

    env = ActivityEnvironment()
    # Set heartbeat handler to add to list
    heartbeats = []
    env.on_heartbeat = lambda *args: heartbeats.append(args[0])
    # Start the activity thread. It is a daemon so that a hung activity cannot
    # keep the interpreter alive after this test has already failed.
    thread = threading.Thread(
        target=env.run, args=[do_stuff, "param1"], daemon=True
    )
    thread.start()
    # Bounded waits: fail the test rather than hang pytest forever.
    assert waiting.wait(timeout=10), "activity never reached the wait point"
    # Give the activity time to enter the shielded block before cancelling.
    time.sleep(1)
    env.cancel()
    thread.join(timeout=10)
    assert not thread.is_alive(), "activity did not finish after cancel"
    assert heartbeats == ["param: param1", "task, type: unknown"]
    assert properly_cancelled
async def test_activity_env_assert():
    """An error raised by an activity keeps its type and message through env.run.

    The same failing assertion is evaluated once directly and once inside
    ``ActivityEnvironment.run``. Both raised exceptions must have the same
    exact type and the same ``str()``.
    """

    async def assert_equals(a: str, b: str) -> None:
        assert a == b

    # Get out-of-env expected err. The "did not raise" failure lives in the
    # ``else`` clause: an ``assert False`` inside the ``try`` would itself be
    # caught by ``except Exception`` and silently recorded as the error.
    try:
        await assert_equals("foo", "bar")
    except Exception as err:
        expected_err = err
    else:
        raise AssertionError("assert_equals('foo', 'bar') did not raise")

    # Get in-env actual err (same else-clause pattern as above).
    try:
        await ActivityEnvironment().run(assert_equals, "foo", "bar")
    except Exception as err:
        actual_err = err
    else:
        raise AssertionError(
            "ActivityEnvironment.run did not propagate the activity error"
        )

    assert type(expected_err) is type(actual_err)
    assert str(expected_err) == str(actual_err)
def test_get_activities_from_cls():
    """get_activities_from_cls returns the class- and static-method activities."""

    class ClassAndStaticActivities(activity.ActivitiesProvider):
        @classmethod
        @activity.defn
        async def class_method_activity(cls):
            """Class-method activity fixture."""

        @staticmethod
        @activity.defn
        async def static_method_activity():
            """Static-method activity fixture."""

    found = ClassAndStaticActivities.get_activities_from_cls()
    expected = [
        ClassAndStaticActivities.class_method_activity,
        ClassAndStaticActivities.static_method_activity,
    ]
    assert found == expected
class _AllActivityMethodTypes(activity.ActivitiesProvider):
    """Fixture with one activity of each method kind: instance, class, static.

    Because it has an instance-method activity, the tests below expect
    activities to be collectable only from an instance, not from the class.
    """

    @activity.defn
    async def instance_method_activity(self):
        """Instance-method activity fixture."""

    @classmethod
    @activity.defn
    async def class_method_activity(cls):
        """Class-method activity fixture."""

    @staticmethod
    @activity.defn
    async def static_method_activity():
        """Static-method activity fixture."""
def test_get_activities_from_cls_error():
    """get_activities_from_cls raises ValueError for an instance-method activity."""
    # Plain string literals: the original f-prefix had no placeholders (F541).
    expected_message = (
        "Class _AllActivityMethodTypes method instance_method_activity "
        "is an activity, but it is an instance method. "
        "Because of that, you cannot gather activities from the class, "
        "you must get them from an instance using "
        "instance.get_activities_from_instance()"
    )
    try:
        _AllActivityMethodTypes.get_activities_from_cls()
    except ValueError as ex:
        assert str(ex) == expected_message
    else:
        # In ``else`` the failure can never be mistaken for the expected error.
        raise AssertionError("above call should have thrown value error")
def test_get_activities_from_instance():
    """An instance exposes all three activity kinds via get_activities_from_instance."""
    provider = _AllActivityMethodTypes()
    # Bound attributes, in the order the provider is expected to return them.
    expected_names = (
        "class_method_activity",
        "instance_method_activity",
        "static_method_activity",
    )
    expected = [getattr(provider, name) for name in expected_names]
    assert provider.get_activities_from_instance() == expected
@activity.defn
def _some_activity():
    """Module-level sync activity fixture, expected by test_get_activities."""
@activity.defn
async def _some_async_activity():
    """Module-level async activity fixture, expected by test_get_activities."""
def test_get_activities():
    """get_activities finds exactly the module-level activity functions here."""
    this_module = sys.modules[__name__]
    found = activity.get_activities(this_module)
    assert found == [_some_activity, _some_async_activity]