aicsmlsegment/Model.py (40 changes: 18 additions & 22 deletions)
@@ -12,7 +12,7 @@
undo_resize,
UniversalDataset,
)
from aicsmlsegment.utils import compute_iou
from aicsmlsegment.utils import compute_iou, save_image

import numpy as np
from skimage.io import imsave
@@ -118,6 +118,7 @@ def __init__(self, config, model_config, train):
self.args_inference["inference_batch_size"] = config["batch_size"]
self.args_inference["mode"] = config["mode"]["name"]
self.args_inference["Threshold"] = config["Threshold"]
self.uncertainty = config["uncertainty"]
if config["large_image_resize"] != [1, 1, 1]:
self.aggregate_img = {}
self.count_map = {}
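
The new `uncertainty` key is read alongside the existing inference options in `__init__`. A minimal sketch of the relevant config entries, assuming the key holds either `None` or a string naming the uncertainty method (the exact schema is not shown in this diff):

```python
# Hypothetical config snippet; only the keys visible in this diff are shown.
# "uncertainty" appears to be None (disabled) or a string naming the method,
# which later also ends up in the uncertainty output filename.
config = {
    "batch_size": 4,
    "mode": {"name": "file"},
    "Threshold": 0.5,
    "uncertainty": "epistemic",      # assumed example value; use None to disable
    "large_image_resize": [1, 1, 1],
}
```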
@@ -239,15 +240,6 @@ def on_train_epoch_start(self):
)
self.iter_dataloader = iter(self.DATALOADER)

def get_upsample_grid(self, desired_shape, n_targets):
x = torch.linspace(-1, 1, desired_shape[-1], device=self.device)
y = torch.linspace(-1, 1, desired_shape[-2], device=self.device)
z = torch.linspace(-1, 1, desired_shape[-3], device=self.device)
meshz, meshy, meshx = torch.meshgrid((z, y, x))
grid = torch.stack((meshx, meshy, meshz), 3)
grid = torch.stack([grid] * n_targets) # one grid for each target in batch
return grid

def log_and_return(self, name, value):
# sync_dist on_epoch=True ensures that results will be averaged across gpus
self.log(
@@ -272,7 +264,6 @@ def training_step(self, batch, batch_idx):
targets = batch[1]
cmap = batch[2]
outputs = self(inputs)

vae_loss = 0
if self.model_name == "segresnetvae":
# segresnetvae forward returns an additional vae loss term
@@ -293,7 +284,7 @@ def validation_step(self, batch, batch_idx):
costmap = batch[2]
# fn = batch[3]

outputs, vae_loss = model_inference(
outputs, vae_loss, _ = model_inference(
self.model,
input_img,
self.args_inference,
@@ -335,7 +326,7 @@ def test_step(self, batch, batch_idx):
if self.aggregate_img is not None:
to_numpy = False # prevent excess gpu->cpu data transfer

output_img, _ = apply_on_image(
output_img, _, uncertaintymap = apply_on_image(
self.model,
img,
args_inference,
@@ -344,9 +335,11 @@
softmax=True,
model_name=self.model_name,
extract_output_ch=True,
uncertainty=self.uncertainty,
)

if self.aggregate_img is not None:
if self.uncertainty is not None:
print("Uncertainty is not yet supported with large image resizing.")
# initialize the aggregate img
i, j, k = batch["ijk"][0], batch["ijk"][1], batch["ijk"][2]
if fn not in self.aggregate_img:
@@ -378,8 +371,6 @@ def test_step(self, batch, batch_idx):
# only want to perform post-processing and saving once the aggregated image
# is complete or we're not aggregating an image
if self.batch_count[fn] % save_n_batches == 0:
from aicsimageio.writers.ome_tiff_writer import OmeTiffWriter

if self.aggregate_img is not None:
# normalize for overlapping patches
output_img = self.aggregate_img[fn] / self.count_map[fn]
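
For context, the division above averages predictions where patches overlap. A simplified sketch of the idea, with the accumulation step (which is not shown in this hunk) written out explicitly:

```python
# Simplified sketch of overlapping-patch averaging: each patch prediction is summed
# into aggregate_img at its (i, j, k) offset, count_map counts how many patches
# touched each voxel, and dividing the two averages the overlapping regions.
import numpy as np

def accumulate_patch(aggregate_img, count_map, patch, i, j, k):
    z, y, x = patch.shape[-3:]
    aggregate_img[..., i:i + z, j:j + y, k:k + x] += patch
    count_map[..., i:i + z, j:j + y, k:k + x] += 1
    return aggregate_img, count_map

# Once every patch of an image has been accumulated:
# output_img = aggregate_img / count_map
```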
@@ -409,10 +400,15 @@ def test_step(self, batch, batch_idx):
path = self.config["OutputDir"] + os.sep + pathlib.PurePosixPath(fn).stem
if tt != -1:
path = path + "_T_" + f"{tt:03}"
path += "_struct_segmentation.tiff"
with OmeTiffWriter(path, overwrite_file=True) as writer:
writer.save(
data=out,
channel_names=[self.config["segmentation_name"]],
dimension_order="CZYX",

save_image(
path + "_struct_segmentation.tiff",
out,
[self.config["segmentation_name"]],
)
if uncertaintymap is not None:
save_image(
path + "_" + self.uncertainty + "_uncertainty.tiff",
uncertaintymap,
[self.uncertainty],
)
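
The OME-TIFF writing that `test_step` previously inlined is now delegated to `save_image`, imported from `aicsmlsegment.utils`. That helper's implementation is not part of this diff; a minimal sketch consistent with the call sites above and with the removed `OmeTiffWriter` block might look like:

```python
# Hypothetical sketch of the save_image helper from aicsmlsegment.utils; the real
# implementation is not shown in this diff. It wraps the OmeTiffWriter logic that
# test_step used to inline.
from aicsimageio.writers.ome_tiff_writer import OmeTiffWriter

def save_image(path, data, channel_names):
    # Write a CZYX array as an OME-TIFF with the given channel names.
    with OmeTiffWriter(path, overwrite_file=True) as writer:
        writer.save(
            data=data,
            channel_names=channel_names,
            dimension_order="CZYX",
        )
```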
aicsmlsegment/NetworkArchitecture/unet_xy_zoom_0pad.py (19 changes: 18 additions & 1 deletion)
@@ -4,11 +4,18 @@

class UNet3D(nn.Module):
def __init__(
self, in_channel, n_classes, down_ratio, test_mode=True, batchnorm_flag=True
self,
in_channel,
n_classes,
down_ratio,
test_mode=True,
batchnorm_flag=True,
dropout=0,
):
self.in_channel = in_channel
self.n_classes = n_classes
self.test_mode = test_mode
self.dropout = dropout
super(UNet3D, self).__init__()

k = down_ratio
@@ -129,6 +136,7 @@ def encoder(
padding=0,
bias=True,
batchnorm=False,
dropout=0,
):
if batchnorm:
layer = nn.Sequential(
@@ -142,6 +150,7 @@
),
nn.BatchNorm3d(out_channels, affine=False),
nn.ReLU(),
nn.Dropout3d(p=dropout),
nn.Conv3d(
out_channels,
2 * out_channels,
@@ -152,6 +161,7 @@
),
nn.BatchNorm3d(2 * out_channels, affine=False),
nn.ReLU(),
nn.Dropout3d(p=dropout),
)
else:
layer = nn.Sequential(
@@ -164,6 +174,7 @@
bias=bias,
),
nn.ReLU(),
nn.Dropout3d(p=dropout),
nn.Conv3d(
out_channels,
2 * out_channels,
@@ -173,6 +184,7 @@
bias=bias,
),
nn.ReLU(),
nn.Dropout3d(p=dropout),
)
return layer

@@ -185,6 +197,7 @@ def decoder(
padding=0,
bias=True,
batchnorm=False,
dropout=0,
):
if batchnorm:
layer = nn.Sequential(
@@ -198,6 +211,7 @@
),
nn.BatchNorm3d(out_channels, affine=False),
nn.ReLU(),
nn.Dropout3d(p=dropout),
nn.Conv3d(
out_channels,
out_channels,
@@ -208,6 +222,7 @@
),
nn.BatchNorm3d(out_channels, affine=False),
nn.ReLU(),
nn.Dropout3d(p=dropout),
)
else:
layer = nn.Sequential(
@@ -220,6 +235,7 @@
bias=bias,
),
nn.ReLU(),
nn.Dropout3d(p=dropout),
nn.Conv3d(
out_channels,
out_channels,
@@ -229,6 +245,7 @@
bias=bias,
),
nn.ReLU(),
nn.Dropout3d(p=dropout),
)
return layer
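
The `Dropout3d` layers threaded through the encoder and decoder blocks are what make test-time uncertainty estimation possible. One common approach is Monte Carlo dropout, where dropout stays active during inference and the spread over repeated stochastic forward passes serves as the uncertainty map. The sketch below illustrates that idea under the assumption that the model returns a single logits tensor of shape (N, C, Z, Y, X); how `apply_on_image` actually computes `uncertaintymap` is not shown in this diff.

```python
# Illustrative Monte Carlo dropout inference sketch; not the repository's own
# uncertainty implementation. Assumes model(img) returns one (N, C, Z, Y, X) tensor.
import torch

def mc_dropout_predict(model, img, n_samples=10):
    model.eval()
    # Re-enable only the dropout layers so each forward pass stays stochastic.
    for m in model.modules():
        if isinstance(m, torch.nn.Dropout3d):
            m.train()
    with torch.no_grad():
        preds = torch.stack(
            [torch.softmax(model(img), dim=1) for _ in range(n_samples)]
        )
    mean_pred = preds.mean(dim=0)   # averaged class probabilities
    uncertainty = preds.std(dim=0)  # per-voxel spread used as an uncertainty map
    return mean_pred, uncertainty
```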
