Skip to content

Commit e2fa53b

Browse files
authored
Add channel_wise in RandScaleIntensity (#6793)
Part of #6629. ### Description Add `channel_wise` in `RandScaleIntensity` and `RandScaleIntensityd`. ### Types of changes <!--- Put an `x` in all the boxes that apply, and remove the not applicable items --> - [x] Non-breaking change (fix or new feature that would not break existing functionality). - [ ] Breaking change (fix or new feature that would cause existing functionality to change). - [ ] New tests added to cover the changes. - [ ] Integration tests passed locally by running `./runtests.sh -f -u --net --coverage`. - [ ] Quick tests passed locally by running `./runtests.sh --quick --unittests --disttests`. - [ ] In-line docstrings updated. - [ ] Documentation updated, tested `make html` command in the `docs/` folder. --------- Signed-off-by: KumoLiu <[email protected]>
1 parent 11546e8 commit e2fa53b

File tree

4 files changed

+71
-8
lines changed

4 files changed

+71
-8
lines changed

monai/transforms/intensity/array.py

Lines changed: 27 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -493,8 +493,8 @@ def __init__(
493493
fixed_mean: subtract the mean intensity before scaling with `factor`, then add the same value after scaling
494494
to ensure that the output has the same mean as the input.
495495
channel_wise: if True, scale on each channel separately. `preserve_range` and `fixed_mean` are also applied
496-
on each channel separately if `channel_wise` is True. Please ensure that the first dimension represents the
497-
channel of the image if True.
496+
on each channel separately if `channel_wise` is True. Please ensure that the first dimension represents the
497+
channel of the image if True.
498498
dtype: output data type, if None, same as input image. defaults to float32.
499499
"""
500500
self.factor = factor
@@ -633,12 +633,20 @@ class RandScaleIntensity(RandomizableTransform):
633633

634634
backend = ScaleIntensity.backend
635635

636-
def __init__(self, factors: tuple[float, float] | float, prob: float = 0.1, dtype: DtypeLike = np.float32) -> None:
636+
def __init__(
637+
self,
638+
factors: tuple[float, float] | float,
639+
prob: float = 0.1,
640+
channel_wise: bool = False,
641+
dtype: DtypeLike = np.float32,
642+
) -> None:
637643
"""
638644
Args:
639645
factors: factor range to randomly scale by ``v = v * (1 + factor)``.
640646
if single number, factor value is picked from (-factors, factors).
641647
prob: probability of scale.
648+
channel_wise: if True, scale on each channel separately. Please ensure
649+
that the first dimension represents the channel of the image if True.
642650
dtype: output data type, if None, same as input image. defaults to float32.
643651
644652
"""
@@ -650,26 +658,39 @@ def __init__(self, factors: tuple[float, float] | float, prob: float = 0.1, dtyp
650658
else:
651659
self.factors = (min(factors), max(factors))
652660
self.factor = self.factors[0]
661+
self.channel_wise = channel_wise
653662
self.dtype = dtype
654663

655664
def randomize(self, data: Any | None = None) -> None:
656665
super().randomize(None)
657666
if not self._do_transform:
658667
return None
659-
self.factor = self.R.uniform(low=self.factors[0], high=self.factors[1])
668+
if self.channel_wise:
669+
self.factor = [self.R.uniform(low=self.factors[0], high=self.factors[1]) for _ in range(data.shape[0])] # type: ignore
670+
else:
671+
self.factor = self.R.uniform(low=self.factors[0], high=self.factors[1])
660672

661673
def __call__(self, img: NdarrayOrTensor, randomize: bool = True) -> NdarrayOrTensor:
662674
"""
663675
Apply the transform to `img`.
664676
"""
665677
img = convert_to_tensor(img, track_meta=get_track_meta())
666678
if randomize:
667-
self.randomize()
679+
self.randomize(img)
668680

669681
if not self._do_transform:
670682
return convert_data_type(img, dtype=self.dtype)[0]
671683

672-
return ScaleIntensity(minv=None, maxv=None, factor=self.factor, dtype=self.dtype)(img)
684+
ret: NdarrayOrTensor
685+
if self.channel_wise:
686+
out = []
687+
for i, d in enumerate(img):
688+
out_channel = ScaleIntensity(minv=None, maxv=None, factor=self.factor[i], dtype=self.dtype)(d) # type: ignore
689+
out.append(out_channel)
690+
ret = torch.stack(out) # type: ignore
691+
else:
692+
ret = ScaleIntensity(minv=None, maxv=None, factor=self.factor, dtype=self.dtype)(img)
693+
return ret
673694

674695

675696
class RandBiasField(RandomizableTransform):

monai/transforms/intensity/dictionary.py

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -586,6 +586,7 @@ def __init__(
586586
keys: KeysCollection,
587587
factors: tuple[float, float] | float,
588588
prob: float = 0.1,
589+
channel_wise: bool = False,
589590
dtype: DtypeLike = np.float32,
590591
allow_missing_keys: bool = False,
591592
) -> None:
@@ -597,13 +598,15 @@ def __init__(
597598
if single number, factor value is picked from (-factors, factors).
598599
prob: probability of scale.
599600
(Default 0.1, with 10% probability it returns a scaled array.)
601+
channel_wise: if True, scale on each channel separately. Please ensure
602+
that the first dimension represents the channel of the image if True.
600603
dtype: output data type, if None, same as input image. defaults to float32.
601604
allow_missing_keys: don't raise exception if key is missing.
602605
603606
"""
604607
MapTransform.__init__(self, keys, allow_missing_keys)
605608
RandomizableTransform.__init__(self, prob)
606-
self.scaler = RandScaleIntensity(factors=factors, dtype=dtype, prob=1.0)
609+
self.scaler = RandScaleIntensity(factors=factors, dtype=dtype, prob=1.0, channel_wise=channel_wise)
607610

608611
def set_random_state(
609612
self, seed: int | None = None, state: np.random.RandomState | None = None
@@ -620,8 +623,15 @@ def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> dict[Hashable, N
620623
d[key] = convert_to_tensor(d[key], track_meta=get_track_meta())
621624
return d
622625

626+
# the first key's data is used to draw the factors: with `channel_wise=True`, all the specified keys are expected to have the same number of channels
627+
first_key: Hashable = self.first_key(d)
628+
if first_key == ():
629+
for key in self.key_iterator(d):
630+
d[key] = convert_to_tensor(d[key], track_meta=get_track_meta())
631+
return d
632+
623633
# all the keys share the same random scale factor
624-
self.scaler.randomize(None)
634+
self.scaler.randomize(d[first_key])
625635
for key in self.key_iterator(d):
626636
d[key] = self.scaler(d[key], randomize=False)
627637
return d

tests/test_rand_scale_intensity.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,22 @@ def test_value(self, p):
3333
expected = p((self.imt * (1 + np.random.uniform(low=-0.5, high=0.5))).astype(np.float32))
3434
assert_allclose(result, p(expected), rtol=1e-7, atol=0, type_test="tensor")
3535

36+
@parameterized.expand([[p] for p in TEST_NDARRAYS])
37+
def test_channel_wise(self, p):
38+
scaler = RandScaleIntensity(factors=0.5, channel_wise=True, prob=1.0)
39+
scaler.set_random_state(seed=0)
40+
im = p(self.imt)
41+
result = scaler(im)
42+
np.random.seed(0)
43+
# simulate the randomize() of transform
44+
np.random.random()
45+
channel_num = self.imt.shape[0]
46+
factor = [np.random.uniform(low=-0.5, high=0.5) for _ in range(channel_num)]
47+
expected = p(
48+
np.stack([np.asarray((self.imt[i]) * (1 + factor[i])) for i in range(channel_num)]).astype(np.float32)
49+
)
50+
assert_allclose(result, expected, atol=0, rtol=1e-5, type_test=False)
51+
3652

3753
if __name__ == "__main__":
3854
unittest.main()

tests/test_rand_scale_intensityd.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,22 @@ def test_value(self):
3232
expected = (self.imt * (1 + np.random.uniform(low=-0.5, high=0.5))).astype(np.float32)
3333
assert_allclose(result[key], p(expected), type_test="tensor")
3434

35+
def test_channel_wise(self):
36+
key = "img"
37+
for p in TEST_NDARRAYS:
38+
scaler = RandScaleIntensityd(keys=[key], factors=0.5, prob=1.0, channel_wise=True)
39+
scaler.set_random_state(seed=0)
40+
result = scaler({key: p(self.imt)})
41+
np.random.seed(0)
42+
# simulate the randomize function of transform
43+
np.random.random()
44+
channel_num = self.imt.shape[0]
45+
factor = [np.random.uniform(low=-0.5, high=0.5) for _ in range(channel_num)]
46+
expected = p(
47+
np.stack([np.asarray((self.imt[i]) * (1 + factor[i])) for i in range(channel_num)]).astype(np.float32)
48+
)
49+
assert_allclose(result[key], p(expected), type_test="tensor")
50+
3551

3652
if __name__ == "__main__":
3753
unittest.main()

0 commit comments

Comments
 (0)