Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
98 changes: 56 additions & 42 deletions harness/determined/cli/token.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,13 @@
import argparse
import json
from typing import Any, List, Sequence
from typing import Any, List

from determined import cli
from determined.cli import errors, render
from determined.common import api, util
from determined.common.api import authentication, bindings
from determined.common.api import authentication
from determined.common.experimental import token
from determined.experimental import client

TOKEN_HEADERS = [
"ID",
Expand All @@ -18,9 +20,9 @@
]


def render_token_info(token_info: Sequence[bindings.v1TokenInfo]) -> None:
def render_token_info(token_info: List[token.AccessToken]) -> None:
values = [
[t.id, t.userId, t.description, t.createdAt, t.expiry, t.revoked, t.tokenType]
[t.id, t.user_id, t.description, t.created_at, t.expiry, t.revoked, t.token_type]
for t in token_info
]
render.tabulate_or_csv(TOKEN_HEADERS, values, False)
Expand All @@ -29,15 +31,18 @@ def render_token_info(token_info: Sequence[bindings.v1TokenInfo]) -> None:
def describe_token(args: argparse.Namespace) -> None:
sess = cli.setup_session(args)
try:
resp = bindings.get_GetAccessTokens(session=sess, tokenIds=args.token_id)
d = client.Determined._from_session(sess)
token_info = d.describe_tokens(args.token_id)

if args.json or args.yaml:
json_data = [t.to_json() for t in resp.tokenInfo]
json_data = [t.to_json() for t in token_info]
print(json_data)
if args.json:
render.print_json(json_data)
else:
print(util.yaml_safe_dump(json_data, default_flow_style=False))
else:
render_token_info(resp.tokenInfo)
render_token_info(token_info)
except api.errors.APIException as e:
raise errors.CliError(f"Caught APIException: {str(e)}")
except Exception as e:
Expand All @@ -49,27 +54,31 @@ def list_tokens(args: argparse.Namespace) -> None:
try:
username = args.username if args.username else None
show_inactive = True if args.show_inactive else False
resp = bindings.get_GetAccessTokens(sess, username=username, showInactive=show_inactive)
d = client.Determined._from_session(sess)
token_info = d.list_tokens(username, show_inactive)

if args.json or args.yaml:
json_data = [t.to_json() for t in resp.tokenInfo]
json_data = [t.to_json() for t in token_info]
if args.json:
render.print_json(json_data)
else:
print(util.yaml_safe_dump(json_data, default_flow_style=False))
else:
render_token_info(resp.tokenInfo)
render_token_info(token_info)
except Exception as e:
raise errors.CliError(f"Error fetching tokens: {e}")


def revoke_token(args: argparse.Namespace) -> None:
sess = cli.setup_session(args)
try:
request = bindings.v1PatchAccessTokenRequest(
tokenId=args.token_id, description=None, setRevoked=True
)
resp = bindings.patch_PatchAccessToken(sess, body=request, tokenId=args.token_id)
print(json.dumps(resp.to_json(), indent=2))
d = client.Determined._from_session(sess)
print(args.token_id)
token_info_list = d.describe_token(args.token_id)
# Only one token will be returned, use the first one
token_info = token_info_list[0]
token_info.revoke_token()
render_token_info([token_info])
print(f"Successfully revoked token {args.token_id}.")
except api.errors.NotFoundException:
raise errors.CliError("Token not found")
Expand All @@ -79,27 +88,25 @@ def create_token(args: argparse.Namespace) -> None:
sess = cli.setup_session(args)
try:
username = args.username or sess.username
user = bindings.get_GetUserByUsername(session=sess, username=username).user

if user is None or user.id is None:
d = client.Determined._from_session(sess)
user_obj = d.get_user_by_name(username)
if user_obj is None or user_obj.user_id is None:
raise errors.CliError(f"User '{username}' not found or does not have an ID")

# convert days into hours Go duration format
expiration_in_hours = None
if args.expiration_days:
expiration_in_hours = str(24 * args.expiration_days) + "h"

