diff --git a/spimquant/workflow/scripts/clean_segmentation.py b/spimquant/workflow/scripts/clean_segmentation.py
index 9e6523b..0d72c2f 100644
--- a/spimquant/workflow/scripts/clean_segmentation.py
+++ b/spimquant/workflow/scripts/clean_segmentation.py
@@ -15,31 +15,32 @@
 from zarrnii import ZarrNii
 from zarrnii.plugins import SegmentationCleaner
 
-with get_dask_client(snakemake.config["dask_scheduler"], snakemake.threads):
-
-    hires_level = int(snakemake.wildcards.level)
-    proc_level = int(snakemake.params.proc_level)
-
-    znimg = ZarrNii.from_ome_zarr(
-        snakemake.input.mask,
-        level=0,  # we load level 0 since we are already at the highres level
-        **snakemake.params.zarrnii_kwargs,
-    )
-
-    # perform cleaning of artifactual positives by
-    # removing objects with low extent (extent is ratio of num voxels to bounding box)
-
-    # the downsample_factor we use should be proportional to the segmentation level
-    # e.g. if segmentation level is 3, then we have already downsampled by 2^3, so
-    # the downsample factor should be divided by that..
-    unadjusted_downsample_factor = 2**proc_level
-    adjusted_downsample_factor = unadjusted_downsample_factor / (2**hires_level)
-
-    znimg_cleaned = znimg.apply_scaled_processing(
-        SegmentationCleaner(max_extent=snakemake.params.max_extent),
-        downsample_factor=adjusted_downsample_factor,
-        upsampled_ome_zarr_path=snakemake.output.exclude_mask,
-    )
-
-    # write to final ome_zarr
-    znimg_cleaned.to_ome_zarr(snakemake.output.cleaned_mask, max_layer=5)
+if __name__ == "__main__":
+    with get_dask_client(snakemake.config["dask_scheduler"], snakemake.threads):
+
+        hires_level = int(snakemake.wildcards.level)
+        proc_level = int(snakemake.params.proc_level)
+
+        znimg = ZarrNii.from_ome_zarr(
+            snakemake.input.mask,
+            level=0,  # we load level 0 since we are already at the highres level
+            **snakemake.params.zarrnii_kwargs,
+        )
+
+        # perform cleaning of artifactual positives by
+        # removing objects with low extent (extent is ratio of num voxels to bounding box)
+
+        # the downsample_factor we use should be proportional to the segmentation level
+        # e.g. if segmentation level is 3, then we have already downsampled by 2^3, so
+        # the downsample factor should be divided by that..
+        unadjusted_downsample_factor = 2**proc_level
+        adjusted_downsample_factor = unadjusted_downsample_factor / (2**hires_level)
+
+        znimg_cleaned = znimg.apply_scaled_processing(
+            SegmentationCleaner(max_extent=snakemake.params.max_extent),
+            downsample_factor=adjusted_downsample_factor,
+            upsampled_ome_zarr_path=snakemake.output.exclude_mask,
+        )
+
+        # write to final ome_zarr
+        znimg_cleaned.to_ome_zarr(snakemake.output.cleaned_mask, max_layer=5)
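Note on the downsample-factor arithmetic in clean_segmentation.py (the same pattern recurs in gaussian_biasfield.py and n4_biasfield.py below): the mask is loaded at pyramid level hires_level, i.e. already downsampled by 2**hires_level, so only the remaining ratio up to proc_level is applied. A standalone worked example of the diff's own formula, with hypothetical values:

    hires_level = 2  # e.g. the {level} wildcard
    proc_level = 5   # e.g. params.proc_level

    unadjusted_downsample_factor = 2**proc_level  # 32x relative to level 0
    adjusted_downsample_factor = unadjusted_downsample_factor / (2**hires_level)
    assert adjusted_downsample_factor == 8.0  # only 8x more downsampling is needed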
diff --git a/spimquant/workflow/scripts/coloc_per_voxel_template.py b/spimquant/workflow/scripts/coloc_per_voxel_template.py
index 47b206a..a6d7f08 100644
--- a/spimquant/workflow/scripts/coloc_per_voxel_template.py
+++ b/spimquant/workflow/scripts/coloc_per_voxel_template.py
@@ -1,7 +1,6 @@
 import numpy as np
 from zarrnii import ZarrNii, density_from_points
 from dask.diagnostics import ProgressBar
-from dask_setup import get_dask_client
 import pandas as pd
 
 img = ZarrNii.from_nifti(
@@ -11,12 +10,11 @@
 if hasattr(snakemake.wildcards, "level"):
     img = img.downsample(level=int(snakemake.wildcards.level))
 
-with get_dask_client(snakemake.config["dask_scheduler"], snakemake.threads):
-    df = pd.read_parquet(snakemake.input.coloc_parquet)
+df = pd.read_parquet(snakemake.input.coloc_parquet)
 
-    points = df[snakemake.params.coord_column_names].values
+points = df[snakemake.params.coord_column_names].values
 
-    # Create counts map (zarrnii is calling this density right now)..
-    counts = density_from_points(points, img, in_physical_space=True)
-    with ProgressBar():
-        counts.to_nifti(snakemake.output.counts_nii)
+# Create counts map (zarrnii is calling this density right now)..
+counts = density_from_points(points, img, in_physical_space=True)
+with ProgressBar():
+    counts.to_nifti(snakemake.output.counts_nii)
diff --git a/spimquant/workflow/scripts/compute_filtered_regionprops.py b/spimquant/workflow/scripts/compute_filtered_regionprops.py
index da2a511..e85c03f 100644
--- a/spimquant/workflow/scripts/compute_filtered_regionprops.py
+++ b/spimquant/workflow/scripts/compute_filtered_regionprops.py
@@ -8,16 +8,17 @@
 from dask_setup import get_dask_client
 from zarrnii import ZarrNii
 
-with get_dask_client(snakemake.config["dask_scheduler"], snakemake.threads):
+if __name__ == "__main__":
+    with get_dask_client(snakemake.config["dask_scheduler"], snakemake.threads):
 
-    znimg = ZarrNii.from_ome_zarr(
-        snakemake.input.mask,
-        level=0,  # input image is already downsampled to the wildcard level
-        **snakemake.params.zarrnii_kwargs,
-    )
+        znimg = ZarrNii.from_ome_zarr(
+            snakemake.input.mask,
+            level=0,  # input image is already downsampled to the wildcard level
+            **snakemake.params.zarrnii_kwargs,
+        )
 
-    znimg.compute_region_properties(
-        output_path=snakemake.output.regionprops_parquet,
-        region_filters=snakemake.params.region_filters,
-        output_properties=snakemake.params.output_properties,
-    )
+        znimg.compute_region_properties(
+            output_path=snakemake.output.regionprops_parquet,
+            region_filters=snakemake.params.region_filters,
+            output_properties=snakemake.params.output_properties,
+        )
diff --git a/spimquant/workflow/scripts/counts_per_voxel.py b/spimquant/workflow/scripts/counts_per_voxel.py
index 7e5711e..7403eb0 100644
--- a/spimquant/workflow/scripts/counts_per_voxel.py
+++ b/spimquant/workflow/scripts/counts_per_voxel.py
@@ -1,25 +1,25 @@
 import numpy as np
 from zarrnii import ZarrNii, density_from_points
 from dask.diagnostics import ProgressBar
 from dask_setup import get_dask_client
 import pandas as pd
 
 level = int(snakemake.wildcards.level)
 stain = snakemake.wildcards.stain
 
-img = ZarrNii.from_ome_zarr(
-    snakemake.input.ref_spim,
-    level=level,
-    channel_labels=[stain],
-    downsample_near_isotropic=True,
-)
+with get_dask_client("threads", snakemake.threads):
+    img = ZarrNii.from_ome_zarr(
+        snakemake.input.ref_spim,
+        level=level,
+        channel_labels=[stain],
+        downsample_near_isotropic=True,
+    )
 
-with get_dask_client(snakemake.config["dask_scheduler"], snakemake.threads):
     df = pd.read_parquet(snakemake.input.regionprops_parquet)
 
     points = df[snakemake.params.coord_column_names].values
 
-    # Create counts map (zarrnii is calling this density right now)..
-    counts = density_from_points(points, img, in_physical_space=True)
-    with ProgressBar():
-        counts.to_nifti(snakemake.output.counts_nii)
+# Create counts map (zarrnii is calling this density right now)..
+counts = density_from_points(points, img, in_physical_space=True)
+with ProgressBar():
+    counts.to_nifti(snakemake.output.counts_nii)
diff --git a/spimquant/workflow/scripts/counts_per_voxel_template.py b/spimquant/workflow/scripts/counts_per_voxel_template.py
index d0db2c1..8407578 100644
--- a/spimquant/workflow/scripts/counts_per_voxel_template.py
+++ b/spimquant/workflow/scripts/counts_per_voxel_template.py
@@ -13,7 +13,7 @@
 if hasattr(snakemake.wildcards, "level"):
     img = img.downsample(level=int(snakemake.wildcards.level))
 
-with get_dask_client(snakemake.config["dask_scheduler"], snakemake.threads):
+with get_dask_client("threads", snakemake.threads):
     df = pd.read_parquet(snakemake.input.regionprops_parquet)
 
     df = df[df["stain"] == snakemake.wildcards.stain]
diff --git a/spimquant/workflow/scripts/create_imaris_crops.py b/spimquant/workflow/scripts/create_imaris_crops.py
index 4fac9a0..889f8fd 100644
--- a/spimquant/workflow/scripts/create_imaris_crops.py
+++ b/spimquant/workflow/scripts/create_imaris_crops.py
@@ -48,7 +48,7 @@
 # Create output directory
 Path(output_dir).mkdir(parents=True, exist_ok=True)
 
-with get_dask_client(snakemake.config["dask_scheduler"], snakemake.threads):
+with get_dask_client("threads", snakemake.threads):
     # Load the atlas with labels
     atlas = ZarrNiiAtlas.from_files(
         input_dseg,
diff --git a/spimquant/workflow/scripts/create_patches.py b/spimquant/workflow/scripts/create_patches.py
index f44b292..22d7ff1 100644
--- a/spimquant/workflow/scripts/create_patches.py
+++ b/spimquant/workflow/scripts/create_patches.py
@@ -57,7 +57,7 @@
 # Create output directory
 Path(output_dir).mkdir(parents=True, exist_ok=True)
 
-with get_dask_client(snakemake.config["dask_scheduler"], snakemake.threads):
+with get_dask_client("threads", snakemake.threads):
     # Load the atlas with labels
     atlas = ZarrNiiAtlas.from_files(
         input_dseg,
diff --git a/spimquant/workflow/scripts/gaussian_biasfield.py b/spimquant/workflow/scripts/gaussian_biasfield.py
index ef6cb9f..460d748 100644
--- a/spimquant/workflow/scripts/gaussian_biasfield.py
+++ b/spimquant/workflow/scripts/gaussian_biasfield.py
@@ -3,31 +3,32 @@
 from zarrnii import ZarrNii
 from zarrnii.plugins import GaussianBiasFieldCorrection
 
-hires_level = int(snakemake.wildcards.level)
-proc_level = int(snakemake.params.proc_level)
+if __name__ == "__main__":
+    hires_level = int(snakemake.wildcards.level)
+    proc_level = int(snakemake.params.proc_level)
 
-unadjusted_downsample_factor = 2**proc_level
-adjusted_downsample_factor = unadjusted_downsample_factor / (2**hires_level)
+    unadjusted_downsample_factor = 2**proc_level
+    adjusted_downsample_factor = unadjusted_downsample_factor / (2**hires_level)
 
-with get_dask_client(snakemake.config["dask_scheduler"], snakemake.threads):
+    with get_dask_client(snakemake.config["dask_scheduler"], snakemake.threads):
 
-    znimg = ZarrNii.from_ome_zarr(
-        snakemake.input.spim,
-        channel_labels=[snakemake.wildcards.stain],
-        level=hires_level,
-        downsample_near_isotropic=True,
-        **snakemake.params.zarrnii_kwargs,
-    )
+        znimg = ZarrNii.from_ome_zarr(
+            snakemake.input.spim,
+            channel_labels=[snakemake.wildcards.stain],
+            level=hires_level,
+            downsample_near_isotropic=True,
+            **snakemake.params.zarrnii_kwargs,
+        )
 
-    print("compute bias field correction")
-    with ProgressBar():
+        print("compute bias field correction")
+        with ProgressBar():
 
-        # Apply bias field correction
-        znimg_corrected = znimg.apply_scaled_processing(
-            GaussianBiasFieldCorrection(sigma=5.0),
-            downsample_factor=adjusted_downsample_factor,
-            upsampled_ome_zarr_path=snakemake.output.biasfield,
-        )
+            # Apply bias field correction
+            znimg_corrected = znimg.apply_scaled_processing(
+                GaussianBiasFieldCorrection(sigma=5.0),
+                downsample_factor=adjusted_downsample_factor,
+                upsampled_ome_zarr_path=snakemake.output.biasfield,
+            )
 
-    # write to ome_zarr
-    znimg_corrected.to_ome_zarr(snakemake.output.corrected, max_layer=5)
+        # write to ome_zarr
+        znimg_corrected.to_ome_zarr(snakemake.output.corrected, max_layer=5)
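The GaussianBiasFieldCorrection plugin's implementation is not part of this diff; the usual recipe behind this kind of correction is to estimate the bias as a heavily smoothed copy of the image and divide it out. A rough standalone sketch with scipy, assuming that recipe (illustrative only, not the plugin's actual code):

    import numpy as np
    from scipy.ndimage import gaussian_filter

    def gaussian_bias_correct(img: np.ndarray, sigma: float = 5.0, eps: float = 1e-6):
        # estimate a smooth multiplicative bias field, then divide it out
        bias = gaussian_filter(img.astype(np.float32), sigma=sigma)
        bias /= bias.mean() + eps  # normalize so the overall intensity scale is preserved
        return img / (bias + eps), bias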
diff --git a/spimquant/workflow/scripts/multiotsu.py b/spimquant/workflow/scripts/multiotsu.py
index 6a15191..6b6c985 100644
--- a/spimquant/workflow/scripts/multiotsu.py
+++ b/spimquant/workflow/scripts/multiotsu.py
@@ -5,39 +5,40 @@
 
 matplotlib.use("agg")
 
-with get_dask_client(snakemake.config["dask_scheduler"], snakemake.threads):
-
-    # we use the default level=0, since we are reading in the n4 output, which is already downsampled if level was >0
-    znimg = ZarrNii.from_ome_zarr(
-        snakemake.input.corrected, **snakemake.params.zarrnii_kwargs
-    )
-
-    # first calculate histogram - using preset bins to avoid issues where bins are too large
-    # because of high intensity outliers
-    (hist_counts, bin_edges) = znimg.compute_histogram(
-        bins=snakemake.params.hist_bins, range=snakemake.params.hist_range
-    )
-
-    # get otsu thresholds (uses histogram)
-    print("computing thresholds")
-    (thresholds, fig) = compute_otsu_thresholds(
-        hist_counts,
-        classes=snakemake.params.otsu_k,
-        bin_edges=bin_edges,
-        return_figure=True,
-    )
-    print(f" 📈 thresholds: {[f'{t:.3f}' for t in thresholds]}")
-
-    fig.savefig(snakemake.output.thresholds_png)
-
-    print("thresholding image, saving as ome zarr")
-    znimg_mask = znimg.segment_threshold(
-        thresholds[snakemake.params.otsu_threshold_index]
-    )
-
-    # multiplying binary mask by 100 (so values are 0 and 100) to enable
-    # field fraction calculation by subsequent local-mean downsampling
-    znimg_mask = znimg_mask * 100
-
-    # write to ome_zarr
-    znimg_mask.to_ome_zarr(snakemake.output.mask, max_layer=5)
+if __name__ == "__main__":
+    with get_dask_client(snakemake.config["dask_scheduler"], snakemake.threads):
+
+        # we use the default level=0, since we are reading in the n4 output, which is already downsampled if level was >0
+        znimg = ZarrNii.from_ome_zarr(
+            snakemake.input.corrected, **snakemake.params.zarrnii_kwargs
+        )
+
+        # first calculate histogram - using preset bins to avoid issues where bins are too large
+        # because of high intensity outliers
+        (hist_counts, bin_edges) = znimg.compute_histogram(
+            bins=snakemake.params.hist_bins, range=snakemake.params.hist_range
+        )
+
+        # get otsu thresholds (uses histogram)
+        print("computing thresholds")
+        (thresholds, fig) = compute_otsu_thresholds(
+            hist_counts,
+            classes=snakemake.params.otsu_k,
+            bin_edges=bin_edges,
+            return_figure=True,
+        )
+        print(f" 📈 thresholds: {[f'{t:.3f}' for t in thresholds]}")
+
+        fig.savefig(snakemake.output.thresholds_png)
+
+        print("thresholding image, saving as ome zarr")
+        znimg_mask = znimg.segment_threshold(
+            thresholds[snakemake.params.otsu_threshold_index]
+        )
+
+        # multiplying binary mask by 100 (so values are 0 and 100) to enable
+        # field fraction calculation by subsequent local-mean downsampling
+        znimg_mask = znimg_mask * 100
+
+        # write to ome_zarr
+        znimg_mask.to_ome_zarr(snakemake.output.mask, max_layer=5)
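compute_otsu_thresholds is imported from elsewhere and not shown in this diff. If an equivalent is needed standalone, recent scikit-image releases can derive multi-Otsu thresholds from a precomputed histogram, matching the counts/bin_edges flow above (hedged sketch; the hist= tuple form is assumed available in your scikit-image version):

    import numpy as np
    from skimage.filters import threshold_multiotsu

    def thresholds_from_histogram(hist_counts, bin_edges, classes=4):
        # threshold_multiotsu accepts a (counts, bin_centers) tuple via hist=
        bin_centers = (bin_edges[:-1] + bin_edges[1:]) / 2
        return threshold_multiotsu(classes=classes, hist=(np.asarray(hist_counts), bin_centers))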
diff --git a/spimquant/workflow/scripts/n4_biasfield.py b/spimquant/workflow/scripts/n4_biasfield.py
index a47dd59..54bcd8b 100644
--- a/spimquant/workflow/scripts/n4_biasfield.py
+++ b/spimquant/workflow/scripts/n4_biasfield.py
@@ -3,33 +3,34 @@
 from zarrnii.plugins import N4BiasFieldCorrection
 from dask.diagnostics import ProgressBar
 
-with get_dask_client(snakemake.config["dask_scheduler"], snakemake.threads):
+if __name__ == "__main__":
+    with get_dask_client(snakemake.config["dask_scheduler"], snakemake.threads):
 
-    hires_level = int(snakemake.wildcards.level)
-    proc_level = int(snakemake.params.proc_level)
+        hires_level = int(snakemake.wildcards.level)
+        proc_level = int(snakemake.params.proc_level)
 
-    unadjusted_downsample_factor = 2**proc_level
-    adjusted_downsample_factor = unadjusted_downsample_factor / (2**hires_level)
+        unadjusted_downsample_factor = 2**proc_level
+        adjusted_downsample_factor = unadjusted_downsample_factor / (2**hires_level)
 
-    znimg = ZarrNii.from_ome_zarr(
-        snakemake.input.spim,
-        channel_labels=[snakemake.wildcards.stain],
-        level=hires_level,
-        downsample_near_isotropic=True,
-        **snakemake.params.zarrnii_kwargs,
-    )
+        znimg = ZarrNii.from_ome_zarr(
+            snakemake.input.spim,
+            channel_labels=[snakemake.wildcards.stain],
+            level=hires_level,
+            downsample_near_isotropic=True,
+            **snakemake.params.zarrnii_kwargs,
+        )
 
-    print("compute bias field correction")
+        print("compute bias field correction")
 
-    adjusted_chunk = int(320 / (2**adjusted_downsample_factor))
+        adjusted_chunk = int(320 / (2**adjusted_downsample_factor))
 
-    with ProgressBar():
-        # Apply bias field correction
-        znimg_corrected = znimg.apply_scaled_processing(
-            N4BiasFieldCorrection(shrink_factor=snakemake.params.shrink_factor),
-            downsample_factor=adjusted_downsample_factor,
-            upsampled_ome_zarr_path=snakemake.output.biasfield,
-        )
+        with ProgressBar():
+            # Apply bias field correction
+            znimg_corrected = znimg.apply_scaled_processing(
+                N4BiasFieldCorrection(shrink_factor=snakemake.params.shrink_factor),
+                downsample_factor=adjusted_downsample_factor,
+                upsampled_ome_zarr_path=snakemake.output.biasfield,
+            )
 
-    # write to ome_zarr
-    znimg_corrected.to_ome_zarr(snakemake.output.corrected, max_layer=5)
+        # write to ome_zarr
+        znimg_corrected.to_ome_zarr(snakemake.output.corrected, max_layer=5)
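The N4BiasFieldCorrection plugin is likewise external to this diff; its shrink_factor parameter mirrors the standard N4 usage of fitting the field on a shrunken image and applying it at full resolution. A sketch with SimpleITK, an assumed backend here (illustrative only, not necessarily how the plugin is implemented):

    import SimpleITK as sitk

    def n4_correct(img: sitk.Image, shrink_factor: int = 4) -> sitk.Image:
        img = sitk.Cast(img, sitk.sitkFloat32)
        small = sitk.Shrink(img, [shrink_factor] * img.GetDimension())
        corrector = sitk.N4BiasFieldCorrectionImageFilter()
        corrector.Execute(small)  # fit the bias field on the cheap, shrunken copy
        log_bias = corrector.GetLogBiasFieldAsImage(img)  # resampled to full resolution
        return img / sitk.Exp(log_bias)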
diff --git a/spimquant/workflow/scripts/ome_zarr_to_nii.py b/spimquant/workflow/scripts/ome_zarr_to_nii.py
index c4c6e74..4869e38 100644
--- a/spimquant/workflow/scripts/ome_zarr_to_nii.py
+++ b/spimquant/workflow/scripts/ome_zarr_to_nii.py
@@ -2,7 +2,7 @@
 from dask_setup import get_dask_client
 from zarrnii import ZarrNii
 
-with get_dask_client(snakemake.config["dask_scheduler"], snakemake.threads):
+with get_dask_client("threads", snakemake.threads):
     znimg = ZarrNii.from_ome_zarr(
         snakemake.input.spim,
         level=int(snakemake.wildcards.level),
diff --git a/spimquant/workflow/scripts/signed_distance_transform.py b/spimquant/workflow/scripts/signed_distance_transform.py
index d20ebe0..4c9fef4 100644
--- a/spimquant/workflow/scripts/signed_distance_transform.py
+++ b/spimquant/workflow/scripts/signed_distance_transform.py
@@ -25,95 +25,98 @@
 
 from dask_setup import get_dask_client
 
-with get_dask_client(snakemake.config["dask_scheduler"], snakemake.threads):
-
-    overlap_depth = snakemake.params.overlap_depth
-
-    znimg = ZarrNii.from_ome_zarr(
-        snakemake.input.mask, **snakemake.params.zarrnii_kwargs
-    )
-
-    # Get physical voxel spacing from the ZarrNii scale metadata.
-    # znimg.scale is a dict keyed by dimension name (e.g. {'z': 0.004, 'y': 0.0027, 'x': 0.0027}).
-    # znimg.dims gives the ordered dimension names, e.g. ['c', 'z', 'y', 'x'].
-    scale = znimg.scale
-    _known_spatial = {"z", "y", "x"}
-    spatial_dims = [d for d in znimg.dims if d in _known_spatial]
-    spacing = []
-    for d in spatial_dims:
-        s = scale.get(d)
-        if s is None:
-            import warnings
-
-            warnings.warn(
-                f"Physical spacing for dimension '{d}' not found in OME-Zarr "
-                "metadata; defaulting to 1.0. Distance values may not be in "
-                "physical units.",
-                stacklevel=2,
-            )
-            s = 1.0
-        spacing.append(s)
-    spacing = np.array(spacing)
-
-    # Calculate the maximum physical distance across a single block.
-    # This is used as the fill value for entirely-foreground or entirely-background
-    # blocks, where the true boundary lies outside the block.
-    # max() is used because the last chunk along each axis may be smaller than
-    # the others (when the array size is not divisible by the chunk size); the
-    # largest chunk determines the worst-case block diagonal.
-    spatial_chunk_indices = [znimg.dims.index(d) for d in spatial_dims]
-    block_size = np.array([max(znimg.darr.chunks[i]) for i in spatial_chunk_indices])
-    max_dist = float(np.sqrt(np.sum((block_size * spacing) ** 2)))
-
-    def signed_dt_block(block):
-        """Compute signed distance transform for a single block.
-
-        Expects block shape (C, Z, Y, X). For each channel the Euclidean
-        distance transform is computed twice:
-          - dt_inside : distance (in physical units) from each foreground
-            voxel to the nearest background voxel.
-          - dt_outside: distance (in physical units) from each background
-            voxel to the nearest foreground voxel.
-
-        The signed distance transform is dt_outside - dt_inside, giving
-        negative values inside the mask and positive values outside.
-
-        Blocks that are entirely foreground are filled with -max_dist, and
-        blocks that are entirely background are filled with +max_dist.
-        """
-        result = np.zeros(block.shape, dtype=np.float32)
-        for c in range(block.shape[0]):
-            binary = block[c] > 0
-            n_fg = np.count_nonzero(binary)
-            n_total = binary.size
-            if n_fg == 0:
-                # All background: nearest foreground is at least one block away
-                result[c] = max_dist
-            elif n_fg == n_total:
-                # All foreground: nearest background is at least one block away
-                result[c] = -max_dist
-            else:
-                dt_inside = distance_transform_edt(binary, sampling=spacing).astype(
-                    np.float32
-                )
-                dt_outside = distance_transform_edt(~binary, sampling=spacing).astype(
-                    np.float32
-                )
-                result[c] = dt_outside - dt_inside
-        return result
-
-    # depth=0 for the channel dimension, overlap_depth for spatial dims
-    depth = {0: 0, 1: overlap_depth, 2: overlap_depth, 3: overlap_depth}
-
-    sdt_darr = da.map_overlap(
-        signed_dt_block,
-        znimg.darr,
-        depth=depth,
-        boundary=0,
-        dtype=np.float32,
-    )
-
-    znimg.darr = sdt_darr
-
-    with ProgressBar():
-        znimg.to_ome_zarr(snakemake.output.dist, max_layer=5)
+if __name__ == "__main__":
+    with get_dask_client(snakemake.config["dask_scheduler"], snakemake.threads):
+
+        overlap_depth = snakemake.params.overlap_depth
+
+        znimg = ZarrNii.from_ome_zarr(
+            snakemake.input.mask, **snakemake.params.zarrnii_kwargs
+        )
+
+        # Get physical voxel spacing from the ZarrNii scale metadata.
+        # znimg.scale is a dict keyed by dimension name (e.g. {'z': 0.004, 'y': 0.0027, 'x': 0.0027}).
+        # znimg.dims gives the ordered dimension names, e.g. ['c', 'z', 'y', 'x'].
+        scale = znimg.scale
+        _known_spatial = {"z", "y", "x"}
+        spatial_dims = [d for d in znimg.dims if d in _known_spatial]
+        spacing = []
+        for d in spatial_dims:
+            s = scale.get(d)
+            if s is None:
+                import warnings
+
+                warnings.warn(
+                    f"Physical spacing for dimension '{d}' not found in OME-Zarr "
+                    "metadata; defaulting to 1.0. Distance values may not be in "
+                    "physical units.",
+                    stacklevel=2,
+                )
+                s = 1.0
+            spacing.append(s)
+        spacing = np.array(spacing)
+
+        # Calculate the maximum physical distance across a single block.
+        # This is used as the fill value for entirely-foreground or entirely-background
+        # blocks, where the true boundary lies outside the block.
+        # max() is used because the last chunk along each axis may be smaller than
+        # the others (when the array size is not divisible by the chunk size); the
+        # largest chunk determines the worst-case block diagonal.
+        spatial_chunk_indices = [znimg.dims.index(d) for d in spatial_dims]
+        block_size = np.array(
+            [max(znimg.darr.chunks[i]) for i in spatial_chunk_indices]
+        )
+        max_dist = float(np.sqrt(np.sum((block_size * spacing) ** 2)))
+
+        def signed_dt_block(block):
+            """Compute signed distance transform for a single block.
+
+            Expects block shape (C, Z, Y, X). For each channel the Euclidean
+            distance transform is computed twice:
+              - dt_inside : distance (in physical units) from each foreground
+                voxel to the nearest background voxel.
+              - dt_outside: distance (in physical units) from each background
+                voxel to the nearest foreground voxel.
+
+            The signed distance transform is dt_outside - dt_inside, giving
+            negative values inside the mask and positive values outside.
+
+            Blocks that are entirely foreground are filled with -max_dist, and
+            blocks that are entirely background are filled with +max_dist.
+            """
+            result = np.zeros(block.shape, dtype=np.float32)
+            for c in range(block.shape[0]):
+                binary = block[c] > 0
+                n_fg = np.count_nonzero(binary)
+                n_total = binary.size
+                if n_fg == 0:
+                    # All background: nearest foreground is at least one block away
+                    result[c] = max_dist
+                elif n_fg == n_total:
+                    # All foreground: nearest background is at least one block away
+                    result[c] = -max_dist
+                else:
+                    dt_inside = distance_transform_edt(binary, sampling=spacing).astype(
+                        np.float32
+                    )
+                    dt_outside = distance_transform_edt(
+                        ~binary, sampling=spacing
+                    ).astype(np.float32)
+                    result[c] = dt_outside - dt_inside
+            return result
+
+        # depth=0 for the channel dimension, overlap_depth for spatial dims
+        depth = {0: 0, 1: overlap_depth, 2: overlap_depth, 3: overlap_depth}
+
+        sdt_darr = da.map_overlap(
+            signed_dt_block,
+            znimg.darr,
+            depth=depth,
+            boundary=0,
+            dtype=np.float32,
+        )
+
+        znimg.darr = sdt_darr
+
+        with ProgressBar():
+            znimg.to_ome_zarr(snakemake.output.dist, max_layer=5)
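The two-pass EDT construction in signed_distance_transform.py is easy to sanity-check on a toy mask (pure numpy/scipy, unit spacing):

    import numpy as np
    from scipy.ndimage import distance_transform_edt

    mask = np.zeros((7, 7), dtype=bool)
    mask[2:5, 2:5] = True  # a 3x3 foreground square

    # positive outside the mask, negative inside, zero-crossing at the boundary
    sdt = distance_transform_edt(~mask) - distance_transform_edt(mask)
    assert sdt[3, 3] == -2.0  # center voxel is 2 voxels from the nearest background
    assert sdt[0, 0] > 0      # far corner is well outside the mask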
diff --git a/spimquant/workflow/scripts/threshold.py b/spimquant/workflow/scripts/threshold.py
index 57593e8..4ba122c 100644
--- a/spimquant/workflow/scripts/threshold.py
+++ b/spimquant/workflow/scripts/threshold.py
@@ -1,18 +1,19 @@
 from dask_setup import get_dask_client
 from zarrnii import ZarrNii
 
-with get_dask_client(snakemake.config["dask_scheduler"], snakemake.threads):
+if __name__ == "__main__":
+    with get_dask_client(snakemake.config["dask_scheduler"], snakemake.threads):
 
-    znimg_hires = ZarrNii.from_ome_zarr(
-        snakemake.input.corrected, **snakemake.params.zarrnii_kwargs
-    )
+        znimg_hires = ZarrNii.from_ome_zarr(
+            snakemake.input.corrected, **snakemake.params.zarrnii_kwargs
+        )
 
-    print("thresholding image, saving as ome zarr")
-    znimg_mask = znimg_hires.segment_threshold(snakemake.params.threshold)
+        print("thresholding image, saving as ome zarr")
+        znimg_mask = znimg_hires.segment_threshold(snakemake.params.threshold)
 
-    # multiplying binary mask by 100 (so values are 0 and 100) to enable
-    # field fraction calculation by subsequent local-mean downsampling
-    znimg_mask = znimg_mask * 100
+        # multiplying binary mask by 100 (so values are 0 and 100) to enable
+        # field fraction calculation by subsequent local-mean downsampling
+        znimg_mask = znimg_mask * 100
 
-    # write to ome_zarr
-    znimg_mask.to_ome_zarr(snakemake.output.mask, max_layer=5)
+        # write to ome_zarr
+        znimg_mask.to_ome_zarr(snakemake.output.mask, max_layer=5)
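On the mask-times-100 idiom (used here and in multiotsu.py and vesselfm.py): scaling a {0, 1} mask to {0, 100} means that subsequent local-mean downsampling yields the foreground field fraction in percent. A quick check:

    import numpy as np

    mask = np.array([[1, 1],
                     [1, 0]]) * 100  # binary mask scaled to {0, 100}
    # a local-mean downsample reduces this 2x2 block to one voxel:
    assert mask.mean() == 75.0  # 3 of 4 voxels are foreground -> 75%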
diff --git a/spimquant/workflow/scripts/vesselfm.py b/spimquant/workflow/scripts/vesselfm.py
index df3c13c..b0e7daf 100644
--- a/spimquant/workflow/scripts/vesselfm.py
+++ b/spimquant/workflow/scripts/vesselfm.py
@@ -4,17 +4,18 @@
 from dask.diagnostics import ProgressBar
 from dask_setup import get_dask_client
 
-with get_dask_client(snakemake.config["dask_scheduler"], snakemake.threads):
-    znimg = ZarrNii.from_ome_zarr(
-        snakemake.input.spim,
-        level=int(snakemake.wildcards.level),
-        channel_labels=[snakemake.wildcards.stain],
-        downsample_near_isotropic=True,
-        **snakemake.params.zarrnii_kwargs,
-    )
-    znimg_mask = znimg.segment(VesselFMPlugin, **snakemake.params.vesselfm_kwargs)
+if __name__ == "__main__":
+    with get_dask_client(snakemake.config["dask_scheduler"], snakemake.threads):
+        znimg = ZarrNii.from_ome_zarr(
+            snakemake.input.spim,
+            level=int(snakemake.wildcards.level),
+            channel_labels=[snakemake.wildcards.stain],
+            downsample_near_isotropic=True,
+            **snakemake.params.zarrnii_kwargs,
+        )
+        znimg_mask = znimg.segment(VesselFMPlugin, **snakemake.params.vesselfm_kwargs)
 
-    znimg_mask = znimg_mask * 100
+        znimg_mask = znimg_mask * 100
 
-    with ProgressBar():
-        znimg_mask.to_ome_zarr(snakemake.output.mask, max_layer=5, zarr_format=2)
+        with ProgressBar():
+            znimg_mask.to_ome_zarr(snakemake.output.mask, max_layer=5, zarr_format=2)