Skip to content

Commit 198cd2f

Browse files
authored
2111 - adds a map_items option to compose (#2112)
* update compose Signed-off-by: Wenqi Li <[email protected]> * update based on comments Signed-off-by: Wenqi Li <[email protected]>
1 parent 06fd8fe commit 198cd2f

File tree

2 files changed

+54
-20
lines changed

2 files changed

+54
-20
lines changed

monai/transforms/compose.py

Lines changed: 31 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -34,9 +34,9 @@
3434

3535
class Compose(Randomizable, InvertibleTransform):
3636
"""
37-
``Compose`` provides the ability to chain a series of calls together in a
38-
sequence. Each transform in the sequence must take a single argument and
39-
return a single value, so that the transforms can be called in a chain.
37+
``Compose`` provides the ability to chain a series of callables together in
38+
a sequential manner. Each transform in the sequence must take a single
39+
argument and return a single value.
4040
4141
``Compose`` can be used in two ways:
4242
@@ -48,23 +48,31 @@ class Compose(Randomizable, InvertibleTransform):
4848
dictionary. It is required that the dictionary is copied between input
4949
and output of each transform.
5050
51-
If some transform generates a list batch of data in the transform chain,
52-
every item in the list is still a dictionary, and all the following
53-
transforms will apply to every item of the list, for example:
51+
If some transform takes a data item dictionary as input, and returns a
52+
sequence of data items in the transform chain, all following transforms
53+
will be applied to each item of this list if `map_items` is `True` (the
54+
default). If `map_items` is `False`, the returned sequence is passed whole
55+
to the next callable in the chain.
5456
55-
#. transformA normalizes the intensity of 'img' field in the dict data.
56-
#. transformB crops out a list batch of images on 'img' and 'seg' field.
57-
And constructs a list of dict data, other fields are copied::
57+
For example:
5858
59-
{ [{ {
60-
'img': [1, 2], 'img': [1], 'img': [2],
61-
'seg': [1, 2], 'seg': [1], 'seg': [2],
62-
'extra': 123, --> 'extra': 123, 'extra': 123,
63-
'shape': 'CHWD' 'shape': 'CHWD' 'shape': 'CHWD'
64-
} }, }]
59+
A `Compose([transformA, transformB, transformC],
60+
map_items=True)(data_dict)` could achieve the following patch-based
61+
transformation on the `data_dict` input:
6562
66-
#. transformC then randomly rotates or flips 'img' and 'seg' fields of
67-
every dictionary item in the list.
63+
#. transformA normalizes the intensity of 'img' field in the `data_dict`.
64+
#. transformB crops out image patches from the 'img' and 'seg' of
65+
`data_dict`, and return a list of three patch samples::
66+
67+
{'img': 3x100x100 data, 'seg': 1x100x100 data, 'shape': (100, 100)}
68+
applying transformB
69+
---------->
70+
[{'img': 3x20x20 data, 'seg': 1x20x20 data, 'shape': (20, 20)},
71+
{'img': 3x20x20 data, 'seg': 1x20x20 data, 'shape': (20, 20)},
72+
{'img': 3x20x20 data, 'seg': 1x20x20 data, 'shape': (20, 20)},]
73+
74+
#. transformC then randomly rotates or flips 'img' and 'seg' of
75+
each dictionary item in the list returned by transformB.
6876
6977
The composed transforms will be set the same global random seed if user called
7078
`set_determinism()`.
@@ -93,10 +101,13 @@ class Compose(Randomizable, InvertibleTransform):
93101
them are called on the labels.
94102
"""
95103

96-
def __init__(self, transforms: Optional[Union[Sequence[Callable], Callable]] = None) -> None:
104+
def __init__(
105+
self, transforms: Optional[Union[Sequence[Callable], Callable]] = None, map_items: bool = True
106+
) -> None:
97107
if transforms is None:
98108
transforms = []
99109
self.transforms = ensure_tuple(transforms)
110+
self.map_items = map_items
100111
self.set_random_state(seed=get_seed())
101112

102113
def set_random_state(self, seed: Optional[int] = None, state: Optional[np.random.RandomState] = None) -> "Compose":
@@ -141,7 +152,7 @@ def __len__(self):
141152

142153
def __call__(self, input_):
143154
for _transform in self.transforms:
144-
input_ = apply_transform(_transform, input_)
155+
input_ = apply_transform(_transform, input_, self.map_items)
145156
return input_
146157

147158
def inverse(self, data):
@@ -151,5 +162,5 @@ def inverse(self, data):
151162

152163
# loop backwards over transforms
153164
for t in reversed(invertible_transforms):
154-
data = apply_transform(t.inverse, data)
165+
data = apply_transform(t.inverse, data, self.map_items)
155166
return data

tests/test_compose.py

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -79,6 +79,29 @@ def c(d): # transform to handle dict data
7979
for item in value:
8080
self.assertDictEqual(item, {"a": 2, "b": 1, "c": 2})
8181

82+
def test_list_dict_compose_no_map(self):
83+
def a(d): # transform to handle dict data
84+
d = dict(d)
85+
d["a"] += 1
86+
return d
87+
88+
def b(d): # transform to generate a batch list of data
89+
d = dict(d)
90+
d["b"] += 1
91+
d = [d] * 5
92+
return d
93+
94+
def c(d): # transform to handle dict data
95+
d = [dict(di) for di in d]
96+
for di in d:
97+
di["c"] += 1
98+
return d
99+
100+
transforms = Compose([a, a, b, c, c], map_items=False)
101+
value = transforms({"a": 0, "b": 0, "c": 0})
102+
for item in value:
103+
self.assertDictEqual(item, {"a": 2, "b": 1, "c": 2})
104+
82105
def test_random_compose(self):
83106
class _Acc(Randomizable):
84107
self.rand = 0.0

0 commit comments

Comments
 (0)