Skip to content

Commit f7aceb1

Browse files
authored
Merge pull request #208 from ydb-platform/topic-reader-refresh-token
Topic reader refresh token
2 parents 584e393 + 618f25c commit f7aceb1

File tree

5 files changed

+153
-67
lines changed

5 files changed

+153
-67
lines changed

ydb/_grpc/grpcwrapper/ydb_topic.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -686,6 +686,8 @@ def to_proto(self) -> ydb_topic_pb2.StreamReadMessage.FromClient:
686686
res.commit_offset_request.CopyFrom(self.client_message.to_proto())
687687
elif isinstance(self.client_message, StreamReadMessage.InitRequest):
688688
res.init_request.CopyFrom(self.client_message.to_proto())
689+
elif isinstance(self.client_message, UpdateTokenRequest):
690+
res.update_token_request.CopyFrom(self.client_message.to_proto())
689691
elif isinstance(
690692
self.client_message, StreamReadMessage.StartPartitionSessionResponse
691693
):
@@ -737,6 +739,13 @@ def from_proto(
737739
msg.start_partition_session_request
738740
),
739741
)
742+
elif mess_type == "update_token_response":
743+
return StreamReadMessage.FromServer(
744+
server_status=server_status,
745+
server_message=UpdateTokenResponse.from_proto(
746+
msg.update_token_response
747+
),
748+
)
740749

741750
# todo replace exception to log
742751
raise NotImplementedError()

ydb/_topic_reader/topic_reader.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,7 @@ class PublicReaderSettings:
4949
# one_attempt_connection_timeout: Union[float, None] = 1
5050
# connection_timeout: Union[float, None] = None
5151
# retry_policy: Union["RetryPolicy", None] = None
52+
update_token_interval: Union[int, float] = 3600
5253

5354
def _init_message(self) -> StreamReadMessage.InitRequest:
5455
return StreamReadMessage.InitRequest(

ydb/_topic_reader/topic_reader_asyncio.py

Lines changed: 79 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -6,9 +6,9 @@
66
import typing
77
from asyncio import Task
88
from collections import deque
9-
from typing import Optional, Set, Dict
9+
from typing import Optional, Set, Dict, Union, Callable
1010

11-
from .. import _apis, issues, RetrySettings
11+
from .. import _apis, issues
1212
from .._utilities import AtomicCounter
1313
from ..aio import Driver
1414
from ..issues import Error as YdbError, _process_response
@@ -19,7 +19,12 @@
1919
SupportedDriverType,
2020
GrpcWrapperAsyncIO,
2121
)
22-
from .._grpc.grpcwrapper.ydb_topic import StreamReadMessage, Codec
22+
from .._grpc.grpcwrapper.ydb_topic import (
23+
StreamReadMessage,
24+
UpdateTokenRequest,
25+
UpdateTokenResponse,
26+
Codec,
27+
)
2328
from .._errors import check_retriable_error
2429

2530

@@ -194,7 +199,6 @@ def __init__(self, driver: Driver, settings: topic_reader.PublicReaderSettings):
194199
self._settings = settings
195200
self._driver = driver
196201
self._background_tasks = set()
197-
self._retry_settins = RetrySettings(idempotent=True) # get from settings
198202

199203
self._state_changed = asyncio.Event()
200204
self._stream_reader = None
@@ -227,7 +231,7 @@ async def wait_message(self):
227231
if self._first_error.done():
228232
raise self._first_error.result()
229233

230-
if self._stream_reader is not None:
234+
if self._stream_reader:
231235
try:
232236
await self._stream_reader.wait_messages()
233237
return
@@ -289,8 +293,15 @@ class ReaderStream:
289293
_message_batches: typing.Deque[datatypes.PublicBatch]
290294
_first_error: asyncio.Future[YdbError]
291295

