Skip to content

Commit 8aa6795

Browse files
authored
Merge pull request #127 from stealthrocket/batch-submit
Batch submit
2 parents e0a1b1d + ff4be5e commit 8aa6795

File tree

7 files changed

+191
-128
lines changed

7 files changed

+191
-128
lines changed

Diff for: examples/github_stats/test_app.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88

99
from fastapi.testclient import TestClient
1010

11-
from dispatch.client import Client
11+
from dispatch.function import Client
1212
from dispatch.test import DispatchServer, DispatchService, EndpointClient
1313

1414

Diff for: src/dispatch/__init__.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -3,8 +3,8 @@
33
from __future__ import annotations
44

55
import dispatch.integrations
6-
from dispatch.client import DEFAULT_API_URL, Client
76
from dispatch.coroutine import call, gather
7+
from dispatch.function import DEFAULT_API_URL, Client
88
from dispatch.id import DispatchID
99
from dispatch.proto import Call, Error, Input, Output
1010
from dispatch.status import Status

Diff for: src/dispatch/client.py

-118
This file was deleted.

Diff for: src/dispatch/fastapi.py

+9-5
Original file line numberDiff line numberDiff line change
@@ -27,8 +27,7 @@ def read_root():
2727
import fastapi.responses
2828
from http_message_signatures import InvalidSignature
2929

