@@ -59,14 +59,15 @@ def __init__(
59
59
name : str ,
60
60
primitive_func : PrimitiveFunctionType ,
61
61
func : Callable ,
62
+ coroutine : bool = False ,
62
63
):
63
64
self ._endpoint = endpoint
64
65
self ._client = client
65
66
self ._name = name
66
67
self ._primitive_func = primitive_func
67
68
# FIXME: is there a way to decorate the function at the definition
68
69
# without making it a class method?
69
- if inspect . iscoroutinefunction ( func ) :
70
+ if coroutine :
70
71
self ._func = durable (self ._call_async )
71
72
else :
72
73
self ._func = func
@@ -204,7 +205,7 @@ def primitive_func(input: Input) -> Output:
204
205
primitive_func .__qualname__ = f"{ func .__qualname__ } _primitive"
205
206
primitive_func = durable (primitive_func )
206
207
207
- return self ._register (func , primitive_func )
208
+ return self ._register (primitive_func , func , coroutine = False )
208
209
209
210
def _register_coroutine (self , func : Callable ) -> Function :
210
211
logger .info ("registering coroutine: %s" , func .__qualname__ )
@@ -218,22 +219,22 @@ def primitive_func(input: Input) -> Output:
218
219
primitive_func .__qualname__ = f"{ func .__qualname__ } _primitive"
219
220
primitive_func = durable (primitive_func )
220
221
221
- return self ._register (func , primitive_func )
222
+ return self ._register (primitive_func , func , coroutine = True )
222
223
223
224
def _register_primitive_function (self , func : PrimitiveFunctionType ) -> Function :
224
225
logger .info ("registering primitive function: %s" , func .__qualname__ )
225
- return self ._register (func , func )
226
+ return self ._register (func , func , coroutine = inspect . iscoroutinefunction ( func ) )
226
227
227
228
def _register (
228
- self , func : Callable , primitive_func : PrimitiveFunctionType
229
+ self , primitive_func : PrimitiveFunctionType , func : Callable , coroutine : bool
229
230
) -> Function :
230
231
name = func .__qualname__
231
232
if name in self ._functions :
232
233
raise ValueError (
233
234
f"function or coroutine already registered with name '{ name } '"
234
235
)
235
236
wrapped_func = Function (
236
- self ._endpoint , self ._client , name , primitive_func , func
237
+ self ._endpoint , self ._client , name , primitive_func , func , coroutine
237
238
)
238
239
self ._functions [name ] = wrapped_func
239
240
return wrapped_func
0 commit comments