diff --git a/docs/conf.py b/docs/conf.py index 0431409..9ad28f5 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -60,18 +60,18 @@ # You can specify multiple suffix as a list of string: # source_suffix = { - ".rst": "restructuredtext", - ".txt": "markdown", - ".md": "markdown", + ".rst": "restructuredtext", + ".txt": "markdown", + ".md": "markdown", } # The master toctree document. master_doc = "index" # General information about the project. -project = u"MMV Im2Im Transformation" -copyright = u'2021, Jianxu Chen' -author = u"Jianxu Chen" +project = "MMV Im2Im Transformation" +copyright = "2021, Jianxu Chen" +author = "Jianxu Chen" # The version info for the project you"re documenting, acts as replacement # for |version| and |release|, also used in various other places throughout @@ -135,15 +135,12 @@ # The paper size ("letterpaper" or "a4paper"). # # "papersize": "letterpaper", - # The font size ("10pt", "11pt" or "12pt"). # # "pointsize": "10pt", - # Additional stuff for the LaTeX preamble. # # "preamble": "", - # Latex figure (float) alignment # # "figure_align": "htbp", @@ -153,9 +150,13 @@ # (source start file, target name, title, author, documentclass # [howto, manual, or own class]). latex_documents = [ - (master_doc, "mmv_im2im.tex", - u"MMV Im2Im Transformation Documentation", - u"Jianxu Chen", "manual"), + ( + master_doc, + "mmv_im2im.tex", + "MMV Im2Im Transformation Documentation", + "Jianxu Chen", + "manual", + ), ] @@ -164,9 +165,7 @@ # One entry per manual page. List of tuples # (source start file, name, description, authors, manual section). 
man_pages = [ - (master_doc, "mmv_im2im", - u"MMV Im2Im Transformation Documentation", - [author], 1) + (master_doc, "mmv_im2im", "MMV Im2Im Transformation Documentation", [author], 1) ] @@ -176,10 +175,13 @@ # (source start file, target name, title, author, # dir menu entry, description, category) texinfo_documents = [ - (master_doc, "mmv_im2im", - u"MMV Im2Im Transformation Documentation", - author, - "mmv_im2im", - "One line description of project.", - "Miscellaneous"), + ( + master_doc, + "mmv_im2im", + "MMV Im2Im Transformation Documentation", + author, + "mmv_im2im", + "One line description of project.", + "Miscellaneous", + ), ] diff --git a/mmv_im2im/models/pl_FCN.py b/mmv_im2im/models/pl_FCN.py index e98e59a..4320479 100644 --- a/mmv_im2im/models/pl_FCN.py +++ b/mmv_im2im/models/pl_FCN.py @@ -156,4 +156,3 @@ def validation_step(self, batch, batch_idx): ) return loss - \ No newline at end of file diff --git a/mmv_im2im/proj_tester.py b/mmv_im2im/proj_tester.py index af0a8c7..baea7c6 100644 --- a/mmv_im2im/proj_tester.py +++ b/mmv_im2im/proj_tester.py @@ -16,6 +16,7 @@ from mmv_im2im.utils.for_transform import parse_monai_ops_vanilla from skimage.io import imsave as save_rgb + # from mmv_im2im.utils.piecewise_inference import predict_piecewise from monai.inferers import sliding_window_inference @@ -60,16 +61,26 @@ def setup_model(self): and self.model_cfg.model_extra["cpu_only"] ): self.cpu = True - pre_train = torch.load( - self.model_cfg.checkpoint, map_location=torch.device("cpu") + checkpoint = torch.load( + self.model_cfg.checkpoint, + map_location=torch.device("cpu"), + weights_only=False, ) else: - pre_train = torch.load(self.model_cfg.checkpoint) + checkpoint = torch.load(self.model_cfg.checkpoint, weights_only=False) + + if isinstance(checkpoint, dict) and "state_dict" in checkpoint: - # TODO: hacky solution to remove a wrongly registered key - pre_train["state_dict"].pop("criterion.xym", None) - pre_train["state_dict"].pop("criterion.xyzm", None) - 
self.model.load_state_dict(pre_train["state_dict"]) + pre_train = checkpoint + pre_train["state_dict"].pop("criterion.xym", None) + pre_train["state_dict"].pop("criterion.xyzm", None) + self.model.load_state_dict(pre_train["state_dict"], strict=False) + else: + + state_dict = checkpoint + state_dict.pop("criterion.xym", None) + state_dict.pop("criterion.xyzm", None) + self.model.load_state_dict(state_dict, strict=False) if not self.cpu: self.model.cuda() @@ -322,7 +333,7 @@ def run_inference(self): img = BioImage(ds).get_image_data( **self.data_cfg.inference_input.reader_params ) - + # prepare output filename if "." in suffix: if ( diff --git a/mmv_im2im/proj_trainer.py b/mmv_im2im/proj_trainer.py index 9f7682a..150286f 100644 --- a/mmv_im2im/proj_trainer.py +++ b/mmv_im2im/proj_trainer.py @@ -49,8 +49,11 @@ def run_training(self): self.model = self.model.load_from_checkpoint( self.model_cfg.model_extra["resume"] ) + elif "pre-train" in self.model_cfg.model_extra: - pre_train = torch.load(self.model_cfg.model_extra["pre-train"]) + pre_train = torch.load( + self.model_cfg.model_extra["pre-train"], weights_only=False + ) if "extend" in self.model_cfg.model_extra: if ( @@ -71,10 +74,10 @@ def run_training(self): ) model_state.update(filtered_dict) - self.model.load_state_dict(model_state) + self.model.load_state_dict(model_state, strict=False) else: pre_train["state_dict"].pop("criterion.xym", None) - self.model.load_state_dict(pre_train["state_dict"]) + self.model.load_state_dict(pre_train["state_dict"], strict=False) if self.train_cfg.callbacks is None: trainer = pl.Trainer(**self.train_cfg.params) @@ -97,4 +100,3 @@ def run_training(self): print("start training ... 
") trainer.fit(model=self.model, datamodule=self.data) - \ No newline at end of file diff --git a/mmv_im2im/utils/connectivity_loss.py b/mmv_im2im/utils/connectivity_loss.py new file mode 100644 index 0000000..a2a9d4f --- /dev/null +++ b/mmv_im2im/utils/connectivity_loss.py @@ -0,0 +1,102 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F + + +class ConnectivityCoherenceLoss(nn.Module): + """ + Calculates a connectivity coherence loss to penalize fragmentation or + undesired isolated components within predicted regions. + + This loss encourages predicted regions to be spatially coherent and continuous + by comparing local neighborhoods in predictions with ground truth. + It penalizes: + 1. Discontinuities within what should be a single, connected region (e.g., breaking a vein). + 2. Isolated 'islands' of one class within another class. + + Args: + kernel_size (int): Size of the convolutional kernel for neighborhood analysis (e.g., 3 for 3x3). + ignore_background (bool): If True, the loss focuses primarily on non-background classes. + Useful if background fragmentation is less critical. + num_classes (int): Total number of classes, including background. + """ + + def __init__( + self, kernel_size: int = 3, ignore_background: bool = True, num_classes: int = 2 + ): + super().__init__() + if kernel_size % 2 == 0: + raise ValueError("kernel_size must be an odd number.") + self.kernel_size = kernel_size + self.ignore_background = ignore_background + self.num_classes = num_classes + self.average_kernel = torch.ones(1, 1, kernel_size, kernel_size) / ( + kernel_size**2 - 1 + ) + self.center_offset = (kernel_size // 2, kernel_size // 2) + + def forward(self, y_pred_softmax, y_true_one_hot): + """ + Args: + y_pred_softmax (torch.Tensor): Softmax probabilities from the model (B, C, H, W). + y_true_one_hot (torch.Tensor): Ground truth as one-hot encoded tensor (B, C, H, W). + Should be float. 
+ + Returns: + torch.Tensor: The calculated connectivity coherence loss. + """ + current_average_kernel = self.average_kernel.to(y_pred_softmax.device) + + y_true_one_hot = y_true_one_hot.float() + + loss_per_class = [] + + for c in range(self.num_classes): + if self.ignore_background and c == 0: + continue + + true_mask_c = y_true_one_hot[:, c : c + 1, :, :] + pred_prob_c = y_pred_softmax[:, c : c + 1, :, :] + + padded_true_mask_c = F.pad( + true_mask_c, + ( + self.center_offset[1], + self.center_offset[1], + self.center_offset[0], + self.center_offset[0], + ), + mode="replicate", + ) + + neighbor_sum_true = F.conv2d( + padded_true_mask_c, current_average_kernel, padding=0, groups=1 + ) * (self.kernel_size**2) + + padded_pred_prob_c = F.pad( + pred_prob_c, + ( + self.center_offset[1], + self.center_offset[1], + self.center_offset[0], + self.center_offset[0], + ), + mode="replicate", + ) + + pred_neighbor_avg = F.conv2d( + padded_pred_prob_c, current_average_kernel, padding=0, groups=1 + ) + + true_neighbor_avg = neighbor_sum_true / (self.kernel_size**2 - 1) + + loss_b = F.mse_loss(pred_neighbor_avg, true_mask_c, reduction="none") + loss_c = F.mse_loss(pred_prob_c, true_neighbor_avg, reduction="none") + + class_coherence_loss = torch.mean(loss_b + loss_c) + loss_per_class.append(class_coherence_loss) + + if not loss_per_class: + return torch.tensor(0.0, device=y_pred_softmax.device) + + return torch.sum(torch.stack(loss_per_class)) diff --git a/mmv_im2im/utils/elbo_loss.py b/mmv_im2im/utils/elbo_loss.py index 2a3cbc0..24d58de 100644 --- a/mmv_im2im/utils/elbo_loss.py +++ b/mmv_im2im/utils/elbo_loss.py @@ -1,6 +1,10 @@ import torch import torch.nn as nn import torch.nn.functional as F +from mmv_im2im.utils.fractal_layers import Slice_windows, FractalDimension +from mmv_im2im.utils.topological_loss import TI_Loss +from mmv_im2im.utils.connectivity_loss import ConnectivityCoherenceLoss +from monai.losses import GeneralizedDiceFocalLoss class KLDivergence(nn.Module): 
@@ -9,19 +13,27 @@ class KLDivergence(nn.Module): def __init__(self): super().__init__() - def forward(self, mu_q, logvar_q, mu_p, logvar_p): + def forward(self, mu_q, logvar_q, mu_p, logvar_p, kl_clamp=None): """ Calculates the KL Divergence between two diagonal Gaussian distributions. Args: mu_q (torch.Tensor): Mean of the approximate posterior distribution. - logvar_q (torch.Tensor): Log-variance of the approximate posterior distribution. + logvar_q (torch.Tensor): Log-variance of the approximate posterior + distribution. mu_p (torch.Tensor): Mean of the prior distribution. logvar_p (torch.Tensor): Log-variance of the prior distribution. + clamp (float): Value to clamp logvar_q, logvar_p in case of gradient explotion. Returns: torch.Tensor: The mean KL divergence over the batch. """ + # Clamp log-variances to prevent numerical instability + # This limits exp(logvar) to a stable range, e.g., [2.06e-9, 4.85e8] + if kl_clamp is not None: + logvar_q = torch.clamp(logvar_q, min=-kl_clamp, max=kl_clamp) + logvar_p = torch.clamp(logvar_p, min=-kl_clamp, max=kl_clamp) + kl_batch_sum = 0.5 * torch.sum( logvar_p - logvar_q @@ -29,54 +41,238 @@ def forward(self, mu_q, logvar_q, mu_p, logvar_p): - 1, dim=[1, 2, 3], # Sum over latent channels, H, W ) - return torch.mean(kl_batch_sum) # Average over batch + return torch.mean(kl_batch_sum) class ELBOLoss(nn.Module): """ - Calculates the Evidence Lower Bound (ELBO) loss for Probabilistic UNet. + Calculates the Evidence Lower Bound (ELBO) loss for Probabilistic UNet, + with optional fractal dimension, topological, and connectivity regularization. Args: beta (float): Weighting factor for the KL divergence term. n_classes (int): Number of classes in the segmentation task. + kl_clamp (float): Value to clamp logvar_q, logvar_p in case of gradient explotion for kl. + use_fractal_regularization (bool): If True, includes the fractal dimension regularization term. 
+ fractal_weight (float): Weighting factor for the fractal dimension loss term (only if use_fractal_regularization is True). + fractal_num_kernels (int): Number of kernels for FractalDimension (only if use_fractal_regularization is True). + fractal_mode (str): Mode for FractalDimension ("classic" or "entropy") (only if use_fractal_regularization is True). + fractal_to_binary (bool): Whether to binarize input for FractalDimension (only if use_fractal_regularization is True). + use_topological_regularization (bool): If True, includes the topological regularization term. + topological_weight (float): Weighting factor for the topological loss term (only if use_topological_regularization is True). + topological_dim (int): Dimension for TI_Loss (2 for 2D, 3 for 3D) (only if use_topological_regularization is True). + topological_connectivity (int): Connectivity for TI_Loss (4 or 8 for 2D; 6 or 26 for 3D) (only if use_topological_regularization is True). + topological_inclusion (list): List of [A,B] class pairs for inclusion in TI_Loss (only if use_topological_regularization is True). + topological_exclusion (list): List of [A,C] class pairs for exclusion in TI_Loss (only if use_topological_regularization is True). + topological_min_thick (int): Minimum thickness for TI_Loss (only if use_topological_regularization is True and connectivity is 8 or 26). + use_connectivity_regularization (bool): If True, includes the new connectivity coherence regularization term. + connectivity_weight (float): Weighting factor for the connectivity coherence loss term. + connectivity_kernel_size (int): Kernel size for connectivity coherence loss (e.g., 3). + connectivity_ignore_background (bool): If True, ignore background for connectivity loss. + elbo_class_weights (list or torch.Tensor, optional): Weights for each class in the cross-entropy loss. + use_gdl_focal_regularization (bool): If True, Includes Generalized Dice Focal (GDF) regularization. 
+ gdl_focal_weight (float): Weighting factor for GDF. + gdl_class_weights (list): Weights for each class. """ - def __init__(self, beta: float = 1.0, n_classes: int = 2): + def __init__( + self, + beta: float = 1.0, + n_classes: int = 2, + kl_clamp: float = None, + use_fractal_regularization: bool = False, + fractal_weight: float = 0.1, + fractal_num_kernels: int = 5, + fractal_mode: str = "classic", + fractal_to_binary: bool = True, + use_topological_regularization: bool = False, + topological_weight: float = 0.1, + topological_dim: int = 2, + topological_connectivity: int = 4, + topological_inclusion: list = None, + topological_exclusion: list = None, + topological_min_thick: int = 1, + use_connectivity_regularization: bool = False, + connectivity_weight: float = 0.1, + connectivity_kernel_size: int = 3, + connectivity_ignore_background: bool = True, + use_gdl_focal_regularization: bool = False, + gdl_focal_weight: float = 1.0, + elbo_class_weights: list = None, + gdl_class_weights: list = None, + ): super().__init__() self.beta = beta self.n_classes = n_classes + self.kl_clamp = kl_clamp self.kl_divergence_calculator = KLDivergence() + self.use_fractal_regularization = use_fractal_regularization + if self.use_fractal_regularization: + self.fractal_weight = fractal_weight + self.fractal_dimension_calculator = FractalDimension( + num_kernels=fractal_num_kernels, + mode=fractal_mode, + to_binary=fractal_to_binary, + ) + else: + self.fractal_weight = 0.0 + + self.use_topological_regularization = use_topological_regularization + if self.use_topological_regularization: + self.topological_weight = topological_weight + if topological_inclusion is None: + topological_inclusion = [] + if topological_exclusion is None: + topological_exclusion = [] + self.topological_loss_calculator = TI_Loss( + dim=topological_dim, + connectivity=topological_connectivity, + inclusion=topological_inclusion, + exclusion=topological_exclusion, + min_thick=topological_min_thick, + ) + else: + 
self.topological_weight = 0.0 + + # New Connectivity Regularization + self.use_connectivity_regularization = use_connectivity_regularization + if self.use_connectivity_regularization: + self.connectivity_weight = connectivity_weight + self.connectivity_coherence_calculator = ConnectivityCoherenceLoss( + kernel_size=connectivity_kernel_size, + ignore_background=connectivity_ignore_background, + num_classes=n_classes, + ) + else: + self.connectivity_weight = 0.0 + + self.use_gdl_focal_regularization = use_gdl_focal_regularization + if self.use_gdl_focal_regularization: + self.gdl_focal_weight = gdl_focal_weight + monai_focal_weights = None + if gdl_class_weights is not None: + monai_focal_weights = torch.tensor( + gdl_class_weights, dtype=torch.float32 + ) + self.gdl_focal_loss_calculator = GeneralizedDiceFocalLoss( + softmax=True, to_onehot_y=True, weight=monai_focal_weights + ) + else: + self.gdl_focal_weight = 0.0 + + # Convert class_weights list to a torch.Tensor + if elbo_class_weights is not None: + self.elbo_class_weights = torch.tensor( + elbo_class_weights, dtype=torch.float32 + ) + else: + self.elbo_class_weights = None + def forward(self, logits, y_true, prior_mu, prior_logvar, post_mu, post_logvar): """ - Computes the ELBO loss. + Computes the ELBO loss, with optional fractal dimension and topological regularization terms. Args: - logits (torch.Tensor): Output logits from the Probabilistic UNet (B, C, H, W). - y_true (torch.Tensor): Ground truth segmentation mask (B, 1, H, W or B, H, W). + logits (torch.Tensor): Output logits from the Probabilistic UNet + (B, C, H, W). + y_true (torch.Tensor): Ground truth segmentation mask (B, 1, H, W + or B, H, W). prior_mu (torch.Tensor): Mean of the prior distribution. prior_logvar (torch.Tensor): Log-variance of the prior distribution. post_mu (torch.Tensor): Mean of the approximate posterior distribution. - post_logvar (torch.Tensor): Log-variance of the approximate posterior distribution. 
+ post_logvar (torch.Tensor): Log-variance of the approximate posterior + distribution. Returns: torch.Tensor: The calculated ELBO loss. """ # Ensure y_true has correct dimensions (e.g., [B, H, W]) for cross_entropy if y_true.ndim == 4 and y_true.shape[1] == 1: - y_true = y_true.squeeze(1) # Squeeze channel dim to [B, H, W] + y_true_squeezed = y_true.squeeze(1) # Squeeze channel dim to [B, H, W] + else: + y_true_squeezed = y_true # Negative Cross-Entropy (Log-Likelihood) - # Using reduction='mean' to get a scalar loss per batch - log_likelihood = -F.cross_entropy(logits, y_true.long(), reduction="mean") + if ( + self.elbo_class_weights is not None + and self.elbo_class_weights.device != logits.device + ): + elbo_class_weights_on_device = self.elbo_class_weights.to(logits.device) + else: + elbo_class_weights_on_device = self.elbo_class_weights + + log_likelihood = -F.cross_entropy( + logits, + y_true_squeezed.long(), + reduction="mean", + weight=elbo_class_weights_on_device, + ) # KL-Divergence kl_div = self.kl_divergence_calculator( - post_mu, post_logvar, prior_mu, prior_logvar + post_mu, post_logvar, prior_mu, prior_logvar, self.kl_clamp ) - # ELBO = Log-Likelihood - beta * KL_Divergence - # We minimize the negative ELBO to maximize the ELBO elbo_loss = -(log_likelihood - self.beta * kl_div) - return elbo_loss + total_loss = elbo_loss + + if self.use_fractal_regularization: + y_pred_mask = F.softmax(logits, dim=1).argmax(dim=1, keepdim=True).float() + + if y_true_squeezed.ndim == 3: + y_true_for_fractal = y_true_squeezed.unsqueeze(1).float() + else: + y_true_for_fractal = y_true.float() + + fd_true = self.fractal_dimension_calculator(y_true_for_fractal) + fd_pred = self.fractal_dimension_calculator(y_pred_mask) + + fractal_loss = torch.mean(torch.abs(fd_true - fd_pred)) + total_loss += self.fractal_weight * fractal_loss + + if self.use_topological_regularization: + # y_true needs to be B, C, H, W or B, C, H, W, D for TI_Loss, where C=1 + # If y_true is B, H, W, 
unsqueeze to B, 1, H, W + if y_true_squeezed.ndim == 3: + y_true_for_topological = y_true_squeezed.unsqueeze(1).float() + else: + y_true_for_topological = ( + y_true.float() + ) # This should already be B, 1, H, W + + # logits are B, C, H, W (or B, C, H, W, D), which is what TI_Loss expects for x + topological_loss = self.topological_loss_calculator( + logits, y_true_for_topological + ) + total_loss += self.topological_weight * topological_loss + + if self.use_connectivity_regularization: + # y_pred_softmax: (B, C, H, W) + y_pred_softmax = F.softmax(logits, dim=1) + + # y_true_one_hot: Need to convert y_true_squeezed (B, H, W) to one-hot (B, C, H, W) + # Ensure the number of classes matches n_classes used in ELBOLoss + y_true_one_hot = ( + F.one_hot(y_true_squeezed.long(), num_classes=self.n_classes) + .permute(0, 3, 1, 2) + .float() + ) + + connectivity_loss = self.connectivity_coherence_calculator( + y_pred_softmax, y_true_one_hot + ) + total_loss += self.connectivity_weight * connectivity_loss + + if self.use_gdl_focal_regularization: + # logits: (B, C, H, W) + # y_true: (B, H, W) o (B, 1, H, W) + # GeneralizedDiceFocalLoss de MONAI puede manejar esto directamente + y_true_for_gdl_focal = y_true_squeezed.unsqueeze(1).long() + gdl_focal_loss = self.gdl_focal_loss_calculator( + logits, y_true_for_gdl_focal + ) + total_loss += self.gdl_focal_weight * gdl_focal_loss + + return total_loss diff --git a/mmv_im2im/utils/fractal_layers.py b/mmv_im2im/utils/fractal_layers.py new file mode 100644 index 0000000..611a680 --- /dev/null +++ b/mmv_im2im/utils/fractal_layers.py @@ -0,0 +1,235 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F +import torch +import torch.nn as nn +import torch.nn.functional as F +import numpy as np +from scipy.stats import linregress +import math + + +class Slice_windows(nn.Module): + """ + A PyTorch module to count windows with ones or calculate the average Shannon entropy + of windows in an input tensor. 
+ + Args: + num_kernels (int): The number of kernel sizes to use. Must be between 1 and 10. + The kernel sizes will be 2^1, 2^2, ..., 2^num_kernels. + mode (str): The operational mode of the module. + - "classic": Counts the number of windows that contain at least one '1'. + - "entropy": Calculates the average Shannon entropy for each window. + H = -p0 * log2(p0) - p1 * log2(p1), where p0 and p1 are + the probabilities of 0s and 1s in the window, respectively. + to_binary (bool): If True, the input tensor will be converted to a binary tensor + where all values > 0 become 1, and 0 remains 0. Defaults to False. + """ + + def __init__( + self, num_kernels: int, mode: str = "classic", to_binary: bool = False + ): + super().__init__() + if not (1 <= num_kernels <= 10): + raise ValueError( + "num_kernels must be between 1 and 10 to avoid excessively large kernels." + ) + self.num_kernels = num_kernels + # Kernel sizes are powers of 2, from 2^1 to 2^num_kernels + self.kernel_sizes = [2**i for i in range(1, num_kernels + 1)] + + if mode not in ["classic", "entropy"]: + raise ValueError(f"Mode must be 'classic' or 'entropy', but got '{mode}'.") + self.mode = mode + self.to_binary = to_binary + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """ + Processes the input tensor based on the selected mode. + + Args: + x (torch.Tensor): The input tensor, expected to be 4D (Batch, Channels, H, W). + + Returns: + torch.Tensor: A 2D tensor (Batch, num_kernels) containing the results. + If mode is "classic", it contains the count of windows with ones. + If mode is "entropy", it contains the average entropy of windows. + """ + if x.dim() != 4: + raise ValueError( + f"Input must be a 4D tensor (Batch, Channels, H, W), but got {x.dim()}D." 
+ ) + + # Convert to binary if to_binary is True + if self.to_binary: + x = ( + x > 0 + ).float() # Convert to float to maintain tensor type for subsequent operations + # Values > 0 become 1.0, 0 remains 0.0 + + batch_size, channels, H, W = x.shape + # Initialize results tensor to store output for each batch and kernel size + results = torch.zeros(batch_size, self.num_kernels, device=x.device) + + # Iterate through each defined kernel size + for i, kernel_size in enumerate(self.kernel_sizes): + # If kernel size is larger than image dimensions, skip and set result to 0 + if kernel_size > H or kernel_size > W: + results[:, i] = ( + 0.0 # Use 0.0 for consistency with entropy which can be float + ) + continue + + # Calculate the total number of elements in a single window + elements_in_window = kernel_size * kernel_size * channels + + # Process each item in the batch + for b in range(batch_size): + # Unfold the input tensor into overlapping windows. + # x[b:b+1] is used to maintain the batch dimension for F.unfold + unfolded_windows = F.unfold( + x[b : b + 1], + kernel_size=(kernel_size, kernel_size), + stride=(kernel_size, kernel_size), + ) + + # Transpose and squeeze to get shape (num_windows, elements_in_window) + # where num_windows is the total number of non-overlapping windows + unfolded_windows = unfolded_windows.transpose(1, 2).squeeze(0) + + if self.mode == "classic": + # Classic mode: Count windows containing at least one '1' + # Sum elements in each window, if sum > 0, it has at least one '1' + windows_with_ones = (unfolded_windows.sum(dim=1) > 0).float() + current_count = torch.sum(windows_with_ones).item() + results[b, i] = current_count + + elif self.mode == "entropy": + # Entropy mode: Calculate Shannon entropy for each window + entropies = [] + # Iterate over each window + for window in unfolded_windows: + # Count number of ones in the current window + num_ones = window.sum().item() + # Count number of zeros in the current window + num_zeros = 
elements_in_window - num_ones + + # Calculate probabilities + p1 = num_ones / elements_in_window + p0 = num_zeros / elements_in_window + + # Shannon entropy calculation + # H = -p0 * log2(p0) - p1 * log2(p1) + # Handle log2(0) case: if p is 0, p * log2(p) is considered 0 + entropy = 0.0 + if p0 > 0: + entropy -= p0 * math.log2(p0) + if p1 > 0: + entropy -= p1 * math.log2(p1) + entropies.append(entropy) + + # Calculate the average entropy for the current batch item and kernel size + if entropies: # Ensure there are entropies to average + results[b, i] = sum(entropies) / len(entropies) + else: + results[b, i] = 0.0 # No windows, so entropy is 0 + + return results + + +class FractalDimension(nn.Module): + def __init__( + self, num_kernels: int, mode: str = "classic", to_binary: bool = False + ): + """ + Initializes the layer to estimate a scaling exponent based on box characteristics. + + Args: + num_kernels (int): The number of kernels to use for the underlying box-counting/entropy layer. + mode (str): The operational mode for the underlying Slice_windows layer. + - "classic": The estimator calculates the traditional box-counting fractal dimension. + - "entropy": The estimator calculates a scaling exponent related to how + average Shannon entropy changes with box size. + to_binary (bool): If True, the input tensor will be converted to a binary tensor + where all values > 0 become 1, and 0 remains 0. This parameter + is passed directly to the Slice_windows layer. Defaults to False. 
+ """ + super().__init__() + # Pass the mode and to_binary directly to the Slice_windows layer + self.count_layer = Slice_windows( + num_kernels=num_kernels, mode=mode, to_binary=to_binary + ) + self.kernel_sizes = ( + self.count_layer.kernel_sizes + ) # Get kernel sizes from the sub-layer + + # Store the mode to clarify the output interpretation + self.mode = mode + self.to_binary = to_binary # Store it for potential external inspection, though not strictly used internally in forward + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """ + Calculates a scaling exponent for each image in the batch. + + Args: + x (torch.Tensor): The input image. If `to_binary` was set to True during initialization, + this input will be converted to binary internally by `Slice_windows`. + Expected format: (Batch_size, Channels, Height, Width) + + Returns: + torch.Tensor: A tensor of size (Batch_size,) where each element + is the estimated scaling exponent for the corresponding image. + If mode is "classic", this is the box-counting fractal dimension. + If mode is "entropy", this is an entropy-based scaling exponent. + Returns 0.0 if regression cannot be performed (e.g., not enough valid points). 
+ """ + # Get the counts/entropies from the underlying layer based on the selected mode + # The to_binary conversion happens inside self.count_layer if self.to_binary was True + results_per_kernel = self.count_layer(x) + + batch_size = x.shape[0] + scaling_exponents = torch.zeros( + batch_size, device=x.device, dtype=torch.float32 + ) + + # Calculate the scaling exponent for each image in the batch + for b in range(batch_size): + current_results = results_per_kernel[b].cpu().numpy() + + # Calculate inverse kernel sizes for the x-axis of the log-log plot + inverse_kernel_sizes = np.array([1 / k for k in self.kernel_sizes]) + + # Filter valid points for regression: only consider points where the result is greater than 0 + # This avoids issues with log(0) which would lead to -infinity + valid_indices = current_results > 0 + + if ( + np.sum(valid_indices) < 2 + ): # Need at least 2 points for a linear regression + scaling_exponents[b] = 0.0 # Cannot estimate scaling exponent + continue + + # Take the logarithm of the results and inverse kernel sizes + if self.mode == "classic": + log_results = np.log(current_results[valid_indices]) + log_inverse_kernel_sizes = np.log(inverse_kernel_sizes[valid_indices]) + if self.mode == "entropy": + log_results = current_results[valid_indices] + log_inverse_kernel_sizes = np.log(inverse_kernel_sizes[valid_indices]) + + try: + # Perform linear regression: the slope is the scaling exponent + slope, intercept, r_value, p_value, std_err = linregress( + log_inverse_kernel_sizes, log_results + ) + scaling_exponents[b] = torch.tensor( + slope, dtype=torch.float32, device=x.device + ) + except ValueError: + # This can happen if there are issues with the input to linregress (e.g., all same values) + scaling_exponents[b] = 0.0 + except Exception as e: + print(f"Error during linregress for batch {b}: {e}") + scaling_exponents[b] = 0.0 + + return scaling_exponents diff --git a/mmv_im2im/utils/topological_loss.py 
"""
This code is a direct implementation of the work described in the paper

