198 changes: 103 additions & 95 deletions openeo_processes_dask/process_implementations/cubes/load.py
@@ -94,23 +94,14 @@ def load_stac(
) -> RasterCube:
stac_type = _validate_stac(url)

# TODO: load_stac should have a parameter to enable scale and offset?

# If the user provides the bands as a single string, wrap it in a list:
if isinstance(bands, str):
bands = [bands]

if stac_type == "COLLECTION":
# If query parameters are passed, try to get the parent Catalog if possible/exists, to use the /search endpoint
if spatial_extent or temporal_extent or bands or properties:
catalog_url, collection_id = _search_for_parent_catalog(url)

# Check if we are connecting to Microsoft Planetary Computer, where we need to sign the connection
modifier = pc.sign_inplace if "planetarycomputer" in catalog_url else None

catalog = pystac_client.Client.open(catalog_url, modifier=modifier)

query_params = {"collections": [collection_id]}

if spatial_extent is not None:
@@ -132,117 +123,147 @@ def load_stac(
raise Exception(f"Unable to parse the provided spatial extent: {e}")

if temporal_extent is not None:
start_date = (
str(temporal_extent[0].to_numpy())
if temporal_extent[0] is not None
else None
)
end_date = (
str(temporal_extent[1].to_numpy())
if temporal_extent[1] is not None
else None
)
query_params["datetime"] = [start_date, end_date]

if properties is not None:
query_params["query"] = properties

items = catalog.search(**query_params).item_collection()

else:
# Load the whole collection without filters
raise Exception(
f"No parameters for filtering provided. Loading the whole STAC Collection is not supported yet."
"No parameters for filtering provided. Loading the whole STAC Collection is not supported yet."
)
elif stac_type == "ITEM":
stac_api = pystac_client.stac_api_io.StacApiIO()
stac_dict = json.loads(stac_api.read_text(url))
items = [stac_api.stac_object_from_dict(stac_dict)]
else:
raise Exception(
f"The provided URL is a STAC {stac_type}, which is not yet supported. Please provide a valid URL to a STAC Collection or Item."
f"The provided URL is a STAC {stac_type}, which is not yet supported."
)
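# The two branches above cover both entry points: a Collection URL is resolved
# to its parent Catalog and queried via the /search endpoint, while an Item URL
# is fetched and parsed directly. A minimal sketch of the search path with
# pystac-client (endpoint, collection id, bbox, and dates are placeholders):
#
# import pystac_client
# catalog = pystac_client.Client.open("https://planetarycomputer.microsoft.com/api/stac/v1")
# items = catalog.search(
#     collections=["sentinel-2-l2a"],
#     bbox=[11.1, 46.0, 11.3, 46.2],
#     datetime=["2022-06-01", "2022-06-30"],
# ).item_collection()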

available_assets = {tuple(i.assets.keys()) for i in items}
if len(available_assets) > 1:
raise OpenEOException(
f"The resulting STAC Items contain different sets of assets: {available_assets}. We can't load them at the same time."
)

available_assets = [x for t in available_assets for x in t]

# Initialize asset metadata tracking
asset_scale_offset = {}
zarr_assets = False
use_xarray_open_kwargs = False
use_xarray_storage_options = False

for asset in available_assets:
asset_dict = items[0].assets[asset].to_dict()
asset_scale = 1
asset_offset = 0
asset_nodata = None
asset_dtype = None
asset_type = None

if "raster:bands" in asset_dict:
asset_scale = asset_dict["raster:bands"][0].get("scale", 1)
asset_offset = asset_dict["raster:bands"][0].get("offset", 0)
asset_nodata = asset_dict["raster:bands"][0].get("nodata", None)
asset_dtype = asset_dict["raster:bands"][0].get("data_type", None)

if "type" in asset_dict:
asset_type = asset_dict["type"]
if asset_type == "application/vnd+zarr":
zarr_assets = True

if "xarray:open_kwargs" in asset_dict:
use_xarray_open_kwargs = True
if "xarray:storage_options" in asset_dict:
use_xarray_storage_options = True

asset_scale_offset[asset] = {
"scale": asset_scale,
"offset": asset_offset,
"nodata": asset_nodata,
"data_type": asset_dtype,
"type": asset_type,
}
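# Scale/offset/nodata come from the `raster` STAC extension when present; note
# that only the first entry of `raster:bands` is inspected, so per-band scaling
# inside a multi-band asset is not captured. A standalone sketch of the same
# lookup against an example asset dict (values are illustrative):
#
# asset_dict = {
#     "href": "https://example.com/B04.tif",
#     "type": "image/tiff; application=geotiff",
#     "raster:bands": [{"scale": 0.0001, "offset": 0, "nodata": 0, "data_type": "uint16"}],
# }
# band_meta = asset_dict.get("raster:bands", [{}])[0]
# scale = band_meta.get("scale", 1)       # -> 0.0001
# offset = band_meta.get("offset", 0)     # -> 0
# nodata = band_meta.get("nodata")        # -> 0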

item_dict = items[0].to_dict() if items else {}
available_variables = []
if "properties" in item_dict and "cube:variables" in item_dict["properties"]:
available_variables = list(item_dict["properties"]["cube:variables"].keys())

if bands is not None:
if zarr_assets and available_variables:
missing_bands = set(bands) - set(available_variables)
if missing_bands:
raise OpenEOException(
f"The following requested bands were not found: {missing_bands}. "
f"Available bands are: {available_variables}"
)
else:
if len(set(available_assets) & set(bands)) == 0:
raise OpenEOException(
f"The provided bands: {bands} can't be found in the STAC assets: {available_assets}"
)

reference_system = None
if "properties" in item_dict and "cube:dimensions" in item_dict["properties"]:
for d in item_dict["properties"]["cube:dimensions"]:
if "reference_system" in item_dict["properties"]["cube:dimensions"][d]:
reference_system = item_dict["properties"]["cube:dimensions"][d][
"reference_system"
]
break
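# The lookup above walks the datacube extension's `cube:dimensions` and keeps
# the first `reference_system` it finds (typically declared on the spatial x/y
# dimensions as an EPSG code). For example, with properties like these the
# loop would yield 32632 (values are illustrative):
#
# properties = {
#     "cube:dimensions": {
#         "x": {"type": "spatial", "axis": "x", "reference_system": 32632},
#         "y": {"type": "spatial", "axis": "y", "reference_system": 32632},
#         "t": {"type": "temporal"},
#     }
# }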

if zarr_assets:
datasets = []
for item in items:
for asset in item.assets.values():
kwargs = (
asset.extra_fields.get("xarray:open_kwargs", {})
if use_xarray_open_kwargs
else {"engine": "zarr", "consolidated": True, "chunks": {}}
)

if use_xarray_storage_options:
storage_opts = asset.extra_fields.get("xarray:storage_options", {})
s3_endpoint_url = storage_opts.get("client_kwargs", {}).get(
"endpoint_url"
)
if s3_endpoint_url is not None:
kwargs["storage_options"] = {
"client_kwargs": {"endpoint_url": s3_endpoint_url}
}

ds = xr.open_dataset(asset.href, **kwargs)
if bands is not None and available_variables:
vars_to_load = [b for b in bands if b in ds.data_vars]
ds = ds[vars_to_load]
datasets.append(ds)
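# When `xarray:storage_options` is present, only the S3 `endpoint_url` is
# forwarded; xarray hands `storage_options` to fsspec when opening remote zarr
# stores. A sketch of the resulting call (href and endpoint are placeholders):
#
# import xarray as xr
# ds = xr.open_dataset(
#     "s3://example-bucket/cube.zarr",
#     engine="zarr",
#     chunks={},
#     storage_options={"client_kwargs": {"endpoint_url": "https://s3.example.com"}},
# )
#
# The per-asset datasets are then aligned below with xr.combine_by_coords using
# join="exact", which raises on mismatched coordinates instead of silently reindexing.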

stack = xr.combine_by_coords(
datasets, join="exact", combine_attrs="drop_conflicts"
)
if not stack.rio.crs:
stack.rio.write_crs(reference_system, inplace=True)
# TODO: for now, drop datetime-typed variables. We should probably allow them when they don't conflict with the other data types.
for d in stack.data_vars:
if "datetime" in str(stack[d].dtype):
stack = stack.drop(d)
stack = stack.to_array(dim="bands")
else:
# If at least one band has the nodata field set, we have to apply it at loading time
apply_nodata = True
nodata_set = {asset_scale_offset[k]["nodata"] for k in asset_scale_offset}
dtype_set = {asset_scale_offset[k]["data_type"] for k in asset_scale_offset}
kwargs = {}

if resolution is not None:
kwargs["resolution"] = resolution
if projection is not None:
@@ -253,14 +274,11 @@ def load_stac(
if len(nodata_set) == 1 and list(nodata_set)[0] is None:
apply_nodata = False
if apply_nodata:
# We can pass only a single nodata value for all the assets/variables/bands: https://github.com/opendatacube/odc-stac/issues/147#issuecomment-2005315438
# Therefore, if we load multiple assets with different nodata values, only the first one is used
kwargs["nodata"] = list(nodata_set)[0]
dtype = list(dtype_set)[0]
if dtype is not None:
kwargs["nodata"] = np.dtype(dtype).type(kwargs["nodata"])
# TODO: the dimension names (like "bands") should come from the STAC metadata and not hardcoded
# Note: unfortunately, converting the Dataset to a DataArray casts all the variables to a common data type

if bands is not None:
stack = odc.stac.load(items, bands=bands, chunks={}, **kwargs).to_dataarray(
dim="bands"
@@ -274,16 +292,6 @@ def load_stac(
if temporal_extent is not None and (stac_type == "ITEM" or zarr_assets):
stack = filter_temporal(stack, temporal_extent)
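# In the non-zarr branch the rasters are loaded through odc-stac, which accepts
# a single nodata/resolution/projection for all bands (hence the first-value-wins
# caveat above). A minimal sketch of such a call (endpoint, collection, and band
# names are placeholders):
#
# import odc.stac
# import pystac_client
# catalog = pystac_client.Client.open("https://earth-search.aws.element84.com/v1")
# items = catalog.search(
#     collections=["sentinel-2-l2a"], bbox=[11.1, 46.0, 11.3, 46.2], max_items=4
# ).item_collection()
# cube = odc.stac.load(
#     items,
#     bands=["red", "nir"],
#     chunks={},   # lazy, dask-backed loading
#     nodata=0,    # one value applied to every band
# ).to_dataarray(dim="bands")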

# If at least one band requires a scale and/or offset, the whole DataArray would have to be cast to float, so we don't apply them automatically yet. See https://github.com/Open-EO/openeo-processes/issues/503
# b_dim = stack.openeo.band_dims[0]
# for b in stack[b_dim]:
# scale = asset_scale_offset[b.item(0)]["scale"]
# offset = asset_scale_offset[b.item(0)]["offset"]
# if scale != 1:
# stack.loc[{b_dim: b.item(0)}] *= scale
# if offset != 0:
# stack.loc[{b_dim: b.item(0)}] += offset
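# The block above is kept disabled because applying a float scale to an
# integer-typed band promotes the whole DataArray to float. A toy demonstration
# of that dtype promotion (values are illustrative):
#
# import numpy as np
# import xarray as xr
# band = xr.DataArray(np.array([[1000, 2000]], dtype="uint16"))
# band.dtype             # uint16
# (band * 0.0001).dtype  # float64 -- the integer dtype is lost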

return stack


@@ -84,7 +84,7 @@ def merge_cubes(
positional_parameters=positional_parameters,
named_parameters=named_parameters,
)
merged_cube.rio.write_crs(crs, inplace=True)
else:
# Example 1 & 2
dims_requiring_resolve = [