Refactored Framework and Added 3D Visualization Capabilities #3
Merged

Commits (27)
af282cb  Removed unused parameter
654d9d8  Doubled the crop dimensions and allowed (MattsonCam)
4f41c2a  Decoupled logic to retrieve data paths and specs (MattsonCam)
3240ec8  Added capability to call multiple image savers
926cafd  Updated to include the following:
7613eaa  Updated to use the image specs
170b56e  Changes:
1606631  Updated to run from uv
65a0920  Updated the uv environment to include the correct
1282eeb  Added documentation to whole slice saver
1b1b80f  Refactored code to visualize 3D images
5eefe06  Ensured the generated predictions are only saved once
57b713e  Renamed the visualization folder
2c25c43  Updated the save path for whole images
df30be2  Updated a comment
de50834  Changes:
3277d92  Updated documentation based on PR comments
f3ecc7d  Updated doc strings
09aea43  Changed the indexing to the thresholded image and
f1e969d  Fixed bug when specifying original slice indices
99f83c9  Removed extra print statement
e85a77f  Changed to update whole image slice and stride
e211d00  Added comments about padding and original crop computation
ba4bc65  Modified channel mapping to account for the
ac96ef6  Included the visualization code in one python file
038eb63  Added necessary module and removed redundant print
23b5e93  Allows for a binary segmentation model that can be
The train_model entry point now runs via uv:

```diff
@@ -2,4 +2,4 @@ name: cell_segmentation_gff
 entry_points:
   train_model:
-    command: "python3 train.py"
+    command: "uv run train.py"
```
New file (221 added lines):

```python
import pathlib
from typing import Any, Optional

import numpy as np
import pandas as pd
import tifffile
import torch

from .image_padding_specs import compute_patch_mapping
from .save_utils import save_image_locally, save_image_mlflow


class SaveWholeSlices:
    """
    Saves all 2D images and slices to a 3D tiff format either locally or in MLflow.
    """

    def __init__(
        self,
        image_dataset: torch.utils.data.Dataset,
        image_dataset_idxs: list[int],
        image_specs: dict[str, Any],
        stride,
        crop_shape,
        pad_mode="reflect",
        image_postprocessor: Any = lambda x: x,
        local_save_path: Optional[pathlib.Path] = None,
        **kwargs,
    ):
        self.image_dataset = image_dataset
        self.image_dataset_idxs = image_dataset_idxs
        self.image_specs = image_specs
        self.stride = stride
        self.crop_shape = crop_shape
        self.pad_mode = pad_mode
        self.image_postprocessor = image_postprocessor
        self.local_save_path = local_save_path

        self.unique_image_dataset_idxs = []
        self.reduce_dataset_idxs(image_dataset=image_dataset)

        self.pad_width, self.original_crop_coords = compute_patch_mapping(
            image_specs=image_specs,
            crop_shape=crop_shape,
            stride=stride,
            pad_slices=True,
        )

        self.epoch = None

    def reduce_dataset_idxs(self, image_dataset: torch.utils.data.Dataset):
        """
        For reducing the dataset to only unique indices.
        """
        self.unique_image_dataset_idxs = []

        for sample_idx in self.image_dataset_idxs:
            if (
                image_dataset[sample_idx]["metadata"]["Metadata_ID"]
                not in self.unique_image_dataset_idxs
            ):
                self.unique_image_dataset_idxs.append(sample_idx)

    def predict_target(
        self, padded_image: torch.Tensor, model: torch.nn.Module
    ) -> torch.Tensor:
        """
        padded_image:
            Expects image of shape: (Z, H, W)
            Z -> Number of Z slices
            H -> Image Height
            W -> Image Width
        """
        output = torch.zeros(
            *padded_image.shape,
            dtype=torch.float32,
            device=padded_image.device,
        )
        weight = torch.zeros_like(output)

        spatial_ranges = [
            range(0, s - c, st)
            for s, c, st in zip(padded_image.shape, self.crop_shape, self.stride)
        ]

        for idx in torch.cartesian_prod(
            *[torch.tensor(list(r)) for r in spatial_ranges]
        ):
            start = idx.tolist()
            end = [s + c for s, c in zip(start, self.crop_shape)]

            slices = tuple(slice(s, e) for s, e in zip(start, end))
            crop = padded_image[slices].unsqueeze(0)  # add batch dim

            with torch.no_grad():
                generated_prediction = self.image_postprocessor(
                    generated_prediction=model(crop)
                ).squeeze(0)

            # Accumulate prediction and weights
            output[slices] += generated_prediction
            weight[slices] += 1.0

        output /= weight

        return output[self.original_crop_coords]

    def pad_image(self, input_image: torch.Tensor) -> torch.Tensor:
        """
        input_image:
            Expects image of shape: (Z, H, W)
            Z -> Number of Z slices
            H -> Image Height
            W -> Image Width
        """
        padded_image = np.pad(
            input_image.detach().cpu().numpy(),
            pad_width=self.pad_width,
            mode=self.pad_mode,
        )

        padded_image = torch.from_numpy(padded_image).to(
            dtype=torch.float32, device=input_image.device
        )

        return padded_image

    def save_image(
        self,
        image_path: pathlib.Path,
        image_type: str,
        image: torch.Tensor,
    ) -> bool:
        """
        - Determines if the image is completely black or not.
        - Saves images in the correct format to the hardcoded path.
        """
        if not ((image > 0.0) & (image < 1.0)).any():
            if image_type == "input":
                raise ValueError("Pixels should be between 0 and 1 in the input image")

        if image_type == "target":
            image = (image != 0).float()

        image = (image * 255).byte().cpu().numpy()

        # Black images will not be saved
        if np.max(image) == 0:
            return False

        image_suffix = ".tiff" if ".tif" in image_path.suffix else image_path.suffix

        filename = f"3D_{image_type}_{image_path.stem}{image_suffix}"

        fov_well_name = image_path.parent.name
        patient_name = image_path.parents[2].name

        save_image_path_folder = f"{patient_name}/{fov_well_name}"
        save_image_path_folder = (
            f"whole_images/epoch_{self.epoch:02}/{save_image_path_folder}"
            if self.epoch is not None
            else save_image_path_folder
        )

        if self.local_save_path is None:
            print("Started_save")
            save_image_mlflow(
                image=image,
                save_image_path_folder=save_image_path_folder,
                image_filename=filename,
            )
        else:
            save_image_path_folder = self.local_save_path / save_image_path_folder
            save_image_locally(
                image=image,
                save_image_path_folder=save_image_path_folder,
                image_filename=filename,
            )

        return True

    def __call__(
        self,
        model: torch.nn.Module,
        epoch: Optional[int] = None,
    ) -> None:
        self.epoch = epoch
        for sample_idx in self.unique_image_dataset_idxs:
            sample_image = self.save_image(
                image_path=self.image_dataset[sample_idx]["target_path"],
                image_type="target",
                image=self.image_dataset[sample_idx]["target"],
            )

            # Only save these images if the segmentation mask isn't black
            if sample_image:
                padded_image = self.pad_image(
                    input_image=self.image_dataset[sample_idx]["input"]
                )

                generated_prediction = self.predict_target(
                    padded_image=padded_image, model=model
                )

                self.save_image(
                    image_path=self.image_dataset[sample_idx]["input_path"],
                    image_type="input",
                    image=self.image_dataset[sample_idx]["input"],
                )

                self.save_image(
                    image_path=self.image_dataset[sample_idx]["target_path"],
                    image_type="generated-prediction",
                    image=generated_prediction,
                )
```
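The deduplication step in `reduce_dataset_idxs` can be sketched with plain Python. The data below is illustrative; this sketch tracks seen `Metadata_ID` values in a separate set, whereas the merged code checks IDs against the stored index list itself:

```python
# Illustrative sketch (not the repository's code): keep the first dataset
# index for each unique Metadata_ID.
samples = [
    {"metadata": {"Metadata_ID": "patient_A"}},
    {"metadata": {"Metadata_ID": "patient_A"}},
    {"metadata": {"Metadata_ID": "patient_B"}},
]

seen_ids = set()
unique_idxs = []
for idx in range(len(samples)):
    sample_id = samples[idx]["metadata"]["Metadata_ID"]
    if sample_id not in seen_ids:
        seen_ids.add(sample_id)
        unique_idxs.append(idx)

print(unique_idxs)  # [0, 2]
```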
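The pad-then-tile averaging that `pad_image` and `predict_target` implement can be sketched with NumPy alone, using an identity function in place of the model. The pad widths, crop shape, and stride are toy values, and the range endpoint here includes the final window so every pixel receives at least one prediction:

```python
import numpy as np
from itertools import product

# Illustrative sketch of the pad -> overlapping-crop -> weighted-average scheme.
image = np.arange(16, dtype=np.float32).reshape(4, 4)

pad_width = ((1, 1), (1, 1))  # stands in for self.pad_width
padded = np.pad(image, pad_width=pad_width, mode="reflect")

crop_shape = (2, 2)
stride = (1, 1)

def model(x):  # identity stand-in for the trained network
    return x

output = np.zeros_like(padded)
weight = np.zeros_like(padded)

ranges = [
    range(0, s - c + 1, st)
    for s, c, st in zip(padded.shape, crop_shape, stride)
]
for start in product(*ranges):
    slices = tuple(slice(s, s + c) for s, c in zip(start, crop_shape))
    # Accumulate the prediction for this crop and count overlaps.
    output[slices] += model(padded[slices])
    weight[slices] += 1.0

output /= weight  # average where crops overlapped

# Crop back to the original coordinates (stands in for self.original_crop_coords).
original_crop_coords = tuple(
    slice(p[0], p[0] + s) for p, s in zip(pad_width, image.shape)
)
recovered = output[original_crop_coords]

assert np.allclose(recovered, image)
```

With an identity model every overlapping crop contributes the same values, so the average reconstructs the padded image exactly and the final crop recovers the original; with a real network, the same accumulation smooths seams between adjacent patches.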
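The black-frame check and byte conversion in `save_image` follow this pattern, shown here with a NumPy stand-in for the torch tensor (the helper name is illustrative):

```python
import numpy as np

# Illustrative sketch: scale a [0, 1] float image to uint8 and skip
# frames that are completely black, as save_image does before writing.
def to_uint8_or_skip(image: np.ndarray):
    as_bytes = (image * 255).astype(np.uint8)
    if as_bytes.max() == 0:  # black images will not be saved
        return None
    return as_bytes

assert to_uint8_or_skip(np.zeros((2, 2))) is None
assert to_uint8_or_skip(np.full((2, 2), 0.5)).max() == 127
```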