Skip to content
Open
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions .scripts/download_zenodo.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,15 +18,15 @@ def calculate_md5(filename):

def download_zenodo_files(output_dir: Path):
"""
Download all files from Zenodo record 14938787 and verify their checksums.
Download all files from Zenodo record 14979785 and verify their checksums.

Args:
output_dir: Directory where files should be downloaded
"""
try:
print("Fetching files from Zenodo record 14938787...")
print("Fetching files from Zenodo record 14979785...")
with urllib.request.urlopen(
"https://zenodo.org/api/records/14938787"
"https://zenodo.org/api/records/14979785"
) as response:
data = json.loads(response.read())

Expand Down
54 changes: 26 additions & 28 deletions httomolibgpu/prep/stripe.py
Original file line number Diff line number Diff line change
Expand Up @@ -201,13 +201,11 @@ def remove_all_stripe(
Corrected 3D tomographic data as a CuPy or NumPy array.

"""
matindex = _create_matindex(data.shape[2], data.shape[0])
for m in range(data.shape[1]):
sino = data[:, m, :]
sino = _rs_dead(sino, snr, la_size, matindex)
sino = _rs_sort(sino, sm_size, dim)
sino = cp.nan_to_num(sino)
data[:, m, :] = sino
data[:, m, :] = _rs_dead(data[:, m, :], snr, la_size)
data[:, m, :] = _rs_sort(data[:, m, :], sm_size, dim)
data[:, m, :] = cp.nan_to_num(data[:, m, :])

return data


Expand Down Expand Up @@ -252,7 +250,7 @@ def _detect_stripe(listdata, snr):
return listmask


def _rs_large(sinogram, snr, size, matindex, drop_ratio=0.1, norm=True):
def _rs_large(sinogram, snr, size, drop_ratio=0.1, norm=True):
"""
Remove large stripes.
"""
Expand All @@ -264,35 +262,35 @@ def _rs_large(sinogram, snr, size, matindex, drop_ratio=0.1, norm=True):
list1 = cp.mean(sinosort[ndrop : nrow - ndrop], axis=0)
list2 = cp.mean(sinosmooth[ndrop : nrow - ndrop], axis=0)
listfact = list1 / list2

# Locate stripes
listmask = _detect_stripe(listfact, snr)
listmask = binary_dilation(listmask, iterations=1).astype(listmask.dtype)
matfact = cp.tile(listfact, (nrow, 1))

# Normalize
if norm is True:
sinogram = sinogram / matfact
sinogram1 = cp.transpose(sinogram)
matcombine = cp.asarray(cp.dstack((matindex, sinogram1)))

ids = cp.argsort(matcombine[:, :, 1], axis=1)
matsort = matcombine.copy()
matsort[:, :, 0] = cp.take_along_axis(matsort[:, :, 0], ids, axis=1)
matsort[:, :, 1] = cp.take_along_axis(matsort[:, :, 1], ids, axis=1)

matsort[:, :, 1] = cp.transpose(sinosmooth)
ids = cp.argsort(matsort[:, :, 0], axis=1)
matsortback = matsort.copy()
matsortback[:, :, 0] = cp.take_along_axis(matsortback[:, :, 0], ids, axis=1)
matsortback[:, :, 1] = cp.take_along_axis(matsortback[:, :, 1], ids, axis=1)

sino_corrected = cp.transpose(matsortback[:, :, 1])
if norm:
sinogram /= cp.tile(listfact, (nrow, 1))

sino_transposed = sinogram.T
ids_sort = cp.argsort(sino_transposed, axis=1)

# Apply sorting without explicit matindex
sino_sorted = cp.take_along_axis(sino_transposed, ids_sort, axis=1)

# Smoothen sorted sinogram
sino_sorted[:, :] = cp.transpose(sinosmooth)

# Restore original order
ids_restore = cp.argsort(ids_sort, axis=1)
sino_corrected = cp.take_along_axis(sino_sorted, ids_restore, axis=1).T

# Apply corrections only to affected columns
listxmiss = cp.where(listmask > 0.0)[0]
sinogram[:, listxmiss] = sino_corrected[:, listxmiss]

return sinogram


def _rs_dead(sinogram, snr, size, matindex, norm=True):
def _rs_dead(sinogram, snr, size, norm=True):
"""remove unresponsive and fluctuating stripes"""
sinogram = cp.copy(sinogram) # Make it mutable
(nrow, _) = sinogram.shape
Expand Down Expand Up @@ -323,7 +321,7 @@ def _rs_dead(sinogram, snr, size, matindex, norm=True):

# Remove residual stripes
if norm is True:
sinogram = _rs_large(sinogram, snr, size, matindex)
sinogram = _rs_large(sinogram, snr, size)
return sinogram


Expand Down
14 changes: 14 additions & 0 deletions zenodo-tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -185,6 +185,20 @@ def geant4_dataset1(geant4_dataset1_file):
)


@pytest.fixture(scope="session")
def synth_tomophantom1_file(test_data_path):
in_file = os.path.join(test_data_path, "synth_tomophantom1.npz")
return np.load(in_file)


@pytest.fixture
def synth_tomophantom1_dataset(synth_tomophantom1_file):
return (
cp.asarray(cp.swapaxes(synth_tomophantom1_file["projdata"], 0, 1)),
synth_tomophantom1_file["angles"],
)


@pytest.fixture
def ensure_clean_memory():
gc.collect()
Expand Down
36 changes: 36 additions & 0 deletions zenodo-tests/test_prep/test_stripe.py
Original file line number Diff line number Diff line change
Expand Up @@ -134,6 +134,42 @@ def test_remove_all_stripe_i12_dataset4(
assert output.flags.c_contiguous


@pytest.mark.parametrize(
"dataset_fixture, snr_val, la_size_val, sm_size_val, norm_res_expected",
[
("synth_tomophantom1_dataset", 1.0, 61, 21, 53435.61),
("synth_tomophantom1_dataset", 0.1, 61, 21, 67917.71),
("synth_tomophantom1_dataset", 0.001, 61, 21, 70015.51),
],
ids=["snr_1", "snr_2", "snr_3"],
)
def test_remove_all_stripe_synth_tomophantom1_dataset(
request, dataset_fixture, snr_val, la_size_val, sm_size_val, norm_res_expected
):
dataset = request.getfixturevalue(dataset_fixture)
force_clean_gpu_memory()

output = remove_all_stripe(
cp.copy(dataset[0]),
snr=snr_val,
la_size=la_size_val,
sm_size=sm_size_val,
dim=1,
)
np.savez(
"/home/algol/Documents/DEV/httomolibgpu/zenodo-tests/large_data_archive/stripe_res2.npz",
data=output.get(),
)

residual_calc = dataset[0] - output
norm_res = cp.linalg.norm(residual_calc.flatten())

assert isclose(norm_res, norm_res_expected, abs_tol=10**-2)
assert not np.isnan(output).any(), "Output contains NaN values"
assert output.dtype == np.float32
assert output.flags.c_contiguous


@pytest.mark.parametrize(
"dataset_fixture, nvalue_val, vvalue_val, norm_res_expected",
[
Expand Down
Loading