Skip to content
57 changes: 29 additions & 28 deletions spimquant/workflow/scripts/clean_segmentation.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,31 +15,32 @@
from zarrnii import ZarrNii
from zarrnii.plugins import SegmentationCleaner

with get_dask_client(snakemake.config["dask_scheduler"], snakemake.threads):

hires_level = int(snakemake.wildcards.level)
proc_level = int(snakemake.params.proc_level)

znimg = ZarrNii.from_ome_zarr(
snakemake.input.mask,
level=0, # we load level 0 since we are already at the highres level
**snakemake.params.zarrnii_kwargs,
)

# perform cleaning of artifactual positives by
# removing objects with low extent (extent is ratio of num voxels to bounding box)

# the downsample_factor we use should be proportional to the segmentation level
# e.g. if segmentation level is 3, then we have already downsampled by 2^3, so
# the downsample factor should be divided by that..
unadjusted_downsample_factor = 2**proc_level
adjusted_downsample_factor = unadjusted_downsample_factor / (2**hires_level)

znimg_cleaned = znimg.apply_scaled_processing(
SegmentationCleaner(max_extent=snakemake.params.max_extent),
downsample_factor=adjusted_downsample_factor,
upsampled_ome_zarr_path=snakemake.output.exclude_mask,
)

# write to final ome_zarr
znimg_cleaned.to_ome_zarr(snakemake.output.cleaned_mask, max_layer=5)
if __name__ == "__main__":
with get_dask_client(snakemake.config["dask_scheduler"], snakemake.threads):

hires_level = int(snakemake.wildcards.level)
proc_level = int(snakemake.params.proc_level)

znimg = ZarrNii.from_ome_zarr(
snakemake.input.mask,
level=0, # we load level 0 since we are already at the highres level
**snakemake.params.zarrnii_kwargs,
)

# perform cleaning of artifactual positives by
# removing objects with low extent (extent is ratio of num voxels to bounding box)

# the downsample_factor we use should be proportional to the segmentation level
# e.g. if segmentation level is 3, then we have already downsampled by 2^3, so
# the downsample factor should be divided by that..
unadjusted_downsample_factor = 2**proc_level
adjusted_downsample_factor = unadjusted_downsample_factor / (2**hires_level)

znimg_cleaned = znimg.apply_scaled_processing(
SegmentationCleaner(max_extent=snakemake.params.max_extent),
downsample_factor=adjusted_downsample_factor,
upsampled_ome_zarr_path=snakemake.output.exclude_mask,
)

# write to final ome_zarr
znimg_cleaned.to_ome_zarr(snakemake.output.cleaned_mask, max_layer=5)
14 changes: 6 additions & 8 deletions spimquant/workflow/scripts/coloc_per_voxel_template.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
import numpy as np
from zarrnii import ZarrNii, density_from_points
from dask.diagnostics import ProgressBar
from dask_setup import get_dask_client
import pandas as pd

