Skip to content

Commit 6a5a0d4

Browse files
authored
✨ allow concurrent enter context
1 parent b4766cb commit 6a5a0d4

File tree

1 file changed

+19
-14
lines changed

1 file changed

+19
-14
lines changed

githubkit/core.py

Lines changed: 19 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
from types import TracebackType
2+
from contextvars import ContextVar
23
from contextlib import contextmanager, asynccontextmanager
34
from typing import (
45
Any,
@@ -129,14 +130,18 @@ def __init__(
129130
base_url, accept_format, previews, user_agent, follow_redirects, timeout
130131
)
131132

132-
self.__sync_client: Optional[httpx.Client] = None
133-
self.__async_client: Optional[httpx.AsyncClient] = None
133+
self.__sync_client: ContextVar[Optional[httpx.Client]] = ContextVar(
134+
"sync_client", default=None
135+
)
136+
self.__async_client: ContextVar[Optional[httpx.AsyncClient]] = ContextVar(
137+
"async_client", default=None
138+
)
134139

135140
# sync context
136141
def __enter__(self):
137-
if self.__sync_client is not None:
142+
if self.__sync_client.get() is not None:
138143
raise RuntimeError("Cannot enter sync context twice")
139-
self.__sync_client = self._create_sync_client()
144+
self.__sync_client.set(self._create_sync_client())
140145
return self
141146

142147
def __exit__(
@@ -145,14 +150,14 @@ def __exit__(
145150
exc_value: Optional[BaseException] = None,
146151
traceback: Optional[TracebackType] = None,
147152
):
148-
cast(httpx.Client, self.__sync_client).close()
149-
self.__sync_client = None
153+
cast(httpx.Client, self.__sync_client.get()).close()
154+
self.__sync_client.set(None)
150155

151156
# async context
152157
async def __aenter__(self):
153-
if self.__async_client is not None:
158+
if self.__async_client.get() is not None:
154159
raise RuntimeError("Cannot enter async context twice")
155-
self.__async_client = self._create_async_client()
160+
self.__async_client.set(self._create_async_client())
156161
return self
157162

158163
async def __aexit__(
@@ -161,8 +166,8 @@ async def __aexit__(
161166
exc_value: Optional[BaseException] = None,
162167
traceback: Optional[TracebackType] = None,
163168
):
164-
await cast(httpx.AsyncClient, self.__async_client).aclose()
165-
self.__async_client = None
169+
await cast(httpx.AsyncClient, self.__async_client.get()).aclose()
170+
self.__async_client.set(None)
166171

167172
# default args for creating client
168173
def _get_client_defaults(self):
@@ -184,8 +189,8 @@ def _create_sync_client(self) -> httpx.Client:
184189
# get or create sync client
185190
@contextmanager
186191
def get_sync_client(self) -> Generator[httpx.Client, None, None]:
187-
if self.__sync_client:
188-
yield self.__sync_client
192+
if client := self.__sync_client.get():
193+
yield client
189194
else:
190195
client = self._create_sync_client()
191196
try:
@@ -200,8 +205,8 @@ def _create_async_client(self) -> httpx.AsyncClient:
200205
# get or create async client
201206
@asynccontextmanager
202207
async def get_async_client(self) -> AsyncGenerator[httpx.AsyncClient, None]:
203-
if self.__async_client:
204-
yield self.__async_client
208+
if client := self.__async_client.get():
209+
yield client
205210
else:
206211
client = self._create_async_client()
207212
try:

0 commit comments

Comments
 (0)