diff --git a/harness/determined/cli/token.py b/harness/determined/cli/token.py index 4471ec24ec3..a69db6b8763 100644 --- a/harness/determined/cli/token.py +++ b/harness/determined/cli/token.py @@ -1,11 +1,13 @@ import argparse import json -from typing import Any, List, Sequence +from typing import Any, List from determined import cli from determined.cli import errors, render from determined.common import api, util -from determined.common.api import authentication, bindings +from determined.common.api import authentication +from determined.common.experimental import token +from determined.experimental import client TOKEN_HEADERS = [ "ID", @@ -18,9 +20,9 @@ ] -def render_token_info(token_info: Sequence[bindings.v1TokenInfo]) -> None: +def render_token_info(token_info: List[token.AccessToken]) -> None: values = [ - [t.id, t.userId, t.description, t.createdAt, t.expiry, t.revoked, t.tokenType] + [t.id, t.user_id, t.description, t.created_at, t.expiry, t.revoked, t.token_type] for t in token_info ] render.tabulate_or_csv(TOKEN_HEADERS, values, False) @@ -29,15 +31,18 @@ def render_token_info(token_info: Sequence[bindings.v1TokenInfo]) -> None: def describe_token(args: argparse.Namespace) -> None: sess = cli.setup_session(args) try: - resp = bindings.get_GetAccessTokens(session=sess, tokenIds=args.token_id) + d = client.Determined._from_session(sess) + token_info = d.describe_tokens(args.token_id) + if args.json or args.yaml: - json_data = [t.to_json() for t in resp.tokenInfo] + json_data = [t.to_json() for t in token_info] + print(json_data) if args.json: render.print_json(json_data) else: print(util.yaml_safe_dump(json_data, default_flow_style=False)) else: - render_token_info(resp.tokenInfo) + render_token_info(token_info) except api.errors.APIException as e: raise errors.CliError(f"Caught APIException: {str(e)}") except Exception as e: @@ -49,15 +54,17 @@ def list_tokens(args: argparse.Namespace) -> None: try: username = args.username if args.username else None show_inactive = True if args.show_inactive else False - resp = bindings.get_GetAccessTokens(sess, username=username, showInactive=show_inactive) + d = client.Determined._from_session(sess) + token_info = d.list_tokens(username, show_inactive) + if args.json or args.yaml: - json_data = [t.to_json() for t in resp.tokenInfo] + json_data = [t.to_json() for t in token_info] if args.json: render.print_json(json_data) else: print(util.yaml_safe_dump(json_data, default_flow_style=False)) else: - render_token_info(resp.tokenInfo) + render_token_info(token_info) except Exception as e: raise errors.CliError(f"Error fetching tokens: {e}") @@ -65,11 +72,13 @@ def list_tokens(args: argparse.Namespace) -> None: def revoke_token(args: argparse.Namespace) -> None: sess = cli.setup_session(args) try: - request = bindings.v1PatchAccessTokenRequest( - tokenId=args.token_id, description=None, setRevoked=True - ) - resp = bindings.patch_PatchAccessToken(sess, body=request, tokenId=args.token_id) - print(json.dumps(resp.to_json(), indent=2)) + d = client.Determined._from_session(sess) + print(args.token_id) + token_info_list = d.describe_token(args.token_id) + # Only one token will be returned, use the first one + token_info = token_info_list[0] + token_info.revoke_token() + render_token_info([token_info]) print(f"Successfully revoked token {args.token_id}.") except api.errors.NotFoundException: raise errors.CliError("Token not found") @@ -79,27 +88,25 @@ def create_token(args: argparse.Namespace) -> None: sess = cli.setup_session(args) try: username = args.username or sess.username - user = bindings.get_GetUserByUsername(session=sess, username=username).user - - if user is None or user.id is None: + d = client.Determined._from_session(sess) + user_obj = d.get_user_by_name(username) + if user_obj is None or user_obj.user_id is None: raise errors.CliError(f"User '{username}' not found or does not have an ID") # convert days into hours Go duration format + expiration_in_hours = None if args.expiration_days: expiration_in_hours = str(24 * args.expiration_days) + "h" - request = bindings.v1PostAccessTokenRequest( - userId=user.id, lifespan=expiration_in_hours, description=args.description - ) - resp = bindings.post_PostAccessToken(sess, body=request).to_json() + token_info = d.create_token(user_obj.user_id, expiration_in_hours, args.description) output_string = None if args.yaml: - output_string = util.yaml_safe_dump(resp, default_flow_style=False) + output_string = util.yaml_safe_dump(token_info.to_json(), default_flow_style=False) elif args.json: - output_string = json.dumps(resp, indent=2) + output_string = json.dumps(token_info.to_json(), indent=2) else: - output_string = f'TokenID: {resp["tokenId"]}\nAccess-Token: {resp["token"]}' + output_string = f"TokenID: {token_info.tokenId}\nAccess-Token: {token_info.token}" print(output_string) except api.errors.APIException as e: @@ -112,14 +119,26 @@ def edit_token(args: argparse.Namespace) -> None: sess = cli.setup_session(args) try: if args.token_id: - request = bindings.v1PatchAccessTokenRequest( - tokenId=args.token_id, - description=args.description if args.description else None, - setRevoked=False, - ) - resp = bindings.patch_PatchAccessToken(sess, body=request, tokenId=args.token_id) - print(json.dumps(resp.to_json(), indent=2)) - print(f"Successfully updated token with ID: {args.token_id}.") + d = client.Determined._from_session(sess) + token_info_list = d.describe_token(args.token_id) + # Only one token will be returned, use the first one + token_info = token_info_list[0] + if args.description: + token_info.edit_token(args.description) + if args.json or args.yaml: + json_data = token_info.to_json() + print(json_data) + if args.json: + render.print_json(json_data) + else: + print(util.yaml_safe_dump(json_data, default_flow_style=False)) + else: + render_token_info([token_info]) + print(f"Successfully updated token with ID: {args.token_id}.") + else: + raise errors.CliError( + f"Please provide a description for token ID '{args.token_id}'." + ) except api.errors.APIException as e: raise errors.CliError(f"Caught APIException: {str(e)}") except api.errors.NotFoundException: @@ -128,15 +147,10 @@ def edit_token(args: argparse.Namespace) -> None: def login_with_token(args: argparse.Namespace) -> None: try: - unauth_session = api.UnauthSession(master=args.master, cert=cli.cert) - auth_headers = {"Authorization": f"Bearer {args.token}"} - user_data = unauth_session.get("/api/v1/me", headers=auth_headers).json() - username = user_data.get("user").get("username") - - token_store = authentication.TokenStore(args.master) - token_store.set_token(username, args.token) - token_store.set_active(username) - print(f"Authenticated as {username}.") + sess = authentication.login_with_token( + master_address=args.master, token=args.token, cert=cli.cert + ) + print(f"Authenticated as {sess.username}.") except api.errors.APIException as e: raise errors.CliError(f"Caught APIException: {str(e)}") except api.errors.UnauthenticatedException as e: diff --git a/harness/determined/common/api/authentication.py b/harness/determined/common/api/authentication.py index 50109125b80..60db2d9ef95 100644 --- a/harness/determined/common/api/authentication.py +++ b/harness/determined/common/api/authentication.py @@ -704,3 +704,35 @@ def validate_token_store_v2(store: Any) -> None: if not all(api.canonicalize_master_url(key) == key for key in masters): # A non-canonical master url is present. raise api.errors.CorruptTokenCacheException() + + +def login_with_token( + master_address: str, + token: str, + cert: Optional[certs.Cert] = None, +) -> "api.Session": + """ + Log in using a provided token, without interacting with the TokenStore on the file system. + + This function sends a login request to the master to authenticate the token and retrieve user + information. If successful, it stores the token in the TokenStore and returns a new api.Session + object, which can be used for future authenticated requests. + + Returns: + api.Session: A new session object with the authenticated token. + """ + unauth_session = api.UnauthSession(master=master_address, cert=cert) + headers = {"Authorization": f"Bearer {token}"} + try: + r = unauth_session.get("api/v1/me", headers=headers) + if r.status_code != 200: + raise api.errors.APIException(response=r) + except (api.errors.UnauthenticatedException, api.errors.APIException): + raise + + username = r.json()["user"]["username"] + + token_store = TokenStore(master_address) + token_store.set_token(username, token) + token_store.set_active(username) + return api.Session(master=master_address, username=username, token=token, cert=cert) diff --git a/harness/determined/common/experimental/determined.py b/harness/determined/common/experimental/determined.py index 6f8496ea123..cd74988c16d 100644 --- a/harness/determined/common/experimental/determined.py +++ b/harness/determined/common/experimental/determined.py @@ -13,6 +13,7 @@ metrics, model, oauth2_scim_client, + token, trial, user, workspace, @@ -582,3 +583,48 @@ def stream_trials_validation_metrics( stacklevel=2, ) return trial._stream_validation_metrics(self._session, trial_ids) + + def describe_token(self, token_id: int) -> token.AccessToken: + """ + Get the :class:`~determined.experimental.Token` representing the + token info with the provided token ID. + """ + resp = bindings.get_GetAccessTokens( + session=self._session, tokenIds=[token_id], showInactive=True + ) + return token.AccessToken._from_bindings(resp.tokenInfo, self._session) + + def describe_tokens(self, token_ids: List[int]) -> token.AccessToken: + """ + Get the :class:`~determined.experimental.Token` representing list of + token info with the provided token IDs. + """ + resp = bindings.get_GetAccessTokens( + session=self._session, tokenIds=token_ids, showInactive=False + ) + return token.AccessToken._from_bindings(resp.tokenInfo, self._session) + + def list_tokens( + self, username: Optional[str] = None, show_inactive: Optional[bool] = None + ) -> token.AccessToken: + """ + Get the :class:`~determined.experimental.Token` representing list of + token info with the provided username. + """ + resp = bindings.get_GetAccessTokens( + session=self._session, username=username, showInactive=show_inactive + ) + return token.AccessToken._from_bindings(resp.tokenInfo, self._session) + + def create_token( + self, user_id: int, lifespan: Optional[str] = None, description: Optional[str] = None + ) -> bindings.v1PostAccessTokenResponse: + """ + Get the :`bindings.v1PostAccessTokenResponse` representing the + token and token ID with the provided user ID. + """ + post_create_token = bindings.v1PostAccessTokenRequest( + userId=user_id, description=description, lifespan=lifespan + ) + resp = bindings.post_PostAccessToken(session=self._session, body=post_create_token) + return resp diff --git a/harness/determined/common/experimental/token.py b/harness/determined/common/experimental/token.py new file mode 100644 index 00000000000..8720c97ae8a --- /dev/null +++ b/harness/determined/common/experimental/token.py @@ -0,0 +1,109 @@ +import enum +from typing import List, Optional + +from determined.common import api +from determined.common.api import bindings + + +class TokenType(enum.Enum): + # UNSPECIFIED is internal to the bound API and is not be exposed to the front end + USER_SESSION = bindings.v1TokenType.USER_SESSION.name + ACCESS_TOKEN = bindings.v1TokenType.ACCESS_TOKEN.name + + +class AccessToken: + """ + A class representing a AccessToken object that contains user session token info and + access token info. + It can be obtained from :func:`determined.experimental.client.list_access_tokens` + Attributes: + session: HTTP request session. + token_id: (int) The ID of the access token in user sessions table. + user_id: (int) Unique ID for the user. + expiry: (str) Timestamp expires at reported. + created_at: (str) Timestamp created at reported. + token_type: (TokenType) Token type of the token. + revoked: (Mutable, Optional[bool]) The datetime when the token was revoked. + Null if the token is still active. + description: (Mutable, Optional[str]) Human-friendly description of token. + + Note: + Mutable properties may be changed by methods that update these values either automatically + (eg. `revoke_tokens`, `edit_tokens`) or explicitly with :meth:`reload()`. + """ + + def __init__(self, token_id: int, session: api.Session): + self.token_id = token_id + self._session = session + + self.user_id: Optional[int] = None + self.expiry: Optional[str] = None + self.created_at: Optional[str] = None + self.token_type: Optional[TokenType] = None + self.revoked: Optional[bool] = None + self.description: Optional[str] = None + + def _hydrate(self, tokenInfo: bindings.v1TokenInfo) -> None: + self.user_id = tokenInfo.userId + self.expiry = tokenInfo.expiry + self.created_at = tokenInfo.createdAt + self.token_type = tokenInfo.tokenType + self.revoked = tokenInfo.revoked if tokenInfo.revoked is not None else False + self.description = tokenInfo.description if tokenInfo.description is not None else "" + + def reload(self) -> None: + resp = bindings.get_GetAccessTokens( + session=self._session, tokenIds=[self.id], showInactive=True + ).tokenInfo + self._hydrate(resp[0]) + + def edit_token(self, desc) -> None: + patch_token_description = bindings.v1PatchAccessTokenRequest( + tokenId=self.token_id, description=desc + ) + bindings.patch_PatchAccessToken( + self._session, body=patch_token_description, tokenId=self.token_id + ) + self.reload() + + def revoke_token(self) -> None: + patch_revoke_token = bindings.v1PatchAccessTokenRequest( + tokenId=self.token_id, description=None, setRevoked=True + ) + bindings.patch_PatchAccessToken( + self._session, body=patch_revoke_token, tokenId=self.token_id + ) + self.reload() + + def to_json(self): + return { + "token_id": self.token_id, + "user_id": self.user_id, + "description": self.description, + "created_at": self.created_at if self.created_at else None, + "expiry": self.expiry if self.expiry else None, + "revoked": self.revoked if self.revoked else None, + "token_type": self.token_type.name + if isinstance(self.token_type, enum.Enum) + else self.token_type, + } + + @classmethod + def _from_bindings( + cls, AccessToken_bindings: List[bindings.v1TokenInfo], session: api.Session + ) -> "AccessToken | List[AccessToken]": + assert len(AccessToken_bindings) > 0 + + access_token_infos = [] + for binding in AccessToken_bindings: + assert binding.token_id + AccessTokenInfo = cls(session=session, token_id=binding.token_id) + AccessTokenInfo._hydrate(binding) + access_token_infos.append(AccessTokenInfo) + + # Return a single instance if only one tokenInfo is provided + if len(access_token_infos) == 1: + return access_token_infos + + # Otherwise, return the list of AccessToken instances + return access_token_infos