diff --git a/spimquant/config/snakebids.yml b/spimquant/config/snakebids.yml index 77e89b3..551c40d 100644 --- a/spimquant/config/snakebids.yml +++ b/spimquant/config/snakebids.yml @@ -275,6 +275,11 @@ templates: YoPro: https://zenodo.org/records/18749025/files/tpl-ABAv3_level-5_stain-YoPro_SPIM.nii.gz Iba1: https://zenodo.org/records/18749025/files/tpl-ABAv3_level-5_stain-Iba1_SPIM.nii.gz Abeta: https://zenodo.org/records/18749025/files/tpl-ABAv3_level-5_stain-Abeta_SPIM.nii.gz + default_segs: + - all + - coarse + - mid + - fine atlases: all: dseg: https://zenodo.org/records/18906782/files/tpl-ABAv3_seg-all_dseg.nii.gz diff --git a/spimquant/workflow/Snakefile b/spimquant/workflow/Snakefile index b83ae82..3e8b5fb 100644 --- a/spimquant/workflow/Snakefile +++ b/spimquant/workflow/Snakefile @@ -55,10 +55,13 @@ stains = get_stains_all_subjects() stain_for_reg = None +template = config["templates"][config["template"]] + + # first, check if there are any SPIM templates defined # for the stains we have use_spim_template = False -spim_templates = config["templates"][config["template"]].get("spim_templates", None) +spim_templates = template.get("spim_templates", None) if spim_templates is not None: for stain in spim_templates.keys(): @@ -103,16 +106,15 @@ else: # atlas segmentations to use -all_atlas_segs = config["templates"][config["template"]]["atlases"].keys() if config["atlas_segs"] is None: - atlas_segs = all_atlas_segs + atlas_segs = template.get("default_segs", template["atlases"].keys()) else: atlas_segs = [] for seg in config["atlas_segs"]: - if seg not in all_atlas_segs: + if seg not in template["atlases"]: raise ValueError( - f"Chosen segmentation {seg} was not found in the template {config['template']}" + f"Chosen segmentation {seg} was not found in the template {template}" ) else: atlas_segs.append(seg) @@ -120,9 +122,9 @@ else: # atlas segmentations to use for patches (defaults to roi22) patch_atlas_segs = [] for seg in config["patch_atlas_segs"]: - if seg 
not in all_atlas_segs: + if seg not in template["atlases"]: raise ValueError( - f"Chosen patch segmentation {seg} was not found in the template {config['template']}" + f"Chosen patch segmentation {seg} was not found in the template {template}" ) else: patch_atlas_segs.append(seg) @@ -130,9 +132,9 @@ for seg in config["patch_atlas_segs"]: # atlas segmentations to use for Imaris crops (defaults to roi22) crop_atlas_segs = [] for seg in config["crop_atlas_segs"]: - if seg not in all_atlas_segs: + if seg not in template["atlases"]: raise ValueError( - f"Chosen crop segmentation {seg} was not found in the template {config['template']}" + f"Chosen crop segmentation {seg} was not found in the template {template}" ) else: crop_atlas_segs.append(seg) @@ -632,6 +634,137 @@ rule all_group_stats_coloc: ), +rule all_qc: + """Target rule for subject-level visual QC outputs.""" + input: + # Intensity histograms for every stain channel + inputs["spim"].expand( + bids( + root=root, + datatype="qc", + stain="{stain}", + suffix="histogram.png", + **inputs["spim"].wildcards, + ), + stain=stains, + ), + # Segmentation overview figures (per stain, per seg method) + inputs["spim"].expand( + bids( + root=root, + datatype="qc", + stain="{stain}", + desc="{desc}", + suffix="segslices.png", + **inputs["spim"].wildcards, + ), + stain=stains_for_seg, + desc=config["seg_method"], + ) + if do_seg + else [], + # Segmentation ROI zoom montage (per stain, per atlas, per seg method) + inputs["spim"].expand( + bids( + root=root, + datatype="qc", + seg="{seg}", + from_="{template}", + stain="{stain}", + desc="{desc}", + suffix="roimontage.png", + **inputs["spim"].wildcards, + ), + seg=atlas_segs, + template=config["template"], + stain=stains_for_seg, + desc=config["seg_method"], + ) + if do_seg + else [], + # Vessel overview figures + inputs["spim"].expand( + bids( + root=root, + datatype="qc", + stain="{stain}", + desc="{desc}", + suffix="vesselslices.png", + **inputs["spim"].wildcards, + ), + 
stain=stain_for_vessels, + desc=config["vessel_seg_method"], + ) + if do_vessels + else [], + # Vessel ROI zoom montage (per atlas) + inputs["spim"].expand( + bids( + root=root, + datatype="qc", + seg="{seg}", + from_="{template}", + stain="{stain}", + desc="{desc}", + suffix="vesselroimontage.png", + **inputs["spim"].wildcards, + ), + seg=atlas_segs, + template=config["template"], + stain=stain_for_vessels, + desc=config["vessel_seg_method"], + ) + if do_vessels + else [], + # Z-profile QC (per stain, per seg method) + inputs["spim"].expand( + bids( + root=root, + datatype="qc", + stain="{stain}", + desc="{desc}", + suffix="zprofile.png", + **inputs["spim"].wildcards, + ), + stain=stains_for_seg, + desc=config["seg_method"], + ) + if do_seg + else [], + # Object-level statistics (per stain, per seg method) + inputs["spim"].expand( + bids( + root=root, + datatype="qc", + stain="{stain}", + desc="{desc}", + suffix="objectstats.png", + **inputs["spim"].wildcards, + ), + stain=stains_for_seg, + desc=config["seg_method"], + ) + if do_seg + else [], + # Per-ROI summary (per atlas seg, per seg method) + inputs["spim"].expand( + bids( + root=root, + datatype="qc", + seg="{seg}", + from_="{template}", + desc="{desc}", + suffix="roisummary.png", + **inputs["spim"].wildcards, + ), + seg=atlas_segs, + template=config["template"], + desc=config["seg_method"], + ) + if do_seg + else [], + + rule all_participant: default_target: True input: @@ -640,6 +773,7 @@ rule all_participant: rules.all_segment.input if do_seg else [], rules.all_mri_reg.input if config["register_to_mri"] else [], rules.all_segment_coloc.input if do_coloc else [], + rules.all_qc.input, rule all_group: @@ -660,6 +794,7 @@ include: "rules/regionprops.smk" include: "rules/segstats.smk" include: "rules/patches.smk" include: "rules/groupstats.smk" +include: "rules/qc.smk" if config["register_to_mri"]: diff --git a/spimquant/workflow/rules/qc.smk b/spimquant/workflow/rules/qc.smk new file mode 100644 index 
0000000..882fd26 --- /dev/null +++ b/spimquant/workflow/rules/qc.smk @@ -0,0 +1,372 @@ +"""Visual QC outputs for SPIMquant, generated at subject/channel level. + +This module produces PNG quality-control figures covering: + +1. Raw intensity / histogram plots + - Per-channel linear and log-scale intensity histograms + - Cumulative distribution and saturation/clip fraction + +2. Segmentation overview figures + - Slice montages (axial, coronal, sagittal) with field-fraction overlay, + aspect ratio corrected from NIfTI voxel dimensions + - Max-intensity projection (MIP) with overlay + - Zoomed ROI montage: per-atlas-region crops with overlay detail + +3. Spatial QC / coverage + - Z-profile of mean signal intensity and segmented field fraction + +4. Object-level summaries + - Volume distribution, log-volume distribution, equivalent-radius + distribution and summary statistics for detected objects + +5. Per-ROI summaries (subject level) + - Top-regions bar plots for field fraction, count, and density + +Rules 2 and the ROI zoom are also generated for vessel segmentations. + +All outputs are written to the ``qc`` datatype directory for each subject. +""" + + +rule qc_intensity_histogram: + """Per-channel intensity histogram QC. + +Reads the raw OME-Zarr at the registration downsampling level and +generates a four-panel figure: linear histogram, log-scale histogram, +cumulative distribution, and a summary-statistics panel including the +saturation/clip fraction (percentage of voxels at the maximum bin). 
+""" + input: + spim=inputs["spim"].path, + output: + png=bids( + root=root, + datatype="qc", + stain="{stain}", + suffix="histogram.png", + **inputs["spim"].wildcards, + ), + threads: 4 + resources: + mem_mb=16000, + runtime=30, + params: + level=config["registration_level"], + hist_bins=500, + hist_range=[0, 65535], + zarrnii_kwargs={"orientation": config["orientation"]}, + script: + "../scripts/qc_intensity_histogram.py" + + +rule qc_segmentation_overview: + """Segmentation overview slice montage QC. + +Loads SPIM data via ZarrNii (``downsample_near_isotropic=True``) and the +raw binary segmentation mask at the corresponding pyramid level. Displays +sample slices in axial, coronal, and sagittal orientations with the mask +overlay, and a max-intensity projection column for each orientation. +Aspect ratio is corrected using voxel spacings from ``ZarrNii.get_zooms()``. +""" + input: + spim=inputs["spim"].path, + mask=bids( + root=root, + datatype="seg", + stain="{stain}", + level=config["segmentation_level"], + desc="{desc}", + suffix="mask.ozx", + **inputs["spim"].wildcards, + ), + output: + png=bids( + root=root, + datatype="qc", + stain="{stain}", + desc="{desc}", + suffix="segslices.png", + **inputs["spim"].wildcards, + ), + threads: 4 + resources: + mem_mb=16000, + runtime=30, + params: + level=config["registration_level"], + mask_level=config["registration_level"] - config["segmentation_level"], + zarrnii_kwargs={"orientation": config["orientation"]}, + script: + "../scripts/qc_segmentation_overview.py" + + +rule qc_vessels_overview: + """Vessel segmentation overview slice montage QC. + +Identical visualisation to ``qc_segmentation_overview`` but applied to the +vessel binary mask. Loads data via ZarrNii with ``downsample_near_isotropic`` +for isotropic display and physically correct aspect ratio. 
+""" + input: + spim=inputs["spim"].path, + mask=bids( + root=root, + datatype="vessels", + stain="{stain}", + level=config["segmentation_level"], + desc="{desc}", + suffix="mask.ozx", + **inputs["spim"].wildcards, + ), + output: + png=bids( + root=root, + datatype="qc", + stain="{stain}", + desc="{desc}", + suffix="vesselslices.png", + **inputs["spim"].wildcards, + ), + threads: 4 + resources: + mem_mb=16000, + runtime=30, + params: + level=config["registration_level"], + mask_level=config["registration_level"] - config["segmentation_level"], + zarrnii_kwargs={"orientation": config["orientation"]}, + script: + "../scripts/qc_segmentation_overview.py" + + +rule qc_segmentation_roi_zoom: + """Zoomed ROI montage QC for segmentation. + +Crops the SPIM image and segmentation field-fraction mask to each atlas +region's bounding box (in subject space) and displays the best axial slice +with the field-fraction overlay. Aspect ratio is corrected from NIfTI +voxel dimensions. Provides detail-level visualisation of segmentation +quality within individual brain regions. 
+""" + input: + spim=inputs["spim"].path, + mask=bids( + root=root, + datatype="seg", + stain="{stain}", + level=config["segmentation_level"], + desc="{desc}", + suffix="mask.ozx", + **inputs["spim"].wildcards, + ), + dseg_nii=bids( + root=root, + datatype="parc", + seg="{seg}", + level=config["registration_level"], + from_="{template}", + suffix="dseg.nii.gz", + **inputs["spim"].wildcards, + ), + label_tsv=bids( + root=root, + template="{template}", + seg="{seg}", + suffix="dseg.tsv", + ), + output: + png=bids( + root=root, + datatype="qc", + seg="{seg}", + from_="{template}", + stain="{stain}", + desc="{desc}", + suffix="roimontage.png", + **inputs["spim"].wildcards, + ), + threads: 4 + resources: + mem_mb=32000, + runtime=15, + params: + max_rois=lambda wildcards: 25 if wildcards.seg == "coarse" else 100, + n_cols=lambda wildcards: 5 if wildcards.seg == "coarse" else 10, + patch_size=lambda wildcards: 2000 if wildcards.seg == "coarse" else 500, + level=config["segmentation_level"], + script: + "../scripts/qc_segmentation_roi_zoom.py" + + +rule qc_vessels_roi_zoom: + """Zoomed ROI montage QC for vessel segmentation. + +Identical to ``qc_segmentation_roi_zoom`` but applied to the vessel +binary mask. Uses ZarrNii to load full-resolution data and +ZarrNiiAtlas for atlas-based ROI cropping. 
+""" + input: + spim=inputs["spim"].path, + mask=bids( + root=root, + datatype="vessels", + stain="{stain}", + level=config["segmentation_level"], + desc="{desc}", + suffix="mask.ozx", + **inputs["spim"].wildcards, + ), + dseg_nii=bids( + root=root, + datatype="parc", + seg="{seg}", + level=config["registration_level"], + from_="{template}", + suffix="dseg.nii.gz", + **inputs["spim"].wildcards, + ), + label_tsv=bids( + root=root, + template="{template}", + seg="{seg}", + suffix="dseg.tsv", + ), + output: + png=bids( + root=root, + datatype="qc", + seg="{seg}", + from_="{template}", + stain="{stain}", + desc="{desc}", + suffix="vesselroimontage.png", + **inputs["spim"].wildcards, + ), + threads: 4 + resources: + mem_mb=32000, + runtime=15, + params: + max_rois=lambda wildcards: 25 if wildcards.seg == "coarse" else 100, + n_cols=lambda wildcards: 5 if wildcards.seg == "coarse" else 10, + patch_size=lambda wildcards: 2000 if wildcards.seg == "coarse" else 500, + level=config["segmentation_level"], + script: + "../scripts/qc_segmentation_roi_zoom.py" + + +rule qc_zprofile: + """Z-profile QC: per-slice signal intensity and segmented fraction. + +Plots the mean signal intensity (with ±1 SD band) and mean field fraction +across Z-slices. Reveals depth-dependent artefacts such as striping, +illumination fall-off, or uneven staining. 
+""" + input: + spim=bids( + root=root, + datatype="micr", + stain="{stain}", + level=config["registration_level"], + suffix="SPIM.nii.gz", + **inputs["spim"].wildcards, + ), + fieldfrac=bids( + root=root, + datatype="seg", + stain="{stain}", + level=config["registration_level"], + desc="{desc}", + suffix="fieldfrac.nii.gz", + **inputs["spim"].wildcards, + ), + output: + png=bids( + root=root, + datatype="qc", + stain="{stain}", + desc="{desc}", + suffix="zprofile.png", + **inputs["spim"].wildcards, + ), + threads: 1 + resources: + mem_mb=4000, + runtime=10, + script: + "../scripts/qc_zprofile.py" + + +rule qc_objectstats: + """Object-level statistics QC. + +Loads the aggregated region-properties parquet (all stains combined) and +plots volume distribution, log-volume distribution, equivalent spherical +radius distribution, and a summary-statistics panel for the stain +specified by the ``{stain}`` wildcard. +""" + input: + regionprops=bids( + root=root, + datatype="tabular", + desc="{desc}", + space=config["template"], + suffix="regionprops.parquet", + **inputs["spim"].wildcards, + ), + output: + png=bids( + root=root, + datatype="qc", + stain="{stain}", + desc="{desc}", + suffix="objectstats.png", + **inputs["spim"].wildcards, + ), + threads: 1 + resources: + mem_mb=4000, + runtime=10, + script: + "../scripts/qc_objectstats.py" + + +rule qc_roi_summary: + """Per-ROI summary QC: top-region bar plots for a single subject. + +Reads the merged segmentation-statistics TSV (all stains) and the atlas +label table, then produces horizontal bar charts of the top brain regions +ranked by field fraction, object count, and density for every stain. 
+""" + input: + segstats=bids( + root=root, + datatype="tabular", + seg="{seg}", + from_="{template}", + desc="{desc}", + suffix="mergedsegstats.tsv", + **inputs["spim"].wildcards, + ), + label_tsv=bids( + root=root, + template="{template}", + seg="{seg}", + suffix="dseg.tsv", + ), + output: + png=bids( + root=root, + datatype="qc", + seg="{seg}", + from_="{template}", + desc="{desc}", + suffix="roisummary.png", + **inputs["spim"].wildcards, + ), + threads: 1 + resources: + mem_mb=4000, + runtime=10, + script: + "../scripts/qc_roi_summary.py" diff --git a/spimquant/workflow/scripts/qc_intensity_histogram.py b/spimquant/workflow/scripts/qc_intensity_histogram.py new file mode 100644 index 0000000..7d4f457 --- /dev/null +++ b/spimquant/workflow/scripts/qc_intensity_histogram.py @@ -0,0 +1,163 @@ +"""Per-channel intensity histogram QC for SPIM data. + +Generates linear and log-scale histograms, a cumulative distribution, +saturation/clip fraction, and summary statistics for a single stain channel. + +This is a Snakemake script that expects the ``snakemake`` object to be +available, which is automatically provided when executed as part of a +Snakemake workflow. 
+""" + +import matplotlib + +matplotlib.use("agg") +import matplotlib.pyplot as plt +import numpy as np + +from dask_setup import get_dask_client +from zarrnii import ZarrNii + + +def main(): + stain = snakemake.wildcards.stain + level = snakemake.params.level + hist_bins = snakemake.params.hist_bins + hist_range = snakemake.params.hist_range + + with get_dask_client("threads", snakemake.threads): + znimg = ZarrNii.from_ome_zarr( + snakemake.input.spim, + level=level, + channel_labels=[stain], + **snakemake.params.zarrnii_kwargs, + ) + hist_counts, bin_edges = znimg.compute_histogram( + bins=hist_bins, + range=hist_range, + ) + + hist_counts = np.asarray(hist_counts, dtype=float) + bin_edges = np.asarray(bin_edges, dtype=float) + bin_centers = (bin_edges[:-1] + bin_edges[1:]) / 2 + bin_width = bin_edges[1] - bin_edges[0] + + total_voxels = hist_counts.sum() + max_range = hist_range[1] + + # Determine effective display max (last bin with data, with 5 % headroom) + nonzero_mask = hist_counts > 0 + if nonzero_mask.any(): + disp_max = float(bin_centers[nonzero_mask][-1]) * 1.05 + else: + disp_max = max_range + sat_fraction = ( + float(hist_counts[-1]) / total_voxels * 100 if total_voxels > 0 else 0.0 + ) + + # Summary statistics derived from histogram + if total_voxels > 0: + mean_val = float(np.sum(bin_centers * hist_counts) / total_voxels) + cumsum_norm = np.cumsum(hist_counts) / total_voxels + p50_val = float( + bin_centers[min(np.searchsorted(cumsum_norm, 0.50), len(bin_centers) - 1)] + ) + p99_val = float( + bin_centers[min(np.searchsorted(cumsum_norm, 0.99), len(bin_centers) - 1)] + ) + else: + mean_val = p50_val = p99_val = 0.0 + + # Percentile-based display bounds for the linear-scale histogram panel + # X: cap at the 99th percentile value (+ 5 % headroom) to avoid long empty tails + lin_xlim = p99_val * 1.05 if total_voxels > 0 else max_range + # Y: cap at the tallest bar in the visible x range (+ 5 % headroom) so the + # body of the distribution is visible 
rather than dominated by a background spike + visible = hist_counts[bin_centers <= lin_xlim] + lin_ylim = ( + float(visible.max()) * 1.05 if visible.size and visible.max() > 0 else 1.0 + ) + + subject = snakemake.wildcards.subject + + fig, axes = plt.subplots(2, 2, figsize=(12, 8)) + fig.suptitle( + f"Intensity Histogram QC\nSubject: {subject} | Stain: {stain}", + fontsize=13, + fontweight="bold", + ) + + # Panel 1: linear-scale histogram + ax = axes[0, 0] + ax.bar(bin_centers, hist_counts, width=bin_width, color="steelblue", alpha=0.75) + ax.set_xlabel("Intensity") + ax.set_ylabel("Voxel count") + ax.set_title("Linear-scale histogram") + ax.set_xlim(0, lin_xlim) + ax.set_ylim(0, lin_ylim) + + # Panel 2: log-scale histogram + ax = axes[0, 1] + log_counts = np.where(hist_counts > 0, np.log10(hist_counts), np.nan) + ax.bar(bin_centers, log_counts, width=bin_width, color="darkorange", alpha=0.75) + ax.set_xlabel("Intensity") + ax.set_ylabel("log\u2081\u2080(voxel count)") + ax.set_title("Log-scale histogram") + ax.set_xlim(0, disp_max) + + # Panel 3: cumulative distribution + ax = axes[1, 0] + if total_voxels > 0: + cumsum_pct = cumsum_norm * 100 + ax.plot(bin_centers, cumsum_pct, color="forestgreen", lw=1.5) + ax.axvline( + x=p50_val, + color="purple", + linestyle="--", + alpha=0.7, + label=f"Median ({p50_val:.1f})", + ) + ax.axvline( + x=p99_val, + color="red", + linestyle="--", + alpha=0.7, + label=f"99th pctile ({p99_val:.1f})", + ) + ax.legend(fontsize=8) + ax.set_xlabel("Intensity") + ax.set_ylabel("Cumulative voxels (%)") + ax.set_title("Cumulative distribution") + ax.set_ylim(0, 105) + ax.set_xlim(0, disp_max) + + # Panel 4: summary statistics + ax = axes[1, 1] + ax.axis("off") + summary_text = ( + f"Total voxels: {int(total_voxels):>14,}\n" + f"Mean intensity: {mean_val:>14.2f}\n" + f"Median (50th): {p50_val:>14.2f}\n" + f"99th percentile: {p99_val:>14.2f}\n" + f"Max range: {max_range:>14.1f}\n" + f"Saturation frac.: {sat_fraction:>13.3f}%" + ) + ax.text( + 
0.1, + 0.55, + summary_text, + transform=ax.transAxes, + fontsize=11, + verticalalignment="center", + fontfamily="monospace", + bbox=dict(boxstyle="round", facecolor="lightyellow", alpha=0.8), + ) + ax.set_title("Summary statistics") + + plt.tight_layout() + plt.savefig(snakemake.output.png, dpi=150, bbox_inches="tight") + plt.close() + print(f"Saved intensity histogram QC to {snakemake.output.png}") + + +if __name__ == "__main__": + main() diff --git a/spimquant/workflow/scripts/qc_objectstats.py b/spimquant/workflow/scripts/qc_objectstats.py new file mode 100644 index 0000000..91cf4d1 --- /dev/null +++ b/spimquant/workflow/scripts/qc_objectstats.py @@ -0,0 +1,150 @@ +"""Object-level statistics QC: distributions of detected objects. + +Plots count, volume/size distribution, and equivalent-radius distribution +for detected objects (plaques, cells, etc.) from segmentation region +properties. Objects are filtered to the stain specified by the ``stain`` +wildcard. + +This is a Snakemake script that expects the ``snakemake`` object to be +available, which is automatically provided when executed as part of a +Snakemake workflow. 
+""" + +import matplotlib + +matplotlib.use("agg") +import matplotlib.pyplot as plt +import numpy as np +import pandas as pd + + +def main(): + stain = snakemake.wildcards.stain + desc = snakemake.wildcards.desc + subject = snakemake.wildcards.subject + + df = pd.read_parquet(snakemake.input.regionprops) + + # Filter to the requested stain (the aggregated parquet has a 'stain' column) + if "stain" in df.columns: + df = df[df["stain"] == stain].copy() + + n_objects = len(df) + + fig, axes = plt.subplots(2, 2, figsize=(12, 8)) + fig.suptitle( + f"Object Statistics QC\n" + f"Subject: {subject} | Stain: {stain} | Method: {desc} | " + f"Total objects: {n_objects:,}", + fontsize=12, + fontweight="bold", + ) + + # If no objects were detected, show an empty panel + if n_objects == 0: + for ax in axes.flat: + ax.text( + 0.5, + 0.5, + "No objects detected", + transform=ax.transAxes, + ha="center", + va="center", + fontsize=14, + color="gray", + ) + ax.axis("off") + plt.tight_layout() + plt.savefig(snakemake.output.png, dpi=150, bbox_inches="tight") + plt.close() + return + + has_nvoxels = "nvoxels" in df.columns + + # --- Panel 1: object size (nvoxels) linear scale --- + ax = axes[0, 0] + if has_nvoxels: + nvoxels = df["nvoxels"].values.astype(float) + ax.hist(nvoxels, bins=50, color="steelblue", alpha=0.75, edgecolor="white") + ax.set_xlabel("Volume (voxels)") + ax.set_ylabel("Count") + ax.set_title("Object size distribution") + else: + ax.text( + 0.5, + 0.5, + "nvoxels not available", + ha="center", + va="center", + transform=ax.transAxes, + color="gray", + ) + ax.axis("off") + + # --- Panel 2: object size log scale --- + ax = axes[0, 1] + if has_nvoxels: + ax.hist( + np.log10(nvoxels + 1), + bins=50, + color="darkorange", + alpha=0.75, + edgecolor="white", + ) + ax.set_xlabel("log\u2081\u2080(volume in voxels + 1)") + ax.set_ylabel("Count") + ax.set_title("Object size distribution (log scale)") + else: + ax.axis("off") + + # --- Panel 3: equivalent spherical radius --- + 
ax = axes[1, 0] + if has_nvoxels: + # Radius distribution. + # Assumes objects are approximately spherical and voxels are isotropic. + # Formula: r = (3V / 4π)^(1/3), where V is volume in voxels. + # For anisotropic voxels this is an approximation; physical radius + # would require multiplying by the voxel size in each dimension. + radii = (3.0 * nvoxels / (4.0 * np.pi)) ** (1.0 / 3.0) + ax.hist(radii, bins=50, color="forestgreen", alpha=0.75, edgecolor="white") + ax.set_xlabel("Equivalent radius (voxels)") + ax.set_ylabel("Count") + ax.set_title("Object radius distribution (spherical approximation)") + else: + ax.axis("off") + + # --- Panel 4: summary statistics --- + ax = axes[1, 1] + ax.axis("off") + if has_nvoxels: + summary_text = ( + f"Object count: {n_objects:>12,}\n" + f"Mean volume: {np.mean(nvoxels):>12.1f} vx\n" + f"Median volume: {np.median(nvoxels):>12.1f} vx\n" + f"Max volume: {np.max(nvoxels):>12.1f} vx\n" + f"Total volume: {np.sum(nvoxels):>12.0f} vx\n" + f"Mean radius: {np.mean(radii):>12.2f} vx\n" + f"Median radius: {np.median(radii):>12.2f} vx" + ) + else: + summary_text = f"Object count: {n_objects:,}" + ax.text( + 0.1, + 0.55, + summary_text, + transform=ax.transAxes, + fontsize=11, + verticalalignment="center", + fontfamily="monospace", + bbox=dict(boxstyle="round", facecolor="lightyellow", alpha=0.8), + ) + ax.set_title("Summary statistics") + + plt.tight_layout() + plt.savefig(snakemake.output.png, dpi=150, bbox_inches="tight") + plt.close() + print(f"Saved object statistics QC to {snakemake.output.png}") + + +if __name__ == "__main__": + main() diff --git a/spimquant/workflow/scripts/qc_roi_summary.py b/spimquant/workflow/scripts/qc_roi_summary.py new file mode 100644 index 0000000..52271ec --- /dev/null +++ b/spimquant/workflow/scripts/qc_roi_summary.py @@ -0,0 +1,141 @@ +"""Per-ROI summary QC: top-region bar plots for a single subject. 
+ +Loads the merged segmentation-statistics TSV (all stains) together with +the atlas label table, then produces bar-chart visualisations of the +top brain-regions ranked by field fraction and density for each stain. + +This is a Snakemake script that expects the ``snakemake`` object to be +available, which is automatically provided when executed as part of a +Snakemake workflow. +""" + +import matplotlib + +matplotlib.use("agg") +import matplotlib.pyplot as plt +import numpy as np +import pandas as pd + +# Suffixes used to identify stain-prefixed metric columns in mergedsegstats TSV. +# Columns follow the pattern "{stain}+{metric}", e.g. "Abeta+fieldfrac". +_SUFFIX_FIELDFRAC = "+fieldfrac" +_SUFFIX_DENSITY = "+density" + + +def _top_regions(df, col, n=20, ascending=False): + """Return the top *n* rows of *df* sorted by *col*.""" + valid = df[df[col].notna() & (df[col] != 0)] + return valid.nlargest(n, col) if not ascending else valid.nsmallest(n, col) + + +def _bar_plot(ax, names, values, title, xlabel, color="steelblue"): + """Draw a horizontal bar chart on *ax*.""" + y_pos = np.arange(len(names)) + ax.barh(y_pos, values, color=color, alpha=0.8, edgecolor="white") + ax.set_yticks(y_pos) + ax.set_yticklabels(names, fontsize=8) + ax.invert_yaxis() # highest value at top + ax.set_xlabel(xlabel, fontsize=9) + ax.set_title(title, fontsize=10) + ax.grid(True, axis="x", alpha=0.3) + + +def main(): + desc = snakemake.wildcards.desc + seg = snakemake.wildcards.seg + template = snakemake.wildcards.template + subject = snakemake.wildcards.subject + + stats_df = pd.read_csv(snakemake.input.segstats, sep="\t") + label_df = pd.read_csv(snakemake.input.label_tsv, sep="\t") + + # Drop background (atlas label 0) — those voxels are outside the brain + if "index" in stats_df.columns: + stats_df = stats_df[stats_df["index"] != 0].copy() + + # Merge region names (label_df has 'index' and 'name' columns) + if "name" not in stats_df.columns and "index" in stats_df.columns: + stats_df = 
stats_df.merge(label_df[["index", "name"]], on="index", how="left") + + region_name_col = "name" if "name" in stats_df.columns else "index" + + # Identify stain-prefixed metric columns (pattern: "{stain}+{metric}") + ff_cols = [c for c in stats_df.columns if c.endswith(_SUFFIX_FIELDFRAC)] + density_cols = [c for c in stats_df.columns if c.endswith(_SUFFIX_DENSITY)] + + # Determine number of rows: 1 row per metric type (ff, density) + # with one subplot per stain within each row + n_ff = len(ff_cols) + n_density = len(density_cols) + n_rows = (1 if n_ff else 0) + (1 if n_density else 0) + + if n_rows == 0: + fig, ax = plt.subplots(figsize=(8, 4)) + ax.text( + 0.5, + 0.5, + "No stain-prefixed metric columns found in segstats TSV", + ha="center", + va="center", + fontsize=12, + color="gray", + transform=ax.transAxes, + ) + ax.axis("off") + plt.savefig(snakemake.output.png, dpi=150, bbox_inches="tight") + plt.close() + return + + n_top = 20 + colors = plt.rcParams["axes.prop_cycle"].by_key()["color"] + + row_specs = [] + if n_ff: + row_specs.append(("Field Fraction (%)", ff_cols, "steelblue")) + if n_density: + row_specs.append(("Density (objects/vol)", density_cols, "forestgreen")) + + max_stains = max(len(cols) for _, cols, _ in row_specs) + fig_width = max(10, max_stains * 6) + fig_height = n_rows * 6 + + fig, axes = plt.subplots( + n_rows, max_stains, figsize=(fig_width, fig_height), squeeze=False + ) + fig.suptitle( + f"Per-ROI Summary QC\n" + f"Subject: {subject} | Atlas: {seg} | Template: {template} | " + f"Method: {desc}", + fontsize=12, + fontweight="bold", + ) + + for row_idx, (metric_label, metric_cols, base_color) in enumerate(row_specs): + for col_idx in range(max_stains): + ax = axes[row_idx, col_idx] + if col_idx >= len(metric_cols): + ax.axis("off") + continue + col = metric_cols[col_idx] + stain_label = col.split("+")[0] + top = _top_regions(stats_df, col, n=n_top) + names = top[region_name_col].astype(str).tolist() + values = top[col].values + color = 
colors[col_idx % len(colors)] + _bar_plot( + ax, + names, + values, + title=f"Top {n_top} regions — {stain_label}", + xlabel=metric_label, + color=color, + ) + + plt.tight_layout() + plt.savefig(snakemake.output.png, dpi=120, bbox_inches="tight") + plt.close() + print(f"Saved per-ROI summary QC to {snakemake.output.png}") + + +if __name__ == "__main__": + main() diff --git a/spimquant/workflow/scripts/qc_segmentation_overview.py b/spimquant/workflow/scripts/qc_segmentation_overview.py new file mode 100644 index 0000000..4885b49 --- /dev/null +++ b/spimquant/workflow/scripts/qc_segmentation_overview.py @@ -0,0 +1,146 @@ +"""Segmentation / vessel overview QC: whole-brain slice montage with mask overlay. + +Generates a multi-panel figure with sample slices in three anatomical +orientations (axial, coronal, sagittal), each with the binary segmentation or +vessel mask overlaid on the SPIM background image, plus a max-intensity +projection (MIP) column for each orientation. + +Data are loaded via ZarrNii so that the correct physical resolution and +aspect ratio are applied automatically. The SPIM is loaded with +``downsample_near_isotropic=True`` to obtain near-isotropic voxels. +Voxel spacings are read with ``get_zooms()`` and used to compute the correct +``imshow`` aspect ratio for each panel. + +This is a Snakemake script that expects the ``snakemake`` object to be +available, which is automatically provided when executed as part of a +Snakemake workflow. 
+""" + +import matplotlib + +matplotlib.use("agg") +import matplotlib.pyplot as plt +import numpy as np +from scipy.ndimage import zoom +from zarrnii import ZarrNii + + +def _percentile_norm(arr, pct_low=1, pct_high=99): + """Percentile-normalise *arr* to the range [0, 1].""" + lo = np.percentile(arr, pct_low) + hi = np.percentile(arr, pct_high) + if hi > lo: + return np.clip((arr.astype(float) - lo) / (hi - lo), 0.0, 1.0) + return np.zeros_like(arr, dtype=float) + + +def _sample_slices(size, n=5): + """Return *n* evenly-spaced slice indices in the central 60 % of *size*.""" + start = int(size * 0.2) + stop = int(size * 0.8) + return np.linspace(start, stop - 1, n, dtype=int) + + +def _slice_aspect(zooms, step_axis): + """Compute imshow aspect ratio for a slice through *step_axis*. + + ZarrNii data is indexed (x, y, z) and ``np.rot90`` is applied before + display, so the displayed image rows and columns are: + + - step_axis=2 (Z-slice): rot90 of data[:, :, z] → rows=y, cols=x → dy/dx + - step_axis=1 (Y-slice): rot90 of data[:, y, :] → rows=z, cols=x → dz/dx + - step_axis=0 (X-slice): rot90 of data[x, :, :] → rows=z, cols=y → dz/dy + """ + dx, dy, dz = float(zooms[0]), float(zooms[1]), float(zooms[2]) + if step_axis == 2: + return dy / dx + if step_axis == 1: + return dz / dx + return dz / dy # step_axis == 0 + + +def main(): + stain = snakemake.wildcards.stain + desc = snakemake.wildcards.desc + subject = snakemake.wildcards.subject + + spim_img = ZarrNii.from_ome_zarr( + snakemake.input.spim, + level=snakemake.params.level, + downsample_near_isotropic=True, + channel_labels=[stain], + **snakemake.params.zarrnii_kwargs, + ) + mask_img = ZarrNii.from_ome_zarr( + snakemake.input.mask, + level=snakemake.params.mask_level, + downsample_near_isotropic=True, + **snakemake.params.zarrnii_kwargs, + ) + + # Voxel dimensions (mm) for physical aspect-ratio correction + zooms = spim_img.get_zooms() + + spim_data = spim_img.data[0].compute() # (X, Y, Z) + mask_data = 
mask_img.data[0].compute() # (X, Y, Z), values 0–100 + + # Bring mask to the same grid as SPIM if needed + if mask_data.shape != spim_data.shape: + factors = [t / s for t, s in zip(spim_data.shape, mask_data.shape)] + mask_data = zoom(mask_data, factors, order=1) + + spim_norm = _percentile_norm(spim_data) + # Mask values are 0–100 (field-fraction percent); normalise to 0–1 for display + mask_norm = np.clip(mask_data / 100.0, 0.0, 1.0) + + n_slices = 5 + orient_labels = ["Axial (Z)", "Coronal (Y)", "Sagittal (X)"] + # axis along which we step through slices: 2 = Z, 1 = Y, 0 = X + step_axes = [2, 1, 0] + + fig, axes = plt.subplots(3, n_slices + 1, figsize=(18, 9), constrained_layout=True) + fig.suptitle( + f"Segmentation Overview QC\n" + f"Subject: {subject} | Stain: {stain} | Method: {desc}", + fontsize=12, + fontweight="bold", + ) + + for row, (orient_name, ax_idx) in enumerate(zip(orient_labels, step_axes)): + aspect = _slice_aspect(zooms, ax_idx) + slice_indices = _sample_slices(spim_data.shape[ax_idx], n_slices) + + for col, sl in enumerate(slice_indices): + idx = [slice(None)] * 3 + idx[ax_idx] = int(sl) + spim_sl = spim_norm[tuple(idx)] + mask_sl = mask_norm[tuple(idx)] + + ax = axes[row, col] + ax.imshow(spim_sl, cmap="gray", vmin=0, vmax=1, aspect=aspect) + mask_masked = np.ma.masked_where(mask_sl < 0.01, mask_sl) + ax.imshow(mask_masked, cmap="hot", alpha=0.6, vmin=0, vmax=1, aspect=aspect) + ax.set_xticks([]) + ax.set_yticks([]) + if col == 0: + ax.set_ylabel(orient_name, fontsize=9) + ax.set_title(f"sl {sl}", fontsize=7) + + # MIP column (last column) + ax = axes[row, n_slices] + mip_spim = np.rot90(np.max(spim_norm, axis=ax_idx)) + mip_mask = np.rot90(np.max(mask_norm, axis=ax_idx)) + ax.imshow(mip_spim, cmap="gray", vmin=0, vmax=1, aspect=aspect) + mip_masked = np.ma.masked_where(mip_mask < 0.01, mip_mask) + ax.imshow(mip_masked, cmap="hot", alpha=0.6, vmin=0, vmax=1, aspect=aspect) + ax.set_xticks([]) + ax.set_yticks([]) + ax.set_title("MIP", 
fontsize=9) + + plt.savefig(snakemake.output.png, dpi=120, bbox_inches="tight") + plt.close() + print(f"Saved segmentation overview QC to {snakemake.output.png}") + + +if __name__ == "__main__": + main() diff --git a/spimquant/workflow/scripts/qc_segmentation_roi_zoom.py b/spimquant/workflow/scripts/qc_segmentation_roi_zoom.py new file mode 100644 index 0000000..5f17cb9 --- /dev/null +++ b/spimquant/workflow/scripts/qc_segmentation_roi_zoom.py @@ -0,0 +1,172 @@ +"""ROI-cropped segmentation montage QC. + +For each brain region in the atlas parcellation (resampled to subject space), +crops the SPIM image and the segmentation mask to a fixed 2D bounding box +at the region centroid. This provides a detail-level view of segmentation +quality within individual brain regions, complementing the whole-brain +overview in ``qc_segmentation_overview``. + + +This is a Snakemake script that expects the ``snakemake`` object to be +available, which is automatically provided when executed as part of a +Snakemake workflow. 
+""" + +import matplotlib + +matplotlib.use("agg") +import matplotlib.pyplot as plt +from zarrnii import ZarrNii, ZarrNiiAtlas +import nibabel as nib +import numpy as np +import pandas as pd +from scipy.ndimage import zoom + + +def _estimate_global_percentiles( + arr, + pct_low=1, + pct_high=99, +): + """ + Estimate percentile normalization bounds from an image + """ + lo = np.percentile(arr, pct_low) + hi = np.percentile(arr, pct_high) + return float(lo), float(hi) + + +def _apply_fixed_percentile_norm(arr, lo, hi): + if hi > lo: + return np.clip((arr.astype(np.float32) - lo) / (hi - lo), 0.0, 1.0) + return np.zeros_like(arr, dtype=np.float32) + + +def main(): + stain = snakemake.wildcards.stain + desc = snakemake.wildcards.desc + subject = snakemake.wildcards.subject + max_rois = snakemake.params.max_rois + n_cols = snakemake.params.n_cols + + spim_img = ZarrNii.from_ome_zarr( + snakemake.input.spim, + level=snakemake.params.level, + downsample_near_isotropic=True, + channel_labels=[snakemake.wildcards.stain], + ) + mask_img = ZarrNii.from_ome_zarr(snakemake.input.mask, level=0) + + atlas = ZarrNiiAtlas.from_files(snakemake.input.dseg_nii, snakemake.input.label_tsv) + + dseg_data = atlas.dseg.data.compute() + + # Voxel dimensions (mm) for physical aspect-ratio correction - not implemented yet + # but should be easy with ZarrNii image .scale + aspect_axial = 1 + + spim_img_ds = ZarrNii.from_ome_zarr( + snakemake.input.spim, + level=(int(snakemake.params.level) + 5), + downsample_near_isotropic=True, + channel_labels=[snakemake.wildcards.stain], + ) + + # estimate once globally, from a coarse version of the full image + glob_lo, glob_hi = _estimate_global_percentiles( + spim_img_ds.data.compute(), pct_low=1, pct_high=99 + ) + + # Load atlas label table + label_df = atlas.labels_df + + # Keep non-background labels that are present in this subject's dseg + present_ids = set(np.unique(dseg_data)) - {0} + roi_rows = [ + row + for _, row in label_df[label_df["index"] > 
0].iterrows() + if int(row["index"]) in present_ids + ][:max_rois] + + n_rois = len(roi_rows) + + if n_rois == 0: + fig, ax = plt.subplots(figsize=(18, 12)) + ax.text( + 0.5, + 0.5, + "No atlas ROIs found in subject", + ha="center", + va="center", + transform=ax.transAxes, + fontsize=12, + color="gray", + ) + ax.axis("off") + plt.savefig(snakemake.output.png, dpi=120, bbox_inches="tight") + plt.close() + return + + n_rows = int(np.ceil(n_rois / n_cols)) + fig, axes = plt.subplots( + n_rows, n_cols, figsize=(n_cols * 3, n_rows * 3), constrained_layout=True + ) + fig.suptitle( + f"ROI Zoom Montage QC\n" + f"Subject: {subject} | Stain: {stain} | Method: {desc}", + fontsize=11, + fontweight="bold", + ) + + # Normalise axes array to always be 2-D + if n_rows == 1 and n_cols == 1: + axes = np.array([[axes]]) + elif n_rows == 1: + axes = axes[np.newaxis, :] + elif n_cols == 1: + axes = axes[:, np.newaxis] + + for i, row in enumerate(roi_rows): + ax_row = i // n_cols + ax_col = i % n_cols + ax = axes[ax_row, ax_col] + + label_id = int(row["index"]) + label_name = str(row.get("name", label_id)) + + # get cropped images for this label + bbox_min, bbox_max = atlas.get_region_bounding_box(region_ids=label_id) + center_coord = tuple((x + y) / 2 for x, y in zip(bbox_min, bbox_max)) + spim_crop = spim_img.crop_centered( + center_coord, + patch_size=(snakemake.params.patch_size, snakemake.params.patch_size, 1), + ) + mask_crop = mask_img.crop_centered( + center_coord, + patch_size=(snakemake.params.patch_size, snakemake.params.patch_size, 1), + ) + + spim_sl = spim_crop.data[0, :, :].squeeze().compute() + spim_sl = _apply_fixed_percentile_norm(spim_sl, glob_lo, glob_hi) + mask_sl = mask_crop.data[0, :, :].squeeze().compute() + + ax.imshow(spim_sl, cmap="gray") + mask_masked = np.ma.masked_where(mask_sl < 100, mask_sl) + ax.imshow( + mask_masked, cmap="spring", alpha=0.6, vmin=0, vmax=100, aspect=aspect_axial + ) + ax.set_title(label_name, fontsize=7, pad=2) + ax.set_xticks([]) + 
        ax.set_yticks([])

    # Hide unused axes
    for i in range(n_rois, n_rows * n_cols):
        axes[i // n_cols, i % n_cols].axis("off")

    plt.savefig(snakemake.output.png, dpi=150, bbox_inches="tight")
    plt.close()
    print(f"Saved ROI zoom montage to {snakemake.output.png}")


if __name__ == "__main__":
    main()
diff --git a/spimquant/workflow/scripts/qc_zprofile.py b/spimquant/workflow/scripts/qc_zprofile.py
new file mode 100644
index 0000000..0cedeee
--- /dev/null
+++ b/spimquant/workflow/scripts/qc_zprofile.py
@@ -0,0 +1,96 @@
"""Z-profile QC: per-slice signal intensity and segmented fraction.

