Skip to content

Commit a6a3bc9

Browse files
committed
Merge branch 'issue817-prerefresh-bearer-for-threaded-task'
2 parents 90b55da + 7839676 commit a6a3bc9

File tree

7 files changed

+407
-37
lines changed

7 files changed

+407
-37
lines changed

CHANGELOG.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
1818
### Fixed
1919

2020
- `DataCube.sar_backscatter()`: add corresponding band names to metadata when enabling "mask", "contributing_area", "local_incidence_angle" or "ellipsoid_incidence_angle" ([#804](https://github.com/Open-EO/openeo-python-client/issues/804))
21+
- Proactively refresh access/bearer token in `MultiBackendJobManager` before launching a job start thread ([#817](https://github.com/Open-EO/openeo-python-client/issues/817))
2122

2223

2324
## [0.45.0] - 2025-09-17

openeo/extra/job_management/__init__.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -244,6 +244,8 @@ def __init__(
244244
)
245245
self._thread = None
246246
self._worker_pool = None
247+
# Generic cache
248+
self._cache = {}
247249

248250
def add_backend(
249251
self,
@@ -650,6 +652,8 @@ def _launch_job(self, start_job, df, i, backend_name, stats: Optional[dict] = No
650652
# start job if not yet done by callback
651653
try:
652654
job_con = job.connection
655+
# Proactively refresh bearer token (because task in thread will not be able to do that)
656+
self._refresh_bearer_token(connection=job_con)
653657
task = _JobStartTask(
654658
root_url=job_con.root_url,
655659
bearer_token=job_con.auth.bearer if isinstance(job_con.auth, BearerAuth) else None,
@@ -670,6 +674,21 @@ def _launch_job(self, start_job, df, i, backend_name, stats: Optional[dict] = No
670674
df.loc[i, "status"] = "skipped"
671675
stats["start_job skipped"] += 1
672676

677+
def _refresh_bearer_token(self, connection: Connection, *, max_age: float = 60) -> None:
    """
    Helper to proactively refresh the bearer (access) token of the connection
    (but not too often, based on `max_age`).

    Refresh attempts are throttled per connection, whether they succeed or not:
    at most one attempt is made per `max_age` seconds.

    :param connection: connection to refresh the access token of
    :param max_age: minimum number of seconds between two refresh attempts
        on the same connection
    """
    # TODO: be smarter about timing, e.g. by inspecting expiry of current token?
    now = time.time()
    key = f"connection:{id(connection)}:refresh-time"
    if self._cache.get(key, 0) + max_age < now:
        refreshed = connection.try_access_token_refresh()
        # Remember the time of the attempt, even when it failed: otherwise a connection
        # that can not refresh its token would be retried (and warned about)
        # on every single call instead of once per `max_age` window.
        self._cache[key] = now
        if not refreshed:
            _log.warning("Failed to proactively refresh bearer token")
691+
673692
def _process_threadworker_updates(
674693
self,
675694
worker_pool: _JobManagerWorkerThreadPool,

openeo/rest/_testing.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -132,6 +132,22 @@ def at_url(cls, root_url: str, *, requests_mock, capabilities: Optional[dict] =
132132
connection = Connection(root_url)
133133
return cls(requests_mock=requests_mock, connection=connection)
134134

135+
def setup_credentials_oidc(self, *, issuer: str = "https://oidc.test", id: str = "oi"):
    """
    Register a mocked response for the "/credentials/oidc" endpoint
    that lists a single OIDC provider.

    :param issuer: issuer URL of the provider
    :param id: identifier of the provider (also used as its title)
    :return: this instance, to allow call chaining
    """
    provider = {
        "id": id,
        "issuer": issuer,
        "title": id,
        "scopes": ["openid"],
    }
    url = self.connection.build_url("/credentials/oidc")
    self._requests_mock.get(url, json={"providers": [provider]})
    return self
150+
135151
def setup_collection(
136152
self,
137153
collection_id: str,

openeo/rest/auth/testing.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -143,6 +143,11 @@ def token_callback_resource_owner_password_credentials(self, params: dict, conte
143143
assert params["scope"] == self.expected_fields["scope"]
144144
return self._build_token_response()
145145

146+
def token_callback_block_400(self, params: dict, context):
    """
    Token callback that always fails with "400 Bad Request".

    :param params: request parameters (not used)
    :param context: response context on which the status code is set
    :return: dummy response body
    """
    response_body = "block_400"
    context.status_code = 400
    return response_body
150+
146151
def device_code_callback(self, request: requests_mock.request._RequestObjectProxy, context):
147152
params = self._get_query_params(query=request.text)
148153
assert params["client_id"] == self.expected_client_id

openeo/rest/connection.py

Lines changed: 56 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -342,28 +342,32 @@ def _authenticate_oidc(
342342
*,
343343
provider_id: str,
344344
store_refresh_token: bool = False,
345-
fallback_refresh_token_to_store: Optional[str] = None,
345+
auto_renew_from_refresh_token: bool = False,
346+
fallback_refresh_token: Optional[str] = None,
346347
oidc_auth_renewer: Optional[OidcAuthenticator] = None,
347348
) -> Connection:
348349
"""
349350
Authenticate through OIDC and set up bearer token (based on OIDC access_token) for further requests.
350351
"""
351-
tokens = authenticator.get_tokens(request_refresh_token=store_refresh_token)
352+
request_refresh_token = store_refresh_token or (not oidc_auth_renewer and auto_renew_from_refresh_token)
353+
tokens = authenticator.get_tokens(request_refresh_token=request_refresh_token)
352354
_log.info("Obtained tokens: {t}".format(t=[k for k, v in tokens._asdict().items() if v]))
355+
356+
refresh_token = tokens.refresh_token or fallback_refresh_token
353357
if store_refresh_token:
354-
refresh_token = tokens.refresh_token or fallback_refresh_token_to_store
355358
if refresh_token:
356359
self._get_refresh_token_store().set_refresh_token(
357360
issuer=authenticator.provider_info.issuer,
358361
client_id=authenticator.client_id,
359362
refresh_token=refresh_token
360363
)
361-
if not oidc_auth_renewer:
362-
oidc_auth_renewer = OidcRefreshTokenAuthenticator(
363-
client_info=authenticator.client_info, refresh_token=refresh_token
364-
)
365364
else:
366365
_log.warning("No OIDC refresh token to store.")
366+
if not oidc_auth_renewer and auto_renew_from_refresh_token and refresh_token:
367+
oidc_auth_renewer = OidcRefreshTokenAuthenticator(
368+
client_info=authenticator.client_info, refresh_token=refresh_token
369+
)
370+
367371
token = tokens.access_token
368372
self.auth = OidcBearerAuth(provider_id=provider_id, access_token=token)
369373
self._oidc_auth_renewer = oidc_auth_renewer
@@ -452,7 +456,12 @@ def authenticate_oidc_resource_owner_password_credentials(
452456
authenticator = OidcResourceOwnerPasswordAuthenticator(
453457
client_info=client_info, username=username, password=password
454458
)
455-
return self._authenticate_oidc(authenticator, provider_id=provider_id, store_refresh_token=store_refresh_token)
459+
return self._authenticate_oidc(
460+
authenticator,
461+
provider_id=provider_id,
462+
store_refresh_token=store_refresh_token,
463+
oidc_auth_renewer=authenticator,
464+
)
456465

457466
def authenticate_oidc_refresh_token(
458467
self,
@@ -493,7 +502,7 @@ def authenticate_oidc_refresh_token(
493502
authenticator,
494503
provider_id=provider_id,
495504
store_refresh_token=store_refresh_token,
496-
fallback_refresh_token_to_store=refresh_token,
505+
fallback_refresh_token=refresh_token,
497506
oidc_auth_renewer=authenticator,
498507
)
499508

@@ -534,7 +543,13 @@ def authenticate_oidc_device(
534543
authenticator = OidcDeviceAuthenticator(
535544
client_info=client_info, use_pkce=use_pkce, max_poll_time=max_poll_time, **kwargs
536545
)
537-
return self._authenticate_oidc(authenticator, provider_id=provider_id, store_refresh_token=store_refresh_token)
546+
return self._authenticate_oidc(
547+
authenticator,
548+
provider_id=provider_id,
549+
store_refresh_token=store_refresh_token,
550+
# TODO: expose `auto_renew_from_refresh_token` directly as option instead of reusing `store_refresh_token` arg?
551+
auto_renew_from_refresh_token=store_refresh_token,
552+
)
538553

539554
def authenticate_oidc(
540555
self,
@@ -604,7 +619,8 @@ def authenticate_oidc(
604619
authenticator,
605620
provider_id=provider_id,
606621
store_refresh_token=store_refresh_token,
607-
fallback_refresh_token_to_store=refresh_token,
622+
fallback_refresh_token=refresh_token,
623+
oidc_auth_renewer=authenticator,
608624
)
609625
# TODO: pluggable/jupyter-aware display function?
610626
print("Authenticated using refresh token.")
@@ -622,6 +638,8 @@ def authenticate_oidc(
622638
authenticator,
623639
provider_id=provider_id,
624640
store_refresh_token=store_refresh_token,
641+
# TODO: expose `auto_renew_from_refresh_token` directly as option instead of reusing `store_refresh_token` arg?
642+
auto_renew_from_refresh_token=store_refresh_token,
625643
)
626644
print("Authenticated using device code flow.")
627645
return con
@@ -665,6 +683,28 @@ def authenticate_bearer_token(self, bearer_token: str) -> Connection:
665683
self._oidc_auth_renewer = None
666684
return self
667685

686+
def try_access_token_refresh(self, *, reason: Optional[str] = None) -> bool:
    """
    Attempt to obtain a fresh access token through the OIDC auth renewer (if any).

    :param reason: optional explanation of why the refresh is attempted,
        appended to the log messages
    :return: whether a new access token was obtained
    """
    suffix = f" Reason: {reason}" if reason else ""

    # Refreshing is only possible with OIDC based bearer auth and an available renewer.
    if not isinstance(self.auth, OidcBearerAuth) or not self._oidc_auth_renewer:
        return False

    renewer = self._oidc_auth_renewer
    try:
        self._authenticate_oidc(
            authenticator=renewer,
            provider_id=renewer.provider_info.id,
            store_refresh_token=False,
            oidc_auth_renewer=renewer,
        )
    except OpenEoClientException as auth_exc:
        _log.error(f"Failed to obtain new access token (grant {renewer.grant_type!r}): {auth_exc!r}.{suffix}")
        return False

    _log.info(f"Obtained new access token (grant {renewer.grant_type!r}).{suffix}")
    return True
707+
668708
def request(
669709
self,
670710
method: str,
@@ -690,24 +730,11 @@ def _request():
690730
api_exc.http_status_code in {HTTP_401_UNAUTHORIZED, HTTP_403_FORBIDDEN}
691731
and api_exc.code == "TokenInvalid"
692732
):
693-
# Auth token expired: can we refresh?
694-
if isinstance(self.auth, OidcBearerAuth) and self._oidc_auth_renewer:
695-
msg = f"OIDC access token expired ({api_exc.http_status_code} {api_exc.code})."
696-
try:
697-
self._authenticate_oidc(
698-
authenticator=self._oidc_auth_renewer,
699-
provider_id=self._oidc_auth_renewer.provider_info.id,
700-
store_refresh_token=False,
701-
oidc_auth_renewer=self._oidc_auth_renewer,
702-
)
703-
_log.info(f"{msg} Obtained new access token (grant {self._oidc_auth_renewer.grant_type!r}).")
704-
except OpenEoClientException as auth_exc:
705-
_log.error(
706-
f"{msg} Failed to obtain new access token (grant {self._oidc_auth_renewer.grant_type!r}): {auth_exc!r}."
707-
)
708-
else:
709-
# Retry request.
710-
return _request()
733+
# Retry if we can refresh the access token
734+
if self.try_access_token_refresh(
735+
reason=f"OIDC access token expired ({api_exc.http_status_code} {api_exc.code})."
736+
):
737+
return _request()
711738
raise
712739

713740
def describe_account(self) -> dict:

tests/extra/job_management/test_job_management.py

Lines changed: 48 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,7 @@
4747
_TaskResult,
4848
)
4949
from openeo.rest._testing import OPENEO_BACKEND, DummyBackend, build_capabilities
50+
from openeo.rest.auth.testing import OidcMock
5051
from openeo.util import rfc3339
5152
from openeo.utils.version import ComparableVersion
5253

@@ -269,7 +270,7 @@ def test_create_job_db(self, tmp_path, job_manager, job_manager_root_dir, sleep_
269270
assert set(result.status) == {"finished"}
270271
assert set(result.backend_name) == {"foo", "bar"}
271272

272-
def test_basic_threading(self, tmp_path, job_manager, job_manager_root_dir, sleep_mock):
273+
def test_start_job_thread_basic(self, tmp_path, job_manager, job_manager_root_dir, sleep_mock):
273274
df = pd.DataFrame(
274275
{
275276
"year": [2018, 2019, 2020, 2021, 2022],
@@ -868,6 +869,52 @@ def execute(self):
868869
assert any("Skipping invalid db_update" in msg for msg in caplog.messages)
869870
assert any("Skipping invalid stats_update" in msg for msg in caplog.messages)
870871

872+
def test_refresh_bearer_token_before_start(
    self,
    tmp_path,
    job_manager,
    dummy_backend_foo,
    dummy_backend_bar,
    job_manager_root_dir,
    sleep_mock,
    requests_mock,
):
    """
    Running jobs should trigger a proactive (but throttled) token refresh
    on top of the initial client credentials authentication of each backend.
    """
    credentials = {"client_id": "client123", "client_secret": "$3cr3t"}
    oidc_issuer = "https://oidc.test/"
    oidc_mock = OidcMock(
        requests_mock=requests_mock,
        expected_grant_type="client_credentials",
        expected_client_id=credentials["client_id"],
        expected_fields={"client_secret": credentials["client_secret"], "scope": "openid"},
        oidc_issuer=oidc_issuer,
    )
    backends = [dummy_backend_foo, dummy_backend_bar]
    for backend in backends:
        backend.setup_credentials_oidc(issuer=oidc_issuer)
    for backend in backends:
        backend.connection.authenticate_oidc_client_credentials(**credentials)

    # Initial authentication: one client credentials token request per backend
    assert len(oidc_mock.grant_request_history) == 2

    years = pd.DataFrame({"year": [2020, 2021, 2022, 2023, 2024]})
    job_db = CsvJobDatabase(tmp_path / "jobs.csv").initialize_from_df(years)
    run_stats = job_manager.run_jobs(job_db=job_db, start_job=self._create_year_job)

    expected_stats = {
        "job_queued_for_start": 5,
        "job started running": 5,
        "job finished": 5,
    }
    assert run_stats == dirty_equals.IsPartialDict(expected_stats)

    # Proactive token refreshing is throttled,
    # so only 2 additional token requests are expected
    assert len(oidc_mock.grant_request_history) == 4
917+
871918

872919
JOB_DB_DF_BASICS = pd.DataFrame(
873920
{

0 commit comments

Comments
 (0)