6
6
from collections import deque
7
7
from typing import Deque , AsyncIterator , Union , List , Optional , Dict , Callable
8
8
9
+ import logging
10
+
9
11
import ydb
10
12
from .topic_writer import (
11
13
PublicWriterSettings ,
38
40
GrpcWrapperAsyncIO ,
39
41
)
40
42
43
+ logger = logging .getLogger (__name__ )
44
+
41
45
42
46
class WriterAsyncIO :
43
47
_loop : asyncio .AbstractEventLoop
@@ -154,7 +158,6 @@ class WriterAsyncIOReconnector:
154
158
_credentials : Union [ydb .credentials .Credentials , None ]
155
159
_driver : ydb .aio .Driver
156
160
_init_message : StreamWriteMessage .InitRequest
157
- _init_info : asyncio .Future
158
161
_stream_connected : asyncio .Event
159
162
_settings : WriterSettings
160
163
_codec : PublicCodec
@@ -164,25 +167,30 @@ class WriterAsyncIOReconnector:
164
167
_codec_selector_last_codec : Optional [PublicCodec ]
165
168
_codec_selector_check_batches_interval : int
166
169
167
- _last_known_seq_no : int
168
170
if typing .TYPE_CHECKING :
169
171
_messages_for_encode : asyncio .Queue [List [InternalMessage ]]
170
172
else :
171
173
_messages_for_encode : asyncio .Queue
172
174
_messages : Deque [InternalMessage ]
173
175
_messages_future : Deque [asyncio .Future ]
174
176
_new_messages : asyncio .Queue
175
- _stop_reason : asyncio .Future
176
177
_background_tasks : List [asyncio .Task ]
177
178
179
+ _state_changed : asyncio .Event
180
+ if typing .TYPE_CHECKING :
181
+ _stop_reason : asyncio .Future [BaseException ]
182
+ else :
183
+ _stop_reason : asyncio .Future
184
+ _init_info : Optional [PublicWriterInitInfo ]
185
+
178
186
def __init__ (self , driver : SupportedDriverType , settings : WriterSettings ):
179
187
self ._closed = False
180
188
self ._loop = asyncio .get_running_loop ()
181
189
self ._driver = driver
182
190
self ._credentials = driver ._credentials
183
191
self ._init_message = settings .create_init_request ()
184
192
self ._new_messages = asyncio .Queue ()
185
- self ._init_info = self . _loop . create_future ()
193
+ self ._init_info = None
186
194
self ._stream_connected = asyncio .Event ()
187
195
self ._settings = settings
188
196
@@ -219,14 +227,17 @@ def __init__(self, driver: SupportedDriverType, settings: WriterSettings):
219
227
asyncio .create_task (self ._encode_loop (), name = "encode_loop" ),
220
228
]
221
229
230
+ self ._state_changed = asyncio .Event ()
231
+
222
232
async def close (self , flush : bool ):
223
233
if self ._closed :
224
234
return
235
+ self ._closed = True
236
+ logger .debug ("Close writer reconnector" )
225
237
226
238
if flush :
227
239
await self .flush ()
228
240
229
- self ._closed = True
230
241
self ._stop (TopicWriterStopped ())
231
242
232
243
for task in self ._background_tasks :
@@ -240,19 +251,20 @@ async def close(self, flush: bool):
240
251
pass
241
252
242
253
async def wait_init (self ) -> PublicWriterInitInfo :
243
- done , _ = await asyncio .wait (
244
- [self ._init_info , self ._stop_reason ], return_when = asyncio .FIRST_COMPLETED
245
- )
246
- res = done .pop () # type: asyncio.Future
247
- res_val = res .result ()
254
+ while True :
255
+ if self ._stop_reason .done ():
256
+ raise self ._stop_reason .exception ()
248
257
249
- if isinstance ( res_val , BaseException ) :
250
- raise res_val
258
+ if self . _init_info :
259
+ return self . _init_info
251
260
252
- return res_val
261
+ await self . _state_changed . wait ()
253
262
254
- async def wait_stop (self ) -> Exception :
255
- return await self ._stop_reason
263
+ async def wait_stop (self ) -> BaseException :
264
+ try :
265
+ await self ._stop_reason
266
+ except BaseException as stop_reason :
267
+ return stop_reason
256
268
257
269
async def write_with_ack_future (
258
270
self , messages : List [PublicMessage ]
@@ -343,13 +355,14 @@ async def _connection_loop(self):
343
355
self ._settings .update_token_interval ,
344
356
)
345
357
try :
346
- self ._last_known_seq_no = stream_writer . last_seqno
347
- self ._init_info . set_result (
348
- PublicWriterInitInfo (
358
+ if self ._init_info is None :
359
+ self ._last_known_seq_no = stream_writer . last_seqno
360
+ self . _init_info = PublicWriterInitInfo (
349
361
last_seqno = stream_writer .last_seqno ,
350
362
supported_codecs = stream_writer .supported_codecs ,
351
363
)
352
- )
364
+ self ._state_changed .set ()
365
+
353
366
except asyncio .InvalidStateError :
354
367
pass
355
368
@@ -369,9 +382,6 @@ async def _connection_loop(self):
369
382
await stream_writer .close ()
370
383
done .pop ().result ()
371
384
except issues .Error as err :
372
- # todo log error
373
- print (err )
374
-
375
385
err_info = check_retriable_error (err , retry_settings , attempt )
376
386
if not err_info .is_retriable :
377
387
self ._stop (err )
@@ -550,8 +560,13 @@ def _stop(self, reason: Exception):
550
560
551
561
self ._stop_reason .set_result (reason )
552
562
563
+ for f in self ._messages_future :
564
+ f .set_exception (reason )
565
+
566
+ self ._state_changed .set ()
567
+ logger .info ("Stop topic writer: %s" % reason )
568
+
553
569
async def flush (self ):
554
- self ._check_stop ()
555
570
if not self ._messages_future :
556
571
return
557
572
0 commit comments