@inproceedings{gupta2022learning,
    title={Learning Topological Interactions for Multi-Class Medical Image Segmentation},
    author={Gupta, Saumya and Hu, Xiaoling and Kaan, James and Jin, Michael and Mpoy, Mutshipay and Chung, Katherine and Singh, Gagandeep and Saltz, Mary and Kurc, Tahsin and Saltz, Joel and others},
    booktitle={Computer Vision--ECCV 2022: 17th European Conference, Tel Aviv, Israel, October 23--27, 2022, Proceedings, Part XXIX},
    pages={701--718},
    year={2022},
    organization={Springer}
}

https://www.ecva.net/papers/eccv_2022/papers_ECCV/papers/136890691.pdf

and provided by the github repo: https://github.com/TopoXLab/TopoInteraction
"""

import numpy as np
import torch


class TI_Loss(torch.nn.Module):
    """Topological interaction (TI) loss.

    The TI module encodes topological interactions between classes by
    computing a critical-voxels map: the locations that violate the declared
    inclusion/exclusion relations. The loss is the per-voxel cross-entropy
    restricted to those critical voxels.
    """

    def __init__(self, dim, connectivity, inclusion, exclusion, min_thick=1):
        """
        :param dim: 2 if 2D; 3 if 3D
        :param connectivity: 4 or 8 for 2D; 6 or 26 for 3D
        :param inclusion: list of [A, B] pairs where class A must be
            completely surrounded by class B
        :param exclusion: list of [A, C] pairs where classes A and C must
            not touch each other
        :param min_thick: minimum thickness/separation between the two
            classes; only used when connectivity is 8 (2D) or 26 (3D)
        :raises ValueError: if dim or connectivity is unsupported
        """
        super().__init__()

        self.dim = dim
        self.connectivity = connectivity
        self.min_thick = min_thick
        self.apply_nonlin = lambda x: torch.nn.functional.softmax(x, 1)
        # per-voxel CE (reduction="none") so it can be masked by the
        # critical-voxels map in forward()
        self.ce_loss_func = torch.nn.CrossEntropyLoss(reduction="none")

        if self.dim == 2:
            self.sum_dim_list = [1, 2, 3]
            self.conv_op = torch.nn.functional.conv2d
        elif self.dim == 3:
            self.sum_dim_list = [1, 2, 3, 4]
            self.conv_op = torch.nn.functional.conv3d
        else:
            # fail early instead of leaving conv_op unset (NameError later)
            raise ValueError(f"dim must be 2 or 3, got {dim}")

        self.set_kernel()

        # each entry is [is_inclusion, label_A, label_B_or_C]
        self.interaction_list = [[True, inc[0], inc[1]] for inc in inclusion]
        self.interaction_list += [[False, exc[0], exc[1]] for exc in exclusion]

    def set_kernel(self):
        """Build the connectivity kernel from the user's specification of
        dim, connectivity and min_thick."""
        k = 2 * self.min_thick + 1
        np_kernel = None
        if self.dim == 2:
            if self.connectivity == 4:
                np_kernel = np.array([[0, 1, 0], [1, 1, 1], [0, 1, 0]])
            elif self.connectivity == 8:
                np_kernel = np.ones((k, k))
        elif self.dim == 3:
            if self.connectivity == 6:
                np_kernel = np.array(
                    [
                        [[0, 0, 0], [0, 1, 0], [0, 0, 0]],
                        [[0, 1, 0], [1, 1, 1], [0, 1, 0]],
                        [[0, 0, 0], [0, 1, 0], [0, 0, 0]],
                    ]
                )
            elif self.connectivity == 26:
                np_kernel = np.ones((k, k, k))
        if np_kernel is None:
            # previously this fell through to an opaque NameError
            raise ValueError(
                f"unsupported connectivity {self.connectivity} for dim {self.dim}"
            )
        # add (out_channels, in_channels) singleton dims for conv
        self.kernel = torch.from_numpy(np_kernel[np.newaxis, np.newaxis])

    def topological_interaction_module(self, P):
        """
        Given a discrete segmentation map and the intended topological
        interactions, compute the critical voxels map.

        :param P: discrete segmentation map, shape (b, 1, x, y[, z])
        :return: critical voxels map (double tensor, same shape as P)
        """
        critical_voxels_map = torch.zeros_like(P, dtype=torch.double)
        kernel = self.kernel.double()

        for interaction_type, label_A, label_C in self.interaction_list:
            mask_A = torch.where(P == label_A, 1.0, 0.0).double()
            if interaction_type:
                # inclusion: the "C side" is everything that is neither A
                # nor its enclosing class B
                mask_C = torch.where(P == label_C, 1.0, 0.0).double()
                mask_C = torch.logical_or(mask_C, mask_A).double()
                mask_C = torch.logical_not(mask_C).double()
            else:
                mask_C = torch.where(P == label_C, 1.0, 0.0).double()

            # dilate each mask by the connectivity kernel to get its
            # neighbourhood
            neighbourhood_C = self.conv_op(mask_C, kernel, padding="same")
            neighbourhood_C = torch.where(neighbourhood_C >= 1.0, 1.0, 0.0)
            neighbourhood_A = self.conv_op(mask_A, kernel, padding="same")
            neighbourhood_A = torch.where(neighbourhood_A >= 1.0, 1.0, 0.0)

            # voxels of one class lying inside the other's neighbourhood
            violating = neighbourhood_C * mask_A + neighbourhood_A * mask_C
            violating = torch.where(violating >= 1.0, 1.0, 0.0)

            critical_voxels_map = torch.logical_or(
                critical_voxels_map, violating
            ).double()

        return critical_voxels_map

    def forward(self, x, y):
        """
        Compute the TI loss value.

        :param x: likelihood map of shape (b, c, x, y[, z]) with
            c = total number of classes
        :param y: GT of shape (b, 1, x, y[, z]); values in [0, L) where L
            is the total number of classes
        :return: TI loss value (scalar tensor)
        """
        if x.device.type == "cuda":
            # keep the kernel on the same device as the input
            self.kernel = self.kernel.cuda(x.device.index)

        # obtain the discrete segmentation map
        x_softmax = self.apply_nonlin(x)
        P = torch.argmax(x_softmax, dim=1)
        P = torch.unsqueeze(P.double(), dim=1)
        del x_softmax

        critical_voxels_map = self.topological_interaction_module(P)

        # cross-entropy masked by the critical voxels
        ce_tensor = torch.unsqueeze(
            self.ce_loss_func(x.double(), y[:, 0].long()), dim=1
        )
        ce_tensor[:, 0] = ce_tensor[:, 0] * torch.squeeze(critical_voxels_map, dim=1)
        return ce_tensor.sum(dim=self.sum_dim_list).mean()
+ + Return: + a binary image after hole filling + """ + bw = bw > 0 + if len(bw.shape) == 2: + background_lab = label(~bw, connectivity=1) + fill_out = np.copy(background_lab) + component_sizes = np.bincount(background_lab.ravel()) + too_big = component_sizes > hole_max + too_big_mask = too_big[background_lab] + fill_out[too_big_mask] = 0 + too_small = component_sizes < hole_min + too_small_mask = too_small[background_lab] + fill_out[too_small_mask] = 0 + elif len(bw.shape) == 3: + if fill_2d: + fill_out = np.zeros_like(bw) + for zz in range(bw.shape[0]): + background_lab = label(~bw[zz, :, :], connectivity=1) + out = np.copy(background_lab) + component_sizes = np.bincount(background_lab.ravel()) + too_big = component_sizes > hole_max + too_big_mask = too_big[background_lab] + out[too_big_mask] = 0 + too_small = component_sizes < hole_min + too_small_mask = too_small[background_lab] + out[too_small_mask] = 0 + fill_out[zz, :, :] = out + else: + background_lab = label(~bw, connectivity=1) + fill_out = np.copy(background_lab) + component_sizes = np.bincount(background_lab.ravel()) + too_big = component_sizes > hole_max + too_big_mask = too_big[background_lab] + fill_out[too_big_mask] = 0 + too_small = component_sizes < hole_min + too_small_mask = too_small[background_lab] + fill_out[too_small_mask] = 0 + else: + print("error in image shape") + return + + return np.logical_or(bw, fill_out) + + +def size_filter( + img: np.ndarray, min_size: int, method: str = "3D", connectivity: int = 1 +): + """size filter + + Parameters: + ------------ + img: np.ndarray + the image to filter on + min_size: int + the minimum size to keep + method: str + either "3D" or "slice_by_slice", default is "3D" + connnectivity: int + the connectivity to use when computing object size + """ + assert len(img.shape) == 3, "image has to be 3D" + if method == "3D": + return remove_small_objects( + img > 0, min_size=min_size, connectivity=connectivity + ) + elif method == "slice_by_slice": + seg = 
np.zeros(img.shape, dtype=bool) + for zz in range(img.shape[0]): + seg[zz, :, :] = remove_small_objects( + img[zz, :, :] > 0, + min_size=min_size, + connectivity=connectivity, + ) + return seg + else: + raise NotImplementedError(f"unsupported method {method}") + + +def topology_preserving_thinning( + bw: np.ndarray, min_thickness: int = 1, thin: int = 1 +) -> np.ndarray: + """perform thinning on segmentation without breaking topology + + Parameters: + -------------- + bw: np.ndarray + the 3D binary image to be thinned + min_thickness: int + Half of the minimum width you want to keep from being thinned. + For example, when the object width is smaller than 4, you don't + want to make this part even thinner (may break the thin object + and alter the topology), you can set this value as 2. + thin: int + the amount to thin (has to be an positive integer). The number of + pixels to be removed from outter boundary towards center. + + Return: + ------------- + A binary image after thinning + """ + bw = bw > 0 + safe_zone = np.zeros_like(bw) + for zz in range(bw.shape[0]): + if np.any(bw[zz, :, :]): + ctl = medial_axis(bw[zz, :, :] > 0) + dist = distance_transform_edt(ctl == 0) + safe_zone[zz, :, :] = dist > min_thickness + 1e-5 + + rm_candidate = np.logical_xor(bw > 0, erosion(bw > 0, ball(thin))) + + bw[np.logical_and(safe_zone, rm_candidate)] = 0 + + return bw + + +def divide_nonzero(array1, array2): + """ + Divides two arrays. Returns zero when dividing by zero. 
+ """ + denominator = np.copy(array2) + denominator[denominator == 0] = 1e-10 + return np.divide(array1, denominator) + + +def histogram_otsu(hist): + """Apply Otsu thresholding method on 1D histogram""" + + # modify the elements in hist to avoid completely zero value in cumsum + hist = hist + 1e-5 + + bin_size = 1 / (len(hist) - 1) + bin_centers = np.arange(0, 1 + 0.5 * bin_size, bin_size) + hist = hist.astype(float) + + # class probabilities for all possible thresholds + weight1 = np.cumsum(hist) + weight2 = np.cumsum(hist[::-1])[::-1] + # class means for all possible thresholds + + mean1 = np.cumsum(hist * bin_centers) / weight1 + mean2 = (np.cumsum((hist * bin_centers)[::-1]) / weight2[::-1])[::-1] + + # Clip ends to align class 1 and class 2 variables: + # The last value of `weight1`/`mean1` should pair with zero values in + # `weight2`/`mean2`, which do not exist. + variance12 = weight1[:-1] * weight2[1:] * (mean1[:-1] - mean2[1:]) ** 2 + + idx = np.argmax(variance12) + threshold = bin_centers[:-1][idx] + return threshold + + +def absolute_eigenvaluesh(nd_array): + """Computes the eigenvalues sorted by absolute value from the symmetrical matrix. + + Parameters: + ------------- + nd_array: nd.ndarray + array from which the eigenvalues will be calculated. + + Return: + ------------- + A list with the eigenvalues sorted in absolute ascending order (e.g. 
def absolute_eigenvaluesh(nd_array):
    """Compute the eigenvalues of a symmetric matrix (or stack of symmetric
    matrices), ordered by absolute value.

    Parameters:
    -------------
    nd_array: np.ndarray
        array from which the eigenvalues will be calculated

    Return:
    -------------
    A list with the eigenvalues sorted in absolute ascending order
    (e.g. [eigenvalue1, eigenvalue2, ...])
    """
    ordered = sortbyabs(np.linalg.eigvalsh(nd_array), axis=-1)
    return [
        np.squeeze(component, axis=-1)
        for component in np.split(ordered, ordered.shape[-1], axis=-1)
    ]


