1 change: 1 addition & 0 deletions CHANGELOG.md
@@ -10,6 +10,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
### Added

- `MultiBackendJobManager`: add `download_results` option to enable/disable the automated download of job results once completed by the job manager ([#744](https://github.com/Open-EO/openeo-python-client/issues/744))
- Support UDF based spatial and temporal extents in `load_collection`, `load_stac` and `filter_temporal` ([#831](https://github.com/Open-EO/openeo-python-client/pull/831))

### Changed

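For illustration (not part of the diff), a minimal usage sketch of the new UDF-based extent support, assuming a hypothetical backend URL; the placeholder UDF and data mirror the tests added below:

```python
import openeo
import openeo.processes

# Hypothetical backend URL, for illustration only.
connection = openeo.connect("https://openeo.example.test")

# Build a `run_udf` node; its result can now be passed directly as an extent.
temporal_extent = openeo.processes.run_udf(
    data=[1, 2, 3],             # placeholder UDF input
    udf="print('hello time')",  # placeholder UDF source
    runtime="Python",
)

# The UDF node ends up as a {"from_node": ...} reference in the process graph.
cube = connection.load_collection("S2", temporal_extent=temporal_extent)
```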
2 changes: 2 additions & 0 deletions openeo/internal/graph_building.py
@@ -95,6 +95,8 @@ def print_json(
class _FromNodeMixin(abc.ABC):
"""Mixin for classes that want to hook into the generation of a "from_node" reference."""

# TODO: rename this class: it's more an interface than a mixin, and "from node" might be confusing as explained below.

@abc.abstractmethod
def from_node(self) -> PGNode:
# TODO: "from_node" is a bit a confusing name:
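For context (not part of the diff), a minimal sketch of the contract this interface encodes: `PGNode` is itself a `_FromNodeMixin` whose `from_node()` returns the node, and a node passed as a process argument is serialized as a `{"from_node": ...}` reference:

```python
from openeo.internal.graph_building import PGNode

# A raw "run_udf" node, constructed the same way as in the tests below.
udf_node = PGNode(process_id="run_udf", data=[1, 2, 3], udf="print('hi')", runtime="Python")

# Passing the node as an argument of another process wraps it as a
# "from_node" reference when the graph is flattened.
filter_node = PGNode(process_id="filter_temporal", data=None, extent=udf_node)
print(filter_node.flat_graph())
# Roughly: {"runudf1": {...}, "filtertemporal1": {"process_id": "filter_temporal",
#           "arguments": {..., "extent": {"from_node": "runudf1"}}, "result": True}}
```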
17 changes: 12 additions & 5 deletions openeo/rest/connection.py
@@ -35,7 +35,12 @@
import openeo
from openeo.config import config_log, get_config_option
from openeo.internal.documentation import openeo_process
from openeo.internal.graph_building import FlatGraphableMixin, PGNode, as_flat_graph
from openeo.internal.graph_building import (
FlatGraphableMixin,
PGNode,
_FromNodeMixin,
as_flat_graph,
)
from openeo.internal.jupyter import VisualDict, VisualList
from openeo.internal.processes.builder import ProcessBuilderBase
from openeo.internal.warnings import deprecated, legacy_alias
@@ -1186,8 +1191,8 @@ def load_collection(
self,
collection_id: Union[str, Parameter],
spatial_extent: Union[dict, Parameter, shapely.geometry.base.BaseGeometry, str, Path, None] = None,
temporal_extent: Union[Sequence[InputDate], Parameter, str, None] = None,
bands: Union[Iterable[str], Parameter, str, None] = None,
temporal_extent: Union[Sequence[InputDate], Parameter, str, _FromNodeMixin, None] = None,
bands: Union[Iterable[str], Parameter, str, _FromNodeMixin, None] = None,
properties: Union[
Dict[str, Union[PGNode, Callable]], List[CollectionProperty], CollectionProperty, None
] = None,
@@ -1287,8 +1292,10 @@ def load_result(
def load_stac(
self,
url: str,
spatial_extent: Union[dict, Parameter, shapely.geometry.base.BaseGeometry, str, Path, None] = None,
temporal_extent: Union[Sequence[InputDate], Parameter, str, None] = None,
spatial_extent: Union[
dict, Parameter, shapely.geometry.base.BaseGeometry, str, Path, _FromNodeMixin, None
] = None,
temporal_extent: Union[Sequence[InputDate], Parameter, str, _FromNodeMixin, None] = None,
bands: Union[Iterable[str], Parameter, str, None] = None,
properties: Union[
Dict[str, Union[PGNode, Callable]], List[CollectionProperty], CollectionProperty, None
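For illustration (not part of the diff), a sketch of the widened `load_stac` signature in use, assuming a hypothetical backend URL; this mirrors `test_load_stac_extents_from_udf` below:

```python
import openeo
import openeo.processes

connection = openeo.connect("https://openeo.example.test")  # hypothetical backend

spatial_extent = openeo.processes.run_udf(data=[1, 2, 3], udf="print('hello space')", runtime="Python")
temporal_extent = openeo.processes.run_udf(data=[4, 5, 6], udf="print('hello time')", runtime="Python")

# Both extents are emitted as {"from_node": ...} references to the run_udf nodes.
cube = connection.load_stac(
    "https://stac.test/data",
    spatial_extent=spatial_extent,
    temporal_extent=temporal_extent,
)
```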
20 changes: 11 additions & 9 deletions openeo/rest/datacube.py
@@ -91,7 +91,7 @@


# Type annotation aliases
InputDate = Union[str, datetime.date, Parameter, PGNode, ProcessBuilderBase, None]
InputDate = Union[str, datetime.date, Parameter, PGNode, ProcessBuilderBase, _FromNodeMixin, None]


class DataCube(_ProcessGraphAbstraction):
@@ -165,8 +165,10 @@ def load_collection(
cls,
collection_id: Union[str, Parameter],
connection: Optional[Connection] = None,
spatial_extent: Union[dict, Parameter, shapely.geometry.base.BaseGeometry, str, pathlib.Path, None] = None,
temporal_extent: Union[Sequence[InputDate], Parameter, str, None] = None,
spatial_extent: Union[
dict, Parameter, shapely.geometry.base.BaseGeometry, str, pathlib.Path, _FromNodeMixin, None
] = None,
temporal_extent: Union[Sequence[InputDate], Parameter, str, _FromNodeMixin, None] = None,
bands: Union[Iterable[str], Parameter, str, None] = None,
fetch_metadata: bool = True,
properties: Union[
@@ -480,22 +482,22 @@ def _get_temporal_extent(
*args,
start_date: InputDate = None,
end_date: InputDate = None,
extent: Union[Sequence[InputDate], Parameter, str, None] = None,
) -> Union[List[Union[str, Parameter, PGNode, None]], Parameter]:
extent: Union[Sequence[InputDate], Parameter, str, _FromNodeMixin, None] = None,
) -> Union[List[Union[str, Parameter, PGNode, _FromNodeMixin, None]], Parameter, _FromNodeMixin]:
"""Parameter aware temporal_extent normalizer"""
# TODO: move this outside of DataCube class
# TODO: return extent as tuple instead of list
if len(args) == 1 and isinstance(args[0], Parameter):
if len(args) == 1 and isinstance(args[0], (Parameter, _FromNodeMixin)):
assert start_date is None and end_date is None and extent is None
return args[0]
elif len(args) == 0 and isinstance(extent, Parameter):
elif len(args) == 0 and isinstance(extent, (Parameter, _FromNodeMixin)):
assert start_date is None and end_date is None
# TODO: warn about unexpected parameter schema
return extent
else:
def convertor(d: Any) -> Any:
# TODO: can this be generalized through _FromNodeMixin?
if isinstance(d, Parameter) or isinstance(d, PGNode):
if isinstance(d, Parameter) or isinstance(d, _FromNodeMixin):
# TODO: warn about unexpected parameter schema
return d
elif isinstance(d, ProcessBuilderBase):
@@ -531,7 +533,7 @@ def filter_temporal(
*args,
start_date: InputDate = None,
end_date: InputDate = None,
extent: Union[Sequence[InputDate], Parameter, str, None] = None,
extent: Union[Sequence[InputDate], Parameter, str, _FromNodeMixin, None] = None,
) -> DataCube:
"""
Limit the DataCube to a certain date range, which can be specified in several ways:
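For illustration (not part of the diff), a sketch of `filter_temporal` with UDF-driven extents, assuming a hypothetical backend URL; this mirrors `test_filter_temporal_from_udf` and `test_filter_temporal_start_end_from_udf` below:

```python
import openeo
import openeo.processes

connection = openeo.connect("https://openeo.example.test")  # hypothetical backend
cube = connection.load_collection("S2")

# A single run_udf node for the whole extent...
extent = openeo.processes.run_udf(data=[1, 2, 3], udf="print('hello time')", runtime="Python")
cube_a = cube.filter_temporal(extent)

# ...or one run_udf node per bound.
start = openeo.processes.run_udf(data=[1, 2, 3], udf="print('hello start')", runtime="Python")
end = openeo.processes.run_udf(data=[4, 5, 6], udf="print('hello end')", runtime="Python")
cube_b = cube.filter_temporal(start_date=start, end_date=end)
```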
65 changes: 65 additions & 0 deletions tests/rest/datacube/test_datacube.py
@@ -18,8 +18,10 @@
import shapely
import shapely.geometry

import openeo.processes
from openeo import collection_property
from openeo.api.process import Parameter
from openeo.internal.graph_building import PGNode
from openeo.metadata import SpatialDimension
from openeo.rest import BandMathException, OpenEoClientException
from openeo.rest._testing import build_capabilities
@@ -698,6 +700,69 @@ def test_filter_temporal_single_arg(s2cube: DataCube, arg, expect_failure):
_ = s2cube.filter_temporal(arg)


@pytest.mark.parametrize(
"udf_factory",
[
(lambda data, udf, runtime: openeo.processes.run_udf(data=data, udf=udf, runtime=runtime)),
(lambda data, udf, runtime: PGNode(process_id="run_udf", data=data, udf=udf, runtime=runtime)),
],
)
def test_filter_temporal_from_udf(s2cube: DataCube, udf_factory):
temporal_extent = udf_factory(data=[1, 2, 3], udf="print('hello time')", runtime="Python")
cube = s2cube.filter_temporal(temporal_extent)
assert get_download_graph(cube, drop_save_result=True) == {
"loadcollection1": {
"process_id": "load_collection",
"arguments": {"id": "S2", "spatial_extent": None, "temporal_extent": None},
},
"runudf1": {
"process_id": "run_udf",
"arguments": {"data": [1, 2, 3], "udf": "print('hello time')", "runtime": "Python"},
},
"filtertemporal1": {
"process_id": "filter_temporal",
"arguments": {
"data": {"from_node": "loadcollection1"},
"extent": {"from_node": "runudf1"},
},
},
}


@pytest.mark.parametrize(
"udf_factory",
[
(lambda data, udf, runtime: openeo.processes.run_udf(data=data, udf=udf, runtime=runtime)),
(lambda data, udf, runtime: PGNode(process_id="run_udf", data=data, udf=udf, runtime=runtime)),
],
)
def test_filter_temporal_start_end_from_udf(s2cube: DataCube, udf_factory):
start = udf_factory(data=[1, 2, 3], udf="print('hello start')", runtime="Python")
end = udf_factory(data=[4, 5, 6], udf="print('hello end')", runtime="Python")
cube = s2cube.filter_temporal(start_date=start, end_date=end)
assert get_download_graph(cube, drop_save_result=True) == {
"loadcollection1": {
"process_id": "load_collection",
"arguments": {"id": "S2", "spatial_extent": None, "temporal_extent": None},
},
"runudf1": {
"process_id": "run_udf",
"arguments": {"data": [1, 2, 3], "udf": "print('hello start')", "runtime": "Python"},
},
"runudf2": {
"process_id": "run_udf",
"arguments": {"data": [4, 5, 6], "udf": "print('hello end')", "runtime": "Python"},
},
"filtertemporal1": {
"process_id": "filter_temporal",
"arguments": {
"data": {"from_node": "loadcollection1"},
"extent": [{"from_node": "runudf1"}, {"from_node": "runudf2"}],
},
},
}


def test_max_time(s2cube, api_version):
im = s2cube.max_time()
graph = _get_leaf_node(im, force_flat=True)
64 changes: 64 additions & 0 deletions tests/rest/datacube/test_datacube100.py
@@ -2375,6 +2375,70 @@ def test_load_collection_parameterized_extents(con100, spatial_extent, temporal_
}


@pytest.mark.parametrize(
"udf_factory",
[
(lambda data, udf, runtime: openeo.processes.run_udf(data=data, udf=udf, runtime=runtime)),
(lambda data, udf, runtime: PGNode(process_id="run_udf", data=data, udf=udf, runtime=runtime)),
],
)
def test_load_collection_extents_from_udf(con100, udf_factory):
spatial_extent = udf_factory(data=[1, 2, 3], udf="print('hello space')", runtime="Python")
temporal_extent = udf_factory(data=[4, 5, 6], udf="print('hello time')", runtime="Python")
cube = con100.load_collection("S2", spatial_extent=spatial_extent, temporal_extent=temporal_extent)
assert get_download_graph(cube, drop_save_result=True) == {
"runudf1": {
"process_id": "run_udf",
"arguments": {"data": [1, 2, 3], "udf": "print('hello space')", "runtime": "Python"},
},
"runudf2": {
"process_id": "run_udf",
"arguments": {"data": [4, 5, 6], "udf": "print('hello time')", "runtime": "Python"},
},
"loadcollection1": {
"process_id": "load_collection",
"arguments": {
"id": "S2",
"spatial_extent": {"from_node": "runudf1"},
"temporal_extent": {"from_node": "runudf2"},
},
},
}


@pytest.mark.parametrize(
"udf_factory",
[
(lambda data, udf, runtime: openeo.processes.run_udf(data=data, udf=udf, runtime=runtime)),
(lambda data, udf, runtime: PGNode(process_id="run_udf", data=data, udf=udf, runtime=runtime)),
],
)
def test_load_collection_temporal_extent_from_udf(con100, udf_factory):
temporal_extent = [
udf_factory(data=[1, 2, 3], udf="print('hello start')", runtime="Python"),
udf_factory(data=[4, 5, 6], udf="print('hello end')", runtime="Python"),
]
cube = con100.load_collection("S2", temporal_extent=temporal_extent)
assert get_download_graph(cube, drop_save_result=True) == {
"runudf1": {
"process_id": "run_udf",
"arguments": {"data": [1, 2, 3], "udf": "print('hello start')", "runtime": "Python"},
},
"runudf2": {
"process_id": "run_udf",
"arguments": {"data": [4, 5, 6], "udf": "print('hello end')", "runtime": "Python"},
},
"loadcollection1": {
"process_id": "load_collection",
"arguments": {
"id": "S2",
"spatial_extent": None,
"temporal_extent": [{"from_node": "runudf1"}, {"from_node": "runudf2"}],
},
},
}


def test_apply_dimension_temporal_cumsum_with_target(con100, test_data):
cumsum = con100.load_collection("S2").apply_dimension('cumsum', dimension="t", target_dimension="MyNewTime")
actual_graph = cumsum.flat_graph()
68 changes: 68 additions & 0 deletions tests/rest/test_connection.py
@@ -17,6 +17,7 @@
import shapely.geometry

import openeo
import openeo.processes
from openeo import BatchJob
from openeo.api.process import Parameter
from openeo.internal.graph_building import FlatGraphableMixin, PGNode
@@ -3715,6 +3716,73 @@ def test_load_stac_spatial_extent_vector_cube(self, dummy_backend):
},
}

@pytest.mark.parametrize(
"udf_factory",
[
(lambda data, udf, runtime: openeo.processes.run_udf(data=data, udf=udf, runtime=runtime)),
(lambda data, udf, runtime: PGNode(process_id="run_udf", data=data, udf=udf, runtime=runtime)),
],
)
def test_load_stac_extents_from_udf(self, dummy_backend, udf_factory):
spatial_extent = udf_factory(data=[1, 2, 3], udf="print('hello space')", runtime="Python")
temporal_extent = udf_factory(data=[4, 5, 6], udf="print('hello time')", runtime="Python")
cube = dummy_backend.connection.load_stac(
"https://stac.test/data", spatial_extent=spatial_extent, temporal_extent=temporal_extent
)
cube.execute()
assert dummy_backend.get_sync_pg() == {
"runudf1": {
"process_id": "run_udf",
"arguments": {"data": [1, 2, 3], "udf": "print('hello space')", "runtime": "Python"},
},
"runudf2": {
"process_id": "run_udf",
"arguments": {"data": [4, 5, 6], "udf": "print('hello time')", "runtime": "Python"},
},
"loadstac1": {
"process_id": "load_stac",
"arguments": {
"url": "https://stac.test/data",
"spatial_extent": {"from_node": "runudf1"},
"temporal_extent": {"from_node": "runudf2"},
},
"result": True,
},
}

@pytest.mark.parametrize(
"udf_factory",
[
(lambda data, udf, runtime: openeo.processes.run_udf(data=data, udf=udf, runtime=runtime)),
(lambda data, udf, runtime: PGNode(process_id="run_udf", data=data, udf=udf, runtime=runtime)),
],
)
def test_load_stac_temporal_extent_from_udf(self, dummy_backend, udf_factory):
temporal_extent = [
udf_factory(data=[1, 2, 3], udf="print('hello start')", runtime="Python"),
udf_factory(data=[4, 5, 6], udf="print('hello end')", runtime="Python"),
]
cube = dummy_backend.connection.load_stac("https://stac.test/data", temporal_extent=temporal_extent)
cube.execute()
assert dummy_backend.get_sync_pg() == {
"runudf1": {
"process_id": "run_udf",
"arguments": {"data": [1, 2, 3], "udf": "print('hello start')", "runtime": "Python"},
},
"runudf2": {
"process_id": "run_udf",
"arguments": {"data": [4, 5, 6], "udf": "print('hello end')", "runtime": "Python"},
},
"loadstac1": {
"process_id": "load_stac",
"arguments": {
"url": "https://stac.test/data",
"temporal_extent": [{"from_node": "runudf1"}, {"from_node": "runudf2"}],
},
"result": True,
},
}


@pytest.mark.parametrize(
"data",