Skip to content

Commit b1983a1

Browse files
committed
Remove matindex
1 parent 8180218 commit b1983a1

File tree

1 file changed

+26
-28
lines changed

1 file changed

+26
-28
lines changed

httomolibgpu/prep/stripe.py

Lines changed: 26 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -201,13 +201,11 @@ def remove_all_stripe(
201201
Corrected 3D tomographic data as a CuPy or NumPy array.
202202
203203
"""
204-
matindex = _create_matindex(data.shape[2], data.shape[0])
205204
for m in range(data.shape[1]):
206-
sino = data[:, m, :]
207-
sino = _rs_dead(sino, snr, la_size, matindex)
208-
sino = _rs_sort(sino, sm_size, dim)
209-
sino = cp.nan_to_num(sino)
210-
data[:, m, :] = sino
205+
data[:, m, :] = _rs_dead(data[:, m, :], snr, la_size)
206+
data[:, m, :] = _rs_sort(data[:, m, :], sm_size, dim)
207+
data[:, m, :] = cp.nan_to_num(data[:, m, :])
208+
211209
return data
212210

213211

@@ -252,7 +250,7 @@ def _detect_stripe(listdata, snr):
252250
return listmask
253251

254252

255-
def _rs_large(sinogram, snr, size, matindex, drop_ratio=0.1, norm=True):
253+
def _rs_large(sinogram, snr, size, drop_ratio=0.1, norm=True):
256254
"""
257255
Remove large stripes.
258256
"""
@@ -264,35 +262,35 @@ def _rs_large(sinogram, snr, size, matindex, drop_ratio=0.1, norm=True):
264262
list1 = cp.mean(sinosort[ndrop : nrow - ndrop], axis=0)
265263
list2 = cp.mean(sinosmooth[ndrop : nrow - ndrop], axis=0)
266264
listfact = list1 / list2
267-
268265
# Locate stripes
269266
listmask = _detect_stripe(listfact, snr)
270267
listmask = binary_dilation(listmask, iterations=1).astype(listmask.dtype)
271-
matfact = cp.tile(listfact, (nrow, 1))
268+
272269
# Normalize
273-
if norm is True:
274-
sinogram = sinogram / matfact
275-
sinogram1 = cp.transpose(sinogram)
276-
matcombine = cp.asarray(cp.dstack((matindex, sinogram1)))
277-
278-
ids = cp.argsort(matcombine[:, :, 1], axis=1)
279-
matsort = matcombine.copy()
280-
matsort[:, :, 0] = cp.take_along_axis(matsort[:, :, 0], ids, axis=1)
281-
matsort[:, :, 1] = cp.take_along_axis(matsort[:, :, 1], ids, axis=1)
282-
283-
matsort[:, :, 1] = cp.transpose(sinosmooth)
284-
ids = cp.argsort(matsort[:, :, 0], axis=1)
285-
matsortback = matsort.copy()
286-
matsortback[:, :, 0] = cp.take_along_axis(matsortback[:, :, 0], ids, axis=1)
287-
matsortback[:, :, 1] = cp.take_along_axis(matsortback[:, :, 1], ids, axis=1)
288-
289-
sino_corrected = cp.transpose(matsortback[:, :, 1])
270+
if norm:
271+
sinogram /= cp.tile(listfact, (nrow, 1))
272+
273+
sino_transposed = sinogram.T
274+
ids_sort = cp.argsort(sino_transposed, axis=1)
275+
276+
# Apply sorting without explicit matindex
277+
sino_sorted = cp.take_along_axis(sino_transposed, ids_sort, axis=1)
278+
279+
# Smoothen sorted sinogram
280+
sino_sorted[:, :] = cp.transpose(sinosmooth)
281+
282+
# Restore original order
283+
ids_restore = cp.argsort(ids_sort, axis=1)
284+
sino_corrected = cp.take_along_axis(sino_sorted, ids_restore, axis=1).T
285+
286+
# Apply corrections only to affected columns
290287
listxmiss = cp.where(listmask > 0.0)[0]
291288
sinogram[:, listxmiss] = sino_corrected[:, listxmiss]
289+
292290
return sinogram
293291

294292

295-
def _rs_dead(sinogram, snr, size, matindex, norm=True):
293+
def _rs_dead(sinogram, snr, size, norm=True):
296294
"""remove unresponsive and fluctuating stripes"""
297295
sinogram = cp.copy(sinogram) # Make it mutable
298296
(nrow, _) = sinogram.shape
@@ -323,7 +321,7 @@ def _rs_dead(sinogram, snr, size, matindex, norm=True):
323321

324322
# Remove residual stripes
325323
if norm is True:
326-
sinogram = _rs_large(sinogram, snr, size, matindex)
324+
sinogram = _rs_large(sinogram, snr, size)
327325
return sinogram
328326

329327

0 commit comments

Comments
 (0)