Skip to content

Commit 5feb353

Browse files
function2-llxwyli
andauthored
Fix lazy rand affine (#6774)
Fixes #6773. ### Description Call `rand_affine_grid()` once before call `rand_affine_grid.get_transformation_matrix()`, since its documented as "Get the most recently applied transformation matrix", or the `.affine` attribute will not be set. Also, set `randomize=False` here since randomization if performed in the beginning of the function. ### Types of changes <!--- Put an `x` in all the boxes that apply, and remove the not applicable items --> - [x] Non-breaking change (fix or new feature that would not break existing functionality). - [ ] Integration tests passed locally by running `./runtests.sh -f -u --net --coverage`. - [ ] Quick tests passed locally by running `./runtests.sh --quick --unittests --disttests`. --------- Signed-off-by: function2 <[email protected]> Co-authored-by: Wenqi Li <[email protected]>
1 parent e2fa53b commit 5feb353

File tree

3 files changed

+4
-0
lines changed

3 files changed

+4
-0
lines changed

monai/transforms/spatial/array.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2523,6 +2523,8 @@ def __call__(
25232523
img = convert_to_tensor(img, track_meta=get_track_meta())
25242524
if lazy_:
25252525
if self._do_transform:
2526+
if grid is None:
2527+
self.rand_affine_grid(sp_size, randomize=randomize, lazy=True)
25262528
affine = self.rand_affine_grid.get_transformation_matrix()
25272529
else:
25282530
affine = convert_to_dst_type(torch.eye(len(sp_size) + 1), img, dtype=self.rand_affine_grid.dtype)[0]

monai/transforms/spatial/dictionary.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1141,6 +1141,7 @@ def __call__(
11411141
grid = self.rand_affine.get_identity_grid(sp_size, lazy=lazy_)
11421142
if self._do_transform: # add some random factors
11431143
grid = self.rand_affine.rand_affine_grid(sp_size, grid=grid, lazy=lazy_)
1144+
grid = 0 if grid is None else grid # always provide a grid to self.rand_affine
11441145

11451146
for key, mode, padding_mode in self.key_iterator(d, self.mode, self.padding_mode):
11461147
# do the transform

tests/test_rand_affine.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -145,6 +145,7 @@ def test_rand_affine(self, input_param, input_data, expected_val):
145145
g = RandAffine(**input_param)
146146
g.set_random_state(123)
147147
result = g(**input_data)
148+
g.rand_affine_grid.affine = torch.eye(4, dtype=torch.float64) # reset affine
148149
test_resampler_lazy(g, result, input_param, input_data, seed=123)
149150
if input_param.get("cache_grid", False):
150151
self.assertTrue(g._cached_grid is not None)

0 commit comments

Comments
 (0)