Skip to content

Commit e16e380

Browse files
authored
Merge pull request #540 from ydb-platform/fix_aio_creds
Fix auth credentials
2 parents 0dfe8e7 + 4ea96ce commit e16e380

File tree

4 files changed

+21
-6
lines changed

4 files changed

+21
-6
lines changed

Diff for: tests/aio/test_credentials.py

+2-1
Original file line numberDiff line numberDiff line change
@@ -36,9 +36,10 @@ async def test_yandex_service_account_credentials():
3636
tests.auth.test_credentials.PRIVATE_KEY,
3737
server.get_endpoint(),
3838
)
39-
t = (await credentials.auth_metadata())[0][1]
39+
t = await credentials.get_auth_token()
4040
assert t == "test_token"
4141
assert credentials.get_expire_time() <= 42
42+
4243
server.stop()
4344

4445

Diff for: ydb/_topic_reader/topic_reader_asyncio.py

+4-1
Original file line numberDiff line numberDiff line change
@@ -516,7 +516,10 @@ async def _read_messages_loop(self):
516516
async def _update_token_loop(self):
517517
while True:
518518
await asyncio.sleep(self._update_token_interval)
519-
await self._update_token(token=self._get_token_function())
519+
token = self._get_token_function()
520+
if asyncio.iscoroutine(token):
521+
token = await token
522+
await self._update_token(token=token)
520523

521524
async def _update_token(self, token: str):
522525
await self._update_token_event.wait()

Diff for: ydb/_topic_writer/topic_writer_asyncio.py

+4-1
Original file line numberDiff line numberDiff line change
@@ -686,7 +686,10 @@ def write(self, messages: List[InternalMessage]):
686686
async def _update_token_loop(self):
687687
while True:
688688
await asyncio.sleep(self._update_token_interval)
689-
await self._update_token(token=self._get_token_function())
689+
token = self._get_token_function()
690+
if asyncio.iscoroutine(token):
691+
token = await token
692+
await self._update_token(token=token)
690693

691694
async def _update_token(self, token: str):
692695
await self._update_token_event.wait()

Diff for: ydb/aio/credentials.py

+11-3
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,13 @@
1-
import time
2-
31
import abc
42
import asyncio
53
import logging
6-
from ydb import issues, credentials
4+
import time
5+
6+
from ydb import credentials
7+
from ydb import issues
78

89
logger = logging.getLogger(__name__)
10+
YDB_AUTH_TICKET_HEADER = "x-ydb-auth-ticket"
911

1012

1113
class _OneToManyValue(object):
@@ -64,6 +66,12 @@ def __init__(self):
6466
async def _make_token_request(self):
6567
pass
6668

69+
async def get_auth_token(self) -> str:
70+
for header, token in await self.auth_metadata():
71+
if header == YDB_AUTH_TICKET_HEADER:
72+
return token
73+
return ""
74+
6775
async def _refresh(self):
6876
current_time = time.time()
6977
self._log_refresh_start(current_time)

0 commit comments

Comments
 (0)