Skip to content

Commit 53d6868

Browse files
authored
Merge pull request #40 from ydb-platform/fix-issues-in-aio-library
fix importing aio credentials
2 parents 008d331 + da3b7e8 commit 53d6868

File tree

2 files changed

+126
-0
lines changed

2 files changed

+126
-0
lines changed

tests/aio/test_tx.py

+13
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import pytest
2+
import ydb.aio.iam
23

34

45
@pytest.mark.asyncio
@@ -36,3 +37,15 @@ async def test_tx_begin(driver, database):
3637
await tx.begin()
3738
await tx.begin()
3839
await tx.rollback()
40+
41+
42+
@pytest.mark.asyncio
43+
async def test_credentials():
44+
credentials = ydb.aio.iam.MetadataUrlCredentials()
45+
raised = False
46+
try:
47+
await credentials.auth_metadata()
48+
except Exception:
49+
raised = True
50+
51+
assert raised

ydb/aio/credentials.py

+113
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,113 @@
1+
import time
2+
3+
import abc
4+
import asyncio
5+
import logging
6+
import six
7+
from ydb import issues, credentials
8+
9+
logger = logging.getLogger(__name__)
10+
11+
12+
class _OneToManyValue(object):
13+
def __init__(self):
14+
self._value = None
15+
self._condition = asyncio.Condition()
16+
17+
async def consume(self, timeout=3):
18+
async with self._condition:
19+
if self._value is None:
20+
try:
21+
await asyncio.wait_for(self._condition.wait(), timeout=timeout)
22+
except Exception:
23+
return self._value
24+
return self._value
25+
26+
async def update(self, n_value):
27+
async with self._condition:
28+
prev_value = self._value
29+
self._value = n_value
30+
if prev_value is None:
31+
self._condition.notify_all()
32+
33+
34+
class _AtMostOneExecution(object):
35+
def __init__(self):
36+
self._can_schedule = True
37+
self._lock = asyncio.Lock() # Lock to guarantee only one execution
38+
39+
async def _wrapped_execution(self, callback):
40+
await self._lock.acquire()
41+
try:
42+
res = callback()
43+
if asyncio.iscoroutine(res):
44+
await res
45+
except Exception:
46+
pass
47+
48+
finally:
49+
self._lock.release()
50+
self._can_schedule = True
51+
52+
def submit(self, callback):
53+
if self._can_schedule:
54+
self._can_schedule = False
55+
asyncio.ensure_future(self._wrapped_execution(callback))
56+
57+
58+
@six.add_metaclass(abc.ABCMeta)
59+
class AbstractExpiringTokenCredentials(credentials.AbstractExpiringTokenCredentials):
60+
def __init__(self):
61+
super(AbstractExpiringTokenCredentials, self).__init__()
62+
self._tp = _AtMostOneExecution()
63+
self._cached_token = _OneToManyValue()
64+
65+
@abc.abstractmethod
66+
async def _make_token_request(self):
67+
pass
68+
69+
async def _refresh(self):
70+
current_time = time.time()
71+
self._log_refresh_start(current_time)
72+
73+
try:
74+
auth_metadata = await self._make_token_request()
75+
await self._cached_token.update(auth_metadata["access_token"])
76+
self.update_expiration_info(auth_metadata)
77+
self.logger.info(
78+
"Token refresh successful. current_time %s, refresh_in %s",
79+
current_time,
80+
self._refresh_in,
81+
)
82+
83+
except (KeyboardInterrupt, SystemExit):
84+
return
85+
86+
except Exception as e:
87+
self.last_error = str(e)
88+
await asyncio.sleep(1)
89+
self._tp.submit(self._refresh)
90+
91+
async def token(self):
92+
current_time = time.time()
93+
if current_time > self._refresh_in:
94+
self._tp.submit(self._refresh)
95+
96+
cached_token = await self._cached_token.consume(timeout=3)
97+
if cached_token is None:
98+
if self.last_error is None:
99+
raise issues.ConnectionError(
100+
"%s: timeout occurred while waiting for token.\n%s"
101+
% (
102+
self.__class__.__name__,
103+
self.extra_error_message,
104+
)
105+
)
106+
raise issues.ConnectionError(
107+
"%s: %s.\n%s"
108+
% (self.__class__.__name__, self.last_error, self.extra_error_message)
109+
)
110+
return cached_token
111+
112+
async def auth_metadata(self):
113+
return [(credentials.YDB_AUTH_TICKET_HEADER, await self.token())]

0 commit comments

Comments
 (0)