Skip to content
Merged
Show file tree
Hide file tree
Changes from 16 commits
Commits
Show all changes
27 commits
Select commit Hold shift + click to select a range
af282cb
Removed unused parameter
MattsonCam Sep 22, 2025
654d9d8
Doubled the crop dimensionsions and allowed
MattsonCam Sep 22, 2025
4f41c2a
Decoupled logic to retrieve data paths and specs
Oct 1, 2025
3240ec8
Added capability to call multiple image savers
Oct 10, 2025
926cafd
Updated to include the following:
Oct 10, 2025
7613eaa
Updated to use the image specs
Oct 10, 2025
170b56e
Changes:
Oct 10, 2025
1606631
Updated to run from uv
Oct 15, 2025
65a0920
Updated the uv environment to include the correct
Oct 15, 2025
1282eeb
Added documentation to whole slice saver
Oct 16, 2025
1b1b80f
Refactored code to visualize 3D images
Oct 16, 2025
5eefe06
Ensured the generated predictions are only saved once
Oct 16, 2025
57b713e
Renamed the visualization folder
Oct 16, 2025
2c25c43
Updated the save path for whole images
Oct 16, 2025
df30be2
Updated a comment
Oct 16, 2025
de50834
Changes:
Oct 16, 2025
3277d92
Updated documentation based on pr comments
Oct 16, 2025
f3ecc7d
Updated doc strings
Oct 16, 2025
09aea43
Changed the indexing to the thresholded image and
Oct 16, 2025
f1e969d
Fixed bug when specifying original slice indices
Oct 17, 2025
99f83c9
Removed extra print statement
Oct 17, 2025
e85a77f
Changed to update whole image slice and stride
Oct 17, 2025
e211d00
Added comments about padding and original crop computation
Oct 20, 2025
ba4bc65
Modified channel mapping to account for the
Oct 23, 2025
ac96ef6
Included the visualization code in one python file
Oct 27, 2025
038eb63
Added necessary module and removed redundant print
Oct 27, 2025
23b5e93
Allows for a binary segmentation model that can be
Oct 29, 2025
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion MLproject
Original file line number Diff line number Diff line change
Expand Up @@ -2,4 +2,4 @@ name: cell_segmentation_gff

entry_points:
train_model:
command: "python3 train.py"
command: "uv run train.py"
9 changes: 6 additions & 3 deletions callbacks/Callbacks.py
Original file line number Diff line number Diff line change
Expand Up @@ -167,6 +167,7 @@ def _on_epoch_end(
("train", train_dataloader),
("validation", val_dataloader),
]:
f = 0