def sortbyabs(a: np.ndarray, axis=0):
    """Sort an array along the given axis by absolute value.

    Equivalent to the classic np.ix_ fancy-indexing trick
    (http://stackoverflow.com/a/11253931/4067734) but expressed with
    np.take_along_axis.
    """
    order = np.abs(a).argsort(axis)
    return np.take_along_axis(a, order, axis=axis)
def get_middle_frame(struct_img: np.ndarray, method: str = "z") -> int:
    """find the middle z frame of an image stack

    Parameters:
    ------------
    struct_img: np.ndarray
        the 3D image to process
    method: str
        which method to use to determine the middle frame. Options are
        "z" or "intensity". "z" is solely based on the number of z frames.
        "intensity" uses an Otsu threshold to estimate the volume of
        foreground signal per z frame; the resulting z-profile is then
        thresholded with Otsu again to find the best separating frame
        (assuming two peaks along z, near the bottom and top of the cells,
        so the optimal separation is the middle of the stack).

    Return:
    -----------
    mid_frame: int
        the z index of the middle z frame

    Raises:
        ValueError: if method is not "z" or "intensity"
    """
    if method == "z":
        return struct_img.shape[0] // 2

    if method == "intensity":
        # lazy import: only this branch needs skimage
        from skimage.filters import threshold_otsu

        bw = struct_img > threshold_otsu(struct_img)
        z_profile = np.array(
            [np.count_nonzero(bw[zz, :, :]) for zz in range(bw.shape[0])]
        )
        # round() on a float already returns an int in Python 3; the old
        # isinstance/astype dance was dead code, and histogram_otsu was
        # computed twice
        return int(round(histogram_otsu(z_profile) * bw.shape[0]))

    # the original printed a message and called quit() here
    raise ValueError(f"unsupported method: {method}")
