Skip to content

Commit 45ca4b1

Browse files
mjanuszcopybara-github
authored andcommitted
Add support for block interpolation in backward mode.
PiperOrigin-RevId: 595629332
1 parent 00ac445 commit 45ca4b1

File tree

1 file changed

+44
-18
lines changed

1 file changed

+44
-18
lines changed

processor/maps.py

+44-18
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,6 @@
2121
from connectomics.common import bounding_box
2222
from connectomics.volume import subvolume
2323
from connectomics.volume import subvolume_processor
24-
2524
import numpy as np
2625
from scipy import spatial
2726
from sofima import map_utils
@@ -54,13 +53,14 @@ class ReconcileCrossBlockMaps(subvolume_processor.SubvolumeProcessor):
5453

5554
def __init__(
5655
self,
57-
cross_block_volinfo,
58-
cross_block_inv_volinfo,
59-
last_inv_volinfo,
60-
main_inv_volinfo,
61-
z_map,
62-
stride,
63-
xy_overlap=128,
56+
cross_block_volinfo: str,
57+
cross_block_inv_volinfo: str,
58+
last_inv_volinfo: str,
59+
main_inv_volinfo: str,
60+
z_map: dict[int | str, int | str],
61+
stride: int,
62+
xy_overlap: int = 128,
63+
backward: bool = False,
6464
input_volinfo=None,
6565
):
6666
"""Constructor.
@@ -80,6 +80,8 @@ def __init__(
8080
in pixels of the output volume
8181
xy_overlap: neighboring subvolume overlap in the XY directions, in units
8282
of pixels of main input volume
83+
backward: whether the mesh was solved in backward mode (proceeding from
84+
higher z coordinates towards lower ones)
8385
input_volinfo: path to the high-res input volume (unused)
8486
"""
8587
del input_volinfo
@@ -91,6 +93,7 @@ def __init__(
9193
self._z_map = {int(k): int(v) for k, v in z_map.items()}
9294
self._sorted_z = list(sorted(self._z_map.keys()))
9395
self._stride = stride
96+
self._backward = backward
9497

9598
def _open_volume(self, path: str) -> Any:
9699
"""Returns a CZYX-shaped ndarray-like object."""
@@ -146,17 +149,30 @@ def _interpolate(
146149
'cross_block' volume
147150
done: set of 'z' section coordinates that have already been processed
148151
"""
149-
xblock_post = load_xblock(self._z_map[z1])
150-
if z0 > 0:
152+
if self._backward:
153+
xblock_post = load_xblock(self._z_map[z0])
154+
else:
155+
xblock_post = load_xblock(self._z_map[z1])
156+
157+
if not self._backward and z0 > 0:
151158
xblock_pre = load_xblock(self._z_map[z0])
152159
xblock_pre_inv = load_xblock_inv(self._z_map[z0])
160+
elif self._backward and z1 < self._sorted_z[-1]:
161+
xblock_pre = load_xblock(self._z_map[z1])
162+
xblock_pre_inv = load_xblock_inv(self._z_map[z1])
153163
else:
154164
xblock_pre_inv = xblock_pre = np.zeros_like(xblock_post)
155165

156-
if z1 != self._sorted_z[-1]:
157-
block_end_inv = load_last_inv(z1)
166+
if self._backward:
167+
if z0 != self._sorted_z[0]:
168+
block_end_inv = load_last_inv(z0)
169+
else:
170+
block_end_inv = load_main_inv(z0)
158171
else:
159-
block_end_inv = load_main_inv(z1)
172+
if z1 != self._sorted_z[-1]:
173+
block_end_inv = load_last_inv(z1)
174+
else:
175+
block_end_inv = load_main_inv(z1)
160176

161177
flat_box = bounding_box.BoundingBox(
162178
start=box.start, size=(box.size[0], box.size[1], 1)
@@ -201,19 +217,29 @@ def _interpolate(
201217
self._stride,
202218
)
203219

204-
b = z1 - z0
220+
block_size = z1 - z0
205221
for z in range(max(box.start[2], z0), min(box.end[2], z1 + 1)):
206222
i = z - z0
207223
# Each section can be processed only once.
208224
if z in done:
209225
continue
210226
rel_z = z - box.start[2]
211227

212-
if i == b:
213-
data[:, rel_z : rel_z + 1, ...] = xblock_post
228+
if i == block_size:
229+
data[:, rel_z : rel_z + 1, ...] = (
230+
xblock_pre if self._backward else xblock_post
231+
)
214232
elif i == 0:
215-
data[:, rel_z : rel_z + 1, ...] = xblock_pre
233+
data[:, rel_z : rel_z + 1, ...] = (
234+
xblock_post if self._backward else xblock_pre
235+
)
216236
else:
237+
238+
if self._backward:
239+
scale = (block_size - i) / block_size
240+
else:
241+
scale = i / block_size
242+
217243
try:
218244
# The output coordinate map here is the inverse of the argument
219245
# passed to warp() in the comment above, i.e.:
@@ -230,7 +256,7 @@ def _interpolate(
230256
interior_aligned,
231257
flat_box,
232258
self._stride,
233-
offset * i / b,
259+
offset * scale,
234260
flat_box,
235261
self._stride,
236262
)

0 commit comments

Comments
 (0)