Skip to content

Commit d016221

Browse files
fix running examples as tests
Signed-off-by: Achille Roussel <[email protected]>
1 parent 0918918 commit d016221

File tree

7 files changed

+83
-32
lines changed

7 files changed

+83
-32
lines changed

examples/auto_retry.py

+4-2
Original file line numberDiff line numberDiff line change
@@ -12,16 +12,18 @@ def third_party_api_call(x):
1212
# Simulate a third-party API call that fails.
1313
print(f"Simulating third-party API call with {x}")
1414
if x < 3:
15+
print("RAISE EXCEPTION")
1516
raise requests.RequestException("Simulated failure")
1617
else:
1718
return "SUCCESS"
1819

1920

2021
# Use the `dispatch.function` decorator to declare a stateful function.
2122
@dispatch.function
22-
def application():
23+
def auto_retry():
2324
x = rng.randint(0, 5)
2425
return third_party_api_call(x)
2526

2627

27-
dispatch.run(application())
28+
if __name__ == "__main__":
29+
print(dispatch.run(auto_retry()))

examples/fanout.py

+2-1
Original file line numberDiff line numberDiff line change
@@ -42,4 +42,5 @@ async def fanout():
4242
return await reduce_stargazers(repos)
4343

4444

45-
print(dispatch.run(fanout()))
45+
if __name__ == "__main__":
46+
print(dispatch.run(fanout()))

examples/getting_started.py

+7-5
Original file line numberDiff line numberDiff line change
@@ -3,15 +3,17 @@
33
import dispatch
44

55

6-
# Use the `dispatch.function` decorator declare a stateful function.
76
@dispatch.function
87
def publish(url, payload):
98
r = requests.post(url, data=payload)
109
r.raise_for_status()
1110
return r.text
1211

1312

14-
# Use the `dispatch.run` function to run the function with automatic error
15-
# handling and retries.
16-
res = dispatch.run(publish("https://httpstat.us/200", {"hello": "world"}))
17-
print(res)
13+
@dispatch.function
14+
async def getting_started():
15+
return await publish("https://httpstat.us/200", {"hello": "world"})
16+
17+
18+
if __name__ == "__main__":
19+
print(dispatch.run(getting_started()))

examples/github_stats.py

+4-20
Original file line numberDiff line numberDiff line change
@@ -1,19 +1,3 @@
1-
"""Github repository stats example.
2-
3-
This example demonstrates how to use async functions orchestrated by Dispatch.
4-
5-
Make sure to follow the setup instructions at
6-
https://docs.dispatch.run/dispatch/stateful-functions/getting-started/
7-
8-
Run with:
9-
10-
uvicorn app:app
11-
12-
13-
Logs will show a pipeline of functions being called and their results.
14-
15-
"""
16-
171
import httpx
182

193
import dispatch
@@ -31,21 +15,21 @@ def get_gh_api(url):
3115

3216

3317
@dispatch.function
34-
async def get_repo_info(repo_owner, repo_name):
18+
def get_repo_info(repo_owner, repo_name):
3519
url = f"https://api.github.com/repos/{repo_owner}/{repo_name}"
3620
repo_info = get_gh_api(url)
3721
return repo_info
3822

3923

4024
@dispatch.function
41-
async def get_contributors(repo_info):
25+
def get_contributors(repo_info):
4226
url = repo_info["contributors_url"]
4327
contributors = get_gh_api(url)
4428
return contributors
4529

4630

