@@ -201,13 +201,11 @@ def remove_all_stripe(
201201 Corrected 3D tomographic data as a CuPy or NumPy array.
202202
203203 """
204- matindex = _create_matindex (data .shape [2 ], data .shape [0 ])
205204 for m in range (data .shape [1 ]):
206- sino = data [:, m , :]
207- sino = _rs_dead (sino , snr , la_size , matindex )
208- sino = _rs_sort (sino , sm_size , dim )
209- sino = cp .nan_to_num (sino )
210- data [:, m , :] = sino
205+ data [:, m , :] = _rs_dead (data [:, m , :], snr , la_size )
206+ data [:, m , :] = _rs_sort (data [:, m , :], sm_size , dim )
207+ data [:, m , :] = cp .nan_to_num (data [:, m , :])
208+
211209 return data
212210
213211
@@ -252,7 +250,7 @@ def _detect_stripe(listdata, snr):
252250 return listmask
253251
254252
255- def _rs_large (sinogram , snr , size , matindex , drop_ratio = 0.1 , norm = True ):
253+ def _rs_large (sinogram , snr , size , drop_ratio = 0.1 , norm = True ):
256254 """
257255 Remove large stripes.
258256 """
@@ -264,35 +262,35 @@ def _rs_large(sinogram, snr, size, matindex, drop_ratio=0.1, norm=True):
264262 list1 = cp .mean (sinosort [ndrop : nrow - ndrop ], axis = 0 )
265263 list2 = cp .mean (sinosmooth [ndrop : nrow - ndrop ], axis = 0 )
266264 listfact = list1 / list2
267-
268265 # Locate stripes
269266 listmask = _detect_stripe (listfact , snr )
270267 listmask = binary_dilation (listmask , iterations = 1 ).astype (listmask .dtype )
271- matfact = cp . tile ( listfact , ( nrow , 1 ))
268+
272269 # Normalize
273- if norm is True :
274- sinogram = sinogram / matfact
275- sinogram1 = cp . transpose ( sinogram )
276- matcombine = cp . asarray ( cp . dstack (( matindex , sinogram1 )))
277-
278- ids = cp . argsort ( matcombine [:, :, 1 ], axis = 1 )
279- matsort = matcombine . copy ()
280- matsort [:, :, 0 ] = cp .take_along_axis (matsort [:, :, 0 ], ids , axis = 1 )
281- matsort [:, :, 1 ] = cp . take_along_axis ( matsort [:, :, 1 ], ids , axis = 1 )
282-
283- matsort [:, :, 1 ] = cp .transpose (sinosmooth )
284- ids = cp . argsort ( matsort [:, :, 0 ], axis = 1 )
285- matsortback = matsort . copy ()
286- matsortback [:, :, 0 ] = cp .take_along_axis ( matsortback [:, :, 0 ], ids , axis = 1 )
287- matsortback [:, :, 1 ] = cp .take_along_axis (matsortback [:, :, 1 ], ids , axis = 1 )
288-
289- sino_corrected = cp . transpose ( matsortback [:, :, 1 ])
270+ if norm :
271+ sinogram /= cp . tile ( listfact , ( nrow , 1 ))
272+
273+ sino_transposed = sinogram . T
274+ ids_sort = cp . argsort ( sino_transposed , axis = 1 )
275+
276+ # Apply sorting without explicit matindex
277+ sino_sorted = cp .take_along_axis (sino_transposed , ids_sort , axis = 1 )
278+
279+ # Smoothen sorted sinogram
280+ sino_sorted [:, :] = cp .transpose (sinosmooth )
281+
282+ # Restore original order
283+ ids_restore = cp .argsort ( ids_sort , axis = 1 )
284+ sino_corrected = cp .take_along_axis (sino_sorted , ids_restore , axis = 1 ). T
285+
286+ # Apply corrections only to affected columns
290287 listxmiss = cp .where (listmask > 0.0 )[0 ]
291288 sinogram [:, listxmiss ] = sino_corrected [:, listxmiss ]
289+
292290 return sinogram
293291
294292
295- def _rs_dead (sinogram , snr , size , matindex , norm = True ):
293+ def _rs_dead (sinogram , snr , size , norm = True ):
296294 """remove unresponsive and fluctuating stripes"""
297295 sinogram = cp .copy (sinogram ) # Make it mutable
298296 (nrow , _ ) = sinogram .shape
@@ -323,7 +321,7 @@ def _rs_dead(sinogram, snr, size, matindex, norm=True):
323321
324322 # Remove residual stripes
325323 if norm is True :
326- sinogram = _rs_large (sinogram , snr , size , matindex )
324+ sinogram = _rs_large (sinogram , snr , size )
327325 return sinogram
328326
329327
0 commit comments