diff --git a/hs2p/api.py b/hs2p/api.py index 8b44188..299d808 100644 --- a/hs2p/api.py +++ b/hs2p/api.py @@ -24,7 +24,6 @@ ) from hs2p.configs.resolvers import build_default_sampling_spec from hs2p.progress import emit_progress, emit_progress_log -from hs2p.stderr_utils import run_with_filtered_stderr from hs2p.wsi import ( CoordinateOutputMode, CoordinateSelectionStrategy, @@ -33,7 +32,7 @@ overlay_mask_on_slide as _overlay_mask_on_slide, write_coordinate_preview, ) -from hs2p.wsi.backend import resolve_backend +from hs2p.wsi.backend import coerce_wsd_path, resolve_backend @dataclass(frozen=True) @@ -645,7 +644,7 @@ def _iter_cucim_tile_arrays_for_tar_extraction( if result.backend != "cucim": return None try: - cucim = run_with_filtered_stderr(lambda: importlib.import_module("cucim")) + cucim = importlib.import_module("cucim") except ModuleNotFoundError: warnings.warn( "CuCIM is unavailable for backend='cucim'; falling back to sequential wholeslidedata tile extraction.", @@ -695,7 +694,10 @@ def _iter_wsd_tile_arrays_for_tar_extraction( ): import wholeslidedata as wsd - wsi = wsd.WholeSlideImage(result.image_path, backend=result.backend) + wsi = wsd.WholeSlideImage( + coerce_wsd_path(result.image_path, backend=result.backend), + backend=result.backend, + ) read_step_px = _resolve_read_step_px(result) step_px_lv0 = _resolve_step_px_lv0(result) for read_plan in _iter_grouped_read_plans_for_tar_extraction( diff --git a/hs2p/wsi/backend.py b/hs2p/wsi/backend.py index 5c7fc1f..6ba95df 100644 --- a/hs2p/wsi/backend.py +++ b/hs2p/wsi/backend.py @@ -24,6 +24,17 @@ def _normalize_path(path: Path | None) -> str | None: return str(Path(path)) +def coerce_wsd_path(path: Path | str, *, backend: str) -> Path | str: + """Return a path object compatible with the requested WSD backend. + + CuCIM-backed WSD opens require plain strings, while the other backends + accept pathlib objects. + """ + if backend == "cucim": + return str(path) + return Path(path) + + def _is_cucim_supported_format(wsi_path: Path) -> bool: suffix = wsi_path.suffix.lower() return suffix in CUCIM_SUPPORTED_SUFFIXES @@ -37,9 +48,12 @@ def _backend_can_open_slide( backend: str, ) -> bool: try: - wsd.WholeSlideImage(Path(wsi_path), backend=backend) + wsd.WholeSlideImage(coerce_wsd_path(wsi_path, backend=backend), backend=backend) if mask_path is not None: - wsd.WholeSlideImage(Path(mask_path), backend=backend) + wsd.WholeSlideImage( + coerce_wsd_path(mask_path, backend=backend), + backend=backend, + ) return True except Exception: return False diff --git a/hs2p/wsi/wsi.py b/hs2p/wsi/wsi.py index a71465a..357fd7e 100644 --- a/hs2p/wsi/wsi.py +++ b/hs2p/wsi/wsi.py @@ -11,7 +11,7 @@ from PIL import Image from hs2p.configs import FilterConfig, SegmentationConfig, TilingConfig -from hs2p.wsi.backend import resolve_backend +from hs2p.wsi.backend import coerce_wsd_path, resolve_backend from hs2p.wsi.utils import HasEnoughTissue, ResolvedTileGeometry # ignore all warnings from wholeslidedata @@ -78,7 +78,10 @@ def __init__( self.requested_backend = backend selection = resolve_backend(backend, wsi_path=path, mask_path=mask_path) self.backend = selection.backend - self.wsi = wsd.WholeSlideImage(path, backend=self.backend) + self.wsi = wsd.WholeSlideImage( + coerce_wsd_path(path, backend=self.backend), + backend=self.backend, + ) self._scaled_contours_cache = {} # add a cache for scaled contours self._scaled_holes_cache = {} # add a cache for scaled holes @@ -97,7 +100,10 @@ def __init__( raise ValueError( "sampling_spec is required when loading a mask-backed slide" ) - self.mask = wsd.WholeSlideImage(mask_path, backend=self.backend) + self.mask = wsd.WholeSlideImage( + coerce_wsd_path(mask_path, backend=self.backend), + backend=self.backend, + ) self.seg_level = self.load_segmentation( segment_params, sampling_spec=sampling_spec, diff --git a/scripts/benchmark_tile_read_strategies.py b/scripts/benchmark_tile_read_strategies.py index 5906a6e..c45f6a5 100644 --- a/scripts/benchmark_tile_read_strategies.py +++ b/scripts/benchmark_tile_read_strategies.py @@ -26,7 +26,6 @@ TimeRemainingColumn, ) -from hs2p.stderr_utils import run_with_filtered_stderr MODE_CONFIG = { "regular_wsd": { @@ -380,7 +379,7 @@ def benchmark_wsd_mode( def _require_cucim(): try: - return run_with_filtered_stderr(lambda: importlib.import_module("cucim")) + return importlib.import_module("cucim") except ModuleNotFoundError as exc: raise RuntimeError( "CuCIM is required for the requested benchmark modes but is not installed." @@ -398,7 +397,7 @@ def benchmark_cucim_batch_mode( from hs2p.benchmarking import group_read_plans_by_read_size cucim = _require_cucim() - cu_image = run_with_filtered_stderr(lambda: cucim.CuImage(str(result.image_path))) + cu_image = cucim.CuImage(str(result.image_path)) tile_size_px = int(result.read_tile_size_px) checksum = 0 tile_count = 0 diff --git a/scripts/benchmark_tile_store.py b/scripts/benchmark_tile_store.py new file mode 100644 index 0000000..9357964 --- /dev/null +++ b/scripts/benchmark_tile_store.py @@ -0,0 +1,505 @@ +#!/usr/bin/env python3 +"""Benchmark tile-store creation with read/encode/write time breakdown. + +Replicates the ``extract_tiles_to_tar`` pipeline from ``hs2p.api`` with +per-phase timing instrumentation so that read, JPEG-encode, and tar-write +costs can be measured independently. +""" +from __future__ import annotations + +import argparse +import io +import statistics +import sys +import tarfile +import tempfile +import time +from pathlib import Path +from typing import Any, Callable + +_REPO_ROOT = Path(__file__).resolve().parents[1] +if str(_REPO_ROOT) not in sys.path: + sys.path.insert(0, str(_REPO_ROOT)) + +from hs2p.api import _iter_tile_arrays_for_tar_extraction + +ProgressCallback = Callable[[int, int], None] + + +def parse_args() -> argparse.Namespace: + parser = argparse.ArgumentParser( + description=( + "Benchmark tile-store creation (read/encode/write breakdown) " + "from a fresh tiling result generated from an hs2p config file." + ), + formatter_class=argparse.ArgumentDefaultsHelpFormatter, + ) + parser.add_argument( + "--config-file", + type=Path, + required=True, + help="Path to an hs2p config file whose CSV contains exactly one slide.", + ) + parser.add_argument( + "--output-dir", + type=Path, + required=True, + help="Directory where benchmark CSV outputs are written.", + ) + parser.add_argument( + "--repeat", + type=int, + default=1, + help="Number of timed repetitions.", + ) + parser.add_argument( + "--warmup", + type=int, + default=0, + help="Untimed warmup repetitions.", + ) + parser.add_argument( + "--max-tiles", + type=int, + default=0, + help="Use only the first N tiles. Set to 0 to use all tiles.", + ) + parser.add_argument( + "--workers", + type=int, + nargs="+", + default=[4, 8, 16, 32], + help="Worker counts to sweep.", + ) + parser.add_argument( + "--jpeg-quality", + type=int, + default=90, + help="JPEG encoding quality (1-100).", + ) + return parser.parse_args() + + +# ------------------------------------------------------------------ +# Core benchmark +# ------------------------------------------------------------------ + + +def benchmark_tile_store( + *, + result, + jpeg_quality: int, + num_workers: int, + output_dir: Path, + progress_callback: ProgressCallback | None = None, +) -> dict[str, float | int]: + """Run one extraction pass and return per-phase timing metrics.""" + from PIL import Image + + output_dir.mkdir(parents=True, exist_ok=True) + + read_s = 0.0 + encode_s = 0.0 + write_s = 0.0 + tile_count = 0 + jpeg_bytes = 0 + + temp_tar_path: Path | None = None + try: + with tempfile.NamedTemporaryFile( + suffix=".tar", dir=output_dir, delete=False + ) as tmp: + temp_tar_path = Path(tmp.name) + + total_start = time.perf_counter() + + with tarfile.open(temp_tar_path, "w") as tf: + iterator = _iter_tile_arrays_for_tar_extraction( + result=result, num_workers=num_workers, + ) + + while True: + t0 = time.perf_counter() + try: + tile_arr = next(iterator) + except StopIteration: + break + t1 = time.perf_counter() + read_s += t1 - t0 + + if tile_arr.shape[2] > 3: + tile_arr = tile_arr[:, :, :3] + + img = Image.fromarray(tile_arr).convert("RGB") + if result.read_tile_size_px != result.target_tile_size_px: + img = img.resize( + (result.target_tile_size_px, result.target_tile_size_px), + Image.LANCZOS, + ) + buf = io.BytesIO() + img.save(buf, format="JPEG", quality=jpeg_quality) + buf.seek(0) + t2 = time.perf_counter() + encode_s += t2 - t1 + + info = tarfile.TarInfo(name=f"{tile_count:06d}.jpg") + info.size = buf.getbuffer().nbytes + tf.addfile(info, buf) + t3 = time.perf_counter() + write_s += t3 - t2 + + jpeg_bytes += info.size + tile_count += 1 + if progress_callback is not None: + progress_callback(1, 1) + + total_s = time.perf_counter() - total_start + finally: + if temp_tar_path is not None: + temp_tar_path.unlink(missing_ok=True) + + return { + "tile_count": tile_count, + "jpeg_bytes": jpeg_bytes, + "read_s": read_s, + "encode_s": encode_s, + "write_s": write_s, + "total_s": total_s, + } + + +# ------------------------------------------------------------------ +# Result formatting +# ------------------------------------------------------------------ + + +def build_result_row( + *, + sample_id: str, + image_path: str, + repeat_index: int, + tiles: int, + jpeg_quality: int, + num_workers: int, + read_s: float, + encode_s: float, + write_s: float, + total_s: float, + jpeg_bytes: int, +) -> dict[str, Any]: + return { + "sample_id": sample_id, + "image_path": image_path, + "repeat_index": repeat_index, + "tiles": tiles, + "jpeg_quality": jpeg_quality, + "num_workers": num_workers, + "read_s": round(read_s, 6), + "encode_s": round(encode_s, 6), + "write_s": round(write_s, 6), + "total_s": round(total_s, 6), + "read_pct": round(100 * read_s / total_s, 2) if total_s > 0 else 0.0, + "encode_pct": round(100 * encode_s / total_s, 2) if total_s > 0 else 0.0, + "write_pct": round(100 * write_s / total_s, 2) if total_s > 0 else 0.0, + "tiles_per_second": round(tiles / total_s, 2) if total_s > 0 else 0.0, + "jpeg_bytes": jpeg_bytes, + "jpeg_mb_per_second": round( + (jpeg_bytes / 1_000_000) / total_s, 2 + ) if total_s > 0 else 0.0, + } + + +def summarize_results(rows: list[dict[str, Any]]) -> list[dict[str, Any]]: + if not rows: + return [] + + def _mean(vals): + return round(statistics.mean(vals), 6) + + def _pstdev(vals): + return round(statistics.pstdev(vals), 6) if len(vals) > 1 else 0.0 + + grouped: dict[int, list[dict[str, Any]]] = {} + for row in rows: + grouped.setdefault(int(row["num_workers"]), []).append(row) + + summary: list[dict[str, Any]] = [] + for nw, nw_rows in grouped.items(): + read_s = [float(r["read_s"]) for r in nw_rows] + encode_s = [float(r["encode_s"]) for r in nw_rows] + write_s = [float(r["write_s"]) for r in nw_rows] + total_s = [float(r["total_s"]) for r in nw_rows] + read_pct = [float(r["read_pct"]) for r in nw_rows] + encode_pct = [float(r["encode_pct"]) for r in nw_rows] + write_pct = [float(r["write_pct"]) for r in nw_rows] + tps = [float(r["tiles_per_second"]) for r in nw_rows] + summary.append( + { + "num_workers": nw, + "tiles": int(nw_rows[0]["tiles"]), + "jpeg_quality": int(nw_rows[0]["jpeg_quality"]), + "mean_read_s": _mean(read_s), + "mean_encode_s": _mean(encode_s), + "mean_write_s": _mean(write_s), + "mean_total_s": _mean(total_s), + "mean_read_pct": round(statistics.mean(read_pct), 2), + "mean_encode_pct": round(statistics.mean(encode_pct), 2), + "mean_write_pct": round(statistics.mean(write_pct), 2), + "mean_tiles_per_second": round(statistics.mean(tps), 2), + "std_tiles_per_second": round(_pstdev(tps), 2), + } + ) + return summary + + +# ------------------------------------------------------------------ +# Chart +# ------------------------------------------------------------------ + +import matplotlib +matplotlib.use("Agg") +import matplotlib.pyplot as plt # noqa: E402 +import matplotlib.ticker as ticker # noqa: E402 + +_C_LINE = "#1a6faf" # main line โ€“ steel blue +_C_BAND = "#a8c8e8" # error band โ€“ desaturated blue +_C_TEXT = "#222222" # primary text +_C_MUTED = "#666666" # secondary text +_C_GRID = "#e8e8e8" # gridlines + + +def plot_results( + summary_rows: list[dict[str, Any]], + *, + output_path: Path, + max_tiles: int, + jpeg_quality: int, +) -> None: + workers = [int(r["num_workers"]) for r in summary_rows] + tps = [float(r["mean_tiles_per_second"]) for r in summary_rows] + err = [float(r["std_tiles_per_second"]) for r in summary_rows] + has_err = any(e > 0 for e in err) + + plt.rcParams.update({ + "font.family": "sans-serif", + "font.size": 11, + "axes.spines.top": False, + "axes.spines.right": False, + "axes.linewidth": 0.8, + "xtick.direction": "out", + "ytick.direction": "out", + "xtick.major.size": 4, + "ytick.major.size": 4, + "xtick.major.width": 0.8, + "ytick.major.width": 0.8, + }) + + fig, ax = plt.subplots(figsize=(8, 5)) + fig.patch.set_facecolor("white") + ax.set_facecolor("white") + + ax.yaxis.grid(True, color=_C_GRID, linewidth=0.6, linestyle="-", zorder=0) + ax.set_axisbelow(True) + ax.xaxis.grid(False) + + ax.set_xscale("log", base=2) + + y_min = min(tps) * 0.72 + y_max = max(tps) * 1.48 + + if has_err: + lower = [max(0.0, t - e) for t, e in zip(tps, err)] + upper = [t + e for t, e in zip(tps, err)] + ax.fill_between(workers, lower, upper, color=_C_BAND, alpha=0.35, zorder=3) + + ax.plot(workers, tps, color=_C_LINE, linewidth=2.4, + solid_capstyle="round", solid_joinstyle="round", zorder=4) + + for w, t in zip(workers, tps): + ax.scatter(w, t, s=65, color=_C_LINE, zorder=6) + ax.scatter(w, t, s=20, color="white", zorder=7) + + v_offset = y_max * 0.04 + for w, t in zip(workers, tps): + ax.text(w, t + v_offset, f"{t:,.0f}", + ha="center", va="bottom", fontsize=8.5, + color=_C_TEXT, fontweight="semibold") + + ax.set_xlabel("Number of workers", fontsize=11, labelpad=10, color=_C_TEXT) + ax.set_ylabel("Throughput (tiles / s)", fontsize=11, labelpad=10, color=_C_TEXT) + ax.set_xticks(workers) + ax.set_xticklabels([str(w) for w in workers], fontsize=10, color=_C_TEXT) + ax.yaxis.set_major_formatter(ticker.FuncFormatter(lambda v, _: f"{v:,.0f}")) + ax.tick_params(axis="y", labelsize=10, colors=_C_TEXT) + ax.tick_params(axis="x", colors=_C_TEXT) + log_pad = 0.30 + ax.set_xlim(workers[0] * 2 ** (-log_pad), workers[-1] * 2 ** log_pad) + ax.set_ylim(y_min, y_max) + + tiles_label = f"{max_tiles:,}" if max_tiles > 0 else "all" + fig.text(0.13, 0.97, "Tile Store Throughput", + ha="left", va="top", fontsize=14, fontweight="bold", color=_C_TEXT) + fig.text(0.13, 0.925, f"JPEG quality={jpeg_quality} ยท tiles={tiles_label}", + ha="left", va="top", fontsize=8.5, color=_C_MUTED) + + # inset table: read/encode/write breakdown per worker count + tbl_ax = fig.add_axes([0.60, 0.845, 0.37, 0.145]) + tbl_ax.axis("off") + table_data = [ + [str(int(r["num_workers"])), + f"{float(r['mean_read_pct']):.0f}%", + f"{float(r['mean_encode_pct']):.0f}%", + f"{float(r['mean_write_pct']):.0f}%"] + for r in summary_rows + ] + tbl = tbl_ax.table( + cellText=table_data, + colLabels=["workers", "read", "encode", "write"], + loc="center", + cellLoc="center", + ) + tbl.auto_set_font_size(False) + tbl.set_fontsize(7.5) + for (row, col), cell in tbl.get_celld().items(): + cell.set_linewidth(0.4) + cell.set_edgecolor(_C_GRID) + if row == 0: + cell.set_text_props(color=_C_MUTED, fontweight="semibold") + cell.set_facecolor("#f4f8fc") + else: + cell.set_facecolor("white") + cell.set_text_props(color=_C_MUTED) + cell.set_height(0.155) + + fig.subplots_adjust(top=0.84, bottom=0.13, left=0.13, right=0.97) + output_path.parent.mkdir(parents=True, exist_ok=True) + fig.savefig(output_path, dpi=300, bbox_inches="tight", facecolor="white") + plt.close(fig) + plt.rcdefaults() + print(f"Chart saved โ†’ {output_path}") + + +# ------------------------------------------------------------------ +# CLI +# ------------------------------------------------------------------ + + +def main() -> int: + from benchmark_tile_read_strategies import ( + BenchmarkProgressReporter, + load_single_slide_result_from_config, + write_csv, + ) + from hs2p.benchmarking import limit_tiling_result + + args = parse_args() + workers = sorted(args.workers) + + result = load_single_slide_result_from_config( + config_file=args.config_file, + num_workers=workers[0], + ) + result = limit_tiling_result(result, max_tiles=int(args.max_tiles)) + args.output_dir.mkdir(parents=True, exist_ok=True) + + total_runs = len(workers) * (int(args.warmup) + int(args.repeat)) + + timed_rows: list[dict[str, Any]] = [] + run_counter = 0 + with BenchmarkProgressReporter(total_runs=total_runs) as reporter: + reporter.print_banner( + result=result, + modes=[f"workers={w}" for w in workers], + repeat=int(args.repeat), + warmup=int(args.warmup), + ) + for nw in workers: + mode_label = f"tile_store w={nw}" + for warmup_idx in range(int(args.warmup)): + run_counter += 1 + reporter.start_run( + run_counter=run_counter, + phase="warmup", + mode=mode_label, + iteration_index=warmup_idx, + iteration_total=int(args.warmup), + total_read_calls=int(result.num_tiles), + total_tiles=int(result.num_tiles), + ) + benchmark_tile_store( + result=result, + jpeg_quality=int(args.jpeg_quality), + num_workers=nw, + output_dir=args.output_dir, + progress_callback=reporter.advance, + ) + reporter.finish_run() + + for repeat_index in range(int(args.repeat)): + run_counter += 1 + reporter.start_run( + run_counter=run_counter, + phase="timed", + mode=mode_label, + iteration_index=repeat_index, + iteration_total=int(args.repeat), + total_read_calls=int(result.num_tiles), + total_tiles=int(result.num_tiles), + ) + metrics = benchmark_tile_store( + result=result, + jpeg_quality=int(args.jpeg_quality), + num_workers=nw, + output_dir=args.output_dir, + progress_callback=reporter.advance, + ) + reporter.finish_run() + + row = build_result_row( + sample_id=str(result.sample_id), + image_path=str(result.image_path), + repeat_index=repeat_index, + tiles=metrics["tile_count"], + jpeg_quality=int(args.jpeg_quality), + num_workers=nw, + read_s=metrics["read_s"], + encode_s=metrics["encode_s"], + write_s=metrics["write_s"], + total_s=metrics["total_s"], + jpeg_bytes=metrics["jpeg_bytes"], + ) + reporter.console.print( + ( + f"workers={nw:<3} rep={repeat_index + 1} " + f"tiles={int(row['tiles']):>7,d} " + f"read={float(row['read_pct']):>5.1f}% " + f"encode={float(row['encode_pct']):>5.1f}% " + f"write={float(row['write_pct']):>5.1f}% " + f"elapsed={float(row['total_s']):>8.3f}s " + f"throughput={float(row['tiles_per_second']):>10,.0f} tiles/s" + ), + highlight=False, + ) + timed_rows.append(row) + + summary_rows = summarize_results(timed_rows) + runs_csv_path = write_csv(timed_rows, args.output_dir / "benchmark_runs.csv") + summary_csv_path = write_csv(summary_rows, args.output_dir / "benchmark_summary.csv") + + print(f"\nWrote {runs_csv_path}", flush=True) + print(f"Wrote {summary_csv_path}", flush=True) + + if len(summary_rows) > 1: + chart_path = args.output_dir / "throughput.png" + plot_results( + summary_rows, + output_path=chart_path, + max_tiles=int(args.max_tiles), + jpeg_quality=int(args.jpeg_quality), + ) + + return 0 + + +if __name__ == "__main__": + raise SystemExit(main()) diff --git a/tests/test_backend_selection.py b/tests/test_backend_selection.py index a1ff564..f0a56d1 100644 --- a/tests/test_backend_selection.py +++ b/tests/test_backend_selection.py @@ -5,6 +5,7 @@ import hs2p.api as api_mod import hs2p.wsi.backend as backend_mod +import hs2p.wsi.wsi as wsi_mod def test_resolve_backend_prefers_cucim_when_supported(monkeypatch): @@ -60,6 +61,68 @@ def _fake_can_open_slide(*, wsi_path: str, mask_path: str | None, backend: str): assert calls == [] +def test_backend_probe_coerces_cucim_paths_to_strings(monkeypatch): + seen_paths: list[tuple[object, str]] = [] + + def _fake_wholeslideimage(path, *, backend: str): + seen_paths.append((path, backend)) + return SimpleNamespace() + + backend_mod._backend_can_open_slide.cache_clear() + monkeypatch.setattr(backend_mod.wsd, "WholeSlideImage", _fake_wholeslideimage) + + assert backend_mod._backend_can_open_slide( + wsi_path="/tmp/slide.tiff", + mask_path="/tmp/mask.tiff", + backend="cucim", + ) + assert seen_paths == [ + ("/tmp/slide.tiff", "cucim"), + ("/tmp/mask.tiff", "cucim"), + ] + + +def test_wholeslideimage_coerces_cucim_paths_to_strings(monkeypatch): + seen_paths: list[tuple[object, str]] = [] + + class _FakeSlide: + spacings = [0.5] + shapes = [(100, 100)] + + def _fake_wholeslideimage(path, *, backend: str): + seen_paths.append((path, backend)) + return _FakeSlide() + + monkeypatch.setattr( + wsi_mod, + "resolve_backend", + lambda requested_backend, *, wsi_path, mask_path=None: backend_mod.BackendSelection( + backend="cucim", + tried=("cucim",), + ), + ) + monkeypatch.setattr(wsi_mod.wsd, "WholeSlideImage", _fake_wholeslideimage) + monkeypatch.setattr(wsi_mod.WholeSlideImage, "load_segmentation", lambda *args, **kwargs: 0) + + wsi_mod.WholeSlideImage( + path=Path("/tmp/slide.tiff"), + mask_path=Path("/tmp/mask.tiff"), + backend="auto", + sampling_spec=wsi_mod.ResolvedSamplingSpec( + pixel_mapping={"background": 0, "tumor": 1}, + color_mapping={"background": None, "tumor": None}, + tissue_percentage={"background": None, "tumor": 0.1}, + active_annotations=("tumor",), + ), + segment_params=SimpleNamespace(), + ) + + assert seen_paths == [ + ("/tmp/slide.tiff", "cucim"), + ("/tmp/mask.tiff", "cucim"), + ] + + def test_tile_slide_uses_resolved_backend_for_hash_and_result(monkeypatch): captured: dict[str, str] = {} diff --git a/tests/test_benchmark_tile_store.py b/tests/test_benchmark_tile_store.py new file mode 100644 index 0000000..29443f3 --- /dev/null +++ b/tests/test_benchmark_tile_store.py @@ -0,0 +1,168 @@ +import importlib.util +import statistics +from pathlib import Path +from unittest.mock import MagicMock + +import numpy as np +import pytest + + +def _load_benchmark_module(): + module_path = ( + Path(__file__).resolve().parents[1] / "scripts" / "benchmark_tile_store.py" + ) + spec = importlib.util.spec_from_file_location("benchmark_tile_store", module_path) + assert spec is not None + assert spec.loader is not None + module = importlib.util.module_from_spec(spec) + spec.loader.exec_module(module) + return module + + +def test_build_result_row_computes_percentages_and_throughput(): + mod = _load_benchmark_module() + row = mod.build_result_row( + sample_id="slide-1", + image_path="/slides/slide-1.svs", + repeat_index=0, + tiles=100, + jpeg_quality=90, + num_workers=4, + read_s=1.0, + encode_s=2.0, + write_s=1.0, + total_s=4.0, + jpeg_bytes=500_000, + ) + assert row["sample_id"] == "slide-1" + assert row["image_path"] == "/slides/slide-1.svs" + assert row["repeat_index"] == 0 + assert row["tiles"] == 100 + assert row["jpeg_quality"] == 90 + assert row["num_workers"] == 4 + assert row["read_s"] == 1.0 + assert row["encode_s"] == 2.0 + assert row["write_s"] == 1.0 + assert row["total_s"] == 4.0 + assert row["read_pct"] == 25.0 + assert row["encode_pct"] == 50.0 + assert row["write_pct"] == 25.0 + assert row["tiles_per_second"] == 25.0 + assert row["jpeg_bytes"] == 500_000 + assert row["jpeg_mb_per_second"] == 0.12 + + +def test_summarize_results_computes_mean_and_std(): + mod = _load_benchmark_module() + rows = [ + mod.build_result_row( + sample_id="s", + image_path="/s.svs", + repeat_index=i, + tiles=100, + jpeg_quality=90, + num_workers=4, + read_s=r, + encode_s=e, + write_s=w, + total_s=r + e + w, + jpeg_bytes=500_000, + ) + for i, (r, e, w) in enumerate([(1.0, 2.0, 1.0), (1.2, 2.2, 0.8), (0.8, 1.8, 1.2)]) + ] + summary = mod.summarize_results(rows) + assert len(summary) == 1 + s = summary[0] + assert s["tiles"] == 100 + assert s["jpeg_quality"] == 90 + assert s["num_workers"] == 4 + + expected_totals = [4.0, 4.2, 3.8] + expected_tps = [100 / t for t in expected_totals] + assert s["mean_total_s"] == pytest.approx(statistics.mean(expected_totals), abs=1e-4) + assert s["mean_read_s"] == pytest.approx(statistics.mean([1.0, 1.2, 0.8]), abs=1e-4) + assert s["mean_encode_s"] == pytest.approx(statistics.mean([2.0, 2.2, 1.8]), abs=1e-4) + assert s["mean_write_s"] == pytest.approx(statistics.mean([1.0, 0.8, 1.2]), abs=1e-4) + assert s["mean_tiles_per_second"] == pytest.approx(statistics.mean(expected_tps), abs=0.1) + assert s["std_tiles_per_second"] == pytest.approx(statistics.pstdev(expected_tps), abs=0.1) + + +def test_benchmark_tile_store_accumulates_per_phase_times(tmp_path, monkeypatch): + mod = _load_benchmark_module() + tiles = [np.zeros((64, 64, 3), dtype=np.uint8) for _ in range(3)] + + monkeypatch.setattr( + mod, "_iter_tile_arrays_for_tar_extraction", lambda **kwargs: iter(tiles) + ) + + result_stub = MagicMock() + result_stub.read_tile_size_px = 64 + result_stub.target_tile_size_px = 64 + + metrics = mod.benchmark_tile_store( + result=result_stub, + jpeg_quality=90, + num_workers=4, + output_dir=tmp_path, + ) + + assert metrics["tile_count"] == 3 + assert metrics["jpeg_bytes"] > 0 + assert metrics["read_s"] >= 0 + assert metrics["encode_s"] > 0 + assert metrics["write_s"] > 0 + assert metrics["total_s"] >= metrics["read_s"] + metrics["encode_s"] + metrics["write_s"] + # temp tar cleaned up + assert not list(tmp_path.glob("*.tar")) + + +def test_benchmark_tile_store_handles_empty_iterator(tmp_path, monkeypatch): + mod = _load_benchmark_module() + + monkeypatch.setattr( + mod, "_iter_tile_arrays_for_tar_extraction", lambda **kwargs: iter([]) + ) + + result_stub = MagicMock() + result_stub.read_tile_size_px = 64 + result_stub.target_tile_size_px = 64 + + metrics = mod.benchmark_tile_store( + result=result_stub, + jpeg_quality=90, + num_workers=4, + output_dir=tmp_path, + ) + + assert metrics["tile_count"] == 0 + assert metrics["jpeg_bytes"] == 0 + assert metrics["read_s"] == 0.0 + assert metrics["encode_s"] == 0.0 + assert metrics["write_s"] == 0.0 + assert metrics["total_s"] >= 0.0 + + +def test_progress_callback_called_per_tile(tmp_path, monkeypatch): + mod = _load_benchmark_module() + tiles = [np.zeros((64, 64, 3), dtype=np.uint8) for _ in range(5)] + + monkeypatch.setattr( + mod, "_iter_tile_arrays_for_tar_extraction", lambda **kwargs: iter(tiles) + ) + + result_stub = MagicMock() + result_stub.read_tile_size_px = 64 + result_stub.target_tile_size_px = 64 + + callback = MagicMock() + mod.benchmark_tile_store( + result=result_stub, + jpeg_quality=90, + num_workers=4, + output_dir=tmp_path, + progress_callback=callback, + ) + + assert callback.call_count == 5 + for call in callback.call_args_list: + assert call == ((1, 1),) diff --git a/tests/test_tile_extraction.py b/tests/test_tile_extraction.py index 8aee70a..6f712ee 100644 --- a/tests/test_tile_extraction.py +++ b/tests/test_tile_extraction.py @@ -362,7 +362,7 @@ def _import_module(name): assert tar_path.is_file() assert out_result is result - mock_wsd.assert_called_once_with(result.image_path, backend="cucim") + mock_wsd.assert_called_once_with(str(result.image_path), backend="cucim") def test_cucim_iterator_groups_dense_4x4_grid_into_one_batched_read( self, monkeypatch