Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refactor select_reference_point to pick the centroid of high-coherence candidate pixels #535

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 21 additions & 2 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -1,8 +1,27 @@
# Changelog

## [Unreleased](https://github.com/isce-framework/dolphin/compare/v0.30.0...main)
## [Unreleased](https://github.com/isce-framework/dolphin/compare/v0.35.1...main)

## [0.35.0](https://github.com/isce-framework/dolphin/compare/v0.34.0...v0.35.0) - 2025-12-09
### Changed

- `timeseries.py`: Auto reference point selection
- Picks the center of mass instead of arbitrary `argmax` result
- Rename `condition_file` to `quality_file`

### Removed

- Removed `condition` parameter in `timeseries` reference point functions
- Removed `CallFunc` enum

## [0.35.1](https://github.com/isce-framework/dolphin/compare/v0.35.0...v0.35.1) - 2025-01-15

### Fixed

- `filtering.py` Fix in_bounds_pixels masking, set default to 25 km
- Set `output_reference_idx` separately from `compressed_reference_idx` during phase linking setup


## [0.35.0](https://github.com/isce-framework/dolphin/compare/v0.34.0...v0.35.0) - 2025-01-09

### Added

Expand Down
12 changes: 1 addition & 11 deletions src/dolphin/_cli_timeseries.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@

from dolphin._log import setup_logging
from dolphin.timeseries import InversionMethod
from dolphin.workflows import CallFunc

