Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 11 additions & 13 deletions eodag_cube/api/product/_product.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@
from eodag_cube.api.product._assets import AssetsDict
from eodag_cube.types import XarrayDict
from eodag_cube.utils.exceptions import DatasetCreationError
from eodag_cube.utils.metadata import build_bands, build_cube_metadata, merge_bands
from eodag_cube.utils.metadata import build_bands, build_stac_metadata, merge_bands
from eodag_cube.utils.xarray import try_open_dataset

logger = logging.getLogger("eodag-cube.api.product")
Expand Down Expand Up @@ -363,15 +363,14 @@ def augment_from_xarray(
if not self.assets:
try:
xd = self.to_xarray(roles=roles)
# single ds in XarrayDict
ds = next(iter(xd.values()))
except Exception:
return self

dimensions, variables, proj_info = build_cube_metadata(xd)
self.properties["cube:dimensions"] = dimensions
self.properties["cube:variables"] = variables
self.properties["bands"] = build_bands(xd)
for key, value in proj_info.items():
self.properties[key] = value
# update product properties
self.properties |= build_stac_metadata(ds)
self.properties["bands"] = build_bands(ds)

else:
# have roles been set in assets ?
Expand All @@ -391,16 +390,15 @@ def augment_from_xarray(
except Exception:
continue

dimensions, variables, proj_info = build_cube_metadata(xd)
asset["cube:dimensions"] = dimensions
asset["cube:variables"] = variables
for key, value in proj_info.items():
asset[key] = value
# single ds in XarrayDict
ds = next(iter(xd.values()))
# update asset metadata
asset |= build_stac_metadata(ds)

has_band_data = any("band_data" in ds.data_vars for ds in xd.values())

if has_band_data:
generated_bands = build_bands(xd)
generated_bands = build_bands(ds)
if "bands" in asset:
asset["bands"] = merge_bands(asset["bands"], generated_bands)
else:
Expand Down
133 changes: 60 additions & 73 deletions eodag_cube/utils/metadata.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,8 +23,6 @@
import numpy as np
from xarray import DataArray, Dataset

from eodag_cube.types import XarrayDict


def extract_projection_info(ds: Dataset) -> dict[str, Any]:
"""
Expand Down Expand Up @@ -126,102 +124,91 @@ def set_variables(ds: Dataset) -> dict[str, Any]:
return variables


def build_cube_metadata(ds_dict: XarrayDict) -> tuple[dict, dict, dict]:
def build_stac_metadata(ds: Dataset) -> dict[str, Any]:
"""
Build datacube and projection metadata from a dict of :class:`xarray.Dataset`.
Read STAC metadata from an xarray dataset.

:param ds_dict: input xarray dict
:return: tuple of 3 dicts for cube dimensions, cube variables and projection info
:param ds: input xarray dataset
:return: STAC metadata dictionary
"""
dimensions: dict[str, dict] = {}
variables: dict[str, dict] = {}

for ds in ds_dict.values():
proj_info: dict[str, Any] = extract_projection_info(ds)

# Dimensions
for dim_name in ds.sizes.keys():
dim_name_str = str(dim_name)

# Type
dim_type = (
"spatial"
if dim_name_str in ("x", "y", "lon", "lat")
else "temporal"
if dim_name_str == "time"
else "other"
)

dim_entry: dict[str, Any] = {"type": dim_type}

if dim_type == "spatial":
# Axis
if dim_name_str in ("x", "lon"):
dim_entry["axis"] = "x"
elif dim_name_str in ("y", "lat"):
dim_entry["axis"] = "y"
elif dim_name_str == "z":
dim_entry["axis"] = "z"

proj_code = proj_info.get("proj:code", "EPSG:4326")
try:
dim_entry["reference_system"] = int(proj_code.split(":")[-1])
except ValueError:
pass

if dim_name_str in ds.coords:
values = ds[dim_name_str].values
if values.ndim == 1:
if values.size <= 10:
dim_entry["values"] = values.tolist()
else:
dim_entry["extent"] = (
[float(values.min()), float(values.max())]
if np.issubdtype(values.dtype, np.number)
else [str(values.min()), str(values.max())]
)
diffs = np.diff(values)
if np.allclose(diffs, diffs[0]):
dim_entry["step"] = (
float(diffs[0]) if np.issubdtype(values.dtype, np.number) else str(diffs[0])
)
proj_info: dict[str, Any] = extract_projection_info(ds)

# Dimensions
for dim_name in ds.sizes.keys():
dim_name_str = str(dim_name)

# Type
dim_type = (
"spatial" if dim_name_str in ("x", "y", "lon", "lat") else "temporal" if dim_name_str == "time" else "other"
)

dim_entry: dict[str, Any] = {"type": dim_type}

if dim_type == "spatial":
# Axis
if dim_name_str in ("x", "lon"):
dim_entry["axis"] = "x"
elif dim_name_str in ("y", "lat"):
dim_entry["axis"] = "y"
elif dim_name_str == "z":
dim_entry["axis"] = "z"

proj_code = proj_info.get("proj:code", "EPSG:4326")
try:
dim_entry["reference_system"] = int(proj_code.split(":")[-1])
except ValueError:
pass

if dim_name_str in ds.coords:
values = ds[dim_name_str].values
if values.ndim == 1:
if values.size <= 10:
dim_entry["values"] = values.tolist()
else:
dim_entry["extent"] = [float(np.nanmin(values)), float(np.nanmax(values))]
dim_entry["extent"] = (
[float(values.min()), float(values.max())]
if np.issubdtype(values.dtype, np.number)
else [str(values.min()), str(values.max())]
)
diffs = np.diff(values)
if np.allclose(diffs, diffs[0]):
dim_entry["step"] = float(diffs[0]) if np.issubdtype(values.dtype, np.number) else str(diffs[0])
else:
dim_entry["extent"] = [float(np.nanmin(values)), float(np.nanmax(values))]

dimensions[dim_name_str] = dim_entry
dimensions[dim_name_str] = dim_entry

# Variables
var_ds = set_variables(ds)
variables.update(var_ds)
# Variables
var_ds = set_variables(ds)
variables.update(var_ds)

return dimensions, variables, proj_info
return {"cube:dimensions": dimensions, "cube:variables": variables, **proj_info}


def build_bands(xd: XarrayDict) -> list[dict]:
def build_bands(ds: Dataset) -> list[dict]:
"""
Build STAC bands metadata from xarray datasets.
Build STAC bands metadata from xarray dataset.

If names are not available, use generic band names.

:param xd: input xarray dict
:param ds: input xarray dataset
:return: list of bands metadata
"""
band_count = 0

for ds in xd.values():
for var in ds.data_vars.values():
for dim in var.dims:
if str(dim).lower() in ("band", "bands"):
band_count = ds.sizes[dim]
break
if band_count:
for var in ds.data_vars.values():
for dim in var.dims:
if str(dim).lower() in ("band", "bands"):
band_count = ds.sizes[dim]
break

if band_count:
break

if band_count == 0:
band_count = len(next(iter(xd.values())).data_vars)
band_count = len(ds.data_vars)

return [{"name": f"band{i + 1}"} for i in range(band_count)]

Expand Down
86 changes: 51 additions & 35 deletions tests/units/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,6 @@
import xarray as xr
from fsspec.core import OpenFile

from eodag_cube.types import XarrayDict
from eodag_cube.utils import metadata
from tests.context import (
DatasetCreationError,
Expand Down Expand Up @@ -415,47 +414,66 @@ def __init__(self):
var_without_rio = MockVar()
self.assertIsNone(metadata._get_nodata_value(var_without_rio))

def test_build_cube_metadata(self):
def test_build_stac_metadata_1d(self):
"""Test cube dimensions, variables and projection metadata"""

dims, vars_, proj_info = metadata.build_cube_metadata(self.xd_dict)
stac_mtd = metadata.build_stac_metadata(self.ds_1d)

self.assertIn("x", dims)
self.assertIn("y", dims)
self.assertIn("time", dims)
self.assertIn("x", stac_mtd["cube:dimensions"])
self.assertIn("time", stac_mtd["cube:dimensions"])

self.assertEqual(dims["x"]["type"], "spatial")
self.assertEqual(dims["x"]["axis"], "x")
self.assertEqual(stac_mtd["cube:dimensions"]["x"]["type"], "spatial")
self.assertEqual(stac_mtd["cube:dimensions"]["x"]["axis"], "x")

self.assertEqual(dims["y"]["type"], "spatial")
self.assertEqual(dims["y"]["axis"], "y")
self.assertEqual(stac_mtd["cube:dimensions"]["time"]["type"], "temporal")

self.assertEqual(dims["time"]["type"], "temporal")
if "extent" in stac_mtd["cube:dimensions"]["x"]:
self.assertEqual(len(stac_mtd["cube:dimensions"]["x"]["extent"]), 2)

if "extent" in dims["x"]:
self.assertEqual(len(dims["x"]["extent"]), 2)
if "step" in stac_mtd["cube:dimensions"]["x"]:
self.assertIsInstance(stac_mtd["cube:dimensions"]["x"]["step"], (int, float))
self.assertIn("band_data", stac_mtd["cube:variables"])
self.assertEqual(stac_mtd["cube:variables"]["band_data"]["type"], "data")

if "step" in dims["x"]:
self.assertIsInstance(dims["x"]["step"], (int, float))
self.assertIn("latitude", stac_mtd["cube:variables"])
self.assertIn("longitude", stac_mtd["cube:variables"])

self.assertIn("band_data", vars_)
self.assertEqual(vars_["band_data"]["type"], "data")
self.assertEqual(stac_mtd["cube:variables"]["latitude"]["type"], "auxiliary")
self.assertEqual(stac_mtd["cube:variables"]["longitude"]["type"], "auxiliary")

self.assertIn("latitude", vars_)
self.assertIn("longitude", vars_)
self.assertEqual(stac_mtd["cube:variables"]["latitude"]["description"], "Latitude")
self.assertEqual(stac_mtd["cube:variables"]["longitude"]["description"], "Longitude")

self.assertEqual(vars_["latitude"]["type"], "auxiliary")
self.assertEqual(vars_["longitude"]["type"], "auxiliary")
self.assertIn("dimensions", stac_mtd["cube:variables"]["latitude"])
self.assertIsInstance(stac_mtd["cube:variables"]["latitude"]["dimensions"], list)
self.assertEqual(stac_mtd["proj:code"], "EPSG:4326")
self.assertNotIn("proj:shape", stac_mtd)

self.assertEqual(vars_["latitude"]["description"], "Latitude")
self.assertEqual(vars_["longitude"]["description"], "Longitude")
def test_build_stac_metadata_2d(self):
"""Test cube dimensions, variables and projection metadata"""

self.assertIn("dimensions", vars_["latitude"])
self.assertIsInstance(vars_["latitude"]["dimensions"], list)
stac_mtd = metadata.build_stac_metadata(self.ds_2d)

self.assertEqual(proj_info["proj:code"], "EPSG:4326")
self.assertIn("proj:shape", proj_info)
self.assertEqual(proj_info["proj:shape"], [2, 2])
self.assertIn("x", stac_mtd["cube:dimensions"])
self.assertIn("y", stac_mtd["cube:dimensions"])

self.assertEqual(stac_mtd["cube:dimensions"]["x"]["type"], "spatial")
self.assertEqual(stac_mtd["cube:dimensions"]["x"]["axis"], "x")

self.assertEqual(stac_mtd["cube:dimensions"]["y"]["type"], "spatial")
self.assertEqual(stac_mtd["cube:dimensions"]["y"]["axis"], "y")

if "extent" in stac_mtd["cube:dimensions"]["x"]:
self.assertEqual(len(stac_mtd["cube:dimensions"]["x"]["extent"]), 2)

if "step" in stac_mtd["cube:dimensions"]["x"]:
self.assertIsInstance(stac_mtd["cube:dimensions"]["x"]["step"], (int, float))
self.assertIn("band_data", stac_mtd["cube:variables"])
self.assertEqual(stac_mtd["cube:variables"]["band_data"]["type"], "data")

self.assertEqual(stac_mtd["proj:code"], "EPSG:4326")
self.assertIn("proj:shape", stac_mtd)
self.assertEqual(stac_mtd["proj:shape"], [2, 2])

def test_aux_variable_not_added_if_dimension(self):
"""latitude/longitude must not be added as auxiliary if they are dimensions"""
Expand All @@ -465,16 +483,14 @@ def test_aux_variable_not_added_if_dimension(self):
coords={"latitude": [10, 20, 30]},
)

xd = XarrayDict({"test": ds})

dims, vars_, _ = metadata.build_cube_metadata(xd)
stac_mtd = metadata.build_stac_metadata(ds)

self.assertIn("latitude", dims)
self.assertNotIn("latitude", vars_)
self.assertIn("latitude", stac_mtd["cube:dimensions"])
self.assertNotIn("latitude", stac_mtd["cube:variables"])

def test_build_bands(self):
"""Test bands generation"""
bands = metadata.build_bands({"ds": self.ds_1d})
bands = metadata.build_bands(self.ds_1d)
self.assertEqual(len(bands), 4)
self.assertTrue(all("name" in b for b in bands))

Expand All @@ -490,7 +506,7 @@ def test_merge_bands(self):
def test_build_bands_no_band_dim(self):
"""If no 'band' dimension, should fallback to number of data_vars"""
ds_simple = xr.Dataset({"a": (("x",), [1, 2]), "b": (("x",), [3, 4])})
bands = metadata.build_bands({"simple": ds_simple})
bands = metadata.build_bands(ds_simple)
self.assertEqual(len(bands), 2)
self.assertEqual(bands[0]["name"], "band1")
self.assertEqual(bands[1]["name"], "band2")
Loading