diff --git a/segmenter_model_zoo/zoo.py b/segmenter_model_zoo/zoo.py
index b3e9cd0..1da4a26 100644
--- a/segmenter_model_zoo/zoo.py
+++ b/segmenter_model_zoo/zoo.py
@@ -7,12 +7,12 @@
 import importlib
 
 import torch
-from torch.autograd import Variable
 from aicsmlsegment.utils import input_normalization
 from scipy.ndimage import zoom
 from aicsimageio import AICSImage
 
 from segmenter_model_zoo.quilt_utils import validate_model
+from aicsmlsegment.multichannel_sliding_window import sliding_window_inference
 
 
 ###############################################################################
@@ -214,6 +214,7 @@ def load_train(
         else:
             model_type = CHECKPOINT_PATH_MAPPING[checkpoint_name]["model_type"]
+        self.model_name = model_type
 
         # load default model parameters or from model_param
         if "size_in" in model_param:
@@ -304,6 +305,8 @@ def apply_on_single_zstack(
         already_normalized: bool = False,
         cutoff: float = None,
         inference_param: Dict = {},
+        size_in: List = None,
+        size_out: List = None,
    ) -> np.ndarray:
        """
        Apply a trained model on an image
@@ -334,6 +337,10 @@
            only one parameter is allowed: "ResizeRatio" (a list of three float
            numbers to indicate the ResizeRatio to apply on ZYX axis). More
            parameters may be added in the future.
+        size_in: List
+            the input patch size; overrides the model's default when given
+        size_out: List
+            the output patch size; overrides the model's default when given
 
        Return:
        -------------
@@ -383,8 +390,14 @@
         model = self.model
         model.eval()
 
+        # fall back to the model's default patch sizes when not given
+        if size_in is None:
+            size_in = self.size_in
+        if size_out is None:
+            size_out = self.size_out
+
         # do padding on input
-        padding = [(x - y) // 2 for x, y in zip(self.size_in, self.size_out)]
+        padding = [(x - y) // 2 for x, y in zip(size_in, size_out)]
         img_pad0 = np.pad(
             input_img,
             ((0, 0), (0, 0), (padding[1], padding[1]), (padding[2], padding[2])),
@@ -394,62 +407,35 @@
             img_pad0, ((0, 0), (padding[0], padding[0]), (0, 0), (0, 0)), "constant"
         )
 
-        # we only support single output image in model zoo
-        # other outputs are only supported in full segmenter prediction so far
-        assert len(self.OutputCh) == 2
-        output_img = np.zeros(input_img.shape)
+        # add the batch dimension expected by the model
+        img_pad = np.expand_dims(img_pad, axis=0)
 
-        # loop through the image patch by patch
-        num_step_z = int(np.ceil(input_img.shape[1] / self.size_out[0]))
-        num_step_y = int(np.ceil(input_img.shape[2] / self.size_out[1]))
-        num_step_x = int(np.ceil(input_img.shape[3] / self.size_out[2]))
-
+        # run sliding window inference
         with torch.no_grad():
-            for ix in range(num_step_x):
-                if ix < num_step_x - 1:
-                    xa = ix * self.size_out[2]
-                else:
-                    xa = input_img.shape[3] - self.size_out[2]
-
-                for iy in range(num_step_y):
-                    if iy < num_step_y - 1:
-                        ya = iy * self.size_out[1]
-                    else:
-                        ya = input_img.shape[2] - self.size_out[1]
-
-                    for iz in range(num_step_z):
-                        if iz < num_step_z - 1:
-                            za = iz * self.size_out[0]
-                        else:
-                            za = input_img.shape[1] - self.size_out[0]
-
-                        input_patch = img_pad[
-                            :,
-                            za : (za + self.size_in[0]),
-                            ya : (ya + self.size_in[1]),
-                            xa : (xa + self.size_in[2]),
-                        ]
-                        input_img_tensor = torch.from_numpy(input_patch)
-                        tmp_out = model(Variable(input_img_tensor.cuda()).unsqueeze(0))
-                        assert len(self.OutputCh) // 2 <= len(
-                            tmp_out
-                        ), "the parameter OutputCh not compatible with output tensors"
-
-                        label = tmp_out[self.OutputCh[0]]
-                        prob = self.softmax(label)
-                        out_flat_tensor = prob.cpu().data
-                        out_tensor = out_flat_tensor.view(
-                            self.size_out[0],
-                            self.size_out[1],
-                            self.size_out[2],
-                            self.nclass[0],
-                        )
-                        out_nda = out_tensor.numpy()
-                        output_img[
-                            0,
-                            za : (za + self.size_out[0]),
-                            ya : (ya + self.size_out[1]),
-                            xa : (xa + self.size_out[2]),
-                        ] = out_nda[:, :, :, self.OutputCh[1]]
+            output_tensor, _ = sliding_window_inference(
+                inputs=torch.from_numpy(img_pad).float().cuda(),
+                roi_size=size_in,
+                out_size=size_out,
+                original_image_size=input_img.shape[-3:],
+                sw_batch_size=1,
+                predictor=model.forward,
+                overlap=0.25,
+                mode="gaussian",
+                model_name=self.model_name,
+            )
+
+            output_img = output_tensor.cpu().data.numpy()
+            if self.OutputCh:
+                # for older models, take only the highest-resolution output
+                if isinstance(self.OutputCh, list):
+                    # if it is [v1, v2], the second value selects which channel
+                    # to keep from the highest-resolution output
+                    if len(self.OutputCh) >= 2:
+                        self.OutputCh = self.OutputCh[1]
+                    else:
+                        # single-element list: unwrap it to an integer
+                        self.OutputCh = self.OutputCh[0]
+
+                output_img = output_img[:, self.OutputCh, :, :, :]
 
         torch.cuda.empty_cache()
diff --git a/setup.py b/setup.py
index ff7128d..8261b3f 100644
--- a/setup.py
+++ b/setup.py
@@ -41,9 +41,9 @@ requirements = [
     'PyYAML',
     'aicsimageio>3.3.0',
-    'aicsmlsegment>0.0.5'
+    'aicsmlsegment>0.0.5',
     'scikit-image',
-    "quilt3",
+    'quilt3',
 ]
 
 
 extra_requirements = {
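
Reviewer note: a minimal usage sketch of the new per-call patch-size overrides added to `apply_on_single_zstack`. The wrapper class name (`SuperModel`), the `load_train` call pattern, the checkpoint name, and the positional image argument are illustrative assumptions not confirmed by this diff, and the patch sizes are placeholders that must match the network architecture.

```python
# Sketch only: class name, checkpoint name, and argument names are assumed.
# Requires a CUDA device, since inference runs on GPU.
import numpy as np

from segmenter_model_zoo.zoo import SuperModel  # assumed entry point

model = SuperModel()
model.load_train("DNA_mask_production", {})  # hypothetical checkpoint name

# a single-channel z-stack in CZYX order
img = np.random.rand(1, 64, 512, 512).astype(np.float32)

# default behavior: patch sizes come from the checkpoint's stored config
seg = model.apply_on_single_zstack(img)

# new behavior: override the patch sizes per call, e.g. to fit GPU memory;
# size_in / size_out must stay consistent with the network's receptive field
seg_small = model.apply_on_single_zstack(
    img,
    size_in=[48, 148, 148],  # placeholder ZYX input patch size
    size_out=[20, 60, 60],   # placeholder ZYX output patch size
)
```

Because both parameters default to `None` and fall back to `self.size_in` / `self.size_out`, existing callers are unaffected; the override only changes how the padded volume is tiled by `sliding_window_inference`.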