Skip to content
Open
Show file tree
Hide file tree
Changes from 12 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions tobac/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,7 @@
add_coordinates,
get_spacings,
)
from .utils.mask import convert_feature_mask_to_cells, convert_cell_mask_to_features
from .feature_detection import feature_detection_multithreshold
from .tracking import linking_trackpy
from .wrapper import maketrack
Expand Down
30 changes: 30 additions & 0 deletions tobac/segmentation/watershed_segmentation.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,7 @@
from tobac.utils import get_statistics
from tobac.utils import decorators
from tobac.utils.generators import field_and_features_over_time
from tobac.utils.mask import convert_feature_mask_to_cells


def add_markers(
Expand Down Expand Up @@ -1135,6 +1136,8 @@ def segmentation(
segment_number_unassigned: int = 0,
statistic: Union[dict[str, Union[Callable, tuple[Callable, dict]]], None] = None,
time_padding: Optional[datetime.timedelta] = datetime.timedelta(seconds=0.5),
return_cells: bool = False,
stubs: Optional[int] = None,
) -> tuple[xr.DataArray, pd.DataFrame]:
"""Use watershedding to determine region above a threshold
value around initial seeding position for all time steps of
Expand Down Expand Up @@ -1212,6 +1215,18 @@ def segmentation(
timestep that is time_padding off of the feature. Extremely useful when
converting between micro- and nanoseconds, as is common when using Pandas
dataframes.
return_cells: bool, optional (default: False)
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Since you added the capability to directly go from features to tracks as well, should that also be an option in the segmentation function?

If True, the segmentation mask returned will use the cell values of the
input dataframe, rather than the feature values. This requires the
features input to be the output from tobac.linking_trackpy
stubs: int, optional (default: None)
The stub values used for unlinked cells in tobac.linking_trackpy, used
when return_cells=True. If None, the stub cells will be relabelled with
the stub cell value in the feature dataframe. If a value is provided,
the masked regions corresponding to stub cells will be removed from the
output. Warning: the presence of stub cells may make it impossible to
perfectly reconstruct the feature mask afterwards as any stub features
will be removed.

Returns
-------
Expand Down Expand Up @@ -1250,6 +1265,12 @@ def segmentation(
)
) from exc

# Check features has cell column if return_cells is True:
if return_cells and "cell" not in features.columns:
raise ValueError(
"`cell` column not found in features input, please perform tracking on this data before performing segmentation with `return_cells=True`"
)

# create our output dataarray
segmentation_out_data = xr.DataArray(
np.zeros(field.shape, dtype=int),
Expand Down Expand Up @@ -1300,6 +1321,15 @@ def segmentation(

# Merge output from individual timesteps:
features_out = pd.concat(features_out_list)

# Convert feature mask to cells if return_cells is True:
if return_cells:
segmentation_out_data = convert_feature_mask_to_cells(
features_out,
segmentation_out_data,
stubs=stubs,
)

logging.debug("Finished segmentation")
return segmentation_out_data, features_out

Expand Down
146 changes: 136 additions & 10 deletions tobac/tests/segmentation_tests/test_segmentation.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
import pytest
import tobac.segmentation as seg
from datetime import datetime
import numpy as np
import pandas as pd
import xarray as xr
import pytest
from tobac import segmentation, feature_detection, testing
from tobac.utils import periodic_boundaries as pbc_utils

Expand Down Expand Up @@ -86,7 +88,7 @@ def test_segmentation_timestep_2D_feature_2D_seg():
)

for pbc_option in ["none", "hdim_1", "hdim_2", "both"]:
out_seg_mask, out_df = seg.segmentation_timestep(
out_seg_mask, out_df = segmentation.segmentation_timestep(
field_in=test_data_iris,
features_in=test_feature_ds,
dxy=test_dxy,
Expand Down Expand Up @@ -154,7 +156,7 @@ def test_segmentation_timestep_2D_feature_2D_seg():
)

for pbc_option in ["none", "hdim_1", "hdim_2", "both"]:
out_seg_mask, out_df = seg.segmentation_timestep(
out_seg_mask, out_df = segmentation.segmentation_timestep(
field_in=test_data_iris,
features_in=test_feature_ds,
dxy=test_dxy,
Expand Down Expand Up @@ -222,7 +224,7 @@ def test_segmentation_timestep_2D_feature_2D_seg():
)

for pbc_option in ["none", "hdim_1", "hdim_2", "both"]:
out_seg_mask, out_df = seg.segmentation_timestep(
out_seg_mask, out_df = segmentation.segmentation_timestep(
field_in=test_data_iris,
features_in=test_feature_ds,
dxy=test_dxy,
Expand Down Expand Up @@ -762,7 +764,7 @@ def test_segmentation_timestep_3d_buddy_box(
common_seg_opts["seed_3D_flag"] = "box"
common_seg_opts["seed_3D_size"] = seed_3D_size

out_seg_mask, out_df = seg.segmentation_timestep(
out_seg_mask, out_df = segmentation.segmentation_timestep(
field_in=test_data_iris, features_in=test_feature_ds, **common_seg_opts
)

Expand Down Expand Up @@ -790,7 +792,7 @@ def test_segmentation_timestep_3d_buddy_box(
PBC_flag="both",
)
test_feature_ds_shifted = pd.concat([test_feature_ds_1, test_feature_ds_2])
out_seg_mask_shifted, out_df = seg.segmentation_timestep(
out_seg_mask_shifted, out_df = segmentation.segmentation_timestep(
field_in=test_data_iris_shifted,
features_in=test_feature_ds_shifted,
**common_seg_opts,
Expand Down Expand Up @@ -895,7 +897,7 @@ def test_add_markers_pbcs(
common_marker_opts["seed_3D_flag"] = "box"
common_marker_opts["seed_3D_size"] = seed_3D_size

marker_arr = seg.add_markers(
marker_arr = segmentation.add_markers(
test_feature_ds, np.zeros(dset_size), **common_marker_opts
)

Expand Down Expand Up @@ -931,7 +933,7 @@ def test_add_markers_pbcs(

test_feature_ds_shifted = pd.concat([test_feature_ds_1, test_feature_ds_2])

marker_arr_shifted = seg.add_markers(
marker_arr_shifted = segmentation.add_markers(
test_feature_ds_shifted, np.zeros(dset_size), **common_marker_opts
)

Expand Down Expand Up @@ -989,7 +991,7 @@ def test_empty_segmentation(PBC_flag):
seg_arr, data_type="iris", z_dim_num=0, y_dim_num=1, x_dim_num=2
)

out_seg_mask, out_df = seg.segmentation_timestep(
out_seg_mask, out_df = segmentation.segmentation_timestep(
field_in=test_data_iris, features_in=test_feature, **seg_opts
)

Expand Down Expand Up @@ -1181,3 +1183,127 @@ def test_seg_alt_unseed_num(below_thresh, above_thresh, error):

seg_out_arr = seg_output.core_data()
assert np.all(correct_seg_arr == seg_out_arr)


def test_segmentation_return_cells():
    """Check that ``return_cells=True`` labels the mask with cell values.

    A 3-timestep field holds one rectangular region of value 2 on a zero
    background. Each timestep has its own feature (1, 2, 3), and all three
    belong to cell 1, so every pixel of the region must be labelled 1 and
    the background must stay 0.
    """
    # Three timestamps, one hour apart, shared by the field and the features.
    timestamps = pd.date_range(
        datetime(2000, 1, 1, 0), datetime(2000, 1, 1, 2), periods=3
    )

    raw_field = np.zeros([3, 4, 5], dtype=int)
    raw_field[:, 1:3, 1:4] = 2
    field = xr.DataArray(
        raw_field,
        dims=("time", "y", "x"),
        coords={"time": timestamps},
        attrs={"units": "feature"},
    )

    tracked_features = pd.DataFrame(
        {
            "feature": [1, 2, 3],
            "frame": [0, 1, 2],
            "time": timestamps,
            "hdim_1": [1.5, 1.5, 1.5],
            "hdim_2": [2, 2, 2],
            "cell": [1, 1, 1],
        }
    )

    cell_mask, _ = segmentation.segmentation(
        tracked_features, field, 1, threshold=1, return_cells=True
    )

    labels = cell_mask.values
    field_values = field.values

    # Region pixels carry the cell value; everything else is unsegmented.
    assert np.all(labels[field_values == 2] == 1)
    assert np.all(labels[field_values == 0] == 0)


def test_segmentation_return_cells_stubs():
    """Check ``return_cells=True`` with and without the ``stubs`` option.

    The features of the first two timesteps belong to cell 1; the feature
    of the last timestep has cell value -1. With the default ``stubs`` the
    last timestep's region is expected to be labelled -1. With ``stubs=-1``
    that region is expected to be 0 instead.
    """
    # Three timestamps, one hour apart, shared by the field and the features.
    timestamps = pd.date_range(
        datetime(2000, 1, 1, 0), datetime(2000, 1, 1, 2), periods=3
    )

    raw_field = np.zeros([3, 4, 5], dtype=int)
    raw_field[:, 1:3, 1:4] = 2
    field = xr.DataArray(
        raw_field,
        dims=("time", "y", "x"),
        coords={"time": timestamps},
        attrs={"units": "feature"},
    )

    tracked_features = pd.DataFrame(
        {
            "feature": [1, 2, 3],
            "frame": [0, 1, 2],
            "time": timestamps,
            "hdim_1": [1.5, 1.5, 1.5],
            "hdim_2": [2, 2, 2],
            "cell": [1, 1, -1],
        }
    )

    # Without stubs: the -1 cell value is kept in the output mask.
    cell_mask, _ = segmentation.segmentation(
        tracked_features, field, 1, threshold=1, return_cells=True
    )

    early_region = field[:-1].values == 2
    final_region = field[-1].values == 2
    assert np.all(cell_mask[:-1].values[early_region] == 1)
    assert np.all(cell_mask[-1].values[final_region] == -1)
    assert np.all(cell_mask.values[field.values == 0] == 0)

    # With stubs: the region of the -1 cell is removed from the mask.
    cell_mask, _ = segmentation.segmentation(
        tracked_features, field, 1, threshold=1, return_cells=True, stubs=-1
    )

    early_region = field[:-1].values == 2
    final_region = field[-1].values == 2
    assert np.all(cell_mask[:-1].values[early_region] == 1)
    assert np.all(cell_mask[-1].values[final_region] == 0)
    assert np.all(cell_mask.values[field.values == 0] == 0)


def test_segmentation_return_cells_no_cell_column():
    """Check that ``return_cells=True`` raises when there is no cell column.

    The features dataframe here has no ``cell`` column, so segmentation is
    expected to raise a ``ValueError`` whose message matches the pattern
    below.
    """
    # Three timestamps, one hour apart, shared by the field and the features.
    timestamps = pd.date_range(
        datetime(2000, 1, 1, 0), datetime(2000, 1, 1, 2), periods=3
    )

    raw_field = np.zeros([3, 4, 5], dtype=int)
    raw_field[:, 1:3, 1:4] = 2
    field = xr.DataArray(
        raw_field,
        dims=("time", "y", "x"),
        coords={"time": timestamps},
        attrs={"units": "feature"},
    )

    # Features without a "cell" column.
    untracked_features = pd.DataFrame(
        {
            "feature": [1, 2, 3],
            "frame": [0, 1, 2],
            "time": timestamps,
            "hdim_1": [1.5, 1.5, 1.5],
            "hdim_2": [2, 2, 2],
        }
    )

    expected_message = "`cell` column not found in features input, please perform tracking on this data before performing segmentation with *"
    with pytest.raises(ValueError, match=expected_message):
        segmentation.segmentation(
            untracked_features, field, 1, threshold=1, return_cells=True
        )
Loading
Loading