35
35
UpdateTokenRequest ,
36
36
UpdateTokenResponse ,
37
37
StreamWriteMessage ,
38
+ TransactionIdentity ,
38
39
WriterMessagesFromServerToClient ,
39
40
)
40
41
from .._grpc .grpcwrapper .common_utils import (
43
44
GrpcWrapperAsyncIO ,
44
45
)
45
46
47
+ if typing .TYPE_CHECKING :
48
+ from ..query .transaction import BaseQueryTxContext
49
+
46
50
logger = logging .getLogger (__name__ )
47
51
48
52
@@ -165,7 +169,20 @@ async def wait_init(self) -> PublicWriterInitInfo:
165
169
166
170
167
171
class TxWriterAsyncIO (WriterAsyncIO ):
168
- ...
172
+ _tx : object
173
+
174
+ def __init__ (
175
+ self ,
176
+ tx ,
177
+ driver : SupportedDriverType ,
178
+ settings : PublicWriterSettings ,
179
+ _client = None ,
180
+ ):
181
+ self ._tx = tx
182
+ self ._loop = asyncio .get_running_loop ()
183
+ self ._closed = False
184
+ self ._reconnector = WriterAsyncIOReconnector (driver = driver , settings = WriterSettings (settings ), tx = self ._tx )
185
+ self ._parent = _client
169
186
170
187
171
188
class WriterAsyncIOReconnector :
@@ -182,6 +199,7 @@ class WriterAsyncIOReconnector:
182
199
_codec_selector_batch_num : int
183
200
_codec_selector_last_codec : Optional [PublicCodec ]
184
201
_codec_selector_check_batches_interval : int
202
+ _tx : Optional ["BaseQueryTxContext" ]
185
203
186
204
if typing .TYPE_CHECKING :
187
205
_messages_for_encode : asyncio .Queue [List [InternalMessage ]]
@@ -199,7 +217,9 @@ class WriterAsyncIOReconnector:
199
217
_stop_reason : asyncio .Future
200
218
_init_info : Optional [PublicWriterInitInfo ]
201
219
202
- def __init__ (self , driver : SupportedDriverType , settings : WriterSettings ):
220
+ def __init__ (
221
+ self , driver : SupportedDriverType , settings : WriterSettings , tx : Optional ["BaseQueryTxContext" ] = None
222
+ ):
203
223
self ._closed = False
204
224
self ._loop = asyncio .get_running_loop ()
205
225
self ._driver = driver
@@ -209,6 +229,7 @@ def __init__(self, driver: SupportedDriverType, settings: WriterSettings):
209
229
self ._init_info = None
210
230
self ._stream_connected = asyncio .Event ()
211
231
self ._settings = settings
232
+ self ._tx = tx
212
233
213
234
self ._codec_functions = {
214
235
PublicCodec .RAW : lambda data : data ,
@@ -358,10 +379,12 @@ async def _connection_loop(self):
358
379
# noinspection PyBroadException
359
380
stream_writer = None
360
381
try :
382
+ tx_identity = None if self ._tx is None else self ._tx ._tx_identity ()
361
383
stream_writer = await WriterAsyncIOStream .create (
362
384
self ._driver ,
363
385
self ._init_message ,
364
386
self ._settings .update_token_interval ,
387
+ tx_identity = tx_identity ,
365
388
)
366
389
try :
367
390
if self ._init_info is None :
@@ -601,10 +624,13 @@ class WriterAsyncIOStream:
601
624
_update_token_event : asyncio .Event
602
625
_get_token_function : Optional [Callable [[], str ]]
603
626
627
+ _tx_identity : Optional [TransactionIdentity ]
628
+
604
629
def __init__ (
605
630
self ,
606
631
update_token_interval : Optional [Union [int , float ]] = None ,
607
632
get_token_function : Optional [Callable [[], str ]] = None ,
633
+ tx_identity : Optional [TransactionIdentity ] = None ,
608
634
):
609
635
self ._closed = False
610
636
@@ -613,6 +639,8 @@ def __init__(
613
639
self ._update_token_event = asyncio .Event ()
614
640
self ._update_token_task = None
615
641
642
+ self ._tx_identity = tx_identity
643
+
616
644
async def close (self ):
617
645
if self ._closed :
618
646
return
@@ -629,6 +657,7 @@ async def create(
629
657
driver : SupportedDriverType ,
630
658
init_request : StreamWriteMessage .InitRequest ,
631
659
update_token_interval : Optional [Union [int , float ]] = None ,
660
+ tx_identity : Optional [TransactionIdentity ] = None ,
632
661
) -> "WriterAsyncIOStream" :
633
662
stream = GrpcWrapperAsyncIO (StreamWriteMessage .FromServer .from_proto )
634
663
@@ -638,6 +667,7 @@ async def create(
638
667
writer = WriterAsyncIOStream (
639
668
update_token_interval = update_token_interval ,
640
669
get_token_function = creds .get_auth_token if creds else lambda : "" ,
670
+ tx_identity = tx_identity ,
641
671
)
642
672
await writer ._start (stream , init_request )
643
673
return writer
@@ -684,7 +714,7 @@ def write(self, messages: List[InternalMessage]):
684
714
if self ._closed :
685
715
raise RuntimeError ("Can not write on closed stream." )
686
716
687
- for request in messages_to_proto_requests (messages ):
717
+ for request in messages_to_proto_requests (messages , self . _tx_identity ):
688
718
self ._stream .write (request )
689
719
690
720
async def _update_token_loop (self ):
0 commit comments