+ + """ + from skimage.morphology import remove_small_objects + + out = remove_small_objects(bw > 0, hole_min) + + out1 = label(out) + stat = regionprops(out1) + + # build the seed + seed = np.zeros(stack_shape) + seed_count = 0 + if bg_seed: + seed[0, :, :] = 1 + seed_count += 1 + + for idx in range(len(stat)): + py, px = np.round(stat[idx].centroid) + seed_count += 1 + seed[mid_frame, int(py), int(px)] = seed_count + + return seed + + +def remove_hot_pixel(seg: np.ndarray) -> np.ndarray: + """ + remove hot pixel from segmentation + """ + + assert len(seg.shape) == 3, "input segmentation must be 3D" + + # make sure the segmentation is 0/1 + seg = seg.astype(np.uint8) + seg[seg > 0] = 1 + + # get sum projection along z + seg_proj = np.sum(seg, axis=0) + + # find hot pixels + hot_pixel = seg_proj >= seg.shape[0] - 2 + + # dilate the area to cover the surrounding pixels + hot_pixel = dilation(hot_pixel, disk(2)) + + # clean up every z + for z in range(seg.shape[0]): + seg_z = seg[z, :, :] + seg_z[hot_pixel] = 0 + seg[z, :, :] = seg_z + + return seg + + +def get_seed_for_objects( + raw: np.ndarray, + bw: np.ndarray, + area_min: int = 1, + area_max: int = 10000, + bg_seed: bool = True, +) -> np.ndarray: + """ + build a seed image for an image of 3D objects (assuming roughly convex shape + in 3D) using the information in the middle slice + + Parameters: + ------------ + raw: np.ndarray + orignal image used to determine middle slice + bw: np.ndarray + a round 3D segmentation, expecting the segmentation in the middle slice + having relatively good quality + area_min: int + estimated minimal size on one single slice (major body chunk, e.g. the + center XY plane of a 3D ball) of an object + area_max: int + estimated maximal size on one single slice (major body chunk, e.g. the + center XY plane of a 3D ball) of an object. 
def get_seed_for_objects(
    raw: np.ndarray,
    bw: np.ndarray,
    area_min: int = 1,
    area_max: int = 10000,
    bg_seed: bool = True,
) -> np.ndarray:
    """
    build a seed image for an image of 3D objects (assuming roughly convex
    shape in 3D) using the information in the middle slice

    Parameters:
    ------------
    raw: np.ndarray
        original image used to determine the middle slice
    bw: np.ndarray
        a rough 3D segmentation, expecting the segmentation in the middle
        slice to have relatively good quality
    area_min: int
        estimated minimal size on one single slice (major body chunk, e.g.
        the center XY plane of a 3D ball) of an object
    area_max: int
        estimated maximal size on one single slice (major body chunk, e.g.
        the center XY plane of a 3D ball) of an object. It is recommended
        to be conservative (setting this value a little larger)
    bg_seed: bool
        bg_seed=True will add a background seed at the first frame (z=0)
    """
    from skimage.morphology import remove_small_objects

    # determine middle slice
    mid_z = get_middle_frame(raw, method="intensity")

    # take the segmentation of the middle slice
    bw2d = bw[mid_z, :, :]

    # fill in holes to form solid objects
    bw2d_fill = hole_filling(bw2d, area_min, area_max)

    # prune the objects in the middle slice
    out = remove_small_objects(bw2d_fill > 0, area_min)

    # extract objects and calculate centroids
    stat = regionprops(label(out))

    # use each centroid as one seed
    seed = np.zeros(raw.shape)
    seed_count = 0
    if bg_seed:
        seed[0, :, :] = 1
        seed_count += 1

    for region in stat:
        py, px = np.round(region.centroid)
        seed_count += 1
        seed[mid_z, int(py), int(px)] = seed_count

    return seed.astype(int)


