Skip to content

Fix Invertd confusion with postprocessing transforms #8526

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 26 additions & 2 deletions monai/transforms/post/dictionary.py
Original file line number Diff line number Diff line change
Expand Up @@ -687,9 +687,33 @@ def __call__(self, data: Mapping[Hashable, Any]) -> dict[Hashable, Any]:

orig_meta_key = orig_meta_key or f"{orig_key}_{meta_key_postfix}"
if orig_key in d and isinstance(d[orig_key], MetaTensor):
transform_info = d[orig_key].applied_operations
all_transforms = d[orig_key].applied_operations
meta_info = d[orig_key].meta
else:

# If orig_key == key, the data at d[orig_key] may have been modified by
# postprocessing transforms added after the preprocessing pipeline
# completed; exclude those transforms so inversion only sees the
# preprocessing ones (see issue #8396).
if orig_key == key:
num_preproc_transforms = 0
try:
if hasattr(self.transform, 'transforms'):
for t in self.transform.flatten().transforms:
if isinstance(t, InvertibleTransform):
num_preproc_transforms += 1
elif isinstance(self.transform, InvertibleTransform):
num_preproc_transforms = 1
except AttributeError:
# Fallback: use all transforms if flatten fails
num_preproc_transforms = len(all_transforms)

if num_preproc_transforms > 0:
transform_info = all_transforms[:num_preproc_transforms]
else:
transform_info = all_transforms
else:
transform_info = all_transforms
transform_info = d[InvertibleTransform.trace_key(orig_key)]
meta_info = d.get(orig_meta_key, {})
if nearest_interp:
Expand Down
43 changes: 43 additions & 0 deletions tests/transforms/inverse/test_invertd.py
Original file line number Diff line number Diff line change
Expand Up @@ -137,6 +137,49 @@ def test_invert(self):

set_determinism(seed=None)

def test_invert_with_postproc_lambdad(self):
    """Invertd must work when postprocessing contains invertible transforms.

    Regression test for the case where a ``Lambdad`` placed *before*
    ``Invertd`` in the postprocessing pipeline previously confused
    ``Invertd`` into considering it part of the transforms to invert.
    """
    import os

    from monai.transforms import Lambdad

    set_determinism(seed=0)

    # Create test images on disk; make_nifti_image returns the file paths,
    # which must be removed again when the test finishes.
    im_fname, seg_fname = (make_nifti_image(i) for i in create_test_image_3d(101, 100, 107, noise_max=100))

    try:
        # Preprocessing pipeline whose spatial transforms are to be inverted.
        preproc = Compose(
            [
                LoadImaged(KEYS, image_only=True),
                EnsureChannelFirstd(KEYS),
                Spacingd(KEYS, pixdim=(1.2, 1.01, 0.9), mode=["bilinear", "nearest"], dtype=np.float32),
                ScaleIntensityd("image", minv=1, maxv=10),
                ResizeWithPadOrCropd(KEYS, 100),
            ]
        )

        # Postprocessing with a Lambdad before Invertd (the problematic case):
        # Invertd should ignore it and only invert the preprocessing transforms.
        postproc = Compose(
            [
                Lambdad(["pred"], lambda x: x),  # identity transform; must not interfere
                Invertd(["pred"], preproc, orig_keys=["image"], nearest_interp=True),
            ]
        )

        # Apply preprocessing, then fabricate a prediction from the image.
        data = {"image": im_fname, "label": seg_fname}
        preprocessed = preproc(data)
        preprocessed["pred"] = preprocessed["image"].clone()

        # Previously this raised; with the fix it must run cleanly.
        result = postproc(preprocessed)

        self.assertIn("pred", result)
        # The spatial shape must be restored to the original image size.
        self.assertTupleEqual(result["pred"].shape[1:], (101, 100, 107))
    finally:
        # Restore global determinism and remove the temporary image files
        # even when an assertion above fails, so no state leaks into
        # subsequent tests.
        set_determinism(seed=None)
        for fname in (im_fname, seg_fname):
            if os.path.exists(fname):
                os.remove(fname)


# Allow running this test module directly (e.g. ``python test_invertd.py``).
if __name__ == "__main__":
    unittest.main()
Loading