Skip to content

Commit 438a9b6

Browse files
committed
please help
1 parent cb99a63 commit 438a9b6

File tree

7 files changed

+122
-56
lines changed

7 files changed

+122
-56
lines changed

ydb/_topic_reader/datatypes.py

+1
Original file line numberDiff line numberDiff line change
@@ -163,6 +163,7 @@ class PublicBatch(ICommittable, ISessionAlive):
163163
_partition_session: PartitionSession
164164
_bytes_size: int
165165
_codec: Codec
166+
_commited_with_tx: bool = False
166167

167168
def _commit_get_partition_session(self) -> PartitionSession:
168169
return self.messages[0]._commit_get_partition_session()

ydb/_topic_reader/topic_reader_asyncio.py

+64-18
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
import gzip
66
import typing
77
from asyncio import Task
8-
from collections import OrderedDict
8+
from collections import OrderedDict, defaultdict
99
from typing import Optional, Set, Dict, Union, Callable
1010

1111
import ydb
@@ -140,7 +140,8 @@ async def receive_batch_with_tx(
140140
use asyncio.wait_for for wait with timeout.
141141
"""
142142
await self._reconnector.wait_message()
143-
return await self._reconnector.receive_batch_with_tx_nowait(
143+
tx._add_listener(self)
144+
return self._reconnector.receive_batch_with_tx_nowait(
144145
tx,
145146
max_messages=max_messages,
146147
)
@@ -177,11 +178,14 @@ async def close(self, flush: bool = True):
177178
self._closed = True
178179
await self._reconnector.close(flush)
179180

180-
def _on_after_commit(self, exc):
181-
return super()._on_after_commit(exc)
181+
async def _on_before_commit(self, tx):
182+
await self._reconnector._on_before_commit(tx)
182183

183-
def _on_after_rollback(self, exc):
184-
return super()._on_after_rollback(exc)
184+
async def _on_after_commit(self, tx, exc):
185+
await self._reconnector._on_after_commit(tx, exc)
186+
187+
async def _on_after_rollback(self, tx, exc):
188+
await self._reconnector._on_after_rollback(tx, exc)
185189

186190

187191
class ReaderReconnector:
@@ -195,8 +199,10 @@ class ReaderReconnector:
195199
_state_changed: asyncio.Event
196200
_stream_reader: Optional["ReaderStream"]
197201
_first_error: asyncio.Future[YdbError]
198-
_batches_to_commit: asyncio.Queue
199-
_wait_executor: Optional[concurrent.futures.ThreadPoolExecutor]
202+
203+
_batches_to_commit_with_tx: asyncio.Queue
204+
_tx_to_batches: typing.Dict[str, typing.List[datatypes.PublicBatch]]
205+
_wait_executor: concurrent.futures.ThreadPoolExecutor
200206

201207
def __init__(self, driver: Driver, settings: topic_reader.PublicReaderSettings):
202208
self._id = self._static_reader_reconnector_counter.inc_and_get()
@@ -210,7 +216,9 @@ def __init__(self, driver: Driver, settings: topic_reader.PublicReaderSettings):
210216
self._first_error = asyncio.get_running_loop().create_future()
211217
self._wait_executor = concurrent.futures.ThreadPoolExecutor(max_workers=1)
212218

213-
self._batches_to_commit = asyncio.Queue()
219+
self._batches_to_commit_with_tx = asyncio.Queue()
220+
self._tx_to_batches = defaultdict(list)
221+
self._background_tasks.add(asyncio.create_task(self._update_offsets_in_tx_loop()))
214222

215223
async def _connection_loop(self):
216224
attempt = 0
@@ -263,19 +271,21 @@ def receive_message_nowait(self):
263271
def commit(self, batch: datatypes.ICommittable) -> datatypes.PartitionSession.CommitAckWaiter:
264272
return self._stream_reader.commit(batch)
265273

266-
async def _commit_with_tx(self, tx: "BaseQueryTxContext", batch: datatypes.ICommittable) -> None:
267-
pass
268-
269-
async def receive_batch_with_tx_nowait(self, tx: "BaseQueryTxContext", max_messages: Optional[int] = None):
274+
def receive_batch_with_tx_nowait(self, tx: "BaseQueryTxContext", max_messages: Optional[int] = None):
270275
batch = self.receive_batch_nowait(max_messages=max_messages)
271-
tx._add_listener(batch)
272-
await self._update_offsets_in_tx_call(self._driver, tx, batch)
276+
self._tx_to_batches[tx.tx_id].append(batch)
277+
278+
self._batches_to_commit_with_tx.put_nowait((tx, batch))
279+
280+
print("batch recieved")
281+
273282
return batch
274-
# self._batches_to_commit.put_nowait((tx, batch))
275283

276284
async def _update_offsets_in_tx_loop(self):
277285
while True:
278-
tx, batch = self._batches_to_commit.get()
286+
print("_update_offsets_in_tx_loop")
287+
288+
tx, batch = await self._batches_to_commit_with_tx.get()
279289
await self._update_offsets_in_tx_call(self._driver, tx, batch)
280290

281291
async def _update_offsets_in_tx_call(self, driver: SupportedDriverType, tx: "BaseQueryTxContext", batch: datatypes.ICommittable) -> None:
@@ -309,9 +319,25 @@ async def _update_offsets_in_tx_call(self, driver: SupportedDriverType, tx: "Bas
309319
else:
310320
res = await to_thread(driver, *args, executor=self._wait_executor)
311321

322+
batch._commited_with_tx = True
323+
312324
return res
313325
except BaseException as e:
314-
self._set_first_error(e)
326+
self._stream_reader._set_first_error(e)
327+
328+
async def _ensure_all_batches_commited_with_tx(self, tx: "BaseQueryTxContext"):
329+
while True:
330+
print("_ensure_all_batches_commited_with_tx")
331+
if tx.tx_id not in self._tx_to_batches:
332+
# we should not be here
333+
return True
334+
batches = self._tx_to_batches.get(tx.tx_id)
335+
everything_commited = True
336+
for batch in batches:
337+
everything_commited = everything_commited and batch._commited_with_tx
338+
if everything_commited:
339+
return True
340+
await asyncio.sleep(0.001)
315341

316342
async def close(self, flush: bool):
317343
if self._stream_reader:
@@ -336,6 +362,26 @@ def _set_first_error(self, err: issues.Error):
336362
# skip if already has result
337363
pass
338364

365+
async def _on_before_commit(self, tx):
366+
print("on before commit")
367+
await asyncio.wait_for(self._ensure_all_batches_commited_with_tx(tx), 1)
368+
pass
369+
370+
async def _on_after_commit(self, tx, exc):
371+
print(f"on after commit, exc = {exc is not None}")
372+
373+
if exc:
374+
self._stream_reader._set_first_error(exc)
375+
for batch in self._tx_to_batches[tx.tx_id]:
376+
batch._partition_session.committed_offset = max(batch._partition_session.committed_offset, batch._commit_get_offsets_range().end)
377+
del self._tx_to_batches[tx.tx_id]
378+
379+
async def _on_after_rollback(self, tx, exc):
380+
print(f"on after rollback, exc = {exc is not None}")
381+
print(exc)
382+
exc = exc if exc is not None else issues.InternalError("tx rollback failed")
383+
self._stream_reader._set_first_error(exc)
384+
339385

340386
class ReaderStream:
341387
_static_id_counter = AtomicCounter()

ydb/_topic_reader/topic_reader_sync.py

+26-1
Original file line numberDiff line numberDiff line change
@@ -20,8 +20,13 @@
2020
TopicReaderClosedError,
2121
)
2222

23+
from ..query.base import TxListener
2324

24-
class TopicReaderSync:
25+
if typing.TYPE_CHECKING:
26+
from ..query.transaction import BaseQueryTxContext
27+
28+
29+
class TopicReaderSync(TxListener):
2530
_caller: CallFromSyncToAsync
2631
_async_reader: PublicAsyncIOReader
2732
_closed: bool
@@ -155,3 +160,23 @@ def close(self, *, flush: bool = True, timeout: TimeoutType = None):
155160
def _check_closed(self):
156161
if self._closed:
157162
raise TopicReaderClosedError()
163+
164+
def _on_before_commit(self, tx: "BaseQueryTxContext"):
165+
self._check_closed()
166+
167+
return self._caller.unsafe_call_with_result(self._async_reader._on_before_commit(tx), 5)
168+
169+
def _on_after_commit(self, tx: "BaseQueryTxContext", exc):
170+
self._check_closed()
171+
172+
return self._caller.unsafe_call_with_result(self._async_reader._on_after_commit(tx, exc), 5)
173+
174+
def _on_before_rollback(self, tx: "BaseQueryTxContext"):
175+
self._check_closed()
176+
177+
return self._caller.unsafe_call_with_result(self._async_reader._on_before_rollback(tx), 5)
178+
179+
def _on_after_rollback(self, tx: "BaseQueryTxContext", exc):
180+
self._check_closed()
181+
182+
return self._caller.unsafe_call_with_result(self._async_reader._on_after_rollback(tx, exc), 5)

ydb/_topic_writer/topic_writer_asyncio.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -187,10 +187,10 @@ def __init__(
187187
self._parent = _client
188188
self._tx._add_listener(self)
189189

190-
async def _on_before_commit(self):
190+
async def _on_before_commit(self, **kwargs):
191191
await self.close()
192192

193-
async def _on_before_rollback(self):
193+
async def _on_before_rollback(self, **kwargs):
194194
await self.close(flush=False)
195195

196196

ydb/_topic_writer/topic_writer_sync.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -160,8 +160,8 @@ async def create_async_writer():
160160

161161
tx._add_listener(self)
162162

163-
def _on_before_commit(self):
163+
def _on_before_commit(self, **kwargs):
164164
self.close()
165165

166-
def _on_before_rollback(self):
166+
def _on_before_rollback(self, **kwargs):
167167
self.close(flush=False)

ydb/aio/query/transaction.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -57,7 +57,7 @@ async def begin(self, settings: Optional[BaseRequestSettings] = None) -> "QueryT
5757
await self._begin_call(settings)
5858
return self
5959

60-
@base.with_async_transaction_events
60+
@base.with_transaction_events
6161
async def commit(self, settings: Optional[BaseRequestSettings] = None) -> None:
6262
"""Calls commit on a transaction if it is open otherwise is no-op. If transaction execution
6363
failed then this method raises PreconditionFailed.
@@ -77,7 +77,7 @@ async def commit(self, settings: Optional[BaseRequestSettings] = None) -> None:
7777

7878
await self._commit_call(settings)
7979

80-
@base.with_async_transaction_events
80+
@base.with_transaction_events
8181
async def rollback(self, settings: Optional[BaseRequestSettings] = None) -> None:
8282
"""Calls rollback on a transaction if it is open otherwise is no-op. If transaction execution
8383
failed then this method raises PreconditionFailed.

ydb/query/base.py

+25-31
Original file line numberDiff line numberDiff line change
@@ -200,74 +200,76 @@ def wrap_execute_query_response(
200200

201201

202202
class TxListener:
203-
def _on_before_commit(self):
203+
def _on_before_commit(self, tx: "BaseQueryTxContext"):
204204
pass
205205

206-
def _on_after_commit(self, exc: typing.Optional[BaseException]):
206+
def _on_after_commit(self, tx: "BaseQueryTxContext", exc: typing.Optional[BaseException]):
207207
pass
208208

209-
def _on_before_rollback(self):
209+
def _on_before_rollback(self, tx: "BaseQueryTxContext"):
210210
pass
211211

212-
def _on_after_rollback(self, exc: typing.Optional[BaseException]):
212+
def _on_after_rollback(self, tx: "BaseQueryTxContext", exc: typing.Optional[BaseException]):
213213
pass
214214

215215

216216
class TxListenerAsyncIO:
217-
async def _on_before_commit(self):
217+
async def _on_before_commit(self, tx: "BaseQueryTxContext"):
218218
pass
219219

220-
async def _on_after_commit(self, exc: typing.Optional[BaseException]):
220+
async def _on_after_commit(self, tx: "BaseQueryTxContext", exc: typing.Optional[BaseException]):
221221
pass
222222

223-
async def _on_before_rollback(self):
223+
async def _on_before_rollback(self, tx: "BaseQueryTxContext"):
224224
pass
225225

226-
async def _on_after_rollback(self, exc: typing.Optional[BaseException]):
226+
async def _on_after_rollback(self, tx: "BaseQueryTxContext", exc: typing.Optional[BaseException]):
227227
pass
228228

229229

230230
def with_transaction_events(method):
231231
@functools.wraps(method)
232-
def wrapper(self, *args, **kwargs):
232+
def wrapper(self: ListenerHandlerMixin, *args, **kwargs):
233233
method_name = method.__name__
234234
before_event = f"_on_before_{method_name}"
235235
after_event = f"_on_after_{method_name}"
236236

237-
self._notify_listeners_sync(before_event)
237+
self._notify_listeners_sync(before_event, tx=self)
238238

239239
try:
240240
result = method(self, *args, **kwargs)
241241

242-
self._notify_listeners_sync(after_event, exc=None)
242+
self._notify_listeners_sync(after_event, tx=self, exc=None)
243243

244244
return result
245245
except BaseException as e:
246-
self._notify_listeners_sync(after_event, exc=e)
246+
self._notify_listeners_sync(after_event, tx=self, exc=e)
247247
raise
248+
finally:
249+
self._clear_listeners()
248250

249-
return wrapper
250-
251-
252-
def with_async_transaction_events(method):
253251
@functools.wraps(method)
254-
async def wrapper(self, *args, **kwargs):
252+
async def async_wrapper(self: ListenerHandlerMixin, *args, **kwargs):
255253
method_name = method.__name__
256254
before_event = f"_on_before_{method_name}"
257255
after_event = f"_on_after_{method_name}"
258256

259-
await self._notify_listeners_async(before_event)
257+
await self._notify_listeners_async(before_event, tx=self)
260258

261259
try:
262260
result = await method(self, *args, **kwargs)
263261

264-
await self._notify_listeners_async(after_event, exc=None)
262+
await self._notify_listeners_async(after_event, tx=self, exc=None)
265263

266264
return result
267265
except BaseException as e:
268-
await self._notify_listeners_async(after_event, exc=e)
266+
await self._notify_listeners_async(after_event, tx=self, exc=e)
269267
raise
268+
finally:
269+
self._clear_listeners()
270270

271+
if asyncio.iscoroutinefunction(method):
272+
return async_wrapper
271273
return wrapper
272274

273275

@@ -289,24 +291,16 @@ def _clear_listeners(self):
289291
self.listeners.clear()
290292
return self
291293

292-
def _notify_sync_listeners(self, event_name: str, **kwargs) -> None:
294+
def _notify_listeners_sync(self, event_name: str, **kwargs) -> None:
293295
for listener in self.listeners:
294296
if isinstance(listener, TxListener) and hasattr(listener, event_name):
295297
getattr(listener, event_name)(**kwargs)
296298

297-
async def _notify_async_listeners(self, event_name: str, **kwargs) -> None:
299+
async def _notify_listeners_async(self, event_name: str, **kwargs) -> None:
298300
coros = []
299301
for listener in self.listeners:
300302
if isinstance(listener, TxListenerAsyncIO) and hasattr(listener, event_name):
301303
coros.append(getattr(listener, event_name)(**kwargs))
302304

303305
if coros:
304-
await asyncio.gather(*coros)
305-
306-
def _notify_listeners_sync(self, event_name: str, **kwargs) -> None:
307-
self._notify_sync_listeners(event_name, **kwargs)
308-
309-
async def _notify_listeners_async(self, event_name: str, **kwargs) -> None:
310-
# self._notify_sync_listeners(event_name, **kwargs)
311-
312-
await self._notify_async_listeners(event_name, **kwargs)
306+
await asyncio.gather(*coros)

0 commit comments

Comments
 (0)