def segmentation_union(seg: List) -> np.ndarray:
    """merge multiple segmentations into a single result

    Parameters
    ------------
    seg: List
        a list of segmentations, should all have the same shape
    """
    return np.logical_or.reduce(seg)


def segmentation_intersection(seg: List) -> np.ndarray:
    """get the intersection of multiple segmentations into a single result

    Parameters
    ------------
    seg: List
        a list of segmentations, should all have the same shape
    """
    return np.logical_and.reduce(seg)


def segmentation_xor(seg: List) -> np.ndarray:
    """get the XOR of multiple segmentations into a single result

    Parameters
    ------------
    seg: List
        a list of segmentations, should all have the same shape
    """
    return np.logical_xor.reduce(seg)


def remove_index_object(
    label, id_to_remove=None, in_place: bool = False
) -> np.ndarray:
    """Zero out the objects with the given label ids.

    Parameters
    ------------
    label: np.ndarray
        a labeled image (note: this parameter name shadows skimage's
        `label` function within this function body)
    id_to_remove: List[int]
        label ids to remove; defaults to [1] (was a mutable default
        argument before — replaced with a None sentinel)
    in_place: bool
        if True, modify `label` directly; otherwise work on a copy
    """
    if id_to_remove is None:
        id_to_remove = [1]

    img = label if in_place else label.copy()

    for obj_id in id_to_remove:
        img[img == obj_id] = 0

    return img
