1
1
from types import TracebackType
2
+ from contextvars import ContextVar
2
3
from contextlib import contextmanager , asynccontextmanager
3
4
from typing import (
4
5
Any ,
@@ -129,14 +130,18 @@ def __init__(
129
130
base_url , accept_format , previews , user_agent , follow_redirects , timeout
130
131
)
131
132
132
- self .__sync_client : Optional [httpx .Client ] = None
133
- self .__async_client : Optional [httpx .AsyncClient ] = None
133
+ self .__sync_client : ContextVar [Optional [httpx .Client ]] = ContextVar (
134
+ "sync_client" , default = None
135
+ )
136
+ self .__async_client : ContextVar [Optional [httpx .AsyncClient ]] = ContextVar (
137
+ "async_client" , default = None
138
+ )
134
139
135
140
# sync context
136
141
def __enter__ (self ):
137
- if self .__sync_client is not None :
142
+ if self .__sync_client . get () is not None :
138
143
raise RuntimeError ("Cannot enter sync context twice" )
139
- self .__sync_client = self ._create_sync_client ()
144
+ self .__sync_client . set ( self ._create_sync_client () )
140
145
return self
141
146
142
147
def __exit__ (
@@ -145,14 +150,14 @@ def __exit__(
145
150
exc_value : Optional [BaseException ] = None ,
146
151
traceback : Optional [TracebackType ] = None ,
147
152
):
148
- cast (httpx .Client , self .__sync_client ).close ()
149
- self .__sync_client = None
153
+ cast (httpx .Client , self .__sync_client . get () ).close ()
154
+ self .__sync_client . set ( None )
150
155
151
156
# async context
152
157
async def __aenter__ (self ):
153
- if self .__async_client is not None :
158
+ if self .__async_client . get () is not None :
154
159
raise RuntimeError ("Cannot enter async context twice" )
155
- self .__async_client = self ._create_async_client ()
160
+ self .__async_client . set ( self ._create_async_client () )
156
161
return self
157
162
158
163
async def __aexit__ (
@@ -161,8 +166,8 @@ async def __aexit__(
161
166
exc_value : Optional [BaseException ] = None ,
162
167
traceback : Optional [TracebackType ] = None ,
163
168
):
164
- await cast (httpx .AsyncClient , self .__async_client ).aclose ()
165
- self .__async_client = None
169
+ await cast (httpx .AsyncClient , self .__async_client . get () ).aclose ()
170
+ self .__async_client . set ( None )
166
171
167
172
# default args for creating client
168
173
def _get_client_defaults (self ):
@@ -184,8 +189,8 @@ def _create_sync_client(self) -> httpx.Client:
184
189
# get or create sync client
185
190
@contextmanager
186
191
def get_sync_client (self ) -> Generator [httpx .Client , None , None ]:
187
- if self .__sync_client :
188
- yield self . __sync_client
192
+ if client := self .__sync_client . get () :
193
+ yield client
189
194
else :
190
195
client = self ._create_sync_client ()
191
196
try :
@@ -200,8 +205,8 @@ def _create_async_client(self) -> httpx.AsyncClient:
200
205
# get or create async client
201
206
@asynccontextmanager
202
207
async def get_async_client (self ) -> AsyncGenerator [httpx .AsyncClient , None ]:
203
- if self .__async_client :
204
- yield self . __async_client
208
+ if client := self .__async_client . get () :
209
+ yield client
205
210
else :
206
211
client = self ._create_async_client ()
207
212
try :
0 commit comments