diff --git a/environment.yaml b/environment.yaml index 97ae7b6..7b06a62 100644 --- a/environment.yaml +++ b/environment.yaml @@ -6,7 +6,7 @@ dependencies: - dandi - mlcroissant - torch<2.4.0 - - axondeepseg==5.0.4 + - axondeepseg==5.3.0 - monai - opencv-python==4.8.1.78 - stardist diff --git a/scripts/compute_dataset_statistics.py b/scripts/compute_dataset_statistics.py index 427967a..8dc0f4a 100644 --- a/scripts/compute_dataset_statistics.py +++ b/scripts/compute_dataset_statistics.py @@ -2,13 +2,6 @@ from AxonDeepSeg.ads_utils import imread, get_imshape from get_data import load_datasets -# REMOVE THIS COMMENT BLOCK -# we want to compute some statistics on each dataset, such as -# - [x] number of images, -# - [x] number of labelled images, -# - [x] average image size, -# - [x] average nb of axons per image and -# - [x] average foreground-background ratio. def compute_dataset_statistics(datapath: Path): # the total nb of images corresponds to the length of the samples.tsv file diff --git a/scripts/evaluate_models.py b/scripts/evaluate_models.py index 53b53a5..eb6e6dd 100644 --- a/scripts/evaluate_models.py +++ b/scripts/evaluate_models.py @@ -1,12 +1,15 @@ from pathlib import Path from monai.metrics import DiceMetric, MeanIoU, compute_panoptic_quality +import matplotlib.pyplot as plt import torch import cv2 import numpy as np import pandas as pd import warnings +import argparse from skimage import measure from AxonDeepSeg.morphometrics.compute_morphometrics import get_watershed_segmentation +from AxonDeepSeg.ads_utils import imread from stardist.matching import matching from get_data import load_datasets @@ -14,35 +17,35 @@ def aggregate_det_metrics(detection_df: pd.DataFrame): agg_dict = { - 'TP': 'sum', - 'FP': 'sum', - 'FN': 'sum', - 'precision': 'mean', - 'recall': 'mean', - 'accuracy': 'mean', - 'f1': 'mean' + 'TP': 'sum', + 'FP': 'sum', + 'FN': 'sum', + 'precision': 'mean', + 'recall': 'mean', + 'accuracy': 'mean', + 'f1': 'mean' } detection_df.drop('image', 
axis=1) return detection_df.groupby('dataset').agg(agg_dict).reset_index() def compute_metrics(pred, gt, metric): - """ - Compute the given metric for a single image + ''' + Computes the given metric for a single image Args: - pred: the prediction image - gt: the ground truth image + pred: the prediction image + gt: the ground truth image metric: the metric to compute Returns: the computed metric - """ + ''' value = metric(pred, gt) return value.item() def make_panoptic_input(instance_map): - """ + ''' Converts a (H, W) instance map into the (B, 2, H, W) format required by MONAI PanopticQualityMetric. - """ + ''' semantic_map = torch.tensor((instance_map > 0)).unsqueeze(dim=0).int() instance_map = torch.tensor(instance_map).unsqueeze(dim=0).int() panoptic_map = torch.cat([semantic_map, instance_map], dim=1) @@ -50,14 +53,14 @@ def make_panoptic_input(instance_map): return panoptic_map def compute_confusion(inst_pred, inst_gt): - """ + ''' Computes the confusion matrix (TP, FP, FN, sum of IoU) < IMPORTANT! > Do NOT use this function! This is an extremly slow - implementation, relying on MONAI's panoptic quality - metric. Also, this algorithm diverges with images with - 1000+ axons. Instead, `stardist` has a much faster - matching function based on sparse graph optimization - instead of dense matrix operations. + implementation, relying on MONAI's panoptic quality metric. Also, this + algorithm diverges on images with 1000+ axons. Instead, `stardist` has + a much faster matching function based on sparse graph optimization + instead of dense matrix operations. 
+ Args: inst_pred: instance segmentation prediction inst_gt: instance segmentation ground-truth @@ -66,7 +69,7 @@ def compute_confusion(inst_pred, inst_gt): FP: false positive count, FN: false negative count, sumIoU: sum of IoU used to compute Panoptic Quality - """ + ''' y_pred_2ch = make_panoptic_input(inst_pred) y_true_2ch = make_panoptic_input(inst_gt) confusion = compute_panoptic_quality( @@ -88,12 +91,11 @@ def apply_watershed(ax_mask, my_mask): return get_watershed_segmentation(ax_mask, my_mask, centroids) - def extract_binary_masks(mask): ''' - This function will take as input an 8-bit image containing both the axon - class (value should be ~255) and the myelin class (value should be ~127). - This function will also convert the numpy arrays read by opencv to Tensors. + This function's input is an 8-bit image containing both the axon class + (value should be ~255) and the myelin class (value should be ~127). It also + converts the numpy arrays read by opencv to Tensors. ''' # axonmyelin masks should always have 3 unique values if len(np.unique(mask)) > 3: @@ -104,8 +106,27 @@ class (value should be ~255) and the myelin class (value should be ~127). axon_mask = torch.from_numpy(axon_mask).float() return axon_mask, myelin_mask -def main(): - metrics = [DiceMetric(), MeanIoU()] #, PanopticQualityMetric(num_classes=1)] +def threshold_sensitivity_analysis(inst_gt, inst_pred, thresholds: list) -> pd.DataFrame: + ''' + Perform a threshold sensitivity analysis by computing detection metrics + over a range of matching thresholds. 
+ ''' + columns = ['threshold', 'precision', 'recall', 'f1'] + sensitivity_df = pd.DataFrame(columns=columns) + for thresh in thresholds: + print('Computing detection metrics for threshold=', thresh) + stats = matching(inst_gt, inst_pred, thresh=thresh) + row = { + 'threshold': thresh, + 'precision': stats.precision, + 'recall': stats.recall, + 'f1': stats.f1 + } + sensitivity_df = pd.concat([sensitivity_df, pd.DataFrame([row])], ignore_index=True) + return sensitivity_df + +def main(eval_cellpose: bool = False, matching_thresh: float = 0.3, sensitivity_analysis: bool = False): + metrics = [DiceMetric(), MeanIoU()] metric_names = [metric.__class__.__name__ for metric in metrics] columns = ['dataset', 'image', 'class'] + metric_names pixelwise_df = pd.DataFrame(columns=columns) @@ -113,6 +134,8 @@ def main(): detection_df = pd.DataFrame(columns=columns_detection) data_splits_path = Path("data/splits") + if eval_cellpose: + cellpose_path = Path("data/cellpose_pipeline") assert data_splits_path.exists(), "Data splits directory does not exist. Please run get_data.py with --make-splits arg first." 
datasets = load_datasets() @@ -127,6 +150,9 @@ def main(): potential_grayscale_img_fname = img_fname.replace('.png', '_grayscale.png') ax_pred_fname = gt.name.replace("_seg-axonmyelin-manual.png", "_seg-axon.png") my_pred_fname = gt.name.replace("_seg-axonmyelin-manual.png", "_seg-myelin.png") + + if sensitivity_analysis and 'sub-uoftRat02_sample-uoftRat02' not in img_fname: + continue # check if image was converted to grayscale prior to inference if (testset_path / potential_grayscale_img_fname).exists(): @@ -140,7 +166,7 @@ def main(): gt_im = np.floor(gt_im / np.max(gt_im) * 255).astype(np.uint8) gt_ax, gt_my = extract_binary_masks(gt_im) - # load predictions + # load axon and myelin predictions ax_pred = cv2.imread(str(testset_path / ax_pred_fname), cv2.IMREAD_GRAYSCALE)[None] ax_pred = np.floor(ax_pred / np.max(ax_pred) * 255).astype(np.uint8) ax_pred, _ = extract_binary_masks(ax_pred) @@ -148,6 +174,47 @@ def main(): my_pred = np.floor(my_pred / np.max(my_pred) * 255).astype(np.uint8) my_pred, _ = extract_binary_masks(my_pred) + # INSTANCE-WISE EVALUATION + if eval_cellpose: + cp_pred_fname = gt.name.replace("_seg-axonmyelin-manual.png", "_cp_masks.png") + cp_pred_path = cellpose_path / f'cellpose_preprocessed_{dset.name}' / 'test' / cp_pred_fname + assert cp_pred_path.exists(), f"Cellpose predictions not found for {img_fname}" + inst_pred = imread(cp_pred_path, use_16bit=True) + + inst_gt_fname = cp_pred_fname.replace('_cp_masks.png', '_seg-cellpose.png') + inst_gt_path = cellpose_path / f'cellpose_preprocessed_{dset.name}' / 'test' / inst_gt_fname + inst_gt = imread(inst_gt_path, use_16bit=True) + else: + inst_pred = apply_watershed(ax_pred, my_pred) + inst_gt = apply_watershed(gt_ax, gt_my) + + if sensitivity_analysis and 'sub-uoftRat02_sample-uoftRat02' in img_fname: + thresholds = np.arange(0.3, 0.99, 0.032).tolist() + sensitivity_df = threshold_sensitivity_analysis(inst_gt, inst_pred, thresholds) + 
sensitivity_df.to_csv('sensitivity_analysis_sub-uoftRat02_sample-uoftRat02.csv', index=False) + print(sensitivity_df) + continue + + stats = matching(inst_gt, inst_pred, thresh=matching_thresh) + detection_row = { + 'dataset': dset.name, + 'image': img_fname, + 'TP': stats.tp, + 'FP': stats.fp, + 'FN': stats.fn, + 'precision': stats.precision, + 'recall': stats.recall, + 'accuracy': stats.accuracy, + 'f1': stats.f1 + } + print(f'detection metrics: {detection_row}') + detection_df = pd.concat([detection_df, pd.DataFrame([detection_row])], ignore_index=True) + + if eval_cellpose: + # for Cellpose evaluation, we only compute detection metrics + continue + + # PIXEL-WISE EVALUATION classwise_pairs = [ ('axon', ax_pred, gt_ax), ('myelin', my_pred, gt_my) @@ -165,31 +232,38 @@ def main(): row[metric.__class__.__name__] = value pixelwise_df = pd.concat([pixelwise_df, pd.DataFrame([row])], ignore_index=True) - # compute detection metrics - inst_gt = apply_watershed(gt_ax, gt_my) - inst_pred = apply_watershed(ax_pred, my_pred) - stats = matching(inst_gt, inst_pred, thresh=0.3) - detection_row = { - 'dataset': dset.name, - 'image': img_fname, - 'TP': stats.tp, - 'FP': stats.fp, - 'FN': stats.fn, - 'precision': stats.precision, - 'recall': stats.recall, - 'accuracy': stats.accuracy, - 'f1': stats.f1 - } - print(f'detection metrics: {detection_row}') - detection_df = pd.concat([detection_df, pd.DataFrame([detection_row])], ignore_index=True) - # Export the dataframe to a CSV file - pixelwise_df.to_csv('metrics.csv', index=False) - print("Metrics computed and saved to metrics.csv") - detection_df.to_csv('det_metrics.csv', index=False) - print("Detection metrics computed and saved to det_metrics.csv") - print(aggregate_det_metrics(detection_df)) + if eval_cellpose: + detection_df.to_csv('cellpose_det_metrics.csv', index=False) + print("Cellpose detection metrics computed and saved to cellpose_det_metrics.csv") + print(aggregate_det_metrics(detection_df)) + else: + 
pixelwise_df.to_csv('metrics.csv', index=False) + print("Metrics computed and saved to metrics.csv") + detection_df.to_csv('det_metrics.csv', index=False) + print("Detection metrics computed and saved to det_metrics.csv") + print(aggregate_det_metrics(detection_df)) if __name__ == "__main__": - main() + ap = argparse.ArgumentParser(description="Evaluate segmentation models") + ap.add_argument( + '-c', '--eval_cellpose', + action='store_true', + default=False, + help="Evaluate Cellpose model predictions. Will only run the object detection evaluation. Expects predictions with suffix '_cp_masks'." + ) + ap.add_argument( + '-t', '--matching_thresh', + type=float, + default=0.5, + help="Matching IoU threshold for object detection evaluation (default: 0.5)" + ) + ap.add_argument( + '-s', '--sensitivity_analysis', + action='store_true', + default=False, + help="Perform threshold sensitivity analysis for detection metrics." + ) + args = ap.parse_args() + main(args.eval_cellpose, args.matching_thresh, args.sensitivity_analysis) \ No newline at end of file diff --git a/scripts/get_data.py b/scripts/get_data.py index 6d12e06..1876cc1 100644 --- a/scripts/get_data.py +++ b/scripts/get_data.py @@ -5,6 +5,8 @@ import requests, zipfile, io import json +from preprocess_for_cellpose import preprocess_dataset + ASTIH_ASCII = ''' █████ ███ █████ @@ -121,7 +123,7 @@ def copy_files_associated(img_path, gt_paths, dest_dir): -def main(make_splits: bool): +def main(make_splits: bool, preprocess_cellpose: bool = False): print(ASTIH_ASCII) # Create a directory to store the downloaded data @@ -149,6 +151,12 @@ def main(make_splits: bool): print(f"Splitting {dataset.name} dataset...") split_dataset(dataset, dataset_path, dataset_split_dir) + if preprocess_cellpose: + print(f"Preprocessing {dataset.name} dataset for Cellpose...") + output_cellpose_dir = data_dir / "cellpose_pipeline" / f'cellpose_preprocessed_{dataset.name}' + output_cellpose_dir.mkdir(parents=True, exist_ok=True) + 
preprocess_dataset(dataset_split_dir, output_cellpose_dir) + if __name__ == "__main__": parser = argparse.ArgumentParser(description="Download and split datasets.") @@ -157,6 +165,15 @@ def main(make_splits: bool): action="store_true", help="Make splits for the datasets.", ) + parser.add_argument( + "--preprocess-cellpose", + action="store_true", + default=False, + help="Preprocess the datasets for Cellpose training.", + ) args = parser.parse_args() - main(args.make_splits) + if args.preprocess_cellpose and not args.make_splits: + parser.error("--preprocess-cellpose requires --make-splits to be set.") + + main(args.make_splits, args.preprocess_cellpose) diff --git a/scripts/plot_sensitivity_analysis.py b/scripts/plot_sensitivity_analysis.py new file mode 100644 index 0000000..b51598c --- /dev/null +++ b/scripts/plot_sensitivity_analysis.py @@ -0,0 +1,80 @@ +import numpy as np +import pandas as pd +import matplotlib.pyplot as plt +import seaborn as sns +import argparse + + +if __name__ == "__main__": + parser = argparse.ArgumentParser(description="Plot sensitivity analysis results.") + parser.add_argument( + '-n', '--nnunet_csv', + type=str, + required=True, + help="Path to the CSV file containing sensitivity analysis results for the nnunet model.", + ) + parser.add_argument( + '-c', '--cellpose_csv', + type=str, + required=True, + help="Path to the CSV file containing sensitivity analysis results for the cellpose model.", + ) + args = parser.parse_args() + + nnunet_df = pd.read_csv(args.nnunet_csv) + cellpose_df = pd.read_csv(args.cellpose_csv) + + plt.figure(figsize=(10, 8)) + sns.set_context("notebook", font_scale=2) # Increase font size + sns.set_style('ticks') + sns.set_palette('mako_r', n_colors=2) + plt.tick_params(axis='both', which='major', labelsize=20) # Adjust tick label font size + sns.lineplot(data=nnunet_df, x='threshold', y='f1', label='nnU-Net') + sns.lineplot(data=cellpose_df, x='threshold', y='f1', label='Cellpose', linestyle='dashed') + 
plt.title('Sensitivity analysis of IoU threshold on F1 Score') + plt.xlabel('IoU Threshold', labelpad=5) + plt.ylabel('F1 Score') + plt.xlim((0.3, 0.99)) + plt.ylim((0, 1)) + plt.grid(True) + plt.legend() + plt.tight_layout() + plt.savefig('f1_sensitivity_analysis_plot.png', dpi=300) + print("Sensitivity analysis plot saved as 'f1_sensitivity_analysis_plot.png'") + + # now, we do the same for precision and recall + plt.figure(figsize=(10, 8)) + sns.set_context("notebook", font_scale=2) # Increase font size + sns.set_style('ticks') + sns.set_palette('mako_r', n_colors=2) + plt.tick_params(axis='both', which='major', labelsize=20) # Adjust tick label font size + sns.lineplot(data=nnunet_df, x='threshold', y='precision', label='nnU-Net') + sns.lineplot(data=cellpose_df, x='threshold', y='precision', label='Cellpose', linestyle='dashed') + plt.title('Sensitivity analysis of IoU threshold on Precision') + plt.xlabel('IoU Threshold', labelpad=5) + plt.ylabel('Precision') + plt.xlim((0.3, 0.99)) + plt.ylim((0, 1)) + plt.grid(True) + plt.legend() + plt.tight_layout() + plt.savefig('precision_sensitivity_analysis_plot.png', dpi=300) + print("Sensitivity analysis plot saved as 'precision_sensitivity_analysis_plot.png'") + + plt.figure(figsize=(10, 8)) + sns.set_context("notebook", font_scale=2) + sns.set_style('ticks') + sns.set_palette('mako_r', n_colors=2) + plt.tick_params(axis='both', which='major', labelsize=20) + sns.lineplot(data=nnunet_df, x='threshold', y='recall', label='nnU-Net') + sns.lineplot(data=cellpose_df, x='threshold', y='recall', label='Cellpose', linestyle='dashed') + plt.title('Sensitivity analysis of IoU threshold on Recall') + plt.xlabel('IoU Threshold', labelpad=5) + plt.ylabel('Recall') + plt.xlim((0.3, 0.99)) + plt.ylim((0, 1)) + plt.grid(True) + plt.legend() + plt.tight_layout() + plt.savefig('recall_sensitivity_analysis_plot.png', dpi=300) + print("Sensitivity analysis plot saved as 'recall_sensitivity_analysis_plot.png'") \ No newline at end of 
file diff --git a/scripts/preprocess_for_cellpose.py b/scripts/preprocess_for_cellpose.py new file mode 100644 index 0000000..15b63fa --- /dev/null +++ b/scripts/preprocess_for_cellpose.py @@ -0,0 +1,104 @@ +'''This file provides utilities to preprocess the dataset into a format suitable +for Cellpose training and inference. + +NOTE: unfortunately, the Cellpose dependency collides with other dependencies in +our project so we can't run the training in the same environment. +''' + +from pathlib import Path +import argparse +import shutil + +from AxonDeepSeg.ads_utils import imread, imwrite, get_imshape +from skimage import measure +import numpy as np + +CELLPOSE_MASK_SUFFIX = '_seg-cellpose.png' + + +def find_all_images_and_masks(dir_path: Path, mask_suffix: str = '_seg-axonmyelin-manual.png') -> list[tuple[Path, Path]]: + """ + Find all image and corresponding mask file paths in the given directory. + + Parameters: + - dir_path: Path to the directory to search. + - mask_suffix: Suffix used to identify mask files. + + Returns: + - List of tuples containing (image_path, mask_path). + """ + image_mask_pairs = [] + mask_files = list(dir_path.glob(f'*{mask_suffix}')) + for mask_file in mask_files: + image_file = mask_file.with_name(mask_file.name.replace(mask_suffix, '.png')) + if image_file.exists(): + image_mask_pairs.append((image_file, mask_file)) + else: + # the image might actually be in TIFF format + image_file_tiff = mask_file.with_name(mask_file.name.replace(mask_suffix, '.tif')) + if image_file_tiff.exists(): + image_mask_pairs.append((image_file_tiff, mask_file)) + else: + print(f'Warning: No corresponding image found for mask {mask_file}') + return image_mask_pairs + +def convert_axonmyelin_mask_to_cellpose(mask_path: Path, output_path: Path): + """ + Convert an axon-myelin segmentation mask to a Cellpose-compatible mask. 
+ + In the axon-myelin mask: + - Background: 0 + - Myelin: 127 + - Axon: 255 + + In the Cellpose mask: + - Instance segmentation + - Background: 0 + - Cell 1: 1 (axon and myelin combined) + - ... and so on for each cell instance. + + Parameters: + - mask_path: Path to the input axon-myelin mask. + - output_path: Path to save the converted Cellpose mask. + """ + mask = imread(str(mask_path)) + cellpose_mask = (mask > 0) # Set axon and myelin to 1, background to 0 + cellpose_mask = measure.label(cellpose_mask, connectivity=1) # Label connected components + + # Ensure the mask is cast to a supported data type + cellpose_mask = cellpose_mask.astype(np.uint16) + + imwrite(str(output_path), cellpose_mask, use_16bit=True) + +def preprocess_dataset(data_dir: Path, output_dir: Path): + train_dir = data_dir / 'train' + test_dir = data_dir / 'test' + + if not train_dir.exists() or not test_dir.exists(): + raise ValueError("The provided data_dir must contain 'train/' and 'test/' subdirectories.") + + output_train_dir = output_dir / 'train' + output_test_dir = output_dir / 'test' + for out_dir in [output_train_dir, output_test_dir]: + out_dir.mkdir(parents=True, exist_ok=True) + + train_data = find_all_images_and_masks(train_dir) + test_data = find_all_images_and_masks(test_dir) + + for data, output_dir in zip([train_data, test_data], [output_train_dir, output_test_dir]): + for image_path, mask_path in data: + output_image_path = output_dir / image_path.name + output_mask_path = output_dir / (mask_path.stem.replace('_seg-axonmyelin-manual', CELLPOSE_MASK_SUFFIX)) + + shutil.copy(image_path, output_image_path) + convert_axonmyelin_mask_to_cellpose(mask_path, output_mask_path) + +if __name__ == '__main__': + ap = argparse.ArgumentParser(description="Preprocess dataset for Cellpose") + ap.add_argument('data_dir', type=str, help="Path to the dataset (split into train/ and test/ directories)") + ap.add_argument('--output_dir', type=str, default=None, help="Path to save the 
preprocessed data")
+    args = ap.parse_args()
+    data_dir = Path(args.data_dir)
+    output_dir = Path(args.output_dir) if args.output_dir else Path('.') / 'cellpose_preprocessed'
+
+    preprocess_dataset(data_dir, output_dir)