Skip to content

Commit 856ad4d

Browse files
feat: avoid starting login flow unless the tokens are expired (#304)
1 parent ab8e800 commit 856ad4d

7 files changed

Lines changed: 131 additions & 6 deletions

File tree

DEPENDENCIES.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,3 +14,4 @@
1414
| `pydantic-settings` | `>=2, <3.0` | Settings management using Pydantic | [https://pypi.org/project/pydantic-settings/](https://pypi.org/project/pydantic-settings/) |
1515
| `quantinuum-schemas` | `>=7.3, <8.0` | Shared data models for Quantinuum. | [https://github.com/CQCL/quantinuum-schemas](https://github.com/CQCL/quantinuum-schemas) |
1616
| `hugr` | `>=0.14.0, <1.0.0` | Quantinuum's common representation for quantum programs | [https://github.com/CQCL/hugr/tree/main/hugr-py](https://github.com/CQCL/hugr/tree/main/hugr-py) |
17+
| `pyjwt` | `>=2.10.1,<3.0.0` | JSON Web Token implementation in Python | [https://pypi.org/project/PyJWT/](https://pypi.org/project/PyJWT/) |

integration/test_auth_flows.py

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
"""Test basic functionality relating to the auth module."""
22

3+
from contextlib import redirect_stdout
34
from io import StringIO
45
from typing import Any
56

@@ -89,3 +90,24 @@ def test_domain_switch() -> None:
8990
assert original_domain in str(get_nexus_client().base_url)
9091

9192
qnx.users.get_self()
93+
94+
95+
def test_login_when_already_logged_in(monkeypatch: Any) -> None:
96+
"""Test that logging in when already logged in notifies the user appropriately."""
97+
username = CONFIG.qa_user_email
98+
pwd = CONFIG.qa_user_password
99+
100+
# Ensure logged out first
101+
qnx.logout()
102+
# First login
103+
monkeypatch.setattr("sys.stdin", StringIO(username + "\n"))
104+
monkeypatch.setattr("getpass.getpass", lambda prompt: pwd)
105+
qnx.login_with_credentials()
106+
107+
# Try to login again, should indicate already logged in
108+
# Capture output if function prints, or check for raised exception/message
109+
output = StringIO()
110+
with redirect_stdout(output):
111+
qnx.login_with_credentials()
112+
out_str = output.getvalue()
113+
assert "already logged in" in out_str.lower()

pyproject.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@ dependencies = [
2323
"pydantic-settings >=2, <3.0",
2424
"quantinuum-schemas>=7.3, <8.0",
2525
"hugr >=0.14.0, <1.0.0",
26+
"pyjwt>=2.10.1,<3.0.0",
2627
]
2728

2829
[project.optional-dependencies]

qnexus/client/auth.py

Lines changed: 56 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,14 @@
11
"""Client API for authentication in Nexus."""
22

3+
import datetime
34
import getpass
45
import time
6+
import warnings
57
import webbrowser
68
from http import HTTPStatus
79

810
import httpx
11+
import jwt
912
from colorama import Fore
1013
from pydantic import EmailStr
1114
from rich.console import Console
@@ -18,12 +21,52 @@
1821
_check_version_headers,
1922
get_nexus_client,
2023
)
21-
from qnexus.client.utils import consolidate_error, remove_token, write_token
24+
from qnexus.client.utils import consolidate_error, read_token, remove_token, write_token
2225
from qnexus.config import CONFIG
2326

2427
console = Console()
2528

2629

30+
def is_logged_in() -> bool:
31+
"""Check if the user is already logged in by verifying tokens and
32+
attempting a lightweight authenticated request."""
33+
34+
try:
35+
refresh_token = read_token("refresh_token")
36+
access_token = read_token("access_token")
37+
38+
# Check that tokens are present
39+
if not refresh_token or not access_token:
40+
return False
41+
except FileNotFoundError:
42+
return False
43+
# Check expiry of refresh token (assume JWT)
44+
try:
45+
payload = jwt.decode(refresh_token, options={"verify_signature": False})
46+
exp = payload.get("exp")
47+
if exp:
48+
expiry_dt = datetime.datetime.fromtimestamp(exp)
49+
now = datetime.datetime.now()
50+
hours_left = (expiry_dt - now).total_seconds() / 3600
51+
if hours_left < 24:
52+
msg = (
53+
f"Your refresh token expires in less than 24 hours (expires at {expiry_dt}). "
54+
"You will need to login again after this time or use force=True to refresh now."
55+
)
56+
warnings.warn(msg, category=UserWarning)
57+
except jwt.PyJWTError:
58+
pass
59+
# Try a lightweight authenticated request to check validity
60+
try:
61+
client = get_nexus_client()
62+
resp = client.get("/api/users/v1beta2/me")
63+
if resp.status_code == HTTPStatus.OK:
64+
return True
65+
except (httpx.HTTPError, qnx_exc.AuthenticationError):
66+
pass
67+
return False
68+
69+
2770
def _get_auth_client() -> httpx.Client:
2871
"""Getter function for the Nexus auth client."""
2972
return httpx.Client(
@@ -33,12 +76,15 @@ def _get_auth_client() -> httpx.Client:
3376
)
3477

3578

36-
def login() -> None:
79+
def login(force: bool = False) -> None:
3780
"""
3881
Log in to Quantinuum Nexus using the web browser.
3982
4083
(if web browser can't be launched, displays the link)
4184
"""
85+
if not force and is_logged_in():
86+
print("Already logged in. Tokens are valid.")
87+
return
4288

4389
res = _get_auth_client().post(
4490
"/device/device_authorization",
@@ -119,8 +165,11 @@ def login() -> None:
119165
raise qnx_exc.AuthenticationError("Browser login Failed, code has expired.")
120166

121167

122-
def login_with_credentials() -> None:
168+
def login_with_credentials(force: bool = False) -> None:
123169
"""Log in to Nexus using a username and password."""
170+
if not force and is_logged_in():
171+
print("Already logged in. Tokens are valid.")
172+
return
124173
user_name = input("Enter your Nexus email: ")
125174
pwd = getpass.getpass(prompt="Enter your Nexus password: ")
126175

@@ -129,12 +178,14 @@ def login_with_credentials() -> None:
129178
print(f"✅ Successfully logged in as {user_name}.")
130179

131180

132-
def login_no_interaction(user: EmailStr, pwd: str) -> None:
181+
def login_no_interaction(user: EmailStr, pwd: str, force: bool = False) -> None:
133182
"""Log in to Nexus using a username and password.
134183
Please be careful with storing credentials in plain text or source code.
135184
"""
185+
if not force and is_logged_in():
186+
print("Already logged in. Tokens are valid.")
187+
return
136188
_request_tokens(user=user, pwd=pwd)
137-
138189
print(f"✅ Successfully logged in as {user}.")
139190

140191

qnexus/client/utils.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,7 @@ def normalize_included(included: list[Any]) -> dict[str, dict[str, Any]]:
4343
def remove_token(token_type: TokenTypes) -> None:
4444
"""Delete a token file."""
4545
# Don't try to delete refresh token in Jupyterhub
46+
# these get mounted in automatically and managed externally
4647
if is_jupyterhub_environment() and token_type == "refresh_token":
4748
return
4849
token_file_path = Path.home() / CONFIG.token_path / token_file_from_type[token_type]
@@ -90,7 +91,8 @@ def read_token(token_type: TokenTypes) -> str:
9091
def write_token(token_type: TokenTypes, token: str) -> None:
9192
"""Write a token to a file."""
9293

93-
# don't allow writing of refresh token in Jupyterhub
94+
# Don't allow writing of refresh token in Jupyterhub
95+
# these get mounted in automatically and managed externally
9496
if is_jupyterhub_environment() and token_type == "refresh_token":
9597
return
9698

tests/test_warnings.py

Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,6 @@
1+
import base64
2+
import json
3+
import time
14
import typing
25
import warnings
36
from importlib.metadata import version
@@ -13,6 +16,7 @@
1316
VERSION_STATUS_HEADER,
1417
get_nexus_client,
1518
)
19+
from qnexus.client.auth import is_logged_in
1620
from qnexus.client.utils import write_token
1721

1822
FAKE_LATEST_VERSION = "999.99.999-never-gonna-happen"
@@ -123,6 +127,12 @@ def test_version_check_emits_warning_login(m: mock.Mock) -> None:
123127
"""
124128
base_url = get_nexus_client().base_url
125129

130+
# Mock the is_logged_in check (returns 401 to indicate not logged in)
131+
respx.get(f"{base_url}/api/users/v1beta2/me").mock(return_value=httpx.Response(401))
132+
133+
# Mock token refresh attempt (also returns 401 since not logged in)
134+
respx.post(f"{base_url}/auth/tokens/refresh").mock(return_value=httpx.Response(401))
135+
126136
# Mock the login endpoints
127137
respx.post(f"{base_url}/auth/device/device_authorization").mock(
128138
return_value=httpx.Response(
@@ -156,3 +166,39 @@ def test_version_check_emits_warning_login(m: mock.Mock) -> None:
156166

157167
_check_request_includes_version_data(token_request_route)
158168
_check_version_warning_emitted(captured)
169+
170+
171+
def _base64url_encode(data: bytes) -> str:
172+
return base64.urlsafe_b64encode(data).rstrip(b"=").decode("ascii")
173+
174+
175+
@respx.mock
176+
def test_refresh_token_expiry_warning_emitted() -> None:
177+
"""Emits a warning when the refresh token expires in less than 24 hours."""
178+
# Create an unsigned JWT with exp in the next hour
179+
header = {"alg": "none", "typ": "JWT"}
180+
payload = {"exp": int(time.time()) + 3600}
181+
182+
jwt_token = (
183+
_base64url_encode(json.dumps(header).encode())
184+
+ "."
185+
+ _base64url_encode(json.dumps(payload).encode())
186+
+ "."
187+
)
188+
189+
write_token("refresh_token", jwt_token)
190+
write_token("access_token", "dummy_id")
191+
192+
# Mock the lightweight authenticated request
193+
me_route = respx.get(f"{get_nexus_client().base_url}/api/users/v1beta2/me").mock(
194+
return_value=httpx.Response(200, json={"ok": True})
195+
)
196+
197+
with warnings.catch_warnings(record=True) as w:
198+
warnings.simplefilter("always")
199+
assert is_logged_in() is True
200+
201+
messages = [str(item.message) for item in w]
202+
assert any("expires in less than 24 hours" in msg for msg in messages)
203+
204+
assert me_route.called

uv.lock

Lines changed: 2 additions & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

0 commit comments

Comments
 (0)