Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion environment.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ dependencies:
- dandi
- mlcroissant
- torch<2.4.0
- axondeepseg==5.0.4
- axondeepseg==5.3.0
- monai
- opencv-python==4.8.1.78
- stardist
7 changes: 0 additions & 7 deletions scripts/compute_dataset_statistics.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,6 @@
from AxonDeepSeg.ads_utils import imread, get_imshape
from get_data import load_datasets

# REMOVE THIS COMMENT BLOCK
# we want to compute some statistics on each dataset, such as
# - [x] number of images,
# - [x] number of labelled images,
# - [x] average image size,
# - [x] average nb of axons per image and
# - [x] average foreground-background ratio.

def compute_dataset_statistics(datapath: Path):
# the total nb of images corresponds to the length of the samples.tsv file
Expand Down
178 changes: 126 additions & 52 deletions scripts/evaluate_models.py
Original file line number Diff line number Diff line change
@@ -1,63 +1,66 @@
from pathlib import Path
from monai.metrics import DiceMetric, MeanIoU, compute_panoptic_quality
import matplotlib.pyplot as plt
import torch
import cv2
import numpy as np
import pandas as pd
import warnings
import argparse
from skimage import measure
from AxonDeepSeg.morphometrics.compute_morphometrics import get_watershed_segmentation
from AxonDeepSeg.ads_utils import imread
from stardist.matching import matching

from get_data import load_datasets


def aggregate_det_metrics(detection_df: pd.DataFrame):
agg_dict = {
'TP': 'sum',
'FP': 'sum',
'FN': 'sum',
'precision': 'mean',
'recall': 'mean',
'accuracy': 'mean',
'f1': 'mean'
'TP': 'sum',
'FP': 'sum',
'FN': 'sum',
'precision': 'mean',
'recall': 'mean',
'accuracy': 'mean',
'f1': 'mean'
}
detection_df.drop('image', axis=1)
return detection_df.groupby('dataset').agg(agg_dict).reset_index()

