diff --git a/ydb/oauth2_token_exchange/token_source.py b/ydb/oauth2_token_exchange/token_source.py index f33e329b..0ed2c06e 100644 --- a/ydb/oauth2_token_exchange/token_source.py +++ b/ydb/oauth2_token_exchange/token_source.py @@ -7,9 +7,15 @@ try: import jwt + import jwt.utils except ImportError: jwt = None +try: + from cryptography.hazmat.primitives.serialization import load_pem_private_key +except ImportError: + load_pem_private_key = None + class Token(abc.ABC): def __init__(self, token: str, token_type: str): @@ -48,6 +54,7 @@ def __init__( token_ttl_seconds: int = 3600, ): assert jwt is not None, "Install pyjwt library to use jwt tokens" + assert load_pem_private_key is not None, "Install cryptography library to use jwt tokens" self._signing_method = signing_method self._key_id = key_id if private_key and private_key_file: @@ -57,7 +64,7 @@ def __init__( self._private_key = private_key if private_key_file: private_key_file = os.path.expanduser(private_key_file) - with open(private_key_file, "r") as key_file: + with open(private_key_file, "rb") as key_file: self._private_key = key_file.read() self._issuer = issuer self._subject = subject @@ -70,6 +77,10 @@ def __init__( raise Exception("JWT: no private key specified") if self._token_ttl_seconds <= 0: raise Exception("JWT: invalid jwt token TTL") + if isinstance(self._private_key, str): + self._private_key = self._private_key.encode() + if isinstance(self._private_key, bytes) and jwt.utils.is_pem_format(self._private_key): + self._private_key = load_pem_private_key(self._private_key, password=None) def token(self) -> Token: now = time.time()