arXiv: link; ICCV Proceedings: link.
Original Github page: bw
A topology-preserving loss function for curvilinear structure segmentation, based on spatially-aware persistent feature matching that leverages critical cell locations. Prior work (BMLoss) that addresses a similar topological feature matching accuracy issue is computationally expensive (O(n³) time complexity, where n is the number of pixels in an image) and is impractical for many applications. Our method adds no complexity beyond computing the persistent homology itself (O(n log n)). It achieves comparable performance while being much more efficient, which makes it applicable to a wider range of tasks.
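As a rough illustration of the persistent homology step and of what a "critical cell location" is, the sketch below computes 0-dimensional persistence pairs of a 2D likelihood map under a superlevel-set filtration, using one sort (O(n log n)) plus union-find. It is a simplified, hypothetical example, not the repository's code: pixels are treated as vertices with 4-connectivity instead of a full cubical complex, only connected components (0-dimensional features) are handled, and the name superlevel_persistence_0d is illustrative. Each pair records the pixel where a component is born and the pixel where it merges into an older one; these are the kind of locations a spatially-aware matching can compare.

```python
def superlevel_persistence_0d(img):
    """0-dim persistence pairs of the superlevel-set filtration of a 2D grid.

    Returns (birth, death, birth_loc, death_loc) tuples. The component that
    never dies gets death=-inf and death_loc=None.
    """
    rows, cols = len(img), len(img[0])
    # Visit pixels from brightest to darkest: the O(n log n) step.
    order = sorted(((img[r][c], r, c) for r in range(rows) for c in range(cols)),
                   key=lambda t: -t[0])
    parent = {}
    birth = {}  # root pixel -> (birth value, birth pixel)

    def find(x):
        while parent[x] != x:
            parent[x] = parent[parent[x]]  # path halving
            x = parent[x]
        return x

    pairs = []
    for v, r, c in order:
        p = (r, c)
        parent[p] = p
        birth[p] = (v, p)
        for dr, dc in ((1, 0), (-1, 0), (0, 1), (0, -1)):
            q = (r + dr, c + dc)
            if q not in parent:  # out of bounds or not yet in the filtration
                continue
            a, b = find(p), find(q)
            if a == b:
                continue
            # Elder rule: the component born at the lower value dies here.
            if birth[a][0] <= birth[b][0]:
                a, b = b, a
            bv, bloc = birth[b]
            if bv > v:  # skip zero-persistence pairs
                pairs.append((bv, v, bloc, p))
            parent[b] = a
    for root in {find(x) for x in parent}:
        bv, bloc = birth[root]
        pairs.append((bv, float("-inf"), bloc, None))
    return pairs
```

For example, a row [0.9, 0.1, 0.8] has two bright segments joined through a dark pixel, giving the pair (0.8, 0.1, (0, 2), (0, 1)) plus the essential component born at (0, 0). 1-dimensional features (loops) require the full cubical complex, which libraries such as gudhi provide.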
A high-level overview:

Matching examples (0/1-persistent features):

For the implementation of the Spatial-Aware Topological Loss (SATLoss), see PDMatchingLoss in ./utils/losses.py and SpatialAware_WassersteinDistance in ./utils/PDMatching.py.
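SpatialAware_WassersteinDistance in the repository is the reference implementation. Purely to illustrate the idea, the sketch below matches two small persistence diagrams with a squared-Wasserstein-style cost in which a hypothetical weight lam penalizes the squared distance between the features' critical-cell locations, and unmatched features are sent to the diagonal. The cost formula, lam, and the (birth, death, location) tuple layout are assumptions, not the paper's definitions, and the brute-force search only works for a handful of points.

```python
import itertools
import math

# Each point: (birth, death, (row, col)); (row, col) is an assumed critical-cell location.


def pair_cost(p, q, lam):
    """Squared distance in (birth, death) plus a weighted spatial penalty."""
    (bp, dp, lp), (bq, dq, lq) = p, q
    spatial = (lp[0] - lq[0]) ** 2 + (lp[1] - lq[1]) ** 2
    return (bp - bq) ** 2 + (dp - dq) ** 2 + lam * spatial


def diag_cost(p):
    """Squared Euclidean distance from (birth, death) to the diagonal."""
    b, d, _ = p
    return (d - b) ** 2 / 2.0


def spatial_matching(pred, gt, lam=0.01):
    """Optimal matching with diagonal slots, found by exhaustive search."""
    n, m = len(pred), len(gt)
    size = n + m
    cost = [[0.0] * size for _ in range(size)]
    for i in range(size):
        for j in range(size):
            if i < n and j < m:  # predicted i <-> ground-truth j
                cost[i][j] = pair_cost(pred[i], gt[j], lam)
            elif i < n:  # predicted i -> its own diagonal slot only
                cost[i][j] = diag_cost(pred[i]) if j - m == i else math.inf
            elif j < m:  # ground-truth j -> its own diagonal slot only
                cost[i][j] = diag_cost(gt[j]) if i - n == j else math.inf
            # diagonal-to-diagonal entries stay 0.0
    best_total, best_perm = math.inf, None
    for perm in itertools.permutations(range(size)):
        total = sum(cost[i][perm[i]] for i in range(size))
        if total < best_total:
            best_total, best_perm = total, perm
    matches = [(i, best_perm[i]) for i in range(n) if best_perm[i] < m]
    return best_total, matches
```

With lam=0, two features that have the same birth and death values are matched arbitrarily; the spatial term breaks the tie by pairing features at the same location. For real diagrams, a Hungarian solver such as scipy.optimize.linear_sum_assignment on the same augmented cost matrix would replace the permutation search.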
Dependencies:
python==3.9.18
torch==2.1.0
numpy==1.26.0
scikit-image==0.24.0
opencv-python==4.8.1.78
gudhi==3.9.0
torch-topological==0.1.7
POT==0.8.2
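A small, hypothetical helper for comparing an environment against the pinned versions, using only the standard library (importlib.metadata). It assumes each distribution name as written in the dependency list resolves with importlib.metadata.version; a package reported as missing may simply be registered under a differently spelled name in your environment.

```python
import sys
from importlib import metadata

PINNED = ("python==3.9.18 torch==2.1.0 numpy==1.26.0 scikit-image==0.24.0 "
          "opencv-python==4.8.1.78 gudhi==3.9.0 torch-topological==0.1.7 POT==0.8.2")


def parse_pins(spec):
    """Turn 'name==version name==version ...' into a {name: version} dict."""
    return dict(item.split("==", 1) for item in spec.split())


def check_pins(pins):
    """Return {name: (pinned, installed or None)}; 'python' is read from sys."""
    report = {}
    for name, wanted in pins.items():
        if name == "python":
            found = "{}.{}.{}".format(*sys.version_info[:3])
        else:
            try:
                found = metadata.version(name)
            except metadata.PackageNotFoundError:
                found = None
        report[name] = (wanted, found)
    return report


if __name__ == "__main__":
    for name, (wanted, found) in check_pins(parse_pins(PINNED)).items():
        status = "ok" if wanted == found else "MISMATCH"
        print(f"{name:20s} pinned={wanted:12s} installed={found} {status}")
```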
To train using the topological loss:
python3 run.py --expmode train --dataset dataset_name --exp note_for_this_run
Hyperparameters vary from dataset to dataset; see the paper for the values used, and change them through the argument parser in run.py.
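This README does not list the hyperparameter flags, so the following is only a sketch of how a run.py-style argparse parser could expose them. --expmode, --dataset, and --exp come from the commands in this README; --topo_weight and all default values are hypothetical placeholders, and run.py's actual flag names should be checked in the source.

```python
import argparse
from pathlib import Path


def build_parser():
    """Sketch of a run.py-style command-line parser."""
    parser = argparse.ArgumentParser(description="Train or test a segmentation model")
    # Flags that appear in the commands in this README.
    parser.add_argument("--expmode", choices=["train", "test"], default="train")
    parser.add_argument("--dataset", default="dataset_name",
                        help="folder name under ./data")
    parser.add_argument("--exp", default="note_for_this_run",
                        help="run name; outputs go to ./exp/<exp>")
    # HYPOTHETICAL: a per-dataset weight for the topological loss term.
    parser.add_argument("--topo_weight", type=float, default=1.0)
    return parser


def exp_dir(args):
    """Directory where checkpoints and results for this run are written."""
    return Path("./exp") / args.exp
```

A new hyperparameter is added the same way and then read from args in the training code.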
To run inference:
python3 run.py --expmode test --dataset dataset_name --exp note_for_this_run
Checkpoints and results are saved to ./exp/note_for_this_run.
Put the images and labels of each set (train/val/test) into ./data/dataset_name/set; any format that cv2.imread can read is supported. An example is provided in ./data.
dataset_name must match the value passed to --dataset when running run.py.
Each input image and its corresponding label must have the same file name.
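A hypothetical pre-flight check for the naming rule: it pairs the files in an image folder and a label folder by identical file name and reports any file without a partner. The function name pair_by_name and the separate image and label folders are assumptions; follow the layout of the example in ./data.

```python
from pathlib import Path


def pair_by_name(image_dir, label_dir):
    """Pair image and label files that share the exact same file name."""
    images = {p.name: p for p in Path(image_dir).iterdir() if p.is_file()}
    labels = {p.name: p for p in Path(label_dir).iterdir() if p.is_file()}
    # Names present in only one of the two folders break the pairing rule.
    unpaired = sorted(set(images) ^ set(labels))
    if unpaired:
        raise ValueError(f"files without a partner: {unpaired}")
    return [(images[name], labels[name]) for name in sorted(images)]
```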
