diff --git a/aicsmlsegment/Model.py b/aicsmlsegment/Model.py
index b044f40..fe8807f 100644
--- a/aicsmlsegment/Model.py
+++ b/aicsmlsegment/Model.py
@@ -12,7 +12,7 @@
     undo_resize,
     UniversalDataset,
 )
-from aicsmlsegment.utils import compute_iou
+from aicsmlsegment.utils import compute_iou, save_image
 import numpy as np
 from skimage.io import imsave
@@ -118,6 +118,9 @@ def __init__(self, config, model_config, train):
         self.args_inference["inference_batch_size"] = config["batch_size"]
         self.args_inference["mode"] = config["mode"]["name"]
         self.args_inference["Threshold"] = config["Threshold"]
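+        # uncertainty estimation method from the prediction config:
+        # None or one of "softmax", "variance", "entropy", "mutual_information"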
"_uncertainty.tiff", + uncertaintymap, + [self.uncertainty], ) diff --git a/aicsmlsegment/NetworkArchitecture/unet_xy_zoom_0pad.py b/aicsmlsegment/NetworkArchitecture/unet_xy_zoom_0pad.py index 2adb776..b5e09f2 100644 --- a/aicsmlsegment/NetworkArchitecture/unet_xy_zoom_0pad.py +++ b/aicsmlsegment/NetworkArchitecture/unet_xy_zoom_0pad.py @@ -4,11 +4,18 @@ class UNet3D(nn.Module): def __init__( - self, in_channel, n_classes, down_ratio, test_mode=True, batchnorm_flag=True + self, + in_channel, + n_classes, + down_ratio, + test_mode=True, + batchnorm_flag=True, + dropout=0, ): self.in_channel = in_channel self.n_classes = n_classes self.test_mode = test_mode + self.dropout = dropout super(UNet3D, self).__init__() k = down_ratio @@ -129,6 +136,7 @@ def encoder( padding=0, bias=True, batchnorm=False, + dropout=0, ): if batchnorm: layer = nn.Sequential( @@ -142,6 +150,7 @@ def encoder( ), nn.BatchNorm3d(out_channels, affine=False), nn.ReLU(), + nn.Dropout3d(p=dropout), nn.Conv3d( out_channels, 2 * out_channels, @@ -152,6 +161,7 @@ def encoder( ), nn.BatchNorm3d(2 * out_channels, affine=False), nn.ReLU(), + nn.Dropout3d(p=dropout), ) else: layer = nn.Sequential( @@ -164,6 +174,7 @@ def encoder( bias=bias, ), nn.ReLU(), + nn.Dropout3d(p=dropout), nn.Conv3d( out_channels, 2 * out_channels, @@ -173,6 +184,7 @@ def encoder( bias=bias, ), nn.ReLU(), + nn.Dropout3d(p=dropout), ) return layer @@ -185,6 +197,7 @@ def decoder( padding=0, bias=True, batchnorm=False, + dropout=0, ): if batchnorm: layer = nn.Sequential( @@ -198,6 +211,7 @@ def decoder( ), nn.BatchNorm3d(out_channels, affine=False), nn.ReLU(), + nn.Dropout3d(p=dropout), nn.Conv3d( out_channels, out_channels, @@ -208,6 +222,7 @@ def decoder( ), nn.BatchNorm3d(out_channels, affine=False), nn.ReLU(), + nn.Dropout3d(p=dropout), ) else: layer = nn.Sequential( @@ -220,6 +235,7 @@ def decoder( bias=bias, ), nn.ReLU(), + nn.Dropout3d(p=dropout), nn.Conv3d( out_channels, out_channels, @@ -229,6 +245,7 @@ def decoder( bias=bias, ), nn.ReLU(), + nn.Dropout3d(p=dropout), ) return layer diff --git a/aicsmlsegment/model_utils.py b/aicsmlsegment/model_utils.py index 8802bac..45c8c07 100644 --- a/aicsmlsegment/model_utils.py +++ b/aicsmlsegment/model_utils.py @@ -3,6 +3,7 @@ from pathlib import Path, PurePosixPath from aicsmlsegment.multichannel_sliding_window import sliding_window_inference from aicsmlsegment.fnet_prediction_torch import predict_piecewise +from typing import List def flip(img: np.ndarray, axis: int) -> torch.Tensor: @@ -32,6 +33,7 @@ def apply_on_image( softmax: bool, model_name, extract_output_ch: bool, + uncertainty: str = None, ) -> np.ndarray: """ Highest level API to perform inference on an input image through a model with @@ -65,6 +67,7 @@ def apply_on_image( to_numpy=to_numpy, extract_output_ch=extract_output_ch, softmax=softmax, + uncertainty=uncertainty, ) else: out0, vae_loss = model_inference( @@ -75,6 +78,8 @@ def apply_on_image( to_numpy=False, softmax=softmax, model_name=model_name, + extract_output_ch=extract_output_ch, + uncertainty=uncertainty, ) input_img = input_img[0] # remove batch_dimension for flip for i in range(3): @@ -88,6 +93,7 @@ def apply_on_image( softmax=softmax, model_name=model_name, extract_output_ch=extract_output_ch, + uncertainty=uncertainty, ) aug_flip = flip(out, axis=i) out0 += aug_flip @@ -136,6 +142,39 @@ def get_supported_model_names(): print(all_names + flist) +def calculate_uncertainty(output_img_list: List, uncertainty_type: str) -> np.ndarray: + # dimension of the output_img_list would be 
+        self.dropout = dropout
         super(UNet3D, self).__init__()
 
         k = down_ratio
@@ -129,6 +137,7 @@ def encoder(
         padding=0,
         bias=True,
         batchnorm=False,
+        dropout=0,
     ):
         if batchnorm:
             layer = nn.Sequential(
@@ -142,6 +151,7 @@ def encoder(
                 ),
                 nn.BatchNorm3d(out_channels, affine=False),
                 nn.ReLU(),
+                nn.Dropout3d(p=dropout),
                 nn.Conv3d(
                     out_channels,
                     2 * out_channels,
@@ -152,6 +162,7 @@ def encoder(
                 ),
                 nn.BatchNorm3d(2 * out_channels, affine=False),
                 nn.ReLU(),
+                nn.Dropout3d(p=dropout),
             )
         else:
             layer = nn.Sequential(
@@ -164,6 +175,7 @@ def encoder(
                     bias=bias,
                 ),
                 nn.ReLU(),
+                nn.Dropout3d(p=dropout),
                 nn.Conv3d(
                     out_channels,
                     2 * out_channels,
@@ -173,6 +185,7 @@ def encoder(
                     bias=bias,
                 ),
                 nn.ReLU(),
+                nn.Dropout3d(p=dropout),
             )
 
         return layer
@@ -185,6 +198,7 @@ def decoder(
         padding=0,
         bias=True,
         batchnorm=False,
+        dropout=0,
     ):
         if batchnorm:
             layer = nn.Sequential(
@@ -198,6 +212,7 @@ def decoder(
                 ),
                 nn.BatchNorm3d(out_channels, affine=False),
                 nn.ReLU(),
+                nn.Dropout3d(p=dropout),
                 nn.Conv3d(
                     out_channels,
                     out_channels,
@@ -208,6 +223,7 @@ def decoder(
                 ),
                 nn.BatchNorm3d(out_channels, affine=False),
                 nn.ReLU(),
+                nn.Dropout3d(p=dropout),
             )
         else:
             layer = nn.Sequential(
@@ -220,6 +236,7 @@ def decoder(
                     bias=bias,
                 ),
                 nn.ReLU(),
+                nn.Dropout3d(p=dropout),
                 nn.Conv3d(
                     out_channels,
                     out_channels,
@@ -229,6 +246,7 @@ def decoder(
                     bias=bias,
                 ),
                 nn.ReLU(),
+                nn.Dropout3d(p=dropout),
             )
 
         return layer
diff --git a/aicsmlsegment/model_utils.py b/aicsmlsegment/model_utils.py
index 8802bac..45c8c07 100644
--- a/aicsmlsegment/model_utils.py
+++ b/aicsmlsegment/model_utils.py
@@ -3,6 +3,7 @@
 from pathlib import Path, PurePosixPath
 from aicsmlsegment.multichannel_sliding_window import sliding_window_inference
 from aicsmlsegment.fnet_prediction_torch import predict_piecewise
+from typing import List
 
 
 def flip(img: np.ndarray, axis: int) -> torch.Tensor:
@@ -32,6 +33,7 @@ def apply_on_image(
     softmax: bool,
     model_name,
     extract_output_ch: bool,
+    uncertainty: str = None,
 ) -> np.ndarray:
     """
     Highest level API to perform inference on an input image through a model with
@@ -65,6 +67,7 @@
             to_numpy=to_numpy,
             extract_output_ch=extract_output_ch,
             softmax=softmax,
+            uncertainty=uncertainty,
         )
     else:
-        out0, vae_loss = model_inference(
+        out0, vae_loss, _ = model_inference(
@@ -75,6 +78,8 @@
             to_numpy=False,
             softmax=softmax,
             model_name=model_name,
+            extract_output_ch=extract_output_ch,
+            uncertainty=uncertainty,
         )
         input_img = input_img[0]  # remove batch_dimension for flip
         for i in range(3):
@@ -88,6 +93,7 @@
                 softmax=softmax,
                 model_name=model_name,
                 extract_output_ch=extract_output_ch,
+                uncertainty=uncertainty,
             )
             aug_flip = flip(out, axis=i)
             out0 += aug_flip
@@ -136,6 +142,47 @@ def get_supported_model_names():
     print(all_names + flist)
 
 
+def calculate_uncertainty(output_img_list: List, uncertainty_type: str) -> np.ndarray:
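+    """
+    Collapse a list of stochastic forward passes into a single per-voxel
+    uncertainty map:
+      softmax            -> 1 - max softmax probability of the first pass
+      variance           -> variance over passes, for the last output channel
+      entropy            -> entropy of the mean prediction
+      mutual_information -> entropy of the mean minus the mean per-pass entropy
+    """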
mode="gaussian", + model_name=model_name, + ) + if uncertainty is not None: + output_img_list.append( + torch.nn.Softmax(dim=1)(result).detach().cpu().numpy() ) - - if softmax: - result = torch.nn.Softmax(dim=1)(result) - if extract_output_ch: - # old models - if type(args["OutputCh"]) == list and len(args["OutputCh"]) >= 2: - args["OutputCh"] = args["OutputCh"][1] - result = result[:, args["OutputCh"], :, :, :] - if not squeeze: - result = torch.unsqueeze(result, dim=1) - if to_numpy: - result = result.detach().cpu().numpy() - return result, vae_loss + if uncertainty is not None: + uncertaintymap = calculate_uncertainty(output_img_list, uncertainty) + # specifying uncertainty results in squeeze = True + result = np.mean(output_img_list, axis=0)[:, -1, ...] + else: + uncertaintymap = None + if softmax: + result = torch.nn.Softmax(dim=1)(result) + if extract_output_ch: + # old models + if type(args["OutputCh"]) == list and len(args["OutputCh"]) >= 2: + args["OutputCh"] = args["OutputCh"][1] + result = result[:, args["OutputCh"], :, :, :] + if not squeeze: + result = torch.unsqueeze(result, dim=1) + if to_numpy: + result = result.detach().cpu().numpy() + return result, vae_loss, uncertaintymap def weights_init(m): diff --git a/aicsmlsegment/utils.py b/aicsmlsegment/utils.py index 6349cf2..fc9e02b 100644 --- a/aicsmlsegment/utils.py +++ b/aicsmlsegment/utils.py @@ -11,6 +11,7 @@ from monai.networks.layers import Norm, Act import os import datetime +from aicsimageio.writers.ome_tiff_writer import OmeTiffWriter REQUIRED_CONFIG_FIELDS = { @@ -68,6 +69,7 @@ "large_image_resize": None, "precision": None, "segmentation_name": None, + "uncertainty": None, }, } @@ -84,6 +86,7 @@ "large_image_resize": [1, 1, 1], "epoch_shuffle": None, "segmentation_name": "segmentation", + "uncertainty": None, } MODEL_PARAMETERS = { @@ -105,7 +108,7 @@ "Required": ["nchannel", "nclass", "size_in", "size_out", "zoom_ratio"], }, "unet_xy_zoom_0pad": { - "Optional": [], + "Optional": ["dropout"], "Required": ["nchannel", "nclass", "size_in", "size_out", "zoom_ratio"], }, "unet_xy_zoom_0pad_stridedconv": { @@ -313,7 +316,7 @@ def create_unique_run_directory(config, train): for sub in subfolders if subdir_names[train][1:] in sub ] - if len(subfolders) > 0: + if len(run_numbers) > 0: most_recent_run_number = max(run_numbers) most_recent_run_dir = ( dir_name + subdir_names[train] + str(most_recent_run_number) @@ -594,3 +597,21 @@ def get_logger(name, level=logging.INFO): logger.addHandler(stream_handler) return logger + + +def save_image( + path: str, + img: np.ndarray, + channel_names: List, + dimension_order: str = "CZYX", + overwrite: bool = True, +): + with OmeTiffWriter( + path, + overwrite_file=overwrite, + ) as writer: + writer.save( + data=img, + channel_names=channel_names, + dimension_order=dimension_order, + ) diff --git a/configs/all_predict_options.yaml b/configs/all_predict_options.yaml index bbd97f3..8edc628 100644 --- a/configs/all_predict_options.yaml +++ b/configs/all_predict_options.yaml @@ -33,7 +33,7 @@ ResizeRatio: [1,1,1] #ratio to resize image Normalization: 18 # normalization 'recipe', should match training data Threshold: -1 # whether to binarize output images. Either -1 or an integer between 0 and 1 RuntimeAug: False # run prediction on four flipped versions of the image - increases quality, takes ~4x longer - +uncertainty: 'entropy' # one of 'variance', 'entropy', 'mutual_information', or 'softmax'. increases inference time and outputs an uncertainty map. 
+    with OmeTiffWriter(
+        path,
+        overwrite_file=overwrite,
+    ) as writer:
+        writer.save(
+            data=img,
+            channel_names=channel_names,
+            dimension_order=dimension_order,
+        )
diff --git a/configs/all_predict_options.yaml b/configs/all_predict_options.yaml
index bbd97f3..8edc628 100644
--- a/configs/all_predict_options.yaml
+++ b/configs/all_predict_options.yaml
@@ -33,7 +33,7 @@ ResizeRatio: [1,1,1] #ratio to resize image
 Normalization: 18 # normalization 'recipe', should match training data
 Threshold: -1 # whether to binarize output images. Either -1 or an integer between 0 and 1
 RuntimeAug: False # run prediction on four flipped versions of the image - increases quality, takes ~4x longer
-
+uncertainty: 'entropy' # one of 'variance', 'entropy', 'mutual_information', or 'softmax'. Runs inference 10 times per image and saves an uncertainty map; only used with models trained with dropout
 mode:
   name: file # whether to predict on individual file or entire folder of images
   InputFile: "path/to/input/input/img"
diff --git a/docs/doc_pred_yaml.md b/docs/doc_pred_yaml.md
index 9d86178..54a8f7a 100644
--- a/docs/doc_pred_yaml.md
+++ b/docs/doc_pred_yaml.md
@@ -72,10 +72,18 @@
 large_image_resize: [1,1,1]
 Threshold: 0.75
 RuntimeAug: False
 Normalization: 10
+uncertainty: 'entropy'
 ```
 
-`DataType` is the type of images to be processed in `InputDir`, which the `InputCh`'th (keep the [ ]) channel of each image will be segmented. If your model is trained on images of a certain resolution and your test images are of different resolution `ResizeRatio` needs to be set as [new_z_size/old_z_size, new_y_size/old_y_size, new_x_size/old_x_size]. The actual output is the likelihood of each voxels being the target structure. `large_image_resize` can be set if large images cause GPU out of memory during prediction. This parameter specifies how many patches in the `ZYX` axes each image should be split into. After patch-wise prediction, the final image is reconstructed based on overlap between patches. A `Threshold` between 0 and 1 needs to be set to generate the binary mask. We recommend to use 0.6 ~ 0.9. When `Threshold` is set as `-1`, the raw prediction from the model will be saved, for users to determine a proper binary cutoff. `Normalization` is the index of a list of pre-defined normalization recipes and should be the same index as generating training data (see [Curator](./bb2.md) for the full list of normalization recipes). If `RuntimeAug` is `True`, the model will predict on the original image and three flipped versions of the image. The final prediction is then averaged across each flipped prediction. This increases prediction quality, but takes ~4x longer.
+`DataType` is the type of images to be processed in `InputDir`, of which the `InputCh`'th (keep the [ ]) channel of each image will be segmented. If your model is trained on images of a certain resolution and your test images are of a different resolution, `ResizeRatio` needs to be set as [new_z_size/old_z_size, new_y_size/old_y_size, new_x_size/old_x_size]. The actual output is the likelihood of each voxel being the target structure. `large_image_resize` can be set if large images cause GPU out of memory during prediction. This parameter specifies how many patches in the `ZYX` axes each image should be split into. After patch-wise prediction, the final image is reconstructed based on overlap between patches. A `Threshold` between 0 and 1 needs to be set to generate the binary mask. We recommend 0.6 ~ 0.9. When `Threshold` is set as `-1`, the raw prediction from the model will be saved, for users to determine a proper binary cutoff. `Normalization` is the index of a list of pre-defined normalization recipes and should be the same index as generating training data (see [Curator](./bb2.md) for the full list of normalization recipes). If `RuntimeAug` is `True`, the model will predict on the original image and three flipped versions of the image. The final prediction is then averaged across each flipped prediction. This increases prediction quality, but takes ~4x longer. `uncertainty` is an optional argument that can be used with models trained with dropout. Possible values are `entropy`, `softmax`, `variance`, or `mutual_information`. These uncertainty estimation techniques run inference 10 times on each image and compute a per-voxel uncertainty map from the differences between runs.
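+
+With `uncertainty` enabled, the uncertainty map is saved next to the segmentation as `<output name>_<uncertainty>_uncertainty.tiff`. Uncertainty estimation is not yet supported together with `large_image_resize`. A minimal example (illustrative values):
+
+```yaml
+RuntimeAug: False
+uncertainty: 'mutual_information'
+```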
\ No newline at end of file
diff --git a/setup.py b/setup.py
index 86a7af5..91d8c92 100644
--- a/setup.py
+++ b/setup.py
@@ -46,7 +46,7 @@
     "scipy>=1.1.0",
     "scikit-image",
     "pandas>=0.23.4",
-    "aicsimageio>3.3.0",
+    "aicsimageio==3.3.3",
     "tqdm",
     "pyyaml",
     "monai>=0.4.0",