@@ -84,7 +84,7 @@ def __init__(
84
84
):
85
85
self ._loop = asyncio .get_running_loop ()
86
86
self ._closed = False
87
- self ._reconnector = ReaderReconnector (driver , settings )
87
+ self ._reconnector = ReaderReconnector (driver , settings , self . _loop )
88
88
self ._parent = _parent
89
89
90
90
async def __aenter__ (self ):
@@ -190,18 +190,24 @@ class ReaderReconnector:
190
190
_first_error : asyncio .Future [YdbError ]
191
191
_tx_to_batches_map : Dict [str , typing .List [datatypes .PublicBatch ]]
192
192
193
- def __init__ (self , driver : Driver , settings : topic_reader .PublicReaderSettings ):
193
+ def __init__ (
194
+ self ,
195
+ driver : Driver ,
196
+ settings : topic_reader .PublicReaderSettings ,
197
+ loop : Optional [asyncio .AbstractEventLoop ] = None ,
198
+ ):
194
199
self ._id = self ._static_reader_reconnector_counter .inc_and_get ()
195
200
self ._settings = settings
196
201
self ._driver = driver
202
+ self ._loop = loop if loop is not None else asyncio .get_running_loop ()
197
203
self ._background_tasks = set ()
198
204
199
205
self ._state_changed = asyncio .Event ()
200
206
self ._stream_reader = None
201
207
self ._background_tasks .add (asyncio .create_task (self ._connection_loop ()))
202
208
self ._first_error = asyncio .get_running_loop ().create_future ()
203
209
204
- self ._tx_to_batches_map = defaultdict ( list )
210
+ self ._tx_to_batches_map = dict ( )
205
211
206
212
async def _connection_loop (self ):
207
213
attempt = 0
@@ -254,22 +260,23 @@ def receive_batch_with_tx_nowait(self, tx: "BaseQueryTxContext", max_messages: O
254
260
max_messages = max_messages ,
255
261
)
256
262
257
- self ._init_tx_if_needed (tx )
263
+ self ._init_tx (tx )
258
264
259
265
self ._tx_to_batches_map [tx .tx_id ].append (batch )
260
266
261
- tx ._add_callback (TxEvent .AFTER_COMMIT , batch ._update_partition_offsets , None ) # probably should be current loop
267
+ tx ._add_callback (TxEvent .AFTER_COMMIT , batch ._update_partition_offsets , self . _loop )
262
268
263
269
return batch
264
270
265
271
def receive_message_nowait (self ):
266
272
return self ._stream_reader .receive_message_nowait ()
267
273
268
- def _init_tx_if_needed (self , tx : "BaseQueryTxContext" ):
274
+ def _init_tx (self , tx : "BaseQueryTxContext" ):
269
275
if tx .tx_id not in self ._tx_to_batches_map : # Init tx callbacks
270
- tx ._add_callback (TxEvent .BEFORE_COMMIT , self ._commit_batches_with_tx , None )
271
- tx ._add_callback (TxEvent .AFTER_COMMIT , self ._handle_after_tx_commit , None )
272
- tx ._add_callback (TxEvent .AFTER_ROLLBACK , self ._handle_after_tx_rollback , None )
276
+ self ._tx_to_batches_map [tx .tx_id ] = []
277
+ tx ._add_callback (TxEvent .BEFORE_COMMIT , self ._commit_batches_with_tx , self ._loop )
278
+ tx ._add_callback (TxEvent .AFTER_COMMIT , self ._handle_after_tx_commit , self ._loop )
279
+ tx ._add_callback (TxEvent .AFTER_ROLLBACK , self ._handle_after_tx_rollback , self ._loop )
273
280
274
281
async def _commit_batches_with_tx (self , tx : "BaseQueryTxContext" ):
275
282
grouped_batches = defaultdict (lambda : defaultdict (list ))
0 commit comments