def peak_local_max_wrapper(
    struct_img_for_peak: np.ndarray, bw: np.ndarray
) -> np.ndarray:
    """Mark local intensity maxima of struct_img_for_peak (restricted to
    the labeled regions of bw) in an image of the same shape."""
    from skimage.feature import peak_local_max

    coordinates = peak_local_max(
        struct_img_for_peak, labels=label(bw), min_distance=2
    )
    marker_image = np.zeros_like(struct_img_for_peak)
    marker_image[tuple(coordinates.T)] = True
    return marker_image


def watershed_wrapper(bw: np.ndarray, local_maxi: np.ndarray) -> np.ndarray:
    """Watershed the mask bw from the given local-maximum markers, using
    the negated distance transform as the topographic surface."""
    from scipy.ndimage import distance_transform_edt
    from skimage.measure import label
    from skimage.morphology import ball, dilation
    from skimage.segmentation import watershed

    distance_map = distance_transform_edt(bw)
    markers = label(dilation(local_maxi, footprint=ball(1)))
    return watershed(
        -distance_map,
        markers,
        mask=bw,
        watershed_line=True,
    )
def prune_z_slices(bw: np.ndarray, min_count: int = 100):
    """
    prune the segmentation by only keeping a certain range of z-slices,
    with the assumption that all signals live in a few consecutive
    z-slices. This function first determines the key z-slice where most
    of the signal lives and then includes the slices up/down along z
    until a slice falls below `min_count` foreground pixels. This is
    useful when you have prior knowledge about your segmentation target
    and can effectively exclude small segmented objects due to
    noise/artifacts in z-slices where the signal cannot live.

    Parameters:
    -----------
    bw: np.ndarray
        the segmentation before pruning
    min_count: int
        a slice with fewer foreground pixels than this is treated as the
        boundary of the signal range (was a hard-coded 100)
    """
    # per-slice foreground counts; int64 — the previous uint16 buffer
    # overflowed for slices with more than 65535 foreground pixels
    bw_z = np.count_nonzero(bw > 0, axis=(1, 2))

    mid_z = int(np.argmax(bw_z))
    low_z = 0
    # NOTE(review): this default drops the very top slice when no boundary
    # is found above mid_z — confirm that asymmetry with low_z is intended
    high_z = bw.shape[0] - 2
    for ii in range(mid_z - 1, 0, -1):
        if bw_z[ii] < min_count:
            low_z = ii
            break
    for ii in range(mid_z + 1, bw.shape[0] - 1):
        if bw_z[ii] < min_count:
            high_z = ii
            break

    seg = bw.copy()
    seg[:low_z, :, :] = 0
    seg[high_z + 1 :, :, :] = 0

    return seg


