diff --git a/eodag_cube/api/product/_product.py b/eodag_cube/api/product/_product.py index c18ef68..680c78f 100644 --- a/eodag_cube/api/product/_product.py +++ b/eodag_cube/api/product/_product.py @@ -46,7 +46,7 @@ from eodag_cube.api.product._assets import AssetsDict from eodag_cube.types import XarrayDict from eodag_cube.utils.exceptions import DatasetCreationError -from eodag_cube.utils.metadata import build_bands, build_cube_metadata, merge_bands +from eodag_cube.utils.metadata import build_bands, build_stac_metadata, merge_bands from eodag_cube.utils.xarray import try_open_dataset logger = logging.getLogger("eodag-cube.api.product") @@ -363,15 +363,14 @@ def augment_from_xarray( if not self.assets: try: xd = self.to_xarray(roles=roles) + # single ds in XarrayDict + ds = next(iter(xd.values())) except Exception: return self - dimensions, variables, proj_info = build_cube_metadata(xd) - self.properties["cube:dimensions"] = dimensions - self.properties["cube:variables"] = variables - self.properties["bands"] = build_bands(xd) - for key, value in proj_info.items(): - self.properties[key] = value + # update product properties + self.properties |= build_stac_metadata(ds) + self.properties["bands"] = build_bands(ds) else: # have roles been set in assets ? @@ -391,16 +390,15 @@ def augment_from_xarray( except Exception: continue - dimensions, variables, proj_info = build_cube_metadata(xd) - asset["cube:dimensions"] = dimensions - asset["cube:variables"] = variables - for key, value in proj_info.items(): - asset[key] = value + # single ds in XarrayDict + ds = next(iter(xd.values())) + # update asset metadata + asset |= build_stac_metadata(ds) has_band_data = any("band_data" in ds.data_vars for ds in xd.values()) if has_band_data: - generated_bands = build_bands(xd) + generated_bands = build_bands(ds) if "bands" in asset: asset["bands"] = merge_bands(asset["bands"], generated_bands) else: diff --git a/eodag_cube/utils/metadata.py b/eodag_cube/utils/metadata.py index fad0f31..50cc4d8 100644 --- a/eodag_cube/utils/metadata.py +++ b/eodag_cube/utils/metadata.py @@ -23,8 +23,6 @@ import numpy as np from xarray import DataArray, Dataset -from eodag_cube.types import XarrayDict - def extract_projection_info(ds: Dataset) -> dict[str, Any]: """ @@ -126,102 +124,91 @@ def set_variables(ds: Dataset) -> dict[str, Any]: return variables -def build_cube_metadata(ds_dict: XarrayDict) -> tuple[dict, dict, dict]: +def build_stac_metadata(ds: Dataset) -> dict[str, Any]: """ - Build datacube and projection metadata from a dict of :class:`xarray.Dataset`. + Read STAC metadata from an xarray dataset. - :param ds_dict: input xarray dict - :return: tuple of 3 dicts for cube dimensions, cube variables and projection info + :param ds: input xarray dataset + :return: STAC metadata dictionary """ dimensions: dict[str, dict] = {} variables: dict[str, dict] = {} - for ds in ds_dict.values(): - proj_info: dict[str, Any] = extract_projection_info(ds) - - # Dimensions - for dim_name in ds.sizes.keys(): - dim_name_str = str(dim_name) - - # Type - dim_type = ( - "spatial" - if dim_name_str in ("x", "y", "lon", "lat") - else "temporal" - if dim_name_str == "time" - else "other" - ) - - dim_entry: dict[str, Any] = {"type": dim_type} - - if dim_type == "spatial": - # Axis - if dim_name_str in ("x", "lon"): - dim_entry["axis"] = "x" - elif dim_name_str in ("y", "lat"): - dim_entry["axis"] = "y" - elif dim_name_str == "z": - dim_entry["axis"] = "z" - - proj_code = proj_info.get("proj:code", "EPSG:4326") - try: - dim_entry["reference_system"] = int(proj_code.split(":")[-1]) - except ValueError: - pass - - if dim_name_str in ds.coords: - values = ds[dim_name_str].values - if values.ndim == 1: - if values.size <= 10: - dim_entry["values"] = values.tolist() - else: - dim_entry["extent"] = ( - [float(values.min()), float(values.max())] - if np.issubdtype(values.dtype, np.number) - else [str(values.min()), str(values.max())] - ) - diffs = np.diff(values) - if np.allclose(diffs, diffs[0]): - dim_entry["step"] = ( - float(diffs[0]) if np.issubdtype(values.dtype, np.number) else str(diffs[0]) - ) + proj_info: dict[str, Any] = extract_projection_info(ds) + + # Dimensions + for dim_name in ds.sizes.keys(): + dim_name_str = str(dim_name) + + # Type + dim_type = ( + "spatial" if dim_name_str in ("x", "y", "lon", "lat") else "temporal" if dim_name_str == "time" else "other" + ) + + dim_entry: dict[str, Any] = {"type": dim_type} + + if dim_type == "spatial": + # Axis + if dim_name_str in ("x", "lon"): + dim_entry["axis"] = "x" + elif dim_name_str in ("y", "lat"): + dim_entry["axis"] = "y" + elif dim_name_str == "z": + dim_entry["axis"] = "z" + + proj_code = proj_info.get("proj:code", "EPSG:4326") + try: + dim_entry["reference_system"] = int(proj_code.split(":")[-1]) + except ValueError: + pass + + if dim_name_str in ds.coords: + values = ds[dim_name_str].values + if values.ndim == 1: + if values.size <= 10: + dim_entry["values"] = values.tolist() else: - dim_entry["extent"] = [float(np.nanmin(values)), float(np.nanmax(values))] + dim_entry["extent"] = ( + [float(values.min()), float(values.max())] + if np.issubdtype(values.dtype, np.number) + else [str(values.min()), str(values.max())] + ) + diffs = np.diff(values) + if np.allclose(diffs, diffs[0]): + dim_entry["step"] = float(diffs[0]) if np.issubdtype(values.dtype, np.number) else str(diffs[0]) + else: + dim_entry["extent"] = [float(np.nanmin(values)), float(np.nanmax(values))] - dimensions[dim_name_str] = dim_entry + dimensions[dim_name_str] = dim_entry - # Variables - var_ds = set_variables(ds) - variables.update(var_ds) + # Variables + var_ds = set_variables(ds) + variables.update(var_ds) - return dimensions, variables, proj_info + return {"cube:dimensions": dimensions, "cube:variables": variables, **proj_info} -def build_bands(xd: XarrayDict) -> list[dict]: +def build_bands(ds: Dataset) -> list[dict]: """ - Build STAC bands metadata from xarray datasets. + Build STAC bands metadata from xarray dataset. If names are not available, use generic band names. - :param xd: input xarray dict + :param ds: input xarray dataset :return: list of bands metadata """ band_count = 0 - for ds in xd.values(): - for var in ds.data_vars.values(): - for dim in var.dims: - if str(dim).lower() in ("band", "bands"): - band_count = ds.sizes[dim] - break - if band_count: + for var in ds.data_vars.values(): + for dim in var.dims: + if str(dim).lower() in ("band", "bands"): + band_count = ds.sizes[dim] break - if band_count: break if band_count == 0: - band_count = len(next(iter(xd.values())).data_vars) + band_count = len(ds.data_vars) return [{"name": f"band{i + 1}"} for i in range(band_count)] diff --git a/tests/units/test_utils.py b/tests/units/test_utils.py index f94fe5e..0756385 100644 --- a/tests/units/test_utils.py +++ b/tests/units/test_utils.py @@ -25,7 +25,6 @@ import xarray as xr from fsspec.core import OpenFile -from eodag_cube.types import XarrayDict from eodag_cube.utils import metadata from tests.context import ( DatasetCreationError, @@ -415,47 +414,66 @@ def __init__(self): var_without_rio = MockVar() self.assertIsNone(metadata._get_nodata_value(var_without_rio)) - def test_build_cube_metadata(self): + def test_build_stac_metadata_1d(self): """Test cube dimensions, variables and projection metadata""" - dims, vars_, proj_info = metadata.build_cube_metadata(self.xd_dict) + stac_mtd = metadata.build_stac_metadata(self.ds_1d) - self.assertIn("x", dims) - self.assertIn("y", dims) - self.assertIn("time", dims) + self.assertIn("x", stac_mtd["cube:dimensions"]) + self.assertIn("time", stac_mtd["cube:dimensions"]) - self.assertEqual(dims["x"]["type"], "spatial") - self.assertEqual(dims["x"]["axis"], "x") + self.assertEqual(stac_mtd["cube:dimensions"]["x"]["type"], "spatial") + self.assertEqual(stac_mtd["cube:dimensions"]["x"]["axis"], "x") - self.assertEqual(dims["y"]["type"], "spatial") - self.assertEqual(dims["y"]["axis"], "y") + self.assertEqual(stac_mtd["cube:dimensions"]["time"]["type"], "temporal") - self.assertEqual(dims["time"]["type"], "temporal") + if "extent" in stac_mtd["cube:dimensions"]["x"]: + self.assertEqual(len(stac_mtd["cube:dimensions"]["x"]["extent"]), 2) - if "extent" in dims["x"]: - self.assertEqual(len(dims["x"]["extent"]), 2) + if "step" in stac_mtd["cube:dimensions"]["x"]: + self.assertIsInstance(stac_mtd["cube:dimensions"]["x"]["step"], (int, float)) + self.assertIn("band_data", stac_mtd["cube:variables"]) + self.assertEqual(stac_mtd["cube:variables"]["band_data"]["type"], "data") - if "step" in dims["x"]: - self.assertIsInstance(dims["x"]["step"], (int, float)) + self.assertIn("latitude", stac_mtd["cube:variables"]) + self.assertIn("longitude", stac_mtd["cube:variables"]) - self.assertIn("band_data", vars_) - self.assertEqual(vars_["band_data"]["type"], "data") + self.assertEqual(stac_mtd["cube:variables"]["latitude"]["type"], "auxiliary") + self.assertEqual(stac_mtd["cube:variables"]["longitude"]["type"], "auxiliary") - self.assertIn("latitude", vars_) - self.assertIn("longitude", vars_) + self.assertEqual(stac_mtd["cube:variables"]["latitude"]["description"], "Latitude") + self.assertEqual(stac_mtd["cube:variables"]["longitude"]["description"], "Longitude") - self.assertEqual(vars_["latitude"]["type"], "auxiliary") - self.assertEqual(vars_["longitude"]["type"], "auxiliary") + self.assertIn("dimensions", stac_mtd["cube:variables"]["latitude"]) + self.assertIsInstance(stac_mtd["cube:variables"]["latitude"]["dimensions"], list) + self.assertEqual(stac_mtd["proj:code"], "EPSG:4326") + self.assertNotIn("proj:shape", stac_mtd) - self.assertEqual(vars_["latitude"]["description"], "Latitude") - self.assertEqual(vars_["longitude"]["description"], "Longitude") + def test_build_stac_metadata_2d(self): + """Test cube dimensions, variables and projection metadata""" - self.assertIn("dimensions", vars_["latitude"]) - self.assertIsInstance(vars_["latitude"]["dimensions"], list) + stac_mtd = metadata.build_stac_metadata(self.ds_2d) - self.assertEqual(proj_info["proj:code"], "EPSG:4326") - self.assertIn("proj:shape", proj_info) - self.assertEqual(proj_info["proj:shape"], [2, 2]) + self.assertIn("x", stac_mtd["cube:dimensions"]) + self.assertIn("y", stac_mtd["cube:dimensions"]) + + self.assertEqual(stac_mtd["cube:dimensions"]["x"]["type"], "spatial") + self.assertEqual(stac_mtd["cube:dimensions"]["x"]["axis"], "x") + + self.assertEqual(stac_mtd["cube:dimensions"]["y"]["type"], "spatial") + self.assertEqual(stac_mtd["cube:dimensions"]["y"]["axis"], "y") + + if "extent" in stac_mtd["cube:dimensions"]["x"]: + self.assertEqual(len(stac_mtd["cube:dimensions"]["x"]["extent"]), 2) + + if "step" in stac_mtd["cube:dimensions"]["x"]: + self.assertIsInstance(stac_mtd["cube:dimensions"]["x"]["step"], (int, float)) + self.assertIn("band_data", stac_mtd["cube:variables"]) + self.assertEqual(stac_mtd["cube:variables"]["band_data"]["type"], "data") + + self.assertEqual(stac_mtd["proj:code"], "EPSG:4326") + self.assertIn("proj:shape", stac_mtd) + self.assertEqual(stac_mtd["proj:shape"], [2, 2]) def test_aux_variable_not_added_if_dimension(self): """latitude/longitude must not be added as auxiliary if they are dimensions""" @@ -465,16 +483,14 @@ def test_aux_variable_not_added_if_dimension(self): coords={"latitude": [10, 20, 30]}, ) - xd = XarrayDict({"test": ds}) - - dims, vars_, _ = metadata.build_cube_metadata(xd) + stac_mtd = metadata.build_stac_metadata(ds) - self.assertIn("latitude", dims) - self.assertNotIn("latitude", vars_) + self.assertIn("latitude", stac_mtd["cube:dimensions"]) + self.assertNotIn("latitude", stac_mtd["cube:variables"]) def test_build_bands(self): """Test bands generation""" - bands = metadata.build_bands({"ds": self.ds_1d}) + bands = metadata.build_bands(self.ds_1d) self.assertEqual(len(bands), 4) self.assertTrue(all("name" in b for b in bands)) @@ -490,7 +506,7 @@ def test_merge_bands(self): def test_build_bands_no_band_dim(self): """If no 'band' dimension, should fallback to number of data_vars""" ds_simple = xr.Dataset({"a": (("x",), [1, 2]), "b": (("x",), [3, 4])}) - bands = metadata.build_bands({"simple": ds_simple}) + bands = metadata.build_bands(ds_simple) self.assertEqual(len(bands), 2) self.assertEqual(bands[0]["name"], "band1") self.assertEqual(bands[1]["name"], "band2")