request = bindings.v1PostAccessTokenRequest(
userId=user.id, lifespan=expiration_in_hours, description=args.description
)
resp = bindings.post_PostAccessToken(sess, body=request).to_json()
token_info = d.create_token(user_obj.user_id, expiration_in_hours, args.description)

output_string = None
if args.yaml:
output_string = util.yaml_safe_dump(resp, default_flow_style=False)
output_string = util.yaml_safe_dump(token_info.to_json(), default_flow_style=False)
elif args.json:
output_string = json.dumps(resp, indent=2)
output_string = json.dumps(token_info.to_json(), indent=2)
else:
output_string = f'TokenID: {resp["tokenId"]}\nAccess-Token: {resp["token"]}'
output_string = f"TokenID: {token_info.tokenId}\nAccess-Token: {token_info.token}"

print(output_string)
except api.errors.APIException as e:
Expand All @@ -112,14 +119,26 @@ def edit_token(args: argparse.Namespace) -> None:
sess = cli.setup_session(args)
try:
if args.token_id:
request = bindings.v1PatchAccessTokenRequest(
tokenId=args.token_id,
description=args.description if args.description else None,
setRevoked=False,
)
resp = bindings.patch_PatchAccessToken(sess, body=request, tokenId=args.token_id)
print(json.dumps(resp.to_json(), indent=2))
print(f"Successfully updated token with ID: {args.token_id}.")
d = client.Determined._from_session(sess)
token_info_list = d.describe_token(args.token_id)
# Only one token will be returned, use the first one
token_info = token_info_list[0]
if args.description:
token_info.edit_token(args.description)
if args.json or args.yaml:
json_data = token_info.to_json()
print(json_data)
if args.json:
render.print_json(json_data)
else:
print(util.yaml_safe_dump(json_data, default_flow_style=False))
else:
render_token_info([token_info])
print(f"Successfully updated token with ID: {args.token_id}.")
else:
raise errors.CliError(
f"Please provide a description for token ID '{args.token_id}'."
)
except api.errors.APIException as e:
raise errors.CliError(f"Caught APIException: {str(e)}")
except api.errors.NotFoundException:
Expand All @@ -128,15 +147,10 @@ def edit_token(args: argparse.Namespace) -> None:

def login_with_token(args: argparse.Namespace) -> None:
try:
unauth_session = api.UnauthSession(master=args.master, cert=cli.cert)
auth_headers = {"Authorization": f"Bearer {args.token}"}
user_data = unauth_session.get("/api/v1/me", headers=auth_headers).json()
username = user_data.get("user").get("username")

token_store = authentication.TokenStore(args.master)
token_store.set_token(username, args.token)
token_store.set_active(username)
print(f"Authenticated as {username}.")
sess = authentication.login_with_token(
master_address=args.master, token=args.token, cert=cli.cert
)
print(f"Authenticated as {sess.username}.")
except api.errors.APIException as e:
raise errors.CliError(f"Caught APIException: {str(e)}")
except api.errors.UnauthenticatedException as e:
Expand Down
32 changes: 32 additions & 0 deletions harness/determined/common/api/authentication.py
Original file line number Diff line number Diff line change
Expand Up @@ -704,3 +704,35 @@ def validate_token_store_v2(store: Any) -> None:
if not all(api.canonicalize_master_url(key) == key for key in masters):
# A non-canonical master url is present.
raise api.errors.CorruptTokenCacheException()


def login_with_token(
master_address: str,
token: str,
cert: Optional[certs.Cert] = None,
) -> "api.Session":
"""
Log in using a provided token, without interacting with the TokenStore on the file system.

This function sends a login request to the master to authenticate the token and retrieve user
information. If successful, it stores the token in the TokenStore and returns a new api.Session
object, which can be used for future authenticated requests.

Returns:
api.Session: A new session object with the authenticated token.
"""
unauth_session = api.UnauthSession(master=master_address, cert=cert)
headers = {"Authorization": f"Bearer {token}"}
try:
r = unauth_session.get("api/v1/me", headers=headers)
if r.status_code != 200:
raise api.errors.APIException(response=r)
except (api.errors.UnauthenticatedException, api.errors.APIException):
raise

