@@ -73,6 +73,8 @@ def my_config(binder):
73
73
inject.configure(my_config)
74
74
75
75
"""
76
+ import contextlib
77
+
76
78
from inject ._version import __version__
77
79
78
80
import inspect
@@ -335,12 +337,28 @@ def __call__(self, func: Callable[..., Union[Awaitable[T], T]]) -> Callable[...,
335
337
async def async_injection_wrapper (* args : Any , ** kwargs : Any ) -> T :
336
338
provided_params = frozenset (
337
339
arg_names [:len (args )]) | frozenset (kwargs .keys ())
340
+ ctx_managers = {}
341
+ async_ctx_managers = {}
338
342
for param , cls in params_to_provide .items ():
339
343
if param not in provided_params :
340
- kwargs [param ] = instance (cls )
344
+ inst = instance (cls )
345
+ if isinstance (inst , contextlib .AbstractContextManager ):
346
+ ctx_managers [param ] = inst
347
+ elif isinstance (inst , contextlib .AbstractAsyncContextManager ):
348
+ async_ctx_managers [param ] = inst
349
+ else :
350
+ kwargs [param ] = inst
341
351
async_func = cast (Callable [..., Awaitable [T ]], func )
342
352
try :
343
- return await async_func (* args , ** kwargs )
353
+ with contextlib .ExitStack () as sync_stack :
354
+ ctx_kwargs = {param : sync_stack .enter_context (ctx_manager ) for param , ctx_manager in
355
+ ctx_managers .items ()}
356
+ kwargs .update (ctx_kwargs )
357
+ async with contextlib .AsyncExitStack () as async_stack :
358
+ asynx_ctx_kwargs = {param : await async_stack .enter_async_context (ctx_manager ) for param , ctx_manager in
359
+ async_ctx_managers .items ()}
360
+ kwargs .update (asynx_ctx_kwargs )
361
+ return await async_func (* args , ** kwargs )
344
362
except TypeError as previous_error :
345
363
raise ConstructorTypeError (func , previous_error )
346
364
@@ -350,12 +368,20 @@ async def async_injection_wrapper(*args: Any, **kwargs: Any) -> T:
350
368
def injection_wrapper (* args : Any , ** kwargs : Any ) -> T :
351
369
provided_params = frozenset (
352
370
arg_names [:len (args )]) | frozenset (kwargs .keys ())
371
+ ctx_managers = {}
353
372
for param , cls in params_to_provide .items ():
354
373
if param not in provided_params :
355
- kwargs [param ] = instance (cls )
374
+ inst = instance (cls )
375
+ if isinstance (inst , contextlib .AbstractContextManager ):
376
+ ctx_managers [param ] = inst
377
+ else :
378
+ kwargs [param ] = inst
356
379
sync_func = cast (Callable [..., T ], func )
357
380
try :
358
- return sync_func (* args , ** kwargs )
381
+ with contextlib .ExitStack () as stack :
382
+ ctx_kwargs = {param : stack .enter_context (ctx_manager ) for param , ctx_manager in ctx_managers .items ()}
383
+ kwargs .update (ctx_kwargs )
384
+ return sync_func (* args , ** kwargs )
359
385
except TypeError as previous_error :
360
386
raise ConstructorTypeError (func , previous_error )
361
387
return injection_wrapper
0 commit comments