4731
@dispatch.function
48-
async def main():
32+
async def github_stats():
4933
repo_info = await get_repo_info("dispatchrun", "coroutine")
5034
print(
5135
f"""Repository: {repo_info['full_name']}
@@ -57,5 +41,5 @@ async def main():
5741

5842

5943
if __name__ == "__main__":
60-
contributors = dispatch.run(main())
44+
contributors = dispatch.run(github_stats())
6145
print(f"Contributors: {len(contributors)}")

examples/test_examples.py

+29
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,29 @@
1+
import dispatch.test
2+
3+
from .auto_retry import auto_retry
4+
from .fanout import fanout
5+
from .getting_started import getting_started
6+
from .github_stats import github_stats
7+
8+
9+
@dispatch.test.function
10+
async def test_auto_retry():
11+
assert await auto_retry() == "SUCCESS"
12+
13+
14+
@dispatch.test.function
15+
async def test_fanout():
16+
contributors = await fanout()
17+
assert len(contributors) >= 15
18+
assert "achille-roussel" in contributors
19+
20+
21+
@dispatch.test.function
22+
async def test_getting_started():
23+
assert await getting_started() == "200 OK"
24+
25+
26+
@dispatch.test.function
27+
async def test_github_stats():
28+
contributors = await github_stats()
29+
assert len(contributors) >= 6

src/dispatch/scheduler.py

+9-2
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,10 @@
3232
CoroutineID: TypeAlias = int
3333
CorrelationID: TypeAlias = int
3434

35-
_in_function_call = contextvars.ContextVar("dispatch.scheduler.in_function_call", default=False)
35+
_in_function_call = contextvars.ContextVar(
36+
"dispatch.scheduler.in_function_call", default=False
37+
)
38+
3639

3740
def in_function_call() -> bool:
3841
return bool(_in_function_call.get())
@@ -523,7 +526,11 @@ def make_coroutine(state: State, coroutine: Coroutine, pending_calls: List[Call]
523526
if isinstance(coroutine_yield, RaceDirective):
524527
return set_coroutine_race(state, coroutine, coroutine_yield.awaitables)
525528

526-
yield coroutine_yield
529+
try:
530+
yield coroutine_yield
531+
except Exception as e:
532+
coroutine_result = CoroutineResult(coroutine_id=coroutine.id, error=e)
533+
return set_coroutine_result(state, coroutine, coroutine_result)
527534

528535

529536
def set_coroutine_result(

src/dispatch/test/__init__.py

+28-2
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,17 @@
2828
from dispatch.sdk.v1.error_pb2 import Error
2929
from dispatch.sdk.v1.function_pb2 import RunRequest, RunResponse
3030
from dispatch.sdk.v1.poll_pb2 import PollResult
31-
from dispatch.sdk.v1.status_pb2 import STATUS_OK
31+
from dispatch.sdk.v1.status_pb2 import (
32+
STATUS_DNS_ERROR,
33+
STATUS_HTTP_ERROR,
34+
STATUS_INCOMPATIBLE_STATE,
35+
STATUS_OK,
36+
STATUS_TCP_ERROR,
37+
STATUS_TEMPORARY_ERROR,
38+
STATUS_THROTTLED,
39+
STATUS_TIMEOUT,
40+
STATUS_TLS_ERROR,
41+
)
3242

3343
from .client import EndpointClient
3444
from .server import DispatchServer
@@ -183,7 +193,18 @@ def make_request(call: Call) -> RunRequest:
183193
res = await self.run(call.endpoint, req)
184194

185195
if res.status != STATUS_OK:
186-
# TODO: emulate retries etc...
196+
if res.status in (
197+
STATUS_TIMEOUT,
198+
STATUS_THROTTLED,
199+
STATUS_TEMPORARY_ERROR,
200+
STATUS_INCOMPATIBLE_STATE,
201+
STATUS_DNS_ERROR,
202+
STATUS_TCP_ERROR,
203+
STATUS_TLS_ERROR,
204+
STATUS_HTTP_ERROR,
205+
):
206+
continue # emulate retries, without backoff for now
207+
187208
if (
188209
res.HasField("exit")
189210
and res.exit.HasField("result")
@@ -263,14 +284,19 @@ async def main(coro: Coroutine[Any, Any, None]) -> None:
263284
api = Service()
264285
app = Dispatch(reg)
265286
try:
287+
print("Starting bakend")
266288
async with Server(api) as backend:
289+
print("Starting server")
267290
async with Server(app) as server:
268291
# Here we break through the abstraction layers a bit, it's not
269292
# ideal but it works for now.
270293
reg.client.api_url.value = backend.url
271294
reg.endpoint = server.url
295+
print("BACKEND:", backend.url)
296+
print("SERVER:", server.url)
272297
await coro
273298
finally:
299+
print("DONE!")
274300
await api.close()
275301
# TODO: let's figure out how to get rid of this global registry
276302
# state at some point, which forces tests to be run sequentially.

0 commit comments

Comments
 (0)