username = r.json()["user"]["username"]

token_store = TokenStore(master_address)
token_store.set_token(username, token)
token_store.set_active(username)
return api.Session(master=master_address, username=username, token=token, cert=cert)
46 changes: 46 additions & 0 deletions harness/determined/common/experimental/determined.py
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

i think it'd be great if the Determined SDK client class accepted a token: Optional[str] parameter as well to be able to use the SDK with these access tokens.

Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
metrics,
model,
oauth2_scim_client,
token,
trial,
user,
workspace,
Expand Down Expand Up @@ -582,3 +583,48 @@ def stream_trials_validation_metrics(
stacklevel=2,
)
return trial._stream_validation_metrics(self._session, trial_ids)

def describe_token(self, token_id: int) -> token.AccessToken:
"""
Get the :class:`~determined.experimental.Token` representing the
token info with the provided token ID.
"""
resp = bindings.get_GetAccessTokens(
session=self._session, tokenIds=[token_id], showInactive=True
)
return token.AccessToken._from_bindings(resp.tokenInfo, self._session)

def describe_tokens(self, token_ids: List[int]) -> token.AccessToken:
"""
Get the :class:`~determined.experimental.Token` representing list of
token info with the provided token IDs.
"""
resp = bindings.get_GetAccessTokens(
session=self._session, tokenIds=token_ids, showInactive=False
)
return token.AccessToken._from_bindings(resp.tokenInfo, self._session)

def list_tokens(
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

i think all 3 of these new methods: describe_token, describe_tokens, and list_tokens should be combined into a singular list_tokens that takes in all the parameters.

the CLI and SDK roughly overlap in functionality, but the methods shouldn't be an exact replica. in the CLI, we have to be careful about which methods to expose and how because the UI is limited to the command line, so we generally value convenience and modularity. but the SDK has a bit of a different philosophy. it's in code and made for developers, so the methods we expose can be more powerful and more robust. (see list_experiments here as an example)

the functionality in describe_token, describe_tokens, and list_tokens can be captured with a single list_tokens method, and the user can pass in exactly what they want, just like in code. i'm thinking list_tokens would have a signature like:

def list_tokens(
        self,
        username: Optional[str],
        token_ids: Optional[List[int]],
        include_inactive: bool,
        sort_by: token.TokenSortBy = token.TokenSortBy.NAME,
        order_by: OrderBy = OrderBy.ASCENDING,
    ) -> List[token.Token]:

this is nice because it basically mirrors the actual bindings API call, the user isn't limited by separate methods, and the method is representative of the actual query we make in the system. we should also include sort/order as accepted parameters.

self, username: Optional[str] = None, show_inactive: Optional[bool] = None
) -> token.AccessToken:
"""
Get the :class:`~determined.experimental.Token` representing list of
token info with the provided username.
"""
resp = bindings.get_GetAccessTokens(
session=self._session, username=username, showInactive=show_inactive
)
return token.AccessToken._from_bindings(resp.tokenInfo, self._session)

def create_token(
self, user_id: int, lifespan: Optional[str] = None, description: Optional[str] = None
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

let's make user_id username here, since the other methods in the SDK use that as the standard identifier.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

also, we've decided to accept days, so let's just make it expiration_days here.

) -> bindings.v1PostAccessTokenResponse:
"""
Get the :`bindings.v1PostAccessTokenResponse` representing the
token and token ID with the provided user ID.
"""
post_create_token = bindings.v1PostAccessTokenRequest(
userId=user_id, description=description, lifespan=lifespan
)
resp = bindings.post_PostAccessToken(session=self._session, body=post_create_token)
return resp
109 changes: 109 additions & 0 deletions harness/determined/common/experimental/token.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,109 @@
import enum
from typing import List, Optional

from determined.common import api
from determined.common.api import bindings


