Skip to content

Commit 0a4391d

Browse files
authored
Merge pull request #86 from stealthrocket/test-server
dispatch.test package
2 parents ca215f5 + 3790422 commit 0a4391d

File tree

13 files changed

+588
-364
lines changed

13 files changed

+588
-364
lines changed

examples/auto_retry/test_app.py

Lines changed: 27 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -8,10 +8,9 @@
88

99
from fastapi.testclient import TestClient
1010

11-
import dispatch.sdk.v1.status_pb2 as status_pb
12-
13-
from ... import function_service
14-
from ...test_client import ServerTest
11+
from dispatch import Client
12+
from dispatch.sdk.v1 import status_pb2 as status_pb
13+
from dispatch.test import DispatchServer, DispatchService, EndpointClient
1514

1615

1716
class TestAutoRetry(unittest.TestCase):
@@ -22,29 +21,33 @@ class TestAutoRetry(unittest.TestCase):
2221
"DISPATCH_API_KEY": "0000000000000000",
2322
},
2423
)
25-
def test_foo(self):
26-
from . import app
24+
def test_app(self):
25+
from .app import app, dispatch
26+
27+
# Setup a fake Dispatch server.
28+
endpoint_client = EndpointClient.from_app(app)
29+
dispatch_service = DispatchService(endpoint_client, collect_roundtrips=True)
30+
with DispatchServer(dispatch_service) as dispatch_server:
2731

28-
server = ServerTest()
29-
servicer = server.servicer
30-
app.dispatch._client = server.client
31-
app.some_logic._client = server.client
32+
# Use it when dispatching function calls.
33+
dispatch.set_client(Client(api_url=dispatch_server.url))
3234

33-
http_client = TestClient(app.app, base_url="http://dispatch-service")
34-
app_client = function_service.client(http_client)
35+
http_client = TestClient(app)
36+
response = http_client.get("/")
37+
self.assertEqual(response.status_code, 200)
3538

36-
response = http_client.get("/")
37-
self.assertEqual(response.status_code, 200)
39+
dispatch_service.dispatch_calls()
3840

39-
server.execute(app_client)
41+
# Seed(2) used in the app outputs 0, 0, 0, 2, 1, 5. So we expect 6
42+
# calls, including 5 retries.
43+
for i in range(6):
44+
dispatch_service.dispatch_calls()
4045

41-
# Seed(2) used in the app outputs 0, 0, 0, 2, 1, 5. So we expect 6
42-
# calls, including 5 retries.
43-
for i in range(6):
44-
server.execute(app_client)
45-
self.assertEqual(len(servicer.responses), 6)
46+
self.assertEqual(len(dispatch_service.roundtrips), 1)
47+
roundtrips = list(dispatch_service.roundtrips.values())[0]
48+
self.assertEqual(len(roundtrips), 6)
4649

47-
statuses = [r["response"].status for r in servicer.responses]
48-
self.assertEqual(
49-
statuses, [status_pb.STATUS_TEMPORARY_ERROR] * 5 + [status_pb.STATUS_OK]
50-
)
50+
statuses = [response.status for request, response in roundtrips]
51+
self.assertEqual(
52+
statuses, [status_pb.STATUS_TEMPORARY_ERROR] * 5 + [status_pb.STATUS_OK]
53+
)

examples/getting_started/test_app.py

Lines changed: 17 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -8,8 +8,8 @@
88

99
from fastapi.testclient import TestClient
1010

11-
from ... import function_service
12-
from ...test_client import ServerTest
11+
from dispatch import Client
12+
from dispatch.test import DispatchServer, DispatchService, EndpointClient
1313

1414

1515
class TestGettingStarted(unittest.TestCase):
@@ -20,20 +20,23 @@ class TestGettingStarted(unittest.TestCase):
2020
"DISPATCH_API_KEY": "0000000000000000",
2121
},
2222
)
23-
def test_foo(self):
24-
from . import app
23+
def test_app(self):
24+
from .app import app, dispatch
2525

26-
server = ServerTest()
27-
servicer = server.servicer
28-
app.dispatch._client = server.client
29-
app.publish._client = server.client
26+
# Setup a fake Dispatch server.
27+
endpoint_client = EndpointClient.from_app(app)
28+
dispatch_service = DispatchService(endpoint_client, collect_roundtrips=True)
29+
with DispatchServer(dispatch_service) as dispatch_server:
3030

31-
http_client = TestClient(app.app, base_url="http://dispatch-service")
32-
app_client = function_service.client(http_client)
31+
# Use it when dispatching function calls.
32+
dispatch.set_client(Client(api_url=dispatch_server.url))
3333

34-
response = http_client.get("/")
35-
self.assertEqual(response.status_code, 200)
34+
http_client = TestClient(app)
35+
response = http_client.get("/")
36+
self.assertEqual(response.status_code, 200)
3637

37-
server.execute(app_client)
38+
dispatch_service.dispatch_calls()
3839

39-
self.assertEqual(len(servicer.responses), 1)
40+
self.assertEqual(len(dispatch_service.roundtrips), 1) # one call submitted
41+
dispatch_id, roundtrips = list(dispatch_service.roundtrips.items())[0]
42+
self.assertEqual(len(roundtrips), 1) # one roundtrip for this call

