3434
3535class Compose (Randomizable , InvertibleTransform ):
3636 """
37- ``Compose`` provides the ability to chain a series of calls together in a
38- sequence . Each transform in the sequence must take a single argument and
39- return a single value, so that the transforms can be called in a chain .
37+ ``Compose`` provides the ability to chain a series of callables together in
38+ a sequential manner . Each transform in the sequence must take a single
39+ argument and return a single value.
4040
4141 ``Compose`` can be used in two ways:
4242
@@ -48,23 +48,31 @@ class Compose(Randomizable, InvertibleTransform):
4848 dictionary. It is required that the dictionary is copied between input
4949 and output of each transform.
5050
51- If some transform generates a list batch of data in the transform chain,
52- every item in the list is still a dictionary, and all the following
53- transforms will apply to every item of the list, for example:
51+ If some transform takes a data item dictionary as input, and returns a
52+ sequence of data items in the transform chain, all following transforms
53+ will be applied to each item of this list if `map_items` is `True` (the
54+ default). If `map_items` is `False`, the returned sequence is passed whole
55+ to the next callable in the chain.
5456
55- #. transformA normalizes the intensity of 'img' field in the dict data.
56- #. transformB crops out a list batch of images on 'img' and 'seg' field.
57- And constructs a list of dict data, other fields are copied::
57+ For example:
5858
59- { [{ {
60- 'img': [1, 2], 'img': [1], 'img': [2],
61- 'seg': [1, 2], 'seg': [1], 'seg': [2],
62- 'extra': 123, --> 'extra': 123, 'extra': 123,
63- 'shape': 'CHWD' 'shape': 'CHWD' 'shape': 'CHWD'
64- } }, }]
59+ A `Compose([transformA, transformB, transformC],
60+ map_items=True)(data_dict)` could achieve the following patch-based
61+ transformation on the `data_dict` input:
6562
66- #. transformC then randomly rotates or flips 'img' and 'seg' fields of
67- every dictionary item in the list.
63+ #. transformA normalizes the intensity of 'img' field in the `data_dict`.
64+ #. transformB crops out image patches from the 'img' and 'seg' of
65+ `data_dict`, and return a list of three patch samples::
66+
67+ {'img': 3x100x100 data, 'seg': 1x100x100 data, 'shape': (100, 100)}
68+ applying transformB
69+ ---------->
70+ [{'img': 3x20x20 data, 'seg': 1x20x20 data, 'shape': (20, 20)},
71+ {'img': 3x20x20 data, 'seg': 1x20x20 data, 'shape': (20, 20)},
72+ {'img': 3x20x20 data, 'seg': 1x20x20 data, 'shape': (20, 20)},]
73+
74+ #. transformC then randomly rotates or flips 'img' and 'seg' of
75+ each dictionary item in the list returned by transformB.
6876
6977 The composed transforms will be set the same global random seed if user called
7078 `set_determinism()`.
@@ -93,10 +101,13 @@ class Compose(Randomizable, InvertibleTransform):
93101 them are called on the labels.
94102 """
95103
96- def __init__ (self , transforms : Optional [Union [Sequence [Callable ], Callable ]] = None ) -> None :
104+ def __init__ (
105+ self , transforms : Optional [Union [Sequence [Callable ], Callable ]] = None , map_items : bool = True
106+ ) -> None :
97107 if transforms is None :
98108 transforms = []
99109 self .transforms = ensure_tuple (transforms )
110+ self .map_items = map_items
100111 self .set_random_state (seed = get_seed ())
101112
102113 def set_random_state (self , seed : Optional [int ] = None , state : Optional [np .random .RandomState ] = None ) -> "Compose" :
@@ -141,7 +152,7 @@ def __len__(self):
141152
142153 def __call__ (self , input_ ):
143154 for _transform in self .transforms :
144- input_ = apply_transform (_transform , input_ )
155+ input_ = apply_transform (_transform , input_ , self . map_items )
145156 return input_
146157
147158 def inverse (self , data ):
@@ -151,5 +162,5 @@ def inverse(self, data):
151162
152163 # loop backwards over transforms
153164 for t in reversed (invertible_transforms ):
154- data = apply_transform (t .inverse , data )
165+ data = apply_transform (t .inverse , data , self . map_items )
155166 return data
0 commit comments