diff --git a/notebooks/07_harmonic_parameters_zarr.ipynb b/notebooks/07_harmonic_parameters_zarr.ipynb new file mode 100644 index 0000000..c4a78f5 --- /dev/null +++ b/notebooks/07_harmonic_parameters_zarr.ipynb @@ -0,0 +1,431 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "id": "0", + "metadata": {}, + "source": [ + "# Calculate Harmonic Parameters from Sentinel-1 Backscatter Virtual Zarr\n", + "\n", + "In operational flood mapping, the area of interest may be significantly smaller than a single Equi7 Grid tile. Fetching data from (non-Cloud Optimized) GeoTIFFs stored in a STAC catalogue means downloading an entire tile-sized image for each timestamp of interest, even though only a small portion of each image is actually needed. This overhead is tolerable when analysing a single flood event, as only a handful of images are needed, but recalculating harmonic parameters requires at least a full year of observations, which can mean dozens of images.\n", + "\n", + "We can avoid this overhead by using a virtual Zarr dataset created with [kerchunk](https://fsspec.github.io/kerchunk/). Because the virtual Zarr maps the coordinates of the S1 datacube to the byte ranges of the corresponding GeoTIFF chunks, we can download and read only the portions of each image that intersect our area of interest, instead of downloading entire images. This is particularly efficient when combined with Dask, which can load and process those chunks in parallel. To read more about virtual Zarr datasets, see [kerchunk](https://fsspec.github.io/kerchunk/) and [virtualizarr](https://virtualizarr.readthedocs.io/en/latest/)." + ] + }, + { + "cell_type": "markdown", + "id": "1", + "metadata": {}, + "source": [ + "Let's begin by downloading the virtual Zarr reference file. The reference file could also be hosted on object storage, which would remove the need to download it in full." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "2", + "metadata": {}, + "outputs": [], + "source": [ + "zip_url_root = \"https://git.geo.tuwien.ac.at/public_projects/rs/s1-virtualzarr/-/raw/main/\"\n", + "zip_filename = \"SIG0_S1_2022-2023_EU020M.parq.zip\"\n", + "zip_path = \"/tmp/\" + zip_filename\n", + "\n", + "import os\n", + "\n", + "if not os.path.exists(zip_path):\n", + " import urllib.request\n", + "\n", + " urllib.request.urlretrieve(zip_url_root + zip_filename, zip_path)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "3", + "metadata": {}, + "outputs": [], + "source": [ + "import dask\n", + "import numpy as np\n", + "import xarray as xr\n", + "from dask.diagnostics import ProgressBar\n", + "from dask_flood_mapper.harmonic_params import create_harmonic_parameters_zarr\n", + "from dask_flood_mapper.vzarr import open_s1_datacube\n", + "from dask_flood_mapper.vzarr.utils import get_bbox_from_tile_cube\n", + "from matplotlib import pyplot as plt\n", + "\n", + "pbar = ProgressBar()\n", + "pbar.register()" + ] + }, + { + "cell_type": "markdown", + "id": "4", + "metadata": {}, + "source": [ + "As an example, we will select only a small region of interest contained in the Zingst case study." + ] + }, + { + "cell_type": "markdown", + "id": "5", + "metadata": {}, + "source": [ + "We will now open the Sentinel-1 virtual Zarr and select the parts we need."
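As background, this is a minimal sketch of how a kerchunk reference is typically opened with fsspec and xarray, independent of the `open_s1_datacube` helper used below; the `refs.json` path and the `https` remote protocol are placeholders, not the zipped parquet reference downloaded above.

```python
import fsspec
import xarray as xr

# Hypothetical kerchunk reference mapping Zarr keys to GeoTIFF byte ranges.
fs = fsspec.filesystem(
    "reference",
    fo="refs.json",           # placeholder reference file
    remote_protocol="https",  # protocol of the referenced GeoTIFFs
)
ds = xr.open_dataset(
    fs.get_mapper(""),
    engine="zarr",
    backend_kwargs={"consolidated": False},
    chunks={},                # keep the data lazy / Dask-backed
)
```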
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "6", + "metadata": {}, + "outputs": [], + "source": [ + "y_chunks = 600\n", + "s1_ds = open_s1_datacube(\n", + " zip_path,\n", + " chunks={\n", + " \"X\": 15000,\n", + " \"Y\": y_chunks,\n", + " \"polarization\": 1,\n", + " \"obs\": 10,\n", + " \"orbit\": 1,\n", + " \"tile\": 1,\n", + " },\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "7", + "metadata": {}, + "outputs": [], + "source": [ + "from pyproj import Transformer\n", + "\n", + "minlon, maxlon = 12.3, 13.1\n", + "minlat, maxlat = 54.3, 54.6\n", + "((minx, maxx), (miny, maxy)) = Transformer.from_crs(\n", + " \"EPSG:4326\", \"EPSG:27704\", always_xy=True\n", + ").transform([minlon, maxlon], [minlat, maxlat])\n", + "bounding_box = [minx, miny, maxx, maxy]\n", + "bounding_box" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "8", + "metadata": {}, + "outputs": [], + "source": [ + "region_ds = get_bbox_from_tile_cube(\n", + " s1_ds, bounding_box, y_chunk_size=y_chunks\n", + ").sel(polarization=\"VV\", drop=True)\n", + "region_ds" + ] + }, + { + "cell_type": "markdown", + "id": "9", + "metadata": {}, + "source": [ + "We only need parameters for orbits that cross our actual time of interest, so we can filter those out as well." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "10", + "metadata": {}, + "outputs": [], + "source": [ + "region_ds.time.load()\n", + "flood_times = (region_ds.time >= np.datetime64(\"2023-10-11T00:00:00\")) & (\n", + " region_ds.time < np.datetime64(\"2023-10-26T00:00:00\")\n", + ")\n", + "harmpar_times = (region_ds.time >= np.datetime64(\"2022-10-11T00:00:00\")) & (\n", + " region_ds.time < np.datetime64(\"2023-10-11T00:00:00\")\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "11", + "metadata": {}, + "outputs": [], + "source": [ + "flood_ds = region_ds.where(flood_times, drop=True)\n", + "# we can shortcut some calculations later by setting data to NaN for any orbits where we don't have flood data for a tile\n", + "harmpar_ds = region_ds.where(harmpar_times, drop=True).where(\n", + " flood_times.any([\"obs\"])\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "12", + "metadata": {}, + "outputs": [], + "source": [ + "flood_ds" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "13", + "metadata": {}, + "outputs": [], + "source": [ + "harmpar_ds" + ] + }, + { + "cell_type": "markdown", + "id": "14", + "metadata": {}, + "source": [ + "## Calculate Harmonic Parameters\n", + "\n", + "This function fits sine and cosine functions known as harmonic oscillators to each pixel of the Sentinel 1 $\\sigma^0$ datacube. These seasonally varying curves can then be extracted from time series. What is left is the noise or transient events, for example flood events, superimposed on the seasonal trend." + ] + }, + { + "cell_type": "markdown", + "id": "15", + "metadata": {}, + "source": [ + "Because the virtual Zarr dataset is already structured along tile and orbit dimensions, and chunked along the Y dimension, Dask can efficiently load only the data needed for each tile, orbit, and chunk, and process them in parallel." 
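The helper called two cells below builds on xarray's `map_blocks`; here is a self-contained toy illustration of that pattern (not the project code), reducing each chunk along the observation axis independently.

```python
import numpy as np
import xarray as xr

# Toy example: reduce each (Y, X) chunk along "obs" independently, so Dask only
# ever holds one chunk of the cube in memory and processes chunks in parallel.
ds = xr.Dataset(
    {"sig0": (("obs", "Y", "X"), np.random.rand(8, 4, 6).astype("float32"))}
).chunk({"obs": -1, "Y": 2, "X": 6})

template = ds.isel(obs=0, drop=True)  # lazy description of the result, one value per pixel
reduced = ds.map_blocks(lambda block: block.mean("obs"), template=template)
reduced.load()  # blocks are fetched and reduced in parallel
```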
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "16", + "metadata": {}, + "outputs": [], + "source": [ + "hpar_dc = create_harmonic_parameters_zarr(\n", + " harmpar_ds.sel(orbit=flood_ds.orbit)\n", + " .isel(orbit=[1])\n", + " .chunk({\"Y\": y_chunks, \"orbit\": 1, \"obs\": -1, \"tile\": 1}),\n", + " min_nobs=10,\n", + ")\n", + "hpar_dc" + ] + }, + { + "cell_type": "markdown", + "id": "17", + "metadata": {}, + "source": [ + "The result of the last cell is lazy. Finally, we can mosaic the constituent tiles together to get a complete harmonic parameter dataset for our region of interest." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "18", + "metadata": {}, + "outputs": [], + "source": [ + "hpar_mosaic = xr.combine_by_coords(\n", + " [\n", + " hpar_dc.sel(tile=i, drop=True).set_xindex(\"Y\").set_xindex(\"X\")\n", + " for i in hpar_dc.tile.values\n", + " ]\n", + ").sel(X=slice(minx, maxx), Y=slice(maxy, miny))\n", + "hpar_mosaic" + ] + }, + { + "cell_type": "markdown", + "id": "19", + "metadata": {}, + "source": [ + "We probably want to use these harmonic parameters more than once without recalculating them, so we will save them to a Zarr store after computing them." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "20", + "metadata": {}, + "outputs": [], + "source": [ + "with dask.config.set(scheduler=\"threading\"):\n", + " hpar_mosaic.load().chunk(\"auto\").rio.set_spatial_dims(\n", + " x_dim=\"X\", y_dim=\"Y\"\n", + " ).rio.write_crs(\"EPSG:27704\").to_zarr(\"hpar_mosaic.zarr\")" + ] + }, + { + "cell_type": "markdown", + "id": "21", + "metadata": {}, + "source": [ + "Let's map the harmonic parameters as a sanity check:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "22", + "metadata": {}, + "outputs": [], + "source": [ + "hpar_mosaic[\"M0\"].plot.imshow(\n", + " cmap=\"viridis\",\n", + " vmin=-20,\n", + " vmax=0,\n", + " col=\"orbit\",\n", + " col_wrap=3,\n", + ")\n", + "plt.savefig(\"harmonic_parameters_M0.png\", dpi=300)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "23", + "metadata": {}, + "outputs": [], + "source": [ + "plt.close()\n", + "hpar_mosaic.isel(orbit=0)[[\"S1\", \"S2\", \"S3\"]].to_dataarray(\n", + " dim=\"param\"\n", + ").plot.imshow(\n", + " cmap=\"viridis\",\n", + " vmin=-2,\n", + " vmax=2,\n", + " col=\"param\",\n", + ")\n", + "plt.savefig(\"harmonic_parameters_sine.png\", dpi=300)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "24", + "metadata": {}, + "outputs": [], + "source": [ + "from matplotlib import pyplot as plt\n", + "\n", + "hpar_mosaic.isel(orbit=0)[[\"C1\", \"C2\", \"C3\"]].to_dataarray(\n", + " dim=\"param\"\n", + ").plot.imshow(cmap=\"viridis\", vmin=-2, vmax=2, col=\"param\")\n", + "plt.savefig(\"harmonic_parameters_cosine.png\", dpi=300)" + ] + }, + { + "cell_type": "markdown", + "id": "25", + "metadata": {}, + "source": [ + "We can now use the harmonic parameters to generate predicted $\\sigma^0$ values for the flood period." 
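For reference, the model behind these parameters (with $k = 3$ harmonics and day-of-year $t$) is the standard harmonic series, in which the $i$-th harmonic oscillates at $i$ cycles per year:

$$\hat{\sigma}^0(t) = M_0 + \sum_{i=1}^{3}\left[C_i\cos\!\left(\frac{2\pi i t}{365}\right) + S_i\sin\!\left(\frac{2\pi i t}{365}\right)\right]$$

The prediction cell below evaluates this expression with $n = 365$.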
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "26", + "metadata": {}, + "outputs": [], + "source": [ + "t = flood_ds.time.dt.dayofyear\n", + "n = 365\n", + "y = (\n", + " hpar_dc.M0\n", + " + hpar_dc.C1 * np.cos(2 * np.pi * t / n)\n", + " + hpar_dc.S1 * np.sin(2 * np.pi * t / n)\n", + " + hpar_dc.C2 * np.cos(4 * np.pi * t / n)\n", + " + hpar_dc.S2 * np.sin(4 * np.pi * t / n)\n", + " + hpar_dc.C3 * np.cos(6 * np.pi * t / n)\n", + " + hpar_dc.S3 * np.sin(6 * np.pi * t / n)\n", + ")" + ] + }, + { + "cell_type": "markdown", + "id": "27", + "metadata": {}, + "source": [ + "## Fit Harmonic Function to Original Data\n", + "\n", + "Finally, we merge the two datasets and superimpose the fitted harmonic function on the raw sigma nought data." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "28", + "metadata": {}, + "outputs": [], + "source": [ + "import hvplot.xarray\n", + "\n", + "xr.merge([y.rename(\"pred\"), flood_ds.sig0]).squeeze().hvplot(x=\"time\")" + ] + }, + { + "cell_type": "markdown", + "id": "29", + "metadata": {}, + "source": [ + "## Integrate into Flood Mapping Workflow" + ] + }, + { + "cell_type": "markdown", + "id": "30", + "metadata": {}, + "source": [ + "To integrate the calculated harmonic parameters into the flood mapping workflow (see [notebook 3](03_flood_map.ipynb)), we simply read from the Zarr store and reproject from the Equi7 grid to WGS84. Replace the code in notebook 3's section \"Harmonic Parameters\" with the following:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "31", + "metadata": {}, + "outputs": [], + "source": [ + "hpar_dc = xr.open_zarr(\"hpar_mosaic.zarr\", chunks=\"auto\")\n", + "hpar_dc = (\n", + " hpar_dc.rio.set_spatial_dims(x_dim=\"X\", y_dim=\"Y\")\n", + " .rio.reproject(\"EPSG:4326\")\n", + " .rename({\"Y\": \"latitude\", \"X\": \"longitude\"})\n", + ")\n", + "hpar_dc" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "dask-flood-mapper", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.12.9" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/setup.cfg b/setup.cfg index 85d47c0..900c2a9 100644 --- a/setup.cfg +++ b/setup.cfg @@ -72,12 +72,21 @@ app = hvplot geoviews datashader +vzarr = + tqdm + imagecodecs + zarr<3 + fastparquet + aiohttp + tifffile + kerchunk all = %(test)s %(remote)s %(docs)s %(dev)s %(app)s + %(vzarr)s [options.entry_points] console_scripts = diff --git a/src/dask_flood_mapper/harmonic_params.py b/src/dask_flood_mapper/harmonic_params.py index 3e1a466..9ebbea4 100644 --- a/src/dask_flood_mapper/harmonic_params.py +++ b/src/dask_flood_mapper/harmonic_params.py @@ -4,9 +4,41 @@ import numpy as np import xarray as xr +from dask_flood_mapper.processing import order_orbits from numba import njit, prange -from dask_flood_mapper.processing import order_orbits + + +def create_harmonic_parameters_zarr( + sig0_dc: xr.Dataset, + min_nobs: int = 32, + k: int = 3, +): + param_names = model_coords(k) + template = ( + sig0_dc.isel(obs=slice(len(param_names))) + .rename({"obs": "param"}) + .drop_vars("time") + ) + template["param"] = param_names + hpar_dc = xr.map_blocks( + 
reduce_ds_to_harmonic_parameters, + obj=sig0_dc, + kwargs={ + "fit_var_name": "sig0", + "k": k, + "x_var_name": "X", + "y_var_name": "Y", + "min_nobs": min_nobs, + }, + template=template, + ) + hpar_dc = hpar_dc.rename({"sig0": "harmonic_parameters"}) + hpar_dc = hpar_dc.where(hpar_dc.sel(param="NOBS") >= min_nobs).drop_sel( + param="NOBS" + ) + hpar_dc = hpar_dc.harmonic_parameters.to_dataset(dim="param") + + return hpar_dc def create_harmonic_parameters( @@ -65,24 +97,52 @@ def process_harmonic_parameters_datacube( return sig0_dc, hpar_dc, orbit_sig0 +def reduce_ds_to_harmonic_parameters( + ts_ds: xr.Dataset, fit_var_name: str, min_nobs: int = 0, **kwargs +): + extra_dims = [dim for dim in ts_ds.dims if dim not in ts_ds.squeeze().dims] + ts_xr = ts_ds[fit_var_name] + + # if all pixels have too few observations, skip the regression and return all NaNs + too_few_obs_short_circuit = ts_xr.count(dim="obs").max().values < min_nobs + ts_dtimes = ts_ds["time.dayofyear"].squeeze(drop=True).values + if too_few_obs_short_circuit: + ts_xr = ts_xr * np.nan + out_dataarray = reduce_to_harmonic_parameters( + ts_xr.squeeze(drop=True), dtimes=ts_dtimes, **kwargs + ) + out_dataset = xr.Dataset( + { + fit_var_name: out_dataarray.expand_dims(dim=extra_dims).transpose( + *ts_xr.rename({"obs": "param"}).dims + ) + }, + coords={ + dim: ts_ds[dim] + for dim in ts_ds.dims + if (dim in extra_dims or dim in out_dataarray.dims) and dim in ts_ds.coords + }, + ) + return out_dataset + + def reduce_to_harmonic_parameters( ts_xr: xr.DataArray, - dtimes: np.ndarray, x_var_name: str = "x", y_var_name: str = "y", **kwargs, # noqa: ANN003 -) -> xr.DataArray: - """Reduce a time series to harmonic parameters.""" - params_arr = harmonic_regression(ts_xr.data, dtimes=dtimes, **kwargs) +): + params_arr = harmonic_regression(ts_xr.values, **kwargs) k: int = kwargs.get("k", 3) out_dims: list[str] = ["param", y_var_name, x_var_name] + coords_dict = {"param": model_coords(k)} + if x_var_name in ts_xr.coords: + coords_dict[x_var_name] = ts_xr[x_var_name] + if y_var_name in ts_xr.coords: + coords_dict[y_var_name] = ts_xr[y_var_name] return xr.DataArray( data=params_arr, - coords={ - "param": model_coords(k), - x_var_name: ts_xr[x_var_name], - y_var_name: ts_xr[y_var_name], - }, + coords=coords_dict, dims=out_dims, ) @@ -100,6 +160,11 @@ def harmonic_regression( # should be in dayofyear format t = dtimes + # drop t and arr where t is nan for efficiency in regression + valid_time = ~np.isnan(t) + t = t[valid_time] + arr = arr[valid_time, ...] 
# type: ignore + # prepare A-matrix num_dims: int = 3 if len(arr.shape) != num_dims: @@ -115,13 +180,11 @@ def harmonic_regression( # run regression param = np.full((nx + 2, rows, cols), np.nan, dtype=np.float32) - _fast_harmonic_regression( - arr=arr, - a_matrix=a, - k=k, - red=redundancy, - param=param, - ) + arr = arr.astype(np.float32) + if np.all(np.isnan(arr)): + # All NaN array, return NaN params + return param + _fast_harmonic_regression(arr=arr, a_matrix=a, k=k, red=redundancy, param=param) return param diff --git a/src/dask_flood_mapper/vzarr/__init__.py b/src/dask_flood_mapper/vzarr/__init__.py new file mode 100644 index 0000000..064ebb6 --- /dev/null +++ b/src/dask_flood_mapper/vzarr/__init__.py @@ -0,0 +1 @@ +from .read import open_s1_datacube # noqa diff --git a/src/dask_flood_mapper/vzarr/construct.py b/src/dask_flood_mapper/vzarr/construct.py new file mode 100644 index 0000000..61c3599 --- /dev/null +++ b/src/dask_flood_mapper/vzarr/construct.py @@ -0,0 +1,319 @@ +import base64 +import io +import logging +from concurrent.futures import ThreadPoolExecutor, as_completed +from pathlib import Path +from typing import Tuple + +import fsspec +import numpy as np +import pandas as pd +import tifffile +import ujson +import xarray as xr +import zarr +from fsspec.implementations.reference import LazyReferenceMapper +from tqdm import tqdm + + +def generate_equi7_vzarr_from_dataframe( + paths_df: pd.DataFrame, + outfile: Path, + replace_in_paths: Tuple[str, str] | None = None, + metadata: dict = {}, +): + paths_df = paths_df.sort_index() + + time_coords, paths_df["time_idx"] = np.unique(paths_df.index, return_inverse=True) + paths_df[["lat_coord", "lon_coord"]] = paths_df.apply( + (lambda x: (int(x["tile_name"][1:4]), int(x["tile_name"][5:8]))), + axis=1, + result_type="expand", + ) + _, paths_df["y_idx"] = np.unique(paths_df["lat_coord"], return_inverse=True) + _, paths_df["x_idx"] = np.unique(paths_df["lon_coord"], return_inverse=True) + polarization_coords, paths_df["polarization_idx"] = np.unique( + paths_df["band"], return_inverse=True + ) + orbit_coords, paths_df["orbit_idx"] = np.unique( + paths_df["extra_field"], return_inverse=True + ) + tile_coords, paths_df["tile_idx"] = np.unique( + paths_df["tile_name"], return_inverse=True + ) + X_coords, Y_coords = coordinates_from_e7_string(tile_coords, n_pixels=15000) + + # Pre-assign sequential order to all files + file_order_map = {} + file_count = 0 + + # It's useful to group by tile first, assuming it's more likely that we'll be + # reading many orbits from a single tile than many tiles from a single orbit. 
+ grouper = paths_df.groupby( + [ + "tile_idx", + "orbit_idx", + "polarization_idx", + ] + ) + + for i, (group, group_df) in enumerate(grouper): + files = group_df.sort_index()["filepath"].to_list() + files = [str(path) for path in files] + for j, filename in enumerate(files): + file_order_map[filename] = file_count + file_count += 1 + + fs, _ = fsspec.core.url_to_fs(outfile, **({})) + out_refs = LazyReferenceMapper.create( + record_size=300000, + root=outfile, + fs=fs, + categorical_threshold=10, + ) + + coordinates = {} + zarrgroup = zarr.open_group(coordinates) + + zarrgroup.array( + "tile", + data=tile_coords, + dtype="= minx) & (X <= maxx)).any("X") + useful_Y_tiles = ((Y >= miny) & (Y <= maxy)).any("Y") + needed_tiles = useful_X_tiles & useful_Y_tiles + X = ds.X.isel(tile=needed_tiles) + Y = ds.Y.isel(tile=needed_tiles) + needed_X = ((X >= minx) & (X <= maxx)).any("tile") + needed_Y = ((Y >= miny) & (Y <= maxy)).any("tile") + needed_Y.data = extend_true_to_chunk_edges(needed_Y.data, y_chunk_size, axis=0) + + tiles_ds = ds.isel(tile=needed_tiles.load()) + out_ds = tiles_ds.isel(Y=needed_Y).isel(X=needed_X) + valid_images = out_ds.time.notnull() + out_ds = out_ds.isel( + orbit=valid_images.any(["obs", "tile"]).values, + obs=valid_images.any(["orbit", "tile"]).values, + ) + return out_ds + + +def extend_true_to_chunk_edges(arr: np.ndarray[bool], chunk_size: int, axis: int = -1): + """Extend True values in a boolean array to the edges of chunks of given size.""" + + arr = np.asarray(arr) + out = np.zeros_like(arr, dtype=bool) + + arr_moved = np.moveaxis(arr, axis, -1) + out_moved = np.moveaxis(out, axis, -1) + + shape = arr_moved.shape + n = shape[-1] + n_chunks = (n + chunk_size - 1) // chunk_size + + it = np.nditer(arr_moved[..., 0], flags=["multi_index"]) + for _ in it: + idx = it.multi_index + line = arr_moved[idx] # 1D boolean array + + line_out = np.zeros_like(line, dtype=bool) + for i in range(n_chunks): + start = i * chunk_size + end = min((i + 1) * chunk_size, n) + if np.any(line[start:end]): + line_out[start:end] = True + out_moved[idx] = line_out + + return np.moveaxis(out_moved, -1, axis) diff --git a/tests/test_dask_flood_mapper.py b/tests/test_dask_flood_mapper.py index 918b72b..08c67c8 100644 --- a/tests/test_dask_flood_mapper.py +++ b/tests/test_dask_flood_mapper.py @@ -153,12 +153,8 @@ def mock_data_cubes(): @pytest.fixture def mock_data(): """Creates a mock dataset similar to real sig0_dc, using Dask arrays.""" - times = np.array( - ["2022-10-11", "2022-10-11", "2022-10-12"], dtype="datetime64" - ) - data_values = np.array( - [[[1, 2], [3, 4]], [[5, 6], [7, 8]], [[9, 10], [11, 12]]] - ) + times = np.array(["2022-10-11", "2022-10-11", "2022-10-12"], dtype="datetime64") + data_values = np.array([[[1, 2], [3, 4]], [[5, 6], [7, 8]], [[9, 10], [11, 12]]]) dask_data = da.from_array(data_values, chunks=(1, 2, 2)) @@ -261,9 +257,7 @@ def test_bayesian_flood_decision(self, mock_hpar_dataset): evidence = (nf_prob * 0.5) + (f_prob * 0.5) nf_post_prob_expected = (nf_prob * 0.5) / evidence f_post_prob_expected = (f_prob * 0.5) / evidence - decision_expected = np.greater( - f_post_prob_expected, nf_post_prob_expected - ) + decision_expected = np.greater(f_post_prob_expected, nf_post_prob_expected) assert (decision.values == decision_expected).all() @@ -287,13 +281,8 @@ def test_post_processing(self, mock_hpar_dataset): decision = post_processing(mock_hpar_dataset) expected_decision = ( - np.logical_and( - mock_hpar_dataset.MPLIA >= 27, mock_hpar_dataset.MPLIA <= 48 - ) - * ( - 
mock_hpar_dataset.hbsc - > (mock_hpar_dataset.wbsc + 0.5 * 2.754041) - ) + np.logical_and(mock_hpar_dataset.MPLIA >= 27, mock_hpar_dataset.MPLIA <= 48) + * (mock_hpar_dataset.hbsc > (mock_hpar_dataset.wbsc + 0.5 * 2.754041)) * ( ( mock_hpar_dataset.sig0 @@ -304,10 +293,7 @@ def test_post_processing(self, mock_hpar_dataset): < (mock_hpar_dataset.hbsc + 3 * mock_hpar_dataset.STD) ) ) - * ( - mock_hpar_dataset.sig0 - < (mock_hpar_dataset.wbsc + 3 * 2.754041) - ) + * (mock_hpar_dataset.sig0 < (mock_hpar_dataset.wbsc + 3 * 2.754041)) * (mock_hpar_dataset.f_post_prob > 0.8) ) @@ -382,9 +368,7 @@ def test_process_datacube( orbit_sig0 in result.orbit.values ), f"Dataset should contain orbit '{orbit_sig0}' only" - mock_post_process.assert_called_once_with( - mock_data, mock_items_orbits, bands - ) + mock_post_process.assert_called_once_with(mock_data, mock_items_orbits, bands) mock_extract_orbit_names.assert_called_once_with(mock_items_orbits) @@ -416,12 +400,8 @@ def test_remove_speckles(mock_data_cubes): sig0_dc, _, _ = mock_data_cubes # Use only sig0_dc for filtering result = remove_speckles(sig0_dc, window_size=3) - assert ( - result.sizes["y"] == sig0_dc.sizes["y"] - ), "y size should remain the same" - assert ( - result.sizes["x"] == sig0_dc.sizes["x"] - ), "x size should remain the same" + assert result.sizes["y"] == sig0_dc.sizes["y"], "y size should remain the same" + assert result.sizes["x"] == sig0_dc.sizes["x"], "x size should remain the same" assert result.chunks, "Dataset should be persisted (chunked with Dask)" assert any( isinstance(v.data, da.Array) for v in result.data_vars.values() diff --git a/tests/test_harmonic_params.py b/tests/test_harmonic_params.py index 7c467d5..705a6ba 100644 --- a/tests/test_harmonic_params.py +++ b/tests/test_harmonic_params.py @@ -6,6 +6,7 @@ harmonic_regression, model_coords, process_harmonic_parameters_datacube, + reduce_ds_to_harmonic_parameters, reduce_to_harmonic_parameters, ) @@ -15,12 +16,8 @@ def generate_harmonic_timeseries(times, mean, sin_amplitudes, cos_amplitudes): w = 2 * np.pi / 365 result = mean * np.ones_like(times) - for k, (sin_amp, cos_amp) in enumerate( - zip(sin_amplitudes, cos_amplitudes), 1 - ): - result += sin_amp * np.sin(k * w * times) + cos_amp * np.cos( - k * w * times - ) + for k, (sin_amp, cos_amp) in enumerate(zip(sin_amplitudes, cos_amplitudes), 1): + result += sin_amp * np.sin(k * w * times) + cos_amp * np.cos(k * w * times) return result @@ -113,9 +110,7 @@ def synthetic_data(request): orbit = np.array([["A1", "B1"][int(time % 2)] for time in times]) # Generate perfect harmonic signal - ts = generate_harmonic_timeseries( - times, mean, sin_amplitudes, cos_amplitudes - ) + ts = generate_harmonic_timeseries(times, mean, sin_amplitudes, cos_amplitudes) ts_data = ts.reshape(-1, 1, 1).astype(np.float32) ts_data = np.broadcast_to(ts_data, (len(times), rows, cols)).copy() @@ -193,9 +188,7 @@ def test_harmonic_regression_handles_insufficient_data(synthetic_data): assert np.isnan( params[:-1, 0, 0] ).all(), "Parameters should be NaN with insufficient data" - assert ( - params[-1, 0, 0] == 2 * synthetic_k - ), f"NOBS should be {2 * synthetic_k}" + assert params[-1, 0, 0] == 2 * synthetic_k, f"NOBS should be {2 * synthetic_k}" assert not np.isnan( params[:, 0, 1] ).any(), "Other pixels should have valid parameters" @@ -205,9 +198,9 @@ def test_harmonic_regression_respects_redundancy(synthetic_data): # Make some data NaN but keep enough for default redundancy data_with_nans = synthetic_data["data"].copy() k = synthetic_data["k"] 
- data_with_nans[: data_with_nans.shape[0] - (2 * k + 2), 0, 0] = ( - np.nan - ) # Leave 6 observations + data_with_nans[ + : data_with_nans.shape[0] - (2 * k + 2), 0, 0 + ] = np.nan # Leave 6 observations # Should work with redundancy=1 params_red1 = harmonic_regression( @@ -216,9 +209,7 @@ def test_harmonic_regression_respects_redundancy(synthetic_data): k=synthetic_data["k"], redundancy=1, ) - assert not np.isnan( - params_red1[:-1, 0, 0] - ).any(), "Should work with redundancy=1" + assert not np.isnan(params_red1[:-1, 0, 0]).any(), "Should work with redundancy=1" # Should fail with redundancy=2 params_red2 = harmonic_regression( @@ -227,9 +218,7 @@ def test_harmonic_regression_respects_redundancy(synthetic_data): k=synthetic_data["k"], redundancy=2, ) - assert np.isnan( - params_red2[:-1, 0, 0] - ).all(), "Should fail with redundancy=2" + assert np.isnan(params_red2[:-1, 0, 0]).all(), "Should fail with redundancy=2" def test_harmonic_regression_handles_no_data(synthetic_data): @@ -272,9 +261,7 @@ def synthetic_xarray_data(synthetic_data): @pytest.fixture def synthetic_xarray_dataset(synthetic_data): # Create synthetic xarray DataArray - times = pd.to_timedelta(synthetic_data["times"], "D") + np.datetime64( - "2018-12-31" - ) + times = pd.to_timedelta(synthetic_data["times"], "D") + np.datetime64("2018-12-31") data = synthetic_data["data"] orbit = synthetic_data["orbit"] @@ -315,9 +302,7 @@ def test_synthetic_array_dataset_contains_original_synthetic_data( ), "Synthetic dataset does not contain the original time data" -def test_reduce_to_harmonic_parameters_basic( - synthetic_xarray_data, synthetic_data -): +def test_reduce_to_harmonic_parameters_basic(synthetic_xarray_data, synthetic_data): # Run reduction result = reduce_to_harmonic_parameters( synthetic_xarray_data, @@ -348,12 +333,102 @@ def test_reduce_to_harmonic_parameters_coordinates(synthetic_xarray_data): # Check coordinates are properly set expected_params = model_coords(k) assert list(result.param.values) == expected_params - np.testing.assert_array_equal( - result.x.values, synthetic_xarray_data.x.values + np.testing.assert_array_equal(result.x.values, synthetic_xarray_data.x.values) + np.testing.assert_array_equal(result.y.values, synthetic_xarray_data.y.values) + + +def test_reducing_via_map_blocks(synthetic_xarray_data): + k = 2 + chunked = synthetic_xarray_data.chunk({"x": 1, "y": 1, "time": -1}) + param_names = model_coords(k) + template = chunked.isel(time=slice(len(param_names))).rename({"time": "param"}) + template["param"] = param_names + reduced = chunked.map_blocks( + reduce_to_harmonic_parameters, + template=template, + kwargs={"k": k, "dtimes": synthetic_xarray_data.time.values}, ) - np.testing.assert_array_equal( - result.y.values, synthetic_xarray_data.y.values + reduced.load() + + +@pytest.fixture +def synthetic_s1_dataset(): + # Reduced dimensions for testing + orbits = ["A015", "A029"] + polarizations = ["VH", "VV"] + tiles = ["E042N012T3", "E042N015T3"] + obs = 1000 + Y = 5 + X = 100 + + # Generate random data + sig0_data = np.random.rand( + obs, len(polarizations), len(orbits), len(tiles), Y, X + ).astype(np.float32) + + # Generate time data + start_date = pd.Timestamp("2021-01-01") + time_data = np.array([start_date + pd.Timedelta(days=i) for i in range(obs)]) + time_data = np.broadcast_to( + time_data[:, np.newaxis, np.newaxis], (obs, len(orbits), len(tiles)) + ) + + # Create the dataset + ds = xr.Dataset( + data_vars={ + "sig0": (("obs", "polarization", "orbit", "tile", "Y", "X"), sig0_data) + }, + 
coords={ + "orbit": orbits, + "polarization": polarizations, + "tile": tiles, + "time": (("obs", "orbit", "tile"), time_data), + }, + ) + + return ds + + +# Test to ensure the fixture is working correctly +def test_synthetic_s1_dataset(synthetic_s1_dataset): + assert isinstance(synthetic_s1_dataset, xr.Dataset) + assert set(synthetic_s1_dataset.dims) == { + "obs", + "polarization", + "orbit", + "tile", + "Y", + "X", + } + assert set(synthetic_s1_dataset.data_vars) == {"sig0"} + assert set(synthetic_s1_dataset.coords) == {"orbit", "polarization", "tile", "time"} + + assert synthetic_s1_dataset.sig0.shape == (1000, 2, 2, 2, 5, 100) + assert synthetic_s1_dataset.time.shape == (1000, 2, 2) + + assert synthetic_s1_dataset.orbit.values.tolist() == ["A015", "A029"] + assert synthetic_s1_dataset.polarization.values.tolist() == ["VH", "VV"] + assert synthetic_s1_dataset.tile.values.tolist() == ["E042N012T3", "E042N015T3"] + + +def test_reducing_via_map_blocks_with_nd_time(synthetic_s1_dataset): + k = 2 + chunked = synthetic_s1_dataset.chunk( + {"tile": 1, "orbit": 1, "polarization": 1, "obs": -1} + ) + param_names = model_coords(k) + template = ( + chunked.isel(obs=slice(len(param_names))) + .rename({"obs": "param"}) + .drop_vars("time") + ) + template["param"] = param_names + reduced = chunked.map_blocks( + reduce_ds_to_harmonic_parameters, + template=template, + kwargs={"fit_var_name": "sig0", "k": k, "x_var_name": "X", "y_var_name": "Y"}, ) + reduced.load() def test_reduce_to_harmonic_parameters_with_nans(synthetic_xarray_data): @@ -367,10 +442,7 @@ def test_reduce_to_harmonic_parameters_with_nans(synthetic_xarray_data): # Check that parameters are computed correctly despite NaNs assert not np.isnan(result.sel(x=0, y=0)).all() - assert ( - result.sel(param="NOBS", x=0, y=0) - == len(synthetic_xarray_data.time) - 2 - ) + assert result.sel(param="NOBS", x=0, y=0) == len(synthetic_xarray_data.time) - 2 @pytest.fixture @@ -386,9 +458,7 @@ def make_pars_list(synthetic_xarray_dataset, synthetic_data): def assert_both_orbits_have_approx_the_same_parameters(hpar_dc): - assert np.all( - np.abs(hpar_dc.diff(dim="orbit")) < 1e-6 - ), "Orbits differ too much" + assert np.all(np.abs(hpar_dc.diff(dim="orbit")) < 1e-6), "Orbits differ too much" def assert_retrieved_harmpars_are_approx_the_same_as_synthetic_data( @@ -422,9 +492,7 @@ def assert_we_have_hpars_for_all_sig0_orbits(sig0_dc, hpar_dc, orbit_sig0): ), "Not all sig0 orbits have harmonic parameters" -def test_make_process( - make_pars_list, synthetic_data, synthetic_xarray_dataset -): +def test_make_process(make_pars_list, synthetic_data, synthetic_xarray_dataset): pars_list = make_pars_list.copy() sig0_dc = synthetic_xarray_dataset.copy().sortby("time") time_range = (sig0_dc.time[-4].data, sig0_dc.time[-1].data)
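As a closing note on the new test fixtures: a quick, self-contained smoke test of `harmonic_regression` on a noise-free synthetic signal could look like the sketch below. The `(arr, dtimes)` call signature and the `model_coords(k)` / trailing `NOBS` layout are taken from the tests above; the specific signal values are illustrative assumptions.

```python
import numpy as np

from dask_flood_mapper.harmonic_params import harmonic_regression, model_coords

# Noise-free annual cycle sampled every 6 days, broadcast to a small (obs, y, x) cube.
doy = np.arange(1, 366, 6, dtype=np.float64)
signal = -12.0 + 1.5 * np.cos(2 * np.pi * doy / 365.0)  # mean plus first harmonic
cube = np.broadcast_to(signal[:, None, None], (doy.size, 2, 3)).astype(np.float32)

params = harmonic_regression(cube, doy, k=1, redundancy=1)

# One layer per entry of model_coords(k); the last layer is the valid-observation count.
assert params.shape == (len(model_coords(1)), 2, 3)
assert np.all(params[-1] == doy.size)
# With a perfect signal, the recovered mean and first-harmonic cosine amplitude
# should be close to -12 and 1.5 respectively.
```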