if TYPE_CHECKING:
_SubparserType = argparse._SubParsersAction[argparse.ArgumentParser]
Expand Down Expand Up @@ -57,21 +56,12 @@ def get_parser(subparser=None, subcommand_name="timeseries") -> argparse.Argumen
),
)
parser.add_argument(
"--condition-file",
"--quality-file",
help=(
"A file with the same size as each raster, like amplitude dispersion or "
"temporal coherence to find reference point"
),
)
parser.add_argument(
"--condition",
type=CallFunc,
default=CallFunc.MIN,
help=(
"A condition to apply to condition file to find the reference point. "
"Options are [min, max]. default=min"
),
)
parser.add_argument(
"--method",
type=InversionMethod,
Expand Down
97 changes: 63 additions & 34 deletions src/dolphin/timeseries.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,6 @@
from dolphin._overviews import ImageType, create_overviews
from dolphin._types import PathOrStr, ReferencePoint
from dolphin.utils import flatten, format_dates, full_suffix, get_nearest_date_idx
from dolphin.workflows import CallFunc

T = TypeVar("T")
DateOrDatetime = datetime | date
Expand All @@ -42,10 +41,10 @@ class ReferencePointError(ValueError):
def run(
unwrapped_paths: Sequence[PathOrStr],
conncomp_paths: Sequence[PathOrStr] | None,
condition_file: PathOrStr,
condition: CallFunc,
quality_file: PathOrStr,
output_dir: PathOrStr,
method: InversionMethod = InversionMethod.L1,
reference_candidate_threshold: float = 0.95,
run_velocity: bool = False,
corr_paths: Sequence[PathOrStr] | None = None,
weight_velocity_by_corr: bool = False,
Expand All @@ -66,19 +65,21 @@ def run(
Sequence unwrapped interferograms to invert.
conncomp_paths : Sequence[Path]
Sequence connected component files, one per file in `unwrapped_paths`
condition_file: PathOrStr
quality_file: PathOrStr
A file with the same size as each raster, like amplitude dispersion or
temporal coherence
condition: CallFunc
The function to apply to the condition file,
for example numpy.argmin which finds the pixel with lowest value
the options are [min, max]
output_dir : Path
Path to the output directory.
method : str, choices = "L1", "L2"
Inversion method to use when solving Ax = b.
Default is L2, which uses least squares to solve Ax = b (faster).
"L1" minimizes |Ax - b|_1 at each pixel.
reference_candidate_threshold: float
The threshold for the quality metric to be considered a candidate
reference point pixel.
Only pixels with values in `quality_file` greater than
`reference_candidate_threshold` will be considered a candidate.
Default is 0.95.
run_velocity : bool
Whether to run velocity estimation on the inverted phase series
corr_paths : Sequence[Path], optional
Expand Down Expand Up @@ -136,13 +137,12 @@ def run(
unwrapped_paths = sorted(unwrapped_paths, key=str)
Path(output_dir).mkdir(exist_ok=True, parents=True)

condition_func = argmax_index if condition == CallFunc.MAX else argmin_index
if reference_point == (-1, -1):
logger.info("Selecting a reference point for unwrapped interferograms")
ref_point = select_reference_point(
condition_file=condition_file,
quality_file=quality_file,
output_dir=Path(output_dir),
condition_func=condition_func,
candidate_threshold=reference_candidate_threshold,
ccl_file_list=conncomp_paths,
)
else:
Expand Down Expand Up @@ -969,11 +969,10 @@ def invert_unw_network(
suffix = ".tif"
# Create the `n_sar_dates - 1` output files (skipping the 0 reference raster)
out_paths = [
Path(output_dir) / (f"{format_dates(ref_date, d)}{suffix}")
for d in sar_dates[1:]
Path(output_dir) / f"{format_dates(ref_date, d)}{suffix}" for d in sar_dates[1:]
]
out_residuals_paths = [
Path(output_dir) / (f"residuals_{format_dates(ref_date, d)}{suffix}")
Path(output_dir) / f"residuals_{format_dates(ref_date, d)}{suffix}"
for d in sar_dates[1:]
]
if all(p.exists() for p in out_paths):
Expand Down Expand Up @@ -1169,32 +1168,37 @@ def correlation_to_variance(correlation: ArrayLike, nlooks: int) -> Array:

def select_reference_point(
*,
condition_file: PathOrStr,
quality_file: PathOrStr,
output_dir: Path,
condition_func: Callable[[ArrayLike], tuple[int, ...]] = argmin_index,
candidate_threshold: float = 0.95,
ccl_file_list: Sequence[PathOrStr] | None = None,
block_shape: tuple[int, int] = (256, 256),
num_threads: int = 4,
) -> ReferencePoint:
"""Automatically select a reference point for a stack of unwrapped interferograms.

Uses the condition file and (optionally) connected component labels.
Uses the quality file and (optionally) connected component labels.
The point is selected which

1. has the condition applied to condition file. for example: has the lowest
amplitude dispersion
2. (optionally) is within intersection of all nonzero connected component labels
1. (optionally) is within intersection of all nonzero connected component labels
2. Has value in `quality_file` above the threshold `candidate_threshold`

Among all points which meet this, the centroid selected using the function
`scipy.ndimage.center_of_mass`.

Parameters
----------
condition_file: PathOrStr
A file with the same size as each raster, like amplitude dispersion or
temporal coherence in `ccl_file_list`
quality_file: PathOrStr
A file with the same size as each raster in `ccl_file_list` containing a quality
metric, such as temporal coherence.
output_dir: Path
Path to store the computed "conncomp_intersection.tif" raster
condition_func: Callable[[ArrayLike, ]]
The function to apply to the condition file,
for example numpy.argmin which finds the pixel with lowest value
candidate_threshold: float
The threshold for the quality metric function to be considered a candidate
reference point pixel.
Only pixels with values in `quality_file` greater than `candidate_threshold` are
considered a candidate.
Default = 0.95
ccl_file_list : Sequence[PathOrStr]
List of connected component label phase files.
block_shape : tuple[int, int]
Expand Down Expand Up @@ -1223,9 +1227,10 @@ def select_reference_point(
return ref_point

logger.info("Selecting reference point")
condition_file_values = io.load_gdal(condition_file, masked=True)
quality_file_values = io.load_gdal(quality_file, masked=True)

isin_largest_conncomp = np.ones(condition_file_values.shape, dtype=bool)
# Start with all points as valid candidates
isin_largest_conncomp = np.ones(quality_file_values.shape, dtype=bool)
if ccl_file_list:
try:
isin_largest_conncomp = _get_largest_conncomp_mask(
Expand All @@ -1235,15 +1240,39 @@ def select_reference_point(
num_threads=num_threads,
)
except ReferencePointError:
msg = "Unable to find find a connected component intersection."
msg += f"Proceeding using only {condition_file = }"
msg = "Unable to find a connected component intersection."
msg += f"Proceeding using only {quality_file = }"
logger.warning(msg, exc_info=True)

# Mask out where the conncomps aren't equal to the largest
condition_file_values.mask = condition_file_values.mask | (~isin_largest_conncomp)
# Find pixels meeting the threshold criteria
is_candidate = quality_file_values > candidate_threshold

# Restrict candidates to the largest connected component region
is_candidate &= isin_largest_conncomp

# Pick the (unmasked) point with the condition applied to condition file
ref_row, ref_col = condition_func(condition_file_values)
# Find connected regions within candidate pixels
labeled, n_objects = ndimage.label(is_candidate, structure=np.ones((3, 3)))

if n_objects == 0:
# If no candidates meet threshold, pick best available point
logger.warning(
f"No pixels above threshold={candidate_threshold}. Choosing best among"
" available."
)
ref_row, ref_col = argmax_index(quality_file_values)
else:
# Find the largest region of connected candidate pixels
label_counts = np.bincount(labeled.ravel())
label_counts[0] = 0 # ignore background
largest_label = label_counts.argmax()
largest_component = labeled == largest_label

# Select point closest to center of largest region
row_c, col_c = ndimage.center_of_mass(largest_component)
rows, cols = np.nonzero(largest_component)
dist_sq = (rows - row_c) ** 2 + (cols - col_c) ** 2
i_min = dist_sq.argmin()
ref_row, ref_col = rows[i_min], cols[i_min]

# Cast to `int` to avoid having `np.int64` types
ref_point = ReferencePoint(int(ref_row), int(ref_col))
Expand Down
8 changes: 0 additions & 8 deletions src/dolphin/workflows/config/_enums.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
from enum import Enum

__all__ = [
"CallFunc",
"ShpMethod",
"UnwrapMethod",
]
Expand All @@ -25,10 +24,3 @@ class UnwrapMethod(str, Enum):
PHASS = "phass"
SPURT = "spurt"
WHIRLWIND = "whirlwind"


class CallFunc(str, Enum):
"""Call function for the timeseries method to find reference point."""

MIN = "min"
MAX = "max"
7 changes: 4 additions & 3 deletions src/dolphin/workflows/displacement.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@
from dolphin import __version__, io, timeseries, utils
from dolphin._log import log_runtime, setup_logging
from dolphin.timeseries import ReferencePoint
from dolphin.workflows import CallFunc

from . import stitching_bursts, unwrapping, wrapped_phase
from ._utils import _create_burst_cfg, _remove_dir_if_empty, parse_ionosphere_files
Expand Down Expand Up @@ -260,8 +259,10 @@ def run(
unwrapped_paths=unwrapped_paths,
conncomp_paths=conncomp_paths,
corr_paths=stitched_paths.interferometric_corr_paths,
condition_file=stitched_paths.temp_coh_file,
condition=CallFunc.MAX,
# TODO: Right now we don't have the option to pick a different candidate
# or quality file. Figure out if this is worth exposing
quality_file=stitched_paths.temp_coh_file,
reference_candidate_threshold=0.95,
output_dir=ts_opts._directory,
method=timeseries.InversionMethod(ts_opts.method),
run_velocity=ts_opts.run_velocity,
Expand Down
76 changes: 76 additions & 0 deletions tests/test_timeseries.py
Original file line number Diff line number Diff line change
Expand Up @@ -272,6 +272,82 @@ def test_stack_unweighted(self, data, x_arr, expected_velo):
npt.assert_allclose(velocities, expected_velo, atol=1e-5)


class TestReferencePoint:
def test_all_ones_center(self, tmp_path):
# All coherence=1 => pick the true center pixel.
shape = (31, 31)
arr = np.ones(shape, dtype="float32")

coh_file = tmp_path / "coh_ones.tif"
io.write_arr(arr=arr, output_name=coh_file)

ref_point = timeseries.select_reference_point(
quality_file=coh_file,
output_dir=tmp_path,
candidate_threshold=0.95, # everything is above 0.95
ccl_file_list=None,
)
npt.assert_equal((ref_point.row, ref_point.col), (15, 15))

def test_half_03_half_099(self, tmp_path):
# Left half=0.3, right half=0.99 => reference point should be
# near the center of the right side.
shape = (31, 31)
arr = np.full(shape, 0.3, dtype="float32")
# Make right half, 16 to 30, high coherence
arr[:, 16:] = 0.99

coh_file = tmp_path / "coh_half_03_099.tif"
io.write_arr(arr=arr, output_name=coh_file)

ref_point = timeseries.select_reference_point(
quality_file=coh_file,
output_dir=tmp_path,
candidate_threshold=0.95,
ccl_file_list=None,
)
# Expect the center row=15, and roughly col=23 for columns 16..30.
npt.assert_equal((ref_point.row, ref_point.col), (15, 23))

def test_with_conncomp(self, tmp_path):
"""Make a temporal coherence with left half=0.3, right half=1.0.

Make the connected-component labels have top half=0, bottom half=1.

Reference point is in the bottom-right quadrant where
coherence > threshold AND conncomp == 1
"""
shape = (31, 31)

# Coherence array: left half=0.3, right half=1.0
coh = np.full(shape, 0.3, dtype="float32")
coh[:, 16:] = 1.0

coh_file = tmp_path / "coh_left03_right1.tif"
io.write_arr(arr=coh, output_name=coh_file)

# ConnComp label: top half=0, bottom half=1
# So only rows >= 15 are labeled '1'.
ccl = np.zeros(shape, dtype="uint16")
ccl[16:, :] = 1

ccl_file1 = tmp_path / "conncomp_bottom_half.tif"
io.write_arr(arr=ccl, output_name=ccl_file1)
# Add another conncomp file with all good pixels
ccl_file2 = tmp_path / "conncomp_full.tif"
io.write_arr(arr=np.ones_like(ccl), output_name=ccl_file2)

ref_point = timeseries.select_reference_point(
ccl_file_list=[ccl_file1, ccl_file2],
quality_file=coh_file,
output_dir=tmp_path,
candidate_threshold=0.95,
)

# Bottom half => rows [16..30], right half => cols [16..30].
npt.assert_equal((ref_point.row, ref_point.col), (23, 23))


if __name__ == "__main__":
sar_dates = make_sar_dates()
sar_phases = make_sar_phases(sar_dates)
Expand Down