296+
_update_token_interval: Union[int, float]
297+
_update_token_event: asyncio.Event
298+
_get_token_function: Callable[[], str]
299+
292300
def __init__(
293-
self, reader_reconnector_id: int, settings: topic_reader.PublicReaderSettings
301+
self,
302+
reader_reconnector_id: int,
303+
settings: topic_reader.PublicReaderSettings,
304+
get_token_function: Optional[Callable[[], str]] = None,
294305
):
295306
self._loop = asyncio.get_running_loop()
296307
self._id = ReaderStream._static_id_counter.inc_and_get()
@@ -313,6 +324,10 @@ def __init__(
313324
self._batches_to_decode = asyncio.Queue()
314325
self._message_batches = deque()
315326

327+
self._update_token_interval = settings.update_token_interval
328+
self._get_token_function = get_token_function
329+
self._update_token_event = asyncio.Event()
330+
316331
@staticmethod
317332
async def create(
318333
reader_reconnector_id: int,
@@ -325,7 +340,12 @@ async def create(
325340
driver, _apis.TopicService.Stub, _apis.TopicService.StreamRead
326341
)
327342

328-
reader = ReaderStream(reader_reconnector_id, settings)
343+
creds = driver._credentials
344+
reader = ReaderStream(
345+
reader_reconnector_id,
346+
settings,
347+
get_token_function=creds.get_auth_token if creds else None,
348+
)
329349
await reader._start(stream, settings._init_message())
330350
return reader
331351

@@ -347,35 +367,41 @@ async def _start(
347367
"Unexpected message after InitRequest: %s", init_response
348368
)
349369

370+
self._update_token_event.set()
371+
350372
self._background_tasks.add(
351-
asyncio.create_task(self._read_messages_loop(stream))
373+
asyncio.create_task(self._read_messages_loop(), name="read_messages_loop")
352374
)
353375
self._background_tasks.add(asyncio.create_task(self._decode_batches_loop()))
376+
if self._get_token_function:
377+
self._background_tasks.add(
378+
asyncio.create_task(self._update_token_loop(), name="update_token_loop")
379+
)
354380

355381
async def wait_error(self):
356382
raise await self._first_error
357383

358384
async def wait_messages(self):
359385
while True:
360-
if self._get_first_error() is not None:
386+
if self._get_first_error():
361387
raise self._get_first_error()
362388

363-
if len(self._message_batches) > 0:
389+
if self._message_batches:
364390
return
365391

366392
await self._state_changed.wait()
367393
self._state_changed.clear()
368394

369395
def receive_batch_nowait(self):
370-
if self._get_first_error() is not None:
396+
if self._get_first_error():
371397
raise self._get_first_error()
372398

373-
try:
374-
batch = self._message_batches.popleft()
375-
self._buffer_release_bytes(batch._bytes_size)
376-
return batch
377-
except IndexError:
378-
return None
399+
if not self._message_batches:
400+
return
401+
402+
batch = self._message_batches.popleft()
403+
self._buffer_release_bytes(batch._bytes_size)
404+
return batch
379405

380406
def commit(
381407
self, batch: datatypes.ICommittable
@@ -413,7 +439,7 @@ def commit(
413439

414440
return waiter
415441

416-
async def _read_messages_loop(self, stream: IGrpcWrapperAsyncIO):
442+
async def _read_messages_loop(self):
417443
try:
418444
self._stream.write(
419445
StreamReadMessage.FromClient(
@@ -423,24 +449,34 @@ async def _read_messages_loop(self, stream: IGrpcWrapperAsyncIO):
423449
)
424450
)
425451
while True:
426-
message = await stream.receive() # type: StreamReadMessage.FromServer
452+
message = (
453+
await self._stream.receive()
454+
) # type: StreamReadMessage.FromServer
427455
_process_response(message.server_status)
456+
428457
if isinstance(message.server_message, StreamReadMessage.ReadResponse):
429458
self._on_read_response(message.server_message)
459+
430460
elif isinstance(
431461
message.server_message, StreamReadMessage.CommitOffsetResponse
432462
):
433463
self._on_commit_response(message.server_message)
464+
434465
elif isinstance(
435466
message.server_message,
436467
StreamReadMessage.StartPartitionSessionRequest,
437468
):
438469
self._on_start_partition_session(message.server_message)
470+
439471
elif isinstance(
440472
message.server_message,
441473
StreamReadMessage.StopPartitionSessionRequest,
442474
):
443475
self._on_partition_session_stop(message.server_message)
476+
477+
elif isinstance(message.server_message, UpdateTokenResponse):
478+
self._update_token_event.set()
479+
444480
else:
445481
raise NotImplementedError(
446482
"Unexpected type of StreamReadMessage.FromServer message: %s"
@@ -450,7 +486,20 @@ async def _read_messages_loop(self, stream: IGrpcWrapperAsyncIO):
450486
self._state_changed.set()
451487
except Exception as e:
452488
self._set_first_error(e)
453-
raise e
489+
raise
490+
491+
async def _update_token_loop(self):
492+
while True:
493+
await asyncio.sleep(self._update_token_interval)
494+
await self._update_token(token=self._get_token_function())
495+
496+
async def _update_token(self, token: str):
497+
await self._update_token_event.wait()
498+
try:
499+
msg = StreamReadMessage.FromClient(UpdateTokenRequest(token))
500+
self._stream.write(msg)
501+
finally:
502+
self._update_token_event.clear()
454503

455504
def _on_start_partition_session(
456505
self, message: StreamReadMessage.StartPartitionSessionRequest
@@ -491,14 +540,12 @@ def _on_start_partition_session(
491540
def _on_partition_session_stop(
492541
self, message: StreamReadMessage.StopPartitionSessionRequest
493542
):
494-
try:
495-
partition = self._partition_sessions[message.partition_session_id]
496-
except KeyError:
543+
if message.partition_session_id not in self._partition_sessions:
497544
# may if receive stop partition with graceful=false after response on stop partition
498545
# with graceful=true and remove partition from internal dictionary
499546
return
500547

501-
del self._partition_sessions[message.partition_session_id]
548+
partition = self._partition_sessions.pop(message.partition_session_id)
502549
partition.close()
503550

504551
if message.graceful:
@@ -519,11 +566,10 @@ def _on_read_response(self, message: StreamReadMessage.ReadResponse):
519566

520567
def _on_commit_response(self, message: StreamReadMessage.CommitOffsetResponse):
521568
for partition_offset in message.partitions_committed_offsets:
522-
session = self._partition_sessions.get(
523-
partition_offset.partition_session_id
524-
)
525-
if session is None:
569+
if partition_offset.partition_session_id not in self._partition_sessions:
526570
continue
571+
572+
session = self._partition_sessions[partition_offset.partition_session_id]
527573
session.ack_notify(partition_offset.committed_offset)
528574

529575
def _buffer_consume_bytes(self, bytes_size):
@@ -544,12 +590,9 @@ def _read_response_to_batches(
544590
) -> typing.List[datatypes.PublicBatch]:
545591
batches = []
546592

547-
batch_count = 0
548-
for partition_data in message.partition_data:
549-
batch_count += len(partition_data.batches)
550-
593+
batch_count = sum(len(p.batches) for p in message.partition_data)
551594
if batch_count == 0:
552-
return []
595+
return batches
553596

554597
bytes_per_batch = message.bytes_size // batch_count
555598
additional_bytes_to_last_batch = (
@@ -577,12 +620,11 @@ def _read_response_to_batches(
577620
_commit_end_offset=message_data.offset + 1,
578621
)
579622
messages.append(mess)
580-
581623
partition_session._next_message_start_commit_offset = (
582624
mess._commit_end_offset
583625
)
584626

585-
if len(messages) > 0:
627+
if messages:
586628
batch = datatypes.PublicBatch(
587629
session_metadata=server_batch.write_session_meta,
588630
messages=messages,
@@ -637,14 +679,12 @@ def _set_first_error(self, err: YdbError):
637679
def _get_first_error(self) -> Optional[YdbError]:
638680
if self._first_error.done():
639681
return self._first_error.result()
640-
else:
641-
return None
642682

643683
async def close(self):
644684
if self._closed:
645-
raise TopicReaderError(message="Double closed ReaderStream")
646-
685+
return
647686
self._closed = True
687+
648688
self._set_first_error(TopicReaderStreamClosedError())
649689
self._state_changed.set()
650690
self._stream.close()
@@ -654,5 +694,4 @@ async def close(self):
654694

655695
for task in self._background_tasks:
656696
task.cancel()
657-
658697
await asyncio.wait(self._background_tasks)

0 commit comments

Comments
 (0)