168 changes: 74 additions & 94 deletions src/imgtools/io/nnunet_output.py
@@ -4,6 +4,8 @@
from enum import Enum
from pathlib import Path
from typing import Any, Dict, Sequence
import tempfile
import shutil

import pandas as pd
from pydantic import (
@@ -26,6 +28,7 @@
generate_dataset_json,
generate_nnunet_scripts,
)
from imgtools.io.sample_output import FailedToSaveSingleImageError, AnnotatedPathSequence

__all__ = ["nnUNetOutput", "MaskSavingStrategy"]

@@ -318,7 +321,7 @@ def __call__(
/,
SampleNumber: str, # noqa: N803
**kwargs: object, # noqa: N803
) -> Sequence[Path]:
) -> AnnotatedPathSequence:
"""
Save the data to files using the configured writer.

@@ -331,104 +334,81 @@

Returns
-------
List[Path]
List of paths to the saved files.
AnnotatedPathSequence
List of paths to the saved files, annotated with any errors that occurred.
"""

valid_masks = self._get_valid_masks(data, SampleNumber)
selected_mask = valid_masks[0] # Select the first valid mask

saved_files = []
save_errors = []
temp_dir = None

try:
temp_dir = Path(tempfile.mkdtemp(prefix=".tmp_nnunet_", dir=self.writer.root_directory))

temp_writer = NIFTIWriter(
root_directory=temp_dir,
filename_format=self._file_name_format,
existing_file_mode=ExistingFileMode.FAIL,
compression_level=self.writer.compression_level,
context={**self.writer.context, **kwargs}
)

match self.mask_saving_strategy:
case MaskSavingStrategy.LABEL_IMAGE:
mask = selected_mask.to_label_image()
case MaskSavingStrategy.SPARSE_MASK:
mask = selected_mask.to_sparse_mask()
case MaskSavingStrategy.REGION_MASK:
mask = selected_mask.to_region_mask()
case _:
msg = f"Unknown mask saving strategy: {self.mask_saving_strategy}"
raise MaskSavingStrategyError(msg)

roi_match_data = {
f"roi_matches.{rmap.roi_key}": "|".join(rmap.roi_names)
for rmap in selected_mask.roi_mapping.values()
}

p = self.writer.save(
mask,
DirType="labels",
SplitType="Tr",
SampleID=SampleNumber,
Dataset=self.dataset_name,
**roi_match_data,
**selected_mask.metadata,
**kwargs,
)
saved_files.append(p)

for image in data:
if isinstance(image, Scan):
# Handle Scan case(CT or MR)
p = self.writer.save(
image,
DirType="images",
SplitType="Tr",
SampleID=f"{SampleNumber}_{MODALITY_MAP[image.metadata['Modality']]}",
Dataset=self.dataset_name,
**image.metadata,
**kwargs,
)
saved_files.append(p)
elif isinstance(image, VectorMask):
pass
else:
errmsg = (
f"Unsupported image type: {type(image)}. "
"Expected Scan or VectorMask."
)
logger.error(errmsg)
raise TypeError(errmsg)

return saved_files


if __name__ == "__main__":
from imgtools.io.sample_input import SampleInput

input_directory = "./data/RADCURE"
output_directory = Path("./temp_outputs")
output_directory.mkdir(exist_ok=True)
modalities = ["CT", "RTSTRUCT"]
roi_match_map = {
"BRAIN": ["Brain"],
"BRAINSTEM": ["Brainstem"],
}

input = SampleInput.build( # noqa: A001
directory=Path(input_directory),
modalities=modalities,
roi_match_map=roi_match_map,
)
output = nnUNetOutput(
directory=output_directory,
existing_file_mode=ExistingFileMode.OVERWRITE,
dataset_name="RADCURE",
extra_context={},
roi_keys=list(roi_match_map.keys()),
mask_saving_strategy=MaskSavingStrategy.REGION_MASK,
)
files_to_commit = {}

samples = input.query()
valid_masks = self._get_valid_masks(data, SampleNumber)
selected_mask = valid_masks[0]

for idx, sample in enumerate(samples, start=1):
loaded_sample = input(sample)
# Stage 1: Save all files for the sample to the temporary directory
mask = {
MaskSavingStrategy.LABEL_IMAGE: selected_mask.to_label_image,
MaskSavingStrategy.SPARSE_MASK: selected_mask.to_sparse_mask,
MaskSavingStrategy.REGION_MASK: selected_mask.to_region_mask,
}.get(self.mask_saving_strategy)()

with contextlib.suppress(Exception):
output(loaded_sample, SampleNumber=f"{idx:03}")
if mask is None:
raise MaskSavingStrategyError(f"Unknown mask saving strategy: {self.mask_saving_strategy}")
Comment on lines +367 to +368 (Contributor)

⚠️ Potential issue

Dead code: mask will never be None

The dictionary lookup on lines 361-365 will always return a callable method (to_label_image, to_sparse_mask, or to_region_mask) because self.mask_saving_strategy is validated as an enum. The .get() call with no default will return None only for invalid enum values, which can't happen. This check is unreachable.

-            if mask is None:
-                raise MaskSavingStrategyError(f"Unknown mask saving strategy: {self.mask_saving_strategy}")
+            # No need to check for None - enum validation ensures valid strategy

Committable suggestion skipped: line range outside the PR's diff.

🤖 Prompt for AI Agents
In src/imgtools/io/nnunet_output.py around lines 367-368, the if-check raising
MaskSavingStrategyError when mask is None is dead code because the dict lookup
for mask saving strategy always returns a callable for validated enum values;
remove this unreachable if-block and its raise, and if you need a defensive
guarantee for static analysis keep a short assert like "assert callable(mask)"
immediately after the lookup (or use typing.cast) instead of the conditional
raise.
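
A hedged sketch of the reviewer's suggestion, not the project's code: index the converter dict directly so that an unknown value raises KeyError on its own, and keep at most an assert for static analysis. The converter method names come from the diff above; the enum values and the wrapper function are placeholders for illustration.

```python
from enum import Enum


class MaskSavingStrategy(str, Enum):
    # Placeholder values; the real enum's members may differ.
    LABEL_IMAGE = "label_image"
    SPARSE_MASK = "sparse_mask"
    REGION_MASK = "region_mask"


def convert_mask(strategy: MaskSavingStrategy, selected_mask):
    # Direct indexing: an unknown strategy raises KeyError immediately,
    # so the unreachable `if mask is None` branch is not needed.
    converters = {
        MaskSavingStrategy.LABEL_IMAGE: selected_mask.to_label_image,
        MaskSavingStrategy.SPARSE_MASK: selected_mask.to_sparse_mask,
        MaskSavingStrategy.REGION_MASK: selected_mask.to_region_mask,
    }
    converter = converters[strategy]
    assert callable(converter)  # defensive guard for static analysis only
    return converter()
```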


if idx == 5:
break
roi_match_data = {
f"roi_matches.{rmap.roi_key}": "|".join(rmap.roi_names)
for rmap in selected_mask.roi_mapping.values()
}

output.finalize_dataset()
mask_context = {
"DirType": "labels", "SplitType": "Tr", "SampleID": SampleNumber,
"Dataset": self.dataset_name, **roi_match_data, **selected_mask.metadata, **kwargs
}
temp_mask_path = temp_writer.save(mask, **mask_context)
final_mask_path = self.writer.resolve_path(**mask_context)
files_to_commit[temp_mask_path] = final_mask_path

for image in data:
if isinstance(image, Scan):
image_context = {
"DirType": "images", "SplitType": "Tr",
"SampleID": f"{SampleNumber}_{MODALITY_MAP[image.metadata['Modality']]}",
"Dataset": self.dataset_name, **image.metadata, **kwargs
}
temp_image_path = temp_writer.save(image, **image_context)
final_image_path = self.writer.resolve_path(**image_context)
files_to_commit[temp_image_path] = final_image_path
elif not isinstance(image, VectorMask):
raise TypeError(f"Unsupported image type: {type(image)}. Expected Scan or VectorMask.")

# Stage 2: Commit files by moving them to their final destination
for temp_path, final_path in files_to_commit.items():
final_path.parent.mkdir(parents=True, exist_ok=True)
shutil.move(str(temp_path), str(final_path))
saved_files.append(final_path)

except Exception as e:
errmsg = f"Failed to save nnUNet sample atomically: {e}"
image_context = data[0] if data else None
save_error = FailedToSaveSingleImageError(errmsg, image_context)
save_errors.append(save_error)
logger.error(errmsg, error=save_error)

finally:
# Stage 3: cleanup the temporary directory
if temp_dir and temp_dir.exists():
shutil.rmtree(temp_dir)

return AnnotatedPathSequence(saved_files, save_errors)
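
For readers skimming the change: the new `__call__` stages every file in a hidden temporary directory under the output root and only moves files into place once the whole sample has been written, so a failure part-way through leaves no partial output behind. Below is a minimal, self-contained sketch of that write-then-move idea under stated assumptions (generic byte payloads and a made-up helper name, not the imgtools writer API).

```python
import shutil
import tempfile
from pathlib import Path


def save_atomically(payloads: dict[Path, bytes], root: Path) -> list[Path]:
    """Write every payload into a staging dir under `root`, then move the
    files into place only after all writes succeed (hypothetical helper)."""
    saved: list[Path] = []
    temp_dir = Path(tempfile.mkdtemp(prefix=".tmp_commit_", dir=root))
    try:
        staged: dict[Path, Path] = {}
        # Stage 1: write everything into the staging directory.
        for i, (final_path, blob) in enumerate(payloads.items()):
            temp_path = temp_dir / f"{i}_{final_path.name}"
            temp_path.write_bytes(blob)
            staged[temp_path] = final_path
        # Stage 2: commit by moving staged files to their final destinations.
        # Staging under `root` keeps the move on one filesystem (a cheap rename).
        for temp_path, final_path in staged.items():
            final_path.parent.mkdir(parents=True, exist_ok=True)
            shutil.move(str(temp_path), str(final_path))
            saved.append(final_path)
    finally:
        # Stage 3: always remove the staging directory.
        shutil.rmtree(temp_dir, ignore_errors=True)
    return saved
```

Strictly speaking the commit loop itself is not atomic across files: a crash during Stage 2 can still leave a partially committed sample, which matches the behaviour of the diff above, where any committed paths are kept and the error is recorded.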
90 changes: 56 additions & 34 deletions src/imgtools/io/sample_output.py
Expand Up @@ -2,6 +2,8 @@

from pathlib import Path
from typing import Any, Dict, List, Sequence
import tempfile
import shutil

from pydantic import (
BaseModel,
@@ -186,48 +188,68 @@ def __call__(
"""
saved_files = []
save_errors = []
for image in data:
try:
temp_dir = None

try:
# Create a unique temporary directory within the root output directory for this transaction
temp_dir = Path(tempfile.mkdtemp(prefix=".tmp_sample_", dir=self.writer.root_directory))

# Create a temporary writer configured to save into this new directory
temp_writer = NIFTIWriter(
root_directory=temp_dir,
filename_format=self.filename_format,
existing_file_mode=ExistingFileMode.FAIL, # Should not have conflicts in a new temp dir
compression_level=self.writer.compression_level,
context={**self.writer.context, **kwargs}
)

files_to_commit = {}

# Stage 1: Write all files to the temporary directory
for image in data:
if isinstance(image, VectorMask):
for (
_i,
roi_key,
roi_names,
image_id,
mask,
) in image.iter_masks():
# image_id = f"{roi_key}_[{matched_rois}]"
p = self.writer.save(
mask,
roi_key=roi_key,
matched_rois="|".join(roi_names),
for (_i, roi_key, roi_names, image_id, mask) in image.iter_masks():
context = {
"roi_key": roi_key,
"matched_rois": "|".join(roi_names),
"ImageID": image_id,
**image.metadata,
**kwargs,
ImageID=image_id,
)
saved_files.append(p)
}
# Save to temp and resolve final path
temp_path = temp_writer.save(mask, **context)
final_path = self.writer.resolve_path(**context)
files_to_commit[temp_path] = final_path
elif isinstance(image, MedImage):
# Handle MedImage case
p = self.writer.save(
image,
context = {
"ImageID": image.metadata.get("Modality", "Unknown"),
**image.metadata,
**kwargs,
ImageID=image.metadata["Modality"],
)
saved_files.append(p)
}
# Save to temp and resolve final path
temp_path = temp_writer.save(image, **context)
final_path = self.writer.resolve_path(**context)
files_to_commit[temp_path] = final_path
else:
errmsg = (
f"Unsupported image type: {type(image)}. "
"Expected MedImage or VectorMask."
)
logger.error(errmsg)
errmsg = f"Unsupported image type: {type(image)}. Expected MedImage or VectorMask."
raise TypeError(errmsg)
except Exception as e:
errmsg = f"Failed to save image SeriesUID: {image.metadata['SeriesInstanceUID']}: {e}"

# create an error object
save_error = FailedToSaveSingleImageError(errmsg, image)
save_errors.append(save_error)
logger.error(errmsg, error=save_error)
# Stage 2: Commit files by moving them from temp to final destination
for temp_path, final_path in files_to_commit.items():
final_path.parent.mkdir(parents=True, exist_ok=True)
shutil.move(str(temp_path), str(final_path))
saved_files.append(final_path)

except Exception as e:
errmsg = f"Failed to save sample atomically: {e}"
image_context = data[0] if data else None
save_error = FailedToSaveSingleImageError(errmsg, image_context)
save_errors.append(save_error)
logger.error(errmsg, error=save_error)

finally:
# Stage 3: Cleanup the temporary directory
if temp_dir and temp_dir.exists():
shutil.rmtree(temp_dir)

return AnnotatedPathSequence(saved_files, save_errors)
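
Both methods now return `AnnotatedPathSequence(saved_files, save_errors)` instead of a plain list. Its exact interface is not shown in this diff; the sketch below is a hypothetical list-like stand-in, included only to illustrate how a caller could keep treating the result as a sequence of paths while still inspecting the collected errors. The class name, the `.errors` attribute, and the usage are assumptions, not the library's API.

```python
from pathlib import Path
from typing import List, Optional


class AnnotatedPaths(List[Path]):
    """Hypothetical stand-in for AnnotatedPathSequence: behaves like a list
    of saved paths but also carries the errors collected while saving."""

    def __init__(self, paths: List[Path], errors: Optional[List[Exception]] = None) -> None:
        super().__init__(paths)
        self.errors = list(errors or [])


# Caller side: iterate the paths as usual, then check for failures.
result = AnnotatedPaths([Path("a.nii.gz")], errors=[RuntimeError("mask failed")])
for p in result:
    print(p)
if result.errors:
    print(f"{len(result.errors)} file(s) failed to save")
```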