diff --git a/spimquant/config/snakebids.yml b/spimquant/config/snakebids.yml index a004784..795b33d 100644 --- a/spimquant/config/snakebids.yml +++ b/spimquant/config/snakebids.yml @@ -142,16 +142,18 @@ parse_args: - otsu+k3i2 nargs: '+' - --seg_hist_range: - help: "Range of intensities to use for histogram calculation in multiotsu segmentation. Only applicable when seg_method is otsu+k{}i{}. Specify 2 numbers, for min and max values. (default: %(default)s)" - default: - - 0 - - 1000 + --seg_hist_percentile_range: + help: "Percentile range to use for histogram calculation in multiotsu segmentation. Only applicable when seg_method is otsu+k{}i{}. Specify 2 numbers for the low and high percentiles (default: %(default)s)" + default: + - 1 + - 99 nargs: 2 + type: float - --seg_hist_bins: - help: "Number of bins to use for histogram calculation in multiotsu segmentation. Only applicable when seg_method is otsu+k{}i{}. (default: %(default)s)" - default: 1000 + --seg_hist_bin_width: + help: "Bin width to use for histogram calculation in multiotsu segmentation. Only applicable when seg_method is otsu+k{}i{}. (default: %(default)s)" + default: 1 + type: float --register_to_mri: diff --git a/spimquant/workflow/Snakefile b/spimquant/workflow/Snakefile index b9fc1ca..bec70e2 100644 --- a/spimquant/workflow/Snakefile +++ b/spimquant/workflow/Snakefile @@ -82,6 +82,9 @@ if stain_for_reg is None: stains_for_seg = list(set(config["stains_for_seg"]).intersection(set(stains))) +# seg methods that use multi-Otsu thresholding (otsu+k{}i{} pattern) +otsu_seg_methods = [m for m in config["seg_method"] if m.startswith("otsu+")] + if len(stains_for_seg) == 0 or config["no_segmentation"]: do_seg = False do_coloc = False @@ -223,6 +226,32 @@ rule all_fieldfrac_tune: ), +rule all_otsu_hist_qc: + """Target rule to generate otsu threshold sweep QC HTML reports. + +Produces one HTML report per subject/stain/method combination for every +otsu+k{}i{} segmentation method. Each report shows the multi-Otsu +histogram alongside 2D crops at a sweep of threshold values so that the +optimal threshold can be identified visually. Only meaningful when the +active ``seg_method`` contains at least one ``otsu+k{}i{}`` entry. +""" + input: + inputs["spim"].expand( + bids( + root=root, + datatype="qc", + stain="{stain}", + desc="{desc}", + suffix="otsuthreshqc.html", + **inputs["spim"].wildcards, + ), + stain=stains_for_seg, + desc=otsu_seg_methods, + ) + if (do_seg and len(otsu_seg_methods) > 0) + else [], + + rule all_register: input: inputs["spim"].expand( diff --git a/spimquant/workflow/rules/qc.smk b/spimquant/workflow/rules/qc.smk index 79dabf7..259c6ce 100644 --- a/spimquant/workflow/rules/qc.smk +++ b/spimquant/workflow/rules/qc.smk @@ -396,6 +396,62 @@ specified by the ``{stain}`` wildcard. "../scripts/qc_objectstats.py" +rule qc_otsu_threshold_sweep: + """Threshold sweep QC HTML report for multiotsu segmentation. + +Sweeps over a range of threshold values (spanning the configurable +percentile range of the bias-field corrected image) and generates 2D +crops at evenly-spaced positions in the image, one figure per threshold +value. The otsu histogram PNG produced by the ``multiotsu`` rule is +embedded at the top of the report. The resulting HTML report can be +visually assessed to select the optimal threshold before running the full +segmentation pipeline. + +Only applicable when ``seg_method`` uses the ``otsu+k{}i{}`` pattern. +""" + input: + corrected=bids( + root=work, + datatype="seg", + stain="{stain}", + level=str(config["segmentation_level"]), + desc="corrected{method}".format(method=config["correction_method"]), + suffix="SPIM.ome.zarr", + **inputs["spim"].wildcards, + ), + thresholds_png=bids( + root=root, + datatype="seg", + stain="{stain}", + level=str(config["segmentation_level"]), + desc="{desc}", + suffix="thresholds.png", + **inputs["spim"].wildcards, + ), + output: + html=bids( + root=root, + datatype="qc", + stain="{stain}", + desc="{desc}", + suffix="otsuthreshqc.html", + **inputs["spim"].wildcards, + ), + threads: 4 + resources: + mem_mb=32000, + runtime=30, + params: + n_thresholds=10, + n_crops=5, + patch_size=300, + level=config["segmentation_level"], + hist_percentile_range=[float(x) for x in config["seg_hist_percentile_range"]], + zarrnii_kwargs={"orientation": config["orientation"]}, + script: + "../scripts/qc_otsu_threshold_sweep.py" + + rule qc_roi_summary: """Per-ROI summary QC: top-region bar plots for a single subject. diff --git a/spimquant/workflow/rules/segmentation.smk b/spimquant/workflow/rules/segmentation.smk index 9b13252..70904b6 100644 --- a/spimquant/workflow/rules/segmentation.smk +++ b/spimquant/workflow/rules/segmentation.smk @@ -119,8 +119,8 @@ rule multiotsu: **inputs["spim"].wildcards, ), params: - hist_bins=int(config["seg_hist_bins"]), - hist_range=[int(x) for x in config["seg_hist_range"]], + hist_bin_width=float(config["seg_hist_bin_width"]), + hist_percentile_range=[float(x) for x in config["seg_hist_percentile_range"]], otsu_k=lambda wildcards: int(wildcards.k), otsu_threshold_index=lambda wildcards: int(wildcards.i), zarrnii_kwargs={"orientation": config["orientation"]}, diff --git a/spimquant/workflow/scripts/multiotsu.py b/spimquant/workflow/scripts/multiotsu.py index 6b6c985..8f85c7d 100644 --- a/spimquant/workflow/scripts/multiotsu.py +++ b/spimquant/workflow/scripts/multiotsu.py @@ -1,3 +1,5 @@ +import numpy as np + from dask_setup import get_dask_client from zarrnii import ZarrNii from zarrnii.analysis import compute_otsu_thresholds @@ -8,15 +10,45 @@ if __name__ == "__main__": with get_dask_client(snakemake.config["dask_scheduler"], snakemake.threads): - # we use the default level=0, since we are reading in the n4 output, which is already downsampled if level was >0 - znimg = ZarrNii.from_ome_zarr( - snakemake.input.corrected, **snakemake.params.zarrnii_kwargs + zarrnii_kwargs = snakemake.params.zarrnii_kwargs + pct_lo, pct_hi = snakemake.params.hist_percentile_range + bin_width = snakemake.params.hist_bin_width + + # load a downsampled version to estimate the percentile-based range + print("estimating intensity range from downsampled image...") + znimg_ds = None + for ds_level in [5, 4, 3, 2, 1]: + try: + candidate = ZarrNii.from_ome_zarr( + snakemake.input.corrected, level=ds_level, **zarrnii_kwargs + ) + znimg_ds = candidate + break + except Exception: + pass + + if znimg_ds is None: + znimg_ds = ZarrNii.from_ome_zarr( + snakemake.input.corrected, **zarrnii_kwargs + ) + + data_ds = znimg_ds.data.compute().ravel().astype(np.float32) + range_lo = float(np.percentile(data_ds, pct_lo)) + range_hi = float(np.percentile(data_ds, pct_hi)) + print( + f" πŸ“Š percentile range [{pct_lo}%, {pct_hi}%]: [{range_lo:.3f}, {range_hi:.3f}]" ) - # first calculate histogram - using preset bins to avoid issues where bins are too large - # because of high intensity outliers + # compute number of bins from bin width + n_bins = max(2, int(np.ceil((range_hi - range_lo) / bin_width))) + print(f" πŸ“Š bins: {n_bins} (bin width: {bin_width})") + + # we use the default level=0, since we are reading in the n4 output, which is already downsampled if level was >0 + znimg = ZarrNii.from_ome_zarr(snakemake.input.corrected, **zarrnii_kwargs) + + # calculate histogram using percentile-based range and bin-width-derived bin count (hist_counts, bin_edges) = znimg.compute_histogram( - bins=snakemake.params.hist_bins, range=snakemake.params.hist_range + bins=n_bins, range=[range_lo, range_hi] ) # get otsu thresholds (uses histogram) diff --git a/spimquant/workflow/scripts/qc_otsu_threshold_sweep.py b/spimquant/workflow/scripts/qc_otsu_threshold_sweep.py new file mode 100644 index 0000000..5790d66 --- /dev/null +++ b/spimquant/workflow/scripts/qc_otsu_threshold_sweep.py @@ -0,0 +1,262 @@ +"""Otsu threshold sweep QC report. + +For a given stain and multiotsu segmentation method, sweeps over a range of +threshold values (spanning the 1st–99th percentile of the image) and generates +2D crops at multiple positions, producing a self-contained HTML report. The +report can be visually assessed to select the optimal threshold before running +the full segmentation pipeline. + +The otsu histogram PNG figures produced by the ``multiotsu`` rule are embedded +in the report alongside the sweep visualizations. + +This is a Snakemake script; the ``snakemake`` object is automatically provided +when executed as part of a Snakemake workflow. +""" + +import base64 +from io import BytesIO + +import matplotlib + +matplotlib.use("agg") +import matplotlib.pyplot as plt +import numpy as np + +from zarrnii import ZarrNii + + +def _fig_to_base64(fig): + """Convert a matplotlib figure to a base64-encoded PNG data URI.""" + buf = BytesIO() + fig.savefig(buf, format="png", dpi=100, bbox_inches="tight") + buf.seek(0) + b64 = base64.b64encode(buf.read()).decode("utf-8") + plt.close(fig) + return f"data:image/png;base64,{b64}" + + +def _file_to_base64_png(path): + """Read a PNG file and return a base64 data URI.""" + with open(path, "rb") as fh: + b64 = base64.b64encode(fh.read()).decode("utf-8") + return f"data:image/png;base64,{b64}" + + +def _norm(arr, lo, hi): + """Linearly normalise *arr* to [0, 1] using the given bounds.""" + if hi > lo: + return np.clip((arr.astype(np.float32) - lo) / (hi - lo), 0.0, 1.0) + return np.zeros_like(arr, dtype=np.float32) + + +def main(): + zarrnii_kwargs = snakemake.params.zarrnii_kwargs + n_thresholds = snakemake.params.n_thresholds + n_crops = snakemake.params.n_crops + patch_size = snakemake.params.patch_size + level = snakemake.params.level + pct_lo, pct_hi = snakemake.params.hist_percentile_range + + stain = snakemake.wildcards.stain + desc = snakemake.wildcards.desc + subject = snakemake.wildcards.subject + + # ------------------------------------------------------------------ + # Load a downsampled pyramid level for speed + # ------------------------------------------------------------------ + print("Loading image for threshold sweep QC...") + znimg = None + for ds_offset in [3, 2, 1, 0]: + try: + candidate = ZarrNii.from_ome_zarr( + snakemake.input.corrected, level=level + ds_offset, **zarrnii_kwargs + ) + znimg = candidate + print(f" Loaded at pyramid level {level + ds_offset}") + break + except Exception as exc: + print(f" Level {level + ds_offset} not available: {exc}") + + if znimg is None: + znimg = ZarrNii.from_ome_zarr(snakemake.input.corrected, **zarrnii_kwargs) + + # ------------------------------------------------------------------ + # Get array data β€” shape is (C, Z, Y, X) or (Z, Y, X) + # ------------------------------------------------------------------ + data = znimg.data.compute() + if data.ndim == 4: + data = data[0] # drop channel dim β†’ (Z, Y, X) + + # ------------------------------------------------------------------ + # Compute percentile-based range for the sweep + # ------------------------------------------------------------------ + flat = data.ravel().astype(np.float32) + range_lo = float(np.percentile(flat, pct_lo)) + range_hi = float(np.percentile(flat, pct_hi)) + print( + f" Sweep range [{pct_lo}th–{pct_hi}th percentile]: " + f"[{range_lo:.3f}, {range_hi:.3f}]" + ) + + # ------------------------------------------------------------------ + # Generate evenly-spaced threshold values across the percentile range + # ------------------------------------------------------------------ + thresholds = np.linspace(range_lo, range_hi, n_thresholds) + + # ------------------------------------------------------------------ + # Select crop positions: evenly spaced along Y at mid-Z + # ------------------------------------------------------------------ + Z, Y, X = data.shape + z_mid = Z // 2 + half_patch = patch_size // 2 + + crop_boxes = [] + for i in range(n_crops): + y_pos = int(Y * (i + 0.5) / n_crops) + x_pos = X // 2 + y0 = max(0, y_pos - half_patch) + y1 = min(Y, y_pos + half_patch) + x0 = max(0, x_pos - half_patch) + x1 = min(X, x_pos + half_patch) + crop_boxes.append((z_mid, y0, y1, x0, x1)) + + # Pre-normalise the raw image crops (shared across all thresholds) + img_crops_norm = [ + _norm(data[z, y0:y1, x0:x1], range_lo, range_hi) + for (z, y0, y1, x0, x1) in crop_boxes + ] + + # ------------------------------------------------------------------ + # Build a figure for each threshold value + # ------------------------------------------------------------------ + threshold_entries = [] + for thresh in thresholds: + fig, axes = plt.subplots(1, n_crops, figsize=(n_crops * 3, 3)) + fig.suptitle(f"Threshold = {thresh:.1f}", fontsize=10, fontweight="bold") + if n_crops == 1: + axes = [axes] + + for ax, (z, y0, y1, x0, x1), crop_norm in zip(axes, crop_boxes, img_crops_norm): + mask_crop = (data[z, y0:y1, x0:x1] > thresh).astype(np.float32) + ax.imshow(crop_norm, cmap="gray") + mask_ma = np.ma.masked_where(mask_crop < 0.5, mask_crop) + ax.imshow(mask_ma, cmap="Reds", alpha=0.6, vmin=0, vmax=1) + ax.set_xticks([]) + ax.set_yticks([]) + + plt.tight_layout() + threshold_entries.append( + {"threshold": float(thresh), "b64": _fig_to_base64(fig)} + ) + + # ------------------------------------------------------------------ + # Embed the otsu histogram PNG from the multiotsu rule + # ------------------------------------------------------------------ + otsu_hist_b64 = _file_to_base64_png(snakemake.input.thresholds_png) + + # ------------------------------------------------------------------ + # Generate HTML report + # ------------------------------------------------------------------ + sweep_html_parts = [] + for entry in threshold_entries: + thresh_label = f"{entry['threshold']:.1f}" + sweep_html_parts.append( + f'
' + f"

