diff --git a/mmv_im2im/bin/run_im2im.py b/mmv_im2im/bin/run_im2im.py index 42593b9..374330c 100644 --- a/mmv_im2im/bin/run_im2im.py +++ b/mmv_im2im/bin/run_im2im.py @@ -13,7 +13,9 @@ # import torch + from mmv_im2im import ProjectTester, ProjectTrainer +from mmv_im2im.map_extractor import MapExtractor from mmv_im2im.configs.config_base import ( ProgramConfig, parse_adaptor, @@ -32,6 +34,7 @@ ############################################################################### TRAIN_MODE = "train" INFER_MODE = "inference" +MAP_MODE = "uncertainty_map" ############################################################################### @@ -70,6 +73,9 @@ def main(): elif cfg.mode.lower() == INFER_MODE: exe = ProjectTester(cfg) exe.run_inference() + elif cfg.mode.lower() == MAP_MODE: + exe = MapExtractor(cfg) + exe.run_inference() else: log.error(f"Mode {cfg.mode} is not supported yet") sys.exit(1) diff --git a/mmv_im2im/map_extractor.py b/mmv_im2im/map_extractor.py new file mode 100644 index 0000000..b84acb0 --- /dev/null +++ b/mmv_im2im/map_extractor.py @@ -0,0 +1,944 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +import logging +from typing import Union +from dask.array.core import Array as DaskArray +from numpy import ndarray as NumpyArray +from importlib import import_module +from pathlib import Path +import numpy as np +from bioio import BioImage +from bioio.writers import OmeTiffWriter +import torch +from mmv_im2im.utils.misc import parse_config +from mmv_im2im.utils.for_transform import parse_monai_ops_vanilla +from skimage.io import imsave as save_rgb +import bioio_tifffile +from mmv_im2im.utils.urcentainity_extractor import ( + Hole_Correction, + Thickness_Corretion, + Remove_objects, + Extract_Uncertainty_Maps, + perturb_image, + Perycites_correction, +) +from monai.inferers import sliding_window_inference +import itertools +from mmv_im2im.utils.multi_pred import ( + variance_prediction, + mean_prediction, + max_prediction, + add_prediction, +) +from 
bioio_base.types import PhysicalPixelSizes + +# https://pytorch-lightning.readthedocs.io/en/latest/starter/introduction_guide.html#predicting +############################################################################### + +log = logging.getLogger(__name__) + +############################################################################### + + +class MapExtractor(object): + """ + entry for training models + + Parameters + ---------- + cfg: configuration + """ + + def __init__(self, cfg): + # extract the three major chuck of the config + self.model_cfg = cfg.model + self.data_cfg = cfg.data + + # define variables + self.model = None + self.data = None + self.pre_process = None + self.cpu = False + self.spatial_dims = -1 + + def setup_model(self): + model_category = self.model_cfg.framework + model_module = import_module(f"mmv_im2im.models.pl_{model_category}") + my_model_func = getattr(model_module, "Model") + self.model = my_model_func(self.model_cfg, train=False) + + if ( + self.model_cfg.model_extra is not None + and "cpu_only" in self.model_cfg.model_extra + and self.model_cfg.model_extra["cpu_only"] + ): + self.cpu = True + checkpoint = torch.load( + self.model_cfg.checkpoint, + map_location=torch.device("cpu"), + weights_only=False, + ) + else: + checkpoint = torch.load(self.model_cfg.checkpoint, weights_only=False) + + if isinstance(checkpoint, dict) and "state_dict" in checkpoint: + + pre_train = checkpoint + pre_train["state_dict"].pop("criterion.xym", None) + pre_train["state_dict"].pop("criterion.xyzm", None) + self.model.load_state_dict(pre_train["state_dict"], strict=False) + else: + + state_dict = checkpoint + state_dict.pop("criterion.xym", None) + state_dict.pop("criterion.xyzm", None) + self.model.load_state_dict(state_dict, strict=False) + + if not self.cpu: + self.model.cuda() + + self.model.eval() + + def setup_data_processing(self): + # determine spatial dimension from reader parameters + if "Z" in 
self.data_cfg.inference_input.reader_params["dimension_order_out"]: + self.spatial_dims = 3 + else: + self.spatial_dims = 2 + + # prepare data preprocessing if needed + if self.data_cfg.preprocess is not None: + # load preprocessing transformation + self.pre_process = parse_monai_ops_vanilla(self.data_cfg.preprocess) + + def process_one_image( + self, img: Union[DaskArray, NumpyArray], out_fn: Union[str, Path] = None + ): + + if isinstance(img, DaskArray): + # Perform the prediction + x = img.compute() + + elif isinstance(img, NumpyArray): + x = img + else: + raise ValueError("invalid image") + + # check if need to add channel dimension + if len(x.shape) == self.spatial_dims: + x = np.expand_dims(x, axis=0) + + # convert the numpy array to float tensor + x = torch.tensor(x.astype(np.float32)) + + # run pre-processing on tensor if needed + + if self.pre_process is not None: + x = self.pre_process(x) + + # choose different inference function for different types of models + # the input here is assumed to be a tensor + with torch.no_grad(): + # add batch dimension and move to GPU + + if self.cpu: + x = torch.unsqueeze(x, dim=0) + else: + x = torch.unsqueeze(x, dim=0).cuda() + + # TODO: add convert to tensor with proper type, similar to torchio check + + if ( + self.model_cfg.model_extra is not None + and "sliding_window_params" in self.model_cfg.model_extra + ): + y_hat = sliding_window_inference( + inputs=x, + predictor=self.model, + device=torch.device("cpu"), + **self.model_cfg.model_extra["sliding_window_params"], + ) + + # currently, we keep sliding window stiching step on CPU, but assume + # the output is on GPU (see note below). 
So, we manually move the data + # back to GPU + if not self.cpu: + y_hat = y_hat.cuda() + else: + y_hat = self.model(x) + + ############################################################################### + # + # Note: currently, we assume y_hat is still on gpu, because embedseg clustering + # step is still only running on GPU (possible on CPU, need to some update on + # grid loading). All the post-procesisng functions we tested so far can accept + # tensor on GPU. If it is from mmv_im2im.post_processing, it will automatically + # convert the tensor to a numpy array and return the result as numpy array; if + # it is from monai.transforms, it is tensor in and tensor out. We have two items + # as #TODO: (1) we will extend post-processing functions in mmv_im2im to work + # similarly to monai transforms, ie. ndarray in ndarray out or tensor in tensor + # out. (2) allow yaml config to control if we want to run post-processing on + # GPU tensors or ndarrays + # + ############################################################################## + + # do post-processing on the prediction + if self.data_cfg.postprocess is not None: + pp_data = y_hat + for pp_info in self.data_cfg.postprocess: + pp = parse_config(pp_info) + pp_data = pp(pp_data) + if torch.is_tensor(pp_data): + pred = pp_data.cpu().numpy() + else: + pred = pp_data + else: + pred = y_hat.cpu().numpy() + + if out_fn is None: + return pred + + # determine output dimension orders + if out_fn.suffix == ".npy": + np.save(out_fn, pred) + else: + if len(pred.shape) == 2: + OmeTiffWriter.save(pred, out_fn, dim_order="YX") + elif len(pred.shape) == 3: + # 3D output, for 2D data + if self.spatial_dims == 2: + # save as RGB or multi-channel 2D + if pred.shape[0] == 3: + if out_fn.suffix != ".png": + out_fn = out_fn.with_suffix(".png") + save_rgb(out_fn, np.moveaxis(pred, 0, -1)) + else: + OmeTiffWriter.save(pred, out_fn, dim_order="CYX") + elif self.spatial_dims == 3: + OmeTiffWriter.save(pred, out_fn, dim_order="ZYX") + 
else: + raise ValueError("Invalid spatial dimension of pred") + elif len(pred.shape) == 4: + if self.spatial_dims == 3: + OmeTiffWriter.save(pred, out_fn, dim_order="CZYX") + elif self.spatial_dims == 2: + if pred.shape[0] == 1: + if pred.shape[1] == 1: + OmeTiffWriter.save(pred[0, 0], out_fn, dim_order="YX") + elif pred.shape[1] == 3: + if out_fn.suffix != ".png": + out_fn = out_fn.with_suffix(".png") + save_rgb( + out_fn, + np.moveaxis( + pred[0,], + 0, + -1, + ), + ) + else: + OmeTiffWriter.save( + pred[0,], + out_fn, + dim_order="CYX", + ) + else: + raise ValueError("invalid 4D output for 2d data") + elif len(pred.shape) == 5: + assert pred.shape[0] == 1, "error, found non-trivial batch dimension" + OmeTiffWriter.save( + pred[0,], + out_fn, + dim_order="CZYX", + ) + else: + raise ValueError("error in prediction output shape") + + def run_inference(self): + + self.setup_model() + if "pred_slice2vol" in self.model_cfg.net: + + if self.model_cfg.net["pred_slice2vol"] is not None: + # handle multiple kind of elements tiff/tif + if "," in self.data_cfg.inference_input.data_type: + types = self.data_cfg.inference_input.data_type.split(",") + extensions = [f"*{tipe}" for tipe in types] + filenames = sorted( + list( + itertools.chain.from_iterable( + self.data_cfg.inference_input.dir.glob(extension) + for extension in extensions + ) + ) + ) + else: + filenames = sorted( + self.data_cfg.inference_input.dir.glob( + "*" + self.data_cfg.inference_input.data_type + ) + ) + + vs_flag = PhysicalPixelSizes(1, 1, 1) + if "pixel_dim" in self.model_cfg.net["pred_slice2vol"]: + if self.model_cfg.net["pred_slice2vol"]["pixel_dim"] is not None: + if isinstance( + self.model_cfg.net["pred_slice2vol"]["pixel_dim"], str + ): + vs_flag = "auto" + elif isinstance( + self.model_cfg.net["pred_slice2vol"]["pixel_dim"], tuple + ): + if ( + len(self.model_cfg.net["pred_slice2vol"]["pixel_dim"]) + == 3 + ): + z, y, x = self.model_cfg.net["pred_slice2vol"][ + "pixel_dim" + ] + vs_flag = 
PhysicalPixelSizes(z, y, x) + elif isinstance( + self.model_cfg.net["pred_slice2vol"]["pixel_dim"], list + ): + if ( + len(self.model_cfg.net["pred_slice2vol"]["pixel_dim"]) + == 3 + ): + z, y, x = self.model_cfg.net["pred_slice2vol"][ + "pixel_dim" + ] + vs_flag = PhysicalPixelSizes(z, y, x) + + max_proj = False + if "max_proj" in self.model_cfg.net["pred_slice2vol"]: + if self.model_cfg.net["pred_slice2vol"]["max_proj"] is not None: + if isinstance( + self.model_cfg.net["pred_slice2vol"]["max_proj"], bool + ): + max_proj = self.model_cfg.net["pred_slice2vol"]["max_proj"] + + perycites_correction = False + if "perycites_correction" in self.model_cfg.net["pred_slice2vol"]: + if ( + self.model_cfg.net["pred_slice2vol"]["perycites_correction"] + is not None + ): + perycites_correction = self.model_cfg.net["pred_slice2vol"][ + "perycites_correction" + ] + + print( + f"########################################### {len(filenames)} Volume(s) found for prediction ###########################################" + ) + + if ( + self.model_cfg.net["pred_slice2vol"]["uncertainity_map"] + and "prob" not in self.model_cfg.net["func_name"].lower() + ): + print( + "##################################################################### Warning #####################################################################" + ) + print(f"Your selected Model is {self.model_cfg.net['func_name']}") + print( + "If the model is NOT probabilistic the uncertainity map and multiple prediction mode won't have sense" + ) + # save post process indicated by the user + original_postprocess = self.data_cfg.postprocess + # we set non postprocess we need logits for the model + self.data_cfg.postprocess = None + + n_trunc = -1 + threshold_um = -1 + border_corr = False + pert_opt = False + # handle uncertainity maps generation + if "uncertainity_map" in self.model_cfg.net["pred_slice2vol"]: + if ( + self.model_cfg.net["pred_slice2vol"]["uncertainity_map"] + is not None + ): + if 
self.model_cfg.net["pred_slice2vol"]["uncertainity_map"]: + + if "trunc" in self.model_cfg.net["pred_slice2vol"]: + if ( + self.model_cfg.net["pred_slice2vol"]["trunc"] + is not None + ): + if isinstance( + self.model_cfg.net["pred_slice2vol"]["trunc"], + bool, + ): + if not self.model_cfg.net["pred_slice2vol"][ + "trunc" + ]: + n_trunc = -1 + else: + n_trunc = 4 + else: + if ( + type( + self.model_cfg.net["pred_slice2vol"][ + "trunc" + ] + ) + is int + and self.model_cfg.net["pred_slice2vol"][ + "trunc" + ] + >= 0 + ): + n_trunc = self.model_cfg.net[ + "pred_slice2vol" + ]["trunc"] + else: + raise ValueError( + f"Unexpected Value for trunc: {self.model_cfg.net['pred_slice2vol']['trunc']}. It should be a positive integer" + ) + else: + n_trunc = 4 + else: + n_trunc = 4 + + if "threshold" in self.model_cfg.net["pred_slice2vol"]: + if ( + self.model_cfg.net["pred_slice2vol"]["threshold"] + is not None + ): + if isinstance( + self.model_cfg.net["pred_slice2vol"][ + "threshold" + ], + bool, + ): + if not self.model_cfg.net["pred_slice2vol"][ + "threshold" + ]: + threshold_um = -1 + else: + raise ValueError( + f"Unexpected Value for threshold: {self.model_cfg.net['pred_slice2vol']['threshold']}. It should be a positive float" + ) + else: + if ( + type( + self.model_cfg.net["pred_slice2vol"][ + "threshold" + ] + ) + is float + and self.model_cfg.net["pred_slice2vol"][ + "threshold" + ] + > 0 + ): + threshold_um = self.model_cfg.net[ + "pred_slice2vol" + ]["threshold"] + else: + raise ValueError( + f"Unexpected Value for threshold: {self.model_cfg.net['pred_slice2vol']['threshold']}. 
It should be a positive float" + ) + else: + threshold_um = -1 + else: + threshold_um = -1 + + if ( + "border_correction" + in self.model_cfg.net["pred_slice2vol"] + ): + if ( + self.model_cfg.net["pred_slice2vol"][ + "border_correction" + ] + is not None + ): + if isinstance( + self.model_cfg.net["pred_slice2vol"][ + "border_correction" + ], + bool, + ): + if not self.model_cfg.net["pred_slice2vol"][ + "border_correction" + ]: + border_corr = False + else: + raise ValueError( + f"Unexpected Value for border_correction: {self.model_cfg.net['pred_slice2vol']['border_correction']}." + ) + else: + if ( + type( + self.model_cfg.net["pred_slice2vol"][ + "border_correction" + ] + ) + is int + and self.model_cfg.net["pred_slice2vol"][ + "border_correction" + ] + >= 0 + ): + if ( + self.model_cfg.net["pred_slice2vol"][ + "border_correction" + ] + == 0 + ): + border_corr = False + else: + border_corr = [ + self.model_cfg.net[ + "pred_slice2vol" + ]["border_correction"] + ] * 2 + elif ( + type( + self.model_cfg.net["pred_slice2vol"][ + "border_correction" + ] + ) + is list + and len( + self.model_cfg.net["pred_slice2vol"][ + "border_correction" + ] + ) + <= 2 + ): + if self.model_cfg.net["pred_slice2vol"][ + "border_correction" + ] != [0, 0] and self.model_cfg.net[ + "pred_slice2vol" + ][ + "border_correction" + ] != [ + 0 + ]: + if ( + len( + self.model_cfg.net[ + "pred_slice2vol" + ]["border_correction"] + ) + == 1 + ): + border_corr = ( + self.model_cfg.net[ + "pred_slice2vol" + ]["border_correction"] + * 2 + ) + else: + border_corr = self.model_cfg.net[ + "pred_slice2vol" + ]["border_correction"] + else: + border_corr = False + else: + raise ValueError( + f"Unexpected Value for border_correction: {self.model_cfg.net['pred_slice2vol']['border_correction']}." 
+ ) + else: + border_corr = False + else: + border_corr = False + + if self.model_cfg.net["pred_slice2vol"]["n_samples"] <= 1: + print( + "Number of samples are less or equal to 1 more are required for uncertainty generation" + ) + print( + "Automatically 5 samples will use to uncetainity calculation" + ) + self.model_cfg.net["pred_slice2vol"]["n_samples"] = 5 + + if "var_reductor" in self.model_cfg.net["pred_slice2vol"]: + if ( + self.model_cfg.net["pred_slice2vol"]["var_reductor"] + is None + ): + self.model_cfg.net["pred_slice2vol"][ + "var_reductor" + ] = True + else: + self.model_cfg.net["pred_slice2vol"][ + "var_reductor" + ] = True + + if "relative_MI" in self.model_cfg.net["pred_slice2vol"]: + if ( + self.model_cfg.net["pred_slice2vol"]["relative_MI"] + is None + ): + self.model_cfg.net["pred_slice2vol"][ + "relative_MI" + ] = True + else: + self.model_cfg.net["pred_slice2vol"][ + "relative_MI" + ] = True + + if "compute_mode" in self.model_cfg.net["pred_slice2vol"]: + if ( + self.model_cfg.net["pred_slice2vol"]["compute_mode"] + is None + ): + self.model_cfg.net["pred_slice2vol"][ + "compute_mode" + ] = "mutual_inf" + else: + self.model_cfg.net["pred_slice2vol"][ + "compute_mode" + ] = "mutual_inf" + + if "estabilizer" in self.model_cfg.net["pred_slice2vol"]: + if ( + self.model_cfg.net["pred_slice2vol"]["estabilizer"] + is None + ): + self.model_cfg.net["pred_slice2vol"][ + "estabilizer" + ] = False + else: + self.model_cfg.net["pred_slice2vol"][ + "estabilizer" + ] = False + + if "pertubations" in self.model_cfg.net["pred_slice2vol"]: + if ( + self.model_cfg.net["pred_slice2vol"]["pertubations"] + is not None + ): + if isinstance( + self.model_cfg.net["pred_slice2vol"][ + "pertubations" + ], + str, + ): + pert_opt = True + elif isinstance( + self.model_cfg.net["pred_slice2vol"][ + "pertubations" + ], + bool, + ): + pert_opt = True + elif isinstance( + self.model_cfg.net["pred_slice2vol"][ + "pertubations" + ], + list, + ): + if ( + len( + 
self.model_cfg.net["pred_slice2vol"][ + "pertubations" + ] + ) + != 0 + ): + pert_opt = True + else: + self.model_cfg.net["pred_slice2vol"]["uncertainity_map"] = False + else: + self.model_cfg.net["pred_slice2vol"]["uncertainity_map"] = False + + # handle multi pred generation + if "multi_pred_mode" in self.model_cfg.net["pred_slice2vol"]: + if ( + self.model_cfg.net["pred_slice2vol"]["multi_pred_mode"] + is not None + ): + if ( + self.model_cfg.net["pred_slice2vol"][ + "multi_pred_mode" + ].lower() + != "single" + ): + if self.model_cfg.net["pred_slice2vol"]["n_samples"] <= 1: + print( + "Number of samples are less or equal to 1 more are required for multi prediction usage" + ) + print("Automatically 5 samples will use to prediction") + self.model_cfg.net["pred_slice2vol"]["n_samples"] = 5 + else: + self.model_cfg.net["pred_slice2vol"][ + "multi_pred_mode" + ] = "single" + + else: + self.model_cfg.net["pred_slice2vol"]["multi_pred_mode"] = "single" + + if "jupyter" in self.model_cfg.net["pred_slice2vol"]: + from tqdm.notebook import tqdm + else: + from tqdm import tqdm + + for fn in tqdm(filenames, desc="Prediction for the Volume"): + try: + img = BioImage(fn, reader=bioio_tifffile.Reader).get_image_data( + self.data_cfg.inference_input.reader_params[ + "dimension_order_out" + ], + T=self.data_cfg.inference_input.reader_params["T"], + ) + except Exception: + try: + img = BioImage(fn).get_image_data( + self.data_cfg.inference_input.reader_params[ + "dimension_order_out" + ], + T=self.data_cfg.inference_input.reader_params["T"], + ) + except Exception as e: + print(f"Error: {e}") + print( + f"Image {fn} failed at read process check the format." 
+ ) + + voxel_sizes = PhysicalPixelSizes(1, 1, 1) + if vs_flag == "auto": + pps = getattr(BioImage(fn), "physical_pixel_sizes", None) + if pps is None: + voxel_sizes = PhysicalPixelSizes(None, None, None) + elif isinstance(pps, tuple): + voxel_sizes = pps # tuple like (Z, Y, X) + else: + voxel_sizes = ( + getattr(pps, "Z", None), + getattr(pps, "Y", None), + getattr(pps, "X", None), + ) + voxel_sizes = PhysicalPixelSizes( + voxel_sizes[0], voxel_sizes[1], voxel_sizes[2] + ) + + voxel_sizes = [ + 1.0 if v is None else float(v) for v in voxel_sizes + ] + voxel_sizes = PhysicalPixelSizes( + voxel_sizes[0], voxel_sizes[1], voxel_sizes[2] + ) + else: + voxel_sizes = vs_flag + + if max_proj: + img = np.max(img, axis=1) + img = np.expand_dims(img, axis=1) + + out_list = [] + uncertainity_map = [] + n = fn.name + if len(img.shape) == 3: + # chanel dummy + img = img[None, ...] + + for zz in tqdm( + range(img.shape[1]), desc="infering slice", leave=False + ): + samplesz = [] + im_input = img[:, zz, :, :] + for i in range( + self.model_cfg.net["pred_slice2vol"]["n_samples"] + ): + + if pert_opt and i != 0: + inp = perturb_image( + im_input, + self.model_cfg.net["pred_slice2vol"][ + "pertubations" + ], + ) + else: + inp = im_input + + logits = self.process_one_image(inp) + samplesz.append(np.squeeze(logits)) + + if ( + self.model_cfg.net["pred_slice2vol"][ + "multi_pred_mode" + ].lower() + == "single" + ): + seg = samplesz[0] + seg = seg[None, ...] + elif ( + self.model_cfg.net["pred_slice2vol"][ + "multi_pred_mode" + ].lower() + == "max" + ): + seg = max_prediction(samplesz) + seg = seg[None, ...] + elif ( + self.model_cfg.net["pred_slice2vol"][ + "multi_pred_mode" + ].lower() + == "mean" + ): + seg = mean_prediction(samplesz) + seg = seg[None, ...] + elif ( + self.model_cfg.net["pred_slice2vol"][ + "multi_pred_mode" + ].lower() + == "variance" + ): + seg = variance_prediction(samplesz) + seg = seg[None, ...] 
+ elif ( + self.model_cfg.net["pred_slice2vol"][ + "multi_pred_mode" + ].lower() + == "sum" + ): + seg = add_prediction(samplesz) + seg = seg[None, ...] + else: + raise ValueError( + f"{self.model_cfg.net['pred_slice2vol']['multi_pred_mode']} is not a valid method for multi prediction use." + ) + + if original_postprocess is not None: + pp_data = seg + for pp_info in original_postprocess: + pp = parse_config(pp_info) + pp_data = pp(pp_data) + if torch.is_tensor(pp_data): + seg = pp_data.cpu().numpy() + else: + seg = pp_data + out_list.append(np.squeeze(seg)) + if self.model_cfg.net["pred_slice2vol"]["uncertainity_map"]: + Uc_zmap = Extract_Uncertainty_Maps( + logits_samples=samplesz, + compute_mode=self.model_cfg.net["pred_slice2vol"][ + "compute_mode" + ], + relative_MI=self.model_cfg.net["pred_slice2vol"][ + "relative_MI" + ], + var_reductor=self.model_cfg.net["pred_slice2vol"][ + "var_reductor" + ], + estabilizer=self.model_cfg.net["pred_slice2vol"][ + "estabilizer" + ], + ) + uncertainity_map.append(Uc_zmap) + + # elimina _IM del name para vessqc + if "_IM" in n: + n = n.replace("_IM", "") + + seg_full = np.stack(out_list, axis=0) + if self.model_cfg.net["pred_slice2vol"]["uncertainity_map"]: + UM_full = np.stack(uncertainity_map, axis=0) + + if "remove_object_size" in self.model_cfg.net["pred_slice2vol"]: + if ( + self.model_cfg.net["pred_slice2vol"]["remove_object_size"] + is not None + ): + seg_full = Remove_objects( + seg_full=seg_full, + n_classes=self.model_cfg.net["pred_slice2vol"][ + "n_class_correction" + ], + remove_object_size=self.model_cfg.net["pred_slice2vol"][ + "remove_object_size" + ], + voxel_sizes=tuple(voxel_sizes), + ) + + if "hole_size_threshold" in self.model_cfg.net["pred_slice2vol"]: + if ( + self.model_cfg.net["pred_slice2vol"]["hole_size_threshold"] + is not None + ): + seg_full = Hole_Correction( + seg_full=seg_full, + n_classes=self.model_cfg.net["pred_slice2vol"][ + "n_class_correction" + ], + 
hole_size_threshold=self.model_cfg.net[ + "pred_slice2vol" + ]["hole_size_threshold"], + voxel_sizes=tuple(voxel_sizes), + ) + + if "min_thickness_list" in self.model_cfg.net["pred_slice2vol"]: + if ( + self.model_cfg.net["pred_slice2vol"]["min_thickness_list"] + is not None + ): + seg_full = Thickness_Corretion( + seg_full=seg_full, + n_classes=self.model_cfg.net["pred_slice2vol"][ + "n_class_correction" + ], + min_thickness_physical=self.model_cfg.net[ + "pred_slice2vol" + ]["min_thickness_list"], + voxel_sizes=tuple(voxel_sizes), + ) + + if perycites_correction: + seg_full = Perycites_correction(seg_full=seg_full) + + if ".tiff" in n: + + out_fn = self.data_cfg.inference_output.path / n.replace( + ".tiff", "_segPred.tiff" + ) + UM_out_fn = self.data_cfg.inference_output.path / n.replace( + ".tiff", "_uncertainty.tiff" + ) + elif ".tif" in n: + out_fn = self.data_cfg.inference_output.path / n.replace( + ".tif", "_segPred.tif" + ) + UM_out_fn = self.data_cfg.inference_output.path / n.replace( + ".tif", "_uncertainty.tif" + ) + + OmeTiffWriter.save( + data=seg_full, + uri=out_fn, + dim_order="ZYX", + physical_pixel_sizes=voxel_sizes, + physical_pixel_units="micron", + ) + + if self.model_cfg.net["pred_slice2vol"]["uncertainity_map"]: + + if n_trunc >= 0: + UM_full = np.trunc(UM_full * (10**n_trunc)) / (10**n_trunc) + + if threshold_um >= 0: + UM_full[UM_full < threshold_um] = 0 + + if border_corr: + nx, ny = border_corr + nx = abs(nx) + ny = abs(ny) + Z, X, Y = UM_full.shape + UM_full[:, : nx + 1, :] = 0 + UM_full[:, X - (nx + 1) :, :] = 0 + UM_full[:, :, : (ny + 1)] = 0 + UM_full[:, :, Y - (ny + 1) :] = 0 + + if self.model_cfg.net["pred_slice2vol"]["var_reductor"]: + OmeTiffWriter.save(UM_full, UM_out_fn, dim_order="ZYX") + + else: + UM_full_CZYX = np.moveaxis(UM_full, 1, 0) + OmeTiffWriter.save( + UM_full_CZYX, UM_out_fn, dim_order="CZYX" + ) + + else: + raise ValueError("Please provide params for the volumetric prediction.") diff --git 
a/mmv_im2im/models/nets/ProbUnet.py b/mmv_im2im/models/nets/ProbUnet.py index ed64e67..312cb1a 100644 --- a/mmv_im2im/models/nets/ProbUnet.py +++ b/mmv_im2im/models/nets/ProbUnet.py @@ -1,4 +1,3 @@ -# Save this as ProbUnet.py (or mmv_im2im/models/ProbUnet.py if that's its actual path) import torch import torch.nn as nn import torch.nn.functional as F @@ -13,17 +12,17 @@ def get_valid_num_groups(channels): class ConvBlock(nn.Module): - """Standard 2D Convolutional Block.""" + """Standard 2D/3D Convolutional Block.""" - def __init__(self, in_channels, out_channels): + def __init__(self, in_channels, out_channels, Conv, GroupNorm): super().__init__() gn_groups1 = get_valid_num_groups(out_channels) gn_groups2 = get_valid_num_groups(out_channels) - self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1) - self.gn1 = nn.GroupNorm(gn_groups1, out_channels) + self.conv1 = Conv(in_channels, out_channels, kernel_size=3, padding=1) + self.gn1 = GroupNorm(gn_groups1, out_channels) self.relu1 = nn.ReLU(inplace=True) - self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1) - self.gn2 = nn.GroupNorm(gn_groups2, out_channels) + self.conv2 = Conv(out_channels, out_channels, kernel_size=3, padding=1) + self.gn2 = GroupNorm(gn_groups2, out_channels) self.relu2 = nn.ReLU(inplace=True) def forward(self, x): @@ -35,10 +34,10 @@ def forward(self, x): class Down(nn.Module): """Downsampling block (MaxPool + ConvBlock).""" - def __init__(self, in_channels, out_channels): + def __init__(self, in_channels, out_channels, MaxPool, ConvBlock, Conv, GroupNorm): super().__init__() - self.pool = nn.MaxPool2d(2) - self.conv_block = ConvBlock(in_channels, out_channels) + self.pool = MaxPool(2) + self.conv_block = ConvBlock(in_channels, out_channels, Conv, GroupNorm) def forward(self, x): x = self.pool(x) @@ -51,7 +50,7 @@ class Up(nn.Module): Args: in_channels_x1_before_upsample (int): Number of channels of the feature map (x1) - before being upsampled by 
ConvTranspose2d. + before being upsampled by ConvTranspose. in_channels_x2_skip_connection (int): Number of channels of the skip connection (x2). out_channels (int): Number of output channels for the final ConvBlock in this Up stage. """ @@ -61,10 +60,15 @@ def __init__( in_channels_x1_before_upsample, in_channels_x2_skip_connection, out_channels, + ConvTranspose, + ConvBlock, + Conv, + GroupNorm, + interpolation_mode, ): super().__init__() - - self.up = nn.ConvTranspose2d( + self.interpolation_mode = interpolation_mode + self.up = ConvTranspose( in_channels_x1_before_upsample, in_channels_x1_before_upsample // 2, kernel_size=2, @@ -74,14 +78,24 @@ def __init__( channels_for_conv_block = ( in_channels_x1_before_upsample // 2 ) + in_channels_x2_skip_connection - self.conv_block = ConvBlock(channels_for_conv_block, out_channels) + self.conv_block = ConvBlock( + channels_for_conv_block, out_channels, Conv, GroupNorm + ) def forward(self, x1, x2): x1 = self.up(x1) - # Adjust dimensions if there's a mismatch due to padding or odd sizes - diffY = x2.size()[2] - x1.size()[2] - diffX = x2.size()[3] - x1.size()[3] - x1 = F.pad(x1, [diffX // 2, diffX - diffX // 2, diffY // 2, diffY - diffY // 2]) + + # Get the spatial size of the skip connection tensor + spatial_size_x2 = x2.size()[2:] + + if x1.size()[2:] != spatial_size_x2: + x1 = F.interpolate( + x1, + size=spatial_size_x2, + mode=self.interpolation_mode, + align_corners=False, + ) + x = torch.cat([x2, x1], dim=1) return self.conv_block(x) @@ -89,96 +103,140 @@ def forward(self, x1, x2): class PriorNet(nn.Module): """Network to predict prior distribution (mu, logvar).""" - def __init__(self, in_channels, latent_dim): + def __init__(self, in_channels, latent_dim, Conv): super().__init__() - self.conv = nn.Conv2d(in_channels, 2 * latent_dim, kernel_size=1) + self.conv = Conv(in_channels, 2 * latent_dim, kernel_size=1) def forward(self, x): mu_logvar = self.conv(x) - mu = mu_logvar[:, : self.conv.out_channels // 2, :, :] - 
logvar = mu_logvar[:, self.conv.out_channels // 2 :, :, :] + mu = mu_logvar[:, : self.conv.out_channels // 2, ...] + logvar = mu_logvar[:, self.conv.out_channels // 2 :, ...] return mu, logvar class PosteriorNet(nn.Module): """Network to predict posterior distribution (mu, logvar).""" - def __init__(self, in_channels, latent_dim): + def __init__(self, in_channels, latent_dim, Conv): super().__init__() - self.conv = nn.Conv2d(in_channels, 2 * latent_dim, kernel_size=1) + self.conv = Conv(in_channels, 2 * latent_dim, kernel_size=1) def forward(self, x): mu_logvar = self.conv(x) - mu = mu_logvar[:, : self.conv.out_channels // 2, :, :] - logvar = mu_logvar[:, self.conv.out_channels // 2 :, :, :] + mu = mu_logvar[:, : self.conv.out_channels // 2, ...] + logvar = mu_logvar[:, self.conv.out_channels // 2 :, ...] return mu, logvar class ProbabilisticUNet(nn.Module): - """Probabilistic UNet model.""" + """Probabilistic UNet model. + + This model can operate in 2D or 3D based on the 'model_type' parameter. 
+ """ def __init__( - self, in_channels, n_classes, latent_dim=6, **kwargs - ): # Added **kwargs to capture extra params + self, + in_channels, + n_classes, + latent_dim=6, + task="segment", + model_type="2D", + **kwargs, + ): super().__init__() self.in_channels = in_channels self.n_classes = n_classes self.latent_dim = latent_dim - # self.beta is no longer needed here as it's handled by the loss function + self.task = task + self.model_type = model_type + + # Select the appropriate layers based on model_type + if model_type == "2D": + self.Conv = nn.Conv2d + self.MaxPool = nn.MaxPool2d + self.ConvTranspose = nn.ConvTranspose2d + self.GroupNorm = nn.GroupNorm + self.interpolation_mode = "bilinear" + elif model_type == "3D": + self.Conv = nn.Conv3d + self.MaxPool = nn.MaxPool3d + self.ConvTranspose = nn.ConvTranspose3d + self.GroupNorm = nn.GroupNorm + self.interpolation_mode = "trilinear" + else: + raise ValueError(f"Unknown model_type: {model_type}") # Encoder path (U-Net) - self.inc = ConvBlock(in_channels, 32) - self.down1 = Down(32, 64) - self.down2 = Down(64, 128) - self.down3 = Down(128, 256) - self.down4 = Down(256, 512) # Bottleneck features + self.inc = ConvBlock(in_channels, 32, self.Conv, self.GroupNorm) + self.down1 = Down(32, 64, self.MaxPool, ConvBlock, self.Conv, self.GroupNorm) + self.down2 = Down(64, 128, self.MaxPool, ConvBlock, self.Conv, self.GroupNorm) + self.down3 = Down(128, 256, self.MaxPool, ConvBlock, self.Conv, self.GroupNorm) + self.down4 = Down(256, 512, self.MaxPool, ConvBlock, self.Conv, self.GroupNorm) # Prior and Posterior Networks - self.prior_net = PriorNet(512, latent_dim) - # PosteriorNet input channels: 512 (features) + n_classes (one-hot y) - self.posterior_net = PosteriorNet(512 + n_classes, latent_dim) + self.prior_net = PriorNet(512, latent_dim, self.Conv) + # PosteriorNet input channels: 512 (features) + n_classes (one-hot y or float y) + self.posterior_net = PosteriorNet(512 + n_classes, latent_dim, self.Conv) # Decoder 
Path (U-Net upsampling path) - # Input channels for Up blocks adjusted to include latent_dim self.up1 = Up( in_channels_x1_before_upsample=512 + latent_dim, in_channels_x2_skip_connection=256, out_channels=256, + ConvTranspose=self.ConvTranspose, + ConvBlock=ConvBlock, + Conv=self.Conv, + GroupNorm=self.GroupNorm, + interpolation_mode=self.interpolation_mode, ) self.up2 = Up( in_channels_x1_before_upsample=256, in_channels_x2_skip_connection=128, out_channels=128, + ConvTranspose=self.ConvTranspose, + ConvBlock=ConvBlock, + Conv=self.Conv, + GroupNorm=self.GroupNorm, + interpolation_mode=self.interpolation_mode, ) self.up3 = Up( in_channels_x1_before_upsample=128, in_channels_x2_skip_connection=64, out_channels=64, + ConvTranspose=self.ConvTranspose, + ConvBlock=ConvBlock, + Conv=self.Conv, + GroupNorm=self.GroupNorm, + interpolation_mode=self.interpolation_mode, ) self.up4 = Up( in_channels_x1_before_upsample=64, in_channels_x2_skip_connection=32, out_channels=32, + ConvTranspose=self.ConvTranspose, + ConvBlock=ConvBlock, + Conv=self.Conv, + GroupNorm=self.GroupNorm, + interpolation_mode=self.interpolation_mode, ) - self.outc = nn.Conv2d(32, n_classes, kernel_size=1) + self.outc = self.Conv(32, n_classes, kernel_size=1) def forward(self, x, y=None): """ Forward pass of the Probabilistic UNet. Args: - x (torch.Tensor): Input image tensor (B, C, H, W). - y (torch.Tensor, optional): Ground truth segmentation mask (B, 1, H, W or B, H, W) - used for training to calculate posterior. - Defaults to None (for inference). + x (torch.Tensor): Input image tensor (B, C, H, W) for 2D or (B, C, D, H, W) for 3D. + y (torch.Tensor, optional): Ground truth segmentation mask used for training to + calculate posterior. Defaults to None (for inference). Returns: tuple: A tuple containing: - - logits (torch.Tensor): Output logits of the UNet (B, n_classes, H, W). + - logits (torch.Tensor): Output logits of the UNet. - prior_mu (torch.Tensor): Mean of the prior distribution. 
- prior_logvar (torch.Tensor): Log-variance of the prior distribution. - post_mu (torch.Tensor or None): Mean of the posterior distribution (None if y is None). @@ -197,21 +255,28 @@ def forward(self, x, y=None): # Posterior calculation and latent variable sampling post_mu, post_logvar = None, None if y is not None: - # Ensure y is one-hot encoded and downsampled to match features spatial dimensions. - # y typically comes as [B, 1, H, W] with integer class labels. - # Convert to [B, n_classes, H, W] for one-hot, then permute for channel dim. - y_one_hot = ( - F.one_hot(y.long().squeeze(1), num_classes=self.n_classes) - .permute(0, 3, 1, 2) - .float() - ) - - # Downsample y_one_hot to match features' spatial dimensions - y_downsampled = F.interpolate( - y_one_hot, size=features.shape[2:], mode="nearest" - ) + if self.task == "segment": + # Ensure y is one-hot encoded and downsampled + y_one_hot = F.one_hot(y.long().squeeze(1), num_classes=self.n_classes) + y_one_hot = y_one_hot.permute( + 0, -1, *range(1, y_one_hot.dim() - 1) + ).float() + + y_downsampled = F.interpolate( + y_one_hot, size=features.shape[2:], mode="nearest" + ) + elif self.task == "regression": + # For regression, y is already a float tensor, just downsample + y_downsampled = F.interpolate( + y, + size=features.shape[2:], + mode=self.interpolation_mode, + align_corners=False, + ) + else: + raise ValueError(f"Unknown task type: {self.task}") - # Concatenate features and downsampled one-hot y for posterior network + # Concatenate features and downsampled y for posterior network post_mu, post_logvar = self.posterior_net( torch.cat([features, y_downsampled], dim=1) ) @@ -227,17 +292,16 @@ def forward(self, x, y=None): z = prior_mu + eps * std_prior # Expand 'z' to spatial dimensions for concatenation + spatial_dims = features.size()[2:] if z.dim() == 2: # [B, latent_dim] - z_expanded = ( - z.unsqueeze(-1) - .unsqueeze(-1) - .repeat(1, 1, features.size(2), features.size(3)) - ) - elif z.dim() == 4: # [B, 
latent_dim, H, W] - if z.size(2) != features.size(2) or z.size(3) != features.size(3): - z_expanded = F.interpolate( - z, size=(features.size(2), features.size(3)), mode="nearest" - ) + # Expands latent vector to spatial dimensions + z_expanded = z.unsqueeze(-1) + for _ in range(len(spatial_dims) - 1): + z_expanded = z_expanded.unsqueeze(-1) + z_expanded = z_expanded.repeat(1, 1, *spatial_dims) + elif z.dim() == 2 + len(spatial_dims): + if z.size()[2:] != spatial_dims: + z_expanded = F.interpolate(z, size=spatial_dims, mode="nearest") else: z_expanded = z else: diff --git a/mmv_im2im/models/pl_ProbUnet.py b/mmv_im2im/models/pl_ProbUnet.py index f25aa89..8798eac 100644 --- a/mmv_im2im/models/pl_ProbUnet.py +++ b/mmv_im2im/models/pl_ProbUnet.py @@ -136,23 +136,33 @@ def validation_step(self, batch, batch_idx): def log_images(self, batch, y_hat, stage): src = batch["IM"] tar = batch["GT"] + task = self.model_info.net["params"].get("task", "segment") save_path = Path(self.trainer.log_dir) save_path.mkdir(parents=True, exist_ok=True) - act = torch.nn.Softmax(dim=1) - yhat_act = act(y_hat) - - src_out = np.squeeze(src[0].detach().cpu().numpy()).astype(float) - tar_out = np.squeeze(tar[0].detach().cpu().numpy()).astype(float) - prd_out = np.squeeze(yhat_act[0].detach().cpu().numpy()).astype(float) + if task == "segment": + act = torch.nn.Softmax(dim=1) + yhat_act = act(y_hat) + prd_out = np.squeeze( + yhat_act[0].detach().cpu().numpy().argmax(axis=0) + ).astype(float) + tar_out = np.squeeze(tar[0].detach().cpu().numpy()).astype(float) + + elif task == "regression": + prd_out = np.squeeze(y_hat[0].detach().cpu().numpy()).astype(float) + tar_out = np.squeeze(tar[0].detach().cpu().numpy()).astype(float) + else: + raise ValueError(f"Unknown task type for logging: {task}") def get_dim_order(arr): dims = len(arr.shape) - return {2: "YX", 3: "ZYX", 4: "CZYX"}.get(dims, "YX") + return {2: "YX", 3: "CYX"}.get(dims, "YX") rand_tag = randint(1, 1000) + src_out = 
np.squeeze(src[0].detach().cpu().numpy()).astype(float) + out_fn = save_path / f"epoch_{self.current_epoch}_{stage}_src_{rand_tag}.tiff" OmeTiffWriter.save(src_out, out_fn, dim_order=get_dim_order(src_out)) diff --git a/mmv_im2im/proj_tester.py b/mmv_im2im/proj_tester.py index baea7c6..1cbf373 100644 --- a/mmv_im2im/proj_tester.py +++ b/mmv_im2im/proj_tester.py @@ -15,11 +15,11 @@ from mmv_im2im.utils.misc import generate_test_dataset_dict, parse_config from mmv_im2im.utils.for_transform import parse_monai_ops_vanilla from skimage.io import imsave as save_rgb - - -# from mmv_im2im.utils.piecewise_inference import predict_piecewise +import bioio_tifffile +from tqdm import tqdm from monai.inferers import sliding_window_inference + # https://pytorch-lightning.readthedocs.io/en/latest/starter/introduction_guide.html#predicting ############################################################################### @@ -106,9 +106,9 @@ def process_one_image( if isinstance(img, DaskArray): # Perform the prediction x = img.compute() + elif isinstance(img, NumpyArray): x = img - else: raise ValueError("invalid image") @@ -243,19 +243,16 @@ def process_one_image( raise ValueError("error in prediction output shape") def run_inference(self): + self.setup_model() self.setup_data_processing() # set up data filenames dataset_list = generate_test_dataset_dict( self.data_cfg.inference_input.dir, self.data_cfg.inference_input.data_type ) - dataset_length = len(dataset_list) # loop through all images and apply the model - for i, ds in enumerate(dataset_list): - - # Read the image - print(f"Reading the image {i}/{dataset_length}") + for ds in tqdm(dataset_list, desc="Predicting images"): # output file name info fn_core = Path(ds).stem @@ -275,15 +272,35 @@ def run_inference(self): print(f"making a temp folder at {tmppath}") # get the number of time points - reader = BioImage(ds) + try: + reader = BioImage(ds, reader=bioio_tifffile.Reader) + except Exception: + try: + reader = BioImage(ds) + 
except Exception as e: + print(f"Error: {e}") + print(f"Image {ds} failed at read process check the format.") + timelapse_data = reader.dims.T tmpfile_list = [] - for t_idx in range(timelapse_data): - img = BioImage(ds).get_image_data( - T=[t_idx], **self.data_cfg.inference_input.reader_params - ) - print(f"Predicting the image timepoint {t_idx}") + for t_idx in tqdm( + range(timelapse_data), desc="Predicting image timepoint" + ): + try: + img = BioImage(ds, reader=bioio_tifffile.Reader).get_image_data( + T=[t_idx], **self.data_cfg.inference_input.reader_params + ) + except Exception: + try: + img = BioImage(ds).get_image_data( + T=[t_idx], **self.data_cfg.inference_input.reader_params + ) + except Exception as e: + print(f"Error: {e}") + print( + f"Image {ds} failed at read process check the format." + ) # prepare output filename out_fn = Path(tmppath) / f"{fn_core}_{t_idx}.npy" @@ -330,9 +347,18 @@ def run_inference(self): # clean up temporary dir shutil.rmtree(tmppath) else: - img = BioImage(ds).get_image_data( - **self.data_cfg.inference_input.reader_params - ) + try: + img = BioImage(ds, reader=bioio_tifffile.Reader).get_image_data( + **self.data_cfg.inference_input.reader_params + ) + except Exception: + try: + img = BioImage(ds).get_image_data( + **self.data_cfg.inference_input.reader_params + ) + except Exception as e: + print(f"Error: {e}") + print(f"Image {ds} failed at read process check the format.") # prepare output filename if "." 
in suffix: @@ -356,5 +382,4 @@ def run_inference(self): / f"{fn_core}{suffix}.tiff" ) - print("Predicting the image") self.process_one_image(img, out_fn) diff --git a/mmv_im2im/proj_trainer.py b/mmv_im2im/proj_trainer.py index 150286f..e693172 100644 --- a/mmv_im2im/proj_trainer.py +++ b/mmv_im2im/proj_trainer.py @@ -11,6 +11,7 @@ import warnings + warnings.simplefilter(action="ignore", category=FutureWarning) ############################################################################### diff --git a/mmv_im2im/utils/elbo_loss.py b/mmv_im2im/utils/elbo_loss.py index 8720e7e..4e26357 100644 --- a/mmv_im2im/utils/elbo_loss.py +++ b/mmv_im2im/utils/elbo_loss.py @@ -5,6 +5,7 @@ from mmv_im2im.utils.topological_loss import TI_Loss from mmv_im2im.utils.connectivity_loss import ConnectivityCoherenceLoss from monai.losses import GeneralizedDiceFocalLoss +from monai.metrics import HausdorffDistanceMetric class KLDivergence(nn.Module): @@ -73,6 +74,10 @@ class ELBOLoss(nn.Module): use_gdl_focal_regularization (bool): If True, Includes Generalized Dice Focal (GDF) regularization. gdl_focal_weight (float): Weighting factor for GDF. gdl_class_weights (list): Weights for each class. + use_hausdorff_regularization (bool): Enable or disable Hausdorff regularization + hausdorff_weight (float): Weighting factor for the Hausdorff loss term. + hausdorff_ignore_background (bool): If True, ignore background for Hausdorff loss. + task (str): The task type, either "segment" for segmentation or "regression" for regression. 
""" def __init__( @@ -100,11 +105,16 @@ def __init__( gdl_focal_weight: float = 1.0, elbo_class_weights: list = None, gdl_class_weights: list = None, + use_hausdorff_regularization: bool = False, + hausdorff_weight: float = 0.1, + hausdorff_ignore_background: bool = True, + task: str = "segment", ): super().__init__() self.beta = beta self.n_classes = n_classes self.kl_clamp = kl_clamp + self.task = task self.kl_divergence_calculator = KLDivergence() self.use_fractal_regularization = use_fractal_regularization @@ -135,7 +145,6 @@ def __init__( else: self.topological_weight = 0.0 - # New Connectivity Regularization self.use_connectivity_regularization = use_connectivity_regularization if self.use_connectivity_regularization: self.connectivity_weight = connectivity_weight @@ -161,7 +170,16 @@ def __init__( else: self.gdl_focal_weight = 0.0 - # Convert class_weights list to a torch.Tensor + self.use_hausdorff_regularization = use_hausdorff_regularization + if self.use_hausdorff_regularization: + self.hausdorff_weight = hausdorff_weight + self.hausdorff_distance_calculator = HausdorffDistanceMetric( + include_background=not hausdorff_ignore_background, + reduction="mean", + ) + else: + self.hausdorff_weight = 0.0 + if elbo_class_weights is not None: self.elbo_class_weights = torch.tensor( elbo_class_weights, dtype=torch.float32 @@ -187,27 +205,47 @@ def forward(self, logits, y_true, prior_mu, prior_logvar, post_mu, post_logvar): Returns: torch.Tensor: The calculated ELBO loss. 
""" - # Ensure y_true has correct dimensions (e.g., [B, H, W]) for cross_entropy - if y_true.ndim == 4 and y_true.shape[1] == 1: - y_true_squeezed = y_true.squeeze(1) # Squeeze channel dim to [B, H, W] - else: - y_true_squeezed = y_true - - # Negative Cross-Entropy (Log-Likelihood) - if ( - self.elbo_class_weights is not None - and self.elbo_class_weights.device != logits.device - ): - elbo_class_weights_on_device = self.elbo_class_weights.to(logits.device) - else: - elbo_class_weights_on_device = self.elbo_class_weights - log_likelihood = -F.cross_entropy( - logits, - y_true_squeezed.long(), - reduction="mean", - weight=elbo_class_weights_on_device, - ) + if self.task == "segment": + + if y_true.shape[1] == 1: + y_true_squeezed = y_true.squeeze(1) + else: + y_true_squeezed = y_true + + if ( + self.elbo_class_weights is not None + and self.elbo_class_weights.device != logits.device + ): + elbo_class_weights_on_device = self.elbo_class_weights.to(logits.device) + else: + elbo_class_weights_on_device = self.elbo_class_weights + + log_likelihood = -F.cross_entropy( + logits, + y_true_squeezed.long(), + reduction="mean", + weight=elbo_class_weights_on_device, + ) + + elif self.task == "regression": + # log_likelihood = -F.mse_loss( + # logits, + # y_true, + # reduction="mean", + # ) + # log_likelihood = -F.mae_loss( + # logits, + # y_true, + # reduction="mean", + # ) + log_likelihood = -F.huber_loss( + logits, + y_true, + reduction="mean", + ) + else: + raise ValueError(f"Unknown task type: {self.task}") # KL-Divergence kl_div = self.kl_divergence_calculator( @@ -218,61 +256,93 @@ def forward(self, logits, y_true, prior_mu, prior_logvar, post_mu, post_logvar): total_loss = elbo_loss - if self.use_fractal_regularization: - y_pred_mask = F.softmax(logits, dim=1).argmax(dim=1, keepdim=True).float() + if self.task == "segment": + if self.use_fractal_regularization: + y_pred_mask = ( + F.softmax(logits, dim=1).argmax(dim=1, keepdim=True).float() + ) - if 
y_true_squeezed.ndim == 3: - y_true_for_fractal = y_true_squeezed.unsqueeze(1).float() - else: - y_true_for_fractal = y_true.float() + if y_true.ndim == 4 and y_true.shape[1] == 1: + y_true_for_fractal = y_true.float() + else: + y_true_for_fractal = y_true.unsqueeze(1).float() - fd_true = self.fractal_dimension_calculator(y_true_for_fractal) - fd_pred = self.fractal_dimension_calculator(y_pred_mask) + fd_true = self.fractal_dimension_calculator(y_true_for_fractal) + fd_pred = self.fractal_dimension_calculator(y_pred_mask) - fractal_loss = torch.mean(torch.abs(fd_true - fd_pred)) - total_loss += self.fractal_weight * fractal_loss + fractal_loss = torch.mean(torch.abs(fd_true - fd_pred)) + total_loss += self.fractal_weight * fractal_loss - if self.use_topological_regularization: - # y_true needs to be B, C, H, W or B, C, H, W, D for TI_Loss, where C=1 - # If y_true is B, H, W, unsqueeze to B, 1, H, W - if y_true_squeezed.ndim == 3: - y_true_for_topological = y_true_squeezed.unsqueeze(1).float() - else: - y_true_for_topological = ( - y_true.float() - ) # This should already be B, 1, H, W + if self.use_topological_regularization: + # y_true needs to be B, C, H, W or B, C, H, W, D for TI_Loss, where C=1 + # If y_true is B, H, W, unsqueeze to B, 1, H, W + if y_true.ndim == 3: + y_true_for_topological = y_true.unsqueeze(1).float() + else: + y_true_for_topological = y_true.float() - # logits are B, C, H, W (or B, C, H, W, D), which is what TI_Loss expects for x - topological_loss = self.topological_loss_calculator( - logits, y_true_for_topological - ) - total_loss += self.topological_weight * topological_loss + # logits are B, C, H, W (or B, C, H, W, D), which is what TI_Loss expects for x + topological_loss = self.topological_loss_calculator( + logits, y_true_for_topological + ) + total_loss += self.topological_weight * topological_loss - if self.use_connectivity_regularization: - # y_pred_softmax: (B, C, H, W) - y_pred_softmax = F.softmax(logits, dim=1) - - # 
y_true_one_hot: Need to convert y_true_squeezed (B, H, W) to one-hot (B, C, H, W) - # Ensure the number of classes matches n_classes used in ELBOLoss - y_true_one_hot = ( - F.one_hot(y_true_squeezed.long(), num_classes=self.n_classes) - .permute(0, 3, 1, 2) - .float() - ) + if self.use_connectivity_regularization: + # y_pred_softmax: (B, C, H, W) + y_pred_softmax = F.softmax(logits, dim=1) - connectivity_loss = self.connectivity_coherence_calculator( - y_pred_softmax, y_true_one_hot - ) - total_loss += self.connectivity_weight * connectivity_loss + # y_true_one_hot: Need to convert y_true_squeezed (B, H, W) to one-hot (B, C, H, W) + y_true_one_hot = ( + F.one_hot(y_true_squeezed.long(), num_classes=self.n_classes) + .permute(0, 3, 1, 2) + .float() + ) - if self.use_gdl_focal_regularization: - # logits: (B, C, H, W) - # y_true: (B, H, W) o (B, 1, H, W) - # GeneralizedDiceFocalLoss de MONAI puede manejar esto directamente - y_true_for_gdl_focal = y_true_squeezed.unsqueeze(1).long() - gdl_focal_loss = self.gdl_focal_loss_calculator( - logits, y_true_for_gdl_focal - ) - total_loss += self.gdl_focal_weight * gdl_focal_loss + connectivity_loss = self.connectivity_coherence_calculator( + y_pred_softmax, y_true_one_hot + ) + total_loss += self.connectivity_weight * connectivity_loss + + if self.use_gdl_focal_regularization: + # logits: (B, C, H, W) + # y_true: (B, H, W) o (B, 1, H, W) + y_true_for_gdl_focal = y_true_squeezed.unsqueeze(1).long() + gdl_focal_loss = self.gdl_focal_loss_calculator( + logits, y_true_for_gdl_focal + ) + total_loss += self.gdl_focal_weight * gdl_focal_loss + + if self.use_hausdorff_regularization: + # Get the one-hot encoded ground truth + # Squeeze y_true to (B, H, W) if it's (B, 1, H, W) + if y_true.ndim == 4 and y_true.shape[1] == 1: + y_true_squeezed = y_true.squeeze(1) + else: + y_true_squeezed = y_true + + # Calculate the Hausdorff distance + # The metric returns a tensor of shape (B, C), so we take the mean + try: + # Convert ground truth 
to one-hot format (B, C, H, W) + y_true_one_hot = F.one_hot( + y_true_squeezed.long(), num_classes=self.n_classes + ).permute(0, 3, 1, 2) + + # Get the one-hot encoded prediction from logits + y_pred_one_hot = F.one_hot( + logits.argmax(dim=1), num_classes=self.n_classes + ).permute(0, 3, 1, 2) + + # Calculate the Hausdorff distance + # The metric returns a tensor of shape (B, C), so we take the mean + hausdorff_loss = self.hausdorff_distance_calculator( + y_pred=y_pred_one_hot, y=y_true_one_hot + ).mean() + + # Add the Hausdorff loss to the total loss + total_loss += self.hausdorff_weight * hausdorff_loss + except Exception: + # Avoid failures when some tensors become None during the Hausdorff computation + pass return total_loss diff --git a/mmv_im2im/utils/fourier.py b/mmv_im2im/utils/fourier.py new file mode 100644 index 0000000..fa25297 --- /dev/null +++ b/mmv_im2im/utils/fourier.py @@ -0,0 +1,413 @@ +import numpy as np +from bioio import BioImage +from bioio.writers import OmeTiffWriter +from tqdm import tqdm +import os +from typing import Callable + + +def Im2Fourier(image: np.ndarray, mode: str = "complex") -> np.ndarray: + """ + Computes the 2D Fourier Transform for a multi-channel image and returns the result + in different formats based on the specified mode. + + Args: + image (np.ndarray): The input image as a NumPy array of shape (C, Y, X). + mode (str): The desired output format. Must be one of "complex", "real", or "freq". + - "complex": Returns a matrix of shape (2C, Y, X) with the real and + imaginary parts concatenated along the channel axis. + - "real": Returns a matrix of shape (C, Y, X) with only the real + part of the Fourier Transform. + - "freq": Returns a matrix of shape (2C, Y, X) with the magnitude + (abs) and phase (angle) for each channel. + + Returns: + np.ndarray: The transformed image matrix according to the specified mode. + Raises ValueError if the mode is invalid.
+ """ + + allowed_modes = ["complex", "real", "freq"] + if mode not in allowed_modes: + raise ValueError(f"Invalid mode: '{mode}'. Must be one of {allowed_modes}.") + + if image.ndim != 3: + raise ValueError("Input image must be a 3D array of shape (C, Y, X).") + + C, Y, X = image.shape + + transformed_channels = [] + + for c in range(C): + + channel = image[c, :, :] + + f_transform = np.fft.fft2(channel) + + f_transform_shifted = np.fft.fftshift(f_transform) + + if mode == "complex": + real_part = np.real(f_transform_shifted) + imag_part = np.imag(f_transform_shifted) + transformed_channels.append(real_part) + transformed_channels.append(imag_part) + + elif mode == "real": + real_part = np.real(f_transform_shifted) + transformed_channels.append(real_part) + + elif mode == "freq": + magnitude = np.abs(f_transform_shifted) + phase = np.angle(f_transform_shifted) + transformed_channels.append(magnitude) + transformed_channels.append(phase) + + return np.stack(transformed_channels, axis=0) + + +def Fourier2Im(fourier_image: np.ndarray, mode: str = "complex") -> np.ndarray: + """ + Computes the 2D Inverse Fourier Transform for a multi-channel image and + returns the result. It automatically deduces the number of original channels + based on the input shape and mode. + + Args: + fourier_image (np.ndarray): The input Fourier-transformed image. Its shape + depends on the specified mode. + mode (str): The format of the input. Must be one of "complex", "real", or "freq". + + Returns: + np.ndarray: The reconstructed image matrix of shape (C, Y, X). + """ + allowed_modes = ["complex", "real", "freq"] + if mode not in allowed_modes: + raise ValueError(f"Invalid mode: '{mode}'. 
Must be one of {allowed_modes}.") + + if fourier_image.ndim != 3: + raise ValueError("Input image must be a 3D array of shape (N, Y, X).") + + N, Y, X = fourier_image.shape + + # Deduce the number of original channels based on the mode + if mode == "complex": + if N % 2 != 0: + raise ValueError( + f"For mode 'complex', the number of channels ({N}) must be even." + ) + C = N // 2 + elif mode == "freq": + if N % 2 != 0: + raise ValueError( + f"For mode 'freq', the number of channels ({N}) must be even." + ) + C = N // 2 + else: # mode == "real" + C = N + + reconstructed_channels = [] + + if mode == "complex": + for i in range(C): + real_part = fourier_image[2 * i, :, :] + imag_part = fourier_image[2 * i + 1, :, :] + f_transform_shifted = real_part + 1j * imag_part + f_transform = np.fft.ifftshift(f_transform_shifted) + reconstructed_channel = np.fft.ifft2(f_transform) + reconstructed_channels.append(np.real(reconstructed_channel)) + + elif mode == "real": + for i in range(C): + f_transform_shifted = fourier_image[i, :, :] + f_transform = np.fft.ifftshift(f_transform_shifted) + reconstructed_channel = np.fft.ifft2(f_transform) + reconstructed_channels.append(np.real(reconstructed_channel)) + + elif mode == "freq": + for i in range(C): + magnitude = fourier_image[2 * i, :, :] + phase = fourier_image[2 * i + 1, :, :] + f_transform_shifted = magnitude * np.exp(1j * phase) + f_transform = np.fft.ifftshift(f_transform_shifted) + reconstructed_channel = np.fft.ifft2(f_transform) + reconstructed_channels.append(np.real(reconstructed_channel)) + + return np.stack(reconstructed_channels, axis=0) + + +def discretizer2n(matrix: np.ndarray, n: int = 3) -> np.ndarray: + """ + Quantizes the values of a float matrix to a specified number of integer ranges. + + The thresholds are calculated by dividing the range of matrix values + into n equal parts. + + Args: + matrix (np.ndarray): The input matrix with float values. 
+ n (int): The number of integer levels to quantize to (e.g., 3 for 0, 1, 2). + + Returns: + np.ndarray: A new matrix with the quantized integer values (from 0 to n-1). + """ + if not isinstance(matrix, np.ndarray): + raise TypeError("Input must be a NumPy array.") + if not isinstance(n, int) or n < 2: + raise ValueError("n must be an integer greater than or equal to 2.") + + if matrix.size == 0: + return np.array([], dtype=int) + + min_val = np.min(matrix) + max_val = np.max(matrix) + + if min_val == max_val: + return np.full_like(matrix, int(n / 2), dtype=int) + + data_range = max_val - min_val + step = data_range / n + + thresholds = [min_val + i * step for i in range(1, n)] + + quantized_matrix = np.zeros_like(matrix, dtype=int) + + for i in range(n - 1): + quantized_matrix[matrix >= thresholds[i]] = i + 1 + + return quantized_matrix.astype(np.uint8) + + +def Vol2Fourier(volume: np.ndarray, mode: str = "complex") -> np.ndarray: + """ + Applies the 2D Fourier Transform to each image (slice) of a volume and + returns a new volume of transformations. + + Args: + volume (np.ndarray): The input volume as a 4D NumPy array + with shape (Z, C, Y, X). + mode (str): The desired output format. Must be one of "complex", + "real", or "freq". + + Returns: + np.ndarray: The transformed volume with shape (Z, C', Y, X), where C' + is the number of channels of the transformation. + """ + if volume.ndim != 4: + raise ValueError("Input volume must be a 4D array with shape (Z, C, Y, X).") + + Z, C, Y, X = volume.shape + transformed_slices = [] + + for z in range(Z): + image_slice = volume[z, :, :, :] + transformed_slice = Im2Fourier(image_slice, mode=mode) + transformed_slices.append(transformed_slice) + + return np.stack(transformed_slices, axis=0) + + +def Fourier2Vol(fourier_volume: np.ndarray, mode: str = "complex") -> np.ndarray: + """ + Applies the 2D Inverse Fourier Transform to each slice of a transformed volume, + reconstructing the original volume.
+ + Args: + fourier_volume (np.ndarray): The input volume with the Fourier transform applied, + with shape (Z, N, Y, X). + mode (str): The format of the input. Must be one of "complex", + "real", or "freq". + + Returns: + np.ndarray: The reconstructed volume with shape (Z, C, Y, X). + """ + if fourier_volume.ndim != 4: + raise ValueError("Input volume must be a 4D array with shape (Z, N, Y, X).") + + Z, _, Y, X = fourier_volume.shape + reconstructed_slices = [] + + for z in range(Z): + fourier_slice = fourier_volume[z, :, :, :] + reconstructed_slice = Fourier2Im(fourier_slice, mode=mode) + reconstructed_slices.append(reconstructed_slice) + + return np.stack(reconstructed_slices, axis=0) + + +def VolDiscretizer(volume: np.ndarray, n: int = 3) -> np.ndarray: + """ + Applies discretization to each image (slice) of a volume, + quantizing float values to n integer ranges. + + Args: + volume (np.ndarray): The input volume as a 4D array with + shape (Z, C, Y, X). + n (int): The number of integer levels to quantize to + (e.g., 3 for 0, 1, 2). + + Returns: + np.ndarray: The new volume with discretized values as a 4D integer array. + """ + if volume.ndim != 4: + raise ValueError("Input volume must be a 4D array with shape (Z, C, Y, X).") + + Z, C, Y, X = volume.shape + discretized_slices = [] + + for z in range(Z): + image_slice = volume[z, :, :, :] + discretized_channels = [] + for c in range(C): + channel_slice = image_slice[c, :, :] + discretized_channel = discretizer2n(channel_slice, n=n) + discretized_channels.append(discretized_channel) + + discretized_slices.append(np.stack(discretized_channels, axis=0)) + + return np.stack(discretized_slices, axis=0) + + +def FreqNorm(fourier_image: np.ndarray) -> np.ndarray: + """ + Normalizes the magnitude and phase from the output of Im2Fourier with mode 'freq'. + + The magnitude is normalized using a logarithmic scale followed by a Min-Max + normalization to the range [0, 1]. The phase is scaled to the range [0, 1]. 
+ + Args: + fourier_image (np.ndarray): The input Fourier-transformed matrix with shape + (2C, Y, X), where even channels are magnitude + and odd channels are phase. + + Returns: + np.ndarray: The output matrix with the same shape, but with normalized + magnitude and phase channels. + + Raises: + ValueError: If the number of channels is not even. + """ + if fourier_image.ndim != 3: + raise ValueError("Input image must be a 3D array with shape (N, Y, X).") + + N, Y, X = fourier_image.shape + if N % 2 != 0: + raise ValueError(f"The number of channels ({N}) must be even for mode 'freq'.") + + C = N // 2 + normalized_channels = [] + + for i in range(C): + # Separate magnitude and phase for each channel + magnitude = fourier_image[2 * i, :, :] + phase = fourier_image[2 * i + 1, :, :] + + # Magnitude normalization: logarithm + Min-Max + log_magnitude = np.log1p(magnitude) + min_log_mag = np.min(log_magnitude) + max_log_mag = np.max(log_magnitude) + + if max_log_mag == min_log_mag: + norm_magnitude = np.zeros_like(log_magnitude) + else: + norm_magnitude = (log_magnitude - min_log_mag) / (max_log_mag - min_log_mag) + + # Phase normalization: Scale to the range [0, 1] + # np.angle returns phase in the range [-pi, pi] + norm_phase = (phase + np.pi) / (2 * np.pi) + + normalized_channels.append(norm_magnitude) + normalized_channels.append(norm_phase) + + return np.stack(normalized_channels, axis=0) + + +def VolFreqNorm(fourier_volume: np.ndarray) -> np.ndarray: + """ + Applies frequency-based normalization to each slice of a Fourier-transformed volume. + + This function normalizes the magnitude and phase for each 2D image (slice) + within a 4D volume, using the same logic as FreqNorm. + + Args: + fourier_volume (np.ndarray): The input volume with the Fourier transform applied, + with shape (Z, 2C, Y, X). + + Returns: + np.ndarray: The normalized volume with the same shape. 
+ + Raises: + ValueError: If the input volume is not a 4D array or if the number of + channels is not even. + """ + if fourier_volume.ndim != 4: + raise ValueError("Input volume must be a 4D array with shape (Z, N, Y, X).") + + Z, N, Y, X = fourier_volume.shape + if N % 2 != 0: + raise ValueError(f"The number of channels ({N}) must be even for mode 'freq'.") + + normalized_slices = [] + + for z in range(Z): + fourier_slice = fourier_volume[z, :, :, :] + normalized_slice = FreqNorm(fourier_slice) + normalized_slices.append(normalized_slice) + + return np.stack(normalized_slices, axis=0) + + +def transform_files( + path_data: str, + path_out: str, + function: Callable[[np.ndarray], np.ndarray], + mode: str = "complex", + shape: str = "CYX", +) -> None: + """ + Applies a function to all image files in a directory and saves the modified images. + + This function iterates through all files in the specified input directory, + loads each image, applies a user-defined function to its data, and then + saves the modified image to a new output directory. + + Args: + path_data (str): The path to the directory containing the input image files. + path_out (str): The path to the directory where the modified images will be saved. + function (Callable[[np.ndarray], np.ndarray]): A function that takes a NumPy array + (the image data) and returns a + NumPy array (the transformed data). + shape (str, optional): The dimension order of the image to be loaded. + Defaults to 'CYX'. + + Raises: + TypeError: If path_data, path_out, or shape are not strings. + TypeError: If the 'function' argument is not a callable object. + ValueError: If a provided path does not exist. 
+ """ + + if ( + not isinstance(path_data, str) + or not isinstance(path_out, str) + or not isinstance(shape, str) + ): + raise TypeError("path_data, path_out, and shape must be strings.") + + if not isinstance(function, Callable): + raise TypeError("'function' must be a callable object.") + + if not os.path.isdir(path_data): + raise ValueError(f"The input directory '{path_data}' does not exist.") + + if not os.path.exists(path_out): + print(f"Creating output directory: {path_out}") + os.makedirs(path_out) + + files = os.listdir(path_data) + if not files: + print(f"Warning: No files found in the directory: {path_data}") + return + + for file in tqdm(files): + try: + im = BioImage(os.path.join(path_data, file)).get_image_data(shape, T=0) + c_im = function(im, mode) + OmeTiffWriter.save(c_im, os.path.join(path_out, file), dim_order=shape) + except Exception as e: + print(f"Error processing file '{file}': {e}") diff --git a/mmv_im2im/utils/gdl_reg.py b/mmv_im2im/utils/gdl_reg.py new file mode 100644 index 0000000..5b9a6de --- /dev/null +++ b/mmv_im2im/utils/gdl_reg.py @@ -0,0 +1,200 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F +from mmv_im2im.utils.fractal_layers import FractalDimension +from mmv_im2im.utils.topological_loss import TI_Loss +from mmv_im2im.utils.connectivity_loss import ConnectivityCoherenceLoss +from monai.losses import GeneralizedDiceFocalLoss +from monai.metrics import HausdorffDistanceMetric + + +class SegmentationRegularizedLoss(nn.Module): + """ + Segmentation loss that combines a primary loss (Generalized Dice Focal Loss) + with structural regularizers (Fractal, Topological, Connectivity, Hausdorff). + Designed to replace ELBOLoss for DETERMINISTIC UNet models. 
+ """ + + def __init__( + self, + n_classes: int = 3, + # --- Main Loss Parameters (GeneralizedDiceFocalLoss) --- + gdl_focal_weight: float = 1.0, # Overall weight for the GDL/Focal term + gdl_class_weights: list = None, # Class weights for GDL/Focal + # --- Fractal Regularization --- + use_fractal_regularization: bool = False, + fractal_weight: float = 0.1, + fractal_num_kernels: int = 5, + fractal_mode: str = "classic", + fractal_to_binary: bool = True, + # --- Topological Regularization (TI_Loss) --- + use_topological_regularization: bool = False, + topological_weight: float = 0.1, + topological_dim: int = 2, + topological_connectivity: int = 4, + topological_inclusion: list = None, + topological_exclusion: list = None, + topological_min_thick: int = 1, + # --- Connectivity Regularization --- + use_connectivity_regularization: bool = False, + connectivity_weight: float = 0.1, + connectivity_kernel_size: int = 3, + connectivity_ignore_background: bool = True, + # --- Hausdorff Regularization --- + use_hausdorff_regularization: bool = False, + hausdorff_weight: float = 0.1, + hausdorff_ignore_background: bool = True, + **kwargs, # Catch-all for extra params + ): + super().__init__() + self.n_classes = n_classes + + # 1. Main Segmentation Loss (GeneralizedDiceFocalLoss) + self.gdl_focal_weight = gdl_focal_weight + monai_focal_weights = None + if gdl_class_weights is not None: + # MONAI GeneralizedDiceFocalLoss expects a tensor for weights + monai_focal_weights = torch.tensor(gdl_class_weights, dtype=torch.float32) + + self.main_seg_loss_calculator = GeneralizedDiceFocalLoss( + softmax=True, to_onehot_y=True, weight=monai_focal_weights + ) + + # 2. 
Fractal Regularization Setup (from ELBOLoss.py) + self.use_fractal_regularization = use_fractal_regularization + self.fractal_weight = fractal_weight + if self.use_fractal_regularization: + self.fractal_dimension_calculator = FractalDimension( + num_kernels=fractal_num_kernels, + mode=fractal_mode, + to_binary=fractal_to_binary, + ) + + # 3. Topological Regularization Setup (from ELBOLoss.py) + self.use_topological_regularization = use_topological_regularization + self.topological_weight = topological_weight + if self.use_topological_regularization: + if topological_inclusion is None: + topological_inclusion = [] + if topological_exclusion is None: + topological_exclusion = [] + self.topological_loss_calculator = TI_Loss( + dim=topological_dim, + connectivity=topological_connectivity, + inclusion=topological_inclusion, + exclusion=topological_exclusion, + min_thick=topological_min_thick, + ) + + # 4. Connectivity Regularization Setup (from ELBOLoss.py) + self.use_connectivity_regularization = use_connectivity_regularization + self.connectivity_weight = connectivity_weight + if self.use_connectivity_regularization: + self.connectivity_coherence_calculator = ConnectivityCoherenceLoss( + kernel_size=connectivity_kernel_size, + ignore_background=connectivity_ignore_background, + num_classes=n_classes, + ) + + # 5. Hausdorff Regularization Setup (from ELBOLoss.py) + self.use_hausdorff_regularization = use_hausdorff_regularization + self.hausdorff_weight = hausdorff_weight + if self.use_hausdorff_regularization: + self.hausdorff_distance_calculator = HausdorffDistanceMetric( + include_background=not hausdorff_ignore_background, + reduction="mean", + ) + + # La UNet determinística solo devuelve logits, así que esta es la firma + def forward(self, logits, y_true): + """ + Computes the combined segmentation loss with structural regularizers. + + Args: + logits (torch.Tensor): Output logits from the UNet (B, C, H, W). 
+ y_true (torch.Tensor): Ground truth segmentation mask (B, 1, H, W). + + Returns: + torch.Tensor: The calculated total loss. + """ + + # Squeeze y_true to (B, H, W) if it's (B, 1, H, W) + if y_true.shape[1] == 1: + y_true_squeezed = y_true.squeeze(1) + else: + y_true_squeezed = y_true + + # 1. Primary Segmentation Loss (GDL + Focal) + + # The loss calculator handles logits and converts y_true internally + y_true_for_gdl_focal = y_true_squeezed.unsqueeze(1).long() + primary_loss = self.main_seg_loss_calculator(logits, y_true_for_gdl_focal) + + # Apply overall weight for the main term + total_loss = self.gdl_focal_weight * primary_loss + + # Get softmax probabilities for regularizers that need them + y_pred_proba = F.softmax(logits, dim=1) + + # --- REGULARIZATION TERMS --- + + # 2. Fractal Regularization (Uses argmax mask for prediction) + if self.use_fractal_regularization and self.fractal_weight > 0.0: + # y_pred_mask: (B, 1, H, W) with class indices + y_pred_mask = y_pred_proba.argmax(dim=1, keepdim=True).float() + + # Prepare y_true for fractal (B, 1, H, W) + y_true_for_fractal = y_true_squeezed.unsqueeze(1).float() + + fd_true = self.fractal_dimension_calculator(y_true_for_fractal) + fd_pred = self.fractal_dimension_calculator(y_pred_mask) + + fractal_loss = torch.mean(torch.abs(fd_true - fd_pred)) + total_loss += self.fractal_weight * fractal_loss + + # 3. Topological Regularization (TI_Loss expects logits and y_true B, 1, H, W) + if self.use_topological_regularization and self.topological_weight > 0.0: + # TI_Loss expects logits for x and y_true (B, 1, H, W) for y + y_true_for_topological = y_true_squeezed.unsqueeze(1).float() + topological_loss = self.topological_loss_calculator( + logits, y_true_for_topological + ) + total_loss += self.topological_weight * topological_loss + + # 4. 
Connectivity Regularization (ConnectivityCoherenceLoss) + if self.use_connectivity_regularization and self.connectivity_weight > 0.0: + # y_pred_softmax: (B, C, H, W) + # y_true_one_hot: Convert (B, H, W) to one-hot (B, C, H, W) + y_true_one_hot = ( + F.one_hot(y_true_squeezed.long(), num_classes=self.n_classes) + .permute(0, 3, 1, 2) + .float() + ) + connectivity_loss = self.connectivity_coherence_calculator( + y_pred_proba, y_true_one_hot + ) + total_loss += self.connectivity_weight * connectivity_loss + + # 5. Hausdorff Regularization + if self.use_hausdorff_regularization and self.hausdorff_weight > 0.0: + try: + # Convert ground truth to one-hot format (B, C, H, W) + y_true_one_hot = F.one_hot( + y_true_squeezed.long(), num_classes=self.n_classes + ).permute(0, 3, 1, 2) + + # Get the one-hot encoded prediction from logits + y_pred_one_hot = F.one_hot( + logits.argmax(dim=1), num_classes=self.n_classes + ).permute(0, 3, 1, 2) + + # Calculate the Hausdorff distance + hausdorff_loss = self.hausdorff_distance_calculator( + y_pred=y_pred_one_hot, y_true=y_true_one_hot + ).mean() + + total_loss += self.hausdorff_weight * hausdorff_loss + except Exception: + pass + + return total_loss diff --git a/mmv_im2im/utils/misc.py b/mmv_im2im/utils/misc.py index 8a0f2a3..17e3f64 100644 --- a/mmv_im2im/utils/misc.py +++ b/mmv_im2im/utils/misc.py @@ -10,6 +10,7 @@ from monai.utils import ensure_tuple, require_pkg from monai.config import PathLike from monai.data.image_reader import _stack_images +import bioio_tifffile @require_pkg(pkg_name="bioio") @@ -22,7 +23,15 @@ def read(self, data: Union[Sequence[PathLike], PathLike]): filenames: Sequence[PathLike] = ensure_tuple(data) img_ = [] for name in filenames: - img_.append(BioImage(f"{name}")) + try: + img_.append(BioImage(f"{name}", reader=bioio_tifffile.Reader)) + + except Exception: + try: + img_.append(BioImage(f"{name}")) + except Exception as e: + print(f"Error: {e}") + print(f"Image {name} failed at read process check the 
format.") return img_ if len(filenames) > 1 else img_[0] diff --git a/mmv_im2im/utils/multi_pred.py b/mmv_im2im/utils/multi_pred.py new file mode 100644 index 0000000..6bef433 --- /dev/null +++ b/mmv_im2im/utils/multi_pred.py @@ -0,0 +1,27 @@ +import numpy as np + + +def mean_prediction(samples): + stacked_predictions = np.array(samples) + mean_volume = np.mean(stacked_predictions, axis=0) + + return mean_volume + + +def max_prediction(samples): + stacked_predictions = np.array(samples) + max_volume = np.max(stacked_predictions, axis=0) + + return max_volume + + +def variance_prediction(samples): + stacked_predictions = np.array(samples) + variance_volume = np.var(stacked_predictions, axis=0) + return variance_volume + + +def add_prediction(samples): + stacked_predictions = np.array(samples) + sum_volume = np.sum(stacked_predictions, axis=0) + return sum_volume diff --git a/mmv_im2im/utils/topological_loss.py b/mmv_im2im/utils/topological_loss.py index ad973cd..46ccac5 100644 --- a/mmv_im2im/utils/topological_loss.py +++ b/mmv_im2im/utils/topological_loss.py @@ -88,7 +88,7 @@ def set_kernel(self): ) elif self.connectivity == 26: np_kernel = np.ones((k, k, k)) - + self.kernel = torch.from_numpy( np.expand_dims(np.expand_dims(np_kernel, axis=0), axis=0) ) diff --git a/mmv_im2im/utils/urcentainity_extractor.py b/mmv_im2im/utils/urcentainity_extractor.py new file mode 100644 index 0000000..b3e94cb --- /dev/null +++ b/mmv_im2im/utils/urcentainity_extractor.py @@ -0,0 +1,570 @@ +import numpy as np +from skimage.morphology import remove_small_objects, remove_small_holes +from skimage.measure import label, regionprops +from mmv_im2im.utils.utils import topology_preserving_thinning +import torch +from scipy.ndimage import shift, rotate +import random + + +def perturb_image( + im_input, + opts, + gaussian_std=0.01, + sp_prob=0.01, + speckle_std=0.1, + color_jitter_factor=0.1, + max_shift=2, + max_angle=2, + scale_range=(0.98, 1.02), + dropout_rate=0.02, +): + """ + Applies a 
def perturb_image(
    im_input,
    opts,
    gaussian_std=0.01,
    sp_prob=0.01,
    speckle_std=0.1,
    color_jitter_factor=0.1,
    max_shift=2,
    max_angle=2,
    scale_range=(0.98, 1.02),
    dropout_rate=0.02,
):
    """
    Applies a random combination of small perturbations to an image (C, X, Y)
    in NumPy.

    This function is intended for Data Augmentation *before* model inference
    to generate slightly varied inputs for Monte Carlo Dropout or similar
    methods.

    Args:
        im_input (np.ndarray): Input image with shape (channels, height, width).
            Assumed to be a float array (e.g., normalized 0.0 to 1.0).
        opts: List of transformation names to draw from, or any string
            (e.g. 'all') to enable every available transformation.
        gaussian_std (float): Standard deviation for Gaussian Noise.
        sp_prob (float): Probability for Salt and Pepper Noise (controls density).
        speckle_std (float): Standard deviation for Speckle Noise (multiplicative).
        color_jitter_factor (float): Max factor for color perturbation
            (brightness/contrast).
        max_shift (int): Maximum displacement in pixels (both X and Y axes).
        max_angle (float): Maximum rotation angle in degrees (e.g., ±2 degrees).
        scale_range (tuple): Range (min, max) of the scaling factor.
            NOTE(review): currently unused — no scaling transform is
            implemented; kept for API compatibility.
        dropout_rate (float): Probability that an individual pixel is dropped
            (set to 0).

    Returns:
        np.ndarray: The perturbed image, clipped to [0.0, 1.0].

    Raises:
        ValueError: If `opts` selects no valid transformation.
    """

    # Create a copy to avoid modifying the original array
    im_out = im_input.copy()
    C, X, Y = im_out.shape

    # Any string enables the full set of transformations
    if isinstance(opts, str):
        opts = [
            "gauss_noise",
            "impulse_noise",
            "speckle_noise",
            "color_jitter",
            "shift",
            "rotation",
            "pixel_dropout",
        ]

    # Gaussian Noise (Additive)
    def add_gaussian_noise(img):
        noise = np.random.normal(0, gaussian_std, img.shape)
        return img + noise

    # Salt and Pepper Noise (Impulse Noise)
    def add_salt_and_pepper_noise(img):
        out = img.copy()
        # Half of the impulse points are salt, half are pepper
        num_sp_points = int(sp_prob * X * Y * C / 2)

        # Salt Noise: max values (assuming normalized data 0.0 to 1.0)
        coords_salt = [np.random.randint(0, s, num_sp_points) for s in img.shape]
        out[tuple(coords_salt)] = 1.0

        # Pepper Noise: min values
        coords_pepper = [np.random.randint(0, s, num_sp_points) for s in img.shape]
        out[tuple(coords_pepper)] = 0.0

        return out

    # Speckle Noise (Multiplicative Noise)
    def add_speckle_noise(img):
        noise = np.random.normal(0, speckle_std, img.shape)
        return img * (1 + noise)

    # Color Jitter (Small random change in brightness/contrast)
    def apply_color_jitter(img):
        # Random small scale factor (contrast)
        scale = 1.0 + np.random.uniform(-color_jitter_factor, color_jitter_factor)
        # Random small offset (brightness)
        offset = np.random.uniform(-color_jitter_factor / 5, color_jitter_factor / 5)
        return (img * scale) + offset

    # Shift (Translation)
    def apply_shift(img):
        shift_x = np.random.randint(-max_shift, max_shift + 1)
        shift_y = np.random.randint(-max_shift, max_shift + 1)
        # Shift only the spatial axes (1 and 2); channel axis stays fixed
        return shift(img, (0, shift_x, shift_y), mode="nearest")

    # Rotation
    def apply_rotation(img):
        angle = np.random.uniform(-max_angle, max_angle)
        # Rotate in the X-Y plane; reshape=False keeps the output shape
        return rotate(img, angle, axes=(1, 2), reshape=False, mode="nearest")

    # Pixel Dropout (Setting random pixels to 0)
    def apply_pixel_dropout(img):
        # Binary mask: 1 keeps the pixel, 0 drops it
        mask = np.random.binomial(1, 1 - dropout_rate, size=img.shape)
        return img * mask

    # Collect the requested transformation functions
    transformations = []
    if "gauss_noise" in opts:
        transformations.append(add_gaussian_noise)
    if "impulse_noise" in opts:
        transformations.append(add_salt_and_pepper_noise)
    if "speckle_noise" in opts:
        transformations.append(add_speckle_noise)
    if "color_jitter" in opts:
        transformations.append(apply_color_jitter)
    if "shift" in opts:
        transformations.append(apply_shift)
    if "rotation" in opts:
        transformations.append(apply_rotation)
    if "pixel_dropout" in opts:
        transformations.append(apply_pixel_dropout)

    if len(transformations) == 0:
        raise ValueError("Invalid transformations")

    # --- Random Application of Transformations ---

    # Choose how many transformations to apply (1 to N)
    num_transforms_to_apply = random.randint(1, len(transformations))

    # Randomly select the transformations (without replacement)
    selected_transforms = random.sample(transformations, num_transforms_to_apply)

    # Apply the selected transformations in random order
    for transform_func in selected_transforms:
        im_out = transform_func(im_out)

    # Crucial: keep data within a valid range; the clip prevents extreme
    # values generated by noise/jitter from breaking the model.
    # NOTE: if inputs are not normalized to [0, 1], adjust this range.
    im_out = np.clip(im_out, 0.0, 1.0)

    return im_out
def Perycites_correction(seg_full):
    """
    Post-process class 2 (pericytes — 'Perycites' spelling kept for API
    compatibility) of a 3D multi-class segmentation.

    Removes class-2 components smaller than 30 voxels, then, slice by slice,
    deletes mid-sized class-2 blobs (30-300 voxels) that look round and solid
    (eccentricity < 0.88, solidity > 0.85, area < 150) — treated as false
    positives. Surviving class-2 voxels are re-labeled 2; all other former
    class-2 voxels are merged into class 1.

    Args:
        seg_full (np.ndarray): 3D integer label volume (Z, Y, X). Modified
            in place.

    Returns:
        np.ndarray: The same (mutated) volume, for chaining.
    """
    seg_2 = remove_small_objects(seg_full == 2, min_size=30)
    # Components that survive the 30-voxel cut but are below 300 voxels
    seg_2_mid = np.logical_xor(seg_2, remove_small_objects(seg_2, min_size=300))

    for zz in range(seg_2_mid.shape[0]):
        seg_label, num_obj = label(seg_2_mid[zz, :, :], return_num=True)
        if num_obj > 0:
            for region in regionprops(seg_label):
                # Round-ish, compact, small blobs → remove from class 2
                if (
                    region.eccentricity < 0.88
                    and region.solidity > 0.85
                    and region.area < 150
                ):
                    seg_z = seg_2[zz, :, :]
                    # seg_z is a view, so this edits seg_2 in place
                    seg_z[seg_label == region.label] = 0

    # Demote every former class-2 voxel to class 1, then restore the
    # voxels that passed the filters as class 2.
    seg_full[seg_full == 2] = 1
    seg_full[seg_2 > 0] = 2
    return seg_full
+ """ + pz, py, px = voxel_sizes + voxel_volume = pz * py * px + + classes_to_process = range(1, n_classes) + num_target_classes = len(classes_to_process) + + thresholds = [] + + if not isinstance(remove_object_size, list): + physical_thresholds = [remove_object_size] * num_target_classes + else: + list_len = len(remove_object_size) + + if list_len == 1: + physical_thresholds = [remove_object_size[0]] * num_target_classes + + elif list_len == num_target_classes: + physical_thresholds = remove_object_size + + else: + raise ValueError( + f"The list 'remove_object_size' has {list_len} elements, " + f"but {num_target_classes} (or 1) were expected for the {num_target_classes} classes to process " + f"(Class 1 to {n_classes - 1}). The background (Class 0) is ignored." + ) + + for physical_size in physical_thresholds: + min_voxel_count = int(np.ceil(physical_size / voxel_volume)) + thresholds.append(min_voxel_count) + + seg_cleaned = np.zeros_like(seg_full) + + for i, class_id in enumerate(classes_to_process): + min_size_threshold = thresholds[i] + + seg_class_mask = seg_full == class_id + + if seg_class_mask.any(): + seg_class_clean = remove_small_objects( + seg_class_mask, min_size=min_size_threshold + ) + seg_cleaned[seg_class_clean] = class_id + + return seg_cleaned + + +def Hole_Correction(seg_full, n_classes, hole_size_threshold, voxel_sizes=(1, 1, 1)): + """ + Applies hole correction to multiple classes in a segmentation volume. + + The correction is applied to object classes (typically 1 up to n_classes-1). + Each class can have a different hole size threshold. It also includes + an initial removal of small objects for all classes. + + Args: + seg_full (np.ndarray): The 3D segmentation volume with integer class values. + n_classes (int): The total number of classes in the segmentation (including background). + hole_size_threshold (list or int): A single threshold (int) or a list of + thresholds (list) for hole correction. 
def Hole_Correction(seg_full, n_classes, hole_size_threshold, voxel_sizes=(1, 1, 1)):
    """
    Applies 2D hole filling, slice by slice, to every object class
    (1 to n_classes-1) of a 3D segmentation volume. Each class can have a
    different maximum hole area.

    Args:
        seg_full (np.ndarray): The 3D segmentation volume with integer class values.
        n_classes (int): Total number of classes, background (class 0) included.
        hole_size_threshold (list, tuple or number): A single physical hole
            area, or a sequence of areas. If a sequence, its length must be
            1 or equal to the number of object classes (n_classes - 1).
        voxel_sizes (tuple): Physical voxel dimensions (pz, py, px); the
            in-plane sizes (py, px) convert physical areas to pixel counts.

    Returns:
        np.ndarray: The corrected segmentation volume (the input is not
        modified).

    Raises:
        ValueError: If the length of hole_size_threshold is invalid.
    """
    pz, py, px = voxel_sizes
    pixel_area = py * px

    classes_to_correct = range(1, n_classes)
    num_target_classes = len(classes_to_correct)

    # Normalize the spec to one physical area per object class
    if not isinstance(hole_size_threshold, (list, tuple)):
        physical_thresholds = [hole_size_threshold] * num_target_classes
    else:
        list_len = len(hole_size_threshold)
        if list_len == 1:
            physical_thresholds = [hole_size_threshold[0]] * num_target_classes
        elif list_len == num_target_classes:
            physical_thresholds = list(hole_size_threshold)
        else:
            raise ValueError(
                f"The list 'hole_size_threshold' has {list_len} elements, "
                f"but {num_target_classes} (or 1) were expected for the "
                f"{num_target_classes} classes to correct "
                f"(Class 1 to {n_classes - 1}). "
                f"The background (Class 0) does not need a threshold."
            )

    # Convert physical areas to pixel counts (ceil keeps the fill conservative)
    thresholds = [
        int(np.ceil(physical_area / pixel_area))
        for physical_area in physical_thresholds
    ]

    seg_corrected = seg_full.copy()

    for class_id, threshold in zip(classes_to_correct, thresholds):
        seg_obj_mask = seg_corrected == class_id
        if seg_obj_mask.any():
            seg_obj_slice_corrected = seg_obj_mask.copy()

            # Fill small holes independently in each Z slice
            for zz in range(seg_full.shape[0]):
                filled = remove_small_holes(
                    seg_obj_slice_corrected[zz, :, :], area_threshold=threshold
                )
                seg_obj_slice_corrected[zz, :, :] = filled[:, :]

            # Re-stamp the class with its hole-filled mask
            seg_corrected[seg_corrected == class_id] = 0
            seg_corrected[seg_obj_slice_corrected] = class_id

    return seg_corrected
def Thickness_Corretion(
    seg_full, n_classes, min_thickness_physical, voxel_sizes=(1, 1, 1)
):
    """
    Applies topology-preserving thinning (thickness correction) to all object
    classes (1 to n_classes-1) in a 3D segmentation volume, using a specific
    minimum thickness for each class. Class 0 (background) is ignored.

    Args:
        seg_full (np.ndarray): The 3D segmentation volume with integer class values.
        n_classes (int): Total number of classes, Class 0 (background) included.
        min_thickness_physical (list or np.ndarray): Minimum physical
            thickness values; index i corresponds to Class i+1 (index 0 is
            for Class 1).
        voxel_sizes (tuple): Physical voxel dimensions (pz, py, px); the mean
            in-plane size converts physical distances to pixels.

    Returns:
        np.ndarray: A new volume where each object class has been thinned,
        preserving its original class label.

    Raises:
        ValueError: If len(min_thickness_physical) != n_classes - 1.
    """
    pz, py, px = voxel_sizes
    distance_unit = (py + px) / 2

    classes_to_process = range(1, n_classes)
    num_object_classes = len(classes_to_process)

    # 1. Validate the length of the minimum thickness list.
    # BUGFIX: the last two message lines were plain strings, so the
    # placeholders were emitted literally; they are f-strings now.
    if len(min_thickness_physical) != num_object_classes:
        raise ValueError(
            f"The length of 'min_thickness_physical' "
            f"({len(min_thickness_physical)}) does not match "
            f"the number of object classes to process ({num_object_classes}). "
            f"Class 0 (background) is ignored, so {num_object_classes} "
            f"values are expected "
            f"(one for each class from 1 to {n_classes - 1})."
        )

    # Convert physical thicknesses to pixel counts
    min_thickness_list = []
    for physical_distance in min_thickness_physical:
        min_thickness_pixel = int(np.ceil(physical_distance / distance_unit))
        min_thickness_list.append(min_thickness_pixel)

    seg_corrected = np.zeros_like(seg_full)

    for i, class_id in enumerate(classes_to_process):

        current_min_thickness = min_thickness_list[i]
        seg_class_mask = seg_full == class_id

        seg_thinned = topology_preserving_thinning(
            seg_class_mask, min_thickness=current_min_thickness, thin=1
        )

        seg_corrected[seg_thinned > 0] = class_id

    return seg_corrected
def adjust_volume(volume: np.ndarray) -> np.ndarray:
    """
    Compress the dynamic range of a NumPy volume (2D or 3D) with square
    roots, depending on the fraction of values in specific magnitude bands.

    If at least 80% of the values are <= 1e-6, a double square root (fourth
    root) is applied; otherwise, if at least 80% fall in (1e-6, 1e-2], a
    single square root is applied; otherwise the volume is returned as-is.

    Args:
        volume (np.ndarray): Input volume with shape (c, y, x) or (y, x).

    Returns:
        np.ndarray: The transformed volume.
    """
    total = volume.size

    # Fraction of values of order 1e-6 or smaller (including 0)
    tiny_fraction = np.count_nonzero(volume <= 1e-6) / total

    # Fraction of values strictly above 1e-6 and up to 1e-2 (0.000n / 0.00n)
    small_fraction = (
        np.count_nonzero(np.logical_and(volume > 1e-6, volume <= 1e-2)) / total
    )

    if tiny_fraction >= 0.8:
        # Dominated by ~1e-6 values: apply the fourth root
        return np.sqrt(np.sqrt(volume))

    if small_fraction >= 0.8:
        # Dominated by ~1e-3/1e-2 values: apply a single square root
        return np.sqrt(volume)

    return volume
+ """ + + if not logits_samples: + raise ValueError("The list of logit samples cannot be empty.") + + # Convert Logits to Probabilities --- + + # Stack samples along a new axis (axis 0). Shape: (N_samples, C, Y, X) + stacked_logits = np.stack(logits_samples, axis=0) + N, C, Y, X = stacked_logits.shape + + # Apply Softmax along the class axis (axis 1) to get probabilities P. + logits_tensor = torch.from_numpy(stacked_logits).float() + + # stacked_probs.shape: (N_samples, C, Y, X) + stacked_probs = torch.nn.functional.softmax(logits_tensor, dim=1).numpy() + + # CALCULATE AND MERGE UNCERTAINTY --- + + if compute_mode == "variance": + # Calculate Variance of probabilities (class-wise uncertainty) + # uncertainty_map_split shape: (C, Y, X) + uncertainty_map_split = np.var(stacked_probs, axis=0) + + if var_reductor: + # Merge: Take the MINIMUM uncertainty across classes (axis 0) + # merged_uncertainty shape: (Y, X) + merged_uncertainty = np.min(uncertainty_map_split, axis=0) + + if estabilizer: + merged_uncertainty = adjust_volume(merged_uncertainty) + return merged_uncertainty + else: + # Return per-class uncertainty map + # shape: (C, Y, X) + if estabilizer: + uncertainty_map_split = adjust_volume(uncertainty_map_split) + return uncertainty_map_split + elif compute_mode == "prob_inv": + # Probability map + # stacked_max_probs.shape: (N_samples, Y, X) + prob_comp = 1 - np.max(stacked_probs, axis=1) + # averaged_max_probs.shape: (Y, X) + prob_comp = np.mean(prob_comp, axis=0) + + if estabilizer: + prob_comp = adjust_volume(prob_comp) + + return prob_comp + elif compute_mode == "mutual_inf" or compute_mode == "entropy": + # Calculate Mutual Information (MI) + + # helps to avoid log(0) + epsilon = 1e-12 + + # a) Average Predictive Probability (E[P(y|x)]) + avg_probs = np.mean(stacked_probs, axis=0) + # Cliping para la entropía total (H[E[P(y|x)]]) + avg_probs_clipped = np.clip(avg_probs, a_min=epsilon, a_max=None) + + # b) Total Predictive Entropy (H[E[P(y|x)]]) + 
entropy_total = -np.sum(avg_probs_clipped * np.log(avg_probs_clipped), axis=0) + + # c) Average Conditional Entropy (E[H[P(y|x, w)]]) + + # Cliping para la entropía condicional (P(y|x,w) * log(P(y|x,w))) + stacked_probs_clipped = np.clip(stacked_probs, a_min=epsilon, a_max=None) + + # Cálculo de la entropía por muestra: -sum(P log P) + per_sample_entropy = -np.sum( + stacked_probs_clipped * np.log(stacked_probs_clipped), axis=1 + ) + + # Promedio de la entropía por muestra (E[H[P(y|x, w)]]) + entropy_avg_conditional = np.mean(per_sample_entropy, axis=0) + + if compute_mode == "entropy": + # Entropy + mutual_information_map = entropy_total + else: + # d) Mutual Information (MI) + mutual_information_map = entropy_total - entropy_avg_conditional + + # Apply Relative Normalization if requested + if relative_MI: + # Normalization factor: ln(C), the max possible entropy (MI max theoretical bound) + max_mi = np.log(C) + + if max_mi == 0: # Handle C=1 case (though unlikely for segmentation) + return np.zeros_like(mutual_information_map) + + # Clip result to ensure strict [0, 1] range due to floating point arithmetic + normalized_mi = np.clip(mutual_information_map / max_mi, 0.0, 1.0) + + if estabilizer: + normalized_mi = adjust_volume(normalized_mi) + + return normalized_mi + else: + # Return original MI (in nats), range [0, ln(C)] + if estabilizer: + mutual_information_map = adjust_volume(mutual_information_map) + return mutual_information_map + + else: + raise ValueError("Invalid computation mode.") diff --git a/setup.cfg b/setup.cfg index d17d716..4010c51 100644 --- a/setup.cfg +++ b/setup.cfg @@ -14,4 +14,4 @@ replace = version = "{new_version}" [flake8] exclude = docs/, .git/, __pycache__/, build/, dist/, .venv/, .tox/, *.egg-info/ ignore = E203, E402, W291, W503, W293, W292, E501 -max-line-length = 88 +max-line-length = 88 \ No newline at end of file