30-
from dispatch.client import Client
31-
from dispatch.function import Registry
30+
from dispatch.function import Batch, Client, Registry
3231
from dispatch.proto import Input
3332
from dispatch.sdk.v1 import function_pb2 as function_pb
3433
from dispatch.signature import (
@@ -47,7 +46,7 @@ def read_root():
4746
class Dispatch(Registry):
4847
"""A Dispatch programmable endpoint, powered by FastAPI."""
4948

50-
__slots__ = ()
49+
__slots__ = ("client",)
5150

5251
def __init__(
5352
self,
@@ -116,12 +115,17 @@ def __init__(
116115
"request verification is disabled because DISPATCH_VERIFICATION_KEY is not set"
117116
)
118117

119-
client = Client(api_key=api_key, api_url=api_url)
120-
super().__init__(endpoint, client)
118+
self.client = Client(api_key=api_key, api_url=api_url)
119+
super().__init__(endpoint, self.client)
121120

122121
function_service = _new_app(self, verification_key)
123122
app.mount("/dispatch.sdk.v1.FunctionService", function_service)
124123

124+
def batch(self) -> Batch:
125+
"""Returns a Batch instance that can be used to build
126+
a set of calls to dispatch."""
127+
return self.client.batch()
128+
125129

126130
def parse_verification_key(
127131
verification_key: Ed25519PublicKey | str | bytes | None,

Diff for: src/dispatch/function.py

+154-1
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22

33
import inspect
44
import logging
5+
import os
56
from functools import wraps
67
from types import CoroutineType
78
from typing import (
@@ -10,14 +11,19 @@
1011
Coroutine,
1112
Dict,
1213
Generic,
14+
Iterable,
1315
ParamSpec,
1416
TypeAlias,
1517
TypeVar,
1618
overload,
1719
)
20+
from urllib.parse import urlparse
21+
22+
import grpc
1823

1924
import dispatch.coroutine
20-
from dispatch.client import Client
25+
import dispatch.sdk.v1.dispatch_pb2 as dispatch_pb
26+
import dispatch.sdk.v1.dispatch_pb2_grpc as dispatch_grpc
2127
from dispatch.experimental.durable import durable
2228
from dispatch.id import DispatchID
2329
from dispatch.proto import Arguments, Call, Error, Input, Output
@@ -33,6 +39,9 @@
3339
"""
3440

3541

42+
DEFAULT_API_URL = "https://api.dispatch.run"
43+
44+
3645
class PrimitiveFunction:
3746
__slots__ = ("_endpoint", "_client", "_name", "_primitive_func")
3847

@@ -234,3 +243,147 @@ def set_client(self, client: Client):
234243
self._client = client
235244
for fn in self._functions.values():
236245
fn._client = client
246+
247+
248+
class Client:
249+
"""Client for the Dispatch API."""
250+
251+
__slots__ = ("api_url", "api_key", "_stub", "api_key_from")
252+
253+
def __init__(self, api_key: None | str = None, api_url: None | str = None):
254+
"""Create a new Dispatch client.
255+
256+
Args:
257+
api_key: Dispatch API key to use for authentication. Uses the value of
258+
the DISPATCH_API_KEY environment variable by default.
259+
260+
api_url: The URL of the Dispatch API to use. Uses the value of the
261+
DISPATCH_API_URL environment variable if set, otherwise
262+
defaults to the public Dispatch API (DEFAULT_API_URL).
263+
264+
Raises:
265+
ValueError: if the API key is missing.
266+
"""
267+
268+
if api_key:
269+
self.api_key_from = "api_key"
270+
else:
271+
self.api_key_from = "DISPATCH_API_KEY"
272+
api_key = os.environ.get("DISPATCH_API_KEY")
273+
if not api_key:
274+
raise ValueError(
275+
"missing API key: set it with the DISPATCH_API_KEY environment variable"
276+
)
277+
278+
if not api_url:
279+
api_url = os.environ.get("DISPATCH_API_URL", DEFAULT_API_URL)
280+
if not api_url:
281+
raise ValueError(
282+
"missing API URL: set it with the DISPATCH_API_URL environment variable"
283+
)
284+
285+
logger.debug("initializing client for Dispatch API at URL %s", api_url)
286+
self.api_url = api_url
287+
self.api_key = api_key
288+
self._init_stub()
289+
290+
def __getstate__(self):
291+
return {"api_url": self.api_url, "api_key": self.api_key}
292+
293+
def __setstate__(self, state):
294+
self.api_url = state["api_url"]
295+
self.api_key = state["api_key"]
296+
self._init_stub()
297+
298+
def _init_stub(self):
299+
result = urlparse(self.api_url)
300+
match result.scheme:
301+
case "http":
302+
creds = grpc.local_channel_credentials()
303+
case "https":
304+
creds = grpc.ssl_channel_credentials()
305+
case _:
306+
raise ValueError(f"Invalid API scheme: '{result.scheme}'")
307+
308+
call_creds = grpc.access_token_call_credentials(self.api_key)
309+
creds = grpc.composite_channel_credentials(creds, call_creds)
310+
channel = grpc.secure_channel(result.netloc, creds)
311+
312+
self._stub = dispatch_grpc.DispatchServiceStub(channel)
313+
314+
def batch(self) -> Batch:
315+
"""Returns a Batch instance that can be used to build
316+
a set of calls to dispatch."""
317+
return Batch(self)
318+
319+
def dispatch(self, calls: Iterable[Call]) -> list[DispatchID]:
320+
"""Dispatch function calls.
321+
322+
Args:
323+
calls: Calls to dispatch.
324+
325+
Returns:
326+
Identifiers for the function calls, in the same order as the inputs.
327+
"""
328+
calls_proto = [c._as_proto() for c in calls]
329+
logger.debug("dispatching %d function call(s)", len(calls_proto))
330+
req = dispatch_pb.DispatchRequest(calls=calls_proto)
331+
332+
try:
333+
resp = self._stub.Dispatch(req)
334+
except grpc.RpcError as e:
335+
status_code = e.code()
336+
match status_code:
337+
case grpc.StatusCode.UNAUTHENTICATED:
338+
raise PermissionError(
339+
f"Dispatch received an invalid authentication token (check {self.api_key_from} is correct)"
340+
) from e
341+
raise
342+
343+
dispatch_ids = [DispatchID(x) for x in resp.dispatch_ids]
344+
if logger.isEnabledFor(logging.DEBUG):
345+
logger.debug(
346+
"dispatched %d function call(s): %s",
347+
len(calls_proto),
348+
", ".join(dispatch_ids),
349+
)
350+
return dispatch_ids
351+
352+
353+
class Batch:
354+
"""A batch of calls to dispatch."""
355+
356+
__slots__ = ("client", "calls")
357+
358+
def __init__(self, client: Client):
359+
self.client = client
360+
self.calls: list[Call] = []
361+
362+
def add(self, func: Function[P, T], *args: P.args, **kwargs: P.kwargs) -> Batch:
363+
"""Add a call to the specified function to the batch."""
364+
return self.add_call(func.build_call(*args, correlation_id=None, **kwargs))
365+
366+
def add_call(self, call: Call) -> Batch:
367+
"""Add a Call to the batch."""
368+
self.calls.append(call)
369+
return self
370+
371+
def dispatch(self) -> list[DispatchID]:
372+
"""Dispatch dispatches the calls asynchronously.
373+
374+
The batch is reset when the calls are dispatched successfully.
375+
376+
Returns:
377+
Identifiers for the function calls, in the same order they
378+
were added.
379+
"""
380+
if not self.calls:
381+
return []
382+
383+
dispatch_ids = self.client.dispatch(self.calls)
384+
self.reset()
385+
return dispatch_ids
386+
387+
def reset(self):
388+
"""Reset the batch."""
389+
self.calls = []

Diff for: tests/dispatch/test_function.py

+1-2
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,7 @@
11
import pickle
22
import unittest
33

4-
from dispatch.client import Client
5-
from dispatch.function import Registry
4+
from dispatch.function import Client, Registry
65

76

87
class TestFunction(unittest.TestCase):

0 commit comments

Comments
 (0)