img = ZarrNii.from_nifti(
Expand All @@ -11,12 +10,11 @@
if hasattr(snakemake.wildcards, "level"):
img = img.downsample(level=int(snakemake.wildcards.level))

with get_dask_client(snakemake.config["dask_scheduler"], snakemake.threads):
df = pd.read_parquet(snakemake.input.coloc_parquet)
df = pd.read_parquet(snakemake.input.coloc_parquet)

points = df[snakemake.params.coord_column_names].values
points = df[snakemake.params.coord_column_names].values

# Create counts map (zarrnii is calling this density right now)..
counts = density_from_points(points, img, in_physical_space=True)
with ProgressBar():
counts.to_nifti(snakemake.output.counts_nii)
# Create counts map (zarrnii is calling this density right now)..
counts = density_from_points(points, img, in_physical_space=True)
with ProgressBar():
counts.to_nifti(snakemake.output.counts_nii)
23 changes: 12 additions & 11 deletions spimquant/workflow/scripts/compute_filtered_regionprops.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,16 +8,17 @@
from dask_setup import get_dask_client
from zarrnii import ZarrNii

with get_dask_client(snakemake.config["dask_scheduler"], snakemake.threads):
if __name__ == "__main__":
with get_dask_client(snakemake.config["dask_scheduler"], snakemake.threads):

znimg = ZarrNii.from_ome_zarr(
snakemake.input.mask,
level=0, # input image is already downsampled to the wildcard level
**snakemake.params.zarrnii_kwargs,
)
znimg = ZarrNii.from_ome_zarr(
snakemake.input.mask,
level=0, # input image is already downsampled to the wildcard level
**snakemake.params.zarrnii_kwargs,
)

znimg.compute_region_properties(
output_path=snakemake.output.regionprops_parquet,
region_filters=snakemake.params.region_filters,
output_properties=snakemake.params.output_properties,
)
znimg.compute_region_properties(
output_path=snakemake.output.regionprops_parquet,
region_filters=snakemake.params.region_filters,
output_properties=snakemake.params.output_properties,
)
23 changes: 11 additions & 12 deletions spimquant/workflow/scripts/counts_per_voxel.py
Original file line number Diff line number Diff line change
@@ -1,25 +1,24 @@
import numpy as np
from zarrnii import ZarrNii, density_from_points
from dask.diagnostics import ProgressBar
from dask_setup import get_dask_client
import pandas as pd

level = int(snakemake.wildcards.level)
stain = snakemake.wildcards.stain

img = ZarrNii.from_ome_zarr(
snakemake.input.ref_spim,
level=level,
channel_labels=[stain],
downsample_near_isotropic=True,
)
with get_dask_client("threads", snakemake.threads):
img = ZarrNii.from_ome_zarr(
snakemake.input.ref_spim,
level=level,
channel_labels=[stain],
downsample_near_isotropic=True,
)

with get_dask_client(snakemake.config["dask_scheduler"], snakemake.threads):
df = pd.read_parquet(snakemake.input.regionprops_parquet)

points = df[snakemake.params.coord_column_names].values

# Create counts map (zarrnii is calling this density right now)..
counts = density_from_points(points, img, in_physical_space=True)
with ProgressBar():
counts.to_nifti(snakemake.output.counts_nii)
# Create counts map (zarrnii is calling this density right now)..
counts = density_from_points(points, img, in_physical_space=True)
with ProgressBar():
counts.to_nifti(snakemake.output.counts_nii)
3 changes: 1 addition & 2 deletions spimquant/workflow/scripts/counts_per_voxel_template.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
import numpy as np
from zarrnii import ZarrNii, density_from_points
from dask.diagnostics import ProgressBar
from dask_setup import get_dask_client
import pandas as pd

stain = snakemake.wildcards.stain
Expand All @@ -13,7 +12,7 @@
if hasattr(snakemake.wildcards, "level"):
img = img.downsample(level=int(snakemake.wildcards.level))

with get_dask_client(snakemake.config["dask_scheduler"], snakemake.threads):
with get_dask_client("threads", snakemake.threads):
df = pd.read_parquet(snakemake.input.regionprops_parquet)

df = df[df["stain"] == snakemake.wildcards.stain]
Expand Down
2 changes: 1 addition & 1 deletion spimquant/workflow/scripts/create_imaris_crops.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@
# Create output directory
Path(output_dir).mkdir(parents=True, exist_ok=True)

with get_dask_client(snakemake.config["dask_scheduler"], snakemake.threads):
with get_dask_client("threads", snakemake.threads):
# Load the atlas with labels
atlas = ZarrNiiAtlas.from_files(
input_dseg,
Expand Down
2 changes: 1 addition & 1 deletion spimquant/workflow/scripts/create_patches.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@
# Create output directory
Path(output_dir).mkdir(parents=True, exist_ok=True)

with get_dask_client(snakemake.config["dask_scheduler"], snakemake.threads):
with get_dask_client("threads", snakemake.threads):
# Load the atlas with labels
atlas = ZarrNiiAtlas.from_files(
input_dseg,
Expand Down
45 changes: 23 additions & 22 deletions spimquant/workflow/scripts/gaussian_biasfield.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,31 +3,32 @@
from zarrnii import ZarrNii
from zarrnii.plugins import GaussianBiasFieldCorrection

hires_level = int(snakemake.wildcards.level)
proc_level = int(snakemake.params.proc_level)
if __name__ == "__main__":
hires_level = int(snakemake.wildcards.level)
proc_level = int(snakemake.params.proc_level)

unadjusted_downsample_factor = 2**proc_level
adjusted_downsample_factor = unadjusted_downsample_factor / (2**hires_level)
unadjusted_downsample_factor = 2**proc_level
adjusted_downsample_factor = unadjusted_downsample_factor / (2**hires_level)

with get_dask_client(snakemake.config["dask_scheduler"], snakemake.threads):
with get_dask_client(snakemake.config["dask_scheduler"], snakemake.threads):

znimg = ZarrNii.from_ome_zarr(
snakemake.input.spim,
channel_labels=[snakemake.wildcards.stain],
level=hires_level,
downsample_near_isotropic=True,
**snakemake.params.zarrnii_kwargs,
)
znimg = ZarrNii.from_ome_zarr(
snakemake.input.spim,
channel_labels=[snakemake.wildcards.stain],
level=hires_level,
downsample_near_isotropic=True,
**snakemake.params.zarrnii_kwargs,
)

print("compute bias field correction")
with ProgressBar():
print("compute bias field correction")
with ProgressBar():

# Apply bias field correction
znimg_corrected = znimg.apply_scaled_processing(
GaussianBiasFieldCorrection(sigma=5.0),
downsample_factor=adjusted_downsample_factor,
upsampled_ome_zarr_path=snakemake.output.biasfield,
)
# Apply bias field correction
znimg_corrected = znimg.apply_scaled_processing(
GaussianBiasFieldCorrection(sigma=5.0),
downsample_factor=adjusted_downsample_factor,
upsampled_ome_zarr_path=snakemake.output.biasfield,
)

# write to ome_zarr
znimg_corrected.to_ome_zarr(snakemake.output.corrected, max_layer=5)
# write to ome_zarr
znimg_corrected.to_ome_zarr(snakemake.output.corrected, max_layer=5)
73 changes: 37 additions & 36 deletions spimquant/workflow/scripts/multiotsu.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,39 +5,40 @@

matplotlib.use("agg")

with get_dask_client(snakemake.config["dask_scheduler"], snakemake.threads):

# we use the default level=0, since we are reading in the n4 output, which is already downsampled if level was >0
znimg = ZarrNii.from_ome_zarr(
snakemake.input.corrected, **snakemake.params.zarrnii_kwargs
)

# first calculate histogram - using preset bins to avoid issues where bins are too large
# because of high intensity outliers
(hist_counts, bin_edges) = znimg.compute_histogram(
bins=snakemake.params.hist_bins, range=snakemake.params.hist_range
)

# get otsu thresholds (uses histogram)
print("computing thresholds")
(thresholds, fig) = compute_otsu_thresholds(
hist_counts,
classes=snakemake.params.otsu_k,
bin_edges=bin_edges,
return_figure=True,
)
print(f" 📈 thresholds: {[f'{t:.3f}' for t in thresholds]}")

fig.savefig(snakemake.output.thresholds_png)

print("thresholding image, saving as ome zarr")
znimg_mask = znimg.segment_threshold(
thresholds[snakemake.params.otsu_threshold_index]
)

# multiplying binary mask by 100 (so values are 0 and 100) to enable
# field fraction calculation by subsequent local-mean downsampling
znimg_mask = znimg_mask * 100

# write to ome_zarr
znimg_mask.to_ome_zarr(snakemake.output.mask, max_layer=5)
if __name__ == "__main__":
with get_dask_client(snakemake.config["dask_scheduler"], snakemake.threads):

# we use the default level=0, since we are reading in the n4 output, which is already downsampled if level was >0
znimg = ZarrNii.from_ome_zarr(
snakemake.input.corrected, **snakemake.params.zarrnii_kwargs
)

# first calculate histogram - using preset bins to avoid issues where bins are too large
# because of high intensity outliers
(hist_counts, bin_edges) = znimg.compute_histogram(
bins=snakemake.params.hist_bins, range=snakemake.params.hist_range
)

# get otsu thresholds (uses histogram)
print("computing thresholds")
(thresholds, fig) = compute_otsu_thresholds(
hist_counts,
classes=snakemake.params.otsu_k,
bin_edges=bin_edges,
return_figure=True,
)
print(f" 📈 thresholds: {[f'{t:.3f}' for t in thresholds]}")

fig.savefig(snakemake.output.thresholds_png)

print("thresholding image, saving as ome zarr")
znimg_mask = znimg.segment_threshold(
thresholds[snakemake.params.otsu_threshold_index]
)

# multiplying binary mask by 100 (so values are 0 and 100) to enable
# field fraction calculation by subsequent local-mean downsampling
znimg_mask = znimg_mask * 100

# write to ome_zarr
znimg_mask.to_ome_zarr(snakemake.output.mask, max_layer=5)
47 changes: 24 additions & 23 deletions spimquant/workflow/scripts/n4_biasfield.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,33 +3,34 @@
from zarrnii.plugins import N4BiasFieldCorrection
from dask.diagnostics import ProgressBar

with get_dask_client(snakemake.config["dask_scheduler"], snakemake.threads):
if __name__ == "__main__":
with get_dask_client(snakemake.config["dask_scheduler"], snakemake.threads):

hires_level = int(snakemake.wildcards.level)
proc_level = int(snakemake.params.proc_level)
hires_level = int(snakemake.wildcards.level)
proc_level = int(snakemake.params.proc_level)

unadjusted_downsample_factor = 2**proc_level
adjusted_downsample_factor = unadjusted_downsample_factor / (2**hires_level)
unadjusted_downsample_factor = 2**proc_level
adjusted_downsample_factor = unadjusted_downsample_factor / (2**hires_level)

znimg = ZarrNii.from_ome_zarr(
snakemake.input.spim,
channel_labels=[snakemake.wildcards.stain],
level=hires_level,
downsample_near_isotropic=True,
**snakemake.params.zarrnii_kwargs,
)
znimg = ZarrNii.from_ome_zarr(
snakemake.input.spim,
channel_labels=[snakemake.wildcards.stain],
level=hires_level,
downsample_near_isotropic=True,
**snakemake.params.zarrnii_kwargs,
)

print("compute bias field correction")
print("compute bias field correction")

adjusted_chunk = int(320 / (2**adjusted_downsample_factor))
adjusted_chunk = int(320 / (2**adjusted_downsample_factor))

with ProgressBar():
# Apply bias field correction
znimg_corrected = znimg.apply_scaled_processing(
N4BiasFieldCorrection(shrink_factor=snakemake.params.shrink_factor),
downsample_factor=adjusted_downsample_factor,
upsampled_ome_zarr_path=snakemake.output.biasfield,
)
with ProgressBar():
# Apply bias field correction
znimg_corrected = znimg.apply_scaled_processing(
N4BiasFieldCorrection(shrink_factor=snakemake.params.shrink_factor),
downsample_factor=adjusted_downsample_factor,
upsampled_ome_zarr_path=snakemake.output.biasfield,
)

# write to ome_zarr
znimg_corrected.to_ome_zarr(snakemake.output.corrected, max_layer=5)
# write to ome_zarr
znimg_corrected.to_ome_zarr(snakemake.output.corrected, max_layer=5)
2 changes: 1 addition & 1 deletion spimquant/workflow/scripts/ome_zarr_to_nii.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
from dask_setup import get_dask_client
from zarrnii import ZarrNii

with get_dask_client(snakemake.config["dask_scheduler"], snakemake.threads):
with get_dask_client("threads", snakemake.threads):
znimg = ZarrNii.from_ome_zarr(
snakemake.input.spim,
level=int(snakemake.wildcards.level),
Expand Down
Loading
Loading