Update average_rasters, add ability to make weighted average #352

Open · wants to merge 7 commits into main
src/dolphin/timeseries.py (61 changes: 50 additions & 11 deletions)
@@ -9,7 +9,7 @@
 import jax.numpy as jnp
 import numpy as np
 from jax import Array, jit, vmap
-from numpy.typing import ArrayLike
+from numpy.typing import ArrayLike, NDArray
 from opera_utils import get_dates
 from scipy import ndimage

@@ -581,21 +581,25 @@ def read_and_fit(
     create_overviews([output_file])


-class AverageFunc(Protocol):
+class WeightedAverager(Protocol):
     """Protocol for temporally averaging a block of data."""

-    def __call__(self, ArrayLike, axis: int) -> ArrayLike: ...
+    def __call__(
+        self, arr: ArrayLike, axis: int, weights: ArrayLike | None
+    ) -> NDArray: ...


-def create_temporal_average(
+def create_average(
     file_list: Sequence[PathOrStr],
     output_file: PathOrStr,
     block_shape: tuple[int, int] = (512, 512),
     num_threads: int = 5,
-    average_func: Callable[[ArrayLike, int], np.ndarray] = np.nanmean,
+    average_func: WeightedAverager = np.average,
+    mask_average_func: Callable[[ArrayLike, int], np.ndarray] = np.any,
+    weights: ArrayLike | None = None,
     read_masked: bool = False,
 ) -> None:
-    """Average all images in `reader` to create a 2D image in `output_file`.
+    """Average all images in `file_list` to create a 2D image in `output_file`.

     Parameters
     ----------
@@ -611,18 +615,37 @@ def create_temporal_average(
         Default is 5.
     average_func : Callable[[ArrayLike, int], np.ndarray], optional
         The function to use to average the images.
-        Default is `np.nanmean`, which calls `np.nanmean(arr, axis=0)` on each block.
+        Default calls `np.average(arr, axis=0, weights=weights)` on each block.
+    mask_average_func : Callable[[ArrayLike, int], np.ndarray], optional
+        If `read_masked` is true, the function to use to average the masks.
+        Default is `np.any`, which calls `np.any(masks, axis=0)` on each block
+        and masks *any* pixel that is masked in one of the images.
+        Use `np.all` to have more valid pixels (a smaller masked region).
+    weights : ArrayLike, optional
+        If provided, assigns a floating point weight to each file in `file_list`.
+        The output is a weighted average.
+        Default is `None`, equivalent to `np.ones(len(file_list))`.
     read_masked : bool, optional
         If True, reads the data as a masked array based on the rasters' nodata values.
         Default is False.

     """
+    if weights is None:
+        weights = np.ones(len(file_list), dtype="float32")
+    if weights.shape != (len(file_list),):
+        msg = f"weights must be shape (len(file_list),), got {weights.shape}"
+        raise ValueError(msg)
+
     def read_and_average(
         readers: Sequence[io.StackReader], rows: slice, cols: slice
     ) -> tuple[slice, slice, np.ndarray]:
         chunk = readers[0][:, rows, cols]
-        return average_func(chunk, 0), rows, cols
+        out_chunk = average_func(chunk, axis=0, weights=weights)
+        if isinstance(chunk, np.ma.MaskedArray):
+            mask = mask_average_func(chunk.mask, 0)
+            out_chunk = np.ma.MaskedArray(data=out_chunk, mask=mask)
+        return out_chunk, rows, cols

     writer = io.BackgroundRasterWriter(output_file, like_filename=file_list[0])
     with NamedTemporaryFile(mode="w", suffix=".vrt") as f:
@@ -644,6 +667,17 @@ def read_and_average(
     writer.notify_finished()


+def create_temporal_average(*args, **kwargs):
+    import warnings
+
+    warnings.warn(
+        "'create_temporal_average' is deprecated. Use 'create_average' instead.",
+        DeprecationWarning,
+        stacklevel=2,
+    )
+    return create_average(*args, **kwargs)
+
+
 def invert_unw_network(
     unw_file_list: Sequence[PathOrStr],
     reference: ReferencePoint,
@@ -880,8 +914,12 @@ def _get_largest_conncomp_mask(
     block_shape: tuple[int, int] = (512, 512),
     num_threads: int = 5,
 ) -> np.ndarray:
-    def intersect_conncomp(arr: np.ma.MaskedArray, axis: int) -> np.ndarray:
+    def intersect_conncomp(
+        arr: ArrayLike, axis: int, _weights: ArrayLike | None
+    ) -> np.ndarray:
         # Track where input is nodata
+        if not isinstance(arr, np.ma.MaskedArray):
+            arr = np.ma.MaskedArray(data=arr, mask=np.ma.nomask)
         any_masked = np.any(arr.mask, axis=axis)
         # Get the logical AND of all nonzero conncomp labels
         fillval = arr.fill_value
@@ -892,14 +930,15 @@ def intersect_conncomp(arr: np.ma.MaskedArray, axis: int) -> np.ndarray:
         return all_are_valid

     conncomp_intersection_file = Path(output_dir) / "conncomp_intersection.tif"
+    average_func: WeightedAverager = intersect_conncomp  # type: ignore[assignment]
     if ccl_file_list and not conncomp_intersection_file.exists():
         logger.info("Creating intersection of connected components")
-        create_temporal_average(
+        create_average(
             file_list=ccl_file_list,
             output_file=conncomp_intersection_file,
             block_shape=block_shape,
             num_threads=num_threads,
-            average_func=intersect_conncomp,
+            average_func=average_func,
             read_masked=True,
         )

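To make the new surface concrete for reviewers, here is a minimal usage sketch. It is not part of the diff: `file_list` is a placeholder for paths to same-shaped single-band rasters, and the NaN-aware averager is only an illustration of a callable satisfying the `WeightedAverager` protocol, assuming `create_average` invokes it as `average_func(chunk, axis=0, weights=weights)`, as the diff shows.

```python
import numpy as np
from numpy.typing import ArrayLike, NDArray

from dolphin import timeseries

file_list = ["r1.tif", "r2.tif", "r3.tif"]  # placeholder raster paths

# Default behavior: weighted mean via np.average. A weight of 0
# drops the last raster from the average entirely.
weights = np.ones(len(file_list), dtype="float32")
weights[-1] = 0
timeseries.create_average(
    file_list=file_list,
    output_file="weighted_average.tif",
    weights=weights,
)


def nan_weighted_average(
    arr: ArrayLike, axis: int, weights: ArrayLike | None
) -> NDArray:
    """NaN-aware weighted mean; matches the WeightedAverager protocol.

    Assumes axis == 0, which is how `create_average` calls it.
    """
    arr = np.asarray(arr, dtype="float64")
    w = np.ones(arr.shape[axis]) if weights is None else np.asarray(weights)
    w = w.reshape(-1, *([1] * (arr.ndim - 1)))  # broadcast over axis 0
    valid = ~np.isnan(arr)
    num = np.nansum(arr * w, axis=axis)  # NaNs contribute nothing to the sum
    den = np.sum(w * valid, axis=axis)  # total weight of the valid samples
    with np.errstate(invalid="ignore", divide="ignore"):
        out = num / den
    return np.where(den > 0, out, np.nan)


timeseries.create_average(
    file_list=file_list,
    output_file="nan_weighted_average.tif",
    average_func=nan_weighted_average,
    weights=weights,
)
```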
src/dolphin/workflows/sequential.py (32 changes: 4 additions & 28 deletions)
@@ -7,16 +7,13 @@

 import logging
 from itertools import chain
-from os import fspath
 from pathlib import Path
 from typing import Optional

-from osgeo_utils import gdal_calc
-
-from dolphin import io
 from dolphin._types import Filename
 from dolphin.io import VRTStack
 from dolphin.stack import MiniStackPlanner
+from dolphin.timeseries import create_average

 from .config import ShpMethod
 from .single import run_wrapped_phase_single
@@ -128,12 +125,10 @@ def already_processed(d: Path, search_ext: str = ".tif") -> bool:
     # Average the temporal coherence files in each ministack
     full_span = ministack_planner.real_slc_date_range_str
     output_temp_coh_file = output_folder / f"temporal_coherence_average_{full_span}.tif"
-    output_shp_count_file = output_folder / f"shp_counts_average_{full_span}.tif"
+    create_average(temp_coh_files, output_file=output_temp_coh_file)

-    # we can pass the list of files to gdal_calc, which interprets it
-    # as a multi-band file
-    _average_rasters(temp_coh_files, output_temp_coh_file, "Float32")
-    _average_rasters(shp_count_files, output_shp_count_file, "Int16")
+    output_shp_count_file = output_folder / f"shp_counts_average_{full_span}.tif"
+    create_average(shp_count_files, output_file=output_shp_count_file)

     # Combine the separate SLC output lists into a single list
     all_slc_files = list(chain.from_iterable(output_slc_files))
@@ -162,22 +157,3 @@ def _get_outputs_from_folder(
     # Currently ignoring to not stitch:
     # eigenvalues, estimator, avg_coh
     return cur_output_files, cur_comp_slc_file, temp_coh_file, shp_count_file
-
-
-def _average_rasters(file_list: list[Path], outfile: Path, output_type: str):
-    if len(file_list) == 1:
-        file_list[0].rename(outfile)
-        return
-
-    logger.info(f"Averaging {len(file_list)} files into {outfile}")
-    gdal_calc.Calc(
-        NoDataValue=0,
-        format="GTiff",
-        outfile=fspath(outfile),
-        type=output_type,
-        quiet=True,
-        overwrite=True,
-        creation_options=io.DEFAULT_TIFF_OPTIONS,
-        A=file_list,
-        calc="numpy.nanmean(A, axis=0)",
-    )
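One behavioral difference seems worth flagging here (a reviewer observation, not something the diff states): the removed `_average_rasters` computed `numpy.nanmean`, which skips NaNs per pixel, while `create_average`'s default `np.average` propagates NaNs and honors weights. A quick self-contained demonstration:

```python
import numpy as np

# Two "rasters" of shape (1, 2); the second pixel of the first is NaN.
stack = np.array([
    [[1.0, np.nan]],
    [[3.0, 4.0]],
])

print(np.nanmean(stack, axis=0))  # [[2. 4.]]  -- NaN skipped
print(np.average(stack, axis=0))  # [[2. nan]] -- NaN propagates
print(np.average(stack, axis=0, weights=[1.0, 0.0]))  # [[1. nan]]
```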
tests/test_timeseries.py (26 changes: 26 additions & 0 deletions)
@@ -222,6 +222,32 @@ def test_stack_unweighted(self, data, x_arr, expected_velo):
         npt.assert_allclose(velocities, expected_velo, atol=1e-5)


+class TestCreateAverage:
+    def test_basic(self, tmp_path, slc_file_list, slc_stack):
+        output_file = tmp_path / "average.tif"
+        timeseries.create_average(
+            file_list=slc_file_list, output_file=output_file, num_threads=1
+        )
+        computed = io.load_gdal(output_file)
+        expected = np.average(slc_stack, axis=0)
+        npt.assert_allclose(computed, expected)
+
+    def test_weighted(self, tmp_path, slc_file_list, slc_stack):
+        output_file = tmp_path / "average.tif"
+        weights = np.ones(len(slc_file_list))
+        weights[-1] = 0
+        timeseries.create_average(
+            file_list=slc_file_list,
+            output_file=output_file,
+            num_threads=1,
+            weights=weights,
+        )
+
+        computed = io.load_gdal(output_file)
+        expected = np.average(slc_stack[:-1], axis=0)
+        npt.assert_allclose(computed, expected, atol=1e-6)
+
+
 if __name__ == "__main__":
     sar_dates = make_sar_dates()
     sar_phases = make_sar_phases(sar_dates)
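A possible follow-up test, sketched here rather than taken from the PR, would exercise the `read_masked=True` path. The `raster_with_nodata_list` fixture is hypothetical, and the `masked=True` flag on `io.load_gdal` is an assumption, not confirmed by this diff:

```python
import numpy as np

from dolphin import io, timeseries


def test_masked(tmp_path, raster_with_nodata_list):
    # raster_with_nodata_list: hypothetical fixture of rasters with nodata.
    output_file = tmp_path / "average_masked.tif"
    timeseries.create_average(
        file_list=raster_with_nodata_list,
        output_file=output_file,
        num_threads=1,
        read_masked=True,
        # np.all marks a pixel masked only if it is masked in every raster,
        # leaving more valid pixels than the np.any default.
        mask_average_func=np.all,
    )
    computed = io.load_gdal(output_file, masked=True)  # masked kwarg assumed
    assert isinstance(computed, np.ma.MaskedArray)
    assert not computed.mask.all()
```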