@@ -59,17 +59,15 @@ def __init__(
59
59
name : str ,
60
60
primitive_func : PrimitiveFunctionType ,
61
61
func : Callable ,
62
+ coroutine : bool = False ,
62
63
):
63
64
self ._endpoint = endpoint
64
65
self ._client = client
65
66
self ._name = name
66
67
self ._primitive_func = primitive_func
67
68
# FIXME: is there a way to decorate the function at the definition
68
69
# without making it a class method?
69
- if inspect .iscoroutinefunction (func ):
70
- self ._func = durable (self ._call_async )
71
- else :
72
- self ._func = func
70
+ self ._func = durable (self ._call_async ) if coroutine else func
73
71
74
72
def __call__ (self , * args , ** kwargs ):
75
73
return self ._func (* args , ** kwargs )
@@ -204,7 +202,7 @@ def primitive_func(input: Input) -> Output:
204
202
primitive_func .__qualname__ = f"{ func .__qualname__ } _primitive"
205
203
primitive_func = durable (primitive_func )
206
204
207
- return self ._register (func , primitive_func )
205
+ return self ._register (primitive_func , func , coroutine = False )
208
206
209
207
def _register_coroutine (self , func : Callable ) -> Function :
210
208
logger .info ("registering coroutine: %s" , func .__qualname__ )
@@ -218,22 +216,22 @@ def primitive_func(input: Input) -> Output:
218
216
primitive_func .__qualname__ = f"{ func .__qualname__ } _primitive"
219
217
primitive_func = durable (primitive_func )
220
218
221
- return self ._register (func , primitive_func )
219
+ return self ._register (primitive_func , func , coroutine = True )
222
220
223
221
def _register_primitive_function (self , func : PrimitiveFunctionType ) -> Function :
224
222
logger .info ("registering primitive function: %s" , func .__qualname__ )
225
- return self ._register (func , func )
223
+ return self ._register (func , func , coroutine = inspect . iscoroutinefunction ( func ) )
226
224
227
225
def _register (
228
- self , func : Callable , primitive_func : PrimitiveFunctionType
226
+ self , primitive_func : PrimitiveFunctionType , func : Callable , coroutine : bool
229
227
) -> Function :
230
228
name = func .__qualname__
231
229
if name in self ._functions :
232
230
raise ValueError (
233
231
f"function or coroutine already registered with name '{ name } '"
234
232
)
235
233
wrapped_func = Function (
236
- self ._endpoint , self ._client , name , primitive_func , func
234
+ self ._endpoint , self ._client , name , primitive_func , func , coroutine
237
235
)
238
236
self ._functions [name ] = wrapped_func
239
237
return wrapped_func
0 commit comments