self._log_epoch_metrics(
model=model,
Expand All @@ -179,9 +180,11 @@ def _on_epoch_end(

# Images can be saved in different ways if desired in the future
if self.image_savers is not None and not isinstance(self.image_savers, list):
self.image_savers(
dataset=val_dataloader.dataset.dataset, model=model, epoch=epoch
)
self.image_savers(model=model, epoch=epoch)

else:
for image_saver in self.image_savers:
image_saver(model=model, epoch=epoch)

val_sample = next(iter(val_dataloader))
val_sample = val_sample["input"]
Expand Down
50 changes: 21 additions & 29 deletions callbacks/utils/SaveEpochSlices.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@
import tifffile
import torch

from .save_utils import save_image_mlflow


class SaveEpochSlices:
"""
Expand All @@ -15,30 +17,24 @@ class SaveEpochSlices:

def __init__(
self,
image_dataset_idxs: list[int],
data_split: str,
image_dataset: torch.utils.data.Dataset,
image_postprocessor: Any = lambda x: x,
image_dataset_idxs: Optional[list[int]] = None,
) -> None:

self.image_dataset = image_dataset
self.image_dataset_idxs = image_dataset_idxs
self.data_split = data_split
self.crop_key_order = ["height_start", "height_end", "width_start", "width_end"]
self.image_postprocessor = image_postprocessor

def save_image_mlflow(
self,
image: torch.Tensor,
save_image_path_folder: str,
image_filename: str,
) -> None:

with tempfile.TemporaryDirectory() as tmp_dir:
save_path = pathlib.Path(tmp_dir) / image_filename
tifffile.imwrite(save_path, image.astype(np.uint8))
self.epoch = None
self.metadata = None

mlflow.log_artifact(
local_path=save_path, artifact_path=save_image_path_folder
)
self.image_dataset_idxs = (
range(len(image_dataset))
if image_dataset_idxs is None
else image_dataset_idxs
)

def save_image(
self,
Expand Down Expand Up @@ -67,20 +63,18 @@ def save_image(
for k in self.crop_key_order
)

filename = (
f"{image_path.stem}__{image_type}{image_path.suffix}"
if image_type == "generated_prediction"
else image_path.name
)
image_suffix = ".tiff" if ".tif" in image_path.suffix else image_path.suffix

image_filename = f"{crop_name}__{filename}"
image_filename = (
f"3D_{image_type}_{image_path.stem}__{crop_name}__{image_suffix}"
)

fov_well_name = image_path.parent.name
patient_name = image_path.parents[2].name

save_image_path_folder = f"epoch_{self.epoch:02}/{patient_name}/{fov_well_name}/{input_slices_name}__{target_slices_name}"
save_image_path_folder = f"cropped_images/epoch_{self.epoch:02}/{patient_name}/{fov_well_name}/{input_slices_name}__{target_slices_name}"

self.save_image_mlflow(
save_image_mlflow(
image=image,
save_image_path_folder=save_image_path_folder,
image_filename=image_filename,
Expand All @@ -93,12 +87,10 @@ def predict_target(
) -> torch.Tensor:
return self.image_postprocessor(model(image.unsqueeze(0)).squeeze(0))

def __call__(
self, dataset: torch.utils.data.Dataset, model: torch.nn.Module, epoch: int
) -> None:
def __call__(self, model: torch.nn.Module, epoch: int) -> None:
self.epoch = epoch
for sample_idx in self.image_dataset_idxs:
sample = dataset[sample_idx]
sample = self.image_dataset[sample_idx]
self.metadata = sample["metadata"]

sample_image = self.save_image(
Expand All @@ -121,6 +113,6 @@ def __call__(

self.save_image(
image_path=sample["target_path"],
image_type="generated_prediction",
image_type="generated-prediction",
image=generated_prediction,
)
221 changes: 221 additions & 0 deletions callbacks/utils/SaveWholeSlices.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,221 @@
import pathlib
from typing import Any, Optional

import numpy as np
import pandas as pd
import tifffile
import torch

from .image_padding_specs import compute_patch_mapping
from .save_utils import save_image_locally, save_image_mlflow


class SaveWholeSlices:
"""
Saves all 2D images and slices to a 3D tiff format either locally or in MLflow.
"""

def __init__(
self,
image_dataset: torch.utils.data.Dataset,
image_dataset_idxs: list[int],
image_specs: dict[str, Any],
stride,
crop_shape,
pad_mode="reflect",
image_postprocessor: Any = lambda x: x,
local_save_path: Optional[pathlib.Path] = None,
**kwargs,
):

self.image_dataset = image_dataset
self.image_dataset_idxs = image_dataset_idxs
self.image_specs = image_specs
self.stride = stride
self.crop_shape = crop_shape
self.pad_mode = pad_mode
self.image_postprocessor = image_postprocessor
self.local_save_path = local_save_path

self.unique_image_dataset_idxs = []
self.reduce_dataset_idxs(image_dataset=image_dataset)

self.pad_width, self.original_crop_coords = compute_patch_mapping(
image_specs=image_specs,
crop_shape=crop_shape,
stride=stride,
pad_slices=True,
)

self.epoch = None

def reduce_dataset_idxs(self, image_dataset: torch.utils.data.Dataset):
"""
For reducing the dataset to only unique indices.
"""
self.unique_image_dataset_idxs = []

for sample_idx in self.image_dataset_idxs:
if (
image_dataset[sample_idx]["metadata"]["Metadata_ID"]
not in self.unique_image_dataset_idxs
):
self.unique_image_dataset_idxs.append(sample_idx)

def predict_target(
self, padded_image: torch.Tensor, model: torch.nn.Module
) -> torch.Tensor:
"""
padded_image:
Expects image of shape: (Z, H, W)
Z -> Number of Z slices
H -> Image Height
W -> Image Width
"""

output = torch.zeros(
*padded_image.shape,
dtype=torch.float32,
device=padded_image.device,
)
weight = torch.zeros_like(output)

spatial_ranges = [
range(0, s - c, st)
for s, c, st in zip(padded_image.shape, self.crop_shape, self.stride)
]

for idx in torch.cartesian_prod(
*[torch.tensor(list(r)) for r in spatial_ranges]
):
start = idx.tolist()
end = [s + c for s, c in zip(start, self.crop_shape)]

slices = tuple(slice(s, e) for s, e in zip(start, end))
crop = padded_image[slices].unsqueeze(0) # add batch dim

with torch.no_grad():
generated_prediction = self.image_postprocessor(
generated_prediction=model(crop)
).squeeze(0)

# Accumulate prediction and weights
output[slices] += generated_prediction
weight[slices] += 1.0

output /= weight

return output[self.original_crop_coords]

def pad_image(self, input_image: torch.Tensor) -> torch.Tensor:
"""
input_image:
Expects image of shape: (Z, H, W)
Z -> Number of Z slices
H -> Image Height
W -> Image Width
"""

padded_image = np.pad(
input_image.detach().cpu().numpy(),
pad_width=self.pad_width,
mode=self.pad_mode,
)

padded_image = torch.from_numpy(padded_image).to(
dtype=torch.float32, device=input_image.device
)

return padded_image

def save_image(
self,
image_path: pathlib.Path,
image_type: str,
image: torch.Tensor,
) -> bool:
"""
- Determines if the image is completely black or not.
- Saves images in the correct format to the hardcoded path.
"""

if not ((image > 0.0) & (image < 1.0)).any():
if image_type == "input":
raise ValueError("Pixels should be between 0 and 1 in the input image")

if image_type == "target":
image = (image != 0).float()

image = (image * 255).byte().cpu().numpy()

# Black images will not be saved
if np.max(image) == 0:
return False

image_suffix = ".tiff" if ".tif" in image_path.suffix else image_path.suffix

filename = f"3D_{image_type}_{image_path.stem}{image_suffix}"

fov_well_name = image_path.parent.name
patient_name = image_path.parents[2].name

save_image_path_folder = f"{patient_name}/{fov_well_name}"
save_image_path_folder = (
f"whole_images/epoch_{self.epoch:02}/{save_image_path_folder}"
if self.epoch is not None
else save_image_path_folder
)

if self.local_save_path is None:
print("Started_save")
save_image_mlflow(
image=image,
save_image_path_folder=save_image_path_folder,
image_filename=filename,
)
else:
save_image_path_folder = self.local_save_path / save_image_path_folder
save_image_locally(
image=image,
save_image_path_folder=save_image_path_folder,
image_filename=filename,
)

return True

def __call__(
self,
model: torch.nn.Module,
epoch: Optional[int] = None,
) -> None:

self.epoch = epoch
for sample_idx in self.unique_image_dataset_idxs:

sample_image = self.save_image(
image_path=self.image_dataset[sample_idx]["target_path"],
image_type="target",
image=self.image_dataset[sample_idx]["target"],
)

# Only save these images if the segmentation mask isn't black
if sample_image:
padded_image = self.pad_image(
input_image=self.image_dataset[sample_idx]["input"]
)

generated_prediction = self.predict_target(
padded_image=padded_image, model=model
)

self.save_image(
image_path=self.image_dataset[sample_idx]["input_path"],
image_type="input",
image=self.image_dataset[sample_idx]["input"],
)

self.save_image(
image_path=self.image_dataset[sample_idx]["target_path"],
image_type="generated-prediction",
image=generated_prediction,
)
Loading