5
5
import gzip
6
6
import typing
7
7
from asyncio import Task
8
- from collections import OrderedDict
8
+ from collections import OrderedDict , defaultdict
9
9
from typing import Optional , Set , Dict , Union , Callable
10
10
11
11
import ydb
@@ -140,7 +140,8 @@ async def receive_batch_with_tx(
140
140
use asyncio.wait_for for wait with timeout.
141
141
"""
142
142
await self ._reconnector .wait_message ()
143
- return await self ._reconnector .receive_batch_with_tx_nowait (
143
+ tx ._add_listener (self )
144
+ return self ._reconnector .receive_batch_with_tx_nowait (
144
145
tx ,
145
146
max_messages = max_messages ,
146
147
)
@@ -177,11 +178,14 @@ async def close(self, flush: bool = True):
177
178
self ._closed = True
178
179
await self ._reconnector .close (flush )
179
180
180
- def _on_after_commit (self , exc ):
181
- return super (). _on_after_commit ( exc )
181
+ async def _on_before_commit (self , tx ):
182
+ await self . _reconnector . _on_before_commit ( tx )
182
183
183
- def _on_after_rollback (self , exc ):
184
- return super ()._on_after_rollback (exc )
184
+ async def _on_after_commit (self , tx , exc ):
185
+ await self ._reconnector ._on_after_commit (tx , exc )
186
+
187
+ async def _on_after_rollback (self , tx , exc ):
188
+ await self ._reconnector ._on_after_rollback (tx , exc )
185
189
186
190
187
191
class ReaderReconnector :
@@ -195,8 +199,10 @@ class ReaderReconnector:
195
199
_state_changed : asyncio .Event
196
200
_stream_reader : Optional ["ReaderStream" ]
197
201
_first_error : asyncio .Future [YdbError ]
198
- _batches_to_commit : asyncio .Queue
199
- _wait_executor : Optional [concurrent .futures .ThreadPoolExecutor ]
202
+
203
+ _batches_to_commit_with_tx : asyncio .Queue
204
+ _tx_to_batches : typing .Dict [str , typing .List [datatypes .PublicBatch ]]
205
+ _wait_executor : concurrent .futures .ThreadPoolExecutor
200
206
201
207
def __init__ (self , driver : Driver , settings : topic_reader .PublicReaderSettings ):
202
208
self ._id = self ._static_reader_reconnector_counter .inc_and_get ()
@@ -210,7 +216,9 @@ def __init__(self, driver: Driver, settings: topic_reader.PublicReaderSettings):
210
216
self ._first_error = asyncio .get_running_loop ().create_future ()
211
217
self ._wait_executor = concurrent .futures .ThreadPoolExecutor (max_workers = 1 )
212
218
213
- self ._batches_to_commit = asyncio .Queue ()
219
+ self ._batches_to_commit_with_tx = asyncio .Queue ()
220
+ self ._tx_to_batches = defaultdict (list )
221
+ self ._background_tasks .add (asyncio .create_task (self ._update_offsets_in_tx_loop ()))
214
222
215
223
async def _connection_loop (self ):
216
224
attempt = 0
@@ -263,19 +271,21 @@ def receive_message_nowait(self):
263
271
def commit (self , batch : datatypes .ICommittable ) -> datatypes .PartitionSession .CommitAckWaiter :
264
272
return self ._stream_reader .commit (batch )
265
273
266
- async def _commit_with_tx (self , tx : "BaseQueryTxContext" , batch : datatypes .ICommittable ) -> None :
267
- pass
268
-
269
- async def receive_batch_with_tx_nowait (self , tx : "BaseQueryTxContext" , max_messages : Optional [int ] = None ):
274
+ def receive_batch_with_tx_nowait (self , tx : "BaseQueryTxContext" , max_messages : Optional [int ] = None ):
270
275
batch = self .receive_batch_nowait (max_messages = max_messages )
271
- tx ._add_listener (batch )
272
- await self ._update_offsets_in_tx_call (self ._driver , tx , batch )
276
+ self ._tx_to_batches [tx .tx_id ].append (batch )
277
+
278
+ self ._batches_to_commit_with_tx .put_nowait ((tx , batch ))
279
+
280
+ print ("batch recieved" )
281
+
273
282
return batch
274
- # self._batches_to_commit.put_nowait((tx, batch))
275
283
276
284
async def _update_offsets_in_tx_loop (self ):
277
285
while True :
278
- tx , batch = self ._batches_to_commit .get ()
286
+ print ("_update_offsets_in_tx_loop" )
287
+
288
+ tx , batch = await self ._batches_to_commit_with_tx .get ()
279
289
await self ._update_offsets_in_tx_call (self ._driver , tx , batch )
280
290
281
291
async def _update_offsets_in_tx_call (self , driver : SupportedDriverType , tx : "BaseQueryTxContext" , batch : datatypes .ICommittable ) -> None :
@@ -309,9 +319,25 @@ async def _update_offsets_in_tx_call(self, driver: SupportedDriverType, tx: "Bas
309
319
else :
310
320
res = await to_thread (driver , * args , executor = self ._wait_executor )
311
321
322
+ batch ._commited_with_tx = True
323
+
312
324
return res
313
325
except BaseException as e :
314
- self ._set_first_error (e )
326
+ self ._stream_reader ._set_first_error (e )
327
+
328
+ async def _ensure_all_batches_commited_with_tx (self , tx : "BaseQueryTxContext" ):
329
+ while True :
330
+ print ("_ensure_all_batches_commited_with_tx" )
331
+ if tx .tx_id not in self ._tx_to_batches :
332
+ # we should not be here
333
+ return True
334
+ batches = self ._tx_to_batches .get (tx .tx_id )
335
+ everything_commited = True
336
+ for batch in batches :
337
+ everything_commited = everything_commited and batch ._commited_with_tx
338
+ if everything_commited :
339
+ return True
340
+ await asyncio .sleep (0.001 )
315
341
316
342
async def close (self , flush : bool ):
317
343
if self ._stream_reader :
@@ -336,6 +362,26 @@ def _set_first_error(self, err: issues.Error):
336
362
# skip if already has result
337
363
pass
338
364
365
+ async def _on_before_commit (self , tx ):
366
+ print ("on before commit" )
367
+ await asyncio .wait_for (self ._ensure_all_batches_commited_with_tx (tx ), 1 )
368
+ pass
369
+
370
+ async def _on_after_commit (self , tx , exc ):
371
+ print (f"on after commit, exc = { exc is not None } " )
372
+
373
+ if exc :
374
+ self ._stream_reader ._set_first_error (exc )
375
+ for batch in self ._tx_to_batches [tx .tx_id ]:
376
+ batch ._partition_session .committed_offset = max (batch ._partition_session .committed_offset , batch ._commit_get_offsets_range ().end )
377
+ del self ._tx_to_batches [tx .tx_id ]
378
+
379
+ async def _on_after_rollback (self , tx , exc ):
380
+ print (f"on after rollback, exc = { exc is not None } " )
381
+ print (exc )
382
+ exc = exc if exc is not None else issues .InternalError ("tx rollback failed" )
383
+ self ._stream_reader ._set_first_error (exc )
384
+
339
385
340
386
class ReaderStream :
341
387
_static_id_counter = AtomicCounter ()
0 commit comments