def compute_metrics(pred, gt, metric):
"""
Compute the given metric for a single image
'''
Computes the given metric for a single image
Args:
pred: the prediction image
gt: the ground truth image
pred: the prediction image
gt: the ground truth image
metric: the metric to compute
Returns:
the computed metric
"""
'''
value = metric(pred, gt)
return value.item()

def make_panoptic_input(instance_map):
"""
'''
Converts a (H, W) instance map into the (B, 2, H, W) format
required by MONAI PanopticQualityMetric.
"""
'''
semantic_map = torch.tensor((instance_map > 0)).unsqueeze(dim=0).int()
instance_map = torch.tensor(instance_map).unsqueeze(dim=0).int()
panoptic_map = torch.cat([semantic_map, instance_map], dim=1)

return panoptic_map

def compute_confusion(inst_pred, inst_gt):
"""
'''
Computes the confusion matrix (TP, FP, FN, sum of IoU)
< IMPORTANT! > Do NOT use this function! This is an extremly slow
implementation, relying on MONAI's panoptic quality
metric. Also, this algorithm diverges with images with
1000+ axons. Instead, `stardist` has a much faster
matching function based on sparse graph optimization
instead of dense matrix operations.
implementation, relying on MONAI's panoptic quality metric. Also, this
algorithm diverges on images with 1000+ axons. Instead, `stardist` has
a much faster matching function based on sparse graph optimization
instead of dense matrix operations.

Args:
inst_pred: instance segmentation prediction
inst_gt: instance segmentation ground-truth
Expand All @@ -66,7 +69,7 @@ def compute_confusion(inst_pred, inst_gt):
FP: false positive count,
FN: false negative count,
sumIoU: sum of IoU used to compute Panoptic Quality
"""
'''
y_pred_2ch = make_panoptic_input(inst_pred)
y_true_2ch = make_panoptic_input(inst_gt)
confusion = compute_panoptic_quality(
Expand All @@ -88,12 +91,11 @@ def apply_watershed(ax_mask, my_mask):

return get_watershed_segmentation(ax_mask, my_mask, centroids)


def extract_binary_masks(mask):
'''
This function will take as input an 8-bit image containing both the axon
class (value should be ~255) and the myelin class (value should be ~127).
This function will also convert the numpy arrays read by opencv to Tensors.
This function's input is an 8-bit image containing both the axon class
(value should be ~255) and the myelin class (value should be ~127). It also
converts the numpy arrays read by opencv to Tensors.
'''
# axonmyelin masks should always have 3 unique values
if len(np.unique(mask)) > 3:
Expand All @@ -104,15 +106,36 @@ class (value should be ~255) and the myelin class (value should be ~127).
axon_mask = torch.from_numpy(axon_mask).float()
return axon_mask, myelin_mask

def main():
metrics = [DiceMetric(), MeanIoU()] #, PanopticQualityMetric(num_classes=1)]
def threshold_sensitivity_analysis(inst_gt, inst_pred, thresholds: list) -> pd.DataFrame:
'''
Perform a threshold sensitivity analysis by computing detection metrics
over a range of matching thresholds.
'''
columns = ['threshold', 'precision', 'recall', 'f1']
sensitivity_df = pd.DataFrame(columns=columns)
for thresh in thresholds:
print('Computing detection metrics for threshold=', thresh)
stats = matching(inst_gt, inst_pred, thresh=thresh)
row = {
'threshold': thresh,
'precision': stats.precision,
'recall': stats.recall,
'f1': stats.f1
}
sensitivity_df = pd.concat([sensitivity_df, pd.DataFrame([row])], ignore_index=True)
return sensitivity_df

def main(eval_cellpose: bool = False, matching_thresh: float = 0.3, sensitivity_analysis: bool = False):
metrics = [DiceMetric(), MeanIoU()]
metric_names = [metric.__class__.__name__ for metric in metrics]
columns = ['dataset', 'image', 'class'] + metric_names
pixelwise_df = pd.DataFrame(columns=columns)
columns_detection = ['dataset', 'image', 'TP', 'FP', 'FN', 'precision', 'recall', 'accuracy', 'f1']
detection_df = pd.DataFrame(columns=columns_detection)

data_splits_path = Path("data/splits")
if eval_cellpose:
cellpose_path = Path("data/cellpose_pipeline")
assert data_splits_path.exists(), "Data splits directory does not exist. Please run get_data.py with --make-splits arg first."

datasets = load_datasets()
Expand All @@ -127,6 +150,9 @@ def main():
potential_grayscale_img_fname = img_fname.replace('.png', '_grayscale.png')
ax_pred_fname = gt.name.replace("_seg-axonmyelin-manual.png", "_seg-axon.png")
my_pred_fname = gt.name.replace("_seg-axonmyelin-manual.png", "_seg-myelin.png")

if sensitivity_analysis and 'sub-uoftRat02_sample-uoftRat02' not in img_fname:
continue

# check if image was converted to grayscale prior to inference
if (testset_path / potential_grayscale_img_fname).exists():
Expand All @@ -140,14 +166,55 @@ def main():
gt_im = np.floor(gt_im / np.max(gt_im) * 255).astype(np.uint8)
gt_ax, gt_my = extract_binary_masks(gt_im)

# load predictions
# load axon and myelin predictions
ax_pred = cv2.imread(str(testset_path / ax_pred_fname), cv2.IMREAD_GRAYSCALE)[None]
ax_pred = np.floor(ax_pred / np.max(ax_pred) * 255).astype(np.uint8)
ax_pred, _ = extract_binary_masks(ax_pred)
my_pred = cv2.imread(str(testset_path / my_pred_fname), cv2.IMREAD_GRAYSCALE)[None]
my_pred = np.floor(my_pred / np.max(my_pred) * 255).astype(np.uint8)
my_pred, _ = extract_binary_masks(my_pred)

# INSTANCE-WISE EVALUATION
if eval_cellpose:
cp_pred_fname = gt.name.replace("_seg-axonmyelin-manual.png", "_cp_masks.png")
cp_pred_path = cellpose_path / f'cellpose_preprocessed_{dset.name}' / 'test' / cp_pred_fname
assert cp_pred_path.exists(), f"Cellpose predictions not found for {img_fname}"
inst_pred = imread(cp_pred_path, use_16bit=True)

inst_gt_fname = cp_pred_fname.replace('_cp_masks.png', '_seg-cellpose.png')
inst_gt_path = cellpose_path / f'cellpose_preprocessed_{dset.name}' / 'test' / inst_gt_fname
inst_gt = imread(inst_gt_path, use_16bit=True)
else:
inst_pred = apply_watershed(ax_pred, my_pred)
inst_gt = apply_watershed(gt_ax, gt_my)

if sensitivity_analysis and 'sub-uoftRat02_sample-uoftRat02' in img_fname:
thresholds = np.arange(0.3, 0.99, 0.032).tolist()
sensitivity_df = threshold_sensitivity_analysis(inst_gt, inst_pred, thresholds)
sensitivity_df.to_csv('sensitivity_analysis_sub-uoftRat02_sample-uoftRat02.csv', index=False)
print(sensitivity_df)
continue

stats = matching(inst_gt, inst_pred, thresh=matching_thresh)
detection_row = {
'dataset': dset.name,
'image': img_fname,
'TP': stats.tp,
'FP': stats.fp,
'FN': stats.fn,
'precision': stats.precision,
'recall': stats.recall,
'accuracy': stats.accuracy,
'f1': stats.f1
}
print(f'detection metrics: {detection_row}')
detection_df = pd.concat([detection_df, pd.DataFrame([detection_row])], ignore_index=True)

if eval_cellpose:
# for Cellpose evaluation, we only compute detection metrics
continue

# PIXEL-WISE EVALUATION
classwise_pairs = [
('axon', ax_pred, gt_ax),
('myelin', my_pred, gt_my)
Expand All @@ -165,31 +232,38 @@ def main():
row[metric.__class__.__name__] = value
pixelwise_df = pd.concat([pixelwise_df, pd.DataFrame([row])], ignore_index=True)

# compute detection metrics
inst_gt = apply_watershed(gt_ax, gt_my)
inst_pred = apply_watershed(ax_pred, my_pred)
stats = matching(inst_gt, inst_pred, thresh=0.3)
detection_row = {
'dataset': dset.name,
'image': img_fname,
'TP': stats.tp,
'FP': stats.fp,
'FN': stats.fn,
'precision': stats.precision,
'recall': stats.recall,
'accuracy': stats.accuracy,
'f1': stats.f1
}
print(f'detection metrics: {detection_row}')
detection_df = pd.concat([detection_df, pd.DataFrame([detection_row])], ignore_index=True)

# Export the dataframe to a CSV file
pixelwise_df.to_csv('metrics.csv', index=False)
print("Metrics computed and saved to metrics.csv")
detection_df.to_csv('det_metrics.csv', index=False)
print("Detection metrics computed and saved to det_metrics.csv")
print(aggregate_det_metrics(detection_df))
if eval_cellpose:
detection_df.to_csv('cellpose_det_metrics.csv', index=False)
print("Cellpose detection metrics computed and saved to cellpose_det_metrics.csv")
print(aggregate_det_metrics(detection_df))
else:
pixelwise_df.to_csv('metrics.csv', index=False)
print("Metrics computed and saved to metrics.csv")
detection_df.to_csv('det_metrics.csv', index=False)
print("Detection metrics computed and saved to det_metrics.csv")
print(aggregate_det_metrics(detection_df))


if __name__ == "__main__":
main()
ap = argparse.ArgumentParser(description="Evaluate segmentation models")
ap.add_argument(
'-c', '--eval_cellpose',
action='store_true',
default=False,
help="Evaluate Cellpose model predictions. Will only run the object detection evaluation. Expects predictions with suffix '_cp_masks'."
)
ap.add_argument(
'-t', '--matching_thresh',
type=float,
default=0.5,
help="Matching IoU threshold for object detection evaluation (default: 0.5)"
)
ap.add_argument(
'-s', '--sensitivity_analysis',
action='store_true',
default=False,
help="Perform threshold sensitivity analysis for detection metrics."
)
args = ap.parse_args()
main(args.eval_cellpose, args.matching_thresh, args.sensitivity_analysis)
21 changes: 19 additions & 2 deletions scripts/get_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@
import requests, zipfile, io
import json

from preprocess_for_cellpose import preprocess_dataset


ASTIH_ASCII = '''
█████ ███ █████
Expand Down Expand Up @@ -121,7 +123,7 @@ def copy_files_associated(img_path, gt_paths, dest_dir):



def main(make_splits: bool):
def main(make_splits: bool, preprocess_cellpose: bool = False):
print(ASTIH_ASCII)

# Create a directory to store the downloaded data
Expand Down Expand Up @@ -149,6 +151,12 @@ def main(make_splits: bool):
print(f"Splitting {dataset.name} dataset...")
split_dataset(dataset, dataset_path, dataset_split_dir)

if preprocess_cellpose:
print(f"Preprocessing {dataset.name} dataset for Cellpose...")
output_cellpose_dir = data_dir / "cellpose_pipeline" / f'cellpose_preprocessed_{dataset.name}'
output_cellpose_dir.mkdir(parents=True, exist_ok=True)
preprocess_dataset(dataset_split_dir, output_cellpose_dir)


if __name__ == "__main__":
parser = argparse.ArgumentParser(description="Download and split datasets.")
Expand All @@ -157,6 +165,15 @@ def main(make_splits: bool):
action="store_true",
help="Make splits for the datasets.",
)
parser.add_argument(
"--preprocess-cellpose",
action="store_true",
default=False,
help="Preprocess the datasets for Cellpose training.",
)
args = parser.parse_args()

main(args.make_splits)
if args.preprocess_cellpose and not args.make_splits:
parser.error("--preprocess-cellpose requires --make-splits to be set.")

main(args.make_splits, args.preprocess_cellpose)
Loading