diff --git a/monai/transforms/post/dictionary.py b/monai/transforms/post/dictionary.py
index 7e1e074f71..201d5dcb98 100644
--- a/monai/transforms/post/dictionary.py
+++ b/monai/transforms/post/dictionary.py
@@ -687,8 +687,32 @@ def __call__(self, data: Mapping[Hashable, Any]) -> dict[Hashable, Any]:
             orig_meta_key = orig_meta_key or f"{orig_key}_{meta_key_postfix}"
             if orig_key in d and isinstance(d[orig_key], MetaTensor):
-                transform_info = d[orig_key].applied_operations
+                all_transforms = d[orig_key].applied_operations
                 meta_info = d[orig_key].meta
+
+                # When orig_key == key, the data at d[orig_key] may already have been
+                # modified by earlier postprocessing transforms, which append their own
+                # entries to applied_operations. Keep only the operations recorded by
+                # the preprocessing pipeline so Invertd does not try to invert
+                # postprocessing transforms such as Lambdad (see issue #8396).
+                if orig_key == key:
+                    num_preproc_transforms = 0
+                    try:
+                        if hasattr(self.transform, "transforms"):
+                            for t in self.transform.flatten().transforms:
+                                if isinstance(t, InvertibleTransform):
+                                    num_preproc_transforms += 1
+                        elif isinstance(self.transform, InvertibleTransform):
+                            num_preproc_transforms = 1
+                    except AttributeError:
+                        # Fallback: keep every recorded transform if flatten() is unavailable
+                        num_preproc_transforms = len(all_transforms)
+
+                    if num_preproc_transforms > 0:
+                        transform_info = all_transforms[:num_preproc_transforms]
+                    else:
+                        transform_info = all_transforms
+                else:
+                    transform_info = all_transforms
             else:
                 transform_info = d[InvertibleTransform.trace_key(orig_key)]
                 meta_info = d.get(orig_meta_key, {})
             if nearest_interp:
diff --git a/tests/transforms/inverse/test_invertd.py b/tests/transforms/inverse/test_invertd.py
index 2b5e9da85d..e80f105f6a 100644
--- a/tests/transforms/inverse/test_invertd.py
+++ b/tests/transforms/inverse/test_invertd.py
@@ -137,5 +137,50 @@ def test_invert(self):
         set_determinism(seed=None)
 
+    def test_invert_with_postproc_lambdad(self):
+        """Test that Invertd only inverts the preprocessing transforms when an
+        invertible transform such as Lambdad runs in postprocessing before Invertd."""
+        from monai.transforms import Lambdad  # local import: keeps this patch confined to the new test
+
+        set_determinism(seed=0)
+
+        # Create test images
+        im_fname, seg_fname = (make_nifti_image(i) for i in create_test_image_3d(101, 100, 107, noise_max=100))
+
+        # Define preprocessing transforms
+        preproc = Compose(
+            [
+                LoadImaged(KEYS, image_only=True),
+                EnsureChannelFirstd(KEYS),
+                Spacingd(KEYS, pixdim=(1.2, 1.01, 0.9), mode=["bilinear", "nearest"], dtype=np.float32),
+                ScaleIntensityd("image", minv=1, maxv=10),
+                ResizeWithPadOrCropd(KEYS, 100),
+            ]
+        )
+
+        # Postprocessing with an invertible Lambdad before Invertd (the problematic case)
+        postproc = Compose(
+            [
+                # This identity Lambdad must not interfere with Invertd
+                Lambdad(["pred"], lambda x: x),
+                # Invertd should only invert the preprocessing transforms
+                Invertd(["pred"], preproc, orig_keys=["image"], nearest_interp=True),
+            ]
+        )
+
+        # Apply preprocessing
+        data = {"image": im_fname, "label": seg_fname}
+        preprocessed = preproc(data)
+
+        # Create prediction (copy of the preprocessed image)
+        preprocessed["pred"] = preprocessed["image"].clone()
+
+        # Running postprocessing must not raise, and the inversion must restore the original shape
+        result = postproc(preprocessed)
+        self.assertIn("pred", result)
+        self.assertTupleEqual(result["pred"].shape[1:], (101, 100, 107))
+
+        set_determinism(seed=None)
+
 
 if __name__ == "__main__":
     unittest.main()