def cell_local_adaptive_threshold(
    structure_img_smooth: np.ndarray, cell_wise_min_area: int
):
    """Two-stage cell-wise thresholding: a permissive global triangle
    threshold finds candidate cells, then each candidate is re-thresholded
    with its own local Otsu value."""
    from skimage.filters import threshold_triangle, threshold_otsu
    from skimage.morphology import dilation

    # cell-wise local adaptive thresholding
    th_low_level = threshold_triangle(structure_img_smooth)

    bw_low_level = structure_img_smooth > th_low_level
    bw_low_level = remove_small_objects(
        bw_low_level, min_size=cell_wise_min_area, connectivity=1, out=bw_low_level
    )
    bw_low_level = dilation(bw_low_level, footprint=ball(2))

    bw_high_level = np.zeros_like(bw_low_level)
    lab_low, num_obj = label(bw_low_level, return_num=True, connectivity=1)

    for idx in range(num_obj):
        single_obj = lab_low == (idx + 1)
        local_otsu = threshold_otsu(structure_img_smooth[single_obj > 0])
        # 0.98 slightly relaxes the per-object Otsu threshold — presumably
        # tuned empirically; confirm before changing
        bw_high_level[
            np.logical_and(structure_img_smooth > local_otsu * 0.98, single_obj)
        ] = 1
    return bw_high_level


