diff --git a/README.md b/README.md index d6d63ba..5d20e57 100644 --- a/README.md +++ b/README.md @@ -27,6 +27,15 @@ You can also upload your own pyramidal WSI (up to 1 GB). pip install hs2p ``` +Optional CuCIM install for faster tile tar export when using `tiling.backend="cucim"`: + +```bash +pip install cucim-cu12 +``` + +Use the CuCIM wheel that matches your CUDA runtime. The base `hs2p` install does not +require CuCIM. + ## Workflows ### Tiling @@ -122,6 +131,11 @@ For a first run, start from [hs2p/configs/default.yaml](hs2p/configs/default.yam - `tiling.params.target_spacing_um` - `tiling.params.target_tile_size_px` +Optional: + +- `save_tiles` + - also write `tiles/{sample_id}.tiles.tar` archives; with `tiling.backend="cucim"` this uses batched CuCIM reads during tile extraction + Run tiling: ```bash diff --git a/docs/cli.md b/docs/cli.md index 8961dc8..6144e73 100644 --- a/docs/cli.md +++ b/docs/cli.md @@ -52,6 +52,16 @@ Run sampling: python -m hs2p.sampling --config-file /path/to/config.yaml ``` +Optional CuCIM install for faster tar export with `save_tiles: true` and +`tiling.backend: cucim`: + +```bash +pip install cucim-cu12 +``` + +Use the CuCIM wheel that matches your CUDA runtime. Non-CuCIM backends continue to +use the default sequential tile export path. + ## Progress UX When stdout is an interactive terminal, `hs2p` uses `rich` to show live progress for both CLI entrypoints. 
@@ -85,8 +95,10 @@ Detailed logs still go to `output_dir/logs/log.txt`, which is the best place to - Annotation-specific sampling rules for `hs2p.sampling` - `save_previews` - Global switch for writing mask and tiling previews to disk +- `save_tiles` + - Global switch for writing `tiles/{sample_id}.tiles.tar` alongside coordinate artifacts - `speed.num_workers` - - Parallelism for slide processing + - Parallelism for slide processing, and the per-slide worker budget reused by CuCIM batched tile extraction when `tiling.backend: cucim` is set ## Sampling-specific settings @@ -120,6 +132,14 @@ These filters are **disabled by default** and should stay off unless your datase When enabled, every candidate tile that passes the tissue mask check is read from the slide at full resolution and its pixel values inspected. This is the **only step in the tiling pipeline that reads actual tile pixel data**. For slides with large internal JPEG tiles (common in some scanner formats), each read triggers a full JPEG decode of the underlying tile block — which can be an order of magnitude slower than the rest of the pipeline per slide. +### Saved tile export (`save_tiles`) + +When `save_tiles: true`, `hs2p` also writes a `tiles/{sample_id}.tiles.tar` archive with JPEG-encoded tile images. + +- For non-CuCIM backends, tile extraction uses the existing sequential reader. +- For `tiling.backend: cucim`, tile extraction uses a CuCIM batch-read fast path and reuses the per-slide worker count from `speed.num_workers`. +- Installing CuCIM is optional. If `tiling.backend: cucim` is selected but CuCIM is not installed, `hs2p` falls back to the sequential export path and emits a warning. 
+ ## Resume and precomputed artifacts - `resume: true` expects the current `process_list.csv` schema and current-format artifacts diff --git a/hs2p/api.py b/hs2p/api.py index 31506b6..d7e2c71 100644 --- a/hs2p/api.py +++ b/hs2p/api.py @@ -1,4 +1,5 @@ import hashlib +import importlib import io import json import multiprocessing as mp @@ -487,6 +488,7 @@ def extract_tiles_to_tar( jpeg_quality: int = 90, tiles_dir: Path | None = None, filter_params: FilterConfig | None = None, + num_workers: int = 4, ) -> tuple[Path, TilingResult]: """Extract tile images from a WSI and save them as a JPEG tar archive. @@ -494,15 +496,12 @@ def extract_tiles_to_tar( during extraction so that pixel data is read only once. The returned ``TilingResult`` has its coordinate arrays trimmed to the surviving tiles. """ - import wholeslidedata as wsd from PIL import Image tiles_dir = Path(tiles_dir) if tiles_dir is not None else Path(output_dir) / "tiles" tiles_dir.mkdir(parents=True, exist_ok=True) tar_path = tiles_dir / f"{result.sample_id}.tiles.tar" - wsi = wsd.WholeSlideImage(result.image_path, backend=result.backend) - do_filter_white = filter_params is not None and filter_params.filter_white do_filter_black = filter_params is not None and filter_params.filter_black white_thresh = getattr(filter_params, "white_threshold", 220) if filter_params else 220 @@ -520,15 +519,12 @@ def extract_tiles_to_tar( temp_tar_path = Path(tmp.name) with tarfile.open(temp_tar_path, "w") as tf: - for i in range(result.num_tiles): - tile_arr = wsi.get_patch( - int(result.x[i]), - int(result.y[i]), - int(result.read_tile_size_px), - int(result.read_tile_size_px), - spacing=float(result.read_spacing_um), - center=False, + for i, tile_arr in enumerate( + _iter_tile_arrays_for_tar_extraction( + result=result, + num_workers=num_workers, ) + ): if tile_arr.shape[2] > 3: tile_arr = tile_arr[:, :, :3] @@ -601,6 +597,76 @@ def extract_tiles_to_tar( return tar_path, filtered_result +def 
_iter_tile_arrays_for_tar_extraction( + *, + result: TilingResult, + num_workers: int, +): + tile_arrays = _iter_cucim_tile_arrays_for_tar_extraction( + result=result, + num_workers=num_workers, + ) + if tile_arrays is not None: + yield from tile_arrays + return + yield from _iter_wsd_tile_arrays_for_tar_extraction(result=result) + + +def _iter_cucim_tile_arrays_for_tar_extraction( + *, + result: TilingResult, + num_workers: int, +): + if result.backend != "cucim": + return None + try: + cucim = importlib.import_module("cucim") + except ModuleNotFoundError: + warnings.warn( + "CuCIM is unavailable for backend='cucim'; falling back to sequential wholeslidedata tile extraction.", + UserWarning, + stacklevel=2, + ) + return None + + cu_image = cucim.CuImage(str(result.image_path)) + locations = [ + (int(x), int(y)) + for x, y in zip( + result.x.astype(np.int64, copy=False).tolist(), + result.y.astype(np.int64, copy=False).tolist(), + ) + ] + read_size = (int(result.read_tile_size_px), int(result.read_tile_size_px)) + return ( + np.asarray(region) + for region in cu_image.read_region( + locations, + read_size, + level=int(result.read_level), + num_workers=max(1, int(num_workers)), + ) + ) + + +def _iter_wsd_tile_arrays_for_tar_extraction( + *, + result: TilingResult, +): + import wholeslidedata as wsd + + wsi = wsd.WholeSlideImage(result.image_path, backend=result.backend) + for i in range(result.num_tiles): + yield wsi.get_patch( + int(result.x[i]), + int(result.y[i]), + int(result.read_tile_size_px), + int(result.read_tile_size_px), + spacing=float(result.read_spacing_um), + center=False, + ) + + def _needs_pixel_filtering(filtering: FilterConfig) -> bool: return bool(filtering.filter_white or filtering.filter_black) @@ -1000,6 +1066,7 @@ def _compute_tiling_result_from_request( result, output_dir=request.output_dir, filter_params=request.filtering if _needs_pixel_filtering(request.filtering) else None, + num_workers=request.num_workers, ) artifact = 
save_tiling_result(result, output_dir=request.output_dir) artifact = TilingArtifacts( diff --git a/hs2p/configs/default.yaml b/hs2p/configs/default.yaml index fd9c620..c64a4d3 100644 --- a/hs2p/configs/default.yaml +++ b/hs2p/configs/default.yaml @@ -5,6 +5,7 @@ resume: false # resume from a previous run resume_dirname: # directory name to resume from save_previews: true # save preview images of slide tiling and mask overlays +save_tiles: false # save extracted tiles as {sample_id}.tiles.tar in addition to coordinate artifacts seed: 0 # seed for reproducibility diff --git a/hs2p/tiling.py b/hs2p/tiling.py index 4774097..cbf9333 100644 --- a/hs2p/tiling.py +++ b/hs2p/tiling.py @@ -78,6 +78,7 @@ def main(args): num_workers=cfg.speed.num_workers, resume=cfg.resume, read_coordinates_from=read_coordinates_from, + save_tiles=bool(getattr(cfg, "save_tiles", False)), ) pd.read_csv(output_dir / "process_list.csv") progress.emit_progress( diff --git a/tests/test_cli_smoke.py b/tests/test_cli_smoke.py index f947528..b5d00ce 100644 --- a/tests/test_cli_smoke.py +++ b/tests/test_cli_smoke.py @@ -38,6 +38,7 @@ def _base_cfg(tmp_path: Path, csv_path: Path) -> SimpleNamespace: output_dir=str(tmp_path / "output"), resume=False, save_previews=False, + save_tiles=False, speed=SimpleNamespace(num_workers=1), tiling=SimpleNamespace( read_coordinates_from=None, @@ -102,9 +103,11 @@ def _fake_tile_slides( num_workers, resume, read_coordinates_from, + save_tiles, ): del tiling, segmentation, filtering, preview, num_workers, resume, read_coordinates_from captured["whole_slides"] = whole_slides + captured["save_tiles"] = save_tiles output_dir = Path(output_dir) output_dir.mkdir(parents=True, exist_ok=True) process_df = pd.DataFrame( @@ -148,6 +151,7 @@ def _fake_tile_slides( mask_path=Path("slide-1-mask.png"), ) ] + assert captured["save_tiles"] is False process_df = pd.read_csv(Path(cfg.output_dir) / "process_list.csv") assert list(process_df.columns) == [ "sample_id", diff --git 
a/tests/test_tile_extraction.py b/tests/test_tile_extraction.py index 0b2182c..0ff02ff 100644 --- a/tests/test_tile_extraction.py +++ b/tests/test_tile_extraction.py @@ -1,11 +1,13 @@ """Tests for extract_tiles_to_tar() and the save_tiles pipeline option.""" import io +import types import tarfile from pathlib import Path from unittest.mock import MagicMock, patch import numpy as np +import pytest from PIL import Image from hs2p.api import TilingResult, extract_tiles_to_tar @@ -228,6 +230,75 @@ def test_no_filter_params_keeps_all_tiles(self, tmp_path: Path): assert out_result is result # unchanged + def test_cucim_backend_uses_batched_read_region(self, monkeypatch, tmp_path: Path): + result = _make_tiling_result(num_tiles=2) + result.backend = "cucim" + result.read_level = 3 + result.read_tile_size_px = 128 + + regions = [_solid_patch((10, 20, 30), size=128), _solid_patch((40, 50, 60), size=128)] + mock_cu_image = MagicMock() + mock_cu_image.read_region.return_value = iter(regions) + fake_cucim = types.SimpleNamespace(CuImage=MagicMock(return_value=mock_cu_image)) + + import hs2p.api as api_mod + + monkeypatch.setattr( + api_mod.importlib, + "import_module", + lambda name: fake_cucim if name == "cucim" else None, + ) + + with patch("wholeslidedata.WholeSlideImage") as mock_wsd: + tar_path, out_result = extract_tiles_to_tar( + result, + output_dir=tmp_path, + num_workers=5, + ) + + assert tar_path.is_file() + assert out_result is result + fake_cucim.CuImage.assert_called_once_with(str(result.image_path)) + mock_cu_image.read_region.assert_called_once_with( + [(0, 0), (256, 0)], + (128, 128), + level=3, + num_workers=5, + ) + mock_wsd.assert_not_called() + + def test_cucim_backend_falls_back_to_wsd_when_cucim_is_unavailable( + self, monkeypatch, tmp_path: Path + ): + result = _make_tiling_result(num_tiles=1) + result.backend = "cucim" + + mock_wsi = MagicMock() + mock_wsi.get_patch.return_value = _solid_patch((70, 80, 90)) + + import hs2p.api as api_mod + + def 
_import_module(name): + if name == "cucim": + raise ModuleNotFoundError("No module named 'cucim'") + raise AssertionError(f"unexpected module import: {name}") + + monkeypatch.setattr(api_mod.importlib, "import_module", _import_module) + + with pytest.warns(UserWarning, match="CuCIM is unavailable"), patch( + "wholeslidedata.WholeSlideImage", + return_value=mock_wsi, + ) as mock_wsd: + tar_path, out_result = extract_tiles_to_tar( + result, + output_dir=tmp_path, + num_workers=4, + ) + + assert tar_path.is_file() + assert out_result is result + mock_wsd.assert_called_once_with(result.image_path, backend="cucim") + class TestNeedsPixelFiltering: def test_no_filtering(self): diff --git a/tests/test_tiling_api.py b/tests/test_tiling_api.py index 46f73dc..c3f6f02 100644 --- a/tests/test_tiling_api.py +++ b/tests/test_tiling_api.py @@ -37,6 +37,7 @@ from hs2p.utils import load_csv from hs2p.wsi import CoordinateExtractionResult import hs2p.wsi.wsi as wsi_mod +import hs2p.api as api_mod @pytest.fixture @@ -711,6 +712,84 @@ def imap_unordered(self, fn, args_list): assert seen["inner_workers"] == [4, 4] +def test_compute_request_passes_inner_workers_to_tile_extraction( + monkeypatch, tmp_path: Path +): + seen = {} + + def _fake_compute_tiling_result(*args, **kwargs): + del args, kwargs + return TilingResult( + sample_id="slide-1", + image_path=Path("slide-1.svs"), + mask_path=None, + backend="cucim", + x=np.array([10], dtype=np.int64), + y=np.array([20], dtype=np.int64), + tile_index=np.array([0], dtype=np.int32), + target_spacing_um=0.5, + target_tile_size_px=224, + read_level=0, + read_spacing_um=0.5, + read_tile_size_px=224, + tile_size_lv0=224, + overlap=0.0, + tissue_threshold=0.1, + num_tiles=1, + config_hash="hash", + ) + + def _fake_extract_tiles_to_tar(result, output_dir, *, filter_params=None, num_workers=1, **kwargs): + del output_dir, filter_params, kwargs + seen["num_workers"] = num_workers + return tmp_path / "tiles" / "slide-1.tiles.tar", result + + def 
_fake_save_tiling_result(result, output_dir, *, tiles_dir=None): + del tiles_dir + tiles_dir = Path(output_dir) / "tiles" + tiles_dir.mkdir(parents=True, exist_ok=True) + npz_path = tiles_dir / f"{result.sample_id}.coordinates.npz" + meta_path = tiles_dir / f"{result.sample_id}.coordinates.meta.json" + npz_path.write_bytes(b"npz") + meta_path.write_text("{}") + return TilingArtifacts( + sample_id=result.sample_id, + coordinates_npz_path=npz_path, + coordinates_meta_path=meta_path, + num_tiles=result.num_tiles, + ) + + request = api_mod._SlideComputeRequest( + input_index=0, + whole_slide=SlideSpec(sample_id="slide-1", image_path=Path("slide-1.svs")), + tiling=TilingConfig( + backend="cucim", + target_spacing_um=0.5, + target_tile_size_px=224, + tolerance=0.07, + overlap=0.0, + tissue_threshold=0.1, + drop_holes=False, + use_padding=True, + ), + segmentation=SegmentationConfig(64, 8, 255, 7, 4, False, True), + filtering=FilterConfig(224, 4, 2, 8, False, False, 220, 25, 0.9), + config_hash="hash", + mask_preview_path=None, + output_dir=tmp_path, + num_workers=6, + save_tiles=True, + ) + + monkeypatch.setattr(api_mod, "_compute_tiling_result", _fake_compute_tiling_result) + monkeypatch.setattr(api_mod, "extract_tiles_to_tar", _fake_extract_tiles_to_tar) + monkeypatch.setattr(api_mod, "save_tiling_result", _fake_save_tiling_result) + response = api_mod._compute_tiling_result_from_request(request) + + assert response.ok + assert seen["num_workers"] == 6 + + def test_save_tiling_result_rejects_invalid_tile_index(tmp_path: Path): invalid = TilingResult( sample_id="broken-slide",