|
2 | 2 |
|
3 | 3 | import inspect
|
4 | 4 | import logging
|
| 5 | +import os |
5 | 6 | from functools import wraps
|
6 | 7 | from types import CoroutineType
|
7 | 8 | from typing import (
|
|
10 | 11 | Coroutine,
|
11 | 12 | Dict,
|
12 | 13 | Generic,
|
| 14 | + Iterable, |
13 | 15 | ParamSpec,
|
14 | 16 | TypeAlias,
|
15 | 17 | TypeVar,
|
16 | 18 | overload,
|
17 | 19 | )
|
| 20 | +from urllib.parse import urlparse |
| 21 | + |
| 22 | +import grpc |
18 | 23 |
|
19 | 24 | import dispatch.coroutine
|
20 |
| -from dispatch.client import Client |
| 25 | +import dispatch.sdk.v1.dispatch_pb2 as dispatch_pb |
| 26 | +import dispatch.sdk.v1.dispatch_pb2_grpc as dispatch_grpc |
21 | 27 | from dispatch.experimental.durable import durable
|
22 | 28 | from dispatch.id import DispatchID
|
23 | 29 | from dispatch.proto import Arguments, Call, Error, Input, Output
|
|
33 | 39 | """
|
34 | 40 |
|
35 | 41 |
|
| 42 | +DEFAULT_API_URL = "https://api.dispatch.run" |
| 43 | + |
| 44 | + |
36 | 45 | class PrimitiveFunction:
|
37 | 46 | __slots__ = ("_endpoint", "_client", "_name", "_primitive_func")
|
38 | 47 |
|
@@ -234,3 +243,147 @@ def set_client(self, client: Client):
|
234 | 243 | self._client = client
|
235 | 244 | for fn in self._functions.values():
|
236 | 245 | fn._client = client
|
| 246 | + |
| 247 | + |
| 248 | +class Client: |
| 249 | + """Client for the Dispatch API.""" |
| 250 | + |
| 251 | + __slots__ = ("api_url", "api_key", "_stub", "api_key_from") |
| 252 | + |
| 253 | + def __init__(self, api_key: None | str = None, api_url: None | str = None): |
| 254 | + """Create a new Dispatch client. |
| 255 | +
|
| 256 | + Args: |
| 257 | + api_key: Dispatch API key to use for authentication. Uses the value of |
| 258 | + the DISPATCH_API_KEY environment variable by default. |
| 259 | +
|
| 260 | + api_url: The URL of the Dispatch API to use. Uses the value of the |
| 261 | + DISPATCH_API_URL environment variable if set, otherwise |
| 262 | + defaults to the public Dispatch API (DEFAULT_API_URL). |
| 263 | +
|
| 264 | + Raises: |
| 265 | + ValueError: if the API key is missing. |
| 266 | + """ |
| 267 | + |
| 268 | + if api_key: |
| 269 | + self.api_key_from = "api_key" |
| 270 | + else: |
| 271 | + self.api_key_from = "DISPATCH_API_KEY" |
| 272 | + api_key = os.environ.get("DISPATCH_API_KEY") |
| 273 | + if not api_key: |
| 274 | + raise ValueError( |
| 275 | + "missing API key: set it with the DISPATCH_API_KEY environment variable" |
| 276 | + ) |
| 277 | + |
| 278 | + if not api_url: |
| 279 | + api_url = os.environ.get("DISPATCH_API_URL", DEFAULT_API_URL) |
| 280 | + if not api_url: |
| 281 | + raise ValueError( |
| 282 | + "missing API URL: set it with the DISPATCH_API_URL environment variable" |
| 283 | + ) |
| 284 | + |
| 285 | + logger.debug("initializing client for Dispatch API at URL %s", api_url) |
| 286 | + self.api_url = api_url |
| 287 | + self.api_key = api_key |
| 288 | + self._init_stub() |
| 289 | + |
| 290 | + def __getstate__(self): |
| 291 | + return {"api_url": self.api_url, "api_key": self.api_key} |
| 292 | + |
| 293 | + def __setstate__(self, state): |
| 294 | + self.api_url = state["api_url"] |
| 295 | + self.api_key = state["api_key"] |
| 296 | + self._init_stub() |
| 297 | + |
| 298 | + def _init_stub(self): |
| 299 | + result = urlparse(self.api_url) |
| 300 | + match result.scheme: |
| 301 | + case "http": |
| 302 | + creds = grpc.local_channel_credentials() |
| 303 | + case "https": |
| 304 | + creds = grpc.ssl_channel_credentials() |
| 305 | + case _: |
| 306 | + raise ValueError(f"Invalid API scheme: '{result.scheme}'") |
| 307 | + |
| 308 | + call_creds = grpc.access_token_call_credentials(self.api_key) |
| 309 | + creds = grpc.composite_channel_credentials(creds, call_creds) |
| 310 | + channel = grpc.secure_channel(result.netloc, creds) |
| 311 | + |
| 312 | + self._stub = dispatch_grpc.DispatchServiceStub(channel) |
| 313 | + |
| 314 | + def batch(self) -> Batch: |
| 315 | + """Returns a Batch instance that can be used to build |
| 316 | + a set of calls to dispatch.""" |
| 317 | + return Batch(self) |
| 318 | + |
| 319 | + def dispatch(self, calls: Iterable[Call]) -> list[DispatchID]: |
| 320 | + """Dispatch function calls. |
| 321 | +
|
| 322 | + Args: |
| 323 | + calls: Calls to dispatch. |
| 324 | +
|
| 325 | + Returns: |
| 326 | + Identifiers for the function calls, in the same order as the inputs. |
| 327 | + """ |
| 328 | + calls_proto = [c._as_proto() for c in calls] |
| 329 | + logger.debug("dispatching %d function call(s)", len(calls_proto)) |
| 330 | + req = dispatch_pb.DispatchRequest(calls=calls_proto) |
| 331 | + |
| 332 | + try: |
| 333 | + resp = self._stub.Dispatch(req) |
| 334 | + except grpc.RpcError as e: |
| 335 | + status_code = e.code() |
| 336 | + match status_code: |
| 337 | + case grpc.StatusCode.UNAUTHENTICATED: |
| 338 | + raise PermissionError( |
| 339 | + f"Dispatch received an invalid authentication token (check {self.api_key_from} is correct)" |
| 340 | + ) from e |
| 341 | + raise |
| 342 | + |
| 343 | + dispatch_ids = [DispatchID(x) for x in resp.dispatch_ids] |
| 344 | + if logger.isEnabledFor(logging.DEBUG): |
| 345 | + logger.debug( |
| 346 | + "dispatched %d function call(s): %s", |
| 347 | + len(calls_proto), |
| 348 | + ", ".join(dispatch_ids), |
| 349 | + ) |
| 350 | + return dispatch_ids |
| 351 | + |
| 352 | + |
| 353 | +class Batch: |
| 354 | + """A batch of calls to dispatch.""" |
| 355 | + |
| 356 | + __slots__ = ("client", "calls") |
| 357 | + |
| 358 | + def __init__(self, client: Client): |
| 359 | + self.client = client |
| 360 | + self.calls: list[Call] = [] |
| 361 | + |
| 362 | + def add(self, func: Function[P, T], *args: P.args, **kwargs: P.kwargs) -> Batch: |
| 363 | + """Add a call to the specified function to the batch.""" |
| 364 | + return self.add_call(func.build_call(*args, correlation_id=None, **kwargs)) |
| 365 | + |
| 366 | + def add_call(self, call: Call) -> Batch: |
| 367 | + """Add a Call to the batch.""" |
| 368 | + self.calls.append(call) |
| 369 | + return self |
| 370 | + |
| 371 | + def dispatch(self) -> list[DispatchID]: |
| 372 | + """Dispatch dispatches the calls asynchronously. |
| 373 | +
|
| 374 | + The batch is reset when the calls are dispatched successfully. |
| 375 | +
|
| 376 | + Returns: |
| 377 | + Identifiers for the function calls, in the same order they |
| 378 | + were added. |
| 379 | + """ |
| 380 | + if not self.calls: |
| 381 | + return [] |
| 382 | + |
| 383 | + dispatch_ids = self.client.dispatch(self.calls) |
| 384 | + self.reset() |
| 385 | + return dispatch_ids |
| 386 | + |
| 387 | + def reset(self): |
| 388 | + """Reset the batch.""" |
| 389 | + self.calls = [] |
0 commit comments