def invert_mask(img):
    """Return the complement of a 0/1 mask."""
    return 1 - img


def mask_image(image, mask, value: int = 0):
    """Set image to `value` where mask is True. NOTE: modifies and returns
    the input `image` in place."""
    image[mask] = value
    return image
"protobuf<4.21.0", + "protobuf", "pyrallis", "scikit-learn", "tensorboard", "numba", - "numpy<2", - "pydantic==2.11.7", + "numpy", + "pydantic", "fastapi", "uvicorn", - "botocore==1.38.38", - "bioio-ome-tiff==1.1.0", - "bioio-ome-zarr==1.2.0", + "botocore", + "bioio-ome-tiff", + "bioio-ome-zarr", "pydantic-zarr", - "bioio-tifffile==1.1.0", - "bioio-lif==1.1.0", + "bioio-tifffile", + "bioio-lif", "ngff-zarr", "tifffile", "ome-types", @@ -56,7 +56,7 @@ advanced = [ paper = [ "quilt3", "pooch", - "matplotlib", + "matplotlib>=3.10.5", "notebook", ] @@ -72,7 +72,8 @@ dev = [ "pytest>=5.4.3", "pytest-cov>=2.9.0", "pytest-raises>=0.11", - "numpy<2", + "scipy>=1.10.0", + "numpy", "bump2version>=1.0.1", "coverage>=5.1", "ipython>=7.15.0", @@ -92,13 +93,13 @@ test = [ "pytest>=5.4.3", "pytest-cov>=2.9.0", "pytest-raises>=0.11", - "numpy<2", + "numpy", ] data_requirements = [ "quilt3", "pooch", - "matplotlib", + "matplotlib>=3.10.5", "notebook" ] diff --git a/script/generate_synthetic_data.py b/script/generate_synthetic_data.py index ead4ef4..417f2ac 100644 --- a/script/generate_synthetic_data.py +++ b/script/generate_synthetic_data.py @@ -143,7 +143,7 @@ def generate_data(args): px = randint(15, im.shape[-0] - 15) im[py, px] = 1 im = dilation(im > 0, disk(5)).astype(np.float32) - imsave(out_raw_fn, im) + imsave(out_raw_fn, im) raw = gaussian_filter(im, 5) raw = random_noise(raw).astype(np.float32) # raw = random_noise(raw, mode="salt").astype(np.float32) diff --git a/script/pull_labelfree_sample_data.py b/script/pull_labelfree_sample_data.py index cb4ea5f..a21f576 100644 --- a/script/pull_labelfree_sample_data.py +++ b/script/pull_labelfree_sample_data.py @@ -92,7 +92,7 @@ def show_info(self): log.debug("Command Line:") log.debug("\t{}".format(" ".join(sys.argv))) log.debug("Args:") - for (k, v) in self.__dict__.items(): + for k, v in self.__dict__.items(): log.debug("\t{}: {}".format(k, v)) @@ -139,7 +139,7 @@ def execute(self, args): holdout_path = holdout_path_base / 
Path(cline) holdout_path.mkdir(exist_ok=True) - # download all FOVs or a certain + # download all FOVs or a certain if num_samples_per_cell_line > 0: num = num_samples_per_cell_line else: @@ -162,9 +162,7 @@ def execute(self, args): bf_img = reader.get_image_data( "ZYX", C=row.ChannelNumberBrightfield, S=0, T=0 ) - str_img = reader.get_image_data( - "ZYX", C=row.ChannelNumberStruct, S=0, T=0 - ) + str_img = reader.get_image_data("ZYX", C=row.ChannelNumberStruct, S=0, T=0) if random.random() < 0.2: data_path = holdout_path @@ -197,4 +195,4 @@ def main(): if __name__ == "__main__": - main() \ No newline at end of file + main() diff --git a/tutorials/example_by_use_case.md b/tutorials/example_by_use_case.md index bb06d5d..594a423 100644 --- a/tutorials/example_by_use_case.md +++ b/tutorials/example_by_use_case.md @@ -5,7 +5,7 @@ This page lists the jupyter notebooks (for downloading the data from public reso | Application | data preparation | training config | inference config | | :---: | :---: | :---: | :---: | | 3D labelfree prediction | [notebook](../paper_configs/prepare_data/labelfree_3d.ipynb) | [FCN](../paper_configs/labelfree_3d_FCN_train.yaml), [pix2pix_from_scratch](../paper_configs/labelfree_3d_pix2pix_train.yaml), [pix2pix_transfer_learning](../paper_configs/labelfree_3d_pix2pix_finetune.yaml) | [FCN](../paper_configs/labelfree_3d_FCN_inference.yaml), [ix2pix](../paper_configs/labelfree_3d_pix2pix_inference.yaml)| -| 2D labelfree prediction | [notebook](../paper_configs/prepare_data/labelfree_2d.ipynb) | [FCN](../paper_configs/labelfree_2d_FCN_train.yaml | [FCN](../paper_configs/labelfree_2d_FCN_inference.yaml) | +| 2D labelfree prediction | [notebook](../paper_configs/prepare_data/labelfree_2d.ipynb) | [FCN](../paper_configs/labelfree_2d_FCN_train.yaml) | [FCN](../paper_configs/labelfree_2d_FCN_inference.yaml) | | 2D semantic segmentation | [notebook](../paper_configs/prepare_data/semantic_seg_2d.ipynb) | [FCN](../paper_configs/semantic_seg_2d_train.yaml) 
| [FCN](../paper_configs/semantic_seg_2d_inference.yaml) | | 3D semantic segmentation | [notebook](../paper_configs/prepare_data/semantic_seg_3d.ipynb) | [FCN](../paper_configs/semantic_seg_3d_train.yaml) | [FCN](../paper_configs/semantic_seg_3d_inference.yaml) | | 2D instance segmentation | [notebook](../paper_configs/prepare_data/instance_seg_2d.ipynb) | [EmbedSeg](../paper_configs/instance_seg_2d_train.yaml) | [EmbedSeg](../paper_configs/instance_seg_2d_inference.yaml) |