Skip to content

Commit 7d161d7

Browse files
committed
Allow provider to be a context manager (sync/async)
1 parent 5be9189 commit 7d161d7

File tree

2 files changed

+137
-4
lines changed

2 files changed

+137
-4
lines changed

src/inject/__init__.py

+30-4
Original file line numberDiff line numberDiff line change
@@ -73,6 +73,8 @@ def my_config(binder):
7373
inject.configure(my_config)
7474
7575
"""
76+
import contextlib
77+
7678
from inject._version import __version__
7779

7880
import inspect
@@ -335,12 +337,28 @@ def __call__(self, func: Callable[..., Union[Awaitable[T], T]]) -> Callable[...,
335337
async def async_injection_wrapper(*args: Any, **kwargs: Any) -> T:
336338
provided_params = frozenset(
337339
arg_names[:len(args)]) | frozenset(kwargs.keys())
340+
ctx_managers = {}
341+
async_ctx_managers = {}
338342
for param, cls in params_to_provide.items():
339343
if param not in provided_params:
340-
kwargs[param] = instance(cls)
344+
inst = instance(cls)
345+
if isinstance(inst, contextlib.AbstractContextManager):
346+
ctx_managers[param] = inst
347+
elif isinstance(inst, contextlib.AbstractAsyncContextManager):
348+
async_ctx_managers[param] = inst
349+
else:
350+
kwargs[param] = inst
341351
async_func = cast(Callable[..., Awaitable[T]], func)
342352
try:
343-
return await async_func(*args, **kwargs)
353+
with contextlib.ExitStack() as sync_stack:
354+
ctx_kwargs = {param: sync_stack.enter_context(ctx_manager) for param, ctx_manager in
355+
ctx_managers.items()}
356+
kwargs.update(ctx_kwargs)
357+
async with contextlib.AsyncExitStack() as async_stack:
358+
asynx_ctx_kwargs = {param: await async_stack.enter_async_context(ctx_manager) for param, ctx_manager in
359+
async_ctx_managers.items()}
360+
kwargs.update(asynx_ctx_kwargs)
361+
return await async_func(*args, **kwargs)
344362
except TypeError as previous_error:
345363
raise ConstructorTypeError(func, previous_error)
346364

@@ -350,12 +368,20 @@ async def async_injection_wrapper(*args: Any, **kwargs: Any) -> T:
350368
def injection_wrapper(*args: Any, **kwargs: Any) -> T:
351369
provided_params = frozenset(
352370
arg_names[:len(args)]) | frozenset(kwargs.keys())
371+
ctx_managers = {}
353372
for param, cls in params_to_provide.items():
354373
if param not in provided_params:
355-
kwargs[param] = instance(cls)
374+
inst = instance(cls)
375+
if isinstance(inst, contextlib.AbstractContextManager):
376+
ctx_managers[param] = inst
377+
else:
378+
kwargs[param] = inst
356379
sync_func = cast(Callable[..., T], func)
357380
try:
358-
return sync_func(*args, **kwargs)
381+
with contextlib.ExitStack() as stack:
382+
ctx_kwargs = {param: stack.enter_context(ctx_manager) for param, ctx_manager in ctx_managers.items()}
383+
kwargs.update(ctx_kwargs)
384+
return sync_func(*args, **kwargs)
359385
except TypeError as previous_error:
360386
raise ConstructorTypeError(func, previous_error)
361387
return injection_wrapper

test/test_context_manager.py

+107
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,107 @@
1+
import contextlib
2+
3+
import inject
4+
from test import BaseTestInject
5+
6+
7+
class Destroyable:
8+
def __init__(self):
9+
self.started = True
10+
11+
def destroy(self):
12+
self.started = False
13+
14+
15+
class MockFile(Destroyable):
16+
...
17+
18+
19+
class MockConnection(Destroyable):
20+
...
21+
22+
23+
class MockFoo(Destroyable):
24+
...
25+
26+
27+
@contextlib.contextmanager
28+
def get_file_sync():
29+
obj = MockFile()
30+
yield obj
31+
obj.destroy()
32+
33+
34+
@contextlib.contextmanager
35+
def get_conn_sync():
36+
obj = MockConnection()
37+
yield obj
38+
obj.destroy()
39+
40+
41+
@contextlib.contextmanager
42+
def get_foo_sync():
43+
obj = MockFoo()
44+
yield obj
45+
obj.destroy()
46+
47+
48+
@contextlib.asynccontextmanager
49+
async def get_file_async():
50+
obj = MockFile()
51+
yield obj
52+
obj.destroy()
53+
54+
55+
@contextlib.asynccontextmanager
56+
async def get_conn_async():
57+
obj = MockConnection()
58+
yield obj
59+
obj.destroy()
60+
61+
62+
class TestContextManagerFunctional(BaseTestInject):
63+
64+
def test_provider_as_context_manager_sync(self):
65+
def config(binder):
66+
binder.bind_to_provider(MockFile, get_file_sync)
67+
binder.bind(int, 100)
68+
binder.bind_to_provider(str, lambda: "Hello")
69+
binder.bind_to_provider(MockConnection, get_conn_sync)
70+
71+
inject.configure(config)
72+
73+
@inject.autoparams()
74+
def mock_func(conn: MockConnection, name: str, f: MockFile, number: int):
75+
assert f.started
76+
assert conn.started
77+
assert name == "Hello"
78+
assert number == 100
79+
return f, conn
80+
81+
f_, conn_ = mock_func()
82+
assert not f_.started
83+
assert not conn_.started
84+
85+
def test_provider_as_context_manager_async(self):
86+
def config(binder):
87+
binder.bind_to_provider(MockFile, get_file_async)
88+
binder.bind(int, 100)
89+
binder.bind_to_provider(str, lambda: "Hello")
90+
binder.bind_to_provider(MockConnection, get_conn_async)
91+
binder.bind_to_provider(MockFoo, get_foo_sync)
92+
93+
inject.configure(config)
94+
95+
@inject.autoparams()
96+
async def mock_func(conn: MockConnection, name: str, f: MockFile, number: int, foo: MockFoo):
97+
assert f.started
98+
assert conn.started
99+
assert foo.started
100+
assert name == "Hello"
101+
assert number == 100
102+
return f, conn, foo
103+
104+
f_, conn_, foo_ = self.run_async(mock_func())
105+
assert not f_.started
106+
assert not conn_.started
107+
assert not foo_.started

0 commit comments

Comments
 (0)