6
6
import typing
7
7
from asyncio import Task
8
8
from collections import deque
9
- from typing import Optional , Set , Dict
9
+ from typing import Optional , Set , Dict , Union , Callable
10
10
11
- from .. import _apis , issues , RetrySettings
11
+ from .. import _apis , issues
12
12
from .._utilities import AtomicCounter
13
13
from ..aio import Driver
14
14
from ..issues import Error as YdbError , _process_response
19
19
SupportedDriverType ,
20
20
GrpcWrapperAsyncIO ,
21
21
)
22
- from .._grpc .grpcwrapper .ydb_topic import StreamReadMessage , Codec
22
+ from .._grpc .grpcwrapper .ydb_topic import (
23
+ StreamReadMessage ,
24
+ UpdateTokenRequest ,
25
+ UpdateTokenResponse ,
26
+ Codec ,
27
+ )
23
28
from .._errors import check_retriable_error
24
29
25
30
@@ -194,7 +199,6 @@ def __init__(self, driver: Driver, settings: topic_reader.PublicReaderSettings):
194
199
self ._settings = settings
195
200
self ._driver = driver
196
201
self ._background_tasks = set ()
197
- self ._retry_settins = RetrySettings (idempotent = True ) # get from settings
198
202
199
203
self ._state_changed = asyncio .Event ()
200
204
self ._stream_reader = None
@@ -227,7 +231,7 @@ async def wait_message(self):
227
231
if self ._first_error .done ():
228
232
raise self ._first_error .result ()
229
233
230
- if self ._stream_reader is not None :
234
+ if self ._stream_reader :
231
235
try :
232
236
await self ._stream_reader .wait_messages ()
233
237
return
@@ -289,8 +293,15 @@ class ReaderStream:
289
293
_message_batches : typing .Deque [datatypes .PublicBatch ]
290
294
_first_error : asyncio .Future [YdbError ]
291
295
296
+ _update_token_interval : Union [int , float ]
297
+ _update_token_event : asyncio .Event
298
+ _get_token_function : Callable [[], str ]
299
+
292
300
def __init__ (
293
- self , reader_reconnector_id : int , settings : topic_reader .PublicReaderSettings
301
+ self ,
302
+ reader_reconnector_id : int ,
303
+ settings : topic_reader .PublicReaderSettings ,
304
+ get_token_function : Optional [Callable [[], str ]] = None ,
294
305
):
295
306
self ._loop = asyncio .get_running_loop ()
296
307
self ._id = ReaderStream ._static_id_counter .inc_and_get ()
@@ -313,6 +324,10 @@ def __init__(
313
324
self ._batches_to_decode = asyncio .Queue ()
314
325
self ._message_batches = deque ()
315
326
327
+ self ._update_token_interval = settings .update_token_interval
328
+ self ._get_token_function = get_token_function
329
+ self ._update_token_event = asyncio .Event ()
330
+
316
331
@staticmethod
317
332
async def create (
318
333
reader_reconnector_id : int ,
@@ -325,7 +340,12 @@ async def create(
325
340
driver , _apis .TopicService .Stub , _apis .TopicService .StreamRead
326
341
)
327
342
328
- reader = ReaderStream (reader_reconnector_id , settings )
343
+ creds = driver ._credentials
344
+ reader = ReaderStream (
345
+ reader_reconnector_id ,
346
+ settings ,
347
+ get_token_function = creds .get_auth_token if creds else None ,
348
+ )
329
349
await reader ._start (stream , settings ._init_message ())
330
350
return reader
331
351
@@ -347,35 +367,41 @@ async def _start(
347
367
"Unexpected message after InitRequest: %s" , init_response
348
368
)
349
369
370
+ self ._update_token_event .set ()
371
+
350
372
self ._background_tasks .add (
351
- asyncio .create_task (self ._read_messages_loop (stream ) )
373
+ asyncio .create_task (self ._read_messages_loop (), name = "read_messages_loop" )
352
374
)
353
375
self ._background_tasks .add (asyncio .create_task (self ._decode_batches_loop ()))
376
+ if self ._get_token_function :
377
+ self ._background_tasks .add (
378
+ asyncio .create_task (self ._update_token_loop (), name = "update_token_loop" )
379
+ )
354
380
355
381
async def wait_error (self ):
356
382
raise await self ._first_error
357
383
358
384
async def wait_messages (self ):
359
385
while True :
360
- if self ._get_first_error () is not None :
386
+ if self ._get_first_error ():
361
387
raise self ._get_first_error ()
362
388
363
- if len ( self ._message_batches ) > 0 :
389
+ if self ._message_batches :
364
390
return
365
391
366
392
await self ._state_changed .wait ()
367
393
self ._state_changed .clear ()
368
394
369
395
def receive_batch_nowait (self ):
370
- if self ._get_first_error () is not None :
396
+ if self ._get_first_error ():
371
397
raise self ._get_first_error ()
372
398
373
- try :
374
- batch = self . _message_batches . popleft ()
375
- self . _buffer_release_bytes ( batch . _bytes_size )
376
- return batch
377
- except IndexError :
378
- return None
399
+ if not self . _message_batches :
400
+ return
401
+
402
+ batch = self . _message_batches . popleft ()
403
+ self . _buffer_release_bytes ( batch . _bytes_size )
404
+ return batch
379
405
380
406
def commit (
381
407
self , batch : datatypes .ICommittable
@@ -413,7 +439,7 @@ def commit(
413
439
414
440
return waiter
415
441
416
- async def _read_messages_loop (self , stream : IGrpcWrapperAsyncIO ):
442
+ async def _read_messages_loop (self ):
417
443
try :
418
444
self ._stream .write (
419
445
StreamReadMessage .FromClient (
@@ -423,24 +449,34 @@ async def _read_messages_loop(self, stream: IGrpcWrapperAsyncIO):
423
449
)
424
450
)
425
451
while True :
426
- message = await stream .receive () # type: StreamReadMessage.FromServer
452
+ message = (
453
+ await self ._stream .receive ()
454
+ ) # type: StreamReadMessage.FromServer
427
455
_process_response (message .server_status )
456
+
428
457
if isinstance (message .server_message , StreamReadMessage .ReadResponse ):
429
458
self ._on_read_response (message .server_message )
459
+
430
460
elif isinstance (
431
461
message .server_message , StreamReadMessage .CommitOffsetResponse
432
462
):
433
463
self ._on_commit_response (message .server_message )
464
+
434
465
elif isinstance (
435
466
message .server_message ,
436
467
StreamReadMessage .StartPartitionSessionRequest ,
437
468
):
438
469
self ._on_start_partition_session (message .server_message )
470
+
439
471
elif isinstance (
440
472
message .server_message ,
441
473
StreamReadMessage .StopPartitionSessionRequest ,
442
474
):
443
475
self ._on_partition_session_stop (message .server_message )
476
+
477
+ elif isinstance (message .server_message , UpdateTokenResponse ):
478
+ self ._update_token_event .set ()
479
+
444
480
else :
445
481
raise NotImplementedError (
446
482
"Unexpected type of StreamReadMessage.FromServer message: %s"
@@ -450,7 +486,20 @@ async def _read_messages_loop(self, stream: IGrpcWrapperAsyncIO):
450
486
self ._state_changed .set ()
451
487
except Exception as e :
452
488
self ._set_first_error (e )
453
- raise e
489
+ raise
490
+
491
+ async def _update_token_loop (self ):
492
+ while True :
493
+ await asyncio .sleep (self ._update_token_interval )
494
+ await self ._update_token (token = self ._get_token_function ())
495
+
496
+ async def _update_token (self , token : str ):
497
+ await self ._update_token_event .wait ()
498
+ try :
499
+ msg = StreamReadMessage .FromClient (UpdateTokenRequest (token ))
500
+ self ._stream .write (msg )
501
+ finally :
502
+ self ._update_token_event .clear ()
454
503
455
504
def _on_start_partition_session (
456
505
self , message : StreamReadMessage .StartPartitionSessionRequest
@@ -491,14 +540,12 @@ def _on_start_partition_session(
491
540
def _on_partition_session_stop (
492
541
self , message : StreamReadMessage .StopPartitionSessionRequest
493
542
):
494
- try :
495
- partition = self ._partition_sessions [message .partition_session_id ]
496
- except KeyError :
543
+ if message .partition_session_id not in self ._partition_sessions :
497
544
# may if receive stop partition with graceful=false after response on stop partition
498
545
# with graceful=true and remove partition from internal dictionary
499
546
return
500
547
501
- del self ._partition_sessions [ message .partition_session_id ]
548
+ partition = self ._partition_sessions . pop ( message .partition_session_id )
502
549
partition .close ()
503
550
504
551
if message .graceful :
@@ -519,11 +566,10 @@ def _on_read_response(self, message: StreamReadMessage.ReadResponse):
519
566
520
567
def _on_commit_response (self , message : StreamReadMessage .CommitOffsetResponse ):
521
568
for partition_offset in message .partitions_committed_offsets :
522
- session = self ._partition_sessions .get (
523
- partition_offset .partition_session_id
524
- )
525
- if session is None :
569
+ if partition_offset .partition_session_id not in self ._partition_sessions :
526
570
continue
571
+
572
+ session = self ._partition_sessions [partition_offset .partition_session_id ]
527
573
session .ack_notify (partition_offset .committed_offset )
528
574
529
575
def _buffer_consume_bytes (self , bytes_size ):
@@ -544,12 +590,9 @@ def _read_response_to_batches(
544
590
) -> typing .List [datatypes .PublicBatch ]:
545
591
batches = []
546
592
547
- batch_count = 0
548
- for partition_data in message .partition_data :
549
- batch_count += len (partition_data .batches )
550
-
593
+ batch_count = sum (len (p .batches ) for p in message .partition_data )
551
594
if batch_count == 0 :
552
- return []
595
+ return batches
553
596
554
597
bytes_per_batch = message .bytes_size // batch_count
555
598
additional_bytes_to_last_batch = (
@@ -577,12 +620,11 @@ def _read_response_to_batches(
577
620
_commit_end_offset = message_data .offset + 1 ,
578
621
)
579
622
messages .append (mess )
580
-
581
623
partition_session ._next_message_start_commit_offset = (
582
624
mess ._commit_end_offset
583
625
)
584
626
585
- if len ( messages ) > 0 :
627
+ if messages :
586
628
batch = datatypes .PublicBatch (
587
629
session_metadata = server_batch .write_session_meta ,
588
630
messages = messages ,
@@ -637,14 +679,12 @@ def _set_first_error(self, err: YdbError):
637
679
def _get_first_error (self ) -> Optional [YdbError ]:
638
680
if self ._first_error .done ():
639
681
return self ._first_error .result ()
640
- else :
641
- return None
642
682
643
683
async def close (self ):
644
684
if self ._closed :
645
- raise TopicReaderError (message = "Double closed ReaderStream" )
646
-
685
+ return
647
686
self ._closed = True
687
+
648
688
self ._set_first_error (TopicReaderStreamClosedError ())
649
689
self ._state_changed .set ()
650
690
self ._stream .close ()
@@ -654,5 +694,4 @@ async def close(self):
654
694
655
695
for task in self ._background_tasks :
656
696
task .cancel ()
657
-
658
697
await asyncio .wait (self ._background_tasks )
0 commit comments