Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Finishing Stage 2 #7

Open
wants to merge 16 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 1 addition & 2 deletions kits19cnn/experiments/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,9 @@
import numpy as np
import torch
import albumentations as albu
from albumentations.pytorch import ToTensorV2
from copy import deepcopy

from kits19cnn.io import CenterCrop
from kits19cnn.io import CenterCrop, ToTensorV2

def get_training_augmentation(augmentation_key="aug1"):
transform_dict = {
Expand Down
6 changes: 2 additions & 4 deletions kits19cnn/inference/__init__.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,5 @@
from .evaluate import Evaluator
from .utils import create_submission, load_weights_infer, \
remove_3D_connected_components
from .predictors import BasePredictor, General3DPredictor, Stage1Predictor
from .ensemble import Ensembler
from .base_predictor import BasePredictor
from .stage1 import Stage1Predictor
from .general_predictors import General3DPredictor
from .evaluate import Evaluator
3 changes: 3 additions & 0 deletions kits19cnn/inference/predictors/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
from .base_predictor import BasePredictor
from .general_predictors import General3DPredictor
from .stage1 import Stage1Predictor
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,8 @@
import inspect
import torch

from kits19cnn.inference import remove_3D_connected_components, BasePredictor
from kits19cnn.inference.utils import remove_3D_connected_components
from kits19cnn.inference.predictors import BasePredictor

class General3DPredictor(BasePredictor):
"""
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,8 @@
import inspect
import torch

from kits19cnn.inference import remove_3D_connected_components, BasePredictor
from kits19cnn.inference.utils import remove_3D_connected_components
from kits19cnn.inference.predictors import BasePredictor
from kits19cnn.io import get_bbox_from_mask, expand_bbox, crop_to_bbox, resize_bbox
from kits19cnn.utils import load_json, save_json

Expand Down
2 changes: 1 addition & 1 deletion kits19cnn/io/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
from .test_dataset import VoxelDataset, TestVoxelDataset
from .preprocess import Preprocessor
from .resample import resample_patient
from .custom_transforms import CenterCrop
from .custom_transforms import CenterCrop, ToTensorV2
from .custom_augmentations import resize_data_and_seg, crop_to_bbox, \
expand_bbox, get_bbox_from_mask, resize_bbox
from .slice_sampler import SliceIDSampler
25 changes: 24 additions & 1 deletion kits19cnn/io/custom_transforms.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,28 @@
from albumentations.core.transforms_interface import DualTransform
from albumentations.core.transforms_interface import BasicTransform, DualTransform
import numpy as np
import torch

class ToTensorV2(BasicTransform):
"""Convert image and mask to `torch.Tensor`."""

def __init__(self, always_apply=True, p=1.0):
super(ToTensorV2, self).__init__(always_apply=always_apply, p=p)

@property
def targets(self):
return {"image": self.apply, "mask": self.apply_to_mask}

def apply(self, img, **params):
return torch.from_numpy(img.transpose(2, 0, 1))

def apply_to_mask(self, mask, **params):
return torch.from_numpy(mask.transpose(2, 0, 1))

def get_transform_init_args_names(self):
return []

def get_params_dependent_on_targets(self, params):
return {}

class CenterCrop(DualTransform):
"""
Expand Down
16 changes: 13 additions & 3 deletions kits19cnn/io/preprocess.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,9 @@ def __init__(self, in_dir, out_dir, cases=None, kits_json_path=None,
if not isdir(out_dir):
os.mkdir(out_dir)
print("Created directory: {0}".format(out_dir))
self.resize_xy_shape = tuple(resize_xy_shape)
self.resize_xy_shape = tuple(resize_xy_shape) \
if isinstance(resize_xy_shape, list) \
else None

def gen_data(self, save_fnames=["imaging", "segmentation"]):
"""
Expand Down Expand Up @@ -324,9 +326,17 @@ def _load_bbox_json(self, json_path):
def crop_case_to_bbox(self, image, label, case):
"""
Crops a 3D image and 3D label to the corresponding bounding box.
Args:
image (np.ndarray): 3D array (no channels)
label (np.ndarray): Same shape as image
case (str): Path to the case (will be processed to the raw case)
"""
bbox_coord = self.bbox_dict[case]
return (crop_to_bbox(image, bbox), crop_to_bbox(label, case))
bbox_coord = self.bbox_dict[Path(case).name]
if label is not None:
return (crop_to_bbox(image, bbox_coord),
crop_to_bbox(label, bbox_coord))
elif label is None:
return (crop_to_bbox(image, bbox_coord), None)

def standardize_per_image(image):
"""
Expand Down
2 changes: 1 addition & 1 deletion kits19cnn/visualize.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
import plotly.graph_objs as go
from plotly.offline import init_notebook_mode, iplot

from catalyst.utils.tensorboard import SummaryItem, SummaryReader
from catalyst.utils.tools.tensorboard import SummaryItem, SummaryReader

print("If you're using a notebook, "
"make sure to run %matplotlib inline beforehand.")
Expand Down
59 changes: 0 additions & 59 deletions script_configs/pred.yml

This file was deleted.

2 changes: 1 addition & 1 deletion script_configs/stage1/pred.yml
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
out_dir: C:\\Users\\jchen\\Desktop\\Datasets\\kits19_predictions_empty #C:\\Users\\Joseph\\Desktop\\kits19_predictions
scale_ratios_json_path: C:\Users\jchen\Active Github Repositories\kits19-2d-reproduce\scale_factors.json
with_masks: True
mode: segmentation
stage: 1
checkpoint_path: C:\\Users\\jchen\\Desktop\\stage1resunet_23epochs_last_full.pth
pseudo_3D: True

Expand Down
32 changes: 32 additions & 0 deletions script_configs/stage2/pred_resnet.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
in_dir: C:\\Users\\jchen\\Desktop\\Datasets\\kits_reproduced_from_raw_cropped #C:\\Users\\Joseph\\Desktop\\kits19_preprocessed\\data
out_dir: C:\\Users\\jchen\\Desktop\\Datasets\\kits19_predictions_cropped #C:\\Users\\Joseph\\Desktop\\kits19_predictions
with_masks: True
stage: 2
checkpoint_path: C:\\Users\\jchen\\Desktop\\stage1resunet_23epochs_last_full.pth
pseudo_3D: True

io_params:
test_size: 0.2
split_seed: 42
batch_size: 1
num_workers: 2
file_ending: .npy # nii.gz

model_params:
model_name: ResNetSeg
ResNetSeg:
input_channels: 5

predict_3D_params:
do_mirroring: True
num_repeats: 1
min_size:
- 256
- 256
batch_size: 1
mirror_axes:
- 0
- 1
regions_class_order: ~ # this is argmax
pseudo3D_slices: 5
all_in_gpu: False
36 changes: 36 additions & 0 deletions script_configs/stage2/pred_resunet2.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
in_dir: C:\\Users\\jchen\\Desktop\\Datasets\\kits_reproduced_from_raw_cropped #C:\\Users\\Joseph\\Desktop\\kits19_preprocessed\\data
out_dir: C:\\Users\\jchen\\Desktop\\Datasets\\kits19_predictions_cropped #C:\\Users\\Joseph\\Desktop\\kits19_predictions
with_masks: True
stage: 2
checkpoint_path: C:\\Users\\jchen\\Desktop\\stage1resunet_23epochs_last_full.pth
pseudo_3D: True

io_params:
test_size: 0.2
split_seed: 42
batch_size: 1
num_workers: 2
file_ending: .npy # nii.gz

model_params:
model_name: ResUNet
ResUNet:
input_channels: 5
base_num_features: 16
num_classes: 3
num_pool: 4
max_num_features: 256

predict_3D_params:
do_mirroring: True
num_repeats: 1
min_size:
- 256
- 256
batch_size: 1
mirror_axes:
- 0
- 1
regions_class_order: ~ # this is argmax
pseudo3D_slices: 5
all_in_gpu: False
5 changes: 3 additions & 2 deletions script_configs/stage2/preprocess.yml
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
preprocessor_params:
in_dir: /content/kits19/data
out_dir: /content/kits_preprocessed
out_dir: /content/kits_preprocessed_cropped
cases: ~
kits_json_path: /content/kits19/data/kits.json
kits_json_path: ~
bbox_json_path: /content/kits19_predictions/bbox_stage1.json
clip_values:
- -30
- 300
Expand Down
6 changes: 6 additions & 0 deletions script_configs/utility/create_stage2_training_labels.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
cases_raw: ~
in_dir: C:\Users\jchen\kits19\data #/content/kits19/data
save_path: C:\Users\jchen\Active Github Repositories\kits19-2d-reproduce\example_generated_files\stage2_actual_bbox.json #/content/kits19_preprocessed/stage2_actual_bbox.json
# in_dir: C:\Users\jchen\Desktop\Datasets\kits19_interpolated\data #/content/kits19/data
# save_path: C:\Users\jchen\Active Github Repositories\kits19-2d-reproduce\example_generated_files\stage2_actual_bbox_interpolated.json #/content/kits19_preprocessed/stage2_actual_bbox.json
# python "C:\Users\jchen\Active Github Repositories\kits19-2d-reproduce\scripts\utility\create_stage2_training_labels.py" --yml_path="C:\Users\jchen\Active Github Repositories\kits19-2d-reproduce\script_configs\utility\create_stage2_training_labels.yml"
20 changes: 12 additions & 8 deletions scripts/predict.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,4 @@
from catalyst.dl.runner import SupervisedRunner

from kits19cnn.inference import Stage1Predictor
from kits19cnn.inference import Stage1Predictor, General3DPredictor
from kits19cnn.experiments import SegmentationInferenceExperiment2D, \
seed_everything

Expand All @@ -20,11 +18,17 @@ def main(config):
exp = SegmentationInferenceExperiment2D(config)

print(f"Seed: {seed}")
pred = Stage1Predictor(out_dir=config["out_dir"],
model=exp.model, test_loader=exp.loaders["test"],
scale_ratios_json_path=config["scale_ratios_json_path"],
pred_3D_params=config["predict_3D_params"],
pseudo_3D=config.get("pseudo_3D"))
if config["stage"] == 1:
pred = Stage1Predictor(out_dir=config["out_dir"],
model=exp.model, test_loader=exp.loaders["test"],
scale_ratios_json_path=config["scale_ratios_json_path"],
pred_3D_params=config["predict_3D_params"],
pseudo_3D=config.get("pseudo_3D"))
elif config["stage"] == 2:
pred = General3DPredictor(out_dir=config["out_dir"],
model=exp.model, test_loader=exp.loaders["test"],
pred_3D_params=config["predict_3D_params"],
pseudo_3D=config.get("pseudo_3D"))
pred.run_3D_predictions()

if __name__ == "__main__":
Expand Down
68 changes: 68 additions & 0 deletions scripts/utility/create_stage2_training_labels.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,68 @@
from tqdm import tqdm
import os
from os.path import join
import nibabel as nib

from kits19cnn.io import get_bbox_from_mask, expand_bbox
from kits19cnn.utils import save_json

def create_bbox_stage1(mask):
"""
Creates the bounding box from mask and expands it to 256 by 464. DOES NOT
RESIZE THE BBOX LIKE IN STAGE 1.
Args:
mask (np.ndarray): 3D Array (no channels)
Returns:
expanded_bbox (list):
[[lb_x, ub_x], [lb_y, ub_y], [lb_z, ub_z]]
where lb -> lower bound coordinate
ub -> upper bound coordinate
"""
bbox = get_bbox_from_mask(mask, outside_value=0)
expanded_bbox = expand_bbox(bbox,
bbox_lengths=[None, 256, 464])
# Changed to 256, 464 because the max y length
# in the uninterpolated dset is 459 and it
# needs to be divisible by 16
return expanded_bbox

def fetch_cases(in_dir):
"""
Creates a list of all available case folders in `in_dir`
"""
cases_raw = [case \
for case in os.listdir(in_dir) \
if case.startswith("case")]
cases_raw = sorted(cases_raw)
assert len(cases_raw) > 0, \
"Please make sure that in_dir refers to the proper directory."
return cases_raw[:210] # past 210 are the test cases with no masks

def main(cases_raw, in_dir, save_path):
"""
Reads all raw masks and creates bbox. This bbox is then expanded to 256 by
256 (z-dim stays the same).
"""
cases_raw = cases_raw if cases_raw is not None else fetch_cases(in_dir)
actual_bbox_dict = {}
for case_raw in tqdm(cases_raw):
mask = nib.load(join(in_dir, case_raw, "segmentation.nii.gz")).get_fdata()
actual_bbox_dict[case_raw] = create_bbox_stage1(mask)
save_json(actual_bbox_dict, save_path)

if __name__ == "__main__":
import yaml
import argparse

parser = argparse.ArgumentParser(description="For prediction.")
parser.add_argument("--yml_path", type=str, required=True,
help="Path to the .yml config.")
args = parser.parse_args()

with open(args.yml_path, 'r') as stream:
try:
config = yaml.safe_load(stream)
except yaml.YAMLError as exc:
print(exc)

main(**config)