examples/github_stats/test_app.py

Lines changed: 34 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -8,8 +8,8 @@
88

99
from fastapi.testclient import TestClient
1010

11-
from ... import function_service
12-
from ...test_client import ServerTest
11+
from dispatch.client import Client
12+
from dispatch.test import DispatchServer, DispatchService, EndpointClient
1313

1414

1515
class TestGithubStats(unittest.TestCase):
@@ -20,22 +20,35 @@ class TestGithubStats(unittest.TestCase):
2020
"DISPATCH_API_KEY": "0000000000000000",
2121
},
2222
)
23-
def test_foo(self):
24-
from . import app
25-
26-
server = ServerTest()
27-
servicer = server.servicer
28-
app.dispatch._client = server.client
29-
app.get_repo_info._client = server.client
30-
app.get_contributors._client = server.client
31-
app.main._client = server.client
32-
33-
http_client = TestClient(app.app, base_url="http://dispatch-service")
34-
app_client = function_service.client(http_client)
35-
36-
response = http_client.get("/")
37-
self.assertEqual(response.status_code, 200)
38-
39-
server.execute(app_client)
40-
41-
self.assertEqual(len(servicer.responses), 1)
23+
def test_app(self):
24+
from .app import app, dispatch
25+
26+
# Setup a fake Dispatch server.
27+
endpoint_client = EndpointClient.from_app(app)
28+
dispatch_service = DispatchService(endpoint_client, collect_roundtrips=True)
29+
with DispatchServer(dispatch_service) as dispatch_server:
30+
31+
# Use it when dispatching function calls.
32+
dispatch.set_client(Client(api_url=dispatch_server.url))
33+
34+
http_client = TestClient(app)
35+
response = http_client.get("/")
36+
self.assertEqual(response.status_code, 200)
37+
38+
while dispatch_service.queue:
39+
dispatch_service.dispatch_calls()
40+
41+
# Three unique functions were called, with five total round-trips.
42+
# The main function is called initially, and then polls
43+
# twice, for three total round-trips. There's one round-trip
44+
# to get_repo_info and one round-trip to get_contributors.
45+
self.assertEqual(
46+
3, len(dispatch_service.roundtrips)
47+
) # 3 unique functions were called
48+
self.assertEqual(
49+
5,
50+
sum(
51+
len(roundtrips)
52+
for roundtrips in dispatch_service.roundtrips.values()
53+
),
54+
)

src/dispatch/client.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -84,7 +84,7 @@ def _init_stub(self):
8484

8585
self._stub = dispatch_grpc.DispatchServiceStub(channel)
8686

87-
def dispatch(self, calls: Iterable[Call]) -> Iterable[DispatchID]:
87+
def dispatch(self, calls: Iterable[Call]) -> list[DispatchID]:
8888
"""Dispatch function calls.
8989
9090
Args:

src/dispatch/function.py

Lines changed: 10 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -55,7 +55,7 @@ class Function:
5555
def __init__(
5656
self,
5757
endpoint: str,
58-
client: Client | None,
58+
client: Client,
5959
name: str,
6060
primitive_func: PrimitiveFunctionType,
6161
func: Callable,
@@ -102,11 +102,6 @@ def dispatch(self, *args: Any, **kwargs: Any) -> DispatchID:
102102
return self._primitive_dispatch(Arguments(args, kwargs))
103103

104104
def _primitive_dispatch(self, input: Any = None) -> DispatchID:
105-
if self._client is None:
106-
raise RuntimeError(
107-
"Dispatch Client has not been configured (api_key not provided)"
108-
)
109-
110105
[dispatch_id] = self._client.dispatch([self._build_primitive_call(input)])
111106
return dispatch_id
112107

@@ -151,13 +146,13 @@ class Registry:
151146

152147
__slots__ = ("_functions", "_endpoint", "_client")
153148

154-
def __init__(self, endpoint: str, client: Client | None):
149+
def __init__(self, endpoint: str, client: Client):
155150
"""Initialize a local function registry.
156151
157152
Args:
158153
endpoint: URL of the endpoint that the function is accessible from.
159-
client: Optional client for the Dispatch API. If provided, calls
160-
to local functions can be dispatched directly.
154+
client: Client for the Dispatch API. Used to dispatch calls to
155+
local functions.
161156
"""
162157
self._functions: Dict[str, Function] = {}
163158
self._endpoint = endpoint
@@ -235,3 +230,9 @@ def _register(
235230
)
236231
self._functions[name] = wrapped_func
237232
return wrapped_func
233+
234+
def set_client(self, client: Client):
235+
"""Set the Client instance used to dispatch calls to local functions."""
236+
self._client = client
237+
for fn in self._functions.values():
238+
fn._client = client

src/dispatch/test/__init__.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
from .client import EndpointClient
2+
from .server import DispatchServer
3+
from .service import DispatchService
4+
5+
__all__ = ["EndpointClient", "DispatchServer", "DispatchService"]

0 commit comments

Comments
 (0)