From 17d90a7efe50920772b1812fe5c262825c28f8fb Mon Sep 17 00:00:00 2001 From: Hans Vanrompay Date: Wed, 26 Nov 2025 10:20:57 +0100 Subject: [PATCH 01/16] adding download task and creating seperate download pool --- openeo/extra/job_management/_manager.py | 45 ++++++++++++++---- openeo/extra/job_management/_thread_worker.py | 46 +++++++++++++++++++ 2 files changed, 81 insertions(+), 10 deletions(-) diff --git a/openeo/extra/job_management/_manager.py b/openeo/extra/job_management/_manager.py index 2ce78ba52..d0458ba83 100644 --- a/openeo/extra/job_management/_manager.py +++ b/openeo/extra/job_management/_manager.py @@ -32,6 +32,7 @@ from openeo.extra.job_management._thread_worker import ( _JobManagerWorkerThreadPool, _JobStartTask, + _JobDownloadTask ) from openeo.rest import OpenEoApiError from openeo.rest.auth.auth import BearerAuth @@ -172,6 +173,7 @@ def start_job( .. versionchanged:: 0.47.0 Added ``download_results`` parameter. + """ # Expected columns in the job DB dataframes. @@ -219,6 +221,7 @@ def __init__( ) self._thread = None self._worker_pool = None + self._download_pool = None # Generic cache self._cache = {} @@ -351,6 +354,7 @@ def start_job_thread(self, start_job: Callable[[], BatchJob], job_db: JobDatabas self._stop_thread = False self._worker_pool = _JobManagerWorkerThreadPool() + self._download_pool = _JobManagerWorkerThreadPool() def run_loop(): # TODO: support user-provided `stats` @@ -388,7 +392,13 @@ def stop_job_thread(self, timeout_seconds: Optional[float] = _UNSET): .. versionadded:: 0.32.0 """ - self._worker_pool.shutdown() + if self._worker_pool is not None: + self._worker_pool.shutdown() + self._worker_pool = None + + if self._download_pool is not None: + self._download_pool.shutdown() + self._download_pool = None if self._thread is not None: self._stop_thread = True @@ -493,6 +503,8 @@ def run_jobs( stats = collections.defaultdict(int) self._worker_pool = _JobManagerWorkerThreadPool() + self._download_pool = _JobManagerWorkerThreadPool() + while ( sum( @@ -511,7 +523,7 @@ def run_jobs( stats["sleep"] += 1 # TODO; run post process after shutdown once more to ensure completion? - self._worker_pool.shutdown() + self.stop_job_thread() return stats @@ -553,7 +565,11 @@ def _job_update_loop( stats["job_db persist"] += 1 total_added += 1 - self._process_threadworker_updates(self._worker_pool, job_db=job_db, stats=stats) + if self._worker_pool is not None: + self._process_threadworker_updates(worker_pool=self._worker_pool, job_db=job_db, stats=stats) + + if self._download_pool is not None: + self._process_threadworker_updates(worker_pool=self._download_pool, job_db=job_db, stats=stats) # TODO: move this back closer to the `_track_statuses` call above, once job done/error handling is also handled in threads? for job, row in jobs_done: @@ -565,6 +581,7 @@ def _job_update_loop( for job, row in jobs_cancel: self.on_job_cancel(job, row) + def _launch_job(self, start_job, df, i, backend_name, stats: Optional[dict] = None): """Helper method for launching jobs @@ -657,7 +674,7 @@ def _refresh_bearer_token(self, connection: Connection, *, max_age: float = 60) else: _log.warning("Failed to proactively refresh bearer token") - def _process_threadworker_updates( + def _process_task_results( self, worker_pool: _JobManagerWorkerThreadPool, *, @@ -723,15 +740,23 @@ def on_job_done(self, job: BatchJob, row): """ # TODO: param `row` is never accessed in this method. Remove it? Is this intended for future use? 
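# Annotation (not part of the patch): in the hunk below, the synchronous download that used to run
# inside on_job_done() is replaced by submitting a _JobDownloadTask to the dedicated download pool,
# so downloading results no longer blocks the job tracking loop.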
if self._download_results: - job_metadata = job.describe() - job_dir = self.get_job_dir(job.job_id) - metadata_path = self.get_job_metadata_path(job.job_id) + job_dir = self.get_job_dir(job.job_id) self.ensure_job_dir_exists(job.job_id) - job.get_results().download_files(target=job_dir) - with metadata_path.open("w", encoding="utf-8") as f: - json.dump(job_metadata, f, ensure_ascii=False) + # Proactively refresh bearer token (because task in thread will not be able to do that + job_con = job.connection + self._refresh_bearer_token(connection=job_con) + + task = _JobDownloadTask( + root_url=job_con.root_url, + bearer_token=job_con.auth.bearer if isinstance(job_con.auth, BearerAuth) else None, + job_id=job.job_id, + df_idx=row.name, # TODO figure out correct index usage + download_dir=job_dir, + ) + _log.info(f"Submitting download task {task} to download thread pool") + self._download_pool.submit_task(task) def on_job_error(self, job: BatchJob, row): """ diff --git a/openeo/extra/job_management/_thread_worker.py b/openeo/extra/job_management/_thread_worker.py index 6040fade1..e5306ff7a 100644 --- a/openeo/extra/job_management/_thread_worker.py +++ b/openeo/extra/job_management/_thread_worker.py @@ -7,7 +7,9 @@ from abc import ABC, abstractmethod from dataclasses import dataclass, field from typing import Any, Dict, List, Optional, Tuple, Union +from pathlib import Path +import json import urllib3.util import openeo @@ -140,6 +142,50 @@ def execute(self) -> _TaskResult: stats_update={"start_job error": 1}, ) +@dataclass(frozen=True) +class _JobDownloadTask(ConnectedTask): + """ + Task for downloading job results and metadata. + + :param download_dir: + Root directory where job results and metadata will be downloaded. + """ + download_dir: Path + + def execute(self) -> _TaskResult: + """ + Download job results and metadata. 
+ """ + try: + job = self.get_connection(retry=True).job(self.job_id) + + # Ensure download directory exists + self.download_dir.mkdir(parents=True, exist_ok=True) + + # Download results + job.get_results().download_files(target=self.download_dir) + + # Download metadata + job_metadata = job.describe() + metadata_path = self.download_dir / f"job_{self.job_id}.json" + with metadata_path.open("w", encoding="utf-8") as f: + json.dump(job_metadata, f, ensure_ascii=False) + + _log.info(f"Job {self.job_id!r} results downloaded successfully") + return _TaskResult( + job_id=self.job_id, + df_idx=self.df_idx, + db_update={}, #TODO consider db updates + stats_update={"job download": 1}, + ) + except Exception as e: + _log.error(f"Failed to download results for job {self.job_id!r}: {e!r}") + return _TaskResult( + job_id=self.job_id, + df_idx=self.df_idx, + db_update={}, + stats_update={"job download error": 1}, + ) class _JobManagerWorkerThreadPool: """ From d22ac209858adf0c0f90d9dce5903abec3841730 Mon Sep 17 00:00:00 2001 From: Hans Vanrompay Date: Fri, 28 Nov 2025 11:03:47 +0100 Subject: [PATCH 02/16] include initial unit testing --- openeo/extra/job_management/_manager.py | 5 +- openeo/extra/job_management/_thread_worker.py | 10 ++- .../job_management/test_thread_worker.py | 64 +++++++++++++++++++ 3 files changed, 72 insertions(+), 7 deletions(-) diff --git a/openeo/extra/job_management/_manager.py b/openeo/extra/job_management/_manager.py index d0458ba83..c575c98e6 100644 --- a/openeo/extra/job_management/_manager.py +++ b/openeo/extra/job_management/_manager.py @@ -674,7 +674,7 @@ def _refresh_bearer_token(self, connection: Connection, *, max_age: float = 60) else: _log.warning("Failed to proactively refresh bearer token") - def _process_task_results( + def _process_threadworker_updates( self, worker_pool: _JobManagerWorkerThreadPool, *, @@ -756,6 +756,9 @@ def on_job_done(self, job: BatchJob, row): download_dir=job_dir, ) _log.info(f"Submitting download task {task} to download thread pool") + + if self._download_pool is None: + self._download_pool = _JobManagerWorkerThreadPool() self._download_pool.submit_task(task) def on_job_error(self, job: BatchJob, row): diff --git a/openeo/extra/job_management/_thread_worker.py b/openeo/extra/job_management/_thread_worker.py index e5306ff7a..523cce0d1 100644 --- a/openeo/extra/job_management/_thread_worker.py +++ b/openeo/extra/job_management/_thread_worker.py @@ -142,7 +142,6 @@ def execute(self) -> _TaskResult: stats_update={"start_job error": 1}, ) -@dataclass(frozen=True) class _JobDownloadTask(ConnectedTask): """ Task for downloading job results and metadata. @@ -150,12 +149,11 @@ class _JobDownloadTask(ConnectedTask): :param download_dir: Root directory where job results and metadata will be downloaded. """ - download_dir: Path + def __init__(self, download_dir: Path, **kwargs): + super().__init__(**kwargs) + object.__setattr__(self, 'download_dir', download_dir) def execute(self) -> _TaskResult: - """ - Download job results and metadata. - """ try: job = self.get_connection(retry=True).job(self.job_id) @@ -186,7 +184,7 @@ def execute(self) -> _TaskResult: db_update={}, stats_update={"job download error": 1}, ) - + class _JobManagerWorkerThreadPool: """ Thread pool-based worker that manages the execution of asynchronous tasks. 
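Note (reviewer annotation, not part of the patch): a minimal sketch of how the new `_JobDownloadTask` defined in the hunks above is constructed and executed. The backend URL, bearer token and paths are placeholders, not values from the patch.

    from pathlib import Path

    from openeo.extra.job_management._thread_worker import _JobDownloadTask

    task = _JobDownloadTask(
        root_url="https://openeo.example.test/",   # placeholder backend URL
        bearer_token="dummy-token",                # pre-refreshed by the manager; may be None
        job_id="job-123",
        df_idx=0,
        download_dir=Path("jobs/job_job-123"),     # receives the result assets plus job_job-123.json metadata
    )
    result = task.execute()
    # On success result.stats_update == {"job download": 1}; on failure {"job download error": 1}.
    # result.db_update stays {} in both cases, so the download task never writes to the job database.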
diff --git a/tests/extra/job_management/test_thread_worker.py b/tests/extra/job_management/test_thread_worker.py index 52ee833f1..128741ac5 100644 --- a/tests/extra/job_management/test_thread_worker.py +++ b/tests/extra/job_management/test_thread_worker.py @@ -11,6 +11,7 @@ _JobManagerWorkerThreadPool, _JobStartTask, _TaskResult, + _JobDownloadTask ) from openeo.rest._testing import DummyBackend @@ -288,3 +289,66 @@ def test_job_start_task_failure(self, worker_pool, dummy_backend, caplog): assert caplog.messages == [ "Failed to start job 'job-000': OpenEoApiError('[500] Internal: No job starting for you, buddy')" ] + + def test_download_task_in_pool(self, worker_pool, tmp_path): + # Test that download tasks can be submitted to the thread pool + # without needing actual backend functionality + task = _JobDownloadTask( + job_id="pool-job-123", + df_idx=42, + root_url="https://example.com", + bearer_token="test-token", + download_dir=tmp_path + ) + + worker_pool.submit_task(task) + results, remaining = worker_pool.process_futures(timeout=1) + + # We can't test the actual download result without a backend, + # but we can verify the task was processed + assert len(results) == 1 + result = results[0] + assert result.job_id == "pool-job-123" + assert result.df_idx == 42 + assert remaining == 0 + +class TestJobDownloadTask: + def test_download_success(self, tmp_path, caplog): + caplog.set_level(logging.INFO) + + # Test the basic functionality without complex backend setup + download_dir = tmp_path / "downloads" + task = _JobDownloadTask( + job_id="test-job-123", + df_idx=0, + root_url="https://example.com", + bearer_token="test-token", + download_dir=download_dir + ) + + # Since we can't test actual downloads without a real backend, + # we'll test that the task is properly constructed and the directory is handled + assert task.job_id == "test-job-123" + assert task.df_idx == 0 + assert task.root_url == "https://example.com" + assert task.download_dir == download_dir + # Token should be hidden in repr + assert "test-token" not in repr(task) + + def test_download_failure_handling(self, tmp_path, caplog): + caplog.set_level(logging.ERROR) + + # Test that the task properly handles execution context + # We can't easily test actual download failures without complex setup, + # but we can verify the task structure and error handling approach + download_dir = tmp_path / "downloads" + task = _JobDownloadTask( + job_id="failing-job", + df_idx=1, + root_url="https://example.com", + bearer_token="test-token", + download_dir=download_dir + ) + + # The task should be properly constructed for error handling + assert task.job_id == "failing-job" \ No newline at end of file From 7acc04304dc4829fdebd5ae44df6e483e499fe28 Mon Sep 17 00:00:00 2001 From: Hans Vanrompay Date: Wed, 10 Dec 2025 16:19:08 +0100 Subject: [PATCH 03/16] updated unit tests --- .../job_management/test_thread_worker.py | 159 ++++++++++++------ 1 file changed, 107 insertions(+), 52 deletions(-) diff --git a/tests/extra/job_management/test_thread_worker.py b/tests/extra/job_management/test_thread_worker.py index 128741ac5..c28c1c3a5 100644 --- a/tests/extra/job_management/test_thread_worker.py +++ b/tests/extra/job_management/test_thread_worker.py @@ -2,7 +2,9 @@ import threading import time from dataclasses import dataclass -from typing import Iterator +from typing import Iterator, Dict, Any +from pathlib import Path +import json import pytest @@ -290,65 +292,118 @@ def test_job_start_task_failure(self, worker_pool, dummy_backend, caplog): "Failed 
to start job 'job-000': OpenEoApiError('[500] Internal: No job starting for you, buddy')" ] - def test_download_task_in_pool(self, worker_pool, tmp_path): - # Test that download tasks can be submitted to the thread pool - # without needing actual backend functionality - task = _JobDownloadTask( - job_id="pool-job-123", - df_idx=42, - root_url="https://example.com", - bearer_token="test-token", - download_dir=tmp_path - ) + + +import tempfile +from requests_mock import Mocker +OPENEO_BACKEND = "https://openeo.dummy.test/" + +class TestJobDownloadTask: + + # Use a temporary directory for safe file handling + @pytest.fixture + def temp_dir(self): + with tempfile.TemporaryDirectory() as temp_dir: + yield Path(temp_dir) + + def test_job_download_success(self, requests_mock: Mocker, temp_dir: Path): + """ + Test a successful job download and verify file content and stats update. + """ + job_id = "job-007" + df_idx = 42 - worker_pool.submit_task(task) - results, remaining = worker_pool.process_futures(timeout=1) + # Setup Dummy Backend + backend = DummyBackend.at_url(OPENEO_BACKEND, requests_mock=requests_mock) + backend.next_result = b"The downloaded file content." - # We can't test the actual download result without a backend, - # but we can verify the task was processed - assert len(results) == 1 - result = results[0] - assert result.job_id == "pool-job-123" - assert result.df_idx == 42 - assert remaining == 0 + # Pre-set job status to "finished" so the download link is available + backend.batch_jobs[job_id] = {"job_id": job_id, "pg": {}, "status": "created"} + + # Need to ensure job status is "finished" for download attempt to occur + backend._set_job_status(job_id=job_id, status="finished") + backend.batch_jobs[job_id]["status"] = "finished" -class TestJobDownloadTask: - def test_download_success(self, tmp_path, caplog): - caplog.set_level(logging.INFO) + download_dir = temp_dir / job_id / "results" + download_dir.mkdir(parents=True) - # Test the basic functionality without complex backend setup - download_dir = tmp_path / "downloads" + # Create the task instance task = _JobDownloadTask( - job_id="test-job-123", - df_idx=0, - root_url="https://example.com", - bearer_token="test-token", - download_dir=download_dir + root_url=OPENEO_BACKEND, + bearer_token="dummy-token-7", + job_id=job_id, + df_idx=df_idx, + download_dir=download_dir, ) + + # Execute the task + result = task.execute() + + # 4. Assertions + + # A. Verify TaskResult structure + assert isinstance(result, _TaskResult) + assert result.job_id == job_id + assert result.df_idx == df_idx + + # B. Verify stats update for the MultiBackendJobManager + assert result.stats_update == {"job download": 1} + + # C. Verify download content (crucial part of the unit test) + downloaded_file = download_dir / "result.data" + assert downloaded_file.exists() + assert downloaded_file.read_bytes() == b"The downloaded file content." + + # Verify backend interaction + get_results_calls = [c for c in requests_mock.request_history if c.method == "GET" and f"/jobs/{job_id}/results" in c.url] + assert len(get_results_calls) >= 1 + get_asset_calls = [c for c in requests_mock.request_history if c.method == "GET" and f"/jobs/{job_id}/results/result.data" in c.url] + assert len(get_asset_calls) == 1 + + def test_job_download_failure(self, requests_mock: Mocker, temp_dir: Path): + """ + Test a failed download (e.g., bad connection) and verify error reporting. 
+ """ + job_id = "job-008" + df_idx = 55 + + # Need to ensure job status is "finished" for download attempt to occur + backend = DummyBackend.at_url(OPENEO_BACKEND, requests_mock=requests_mock) + + requests_mock.get( + f"{OPENEO_BACKEND}jobs/{job_id}/results", + status_code=500, + json={"code": "InternalError", "message": "Failed to list results"}) + + backend.batch_jobs[job_id] = {"job_id": job_id, "pg": {}, "status": "created"} - # Since we can't test actual downloads without a real backend, - # we'll test that the task is properly constructed and the directory is handled - assert task.job_id == "test-job-123" - assert task.df_idx == 0 - assert task.root_url == "https://example.com" - assert task.download_dir == download_dir - # Token should be hidden in repr - assert "test-token" not in repr(task) - - def test_download_failure_handling(self, tmp_path, caplog): - caplog.set_level(logging.ERROR) + # Need to ensure job status is "finished" for download attempt to occur + backend._set_job_status(job_id=job_id, status="finished") + backend.batch_jobs[job_id]["status"] = "finished" - # Test that the task properly handles execution context - # We can't easily test actual download failures without complex setup, - # but we can verify the task structure and error handling approach - download_dir = tmp_path / "downloads" + download_dir = temp_dir / job_id / "results" + download_dir.mkdir(parents=True) + + # Create the task instance task = _JobDownloadTask( - job_id="failing-job", - df_idx=1, - root_url="https://example.com", - bearer_token="test-token", - download_dir=download_dir + root_url=OPENEO_BACKEND, + bearer_token="dummy-token-8", + job_id=job_id, + df_idx=df_idx, + download_dir=download_dir, ) + + # Execute the task + result = task.execute() + + # Verify TaskResult structure + assert isinstance(result, _TaskResult) + assert result.job_id == job_id + assert result.df_idx == df_idx + + # Verify stats update for the MultiBackendJobManager + assert result.stats_update == {"job download error": 1} - # The task should be properly constructed for error handling - assert task.job_id == "failing-job" \ No newline at end of file + # Verify no file was created (or only empty/failed files) + assert not any(p.is_file() for p in download_dir.glob("*")) + From 246ac2f8f8fde90c629b9233263a706b7d9d0c90 Mon Sep 17 00:00:00 2001 From: Hans Vanrompay Date: Thu, 11 Dec 2025 09:54:18 +0100 Subject: [PATCH 04/16] including two simple unit tests and unifying pool usage --- .../job_management/test_thread_worker.py | 216 ++++++++---------- 1 file changed, 100 insertions(+), 116 deletions(-) diff --git a/tests/extra/job_management/test_thread_worker.py b/tests/extra/job_management/test_thread_worker.py index c28c1c3a5..47bde093f 100644 --- a/tests/extra/job_management/test_thread_worker.py +++ b/tests/extra/job_management/test_thread_worker.py @@ -2,9 +2,10 @@ import threading import time from dataclasses import dataclass -from typing import Iterator, Dict, Any +from typing import Iterator from pathlib import Path -import json +import tempfile +from requests_mock import Mocker import pytest @@ -82,6 +83,103 @@ def test_hide_token(self, serializer): assert "job-123" in serialized assert secret not in serialized +class TestJobDownloadTask: + + # Use a temporary directory for safe file handling + @pytest.fixture + def temp_dir(self): + with tempfile.TemporaryDirectory() as temp_dir: + yield Path(temp_dir) + + def test_job_download_success(self, requests_mock: Mocker, temp_dir: Path): + """ + Test a successful job 
download and verify file content and stats update. + """ + job_id = "job-007" + df_idx = 42 + + # We set up a dummy backend to simulate the job results and assert the expected calls are triggered + backend = DummyBackend.at_url("https://openeo.dummy.test/", requests_mock=requests_mock) + backend.next_result = b"The downloaded file content." + backend.batch_jobs[job_id] = {"job_id": job_id, "pg": {}, "status": "created"} + + backend._set_job_status(job_id=job_id, status="finished") + backend.batch_jobs[job_id]["status"] = "finished" + + download_dir = temp_dir / job_id / "results" + download_dir.mkdir(parents=True) + + # Create the task instance + task = _JobDownloadTask( + root_url="https://openeo.dummy.test/", + bearer_token="dummy-token-7", + job_id=job_id, + df_idx=df_idx, + download_dir=download_dir, + ) + + # Execute the task + result = task.execute() + + # Verify TaskResult structure + assert isinstance(result, _TaskResult) + assert result.job_id == job_id + assert result.df_idx == df_idx + + # Verify stats update for the MultiBackendJobManager + assert result.stats_update == {"job download": 1} + + # Verify download content (crucial part of the unit test) + downloaded_file = download_dir / "result.data" + assert downloaded_file.exists() + assert downloaded_file.read_bytes() == b"The downloaded file content." + + + def test_job_download_failure(self, requests_mock: Mocker, temp_dir: Path): + """ + Test a failed download (e.g., bad connection) and verify error reporting. + """ + job_id = "job-008" + df_idx = 55 + + # Set up dummy backend to simulate failure during results listing + backend = DummyBackend.at_url("https://openeo.dummy.test/", requests_mock=requests_mock) + + #simulate and error when downloading the results + requests_mock.get( + f"https://openeo.dummy.test/jobs/{job_id}/results", + status_code=500, + json={"code": "InternalError", "message": "Failed to list results"}) + + backend.batch_jobs[job_id] = {"job_id": job_id, "pg": {}, "status": "created"} + backend._set_job_status(job_id=job_id, status="finished") + backend.batch_jobs[job_id]["finished"] = "error" + + download_dir = temp_dir / job_id / "results" + download_dir.mkdir(parents=True) + + # Create the task instance + task = _JobDownloadTask( + root_url="https://openeo.dummy.test/", + bearer_token="dummy-token-8", + job_id=job_id, + df_idx=df_idx, + download_dir=download_dir, + ) + + # Execute the task + result = task.execute() + + # Verify TaskResult structure + assert isinstance(result, _TaskResult) + assert result.job_id == job_id + assert result.df_idx == df_idx + + # Verify stats update for the MultiBackendJobManager + assert result.stats_update == {"job download error": 1} + + # Verify no file was created (or only empty/failed files) + assert not any(p.is_file() for p in download_dir.glob("*")) class NopTask(Task): """Do Nothing""" @@ -293,117 +391,3 @@ def test_job_start_task_failure(self, worker_pool, dummy_backend, caplog): ] - -import tempfile -from requests_mock import Mocker -OPENEO_BACKEND = "https://openeo.dummy.test/" - -class TestJobDownloadTask: - - # Use a temporary directory for safe file handling - @pytest.fixture - def temp_dir(self): - with tempfile.TemporaryDirectory() as temp_dir: - yield Path(temp_dir) - - def test_job_download_success(self, requests_mock: Mocker, temp_dir: Path): - """ - Test a successful job download and verify file content and stats update. 
- """ - job_id = "job-007" - df_idx = 42 - - # Setup Dummy Backend - backend = DummyBackend.at_url(OPENEO_BACKEND, requests_mock=requests_mock) - backend.next_result = b"The downloaded file content." - - # Pre-set job status to "finished" so the download link is available - backend.batch_jobs[job_id] = {"job_id": job_id, "pg": {}, "status": "created"} - - # Need to ensure job status is "finished" for download attempt to occur - backend._set_job_status(job_id=job_id, status="finished") - backend.batch_jobs[job_id]["status"] = "finished" - - download_dir = temp_dir / job_id / "results" - download_dir.mkdir(parents=True) - - # Create the task instance - task = _JobDownloadTask( - root_url=OPENEO_BACKEND, - bearer_token="dummy-token-7", - job_id=job_id, - df_idx=df_idx, - download_dir=download_dir, - ) - - # Execute the task - result = task.execute() - - # 4. Assertions - - # A. Verify TaskResult structure - assert isinstance(result, _TaskResult) - assert result.job_id == job_id - assert result.df_idx == df_idx - - # B. Verify stats update for the MultiBackendJobManager - assert result.stats_update == {"job download": 1} - - # C. Verify download content (crucial part of the unit test) - downloaded_file = download_dir / "result.data" - assert downloaded_file.exists() - assert downloaded_file.read_bytes() == b"The downloaded file content." - - # Verify backend interaction - get_results_calls = [c for c in requests_mock.request_history if c.method == "GET" and f"/jobs/{job_id}/results" in c.url] - assert len(get_results_calls) >= 1 - get_asset_calls = [c for c in requests_mock.request_history if c.method == "GET" and f"/jobs/{job_id}/results/result.data" in c.url] - assert len(get_asset_calls) == 1 - - def test_job_download_failure(self, requests_mock: Mocker, temp_dir: Path): - """ - Test a failed download (e.g., bad connection) and verify error reporting. 
- """ - job_id = "job-008" - df_idx = 55 - - # Need to ensure job status is "finished" for download attempt to occur - backend = DummyBackend.at_url(OPENEO_BACKEND, requests_mock=requests_mock) - - requests_mock.get( - f"{OPENEO_BACKEND}jobs/{job_id}/results", - status_code=500, - json={"code": "InternalError", "message": "Failed to list results"}) - - backend.batch_jobs[job_id] = {"job_id": job_id, "pg": {}, "status": "created"} - - # Need to ensure job status is "finished" for download attempt to occur - backend._set_job_status(job_id=job_id, status="finished") - backend.batch_jobs[job_id]["status"] = "finished" - - download_dir = temp_dir / job_id / "results" - download_dir.mkdir(parents=True) - - # Create the task instance - task = _JobDownloadTask( - root_url=OPENEO_BACKEND, - bearer_token="dummy-token-8", - job_id=job_id, - df_idx=df_idx, - download_dir=download_dir, - ) - - # Execute the task - result = task.execute() - - # Verify TaskResult structure - assert isinstance(result, _TaskResult) - assert result.job_id == job_id - assert result.df_idx == df_idx - - # Verify stats update for the MultiBackendJobManager - assert result.stats_update == {"job download error": 1} - - # Verify no file was created (or only empty/failed files) - assert not any(p.is_file() for p in download_dir.glob("*")) - From d67fdd67ccad386b20333f17a9df950da751242d Mon Sep 17 00:00:00 2001 From: Hans Vanrompay Date: Thu, 11 Dec 2025 09:58:32 +0100 Subject: [PATCH 05/16] changes to job manager --- openeo/extra/job_management/_manager.py | 1 + openeo/extra/job_management/_thread_worker.py | 2 +- 2 files changed, 2 insertions(+), 1 deletion(-) diff --git a/openeo/extra/job_management/_manager.py b/openeo/extra/job_management/_manager.py index c575c98e6..6bff9aa1a 100644 --- a/openeo/extra/job_management/_manager.py +++ b/openeo/extra/job_management/_manager.py @@ -811,6 +811,7 @@ def _cancel_prolonged_job(self, job: BatchJob, row): except Exception as e: _log.error(f"Unexpected error while handling job {job.job_id}: {e}") + #TODO pull this functionality away from the manager to a general utility class? job dir creation could be reused for tje Jobdownload task def get_job_dir(self, job_id: str) -> Path: """Path to directory where job metadata, results and error logs are be saved.""" return self._root_dir / f"job_{job_id}" diff --git a/openeo/extra/job_management/_thread_worker.py b/openeo/extra/job_management/_thread_worker.py index 523cce0d1..1a2c33ca9 100644 --- a/openeo/extra/job_management/_thread_worker.py +++ b/openeo/extra/job_management/_thread_worker.py @@ -173,7 +173,7 @@ def execute(self) -> _TaskResult: return _TaskResult( job_id=self.job_id, df_idx=self.df_idx, - db_update={}, #TODO consider db updates + db_update={}, #TODO consider db updates? 
stats_update={"job download": 1}, ) except Exception as e: From 2973beedef01432920ed16f77826d2906d8c8ca7 Mon Sep 17 00:00:00 2001 From: Hans Vanrompay Date: Thu, 11 Dec 2025 10:28:18 +0100 Subject: [PATCH 06/16] adding easy callback to check number of pending tasks on thread workers; this can be used to guarantee the download gets finished --- openeo/extra/job_management/_manager.py | 25 ++++++++++++++++++- openeo/extra/job_management/_thread_worker.py | 4 +++ 2 files changed, 28 insertions(+), 1 deletion(-) diff --git a/openeo/extra/job_management/_manager.py b/openeo/extra/job_management/_manager.py index 6bff9aa1a..43d43fd9c 100644 --- a/openeo/extra/job_management/_manager.py +++ b/openeo/extra/job_management/_manager.py @@ -400,6 +400,14 @@ def stop_job_thread(self, timeout_seconds: Optional[float] = _UNSET): self._download_pool.shutdown() self._download_pool = None + if self._download_pool is not None: + # Wait for downloads to complete before shutting down + _log.info("Waiting for download tasks to complete before stopping...") + while self._download_pool.num_pending_tasks() > 0: + time.sleep(0.5) + self._download_pool.shutdown() + self._download_pool = None + if self._thread is not None: self._stop_thread = True if timeout_seconds is _UNSET: @@ -513,6 +521,8 @@ def run_jobs( ).values() ) > 0 + + or self._worker_pool.num_pending_tasks() > 0 ): self._job_update_loop(job_db=job_db, start_job=start_job, stats=stats) stats["run_jobs loop"] += 1 @@ -523,7 +533,20 @@ def run_jobs( stats["sleep"] += 1 # TODO; run post process after shutdown once more to ensure completion? - self.stop_job_thread() + # Wait for all download tasks to complete + if self._download_results and self._download_pool is not None: + _log.info("Waiting for download tasks to complete...") + while self._download_pool.num_pending_tasks() > 0: + self._process_threadworker_updates( + worker_pool=self._download_pool, + job_db=job_db, + stats=stats + ) + time.sleep(1) # Brief pause to avoid busy waiting + _log.info("All download tasks completed.") + + self._worker_pool.shutdown() + self._download_pool.shutdown() return stats diff --git a/openeo/extra/job_management/_thread_worker.py b/openeo/extra/job_management/_thread_worker.py index 1a2c33ca9..ba00b8828 100644 --- a/openeo/extra/job_management/_thread_worker.py +++ b/openeo/extra/job_management/_thread_worker.py @@ -251,6 +251,10 @@ def process_futures(self, timeout: Union[float, None] = 0) -> Tuple[List[_TaskRe self._future_task_pairs = to_keep return results, len(to_keep) + + def num_pending_tasks(self) -> int: + """Return the number of tasks that are still pending (not completed).""" + return len(self._future_task_pairs) def shutdown(self) -> None: """Shuts down the thread pool gracefully.""" From 0e7c4f5de22719cf6db9dc516f773abe60ca6e62 Mon Sep 17 00:00:00 2001 From: Hans Vanrompay Date: Thu, 11 Dec 2025 12:30:37 +0100 Subject: [PATCH 07/16] process updates through job update loop --- openeo/extra/job_management/_manager.py | 28 +++++++++---------------- 1 file changed, 10 insertions(+), 18 deletions(-) diff --git a/openeo/extra/job_management/_manager.py b/openeo/extra/job_management/_manager.py index 43d43fd9c..ee692a06c 100644 --- a/openeo/extra/job_management/_manager.py +++ b/openeo/extra/job_management/_manager.py @@ -511,18 +511,21 @@ def run_jobs( stats = collections.defaultdict(int) self._worker_pool = _JobManagerWorkerThreadPool() - self._download_pool = _JobManagerWorkerThreadPool() + + if self._download_results: + self._download_pool = 
_JobManagerWorkerThreadPool() while ( sum( job_db.count_by_status( statuses=["not_started", "created", "queued_for_start", "queued", "running"] - ).values() - ) - > 0 + ).values()) > 0 + + or (self._worker_pool is not None and self._worker_pool.num_pending_tasks() > 0) - or self._worker_pool.num_pending_tasks() > 0 + or (self._download_pool is not None and self._download_pool.num_pending_tasks() > 0) + ): self._job_update_loop(job_db=job_db, start_job=start_job, stats=stats) stats["run_jobs loop"] += 1 @@ -532,18 +535,6 @@ def run_jobs( time.sleep(self.poll_sleep) stats["sleep"] += 1 - # TODO; run post process after shutdown once more to ensure completion? - # Wait for all download tasks to complete - if self._download_results and self._download_pool is not None: - _log.info("Waiting for download tasks to complete...") - while self._download_pool.num_pending_tasks() > 0: - self._process_threadworker_updates( - worker_pool=self._download_pool, - job_db=job_db, - stats=stats - ) - time.sleep(1) # Brief pause to avoid busy waiting - _log.info("All download tasks completed.") self._worker_pool.shutdown() self._download_pool.shutdown() @@ -775,13 +766,14 @@ def on_job_done(self, job: BatchJob, row): root_url=job_con.root_url, bearer_token=job_con.auth.bearer if isinstance(job_con.auth, BearerAuth) else None, job_id=job.job_id, - df_idx=row.name, # TODO figure out correct index usage + df_idx=row.name, #this is going to be the index in the not saterted dataframe; should not be an issue as there is no db update for download task download_dir=job_dir, ) _log.info(f"Submitting download task {task} to download thread pool") if self._download_pool is None: self._download_pool = _JobManagerWorkerThreadPool() + self._download_pool.submit_task(task) def on_job_error(self, job: BatchJob, row): From 8ccb442a65a572ddb7301fdc38dc83d2850373a3 Mon Sep 17 00:00:00 2001 From: Hans Vanrompay Date: Thu, 11 Dec 2025 12:38:57 +0100 Subject: [PATCH 08/16] remove folder creation logic from thread to resprect optional downloa'; similarly download shutdown depend on optional download --- openeo/extra/job_management/_manager.py | 7 +++++-- openeo/extra/job_management/_thread_worker.py | 3 --- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/openeo/extra/job_management/_manager.py b/openeo/extra/job_management/_manager.py index ee692a06c..d6676d198 100644 --- a/openeo/extra/job_management/_manager.py +++ b/openeo/extra/job_management/_manager.py @@ -536,8 +536,11 @@ def run_jobs( stats["sleep"] += 1 - self._worker_pool.shutdown() - self._download_pool.shutdown() + if self._worker_pool is not None: + self._worker_pool.shutdown() + + if self._download_pool is not None: + self._download_pool.shutdown() return stats diff --git a/openeo/extra/job_management/_thread_worker.py b/openeo/extra/job_management/_thread_worker.py index ba00b8828..2a55c2599 100644 --- a/openeo/extra/job_management/_thread_worker.py +++ b/openeo/extra/job_management/_thread_worker.py @@ -157,9 +157,6 @@ def execute(self) -> _TaskResult: try: job = self.get_connection(retry=True).job(self.job_id) - # Ensure download directory exists - self.download_dir.mkdir(parents=True, exist_ok=True) - # Download results job.get_results().download_files(target=self.download_dir) From 855a393aff15aa2872b3b78a7b0f36c2e4bcd6b6 Mon Sep 17 00:00:00 2001 From: Hans Vanrompay Date: Thu, 11 Dec 2025 12:39:58 +0100 Subject: [PATCH 09/16] fix stop_job_thread --- openeo/extra/job_management/_manager.py | 6 +----- 1 file changed, 1 insertion(+), 5 deletions(-) diff 
--git a/openeo/extra/job_management/_manager.py b/openeo/extra/job_management/_manager.py index d6676d198..e99ebd250 100644 --- a/openeo/extra/job_management/_manager.py +++ b/openeo/extra/job_management/_manager.py @@ -395,16 +395,12 @@ def stop_job_thread(self, timeout_seconds: Optional[float] = _UNSET): if self._worker_pool is not None: self._worker_pool.shutdown() self._worker_pool = None - - if self._download_pool is not None: - self._download_pool.shutdown() - self._download_pool = None if self._download_pool is not None: # Wait for downloads to complete before shutting down _log.info("Waiting for download tasks to complete before stopping...") while self._download_pool.num_pending_tasks() > 0: - time.sleep(0.5) + time.sleep(1) self._download_pool.shutdown() self._download_pool = None From e2b6ab8d7573b2d0ed08e498ab7a366a2e41317c Mon Sep 17 00:00:00 2001 From: Hans Vanrompay Date: Thu, 11 Dec 2025 12:52:32 +0100 Subject: [PATCH 10/16] working on fix for indefinete loop --- openeo/extra/job_management/_manager.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/openeo/extra/job_management/_manager.py b/openeo/extra/job_management/_manager.py index e99ebd250..9787fd678 100644 --- a/openeo/extra/job_management/_manager.py +++ b/openeo/extra/job_management/_manager.py @@ -534,9 +534,11 @@ def run_jobs( if self._worker_pool is not None: self._worker_pool.shutdown() + self._worker_pool = None if self._download_pool is not None: self._download_pool.shutdown() + self._download_pool = None return stats @@ -751,7 +753,6 @@ def on_job_done(self, job: BatchJob, row): :param job: The job that has finished. :param row: DataFrame row containing the job's metadata. """ - # TODO: param `row` is never accessed in this method. Remove it? Is this intended for future use? 
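# Annotation (not part of the patch): the TODO above ("param `row` is never accessed") no longer
# applies — since patch 07, on_job_done() reads `row.name` to fill the download task's `df_idx`.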
if self._download_results: job_dir = self.get_job_dir(job.job_id) From dc75ca8ec27d902ffbc18507a079a60229ea3c16 Mon Sep 17 00:00:00 2001 From: Hans Vanrompay Date: Thu, 11 Dec 2025 13:31:04 +0100 Subject: [PATCH 11/16] fix infinite loop --- openeo/extra/job_management/_manager.py | 4 ---- 1 file changed, 4 deletions(-) diff --git a/openeo/extra/job_management/_manager.py b/openeo/extra/job_management/_manager.py index 9787fd678..c83a2a8d6 100644 --- a/openeo/extra/job_management/_manager.py +++ b/openeo/extra/job_management/_manager.py @@ -397,10 +397,6 @@ def stop_job_thread(self, timeout_seconds: Optional[float] = _UNSET): self._worker_pool = None if self._download_pool is not None: - # Wait for downloads to complete before shutting down - _log.info("Waiting for download tasks to complete before stopping...") - while self._download_pool.num_pending_tasks() > 0: - time.sleep(1) self._download_pool.shutdown() self._download_pool = None From 4fc299de157ecf4b7485391448bcf890e673dd45 Mon Sep 17 00:00:00 2001 From: Hans Vanrompay Date: Mon, 15 Dec 2025 08:42:21 +0100 Subject: [PATCH 12/16] wrapper to abstract multiple threadpools --- openeo/extra/job_management/_manager.py | 39 +++++---------- openeo/extra/job_management/_thread_worker.py | 48 +++++++++++++++++-- 2 files changed, 57 insertions(+), 30 deletions(-) diff --git a/openeo/extra/job_management/_manager.py b/openeo/extra/job_management/_manager.py index c83a2a8d6..2c5d98674 100644 --- a/openeo/extra/job_management/_manager.py +++ b/openeo/extra/job_management/_manager.py @@ -221,7 +221,6 @@ def __init__( ) self._thread = None self._worker_pool = None - self._download_pool = None # Generic cache self._cache = {} @@ -354,7 +353,6 @@ def start_job_thread(self, start_job: Callable[[], BatchJob], job_db: JobDatabas self._stop_thread = False self._worker_pool = _JobManagerWorkerThreadPool() - self._download_pool = _JobManagerWorkerThreadPool() def run_loop(): # TODO: support user-provided `stats` @@ -367,6 +365,9 @@ def run_loop(): ).values() ) > 0 + + or (self._worker_pool.num_pending_tasks() > 0) + and not self._stop_thread ): self._job_update_loop(job_db=job_db, start_job=start_job, stats=stats) @@ -392,13 +393,10 @@ def stop_job_thread(self, timeout_seconds: Optional[float] = _UNSET): .. 
versionadded:: 0.32.0 """ - if self._worker_pool is not None: + if self._worker_pool is not None: #TODO or thread_pool.num_pending_tasks() > 0 self._worker_pool.shutdown() self._worker_pool = None - if self._download_pool is not None: - self._download_pool.shutdown() - self._download_pool = None if self._thread is not None: self._stop_thread = True @@ -504,9 +502,6 @@ def run_jobs( self._worker_pool = _JobManagerWorkerThreadPool() - if self._download_results: - self._download_pool = _JobManagerWorkerThreadPool() - while ( sum( @@ -514,9 +509,7 @@ def run_jobs( statuses=["not_started", "created", "queued_for_start", "queued", "running"] ).values()) > 0 - or (self._worker_pool is not None and self._worker_pool.num_pending_tasks() > 0) - - or (self._download_pool is not None and self._download_pool.num_pending_tasks() > 0) + or (self._worker_pool.num_pending_tasks() > 0) ): self._job_update_loop(job_db=job_db, start_job=start_job, stats=stats) @@ -528,13 +521,9 @@ def run_jobs( stats["sleep"] += 1 - if self._worker_pool is not None: - self._worker_pool.shutdown() - self._worker_pool = None - - if self._download_pool is not None: - self._download_pool.shutdown() - self._download_pool = None + + self._worker_pool.shutdown() + self._worker_pool = None return stats @@ -579,8 +568,6 @@ def _job_update_loop( if self._worker_pool is not None: self._process_threadworker_updates(worker_pool=self._worker_pool, job_db=job_db, stats=stats) - if self._download_pool is not None: - self._process_threadworker_updates(worker_pool=self._download_pool, job_db=job_db, stats=stats) # TODO: move this back closer to the `_track_statuses` call above, once job done/error handling is also handled in threads? for job, row in jobs_done: @@ -657,7 +644,7 @@ def _launch_job(self, start_job, df, i, backend_name, stats: Optional[dict] = No df_idx=i, ) _log.info(f"Submitting task {task} to thread pool") - self._worker_pool.submit_task(task) + self._worker_pool.submit_start_task(task) stats["job_queued_for_start"] += 1 df.loc[i, "status"] = "queued_for_start" @@ -703,7 +690,7 @@ def _process_threadworker_updates( :param stats: Dictionary accumulating statistic counters """ # Retrieve completed task results immediately - results, _ = worker_pool.process_futures(timeout=0) + results, start_remaining, download_remaining = worker_pool.process_all_updates(timeout=0) # Collect update dicts updates: List[Dict[str, Any]] = [] @@ -767,10 +754,10 @@ def on_job_done(self, job: BatchJob, row): ) _log.info(f"Submitting download task {task} to download thread pool") - if self._download_pool is None: - self._download_pool = _JobManagerWorkerThreadPool() + if self._worker_pool is None: + self._worker_pool = _JobManagerWorkerThreadPool() - self._download_pool.submit_task(task) + self._worker_pool.submit_download_task(task) def on_job_error(self, job: BatchJob, row): """ diff --git a/openeo/extra/job_management/_thread_worker.py b/openeo/extra/job_management/_thread_worker.py index 2a55c2599..1fe3e713c 100644 --- a/openeo/extra/job_management/_thread_worker.py +++ b/openeo/extra/job_management/_thread_worker.py @@ -3,6 +3,8 @@ """ import concurrent.futures +import threading +import queue import logging from abc import ABC, abstractmethod from dataclasses import dataclass, field @@ -148,12 +150,14 @@ class _JobDownloadTask(ConnectedTask): :param download_dir: Root directory where job results and metadata will be downloaded. + :param download_throttle: + A threading.Semaphore to limit concurrent downloads. 
""" - def __init__(self, download_dir: Path, **kwargs): - super().__init__(**kwargs) - object.__setattr__(self, 'download_dir', download_dir) + download_dir: Path = field(repr=False) + def execute(self) -> _TaskResult: + try: job = self.get_connection(retry=True).job(self.job_id) @@ -182,7 +186,7 @@ def execute(self) -> _TaskResult: stats_update={"job download error": 1}, ) -class _JobManagerWorkerThreadPool: +class _TaskThreadPool: """ Thread pool-based worker that manages the execution of asynchronous tasks. @@ -257,3 +261,39 @@ def shutdown(self) -> None: """Shuts down the thread pool gracefully.""" _log.info("Shutting down thread pool") self._executor.shutdown(wait=True) + + +class _JobManagerWorkerThreadPool: + """WRAPPER that hides two pools behind one interface""" + + def __init__(self, max_start_workers=2, max_download_workers=10): + # These are the TWO pools with their OWN _future_task_pairs + self._start_pool = _TaskThreadPool(max_workers=max_start_workers) + self._download_pool = _TaskThreadPool(max_workers=max_download_workers) + + def submit_start_task(self, task): + # Delegate to start pool + self._start_pool.submit_task(task) + + def submit_download_task(self, task): + # Delegate to download pool + self._download_pool.submit_task(task) + + def process_all_updates(self, timeout=0): + # Get results from BOTH pools + start_results, start_remaining = self._start_pool.process_futures(timeout) + download_results, download_remaining = self._download_pool.process_futures(timeout) + + # Combine and return + all_results = start_results + download_results + return all_results, start_remaining, download_remaining + + def num_pending_tasks(self): + # Sum of BOTH pools + return (self._start_pool.num_pending_tasks() + + self._download_pool.num_pending_tasks()) + + def shutdown(self): + # Shutdown BOTH pools + self._start_pool.shutdown() + self._download_pool.shutdown() From 382eae4ce816ac6e60717e2fcfb4006c6d4bf126 Mon Sep 17 00:00:00 2001 From: Hans Vanrompay Date: Mon, 15 Dec 2025 16:18:46 +0100 Subject: [PATCH 13/16] coupling task type to seperate pool --- openeo/extra/job_management/_manager.py | 12 +- openeo/extra/job_management/_thread_worker.py | 110 ++++++++++++------ tests/extra/job_management/test_manager.py | 8 +- .../job_management/test_thread_worker.py | 2 +- 4 files changed, 85 insertions(+), 47 deletions(-) diff --git a/openeo/extra/job_management/_manager.py b/openeo/extra/job_management/_manager.py index 2c5d98674..ad4a6a2ff 100644 --- a/openeo/extra/job_management/_manager.py +++ b/openeo/extra/job_management/_manager.py @@ -367,7 +367,7 @@ def run_loop(): > 0 or (self._worker_pool.num_pending_tasks() > 0) - + and not self._stop_thread ): self._job_update_loop(job_db=job_db, start_job=start_job, stats=stats) @@ -644,7 +644,7 @@ def _launch_job(self, start_job, df, i, backend_name, stats: Optional[dict] = No df_idx=i, ) _log.info(f"Submitting task {task} to thread pool") - self._worker_pool.submit_start_task(task) + self._worker_pool.submit_task(task) stats["job_queued_for_start"] += 1 df.loc[i, "status"] = "queued_for_start" @@ -690,7 +690,7 @@ def _process_threadworker_updates( :param stats: Dictionary accumulating statistic counters """ # Retrieve completed task results immediately - results, start_remaining, download_remaining = worker_pool.process_all_updates(timeout=0) + results, _ = worker_pool.process_all_updates(timeout=0) # Collect update dicts updates: List[Dict[str, Any]] = [] @@ -746,10 +746,10 @@ def on_job_done(self, job: BatchJob, row): 
self._refresh_bearer_token(connection=job_con) task = _JobDownloadTask( - root_url=job_con.root_url, - bearer_token=job_con.auth.bearer if isinstance(job_con.auth, BearerAuth) else None, job_id=job.job_id, df_idx=row.name, #this is going to be the index in the not saterted dataframe; should not be an issue as there is no db update for download task + root_url=job_con.root_url, + bearer_token=job_con.auth.bearer if isinstance(job_con.auth, BearerAuth) else None, download_dir=job_dir, ) _log.info(f"Submitting download task {task} to download thread pool") @@ -757,7 +757,7 @@ def on_job_done(self, job: BatchJob, row): if self._worker_pool is None: self._worker_pool = _JobManagerWorkerThreadPool() - self._worker_pool.submit_download_task(task) + self._worker_pool.submit_task(task) def on_job_error(self, job: BatchJob, row): """ diff --git a/openeo/extra/job_management/_thread_worker.py b/openeo/extra/job_management/_thread_worker.py index 1fe3e713c..bfb9814ef 100644 --- a/openeo/extra/job_management/_thread_worker.py +++ b/openeo/extra/job_management/_thread_worker.py @@ -3,8 +3,6 @@ """ import concurrent.futures -import threading -import queue import logging from abc import ABC, abstractmethod from dataclasses import dataclass, field @@ -103,7 +101,7 @@ def get_connection(self, retry: Union[urllib3.util.Retry, dict, bool, None] = No connection.authenticate_bearer_token(self.bearer_token) return connection - +@dataclass(frozen=True) class _JobStartTask(ConnectedTask): """ Task for starting an openEO batch job (the `POST /jobs//result` request). @@ -143,18 +141,16 @@ def execute(self) -> _TaskResult: db_update={"status": "start_failed"}, stats_update={"start_job error": 1}, ) - + +@dataclass(frozen=True) class _JobDownloadTask(ConnectedTask): """ Task for downloading job results and metadata. :param download_dir: Root directory where job results and metadata will be downloaded. - :param download_throttle: - A threading.Semaphore to limit concurrent downloads. """ - download_dir: Path = field(repr=False) - + download_dir: Path = field(default=None, repr=False) def execute(self) -> _TaskResult: @@ -198,9 +194,10 @@ class _TaskThreadPool: Defaults to 2. """ - def __init__(self, max_workers: int = 2): + def __init__(self, max_workers: int = 2, name: str = None): self._executor = concurrent.futures.ThreadPoolExecutor(max_workers=max_workers) self._future_task_pairs: List[Tuple[concurrent.futures.Future, Task]] = [] + self._name = name def submit_task(self, task: Task) -> None: """ @@ -264,36 +261,77 @@ def shutdown(self) -> None: class _JobManagerWorkerThreadPool: - """WRAPPER that hides two pools behind one interface""" - - def __init__(self, max_start_workers=2, max_download_workers=10): - # These are the TWO pools with their OWN _future_task_pairs - self._start_pool = _TaskThreadPool(max_workers=max_start_workers) - self._download_pool = _TaskThreadPool(max_workers=max_download_workers) + + """ + Generic wrapper that manages multiple thread pools with a dict. + Uses task class names as pool names automatically. 
+ """ - def submit_start_task(self, task): - # Delegate to start pool - self._start_pool.submit_task(task) + def __init__(self, pool_configs: Optional[Dict[str, int]] = None): + """ + :param pool_configs: Dict of task_class_name -> max_workers + Example: {"_JobStartTask": 1, "_JobDownloadTask": 2} + """ + self._pools: Dict[str, _TaskThreadPool] = {} + self._pool_configs = pool_configs or {} - def submit_download_task(self, task): - # Delegate to download pool - self._download_pool.submit_task(task) + def _get_pool_name_for_task(self, task: Task) -> str: + """ + Get pool name from task class name. + """ + return task.__class__.__name__ - def process_all_updates(self, timeout=0): - # Get results from BOTH pools - start_results, start_remaining = self._start_pool.process_futures(timeout) - download_results, download_remaining = self._download_pool.process_futures(timeout) + def submit_task(self, task: Task) -> None: + """ + Submit a task to a pool named after its class. + Creates pool dynamically if it doesn't exist. + """ + pool_name = self._get_pool_name_for_task(task) - # Combine and return - all_results = start_results + download_results - return all_results, start_remaining, download_remaining + if pool_name not in self._pools: + # Create pool on-demand + max_workers = self._pool_configs.get(pool_name, 1) # Default 1 worker + self._pools[pool_name] = _TaskThreadPool(max_workers=max_workers, name=pool_name) + _log.info(f"Created pool '{pool_name}' with {max_workers} workers") + + self._pools[pool_name].submit_task(task) + + def process_all_updates(self, timeout: Union[float, None] = 0) -> Tuple[List[_TaskResult], Dict[str, int]]: + """ + Process updates from ALL pools. + Returns: (all_results, dict of remaining tasks per pool) + """ + all_results = [] + remaining_by_pool = {} + + for pool_name, pool in self._pools.items(): + results, remaining = pool.process_futures(timeout) + all_results.extend(results) + remaining_by_pool[pool_name] = remaining + + return all_results, remaining_by_pool - def num_pending_tasks(self): - # Sum of BOTH pools - return (self._start_pool.num_pending_tasks() + - self._download_pool.num_pending_tasks()) + def num_pending_tasks(self, pool_name: Optional[str] = None) -> int: + if pool_name: + pool = self._pools.get(pool_name) + return pool.num_pending_tasks() if pool else 0 + else: + return sum(pool.num_pending_tasks() for pool in self._pools.values()) + + def shutdown(self, pool_name: Optional[str] = None) -> None: + """ + Shutdown pools. + If pool_name is None, shuts down all pools. 
+ """ + if pool_name: + if pool_name in self._pools: + self._pools[pool_name].shutdown() + del self._pools[pool_name] + else: + for pool_name, pool in list(self._pools.items()): + pool.shutdown() + del self._pools[pool_name] - def shutdown(self): - # Shutdown BOTH pools - self._start_pool.shutdown() - self._download_pool.shutdown() + def list_pools(self) -> List[str]: + """List all active pool names.""" + return list(self._pools.keys()) diff --git a/tests/extra/job_management/test_manager.py b/tests/extra/job_management/test_manager.py index 1d02afb1c..2c4162974 100644 --- a/tests/extra/job_management/test_manager.py +++ b/tests/extra/job_management/test_manager.py @@ -729,7 +729,7 @@ def get_status(job_id, current_status): assert isinstance(rfc3339.parse_datetime(filled_running_start_time), datetime.datetime) def test_process_threadworker_updates(self, tmp_path, caplog): - pool = _JobManagerWorkerThreadPool(max_workers=2) + pool = _JobManagerWorkerThreadPool() stats = collections.defaultdict(int) # Submit tasks covering all cases @@ -769,7 +769,7 @@ def test_process_threadworker_updates(self, tmp_path, caplog): assert caplog.messages == [] def test_process_threadworker_updates_unknown(self, tmp_path, caplog): - pool = _JobManagerWorkerThreadPool(max_workers=2) + pool = _JobManagerWorkerThreadPool() stats = collections.defaultdict(int) pool.submit_task(DummyResultTask("j-123", df_idx=0, db_update={"status": "queued"}, stats_update={"queued": 1})) @@ -806,7 +806,7 @@ def test_process_threadworker_updates_unknown(self, tmp_path, caplog): assert caplog.messages == [dirty_equals.IsStr(regex=".*Ignoring unknown.*indices.*4.*")] def test_no_results_leaves_db_and_stats_untouched(self, tmp_path, caplog): - pool = _JobManagerWorkerThreadPool(max_workers=2) + pool = _JobManagerWorkerThreadPool() stats = collections.defaultdict(int) df_initial = pd.DataFrame({"id": ["j-0"], "status": ["created"]}) @@ -820,7 +820,7 @@ def test_no_results_leaves_db_and_stats_untouched(self, tmp_path, caplog): assert stats == {} def test_logs_on_invalid_update(self, tmp_path, caplog): - pool = _JobManagerWorkerThreadPool(max_workers=2) + pool = _JobManagerWorkerThreadPool() stats = collections.defaultdict(int) # Malformed db_update (not a dict unpackable via **) diff --git a/tests/extra/job_management/test_thread_worker.py b/tests/extra/job_management/test_thread_worker.py index 47bde093f..04d810764 100644 --- a/tests/extra/job_management/test_thread_worker.py +++ b/tests/extra/job_management/test_thread_worker.py @@ -221,7 +221,7 @@ class TestJobManagerWorkerThreadPool: @pytest.fixture def worker_pool(self) -> Iterator[_JobManagerWorkerThreadPool]: """Fixture for creating and cleaning up a worker thread pool.""" - pool = _JobManagerWorkerThreadPool(max_workers=2) + pool = _JobManagerWorkerThreadPool() yield pool pool.shutdown() From ab9914a7f21aa3b7d358c10aec909315723d5f98 Mon Sep 17 00:00:00 2001 From: Hans Vanrompay Date: Mon, 15 Dec 2025 19:21:56 +0100 Subject: [PATCH 14/16] include unit test for dict of pools --- openeo/extra/job_management/_thread_worker.py | 4 +- .../job_management/test_thread_worker.py | 378 +++++++++++++++++- 2 files changed, 378 insertions(+), 4 deletions(-) diff --git a/openeo/extra/job_management/_thread_worker.py b/openeo/extra/job_management/_thread_worker.py index bfb9814ef..de5c606ec 100644 --- a/openeo/extra/job_management/_thread_worker.py +++ b/openeo/extra/job_management/_thread_worker.py @@ -194,7 +194,7 @@ class _TaskThreadPool: Defaults to 2. 
""" - def __init__(self, max_workers: int = 2, name: str = None): + def __init__(self, max_workers: int = 1, name: str = 'default'): self._executor = concurrent.futures.ThreadPoolExecutor(max_workers=max_workers) self._future_task_pairs: List[Tuple[concurrent.futures.Future, Task]] = [] self._name = name @@ -335,3 +335,5 @@ def shutdown(self, pool_name: Optional[str] = None) -> None: def list_pools(self) -> List[str]: """List all active pool names.""" return list(self._pools.keys()) + + diff --git a/tests/extra/job_management/test_thread_worker.py b/tests/extra/job_management/test_thread_worker.py index 04d810764..a52a763b7 100644 --- a/tests/extra/job_management/test_thread_worker.py +++ b/tests/extra/job_management/test_thread_worker.py @@ -11,6 +11,7 @@ from openeo.extra.job_management._thread_worker import ( Task, + _TaskThreadPool, _JobManagerWorkerThreadPool, _JobStartTask, _TaskResult, @@ -217,11 +218,11 @@ def execute(self) -> _TaskResult: return _TaskResult(job_id=self.job_id, df_idx=self.df_idx, db_update={"status": "all fine"}) -class TestJobManagerWorkerThreadPool: +class TestTaskThreadPool: @pytest.fixture - def worker_pool(self) -> Iterator[_JobManagerWorkerThreadPool]: + def worker_pool(self) -> Iterator[_TaskThreadPool]: """Fixture for creating and cleaning up a worker thread pool.""" - pool = _JobManagerWorkerThreadPool() + pool = _TaskThreadPool() yield pool pool.shutdown() @@ -391,3 +392,374 @@ def test_job_start_task_failure(self, worker_pool, dummy_backend, caplog): ] +import pytest +import time +import threading +import logging +from typing import Iterator + +_log = logging.getLogger(__name__) + + +class TestJobManagerWorkerThreadPool: + @pytest.fixture + def thread_pool(self) -> Iterator[_JobManagerWorkerThreadPool]: + """Fixture for creating and cleaning up a thread pool manager.""" + pool = _JobManagerWorkerThreadPool() + yield pool + pool.shutdown() + + @pytest.fixture + def configured_pool(self) -> Iterator[_JobManagerWorkerThreadPool]: + """Fixture with pre-configured pools.""" + pool = _JobManagerWorkerThreadPool( + pool_configs={ + "NopTask": 2, + "DummyTask": 3, + "BlockingTask": 1, + } + ) + yield pool + pool.shutdown() + + def test_init_empty_config(self): + """Test initialization with empty config.""" + pool = _JobManagerWorkerThreadPool() + assert pool._pools == {} + assert pool._pool_configs == {} + pool.shutdown() + + def test_init_with_config(self): + """Test initialization with pool configurations.""" + pool = _JobManagerWorkerThreadPool({ + "NopTask": 2, + "DummyTask": 3, + }) + # Pools should NOT be created until first use + assert pool._pools == {} + assert pool._pool_configs == { + "NopTask": 2, + "DummyTask": 3, + } + pool.shutdown() + + def test_submit_task_creates_pool(self, thread_pool): + """Test that submitting a task creates a pool dynamically.""" + task = NopTask(job_id="j-1", df_idx=1) + + # No pools initially + assert thread_pool.list_pools() == [] + + # Submit task - should create pool + thread_pool.submit_task(task) + + # Pool should be created with default workers (1) + assert thread_pool.list_pools() == ["NopTask"] + assert "NopTask" in thread_pool._pools + + # Process to complete the task + results, remaining = thread_pool.process_all_updates(timeout=0.1) + assert len(results) == 1 + assert results[0].job_id == "j-1" + assert remaining == {"NopTask": 0} + + def test_submit_task_uses_config(self, configured_pool): + """Test that pool creation uses configuration.""" + task = NopTask(job_id="j-1", df_idx=1) + + # Submit task - should 
create pool with configured workers + configured_pool.submit_task(task) + + assert "NopTask" in configured_pool._pools + # Can't directly check max_workers, but pool should exist + assert "NopTask" in configured_pool.list_pools() + + def test_submit_multiple_task_types(self, thread_pool): + """Test submitting different task types to different pools.""" + # Submit different task types + task1 = NopTask(job_id="j-1", df_idx=1) + task2 = DummyTask(job_id="j-2", df_idx=2) + task3 = NopTask(job_id="j-3", df_idx=3) + + thread_pool.submit_task(task1) # Goes to "NopTask" pool + thread_pool.submit_task(task2) # Goes to "DummyTask" pool + thread_pool.submit_task(task3) # Goes to "NopTask" pool (existing) + + # Should have 2 pools + pools = sorted(thread_pool.list_pools()) + assert pools == ["DummyTask", "NopTask"] + + # Check pending tasks + assert thread_pool.num_pending_tasks() == 3 + assert thread_pool.num_pending_tasks("NopTask") == 2 + assert thread_pool.num_pending_tasks("DummyTask") == 1 + assert thread_pool.num_pending_tasks("NonExistent") == 0 + + def test_process_all_updates_empty(self, thread_pool): + """Test processing updates with no pools.""" + results, remaining = thread_pool.process_all_updates(timeout=0) + assert results == [] + assert remaining == {} + + def test_process_all_updates_multiple_pools(self, thread_pool): + """Test processing updates across multiple pools.""" + # Submit tasks to different pools + thread_pool.submit_task(NopTask(job_id="j-1", df_idx=1)) # NopTask pool + thread_pool.submit_task(NopTask(job_id="j-2", df_idx=2)) # NopTask pool + thread_pool.submit_task(DummyTask(job_id="j-3", df_idx=3)) # DummyTask pool + + # Process updates + results, remaining = thread_pool.process_all_updates(timeout=0.1) + + # Should get 3 results + assert len(results) == 3 + # Check results by pool + nop_results = [r for r in results if r.job_id in ["j-1", "j-2"]] + dummy_results = [r for r in results if r.job_id == "j-3"] + assert len(nop_results) == 2 + assert len(dummy_results) == 1 + + # All tasks should be completed + assert remaining == {"NopTask": 0, "DummyTask": 0} + + def test_process_all_updates_partial_completion(self): + """Test processing when some tasks are still running.""" + # Use a pool with blocking tasks + pool = _JobManagerWorkerThreadPool() + + # Create a blocking task + event = threading.Event() + blocking_task = BlockingTask(job_id="j-block", df_idx=0, event=event, success=True) + + # Create a quick task + quick_task = NopTask(job_id="j-quick", df_idx=1) + + pool.submit_task(blocking_task) # BlockingTask pool + pool.submit_task(quick_task) # NopTask pool + + # Process with timeout=0 - only quick task should complete + results, remaining = pool.process_all_updates(timeout=0) + + # Only quick task completed + assert len(results) == 1 + assert results[0].job_id == "j-quick" + + # Blocking task still pending + assert remaining == {"BlockingTask": 1, "NopTask": 0} + assert pool.num_pending_tasks() == 1 + assert pool.num_pending_tasks("BlockingTask") == 1 + + # Release blocking task and process again + event.set() + results2, remaining2 = pool.process_all_updates(timeout=0.1) + + assert len(results2) == 1 + assert results2[0].job_id == "j-block" + assert remaining2 == {"BlockingTask": 0, "NopTask": 0} + + pool.shutdown() + + def test_num_pending_tasks(self, thread_pool): + """Test counting pending tasks.""" + # Initially empty + assert thread_pool.num_pending_tasks() == 0 + assert thread_pool.num_pending_tasks("NopTask") == 0 + + # Add some tasks + 
thread_pool.submit_task(NopTask(job_id="j-1", df_idx=1)) + thread_pool.submit_task(NopTask(job_id="j-2", df_idx=2)) + thread_pool.submit_task(DummyTask(job_id="j-3", df_idx=3)) + + # Check totals + assert thread_pool.num_pending_tasks() == 3 + assert thread_pool.num_pending_tasks("NopTask") == 2 + assert thread_pool.num_pending_tasks("DummyTask") == 1 + assert thread_pool.num_pending_tasks("NonExistentPool") == 0 + + # Process all + thread_pool.process_all_updates(timeout=0.1) + + # Should be empty + assert thread_pool.num_pending_tasks() == 0 + assert thread_pool.num_pending_tasks("NopTask") == 0 + + def test_shutdown_specific_pool(self): + """Test shutting down a specific pool.""" + # Create fresh pool for destructive test + pool = _JobManagerWorkerThreadPool() + + # Create two pools + pool.submit_task(NopTask(job_id="j-1", df_idx=1)) # NopTask pool + pool.submit_task(DummyTask(job_id="j-2", df_idx=2)) # DummyTask pool + + assert sorted(pool.list_pools()) == ["DummyTask", "NopTask"] + + # Shutdown NopTask pool only + pool.shutdown("NopTask") + + # Only DummyTask pool should remain + assert pool.list_pools() == ["DummyTask"] + + # Can't submit to shutdown pool + # Actually, it will create a new pool since we deleted it + pool.submit_task(NopTask(job_id="j-3", df_idx=3)) # Creates new NopTask pool + assert sorted(pool.list_pools()) == ["DummyTask", "NopTask"] + + pool.shutdown() + + def test_shutdown_all(self): + """Test shutting down all pools.""" + # Create fresh pool for destructive test + pool = _JobManagerWorkerThreadPool() + + # Create multiple pools + pool.submit_task(NopTask(job_id="j-1", df_idx=1)) + pool.submit_task(DummyTask(job_id="j-2", df_idx=2)) + + assert len(pool.list_pools()) == 2 + + # Shutdown all + pool.shutdown() + + # All pools should be gone + assert pool.list_pools() == [] + + # Can't submit any more tasks after shutdown + # Actually, shutdown() doesn't prevent creating new pools + # So we can test that shutdown clears existing pools + assert len(pool._pools) == 0 + + def test_custom_get_pool_name(self): + """Test custom task class to verify pool name selection.""" + + @dataclass(frozen=True) + class CustomTask(Task): + # Fields are inherited from Task: job_id, df_idx + + def execute(self) -> _TaskResult: + return _TaskResult(job_id=self.job_id, df_idx=self.df_idx) + + pool = _JobManagerWorkerThreadPool() + + # Submit custom task - must provide all required fields + task = CustomTask(job_id="j-1", df_idx=1) + pool.submit_task(task) + + # Pool should be named after class + assert pool.list_pools() == ["CustomTask"] + assert pool.num_pending_tasks() == 1 + + # Process it + results, remaining = pool.process_all_updates(timeout=0.1) + assert len(results) == 1 + assert results[0].job_id == "j-1" + + pool.shutdown() + + def test_concurrent_submissions(self, thread_pool): + """Test concurrent task submissions to same pool.""" + import concurrent.futures + + def submit_tasks(start_idx: int): + for i in range(5): + thread_pool.submit_task(NopTask(job_id=f"j-{start_idx + i}", df_idx=start_idx + i)) + + # Submit tasks from multiple threads + with concurrent.futures.ThreadPoolExecutor(max_workers=3) as executor: + futures = [executor.submit(submit_tasks, i * 10) for i in range(3)] + concurrent.futures.wait(futures) + + # Should have all tasks in one pool + assert thread_pool.list_pools() == ["NopTask"] + assert thread_pool.num_pending_tasks() == 15 + + # Process them all + results, remaining = thread_pool.process_all_updates(timeout=0.5) + + assert len(results) == 15 + assert 
remaining == {"NopTask": 0} + + def test_pool_parallelism_with_blocking_tasks(self): + """Test that multiple workers allow parallel execution.""" + pool = _JobManagerWorkerThreadPool({ + "BlockingTask": 3, # 3 workers for blocking tasks + }) + + # Create multiple blocking tasks + events = [threading.Event() for _ in range(5)] + start_time = time.time() + + for i, event in enumerate(events): + pool.submit_task(BlockingTask( + job_id=f"j-block-{i}", + df_idx=i, + event=event, + success=True + )) + + # Initially all pending + assert pool.num_pending_tasks() == 5 + + # Release all events at once + for event in events: + event.set() + + # Process with timeout - all should complete + results, remaining = pool.process_all_updates(timeout=0.5) + + # All should complete (if pool had enough workers) + assert len(results) == 5 + assert remaining == {"BlockingTask": 0} + + # Check they all completed + for result in results: + assert result.job_id.startswith("j-block-") + + pool.shutdown() + + def test_task_with_error_handling(self, thread_pool): + """Test that task errors are properly handled in the pool.""" + # Submit a failing DummyTask (j-666 fails) + thread_pool.submit_task(DummyTask(job_id="j-666", df_idx=0)) + + # Process it + results, remaining = thread_pool.process_all_updates(timeout=0.1) + + # Should get error result + assert len(results) == 1 + result = results[0] + assert result.job_id == "j-666" + assert result.db_update == {"status": "threaded task failed"} + assert result.stats_update == {"threaded task failed": 1} + assert remaining == {"DummyTask": 0} + + def test_mixed_success_and_error_tasks(self, thread_pool): + """Test mix of successful and failing tasks.""" + # Submit mix of tasks + thread_pool.submit_task(DummyTask(job_id="j-1", df_idx=1)) # Success + thread_pool.submit_task(DummyTask(job_id="j-666", df_idx=2)) # Failure + thread_pool.submit_task(DummyTask(job_id="j-3", df_idx=3)) # Success + + # Process all + results, remaining = thread_pool.process_all_updates(timeout=0.1) + + # Should get 3 results + assert len(results) == 3 + assert remaining == {"DummyTask": 0} + + # Check results + success_results = [r for r in results if r.job_id != "j-666"] + error_results = [r for r in results if r.job_id == "j-666"] + + assert len(success_results) == 2 + assert len(error_results) == 1 + + # Verify success results + for result in success_results: + assert result.db_update == {"status": "dummified"} + assert result.stats_update == {"dummy": 1} + + # Verify error result + error_result = error_results[0] + assert error_result.db_update == {"status": "threaded task failed"} + assert error_result.stats_update == {"threaded task failed": 1} \ No newline at end of file From 1fce77b89f7feec5cd77e14f509cfa06802e5c5e Mon Sep 17 00:00:00 2001 From: Hans Vanrompay Date: Tue, 16 Dec 2025 12:53:19 +0100 Subject: [PATCH 15/16] tmp_path usage and renaming --- openeo/extra/job_management/_manager.py | 2 +- openeo/extra/job_management/_thread_worker.py | 2 +- .../job_management/test_thread_worker.py | 76 ++++++------------- 3 files changed, 25 insertions(+), 55 deletions(-) diff --git a/openeo/extra/job_management/_manager.py b/openeo/extra/job_management/_manager.py index ad4a6a2ff..ae0b38305 100644 --- a/openeo/extra/job_management/_manager.py +++ b/openeo/extra/job_management/_manager.py @@ -690,7 +690,7 @@ def _process_threadworker_updates( :param stats: Dictionary accumulating statistic counters """ # Retrieve completed task results immediately - results, _ = 
worker_pool.process_all_updates(timeout=0) + results, _ = worker_pool.process_futures(timeout=0) # Collect update dicts updates: List[Dict[str, Any]] = [] diff --git a/openeo/extra/job_management/_thread_worker.py b/openeo/extra/job_management/_thread_worker.py index de5c606ec..439a6879b 100644 --- a/openeo/extra/job_management/_thread_worker.py +++ b/openeo/extra/job_management/_thread_worker.py @@ -296,7 +296,7 @@ def submit_task(self, task: Task) -> None: self._pools[pool_name].submit_task(task) - def process_all_updates(self, timeout: Union[float, None] = 0) -> Tuple[List[_TaskResult], Dict[str, int]]: + def process_futures(self, timeout: Union[float, None] = 0) -> Tuple[List[_TaskResult], Dict[str, int]]: """ Process updates from ALL pools. Returns: (all_results, dict of remaining tasks per pool) diff --git a/tests/extra/job_management/test_thread_worker.py b/tests/extra/job_management/test_thread_worker.py index a52a763b7..13115aeb6 100644 --- a/tests/extra/job_management/test_thread_worker.py +++ b/tests/extra/job_management/test_thread_worker.py @@ -86,13 +86,8 @@ def test_hide_token(self, serializer): class TestJobDownloadTask: - # Use a temporary directory for safe file handling - @pytest.fixture - def temp_dir(self): - with tempfile.TemporaryDirectory() as temp_dir: - yield Path(temp_dir) - def test_job_download_success(self, requests_mock: Mocker, temp_dir: Path): + def test_job_download_success(self, requests_mock: Mocker, tmp_path: Path): """ Test a successful job download and verify file content and stats update. """ @@ -107,7 +102,7 @@ def test_job_download_success(self, requests_mock: Mocker, temp_dir: Path): backend._set_job_status(job_id=job_id, status="finished") backend.batch_jobs[job_id]["status"] = "finished" - download_dir = temp_dir / job_id / "results" + download_dir = tmp_path / job_id / "results" download_dir.mkdir(parents=True) # Create the task instance @@ -136,7 +131,7 @@ def test_job_download_success(self, requests_mock: Mocker, temp_dir: Path): assert downloaded_file.read_bytes() == b"The downloaded file content." - def test_job_download_failure(self, requests_mock: Mocker, temp_dir: Path): + def test_job_download_failure(self, requests_mock: Mocker, tmp_path: Path): """ Test a failed download (e.g., bad connection) and verify error reporting. 
""" @@ -156,7 +151,7 @@ def test_job_download_failure(self, requests_mock: Mocker, temp_dir: Path): backend._set_job_status(job_id=job_id, status="finished") backend.batch_jobs[job_id]["finished"] = "error" - download_dir = temp_dir / job_id / "results" + download_dir = tmp_path / job_id / "results" download_dir.mkdir(parents=True) # Create the task instance @@ -392,14 +387,6 @@ def test_job_start_task_failure(self, worker_pool, dummy_backend, caplog): ] -import pytest -import time -import threading -import logging -from typing import Iterator - -_log = logging.getLogger(__name__) - class TestJobManagerWorkerThreadPool: @pytest.fixture @@ -447,18 +434,17 @@ def test_submit_task_creates_pool(self, thread_pool): """Test that submitting a task creates a pool dynamically.""" task = NopTask(job_id="j-1", df_idx=1) - # No pools initially assert thread_pool.list_pools() == [] # Submit task - should create pool thread_pool.submit_task(task) - # Pool should be created with default workers (1) + # Pool should be created assert thread_pool.list_pools() == ["NopTask"] assert "NopTask" in thread_pool._pools # Process to complete the task - results, remaining = thread_pool.process_all_updates(timeout=0.1) + results, remaining = thread_pool.process_futures(timeout=0.1) assert len(results) == 1 assert results[0].job_id == "j-1" assert remaining == {"NopTask": 0} @@ -471,7 +457,6 @@ def test_submit_task_uses_config(self, configured_pool): configured_pool.submit_task(task) assert "NopTask" in configured_pool._pools - # Can't directly check max_workers, but pool should exist assert "NopTask" in configured_pool.list_pools() def test_submit_multiple_task_types(self, thread_pool): @@ -495,25 +480,23 @@ def test_submit_multiple_task_types(self, thread_pool): assert thread_pool.num_pending_tasks("DummyTask") == 1 assert thread_pool.num_pending_tasks("NonExistent") == 0 - def test_process_all_updates_empty(self, thread_pool): - """Test processing updates with no pools.""" - results, remaining = thread_pool.process_all_updates(timeout=0) + def test_process_futures_updates_empty(self, thread_pool): + """Test process futures with no pools.""" + results, remaining = thread_pool.process_futures(timeout=0) assert results == [] assert remaining == {} - def test_process_all_updates_multiple_pools(self, thread_pool): + def test_process_futures_updates_multiple_pools(self, thread_pool): """Test processing updates across multiple pools.""" # Submit tasks to different pools thread_pool.submit_task(NopTask(job_id="j-1", df_idx=1)) # NopTask pool thread_pool.submit_task(NopTask(job_id="j-2", df_idx=2)) # NopTask pool thread_pool.submit_task(DummyTask(job_id="j-3", df_idx=3)) # DummyTask pool - # Process updates - results, remaining = thread_pool.process_all_updates(timeout=0.1) + results, remaining = thread_pool.process_futures(timeout=0.1) - # Should get 3 results assert len(results) == 3 - # Check results by pool + nop_results = [r for r in results if r.job_id in ["j-1", "j-2"]] dummy_results = [r for r in results if r.job_id == "j-3"] assert len(nop_results) == 2 @@ -522,7 +505,7 @@ def test_process_all_updates_multiple_pools(self, thread_pool): # All tasks should be completed assert remaining == {"NopTask": 0, "DummyTask": 0} - def test_process_all_updates_partial_completion(self): + def test_process_futures_updates_partial_completion(self): """Test processing when some tasks are still running.""" # Use a pool with blocking tasks pool = _JobManagerWorkerThreadPool() @@ -538,7 +521,7 @@ def 
test_process_all_updates_partial_completion(self): pool.submit_task(quick_task) # NopTask pool # Process with timeout=0 - only quick task should complete - results, remaining = pool.process_all_updates(timeout=0) + results, remaining = pool.process_futures(timeout=0) # Only quick task completed assert len(results) == 1 @@ -551,7 +534,7 @@ def test_process_all_updates_partial_completion(self): # Release blocking task and process again event.set() - results2, remaining2 = pool.process_all_updates(timeout=0.1) + results2, remaining2 = pool.process_futures(timeout=0.1) assert len(results2) == 1 assert results2[0].job_id == "j-block" @@ -577,7 +560,7 @@ def test_num_pending_tasks(self, thread_pool): assert thread_pool.num_pending_tasks("NonExistentPool") == 0 # Process all - thread_pool.process_all_updates(timeout=0.1) + thread_pool.process_futures(timeout=0.1) # Should be empty assert thread_pool.num_pending_tasks() == 0 @@ -620,28 +603,20 @@ def test_shutdown_all(self): # Shutdown all pool.shutdown() - - # All pools should be gone + assert pool.list_pools() == [] - - # Can't submit any more tasks after shutdown - # Actually, shutdown() doesn't prevent creating new pools - # So we can test that shutdown clears existing pools assert len(pool._pools) == 0 def test_custom_get_pool_name(self): """Test custom task class to verify pool name selection.""" @dataclass(frozen=True) - class CustomTask(Task): - # Fields are inherited from Task: job_id, df_idx - + class CustomTask(Task): def execute(self) -> _TaskResult: return _TaskResult(job_id=self.job_id, df_idx=self.df_idx) pool = _JobManagerWorkerThreadPool() - # Submit custom task - must provide all required fields task = CustomTask(job_id="j-1", df_idx=1) pool.submit_task(task) @@ -650,7 +625,7 @@ def execute(self) -> _TaskResult: assert pool.num_pending_tasks() == 1 # Process it - results, remaining = pool.process_all_updates(timeout=0.1) + results, _ = pool.process_futures(timeout=0.1) assert len(results) == 1 assert results[0].job_id == "j-1" @@ -674,7 +649,7 @@ def submit_tasks(start_idx: int): assert thread_pool.num_pending_tasks() == 15 # Process them all - results, remaining = thread_pool.process_all_updates(timeout=0.5) + results, remaining = thread_pool.process_futures(timeout=0.5) assert len(results) == 15 assert remaining == {"NopTask": 0} @@ -687,7 +662,6 @@ def test_pool_parallelism_with_blocking_tasks(self): # Create multiple blocking tasks events = [threading.Event() for _ in range(5)] - start_time = time.time() for i, event in enumerate(events): pool.submit_task(BlockingTask( @@ -704,14 +678,10 @@ def test_pool_parallelism_with_blocking_tasks(self): for event in events: event.set() - # Process with timeout - all should complete - results, remaining = pool.process_all_updates(timeout=0.5) - - # All should complete (if pool had enough workers) + results, remaining = pool.process_futures(timeout=0.5) assert len(results) == 5 assert remaining == {"BlockingTask": 0} - # Check they all completed for result in results: assert result.job_id.startswith("j-block-") @@ -723,7 +693,7 @@ def test_task_with_error_handling(self, thread_pool): thread_pool.submit_task(DummyTask(job_id="j-666", df_idx=0)) # Process it - results, remaining = thread_pool.process_all_updates(timeout=0.1) + results, remaining = thread_pool.process_futures(timeout=0.1) # Should get error result assert len(results) == 1 @@ -741,7 +711,7 @@ def test_mixed_success_and_error_tasks(self, thread_pool): thread_pool.submit_task(DummyTask(job_id="j-3", df_idx=3)) # Success # 
Process all - results, remaining = thread_pool.process_all_updates(timeout=0.1) + results, remaining = thread_pool.process_futures(timeout=0.1) # Should get 3 results assert len(results) == 3 From 21992faa6d94ab8611976b8f6385db8281c281ff Mon Sep 17 00:00:00 2001 From: Hans Vanrompay Date: Tue, 16 Dec 2025 13:43:55 +0100 Subject: [PATCH 16/16] fix documentation --- openeo/extra/job_management/_thread_worker.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/openeo/extra/job_management/_thread_worker.py b/openeo/extra/job_management/_thread_worker.py index 439a6879b..42a3ee097 100644 --- a/openeo/extra/job_management/_thread_worker.py +++ b/openeo/extra/job_management/_thread_worker.py @@ -191,7 +191,7 @@ class _TaskThreadPool: :param max_workers: Maximum number of concurrent threads to use for execution. - Defaults to 2. + Defaults to 1. """ def __init__(self, max_workers: int = 1, name: str = 'default'):
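
A minimal usage sketch of the per-task-type pool manager introduced by this series, restricted to the API the tests above exercise (pool_configs, submit_task, list_pools, num_pending_tasks). PingTask is a hypothetical Task subclass used purely for illustration and is not part of the patch.

    from dataclasses import dataclass

    from openeo.extra.job_management._thread_worker import (
        Task,
        _JobManagerWorkerThreadPool,
        _TaskResult,
    )


    @dataclass(frozen=True)
    class PingTask(Task):
        """Hypothetical task: inherits the job_id and df_idx fields from Task."""

        def execute(self) -> _TaskResult:
            return _TaskResult(
                job_id=self.job_id,
                df_idx=self.df_idx,
                db_update={"status": "pinged"},
                stats_update={"ping": 1},
            )


    # Pools are keyed by the task's class name; pool_configs sets max_workers
    # per pool (task types without an entry fall back to the single-worker default).
    manager = _JobManagerWorkerThreadPool(pool_configs={"PingTask": 2})

    manager.submit_task(PingTask(job_id="j-1", df_idx=0))  # lazily creates the "PingTask" pool
    manager.submit_task(PingTask(job_id="j-2", df_idx=1))

    assert manager.list_pools() == ["PingTask"]
    assert manager.num_pending_tasks("PingTask") == 2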
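
Continuing that sketch: draining the completed futures mirrors the pattern the job manager now uses through process_futures. The job-database persistence step is only hinted at in a comment, and stats is an ad-hoc counter for illustration, not the manager's internal one.

    import collections

    stats = collections.defaultdict(int)

    # Collect results that finished across all pools; `remaining` maps each
    # pool name to the number of futures still outstanding.
    results, remaining = manager.process_futures(timeout=0.1)
    for result in results:
        # result.db_update (e.g. {"status": "pinged"}) would be written back to
        # the job database row identified by result.df_idx.
        for key, delta in (result.stats_update or {}).items():
            stats[key] += delta

    print(dict(stats))   # e.g. {"ping": 2}
    print(remaining)     # e.g. {"PingTask": 0}

    manager.shutdown()   # or manager.shutdown("PingTask") to stop only that pool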