Skip to content

Commit a907e38

Browse files
committed
Disallow empty bands array in load_collection/load_stac
and add support for parameterized bands in `load_stac` refs: #424, Open-EO/openeo-processes#372
1 parent ddd2185 commit a907e38

File tree

5 files changed

+105
-10
lines changed

5 files changed

+105
-10
lines changed

CHANGELOG.md

+3
Original file line numberDiff line numberDiff line change
@@ -13,9 +13,12 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
1313
- Added `Connection.web_editor()` to build link to the openEO backend in the openEO Web Editor
1414
- Add support for `log_level` in `create_job()` and `execute_job()` ([#704](https://github.com/Open-EO/openeo-python-client/issues/704))
1515
- Add initial support for "geometry" dimension type in `CubeMetadata` ([#705](https://github.com/Open-EO/openeo-python-client/issues/705))
16+
- Add support for parameterized `bands` argument in `load_stac()`
1617

1718
### Changed
1819

20+
- Raise exception when providing empty bands array to `load_collection`/`load_stac` ([#424](https://github.com/Open-EO/openeo-python-client/issues/424), [Open-EO/openeo-processes#372](https://github.com/Open-EO/openeo-processes/issues/372))
21+
1922
### Removed
2023

2124
### Fixed

openeo/rest/connection.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -1258,7 +1258,7 @@ def load_collection(
12581258
collection_id: Union[str, Parameter],
12591259
spatial_extent: Union[Dict[str, float], Parameter, None] = None,
12601260
temporal_extent: Union[Sequence[InputDate], Parameter, str, None] = None,
1261-
bands: Union[None, List[str], Parameter] = None,
1261+
bands: Union[Iterable[str], Parameter, str, None] = None,
12621262
properties: Union[
12631263
None, Dict[str, Union[str, PGNode, Callable]], List[CollectionProperty], CollectionProperty
12641264
] = None,
@@ -1348,7 +1348,7 @@ def load_stac(
13481348
url: str,
13491349
spatial_extent: Union[Dict[str, float], Parameter, None] = None,
13501350
temporal_extent: Union[Sequence[InputDate], Parameter, str, None] = None,
1351-
bands: Optional[List[str]] = None,
1351+
bands: Union[Iterable[str], Parameter, str, None] = None,
13521352
properties: Optional[Dict[str, Union[str, PGNode, Callable]]] = None,
13531353
) -> DataCube:
13541354
"""

openeo/rest/datacube.py

+26-8
Original file line numberDiff line numberDiff line change
@@ -145,7 +145,7 @@ def load_collection(
145145
connection: Optional[Connection] = None,
146146
spatial_extent: Union[Dict[str, float], Parameter, None] = None,
147147
temporal_extent: Union[Sequence[InputDate], Parameter, str, None] = None,
148-
bands: Union[None, List[str], Parameter] = None,
148+
bands: Union[Iterable[str], Parameter, str, None] = None,
149149
fetch_metadata: bool = True,
150150
properties: Union[
151151
None, Dict[str, Union[str, PGNode, typing.Callable]], List[CollectionProperty], CollectionProperty
@@ -198,10 +198,9 @@ def load_collection(
198198
metadata: Optional[CollectionMetadata] = (
199199
connection.collection_metadata(collection_id) if connection and fetch_metadata else None
200200
)
201-
if bands:
202-
if isinstance(bands, str):
203-
bands = [bands]
204-
elif isinstance(bands, Parameter):
201+
if bands is not None:
202+
bands = cls._get_bands(bands, process_id="load_collection")
203+
if isinstance(bands, Parameter):
205204
metadata = None
206205
if metadata:
207206
bands = [b if isinstance(b, str) else metadata.band_dimension.band_name(b) for b in bands]
@@ -272,7 +271,7 @@ def load_stac(
272271
url: str,
273272
spatial_extent: Union[Dict[str, float], Parameter, None] = None,
274273
temporal_extent: Union[Sequence[InputDate], Parameter, str, None] = None,
275-
bands: Optional[List[str]] = None,
274+
bands: Union[Iterable[str], Parameter, str, None] = None,
276275
properties: Optional[Dict[str, Union[str, PGNode, Callable]]] = None,
277276
connection: Optional[Connection] = None,
278277
) -> DataCube:
@@ -379,7 +378,8 @@ def load_stac(
379378
arguments["spatial_extent"] = spatial_extent
380379
if temporal_extent:
381380
arguments["temporal_extent"] = DataCube._get_temporal_extent(extent=temporal_extent)
382-
if bands:
381+
bands = cls._get_bands(bands, process_id="load_stac")
382+
if bands is not None:
383383
arguments["bands"] = bands
384384
if properties:
385385
arguments["properties"] = {
@@ -388,7 +388,7 @@ def load_stac(
388388
graph = PGNode("load_stac", arguments=arguments)
389389
try:
390390
metadata = metadata_from_stac(url)
391-
if bands:
391+
if isinstance(bands, list):
392392
# TODO: also apply spatial/temporal filters to metadata?
393393
metadata = metadata.filter_bands(band_names=bands)
394394
except Exception:
@@ -429,6 +429,24 @@ def convertor(d: Any) -> Any:
429429
get_temporal_extent(*args, start_date=start_date, end_date=end_date, extent=extent, convertor=convertor)
430430
)
431431

432+
@staticmethod
433+
def _get_bands(
434+
bands: Union[Iterable[str], Parameter, str, None], process_id: str
435+
) -> Union[None, List[str], Parameter]:
436+
"""Normalize band array for processes like load_collection, load_stac"""
437+
if bands is None:
438+
pass
439+
elif isinstance(bands, str):
440+
bands = [bands]
441+
elif isinstance(bands, Parameter):
442+
pass
443+
else:
444+
# Coerce to list
445+
bands = list(bands)
446+
if len(bands) == 0:
447+
raise OpenEoClientException(f"Bands array should not be empty (process {process_id!r})")
448+
return bands
449+
432450
@openeo_process
433451
def filter_temporal(
434452
self,

tests/rest/datacube/test_datacube100.py

+28
Original file line numberDiff line numberDiff line change
@@ -2308,6 +2308,34 @@ def test_load_collection_parameterized_bands(con100):
23082308
}
23092309

23102310

2311+
@pytest.mark.parametrize(
2312+
"bands",
2313+
[
2314+
["B02", "B03"],
2315+
("B02", "B03"),
2316+
iter(["B02", "B03"]),
2317+
],
2318+
)
2319+
def test_load_collection_bands_iterable(con100, bands):
2320+
cube = con100.load_collection("S2", bands=bands)
2321+
assert get_download_graph(cube, drop_save_result=True) == {
2322+
"loadcollection1": {
2323+
"process_id": "load_collection",
2324+
"arguments": {
2325+
"id": "S2",
2326+
"spatial_extent": None,
2327+
"temporal_extent": None,
2328+
"bands": ["B02", "B03"],
2329+
},
2330+
},
2331+
}
2332+
2333+
2334+
def test_load_collection_empty_bands_array(con100):
2335+
with pytest.raises(OpenEoClientException, match="Bands array should not be empty"):
2336+
_ = con100.load_collection("S2", bands=[])
2337+
2338+
23112339
@pytest.mark.parametrize(
23122340
["spatial_extent", "temporal_extent", "spatial_name", "temporal_name"],
23132341
[

tests/rest/test_connection.py

+46
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717

1818
import openeo
1919
from openeo import BatchJob
20+
from openeo.api.process import Parameter
2021
from openeo.capabilities import ApiVersionException
2122
from openeo.internal.graph_building import FlatGraphableMixin, PGNode
2223
from openeo.metadata import _PYSTAC_1_9_EXTENSION_INTERFACE, TemporalDimension
@@ -2681,6 +2682,51 @@ def test_load_stac_band_filtering(self, con120, tmp_path):
26812682
cube = con120.load_stac(str(stac_path), bands=["B03", "B02"])
26822683
assert cube.metadata.band_names == ["B03", "B02"]
26832684

2685+
@pytest.mark.parametrize(
2686+
"bands",
2687+
[
2688+
["B02", "B03"],
2689+
("B02", "B03"),
2690+
iter(["B02", "B03"]),
2691+
],
2692+
)
2693+
def test_bands_iterable(self, con120, bands):
2694+
cube = con120.load_stac(
2695+
"https://provider.test/dataset",
2696+
bands=bands,
2697+
)
2698+
assert cube.flat_graph() == {
2699+
"loadstac1": {
2700+
"process_id": "load_stac",
2701+
"arguments": {
2702+
"url": "https://provider.test/dataset",
2703+
"bands": ["B02", "B03"],
2704+
},
2705+
"result": True,
2706+
}
2707+
}
2708+
2709+
def test_bands_empty(self, con120):
2710+
with pytest.raises(OpenEoClientException, match="Bands array should not be empty"):
2711+
_ = con120.load_stac("https://provider.test/dataset", bands=[])
2712+
2713+
def test_bands_parameterized(self, con120):
2714+
bands = Parameter(name="my_bands", schema={"type": "array", "items": {"type": "string"}})
2715+
cube = con120.load_stac(
2716+
"https://provider.test/dataset",
2717+
bands=bands,
2718+
)
2719+
assert cube.flat_graph() == {
2720+
"loadstac1": {
2721+
"process_id": "load_stac",
2722+
"arguments": {
2723+
"url": "https://provider.test/dataset",
2724+
"bands": {"from_parameter": "my_bands"},
2725+
},
2726+
"result": True,
2727+
}
2728+
}
2729+
26842730

26852731
@pytest.mark.parametrize(
26862732
"data",

0 commit comments

Comments
 (0)