class TokenType(enum.Enum):
# UNSPECIFIED is internal to the bound API and is not be exposed to the front end
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this comment seems to be leftover from copy/paste?

USER_SESSION = bindings.v1TokenType.USER_SESSION.name
ACCESS_TOKEN = bindings.v1TokenType.ACCESS_TOKEN.name


class AccessToken:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

let's just call this Token, that's what it is in the CLI after all

"""
A class representing a AccessToken object that contains user session token info and
access token info.
It can be obtained from :func:`determined.experimental.client.list_access_tokens`
Attributes:
session: HTTP request session.
token_id: (int) The ID of the access token in user sessions table.
user_id: (int) Unique ID for the user.
expiry: (str) Timestamp expires at reported.
created_at: (str) Timestamp created at reported.
token_type: (TokenType) Token type of the token.
revoked: (Mutable, Optional[bool]) The datetime when the token was revoked.
Null if the token is still active.
description: (Mutable, Optional[str]) Human-friendly description of token.
Note:
Mutable properties may be changed by methods that update these values either automatically
(eg. `revoke_tokens`, `edit_tokens`) or explicitly with :meth:`reload()`.
"""

def __init__(self, token_id: int, session: api.Session):
self.token_id = token_id
self._session = session

self.user_id: Optional[int] = None
self.expiry: Optional[str] = None
self.created_at: Optional[str] = None
self.token_type: Optional[TokenType] = None
self.revoked: Optional[bool] = None
self.description: Optional[str] = None

def _hydrate(self, tokenInfo: bindings.v1TokenInfo) -> None:
self.user_id = tokenInfo.userId
self.expiry = tokenInfo.expiry
self.created_at = tokenInfo.createdAt
self.token_type = tokenInfo.tokenType
self.revoked = tokenInfo.revoked if tokenInfo.revoked is not None else False
self.description = tokenInfo.description if tokenInfo.description is not None else ""

def reload(self) -> None:
resp = bindings.get_GetAccessTokens(
session=self._session, tokenIds=[self.id], showInactive=True
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

does self.id exist? either change this to self.token_id or add a property to this class:

@property
    def id(self) -> int:
        return self._id

).tokenInfo
self._hydrate(resp[0])

def edit_token(self, desc) -> None:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

let's make this set_description to follow standard for other SDK methods.

patch_token_description = bindings.v1PatchAccessTokenRequest(
tokenId=self.token_id, description=desc
)
bindings.patch_PatchAccessToken(
self._session, body=patch_token_description, tokenId=self.token_id
)
self.reload()

def revoke_token(self) -> None:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

let's call this revoke, we're already on the token object, no need for the extra verbosity.

patch_revoke_token = bindings.v1PatchAccessTokenRequest(
tokenId=self.token_id, description=None, setRevoked=True
)
bindings.patch_PatchAccessToken(
self._session, body=patch_revoke_token, tokenId=self.token_id
)
self.reload()

def to_json(self):
return {
"token_id": self.token_id,
"user_id": self.user_id,
"description": self.description,
"created_at": self.created_at if self.created_at else None,
"expiry": self.expiry if self.expiry else None,
"revoked": self.revoked if self.revoked else None,
"token_type": self.token_type.name
if isinstance(self.token_type, enum.Enum)
else self.token_type,
}

@classmethod
def _from_bindings(
cls, AccessToken_bindings: List[bindings.v1TokenInfo], session: api.Session
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

casing is weird with AccessToken_bindings. just call it token_bindings?

) -> "AccessToken | List[AccessToken]":
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this method should always just return a single Token object, because it lives on the Token class. it's up to the caller to create the list.

assert len(AccessToken_bindings) > 0

access_token_infos = []
for binding in AccessToken_bindings:
assert binding.token_id
AccessTokenInfo = cls(session=session, token_id=binding.token_id)
AccessTokenInfo._hydrate(binding)
access_token_infos.append(AccessTokenInfo)

# Return a single instance if only one tokenInfo is provided
if len(access_token_infos) == 1:
return access_token_infos

# Otherwise, return the list of AccessToken instances
return access_token_infos