Threshold = {thresh_label}

" + f'threshold {thresh_label}' + f"
" + ) + sweep_html = "\n".join(sweep_html_parts) + + html = f""" + + + + + Otsu Threshold Sweep QC – {subject} + + + +
+

Otsu Threshold Sweep QC Report

+ +
+

Subject: {subject}

+

Stain: {stain}

+

Method: {desc}

+

Intensity range ({pct_lo}th–{pct_hi}th percentile): + {range_lo:.1f} – {range_hi:.1f}

+

Number of threshold values: {n_thresholds}

+
+ +

Otsu Histogram (Multi-level Otsu Thresholds)

+

The histogram below shows the multi-Otsu threshold positions computed + from the bias-field corrected image. Use these as a guide when selecting + the optimal threshold from the sweep below.

+
+ Otsu Histogram +
+ +

Threshold Sweep (2D Crops)

+

Each row shows {n_crops} axial crops at different Y positions (mid-Z + slice) with the resulting binary mask overlaid in red. Use this to select + a threshold that captures the signal of interest without excessive + background segmentation.

+ + {sweep_html} + +
+ +""" + + with open(snakemake.output.html, "w") as fh: + fh.write(html) + + print(f"Saved threshold sweep QC to {snakemake.output.html}") + + +if __name__ == "__main__": + main()