Skip to content

Commit 7839676

Browse files
committed
Add basic unit test for proactive bearer token refresh before _JobStartTask #817
1 parent 202ec9b commit 7839676

File tree

3 files changed

+69
-4
lines changed

3 files changed

+69
-4
lines changed

openeo/extra/job_management/__init__.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -674,18 +674,20 @@ def _launch_job(self, start_job, df, i, backend_name, stats: Optional[dict] = No
674674
df.loc[i, "status"] = "skipped"
675675
stats["start_job skipped"] += 1
676676

677-
def _refresh_bearer_token(self, connection: Connection, *, max_age: float = 60):
677+
def _refresh_bearer_token(self, connection: Connection, *, max_age: float = 60) -> None:
678678
"""
679-
Helper to proactively refresh access token of connection
679+
Helper to proactively refresh the bearer (access) token of the connection
680680
(but not too often, based on `max_age`).
681681
"""
682682
# TODO: be smarter about timing, e.g. by inspecting expiry of current token?
683683
now = time.time()
684-
key = f"connection-{id(connection)}-refresh-time"
684+
key = f"connection:{id(connection)}:refresh-time"
685685
if self._cache.get(key, 0) + max_age < now:
686686
refreshed = connection.try_access_token_refresh()
687687
if refreshed:
688688
self._cache[key] = now
689+
else:
690+
_log.warning("Failed to proactively refresh bearer token")
689691

690692
def _process_threadworker_updates(
691693
self,

openeo/rest/_testing.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -132,6 +132,22 @@ def at_url(cls, root_url: str, *, requests_mock, capabilities: Optional[dict] =
132132
connection = Connection(root_url)
133133
return cls(requests_mock=requests_mock, connection=connection)
134134

135+
def setup_credentials_oidc(self, *, issuer: str = "https://oidc.test", id: str = "oi"):
136+
self._requests_mock.get(
137+
self.connection.build_url("/credentials/oidc"),
138+
json={
139+
"providers": [
140+
{
141+
"id": id,
142+
"issuer": issuer,
143+
"title": id,
144+
"scopes": ["openid"],
145+
}
146+
]
147+
},
148+
)
149+
return self
150+
135151
def setup_collection(
136152
self,
137153
collection_id: str,

tests/extra/job_management/test_job_management.py

Lines changed: 48 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,7 @@
4747
_TaskResult,
4848
)
4949
from openeo.rest._testing import OPENEO_BACKEND, DummyBackend, build_capabilities
50+
from openeo.rest.auth.testing import OidcMock
5051
from openeo.util import rfc3339
5152
from openeo.utils.version import ComparableVersion
5253

@@ -269,7 +270,7 @@ def test_create_job_db(self, tmp_path, job_manager, job_manager_root_dir, sleep_
269270
assert set(result.status) == {"finished"}
270271
assert set(result.backend_name) == {"foo", "bar"}
271272

272-
def test_basic_threading(self, tmp_path, job_manager, job_manager_root_dir, sleep_mock):
273+
def test_start_job_thread_basic(self, tmp_path, job_manager, job_manager_root_dir, sleep_mock):
273274
df = pd.DataFrame(
274275
{
275276
"year": [2018, 2019, 2020, 2021, 2022],
@@ -868,6 +869,52 @@ def execute(self):
868869
assert any("Skipping invalid db_update" in msg for msg in caplog.messages)
869870
assert any("Skipping invalid stats_update" in msg for msg in caplog.messages)
870871

872+
def test_refresh_bearer_token_before_start(
873+
self,
874+
tmp_path,
875+
job_manager,
876+
dummy_backend_foo,
877+
dummy_backend_bar,
878+
job_manager_root_dir,
879+
sleep_mock,
880+
requests_mock,
881+
):
882+
883+
client_id = "client123"
884+
client_secret = "$3cr3t"
885+
oidc_issuer = "https://oidc.test/"
886+
oidc_mock = OidcMock(
887+
requests_mock=requests_mock,
888+
expected_grant_type="client_credentials",
889+
expected_client_id=client_id,
890+
expected_fields={"client_secret": client_secret, "scope": "openid"},
891+
oidc_issuer=oidc_issuer,
892+
)
893+
dummy_backend_foo.setup_credentials_oidc(issuer=oidc_issuer)
894+
dummy_backend_bar.setup_credentials_oidc(issuer=oidc_issuer)
895+
dummy_backend_foo.connection.authenticate_oidc_client_credentials(client_id="client123", client_secret="$3cr3t")
896+
dummy_backend_bar.connection.authenticate_oidc_client_credentials(client_id="client123", client_secret="$3cr3t")
897+
898+
# After this setup, we have 2 client credential token requests (one for each backend)
899+
assert len(oidc_mock.grant_request_history) == 2
900+
901+
df = pd.DataFrame({"year": [2020, 2021, 2022, 2023, 2024]})
902+
job_db_path = tmp_path / "jobs.csv"
903+
job_db = CsvJobDatabase(job_db_path).initialize_from_df(df)
904+
run_stats = job_manager.run_jobs(job_db=job_db, start_job=self._create_year_job)
905+
906+
assert run_stats == dirty_equals.IsPartialDict(
907+
{
908+
"job_queued_for_start": 5,
909+
"job started running": 5,
910+
"job finished": 5,
911+
}
912+
)
913+
914+
# Because of proactive+throttled token refreshing,
915+
# we should have 2 additional token requests now
916+
assert len(oidc_mock.grant_request_history) == 4
917+
871918

872919
JOB_DB_DF_BASICS = pd.DataFrame(
873920
{

0 commit comments

Comments
 (0)