68 changes: 54 additions & 14 deletions openeo/extra/job_management/_manager.py
@@ -32,6 +32,7 @@
from openeo.extra.job_management._thread_worker import (
_JobManagerWorkerThreadPool,
_JobStartTask,
_JobDownloadTask,
)
from openeo.rest import OpenEoApiError
from openeo.rest.auth.auth import BearerAuth
@@ -172,6 +173,7 @@ def start_job(

.. versionchanged:: 0.47.0
Added ``download_results`` parameter.

"""

# Expected columns in the job DB dataframes.
@@ -219,6 +221,7 @@ def __init__(
)
self._thread = None
self._worker_pool = None
self._download_pool = None
# Generic cache
self._cache = {}

@@ -351,6 +354,7 @@ def start_job_thread(self, start_job: Callable[[], BatchJob], job_db: JobDatabas

self._stop_thread = False
self._worker_pool = _JobManagerWorkerThreadPool()
self._download_pool = _JobManagerWorkerThreadPool()

def run_loop():
# TODO: support user-provided `stats`
@@ -388,7 +392,13 @@ def stop_job_thread(self, timeout_seconds: Optional[float] = _UNSET):

.. versionadded:: 0.32.0
"""
self._worker_pool.shutdown()
if self._worker_pool is not None:
self._worker_pool.shutdown()
self._worker_pool = None

if self._download_pool is not None:
self._download_pool.shutdown()
self._download_pool = None

if self._thread is not None:
self._stop_thread = True
@@ -494,13 +504,20 @@ def run_jobs(

self._worker_pool = _JobManagerWorkerThreadPool()

if self._download_results:
self._download_pool = _JobManagerWorkerThreadPool()


        while (
            sum(
                job_db.count_by_status(
                    statuses=["not_started", "created", "queued_for_start", "queued", "running"]
                ).values()
            )
            > 0
            or (self._worker_pool is not None and self._worker_pool.num_pending_tasks() > 0)
            or (self._download_pool is not None and self._download_pool.num_pending_tasks() > 0)
        ):
self._job_update_loop(job_db=job_db, start_job=start_job, stats=stats)
stats["run_jobs loop"] += 1
@@ -510,8 +527,14 @@ def run_jobs(
time.sleep(self.poll_sleep)
stats["sleep"] += 1

# TODO: run post-process after shutdown once more to ensure completion?
self._worker_pool.shutdown()

if self._worker_pool is not None:
self._worker_pool.shutdown()
self._worker_pool = None

if self._download_pool is not None:
self._download_pool.shutdown()
self._download_pool = None

return stats

@@ -553,7 +576,11 @@ def _job_update_loop(
stats["job_db persist"] += 1
total_added += 1

self._process_threadworker_updates(self._worker_pool, job_db=job_db, stats=stats)
if self._worker_pool is not None:
self._process_threadworker_updates(worker_pool=self._worker_pool, job_db=job_db, stats=stats)

if self._download_pool is not None:
self._process_threadworker_updates(worker_pool=self._download_pool, job_db=job_db, stats=stats)

# TODO: move this back closer to the `_track_statuses` call above, once job done/error handling is also handled in threads?
for job, row in jobs_done:
@@ -565,6 +592,7 @@ def _job_update_loop(
for job, row in jobs_cancel:
self.on_job_cancel(job, row)


def _launch_job(self, start_job, df, i, backend_name, stats: Optional[dict] = None):
"""Helper method for launching jobs

@@ -721,17 +749,28 @@ def on_job_done(self, job: BatchJob, row):
:param job: The job that has finished.
:param row: DataFrame row containing the job's metadata.
"""
# TODO: param `row` is never accessed in this method. Remove it? Is this intended for future use?
if self._download_results:
job_metadata = job.describe()
job_dir = self.get_job_dir(job.job_id)
metadata_path = self.get_job_metadata_path(job.job_id)

job_dir = self.get_job_dir(job.job_id)
self.ensure_job_dir_exists(job.job_id)
job.get_results().download_files(target=job_dir)

with metadata_path.open("w", encoding="utf-8") as f:
json.dump(job_metadata, f, ensure_ascii=False)
# Proactively refresh the bearer token (because the task running in the worker thread will not be able to do that)
job_con = job.connection
self._refresh_bearer_token(connection=job_con)

task = _JobDownloadTask(
root_url=job_con.root_url,
bearer_token=job_con.auth.bearer if isinstance(job_con.auth, BearerAuth) else None,
job_id=job.job_id,
df_idx=row.name,  # this will be the index in the "not started" dataframe; should not be an issue, as there is no db update for the download task
download_dir=job_dir,
)
_log.info(f"Submitting download task {task} to download thread pool")

if self._download_pool is None:
self._download_pool = _JobManagerWorkerThreadPool()

self._download_pool.submit_task(task)

def on_job_error(self, job: BatchJob, row):
"""
Expand Down Expand Up @@ -783,6 +822,7 @@ def _cancel_prolonged_job(self, job: BatchJob, row):
except Exception as e:
_log.error(f"Unexpected error while handling job {job.job_id}: {e}")

# TODO: pull this functionality out of the manager into a general utility class? Job dir creation could be reused for the _JobDownloadTask.
def get_job_dir(self, job_id: str) -> Path:
"""Path to directory where job metadata, results and error logs are be saved."""
return self._root_dir / f"job_{job_id}"
45 changes: 45 additions & 0 deletions openeo/extra/job_management/_thread_worker.py
@@ -7,7 +7,9 @@
from abc import ABC, abstractmethod
from dataclasses import dataclass, field
from typing import Any, Dict, List, Optional, Tuple, Union
from pathlib import Path

import json
import urllib3.util

import openeo
@@ -140,7 +142,46 @@ def execute(self) -> _TaskResult:
stats_update={"start_job error": 1},
)

class _JobDownloadTask(ConnectedTask):
"""
Task for downloading job results and metadata.

:param download_dir:
Root directory where job results and metadata will be downloaded.
"""
def __init__(self, download_dir: Path, **kwargs):
super().__init__(**kwargs)
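        # Assumption: the Task dataclasses are frozen, hence object.__setattr__ to bypass immutability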
object.__setattr__(self, 'download_dir', download_dir)

def execute(self) -> _TaskResult:
try:
job = self.get_connection(retry=True).job(self.job_id)

# Download results
job.get_results().download_files(target=self.download_dir)

# Download metadata
job_metadata = job.describe()
metadata_path = self.download_dir / f"job_{self.job_id}.json"
with metadata_path.open("w", encoding="utf-8") as f:
json.dump(job_metadata, f, ensure_ascii=False)

_log.info(f"Job {self.job_id!r} results downloaded successfully")
return _TaskResult(
job_id=self.job_id,
df_idx=self.df_idx,
db_update={},  # TODO: consider db updates?
stats_update={"job download": 1},
Review comment (Member): you probably also want to keep track of the number of files downloaded, not just the number of jobs whose results were downloaded.
)
except Exception as e:
_log.error(f"Failed to download results for job {self.job_id!r}: {e!r}")
return _TaskResult(
job_id=self.job_id,
df_idx=self.df_idx,
db_update={},
stats_update={"job download error": 1},
)
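
Following up on the review comment above, a hedged sketch of how the success branch could also count files, assuming ``download_files()`` returns the list of downloaded paths (as it does in the openeo client):

    # Hypothetical variant of the success branch (not part of this PR):
    # count downloaded files in addition to downloaded jobs.
    downloaded = job.get_results().download_files(target=self.download_dir)
    return _TaskResult(
        job_id=self.job_id,
        df_idx=self.df_idx,
        db_update={},
        stats_update={"job download": 1, "files downloaded": len(downloaded)},
    )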

class _JobManagerWorkerThreadPool:
"""
Thread pool-based worker that manages the execution of asynchronous tasks.
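
For orientation, a minimal usage sketch of the pool, based only on the methods visible in this diff (task construction elided):

    # Submit tasks and drain results; process_futures() returns the finished
    # task results plus the number of still-pending tasks.
    pool = _JobManagerWorkerThreadPool()
    pool.submit_task(task)  # e.g. a _JobStartTask or _JobDownloadTask
    while pool.num_pending_tasks() > 0:
        results, pending = pool.process_futures(timeout=1)
        for result in results:
            ...  # apply result.db_update and result.stats_update
    pool.shutdown()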
@@ -207,6 +248,10 @@ def process_futures(self, timeout: Union[float, None] = 0) -> Tuple[List[_TaskRe

self._future_task_pairs = to_keep
return results, len(to_keep)

def num_pending_tasks(self) -> int:
Review comment (Member): abbreviating "number" to "num" is a bit off-brand with the rest of the openEO client API. In this age of large monitors and AI-powered autocomplete there is little reason anymore to save on keystrokes at the cost of readability. I'd go for pending_task_count, which is just 2 characters more 😄

"""Return the number of tasks that are still pending (not completed)."""
return len(self._future_task_pairs)
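
A sketch of the rename suggested in the review comment above (hypothetical, not part of this PR):

    def pending_task_count(self) -> int:
        """Return the number of tasks that are still pending (not completed)."""
        return len(self._future_task_pairs)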

def shutdown(self) -> None:
"""Shuts down the thread pool gracefully."""
103 changes: 103 additions & 0 deletions tests/extra/job_management/test_thread_worker.py
@@ -3,6 +3,9 @@
import time
from dataclasses import dataclass
from typing import Iterator
from pathlib import Path
import tempfile
from requests_mock import Mocker

import pytest

@@ -11,6 +14,7 @@
_JobManagerWorkerThreadPool,
_JobStartTask,
_TaskResult,
_JobDownloadTask,
)
from openeo.rest._testing import DummyBackend

@@ -79,6 +83,103 @@ def test_hide_token(self, serializer):
assert "job-123" in serialized
assert secret not in serialized

class TestJobDownloadTask:

# Use a temporary directory for safe file handling
@pytest.fixture
def temp_dir(self):
with tempfile.TemporaryDirectory() as temp_dir:
yield Path(temp_dir)

def test_job_download_success(self, requests_mock: Mocker, temp_dir: Path):
"""
Test a successful job download and verify file content and stats update.
"""
job_id = "job-007"
df_idx = 42

# Set up a dummy backend to serve the job results
backend = DummyBackend.at_url("https://openeo.dummy.test/", requests_mock=requests_mock)
backend.next_result = b"The downloaded file content."
backend.batch_jobs[job_id] = {"job_id": job_id, "pg": {}, "status": "created"}

backend._set_job_status(job_id=job_id, status="finished")
backend.batch_jobs[job_id]["status"] = "finished"

download_dir = temp_dir / job_id / "results"
download_dir.mkdir(parents=True)

# Create the task instance
task = _JobDownloadTask(
root_url="https://openeo.dummy.test/",
bearer_token="dummy-token-7",
job_id=job_id,
df_idx=df_idx,
download_dir=download_dir,
)

# Execute the task
result = task.execute()

# Verify TaskResult structure
assert isinstance(result, _TaskResult)
assert result.job_id == job_id
assert result.df_idx == df_idx

# Verify stats update for the MultiBackendJobManager
assert result.stats_update == {"job download": 1}

# Verify download content (crucial part of the unit test)
downloaded_file = download_dir / "result.data"
assert downloaded_file.exists()
assert downloaded_file.read_bytes() == b"The downloaded file content."


def test_job_download_failure(self, requests_mock: Mocker, temp_dir: Path):
"""
Test a failed download (e.g., bad connection) and verify error reporting.
"""
job_id = "job-008"
df_idx = 55

# Set up dummy backend to simulate failure during results listing
backend = DummyBackend.at_url("https://openeo.dummy.test/", requests_mock=requests_mock)

# Simulate an error when downloading the results
requests_mock.get(
f"https://openeo.dummy.test/jobs/{job_id}/results",
status_code=500,
json={"code": "InternalError", "message": "Failed to list results"})

backend.batch_jobs[job_id] = {"job_id": job_id, "pg": {}, "status": "created"}
backend._set_job_status(job_id=job_id, status="finished")
backend.batch_jobs[job_id]["status"] = "finished"

download_dir = temp_dir / job_id / "results"
download_dir.mkdir(parents=True)

# Create the task instance
task = _JobDownloadTask(
root_url="https://openeo.dummy.test/",
bearer_token="dummy-token-8",
job_id=job_id,
df_idx=df_idx,
download_dir=download_dir,
)

# Execute the task
result = task.execute()

# Verify TaskResult structure
assert isinstance(result, _TaskResult)
assert result.job_id == job_id
assert result.df_idx == df_idx

# Verify stats update for the MultiBackendJobManager
assert result.stats_update == {"job download error": 1}

# Verify no files were created
assert not any(p.is_file() for p in download_dir.glob("*"))

class NopTask(Task):
"""Do Nothing"""
Expand Down Expand Up @@ -288,3 +389,5 @@ def test_job_start_task_failure(self, worker_pool, dummy_backend, caplog):
assert caplog.messages == [
"Failed to start job 'job-000': OpenEoApiError('[500] Internal: No job starting for you, buddy')"
]