Skip to content

Commit 20b8f6c

Browse files
Merge pull request #81 from stealthrocket/fix-coroutine-decoration
fix coroutine decoration
2 parents 46b4bf9 + fe5e04a commit 20b8f6c

File tree

1 file changed

+7
-9
lines changed

1 file changed

+7
-9
lines changed

src/dispatch/function.py

Lines changed: 7 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -59,17 +59,15 @@ def __init__(
5959
name: str,
6060
primitive_func: PrimitiveFunctionType,
6161
func: Callable,
62+
coroutine: bool = False,
6263
):
6364
self._endpoint = endpoint
6465
self._client = client
6566
self._name = name
6667
self._primitive_func = primitive_func
6768
# FIXME: is there a way to decorate the function at the definition
6869
# without making it a class method?
69-
if inspect.iscoroutinefunction(func):
70-
self._func = durable(self._call_async)
71-
else:
72-
self._func = func
70+
self._func = durable(self._call_async) if coroutine else func
7371

7472
def __call__(self, *args, **kwargs):
7573
return self._func(*args, **kwargs)
@@ -204,7 +202,7 @@ def primitive_func(input: Input) -> Output:
204202
primitive_func.__qualname__ = f"{func.__qualname__}_primitive"
205203
primitive_func = durable(primitive_func)
206204

207-
return self._register(func, primitive_func)
205+
return self._register(primitive_func, func, coroutine=False)
208206

209207
def _register_coroutine(self, func: Callable) -> Function:
210208
logger.info("registering coroutine: %s", func.__qualname__)
@@ -218,22 +216,22 @@ def primitive_func(input: Input) -> Output:
218216
primitive_func.__qualname__ = f"{func.__qualname__}_primitive"
219217
primitive_func = durable(primitive_func)
220218

221-
return self._register(func, primitive_func)
219+
return self._register(primitive_func, func, coroutine=True)
222220

223221
def _register_primitive_function(self, func: PrimitiveFunctionType) -> Function:
224222
logger.info("registering primitive function: %s", func.__qualname__)
225-
return self._register(func, func)
223+
return self._register(func, func, coroutine=inspect.iscoroutinefunction(func))
226224

227225
def _register(
228-
self, func: Callable, primitive_func: PrimitiveFunctionType
226+
self, primitive_func: PrimitiveFunctionType, func: Callable, coroutine: bool
229227
) -> Function:
230228
name = func.__qualname__
231229
if name in self._functions:
232230
raise ValueError(
233231
f"function or coroutine already registered with name '{name}'"
234232
)
235233
wrapped_func = Function(
236-
self._endpoint, self._client, name, primitive_func, func
234+
self._endpoint, self._client, name, primitive_func, func, coroutine
237235
)
238236
self._functions[name] = wrapped_func
239237
return wrapped_func

0 commit comments

Comments
 (0)