Plots the mean signal intensity and mean field fraction (segmented area)
across Z-slices, revealing depth-dependent artefacts, striping, or uneven
staining/illumination.

This is a Snakemake script that expects the ``snakemake`` object to be
available, which is automatically provided when executed as part of a
Snakemake workflow.
"""

import matplotlib

matplotlib.use("agg")
import matplotlib.pyplot as plt
import nibabel as nib
import numpy as np


def _per_slice_stats(data, axis=2):
    """Return mean and std of *data* for each index along *axis*.

    Reduces over every axis except *axis*, so the result is 1-D with one
    value per slice.
    """
    # Compute reduction axes: all axes except the slice axis
    reduce_axes = tuple(i for i in range(data.ndim) if i != axis)
    means = np.mean(data, axis=reduce_axes)
    stds = np.std(data, axis=reduce_axes)
    return means, stds


def _resample_to_length(arr, target_len):
    """Linearly interpolate *arr* to *target_len* samples.

    No-op when the length already matches; otherwise maps both arrays onto
    a common [0, 1] parameterisation and interpolates.
    """
    if len(arr) == target_len:
        return arr
    src = np.linspace(0, 1, len(arr))
    dst = np.linspace(0, 1, target_len)
    return np.interp(dst, src, arr)


def main():
    # Wildcards identify the subject/stain/method combination being QC'd
    stain = snakemake.wildcards.stain
    desc = snakemake.wildcards.desc
    subject = snakemake.wildcards.subject

    # NOTE(review): both inputs are assumed to be NIfTI with Z as axis 2;
    # the field-fraction map may be at a different Z resolution — confirm.
    spim_data = nib.load(snakemake.input.spim).get_fdata()
    ff_data = nib.load(snakemake.input.fieldfrac).get_fdata()

    # Compute per-z-slice statistics (NIfTI convention: axis 2 = Z)
    spim_mean_z, spim_std_z = _per_slice_stats(spim_data, axis=2)
    ff_mean_z, _ = _per_slice_stats(ff_data, axis=2)

    n_z = len(spim_mean_z)
    z_idx = np.arange(n_z)

    # Resample field-fraction profile to match SPIM Z length if needed
    ff_mean_z = _resample_to_length(ff_mean_z, n_z)

    fig, axes = plt.subplots(2, 1, figsize=(12, 7), sharex=True)
    fig.suptitle(
        f"Z-Profile QC\nSubject: {subject} | Stain: {stain} | Method: {desc}",
        fontsize=12,
        fontweight="bold",
    )

    # Panel 1: mean intensity per slice, with a ±1 SD band (clamped at 0)
    ax = axes[0]
    ax.plot(z_idx, spim_mean_z, color="steelblue", lw=1.5, label="Mean intensity")
    ax.fill_between(
        z_idx,
        np.maximum(spim_mean_z - spim_std_z, 0),
        spim_mean_z + spim_std_z,
        alpha=0.25,
        color="steelblue",
        label="\u00b1 1 SD",
    )
    ax.set_ylabel("Mean intensity (a.u.)")
    ax.set_title("Mean signal intensity per Z-slice")
    ax.legend(fontsize=8)
    ax.grid(True, alpha=0.3)

    # Panel 2: mean field fraction per slice (filled to zero baseline)
    ax = axes[1]
    ax.plot(z_idx, ff_mean_z, color="darkorange", lw=1.5, label="Mean field fraction")
    ax.fill_between(z_idx, 0, ff_mean_z, alpha=0.25, color="darkorange")
    ax.set_xlabel("Z-slice index")
    ax.set_ylabel("Mean field fraction (%)")
    ax.set_title("Mean segmented field fraction per Z-slice")
    ax.legend(fontsize=8)
    ax.grid(True, alpha=0.3)

    plt.tight_layout()
    plt.savefig(snakemake.output.png, dpi=150, bbox_inches="tight")
    plt.close()
    print(f"Saved Z-profile QC to {snakemake.output.png}")


if __name__ == "__main__":
    main()