diff --git a/spimquant/config/snakebids.yml b/spimquant/config/snakebids.yml index 08565e5..3330062 100644 --- a/spimquant/config/snakebids.yml +++ b/spimquant/config/snakebids.yml @@ -155,7 +155,6 @@ parse_args: help: "Method to use for microscopy segmentation (e.g. plaques, protein deposits, cells) applied to 'stains_for_seg' channels, and used to calculate field fractions. " default: - otsu+k3i2 - - th900 nargs: '+' --seg_hist_range: @@ -189,7 +188,7 @@ parse_args: choices: - threads - distributed - default: threads + default: distributed --sloppy: help: "Use low-quality parameters for speed (USE FOR TESTING ONLY)" diff --git a/spimquant/workflow/Snakefile b/spimquant/workflow/Snakefile index ef1b04f..e3e0a69 100644 --- a/spimquant/workflow/Snakefile +++ b/spimquant/workflow/Snakefile @@ -271,6 +271,20 @@ rule all_vessels: ), +# inputs["spim"].expand( +# bids( +# root=root, +# datatype="micr", +# stain="{stain}", +# desc="{desc}", +# suffix="mask.ozx", +# **inputs["spim"].wildcards, +# ), +# desc=config["vessel_seg_method"], +# stain=stains_for_vessels, +# ) + + rule all_segment: input: inputs["spim"].expand( @@ -676,7 +690,6 @@ rule all_participant: rules.all_register.input, rules.all_vessels.input if do_vessels else [], rules.all_segment.input if do_seg else [], - rules.all_spim_patches.input if do_seg else [], rules.all_mri_reg.input if config["register_to_mri"] else [], rules.all_segment_coloc.input if do_coloc else [], rules.all_segment_maskcoloc.input if do_maskcoloc else [], diff --git a/spimquant/workflow/rules/groupstats.smk b/spimquant/workflow/rules/groupstats.smk index 93286fa..1fec69e 100644 --- a/spimquant/workflow/rules/groupstats.smk +++ b/spimquant/workflow/rules/groupstats.smk @@ -46,7 +46,7 @@ rule perform_group_stats: ), threads: 1 resources: - mem_mb=16000, + mem_mb=1500, runtime=10, script: "../scripts/perform_group_stats.py" @@ -89,7 +89,7 @@ rule create_stats_heatmap: threads: 1 resources: mem_mb=8000, - runtime=5, + runtime=15, script: "../scripts/create_stats_heatmap.py" @@ -131,7 +131,7 @@ rule map_groupstats_to_template_nii: threads: 8 resources: mem_mb=16000, - runtime=5, + runtime=15, script: "../scripts/map_tsv_dseg_to_nii.py" @@ -167,7 +167,7 @@ rule concat_subj_parquet: ), threads: 1 resources: - mem_mb=16000, + mem_mb=1500, runtime=10, script: "../scripts/concat_subj_parquet.py" @@ -196,12 +196,10 @@ rule group_counts_per_voxel: desc="{desc}", suffix="{stain}+count.nii.gz", ), - group: - "subj" threads: 16 resources: - mem_mb=15000, - runtime=10, + mem_mb=200000, + runtime=30, script: "../scripts/counts_per_voxel_template.py" @@ -229,8 +227,6 @@ rule group_coloc_counts_per_voxel: desc="{desc}", suffix="coloccount.nii.gz", ), - group: - "subj" threads: 16 resources: mem_mb=15000, @@ -275,7 +271,7 @@ rule concat_subj_parquet_contrast: ), threads: 1 resources: - mem_mb=16000, + mem_mb=1500, runtime=10, script: "../scripts/concat_subj_parquet_contrast.py" @@ -306,12 +302,10 @@ rule group_counts_per_voxel_contrast: contrast="{contrast_column}+{contrast_value}", suffix="{stain}+count.nii.gz", ), - group: - "subj" threads: 16 resources: - mem_mb=15000, - runtime=10, + mem_mb=200000, + runtime=30, script: "../scripts/counts_per_voxel_template.py" @@ -341,8 +335,6 @@ rule group_coloc_counts_per_voxel_contrast: contrast="{contrast_column}+{contrast_value}", suffix="coloccount.nii.gz", ), - group: - "subj" threads: 16 resources: mem_mb=15000, @@ -395,7 +387,7 @@ rule concat_subj_segstats_contrast: ), threads: 1 resources: - mem_mb=16000, + mem_mb=1500, runtime=10, script: "../scripts/concat_subj_segstats_contrast.py" @@ -441,6 +433,6 @@ rule map_groupavg_segstats_to_template_nii: threads: 8 resources: mem_mb=16000, - runtime=5, + runtime=15, script: "../scripts/map_tsv_dseg_to_nii.py" diff --git a/spimquant/workflow/rules/import.smk b/spimquant/workflow/rules/import.smk index a8a24c0..5443a00 100644 --- a/spimquant/workflow/rules/import.smk +++ b/spimquant/workflow/rules/import.smk @@ -42,12 +42,10 @@ rule get_downsampled_nii: suffix="SPIM.nii.gz", **inputs["spim"].wildcards, ), - group: - "subj" threads: 32 resources: mem_mb=16000, - runtime=5, + runtime=15, script: "../scripts/ome_zarr_to_nii.py" @@ -77,8 +75,8 @@ rule import_template_anat: anat=bids_tpl(root=root, template="{template}", suffix="anat.nii.gz"), threads: 1 resources: - mem_mb=16000, - runtime=5, + mem_mb=1500, + runtime=15, log: bids_tpl( root="logs", @@ -109,8 +107,8 @@ rule import_template_spim: ), threads: 1 resources: - mem_mb=16000, - runtime=5, + mem_mb=1500, + runtime=15, log: bids_tpl( root="logs", @@ -133,8 +131,8 @@ rule import_mask: ), threads: 1 resources: - mem_mb=16000, - runtime=5, + mem_mb=1500, + runtime=15, log: bids_tpl( root="logs", @@ -153,8 +151,8 @@ rule generic_lut_bids_to_itksnap: lut="{prefix}_dseg.itksnap.txt", threads: 1 resources: - mem_mb=16000, - runtime=5, + mem_mb=1500, + runtime=15, script: "../scripts/lut_bids_to_itksnap.py" @@ -182,8 +180,8 @@ rule import_dseg: ), threads: 1 resources: - mem_mb=16000, - runtime=5, + mem_mb=1500, + runtime=15, script: "../scripts/copy_nii.py" @@ -201,8 +199,8 @@ rule import_lut_tsv: tsv=bids_tpl(root=root, template="{template}", seg="{seg}", suffix="dseg.tsv"), threads: 1 resources: - mem_mb=16000, - runtime=5, + mem_mb=1500, + runtime=15, shell: "cp {input} {output}" @@ -214,7 +212,7 @@ rule import_DSURQE_tsv: tsv=bids_tpl(root=root, template="DSURQE", seg="all", suffix="dseg.tsv"), threads: 1 resources: - mem_mb=16000, - runtime=5, + mem_mb=1500, + runtime=15, script: "../scripts/import_DSURQE_dseg_tsv.py" diff --git a/spimquant/workflow/rules/masking.smk b/spimquant/workflow/rules/masking.smk index b6058be..cc60b35 100644 --- a/spimquant/workflow/rules/masking.smk +++ b/spimquant/workflow/rules/masking.smk @@ -49,7 +49,7 @@ rule pre_atropos: desc="preAtropos", suffix="SPIM.nii.gz", **inputs["spim"].wildcards, - ) + ), ), mask=temp( bids( @@ -60,14 +60,12 @@ rule pre_atropos: desc="preAtropos", suffix="mask.nii", **inputs["spim"].wildcards, - ) + ), ), - group: - "subj" threads: 1 resources: - mem_mb=16000, - runtime=5, + mem_mb=3000, + runtime=15, conda: "../envs/c3d.yaml" shell: @@ -98,7 +96,7 @@ rule atropos_seg: k="{k}", suffix="dseg.nii", **inputs["spim"].wildcards, - ) + ), ), posteriors_dir=temp( directory( @@ -112,18 +110,16 @@ rule atropos_seg: suffix="posteriors", **inputs["spim"].wildcards, ) - ) + ), ), - group: - "subj" conda: "../envs/ants.yaml" shadow: "minimal" - threads: 1 + threads: 16 resources: - mem_mb=8000, - runtime=15, + mem_mb=32000, + runtime=45, shell: "mkdir -p {output.posteriors_dir} && " "ITK_GLOBAL_DEFAULT_NUMBER_OF_THREADS={threads} " @@ -155,14 +151,12 @@ rule post_atropos: k="{k}", suffix="dseg.nii", **inputs["spim"].wildcards, - ) + ), ), - group: - "subj" threads: 1 resources: - mem_mb=16000, - runtime=5, + mem_mb=3000, + runtime=15, conda: "../envs/c3d.yaml" shell: @@ -199,7 +193,7 @@ rule init_affine_reg: desc="initaffine", suffix="xfm.txt", **inputs["spim"].wildcards, - ) + ), ), warped=temp( bids( @@ -209,10 +203,8 @@ rule init_affine_reg: desc="initaffinewarped", suffix="SPIM.nii.gz", **inputs["spim"].wildcards, - ) + ), ), - group: - "subj" log: bids( root="logs", @@ -224,7 +216,7 @@ rule init_affine_reg: threads: 32 resources: mem_mb=16000, - runtime=5, + runtime=15, shell: "greedy -threads {threads} -d 3 -i {input.template} {input.subject} " " -a -dof 12 -ia-image-centers -m NMI -o {output.xfm_ras} -n {params.iters} && " @@ -254,14 +246,12 @@ rule affine_transform_template_mask_to_subject: from_="{template}", suffix="mask.nii.gz", **inputs["spim"].wildcards, - ) + ), ), - group: - "subj" threads: 32 resources: mem_mb=16000, - runtime=5, + runtime=15, shell: " greedy -threads {threads} -d 3 -rf {input.ref} " " -ri NN" @@ -307,12 +297,10 @@ rule create_mask_from_gmm_and_prior: suffix="mask.nii.gz", **inputs["spim"].wildcards, ), - group: - "subj" threads: 1 resources: - mem_mb=16000, - runtime=5, + mem_mb=4000, + runtime=15, script: "../scripts/create_mask_from_gmm_and_prior.py" @@ -341,12 +329,10 @@ rule create_mask_from_gmm: suffix="mask.nii.gz", **inputs["spim"].wildcards, ), - group: - "subj" threads: 1 resources: - mem_mb=16000, - runtime=5, + mem_mb=4000, + runtime=15, conda: "../envs/c3d.yaml" shell: diff --git a/spimquant/workflow/rules/patches.smk b/spimquant/workflow/rules/patches.smk index 67efed6..cb27392 100644 --- a/spimquant/workflow/rules/patches.smk +++ b/spimquant/workflow/rules/patches.smk @@ -66,8 +66,6 @@ rule create_spim_patches: **inputs["spim"].wildcards, ) ), - group: - "subj" threads: 32 resources: mem_mb=32000, @@ -85,12 +83,12 @@ rule create_mask_patches: """ input: mask=bids( - root=work, + root=root, datatype="micr", stain="{stain}", level=config["segmentation_level"], desc="{desc}", - suffix="mask.ome.zarr", + suffix="mask.ozx", **inputs["spim"].wildcards, ), dseg=bids( @@ -127,8 +125,6 @@ rule create_mask_patches: **inputs["spim"].wildcards, ) ), - group: - "subj" threads: 32 resources: mem_mb=32000, @@ -188,8 +184,6 @@ rule create_corrected_spim_patches: **inputs["spim"].wildcards, ) ), - group: - "subj" threads: 32 resources: mem_mb=32000, @@ -237,8 +231,6 @@ rule create_imaris_crops: **inputs["spim"].wildcards, ) ), - group: - "subj" threads: 32 resources: mem_mb=32000, diff --git a/spimquant/workflow/rules/preproc_mri.smk b/spimquant/workflow/rules/preproc_mri.smk index be6a02d..2f0389e 100644 --- a/spimquant/workflow/rules/preproc_mri.smk +++ b/spimquant/workflow/rules/preproc_mri.smk @@ -70,11 +70,9 @@ rule n4_mri_individual: nii=inputs["mri"].path, output: nii=temp(bids(root=root, datatype="anat", desc="N4", **inputs["mri"].wildcards)), - group: - "subj" threads: 1 resources: - mem_mb=16000, + mem_mb=1500, runtime=15, conda: "../envs/ants.yaml" @@ -99,11 +97,9 @@ rule resample_mri_ref: **inputs["mri"].wildcards, ) ), - group: - "subj" threads: 1 resources: - mem_mb=16000, + mem_mb=1500, runtime=15, conda: "../envs/c3d.yaml" @@ -134,8 +130,6 @@ rule register_mri_to_first: **inputs.subj_wildcards, ) ), - group: - "subj" threads: 8 resources: mem_mb=8000, @@ -178,8 +172,6 @@ rule resample_mri_to_first: **inputs.subj_wildcards, ) ), - group: - "subj" threads: 8 conda: "../envs/c3d.yaml" @@ -217,14 +209,12 @@ rule average_mri: suffix=f"{mri_suffix}.nii.gz", **inputs.subj_wildcards, ), - group: - "subj" threads: 8 resources: - mem_mb=16000, + mem_mb=1500, runtime=15, - group: - "subj" + conda: + "../envs/c3d.yaml" shell: "c3d {input.resampled_images} -accum -add -endaccum -o {output.nii}" @@ -312,12 +302,12 @@ rule rigid_nlin_reg_mri_to_template: **inputs.subj_wildcards, ) ), - group: - "subj" threads: 32 resources: - mem_mb=16000, + mem_mb=1500, runtime=15, + conda: + "../envs/c3d.yaml" shell: "greedy -threads {threads} -d 3 -i {input.template} {input.subject} " " -a -dof 6 -ia-image-centers -m {params.metric} -o {output.xfm_ras} && " @@ -349,8 +339,6 @@ rule all_tune_mri_mask: warpsigma=range(3, 6), radius=[f"{i}x{i}x{i}" for i in range(2, 5)], ), - group: - "subj" rule transform_template_mask_to_mri: @@ -410,11 +398,9 @@ rule transform_template_mask_to_mri: ), shadow: "minimal" - group: - "subj" threads: 32 resources: - mem_mb=16000, + mem_mb=1500, runtime=15, conda: "../envs/c3d.yaml" @@ -453,11 +439,9 @@ rule apply_mri_brain_mask: suffix=f"{mri_suffix}.nii.gz", **inputs.subj_wildcards, ), - group: - "subj" threads: 1 resources: - mem_mb=16000, + mem_mb=1500, runtime=15, conda: "../envs/c3d.yaml" @@ -573,12 +557,12 @@ rule affine_nlin_reg_mri_to_spim: **inputs["spim"].wildcards, ) ), - group: - "subj" threads: 32 resources: - mem_mb=16000, - runtime=15, + mem_mb=32000, + runtime=30, + conda: + "../envs/c3d.yaml" shell: "greedy -threads {threads} -d 3 -i {input.spim} {input.mri} " " -a -dof {params.dof} -ia-image-centers -m {params.metric_rigid} -o {output.xfm_ras} && " @@ -616,8 +600,6 @@ rule all_tune_mri_spim_reg: dof=[12], radius=[f"{i}x{i}x{i}" for i in range(2, 4)], ), - group: - "subj" rule warp_mri_to_template_via_spim: @@ -686,12 +668,12 @@ rule warp_mri_to_template_via_spim: suffix=f"{mri_suffix}.nii.gz", **inputs["spim"].wildcards, ), - group: - "subj" threads: 32 resources: - mem_mb=16000, + mem_mb=32000, runtime=15, + conda: + "../envs/c3d.yaml" shell: " greedy -threads {threads} -d 3 -rf {input.ref} " " -rm {input.mri} {output.warped} " @@ -780,12 +762,12 @@ rule warp_mri_brainmask_to_spim: suffix="warp.nii.gz", **inputs["spim"].wildcards, ), - group: - "subj" threads: 32 resources: - mem_mb=16000, + mem_mb=1500, runtime=15, + conda: + "../envs/c3d.yaml" shell: " greedy -threads {threads} -d 3 -rf {input.ref} -ri NN" " -rm {input.mask} {output.mask} " @@ -860,8 +842,6 @@ rule mri_spim_registration_qc_report: suffix="regqc.html", **inputs["spim"].wildcards, ), - group: - "subj" threads: 1 resources: mem_mb=8000, diff --git a/spimquant/workflow/rules/segmentation.smk b/spimquant/workflow/rules/segmentation.smk index 0b6a53e..3c88ea1 100644 --- a/spimquant/workflow/rules/segmentation.smk +++ b/spimquant/workflow/rules/segmentation.smk @@ -42,7 +42,8 @@ rule gaussian_biasfield: suffix="SPIM.ome.zarr", **inputs["spim"].wildcards, ) - ) + ), + group_jobs=True, ), biasfield=temp( directory( @@ -55,11 +56,10 @@ rule gaussian_biasfield: suffix="biasfield.ome.zarr", **inputs["spim"].wildcards, ) - ) + ), + group_jobs=True, ), - group: - "subj" - threads: 128 + threads: 128 if config["dask_scheduler"] == "distributed" else 32 resources: mem_mb=256000, disk_mb=2097152, @@ -75,6 +75,8 @@ rule n4_biasfield: params: proc_level=5, zarrnii_kwargs={"orientation": config["orientation"]}, + shrink_factor=16 if config["sloppy"] else 1, + target_chunk_size=512, #this sets the chunk size for this and downstream masks output: corrected=temp( directory( @@ -87,28 +89,13 @@ rule n4_biasfield: suffix="SPIM.ome.zarr", **inputs["spim"].wildcards, ) - ) - ), - biasfield=temp( - directory( - bids( - root=work, - datatype="micr", - stain="{stain}", - level="{level}", - desc="n4", - suffix="biasfield.ome.zarr", - **inputs["spim"].wildcards, - ) - ) + ), + group_jobs=True, ), - group: - "subj" - threads: 128 + threads: 128 if config["dask_scheduler"] == "distributed" else 32 resources: - mem_mb=500000, - disk_mb=2097152, - runtime=60, + mem_mb=500000 if config["dask_scheduler"] == "distributed" else 250000, + runtime=180, script: "../scripts/n4_biasfield.py" @@ -138,18 +125,14 @@ rule multiotsu: otsu_threshold_index=lambda wildcards: int(wildcards.i), zarrnii_kwargs={"orientation": config["orientation"]}, output: - mask=temp( - directory( - bids( - root=work, - datatype="micr", - stain="{stain}", - level="{level}", - desc="otsu+k{k,[0-9]+}i{i,[0-9]+}", - suffix="mask.ome.zarr", - **inputs["spim"].wildcards, - ) - ) + mask=bids( + root=root, + datatype="micr", + stain="{stain}", + level="{level}", + desc="otsu+k{k,[0-9]+}i{i,[0-9]+}", + suffix="mask.ozx", + **inputs["spim"].wildcards, ), thresholds_png=bids( root=root, @@ -160,33 +143,15 @@ rule multiotsu: suffix="thresholds.png", **inputs["spim"].wildcards, ), - group: - "subj" - threads: 128 + threads: 128 if config["dask_scheduler"] == "distributed" else 32 resources: - mem_mb=500000, + mem_mb=500000 if config["dask_scheduler"] == "distributed" else 250000, disk_mb=2097152, - runtime=15, + runtime=180, script: "../scripts/multiotsu.py" -rule convert_zarr_to_ozx: - """generic rule to convert ome zarr to zip (.ozx)""" - input: - zarr=str(Path(work) / "{prefix}.ome.zarr"), - output: - ozx=str(Path(root) / "{prefix}.ozx"), - threads: 4 - resources: - mem_mb=32000, - runtime=60, - group: - "subj" - script: - "../scripts/convert_zarr_to_ozx.py" - - rule threshold: """Apply simple intensity threshold for segmentation. @@ -207,25 +172,19 @@ rule threshold: threshold=lambda wildcards: int(wildcards.threshold), zarrnii_kwargs={"orientation": config["orientation"]}, output: - mask=temp( - directory( - bids( - root=work, - datatype="micr", - stain="{stain}", - level="{level}", - desc="th{threshold,[0-9]+}", - suffix="mask.ome.zarr", - **inputs["spim"].wildcards, - ) - ) + mask=bids( + root=root, + datatype="micr", + stain="{stain}", + level="{level}", + desc="th{threshold,[0-9]+}", + suffix="mask.ozx", + **inputs["spim"].wildcards, ), - group: - "subj" - threads: 128 + threads: 128 if config["dask_scheduler"] == "distributed" else 32 resources: - mem_mb=256000, - runtime=15, + mem_mb=500000 if config["dask_scheduler"] == "distributed" else 250000, + runtime=180, script: "../scripts/threshold.py" @@ -239,12 +198,12 @@ rule clean_segmentation: """ input: mask=bids( - root=work, + root=root, datatype="micr", stain="{stain}", level="{level}", desc="{desc}", - suffix="mask.ome.zarr", + suffix="mask.ozx", **inputs["spim"].wildcards, ), params: @@ -252,99 +211,43 @@ rule clean_segmentation: proc_level=2, #level at which to calculate conncomp zarrnii_kwargs={"orientation": config["orientation"]}, output: - exclude_mask=temp( - directory( - bids( - root=work, - datatype="micr", - stain="{stain}", - level="{level}", - desc="{desc}+cleaned", - suffix="excludemask.ome.zarr", - **inputs["spim"].wildcards, - ) - ) - ), - cleaned_mask=temp( - directory( - bids( - root=work, - datatype="micr", - stain="{stain}", - level="{level}", - desc="{desc}+cleaned", - suffix="mask.ome.zarr", - **inputs["spim"].wildcards, - ) - ) - ), - group: - "subj" - threads: 128 - resources: - mem_mb=256000, - disk_mb=2097152, - runtime=30, - script: - "../scripts/clean_segmentation.py" - - -rule signed_distance_transform: - """Compute signed distance transform from a binary mask. - - Applies the chamfer distance transform (distance_transform_cdt from scipy) - to a binary mask using dask map_overlap for chunked, parallel processing. - The output is a signed distance transform where positive values indicate - the interior and negative values indicate the exterior of the mask. - """ - input: - mask=bids( - root=work, + exclude_mask=bids( + root=root, datatype="micr", stain="{stain}", level="{level}", - desc="{desc}", - suffix="mask.ome.zarr", + desc="{desc}+cleaned", + suffix="excludemask.ozx", **inputs["spim"].wildcards, ), - params: - overlap_depth=32, - zarrnii_kwargs={"orientation": config["orientation"]}, - output: - dist=temp( - directory( - bids( - root=work, - datatype="micr", - stain="{stain}", - level="{level}", - desc="{desc}", - suffix="dist.ome.zarr", - **inputs["spim"].wildcards, - ) - ) + cleaned_mask=bids( + root=root, + datatype="micr", + stain="{stain}", + level="{level}", + desc="{desc}+cleaned", + suffix="mask.ozx", + **inputs["spim"].wildcards, ), - group: - "subj" - threads: 32 + threads: 128 if config["dask_scheduler"] == "distributed" else 32 resources: - mem_mb=64000, + mem_mb=256000, disk_mb=2097152, runtime=30, script: - "../scripts/signed_distance_transform.py" + "../scripts/clean_segmentation.py" rule compute_filtered_regionprops: """Calculate region props from filtered objects of segmentation.""" input: mask=bids( - root=work, + root=root, datatype="micr", stain="{stain}", level=config["segmentation_level"], desc="{desc}", - suffix="mask.ome.zarr", + suffix="mask.ozx", **inputs["spim"].wildcards, ), params: @@ -364,12 +267,10 @@ rule compute_filtered_regionprops: **inputs["spim"].wildcards, ) ), - group: - "subj" - threads: 128 + threads: 64 if config["dask_scheduler"] == "distributed" else 32 resources: mem_mb=256000, - runtime=30, + runtime=180, script: "../scripts/compute_filtered_regionprops.py" @@ -417,12 +318,10 @@ rule transform_regionprops_to_template: **inputs["spim"].wildcards, ) ), - group: - "subj" threads: 1 resources: mem_mb=16000, - runtime=5, + runtime=15, script: "../scripts/transform_regionprops_to_template.py" @@ -454,12 +353,10 @@ rule aggregate_regionprops_across_stains: suffix="regionprops.parquet", **inputs["spim"].wildcards, ), - group: - "subj" threads: 1 resources: - mem_mb=16000, - runtime=5, + mem_mb=1500, + runtime=15, script: "../scripts/aggregate_regionprops_across_stains.py" @@ -494,12 +391,10 @@ rule colocalize_regionprops: suffix="coloc.parquet", **inputs["spim"].wildcards, ), - group: - "subj" threads: 1 resources: - mem_mb=16000, - runtime=10, + mem_mb=1500, + runtime=30, script: "../scripts/compute_colocalization.py" @@ -550,12 +445,12 @@ rule colocalize_regionprops_with_mask: **inputs["spim"].wildcards, ), dist=lambda wildcards: bids( - root=work, + root=root, datatype="micr", stain=wildcards.stain_b, level=config["segmentation_level"], desc=get_mask_seg_desc(wildcards.stain_b), - suffix="dist.ome.zarr", + suffix="dist.ozx", **{k: getattr(wildcards, k) for k in inputs["spim"].wildcards}, ), params: @@ -575,12 +470,10 @@ rule colocalize_regionprops_with_mask: wildcard_constraints: stain_a="[a-zA-Z0-9]+", stain_b="[a-zA-Z0-9]+", - group: - "subj" threads: 1 resources: mem_mb=32000, - runtime=10, + runtime=30, script: "../scripts/compute_colocalization_with_mask.py" @@ -633,12 +526,10 @@ rule transform_maskcoloc_to_template: wildcard_constraints: stain_a="[a-zA-Z0-9]+", stain_b="[a-zA-Z0-9]+", - group: - "subj" threads: 1 resources: - mem_mb=16000, - runtime=5, + mem_mb=1500, + runtime=30, script: "../scripts/transform_regionprops_to_template.py" @@ -703,12 +594,10 @@ rule map_maskcoloc_to_atlas_rois: wildcard_constraints: stain_a="[a-zA-Z0-9]+", stain_b="[a-zA-Z0-9]+", - group: - "subj" threads: 1 resources: - mem_mb=16000, - runtime=5, + mem_mb=1500, + runtime=30, script: "../scripts/map_atlas_to_regionprops.py" @@ -738,12 +627,10 @@ rule counts_per_voxel: suffix="counts.nii.gz", **inputs["spim"].wildcards, ), - group: - "subj" threads: 16 resources: mem_mb=15000, - runtime=10, + runtime=20, script: "../scripts/counts_per_voxel.py" @@ -773,12 +660,10 @@ rule counts_per_voxel_template: suffix="counts.nii.gz", **inputs["spim"].wildcards, ), - group: - "subj" threads: 16 resources: - mem_mb=15000, - runtime=10, + mem_mb=64000, + runtime=30, script: "../scripts/counts_per_voxel_template.py" @@ -807,16 +692,18 @@ rule coloc_per_voxel_template: suffix="coloccounts.nii.gz", **inputs["spim"].wildcards, ), - group: - "subj" threads: 16 resources: mem_mb=15000, - runtime=10, + runtime=30, script: "../scripts/coloc_per_voxel_template.py" +# to avoid wildcard conflicts: +ruleorder: fieldfrac > fieldfrac_vessels + + rule fieldfrac: """Calculate field fraction from binary mask. @@ -827,12 +714,12 @@ rule fieldfrac: """ input: mask=bids( - root=work, + root=root, datatype="micr", stain="{stain}", level=config["segmentation_level"], desc="{desc}", - suffix="mask.ome.zarr", + suffix="mask.ozx", **inputs["spim"].wildcards, ), params: @@ -848,12 +735,10 @@ rule fieldfrac: suffix="fieldfrac.nii.gz", **inputs["spim"].wildcards, ), - group: - "subj" threads: 32 resources: mem_mb=16000, - runtime=5, + runtime=30, script: "../scripts/fieldfrac.py" @@ -881,11 +766,9 @@ rule deform_negative_mask_to_subject_nii: suffix="mask.nii.gz", **inputs["spim"].wildcards, ), - group: - "subj" threads: 32 resources: - mem_mb=16000, + mem_mb=1500, runtime=15, shell: " greedy -threads {threads} -d 3 -rf {input.ref} " @@ -932,12 +815,10 @@ rule map_img_to_roi_tsv: **inputs["spim"].wildcards, ) ), - group: - "subj" threads: 1 resources: - mem_mb=16000, - runtime=5, + mem_mb=1500, + runtime=15, script: "../scripts/map_img_to_roi_tsv.py" @@ -994,12 +875,10 @@ rule map_regionprops_to_atlas_rois: **inputs["spim"].wildcards, ) ), - group: - "subj" threads: 1 resources: - mem_mb=16000, - runtime=5, + mem_mb=1500, + runtime=15, script: "../scripts/map_atlas_to_regionprops.py" @@ -1045,12 +924,10 @@ rule map_coloc_to_atlas_rois: **inputs["spim"].wildcards, ) ), - group: - "subj" threads: 1 resources: - mem_mb=16000, - runtime=5, + mem_mb=1500, + runtime=15, script: "../scripts/map_atlas_to_coloc.py" @@ -1104,12 +981,10 @@ rule merge_into_segstats_tsv: **inputs["spim"].wildcards, ) ), - group: - "subj" threads: 1 resources: - mem_mb=16000, - runtime=5, + mem_mb=1500, + runtime=15, script: "../scripts/merge_into_segstats_tsv.py" @@ -1160,12 +1035,10 @@ rule merge_into_colocsegstats_tsv: **inputs["spim"].wildcards, ) ), - group: - "subj" threads: 1 resources: - mem_mb=16000, - runtime=5, + mem_mb=1500, + runtime=15, script: "../scripts/merge_into_segstats_tsv.py" @@ -1218,12 +1091,10 @@ rule merge_indiv_and_coloc_segstats_tsv: suffix="mergedsegstats.tsv", **inputs["spim"].wildcards, ), - group: - "subj" threads: 1 resources: - mem_mb=16000, - runtime=5, + mem_mb=1500, + runtime=15, script: "../scripts/merge_indiv_and_coloc_segstats_tsv.py" @@ -1259,12 +1130,10 @@ rule map_segstats_tsv_dseg_to_template_nii: suffix="{suffix}.nii.gz", **inputs["spim"].wildcards, ), - group: - "subj" threads: 1 resources: mem_mb=16000, - runtime=5, + runtime=30, script: "../scripts/map_tsv_dseg_to_nii.py" @@ -1308,12 +1177,10 @@ rule map_segstats_tsv_dseg_to_subject_nii: suffix="{suffix}.nii.gz", **inputs["spim"].wildcards, ), - group: - "subj" threads: 1 resources: mem_mb=16000, - runtime=15, + runtime=30, script: "../scripts/map_tsv_dseg_to_nii.py" @@ -1359,12 +1226,10 @@ rule deform_fieldfrac_nii_to_template_nii: suffix="fieldfrac.nii.gz", **inputs["spim"].wildcards, ), - group: - "subj" threads: 32 resources: - mem_mb=16000, - runtime=5, + mem_mb=32000, + runtime=30, conda: "../envs/ants.yaml" shell: diff --git a/spimquant/workflow/rules/templatereg.smk b/spimquant/workflow/rules/templatereg.smk index 75d399c..98f030a 100644 --- a/spimquant/workflow/rules/templatereg.smk +++ b/spimquant/workflow/rules/templatereg.smk @@ -62,12 +62,10 @@ rule n4: suffix="biasfield.nii.gz", **inputs["spim"].wildcards, ), - group: - "subj" threads: 1 resources: - mem_mb=16000, - runtime=5, + mem_mb=1500, + runtime=15, conda: "../envs/ants.yaml" shell: @@ -112,12 +110,10 @@ rule apply_mask_to_corrected: suffix="SPIM.nii.gz", **inputs["spim"].wildcards, ), - group: - "subj" threads: 1 resources: - mem_mb=16000, - runtime=5, + mem_mb=1500, + runtime=15, conda: "../envs/c3d.yaml" shell: @@ -143,7 +139,7 @@ rule crop_template: hemisphere="{hemisphere}", threads: 1 resources: - mem_mb=16000, + mem_mb=1500, runtime=15, script: "../scripts/crop_template.py" @@ -189,8 +185,6 @@ rule affine_reg: suffix="SPIM.nii.gz", **inputs["spim"].wildcards, ), - group: - "subj" log: bids( root="logs", @@ -202,7 +196,7 @@ rule affine_reg: threads: 32 resources: mem_mb=16000, - runtime=5, + runtime=15, shell: "greedy -threads {threads} -d 3 -i {input.template} {input.subject} " " -a -dof 12 -ia-image-centers -m NMI -o {output.xfm_ras} -n {params.iters} && " @@ -239,12 +233,10 @@ rule convert_ras_to_itk: suffix="xfm.txt", **inputs["spim"].wildcards, ), - group: - "subj" threads: 1 resources: - mem_mb=16000, - runtime=5, + mem_mb=1500, + runtime=15, conda: "../envs/c3d.yaml" shell: @@ -303,8 +295,6 @@ rule deform_reg: **inputs["spim"].wildcards, ) ), - group: - "subj" log: bids( root="logs", @@ -316,7 +306,7 @@ rule deform_reg: threads: 32 resources: mem_mb=16000, - runtime=5, + runtime=5 if config["sloppy"] else 30, shell: "greedy -threads {threads} -d 3 -i {input.template} {input.subject} " " -it {input.xfm_ras} -m {params.metric} " @@ -342,20 +332,14 @@ rule resample_labels_to_zarr: label_name="dseg", scaling_method="nearest", output: - zarr=temp( - directory( - bids( - root=work, - datatype="micr", - desc="resampled", - from_="{template}", - suffix="dseg.ome.zarr", - **inputs["spim"].wildcards, - ) - ) + zarr=bids( + root=root, + datatype="micr", + desc="resampled", + from_="{template}", + suffix="dseg.ozx", + **inputs["spim"].wildcards, ), - group: - "subj" threads: 10 resources: mem_mb=16000, @@ -390,8 +374,6 @@ rule affine_zarr_to_template_nii: suffix="SPIM.nii.gz", **inputs["spim"].wildcards, ), - group: - "subj" threads: 32 resources: mem_mb=16000, @@ -408,21 +390,15 @@ rule affine_zarr_to_template_ome_zarr: params: ref_opts={"chunks": (1, 50, 50, 50)}, output: - ome_zarr=temp( - directory( - bids( - root=work, - datatype="micr", - desc="affine", - space="{template}", - stain="{stain}", - suffix="spim.ome.zarr", - **inputs["spim"].wildcards, - ) - ) + ome_zarr=bids( + root=root, + datatype="micr", + desc="affine", + space="{template}", + stain="{stain}", + suffix="spim.ozx", + **inputs["spim"].wildcards, ), - group: - "subj" threads: 32 resources: mem_mb=16000, @@ -453,8 +429,6 @@ rule deform_zarr_to_template_nii: suffix="SPIM.nii.gz", **inputs["spim"].wildcards, ), - group: - "subj" threads: 32 resources: mem_mb=16000, @@ -492,11 +466,9 @@ rule deform_to_template_nii_zoomed: suffix="SPIM.nii.gz", **inputs["spim"].wildcards, ), - group: - "subj" threads: 4 resources: - mem_mb=15000, + mem_mb=16000, runtime=15, script: "../scripts/deform_to_template_nii.py" @@ -541,12 +513,10 @@ rule deform_spim_nii_to_template_nii: suffix="SPIM.nii.gz", **inputs["spim"].wildcards, ), - group: - "subj" threads: 32 resources: mem_mb=16000, - runtime=5, + runtime=15, conda: "../envs/ants.yaml" shell: @@ -597,12 +567,10 @@ rule deform_template_dseg_to_subject_nii: suffix="dseg.nii.gz", **inputs["spim"].wildcards, ), - group: - "subj" threads: 32 resources: mem_mb=16000, - runtime=5, + runtime=15, conda: "../envs/ants.yaml" shell: @@ -627,12 +595,8 @@ rule copy_template_dseg_tsv: suffix="dseg.tsv", **inputs["spim"].wildcards, ), - group: - "subj" threads: 1 - resources: - mem_mb=16000, - runtime=5, + localrule: True shell: "cp {input} {output}" @@ -657,8 +621,6 @@ rule deform_transform_labels_to_subj: **inputs["spim"].wildcards, ) ), - group: - "subj" threads: 32 script: #TODO this script doesn't exist?? "../scripts/deform_transform_channel_to_template_nii.py" @@ -685,8 +647,6 @@ rule transform_labels_to_zoomed_template: suffix="dseg.nii.gz", **inputs["spim"].wildcards, ), - group: - "subj" threads: 32 conda: "../envs/ants.yaml" @@ -752,11 +712,9 @@ rule registration_qc_report: suffix="regqc.html", **inputs["spim"].wildcards, ), - group: - "subj" threads: 1 resources: mem_mb=8000, - runtime=10, + runtime=30, script: "../scripts/reg_qc_report.py" diff --git a/spimquant/workflow/rules/vessels.smk b/spimquant/workflow/rules/vessels.smk index 1c70286..e0b7e9a 100644 --- a/spimquant/workflow/rules/vessels.smk +++ b/spimquant/workflow/rules/vessels.smk @@ -3,6 +3,7 @@ rule import_vesselfm_model: model=storage(config["models"]["vesselfm"]), output: "resources/models/vesselfm.pt", + localrule: True shell: "cp {input} {output}" @@ -18,22 +19,99 @@ rule run_vesselfm: "model_path": input.model_path, }, output: - mask=directory( - bids( - root=work, - datatype="micr", - stain="{stain}", - level="{level}", - desc="vesselfm", - suffix="mask.ome.zarr", - **inputs["spim"].wildcards, - ) + mask=bids( + root=root, + datatype="micr", + stain="{stain}", + level="{level}", + desc="vesselfm", + suffix="mask.ozx", + **inputs["spim"].wildcards, ), threads: 32 - group: - "subj" resources: - mem_mb=32000, + gpu=1, + cpus_per_gpu=32, + mem_mb=64000, runtime=lambda wildcards: max(1, int(200.0 / (3.0 ** float(wildcards.level)))), # rough estimate, clamped to >=1 script: "../scripts/vesselfm.py" + + +rule fieldfrac_vessels: + """Calculate field fraction from binary mask. + + Computes the fraction of brain tissue occupied by the vessels. + Note: This is a separate rule from `fieldfrac` to allow the + dags groups to be disjoint. + + """ + input: + mask=bids( + root=root, + datatype="micr", + stain="{stain}", + level=config["segmentation_level"], + desc="{desc}", + suffix="mask.ozx", + **inputs["spim"].wildcards, + ), + params: + hires_level=config["segmentation_level"], + zarrnii_kwargs={"orientation": config["orientation"]}, + output: + fieldfrac_nii=bids( + root=root, + datatype="micr", + stain="{stain}", + level="{level}", + desc="{desc,vesselfm}", + suffix="fieldfrac.nii.gz", + **inputs["spim"].wildcards, + ), + threads: 32 + resources: + mem_mb=1500, + runtime=15, + script: + "../scripts/fieldfrac.py" + + +rule signed_distance_transform: + """Compute signed distance transform from a binary mask. + + Applies the chamfer distance transform (distance_transform_cdt from scipy) + to a binary mask using dask map_overlap for chunked, parallel processing. + The output is a signed distance transform where positive values indicate + the interior and negative values indicate the exterior of the mask. + """ + input: + mask=bids( + root=root, + datatype="micr", + stain="{stain}", + level="{level}", + desc="vesselfm", + suffix="mask.ozx", + **inputs["spim"].wildcards, + ), + params: + overlap_depth=32, + zarrnii_kwargs={"orientation": config["orientation"]}, + output: + dist=bids( + root=root, + datatype="micr", + stain="{stain}", + level="{level}", + desc="{desc}", + suffix="dist.ozx", + **inputs["spim"].wildcards, + ), + threads: 128 if config["dask_scheduler"] == "distributed" else 32 + resources: + mem_mb=256000, + disk_mb=2097152, + runtime=360, + script: + "../scripts/signed_distance_transform.py" diff --git a/spimquant/workflow/scripts/clean_segmentation.py b/spimquant/workflow/scripts/clean_segmentation.py index 9e6523b..0d72c2f 100644 --- a/spimquant/workflow/scripts/clean_segmentation.py +++ b/spimquant/workflow/scripts/clean_segmentation.py @@ -15,31 +15,32 @@ from zarrnii import ZarrNii from zarrnii.plugins import SegmentationCleaner -with get_dask_client(snakemake.config["dask_scheduler"], snakemake.threads): - - hires_level = int(snakemake.wildcards.level) - proc_level = int(snakemake.params.proc_level) - - znimg = ZarrNii.from_ome_zarr( - snakemake.input.mask, - level=0, # we load level 0 since we are already at the highres level - **snakemake.params.zarrnii_kwargs, - ) - - # perform cleaning of artifactual positives by - # removing objects with low extent (extent is ratio of num voxels to bounding box) - - # the downsample_factor we use should be proportional to the segmentation level - # e.g. if segmentation level is 3, then we have already downsampled by 2^3, so - # the downsample factor should be divided by that.. - unadjusted_downsample_factor = 2**proc_level - adjusted_downsample_factor = unadjusted_downsample_factor / (2**hires_level) - - znimg_cleaned = znimg.apply_scaled_processing( - SegmentationCleaner(max_extent=snakemake.params.max_extent), - downsample_factor=adjusted_downsample_factor, - upsampled_ome_zarr_path=snakemake.output.exclude_mask, - ) - - # write to final ome_zarr - znimg_cleaned.to_ome_zarr(snakemake.output.cleaned_mask, max_layer=5) +if __name__ == "__main__": + with get_dask_client(snakemake.config["dask_scheduler"], snakemake.threads): + + hires_level = int(snakemake.wildcards.level) + proc_level = int(snakemake.params.proc_level) + + znimg = ZarrNii.from_ome_zarr( + snakemake.input.mask, + level=0, # we load level 0 since we are already at the highres level + **snakemake.params.zarrnii_kwargs, + ) + + # perform cleaning of artifactual positives by + # removing objects with low extent (extent is ratio of num voxels to bounding box) + + # the downsample_factor we use should be proportional to the segmentation level + # e.g. if segmentation level is 3, then we have already downsampled by 2^3, so + # the downsample factor should be divided by that.. + unadjusted_downsample_factor = 2**proc_level + adjusted_downsample_factor = unadjusted_downsample_factor / (2**hires_level) + + znimg_cleaned = znimg.apply_scaled_processing( + SegmentationCleaner(max_extent=snakemake.params.max_extent), + downsample_factor=adjusted_downsample_factor, + upsampled_ome_zarr_path=snakemake.output.exclude_mask, + ) + + # write to final ome_zarr + znimg_cleaned.to_ome_zarr(snakemake.output.cleaned_mask, max_layer=5) diff --git a/spimquant/workflow/scripts/coloc_per_voxel_template.py b/spimquant/workflow/scripts/coloc_per_voxel_template.py index 47b206a..b3bf4b1 100644 --- a/spimquant/workflow/scripts/coloc_per_voxel_template.py +++ b/spimquant/workflow/scripts/coloc_per_voxel_template.py @@ -1,7 +1,6 @@ import numpy as np from zarrnii import ZarrNii, density_from_points from dask.diagnostics import ProgressBar -from dask_setup import get_dask_client import pandas as pd img = ZarrNii.from_nifti( @@ -11,12 +10,15 @@ if hasattr(snakemake.wildcards, "level"): img = img.downsample(level=int(snakemake.wildcards.level)) -with get_dask_client(snakemake.config["dask_scheduler"], snakemake.threads): - df = pd.read_parquet(snakemake.input.coloc_parquet) +df = pd.read_parquet(snakemake.input.coloc_parquet) - points = df[snakemake.params.coord_column_names].values +points = df[snakemake.params.coord_column_names].values - # Create counts map (zarrnii is calling this density right now).. - counts = density_from_points(points, img, in_physical_space=True) - with ProgressBar(): - counts.to_nifti(snakemake.output.counts_nii) +# Create counts map (zarrnii is calling this density right now).. +counts = density_from_points(points, img, in_physical_space=True) + +# force atlas units (until zarrnii issue #203 fixed): +counts.ngff_image.axes_units = {"x": "millimeter", "y": "millimeter", "z": "millimeter"} + +with ProgressBar(): + counts.to_nifti(snakemake.output.counts_nii) diff --git a/spimquant/workflow/scripts/compute_filtered_regionprops.py b/spimquant/workflow/scripts/compute_filtered_regionprops.py index da2a511..e85c03f 100644 --- a/spimquant/workflow/scripts/compute_filtered_regionprops.py +++ b/spimquant/workflow/scripts/compute_filtered_regionprops.py @@ -8,16 +8,17 @@ from dask_setup import get_dask_client from zarrnii import ZarrNii -with get_dask_client(snakemake.config["dask_scheduler"], snakemake.threads): +if __name__ == "__main__": + with get_dask_client(snakemake.config["dask_scheduler"], snakemake.threads): - znimg = ZarrNii.from_ome_zarr( - snakemake.input.mask, - level=0, # input image is already downsampled to the wildcard level - **snakemake.params.zarrnii_kwargs, - ) + znimg = ZarrNii.from_ome_zarr( + snakemake.input.mask, + level=0, # input image is already downsampled to the wildcard level + **snakemake.params.zarrnii_kwargs, + ) - znimg.compute_region_properties( - output_path=snakemake.output.regionprops_parquet, - region_filters=snakemake.params.region_filters, - output_properties=snakemake.params.output_properties, - ) + znimg.compute_region_properties( + output_path=snakemake.output.regionprops_parquet, + region_filters=snakemake.params.region_filters, + output_properties=snakemake.params.output_properties, + ) diff --git a/spimquant/workflow/scripts/convert_zarr_to_ozx.py b/spimquant/workflow/scripts/convert_zarr_to_ozx.py index f068e25..8673f43 100644 --- a/spimquant/workflow/scripts/convert_zarr_to_ozx.py +++ b/spimquant/workflow/scripts/convert_zarr_to_ozx.py @@ -1,12 +1,6 @@ -# import ngff_zarr as nz -from zarrnii import ZarrNii +from ngff_zarr.rfc9_zip import write_store_to_zip +from zarr.storage import LocalStore -# Read from directory store -# multiscales = nz.from_ngff_zarr(snakemake.input.zarr) - -# print(multiscales) -# Write as .ozx file -# nz.to_ngff_zarr(snakemake.output.ozx, multiscales, version='0.5') - -# note, this recomputes the multiscales -ZarrNii.from_ome_zarr(snakemake.input.zarr).to_ome_zarr(snakemake.output.ozx) +# Direct conversion of existing store to .ozx +source_store = LocalStore(snakemake.input.zarr) +write_store_to_zip(source_store, snakemake.output.ozx, version="0.5") diff --git a/spimquant/workflow/scripts/counts_per_voxel.py b/spimquant/workflow/scripts/counts_per_voxel.py index 7e5711e..10af002 100644 --- a/spimquant/workflow/scripts/counts_per_voxel.py +++ b/spimquant/workflow/scripts/counts_per_voxel.py @@ -7,19 +7,23 @@ level = int(snakemake.wildcards.level) stain = snakemake.wildcards.stain -img = ZarrNii.from_ome_zarr( - snakemake.input.ref_spim, - level=level, - channel_labels=[stain], - downsample_near_isotropic=True, -) +with get_dask_client("threads", snakemake.threads): + img = ZarrNii.from_ome_zarr( + snakemake.input.ref_spim, + level=level, + channel_labels=[stain], + downsample_near_isotropic=True, + ) -with get_dask_client(snakemake.config["dask_scheduler"], snakemake.threads): df = pd.read_parquet(snakemake.input.regionprops_parquet) points = df[snakemake.params.coord_column_names].values - # Create counts map (zarrnii is calling this density right now).. - counts = density_from_points(points, img, in_physical_space=True) - with ProgressBar(): - counts.to_nifti(snakemake.output.counts_nii) +# Create counts map (zarrnii is calling this density right now).. +counts = density_from_points(points, img, in_physical_space=True) + +# force atlas units (until zarrnii issue #203 fixed): +counts.ngff_image.axes_units = {"x": "millimeter", "y": "millimeter", "z": "millimeter"} + +with ProgressBar(): + counts.to_nifti(snakemake.output.counts_nii) diff --git a/spimquant/workflow/scripts/counts_per_voxel_template.py b/spimquant/workflow/scripts/counts_per_voxel_template.py index d0db2c1..bf179e3 100644 --- a/spimquant/workflow/scripts/counts_per_voxel_template.py +++ b/spimquant/workflow/scripts/counts_per_voxel_template.py @@ -1,7 +1,7 @@ import numpy as np from zarrnii import ZarrNii, density_from_points -from dask.diagnostics import ProgressBar from dask_setup import get_dask_client +from dask.diagnostics import ProgressBar import pandas as pd stain = snakemake.wildcards.stain @@ -13,7 +13,7 @@ if hasattr(snakemake.wildcards, "level"): img = img.downsample(level=int(snakemake.wildcards.level)) -with get_dask_client(snakemake.config["dask_scheduler"], snakemake.threads): +with get_dask_client("threads", snakemake.threads): df = pd.read_parquet(snakemake.input.regionprops_parquet) df = df[df["stain"] == snakemake.wildcards.stain] @@ -21,5 +21,13 @@ # Create counts map (zarrnii is calling this density right now).. counts = density_from_points(points, img, in_physical_space=True) + + # force atlas units (until zarrnii issue #203 fixed): + counts.ngff_image.axes_units = { + "x": "millimeter", + "y": "millimeter", + "z": "millimeter", + } + with ProgressBar(): counts.to_nifti(snakemake.output.counts_nii) diff --git a/spimquant/workflow/scripts/create_imaris_crops.py b/spimquant/workflow/scripts/create_imaris_crops.py index 4fac9a0..889f8fd 100644 --- a/spimquant/workflow/scripts/create_imaris_crops.py +++ b/spimquant/workflow/scripts/create_imaris_crops.py @@ -48,7 +48,7 @@ # Create output directory Path(output_dir).mkdir(parents=True, exist_ok=True) -with get_dask_client(snakemake.config["dask_scheduler"], snakemake.threads): +with get_dask_client("threads", snakemake.threads): # Load the atlas with labels atlas = ZarrNiiAtlas.from_files( input_dseg, diff --git a/spimquant/workflow/scripts/create_patches.py b/spimquant/workflow/scripts/create_patches.py index f44b292..22d7ff1 100644 --- a/spimquant/workflow/scripts/create_patches.py +++ b/spimquant/workflow/scripts/create_patches.py @@ -57,7 +57,7 @@ # Create output directory Path(output_dir).mkdir(parents=True, exist_ok=True) -with get_dask_client(snakemake.config["dask_scheduler"], snakemake.threads): +with get_dask_client("threads", snakemake.threads): # Load the atlas with labels atlas = ZarrNiiAtlas.from_files( input_dseg, diff --git a/spimquant/workflow/scripts/gaussian_biasfield.py b/spimquant/workflow/scripts/gaussian_biasfield.py index ef6cb9f..460d748 100644 --- a/spimquant/workflow/scripts/gaussian_biasfield.py +++ b/spimquant/workflow/scripts/gaussian_biasfield.py @@ -3,31 +3,32 @@ from zarrnii import ZarrNii from zarrnii.plugins import GaussianBiasFieldCorrection -hires_level = int(snakemake.wildcards.level) -proc_level = int(snakemake.params.proc_level) +if __name__ == "__main__": + hires_level = int(snakemake.wildcards.level) + proc_level = int(snakemake.params.proc_level) -unadjusted_downsample_factor = 2**proc_level -adjusted_downsample_factor = unadjusted_downsample_factor / (2**hires_level) + unadjusted_downsample_factor = 2**proc_level + adjusted_downsample_factor = unadjusted_downsample_factor / (2**hires_level) -with get_dask_client(snakemake.config["dask_scheduler"], snakemake.threads): + with get_dask_client(snakemake.config["dask_scheduler"], snakemake.threads): - znimg = ZarrNii.from_ome_zarr( - snakemake.input.spim, - channel_labels=[snakemake.wildcards.stain], - level=hires_level, - downsample_near_isotropic=True, - **snakemake.params.zarrnii_kwargs, - ) + znimg = ZarrNii.from_ome_zarr( + snakemake.input.spim, + channel_labels=[snakemake.wildcards.stain], + level=hires_level, + downsample_near_isotropic=True, + **snakemake.params.zarrnii_kwargs, + ) - print("compute bias field correction") - with ProgressBar(): + print("compute bias field correction") + with ProgressBar(): - # Apply bias field correction - znimg_corrected = znimg.apply_scaled_processing( - GaussianBiasFieldCorrection(sigma=5.0), - downsample_factor=adjusted_downsample_factor, - upsampled_ome_zarr_path=snakemake.output.biasfield, - ) + # Apply bias field correction + znimg_corrected = znimg.apply_scaled_processing( + GaussianBiasFieldCorrection(sigma=5.0), + downsample_factor=adjusted_downsample_factor, + upsampled_ome_zarr_path=snakemake.output.biasfield, + ) - # write to ome_zarr - znimg_corrected.to_ome_zarr(snakemake.output.corrected, max_layer=5) + # write to ome_zarr + znimg_corrected.to_ome_zarr(snakemake.output.corrected, max_layer=5) diff --git a/spimquant/workflow/scripts/map_tsv_dseg_to_nii.py b/spimquant/workflow/scripts/map_tsv_dseg_to_nii.py index 1d096a3..8dde5a1 100644 --- a/spimquant/workflow/scripts/map_tsv_dseg_to_nii.py +++ b/spimquant/workflow/scripts/map_tsv_dseg_to_nii.py @@ -11,4 +11,7 @@ label_column=snakemake.params.label_column, ) +# force atlas units (until zarrnii issue #203 fixed): +img.ngff_image.axes_units = {"x": "millimeter", "y": "millimeter", "z": "millimeter"} + img.to_nifti(snakemake.output.nii) diff --git a/spimquant/workflow/scripts/multiotsu.py b/spimquant/workflow/scripts/multiotsu.py index 6a15191..6b6c985 100644 --- a/spimquant/workflow/scripts/multiotsu.py +++ b/spimquant/workflow/scripts/multiotsu.py @@ -5,39 +5,40 @@ matplotlib.use("agg") -with get_dask_client(snakemake.config["dask_scheduler"], snakemake.threads): - - # we use the default level=0, since we are reading in the n4 output, which is already downsampled if level was >0 - znimg = ZarrNii.from_ome_zarr( - snakemake.input.corrected, **snakemake.params.zarrnii_kwargs - ) - - # first calculate histogram - using preset bins to avoid issues where bins are too large - # because of high intensity outliers - (hist_counts, bin_edges) = znimg.compute_histogram( - bins=snakemake.params.hist_bins, range=snakemake.params.hist_range - ) - - # get otsu thresholds (uses histogram) - print("computing thresholds") - (thresholds, fig) = compute_otsu_thresholds( - hist_counts, - classes=snakemake.params.otsu_k, - bin_edges=bin_edges, - return_figure=True, - ) - print(f" 📈 thresholds: {[f'{t:.3f}' for t in thresholds]}") - - fig.savefig(snakemake.output.thresholds_png) - - print("thresholding image, saving as ome zarr") - znimg_mask = znimg.segment_threshold( - thresholds[snakemake.params.otsu_threshold_index] - ) - - # multiplying binary mask by 100 (so values are 0 and 100) to enable - # field fraction calculation by subsequent local-mean downsampling - znimg_mask = znimg_mask * 100 - - # write to ome_zarr - znimg_mask.to_ome_zarr(snakemake.output.mask, max_layer=5) +if __name__ == "__main__": + with get_dask_client(snakemake.config["dask_scheduler"], snakemake.threads): + + # we use the default level=0, since we are reading in the n4 output, which is already downsampled if level was >0 + znimg = ZarrNii.from_ome_zarr( + snakemake.input.corrected, **snakemake.params.zarrnii_kwargs + ) + + # first calculate histogram - using preset bins to avoid issues where bins are too large + # because of high intensity outliers + (hist_counts, bin_edges) = znimg.compute_histogram( + bins=snakemake.params.hist_bins, range=snakemake.params.hist_range + ) + + # get otsu thresholds (uses histogram) + print("computing thresholds") + (thresholds, fig) = compute_otsu_thresholds( + hist_counts, + classes=snakemake.params.otsu_k, + bin_edges=bin_edges, + return_figure=True, + ) + print(f" 📈 thresholds: {[f'{t:.3f}' for t in thresholds]}") + + fig.savefig(snakemake.output.thresholds_png) + + print("thresholding image, saving as ome zarr") + znimg_mask = znimg.segment_threshold( + thresholds[snakemake.params.otsu_threshold_index] + ) + + # multiplying binary mask by 100 (so values are 0 and 100) to enable + # field fraction calculation by subsequent local-mean downsampling + znimg_mask = znimg_mask * 100 + + # write to ome_zarr + znimg_mask.to_ome_zarr(snakemake.output.mask, max_layer=5) diff --git a/spimquant/workflow/scripts/n4_biasfield.py b/spimquant/workflow/scripts/n4_biasfield.py index 148a7ff..cb853d0 100644 --- a/spimquant/workflow/scripts/n4_biasfield.py +++ b/spimquant/workflow/scripts/n4_biasfield.py @@ -1,32 +1,35 @@ -from dask_setup import get_dask_client -from zarrnii import ZarrNii -from zarrnii.plugins import N4BiasFieldCorrection -from dask.diagnostics import ProgressBar +if __name__ == "__main__": -with get_dask_client(snakemake.config["dask_scheduler"], snakemake.threads): + from dask_setup import get_dask_client + from zarrnii import ZarrNii + from zarrnii.plugins import N4BiasFieldCorrection - hires_level = int(snakemake.wildcards.level) - proc_level = int(snakemake.params.proc_level) + with get_dask_client(snakemake.config["dask_scheduler"], snakemake.threads): - unadjusted_downsample_factor = 2**proc_level - adjusted_downsample_factor = unadjusted_downsample_factor / (2**hires_level) + hires_level = int(snakemake.wildcards.level) + proc_level = int(snakemake.params.proc_level) - znimg = ZarrNii.from_ome_zarr( - snakemake.input.spim, - channel_labels=[snakemake.wildcards.stain], - level=hires_level, - downsample_near_isotropic=True, - **snakemake.params.zarrnii_kwargs, - ) + unadjusted_downsample_factor = 2**proc_level + adjusted_downsample_factor = unadjusted_downsample_factor / (2**hires_level) - print("compute bias field correction") + znimg = ZarrNii.from_ome_zarr( + snakemake.input.spim, + channel_labels=[snakemake.wildcards.stain], + level=hires_level, + downsample_near_isotropic=True, + **snakemake.params.zarrnii_kwargs, + ) + + print("compute bias field correction") + + adjusted_chunk = int( + snakemake.params.target_chunk_size / (2**adjusted_downsample_factor) + ) - with ProgressBar(): # Apply bias field correction znimg_corrected = znimg.apply_scaled_processing( - N4BiasFieldCorrection(), + N4BiasFieldCorrection(shrink_factor=snakemake.params.shrink_factor), downsample_factor=adjusted_downsample_factor, - upsampled_ome_zarr_path=snakemake.output.biasfield, ) # write to ome_zarr diff --git a/spimquant/workflow/scripts/ome_zarr_to_nii.py b/spimquant/workflow/scripts/ome_zarr_to_nii.py index c4c6e74..4869e38 100644 --- a/spimquant/workflow/scripts/ome_zarr_to_nii.py +++ b/spimquant/workflow/scripts/ome_zarr_to_nii.py @@ -2,7 +2,7 @@ from dask_setup import get_dask_client from zarrnii import ZarrNii -with get_dask_client(snakemake.config["dask_scheduler"], snakemake.threads): +with get_dask_client("threads", snakemake.threads): znimg = ZarrNii.from_ome_zarr( snakemake.input.spim, level=int(snakemake.wildcards.level), diff --git a/spimquant/workflow/scripts/signed_distance_transform.py b/spimquant/workflow/scripts/signed_distance_transform.py index d09946a..4c9fef4 100644 --- a/spimquant/workflow/scripts/signed_distance_transform.py +++ b/spimquant/workflow/scripts/signed_distance_transform.py @@ -21,97 +21,102 @@ import numpy as np from scipy.ndimage import distance_transform_edt from zarrnii import ZarrNii +from dask.diagnostics import ProgressBar from dask_setup import get_dask_client -with get_dask_client(snakemake.config["dask_scheduler"], snakemake.threads): - - overlap_depth = snakemake.params.overlap_depth - - znimg = ZarrNii.from_ome_zarr( - snakemake.input.mask, **snakemake.params.zarrnii_kwargs - ) - - # Get physical voxel spacing from the ZarrNii scale metadata. - # znimg.scale is a dict keyed by dimension name (e.g. {'z': 0.004, 'y': 0.0027, 'x': 0.0027}). - # znimg.dims gives the ordered dimension names, e.g. ['c', 'z', 'y', 'x']. - scale = znimg.scale - _known_spatial = {"z", "y", "x"} - spatial_dims = [d for d in znimg.dims if d in _known_spatial] - spacing = [] - for d in spatial_dims: - s = scale.get(d) - if s is None: - import warnings - - warnings.warn( - f"Physical spacing for dimension '{d}' not found in OME-Zarr " - "metadata; defaulting to 1.0. Distance values may not be in " - "physical units.", - stacklevel=2, - ) - s = 1.0 - spacing.append(s) - spacing = np.array(spacing) - - # Calculate the maximum physical distance across a single block. - # This is used as the fill value for entirely-foreground or entirely-background - # blocks, where the true boundary lies outside the block. - # max() is used because the last chunk along each axis may be smaller than - # the others (when the array size is not divisible by the chunk size); the - # largest chunk determines the worst-case block diagonal. - spatial_chunk_indices = [znimg.dims.index(d) for d in spatial_dims] - block_size = np.array([max(znimg.darr.chunks[i]) for i in spatial_chunk_indices]) - max_dist = float(np.sqrt(np.sum((block_size * spacing) ** 2))) - - def signed_dt_block(block): - """Compute signed distance transform for a single block. - - Expects block shape (C, Z, Y, X). For each channel the Euclidean - distance transform is computed twice: - - dt_inside : distance (in physical units) from each foreground - voxel to the nearest background voxel. - - dt_outside: distance (in physical units) from each background - voxel to the nearest foreground voxel. - - The signed distance transform is dt_outside - dt_inside, giving - negative values inside the mask and positive values outside. - - Blocks that are entirely foreground are filled with -max_dist, and - blocks that are entirely background are filled with +max_dist. - """ - result = np.zeros(block.shape, dtype=np.float32) - for c in range(block.shape[0]): - binary = block[c] > 0 - n_fg = np.count_nonzero(binary) - n_total = binary.size - if n_fg == 0: - # All background: nearest foreground is at least one block away - result[c] = max_dist - elif n_fg == n_total: - # All foreground: nearest background is at least one block away - result[c] = -max_dist - else: - dt_inside = distance_transform_edt(binary, sampling=spacing).astype( - np.float32 +if __name__ == "__main__": + with get_dask_client(snakemake.config["dask_scheduler"], snakemake.threads): + + overlap_depth = snakemake.params.overlap_depth + + znimg = ZarrNii.from_ome_zarr( + snakemake.input.mask, **snakemake.params.zarrnii_kwargs + ) + + # Get physical voxel spacing from the ZarrNii scale metadata. + # znimg.scale is a dict keyed by dimension name (e.g. {'z': 0.004, 'y': 0.0027, 'x': 0.0027}). + # znimg.dims gives the ordered dimension names, e.g. ['c', 'z', 'y', 'x']. + scale = znimg.scale + _known_spatial = {"z", "y", "x"} + spatial_dims = [d for d in znimg.dims if d in _known_spatial] + spacing = [] + for d in spatial_dims: + s = scale.get(d) + if s is None: + import warnings + + warnings.warn( + f"Physical spacing for dimension '{d}' not found in OME-Zarr " + "metadata; defaulting to 1.0. Distance values may not be in " + "physical units.", + stacklevel=2, ) - dt_outside = distance_transform_edt(~binary, sampling=spacing).astype( - np.float32 - ) - result[c] = dt_outside - dt_inside - return result - - # depth=0 for the channel dimension, overlap_depth for spatial dims - depth = {0: 0, 1: overlap_depth, 2: overlap_depth, 3: overlap_depth} - - sdt_darr = da.map_overlap( - signed_dt_block, - znimg.darr, - depth=depth, - boundary=0, - dtype=np.float32, - ) - - znimg.darr = sdt_darr - - znimg.to_ome_zarr(snakemake.output.dist, max_layer=5) + s = 1.0 + spacing.append(s) + spacing = np.array(spacing) + + # Calculate the maximum physical distance across a single block. + # This is used as the fill value for entirely-foreground or entirely-background + # blocks, where the true boundary lies outside the block. + # max() is used because the last chunk along each axis may be smaller than + # the others (when the array size is not divisible by the chunk size); the + # largest chunk determines the worst-case block diagonal. + spatial_chunk_indices = [znimg.dims.index(d) for d in spatial_dims] + block_size = np.array( + [max(znimg.darr.chunks[i]) for i in spatial_chunk_indices] + ) + max_dist = float(np.sqrt(np.sum((block_size * spacing) ** 2))) + + def signed_dt_block(block): + """Compute signed distance transform for a single block. + + Expects block shape (C, Z, Y, X). For each channel the Euclidean + distance transform is computed twice: + - dt_inside : distance (in physical units) from each foreground + voxel to the nearest background voxel. + - dt_outside: distance (in physical units) from each background + voxel to the nearest foreground voxel. + + The signed distance transform is dt_outside - dt_inside, giving + negative values inside the mask and positive values outside. + + Blocks that are entirely foreground are filled with -max_dist, and + blocks that are entirely background are filled with +max_dist. + """ + result = np.zeros(block.shape, dtype=np.float32) + for c in range(block.shape[0]): + binary = block[c] > 0 + n_fg = np.count_nonzero(binary) + n_total = binary.size + if n_fg == 0: + # All background: nearest foreground is at least one block away + result[c] = max_dist + elif n_fg == n_total: + # All foreground: nearest background is at least one block away + result[c] = -max_dist + else: + dt_inside = distance_transform_edt(binary, sampling=spacing).astype( + np.float32 + ) + dt_outside = distance_transform_edt( + ~binary, sampling=spacing + ).astype(np.float32) + result[c] = dt_outside - dt_inside + return result + + # depth=0 for the channel dimension, overlap_depth for spatial dims + depth = {0: 0, 1: overlap_depth, 2: overlap_depth, 3: overlap_depth} + + sdt_darr = da.map_overlap( + signed_dt_block, + znimg.darr, + depth=depth, + boundary=0, + dtype=np.float32, + ) + + znimg.darr = sdt_darr + + with ProgressBar(): + znimg.to_ome_zarr(snakemake.output.dist, max_layer=5) diff --git a/spimquant/workflow/scripts/threshold.py b/spimquant/workflow/scripts/threshold.py index 57593e8..4ba122c 100644 --- a/spimquant/workflow/scripts/threshold.py +++ b/spimquant/workflow/scripts/threshold.py @@ -1,18 +1,19 @@ from dask_setup import get_dask_client from zarrnii import ZarrNii -with get_dask_client(snakemake.config["dask_scheduler"], snakemake.threads): +if __name__ == "__main__": + with get_dask_client(snakemake.config["dask_scheduler"], snakemake.threads): - znimg_hires = ZarrNii.from_ome_zarr( - snakemake.input.corrected, **snakemake.params.zarrnii_kwargs - ) + znimg_hires = ZarrNii.from_ome_zarr( + snakemake.input.corrected, **snakemake.params.zarrnii_kwargs + ) - print("thresholding image, saving as ome zarr") - znimg_mask = znimg_hires.segment_threshold(snakemake.params.threshold) + print("thresholding image, saving as ome zarr") + znimg_mask = znimg_hires.segment_threshold(snakemake.params.threshold) - # multiplying binary mask by 100 (so values are 0 and 100) to enable - # field fraction calculation by subsequent local-mean downsampling - znimg_mask = znimg_mask * 100 + # multiplying binary mask by 100 (so values are 0 and 100) to enable + # field fraction calculation by subsequent local-mean downsampling + znimg_mask = znimg_mask * 100 - # write to ome_zarr - znimg_mask.to_ome_zarr(snakemake.output.mask, max_layer=5) + # write to ome_zarr + znimg_mask.to_ome_zarr(snakemake.output.mask, max_layer=5) diff --git a/spimquant/workflow/scripts/vesselfm.py b/spimquant/workflow/scripts/vesselfm.py index eb37b16..caccece 100644 --- a/spimquant/workflow/scripts/vesselfm.py +++ b/spimquant/workflow/scripts/vesselfm.py @@ -4,11 +4,12 @@ from dask.diagnostics import ProgressBar from dask_setup import get_dask_client -with get_dask_client(snakemake.config["dask_scheduler"], snakemake.threads): +with get_dask_client("threads", snakemake.threads): znimg = ZarrNii.from_ome_zarr( snakemake.input.spim, level=int(snakemake.wildcards.level), channel_labels=[snakemake.wildcards.stain], + downsample_near_isotropic=True, **snakemake.params.zarrnii_kwargs, ) znimg_mask = znimg.segment(VesselFMPlugin, **snakemake.params.vesselfm_kwargs)