Skip to content

Commit 2ea6b59

Browse files
authored
Merge pull request #243 stop on errors
2 parents 12a4b71 + bb9562a commit 2ea6b59

File tree

1 file changed

+38
-23
lines changed

1 file changed

+38
-23
lines changed

ydb/_topic_writer/topic_writer_asyncio.py

Lines changed: 38 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,8 @@
66
from collections import deque
77
from typing import Deque, AsyncIterator, Union, List, Optional, Dict, Callable
88

9+
import logging
10+
911
import ydb
1012
from .topic_writer import (
1113
PublicWriterSettings,
@@ -38,6 +40,8 @@
3840
GrpcWrapperAsyncIO,
3941
)
4042

43+
logger = logging.getLogger(__name__)
44+
4145

4246
class WriterAsyncIO:
4347
_loop: asyncio.AbstractEventLoop
@@ -154,7 +158,6 @@ class WriterAsyncIOReconnector:
154158
_credentials: Union[ydb.credentials.Credentials, None]
155159
_driver: ydb.aio.Driver
156160
_init_message: StreamWriteMessage.InitRequest
157-
_init_info: asyncio.Future
158161
_stream_connected: asyncio.Event
159162
_settings: WriterSettings
160163
_codec: PublicCodec
@@ -164,25 +167,30 @@ class WriterAsyncIOReconnector:
164167
_codec_selector_last_codec: Optional[PublicCodec]
165168
_codec_selector_check_batches_interval: int
166169

167-
_last_known_seq_no: int
168170
if typing.TYPE_CHECKING:
169171
_messages_for_encode: asyncio.Queue[List[InternalMessage]]
170172
else:
171173
_messages_for_encode: asyncio.Queue
172174
_messages: Deque[InternalMessage]
173175
_messages_future: Deque[asyncio.Future]
174176
_new_messages: asyncio.Queue
175-
_stop_reason: asyncio.Future
176177
_background_tasks: List[asyncio.Task]
177178

179+
_state_changed: asyncio.Event
180+
if typing.TYPE_CHECKING:
181+
_stop_reason: asyncio.Future[BaseException]
182+
else:
183+
_stop_reason: asyncio.Future
184+
_init_info: Optional[PublicWriterInitInfo]
185+
178186
def __init__(self, driver: SupportedDriverType, settings: WriterSettings):
179187
self._closed = False
180188
self._loop = asyncio.get_running_loop()
181189
self._driver = driver
182190
self._credentials = driver._credentials
183191
self._init_message = settings.create_init_request()
184192
self._new_messages = asyncio.Queue()
185-
self._init_info = self._loop.create_future()
193+
self._init_info = None
186194
self._stream_connected = asyncio.Event()
187195
self._settings = settings
188196

@@ -219,14 +227,17 @@ def __init__(self, driver: SupportedDriverType, settings: WriterSettings):
219227
asyncio.create_task(self._encode_loop(), name="encode_loop"),
220228
]
221229

230+
self._state_changed = asyncio.Event()
231+
222232
async def close(self, flush: bool):
223233
if self._closed:
224234
return
235+
self._closed = True
236+
logger.debug("Close writer reconnector")
225237

226238
if flush:
227239
await self.flush()
228240

229-
self._closed = True
230241
self._stop(TopicWriterStopped())
231242

232243
for task in self._background_tasks:
@@ -240,19 +251,20 @@ async def close(self, flush: bool):
240251
pass
241252

242253
async def wait_init(self) -> PublicWriterInitInfo:
243-
done, _ = await asyncio.wait(
244-
[self._init_info, self._stop_reason], return_when=asyncio.FIRST_COMPLETED
245-
)
246-
res = done.pop() # type: asyncio.Future
247-
res_val = res.result()
254+
while True:
255+
if self._stop_reason.done():
256+
raise self._stop_reason.exception()
248257

249-
if isinstance(res_val, BaseException):
250-
raise res_val
258+
if self._init_info:
259+
return self._init_info
251260

252-
return res_val
261+
await self._state_changed.wait()
253262

254-
async def wait_stop(self) -> Exception:
255-
return await self._stop_reason
263+
async def wait_stop(self) -> BaseException:
264+
try:
265+
await self._stop_reason
266+
except BaseException as stop_reason:
267+
return stop_reason
256268

257269
async def write_with_ack_future(
258270
self, messages: List[PublicMessage]
@@ -343,13 +355,14 @@ async def _connection_loop(self):
343355
self._settings.update_token_interval,
344356
)
345357
try:
346-
self._last_known_seq_no = stream_writer.last_seqno
347-
self._init_info.set_result(
348-
PublicWriterInitInfo(
358+
if self._init_info is None:
359+
self._last_known_seq_no = stream_writer.last_seqno
360+
self._init_info = PublicWriterInitInfo(
349361
last_seqno=stream_writer.last_seqno,
350362
supported_codecs=stream_writer.supported_codecs,
351363
)
352-
)
364+
self._state_changed.set()
365+
353366
except asyncio.InvalidStateError:
354367
pass
355368

@@ -369,9 +382,6 @@ async def _connection_loop(self):
369382
await stream_writer.close()
370383
done.pop().result()
371384
except issues.Error as err:
372-
# todo log error
373-
print(err)
374-
375385
err_info = check_retriable_error(err, retry_settings, attempt)
376386
if not err_info.is_retriable:
377387
self._stop(err)
@@ -550,8 +560,13 @@ def _stop(self, reason: Exception):
550560

551561
self._stop_reason.set_result(reason)
552562

563+
for f in self._messages_future:
564+
f.set_exception(reason)
565+
566+
self._state_changed.set()
567+
logger.info("Stop topic writer: %s" % reason)
568+
553569
async def flush(self):
554-
self._check_stop()
555570
if not self._messages_future:
556571
return
557572

0 commit comments

Comments
 (0)