
Commit 66d0478

6136 6146 update the default writer flag (#6147)
Fixes #6136, fixes #6146

### Types of changes

- [x] Non-breaking change (fix or new feature that would not break existing functionality).
- [ ] Breaking change (fix or new feature that would cause existing functionality to change).
- [x] New tests added to cover the changes.
- [ ] Integration tests passed locally by running `./runtests.sh -f -u --net --coverage`.
- [x] Quick tests passed locally by running `./runtests.sh --quick --unittests --disttests`.
- [x] In-line docstrings updated.
- [ ] Documentation updated, tested `make html` command in the `docs/` folder.

Signed-off-by: Wenqi Li <[email protected]>
1 parent 678b512 commit 66d0478

File tree: 10 files changed, 99 additions and 33 deletions


monai/data/image_writer.py

Lines changed: 16 additions & 8 deletions
@@ -373,13 +373,14 @@ class ITKWriter(ImageWriter):
    output_dtype: DtypeLike = None
    channel_dim: int | None

-    def __init__(self, output_dtype: DtypeLike = np.float32, affine_lps_to_ras: bool = True, **kwargs):
+    def __init__(self, output_dtype: DtypeLike = np.float32, affine_lps_to_ras: bool | None = True, **kwargs):
        """
        Args:
            output_dtype: output data type.
            affine_lps_to_ras: whether to convert the affine matrix from "LPS" to "RAS". Defaults to ``True``.
                Set to ``True`` to be consistent with ``NibabelWriter``,
                otherwise the affine matrix is assumed already in the ITK convention.
+                Set to ``None`` to use ``data_array.meta[MetaKeys.SPACE]`` to determine the flag.
            kwargs: keyword arguments passed to ``ImageWriter``.

        The constructor will create ``self.output_dtype`` internally.
@@ -406,17 +407,20 @@ def set_data_array(
            kwargs: keyword arguments passed to ``self.convert_to_channel_last``,
                currently support ``spatial_ndim`` and ``contiguous``, defauting to ``3`` and ``False`` respectively.
        """
-        _r = len(data_array.shape)
+        n_chns = data_array.shape[channel_dim] if channel_dim is not None else 0
        self.data_obj = self.convert_to_channel_last(
            data=data_array,
            channel_dim=channel_dim,
            squeeze_end_dims=squeeze_end_dims,
            spatial_ndim=kwargs.pop("spatial_ndim", 3),
            contiguous=kwargs.pop("contiguous", True),
        )
-        self.channel_dim = (
-            channel_dim if self.data_obj is not None and len(self.data_obj.shape) >= _r else None
-        )  # channel dim is at the end
+        self.channel_dim = -1  # in most cases, the data is set to channel last
+        if squeeze_end_dims and n_chns <= 1:  # num_channel==1 squeezed
+            self.channel_dim = None
+        if not squeeze_end_dims and n_chns < 1:  # originally no channel and convert_to_channel_last added a channel
+            self.channel_dim = None
+            self.data_obj = self.data_obj[..., 0]

    def set_metadata(self, meta_dict: Mapping | None = None, resample: bool = True, **options):
        """
@@ -478,7 +482,7 @@ def create_backend_obj(
        channel_dim: int | None = 0,
        affine: NdarrayOrTensor | None = None,
        dtype: DtypeLike = np.float32,
-        affine_lps_to_ras: bool = True,
+        affine_lps_to_ras: bool | None = True,
        **kwargs,
    ):
        """
@@ -492,14 +496,18 @@ def create_backend_obj(
            affine_lps_to_ras: whether to convert the affine matrix from "LPS" to "RAS". Defaults to ``True``.
                Set to ``True`` to be consistent with ``NibabelWriter``,
                otherwise the affine matrix is assumed already in the ITK convention.
+                Set to ``None`` to use ``data_array.meta[MetaKeys.SPACE]`` to determine the flag.
            kwargs: keyword arguments. Current `itk.GetImageFromArray` will read ``ttype`` from this dictionary.

        see also:

            - https://github.com/InsightSoftwareConsortium/ITK/blob/v5.2.1/Wrapping/Generators/Python/itk/support/extras.py#L389
+
        """
-        if isinstance(data_array, MetaTensor) and data_array.meta.get(MetaKeys.SPACE, SpaceKeys.LPS) != SpaceKeys.LPS:
-            affine_lps_to_ras = False  # do the converting from LPS to RAS only if the space type is currently LPS.
+        if isinstance(data_array, MetaTensor) and affine_lps_to_ras is None:
+            affine_lps_to_ras = (
+                data_array.meta.get(MetaKeys.SPACE, SpaceKeys.LPS) != SpaceKeys.LPS
+            )  # do the converting from LPS to RAS only if the space type is currently LPS.
        data_array = super().create_backend_obj(data_array)
        _is_vec = channel_dim is not None
        if _is_vec:
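To illustrate the new tri-state flag, here is a minimal, hypothetical usage sketch (not part of the commit): with `affine_lps_to_ras=None`, `ITKWriter` derives the conversion flag from `data_array.meta[MetaKeys.SPACE]` instead of always converting. The array values, the `SpaceKeys.RAS` metadata, and the output file name below are illustrative, and writing requires the optional `itk` dependency.

```python
import os
import tempfile

import numpy as np
import torch
from monai.data import ITKWriter, MetaTensor
from monai.utils import MetaKeys, SpaceKeys

# Illustrative volume whose metadata records the coordinate space of its affine.
img = MetaTensor(
    torch.rand(1, 8, 8, 8),
    affine=torch.eye(4),
    meta={MetaKeys.SPACE: SpaceKeys.RAS},  # hypothetical metadata for this sketch
)

with tempfile.TemporaryDirectory() as tempdir:
    # affine_lps_to_ras=None: the flag is resolved from img.meta[MetaKeys.SPACE]
    # rather than unconditionally applying the LPS-to-RAS conversion (the old default).
    writer = ITKWriter(output_dtype=np.float32, affine_lps_to_ras=None)
    writer.set_data_array(img, channel_dim=0)
    writer.set_metadata(img.meta)
    writer.write(os.path.join(tempdir, "example.nii.gz"))
```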

monai/networks/layers/filtering.py

Lines changed: 0 additions & 19 deletions
@@ -32,13 +32,10 @@ class BilateralFilter(torch.autograd.Function):

    Args:
        input: input tensor.
-
        spatial_sigma: the standard deviation of the spatial blur. Higher values can
            hurt performance when not using the approximate method (see fast approx).
-
        color_sigma: the standard deviation of the color blur. Lower values preserve
            edges better whilst higher values tend to a simple gaussian spatial blur.
-
        fast approx: This flag chooses between two implementations. The approximate method may
            produce artifacts in some scenarios whereas the exact solution may be intolerably
            slow for high spatial standard deviations.
@@ -76,9 +73,7 @@ class PHLFilter(torch.autograd.Function):

    Args:
        input: input tensor to be filtered.
-
        features: feature tensor used to filter the input.
-
        sigmas: the standard deviations of each feature in the filter.

    Returns:
@@ -114,13 +109,9 @@ class TrainableBilateralFilterFunction(torch.autograd.Function):

    Args:
        input: input tensor to be filtered.
-
        sigma x: trainable standard deviation of the spatial filter kernel in x direction.
-
        sigma y: trainable standard deviation of the spatial filter kernel in y direction.
-
        sigma z: trainable standard deviation of the spatial filter kernel in z direction.
-
        color sigma: trainable standard deviation of the intensity range kernel. This filter
            parameter determines the degree of edge preservation.

@@ -200,11 +191,9 @@ class TrainableBilateralFilter(torch.nn.Module):

    Args:
        input: input tensor to be filtered.
-
        spatial_sigma: tuple (sigma_x, sigma_y, sigma_z) initializing the trainable standard
            deviations of the spatial filter kernels. Tuple length must equal the number of
            spatial input dimensions.
-
        color_sigma: trainable standard deviation of the intensity range kernel. This filter
            parameter determines the degree of edge preservation.

@@ -280,15 +269,10 @@ class TrainableJointBilateralFilterFunction(torch.autograd.Function):

    Args:
        input: input tensor to be filtered.
-
        guide: guidance image tensor to be used during filtering.
-
        sigma x: trainable standard deviation of the spatial filter kernel in x direction.
-
        sigma y: trainable standard deviation of the spatial filter kernel in y direction.
-
        sigma z: trainable standard deviation of the spatial filter kernel in z direction.
-
        color sigma: trainable standard deviation of the intensity range kernel. This filter
            parameter determines the degree of edge preservation.

@@ -373,13 +357,10 @@ class TrainableJointBilateralFilter(torch.nn.Module):

    Args:
        input: input tensor to be filtered.
-
        guide: guidance image tensor to be used during filtering.
-
        spatial_sigma: tuple (sigma_x, sigma_y, sigma_z) initializing the trainable standard
            deviations of the spatial filter kernels. Tuple length must equal the number of
            spatial input dimensions.
-
        color_sigma: trainable standard deviation of the intensity range kernel. This filter
            parameter determines the degree of edge preservation.

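The change above only tidies the docstrings; as a reminder of the API they describe, a small hedged sketch of the non-trainable filter follows. `BilateralFilter` dispatches to MONAI's compiled C++/CUDA extensions, so it only runs where those extensions are available; the sigma values below are illustrative.

```python
import torch
from monai.networks.layers import BilateralFilter

# Input is a (batch, channel, spatial...) tensor; sigma values are illustrative.
x = torch.rand(1, 1, 16, 16)
# Positional arguments: input, spatial_sigma, color_sigma, fast_approx.
smoothed = BilateralFilter.apply(x, 3.0, 0.2, True)
print(smoothed.shape)  # same shape as the input
```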

monai/transforms/io/array.py

Lines changed: 11 additions & 2 deletions
@@ -40,6 +40,7 @@
    PydicomReader,
)
from monai.data.meta_tensor import MetaTensor
+from monai.data.utils import is_no_channel
from monai.transforms.transform import Transform
from monai.transforms.utility.array import EnsureChannelFirst
from monai.utils import GridSamplePadMode
@@ -440,6 +441,7 @@ def set_options(self, init_kwargs=None, data_kwargs=None, meta_kwargs=None, writ
            self.meta_kwargs.update(meta_kwargs)
        if write_kwargs is not None:
            self.write_kwargs.update(write_kwargs)
+        return self

    def __call__(self, img: torch.Tensor | np.ndarray, meta_data: dict | None = None):
        """
@@ -450,8 +452,15 @@ def __call__(self, img: torch.Tensor | np.ndarray, meta_data: dict | None = None
        meta_data = img.meta if isinstance(img, MetaTensor) else meta_data
        kw = self.fname_formatter(meta_data, self)
        filename = self.folder_layout.filename(**kw)
-        if meta_data and len(ensure_tuple(meta_data.get("spatial_shape", ()))) == len(img.shape):
-            self.data_kwargs["channel_dim"] = None
+        if meta_data:
+            meta_spatial_shape = ensure_tuple(meta_data.get("spatial_shape", ()))
+            if len(meta_spatial_shape) >= len(img.shape):
+                self.data_kwargs["channel_dim"] = None
+            elif is_no_channel(self.data_kwargs.get("channel_dim")):
+                warnings.warn(
+                    f"data shape {img.shape} (with spatial shape {meta_spatial_shape}) "
+                    f"but SaveImage `channel_dim` is set to {self.data_kwargs.get('channel_dim')} no channel."
+                )

        err = []
        for writer_cls in self.writers:
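Because `set_options` now returns `self`, a `SaveImage` transform can be configured inline; below is a minimal sketch, assuming an illustrative output directory and a random input, and one of the optional writer backends (nibabel or ITK) installed. The new `elif` branch also means that if `channel_dim` was forced to "no channel" via `data_kwargs` while the data still appears to carry a channel dimension, the transform warns instead of failing silently.

```python
import torch
from monai.transforms import SaveImage

# set_options returns self, so configuration can be chained onto the constructor.
save = SaveImage(
    output_dir="./out",  # illustrative output directory
    output_ext=".nii.gz",
    separate_folder=False,
).set_options(init_kwargs={"affine_lps_to_ras": None})

# A channel-first image with no metadata: the file name falls back to the data index
# and the default output_postfix, e.g. ./out/0_trans.nii.gz.
save(torch.rand(1, 8, 8, 8))
```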

monai/transforms/io/dictionary.py

Lines changed: 1 addition & 0 deletions
@@ -296,6 +296,7 @@ def __init__(

    def set_options(self, init_kwargs=None, data_kwargs=None, meta_kwargs=None, write_kwargs=None):
        self.saver.set_options(init_kwargs, data_kwargs, meta_kwargs, write_kwargs)
+        return self

    def __call__(self, data):
        d = dict(data)
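The dictionary wrapper mirrors this, so the same chaining works on `SaveImaged`, for example when building a pipeline; a short sketch with an illustrative key and output directory:

```python
from monai.transforms import Compose, LoadImaged, SaveImaged

# Because SaveImaged.set_options also returns self, the configured transform can be
# placed directly inside a Compose pipeline (keys and output_dir are illustrative).
pipeline = Compose(
    [
        LoadImaged(keys="image"),
        SaveImaged(keys="image", output_dir="./out", separate_folder=False).set_options(
            init_kwargs={"affine_lps_to_ras": None}
        ),
    ]
)
```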

tests/test_auto3dseg_ensemble.py

Lines changed: 1 addition & 1 deletion
@@ -70,7 +70,7 @@


@skip_if_quick
-@SkipIfBeforePyTorchVersion((1, 10, 0))
+@SkipIfBeforePyTorchVersion((1, 13, 0))
@unittest.skipIf(not has_tb, "no tensorboard summary writer")
class TestEnsembleBuilder(unittest.TestCase):
    def setUp(self) -> None:

tests/test_image_rw.py

Lines changed: 1 addition & 1 deletion
@@ -167,7 +167,7 @@ def nrrd_rw(self, test_data, reader, writer, dtype, resample=True):
        filepath = f"testfile_{ndim}d"
        saver = SaveImage(
            output_dir=self.test_dir, output_ext=output_ext, resample=resample, separate_folder=False, writer=writer
-        )
+        ).set_options(init_kwargs={"affine_lps_to_ras": True})
        test_data = MetaTensor(
            p(test_data), meta={"filename_or_obj": f"{filepath}{output_ext}", "spatial_shape": test_data.shape}
        )

tests/test_integration_autorunner.py

Lines changed: 1 addition & 1 deletion
@@ -69,7 +69,7 @@


@skip_if_quick
-@SkipIfBeforePyTorchVersion((1, 9, 1))
+@SkipIfBeforePyTorchVersion((1, 13, 0))
@unittest.skipIf(not has_tb, "no tensorboard summary writer")
class TestAutoRunner(unittest.TestCase):
    def setUp(self) -> None:

tests/test_integration_gpu_customization.py

Lines changed: 1 addition & 1 deletion
@@ -69,7 +69,7 @@


@skip_if_quick
-@SkipIfBeforePyTorchVersion((1, 9, 1))
+@SkipIfBeforePyTorchVersion((1, 13, 0))
@unittest.skipIf(not has_tb, "no tensorboard summary writer")
class TestEnsembleGpuCustomization(unittest.TestCase):
    def setUp(self) -> None:

tests/test_itk_writer.py

Lines changed: 11 additions & 0 deletions
@@ -52,6 +52,17 @@ def test_rgb(self):
            np.testing.assert_allclose(output.shape, (5, 5, 3))
            np.testing.assert_allclose(output[1, 1], (5, 5, 4))

+    def test_no_channel(self):
+        with tempfile.TemporaryDirectory() as tempdir:
+            fname = os.path.join(tempdir, "testing.nii.gz")
+            writer = ITKWriter(output_dtype=np.uint8)
+            writer.set_data_array(np.arange(48).reshape(3, 4, 4), channel_dim=None)
+            writer.write(fname)
+
+            output = np.asarray(itk.imread(fname))
+            np.testing.assert_allclose(output.shape, (4, 4, 3))
+            np.testing.assert_allclose(output[1, 1], (5, 21, 37))
+

if __name__ == "__main__":
    unittest.main()

tests/testing_data/integration_answers.py

Lines changed: 56 additions & 0 deletions
@@ -14,6 +14,62 @@
import numpy as np

EXPECTED_ANSWERS = [
+    {  # test answers for PyTorch 2.0
+        "integration_segmentation_3d": {
+            "losses": [
+                0.5430086106061935,
+                0.47010003924369814,
+                0.4453376233577728,
+                0.451901963353157,
+                0.4398456811904907,
+                0.43450237810611725,
+            ],
+            "best_metric": 0.9329540133476257,
+            "infer_metric": 0.9330471754074097,
+            "output_sums": [
+                0.14212507078546172,
+                0.15199039602949577,
+                0.15133471939291526,
+                0.13967984811021827,
+                0.18831614355832332,
+                0.1694076821827231,
+                0.14663931509271658,
+                0.16788710637623733,
+                0.1569452710008219,
+                0.17907130698392254,
+                0.16244092698688475,
+                0.1679350345855819,
+                0.14437674754879065,
+                0.11355098478396568,
+                0.161660275855964,
+                0.20082478187698194,
+                0.17575491677668853,
+                0.0974593860605401,
+                0.19366775441539907,
+                0.20293016863409002,
+                0.19610441127101647,
+                0.20812173772459808,
+                0.16184212006067655,
+                0.13185211452732482,
+                0.14824716961304257,
+                0.14229818359602905,
+                0.23141282114085215,
+                0.1609268635938338,
+                0.14825300029123678,
+                0.10286266811772046,
+                0.11873484714087054,
+                0.1296615212510262,
+                0.11386621034856693,
+                0.15203351148564773,
+                0.16300823766585265,
+                0.1936726544485426,
+                0.2227251185536394,
+                0.18067789917505797,
+                0.19005874127683337,
+                0.07462121515702229,
+            ],
+        }
+    },
    {  # test answers for PyTorch 1.12.1
        "integration_classification_2d": {
            "losses": [0.776835828070428, 0.1615355300011149, 0.07492854832938523, 0.04591309238865877],
