diff --git a/.dockerignore b/.dockerignore index 05aa985..a934c7b 100644 --- a/.dockerignore +++ b/.dockerignore @@ -1,3 +1,3 @@ data/ output/ -docker/ \ No newline at end of file +outputs/ diff --git a/.gitignore b/.gitignore index ee76d01..639bf73 100644 --- a/.gitignore +++ b/.gitignore @@ -168,3 +168,4 @@ archive/ tasks/ docs/documentation.md docs/20*-*.md +data/ \ No newline at end of file diff --git a/Dockerfile b/Dockerfile index 6d920d5..a846bd7 100644 --- a/Dockerfile +++ b/Dockerfile @@ -1,11 +1,10 @@ ARG UBUNTU_VERSION=22.04 -ARG CUDA_MAJOR_VERSION=11.8.0 -ARG CUDNN_MAJOR_VERSION=8 +ARG CUDA_MAJOR_VERSION=12.8.1 ######################## # Stage 1: build stage # ######################## -FROM nvidia/cuda:${CUDA_MAJOR_VERSION}-cudnn${CUDNN_MAJOR_VERSION}-devel-ubuntu${UBUNTU_VERSION} AS build +FROM nvidia/cuda:${CUDA_MAJOR_VERSION}-cudnn-devel-ubuntu${UBUNTU_VERSION} AS build ARG USER_UID=1001 ARG USER_GID=1001 @@ -29,6 +28,7 @@ ENV PATH="/home/user/.local/bin:${PATH}" RUN apt-get update && apt-get install -y --no-install-recommends \ libtiff-dev \ + cmake \ zlib1g-dev \ curl \ vim screen \ @@ -40,6 +40,16 @@ RUN apt-get update && apt-get install -y --no-install-recommends \ && apt-get clean \ && rm -rf /var/lib/apt/lists/* +# libjpeg-turbo 3.x (required by PyTurboJPEG>=2) +ARG LIBJPEG_TURBO_VERSION=3.1.0 +RUN curl -fsSL https://github.com/libjpeg-turbo/libjpeg-turbo/releases/download/${LIBJPEG_TURBO_VERSION}/libjpeg-turbo-${LIBJPEG_TURBO_VERSION}.tar.gz \ + | tar xz -C /tmp \ + && cd /tmp/libjpeg-turbo-${LIBJPEG_TURBO_VERSION} \ + && cmake -G"Unix Makefiles" -DCMAKE_INSTALL_PREFIX=/usr/local . \ + && make -j"$(nproc)" && make install \ + && ldconfig \ + && rm -rf /tmp/libjpeg-turbo-${LIBJPEG_TURBO_VERSION} + WORKDIR /opt/app/ # core deps live in requirements.in; model runtime extras live in requirements-models.in @@ -70,7 +80,7 @@ RUN python -m pip install 'flash-attn>=2.7.1,<=2.8.0' --no-build-isolation ########################## # Stage 2: runtime stage # ########################## -FROM nvidia/cuda:${CUDA_MAJOR_VERSION}-cudnn${CUDNN_MAJOR_VERSION}-runtime-ubuntu${UBUNTU_VERSION} +FROM nvidia/cuda:${CUDA_MAJOR_VERSION}-cudnn-runtime-ubuntu${UBUNTU_VERSION} ARG USER_UID=1001 ARG USER_GID=1001 @@ -104,6 +114,11 @@ RUN apt-get update && apt-get install -y --no-install-recommends \ && apt-get clean \ && rm -rf /var/lib/apt/lists/* +# libjpeg-turbo 3.x (copied from build stage) +COPY --from=build /usr/local/lib/libjpeg* /usr/local/lib/ +COPY --from=build /usr/local/lib/libturbojpeg* /usr/local/lib/ +RUN ldconfig + # install ASAP ARG ASAP_URL=https://github.com/computationalpathologygroup/ASAP/releases/download/ASAP-2.2-(Nightly)/ASAP-2.2-Ubuntu2204.deb RUN apt-get update && curl -L ${ASAP_URL} -o /tmp/ASAP.deb && apt-get install --assume-yes /tmp/ASAP.deb && \ @@ -116,6 +131,10 @@ RUN apt-get update && curl -L ${ASAP_URL} -o /tmp/ASAP.deb && apt-get install -- COPY --from=build /usr/local/lib/python3.10/dist-packages /usr/local/lib/python3.10/dist-packages COPY --from=build /usr/local/bin /usr/local/bin +# register libnvimgcodec so cucim can use GPU-accelerated JPEG decoding +RUN echo "/usr/local/lib/python3.10/dist-packages/nvidia/nvimgcodec" > /etc/ld.so.conf.d/nvimgcodec.conf && \ + ldconfig + # copy app code COPY --from=build /opt/app /opt/app diff --git a/Dockerfile.ci b/Dockerfile.ci index 96a27db..3a4515c 100755 --- a/Dockerfile.ci +++ b/Dockerfile.ci @@ -21,6 +21,7 @@ RUN apt-get update && apt-get install -y --no-install-recommends \ libtiff-dev \ zlib1g-dev \ curl \ + cmake \ vim screen \ zip unzip \ git \ @@ -31,6 +32,16 @@ RUN apt-get update && apt-get install -y --no-install-recommends \ && apt-get clean \ && rm -rf /var/lib/apt/lists/* +# libjpeg-turbo 3.x (required by PyTurboJPEG>=2) +ARG LIBJPEG_TURBO_VERSION=3.1.0 +RUN curl -fsSL https://github.com/libjpeg-turbo/libjpeg-turbo/releases/download/${LIBJPEG_TURBO_VERSION}/libjpeg-turbo-${LIBJPEG_TURBO_VERSION}.tar.gz \ + | tar xz -C /tmp \ + && cd /tmp/libjpeg-turbo-${LIBJPEG_TURBO_VERSION} \ + && cmake -G"Unix Makefiles" -DCMAKE_INSTALL_PREFIX=/usr/local . \ + && make -j"$(nproc)" && make install \ + && ldconfig \ + && rm -rf /tmp/libjpeg-turbo-${LIBJPEG_TURBO_VERSION} + # ASAP ARG ASAP_URL=https://github.com/computationalpathologygroup/ASAP/releases/download/ASAP-2.2-(Nightly)/ASAP-2.2-Ubuntu2204.deb RUN set -eux; \ @@ -65,5 +76,9 @@ COPY --chown=user:user LICENSE /opt/app/LICENSE RUN python -m pip install /opt/app +# register libnvimgcodec so cucim can use GPU-accelerated JPEG decoding +RUN echo "/usr/local/lib/python3.10/dist-packages/nvidia/nvimgcodec" > /etc/ld.so.conf.d/nvimgcodec.conf && \ + ldconfig + USER user WORKDIR /opt/app diff --git a/Dockerfile.coding-agents b/Dockerfile.coding-agents new file mode 100644 index 0000000..0e456c8 --- /dev/null +++ b/Dockerfile.coding-agents @@ -0,0 +1,156 @@ +ARG UBUNTU_VERSION=22.04 +ARG CUDA_MAJOR_VERSION=12.8.1 + +######################## +# Stage 1: build stage # +######################## +FROM nvidia/cuda:${CUDA_MAJOR_VERSION}-cudnn-devel-ubuntu${UBUNTU_VERSION} AS build + +ARG USER_UID=1001 +ARG USER_GID=1001 + +# ensures that Python output to stdout/stderr is not buffered: prevents missing information when terminating +ENV PYTHONUNBUFFERED=1 +ENV DEBIAN_FRONTEND=noninteractive TZ=Europe/Amsterdam + +USER root + +RUN groupadd --gid ${USER_GID} user \ + && useradd -m --no-log-init --uid ${USER_UID} --gid ${USER_GID} user + +# create input/output directory +RUN mkdir /input /output && \ + chown user:user /input /output + +# set /home/user as working directory +WORKDIR /home/user +ENV PATH="/home/user/.local/bin:${PATH}" + +RUN apt-get update && apt-get install -y --no-install-recommends \ + libtiff-dev \ + cmake \ + zlib1g-dev \ + curl \ + vim screen \ + zip unzip \ + git \ + openssh-server \ + python3-pip python3-dev python-is-python3 \ + && mkdir /var/run/sshd \ + && apt-get clean \ + && rm -rf /var/lib/apt/lists/* + +# libjpeg-turbo 3.x (required by PyTurboJPEG>=2) +ARG LIBJPEG_TURBO_VERSION=3.1.0 +RUN curl -fsSL https://github.com/libjpeg-turbo/libjpeg-turbo/releases/download/${LIBJPEG_TURBO_VERSION}/libjpeg-turbo-${LIBJPEG_TURBO_VERSION}.tar.gz \ + | tar xz -C /tmp \ + && cd /tmp/libjpeg-turbo-${LIBJPEG_TURBO_VERSION} \ + && cmake -G"Unix Makefiles" -DCMAKE_INSTALL_PREFIX=/usr/local . \ + && make -j"$(nproc)" && make install \ + && ldconfig \ + && rm -rf /tmp/libjpeg-turbo-${LIBJPEG_TURBO_VERSION} + +WORKDIR /opt/app/ + +# core deps live in requirements.in; model runtime extras live in requirements-models.in +RUN python -m pip install --upgrade pip setuptools pip-tools \ + && rm -rf /home/user/.cache/pip + +# install slide2vec +COPY --chown=user:user requirements.in /opt/app/requirements.in +COPY --chown=user:user requirements-models.in /opt/app/requirements-models.in +RUN python -m pip install \ + --no-cache-dir \ + --no-color \ + --requirement /opt/app/requirements-models.in \ + && rm -rf /home/user/.cache/pip + +COPY --chown=user:user slide2vec /opt/app/slide2vec +COPY --chown=user:user setup.py /opt/app/setup.py +COPY --chown=user:user setup.cfg /opt/app/setup.cfg +COPY --chown=user:user pyproject.toml /opt/app/pyproject.toml +COPY --chown=user:user MANIFEST.in /opt/app/MANIFEST.in +COPY --chown=user:user README.md /opt/app/README.md +COPY --chown=user:user LICENSE /opt/app/LICENSE + +RUN python -m pip install /opt/app +RUN python -m pip install 'flash-attn>=2.7.1,<=2.8.0' --no-build-isolation + + +########################## +# Stage 2: runtime stage # +########################## +FROM nvidia/cuda:${CUDA_MAJOR_VERSION}-cudnn-runtime-ubuntu${UBUNTU_VERSION} + +ARG USER_UID=1001 +ARG USER_GID=1001 + +ENV PYTHONUNBUFFERED=1 +ENV DEBIAN_FRONTEND=noninteractive TZ=Europe/Amsterdam + +USER root + +RUN groupadd --gid ${USER_GID} user \ + && useradd -m --no-log-init --uid ${USER_UID} --gid ${USER_GID} user + +# create input/output directory +RUN mkdir /input /output && \ + chown user:user /input /output + +# set /home/user as working directory +WORKDIR /home/user +ENV PATH="/home/user/.local/bin:${PATH}" + +RUN apt-get update && apt-get install -y --no-install-recommends \ + libtiff-dev \ + zlib1g-dev \ + curl \ + vim screen \ + zip unzip \ + git \ + openssh-server \ + python3-pip python3-dev python-is-python3 \ + && mkdir /var/run/sshd \ + && apt-get clean \ + && rm -rf /var/lib/apt/lists/* + +# libjpeg-turbo 3.x (copied from build stage) +COPY --from=build /usr/local/lib/libjpeg* /usr/local/lib/ +COPY --from=build /usr/local/lib/libturbojpeg* /usr/local/lib/ +RUN ldconfig + +RUN curl -fsSL https://deb.nodesource.com/setup_lts.x | bash - \ + && apt-get install -y --no-install-recommends nodejs + +# install ASAP +ARG ASAP_URL=https://github.com/computationalpathologygroup/ASAP/releases/download/ASAP-2.2-(Nightly)/ASAP-2.2-Ubuntu2204.deb +RUN apt-get update && curl -L ${ASAP_URL} -o /tmp/ASAP.deb && apt-get install --assume-yes /tmp/ASAP.deb && \ + SITE_PACKAGES=`python3 -c "import sysconfig; print(sysconfig.get_paths()['purelib'])"` && \ + printf "/opt/ASAP/bin/\n" > "${SITE_PACKAGES}/asap.pth" && \ + apt-get clean && \ + rm -rf /var/lib/apt/lists/* + +# install codex +RUN npm i -g @openai/codex + +# install claude +RUN curl -fsSL https://claude.ai/install.sh | bash + +# copy Python libs & entrypoints from build stage (includes flash-attn, your deps, ASAP .pth) +COPY --from=build /usr/local/lib/python3.10/dist-packages /usr/local/lib/python3.10/dist-packages +COPY --from=build /usr/local/bin /usr/local/bin + +# register libnvimgcodec so cucim can use GPU-accelerated JPEG decoding +RUN echo "/usr/local/lib/python3.10/dist-packages/nvidia/nvimgcodec" > /etc/ld.so.conf.d/nvimgcodec.conf && \ + ldconfig + +# copy app code +COPY --from=build /opt/app /opt/app + +# expose port for ssh and jupyter +EXPOSE 22 8888 + +WORKDIR /opt/app/ + +# switch to user +USER user diff --git a/docs/benchmarking.md b/docs/benchmarking.md new file mode 100644 index 0000000..b7c6c9b --- /dev/null +++ b/docs/benchmarking.md @@ -0,0 +1,120 @@ +# Benchmarking + +`slide2vec` includes a benchmark runner for end-to-end embedding throughput sweeps across different GPU environments and multiple model configs. + +The script samples a balanced subset of your manifest, runs untimed warmups plus repeated measured trials, tunes only: + +- `model.batch_size` +- `speed.num_workers_embedding` + +It keeps the rest of each model config fixed, disables previews / resume / Weights & Biases, and writes: + +- `trial_results.csv` +- `best_results.csv` +- `throughput_by_gpu.png` +- `throughput_by_gpu_and_size.png` +- `tuning__.png` + +Default sweep values: + +- `--n-slides 0` to benchmark the full manifest by default +- `--batch-sizes 1 32 64 128 256` +- `--embedding-workers 4 8 16 32 64 128` + +## Basic Usage + +```shell +python scripts/benchmark_embedding_throughput.py \ + --config-files /path/to/pathojepa-small.yaml /path/to/pathojepa-base.yaml /path/to/pathojepa-large.yaml \ + --model-labels PathoJEPA-S PathoJEPA-B PathoJEPA-L \ + --size-labels S B L \ + --csv /path/to/slides.csv \ + --gpu-label "A100-80GB" \ + --batch-sizes 1 32 64 128 256 \ + --embedding-workers 4 8 16 32 64 128 \ + --repeat 3 \ + --n-slides 0 \ + --output-dir /tmp/slide2vec-benchmark +``` + +Notes: + +- the benchmark measures the full `Pipeline.run(...)` path, including tiling +- stage timings for tiling, embedding, and aggregation are also recorded when progress events are available +- embedding trials also record per-batch timing summaries from `embedding.batch.timing` events, including mean loader wait, mean ready-wait after async copy/preprocess, mean preprocess time, mean forward time, and a loader-wait fraction +- every compared model reuses the same sampled manifest within a run +- each config gets an untimed warmup before measured repeats +- benchmark config files are loaded through the same default-merge and validation path as the regular CLI, so omitted standard keys inherit the usual defaults + +Single-model usage is still supported: + +```shell +python scripts/benchmark_embedding_throughput.py \ + --config-file /path/to/model-config.yaml \ + --csv /path/to/slides.csv \ + --gpu-label "A100-80GB" +``` + +In multi-model mode: + +- `--config-files` is the primary interface +- `--model-labels` must match the config count +- `--size-labels` must match the config count +- size labels are explicit metadata like `S`, `B`, `L`, `XL`; the script does not infer them + +## Merging GPU Runs + +Run the benchmark once per GPU environment, then regenerate the cross-GPU comparison chart from multiple `trial_results.csv` files: + +```shell +python scripts/benchmark_embedding_throughput.py \ + --chart-only \ + /tmp/a100-benchmark/trial_results.csv \ + /tmp/h100-benchmark/trial_results.csv \ + --output-dir /tmp/slide2vec-benchmark-merged +``` + +The merged outputs include: + +- `throughput_by_gpu.png` for best tuned model entries per GPU +- `throughput_by_gpu_and_size.png` for grouped GPU-vs-size bars, choosing the winning model for each `(gpu, size)` bucket + +Use `--copy-locally` when your slide source lives on network storage and you want to reduce I/O variance during the sweep. + +## End-to-End Path Comparison + +For a direct full-pipeline comparison between: + +- tar-based embedding (`on_the_fly=false`) +- on-the-fly `wsd_single` embedding (`backend=asap`, `use_supertiles=false`) +- on-the-fly `cucim_supertiles` embedding + +use: + +```shell +python scripts/benchmark_end_to_end_paths.py \ + --csv /path/to/slides.csv \ + --config-file /path/to/model-config.yaml \ + --batch-size 256 \ + --repeat 1 \ + --output-dir /tmp/slide2vec-end-to-end +``` + +The model is taken from `--config-file`; the script does not accept a separate `--model` override. + +This benchmark runs the three paths independently from raw slide input to final embedding artifact and writes: + +- `trial_results.csv` +- `summary.csv` +- `end_to_end_by_path.png` +- `stage_breakdown.png` +- `embedding_subpath_breakdown.png` + +The summary also now includes an embedding subpath split derived from per-batch timing +events: + +- `mean_data_pipeline_seconds`: timed embedding seconds spent in loader wait, ready + wait, and preprocessing +- `mean_forward_seconds`: timed embedding seconds spent in model forward +- `mean_data_pipeline_fraction` / `mean_forward_fraction`: shares of the timed + embedding batches accounted for by those two buckets diff --git a/docs/cli.md b/docs/cli.md index 03fcdd8..caf7bc6 100644 --- a/docs/cli.md +++ b/docs/cli.md @@ -59,7 +59,7 @@ In practice, the config controls: - which model preset to use - preprocessing/tiling parameters - output directory -- batch size, workers, mixed precision, and GPU count +- batch size, workers, precision, and GPU count - whether to save tile artifacts alongside slide-level outputs ## Common Overrides @@ -86,7 +86,7 @@ Common overrides: ## Useful Flags - `--run-on-cpu` - Forces CPU inference and disables mixed precision. + Forces CPU inference and uses `speed.precision=fp32`. - `--tiling-only` Runs preprocessing/tiling without feature extraction. - `--output-dir /path/to/output` diff --git a/docs/gpu-throughput-optimization-protocol.md b/docs/gpu-throughput-optimization-protocol.md new file mode 100644 index 0000000..2b7d67d --- /dev/null +++ b/docs/gpu-throughput-optimization-protocol.md @@ -0,0 +1,161 @@ +# GPU Throughput Optimization Protocol + +You are optimizing slide2vec embedding throughput on this machine. Use the existing benchmark and timing metrics as the ground truth. Prioritize changes that maximize throughput, reduce loader_wait_fraction and mean_ready_wait_ms while preserving outputs. For every change, rerun the same benchmark slice, compare throughput and timing metrics to baseline, and keep only changes that improve throughput or clearly reduce GPU idle time. + +## Goal + +Iterate on `slide2vec` code to maximize embedding throughput while preserving correctness. +Primary optimization targets: + +- maximize throughput +- minimize `loader_wait_fraction` +- minimize `mean_loader_wait_ms` +- minimize `mean_ready_wait_ms` +- keep outputs unchanged + +## Recommended Config Shape + +Keep the preprocessing config unchanged, just vary model config to try different model sizes (from ViT-S to ViT-G) and embedding-related parameters (batch_size, num_workers_embeddimg, prefetch_factor_embedding, persistent_workers_embedding) + + +## Baseline Benchmark + +Start with one model config: + +Run: + +```bash +python slide2vec/scripts/benchmark_embedding_throughput.py \ + --config-file slide2vec/configs/h0-mini.yaml \ + --csv debug-histai-local.csv \ + --batch-sizes 32 64 128 256 512 \ + --embedding-workers 4 8 16 32 \ + --repeat 2 \ + --n-slides 0 \ + --output-dir output/benchmark +``` + +This benchmark writes per-trial metrics including embedding timing summaries derived from `embedding.batch.timing` events. + +## Follow-Up Targeted Sweep + +After the baseline: + +- identify the best 2-3 batch sizes +- identify the best 2-3 worker counts +- rerun a tighter sweep around them +- test `prefetch_factor_embedding` values `2`, `4`, `8` + +Example: + +```bash +python slide2vec/scripts/benchmark_embedding_throughput.py \ + --config-file slide2vec/configs/h0-mini.yaml \ + --csv debug-histai-local.csv \ + --batch-sizes 128 256 384 \ + --embedding-workers 8 16 \ + --repeat 3 \ + --n-slides 0 \ + --output-dir output/benchmark-tuned +``` + +## Metrics To Optimize + +Read these from `trial_results.csv`, `best_results.csv`, and `metrics.json`: + +- throughput +- `loader_wait_fraction` +- `mean_loader_wait_ms` +- `max_loader_wait_ms` +- `mean_ready_wait_ms` +- `mean_preprocess_ms` +- `mean_forward_ms` +- `timed_batches` + +Interpretation: + +- high `loader_wait_fraction`: the reader side is the bottleneck +- high `mean_ready_wait_ms`: transfer or preprocessing is not overlapping enough with forward +- high `mean_preprocess_ms` with low `mean_forward_ms`: preprocessing is the bottleneck +- throughput flattening while `mean_forward_ms` dominates: the run is compute-bound rather than loader-bound + +## GPU Telemetry + +Capture lightweight telemetry during benchmark runs: + +```bash +nvidia-smi dmon -s pucvmet -d 1 +``` + +or: + +```bash +watch -n 1 nvidia-smi +``` + +Record: + +- GPU utilization +- memory usage +- power +- SM activity trends during the run + +## Artifacts To Hand To The Optimizing Agent + +Provide: + +- benchmark output directory +- `trial_results.csv` +- `best_results.csv` +- `metrics.json` +- progress JSONL if present +- one or two `.nsys-rep` files +- the exact model config YAML +- GPU type +- CPU core count +- local disk type +- slide count + +## Instructions For The Optimizing Agent + +Give the agent a prompt like: + +```text +You are optimizing slide2vec embedding throughput for h0-mini on a single GPU. Start from slide2vec/configs/models/h0-mini.yaml and benchmark with python slide2vec/scripts/benchmark_embedding_throughput.py --config-file slide2vec/configs/models/h0-mini.yaml --csv debug-histai-local.csv --batch-sizes 32 64 128 256 512 --embedding-workers 4 8 16 32 --repeat 2 --n-slides 0 --output-dir output/benchmark-baseline. + +Your goal is to maximize throughput while preserving embedding correctness. You may change config parameters and, if justified by the metrics, change the codebase. Prioritize improvements that reduce loader_wait_fraction, mean_loader_wait_ms, and mean_ready_wait_ms. Test prefetch_factor_embedding and persistent_workers_embedding. Keep one variable sweep tight and controlled, compare every run to the same baseline, and only keep a change if throughput improves or GPU idle time clearly drops. + +After each promising change, rerun the same benchmark slice, record the throughput delta and timing deltas, and summarize whether the bottleneck is reader-bound, preprocess-bound, or compute-bound. If code changes are made, keep them minimal, document them under docs/optimize-throughput, rerun the benchmark, and verify that output shapes and metadata contracts stay unchanged. Do not count a change as good unless throughput improves or idle-related metrics clearly improve. +``` + +Additional constraints for the agent: + +- compare against the same manifest +- compare on the same GPU type +- compare with the same batch-size and worker grid unless intentionally testing a new knob +- do not count a change as good unless throughput improves or idle-related metrics clearly improve +- preserve embedding outputs and metadata contracts + +## Suggested Iteration Loop + +For each code change: + +1. run the same benchmark slice used for the baseline +2. compare throughput and timing metrics against the baseline +3. keep the change only if it improves throughput or materially reduces idle time +4. rerun one Nsight Systems profile when a change looks promising +5. keep notes on: + - what changed + - throughput delta + - loader-wait delta + - ready-wait delta + - whether correctness changed + +## Success Criteria + +The optimization effort is successful when: + +- throughput improves materially on the target GPU +- `loader_wait_fraction` becomes a small minority of embedding time +- large batches are limited mainly by compute or memory, not by loader stalls +- Nsight shows reduced gaps between forward passes diff --git a/docs/models.md b/docs/models.md index 38d8049..33fa794 100644 --- a/docs/models.md +++ b/docs/models.md @@ -1,48 +1,28 @@ # Supported Models -`slide2vec` currently ships preset configs for 18 model entries: - -- 10 tile-level presets -- 5 region-level presets -- 3 slide-level presets - -The canonical preset files live under `slide2vec/configs/models/`. - -## Tile-Level Presets - -| Preset | Model | Architecture | Parameters | -| --- | --- | --- | --- | -| `conch` | [CONCH](https://huggingface.co/MahmoodLab/conch) | ViT-B/16 | 86M | -| `h-optimus-1` | [H-optimus-1](https://huggingface.co/bioptimus/H-optimus-1) | ViT-G/14 | 1.1B | -| `h0-mini` | [H0-mini](https://huggingface.co/bioptimus/H0-mini) | ViT-B/16 | 86M | -| `hibou` | [Hibou-B](https://huggingface.co/histai/hibou-b) / [Hibou-L](https://huggingface.co/histai/hibou-L) | ViT-B/16 or ViT-L/16 | 86M / 307M | -| `kaiko-midnight` | [MidNight12k](https://huggingface.co/kaiko-ai/midnight) | ViT-G/14 | 1.1B | -| `kaiko` | [Kaiko](https://github.com/kaiko-ai/towards_large_pathology_fms) | Various | 86M - 307M | -| `musk` | [MUSK](https://huggingface.co/xiangjx/musk) | ViT-L/16 | 307M | -| `phikonv2` | [Phikon-v2](https://huggingface.co/owkin/phikon-v2) | ViT-L/16 | 307M | -| `prov-gigapath-tile` | [Prov-GigaPath](https://huggingface.co/prov-gigapath/prov-gigapath) | ViT-G/14 | 1.1B | -| `uni2` | [UNI2](https://huggingface.co/MahmoodLab/UNI2-h) | ViT-G/14 | 1.1B | - -## Region-Level Presets - -| Preset | Model | Architecture | Parameters | -| --- | --- | --- | --- | -| `h-optimus-0` | [H-optimus-0](https://huggingface.co/bioptimus/H-optimus-0) | ViT-G/14 | 1.1B | -| `pathojepa` | PathoJEPA | ViT-S/16 (default) | 22M | -| `uni` | [UNI](https://huggingface.co/MahmoodLab/UNI) | ViT-L/16 | 307M | -| `virchow` | [Virchow](https://huggingface.co/paige-ai/Virchow) | ViT-H/14 | 632M | -| `virchow2` | [Virchow2](https://huggingface.co/paige-ai/Virchow2) | ViT-H/14 | 632M | - -## Slide-Level Presets - -| Preset | Model | Architecture | Parameters | -| --- | --- | --- | --- | -| `prism` | [PRISM](https://huggingface.co/paige-ai/PRISM) | Perceiver Resampler | 99M | -| `prov-gigapath-slide` | [Prov-GigaPath](https://huggingface.co/prov-gigapath/prov-gigapath) | Transformer (LongNet) | 87M | -| `titan` | [TITAN](https://huggingface.co/MahmoodLab/TITAN) | Transformer | 49M | - -## Notes - -- `Model.from_pretrained(...)` chooses a default level for some models; pass `level=` explicitly when you want a non-default preset behavior. -- The `hibou` preset supports both Hibou-B and Hibou-L variants through model options. -- The README stays intentionally short; use this page and [`python-api.md`](/Users/clems/Code/slide2vec/docs/python-api.md) for fuller reference material. +The canonical model presets live under `slide2vec/configs/models/`. Use the table below as the single source of truth for: + +- which preset entries ship with `slide2vec` +- which encoder level each entry uses +- which spacing values are supported by the pretrained-model validator + +| Preset | Model | Encoder Level | Supported Spacing (um) | Notes | +| --- | --- | --- | --- | --- | +| `conch` | [CONCH](https://huggingface.co/MahmoodLab/conch) | `tile` | `0.5` | | +| `conchv15` | [CONCHv1.5](https://huggingface.co/MahmoodLab/conchv1_5) | `tile` | `0.5` | | +| `h-optimus-0` | [H-optimus-0](https://huggingface.co/bioptimus/H-optimus-0) | `tile` | `0.5` | | +| `h-optimus-1` | [H-optimus-1](https://huggingface.co/bioptimus/H-optimus-1) | `tile` | `0.5` | | +| `h0-mini` | [H0-mini](https://huggingface.co/bioptimus/H0-mini) | `tile` | `0.5` | Supports `mode="cls"` or `mode="full"` | +| `hibou` | [Hibou-B](https://huggingface.co/histai/hibou-b) / [Hibou-L](https://huggingface.co/histai/hibou-L) | `tile` | `0.5` | Supports `arch="hibou-b"` and `arch="hibou-L"` | +| `kaiko` | [Kaiko](https://github.com/kaiko-ai/towards_large_pathology_fms) | `tile` | `2.0`, `1.0`, `0.5`, `0.25` | Supports `arch` in [`vits8`, `vits16`, `vitb8`, `vitb16`, `vitl14`] | +| `kaiko-midnight` | [MidNight12k](https://huggingface.co/kaiko-ai/midnight) | `tile` | `2.0`, `1.0`, `0.5`, `0.25` | | +| `musk` | [MUSK](https://huggingface.co/xiangjx/musk) | `tile` | `1.0`, `0.5`, `0.25` | | +| `phikonv2` | [Phikon-v2](https://huggingface.co/owkin/phikon-v2) | `tile` | `0.5` | | +| `uni` | [UNI](https://huggingface.co/MahmoodLab/UNI) | `tile` | `0.5` | | +| `uni2` | [UNI2](https://huggingface.co/MahmoodLab/UNI2-h) | `tile` | `0.5` | | +| `virchow` | [Virchow](https://huggingface.co/paige-ai/Virchow) | `tile` | `0.5` | Supports `mode="cls"` or `mode="full"` | +| `virchow2` | [Virchow2](https://huggingface.co/paige-ai/Virchow2) | `tile` | `2.0`, `1.0`, `0.5`, `0.25` | Supports `mode="cls"` or `mode="full"` | +| `prov-gigapath-tile` | [Prov-GigaPath](https://huggingface.co/prov-gigapath/prov-gigapath) | `tile` | `0.5` | | +| `prov-gigapath-slide` | [Prov-GigaPath](https://huggingface.co/prov-gigapath/prov-gigapath) | `slide` | `0.5` | | +| `prism` | [PRISM](https://huggingface.co/paige-ai/PRISM) | `slide` | `0.5` | | +| `titan` | [TITAN](https://huggingface.co/MahmoodLab/TITAN) | `slide` | `0.5` | | \ No newline at end of file diff --git a/docs/python-api.md b/docs/python-api.md index 37fe024..7ae5825 100644 --- a/docs/python-api.md +++ b/docs/python-api.md @@ -60,6 +60,7 @@ Commonly overridden fields: - `target_tile_size_px` - `tissue_threshold` - `backend` + `backend` is the tiling / HS2P slide-reader backend. It may be `"asap"` or `"openslide"` depending on the reader you want HS2P to use. Defaults that most users can leave alone: @@ -70,9 +71,12 @@ Defaults that most users can leave alone: - `segmentation={}` - `filtering={}` - `preview={}` -- `read_tiles_from=None` +- `read_coordinates_from=/coordinates` when omitted +- `read_tiles_from=None` unless you want slide2vec to reuse an explicitly linked external `.tiles.tar` store root - `resume=False` +`slide2vec` writes `.tiles.tar` stores during tiling by default. Set `read_tiles_from` only when you want embedding to consume an existing external tile-store root instead of the stores generated in the current run. + Advanced example: ```python @@ -101,12 +105,17 @@ preprocessing = PreprocessingConfig( Defaults to `1` in the Python API unless you set it explicitly. - `num_workers` - `num_gpus` -- `mixed_precision` +- `precision` +- `prefetch_factor` +- `persistent_workers` +- `gpu_batch_preprocessing` - `save_tile_embeddings` - `save_latents` `.pt` is the default output format. Use `output_format="npz"` to write NumPy artifacts instead. +`precision` accepts `fp32`, `fp16`, or `bf16`. When you omit it in the Python API, `slide2vec` resolves it to the model's recommended precision when one is known. + For slide-level models, `save_tile_embeddings=False` skips persisted tile embedding artifacts while still returning tile embeddings in-memory from direct APIs. `num_gpus` defaults to all available GPUs. You can set it to control how many GPUs `slide2vec` uses for either direct or manifest-driven workflows: diff --git a/requirements.in b/requirements.in index 6803977..1e00979 100644 --- a/requirements.in +++ b/requirements.in @@ -6,7 +6,7 @@ pandas pillow rich tqdm -hs2p>=2.3.0,<3 +hs2p>=2.4.1,<3 torch torchvision transformers diff --git a/requirements.txt b/requirements.txt index d1c0d10..d091fc2 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,4 +1,4 @@ -hs2p>=2.3.0,<3 +hs2p>=2.4.0,<3 omegaconf>=2.3.0 h5py matplotlib diff --git a/scripts/benchmark_embedding_throughput.py b/scripts/benchmark_embedding_throughput.py new file mode 100644 index 0000000..6556c7f --- /dev/null +++ b/scripts/benchmark_embedding_throughput.py @@ -0,0 +1,1295 @@ +#!/usr/bin/env python3 + +import argparse +import copy +import csv +import json +import math +import os +import random +import shutil +import statistics +import subprocess +import sys +import time +from pathlib import Path +from types import SimpleNamespace +from typing import Any + +import numpy as np + + +REPO_ROOT = Path(__file__).resolve().parents[1] +DEFAULT_OUTPUT_DIR = Path("output/benchmark") +HEAVY_ARTIFACT_DIRS = ( + "coordinates", + "tile_embeddings", + "slide_embeddings", + "slide_latents", + "previews", +) + + +def _prepend_repo_root_to_sys_path(paths: list[str]) -> list[str]: + repo_root = str(REPO_ROOT) + return [repo_root, *[path for path in paths if path != repo_root]] + + +sys.path[:] = _prepend_repo_root_to_sys_path(sys.path) + + +def parse_args() -> argparse.Namespace: + parser = argparse.ArgumentParser( + description="Benchmark slide2vec end-to-end embedding throughput across tuned runtime configs.", + formatter_class=argparse.ArgumentDefaultsHelpFormatter, + ) + parser.add_argument("--csv", type=Path, default=None, help="Manifest CSV to benchmark.") + parser.add_argument("--config-file", type=Path, required=False, help="Base slide2vec YAML config.") + parser.add_argument( + "--config-files", + type=Path, + nargs="+", + default=None, + help="Multiple model config files to compare in one benchmark run.", + ) + parser.add_argument( + "--model-labels", + nargs="+", + default=None, + help="Display labels for the provided model configs.", + ) + parser.add_argument( + "--size-labels", + nargs="+", + default=None, + help="Explicit size labels such as S/B/L/XL for the provided model configs.", + ) + parser.add_argument("--output-dir", type=Path, default=DEFAULT_OUTPUT_DIR, help="Output directory for benchmark artifacts.") + parser.add_argument("--repeat", type=int, default=1, help="Number of timed repeats per config.") + parser.add_argument("--seed", type=int, default=42, help="Random seed for manifest sampling.") + parser.add_argument( + "--n-slides", + type=int, + default=0, + help="Number of balanced slides to sample from the manifest. Set to 0 to use all slides.", + ) + parser.add_argument( + "--batch-sizes", + type=int, + nargs="+", + default=[1, 32, 64, 128, 256], + help="Batch sizes to sweep.", + ) + parser.add_argument( + "--embedding-workers", + type=int, + nargs="+", + default=[4, 8, 16, 32, 64, 128], + help="Embedding dataloader workers to sweep.", + ) + parser.add_argument( + "--num-gpus", + type=int, + nargs="+", + default=[1], + help="Number of GPUs to sweep.", + ) + parser.add_argument("--gpu-label", default="auto", help="Label used to identify this GPU environment in results.") + parser.add_argument("--copy-locally", action="store_true", help="Copy sampled slides to --local-dir before benchmarking.") + parser.add_argument( + "--local-dir", + type=Path, + default=Path("/tmp-data/slide2vec-benchmark-slides"), + help="Destination for local slide copies when --copy-locally is set.", + ) + parser.add_argument( + "--chart-only", + type=Path, + nargs="+", + default=None, + metavar="TRIAL_RESULTS_CSV", + help="Skip benchmarking and regenerate aggregate outputs from one or more trial-results CSV files.", + ) + + parser.add_argument("--internal-harness", action="store_true", help=argparse.SUPPRESS) + parser.add_argument("--metrics-json", type=Path, default=None, help=argparse.SUPPRESS) + parser.add_argument("--progress-jsonl", type=Path, default=None, help=argparse.SUPPRESS) + return parser.parse_args() + + +def sanitize_label(value: str) -> str: + sanitized = "".join(char.lower() if char.isalnum() else "-" for char in value.strip()) + while "--" in sanitized: + sanitized = sanitized.replace("--", "-") + return sanitized.strip("-") or "item" + + +def resolve_model_specs(args: argparse.Namespace) -> list[dict[str, Any]]: + config_files = list(args.config_files or ([] if args.config_file is None else [args.config_file])) + if not config_files: + raise ValueError("Provide --config-file or --config-files.") + + if len(config_files) == 1: + config_file = Path(config_files[0]) + model_label = args.model_labels[0] if args.model_labels else config_file.stem + size_label = args.size_labels[0] if args.size_labels else "unspecified" + return [ + { + "config_file": config_file, + "model_label": model_label, + "size_label": size_label, + } + ] + + if args.model_labels is None or len(args.model_labels) != len(config_files): + raise ValueError("--model-labels must match the number of --config-files entries.") + if args.size_labels is None or len(args.size_labels) != len(config_files): + raise ValueError("--size-labels must match the number of --config-files entries.") + + specs: list[dict[str, Any]] = [] + for config_file, model_label, size_label in zip(config_files, args.model_labels, args.size_labels): + specs.append( + { + "config_file": Path(config_file), + "model_label": str(model_label), + "size_label": str(size_label), + } + ) + return specs + + +def load_slides_from_csv(csv_path: Path) -> list[dict[str, Any]]: + slides: list[dict[str, Any]] = [] + with csv_path.open(newline="") as handle: + reader = csv.DictReader(handle) + fieldnames = set(reader.fieldnames or []) + has_mask = "mask_path" in fieldnames + has_spacing = "spacing_at_level_0" in fieldnames + has_sample_id = "sample_id" in fieldnames + for row in reader: + image_path = Path(row["image_path"]) + mask_path = Path(row["mask_path"]) if has_mask and row.get("mask_path") else None + raw_spacing = row.get("spacing_at_level_0", "") if has_spacing else "" + spacing_at_level_0 = float(raw_spacing) if raw_spacing.strip() else None + sample_id = row["sample_id"] if has_sample_id else image_path.stem + size_bytes = image_path.stat().st_size if image_path.is_file() else 0 + slides.append( + { + "sample_id": sample_id, + "image_path": image_path, + "mask_path": mask_path, + "spacing_at_level_0": spacing_at_level_0, + "size_bytes": size_bytes, + } + ) + return slides + + +def stratified_sample(slides: list[dict[str, Any]], n: int, *, seed: int) -> list[dict[str, Any]]: + rng = random.Random(seed) + if n <= 0 or len(slides) <= n: + return list(slides) + + sizes = [slide["size_bytes"] for slide in slides] + q33 = float(np.percentile(sizes, 33)) + q66 = float(np.percentile(sizes, 66)) + + small = [slide for slide in slides if slide["size_bytes"] < q33] + medium = [slide for slide in slides if q33 <= slide["size_bytes"] < q66] + large = [slide for slide in slides if slide["size_bytes"] >= q66] + + per_bin = n // 3 + remainder = n - per_bin * 3 + + sampled: list[dict[str, Any]] = [] + for index, bucket in enumerate((small, medium, large)): + take = per_bin + (1 if index < remainder else 0) + sampled.extend(rng.sample(bucket, min(take, len(bucket)))) + + if len(sampled) < n: + pool = [slide for slide in slides if slide not in sampled] + sampled.extend(rng.sample(pool, min(n - len(sampled), len(pool)))) + + rng.shuffle(sampled) + return sampled[:n] + + +def build_balanced_sample(slides: list[dict[str, Any]], *, n_slides: int, seed: int) -> list[dict[str, Any]]: + return stratified_sample(slides, n_slides, seed=seed)[:n_slides] + + +def copy_slides_locally(slides: list[dict[str, Any]], local_dir: Path) -> list[dict[str, Any]]: + local_dir.mkdir(parents=True, exist_ok=True) + updated: list[dict[str, Any]] = [] + for slide in slides: + src_image = slide["image_path"] + dst_image = local_dir / f"{slide['sample_id']}{src_image.suffix}" + if not dst_image.exists(): + shutil.copy2(src_image, dst_image) + + dst_mask: Path | None = None + if slide["mask_path"] is not None: + src_mask = slide["mask_path"] + dst_mask = local_dir / f"{slide['sample_id']}.mask{src_mask.suffix}" + if not dst_mask.exists(): + shutil.copy2(src_mask, dst_mask) + + updated.append({**slide, "image_path": dst_image, "mask_path": dst_mask}) + return updated + + +def write_slides_csv(slides: list[dict[str, Any]], path: Path) -> None: + path.parent.mkdir(parents=True, exist_ok=True) + has_spacing = any(slide.get("spacing_at_level_0") is not None for slide in slides) + fieldnames = ["sample_id", "image_path", "mask_path"] + if has_spacing: + fieldnames.append("spacing_at_level_0") + + with path.open("w", newline="") as handle: + writer = csv.DictWriter(handle, fieldnames=fieldnames) + writer.writeheader() + for slide in slides: + row: dict[str, Any] = { + "sample_id": slide["sample_id"], + "image_path": str(slide["image_path"]), + "mask_path": str(slide["mask_path"]) if slide["mask_path"] is not None else "", + } + if has_spacing: + row["spacing_at_level_0"] = slide.get("spacing_at_level_0") or "" + writer.writerow(row) + + +def build_trial_plan( + *, + output_root: Path, + model_specs: list[dict[str, Any]], + batch_sizes: list[int], + embedding_workers: list[int], + num_gpus: list[int], + repeat: int, +) -> list[dict[str, Any]]: + plan: list[dict[str, Any]] = [] + for model_spec in model_specs: + model_root = output_root / "runs" / sanitize_label(str(model_spec["model_label"])) + for n_gpus in num_gpus: + for batch_size in batch_sizes: + for worker_count in embedding_workers: + config_root = model_root / f"ng-{n_gpus:02d}" / f"bs-{batch_size:04d}" / f"ew-{worker_count:02d}" + plan.append( + { + "kind": "warmup", + "config_file": Path(model_spec["config_file"]), + "model_label": str(model_spec["model_label"]), + "size_label": str(model_spec["size_label"]), + "batch_size": batch_size, + "embedding_workers": worker_count, + "num_gpus": n_gpus, + "repeat_index": 0, + "run_dir": config_root / "warmup", + } + ) + for repeat_index in range(1, repeat + 1): + plan.append( + { + "kind": "measure", + "config_file": Path(model_spec["config_file"]), + "model_label": str(model_spec["model_label"]), + "size_label": str(model_spec["size_label"]), + "batch_size": batch_size, + "embedding_workers": worker_count, + "num_gpus": n_gpus, + "repeat_index": repeat_index, + "run_dir": config_root / f"rep-{repeat_index:02d}", + } + ) + return plan + + +def _to_namespace(value: Any) -> Any: + if isinstance(value, dict): + return SimpleNamespace(**{key: _to_namespace(item) for key, item in value.items()}) + if isinstance(value, list): + return [_to_namespace(item) for item in value] + return value + + +def _to_plain_data(value: Any) -> Any: + if value.__class__.__module__.startswith("omegaconf"): + from omegaconf import OmegaConf + + return OmegaConf.to_container(value, resolve=True) + if isinstance(value, SimpleNamespace): + return {key: _to_plain_data(item) for key, item in vars(value).items()} + if isinstance(value, dict): + return {key: _to_plain_data(item) for key, item in value.items()} + if isinstance(value, list): + return [_to_plain_data(item) for item in value] + if isinstance(value, Path): + return str(value) + return value + + +def build_trial_config( + base_config: dict[str, Any] | SimpleNamespace, + *, + csv_path: Path, + output_dir: Path, + batch_size: int, + embedding_workers: int, + num_gpus: int = 1, +) -> SimpleNamespace: + base_data = _to_plain_data(base_config) + config = copy.deepcopy(base_data) + config.setdefault("model", {}) + config.setdefault("speed", {}) + config.setdefault("wandb", {}) + config["csv"] = str(csv_path) + config["output_dir"] = str(output_dir) + config["resume"] = False + config["save_previews"] = False + config["wandb"]["enable"] = False + config["model"]["batch_size"] = int(batch_size) + config["speed"]["num_workers_embedding"] = int(embedding_workers) + config["speed"]["num_gpus"] = int(num_gpus) + return _to_namespace(config) + + +def _load_yaml(path: Path) -> dict[str, Any]: + import yaml + + with path.open(encoding="utf-8") as handle: + data = yaml.safe_load(handle) or {} + if not isinstance(data, dict): + raise ValueError(f"Expected mapping config in {path}") + return data + + +def _load_cli_merged_config(path: Path) -> dict[str, Any]: + from slide2vec.utils.config import get_cfg_from_args + + cfg = get_cfg_from_args( + argparse.Namespace( + config_file=str(path), + output_dir=None, + opts=[], + ) + ) + plain = _to_plain_data(cfg) + if not isinstance(plain, dict): + raise ValueError(f"Expected mapping config in {path}") + return plain + + +def _write_yaml(data: dict[str, Any], path: Path) -> None: + import yaml + + path.parent.mkdir(parents=True, exist_ok=True) + with path.open("w", encoding="utf-8") as handle: + yaml.safe_dump(data, handle, sort_keys=False) + + +def load_progress_records(path: Path) -> list[dict[str, Any]]: + if not path.is_file(): + return [] + records: list[dict[str, Any]] = [] + with path.open(encoding="utf-8") as handle: + for line in handle: + line = line.strip() + if not line: + continue + payload = json.loads(line) + if isinstance(payload, dict): + records.append(payload) + return records + + +def extract_stage_seconds(progress_path: Path) -> dict[str, float | None]: + records = load_progress_records(progress_path) + stage_seconds = { + "tiling_seconds": None, + "embedding_seconds": None, + "aggregation_seconds": None, + } + if not records: + return stage_seconds + + first_timestamps: dict[str, float] = {} + aggregation_starts: dict[str, float] = {} + aggregation_total = 0.0 + + for record in records: + kind = record.get("kind") + timestamp = record.get("timestamp") + if kind is None or timestamp is None: + continue + timestamp = float(timestamp) + if kind not in first_timestamps: + first_timestamps[kind] = timestamp + if kind == "aggregation.started": + sample_id = str(record.get("payload", {}).get("sample_id", "")) + aggregation_starts[sample_id] = timestamp + elif kind == "aggregation.finished": + sample_id = str(record.get("payload", {}).get("sample_id", "")) + started = aggregation_starts.pop(sample_id, None) + if started is not None: + aggregation_total += max(0.0, timestamp - started) + + if "tiling.started" in first_timestamps and "tiling.finished" in first_timestamps: + stage_seconds["tiling_seconds"] = round(first_timestamps["tiling.finished"] - first_timestamps["tiling.started"], 4) + if "embedding.started" in first_timestamps and "embedding.finished" in first_timestamps: + stage_seconds["embedding_seconds"] = round( + first_timestamps["embedding.finished"] - first_timestamps["embedding.started"], + 4, + ) + if aggregation_total > 0: + stage_seconds["aggregation_seconds"] = round(aggregation_total, 4) + return stage_seconds + + +def extract_batch_timing_metrics(progress_path: Path) -> dict[str, float | int]: + records = load_progress_records(progress_path) + batch_payloads = [ + record.get("payload", {}) + for record in records + if record.get("kind") == "embedding.batch.timing" and isinstance(record.get("payload"), dict) + ] + if not batch_payloads: + return { + "timed_batches": 0, + "mean_loader_wait_ms": 0.0, + "max_loader_wait_ms": 0.0, + "mean_ready_wait_ms": 0.0, + "mean_preprocess_ms": 0.0, + "mean_worker_batch_ms": 0.0, + "mean_reader_open_ms": 0.0, + "mean_reader_read_ms": 0.0, + "mean_forward_ms": 0.0, + "loader_wait_fraction": 0.0, + "gpu_busy_fraction": 0.0, + } + + loader_wait_ms = [float(payload.get("loader_wait_ms", 0.0)) for payload in batch_payloads] + ready_wait_ms = [float(payload.get("ready_wait_ms", 0.0)) for payload in batch_payloads] + preprocess_ms = [float(payload.get("preprocess_ms", 0.0)) for payload in batch_payloads] + worker_batch_ms = [float(payload.get("worker_batch_ms", 0.0)) for payload in batch_payloads] + reader_open_ms = [float(payload.get("reader_open_ms", 0.0)) for payload in batch_payloads] + reader_read_ms = [float(payload.get("reader_read_ms", 0.0)) for payload in batch_payloads] + forward_ms = [float(payload.get("forward_ms", 0.0)) for payload in batch_payloads] + gpu_busy_fraction = [float(payload.get("gpu_busy_fraction", 0.0)) for payload in batch_payloads] + total_ms = sum(loader_wait_ms) + sum(ready_wait_ms) + sum(preprocess_ms) + sum(forward_ms) + return { + "timed_batches": len(batch_payloads), + "mean_loader_wait_ms": round(statistics.mean(loader_wait_ms), 4), + "max_loader_wait_ms": round(max(loader_wait_ms), 4), + "mean_ready_wait_ms": round(statistics.mean(ready_wait_ms), 4), + "mean_preprocess_ms": round(statistics.mean(preprocess_ms), 4), + "mean_worker_batch_ms": round(statistics.mean(worker_batch_ms), 4), + "mean_reader_open_ms": round(statistics.mean(reader_open_ms), 4), + "mean_reader_read_ms": round(statistics.mean(reader_read_ms), 4), + "mean_forward_ms": round(statistics.mean(forward_ms), 4), + "loader_wait_fraction": round((sum(loader_wait_ms) + sum(ready_wait_ms)) / total_ms, 4) if total_ms > 0 else 0.0, + "gpu_busy_fraction": round(statistics.mean(gpu_busy_fraction), 4), + } + + +def parse_process_list(path: Path) -> dict[str, int]: + if not path.is_file(): + return { + "slides_total": 0, + "slides_with_tiles": 0, + "failed_slides": 0, + "total_tiles": 0, + } + + with path.open(newline="") as handle: + rows = list(csv.DictReader(handle)) + + total_tiles = sum(int(float(row.get("num_tiles") or 0)) for row in rows) + slides_with_tiles = sum(int(float(row.get("num_tiles") or 0)) > 0 for row in rows) + failed_slides = sum(row.get("tiling_status") == "failed" for row in rows) + return { + "slides_total": len(rows), + "slides_with_tiles": slides_with_tiles, + "failed_slides": failed_slides, + "total_tiles": total_tiles, + } + + +def _coerce_csv_row(row: dict[str, str]) -> dict[str, Any]: + int_fields = { + "batch_size", + "embedding_workers", + "num_gpus", + "repeat_index", + "repeat_count", + "exit_code", + "slides_total", + "slides_with_tiles", + "failed_slides", + "total_tiles", + "timed_batches", + } + float_fields = { + "tiles_per_second", + "slides_per_second", + "end_to_end_seconds", + "tiling_seconds", + "embedding_seconds", + "aggregation_seconds", + "mean_loader_wait_ms", + "max_loader_wait_ms", + "mean_ready_wait_ms", + "mean_preprocess_ms", + "mean_worker_batch_ms", + "mean_reader_open_ms", + "mean_reader_read_ms", + "mean_forward_ms", + "loader_wait_fraction", + "gpu_busy_fraction", + "mean_tiles_per_second", + "std_tiles_per_second", + "mean_end_to_end_seconds", + "mean_slides_per_second", + "mean_mean_loader_wait_ms", + "mean_max_loader_wait_ms", + "mean_mean_ready_wait_ms", + "mean_mean_preprocess_ms", + "mean_mean_worker_batch_ms", + "mean_mean_reader_open_ms", + "mean_mean_reader_read_ms", + "mean_mean_forward_ms", + "mean_loader_wait_fraction", + "mean_gpu_busy_fraction", + } + parsed: dict[str, Any] = {} + for key, value in row.items(): + if value == "": + parsed[key] = "" + elif key in int_fields: + parsed[key] = int(float(value)) + elif key in float_fields: + parsed[key] = float(value) + else: + parsed[key] = value + return parsed + + +def load_trial_results_csvs(paths: list[Path]) -> list[dict[str, Any]]: + rows: list[dict[str, Any]] = [] + for path in paths: + with path.open(newline="") as handle: + reader = csv.DictReader(handle) + rows.extend(_coerce_csv_row(row) for row in reader) + return rows + + +def aggregate_trial_results(trial_rows: list[dict[str, Any]]) -> list[dict[str, Any]]: + grouped: dict[tuple[str, str, str, int, int, int], list[dict[str, Any]]] = {} + for row in trial_rows: + if row.get("exit_code", 0) not in (0, "", None): + continue + gpu_label = str(row["gpu_label"]) + model_label = str(row.get("model_label", "")) + size_label = str(row.get("size_label", "")) + batch_size = int(row["batch_size"]) + embedding_workers = int(row["embedding_workers"]) + num_gpus = int(row.get("num_gpus", 1)) + grouped.setdefault((gpu_label, model_label, size_label, batch_size, embedding_workers, num_gpus), []).append(row) + + aggregated: list[dict[str, Any]] = [] + for (gpu_label, model_label, size_label, batch_size, embedding_workers, num_gpus), rows in sorted(grouped.items()): + tiles_per_second = [float(row["tiles_per_second"]) for row in rows] + end_to_end_seconds = [float(row["end_to_end_seconds"]) for row in rows] + slides_per_second = [float(row.get("slides_per_second", 0.0)) for row in rows] + aggregation_seconds = [float(row["aggregation_seconds"]) for row in rows if row.get("aggregation_seconds") not in ("", None)] + embedding_seconds = [float(row["embedding_seconds"]) for row in rows if row.get("embedding_seconds") not in ("", None)] + tiling_seconds = [float(row["tiling_seconds"]) for row in rows if row.get("tiling_seconds") not in ("", None)] + mean_loader_wait_ms = [float(row.get("mean_loader_wait_ms", 0.0)) for row in rows] + max_loader_wait_ms = [float(row.get("max_loader_wait_ms", 0.0)) for row in rows] + mean_ready_wait_ms = [float(row.get("mean_ready_wait_ms", 0.0)) for row in rows] + mean_preprocess_ms = [float(row.get("mean_preprocess_ms", 0.0)) for row in rows] + mean_worker_batch_ms = [float(row.get("mean_worker_batch_ms", 0.0)) for row in rows] + mean_reader_open_ms = [float(row.get("mean_reader_open_ms", 0.0)) for row in rows] + mean_reader_read_ms = [float(row.get("mean_reader_read_ms", 0.0)) for row in rows] + mean_forward_ms = [float(row.get("mean_forward_ms", 0.0)) for row in rows] + loader_wait_fraction = [float(row.get("loader_wait_fraction", 0.0)) for row in rows] + gpu_busy_fraction = [float(row.get("gpu_busy_fraction", 0.0)) for row in rows] + timed_batches = [int(row.get("timed_batches", 0)) for row in rows] + aggregated.append( + { + "gpu_label": gpu_label, + "model_label": model_label, + "size_label": size_label, + "config_file": str(rows[0].get("config_file", "")), + "batch_size": batch_size, + "embedding_workers": embedding_workers, + "num_gpus": num_gpus, + "repeat_count": len(rows), + "mean_timed_batches": round(statistics.mean(timed_batches), 4), + "mean_tiles_per_second": round(statistics.mean(tiles_per_second), 4), + "std_tiles_per_second": round(statistics.pstdev(tiles_per_second), 4) if len(tiles_per_second) > 1 else 0.0, + "mean_end_to_end_seconds": round(statistics.mean(end_to_end_seconds), 4), + "mean_slides_per_second": round(statistics.mean(slides_per_second), 4), + "mean_tiling_seconds": round(statistics.mean(tiling_seconds), 4) if tiling_seconds else "", + "mean_embedding_seconds": round(statistics.mean(embedding_seconds), 4) if embedding_seconds else "", + "mean_aggregation_seconds": round(statistics.mean(aggregation_seconds), 4) if aggregation_seconds else "", + "mean_loader_wait_ms": round(statistics.mean(mean_loader_wait_ms), 4), + "max_loader_wait_ms": round(max(max_loader_wait_ms), 4), + "mean_ready_wait_ms": round(statistics.mean(mean_ready_wait_ms), 4), + "mean_preprocess_ms": round(statistics.mean(mean_preprocess_ms), 4), + "mean_worker_batch_ms": round(statistics.mean(mean_worker_batch_ms), 4), + "mean_reader_open_ms": round(statistics.mean(mean_reader_open_ms), 4), + "mean_reader_read_ms": round(statistics.mean(mean_reader_read_ms), 4), + "mean_forward_ms": round(statistics.mean(mean_forward_ms), 4), + "loader_wait_fraction": round(statistics.mean(loader_wait_fraction), 4), + "gpu_busy_fraction": round(statistics.mean(gpu_busy_fraction), 4), + } + ) + return aggregated + + +def select_best_results(aggregated_rows: list[dict[str, Any]]) -> list[dict[str, Any]]: + best_by_group: dict[tuple[str, str, str, int], dict[str, Any]] = {} + for row in aggregated_rows: + gpu_label = str(row["gpu_label"]) + model_label = str(row.get("model_label", "")) + size_label = str(row.get("size_label", "")) + num_gpus = int(row.get("num_gpus", 1)) + key = (gpu_label, model_label, size_label, num_gpus) + current = best_by_group.get(key) + candidate_key = ( + float(row["mean_tiles_per_second"]), + -float(row["mean_end_to_end_seconds"]), + -int(row["batch_size"]), + -int(row["embedding_workers"]), + ) + if current is None: + best_by_group[key] = row + continue + current_key = ( + float(current["mean_tiles_per_second"]), + -float(current["mean_end_to_end_seconds"]), + -int(current["batch_size"]), + -int(current["embedding_workers"]), + ) + if candidate_key > current_key: + best_by_group[key] = row + + best_rows = [] + for gpu_label, model_label, size_label, num_gpus in sorted(best_by_group): + row = best_by_group[(gpu_label, model_label, size_label, num_gpus)] + best_rows.append( + { + "gpu_label": gpu_label, + "model_label": model_label, + "size_label": size_label, + "config_file": str(row.get("config_file", "")), + "batch_size": int(row["batch_size"]), + "embedding_workers": int(row["embedding_workers"]), + "num_gpus": num_gpus, + "repeat_count": int(row["repeat_count"]), + "mean_tiles_per_second": float(row["mean_tiles_per_second"]), + "std_tiles_per_second": float(row["std_tiles_per_second"]), + "mean_end_to_end_seconds": float(row["mean_end_to_end_seconds"]), + "mean_slides_per_second": float(row["mean_slides_per_second"]), + "mean_loader_wait_ms": float(row.get("mean_loader_wait_ms", 0.0)), + "max_loader_wait_ms": float(row.get("max_loader_wait_ms", 0.0)), + "mean_ready_wait_ms": float(row.get("mean_ready_wait_ms", 0.0)), + "mean_preprocess_ms": float(row.get("mean_preprocess_ms", 0.0)), + "mean_worker_batch_ms": float(row.get("mean_worker_batch_ms", 0.0)), + "mean_reader_open_ms": float(row.get("mean_reader_open_ms", 0.0)), + "mean_reader_read_ms": float(row.get("mean_reader_read_ms", 0.0)), + "mean_forward_ms": float(row.get("mean_forward_ms", 0.0)), + "loader_wait_fraction": float(row.get("loader_wait_fraction", 0.0)), + "gpu_busy_fraction": float(row.get("gpu_busy_fraction", 0.0)), + } + ) + return best_rows + + +def build_size_plot_rows(best_rows: list[dict[str, Any]]) -> list[dict[str, Any]]: + collapsed: dict[tuple[str, str], dict[str, Any]] = {} + for row in best_rows: + key = (str(row["gpu_label"]), str(row.get("size_label", ""))) + current = collapsed.get(key) + candidate_key = ( + float(row["mean_tiles_per_second"]), + -float(row["mean_end_to_end_seconds"]), + str(row.get("model_label", "")), + ) + if current is None: + collapsed[key] = row + continue + current_key = ( + float(current["mean_tiles_per_second"]), + -float(current["mean_end_to_end_seconds"]), + str(current.get("model_label", "")), + ) + if candidate_key > current_key: + collapsed[key] = row + + rows = [] + for gpu_label, size_label in sorted(collapsed): + row = collapsed[(gpu_label, size_label)] + rows.append( + { + "gpu_label": gpu_label, + "size_label": size_label, + "model_label": str(row.get("model_label", "")), + "mean_tiles_per_second": float(row["mean_tiles_per_second"]), + } + ) + return rows + + +def save_csv(rows: list[dict[str, Any]], path: Path) -> None: + if not rows: + return + path.parent.mkdir(parents=True, exist_ok=True) + with path.open("w", newline="") as handle: + writer = csv.DictWriter(handle, fieldnames=list(rows[0].keys())) + writer.writeheader() + writer.writerows(rows) + + +def cleanup_trial_output(output_dir: Path) -> None: + for dirname in HEAVY_ARTIFACT_DIRS: + candidate = output_dir / dirname + if candidate.exists(): + shutil.rmtree(candidate) + + +def _detect_gpu_label() -> str: + try: + import torch + except ImportError: + return "cpu-or-unknown" + if not torch.cuda.is_available(): + return "cpu-or-unknown" + count = int(torch.cuda.device_count()) + names = [torch.cuda.get_device_name(index).strip() for index in range(count)] + if len(set(names)) == 1: + return f"{count}x {names[0]}" + return " / ".join(names) + + +def _resolve_gpu_label(value: str) -> str: + return _detect_gpu_label() if value == "auto" else value + + +def _build_model_pipeline_from_config(config: dict[str, Any]): + from slide2vec import ExecutionOptions, Model, Pipeline, PreprocessingConfig + + model_cfg = config.get("model", {}) + tiling_cfg = config.get("tiling", {}) + params = tiling_cfg.get("params", {}) + preview = dict(tiling_cfg.get("preview", {})) + preprocessing = PreprocessingConfig( + backend=str(tiling_cfg.get("backend", "asap")), + target_spacing_um=float(params.get("target_spacing_um", 0.5)), + target_tile_size_px=int(params.get("target_tile_size_px", 224)), + tolerance=float(params.get("tolerance", 0.05)), + overlap=float(params.get("overlap", 0.0)), + tissue_threshold=float(params.get("tissue_threshold", 0.01)), + drop_holes=bool(params.get("drop_holes", False)), + use_padding=bool(params.get("use_padding", True)), + read_coordinates_from=( + Path(tiling_cfg["read_coordinates_from"]) + if tiling_cfg.get("read_coordinates_from") + else Path(config["output_dir"]) / "coordinates" + ), + read_tiles_from=Path(tiling_cfg["read_tiles_from"]) if tiling_cfg.get("read_tiles_from") else None, + resume=bool(config.get("resume", False)), + segmentation=dict(tiling_cfg.get("seg_params", {})), + filtering=dict(tiling_cfg.get("filter_params", {})), + preview={ + "save_mask_preview": bool(config.get("save_previews", False)), + "save_tiling_preview": bool(config.get("save_previews", False)), + "downsample": int(preview.get("downsample", 32)), + }, + ) + speed_cfg = config.get("speed", {}) + execution = ExecutionOptions( + output_dir=Path(config["output_dir"]), + output_format=str(config.get("output_format", "pt")), + batch_size=int(model_cfg.get("batch_size", 1)), + num_workers=int(speed_cfg.get("num_workers_embedding", speed_cfg.get("num_workers", 0))), + num_gpus=int(speed_cfg["num_gpus"]) if speed_cfg.get("num_gpus") is not None else None, + precision=str(speed_cfg.get("precision", "fp32")), + prefetch_factor=int(speed_cfg.get("prefetch_factor_embedding", 4)), + persistent_workers=bool(speed_cfg.get("persistent_workers_embedding", True)), + gpu_batch_preprocessing=bool(speed_cfg.get("gpu_batch_preprocessing", True)), + save_tile_embeddings=bool(model_cfg.get("save_tile_embeddings", False)), + save_latents=bool(model_cfg.get("save_latents", False)), + ) + model = Model.from_pretrained( + str(model_cfg["name"]), + level=model_cfg.get("level"), + mode=model_cfg.get("mode"), + arch=model_cfg.get("arch"), + pretrained_weights=model_cfg.get("pretrained_weights"), + input_size=model_cfg.get("input_size"), + patch_size=model_cfg.get("patch_size"), + token_size=model_cfg.get("token_size"), + normalize_embeddings=model_cfg.get("normalize_embeddings"), + device="auto", + ) + return Pipeline(model=model, preprocessing=preprocessing, execution=execution) + + +def _run_internal_harness(args: argparse.Namespace) -> int: + if args.config_file is None or args.metrics_json is None or args.progress_jsonl is None: + raise ValueError("--internal-harness requires --config-file, --metrics-json, and --progress-jsonl") + + from slide2vec.progress import JsonlProgressReporter, activate_progress_reporter + + config = _load_yaml(args.config_file) + pipeline = _build_model_pipeline_from_config(config) + output_dir = Path(config["output_dir"]) + progress_path = Path(args.progress_jsonl) + metrics_path = Path(args.metrics_json) + progress_path.parent.mkdir(parents=True, exist_ok=True) + metrics_path.parent.mkdir(parents=True, exist_ok=True) + + reporter = JsonlProgressReporter(progress_path) + metrics: dict[str, Any] = {} + t0 = time.perf_counter() + try: + with activate_progress_reporter(reporter): + result = pipeline.run(manifest_path=config["csv"]) + end_to_end_seconds = time.perf_counter() - t0 + process_stats = parse_process_list(output_dir / "process_list.csv") + stage_seconds = extract_stage_seconds(progress_path) + batch_timing = extract_batch_timing_metrics(progress_path) + slides_total = int(process_stats["slides_total"]) + slides_per_second = slides_total / end_to_end_seconds if end_to_end_seconds > 0 else 0.0 + tiles_per_second = process_stats["total_tiles"] / end_to_end_seconds if end_to_end_seconds > 0 else 0.0 + metrics = { + "success": True, + "tile_artifacts": len(result.tile_artifacts), + "slide_artifacts": len(result.slide_artifacts), + "slides_total": slides_total, + "slides_with_tiles": int(process_stats["slides_with_tiles"]), + "failed_slides": int(process_stats["failed_slides"]), + "total_tiles": int(process_stats["total_tiles"]), + "end_to_end_seconds": round(end_to_end_seconds, 4), + "tiles_per_second": round(tiles_per_second, 4), + "slides_per_second": round(slides_per_second, 4), + **stage_seconds, + **batch_timing, + } + except Exception as exc: + end_to_end_seconds = time.perf_counter() - t0 + process_stats = parse_process_list(output_dir / "process_list.csv") + metrics = { + "success": False, + "error": str(exc), + "slides_total": int(process_stats["slides_total"]), + "slides_with_tiles": int(process_stats["slides_with_tiles"]), + "failed_slides": int(process_stats["failed_slides"]), + "total_tiles": int(process_stats["total_tiles"]), + "end_to_end_seconds": round(end_to_end_seconds, 4), + "tiles_per_second": 0.0, + "slides_per_second": 0.0, + **extract_stage_seconds(progress_path), + **extract_batch_timing_metrics(progress_path), + } + metrics_path.write_text(json.dumps(metrics, indent=2, sort_keys=True) + "\n", encoding="utf-8") + return 1 + + metrics_path.write_text(json.dumps(metrics, indent=2, sort_keys=True) + "\n", encoding="utf-8") + return 0 + + +def _run_trial_subprocess( + *, + config_path: Path, + metrics_path: Path, + progress_path: Path, + log_path: Path, +) -> subprocess.CompletedProcess[str]: + command = [ + sys.executable, + str(Path(__file__).resolve()), + "--internal-harness", + "--config-file", + str(config_path), + "--metrics-json", + str(metrics_path), + "--progress-jsonl", + str(progress_path), + ] + completed = subprocess.run( + command, + cwd=REPO_ROOT, + capture_output=True, + text=True, + ) + log_path.write_text((completed.stdout or "") + (completed.stderr or ""), encoding="utf-8") + return completed + + +def _ensure_failure_details_in_log(log_path: Path, metrics: dict[str, Any], *, exit_code: int) -> None: + if exit_code == 0: + return + existing = log_path.read_text(encoding="utf-8") if log_path.is_file() else "" + if existing.strip(): + return + error = str(metrics.get("error", "")).strip() + if error: + log_path.write_text(f"ERROR: {error}\n", encoding="utf-8") + return + log_path.write_text(f"ERROR: benchmark harness exited with code {exit_code}\n", encoding="utf-8") + + +def run_trial( + *, + trial_spec: dict[str, Any], + slides: list[dict[str, Any]], + shared_csv_path: Path, + base_config: dict[str, Any], + gpu_label: str, +) -> dict[str, Any]: + run_dir = Path(trial_spec["run_dir"]) + run_dir.mkdir(parents=True, exist_ok=True) + config_path = run_dir / "config.yaml" + progress_path = run_dir / "progress.jsonl" + metrics_path = run_dir / "metrics.json" + log_path = run_dir / "harness.log" + trial_output_dir = run_dir / "output" + + trial_config = build_trial_config( + base_config, + csv_path=shared_csv_path, + output_dir=trial_output_dir, + batch_size=int(trial_spec["batch_size"]), + embedding_workers=int(trial_spec["embedding_workers"]), + num_gpus=int(trial_spec.get("num_gpus", 1)), + ) + _write_yaml(_to_plain_data(trial_config), config_path) + + completed = _run_trial_subprocess( + config_path=config_path, + metrics_path=metrics_path, + progress_path=progress_path, + log_path=log_path, + ) + metrics = json.loads(metrics_path.read_text(encoding="utf-8")) if metrics_path.is_file() else {} + _ensure_failure_details_in_log(log_path, metrics, exit_code=int(completed.returncode)) + cleanup_trial_output(trial_output_dir) + return { + "gpu_label": gpu_label, + "model_label": str(trial_spec["model_label"]), + "size_label": str(trial_spec["size_label"]), + "config_file": str(trial_spec["config_file"]), + "batch_size": int(trial_spec["batch_size"]), + "embedding_workers": int(trial_spec["embedding_workers"]), + "num_gpus": int(trial_spec.get("num_gpus", 1)), + "repeat_index": int(trial_spec["repeat_index"]), + "run_kind": str(trial_spec["kind"]), + "exit_code": int(completed.returncode), + "slides_total": int(metrics.get("slides_total", 0)), + "slides_with_tiles": int(metrics.get("slides_with_tiles", 0)), + "failed_slides": int(metrics.get("failed_slides", 0)), + "total_tiles": int(metrics.get("total_tiles", 0)), + "end_to_end_seconds": float(metrics.get("end_to_end_seconds", 0.0)), + "tiles_per_second": float(metrics.get("tiles_per_second", 0.0)), + "slides_per_second": float(metrics.get("slides_per_second", 0.0)), + "tiling_seconds": metrics.get("tiling_seconds", ""), + "embedding_seconds": metrics.get("embedding_seconds", ""), + "aggregation_seconds": metrics.get("aggregation_seconds", ""), + "timed_batches": int(metrics.get("timed_batches", 0)), + "mean_loader_wait_ms": float(metrics.get("mean_loader_wait_ms", 0.0)), + "max_loader_wait_ms": float(metrics.get("max_loader_wait_ms", 0.0)), + "mean_ready_wait_ms": float(metrics.get("mean_ready_wait_ms", 0.0)), + "mean_preprocess_ms": float(metrics.get("mean_preprocess_ms", 0.0)), + "mean_worker_batch_ms": float(metrics.get("mean_worker_batch_ms", 0.0)), + "mean_reader_open_ms": float(metrics.get("mean_reader_open_ms", 0.0)), + "mean_reader_read_ms": float(metrics.get("mean_reader_read_ms", 0.0)), + "mean_forward_ms": float(metrics.get("mean_forward_ms", 0.0)), + "loader_wait_fraction": float(metrics.get("loader_wait_fraction", 0.0)), + "error": metrics.get("error", ""), + } + + +def plot_throughput_by_gpu(best_rows: list[dict[str, Any]], output_path: Path) -> None: + if not best_rows: + return + import matplotlib + + matplotlib.use("Agg") + import matplotlib.pyplot as plt + gpu_labels = sorted({str(row["gpu_label"]) for row in best_rows}) + series_labels = sorted( + {f"{row.get('model_label', '')} ({row.get('size_label', '')})".strip() for row in best_rows} + ) + x_positions = np.arange(len(gpu_labels), dtype=float) + width = 0.8 / max(len(series_labels), 1) + + fig, ax = plt.subplots(figsize=(max(7.0, 1.8 * len(gpu_labels)), 4.8)) + for index, series_label in enumerate(series_labels): + values = [] + annotations = [] + for gpu_label in gpu_labels: + row = next( + ( + item + for item in best_rows + if str(item["gpu_label"]) == gpu_label + and f"{item.get('model_label', '')} ({item.get('size_label', '')})".strip() == series_label + ), + None, + ) + values.append(float(row["mean_tiles_per_second"]) if row is not None else np.nan) + annotations.append( + f"bs={row['batch_size']}, w={row['embedding_workers']}" if row is not None else "" + ) + offsets = x_positions - 0.4 + width / 2 + index * width + bars = ax.bar(offsets, values, width=width, label=series_label) + for bar, value, annotation in zip(bars, values, annotations): + if math.isnan(value): + continue + ax.text( + bar.get_x() + bar.get_width() / 2, + value, + f"{value:,.1f}\n{annotation}", + ha="center", + va="bottom", + fontsize=7, + ) + + ax.set_ylabel("Tiles / second") + ax.set_title("slide2vec End-to-End Throughput by GPU") + ax.set_xticks(x_positions, labels=gpu_labels) + ax.grid(axis="y", color="#e8e8e8", linewidth=0.8) + ax.set_axisbelow(True) + ax.legend(loc="best", fontsize=8) + fig.tight_layout() + output_path.parent.mkdir(parents=True, exist_ok=True) + fig.savefig(output_path, dpi=300, bbox_inches="tight") + plt.close(fig) + + +def plot_throughput_by_gpu_and_size(best_rows: list[dict[str, Any]], output_path: Path) -> None: + size_rows = build_size_plot_rows(best_rows) + if not size_rows: + return + + import matplotlib + + matplotlib.use("Agg") + import matplotlib.pyplot as plt + + gpu_labels = sorted({str(row["gpu_label"]) for row in size_rows}) + size_labels = sorted({str(row["size_label"]) for row in size_rows}) + x_positions = np.arange(len(gpu_labels), dtype=float) + width = 0.8 / max(len(size_labels), 1) + + fig, ax = plt.subplots(figsize=(max(7.0, 1.8 * len(gpu_labels)), 4.8)) + for index, size_label in enumerate(size_labels): + values = [] + annotations = [] + for gpu_label in gpu_labels: + row = next( + ( + item + for item in size_rows + if str(item["gpu_label"]) == gpu_label and str(item["size_label"]) == size_label + ), + None, + ) + values.append(float(row["mean_tiles_per_second"]) if row is not None else np.nan) + annotations.append(str(row["model_label"]) if row is not None else "") + offsets = x_positions - 0.4 + width / 2 + index * width + bars = ax.bar(offsets, values, width=width, label=size_label) + for bar, value, annotation in zip(bars, values, annotations): + if math.isnan(value): + continue + ax.text( + bar.get_x() + bar.get_width() / 2, + value, + f"{value:,.1f}\n{annotation}", + ha="center", + va="bottom", + fontsize=7, + ) + + ax.set_ylabel("Tiles / second") + ax.set_title("slide2vec Throughput by GPU and Size") + ax.set_xticks(x_positions, labels=gpu_labels) + ax.grid(axis="y", color="#e8e8e8", linewidth=0.8) + ax.set_axisbelow(True) + ax.legend(title="Size", fontsize=8) + fig.tight_layout() + output_path.parent.mkdir(parents=True, exist_ok=True) + fig.savefig(output_path, dpi=300, bbox_inches="tight") + plt.close(fig) + + +def plot_tuning_grid(aggregated_rows: list[dict[str, Any]], *, gpu_label: str, model_label: str, output_path: Path) -> None: + gpu_rows = [ + row + for row in aggregated_rows + if row["gpu_label"] == gpu_label and row.get("model_label", "") == model_label + ] + if not gpu_rows: + return + + import matplotlib + + matplotlib.use("Agg") + import matplotlib.pyplot as plt + + batch_sizes = sorted({int(row["batch_size"]) for row in gpu_rows}) + workers = sorted({int(row["embedding_workers"]) for row in gpu_rows}) + grid = np.full((len(workers), len(batch_sizes)), np.nan, dtype=float) + for row in gpu_rows: + worker_index = workers.index(int(row["embedding_workers"])) + batch_index = batch_sizes.index(int(row["batch_size"])) + grid[worker_index, batch_index] = float(row["mean_tiles_per_second"]) + + fig, ax = plt.subplots(figsize=(1.4 * max(len(batch_sizes), 3), 1.1 * max(len(workers), 3))) + image = ax.imshow(grid, cmap="Blues", aspect="auto") + ax.set_xticks(range(len(batch_sizes)), labels=[str(value) for value in batch_sizes]) + ax.set_yticks(range(len(workers)), labels=[str(value) for value in workers]) + ax.set_xlabel("Batch size") + ax.set_ylabel("Embedding workers") + ax.set_title(f"{gpu_label} - {model_label} tuning sweep") + for worker_index in range(len(workers)): + for batch_index in range(len(batch_sizes)): + value = grid[worker_index, batch_index] + if not math.isnan(value): + ax.text(batch_index, worker_index, f"{value:,.1f}", ha="center", va="center", fontsize=8) + fig.colorbar(image, ax=ax, label="Tiles / second") + fig.tight_layout() + output_path.parent.mkdir(parents=True, exist_ok=True) + fig.savefig(output_path, dpi=300, bbox_inches="tight") + plt.close(fig) + + +def _prepare_chart_outputs(rows: list[dict[str, Any]], output_dir: Path) -> int: + if not rows: + print("No rows available for chart generation.", file=sys.stderr) + return 1 + + if "mean_tiles_per_second" in rows[0]: + aggregated_rows = rows + best_rows = select_best_results(aggregated_rows) + else: + aggregated_rows = aggregate_trial_results(rows) + best_rows = select_best_results(aggregated_rows) + + save_csv(best_rows, output_dir / "best_results.csv") + plot_throughput_by_gpu(best_rows, output_dir / "throughput_by_gpu.png") + plot_throughput_by_gpu_and_size(best_rows, output_dir / "throughput_by_gpu_and_size.png") + for gpu_label, model_label in sorted({(row["gpu_label"], row.get("model_label", "")) for row in aggregated_rows}): + sanitized_gpu = sanitize_label(str(gpu_label)) + sanitized_model = sanitize_label(str(model_label)) + plot_tuning_grid( + aggregated_rows, + gpu_label=gpu_label, + model_label=model_label, + output_path=output_dir / f"tuning_{sanitized_gpu}_{sanitized_model}.png", + ) + return 0 + + +def run_benchmark(args: argparse.Namespace) -> int: + try: + model_specs = resolve_model_specs(args) + except ValueError as exc: + print(f"ERROR: {exc}", file=sys.stderr) + return 1 + + configs_by_path = {spec["config_file"]: _load_cli_merged_config(spec["config_file"]) for spec in model_specs} + first_config = configs_by_path[model_specs[0]["config_file"]] + manifest_path = args.csv or (Path(first_config["csv"]) if first_config.get("csv") else None) + if manifest_path is None: + print("ERROR: provide --csv or set csv in the config file.", file=sys.stderr) + return 1 + + all_slides = load_slides_from_csv(manifest_path) + if not all_slides: + print("ERROR: the manifest is empty.", file=sys.stderr) + return 1 + + target_count = args.n_slides if args.n_slides > 0 else len(all_slides) + balanced = build_balanced_sample(all_slides, n_slides=min(target_count, len(all_slides)), seed=args.seed) + if args.copy_locally: + balanced = copy_slides_locally(balanced, args.local_dir) + + output_dir = args.output_dir + output_dir.mkdir(parents=True, exist_ok=True) + shared_manifest_path = output_dir / "sampled_slides.csv" + write_slides_csv(balanced, shared_manifest_path) + gpu_label = _resolve_gpu_label(args.gpu_label) + trial_plan = build_trial_plan( + output_root=output_dir, + model_specs=model_specs, + batch_sizes=list(args.batch_sizes), + embedding_workers=list(args.embedding_workers), + num_gpus=list(args.num_gpus), + repeat=int(args.repeat), + ) + + trial_rows: list[dict[str, Any]] = [] + total_measured = sum(item["kind"] == "measure" for item in trial_plan) + measured_index = 0 + for trial_spec in trial_plan: + trial_spec = dict(trial_spec) + trial_spec["shared_csv_path"] = shared_manifest_path + label = ( + f"{trial_spec['model_label']} [{trial_spec['size_label']}] " + f"bs={trial_spec['batch_size']} workers={trial_spec['embedding_workers']} gpus={trial_spec.get('num_gpus', 1)}" + ) + if trial_spec["kind"] == "warmup": + print(f"Warmup: {label}") + else: + measured_index += 1 + print(f"[{measured_index}/{total_measured}] {label} repeat={trial_spec['repeat_index']}") + + row = run_trial( + trial_spec=trial_spec, + slides=balanced, + shared_csv_path=shared_manifest_path, + base_config=configs_by_path[Path(trial_spec["config_file"])], + gpu_label=gpu_label, + ) + if trial_spec["kind"] == "measure": + trial_rows.append(row) + status = "OK" if row["exit_code"] == 0 else f"exit={row['exit_code']}" + print( + f" -> {row['total_tiles']:,} tiles in {row['end_to_end_seconds']:.2f}s " + f"({row['tiles_per_second']:,.1f} tiles/s, " + f"loader={row.get('loader_wait_fraction', 0.0) * 100:.1f}% " + f"wait={row.get('mean_loader_wait_ms', 0.0):.1f}ms " + f"ready={row.get('mean_ready_wait_ms', 0.0):.1f}ms " + f"prep={row.get('mean_preprocess_ms', 0.0):.1f}ms " + f"fwd={row.get('mean_forward_ms', 0.0):.1f}ms) [{status}]" + ) + elif row["exit_code"] != 0: + message = f"Warmup failed for {label}. See {trial_spec['run_dir'] / 'harness.log'}" + if row.get("error"): + message += f". Error: {row['error']}" + print(message, file=sys.stderr) + return int(row["exit_code"]) + + trial_results_path = output_dir / "trial_results.csv" + save_csv(trial_rows, trial_results_path) + print(f"Saved raw trial results to {trial_results_path}") + return _prepare_chart_outputs(trial_rows, output_dir) + + +def main() -> int: + args = parse_args() + if args.internal_harness: + return _run_internal_harness(args) + if args.chart_only is not None: + chart_output_dir = args.output_dir or args.chart_only[0].parent + rows = load_trial_results_csvs(list(args.chart_only)) + return _prepare_chart_outputs(rows, chart_output_dir) + return run_benchmark(args) + + +if __name__ == "__main__": + raise SystemExit(main()) diff --git a/scripts/benchmark_end_to_end_paths.py b/scripts/benchmark_end_to_end_paths.py new file mode 100644 index 0000000..9b63295 --- /dev/null +++ b/scripts/benchmark_end_to_end_paths.py @@ -0,0 +1,1008 @@ +#!/usr/bin/env python3 +"""Benchmark slide2vec full pipelines across tar and on-the-fly read modes.""" + +from __future__ import annotations + +import argparse +import csv +import copy +import json +import shutil +import statistics +import subprocess +import sys +import time +from pathlib import Path +from typing import Any + +import numpy as np +from types import SimpleNamespace + + +REPO_ROOT = Path(__file__).resolve().parents[1] +DEFAULT_OUTPUT_DIR = Path("output/benchmark-end-to-end-paths") +HEAVY_ARTIFACT_DIRS = ( + "tiles", + "coordinates", + "tile_embeddings", + "slide_embeddings", + "slide_latents", + "previews", +) + +ALL_MODES = ["tar", "wsd_single", "cucim_supertiles"] + +MODE_CONFIGS: dict[str, dict[str, Any]] = { + "tar": dict( + on_the_fly=False, + backend="cucim", + use_supertiles=True, + adaptive_batching=False, + jpeg_backend="turbojpeg", + ), + "wsd_single": dict( + on_the_fly=True, + backend="asap", + use_supertiles=False, + adaptive_batching=False, + jpeg_backend="turbojpeg", + ), + "cucim_supertiles": dict( + on_the_fly=True, + backend="cucim", + use_supertiles=True, + adaptive_batching=False, + jpeg_backend="turbojpeg", + ), +} + +MODE_DISPLAY_LABELS = { + "tar": "tar path", + "wsd_single": "wsd single", + "cucim_supertiles": "cucim supertiles", +} + + +def _prepend_repo_root_to_sys_path(paths: list[str]) -> list[str]: + repo_root = str(REPO_ROOT) + return [repo_root, *[path for path in paths if path != repo_root]] + + +sys.path[:] = _prepend_repo_root_to_sys_path(sys.path) + + +def parse_args() -> argparse.Namespace: + parser = argparse.ArgumentParser( + description="Benchmark full slide2vec pipelines for tar, on-the-fly wsd single-tile reads, and on-the-fly cucim supertiles.", + formatter_class=argparse.ArgumentDefaultsHelpFormatter, + ) + parser.add_argument("--csv", type=Path, required=False, help="Slide manifest CSV.") + parser.add_argument("--config-file", type=Path, required=False, help="Base slide2vec YAML config.") + parser.add_argument("--repeat", type=int, default=1, help="Timed repetitions per mode.") + parser.add_argument("--warmup", type=int, default=0, help="Untimed warmup reps per mode.") + parser.add_argument("--batch-size", type=int, default=256, help="Embedding batch size.") + parser.add_argument("--num-dataloader-workers", type=int, default=32, help="Tar-path DataLoader workers.") + parser.add_argument("--num-cucim-workers", type=int, default=4, help="cucim internal threads per read_region call.") + parser.add_argument("--num-preprocessing-workers", type=int, default=8, help="Workers for hs2p tiling phase.") + parser.add_argument("--output-dir", type=Path, default=DEFAULT_OUTPUT_DIR, help="Results directory.") + parser.add_argument( + "--chart-only", + type=Path, + nargs="+", + default=None, + metavar="TRIAL_RESULTS_CSV", + help="Skip benchmarking and regenerate charts from existing trial-results CSV files.", + ) + + parser.add_argument("--internal-harness", action="store_true", help=argparse.SUPPRESS) + parser.add_argument("--harness-config", type=Path, default=None, help=argparse.SUPPRESS) + parser.add_argument("--metrics-json", type=Path, default=None, help=argparse.SUPPRESS) + parser.add_argument("--progress-jsonl", type=Path, default=None, help=argparse.SUPPRESS) + return parser.parse_args() + + +def load_slides_from_csv(csv_path: Path) -> list[dict[str, Any]]: + slides: list[dict[str, Any]] = [] + with csv_path.open(newline="") as handle: + reader = csv.DictReader(handle) + fieldnames = set(reader.fieldnames or []) + has_mask = "mask_path" in fieldnames + has_spacing = "spacing_at_level_0" in fieldnames + has_sample_id = "sample_id" in fieldnames + for row in reader: + image_path = Path(row["image_path"]) + mask_path = Path(row["mask_path"]) if has_mask and row.get("mask_path") else None + raw_spacing = row.get("spacing_at_level_0", "") if has_spacing else "" + spacing_at_level_0 = float(raw_spacing) if raw_spacing.strip() else None + sample_id = row["sample_id"] if has_sample_id else image_path.stem + slides.append( + { + "sample_id": sample_id, + "image_path": image_path, + "mask_path": mask_path, + "spacing_at_level_0": spacing_at_level_0, + } + ) + return slides + + +def write_slides_csv(slides: list[dict[str, Any]], path: Path) -> None: + path.parent.mkdir(parents=True, exist_ok=True) + has_spacing = any(slide.get("spacing_at_level_0") is not None for slide in slides) + fieldnames = ["sample_id", "image_path", "mask_path"] + if has_spacing: + fieldnames.append("spacing_at_level_0") + with path.open("w", newline="") as handle: + writer = csv.DictWriter(handle, fieldnames=fieldnames) + writer.writeheader() + for slide in slides: + row: dict[str, Any] = { + "sample_id": slide["sample_id"], + "image_path": str(slide["image_path"]), + "mask_path": str(slide["mask_path"]) if slide["mask_path"] is not None else "", + } + if has_spacing: + row["spacing_at_level_0"] = slide.get("spacing_at_level_0") or "" + writer.writerow(row) + + +def _load_yaml(path: Path) -> dict[str, Any]: + import yaml + + with path.open(encoding="utf-8") as handle: + data = yaml.safe_load(handle) or {} + if not isinstance(data, dict): + raise ValueError(f"Expected mapping config in {path}") + return data + + +def _to_namespace(value: Any) -> Any: + if isinstance(value, dict): + return SimpleNamespace(**{key: _to_namespace(item) for key, item in value.items()}) + if isinstance(value, list): + return [_to_namespace(item) for item in value] + return value + + +def _to_plain_data(value: Any) -> Any: + if value.__class__.__module__.startswith("omegaconf"): + from omegaconf import OmegaConf + + return OmegaConf.to_container(value, resolve=True) + if isinstance(value, SimpleNamespace): + return {key: _to_plain_data(item) for key, item in vars(value).items()} + if isinstance(value, dict): + return {key: _to_plain_data(item) for key, item in value.items()} + if isinstance(value, list): + return [_to_plain_data(item) for item in value] + if isinstance(value, Path): + return str(value) + return value + + +def _load_cli_merged_config(path: Path) -> dict[str, Any]: + from slide2vec.utils.config import get_cfg_from_args + + cfg = get_cfg_from_args( + argparse.Namespace( + config_file=str(path), + output_dir=None, + opts=[], + ) + ) + plain = _to_plain_data(cfg) + if not isinstance(plain, dict): + raise ValueError(f"Expected mapping config in {path}") + return plain + + +def _write_yaml(data: dict[str, Any], path: Path) -> None: + import yaml + + path.parent.mkdir(parents=True, exist_ok=True) + with path.open("w", encoding="utf-8") as handle: + yaml.safe_dump(data, handle, sort_keys=False) + + +def _default_base_config( + *, + csv_path: Path, + output_dir: Path, + batch_size: int, + num_dataloader_workers: int, + num_preprocessing_workers: int, + num_cucim_workers: int, +) -> dict[str, Any]: + return { + "csv": str(csv_path), + "output_dir": str(output_dir), + "resume": False, + "save_previews": False, + "model": {"batch_size": batch_size}, + "tiling": { + "on_the_fly": True, + "gpu_decode": False, + "adaptive_batching": False, + "use_supertiles": True, + "jpeg_backend": "turbojpeg", + "backend": "cucim", + "read_coordinates_from": None, + "read_tiles_from": None, + "params": { + "target_spacing_um": 0.5, + "tolerance": 0.05, + "target_tile_size_px": 256, + "overlap": 0.0, + "tissue_threshold": 0.01, + "drop_holes": False, + "use_padding": True, + }, + "seg_params": { + "downsample": 64, + "sthresh": 8, + "sthresh_up": 255, + "mthresh": 7, + "close": 4, + "use_otsu": False, + "use_hsv": True, + }, + "filter_params": { + "ref_tile_size": 256, + "a_t": 4, + "a_h": 2, + "max_n_holes": 8, + "filter_white": False, + "filter_black": False, + "white_threshold": 220, + "black_threshold": 25, + "fraction_threshold": 0.9, + }, + "preview": {"downsample": 32}, + }, + "speed": { + "precision": "fp32", + "num_preprocessing_workers": num_preprocessing_workers, + "num_dataloader_workers": num_dataloader_workers, + "num_cucim_workers": num_cucim_workers, + "prefetch_factor_embedding": 4, + "persistent_workers_embedding": True, + "gpu_batch_preprocessing": True, + }, + "wandb": {"enable": False}, + } + + +def _merge_base_config(base: dict[str, Any], config_file: Path | None) -> dict[str, Any]: + if config_file is None: + return base + file_data = _load_cli_merged_config(config_file) + merged = copy.deepcopy(file_data) + merged["csv"] = base["csv"] + merged["output_dir"] = base["output_dir"] + merged["resume"] = False + merged["save_previews"] = False + merged.setdefault("model", {})["batch_size"] = base["model"]["batch_size"] + merged.setdefault("speed", {}) + merged["speed"]["num_preprocessing_workers"] = base["speed"]["num_preprocessing_workers"] + merged["speed"]["num_dataloader_workers"] = base["speed"]["num_dataloader_workers"] + merged["speed"]["num_cucim_workers"] = base["speed"]["num_cucim_workers"] + merged.setdefault("tiling", {}) + merged["tiling"]["read_coordinates_from"] = None + merged["tiling"]["read_tiles_from"] = None + merged.setdefault("wandb", {})["enable"] = False + return merged + + +def _apply_mode_overrides(config: dict[str, Any], mode: str) -> dict[str, Any]: + import copy + + cfg = copy.deepcopy(config) + mode_cfg = MODE_CONFIGS[mode] + cfg["tiling"]["on_the_fly"] = mode_cfg["on_the_fly"] + cfg["tiling"]["backend"] = mode_cfg["backend"] + cfg["tiling"]["use_supertiles"] = mode_cfg["use_supertiles"] + cfg["tiling"]["adaptive_batching"] = mode_cfg["adaptive_batching"] + cfg["tiling"]["jpeg_backend"] = mode_cfg["jpeg_backend"] + cfg["tiling"]["read_coordinates_from"] = None + cfg["tiling"]["read_tiles_from"] = None + return cfg + + +def load_progress_records(path: Path) -> list[dict[str, Any]]: + if not path.is_file(): + return [] + records: list[dict[str, Any]] = [] + with path.open(encoding="utf-8") as handle: + for line in handle: + line = line.strip() + if not line: + continue + payload = json.loads(line) + if isinstance(payload, dict): + records.append(payload) + return records + + +def extract_stage_seconds(progress_path: Path) -> dict[str, float | None]: + records = load_progress_records(progress_path) + stage_seconds: dict[str, float | None] = { + "tiling_seconds": None, + "embedding_seconds": None, + } + if not records: + return stage_seconds + first_timestamps: dict[str, float] = {} + for record in records: + kind = record.get("kind") + timestamp = record.get("timestamp") + if kind is None or timestamp is None: + continue + if kind not in first_timestamps: + first_timestamps[kind] = float(timestamp) + if "tiling.started" in first_timestamps and "tiling.finished" in first_timestamps: + stage_seconds["tiling_seconds"] = round(first_timestamps["tiling.finished"] - first_timestamps["tiling.started"], 4) + if "embedding.started" in first_timestamps and "embedding.finished" in first_timestamps: + stage_seconds["embedding_seconds"] = round(first_timestamps["embedding.finished"] - first_timestamps["embedding.started"], 4) + return stage_seconds + + +def extract_batch_timing_metrics(progress_path: Path) -> dict[str, float | int]: + records = load_progress_records(progress_path) + batch_payloads = [ + record.get("payload", {}) + for record in records + if record.get("kind") == "embedding.batch.timing" and isinstance(record.get("payload"), dict) + ] + zeros: dict[str, float | int] = { + "timed_batches": 0, + "mean_loader_wait_ms": 0.0, + "max_loader_wait_ms": 0.0, + "mean_ready_wait_ms": 0.0, + "mean_preprocess_ms": 0.0, + "mean_worker_batch_ms": 0.0, + "mean_reader_open_ms": 0.0, + "mean_reader_read_ms": 0.0, + "mean_forward_ms": 0.0, + "data_pipeline_seconds": 0.0, + "forward_seconds": 0.0, + "accounted_embedding_seconds": 0.0, + "data_pipeline_fraction": 0.0, + "forward_fraction": 0.0, + "loader_wait_fraction": 0.0, + "gpu_busy_fraction": 0.0, + } + if not batch_payloads: + return zeros + loader_wait_ms = [float(p.get("loader_wait_ms", 0.0)) for p in batch_payloads] + ready_wait_ms = [float(p.get("ready_wait_ms", 0.0)) for p in batch_payloads] + preprocess_ms = [float(p.get("preprocess_ms", 0.0)) for p in batch_payloads] + worker_batch_ms = [float(p.get("worker_batch_ms", 0.0)) for p in batch_payloads] + reader_open_ms = [float(p.get("reader_open_ms", 0.0)) for p in batch_payloads] + reader_read_ms = [float(p.get("reader_read_ms", 0.0)) for p in batch_payloads] + forward_ms = [float(p.get("forward_ms", 0.0)) for p in batch_payloads] + gpu_busy_fraction = [float(p.get("gpu_busy_fraction", 0.0)) for p in batch_payloads] + total_loader_ms = sum(loader_wait_ms) + sum(ready_wait_ms) + total_data_pipeline_ms = total_loader_ms + sum(preprocess_ms) + total_forward_ms = sum(forward_ms) + total_ms = total_data_pipeline_ms + total_forward_ms + return { + "timed_batches": len(batch_payloads), + "mean_loader_wait_ms": round(statistics.mean(loader_wait_ms), 4), + "max_loader_wait_ms": round(max(loader_wait_ms), 4), + "mean_ready_wait_ms": round(statistics.mean(ready_wait_ms), 4), + "mean_preprocess_ms": round(statistics.mean(preprocess_ms), 4), + "mean_worker_batch_ms": round(statistics.mean(worker_batch_ms), 4), + "mean_reader_open_ms": round(statistics.mean(reader_open_ms), 4), + "mean_reader_read_ms": round(statistics.mean(reader_read_ms), 4), + "mean_forward_ms": round(statistics.mean(forward_ms), 4), + "data_pipeline_seconds": round(total_data_pipeline_ms / 1000.0, 4), + "forward_seconds": round(total_forward_ms / 1000.0, 4), + "accounted_embedding_seconds": round(total_ms / 1000.0, 4), + "data_pipeline_fraction": round(total_data_pipeline_ms / total_ms, 4) if total_ms > 0 else 0.0, + "forward_fraction": round(total_forward_ms / total_ms, 4) if total_ms > 0 else 0.0, + "loader_wait_fraction": round(total_loader_ms / total_ms, 0 if total_ms <= 0 else 4) if total_ms > 0 else 0.0, + "gpu_busy_fraction": round(statistics.mean(gpu_busy_fraction), 4), + } + + +def parse_process_list(path: Path) -> dict[str, int]: + if not path.is_file(): + return {"slides_total": 0, "slides_with_tiles": 0, "failed_slides": 0, "total_tiles": 0} + with path.open(newline="") as handle: + rows = list(csv.DictReader(handle)) + total_tiles = sum(int(float(row.get("num_tiles") or 0)) for row in rows) + slides_with_tiles = sum(int(float(row.get("num_tiles") or 0)) > 0 for row in rows) + failed_slides = sum(row.get("tiling_status") == "failed" for row in rows) + return { + "slides_total": len(rows), + "slides_with_tiles": slides_with_tiles, + "failed_slides": failed_slides, + "total_tiles": total_tiles, + } + + +def save_csv(rows: list[dict[str, Any]], path: Path) -> None: + if not rows: + return + path.parent.mkdir(parents=True, exist_ok=True) + with path.open("w", newline="") as handle: + writer = csv.DictWriter(handle, fieldnames=list(rows[0].keys())) + writer.writeheader() + writer.writerows(rows) + + +def _build_pipeline_from_config_dict(config: dict[str, Any]): + from slide2vec import ExecutionOptions, Model, Pipeline, PreprocessingConfig + + model_cfg = config.get("model", {}) + tiling_cfg = config.get("tiling", {}) + params = tiling_cfg.get("params", {}) + preview = dict(tiling_cfg.get("preview", {})) + speed_cfg = config.get("speed", {}) + + preprocessing = PreprocessingConfig( + backend=str(tiling_cfg.get("backend", "cucim")), + target_spacing_um=float(params.get("target_spacing_um", 0.5)), + target_tile_size_px=int(params.get("target_tile_size_px", 256)), + tolerance=float(params.get("tolerance", 0.05)), + overlap=float(params.get("overlap", 0.0)), + tissue_threshold=float(params.get("tissue_threshold", 0.01)), + drop_holes=bool(params.get("drop_holes", False)), + use_padding=bool(params.get("use_padding", True)), + read_coordinates_from=None, + read_tiles_from=None, + on_the_fly=bool(tiling_cfg.get("on_the_fly", True)), + gpu_decode=bool(tiling_cfg.get("gpu_decode", False)), + adaptive_batching=bool(tiling_cfg.get("adaptive_batching", False)), + use_supertiles=bool(tiling_cfg.get("use_supertiles", True)), + jpeg_backend=str(tiling_cfg.get("jpeg_backend", "turbojpeg")), + num_cucim_workers=int(speed_cfg.get("num_cucim_workers", tiling_cfg.get("num_cucim_workers", 4))), + resume=bool(config.get("resume", False)), + segmentation=dict(tiling_cfg.get("seg_params", {})), + filtering=dict(tiling_cfg.get("filter_params", {})), + preview={ + "save_mask_preview": bool(config.get("save_previews", False)), + "save_tiling_preview": bool(config.get("save_previews", False)), + "downsample": int(preview.get("downsample", 32)), + }, + ) + execution = ExecutionOptions( + output_dir=Path(config["output_dir"]), + batch_size=int(model_cfg.get("batch_size", 256)), + num_workers=int(speed_cfg.get("num_dataloader_workers", speed_cfg.get("num_workers_embedding", 32))), + num_preprocessing_workers=int(speed_cfg.get("num_preprocessing_workers", 8)), + precision=str(speed_cfg.get("precision", "fp32")), + prefetch_factor=int(speed_cfg.get("prefetch_factor_embedding", 4)), + persistent_workers=bool(speed_cfg.get("persistent_workers_embedding", True)), + gpu_batch_preprocessing=bool(speed_cfg.get("gpu_batch_preprocessing", True)), + save_tile_embeddings=bool(model_cfg.get("save_tile_embeddings", False)), + save_latents=bool(model_cfg.get("save_latents", False)), + ) + model = Model.from_pretrained( + str(model_cfg["name"]), + level=model_cfg.get("level", "tile"), + mode=model_cfg.get("mode"), + arch=model_cfg.get("arch"), + pretrained_weights=model_cfg.get("pretrained_weights"), + input_size=model_cfg.get("input_size"), + patch_size=model_cfg.get("patch_size"), + token_size=model_cfg.get("token_size"), + normalize_embeddings=model_cfg.get("normalize_embeddings"), + device="auto", + ) + return Pipeline(model=model, preprocessing=preprocessing, execution=execution) + + +def _run_internal_harness(args: argparse.Namespace) -> int: + if args.harness_config is None or args.metrics_json is None or args.progress_jsonl is None: + raise ValueError("--internal-harness requires --harness-config, --metrics-json, and --progress-jsonl") + + from slide2vec.progress import JsonlProgressReporter, activate_progress_reporter + + config = _load_yaml(args.harness_config) + pipeline = _build_pipeline_from_config_dict(config) + output_dir = Path(config["output_dir"]) + progress_path = Path(args.progress_jsonl) + metrics_path = Path(args.metrics_json) + progress_path.parent.mkdir(parents=True, exist_ok=True) + metrics_path.parent.mkdir(parents=True, exist_ok=True) + + reporter = JsonlProgressReporter(progress_path) + t0 = time.perf_counter() + try: + with activate_progress_reporter(reporter): + result = pipeline.run(manifest_path=config["csv"]) + end_to_end_seconds = time.perf_counter() - t0 + process_stats = parse_process_list(output_dir / "process_list.csv") + metrics = { + "success": True, + "tile_artifacts": len(result.tile_artifacts), + "slide_artifacts": len(result.slide_artifacts), + "slides_total": int(process_stats["slides_total"]), + "slides_with_tiles": int(process_stats["slides_with_tiles"]), + "failed_slides": int(process_stats["failed_slides"]), + "total_tiles": int(process_stats["total_tiles"]), + "end_to_end_seconds": round(end_to_end_seconds, 4), + "tiles_per_second": round(process_stats["total_tiles"] / end_to_end_seconds, 4) if end_to_end_seconds > 0 else 0.0, + **extract_stage_seconds(progress_path), + **extract_batch_timing_metrics(progress_path), + } + except Exception as exc: + end_to_end_seconds = time.perf_counter() - t0 + process_stats = parse_process_list(output_dir / "process_list.csv") + metrics = { + "success": False, + "error": str(exc), + "slides_total": int(process_stats["slides_total"]), + "slides_with_tiles": int(process_stats["slides_with_tiles"]), + "failed_slides": int(process_stats["failed_slides"]), + "total_tiles": int(process_stats["total_tiles"]), + "end_to_end_seconds": round(end_to_end_seconds, 4), + "tiles_per_second": 0.0, + **extract_stage_seconds(progress_path), + **extract_batch_timing_metrics(progress_path), + } + metrics_path.write_text(json.dumps(metrics, indent=2, sort_keys=True) + "\n", encoding="utf-8") + return 1 + + metrics_path.write_text(json.dumps(metrics, indent=2, sort_keys=True) + "\n", encoding="utf-8") + return 0 + + +def _run_trial_subprocess(*, config_path: Path, metrics_path: Path, progress_path: Path, log_path: Path) -> subprocess.CompletedProcess[str]: + command = [ + sys.executable, + str(Path(__file__).resolve()), + "--internal-harness", + "--harness-config", + str(config_path), + "--metrics-json", + str(metrics_path), + "--progress-jsonl", + str(progress_path), + ] + completed = subprocess.run(command, cwd=REPO_ROOT, capture_output=True, text=True) + log_path.write_text((completed.stdout or "") + (completed.stderr or ""), encoding="utf-8") + return completed + + +def cleanup_trial_output(output_dir: Path) -> None: + for dirname in HEAVY_ARTIFACT_DIRS: + candidate = output_dir / dirname + if candidate.exists(): + shutil.rmtree(candidate) + + +def reset_trial_run_dir(run_dir: Path) -> None: + if run_dir.exists(): + shutil.rmtree(run_dir) + run_dir.mkdir(parents=True, exist_ok=True) + + +def run_trial(*, mode: str, kind: str, repeat_index: int, run_dir: Path, config: dict[str, Any]) -> dict[str, Any]: + reset_trial_run_dir(run_dir) + config_path = run_dir / "config.yaml" + progress_path = run_dir / "progress.jsonl" + metrics_path = run_dir / "metrics.json" + log_path = run_dir / "harness.log" + trial_output_dir = run_dir / "output" + + trial_config = _apply_mode_overrides(config, mode) + trial_config["output_dir"] = str(trial_output_dir) + _write_yaml(trial_config, config_path) + + completed = _run_trial_subprocess( + config_path=config_path, + metrics_path=metrics_path, + progress_path=progress_path, + log_path=log_path, + ) + metrics = json.loads(metrics_path.read_text(encoding="utf-8")) if metrics_path.is_file() else {} + cleanup_trial_output(trial_output_dir) + return { + "mode": mode, + "kind": kind, + "repeat_index": repeat_index, + "exit_code": int(completed.returncode), + "slides_total": int(metrics.get("slides_total", 0)), + "slides_with_tiles": int(metrics.get("slides_with_tiles", 0)), + "failed_slides": int(metrics.get("failed_slides", 0)), + "total_tiles": int(metrics.get("total_tiles", 0)), + "end_to_end_seconds": float(metrics.get("end_to_end_seconds", 0.0)), + "tiles_per_second": float(metrics.get("tiles_per_second", 0.0)), + "tiling_seconds": metrics.get("tiling_seconds") or "", + "embedding_seconds": metrics.get("embedding_seconds") or "", + "timed_batches": int(metrics.get("timed_batches", 0)), + "mean_loader_wait_ms": float(metrics.get("mean_loader_wait_ms", 0.0)), + "max_loader_wait_ms": float(metrics.get("max_loader_wait_ms", 0.0)), + "mean_ready_wait_ms": float(metrics.get("mean_ready_wait_ms", 0.0)), + "mean_preprocess_ms": float(metrics.get("mean_preprocess_ms", 0.0)), + "mean_worker_batch_ms": float(metrics.get("mean_worker_batch_ms", 0.0)), + "mean_reader_open_ms": float(metrics.get("mean_reader_open_ms", 0.0)), + "mean_reader_read_ms": float(metrics.get("mean_reader_read_ms", 0.0)), + "mean_forward_ms": float(metrics.get("mean_forward_ms", 0.0)), + "data_pipeline_seconds": float(metrics.get("data_pipeline_seconds", 0.0)), + "forward_seconds": float(metrics.get("forward_seconds", 0.0)), + "accounted_embedding_seconds": float(metrics.get("accounted_embedding_seconds", 0.0)), + "data_pipeline_fraction": float(metrics.get("data_pipeline_fraction", 0.0)), + "forward_fraction": float(metrics.get("forward_fraction", 0.0)), + "loader_wait_fraction": float(metrics.get("loader_wait_fraction", 0.0)), + "gpu_busy_fraction": float(metrics.get("gpu_busy_fraction", 0.0)), + "error": metrics.get("error", ""), + } + + +def aggregate_trial_results(trial_rows: list[dict[str, Any]]) -> list[dict[str, Any]]: + grouped: dict[str, list[dict[str, Any]]] = {} + for row in trial_rows: + if row.get("exit_code", 0) not in (0, "", None): + continue + grouped.setdefault(str(row["mode"]), []).append(row) + aggregated: list[dict[str, Any]] = [] + for mode in ALL_MODES: + rows = grouped.get(mode) + if not rows: + continue + end_to_end_seconds = [float(r["end_to_end_seconds"]) for r in rows] + tiles_per_second = [float(r["tiles_per_second"]) for r in rows] + tiling_seconds = [float(r["tiling_seconds"]) for r in rows if r.get("tiling_seconds") not in ("", None)] + embedding_seconds = [float(r["embedding_seconds"]) for r in rows if r.get("embedding_seconds") not in ("", None)] + loader_wait_ms = [float(r.get("mean_loader_wait_ms", 0.0)) for r in rows] + forward_ms = [float(r.get("mean_forward_ms", 0.0)) for r in rows] + data_pipeline_seconds = [float(r.get("data_pipeline_seconds", 0.0)) for r in rows] + forward_seconds = [float(r.get("forward_seconds", 0.0)) for r in rows] + accounted_embedding_seconds = [float(r.get("accounted_embedding_seconds", 0.0)) for r in rows] + data_pipeline_fraction = [float(r.get("data_pipeline_fraction", 0.0)) for r in rows] + forward_fraction = [float(r.get("forward_fraction", 0.0)) for r in rows] + gpu_busy_fraction = [float(r.get("gpu_busy_fraction", 0.0)) for r in rows] + aggregated.append( + { + "mode": mode, + "repeat_count": len(rows), + "total_tiles": int(rows[0].get("total_tiles", 0)), + "mean_end_to_end_seconds": round(statistics.mean(end_to_end_seconds), 4), + "std_end_to_end_seconds": round(statistics.pstdev(end_to_end_seconds), 4) if len(end_to_end_seconds) > 1 else 0.0, + "mean_tiles_per_second": round(statistics.mean(tiles_per_second), 4), + "std_tiles_per_second": round(statistics.pstdev(tiles_per_second), 4) if len(tiles_per_second) > 1 else 0.0, + "mean_tiling_seconds": round(statistics.mean(tiling_seconds), 4) if tiling_seconds else "", + "mean_embedding_seconds": round(statistics.mean(embedding_seconds), 4) if embedding_seconds else "", + "mean_loader_wait_ms": round(statistics.mean(loader_wait_ms), 4), + "mean_forward_ms": round(statistics.mean(forward_ms), 4), + "mean_data_pipeline_seconds": round(statistics.mean(data_pipeline_seconds), 4), + "mean_forward_seconds": round(statistics.mean(forward_seconds), 4), + "mean_accounted_embedding_seconds": round(statistics.mean(accounted_embedding_seconds), 4), + "mean_data_pipeline_fraction": round(statistics.mean(data_pipeline_fraction), 4), + "mean_forward_fraction": round(statistics.mean(forward_fraction), 4), + "gpu_busy_fraction": round(statistics.mean(gpu_busy_fraction), 4), + } + ) + return aggregated + + +def plot_end_to_end_by_path(summary_rows: list[dict[str, Any]], output_path: Path) -> None: + if not summary_rows: + return + import matplotlib + + matplotlib.use("Agg") + import matplotlib.pyplot as plt + + labels = [MODE_DISPLAY_LABELS.get(str(r["mode"]), str(r["mode"])) for r in summary_rows] + values = [float(r["mean_end_to_end_seconds"]) for r in summary_rows] + errors = [float(r.get("std_end_to_end_seconds", 0.0)) for r in summary_rows] + x_pos = np.arange(len(summary_rows)) + palette = ["#4C72B0", "#C44E52", "#55A868", "#8172B2"] + colors = [palette[idx % len(palette)] for idx in range(len(summary_rows))] + fig, ax = plt.subplots(figsize=(7.0, 4.5)) + bars = ax.bar(x_pos, values, yerr=errors, capsize=4, width=0.6, color=colors) + for bar, value in zip(bars, values): + ax.text(bar.get_x() + bar.get_width() / 2, value, f"{value:.1f}s", ha="center", va="bottom", fontsize=9) + ax.set_ylabel("End-to-end seconds") + ax.set_title("End-to-End Time by Pipeline Path") + ax.set_xticks(x_pos, labels=labels) + ax.grid(axis="y", color="#e8e8e8", linewidth=0.8) + ax.set_axisbelow(True) + fig.tight_layout() + output_path.parent.mkdir(parents=True, exist_ok=True) + fig.savefig(output_path, dpi=300, bbox_inches="tight") + plt.close(fig) + + +def plot_stage_breakdown(summary_rows: list[dict[str, Any]], output_path: Path) -> None: + if not summary_rows: + return + import matplotlib + + matplotlib.use("Agg") + import matplotlib.pyplot as plt + + labels = [MODE_DISPLAY_LABELS.get(str(r["mode"]), str(r["mode"])) for r in summary_rows] + tiling = [float(r.get("mean_tiling_seconds") or 0.0) for r in summary_rows] + embedding = [float(r.get("mean_embedding_seconds") or 0.0) for r in summary_rows] + x_pos = np.arange(len(summary_rows)) + fig, ax = plt.subplots(figsize=(7.0, 4.5)) + ax.bar(x_pos, tiling, 0.6, label="Tiling", color="#4C72B0") + ax.bar(x_pos, embedding, 0.6, bottom=tiling, label="Embedding", color="#55A868") + ax.set_ylabel("Seconds") + ax.set_title("Stage Breakdown by Pipeline Path") + ax.set_xticks(x_pos, labels=labels) + ax.grid(axis="y", color="#e8e8e8", linewidth=0.8) + ax.set_axisbelow(True) + ax.legend(loc="upper right", fontsize=9) + fig.tight_layout() + output_path.parent.mkdir(parents=True, exist_ok=True) + fig.savefig(output_path, dpi=300, bbox_inches="tight") + plt.close(fig) + + +def plot_embedding_subpath_breakdown(summary_rows: list[dict[str, Any]], output_path: Path) -> None: + if not summary_rows: + return + import matplotlib + + matplotlib.use("Agg") + import matplotlib.pyplot as plt + + labels = [MODE_DISPLAY_LABELS.get(str(r["mode"]), str(r["mode"])) for r in summary_rows] + data_pipeline = [float(r.get("mean_data_pipeline_seconds", 0.0)) for r in summary_rows] + forward = [float(r.get("mean_forward_seconds", 0.0)) for r in summary_rows] + x_pos = np.arange(len(summary_rows)) + + fig, ax = plt.subplots(figsize=(max(7.0, 1.8 * len(summary_rows)), 4.8)) + ax.bar(x_pos, data_pipeline, 0.6, label="Data pipeline", color="#4C72B0") + ax.bar(x_pos, forward, 0.6, bottom=data_pipeline, label="Model forward", color="#55A868") + ax.set_ylabel("Seconds across timed embedding batches") + ax.set_title("Embedding Subpath Breakdown by Pipeline Path") + ax.set_xticks(x_pos, labels=labels) + ax.grid(axis="y", color="#e8e8e8", linewidth=0.8) + ax.set_axisbelow(True) + ax.legend(loc="upper right", fontsize=9) + fig.tight_layout() + output_path.parent.mkdir(parents=True, exist_ok=True) + fig.savefig(output_path, dpi=300, bbox_inches="tight") + plt.close(fig) + + +def _prepare_chart_outputs(trial_rows: list[dict[str, Any]], output_dir: Path) -> list[dict[str, Any]]: + summary_rows = aggregate_trial_results(trial_rows) + save_csv(summary_rows, output_dir / "summary.csv") + plot_end_to_end_by_path(summary_rows, output_dir / "end_to_end_by_path.png") + plot_stage_breakdown(summary_rows, output_dir / "stage_breakdown.png") + plot_embedding_subpath_breakdown(summary_rows, output_dir / "embedding_subpath_breakdown.png") + return summary_rows + + +def _load_trial_results_csvs(paths: list[Path]) -> list[dict[str, Any]]: + rows: list[dict[str, Any]] = [] + for path in paths: + with path.open(newline="") as handle: + rows.extend(dict(row) for row in csv.DictReader(handle)) + int_fields = {"repeat_index", "exit_code", "slides_total", "slides_with_tiles", "failed_slides", "total_tiles", "timed_batches"} + float_fields = { + "end_to_end_seconds", + "tiles_per_second", + "tiling_seconds", + "embedding_seconds", + "mean_loader_wait_ms", + "max_loader_wait_ms", + "mean_ready_wait_ms", + "mean_preprocess_ms", + "mean_worker_batch_ms", + "mean_reader_open_ms", + "mean_reader_read_ms", + "mean_forward_ms", + "data_pipeline_seconds", + "forward_seconds", + "accounted_embedding_seconds", + "data_pipeline_fraction", + "forward_fraction", + "loader_wait_fraction", + "gpu_busy_fraction", + } + coerced: list[dict[str, Any]] = [] + for row in rows: + parsed: dict[str, Any] = {} + for key, value in row.items(): + if value == "": + parsed[key] = "" + elif key in int_fields: + parsed[key] = int(float(value)) + elif key in float_fields: + parsed[key] = float(value) + else: + parsed[key] = value + coerced.append(parsed) + return coerced + + +def _make_summary_table(summary_rows: list[dict[str, Any]]) -> str: + from rich.table import Table + + baseline_end_to_end: float | None = None + for row in summary_rows: + if str(row["mode"]) == "tar": + baseline_end_to_end = float(row["mean_end_to_end_seconds"]) + break + + table = Table(title="End-to-End Path Summary", show_lines=True) + table.add_column("Mode", style="bold") + table.add_column("End-to-end", justify="right") + table.add_column("vs tar", justify="right") + table.add_column("Tiles/s", justify="right") + table.add_column("Tiling", justify="right") + table.add_column("Embedding", justify="right") + table.add_column("Data path", justify="right") + table.add_column("Forward", justify="right") + table.add_column("Data share", justify="right") + table.add_column("GPU busy", justify="right") + table.add_column("Reps", justify="right", style="dim") + + for row in summary_rows: + mode = str(row["mode"]) + end_to_end = float(row["mean_end_to_end_seconds"]) + if baseline_end_to_end and baseline_end_to_end > 0: + relative = end_to_end / baseline_end_to_end + rel_style = "green" if relative <= 1.0 else "yellow" + relative_str = f"{relative:.2f}×" + else: + rel_style = "dim" + relative_str = "—" + table.add_row( + MODE_DISPLAY_LABELS.get(mode, mode), + f"{end_to_end:.1f}s", + f"[{rel_style}]{relative_str}[/{rel_style}]", + f"{float(row['mean_tiles_per_second']):,.1f}", + f"{float(row.get('mean_tiling_seconds') or 0.0):.1f}s", + f"{float(row.get('mean_embedding_seconds') or 0.0):.1f}s", + f"{float(row.get('mean_data_pipeline_seconds', 0.0)):.1f}s", + f"{float(row.get('mean_forward_seconds', 0.0)):.1f}s", + f"{100.0 * float(row.get('mean_data_pipeline_fraction', 0.0)):.1f}%", + f"{100.0 * float(row.get('gpu_busy_fraction', 0.0)):.1f}%", + str(int(row.get("repeat_count", 0))), + ) + return table + + +def _print_log_panel(console: "Any", log_path: Path, title: str = "Error log") -> None: + from rich.panel import Panel + + log = "" + if log_path.is_file(): + log = log_path.read_text(encoding="utf-8").strip() + if not log: + log = "(no output captured)" + console.print(Panel(log, title=f"[red]{title}[/]", border_style="red", highlight=False)) + + +def run_benchmark(args: argparse.Namespace) -> int: + from rich.console import Console + from rich.panel import Panel + from rich.progress import ( + BarColumn, + Progress, + SpinnerColumn, + TaskProgressColumn, + TextColumn, + TimeElapsedColumn, + ) + + console = Console() + output_dir = args.output_dir + output_dir.mkdir(parents=True, exist_ok=True) + if args.csv is None: + console.print("[red]ERROR:[/] --csv is required.") + return 1 + if args.config_file is None: + console.print("[red]ERROR:[/] --config-file is required.") + return 1 + slides = load_slides_from_csv(args.csv) + if not slides: + console.print("[red]ERROR:[/] manifest is empty.") + return 1 + shared_csv = output_dir / "slides.csv" + write_slides_csv(slides, shared_csv) + base = _default_base_config( + csv_path=shared_csv, + output_dir=output_dir / "trial_output", + batch_size=args.batch_size, + num_dataloader_workers=args.num_dataloader_workers, + num_preprocessing_workers=args.num_preprocessing_workers, + num_cucim_workers=args.num_cucim_workers, + ) + config = _merge_base_config(base, args.config_file) + + console.rule("[bold cyan]Benchmark") + console.print( + f" [bold]{len(ALL_MODES)}[/] paths · " + f"[bold]{args.repeat}[/] repeat · " + f"[bold]{args.warmup}[/] warmup · " + f"batch [bold]{args.batch_size}[/] · " + f"config [bold]{args.config_file.name}[/]" + ) + console.print() + + trial_rows: list[dict[str, Any]] = [] + total_trials = len(ALL_MODES) * (args.warmup + args.repeat) + with Progress( + SpinnerColumn(), + TextColumn("[progress.description]{task.description}"), + BarColumn(), + TaskProgressColumn(), + TimeElapsedColumn(), + console=console, + ) as progress: + overall_task = progress.add_task("[bold]Overall", total=total_trials) + trial_task = progress.add_task("", total=None) + + for mode in ALL_MODES: + mode_dir = output_dir / "runs" / mode + for rep_idx in range(args.warmup + args.repeat): + is_warmup = rep_idx < args.warmup + kind = "warmup" if is_warmup else "measure" + rep_num = rep_idx if is_warmup else rep_idx - args.warmup + 1 + run_dir = mode_dir / ("warmup" if is_warmup else f"rep-{rep_num:02d}") + + if is_warmup: + desc = f"[dim]warmup {mode}[/]" + else: + desc = f"[bold cyan]{mode}[/] rep [bold]{rep_num}[/]/{args.repeat}" + progress.update(trial_task, description=desc) + + row = run_trial(mode=mode, kind=kind, repeat_index=rep_num, run_dir=run_dir, config=config) + progress.advance(overall_task) + + ok = row["exit_code"] == 0 + icon = "[green]✓[/]" if ok else "[red]✗[/]" + if is_warmup: + progress.console.log(f"{icon} [dim]warmup[/] {mode} {row['end_to_end_seconds']:.1f}s") + if not ok: + _print_log_panel(progress.console, run_dir / "harness.log", title=f"warmup {mode} — error log") + else: + progress.console.log( + f"{icon} [bold]{mode}[/] rep {rep_num}/{args.repeat} " + f"{row['end_to_end_seconds']:.1f}s total " + f"[bold yellow]{row['tiles_per_second']:,.1f}[/] tiles/s" + + (f" [red]exit={row['exit_code']}[/]" if not ok else "") + ) + if not ok: + _print_log_panel(progress.console, run_dir / "harness.log", title=f"{mode} rep {rep_num} — error log") + trial_rows.append(row) + + progress.update(trial_task, visible=False) + + save_csv(trial_rows, output_dir / "trial_results.csv") + console.print(f"\n[dim]Trial results →[/] {output_dir / 'trial_results.csv'}") + + console.rule("[bold cyan]Results") + summary_rows = _prepare_chart_outputs(trial_rows, output_dir) + if not summary_rows: + return 1 + console.print(_make_summary_table(summary_rows)) + console.print( + Panel( + f"[dim]end_to_end_by_path.png[/]\n[dim]stage_breakdown.png[/]\n[dim]embedding_subpath_breakdown.png[/]\n[dim]summary.csv[/]", + title=f"[bold]Saved to[/] {output_dir}", + expand=False, + ) + ) + return 0 + + +def main() -> int: + args = parse_args() + if args.internal_harness: + return _run_internal_harness(args) + if args.chart_only: + from rich.console import Console + + console = Console() + summary_rows = _prepare_chart_outputs(_load_trial_results_csvs(args.chart_only), args.output_dir) + if not summary_rows: + return 1 + console.print(_make_summary_table(summary_rows)) + return 0 + return run_benchmark(args) + + +if __name__ == "__main__": + sys.exit(main()) diff --git a/scripts/benchmark_tile_read_strategies.py b/scripts/benchmark_tile_read_strategies.py new file mode 100644 index 0000000..c518dd3 --- /dev/null +++ b/scripts/benchmark_tile_read_strategies.py @@ -0,0 +1,1296 @@ +#!/usr/bin/env python3 +"""Benchmark tile reading strategies for slide2vec on-the-fly and tar paths. + +Compares five configurations in increasing order of optimization: + tar - pre-extracted tar archives (cucim+supertiles+turbojpeg extraction) + wsd_single - WSD per-tile reads (ASAP backend, no cucim) + cucim_single - cucim batched read_region, one location per tile + cucim_supertiles - cucim read_region per 8x8/4x4/2x2 super tile block + cucim_supertiles_adaptive - same + SuperTileBatchSampler aligns batches to block boundaries +""" + +from __future__ import annotations + +import argparse +import csv +import json +import math +import os +import shutil +import statistics +import subprocess +import sys +import time +from pathlib import Path +from typing import Any + +import numpy as np + + +REPO_ROOT = Path(__file__).resolve().parents[1] +DEFAULT_OUTPUT_DIR = Path("output/benchmark-read-strategies") +HEAVY_ARTIFACT_DIRS = ( + "tile_embeddings", + "slide_embeddings", + "slide_latents", + "previews", +) + +ALL_MODES = [ + "tar", + "wsd_single", + "wsd_supertiles", + "cucim_single", + "cucim_supertiles", + "cucim_supertiles_adaptive", +] + +MODE_CONFIGS: dict[str, dict[str, Any]] = { + "tar": dict( + on_the_fly=False, + backend="cucim", + use_supertiles=True, + adaptive_batching=False, + jpeg_backend="turbojpeg", + ), + "wsd_single": dict( + on_the_fly=True, + backend="asap", + use_supertiles=False, + adaptive_batching=False, + jpeg_backend="turbojpeg", + ), + "wsd_supertiles": dict( + on_the_fly=True, + backend="asap", + use_supertiles=True, + adaptive_batching=False, + jpeg_backend="turbojpeg", + ), + "cucim_single": dict( + on_the_fly=True, + backend="cucim", + use_supertiles=False, + adaptive_batching=False, + jpeg_backend="turbojpeg", + ), + "cucim_supertiles": dict( + on_the_fly=True, + backend="cucim", + use_supertiles=True, + adaptive_batching=False, + jpeg_backend="turbojpeg", + ), + "cucim_supertiles_adaptive": dict( + on_the_fly=True, + backend="cucim", + use_supertiles=True, + adaptive_batching=True, + jpeg_backend="turbojpeg", + ), +} + +MODE_DISPLAY_LABELS: dict[str, str] = { + "tar": "tar", + "wsd_single": "wsd\n(single)", + "wsd_supertiles": "wsd\n(supertiles)", + "cucim_single": "cucim\n(single)", + "cucim_supertiles": "cucim\n(supertiles)", + "cucim_supertiles_adaptive": "cucim\n(supertiles\nadaptive)", +} + + +def _prepend_repo_root_to_sys_path(paths: list[str]) -> list[str]: + repo_root = str(REPO_ROOT) + return [repo_root, *[path for path in paths if path != repo_root]] + + +sys.path[:] = _prepend_repo_root_to_sys_path(sys.path) + + +def parse_args() -> argparse.Namespace: + parser = argparse.ArgumentParser( + description="Benchmark tile reading strategies for slide2vec.", + formatter_class=argparse.ArgumentDefaultsHelpFormatter, + ) + parser.add_argument("--csv", type=Path, required=False, help="Slide manifest CSV.") + parser.add_argument("--config-file", type=Path, required=False, help="Base slide2vec YAML config (optional).") + parser.add_argument("--model", type=str, default="phikonv2", help="Model name.") + parser.add_argument( + "--modes", + nargs="+", + default=ALL_MODES, + choices=ALL_MODES, + metavar="MODE", + help=f"Reading strategies to benchmark. Choices: {', '.join(ALL_MODES)}", + ) + parser.add_argument("--repeat", type=int, default=3, help="Timed repetitions per mode.") + parser.add_argument("--warmup", type=int, default=1, help="Untimed warmup reps per mode.") + parser.add_argument("--batch-size", type=int, default=256, help="Fixed batch size.") + parser.add_argument( + "--batch-sizes", + type=int, + nargs="+", + default=None, + metavar="BATCH", + help="Optional batch-size sweep. When set, runs every mode for each listed batch size.", + ) + parser.add_argument( + "--num-dataloader-workers", + type=int, + default=32, + help="DataLoader workers for the tar path. On-the-fly path auto-derives from cpu_count // num-cucim-workers.", + ) + parser.add_argument("--num-cucim-workers", type=int, default=4, help="cucim internal threads per read_region call.") + parser.add_argument("--num-preprocessing-workers", type=int, default=8, help="Workers for hs2p tiling phase.") + parser.add_argument( + "--output-dir", + type=Path, + default=DEFAULT_OUTPUT_DIR, + help="Results directory.", + ) + parser.add_argument( + "--chart-only", + type=Path, + nargs="+", + default=None, + metavar="TRIAL_RESULTS_CSV", + help="Skip benchmarking and regenerate charts from existing trial-results CSV files.", + ) + + # Hidden flags for the subprocess harness + parser.add_argument("--internal-harness", action="store_true", help=argparse.SUPPRESS) + parser.add_argument("--harness-mode", type=str, default=None, help=argparse.SUPPRESS) + parser.add_argument("--harness-config", type=Path, default=None, help=argparse.SUPPRESS) + parser.add_argument("--metrics-json", type=Path, default=None, help=argparse.SUPPRESS) + parser.add_argument("--progress-jsonl", type=Path, default=None, help=argparse.SUPPRESS) + return parser.parse_args() + + +def _resolve_batch_sizes(args: argparse.Namespace) -> list[int]: + values = args.batch_sizes if args.batch_sizes else [args.batch_size] + resolved: list[int] = [] + for value in values: + batch_size = int(value) + if batch_size < 1: + raise ValueError("Batch sizes must be positive integers") + if batch_size not in resolved: + resolved.append(batch_size) + return resolved + + +# --------------------------------------------------------------------------- +# Slide loading helpers (reused from benchmark_embedding_throughput.py pattern) +# --------------------------------------------------------------------------- + +def load_slides_from_csv(csv_path: Path) -> list[dict[str, Any]]: + slides: list[dict[str, Any]] = [] + with csv_path.open(newline="") as handle: + reader = csv.DictReader(handle) + fieldnames = set(reader.fieldnames or []) + has_mask = "mask_path" in fieldnames + has_spacing = "spacing_at_level_0" in fieldnames + has_sample_id = "sample_id" in fieldnames + for row in reader: + image_path = Path(row["image_path"]) + mask_path = Path(row["mask_path"]) if has_mask and row.get("mask_path") else None + raw_spacing = row.get("spacing_at_level_0", "") if has_spacing else "" + spacing_at_level_0 = float(raw_spacing) if raw_spacing.strip() else None + sample_id = row["sample_id"] if has_sample_id else image_path.stem + slides.append( + { + "sample_id": sample_id, + "image_path": image_path, + "mask_path": mask_path, + "spacing_at_level_0": spacing_at_level_0, + } + ) + return slides + + +def write_slides_csv(slides: list[dict[str, Any]], path: Path) -> None: + path.parent.mkdir(parents=True, exist_ok=True) + has_spacing = any(slide.get("spacing_at_level_0") is not None for slide in slides) + fieldnames = ["sample_id", "image_path", "mask_path"] + if has_spacing: + fieldnames.append("spacing_at_level_0") + with path.open("w", newline="") as handle: + writer = csv.DictWriter(handle, fieldnames=fieldnames) + writer.writeheader() + for slide in slides: + row: dict[str, Any] = { + "sample_id": slide["sample_id"], + "image_path": str(slide["image_path"]), + "mask_path": str(slide["mask_path"]) if slide["mask_path"] is not None else "", + } + if has_spacing: + row["spacing_at_level_0"] = slide.get("spacing_at_level_0") or "" + writer.writerow(row) + + +# --------------------------------------------------------------------------- +# Config helpers +# --------------------------------------------------------------------------- + +def _load_yaml(path: Path) -> dict[str, Any]: + import yaml + + with path.open(encoding="utf-8") as handle: + data = yaml.safe_load(handle) or {} + if not isinstance(data, dict): + raise ValueError(f"Expected mapping config in {path}") + return data + + +def _write_yaml(data: dict[str, Any], path: Path) -> None: + import yaml + + path.parent.mkdir(parents=True, exist_ok=True) + with path.open("w", encoding="utf-8") as handle: + yaml.safe_dump(data, handle, sort_keys=False) + + +def _default_base_config( + *, + model_name: str, + csv_path: Path, + output_dir: Path, + batch_size: int, + num_dataloader_workers: int, + num_preprocessing_workers: int, + num_cucim_workers: int, +) -> dict[str, Any]: + """Build a minimal config dict without requiring a YAML file.""" + return { + "csv": str(csv_path), + "output_dir": str(output_dir), + "resume": False, + "save_previews": False, + "model": { + "name": model_name, + "level": "tile", + "batch_size": batch_size, + "save_tile_embeddings": False, + "save_latents": False, + }, + "tiling": { + "on_the_fly": True, + "gpu_decode": False, + "adaptive_batching": False, + "use_supertiles": True, + "jpeg_backend": "turbojpeg", + "backend": "cucim", + "read_coordinates_from": None, + "read_tiles_from": None, + "params": { + "target_spacing_um": 0.5, + "tolerance": 0.05, + "target_tile_size_px": 224, + "overlap": 0.0, + "tissue_threshold": 0.1, + "drop_holes": False, + "use_padding": True, + }, + "seg_params": { + "downsample": 64, + "sthresh": 8, + "sthresh_up": 255, + "mthresh": 7, + "close": 4, + "use_otsu": False, + "use_hsv": True, + }, + "filter_params": { + "ref_tile_size": 224, + "a_t": 4, + "a_h": 2, + "max_n_holes": 8, + "filter_white": False, + "filter_black": False, + "white_threshold": 220, + "black_threshold": 25, + "fraction_threshold": 0.9, + }, + "preview": { + "downsample": 32, + }, + }, + "speed": { + "precision": "fp32", + "num_preprocessing_workers": num_preprocessing_workers, + "num_dataloader_workers": num_dataloader_workers, + "num_cucim_workers": num_cucim_workers, + "prefetch_factor_embedding": 4, + "persistent_workers_embedding": True, + "gpu_batch_preprocessing": True, + }, + "wandb": {"enable": False}, + } + + +def _merge_base_config(base: dict[str, Any], config_file: Path | None) -> dict[str, Any]: + """If a config file is provided, use it as the starting point; otherwise use base.""" + if config_file is None: + return base + import copy + + file_data = _load_yaml(config_file) + merged = copy.deepcopy(file_data) + # Override with our baseline settings + merged["csv"] = base["csv"] + merged["output_dir"] = base["output_dir"] + merged["resume"] = False + merged["save_previews"] = False + merged.setdefault("model", {})["batch_size"] = base["model"]["batch_size"] + merged.setdefault("speed", {}) + merged["speed"]["num_preprocessing_workers"] = base["speed"]["num_preprocessing_workers"] + merged["speed"]["num_dataloader_workers"] = base["speed"]["num_dataloader_workers"] + merged["speed"]["num_cucim_workers"] = base["speed"]["num_cucim_workers"] + merged.setdefault("tiling", {}) + merged.setdefault("wandb", {})["enable"] = False + return merged + + +def _apply_mode_overrides( + config: dict[str, Any], + mode: str, + *, + batch_size: int, + read_coordinates_from: Path, + read_tiles_from: Path | None, +) -> dict[str, Any]: + import copy + + cfg = copy.deepcopy(config) + mode_cfg = MODE_CONFIGS[mode] + cfg.setdefault("model", {})["batch_size"] = int(batch_size) + cfg["tiling"]["on_the_fly"] = mode_cfg["on_the_fly"] + cfg["tiling"]["backend"] = mode_cfg["backend"] + cfg["tiling"]["use_supertiles"] = mode_cfg["use_supertiles"] + cfg["tiling"]["adaptive_batching"] = mode_cfg["adaptive_batching"] + cfg["tiling"]["jpeg_backend"] = mode_cfg["jpeg_backend"] + cfg["tiling"]["read_coordinates_from"] = str(read_coordinates_from) + cfg["tiling"]["read_tiles_from"] = str(read_tiles_from) if read_tiles_from is not None else None + return cfg + + +# --------------------------------------------------------------------------- +# Progress / metrics helpers +# --------------------------------------------------------------------------- + +def load_progress_records(path: Path) -> list[dict[str, Any]]: + if not path.is_file(): + return [] + records: list[dict[str, Any]] = [] + with path.open(encoding="utf-8") as handle: + for line in handle: + line = line.strip() + if not line: + continue + payload = json.loads(line) + if isinstance(payload, dict): + records.append(payload) + return records + + +def extract_stage_seconds(progress_path: Path) -> dict[str, float | None]: + records = load_progress_records(progress_path) + stage_seconds: dict[str, float | None] = { + "tiling_seconds": None, + "embedding_seconds": None, + } + if not records: + return stage_seconds + + first_timestamps: dict[str, float] = {} + for record in records: + kind = record.get("kind") + timestamp = record.get("timestamp") + if kind is None or timestamp is None: + continue + timestamp = float(timestamp) + if kind not in first_timestamps: + first_timestamps[kind] = timestamp + + if "tiling.started" in first_timestamps and "tiling.finished" in first_timestamps: + stage_seconds["tiling_seconds"] = round( + first_timestamps["tiling.finished"] - first_timestamps["tiling.started"], 4 + ) + if "embedding.started" in first_timestamps and "embedding.finished" in first_timestamps: + stage_seconds["embedding_seconds"] = round( + first_timestamps["embedding.finished"] - first_timestamps["embedding.started"], 4 + ) + return stage_seconds + + +def extract_batch_timing_metrics(progress_path: Path) -> dict[str, float | int]: + records = load_progress_records(progress_path) + batch_payloads = [ + record.get("payload", {}) + for record in records + if record.get("kind") == "embedding.batch.timing" and isinstance(record.get("payload"), dict) + ] + zeros: dict[str, float | int] = { + "timed_batches": 0, + "mean_loader_wait_ms": 0.0, + "max_loader_wait_ms": 0.0, + "mean_ready_wait_ms": 0.0, + "mean_preprocess_ms": 0.0, + "mean_worker_batch_ms": 0.0, + "mean_reader_open_ms": 0.0, + "mean_reader_read_ms": 0.0, + "mean_forward_ms": 0.0, + "loader_wait_fraction": 0.0, + "gpu_busy_fraction": 0.0, + } + if not batch_payloads: + return zeros + + loader_wait_ms = [float(p.get("loader_wait_ms", 0.0)) for p in batch_payloads] + ready_wait_ms = [float(p.get("ready_wait_ms", 0.0)) for p in batch_payloads] + preprocess_ms = [float(p.get("preprocess_ms", 0.0)) for p in batch_payloads] + worker_batch_ms = [float(p.get("worker_batch_ms", 0.0)) for p in batch_payloads] + reader_open_ms = [float(p.get("reader_open_ms", 0.0)) for p in batch_payloads] + reader_read_ms = [float(p.get("reader_read_ms", 0.0)) for p in batch_payloads] + forward_ms = [float(p.get("forward_ms", 0.0)) for p in batch_payloads] + gpu_busy_fraction = [float(p.get("gpu_busy_fraction", 0.0)) for p in batch_payloads] + total_ms = sum(loader_wait_ms) + sum(ready_wait_ms) + sum(preprocess_ms) + sum(forward_ms) + return { + "timed_batches": len(batch_payloads), + "mean_loader_wait_ms": round(statistics.mean(loader_wait_ms), 4), + "max_loader_wait_ms": round(max(loader_wait_ms), 4), + "mean_ready_wait_ms": round(statistics.mean(ready_wait_ms), 4), + "mean_preprocess_ms": round(statistics.mean(preprocess_ms), 4), + "mean_worker_batch_ms": round(statistics.mean(worker_batch_ms), 4), + "mean_reader_open_ms": round(statistics.mean(reader_open_ms), 4), + "mean_reader_read_ms": round(statistics.mean(reader_read_ms), 4), + "mean_forward_ms": round(statistics.mean(forward_ms), 4), + "loader_wait_fraction": round((sum(loader_wait_ms) + sum(ready_wait_ms)) / total_ms, 4) if total_ms > 0 else 0.0, + "gpu_busy_fraction": round(statistics.mean(gpu_busy_fraction), 4), + } + + +def parse_process_list(path: Path) -> dict[str, int]: + if not path.is_file(): + return {"slides_total": 0, "slides_with_tiles": 0, "failed_slides": 0, "total_tiles": 0} + with path.open(newline="") as handle: + rows = list(csv.DictReader(handle)) + total_tiles = sum(int(float(row.get("num_tiles") or 0)) for row in rows) + slides_with_tiles = sum(int(float(row.get("num_tiles") or 0)) > 0 for row in rows) + failed_slides = sum(row.get("tiling_status") == "failed" for row in rows) + return { + "slides_total": len(rows), + "slides_with_tiles": slides_with_tiles, + "failed_slides": failed_slides, + "total_tiles": total_tiles, + } + + +def save_csv(rows: list[dict[str, Any]], path: Path) -> None: + if not rows: + return + path.parent.mkdir(parents=True, exist_ok=True) + with path.open("w", newline="") as handle: + writer = csv.DictWriter(handle, fieldnames=list(rows[0].keys())) + writer.writeheader() + writer.writerows(rows) + + +# --------------------------------------------------------------------------- +# Internal harness (runs inside subprocess) +# --------------------------------------------------------------------------- + +def _build_pipeline_from_config_dict(config: dict[str, Any]): + from slide2vec import ExecutionOptions, Model, Pipeline, PreprocessingConfig + + model_cfg = config.get("model", {}) + tiling_cfg = config.get("tiling", {}) + params = tiling_cfg.get("params", {}) + preview = dict(tiling_cfg.get("preview", {})) + speed_cfg = config.get("speed", {}) + + preprocessing = PreprocessingConfig( + backend=str(tiling_cfg.get("backend", "cucim")), + target_spacing_um=float(params.get("target_spacing_um", 0.5)), + target_tile_size_px=int(params.get("target_tile_size_px", 256)), + tolerance=float(params.get("tolerance", 0.05)), + overlap=float(params.get("overlap", 0.0)), + tissue_threshold=float(params.get("tissue_threshold", 0.01)), + drop_holes=bool(params.get("drop_holes", False)), + use_padding=bool(params.get("use_padding", True)), + read_coordinates_from=( + Path(tiling_cfg["read_coordinates_from"]) + if tiling_cfg.get("read_coordinates_from") + else Path(config["output_dir"]) / "coordinates" + ), + read_tiles_from=( + Path(tiling_cfg["read_tiles_from"]) + if tiling_cfg.get("read_tiles_from") + else None + ), + on_the_fly=bool(tiling_cfg.get("on_the_fly", True)), + gpu_decode=bool(tiling_cfg.get("gpu_decode", False)), + adaptive_batching=bool(tiling_cfg.get("adaptive_batching", False)), + use_supertiles=bool(tiling_cfg.get("use_supertiles", True)), + jpeg_backend=str(tiling_cfg.get("jpeg_backend", "turbojpeg")), + num_cucim_workers=int(speed_cfg.get("num_cucim_workers", tiling_cfg.get("num_cucim_workers", 4))), + resume=bool(config.get("resume", False)), + segmentation=dict(tiling_cfg.get("seg_params", {})), + filtering=dict(tiling_cfg.get("filter_params", {})), + preview={ + "save_mask_preview": bool(config.get("save_previews", False)), + "save_tiling_preview": bool(config.get("save_previews", False)), + "downsample": int(preview.get("downsample", 32)), + }, + ) + execution = ExecutionOptions( + output_dir=Path(config["output_dir"]), + batch_size=int(model_cfg.get("batch_size", 256)), + num_workers=int(speed_cfg.get("num_dataloader_workers", speed_cfg.get("num_workers_embedding", 32))), + num_preprocessing_workers=int(speed_cfg.get("num_preprocessing_workers", 8)), + precision=str(speed_cfg.get("precision", "fp32")), + prefetch_factor=int(speed_cfg.get("prefetch_factor_embedding", 4)), + persistent_workers=bool(speed_cfg.get("persistent_workers_embedding", True)), + gpu_batch_preprocessing=bool(speed_cfg.get("gpu_batch_preprocessing", True)), + save_tile_embeddings=bool(model_cfg.get("save_tile_embeddings", False)), + save_latents=bool(model_cfg.get("save_latents", False)), + ) + model = Model.from_pretrained( + str(model_cfg["name"]), + level=model_cfg.get("level", "tile"), + mode=model_cfg.get("mode"), + arch=model_cfg.get("arch"), + pretrained_weights=model_cfg.get("pretrained_weights"), + input_size=model_cfg.get("input_size"), + patch_size=model_cfg.get("patch_size"), + token_size=model_cfg.get("token_size"), + normalize_embeddings=model_cfg.get("normalize_embeddings"), + device="auto", + ) + return Pipeline(model=model, preprocessing=preprocessing, execution=execution) + + +def _run_internal_harness(args: argparse.Namespace) -> int: + if args.harness_config is None or args.metrics_json is None or args.progress_jsonl is None: + raise ValueError("--internal-harness requires --harness-config, --metrics-json, and --progress-jsonl") + + from slide2vec.progress import JsonlProgressReporter, activate_progress_reporter + + config = _load_yaml(args.harness_config) + pipeline = _build_pipeline_from_config_dict(config) + output_dir = Path(config["output_dir"]) + progress_path = Path(args.progress_jsonl) + metrics_path = Path(args.metrics_json) + progress_path.parent.mkdir(parents=True, exist_ok=True) + metrics_path.parent.mkdir(parents=True, exist_ok=True) + + reporter = JsonlProgressReporter(progress_path) + metrics: dict[str, Any] = {} + t0 = time.perf_counter() + try: + with activate_progress_reporter(reporter): + result = pipeline.run(manifest_path=config["csv"]) + end_to_end_seconds = time.perf_counter() - t0 + process_stats = parse_process_list(output_dir / "process_list.csv") + stage_seconds = extract_stage_seconds(progress_path) + batch_timing = extract_batch_timing_metrics(progress_path) + slides_total = int(process_stats["slides_total"]) + tiles_per_second = process_stats["total_tiles"] / end_to_end_seconds if end_to_end_seconds > 0 else 0.0 + metrics = { + "success": True, + "tile_artifacts": len(result.tile_artifacts), + "slide_artifacts": len(result.slide_artifacts), + "slides_total": slides_total, + "slides_with_tiles": int(process_stats["slides_with_tiles"]), + "failed_slides": int(process_stats["failed_slides"]), + "total_tiles": int(process_stats["total_tiles"]), + "end_to_end_seconds": round(end_to_end_seconds, 4), + "tiles_per_second": round(tiles_per_second, 4), + **stage_seconds, + **batch_timing, + } + except Exception as exc: + end_to_end_seconds = time.perf_counter() - t0 + process_stats = parse_process_list(output_dir / "process_list.csv") + metrics = { + "success": False, + "error": str(exc), + "slides_total": int(process_stats["slides_total"]), + "slides_with_tiles": int(process_stats["slides_with_tiles"]), + "failed_slides": int(process_stats["failed_slides"]), + "total_tiles": int(process_stats["total_tiles"]), + "end_to_end_seconds": round(end_to_end_seconds, 4), + "tiles_per_second": 0.0, + **extract_stage_seconds(progress_path), + **extract_batch_timing_metrics(progress_path), + } + metrics_path.write_text(json.dumps(metrics, indent=2, sort_keys=True) + "\n", encoding="utf-8") + return 1 + + metrics_path.write_text(json.dumps(metrics, indent=2, sort_keys=True) + "\n", encoding="utf-8") + return 0 + + +# --------------------------------------------------------------------------- +# Subprocess trial runner +# --------------------------------------------------------------------------- + +def _run_trial_subprocess(*, config_path: Path, metrics_path: Path, progress_path: Path, log_path: Path) -> subprocess.CompletedProcess[str]: + command = [ + sys.executable, + str(Path(__file__).resolve()), + "--internal-harness", + "--harness-config", + str(config_path), + "--metrics-json", + str(metrics_path), + "--progress-jsonl", + str(progress_path), + ] + completed = subprocess.run(command, cwd=REPO_ROOT, capture_output=True, text=True) + log_path.write_text((completed.stdout or "") + (completed.stderr or ""), encoding="utf-8") + return completed + + +def cleanup_trial_output(output_dir: Path) -> None: + for dirname in HEAVY_ARTIFACT_DIRS: + candidate = output_dir / dirname + if candidate.exists(): + shutil.rmtree(candidate) + + +def run_trial( + *, + mode: str, + batch_size: int, + kind: str, + repeat_index: int, + run_dir: Path, + config: dict[str, Any], + read_coordinates_from: Path, + read_tiles_from: Path | None, +) -> dict[str, Any]: + run_dir.mkdir(parents=True, exist_ok=True) + config_path = run_dir / "config.yaml" + progress_path = run_dir / "progress.jsonl" + metrics_path = run_dir / "metrics.json" + log_path = run_dir / "harness.log" + trial_output_dir = run_dir / "output" + + trial_config = _apply_mode_overrides( + config, + mode, + batch_size=batch_size, + read_coordinates_from=read_coordinates_from, + read_tiles_from=read_tiles_from, + ) + trial_config["output_dir"] = str(trial_output_dir) + _write_yaml(trial_config, config_path) + + completed = _run_trial_subprocess( + config_path=config_path, + metrics_path=metrics_path, + progress_path=progress_path, + log_path=log_path, + ) + metrics = json.loads(metrics_path.read_text(encoding="utf-8")) if metrics_path.is_file() else {} + cleanup_trial_output(trial_output_dir) + + return { + "mode": mode, + "batch_size": int(batch_size), + "kind": kind, + "repeat_index": repeat_index, + "exit_code": int(completed.returncode), + "slides_total": int(metrics.get("slides_total", 0)), + "slides_with_tiles": int(metrics.get("slides_with_tiles", 0)), + "failed_slides": int(metrics.get("failed_slides", 0)), + "total_tiles": int(metrics.get("total_tiles", 0)), + "end_to_end_seconds": float(metrics.get("end_to_end_seconds", 0.0)), + "tiles_per_second": float(metrics.get("tiles_per_second", 0.0)), + "tiling_seconds": metrics.get("tiling_seconds") or "", + "embedding_seconds": metrics.get("embedding_seconds") or "", + "timed_batches": int(metrics.get("timed_batches", 0)), + "mean_loader_wait_ms": float(metrics.get("mean_loader_wait_ms", 0.0)), + "max_loader_wait_ms": float(metrics.get("max_loader_wait_ms", 0.0)), + "mean_ready_wait_ms": float(metrics.get("mean_ready_wait_ms", 0.0)), + "mean_preprocess_ms": float(metrics.get("mean_preprocess_ms", 0.0)), + "mean_worker_batch_ms": float(metrics.get("mean_worker_batch_ms", 0.0)), + "mean_reader_open_ms": float(metrics.get("mean_reader_open_ms", 0.0)), + "mean_reader_read_ms": float(metrics.get("mean_reader_read_ms", 0.0)), + "mean_forward_ms": float(metrics.get("mean_forward_ms", 0.0)), + "loader_wait_fraction": float(metrics.get("loader_wait_fraction", 0.0)), + "gpu_busy_fraction": float(metrics.get("gpu_busy_fraction", 0.0)), + "error": metrics.get("error", ""), + } + + +# --------------------------------------------------------------------------- +# Setup: tile once to produce coordinates + tar archives +# --------------------------------------------------------------------------- + +def _setup_tiling( + *, + config: dict[str, Any], + setup_dir: Path, + csv_path: Path, + status: "Any | None" = None, +) -> tuple[Path, Path]: + """Run a tiling-only pass (tar path) to produce coordinates and tile archives. + + Returns (coordinates_dir, tiles_dir). hs2p writes: + output_dir/coordinates/ — NPZ + meta JSON per slide + output_dir/tiles/ — .tiles.tar per slide + """ + # hs2p writes to output_dir/coordinates and output_dir/tiles + coordinates_dir = setup_dir / "coordinates" + tiles_dir = setup_dir / "tiles" + + if coordinates_dir.exists() and tiles_dir.exists(): + tar_files = list(tiles_dir.glob("*.tiles.tar")) + if tar_files: + if status is not None: + status.update("Reusing existing tile stores") + return coordinates_dir, tiles_dir + + if status is not None: + status.update("Tiling slides (runs once) …") + import copy + + setup_config = copy.deepcopy(config) + setup_config["csv"] = str(csv_path) + setup_config["output_dir"] = str(setup_dir) + setup_config["resume"] = True # safe to resume if partially done + setup_config["save_previews"] = False + setup_config["tiling"]["on_the_fly"] = False + setup_config["tiling"]["backend"] = "cucim" + setup_config["tiling"]["use_supertiles"] = True + setup_config["tiling"]["jpeg_backend"] = "turbojpeg" + setup_config["tiling"]["read_coordinates_from"] = None + setup_config["tiling"]["read_tiles_from"] = None + + config_path = setup_dir / "setup_config.yaml" + metrics_path = setup_dir / "setup_metrics.json" + progress_path = setup_dir / "setup_progress.jsonl" + log_path = setup_dir / "setup_harness.log" + setup_dir.mkdir(parents=True, exist_ok=True) + _write_yaml(setup_config, config_path) + + completed = _run_trial_subprocess( + config_path=config_path, + metrics_path=metrics_path, + progress_path=progress_path, + log_path=log_path, + ) + if completed.returncode != 0: + raise RuntimeError(f"exit={completed.returncode}", log_path) + + return coordinates_dir, tiles_dir + + +# --------------------------------------------------------------------------- +# Aggregation helpers +# --------------------------------------------------------------------------- + +def aggregate_trial_results(trial_rows: list[dict[str, Any]]) -> list[dict[str, Any]]: + grouped: dict[tuple[str, int], list[dict[str, Any]]] = {} + for row in trial_rows: + if row.get("exit_code", 0) not in (0, "", None): + continue + key = (str(row["mode"]), int(row.get("batch_size", 256))) + grouped.setdefault(key, []).append(row) + + aggregated: list[dict[str, Any]] = [] + batch_sizes = sorted({batch_size for (_mode, batch_size) in grouped}) + for batch_size in batch_sizes: + for mode in ALL_MODES: + rows = grouped.get((mode, batch_size)) + if not rows: + continue + tiles_per_second = [float(r["tiles_per_second"]) for r in rows] + end_to_end_seconds = [float(r["end_to_end_seconds"]) for r in rows] + loader_wait_ms = [float(r.get("mean_loader_wait_ms", 0.0)) for r in rows] + max_loader_wait_ms = [float(r.get("max_loader_wait_ms", 0.0)) for r in rows] + ready_wait_ms = [float(r.get("mean_ready_wait_ms", 0.0)) for r in rows] + preprocess_ms = [float(r.get("mean_preprocess_ms", 0.0)) for r in rows] + worker_batch_ms = [float(r.get("mean_worker_batch_ms", 0.0)) for r in rows] + reader_open_ms = [float(r.get("mean_reader_open_ms", 0.0)) for r in rows] + reader_read_ms = [float(r.get("mean_reader_read_ms", 0.0)) for r in rows] + forward_ms = [float(r.get("mean_forward_ms", 0.0)) for r in rows] + loader_wait_fraction = [float(r.get("loader_wait_fraction", 0.0)) for r in rows] + gpu_busy_fraction = [float(r.get("gpu_busy_fraction", 0.0)) for r in rows] + aggregated.append( + { + "mode": mode, + "batch_size": int(batch_size), + "repeat_count": len(rows), + "total_tiles": int(rows[0].get("total_tiles", 0)), + "mean_tiles_per_second": round(statistics.mean(tiles_per_second), 4), + "std_tiles_per_second": round(statistics.pstdev(tiles_per_second), 4) if len(tiles_per_second) > 1 else 0.0, + "mean_end_to_end_seconds": round(statistics.mean(end_to_end_seconds), 4), + "mean_loader_wait_ms": round(statistics.mean(loader_wait_ms), 4), + "max_loader_wait_ms": round(max(max_loader_wait_ms), 4), + "mean_ready_wait_ms": round(statistics.mean(ready_wait_ms), 4), + "mean_preprocess_ms": round(statistics.mean(preprocess_ms), 4), + "mean_worker_batch_ms": round(statistics.mean(worker_batch_ms), 4), + "mean_reader_open_ms": round(statistics.mean(reader_open_ms), 4), + "mean_reader_read_ms": round(statistics.mean(reader_read_ms), 4), + "mean_forward_ms": round(statistics.mean(forward_ms), 4), + "loader_wait_fraction": round(statistics.mean(loader_wait_fraction), 4), + "gpu_busy_fraction": round(statistics.mean(gpu_busy_fraction), 4), + } + ) + return aggregated + + +# --------------------------------------------------------------------------- +# Charts +# --------------------------------------------------------------------------- + +def plot_throughput_by_strategy(summary_rows: list[dict[str, Any]], output_path: Path) -> None: + if not summary_rows: + return + import matplotlib + + matplotlib.use("Agg") + import matplotlib.pyplot as plt + + # Find baseline for speedup annotation + baseline_tps: float | None = None + for row in summary_rows: + if str(row["mode"]) == "tar": + baseline_tps = float(row["mean_tiles_per_second"]) + break + + modes = [str(r["mode"]) for r in summary_rows] + values = [float(r["mean_tiles_per_second"]) for r in summary_rows] + errors = [float(r.get("std_tiles_per_second", 0.0)) for r in summary_rows] + x_pos = np.arange(len(modes)) + labels = [MODE_DISPLAY_LABELS.get(m, m) for m in modes] + + fig, ax = plt.subplots(figsize=(max(7.0, 1.6 * len(modes)), 5.0)) + bars = ax.bar(x_pos, values, yerr=errors, capsize=4, width=0.6, color="#4C72B0", error_kw={"linewidth": 1.2}) + + for bar, value, mode_name in zip(bars, values, modes): + annotation = f"{value:,.1f}" + if baseline_tps is not None and baseline_tps > 0 and mode_name != "tar": + speedup = value / baseline_tps + annotation += f"\n({speedup:.2f}×)" + ax.text( + bar.get_x() + bar.get_width() / 2, + value, + annotation, + ha="center", + va="bottom", + fontsize=8, + ) + + ax.set_ylabel("Tiles / second") + ax.set_title("Tile Reading Strategy Throughput") + ax.set_xticks(x_pos, labels=labels) + ax.grid(axis="y", color="#e8e8e8", linewidth=0.8) + ax.set_axisbelow(True) + fig.tight_layout() + output_path.parent.mkdir(parents=True, exist_ok=True) + fig.savefig(output_path, dpi=300, bbox_inches="tight") + plt.close(fig) + + +def plot_timing_breakdown(summary_rows: list[dict[str, Any]], output_path: Path) -> None: + if not summary_rows: + return + filtered_rows = [row for row in summary_rows if int(row.get("batch_size", 256)) == int(summary_rows[0].get("batch_size", 256))] + import matplotlib + + matplotlib.use("Agg") + import matplotlib.pyplot as plt + + modes = [str(r["mode"]) for r in filtered_rows] + loader_wait = [float(r.get("mean_loader_wait_ms", 0.0)) for r in filtered_rows] + preprocess = [float(r.get("mean_preprocess_ms", 0.0)) for r in filtered_rows] + forward = [float(r.get("mean_forward_ms", 0.0)) for r in filtered_rows] + x_pos = np.arange(len(modes)) + labels = [MODE_DISPLAY_LABELS.get(m, m) for m in modes] + batch_size = int(filtered_rows[0].get("batch_size", 256)) + + fig, ax = plt.subplots(figsize=(max(7.0, 1.6 * len(modes)), 5.0)) + bar_width = 0.6 + ax.bar(x_pos, loader_wait, bar_width, label="Loader wait", color="#4C72B0") + ax.bar(x_pos, preprocess, bar_width, bottom=loader_wait, label="Preprocess", color="#DD8452") + bottom2 = [a + b for a, b in zip(loader_wait, preprocess)] + ax.bar(x_pos, forward, bar_width, bottom=bottom2, label="Forward pass", color="#55A868") + + ax.set_ylabel("Milliseconds per batch") + ax.set_title(f"Batch Timing Breakdown by Strategy (batch size {batch_size})") + ax.set_xticks(x_pos, labels=labels) + ax.grid(axis="y", color="#e8e8e8", linewidth=0.8) + ax.set_axisbelow(True) + ax.legend(loc="upper right", fontsize=9) + fig.tight_layout() + output_path.parent.mkdir(parents=True, exist_ok=True) + fig.savefig(output_path, dpi=300, bbox_inches="tight") + plt.close(fig) + + +def plot_throughput_vs_batch_size(summary_rows: list[dict[str, Any]], output_path: Path) -> None: + if not summary_rows: + return + import matplotlib + + matplotlib.use("Agg") + import matplotlib.pyplot as plt + + grouped: dict[str, list[dict[str, Any]]] = {} + for row in summary_rows: + grouped.setdefault(str(row["mode"]), []).append(row) + + fig, ax = plt.subplots(figsize=(8.0, 5.0)) + for mode in ALL_MODES: + rows = grouped.get(mode) + if not rows: + continue + rows = sorted(rows, key=lambda row: int(row.get("batch_size", 256))) + x = [int(row.get("batch_size", 256)) for row in rows] + y = [float(row["mean_tiles_per_second"]) for row in rows] + ax.plot(x, y, marker="o", linewidth=2.0, label=MODE_DISPLAY_LABELS.get(mode, mode).replace("\n", " ")) + + ax.set_xlabel("Batch size") + ax.set_ylabel("Tiles / second") + ax.set_title("Throughput vs Batch Size by Strategy") + ax.set_xticks(sorted({int(row.get("batch_size", 256)) for row in summary_rows})) + ax.grid(axis="y", color="#e8e8e8", linewidth=0.8) + ax.set_axisbelow(True) + ax.legend(loc="best", fontsize=9) + fig.tight_layout() + output_path.parent.mkdir(parents=True, exist_ok=True) + fig.savefig(output_path, dpi=300, bbox_inches="tight") + plt.close(fig) + + +def _prepare_chart_outputs( + trial_rows: list[dict[str, Any]], + output_dir: Path, + *, + console: "Any | None" = None, +) -> int: + _print = console.print if console is not None else print + if not trial_rows: + _print("[red]No trial rows available for chart generation.[/]" if console else "No trial rows available for chart generation.") + return 1 + summary_rows = aggregate_trial_results(trial_rows) + save_csv(summary_rows, output_dir / "summary.csv") + batch_sizes = sorted({int(row.get("batch_size", 256)) for row in summary_rows}) + if len(batch_sizes) == 1: + plot_throughput_by_strategy(summary_rows, output_dir / "throughput_by_strategy.png") + plot_timing_breakdown(summary_rows, output_dir / "timing_breakdown.png") + else: + plot_throughput_vs_batch_size(summary_rows, output_dir / "throughput_by_batch_size.png") + for batch_size in batch_sizes: + batch_rows = [row for row in summary_rows if int(row.get("batch_size", 256)) == batch_size] + plot_throughput_by_strategy(batch_rows, output_dir / f"throughput_by_strategy_bs{batch_size}.png") + plot_timing_breakdown(batch_rows, output_dir / f"timing_breakdown_bs{batch_size}.png") + return summary_rows + + +def _load_trial_results_csvs(paths: list[Path]) -> list[dict[str, Any]]: + rows: list[dict[str, Any]] = [] + for path in paths: + with path.open(newline="") as handle: + rows.extend(dict(row) for row in csv.DictReader(handle)) + # Coerce numeric fields + int_fields = {"repeat_index", "exit_code", "slides_total", "slides_with_tiles", "failed_slides", "total_tiles", "timed_batches"} + int_fields.add("batch_size") + float_fields = { + "end_to_end_seconds", "tiles_per_second", "mean_loader_wait_ms", + "max_loader_wait_ms", "mean_ready_wait_ms", "mean_preprocess_ms", + "mean_worker_batch_ms", "mean_reader_open_ms", "mean_reader_read_ms", + "mean_forward_ms", "loader_wait_fraction", "gpu_busy_fraction", + } + coerced = [] + for row in rows: + parsed: dict[str, Any] = {} + for key, value in row.items(): + if value == "": + parsed[key] = "" + elif key in int_fields: + parsed[key] = int(float(value)) + elif key in float_fields: + parsed[key] = float(value) + else: + parsed[key] = value + coerced.append(parsed) + return coerced + + +# --------------------------------------------------------------------------- +# Rich helpers +# --------------------------------------------------------------------------- + +def _print_log_panel(console: "Any", log_path: Path, title: str = "Error log") -> None: + """Print the contents of a subprocess log file in a red panel.""" + from rich.panel import Panel + + log = "" + if log_path.is_file(): + log = log_path.read_text(encoding="utf-8").strip() + if not log: + log = "(no output captured)" + console.print(Panel(log, title=f"[red]{title}[/]", border_style="red", highlight=False)) + +def _make_summary_table(summary_rows: list[dict[str, Any]], *, baseline_mode: str = "tar") -> "Any": + from rich.table import Table + + single_batch = len({int(r.get("batch_size", 256)) for r in summary_rows}) == 1 + + table = Table(title="Benchmark summary", show_lines=True) + table.add_column("Mode", style="bold") + if not single_batch: + table.add_column("Batch", justify="right") + table.add_column("Tiles/s", justify="right") + table.add_column("± std", justify="right", style="dim") + table.add_column("vs tar", justify="right") + table.add_column("Loader wait", justify="right") + table.add_column("Reader read", justify="right") + table.add_column("GPU busy", justify="right") + table.add_column("Preprocess", justify="right") + table.add_column("Forward", justify="right") + table.add_column("Reps", justify="right", style="dim") + + for r in summary_rows: + mode = str(r["mode"]) + tps = float(r["mean_tiles_per_second"]) + std = float(r.get("std_tiles_per_second", 0.0)) + batch_size = int(r.get("batch_size", 256)) + baseline_tps: float | None = None + for candidate in summary_rows: + if str(candidate["mode"]) == baseline_mode and int(candidate.get("batch_size", 256)) == batch_size: + baseline_tps = float(candidate["mean_tiles_per_second"]) + break + if baseline_tps and baseline_tps > 0: + speedup = tps / baseline_tps + speedup_str = f"{speedup:.2f}×" + speedup_style = "green" if speedup >= 1.0 else "red" + else: + speedup_str = "—" + speedup_style = "dim" + row_values = [ + mode, + ] + if not single_batch: + row_values.append(str(batch_size)) + row_values.extend( + [ + f"{tps:,.1f}", + f"{std:,.1f}", + f"[{speedup_style}]{speedup_str}[/{speedup_style}]", + f"{r.get('mean_loader_wait_ms', 0.0):.1f} ms", + f"{r.get('mean_reader_read_ms', 0.0):.1f} ms", + f"{100.0 * float(r.get('gpu_busy_fraction', 0.0)):.1f}%", + f"{r.get('mean_preprocess_ms', 0.0):.1f} ms", + f"{r.get('mean_forward_ms', 0.0):.1f} ms", + str(r.get("repeat_count", 0)), + ] + ) + table.add_row(*row_values) + return table + + +# --------------------------------------------------------------------------- +# Main benchmark orchestration +# --------------------------------------------------------------------------- + +def run_benchmark(args: argparse.Namespace) -> int: + from rich.console import Console + from rich.panel import Panel + from rich.progress import ( + BarColumn, + Progress, + SpinnerColumn, + TaskProgressColumn, + TextColumn, + TimeElapsedColumn, + ) + from rich.status import Status + + console = Console() + + output_dir = args.output_dir + output_dir.mkdir(parents=True, exist_ok=True) + try: + batch_sizes = _resolve_batch_sizes(args) + except ValueError as exc: + console.print(f"[red]ERROR:[/] {exc}") + return 1 + + if args.csv is None: + console.print("[red]ERROR:[/] --csv is required.") + return 1 + + slides = load_slides_from_csv(args.csv) + if not slides: + console.print("[red]ERROR:[/] the manifest is empty.") + return 1 + + shared_csv = output_dir / "slides.csv" + write_slides_csv(slides, shared_csv) + + base = _default_base_config( + model_name=args.model, + csv_path=shared_csv, + output_dir=output_dir / "trial_output", + batch_size=batch_sizes[0], + num_dataloader_workers=args.num_dataloader_workers, + num_preprocessing_workers=args.num_preprocessing_workers, + num_cucim_workers=args.num_cucim_workers, + ) + config = _merge_base_config(base, args.config_file) + + # ── Setup ──────────────────────────────────────────────────────────────── + console.rule("[bold cyan]Setup") + setup_dir = output_dir / "setup" + with Status("Tiling slides (runs once) …", console=console, spinner="dots") as status: + try: + coordinates_dir, tiles_dir = _setup_tiling( + config=config, + setup_dir=setup_dir, + csv_path=shared_csv, + status=status, + ) + except RuntimeError as exc: + args_list = exc.args + msg = args_list[0] if args_list else "" + log_path = args_list[1] if len(args_list) > 1 else setup_dir / "setup_harness.log" + console.print(f"[red bold]✗ Tiling setup failed[/] ({msg})") + _print_log_panel(console, log_path, title="Setup error log") + return 1 + console.print(f"[green]✓[/] Tile stores ready [dim]{tiles_dir}[/]") + + # ── Benchmark ──────────────────────────────────────────────────────────── + modes = list(args.modes) + trial_rows: list[dict[str, Any]] = [] + trial_results_path = output_dir / "trial_results.csv" + total_trials = len(batch_sizes) * len(modes) * (args.warmup + args.repeat) + + console.rule("[bold cyan]Benchmark") + batch_label = ", ".join(str(batch_size) for batch_size in batch_sizes) + console.print( + f" [bold]{len(batch_sizes)}[/] batch sizes · " + f"[bold]{len(modes)}[/] modes · " + f"[bold]{args.repeat}[/] repeat · " + f"[bold]{args.warmup}[/] warmup · " + f"batch [bold]{batch_label}[/] · " + f"model [bold]{args.model}[/]" + ) + console.print() + + with Progress( + SpinnerColumn(), + TextColumn("[progress.description]{task.description}"), + BarColumn(), + TaskProgressColumn(), + TimeElapsedColumn(), + console=console, + ) as progress: + overall_task = progress.add_task("[bold]Overall", total=total_trials) + trial_task = progress.add_task("", total=None) + + for batch_size in batch_sizes: + for mode in modes: + mode_dir = output_dir / "runs" / f"bs-{batch_size}" / mode + read_tiles_from = tiles_dir if mode == "tar" else None + + for rep_idx in range(args.warmup + args.repeat): + is_warmup = rep_idx < args.warmup + kind = "warmup" if is_warmup else "measure" + rep_num = rep_idx if is_warmup else rep_idx - args.warmup + 1 + run_dir = mode_dir / ("warmup" if is_warmup else f"rep-{rep_num:02d}") + + if is_warmup: + desc = f"[dim]warmup bs={batch_size} {mode}[/]" + else: + desc = f"[bold cyan]bs={batch_size} {mode}[/] rep [bold]{rep_num}[/]/{args.repeat}" + progress.update(trial_task, description=desc) + + row = run_trial( + mode=mode, + batch_size=batch_size, + kind=kind, + repeat_index=rep_num, + run_dir=run_dir, + config=config, + read_coordinates_from=coordinates_dir, + read_tiles_from=read_tiles_from, + ) + progress.advance(overall_task) + + ok = row["exit_code"] == 0 + icon = "[green]✓[/]" if ok else "[red]✗[/]" + if is_warmup: + progress.console.log( + f"{icon} [dim]warmup[/] bs={batch_size} {mode} {row['end_to_end_seconds']:.1f}s" + ) + if not ok: + _print_log_panel(progress.console, run_dir / "harness.log", title=f"warmup bs={batch_size} {mode} — error log") + else: + tps = row["tiles_per_second"] + elapsed = row["end_to_end_seconds"] + tiles = row["total_tiles"] + progress.console.log( + f"{icon} [bold]bs={batch_size} {mode}[/] rep {rep_num}/{args.repeat} " + f"[bold yellow]{tps:,.0f}[/] tiles/s " + f"({tiles:,} tiles in {elapsed:.1f}s)" + + (f" [red]exit={row['exit_code']}[/]" if not ok else "") + ) + if not ok: + _print_log_panel(progress.console, run_dir / "harness.log", title=f"bs={batch_size} {mode} rep {rep_num} — error log") + trial_rows.append(row) + + progress.update(trial_task, visible=False) + + # ── Save results ───────────────────────────────────────────────────────── + save_csv(trial_rows, trial_results_path) + console.print(f"\n[dim]Trial results →[/] {trial_results_path}") + + # ── Charts + summary table ──────────────────────────────────────────────── + console.rule("[bold cyan]Results") + summary_rows = _prepare_chart_outputs(trial_rows, output_dir, console=console) + if not summary_rows: + return 1 + + console.print(_make_summary_table(summary_rows)) + console.print( + Panel( + ( + f"[dim]throughput_by_strategy.png[/]\n[dim]timing_breakdown.png[/]\n[dim]summary.csv[/]" + if len(batch_sizes) == 1 + else f"[dim]throughput_by_batch_size.png[/]\n[dim]throughput_by_strategy_bs*.png[/]\n[dim]timing_breakdown_bs*.png[/]\n[dim]summary.csv[/]" + ), + title=f"[bold]Saved to[/] {output_dir}", + expand=False, + ) + ) + return 0 + + +def main() -> int: + args = parse_args() + + if args.internal_harness: + return _run_internal_harness(args) + + if args.chart_only: + from rich.console import Console + + console = Console() + trial_rows = _load_trial_results_csvs(args.chart_only) + summary_rows = _prepare_chart_outputs(trial_rows, args.output_dir, console=console) + if not summary_rows: + return 1 + console.print(_make_summary_table(summary_rows)) + return 0 + + return run_benchmark(args) + + +if __name__ == "__main__": + sys.exit(main()) diff --git a/setup.cfg b/setup.cfg index 6d89cc7..c5285c1 100644 --- a/setup.cfg +++ b/setup.cfg @@ -16,7 +16,7 @@ classifiers = packages = slide2vec install_requires = - hs2p>=2.3.0,<3 + hs2p>=2.4.1,<3 omegaconf h5py matplotlib @@ -38,6 +38,9 @@ zip_safe = no include_package_data = True [options.extras_require] +cucim = + hs2p[cucim]>=2.4.1,<3 + PyTurboJPEG models = huggingface-hub sacremoses diff --git a/slide2vec/api.py b/slide2vec/api.py index f48a9fb..b1c1ec2 100644 --- a/slide2vec/api.py +++ b/slide2vec/api.py @@ -3,6 +3,12 @@ from typing import TYPE_CHECKING, Any, Mapping, Protocol, Sequence, overload from slide2vec.artifacts import SlideEmbeddingArtifact, TileEmbeddingArtifact +from slide2vec.model_settings import ( + canonicalize_model_name, + get_recommended_model_settings, + normalize_precision_name, + validate_model_runtime_compatibility, +) if TYPE_CHECKING: from hs2p import SlideSpec @@ -17,12 +23,6 @@ "titan": "slide", } -MODEL_NAME_ALIASES = { - "phikon-v2": "phikonv2", - "hibou-b": "hibou", - "hibou-l": "hibou", -} - PathLike = str | Path @@ -37,9 +37,19 @@ class SlideLike(Protocol): SlideSequence = Sequence[SlideInput] TilingResultsInput = Sequence[Any] | Mapping[str, Any] + +def _cfg_num_cucim_workers(cfg: Any) -> int: + speed = getattr(cfg, "speed", None) + if speed is not None and hasattr(speed, "num_cucim_workers"): + return int(getattr(speed, "num_cucim_workers")) + tiling = getattr(cfg, "tiling", None) + if tiling is not None and hasattr(tiling, "num_cucim_workers"): + return int(getattr(tiling, "num_cucim_workers")) + return 4 + @dataclass(frozen=True) class PreprocessingConfig: - backend: str = "asap" + backend: str = "auto" target_spacing_um: float = 0.5 target_tile_size_px: int = 224 tolerance: float = 0.05 @@ -47,7 +57,14 @@ class PreprocessingConfig: tissue_threshold: float = 0.01 drop_holes: bool = False use_padding: bool = True + read_coordinates_from: Path | None = None read_tiles_from: Path | None = None + on_the_fly: bool = True + gpu_decode: bool = False + adaptive_batching: bool = False + use_supertiles: bool = True + jpeg_backend: str = "turbojpeg" + num_cucim_workers: int = 4 resume: bool = False segmentation: dict[str, Any] = field(default_factory=dict) filtering: dict[str, Any] = field(default_factory=dict) @@ -56,6 +73,12 @@ class PreprocessingConfig: @classmethod def from_config(cls, cfg: Any) -> "PreprocessingConfig": tiling = cfg.tiling + default_read_coordinates_from = Path(getattr(cfg, "output_dir", "output")) / "coordinates" + read_coordinates_from = getattr(tiling, "read_coordinates_from", None) + read_tiles_from = getattr(tiling, "read_tiles_from", None) + on_the_fly = bool(getattr(tiling, "on_the_fly", True)) + gpu_decode = bool(getattr(tiling, "gpu_decode", False)) + adaptive_batching = bool(getattr(tiling, "adaptive_batching", False)) return cls( backend=tiling.backend, target_spacing_um=float(tiling.params.target_spacing_um), @@ -65,7 +88,18 @@ def from_config(cls, cfg: Any) -> "PreprocessingConfig": tissue_threshold=float(tiling.params.tissue_threshold), drop_holes=bool(tiling.params.drop_holes), use_padding=bool(tiling.params.use_padding), - read_tiles_from=Path(tiling.read_tiles_from) if tiling.read_tiles_from else None, + read_coordinates_from=( + Path(read_coordinates_from) if read_coordinates_from else default_read_coordinates_from + ), + read_tiles_from=( + Path(read_tiles_from) if read_tiles_from else None + ), + on_the_fly=on_the_fly, + gpu_decode=gpu_decode, + adaptive_batching=adaptive_batching, + use_supertiles=bool(getattr(tiling, "use_supertiles", True)), + jpeg_backend=str(getattr(tiling, "jpeg_backend", "turbojpeg")), + num_cucim_workers=_cfg_num_cucim_workers(cfg), resume=bool(getattr(cfg, "resume", False)), segmentation=dict(tiling.seg_params), filtering=dict(tiling.filter_params), @@ -86,21 +120,30 @@ class ExecutionOptions: output_format: str = "pt" batch_size: int = 1 num_workers: int = 0 + num_preprocessing_workers: int = 8 num_gpus: int | None = None - mixed_precision: bool = False + precision: str | None = None + prefetch_factor: int = 4 + persistent_workers: bool = True + gpu_batch_preprocessing: bool = True save_tile_embeddings: bool = False save_latents: bool = False @classmethod def from_config(cls, cfg: Any, *, run_on_cpu: bool = False) -> "ExecutionOptions": configured_num_gpus = getattr(cfg.speed, "num_gpus", None) + requested_precision = normalize_precision_name(getattr(cfg.speed, "precision", "fp32")) return cls( output_dir=Path(cfg.output_dir), output_format="pt", batch_size=int(getattr(cfg.model, "batch_size", 1)), - num_workers=int(getattr(cfg.speed, "num_workers_embedding", cfg.speed.num_workers)), + num_workers=int(getattr(cfg.speed, "num_dataloader_workers", getattr(cfg.speed, "num_workers_embedding", cfg.speed.num_workers))), + num_preprocessing_workers=int(getattr(cfg.speed, "num_preprocessing_workers", cfg.speed.num_workers)), num_gpus=1 if run_on_cpu else _coerce_num_gpus(configured_num_gpus), - mixed_precision=bool(cfg.speed.fp16 and not run_on_cpu), + precision="fp32" if run_on_cpu else requested_precision, + prefetch_factor=int(getattr(cfg.speed, "prefetch_factor_embedding", 4)), + persistent_workers=bool(getattr(cfg.speed, "persistent_workers_embedding", True)), + gpu_batch_preprocessing=bool(getattr(cfg.speed, "gpu_batch_preprocessing", True)), save_tile_embeddings=bool(getattr(cfg.model, "save_tile_embeddings", False)), save_latents=bool(getattr(cfg.model, "save_latents", False)), ) @@ -108,8 +151,11 @@ def from_config(cls, cfg: Any, *, run_on_cpu: bool = False) -> "ExecutionOptions def __post_init__(self) -> None: resolved_num_gpus = _default_num_gpus() if self.num_gpus is None else self.num_gpus object.__setattr__(self, "num_gpus", resolved_num_gpus) + object.__setattr__(self, "precision", normalize_precision_name(self.precision)) if resolved_num_gpus < 1: raise ValueError("ExecutionOptions.num_gpus must be at least 1") + if self.prefetch_factor < 1: + raise ValueError("ExecutionOptions.prefetch_factor must be at least 1") def with_output_dir(self, output_dir: PathLike | None) -> "ExecutionOptions": if output_dir is None: @@ -150,10 +196,12 @@ def __init__( patch_size: int | None = None, token_size: int | None = None, normalize_embeddings: bool | None = None, + allow_non_recommended_settings: bool = False, ) -> None: self.name = _canonical_model_name(name) self.level = level self._requested_device = device + self.allow_non_recommended_settings = bool(allow_non_recommended_settings) self._model_kwargs = { "mode": mode, "arch": arch, @@ -178,6 +226,7 @@ def from_pretrained( patch_size: int | None = None, token_size: int | None = None, normalize_embeddings: bool | None = None, + allow_non_recommended_settings: bool = False, device: str = "auto", ) -> "Model": canonical_name = _canonical_model_name(name) @@ -193,6 +242,7 @@ def from_pretrained( patch_size=patch_size, token_size=token_size, normalize_embeddings=normalize_embeddings, + allow_non_recommended_settings=allow_non_recommended_settings, ) @property @@ -213,8 +263,10 @@ def embed_tiles( ) -> list[TileEmbeddingArtifact]: from slide2vec.inference import embed_tiles - resolved = _coerce_execution_options(execution) + resolved = _coerce_execution_options(execution, model=self) _require_output_dir_for_persistence(resolved, method_name="Model.embed_tiles(...)") + if preprocessing is not None: + validate_model_runtime_compatibility(self, preprocessing, resolved) return embed_tiles(self, slides, tiling_results, execution=resolved, preprocessing=preprocessing) def aggregate_tiles( @@ -226,7 +278,7 @@ def aggregate_tiles( ) -> list[SlideEmbeddingArtifact]: from slide2vec.inference import aggregate_tiles - resolved = _coerce_execution_options(execution) + resolved = _coerce_execution_options(execution, model=self) _require_output_dir_for_persistence(resolved, method_name="Model.aggregate_tiles(...)") return aggregate_tiles(self, tile_artifacts, execution=resolved, preprocessing=preprocessing) @@ -292,7 +344,8 @@ def embed_slides( ) -> list[EmbeddedSlide]: from slide2vec.inference import embed_slides - resolved = _coerce_execution_options(execution) + resolved = _coerce_execution_options(execution, model=self) + validate_model_runtime_compatibility(self, preprocessing, resolved) return embed_slides( self, slides, @@ -303,13 +356,16 @@ def embed_slides( def _load_backend(self) -> "LoadedModel": if self._backend is None: from slide2vec.inference import load_model + from slide2vec.progress import emit_progress + emit_progress("model.loading", model_name=self.name) self._backend = load_model( name=self.name, level=self.level, device=self._requested_device, **self._model_kwargs, ) + emit_progress("model.ready", model_name=self.name, device=str(self._backend.device)) return self._backend @@ -323,7 +379,7 @@ def __init__( ) -> None: self.model = model self.preprocessing = preprocessing - self.execution = _coerce_execution_options(execution) + self.execution = _coerce_execution_options(execution, model=model) def run( self, @@ -334,6 +390,8 @@ def run( ) -> RunResult: from slide2vec.inference import run_pipeline + if not tiling_only: + validate_model_runtime_compatibility(self.model, self.preprocessing, self.execution) return run_pipeline( self.model, slides=slides, @@ -345,14 +403,19 @@ def run( def _canonical_model_name(name: str) -> str: - normalized = name.strip().lower() - return MODEL_NAME_ALIASES.get(normalized, normalized) + return canonicalize_model_name(name) -def _coerce_execution_options(options: ExecutionOptions | None) -> ExecutionOptions: - if options is None: - return ExecutionOptions() - return options +def _coerce_execution_options( + options: ExecutionOptions | None, + *, + model: Model | None = None, +) -> ExecutionOptions: + resolved = ExecutionOptions() if options is None else options + if resolved.precision is not None: + return resolved + recommended = _recommended_execution_precision(model) + return replace(resolved, precision=recommended) def _coerce_num_gpus(value: Any) -> int | None: @@ -374,3 +437,10 @@ def _default_num_gpus() -> int: def _require_output_dir_for_persistence(execution: ExecutionOptions, *, method_name: str) -> None: if execution.output_dir is None: raise ValueError(f"ExecutionOptions.output_dir is required for {method_name}") + + +def _recommended_execution_precision(model: Model | None) -> str: + settings = get_recommended_model_settings(getattr(model, "name", None)) + if settings is not None and settings.precision is not None: + return settings.precision + return "fp32" diff --git a/slide2vec/cli.py b/slide2vec/cli.py index 05961d4..89097c2 100644 --- a/slide2vec/cli.py +++ b/slide2vec/cli.py @@ -33,6 +33,9 @@ def build_model_and_pipeline(args): patch_size=cfg.model.patch_size, token_size=cfg.model.token_size, normalize_embeddings=getattr(cfg.model, "normalize_embeddings", None), + allow_non_recommended_settings=bool( + getattr(cfg.model, "allow_non_recommended_settings", False) + ), device="cpu" if args.run_on_cpu else "auto", ) preprocessing = PreprocessingConfig.from_config(cfg) diff --git a/slide2vec/configs/models/conch.yaml b/slide2vec/configs/models/conch.yaml index 8bc60fc..94b8689 100644 --- a/slide2vec/configs/models/conch.yaml +++ b/slide2vec/configs/models/conch.yaml @@ -1,44 +1,16 @@ -# csv: "/data/temporary/clement/discern/multi-organ-1k.csv" -csv: "/data/pathology/projects/clement/discern/multi-center-prostate-no-mask.csv" +csv: -output_dir: "/data/pathology/projects/clement/discern/hs2p" # output directory - -save_previews: true +output_dir: +save_previews: false tiling: params: - target_spacing_um: 0.5 # spacing at which to tile the slide, in microns per pixel - target_tile_size_px: 448 # size of the tiles to extract, in pixels - tissue_threshold: 0.1 # threshold used to filter out tiles that have less tissue than this value (percentage) - seg_params: - downsample: 16 # find the closest downsample in the slide for tissue segmentation - filter_params: - ref_tile_size: 448 # reference tile size at spacing tiling.params.target_spacing_um - a_t: 16 # area filter threshold for tissue ; positive integer, the minimum size of detected foreground contours to consider, relative to the reference tile size ref_tile_size - # e.g. a value 10 means only detected foreground contours of size greater than 10 [ref_tile_size, ref_tile_size] tiles at spacing tiling.params.target_spacing_um will be kept - a_h: 4 # area filter threshold for holes ; positive integer, the minimum size of detected holes/cavities in foreground contours to avoid, once again relative to the reference tile size ref_tile_size - max_n_holes: 8 # maximum of holes to consider per detected foreground contours (positive integer, higher values lead to more accurate patching but increase computational cost ; keeps the biggest holes) - filter_white: true # whether to filter out mostly white tiles - filter_black: true # whether to filter out mostly black tiles - white_threshold: 220 # threshold for white pixels (0-255) - black_threshold: 25 # threshold for black pixels (0-255) - fraction_threshold: 0.9 # fraction of pixels that must be white/black to filter out the tile + target_spacing_um: 0.5 + target_tile_size_px: 448 model: - level: "tile" # level at which to extract the features ("tile", "region" or "slide") + level: "tile" name: "conch" - batch_size: 64 speed: - fp16: true # use mixed precision during model inference - num_workers: 16 # number of workers for tiling slides - num_workers_embedding: 16 # number of workers for data loading when embedding slides - -wandb: - enable: true - project: "discern" - username: "clemsg" - exp_name: "tiling" - # exp_name: "features" - tags: ["tiling", "${tiling.params.target_tile_size_px}"] - # tags: ["features", "${model.name}", "${model.level}", "${tiling.params.target_tile_size_px}"] + precision: "fp32" diff --git a/slide2vec/configs/models/conchv15.yaml b/slide2vec/configs/models/conchv15.yaml new file mode 100644 index 0000000..5a356a8 --- /dev/null +++ b/slide2vec/configs/models/conchv15.yaml @@ -0,0 +1,16 @@ +csv: + +output_dir: +save_previews: false + +tiling: + params: + target_spacing_um: 0.5 + target_tile_size_px: 448 + +model: + level: "tile" + name: "conchv15" + +speed: + precision: "fp16" diff --git a/slide2vec/configs/models/default.yaml b/slide2vec/configs/models/default.yaml index b7bc3f8..bff4aaa 100644 --- a/slide2vec/configs/models/default.yaml +++ b/slide2vec/configs/models/default.yaml @@ -8,21 +8,22 @@ seed: 0 # seed for reproducibility model: level: "tile" # level at which to extract the features ("tile", "region" or "slide") - name: # foundation model name ["uni", "uni2", "virchow", "virchow2", "prov-gigapath", "h-optimus-0", "h-optimus-1", "pathojepa", "titan", "prism"] (leave empty when using a custom model) - mode: "cls" # embedding mode ["cls", "full"] + name: # foundation model name; see docs/models.md for supported request strings (leave empty when using a custom model) + mode: # embedding mode override ["cls", "full"]; leave empty for the model default arch: # architecture of custom model pretrained_weights: # path to the pretrained weights when using a custom model - batch_size: 256 + batch_size: 32 input_size: ${tiling.params.target_tile_size_px} patch_size: 256 # if level is "region", size used to unroll the region into patches token_size: 16 # size of the tokens used model is a custom pretrained ViT save_tile_embeddings: false # whether to save tile embeddings alongside the pooled slide embedding when level is "slide" save_latents: false # whether to save the latent representations from the model alongside the slide embedding (only supported for 'prism') normalize_embeddings: false # L2 normalize tile embeddings (used by some custom checkpoints such as pathojepa) + allow_non_recommended_settings: false # when true, non-recommended model input size / spacing / precision combinations warn instead of erroring speed: - fp16: false # use mixed precision during model inference - num_workers_embedding: 8 # number of workers for data loading when embedding slides + precision: fp32 # model inference precision ["fp32", "fp16", "bf16"] + num_dataloader_workers: 8 # number of DataLoader worker processes for reading tiles during embedding (tar path); on-the-fly path derives this automatically from cpu_count // speed.num_cucim_workers num_gpus: # number of GPUs to use for feature extraction; defaults to all available GPUs wandb: diff --git a/slide2vec/configs/models/h-optimus-0.yaml b/slide2vec/configs/models/h-optimus-0.yaml index efcf233..c05e299 100644 --- a/slide2vec/configs/models/h-optimus-0.yaml +++ b/slide2vec/configs/models/h-optimus-0.yaml @@ -1,37 +1,16 @@ -# csv: "/data/temporary/clement/leopard/csvs/dev-slide2vec.csv" -# csv: "/data/temporary/clement/leopard/csvs/tcga/tcga-slide2vec.csv" -# csv: "/data/temporary/clement/leopard/csvs/validation-without-pen-marks-slide2vec.csv" -# csv: "/data/temporary/clement/leopard/csvs/test-slide2vec.csv" -# csv: "/data/temporary/clement/leopard/csvs/brazil-slide2vec.csv" -csv: "/data/temporary/clement/leopard/csvs/cologne-slide2vec.csv" +csv: output_dir: "output" -resume: true -resume_dirname: "48juvldi" - save_previews: false tiling: params: - target_spacing_um: 0.5 # spacing at which to tile the slide, in microns per pixel - target_tile_size_px: 2048 # size of the tiles to extract, in pixels + target_spacing_um: 0.5 + target_tile_size_px: 224 model: - level: "region" + level: "tile" name: "h-optimus-0" - batch_size: 1 speed: - fp16: true - -wandb: - enable: true - project: "leopard" - username: "clemsg" - exp_name: "features" - # tags: ["features", "dev", "${model.name}", "${model.level}", "${tiling.params.target_tile_size_px}"] - # tags: ["features", "tcga", "${model.name}", "${model.level}", "${tiling.params.target_tile_size_px}"] - # tags: ["features", "validation", "${model.name}", "${model.level}", "${tiling.params.target_tile_size_px}"] - # tags: ["features", "test", "${model.name}", "${model.level}", "${tiling.params.target_tile_size_px}"] - # tags: ["features", "brazil", "${model.name}", "${model.level}", "${tiling.params.target_tile_size_px}"] - tags: ["features", "cologne", "${model.name}", "${model.level}", "${tiling.params.target_tile_size_px}"] \ No newline at end of file + precision: "fp16" diff --git a/slide2vec/configs/models/h-optimus-1.yaml b/slide2vec/configs/models/h-optimus-1.yaml index 60a017b..7a3b520 100644 --- a/slide2vec/configs/models/h-optimus-1.yaml +++ b/slide2vec/configs/models/h-optimus-1.yaml @@ -1,16 +1,16 @@ -csv: # path to csv containing slide paths +csv: -output_dir: "output" # output directory +output_dir: "output" +save_previews: false tiling: params: - target_spacing_um: 0.5 # spacing at which to tile the slide, in microns per pixel - target_tile_size_px: 224 # size of the tiles to extract, in pixels + target_spacing_um: 0.5 + target_tile_size_px: 224 model: - level: "tile" # level at which to extract the features ("tile", "region" or "slide") + level: "tile" name: "h-optimus-1" - batch_size: 1 speed: - fp16: true # use mixed precision during model inference \ No newline at end of file + precision: "fp16" diff --git a/slide2vec/configs/models/h0-mini.yaml b/slide2vec/configs/models/h0-mini.yaml index e3fc0f9..884310a 100644 --- a/slide2vec/configs/models/h0-mini.yaml +++ b/slide2vec/configs/models/h0-mini.yaml @@ -1,20 +1,16 @@ -csv: # path to csv containing slide paths +csv: output_dir: "output" +save_previews: false tiling: params: - target_spacing_um: 0.5 # spacing at which to tile the slide, in microns per pixel - tolerance: 0.05 # tolerance for matching the spacing (float between 0 and 1, deciding how much the spacing can deviate from the one specified in the slide metadata) - target_tile_size_px: 224 # size of the tiles to extract, in pixels - tissue_threshold: 0.1 # threshold used to filter out tiles that have less tissue than this value (percentage) - filter_params: - ref_tile_size: 224 + target_spacing_um: 0.5 + target_tile_size_px: 224 model: level: "tile" name: "h0-mini" - batch_size: 1 speed: - fp16: true \ No newline at end of file + precision: "fp16" diff --git a/slide2vec/configs/models/hibou.yaml b/slide2vec/configs/models/hibou.yaml index 3ceb3a8..797d164 100644 --- a/slide2vec/configs/models/hibou.yaml +++ b/slide2vec/configs/models/hibou.yaml @@ -1,19 +1,17 @@ -csv: # path to csv containing slide paths +csv: -output_dir: "output" # output directory - -save_previews: true +output_dir: "output" +save_previews: false tiling: params: - target_spacing_um: 0.5 # spacing at which to tile the slide, in microns per pixel - target_tile_size_px: 224 # size of the tiles to extract, in pixels + target_spacing_um: 0.5 + target_tile_size_px: 224 model: - level: "tile" # level at which to extract the features ("tile", "region" or "slide") - arch: "hibou-b" # Hibou model architectures, options: ("hibou-b", "hibou-L") + level: "tile" + arch: "hibou-b" name: "hibou" - batch_size: 1 speed: - fp16: true # use mixed precision during model inference \ No newline at end of file + precision: "fp16" diff --git a/slide2vec/configs/models/kaiko-midnight.yaml b/slide2vec/configs/models/kaiko-midnight.yaml index b630dc6..7cb035a 100644 --- a/slide2vec/configs/models/kaiko-midnight.yaml +++ b/slide2vec/configs/models/kaiko-midnight.yaml @@ -1,18 +1,16 @@ -csv: # path to csv containing slide paths +csv: -output_dir: "output" # output directory - -save_previews: true +output_dir: "output" +save_previews: false tiling: params: - target_spacing_um: 0.5 # spacing at which to tile the slide, in microns per pixel - target_tile_size_px: 224 # size of the tiles to extract, in pixels + target_spacing_um: 0.5 + target_tile_size_px: 224 model: - level: "tile" # level at which to extract the features ("tile" or "region") + level: "tile" name: "kaiko-midnight" - batch_size: 1 speed: - fp16: true # use mixed precision during model inference \ No newline at end of file + precision: "fp16" diff --git a/slide2vec/configs/models/kaiko.yaml b/slide2vec/configs/models/kaiko.yaml index 704dc68..b442e7c 100644 --- a/slide2vec/configs/models/kaiko.yaml +++ b/slide2vec/configs/models/kaiko.yaml @@ -1,19 +1,17 @@ -csv: # path to csv containing slide paths +csv: -output_dir: "output" # output directory - -save_previews: true +output_dir: "output" +save_previews: false tiling: params: - target_spacing_um: 0.5 # spacing at which to tile the slide, in microns per pixel - target_tile_size_px: 224 # size of the tiles to extract, in pixels + target_spacing_um: 0.5 + target_tile_size_px: 224 model: - level: "tile" # level at which to extract the features ("tile", "region" or "slide") + level: "tile" name: "kaiko" - arch: "vitl14" # kaiko model architectures, options: ("vits8", "vits16", "vitb8", "vitb16", "vitl14") - batch_size: 1 + arch: "vitl14" speed: - fp16: true # use mixed precision during model inference \ No newline at end of file + precision: "fp32" diff --git a/slide2vec/configs/models/musk.yaml b/slide2vec/configs/models/musk.yaml index 67d4520..03e2232 100644 --- a/slide2vec/configs/models/musk.yaml +++ b/slide2vec/configs/models/musk.yaml @@ -1,18 +1,16 @@ -csv: # path to csv containing slide paths +csv: -output_dir: "output" # output directory - -save_previews: true +output_dir: "output" +save_previews: false tiling: params: - target_spacing_um: 0.5 # spacing at which to tile the slide, in microns per pixel - target_tile_size_px: 384 # size of the tiles to extract, in pixels + target_spacing_um: 0.5 + target_tile_size_px: 384 model: - level: "tile" # level at which to extract the features ("tile", "region" or "slide") + level: "tile" name: "musk" - batch_size: 1 speed: - fp16: true # use mixed precision during model inference \ No newline at end of file + precision: "fp16" diff --git a/slide2vec/configs/models/panda-vit-s.yaml b/slide2vec/configs/models/panda-vit-s.yaml index a1925bd..e012540 100644 --- a/slide2vec/configs/models/panda-vit-s.yaml +++ b/slide2vec/configs/models/panda-vit-s.yaml @@ -1,23 +1,17 @@ csv: "" -save_previews: true - -output_dir: "output" # output directory +output_dir: "output" +save_previews: false tiling: params: - target_spacing_um: 0.5 # spacing at which to tile the slide, in microns per pixel - tolerance: 0.05 # tolerance for matching the spacing (float between 0 and 1, deciding how much the spacing can deviate from the one specified in the slide metadata) - target_tile_size_px: 224 # size of the tiles to extract, in pixels - tissue_threshold: 0.1 # threshold used to filter out tiles that have less tissue than this value (percentage) - filter_params: - ref_tile_size: 224 + target_spacing_um: 0.5 + target_tile_size_px: 224 model: level: "tile" name: "panda-vit-s" pretrained_weights: "/path/to/model/weights.pt" - batch_size: 1 speed: - fp16: false + precision: "fp32" diff --git a/slide2vec/configs/models/pathojepa.yaml b/slide2vec/configs/models/pathojepa.yaml index e81a2f0..de63f99 100644 --- a/slide2vec/configs/models/pathojepa.yaml +++ b/slide2vec/configs/models/pathojepa.yaml @@ -1,7 +1,3 @@ -# csv: "/data/pathology/projects/clement/leopard/csvs/dev-slide2vec.csv" -# csv: "/data/pathology/projects/clement/leopard/csvs/tcga/tcga-slide2vec.csv" -# csv: "/data/pathology/projects/clement/leopard/csvs/test-slide2vec.csv" -# csv: "/data/pathology/projects/clement/leopard/csvs/cologne-slide2vec.csv" csv: "/data/pathology/projects/clement/leopard/csvs/brazil-slide2vec-august-2025-revision.csv" output_dir: "/data/pathology/projects/clement/discern/pathojepa/slide2vec" @@ -9,28 +5,28 @@ save_previews: false tiling: params: - target_spacing_um: 0.5 # spacing at which to tile the slide, in microns per pixel - tolerance: 0.05 # tolerance for matching the spacing (float between 0 and 1) - target_tile_size_px: 2048 # PathoJEPA inference target tile size - tissue_threshold: 0.1 # threshold used to filter out tiles with too little tissue + target_spacing_um: 0.5 + tolerance: 0.05 + target_tile_size_px: 2048 + tissue_threshold: 0.1 seg_params: downsample: 64 filter_params: ref_tile_size: 256 model: - level: "region" # set to "region" to run region-level inference with this tile encoder + level: "region" name: "pathojepa" arch: "vit_small" pretrained_weights: "/data/pathology/projects/clement/discern/pathojepa/runs/dmky8lh7/jepa-pathorob-latest.pth.tar" input_size: 224 - patch_size: 256 # region-unrolling size when model.level == "region" - token_size: 16 # ViT patch size used by PathoJEPA + patch_size: 256 + token_size: 16 normalize_embeddings: false batch_size: 1 speed: - fp16: false + precision: "fp32" wandb: enable: true diff --git a/slide2vec/configs/models/phikonv2.yaml b/slide2vec/configs/models/phikonv2.yaml index 8eec445..0ac2328 100644 --- a/slide2vec/configs/models/phikonv2.yaml +++ b/slide2vec/configs/models/phikonv2.yaml @@ -1,18 +1,16 @@ -csv: # path to csv containing slide paths +csv: -output_dir: "output" # output directory - -save_previews: true +output_dir: "output" +save_previews: false tiling: params: - target_spacing_um: 0.5 # spacing at which to tile the slide, in microns per pixel - target_tile_size_px: 224 # size of the tiles to extract, in pixels + target_spacing_um: 0.5 + target_tile_size_px: 224 model: - level: "tile" # level at which to extract the features ("tile", "region" or "slide") + level: "tile" name: "phikonv2" - batch_size: 1 speed: - fp16: true # use mixed precision during model inference \ No newline at end of file + precision: "fp32" diff --git a/slide2vec/configs/models/prism.yaml b/slide2vec/configs/models/prism.yaml index 02f8987..dda6d48 100644 --- a/slide2vec/configs/models/prism.yaml +++ b/slide2vec/configs/models/prism.yaml @@ -1,24 +1,18 @@ -csv: # path to csv containing slide paths +csv: -output_dir: "output" # output directory -save_previews: true +output_dir: "output" +save_previews: false tiling: params: - target_spacing_um: 0.5 # spacing at which to tile the slide, in microns per pixel - tolerance: 0.07 # tolerance for matching the spacing (float between 0 and 1, deciding how much the spacing can deviate from the one specified in the slide metadata) - target_tile_size_px: 224 # size of the tiles to extract, in pixels - tissue_threshold: 0.1 # threshold used to filter out tiles that have less tissue than this value (percentage) - filter_params: - ref_tile_size: 224 + target_spacing_um: 0.5 + target_tile_size_px: 224 model: level: "slide" name: "prism" - batch_size: 32 - save_tile_embeddings: true # whether to save tile embeddings alongside the pooled slide embedding - save_latents: false # whether to save the latent representations from the model alongside the slide embedding (only supported for 'prism') + save_tile_embeddings: true + save_latents: false speed: - fp16: true - num_workers_embedding: 16 \ No newline at end of file + precision: "fp16" diff --git a/slide2vec/configs/models/prov-gigapath-slide.yaml b/slide2vec/configs/models/prov-gigapath-slide.yaml index df12e1d..f91bef1 100644 --- a/slide2vec/configs/models/prov-gigapath-slide.yaml +++ b/slide2vec/configs/models/prov-gigapath-slide.yaml @@ -1,18 +1,17 @@ -csv: # path to csv containing slide paths +csv: -output_dir: "output" # output directory -save_previews: true +output_dir: "output" +save_previews: false tiling: params: - target_spacing_um: 0.5 # spacing at which to tile the slide, in microns per pixel - target_tile_size_px: 256 # size of the tiles to extract, in pixels + target_spacing_um: 0.5 + target_tile_size_px: 256 model: - level: "slide" # level at which to extract the features ("tile", "region" or "slide") + level: "slide" name: "prov-gigapath" - batch_size: 32 - save_tile_embeddings: true # whether to save tile embeddings alongside the pooled slide embedding + save_tile_embeddings: true speed: - fp16: true # use mixed precision during model inference \ No newline at end of file + precision: "fp16" diff --git a/slide2vec/configs/models/prov-gigapath-tile.yaml b/slide2vec/configs/models/prov-gigapath-tile.yaml index d902e41..70a5525 100644 --- a/slide2vec/configs/models/prov-gigapath-tile.yaml +++ b/slide2vec/configs/models/prov-gigapath-tile.yaml @@ -1,16 +1,16 @@ -csv: # path to csv containing slide paths +csv: -output_dir: "output" # output directory +output_dir: "output" +save_previews: false tiling: params: - target_spacing_um: 0.5 # spacing at which to tile the slide, in microns per pixel - target_tile_size_px: 256 # size of the tiles to extract, in pixels + target_spacing_um: 0.5 + target_tile_size_px: 256 model: - level: "tile" # level at which to extract the features ("tile", "region" or "slide") + level: "tile" name: "prov-gigapath" - batch_size: 1 speed: - fp16: true # use mixed precision during model inference \ No newline at end of file + precision: "fp16" diff --git a/slide2vec/configs/models/titan.yaml b/slide2vec/configs/models/titan.yaml index dbccbd9..1dffbf4 100644 --- a/slide2vec/configs/models/titan.yaml +++ b/slide2vec/configs/models/titan.yaml @@ -1,18 +1,17 @@ -csv: # path to csv containing slide paths +csv: -output_dir: "output" # output directory -save_previews: true +output_dir: "output" +save_previews: false tiling: params: - target_spacing_um: 0.5 # spacing at which to tile the slide, in microns per pixel - target_tile_size_px: 512 # size of the tiles to extract, in pixels + target_spacing_um: 0.5 + target_tile_size_px: 512 model: - level: "slide" # level at which to extract the features ("tile", "region" or "slide") + level: "slide" name: "titan" - batch_size: 32 - save_tile_embeddings: true # whether to save tile embeddings alongside the pooled slide embedding + save_tile_embeddings: true speed: - fp16: true # use mixed precision during model inference \ No newline at end of file + precision: "fp16" diff --git a/slide2vec/configs/models/uni.yaml b/slide2vec/configs/models/uni.yaml index e9034c4..dfcd571 100644 --- a/slide2vec/configs/models/uni.yaml +++ b/slide2vec/configs/models/uni.yaml @@ -1,39 +1,16 @@ -# csv: "/data/temporary/clement/leopard/csvs/dev-slide2vec.csv" -# csv: "/data/temporary/clement/leopard/csvs/tcga/tcga-slide2vec.csv" -# csv: "/data/temporary/clement/leopard/csvs/validation-without-pen-marks-slide2vec.csv" -# csv: "/data/temporary/clement/leopard/csvs/test-slide2vec.csv" -# csv: "/data/temporary/clement/leopard/csvs/brazil-slide2vec.csv" -# csv: "/data/temporary/clement/leopard/csvs/cologne-slide2vec.csv" -# csv: "/data/temporary/clement/code/ab-mil/data/unicorn-task-1/validation+test.csv" -csv: "/data/temporary/clement/code/slide2vec/tmp.csv" +csv: output_dir: "output/debug" - -save_previews: true +save_previews: false tiling: params: - target_spacing_um: 0.5 # spacing at which to tile the slide, in microns per pixel - tolerance: 0.05 # tolerance for matching the spacing (float between 0 and 1, deciding how much the spacing can deviate from the one specified in the slide metadata) - target_tile_size_px: 2048 # size of the tiles to extract, in pixels - tissue_threshold: 0.01 # threshold used to filter out tiles that have less tissue than this value (percentage) - filter_params: - ref_tile_size: 224 + target_spacing_um: 0.5 + target_tile_size_px: 224 model: - level: "region" + level: "tile" name: "uni" - batch_size: 1 -wandb: - enable: false - project: "unicorn" - username: "clemsg" - exp_name: "features" - # tags: ["features", "dev", "${model.name}", "${model.level}", "${tiling.params.target_tile_size_px}"] - # tags: ["features", "tcga", "${model.name}", "${model.level}", "${tiling.params.target_tile_size_px}"] - # tags: ["features", "validation", "${model.name}", "${model.level}", "${tiling.params.target_tile_size_px}"] - # tags: ["features", "test", "${model.name}", "${model.level}", "${tiling.params.target_tile_size_px}"] - # tags: ["features", "brazil", "${model.name}", "${model.level}", "${tiling.params.target_tile_size_px}"] - # tags: ["features", "cologne", "${model.name}", "${model.level}", "${tiling.params.target_tile_size_px}"] - tags: ["features", "${model.name}", "${model.level}", "${tiling.params.target_tile_size_px}"] +speed: + precision: "fp16" diff --git a/slide2vec/configs/models/uni2.yaml b/slide2vec/configs/models/uni2.yaml index 60786ac..62e6bb0 100644 --- a/slide2vec/configs/models/uni2.yaml +++ b/slide2vec/configs/models/uni2.yaml @@ -1,24 +1,16 @@ -csv: "/data/temporary/clement/code/ab-mil/data/unicorn-task-1/validation+test.csv" +csv: -output_dir: "output" # output directory +output_dir: "output" +save_previews: false tiling: params: - target_spacing_um: 0.5 # spacing at which to tile the slide, in microns per pixel - tolerance: 0.05 # tolerance for matching the spacing (float between 0 and 1, deciding how much the spacing can deviate from the one specified in the slide metadata) - target_tile_size_px: 224 # size of the tiles to extract, in pixels - tissue_threshold: 0.1 # threshold used to filter out tiles that have less tissue than this value (percentage) - filter_params: - ref_tile_size: 224 + target_spacing_um: 0.5 + target_tile_size_px: 224 model: - level: "tile" # level at which to extract the features ("tile", "region" or "slide") + level: "tile" name: "uni2" - batch_size: 1 -wandb: - enable: true - project: "unicorn" - username: "clemsg" - exp_name: "features" - tags: ["features", "${model.name}", "${model.level}", "${tiling.params.target_tile_size_px}"] \ No newline at end of file +speed: + precision: "bf16" diff --git a/slide2vec/configs/models/virchow.yaml b/slide2vec/configs/models/virchow.yaml index c3f2870..09ffd2f 100644 --- a/slide2vec/configs/models/virchow.yaml +++ b/slide2vec/configs/models/virchow.yaml @@ -1,34 +1,16 @@ -# csv: "/data/temporary/clement/dataset/tcga-prad/mutations/tcga-prad-tp53-slide2vec.csv" -# csv: "/data/temporary/clement/dataset/tcga-blca/mutations/tcga-blca-tp53-slide2vec.csv" -csv: "/data/temporary/clement/dataset/tcga-brca/mutations/tcga-brca-tp53-slide2vec.csv" +csv: output_dir: "output" - -save_previews: true +save_previews: false tiling: - read_tiles_from: params: - target_spacing_um: 0.5 # spacing at which to tile the slide, in microns per pixel - tolerance: 0.07 # tolerance for matching the spacing (float between 0 and 1, deciding how much the spacing can deviate from the one specified in the slide metadata) - target_tile_size_px: 4096 # size of the tiles to extract, in pixels - tissue_threshold: 0.1 # threshold used to filter out tiles that have less tissue than this value (percentage) - filter_params: - ref_tile_size: 224 + target_spacing_um: 0.5 + target_tile_size_px: 224 model: - level: "region" + level: "tile" name: "virchow" - batch_size: 1 speed: - fp16: true - -wandb: - enable: true - project: "mut-pred" - username: "clemsg" - exp_name: "features" - # tags: ["features", "tcga-prad", "${model.name}", "${model.level}", "${tiling.params.target_tile_size_px}"] - # tags: ["features", "tcga-blca", "${model.name}", "${model.level}", "${tiling.params.target_tile_size_px}"] - tags: ["features", "tcga-brca", "${model.name}", "${model.level}", "${tiling.params.target_tile_size_px}"] + precision: "fp16" diff --git a/slide2vec/configs/models/virchow2.yaml b/slide2vec/configs/models/virchow2.yaml index 63b1e8a..ce866a8 100644 --- a/slide2vec/configs/models/virchow2.yaml +++ b/slide2vec/configs/models/virchow2.yaml @@ -1,35 +1,16 @@ -# csv: "/data/temporary/clement/dataset/tcga-prad/mutations/tcga-prad-tp53-slide2vec.csv" -# csv: "/data/temporary/clement/dataset/tcga-blca/mutations/tcga-blca-tp53-slide2vec.csv" -# csv: "/data/temporary/clement/dataset/tcga-brca/mutations/tcga-brca-tp53-slide2vec.csv" -csv: "/data/temporary/clement/leopard/csvs/brazil-slide2vec-august-2025-revision.csv" +csv: -output_dir: "/data/temporary/clement/code/slide2vec/output" - -save_previews: true +output_dir: "output" +save_previews: false tiling: - # read_tiles_from: "/data/temporary/clement/code/slide2vec/output/vjh3hrr6/coordinates" params: - target_spacing_um: 0.5 # spacing at which to tile the slide, in microns per pixel - tolerance: 0.05 # tolerance for matching the spacing (float between 0 and 1, deciding how much the spacing can deviate from the one specified in the slide metadata) - target_tile_size_px: 2048 # size of the tiles to extract, in pixels - tissue_threshold: 0.1 # threshold used to filter out tiles that have less tissue than this value (percentage) - filter_params: - ref_tile_size: 256 + target_spacing_um: 0.5 + target_tile_size_px: 224 model: - level: "region" + level: "tile" name: "virchow2" - batch_size: 1 speed: - fp16: true - -wandb: - enable: true - project: "leopard" - username: "clemsg" - exp_name: "features" - tags: ["features", "brazil", "${model.name}", "${model.level}", "${tiling.params.target_tile_size_px}"] - # tags: ["features", "tcga-blca", "${model.name}", "${model.level}", "${tiling.params.target_tile_size_px}"] - # tags: ["features", "tcga-brca", "${model.name}", "${model.level}", "${tiling.params.target_tile_size_px}"] + precision: "fp16" diff --git a/slide2vec/configs/preprocessing/default.yaml b/slide2vec/configs/preprocessing/default.yaml index 2a5a997..d125e86 100644 --- a/slide2vec/configs/preprocessing/default.yaml +++ b/slide2vec/configs/preprocessing/default.yaml @@ -9,14 +9,20 @@ save_previews: true # save preview images of slide tiling and mask overlays seed: 0 # seed for reproducibility tiling: - read_tiles_from: # path to a directory containing HS2P `.tiles.npz` / `.tiles.meta.json` artifacts (leave empty to compute the coordinates) - backend: "asap" # backend to use for slide reading + on_the_fly: true # read tiles directly from WSI during embedding (requires cucim backend) + gpu_decode: false # attempt GPU-accelerated JPEG decoding via nvImageCodec (experimental) + adaptive_batching: false # when true, vary batch size to align with super tile boundaries (avoids redundant reads but batch size fluctuates) + use_supertiles: true # group tiles into 8x8/4x4/2x2 super tile reads to reduce WSI read calls (on-the-fly path only) + jpeg_backend: "turbojpeg" # JPEG encoder for tar extraction: "turbojpeg" (faster) or "pil" (compatible with older ground truth fixtures) + read_coordinates_from: # path to HS2P `.coordinates.npz` / `.coordinates.meta.json` artifacts; defaults to /coordinates when left empty + read_tiles_from: # path to an existing directory containing HS2P `.tiles.tar` tile stores to reuse instead of writing new stores during tiling + backend: "auto" # backend to use for slide reading; "auto" lets hs2p resolve the best backend per slide, preferring cuCIM when available params: target_spacing_um: 0.5 # spacing at which to tile the slide, in microns per pixel tolerance: 0.05 # tolerance for matching the spacing (float between 0 and 1, deciding how much the spacing can deviate from the one specified in the slide metadata) target_tile_size_px: 256 # size of the tiles to extract, in pixels overlap: 0.0 # percentage of overlap between two consecutive tiles (float between 0 and 1) - tissue_threshold: 0.01 # minimum fraction of pixels that must be tissue to keep a tile (float between 0 and 1) + tissue_threshold: 0.1 # minimum fraction of pixels that must be tissue to keep a tile (float between 0 and 1) drop_holes: false # whether or not to drop tiles whose center pixel falls withing an identified holes use_padding: true # whether to pad the border of the slide seg_params: @@ -28,7 +34,7 @@ tiling: use_otsu: false # use otsu's method instead of simple binary thresholding use_hsv: true # use HSV thresholding instead of simple binary thresholding filter_params: - ref_tile_size: ${target_tile_size_px} # reference tile size at spacing tiling.params.target_spacing_um + ref_tile_size: ${tiling.params.target_tile_size_px} # reference tile size at spacing tiling.params.target_spacing_um a_t: 4 # area filter threshold for tissue (positive integer, the minimum size of detected foreground contours to consider, relative to the reference tile size ref_tile_size, e.g. a value 10 means only detected foreground contours of size greater than 10 [ref_tile_size, ref_tile_size] tiles at spacing tiling.params.target_spacing_um will be kept) a_h: 2 # area filter threshold for holes (positive integer, the minimum size of detected holes/cavities in foreground contours to avoid, once again relative to the reference tile size ref_tile_size) max_n_holes: 8 # maximum of holes to consider per detected foreground contours (positive integer, higher values lead to more accurate patching but increase computational cost ; keeps the biggest holes) @@ -41,7 +47,8 @@ tiling: downsample: 32 # downsample to use for preview rendering speed: - num_workers: 8 # number of workers for tiling slides + num_preprocessing_workers: 8 # number of workers for hs2p tiling (WSI reading, JPEG encoding, tar writing) + num_cucim_workers: 4 # number of internal cucim threads per read_region call (embedding path, on-the-fly only); DataLoader workers are auto-set to cpu_count // num_cucim_workers wandb: enable: false diff --git a/slide2vec/data/__init__.py b/slide2vec/data/__init__.py index 98809c3..413a067 100644 --- a/slide2vec/data/__init__.py +++ b/slide2vec/data/__init__.py @@ -1,2 +1,3 @@ -from .dataset import TileDataset from .augmentations import RegionUnfolding +from .dataset import BatchTileCollator, TileIndexDataset +from .tile_store import TarTileReader diff --git a/slide2vec/data/cucim_tile_reader.py b/slide2vec/data/cucim_tile_reader.py new file mode 100644 index 0000000..b1930f4 --- /dev/null +++ b/slide2vec/data/cucim_tile_reader.py @@ -0,0 +1,303 @@ +from __future__ import annotations + +from collections import defaultdict +from dataclasses import dataclass +import time +from typing import TYPE_CHECKING + +import numpy as np +import torch + +if TYPE_CHECKING: + from pathlib import Path + + from hs2p import TilingResult + + +class SuperTileBatchSampler: + """Batch sampler that keeps super tiles intact. + + Greedily packs whole super tiles into batches of approximately + ``batch_size`` tiles. No super tile is ever split across batches, + so each WSI region is read exactly once. + """ + + def __init__(self, supertile_groups: list[np.ndarray], batch_size: int): + self.batches: list[list[int]] = [] + current: list[int] = [] + for group in supertile_groups: + positions = group.tolist() + if current and len(current) + len(positions) > batch_size: + self.batches.append(current) + current = positions + else: + current.extend(positions) + if current: + self.batches.append(current) + + def __iter__(self): + return iter(self.batches) + + def __len__(self): + return len(self.batches) + + +@dataclass(frozen=True) +class _SuperTile: + x_lv0: int + y_lv0: int + read_size_px: int + block_size: int + + +def _build_supertile_index(tiling_result: TilingResult): + """Build super tile grouping and per-tile lookup structures. + + Returns: + supertiles: list of ``_SuperTile`` + tile_to_st: array mapping tile_index → supertile id + tile_crop_x: array mapping tile_index → crop x offset at read level + tile_crop_y: array mapping tile_index → crop y offset at read level + ordered_indices: tile indices reordered so tiles in the same super tile are contiguous + """ + from hs2p.api import ( + _iter_grouped_read_plans_for_tar_extraction, + _resolve_read_step_px, + _resolve_step_px_lv0, + ) + + read_step_px = _resolve_read_step_px(tiling_result) + step_px_lv0 = _resolve_step_px_lv0(tiling_result) + + num_tiles = int(tiling_result.num_tiles) + tile_to_st = np.empty(num_tiles, dtype=np.int32) + tile_crop_x = np.empty(num_tiles, dtype=np.int32) + tile_crop_y = np.empty(num_tiles, dtype=np.int32) + supertiles: list[_SuperTile] = [] + ordered_indices: list[int] = [] + + for plan in _iter_grouped_read_plans_for_tar_extraction( + result=tiling_result, + read_step_px=read_step_px, + step_px_lv0=step_px_lv0, + ): + st_id = len(supertiles) + tile_index_iter = iter(plan.tile_indices) + for x_idx in range(plan.block_size): + for y_idx in range(plan.block_size): + tile_idx = next(tile_index_iter) + tile_to_st[tile_idx] = st_id + tile_crop_x[tile_idx] = x_idx * read_step_px + tile_crop_y[tile_idx] = y_idx * read_step_px + ordered_indices.append(tile_idx) + + supertiles.append(_SuperTile( + x_lv0=int(plan.x), + y_lv0=int(plan.y), + read_size_px=int(plan.read_size_px), + block_size=int(plan.block_size), + )) + + return supertiles, tile_to_st, tile_crop_x, tile_crop_y, np.array(ordered_indices, dtype=np.int64) + + +class CuCIMTileReader: + """Read tiles directly from a WSI using cucim's batched read_region. + + When ``use_supertiles=True``, tiles are grouped into larger read regions + (8x8, 4x4, or 2x2 blocks) following the same logic as hs2p tar extraction. + One ``read_region`` call per super tile replaces many individual calls. + """ + + def __init__( + self, + image_path: Path, + tiling_result: TilingResult, + *, + num_cucim_workers: int = 4, + gpu_decode: bool = False, + use_supertiles: bool = True, + ): + self._image_path = image_path + self._x = tiling_result.x + self._y = tiling_result.y + self._read_level = tiling_result.read_level + self._tile_size_px = int(tiling_result.read_tile_size_px) + self._num_cucim_workers = num_cucim_workers + self._gpu_decode = gpu_decode + self._cu_image = None + + self._use_supertiles = use_supertiles + if use_supertiles: + ( + self._supertiles, + self._tile_to_st, + self._tile_crop_x, + self._tile_crop_y, + self.ordered_indices, + ) = _build_supertile_index(tiling_result) + else: + self._supertiles = None + self._tile_to_st = None + self.ordered_indices = None + + def _ensure_open(self): + if self._cu_image is None: + try: + from cucim import CuImage + except ImportError as exc: + raise ImportError( + "cucim is required for on-the-fly tile reading. " + "Install it with: pip install cucim-cuXX (where XX matches your CUDA version)" + ) from exc + self._cu_image = CuImage(str(self._image_path)) + + def _read_region(self, locations, size): + kwargs = { + "level": int(self._read_level), + "num_workers": max(1, self._num_cucim_workers), + } + if self._gpu_decode: + kwargs["device"] = "cuda" + try: + return self._cu_image.read_region(locations, size, **kwargs) + except TypeError: + kwargs.pop("device", None) + return self._cu_image.read_region(locations, size, **kwargs) + + def read_batch(self, tile_indices: np.ndarray) -> torch.Tensor: + tensor, _timing = self.read_batch_with_timing(tile_indices) + return tensor + + def read_batch_with_timing(self, tile_indices: np.ndarray) -> tuple[torch.Tensor, dict[str, float]]: + if len(tile_indices) == 0: + return torch.empty( + (0, 3, self._tile_size_px, self._tile_size_px), dtype=torch.uint8 + ), {"reader_open_ms": 0.0, "reader_read_ms": 0.0} + was_closed = self._cu_image is None + open_start = time.perf_counter() + self._ensure_open() + reader_open_ms = (time.perf_counter() - open_start) * 1000.0 if was_closed else 0.0 + read_start = time.perf_counter() + + if not self._use_supertiles: + tensor = self._read_batch_simple(tile_indices) + else: + tensor = self._read_batch_supertiles(tile_indices) + reader_read_ms = (time.perf_counter() - read_start) * 1000.0 + return tensor, {"reader_open_ms": reader_open_ms, "reader_read_ms": reader_read_ms} + + def _read_batch_simple(self, tile_indices: np.ndarray) -> torch.Tensor: + locations = [(int(self._x[i]), int(self._y[i])) for i in tile_indices] + regions = self._read_region(locations, (self._tile_size_px, self._tile_size_px)) + batch = np.stack([np.asarray(r)[:, :, :3] for r in regions], axis=0) + return torch.from_numpy(batch).permute(0, 3, 1, 2).contiguous() + + def _read_batch_supertiles(self, tile_indices: np.ndarray) -> torch.Tensor: + ts = self._tile_size_px + batch = np.empty((len(tile_indices), ts, ts, 3), dtype=np.uint8) + + # Group requested tiles by super tile, then by read_size for batched reads. + st_to_batch_positions: dict[int, list[int]] = defaultdict(list) + for batch_pos, tile_idx in enumerate(tile_indices): + st_id = int(self._tile_to_st[tile_idx]) + st_to_batch_positions[st_id].append(batch_pos) + + by_read_size: dict[int, list[int]] = defaultdict(list) + for st_id in st_to_batch_positions: + rs = self._supertiles[st_id].read_size_px + by_read_size[rs].append(st_id) + + for read_size, st_ids in by_read_size.items(): + locations = [ + (self._supertiles[st_id].x_lv0, self._supertiles[st_id].y_lv0) + for st_id in st_ids + ] + regions = self._read_region(locations, (read_size, read_size)) + for st_id, region in zip(st_ids, regions): + region_arr = np.asarray(region)[:, :, :3] + for batch_pos in st_to_batch_positions[st_id]: + tile_idx = int(tile_indices[batch_pos]) + cx = int(self._tile_crop_x[tile_idx]) + cy = int(self._tile_crop_y[tile_idx]) + batch[batch_pos] = region_arr[cy : cy + ts, cx : cx + ts] + + return torch.from_numpy(batch).permute(0, 3, 1, 2).contiguous() + + +class OnTheFlyBatchTileCollator: + """Collator that reads tiles directly from a WSI via cucim. + + Same interface as ``BatchTileCollator``: returns ``(indices_tensor, image_tensor)``. + + When super tiles are enabled (default), tiles are grouped into larger read + regions to reduce the number of WSI reads. Use ``ordered_indices`` to + reorder the dataset so that tiles within the same super tile are batched + together by the DataLoader. + """ + + def __init__( + self, + *, + image_path: Path, + tiling_result: TilingResult, + num_cucim_workers: int = 4, + gpu_decode: bool = False, + use_supertiles: bool = True, + ): + self.tile_size = int(tiling_result.read_tile_size_px) + self._reader = CuCIMTileReader( + image_path, + tiling_result, + num_cucim_workers=num_cucim_workers, + gpu_decode=gpu_decode, + use_supertiles=use_supertiles, + ) + + @property + def ordered_indices(self) -> np.ndarray | None: + """Tile indices reordered so tiles in the same super tile are contiguous.""" + return self._reader.ordered_indices + + def build_batch_sampler( + self, + batch_size: int, + dataset_indices: np.ndarray, + ) -> SuperTileBatchSampler | None: + """Build a batch sampler that never splits super tiles across batches. + + ``dataset_indices`` are the tile indices that will be in the dataset + (after any DDP filtering). The sampler groups consecutive dataset + positions that belong to the same super tile. + + Returns None when super tiles are disabled. + """ + if self._reader._tile_to_st is None: + return None + tile_to_st = self._reader._tile_to_st + groups: list[np.ndarray] = [] + current_st = -1 + start = 0 + for pos, tile_idx in enumerate(dataset_indices): + st_id = int(tile_to_st[tile_idx]) + if st_id != current_st: + if pos > start: + groups.append(np.arange(start, pos, dtype=np.int64)) + current_st = st_id + start = pos + if start < len(dataset_indices): + groups.append(np.arange(start, len(dataset_indices), dtype=np.int64)) + return SuperTileBatchSampler(groups, batch_size) + + def __call__(self, batch_indices): + if not batch_indices: + return ( + torch.empty((0,), dtype=torch.long), + torch.empty((0, 3, self.tile_size, self.tile_size), dtype=torch.uint8), + {"worker_batch_ms": 0.0, "reader_open_ms": 0.0, "reader_read_ms": 0.0}, + ) + worker_start = time.perf_counter() + tile_indices = np.asarray(batch_indices, dtype=np.int64) + tensor, timing = self._reader.read_batch_with_timing(tile_indices) + timing["worker_batch_ms"] = (time.perf_counter() - worker_start) * 1000.0 + return torch.as_tensor(tile_indices, dtype=torch.long), tensor, timing diff --git a/slide2vec/data/dataset.py b/slide2vec/data/dataset.py index e0cd731..7d57b2b 100644 --- a/slide2vec/data/dataset.py +++ b/slide2vec/data/dataset.py @@ -1,79 +1,49 @@ +import time from pathlib import Path -from typing import TYPE_CHECKING, Callable +from typing import TYPE_CHECKING import numpy as np import torch -import wholeslidedata as wsd -from PIL import Image -from transformers.image_processing_utils import BaseImageProcessor -from slide2vec.utils.coordinates import coordinate_arrays, coordinate_matrix +from .tile_store import TarTileReader if TYPE_CHECKING: from hs2p import TilingResult -class TileDataset(torch.utils.data.Dataset): - def __init__( - self, - sample_id: str, - wsi_path: Path, - mask_path: Path | None, - tiling_result: "TilingResult", - backend: str, - transforms: BaseImageProcessor | Callable | None = None, - ): - self.sample_id = sample_id - self.path = wsi_path - self.mask_path = mask_path - self.tiling_result = tiling_result - self.target_spacing = float(tiling_result.target_spacing_um) - self.target_tile_size = int(tiling_result.target_tile_size_px) - self.read_spacing = float(tiling_result.read_spacing_um) - self.read_tile_size = int(tiling_result.read_tile_size_px) - self.resize_factor = self.target_spacing / self.read_spacing - self.backend = backend - self.name = sample_id - self.load_coordinates() - self.transforms = transforms - - def load_coordinates(self): - self.x, self.y = coordinate_arrays(self.tiling_result) - self.coordinates = coordinate_matrix(self.tiling_result) - self.scaled_coordinates = self.scale_coordinates() - self.tile_size_lv0 = int(self.tiling_result.tile_size_lv0) - - def scale_coordinates(self): - # coordinates are defined w.r.t. level 0 - # i need to scale them to target_spacing - wsi = wsd.WholeSlideImage(self.path, backend=self.backend) - min_spacing = wsi.spacings[0] - scale = min_spacing / self.target_spacing - # create a [N, 2] array with x and y coordinates - scaled_coordinates = (self.coordinates * scale).astype(int) - return scaled_coordinates +class TileIndexDataset(torch.utils.data.Dataset): + def __init__(self, tile_indices): + self.tile_indices = np.asarray(tile_indices, dtype=np.int64) def __len__(self): - return len(self.x) + return int(self.tile_indices.shape[0]) def __getitem__(self, idx): - wsi = wsd.WholeSlideImage( - self.path, backend=self.backend - ) # cannot be defined in __init__ because of multiprocessing - tile_arr = wsi.get_patch( - self.x[idx], - self.y[idx], - self.read_tile_size, - self.read_tile_size, - spacing=self.read_spacing, - center=False, + return int(self.tile_indices[idx]) + + +class BatchTileCollator: + def __init__( + self, + *, + tar_path: Path, + tiling_result: "TilingResult", + ): + self.tile_size = int(tiling_result.target_tile_size_px) + self._reader = TarTileReader( + tar_path=tar_path, + tile_size_px=self.tile_size, ) - tile = Image.fromarray(tile_arr).convert("RGB") - if self.target_tile_size != self.read_tile_size: - tile = tile.resize((self.target_tile_size, self.target_tile_size)) - if self.transforms: - if isinstance(self.transforms, BaseImageProcessor): # Hugging Face (`transformer`) - tile = self.transforms(tile, return_tensors="pt")["pixel_values"].squeeze(0) - else: # general callable such as torchvision transforms - tile = self.transforms(tile) - return idx, tile + + def __call__(self, batch_indices): + if not batch_indices: + return ( + torch.empty((0,), dtype=torch.long), + torch.empty((0, 3, self.tile_size, self.tile_size), dtype=torch.uint8), + {"worker_batch_ms": 0.0, "reader_open_ms": 0.0, "reader_read_ms": 0.0}, + ) + worker_start = time.perf_counter() + tile_indices = np.asarray(batch_indices, dtype=np.int64) + tensor, timing = self._reader.read_batch_with_timing(tile_indices) + timing["worker_batch_ms"] = (time.perf_counter() - worker_start) * 1000.0 + return torch.as_tensor(tile_indices, dtype=torch.long), tensor, timing diff --git a/slide2vec/data/tile_store.py b/slide2vec/data/tile_store.py new file mode 100644 index 0000000..5c81180 --- /dev/null +++ b/slide2vec/data/tile_store.py @@ -0,0 +1,55 @@ +import io +import tarfile +import time +from pathlib import Path + +import numpy as np +import torch +from PIL import Image + + +class TarTileReader: + """Read pre-extracted JPEG tiles from a tar archive. + + Reads pre-extracted JPEG tiles by index and returns them as a + ``[B, 3, H, W]`` uint8 tensor, used by ``BatchTileCollator``. + """ + + def __init__(self, tar_path: Path, tile_size_px: int): + self.tar_path = Path(tar_path) + self.tile_size_px = tile_size_px + self._tar_file: tarfile.TarFile | None = None + self._members: list[tarfile.TarInfo] | None = None + + def _ensure_open(self): + if self._tar_file is None: + self._tar_file = tarfile.open(self.tar_path, "r") + self._members = sorted(self._tar_file.getmembers(), key=lambda m: m.name) + + def read_batch(self, tile_indices: np.ndarray) -> torch.Tensor: + tensor, _timing = self.read_batch_with_timing(tile_indices) + return tensor + + def read_batch_with_timing(self, tile_indices: np.ndarray) -> tuple[torch.Tensor, dict[str, float]]: + if len(tile_indices) == 0: + return torch.empty( + (0, 3, self.tile_size_px, self.tile_size_px), dtype=torch.uint8 + ), {"reader_open_ms": 0.0, "reader_read_ms": 0.0} + was_closed = self._tar_file is None + open_start = time.perf_counter() + self._ensure_open() + reader_open_ms = (time.perf_counter() - open_start) * 1000.0 if was_closed else 0.0 + read_start = time.perf_counter() + batch = np.empty( + (len(tile_indices), self.tile_size_px, self.tile_size_px, 3), + dtype=np.uint8, + ) + for i, idx in enumerate(tile_indices): + f = self._tar_file.extractfile(self._members[idx]) + img = Image.open(io.BytesIO(f.read())).convert("RGB") + batch[i] = np.asarray(img) + reader_read_ms = (time.perf_counter() - read_start) * 1000.0 + return torch.from_numpy(batch).permute(0, 3, 1, 2).contiguous(), { + "reader_open_ms": reader_open_ms, + "reader_read_ms": reader_read_ms, + } diff --git a/slide2vec/data/wsd_tile_reader.py b/slide2vec/data/wsd_tile_reader.py new file mode 100644 index 0000000..4288470 --- /dev/null +++ b/slide2vec/data/wsd_tile_reader.py @@ -0,0 +1,173 @@ +from __future__ import annotations + +import time +from typing import TYPE_CHECKING + +import numpy as np +import torch + +if TYPE_CHECKING: + from pathlib import Path + + from hs2p import TilingResult + + +class WSDTileReader: + """Read tiles from a WSI via wholeslidedata (ASAP/OpenSlide backend). + + Supports two reading modes: + - ``use_supertiles=False``: one ``get_patch`` call per tile (baseline). + - ``use_supertiles=True``: one ``get_patch`` call per super tile block + (8×8/4×4/2×2), then individual tiles are cropped from the region. + + Lazy WSI open via ``_ensure_open()`` — DataLoader workers fork, so the WSI + handle must not be created in ``__init__``. + """ + + def __init__( + self, + image_path: "Path", + tiling_result: "TilingResult", + *, + backend: str = "asap", + use_supertiles: bool = False, + ): + self._image_path = str(image_path) + self._x = tiling_result.x + self._y = tiling_result.y + self._read_spacing_um = float(tiling_result.read_spacing_um) + self._tile_size_px = int(tiling_result.read_tile_size_px) + self._backend = backend + self._wsi = None + + self._use_supertiles = use_supertiles + if use_supertiles: + from slide2vec.data.cucim_tile_reader import _build_supertile_index + + ( + self._supertiles, + self._tile_to_st, + self._tile_crop_x, + self._tile_crop_y, + self.ordered_indices, + ) = _build_supertile_index(tiling_result) + else: + self._supertiles = None + self._tile_to_st = None + self.ordered_indices = None + + def _ensure_open(self) -> None: + if self._wsi is None: + import wholeslidedata as wsd + from hs2p.wsi.backend import coerce_wsd_path + + self._wsi = wsd.WholeSlideImage( + coerce_wsd_path(self._image_path, backend=self._backend), + backend=self._backend, + ) + + def read_batch(self, tile_indices: np.ndarray) -> torch.Tensor: + tensor, _timing = self.read_batch_with_timing(tile_indices) + return tensor + + def read_batch_with_timing(self, tile_indices: np.ndarray) -> tuple[torch.Tensor, dict[str, float]]: + if len(tile_indices) == 0: + ts = self._tile_size_px + return torch.empty((0, 3, ts, ts), dtype=torch.uint8), {"reader_open_ms": 0.0, "reader_read_ms": 0.0} + was_closed = self._wsi is None + open_start = time.perf_counter() + self._ensure_open() + reader_open_ms = (time.perf_counter() - open_start) * 1000.0 if was_closed else 0.0 + read_start = time.perf_counter() + if self._use_supertiles: + tensor = self._read_batch_supertiles(tile_indices) + else: + tensor = self._read_batch_simple(tile_indices) + reader_read_ms = (time.perf_counter() - read_start) * 1000.0 + return tensor, {"reader_open_ms": reader_open_ms, "reader_read_ms": reader_read_ms} + + def _read_batch_simple(self, tile_indices: np.ndarray) -> torch.Tensor: + ts = self._tile_size_px + tiles = [] + for i in tile_indices: + region = self._wsi.get_patch( + int(self._x[i]), + int(self._y[i]), + ts, + ts, + spacing=self._read_spacing_um, + center=False, + ) + tiles.append(np.asarray(region)[:, :, :3]) + batch = np.stack(tiles, axis=0) + return torch.from_numpy(batch).permute(0, 3, 1, 2).contiguous() + + def _read_batch_supertiles(self, tile_indices: np.ndarray) -> torch.Tensor: + ts = self._tile_size_px + batch = np.empty((len(tile_indices), ts, ts, 3), dtype=np.uint8) + + st_to_batch_positions: dict[int, list[int]] = {} + for batch_pos, tile_idx in enumerate(tile_indices): + st_id = int(self._tile_to_st[tile_idx]) + st_to_batch_positions.setdefault(st_id, []).append(batch_pos) + + for st_id, batch_positions in st_to_batch_positions.items(): + st = self._supertiles[st_id] + region = self._wsi.get_patch( + st.x_lv0, + st.y_lv0, + st.read_size_px, + st.read_size_px, + spacing=self._read_spacing_um, + center=False, + ) + region_arr = np.asarray(region)[:, :, :3] + for batch_pos in batch_positions: + tile_idx = int(tile_indices[batch_pos]) + cx = int(self._tile_crop_x[tile_idx]) + cy = int(self._tile_crop_y[tile_idx]) + batch[batch_pos] = region_arr[cy : cy + ts, cx : cx + ts] + + return torch.from_numpy(batch).permute(0, 3, 1, 2).contiguous() + + +class WSDOnTheFlyBatchTileCollator: + """Collator that reads individual tiles from a WSI via wholeslidedata. + + Same interface as ``OnTheFlyBatchTileCollator``: returns + ``(indices_tensor, image_tensor)`` where ``image_tensor`` is + ``(B, 3, read_tile_size_px, read_tile_size_px)`` uint8. + + When ``use_supertiles=False`` (default), each tile triggers a separate + ``get_patch`` call — the baseline for benchmarking against cucim. + When ``use_supertiles=True``, tiles are grouped into 8×8/4×4/2×2 blocks + and each block is read as one larger region that is then cropped. + """ + + def __init__( + self, + *, + image_path: "Path", + tiling_result: "TilingResult", + backend: str = "asap", + use_supertiles: bool = False, + ): + self.tile_size = int(tiling_result.read_tile_size_px) + self._reader = WSDTileReader(image_path, tiling_result, backend=backend, use_supertiles=use_supertiles) + + @property + def ordered_indices(self) -> np.ndarray | None: + return self._reader.ordered_indices + + def __call__(self, batch_indices): + if not batch_indices: + return ( + torch.empty((0,), dtype=torch.long), + torch.empty((0, 3, self.tile_size, self.tile_size), dtype=torch.uint8), + {"worker_batch_ms": 0.0, "reader_open_ms": 0.0, "reader_read_ms": 0.0}, + ) + worker_start = time.perf_counter() + tile_indices = np.asarray(batch_indices, dtype=np.int64) + tensor, timing = self._reader.read_batch_with_timing(tile_indices) + timing["worker_batch_ms"] = (time.perf_counter() - worker_start) * 1000.0 + return torch.as_tensor(tile_indices, dtype=torch.long), tensor, timing diff --git a/slide2vec/distributed/direct_embed_worker.py b/slide2vec/distributed/direct_embed_worker.py index 3a3e1f2..163f722 100644 --- a/slide2vec/distributed/direct_embed_worker.py +++ b/slide2vec/distributed/direct_embed_worker.py @@ -55,7 +55,15 @@ def main(argv=None) -> int: for slide, tiling_result in zip(slide_records, tiling_results) } progress_events_path = request.get("progress_events_path") - reporter = JsonlProgressReporter(progress_events_path, rank=global_rank) if progress_events_path else None + reporter = ( + JsonlProgressReporter( + progress_events_path, + rank=global_rank, + progress_label=f"cuda:{local_rank}", + ) + if progress_events_path + else None + ) context = activate_progress_reporter(reporter) if reporter is not None else nullcontext() with context: diff --git a/slide2vec/distributed/pipeline_worker.py b/slide2vec/distributed/pipeline_worker.py index 1525dc9..3d50fef 100644 --- a/slide2vec/distributed/pipeline_worker.py +++ b/slide2vec/distributed/pipeline_worker.py @@ -52,7 +52,15 @@ def main(argv=None) -> int: assigned_slides = [slide for slide, _ in assigned_pairs] assigned_tiling_results = [tiling_result for _, tiling_result in assigned_pairs] progress_events_path = request.get("progress_events_path") - reporter = JsonlProgressReporter(progress_events_path, rank=global_rank) if progress_events_path else None + reporter = ( + JsonlProgressReporter( + progress_events_path, + rank=global_rank, + progress_label=f"cuda:{local_rank}", + ) + if progress_events_path + else None + ) context = activate_progress_reporter(reporter) if reporter is not None else nullcontext() with context: embedded_slides = _compute_embedded_slides( diff --git a/slide2vec/inference.py b/slide2vec/inference.py index f81632a..8ce7c9c 100644 --- a/slide2vec/inference.py +++ b/slide2vec/inference.py @@ -1,4 +1,6 @@ import json +import os +import re import shutil import subprocess import sys @@ -12,6 +14,7 @@ from typing import TYPE_CHECKING, Any, Callable, Sequence import numpy as np +from transformers.image_processing_utils import BaseImageProcessor from slide2vec.api import ( EmbeddedSlide, @@ -27,6 +30,7 @@ write_slide_embeddings, write_tile_embeddings, ) +from slide2vec.model_settings import canonicalize_model_name from slide2vec.progress import ( emit_progress, emit_progress_event, @@ -50,6 +54,29 @@ class LoadedModel: feature_dim: int device: Any + +@dataclass(frozen=True) +class BatchTransformSpec: + resize_size: tuple[int, int] | None + center_crop_size: tuple[int, int] | None + mean: tuple[float, ...] | None + std: tuple[float, ...] | None + region_unfold_tile_size: int | None = None + resize_interpolation: str = "bilinear" + + +@dataclass +class PreparedBatch: + indices: Any + image: Any + loader_wait_ms: float + preprocess_ms: float + ready_wait_ms: float = 0.0 + worker_batch_ms: float = 0.0 + reader_open_ms: float = 0.0 + reader_read_ms: float = 0.0 + + def _slide_spec_cls(): try: from hs2p import SlideSpec @@ -69,6 +96,33 @@ def _optional_float(value: Any) -> float | None: return float(value) +def _slurm_cpu_limit() -> int | None: + for env_name in ("SLURM_CPUS_PER_TASK", "SLURM_JOB_CPUS_PER_NODE"): + value = os.environ.get(env_name) + if not value: + continue + match = re.match(r"\s*(\d+)", value) + if match is None: + continue + limit = int(match.group(1)) + if limit > 0: + return limit + return None + + +def _resolve_on_the_fly_num_workers(num_cucim_workers: int) -> tuple[int, str]: + cpu_count = os.cpu_count() or 4 + worker_budget = cpu_count + details = [f"cpu_count={cpu_count}"] + slurm_limit = _slurm_cpu_limit() + if slurm_limit is not None: + worker_budget = min(worker_budget, slurm_limit) + details.append(f"slurm_cpu_limit={slurm_limit}") + effective_num_workers = max(1, worker_budget // num_cucim_workers) + details.append(f"num_cucim_workers={num_cucim_workers}") + return effective_num_workers, " // ".join(details) + + def _make_slide_spec( *, sample_id: str, @@ -103,10 +157,14 @@ def load_model( from slide2vec.models import ModelFactory from slide2vec.resources import load_config - model_cfg = OmegaConf.create(load_config("models", "default")) + name = canonicalize_model_name(name) + cfg = OmegaConf.merge( + OmegaConf.create(load_config("preprocessing", "default")), + OmegaConf.create(load_config("models", "default")), + ) preset_name = _preset_name(name, level) if preset_name is not None: - model_cfg = OmegaConf.merge(model_cfg, load_config("models", preset_name)) + cfg = OmegaConf.merge(cfg, load_config("models", preset_name)) overrides = { "name": name, @@ -121,7 +179,10 @@ def load_model( } for key, value in overrides.items(): if value is not None: - model_cfg[key] = value + cfg.model[key] = value + + OmegaConf.resolve(cfg) + model_cfg = cfg.model backend_model = ModelFactory(model_cfg).get_model() target_device = _resolve_device(device, backend_model.device) @@ -164,7 +225,7 @@ def embed_slides( slide_records, preprocessing, output_dir=work_dir, - num_workers=execution.num_workers, + num_workers=execution.num_preprocessing_workers, ) _emit_tiling_finished( process_list_path, @@ -271,39 +332,15 @@ def embed_tiles( loaded = model._load_backend() slide_records = [_coerce_slide_spec(slide) for slide in slides] resolved_tiling_results = _normalize_tiling_results(tiling_results, slide_records) - torch = _import_torch() - from slide2vec.data import TileDataset - - autocast_context = ( - torch.autocast(device_type="cuda", dtype=torch.float16) - if execution.mixed_precision and str(loaded.device).startswith("cuda") - else nullcontext() - ) artifacts: list[TileEmbeddingArtifact] = [] for slide, tiling_result in zip(slide_records, resolved_tiling_results): - transforms = _create_transforms(loaded) - dataset = TileDataset( - sample_id=slide.sample_id, - wsi_path=slide.image_path, - mask_path=slide.mask_path, - tiling_result=tiling_result, - backend=_resolve_backend(preprocessing), - transforms=transforms if model.level != "region" else _create_region_transforms(transforms, loaded.model), - ) - dataloader = torch.utils.data.DataLoader( - dataset, - batch_size=execution.batch_size, - shuffle=False, - num_workers=execution.num_workers, - pin_memory=str(loaded.device).startswith("cuda"), - ) - features = _run_forward_pass( - dataloader, + features = _compute_tile_embeddings_for_slide( loaded, - autocast_context, - sample_id=slide.sample_id, - total_items=len(dataset), - unit_label="tile", + model, + slide, + tiling_result, + preprocessing=preprocessing or PreprocessingConfig(), + execution=execution, ) metadata = _build_tile_embedding_metadata( model, @@ -311,7 +348,7 @@ def embed_tiles( image_path=slide.image_path, mask_path=slide.mask_path, tile_size_lv0=int(_require_attr(tiling_result, "tile_size_lv0")), - backend=_resolve_backend(preprocessing), + backend=_resolve_slide_backend(preprocessing, tiling_result), ) artifact = _write_tile_embedding_artifact( slide.sample_id, @@ -338,13 +375,13 @@ def aggregate_tiles( outputs: list[SlideEmbeddingArtifact] = [] for artifact in tile_artifacts: metadata = artifact.metadata - if not metadata.get("tiles_npz_path") or not metadata.get("tiles_meta_path"): + if not metadata.get("coordinates_npz_path") or not metadata.get("coordinates_meta_path"): raise ValueError( f"Tile artifact for {artifact.sample_id} is missing tiling metadata paths required for slide aggregation" ) tiling_result = _load_tiling_result( - Path(metadata["tiles_npz_path"]), - Path(metadata["tiles_meta_path"]), + Path(metadata["coordinates_npz_path"]), + Path(metadata["coordinates_meta_path"]), ) coordinates = _coordinate_matrix(tiling_result) image_path = Path(metadata["image_path"]) @@ -353,7 +390,7 @@ def aggregate_tiles( image_path, coordinates, float(_require_attr(tiling_result, "target_spacing_um")), - metadata.get("backend", _resolve_backend(preprocessing)), + metadata.get("backend", _resolve_slide_backend(preprocessing, tiling_result)), ) coordinate_tensor = torch.tensor(coordinates, dtype=torch.int, device=loaded.device) tile_features = load_array(artifact.path) @@ -412,7 +449,7 @@ def run_pipeline( slide_records, preprocessing, output_dir=output_dir, - num_workers=execution.num_workers, + num_workers=execution.num_preprocessing_workers, ) _emit_tiling_finished( process_list_path, @@ -800,38 +837,109 @@ def _compute_tile_embeddings_for_slide( tile_indices=None, ): torch = _import_torch() - from slide2vec.data import TileDataset + from slide2vec.data.dataset import BatchTileCollator, TileIndexDataset + autocast_dtype = _autocast_dtype(torch, execution.precision) autocast_context = ( - torch.autocast(device_type="cuda", dtype=torch.float16) - if execution.mixed_precision and str(loaded.device).startswith("cuda") + torch.autocast(device_type="cuda", dtype=autocast_dtype) + if autocast_dtype is not None and str(loaded.device).startswith("cuda") else nullcontext() ) - transforms = _create_transforms(loaded) - dataset = TileDataset( - sample_id=slide.sample_id, - wsi_path=slide.image_path, - mask_path=slide.mask_path, - tiling_result=tiling_result, - backend=_resolve_backend(preprocessing), - transforms=transforms if model.level != "region" else _create_region_transforms(transforms, loaded.model), - ) + resolved_indices = np.arange(_num_tiles(tiling_result), dtype=np.int64) if tile_indices is not None: - tile_indices = np.asarray(tile_indices, dtype=np.int64) - if tile_indices.size == 0: + resolved_indices = np.asarray(tile_indices, dtype=np.int64) + if resolved_indices.size == 0: return torch.empty((0, int(loaded.feature_dim)), dtype=torch.float32) - dataset = torch.utils.data.Subset(dataset, tile_indices.tolist()) + if preprocessing.on_the_fly and preprocessing.read_tiles_from is None: + resolved_backend = _resolve_slide_backend(preprocessing, tiling_result) + if resolved_backend == "cucim": + from slide2vec.data.cucim_tile_reader import OnTheFlyBatchTileCollator + + collate_fn = OnTheFlyBatchTileCollator( + image_path=slide.image_path, + tiling_result=tiling_result, + num_cucim_workers=preprocessing.num_cucim_workers, + gpu_decode=preprocessing.gpu_decode, + use_supertiles=preprocessing.use_supertiles, + ) + else: + from slide2vec.data.wsd_tile_reader import WSDOnTheFlyBatchTileCollator + + collate_fn = WSDOnTheFlyBatchTileCollator( + image_path=slide.image_path, + tiling_result=tiling_result, + backend=resolved_backend, + use_supertiles=preprocessing.use_supertiles, + ) + if collate_fn.ordered_indices is not None: + reorder = collate_fn.ordered_indices + if tile_indices is not None: + mask = np.isin(reorder, resolved_indices) + resolved_indices = reorder[mask] + else: + resolved_indices = reorder + if preprocessing.adaptive_batching: + batch_sampler = collate_fn.build_batch_sampler(execution.batch_size, resolved_indices) + else: + batch_sampler = None + else: + batch_sampler = None + if preprocessing.on_the_fly and preprocessing.read_tiles_from is not None: + import logging + logging.getLogger(__name__).warning( + "read_tiles_from is set; ignoring on_the_fly=True and reading tiles from tar archives" + ) + tar_path = _resolve_tile_store_archive_for_slide( + slide=slide, + tiling_result=tiling_result, + preprocessing=preprocessing, + ) + if tar_path is None: + raise ValueError( + f"Slide {slide.sample_id} is missing tiles_tar_path — " + "pre-extracted tile archives are required for embedding" + ) + collate_fn = BatchTileCollator( + tar_path=tar_path, + tiling_result=tiling_result, + ) + dataset = TileIndexDataset(resolved_indices) + batch_preprocessor = _build_batch_preprocessor( + loaded, + model, + tiling_result, + execution=execution, + ) + loader_kwargs = _embedding_dataloader_kwargs(loaded, execution) + if preprocessing.on_the_fly and preprocessing.read_tiles_from is None: + import logging + + effective_num_workers, worker_context = _resolve_on_the_fly_num_workers(preprocessing.num_cucim_workers) + if effective_num_workers != execution.num_workers: + logging.getLogger(__name__).info( + f"on-the-fly mode: setting DataLoader num_workers={effective_num_workers} " + f"({worker_context}); " + f"ignoring speed.num_workers={execution.num_workers}" + ) + loader_kwargs["num_workers"] = effective_num_workers + if effective_num_workers == 0: + loader_kwargs.pop("persistent_workers", None) + loader_kwargs.pop("prefetch_factor", None) + if batch_sampler is not None: + loader_kwargs["batch_sampler"] = batch_sampler + else: + loader_kwargs["batch_size"] = execution.batch_size + loader_kwargs["shuffle"] = False dataloader = torch.utils.data.DataLoader( dataset, - batch_size=execution.batch_size, - shuffle=False, - num_workers=execution.num_workers, - pin_memory=str(loaded.device).startswith("cuda"), + collate_fn=collate_fn, + **loader_kwargs, ) return _run_forward_pass( dataloader, loaded, autocast_context, + batch_preprocessor=batch_preprocessor, sample_id=slide.sample_id, total_items=len(dataset), unit_label="region" if model.level == "region" else "tile", @@ -857,7 +965,7 @@ def _aggregate_tile_embeddings_for_slide( slide.image_path, coordinates, float(_require_attr(tiling_result, "target_spacing_um")), - _resolve_backend(preprocessing), + _resolve_slide_backend(preprocessing, tiling_result), ) coordinate_tensor = torch.tensor(coordinates, dtype=torch.int, device=loaded.device) if not torch.is_tensor(tile_embeddings): @@ -923,7 +1031,7 @@ def _persist_embedded_slide( image_path=embedded_slide.image_path, mask_path=embedded_slide.mask_path, tile_size_lv0=embedded_slide.tile_size_lv0, - backend=_resolve_backend(preprocessing), + backend=_resolve_slide_backend(preprocessing, tiling_result), ), ) slide_artifact = None @@ -950,8 +1058,9 @@ def _build_tile_embedding_metadata( return { "encoder_name": model.name, "encoder_level": model.level, - "tiles_npz_path": str(_require_attr(tiling_result, "tiles_npz_path", allow_missing=True) or ""), - "tiles_meta_path": str(_require_attr(tiling_result, "tiles_meta_path", allow_missing=True) or ""), + "coordinates_npz_path": str(_require_attr(tiling_result, "coordinates_npz_path", allow_missing=True) or ""), + "coordinates_meta_path": str(_require_attr(tiling_result, "coordinates_meta_path", allow_missing=True) or ""), + "tiles_tar_path": str(_require_attr(tiling_result, "tiles_tar_path", allow_missing=True) or ""), "image_path": str(image_path), "mask_path": str(mask_path) if mask_path is not None else None, "tile_size_lv0": int(tile_size_lv0), @@ -1023,11 +1132,385 @@ def _create_region_transforms(base_transforms, backend_model): ] ) + +def _embedding_dataloader_kwargs(loaded: LoadedModel, execution: ExecutionOptions) -> dict[str, Any]: + kwargs: dict[str, Any] = { + "num_workers": execution.num_workers, + "pin_memory": str(loaded.device).startswith("cuda"), + } + if execution.num_workers > 0: + kwargs["persistent_workers"] = bool(execution.persistent_workers) + kwargs["prefetch_factor"] = int(execution.prefetch_factor) + return kwargs + + + +def _build_batch_preprocessor( + loaded: LoadedModel, + model, + tiling_result, + *, + execution: ExecutionOptions, +): + torch = _import_torch() + spec = _build_batch_transform_spec(loaded.transforms) + if spec is None: + raise ValueError("Batched preprocessing is only available for supported deterministic transform stacks") + + preprocess_device = loaded.device if execution.gpu_batch_preprocessing else torch.device("cpu") + + def preprocess(batch): + image = batch + image = _prepare_batch_tensor(image, preprocess_device=preprocess_device) + if spec.resize_size is None: + # Model has no Resize transform: apply bilinear resize to target tile size as fallback + image = _resize_image_batch( + image, + (int(tiling_result.target_tile_size_px), int(tiling_result.target_tile_size_px)), + ) + if model.level == "region": + image = _apply_region_batch_transform_spec( + image, + spec, + tile_size=int(getattr(loaded.model, "tile_size")), + ) + else: + image = _apply_batch_transform_spec(image, spec) + if image.device != loaded.device: + image = image.to(loaded.device, non_blocking=str(loaded.device).startswith("cuda")) + return image.contiguous() + + return preprocess + + +def _build_batch_transform_spec(transforms) -> BatchTransformSpec | None: + if isinstance(transforms, BaseImageProcessor): + resize_size = _normalize_hw( + getattr(transforms, "crop_size", None) or getattr(transforms, "size", None) + ) + if resize_size is None: + return None + mean = getattr(transforms, "image_mean", None) + std = getattr(transforms, "image_std", None) + return BatchTransformSpec( + resize_size=resize_size, + center_crop_size=None, + mean=tuple(float(value) for value in mean) if mean is not None else None, + std=tuple(float(value) for value in std) if std is not None else None, + region_unfold_tile_size=None, + ) + + transform_steps = _iter_transform_steps(transforms) + if transform_steps is None: + return None + + resize_size = None + resize_interpolation = "bilinear" + center_crop_size = None + mean = None + std = None + region_unfold_tile_size = None + supported_step_names = { + "Resize", + "CenterCrop", + "Normalize", + "ToTensor", + "MaybeToTensor", + "ToImage", + "ConvertImageDtype", + } + for step in transform_steps: + if _is_region_unfolding_transform(step): + step_tile_size = int(getattr(step, "tile_size")) + if region_unfold_tile_size is not None and region_unfold_tile_size != step_tile_size: + return None + region_unfold_tile_size = step_tile_size + continue + step_name = type(step).__name__ + if step_name not in supported_step_names: + return None + if step_name == "Resize": + resize_size = _normalize_hw(getattr(step, "size", None)) + resize_interpolation = _interp_mode_to_str(getattr(step, "interpolation", None)) + elif step_name == "CenterCrop": + center_crop_size = _normalize_hw(getattr(step, "size", None)) + elif step_name == "Normalize": + mean = tuple(float(value) for value in getattr(step, "mean")) + std = tuple(float(value) for value in getattr(step, "std")) + return BatchTransformSpec( + resize_size=resize_size, + center_crop_size=center_crop_size, + mean=mean, + std=std, + region_unfold_tile_size=region_unfold_tile_size, + resize_interpolation=resize_interpolation, + ) + + +def _iter_transform_steps(transforms): + transform_steps = getattr(transforms, "transforms", None) + if transform_steps is None: + return None + flattened = [] + for step in transform_steps: + nested = _iter_transform_steps(step) + if nested is not None: + flattened.extend(nested) + else: + flattened.append(step) + return flattened + + +def _is_region_unfolding_transform(step) -> bool: + return type(step).__name__ == "RegionUnfolding" and hasattr(step, "tile_size") + + +def _prepare_batch_tensor(image, *, preprocess_device): + torch = _import_torch() + if image.device != preprocess_device: + image = image.to(preprocess_device, non_blocking=str(preprocess_device).startswith("cuda")) + if image.dtype == torch.uint8: + return image.float().div(255.0) + return image.float() + + +def _interp_mode_to_str(interp_mode) -> str: + """Map a torchvision InterpolationMode to the string accepted by F.interpolate.""" + if interp_mode is None: + return "bilinear" + name = str(interp_mode).upper() + if "BICUBIC" in name: + return "bicubic" + if "NEAREST" in name: + return "nearest" + return "bilinear" + + +def _resize_image_batch(image, size: tuple[int, int], *, mode: str = "bilinear"): + if tuple(int(dim) for dim in image.shape[-2:]) == size: + return image + torch = _import_torch() + align_corners = False if mode in ("bilinear", "bicubic") else None + kwargs = {"antialias": True} if mode in ("bilinear", "bicubic") else {} + return torch.nn.functional.interpolate( + image, + size=size, + mode=mode, + **({"align_corners": align_corners} if align_corners is not None else {}), + **kwargs, + ) + + +def _apply_batch_transform_spec(image, spec: BatchTransformSpec): + torch = _import_torch() + if spec.resize_size is not None: + image = _resize_image_batch(image, spec.resize_size, mode=spec.resize_interpolation) + if spec.center_crop_size is not None: + image = _center_crop_batch(image, spec.center_crop_size) + if spec.mean is not None and spec.std is not None: + mean = torch.tensor(spec.mean, dtype=image.dtype, device=image.device).view(1, -1, 1, 1) + std = torch.tensor(spec.std, dtype=image.dtype, device=image.device).view(1, -1, 1, 1) + image = (image - mean) / std + return image + + +def _apply_region_batch_transform_spec(image, spec: BatchTransformSpec, *, tile_size: int): + if spec.region_unfold_tile_size is not None and spec.region_unfold_tile_size != tile_size: + raise ValueError( + "Region transform stack RegionUnfolding tile_size does not match the region model tile_size" + ) + region_tile_size = spec.region_unfold_tile_size or tile_size + batch_size = int(image.shape[0]) + unfolded = _unfold_region_batch(image, region_tile_size) + num_tiles = int(unfolded.shape[1]) + flattened = unfolded.reshape(batch_size * num_tiles, *unfolded.shape[-3:]) + transformed = _apply_batch_transform_spec(flattened, spec) + return transformed.reshape(batch_size, num_tiles, *transformed.shape[-3:]) + + +def _unfold_region_batch(image, tile_size: int): + torch = _import_torch() + height, width = (int(image.shape[-2]), int(image.shape[-1])) + if height % tile_size != 0 or width % tile_size != 0: + raise ValueError( + f"Region batch with shape {height}x{width} is not divisible by tile_size={tile_size}" + ) + unfolded = torch.nn.functional.unfold(image, kernel_size=tile_size, stride=tile_size) + unfolded = unfolded.transpose(1, 2) + return unfolded.reshape(image.shape[0], -1, image.shape[1], tile_size, tile_size) + + +def _normalize_hw(value) -> tuple[int, int] | None: + if value is None: + return None + if isinstance(value, int): + return (int(value), int(value)) + if isinstance(value, (tuple, list)): + if len(value) == 1: + return (int(value[0]), int(value[0])) + if len(value) >= 2: + return (int(value[0]), int(value[1])) + return None + if isinstance(value, dict): + if "height" in value and "width" in value: + return (int(value["height"]), int(value["width"])) + if "shortest_edge" in value: + edge = int(value["shortest_edge"]) + return (edge, edge) + return None + + +def _center_crop_batch(image, size: tuple[int, int]): + target_h, target_w = size + height, width = int(image.shape[-2]), int(image.shape[-1]) + crop_h = min(target_h, height) + crop_w = min(target_w, width) + top = max((height - crop_h) // 2, 0) + left = max((width - crop_w) // 2, 0) + return image[..., top : top + crop_h, left : left + crop_w] + + +@contextmanager +def _maybe_nvtx_range(torch, label: str): + nvtx = getattr(getattr(torch, "cuda", None), "nvtx", None) + if nvtx is None: + yield + return + pushed = False + try: + nvtx.range_push(label) + pushed = True + except Exception: + yield + return + try: + yield + finally: + if pushed: + try: + nvtx.range_pop() + except Exception: + return + + +class _BatchPrefetcher: + def __init__(self, dataloader, loaded: LoadedModel, batch_preprocessor): + self.torch = _import_torch() + self.iterator = iter(dataloader) + self.loaded = loaded + self.batch_preprocessor = batch_preprocessor + self.copy_stream = self._make_copy_stream() + self._pinned_host_buffer = None + self._next_batch: PreparedBatch | None = None + self._preload() + + def _unpack_loader_batch(self, batch): + if isinstance(batch, (tuple, list)): + if len(batch) == 3 and isinstance(batch[2], dict): + return batch[0], batch[1], batch[2] + if len(batch) == 2: + return batch[0], batch[1], {} + raise ValueError("Expected the embedding dataloader to yield (indices, image) or (indices, image, timing)") + + def _make_copy_stream(self): + if not str(self.loaded.device).startswith("cuda"): + return None + return self.torch.cuda.Stream(device=self.loaded.device) + + def _stage_host_batch(self, image): + if self.copy_stream is None or not self.torch.is_tensor(image): + return image + if image.device.type != "cpu" or image.is_pinned(): + return image + if ( + self._pinned_host_buffer is None + or tuple(self._pinned_host_buffer.shape) != tuple(image.shape) + or self._pinned_host_buffer.dtype != image.dtype + ): + self._pinned_host_buffer = self.torch.empty( + image.shape, + dtype=image.dtype, + pin_memory=True, + ) + self._pinned_host_buffer.copy_(image) + return self._pinned_host_buffer + + def _prepare_batch(self, image): + preprocess_start = time.perf_counter() + if self.batch_preprocessor is not None: + with _maybe_nvtx_range(self.torch, "slide2vec.batch_preprocess"): + prepared = self.batch_preprocessor(image) + else: + with _maybe_nvtx_range(self.torch, "slide2vec.batch_h2d"): + prepared = image.to( + self.loaded.device, + non_blocking=str(self.loaded.device).startswith("cuda"), + ) + preprocess_ms = (time.perf_counter() - preprocess_start) * 1000.0 + return prepared, preprocess_ms + + def _preload(self) -> None: + wait_start = time.perf_counter() + try: + batch = next(self.iterator) + except StopIteration: + self._next_batch = None + return + loader_wait_ms = (time.perf_counter() - wait_start) * 1000.0 + indices, image, timing = self._unpack_loader_batch(batch) + if self.copy_stream is None: + prepared, preprocess_ms = self._prepare_batch(image) + self._next_batch = PreparedBatch( + indices=indices, + image=prepared, + loader_wait_ms=loader_wait_ms, + preprocess_ms=preprocess_ms, + worker_batch_ms=float(timing.get("worker_batch_ms", 0.0)), + reader_open_ms=float(timing.get("reader_open_ms", 0.0)), + reader_read_ms=float(timing.get("reader_read_ms", 0.0)), + ) + return + + staged = self._stage_host_batch(image) + preprocess_start = time.perf_counter() + with self.torch.cuda.stream(self.copy_stream): + prepared = self.batch_preprocessor(staged) if self.batch_preprocessor is not None else staged.to( + self.loaded.device, + non_blocking=True, + ) + preprocess_ms = (time.perf_counter() - preprocess_start) * 1000.0 + self._next_batch = PreparedBatch( + indices=indices, + image=prepared, + loader_wait_ms=loader_wait_ms, + preprocess_ms=preprocess_ms, + worker_batch_ms=float(timing.get("worker_batch_ms", 0.0)), + reader_open_ms=float(timing.get("reader_open_ms", 0.0)), + reader_read_ms=float(timing.get("reader_read_ms", 0.0)), + ) + + def __iter__(self): + return self + + def __next__(self) -> PreparedBatch: + if self._next_batch is None: + raise StopIteration + current = self._next_batch + if self.copy_stream is not None: + ready_start = time.perf_counter() + current_stream = self.torch.cuda.current_stream(device=self.loaded.device) + current_stream.wait_stream(self.copy_stream) + current.ready_wait_ms = (time.perf_counter() - ready_start) * 1000.0 + self._preload() + return current + + def _run_forward_pass( dataloader, loaded: LoadedModel, autocast_context, *, + batch_preprocessor=None, sample_id: str | None = None, total_items: int | None = None, unit_label: str = "tile", @@ -1035,12 +1518,44 @@ def _run_forward_pass( torch = _import_torch() outputs = [] processed = 0 + batch_index = 0 + prefetcher = _BatchPrefetcher(dataloader, loaded, batch_preprocessor) with torch.inference_mode(), autocast_context: - for _, image in dataloader: - image = image.to(loaded.device, non_blocking=str(loaded.device).startswith("cuda")) - embedding = loaded.model(image)["embedding"].detach().cpu() + for prepared_batch in prefetcher: + image = prepared_batch.image + forward_start = time.perf_counter() + with _maybe_nvtx_range(torch, "slide2vec.batch_forward"): + embedding = loaded.model(image)["embedding"].detach().cpu() + forward_ms = (time.perf_counter() - forward_start) * 1000.0 outputs.append(embedding) processed += int(embedding.shape[0]) + batch_index += 1 + batch_total_ms = ( + prepared_batch.loader_wait_ms + + prepared_batch.ready_wait_ms + + prepared_batch.preprocess_ms + + forward_ms + ) + gpu_busy_fraction = ( + (prepared_batch.ready_wait_ms + prepared_batch.preprocess_ms + forward_ms) / batch_total_ms + if batch_total_ms > 0 + else 0.0 + ) + emit_progress( + "embedding.batch.timing", + sample_id=sample_id, + batch_index=batch_index, + batch_size=int(embedding.shape[0]), + loader_wait_ms=round(prepared_batch.loader_wait_ms, 4), + ready_wait_ms=round(prepared_batch.ready_wait_ms, 4), + preprocess_ms=round(prepared_batch.preprocess_ms, 4), + worker_batch_ms=round(prepared_batch.worker_batch_ms, 4), + reader_open_ms=round(prepared_batch.reader_open_ms, 4), + reader_read_ms=round(prepared_batch.reader_read_ms, 4), + forward_ms=round(forward_ms, 4), + gpu_busy_fraction=round(gpu_busy_fraction, 4), + unit=unit_label, + ) if sample_id is not None: emit_progress( "embedding.tile.progress", @@ -1055,8 +1570,8 @@ def _run_forward_pass( def _preset_name(name: str, level: str) -> str | None: - preset = name - if name == "prov-gigapath": + preset = canonicalize_model_name(name) + if preset == "prov-gigapath": preset = "prov-gigapath-slide" if level == "slide" else "prov-gigapath-tile" candidate = Path(__file__).parent / "configs" / "models" / f"{preset}.yaml" if candidate.is_file(): @@ -1260,10 +1775,16 @@ def _embedding_work_dir(output_dir: Path | None): yield Path(tmp_dir) -def _tile_slides(slides: Sequence[SlideSpec], preprocessing: PreprocessingConfig, *, output_dir: Path, num_workers: int): +def _tile_slides( + slides: Sequence[SlideSpec], + preprocessing: PreprocessingConfig, + *, + output_dir: Path, + num_workers: int, +): from hs2p import tile_slides - tiling_cfg, segmentation_cfg, filtering_cfg, preview_cfg, read_tiles_from, resume = _build_hs2p_configs(preprocessing) + tiling_cfg, segmentation_cfg, filtering_cfg, preview_cfg, read_coordinates_from, resume = _build_hs2p_configs(preprocessing) tile_slides( slides, tiling=tiling_cfg, @@ -1272,8 +1793,10 @@ def _tile_slides(slides: Sequence[SlideSpec], preprocessing: PreprocessingConfig preview=preview_cfg, output_dir=output_dir, num_workers=num_workers, - read_tiles_from=read_tiles_from, + read_coordinates_from=read_coordinates_from, resume=resume, + save_tiles=not preprocessing.on_the_fly and preprocessing.read_tiles_from is None, + jpeg_backend=preprocessing.jpeg_backend, ) @@ -1301,7 +1824,7 @@ def _build_hs2p_configs(preprocessing: PreprocessingConfig): from hs2p import FilterConfig, PreviewConfig, SegmentationConfig, TilingConfig tiling_cfg = TilingConfig( - backend=preprocessing.backend, + backend=_resolve_tiling_backend(preprocessing), target_spacing_um=preprocessing.target_spacing_um, target_tile_size_px=preprocessing.target_tile_size_px, tolerance=preprocessing.tolerance, @@ -1318,11 +1841,31 @@ def _build_hs2p_configs(preprocessing: PreprocessingConfig): segmentation_cfg, filtering_cfg, preview_cfg, - preprocessing.read_tiles_from, + preprocessing.read_coordinates_from, preprocessing.resume, ) +def _resolve_tile_store_archive_for_slide( + *, + slide: SlideSpec, + tiling_result, + preprocessing: PreprocessingConfig, +) -> Path | None: + if preprocessing.read_tiles_from is not None: + return _tile_store_archive_path(preprocessing.read_tiles_from, slide.sample_id) + return getattr(tiling_result, "tiles_tar_path", None) + + +def _tile_store_archive_path(tile_store_root: Path, sample_id: str) -> Path: + root = Path(tile_store_root) + if root.is_file(): + return root + if root.suffix == ".tar" and root.exists(): + return root + return root / f"{sample_id}.tiles.tar" + + def _load_process_df( process_list_path: Path, *, @@ -1344,10 +1887,10 @@ def _load_tiling_result_from_row(row): return load_tiling_result_from_row(row) -def _load_tiling_result(tiles_npz_path: Path, tiles_meta_path: Path): +def _load_tiling_result(coordinates_npz_path: Path, coordinates_meta_path: Path): from hs2p import load_tiling_result - return load_tiling_result(tiles_npz_path=tiles_npz_path, tiles_meta_path=tiles_meta_path) + return load_tiling_result(coordinates_npz_path=coordinates_npz_path, coordinates_meta_path=coordinates_meta_path) def _scale_coordinates(wsi_fp: Path, coordinates: np.ndarray, spacing: float, backend: str): @@ -1373,12 +1916,22 @@ def _maybe_import_torch(): return torch -def _resolve_backend(preprocessing: PreprocessingConfig | None) -> str: +def _resolve_tiling_backend(preprocessing: PreprocessingConfig | None) -> str: if preprocessing is None: return "asap" return preprocessing.backend +def _resolve_slide_backend(preprocessing: PreprocessingConfig | None, tiling_result) -> str: + backend = _resolve_tiling_backend(preprocessing) + if backend != "auto": + return backend + resolved_backend = getattr(tiling_result, "backend", None) + if isinstance(resolved_backend, str) and resolved_backend and resolved_backend != "auto": + return resolved_backend + return "asap" + + def _validate_multi_gpu_execution(model, execution: ExecutionOptions) -> None: if model._requested_device == "cpu": raise ValueError("ExecutionOptions.num_gpus > 1 is incompatible with device='cpu'") @@ -1743,6 +2296,7 @@ def _serialize_preprocessing(preprocessing: PreprocessingConfig) -> dict[str, An "tissue_threshold": preprocessing.tissue_threshold, "drop_holes": preprocessing.drop_holes, "use_padding": preprocessing.use_padding, + "read_coordinates_from": str(preprocessing.read_coordinates_from) if preprocessing.read_coordinates_from is not None else None, "read_tiles_from": str(preprocessing.read_tiles_from) if preprocessing.read_tiles_from is not None else None, "resume": preprocessing.resume, "segmentation": dict(preprocessing.segmentation), @@ -1758,7 +2312,10 @@ def _serialize_execution(execution: ExecutionOptions) -> dict[str, Any]: "batch_size": execution.batch_size, "num_workers": execution.num_workers, "num_gpus": execution.num_gpus, - "mixed_precision": execution.mixed_precision, + "precision": execution.precision, + "prefetch_factor": execution.prefetch_factor, + "persistent_workers": execution.persistent_workers, + "gpu_batch_preprocessing": execution.gpu_batch_preprocessing, "save_tile_embeddings": execution.save_tile_embeddings, "save_latents": execution.save_latents, } @@ -1774,6 +2331,7 @@ def deserialize_preprocessing(payload: dict[str, Any]) -> PreprocessingConfig: tissue_threshold=float(payload["tissue_threshold"]), drop_holes=bool(payload["drop_holes"]), use_padding=bool(payload["use_padding"]), + read_coordinates_from=Path(payload["read_coordinates_from"]) if payload.get("read_coordinates_from") else None, read_tiles_from=Path(payload["read_tiles_from"]) if payload.get("read_tiles_from") else None, resume=bool(payload.get("resume", False)), segmentation=dict(payload.get("segmentation", {})), @@ -1790,12 +2348,23 @@ def deserialize_execution(payload: dict[str, Any]) -> ExecutionOptions: batch_size=payload.get("batch_size"), num_workers=int(payload.get("num_workers", 0)), num_gpus=int(payload.get("num_gpus", 1)), - mixed_precision=bool(payload.get("mixed_precision", False)), + precision=payload.get("precision", "fp32"), + prefetch_factor=int(payload.get("prefetch_factor", 4)), + persistent_workers=bool(payload.get("persistent_workers", True)), + gpu_batch_preprocessing=bool(payload.get("gpu_batch_preprocessing", True)), save_tile_embeddings=bool(payload.get("save_tile_embeddings", False)), save_latents=bool(payload.get("save_latents", False)), ) +def _autocast_dtype(torch, precision: str): + if precision == "fp16": + return torch.float16 + if precision == "bf16": + return torch.bfloat16 + return None + + def _collect_pipeline_artifacts( slide_records: Sequence[SlideSpec], *, diff --git a/slide2vec/model_settings.py b/slide2vec/model_settings.py new file mode 100644 index 0000000..7112cd4 --- /dev/null +++ b/slide2vec/model_settings.py @@ -0,0 +1,196 @@ +import logging +from dataclasses import dataclass +from typing import Any + +logger = logging.getLogger("slide2vec") + + +@dataclass(frozen=True) +class RecommendedModelSettings: + input_size: tuple[int, int] + spacings_um: tuple[float, ...] + precision: str | None = None + + +PRECISION_ALIASES = { + "fp32": "fp32", + "float32": "fp32", + "32": "fp32", + "fp16": "fp16", + "float16": "fp16", + "16": "fp16", + "half": "fp16", + "bf16": "bf16", + "bfloat16": "bf16", +} + + +def _square_settings( + size: int, + spacings_um: list[float], + *, + precision: str | None = None, +) -> RecommendedModelSettings: + return RecommendedModelSettings( + input_size=(int(size), int(size)), + spacings_um=tuple(float(value) for value in spacings_um), + precision=normalize_precision_name(precision), + ) + + +def normalize_precision_name(value: Any) -> str | None: + if value is None: + return None + normalized = str(value).strip().lower() + if normalized not in PRECISION_ALIASES: + supported = ", ".join(sorted(PRECISION_ALIASES)) + raise ValueError(f"Unsupported precision {value!r}. Expected one of: {supported}") + return PRECISION_ALIASES[normalized] + + +MODEL_NAME_ALIASES = { + "conch-v1.5": "conchv15", + "conch_v15": "conchv15", + "conchv1.5": "conchv15", + "conchv1_5": "conchv15", + "phikon-v2": "phikonv2", + "hibou-b": "hibou", + "hibou-l": "hibou", + "h-optimus-0-mini": "h0-mini", + "prov-gigapath-tile": "prov-gigapath", + "prov-gigapath-slide": "prov-gigapath", +} + + +RECOMMENDED_MODEL_SETTINGS = { + "conch": _square_settings(448, [0.5], precision="fp32"), + "conchv15": _square_settings(448, [0.5], precision="fp16"), + "h0-mini": _square_settings(224, [0.5], precision="fp16"), + "h-optimus-0": _square_settings(224, [0.5], precision="fp16"), + "h-optimus-1": _square_settings(224, [0.5], precision="fp16"), + "hibou": _square_settings(224, [0.5], precision="fp16"), + "kaiko": _square_settings(224, [2.0, 1.0, 0.5, 0.25], precision="fp32"), + "kaiko-midnight": _square_settings(224, [2.0, 1.0, 0.5, 0.25], precision="fp16"), + "musk": _square_settings(384, [1.0, 0.5, 0.25], precision="fp16"), + "panda-vit-s": _square_settings(224, [0.5], precision="fp32"), + "pathojepa": _square_settings(224, [0.5], precision="fp32"), + "phikon": _square_settings(224, [0.5], precision="fp32"), + "phikonv2": _square_settings(224, [0.5], precision="fp32"), + "prism": _square_settings(224, [0.5], precision="fp16"), + "prov-gigapath": _square_settings(256, [0.5], precision="fp16"), + "rumc-vit-s-50k": _square_settings(224, [0.5]), + "titan": _square_settings(512, [0.5], precision="fp16"), + "uni": _square_settings(224, [0.5], precision="fp16"), + "uni2": _square_settings(224, [0.5], precision="bf16"), + "virchow": _square_settings(224, [0.5], precision="fp16"), + "virchow2": _square_settings(224, [2.0, 1.0, 0.5, 0.25], precision="fp16"), +} + + +def canonicalize_model_name(name: str) -> str: + normalized = name.strip().lower() + return MODEL_NAME_ALIASES.get(normalized, normalized) + + +def get_recommended_model_settings(name: str | None) -> RecommendedModelSettings | None: + if not name: + return None + return RECOMMENDED_MODEL_SETTINGS.get(canonicalize_model_name(name)) + + +def validate_model_settings( + *, + model_name: str | None, + requested_input_size: Any = None, + target_spacing_um: float | None = None, + requested_precision: Any = None, + allow_non_recommended_settings: bool = False, +) -> None: + settings = get_recommended_model_settings(model_name) + if settings is None: + return + + mismatches: list[str] = [] + normalized_input_size = _normalize_input_size(requested_input_size) + if normalized_input_size is not None and normalized_input_size != settings.input_size: + mismatches.append( + f"requested input_size={normalized_input_size[0]}x{normalized_input_size[1]} " + f"(recommended: {settings.input_size[0]}x{settings.input_size[1]})" + ) + + if target_spacing_um is not None and not _matches_supported_spacing( + float(target_spacing_um), settings.spacings_um + ): + supported_spacings = ", ".join(f"{spacing:g}" for spacing in settings.spacings_um) + mismatches.append( + f"requested target_spacing_um={float(target_spacing_um):g} " + f"(recommended: [{supported_spacings}])" + ) + + normalized_precision = normalize_precision_name(requested_precision) + if ( + normalized_precision is not None + and settings.precision is not None + and normalized_precision != settings.precision + ): + mismatches.append( + f"requested precision={normalized_precision} " + f"(recommended: {settings.precision})" + ) + + if not mismatches: + return + + message = ( + f"Model '{canonicalize_model_name(model_name)}' is configured with " + f"{'; '.join(mismatches)}. " + "Set `model.allow_non_recommended_settings=true` in YAML/CLI or " + "`allow_non_recommended_settings=True` in `Model.from_pretrained(...)` " + "to continue with a warning." + ) + if allow_non_recommended_settings: + logger.warning(message) + return + raise ValueError(message) + + +def validate_model_runtime_compatibility(model, preprocessing, execution=None) -> None: + validate_model_settings( + model_name=getattr(model, "name", None), + requested_input_size=_requested_input_size(model, preprocessing), + target_spacing_um=getattr(preprocessing, "target_spacing_um", None), + requested_precision=_requested_precision(model, execution), + allow_non_recommended_settings=bool( + getattr(model, "allow_non_recommended_settings", False) + ), + ) + + +def _normalize_input_size(value: Any) -> tuple[int, int] | None: + if value is None: + return None + if isinstance(value, (tuple, list)): + if len(value) != 2: + raise ValueError(f"Expected input_size to have two dimensions, got {value!r}") + return int(value[0]), int(value[1]) + size = int(value) + return size, size + + +def _matches_supported_spacing(value: float, supported_spacings: tuple[float, ...]) -> bool: + return any(abs(value - supported) <= 1e-8 for supported in supported_spacings) + + +def _requested_input_size(model, preprocessing) -> int | None: + explicit_input_size = getattr(model, "_model_kwargs", {}).get("input_size") + if explicit_input_size is not None: + return int(explicit_input_size) + if getattr(model, "level", None) in {"tile", "slide"}: + return int(getattr(preprocessing, "target_tile_size_px")) + return None + + +def _requested_precision(model, execution) -> str | None: + if getattr(model, "_requested_device", None) == "cpu": + return None + return getattr(execution, "precision", None) diff --git a/slide2vec/models/models.py b/slide2vec/models/models.py index 3f17d17..167ddc8 100644 --- a/slide2vec/models/models.py +++ b/slide2vec/models/models.py @@ -75,6 +75,10 @@ def _select_mode_embedding(cls_embedding, patch_embeddings, *, mode: str): return cls_embedding +def _resolve_mode(mode: str | None, *, default: str) -> str: + return default if mode is None else mode + + def _build_timm_hub_encoder(model_name: str, **kwargs): return timm.create_model(model_name, pretrained=True, **kwargs) @@ -89,6 +93,7 @@ def _build_tile_model(options: DictConfig): "h-optimus-0": Hoptimus0, "h-optimus-1": Hoptimus1, "conch": CONCH, + "conchv15": CONCHv15, "musk": MUSK, "phikonv2": PhikonV2, } @@ -113,14 +118,15 @@ def _build_tile_model(options: DictConfig): def _build_region_tile_encoder(options: DictConfig): region_factories = { - "virchow": Virchow, - "virchow2": Virchow2, + "virchow": lambda: Virchow(mode=options.mode), + "virchow2": lambda: Virchow2(mode=options.mode), "uni": UNI, "uni2": UNI2, "prov-gigapath": ProvGigaPath, "h-optimus-0": Hoptimus0, "h-optimus-1": Hoptimus1, "conch": CONCH, + "conchv15": CONCHv15, "musk": MUSK, "phikonv2": PhikonV2, } @@ -529,10 +535,10 @@ def forward(self, x): class Virchow(FeatureExtractor): - def __init__(self, mode: str = "cls"): - self.mode = mode + def __init__(self, mode: str | None = None): + self.mode = _resolve_mode(mode, default="full") self.features_dim = 1280 - if mode == "full": + if self.mode == "full": self.features_dim = 2560 super(Virchow, self).__init__() @@ -555,10 +561,10 @@ def forward(self, x): class Virchow2(FeatureExtractor): - def __init__(self, mode: str = "cls"): - self.mode = mode + def __init__(self, mode: str | None = None): + self.mode = _resolve_mode(mode, default="full") self.features_dim = 1280 - if mode == "full": + if self.mode == "full": self.features_dim = 2560 super(Virchow2, self).__init__() @@ -633,10 +639,10 @@ def forward(self, x): class Hoptimus0Mini(FeatureExtractor): - def __init__(self, mode: str = "cls"): - self.mode = mode + def __init__(self, mode: str | None = None): + self.mode = _resolve_mode(mode, default="cls") self.features_dim = 768 - if mode == "full": + if self.mode == "full": self.features_dim = 1536 super(Hoptimus0Mini, self).__init__() @@ -681,6 +687,25 @@ def forward(self, x): return _embedding_output(embedding) +class CONCHv15(FeatureExtractor): + def __init__(self): + self.features_dim = 768 + super(CONCHv15, self).__init__() + + def build_encoder(self): + titan = AutoModel.from_pretrained("MahmoodLab/TITAN", trust_remote_code=True) + encoder, transform = titan.return_conch() + self.transform = transform + return encoder + + def get_transforms(self): + return self.transform + + def forward(self, x): + embedding = self.encoder(x) + return _embedding_output(embedding) + + class MUSK(FeatureExtractor): def __init__(self): self.features_dim = 2048 @@ -712,7 +737,7 @@ def forward(self, x): image=x, with_head=False, out_norm=False, - ms_aug=True, + ms_aug=False, return_global=True, )[0] return _embedding_output(embedding) diff --git a/slide2vec/progress.py b/slide2vec/progress.py index 51ca3a1..420b37f 100644 --- a/slide2vec/progress.py +++ b/slide2vec/progress.py @@ -40,16 +40,23 @@ def write_log(self, message: str, *, stream=None) -> None: class JsonlProgressReporter: - def __init__(self, path: str | Path, *, rank: int | None = None) -> None: + def __init__( + self, + path: str | Path, + *, + rank: int | None = None, + progress_label: str | None = None, + ) -> None: base_path = Path(path) self.path = ranked_progress_events_path(base_path, rank) if rank is not None else base_path self.path.parent.mkdir(parents=True, exist_ok=True) self._handle = self.path.open("a", encoding="utf-8", buffering=1) + self.progress_label = progress_label def emit(self, event: ProgressEvent) -> None: payload = { "kind": event.kind, - "payload": event.payload, + "payload": _with_progress_label(event.payload, self.progress_label), "timestamp": time.time(), } self._handle.write(json.dumps(payload, sort_keys=True) + "\n") @@ -106,16 +113,23 @@ def _format_line(self, kind: str, payload: dict[str, Any]) -> str | None: f"Tiling finished: {payload['completed']}/{payload['total']} complete, " f"{payload['failed']} failed, {payload['discovered_tiles']} tiles" ) + if kind == "model.loading": + return f"Loading model {payload['model_name']}..." + if kind == "model.ready": + return f"Model {payload['model_name']} ready on {payload['device']}" if kind == "embedding.started": return f"Embedding slides ({payload['slide_count']} total)..." if kind == "embedding.slide.started": - return f"Embedding {payload['sample_id']} ({payload['total_tiles']} tiles)..." + return f"Embedding {_progress_subject(payload)} ({payload['total_tiles']} tiles)..." if kind == "embedding.tile.progress": - return f"Embedding {payload['sample_id']}: {payload['processed']}/{payload['total']} {payload['unit']}s" + return ( + f"Embedding {_progress_subject(payload)}: " + f"{payload['processed']}/{payload['total']} {payload['unit']}s" + ) if kind == "aggregation.started": - return f"Aggregating slide embedding for {payload['sample_id']}..." + return f"Aggregating slide embedding for {_progress_subject(payload)}..." if kind == "embedding.slide.finished": - return f"Completed {payload['sample_id']} ({payload['num_tiles']} tiles)" + return f"Completed {_progress_subject(payload)} ({payload['num_tiles']} tiles)" if kind == "embedding.finished": return ( f"Embedding finished: {payload['slides_completed']}/{payload['slide_count']} slides, " @@ -157,6 +171,8 @@ def __init__(self, *, output_dir: str | Path | None = None, console=None) -> Non ) self.progress.start() self._task_ids: dict[str, int] = {} + self._model_loading_counts: dict[str, int] = {} + self._model_loading_devices: dict[str, set[str]] = {} def emit(self, event: ProgressEvent) -> None: kind = event.kind @@ -193,50 +209,103 @@ def emit(self, event: ProgressEvent) -> None: ], ) return + if kind == "model.loading": + model_name = str(payload["model_name"]) + count = self._model_loading_counts.get(model_name, 0) + 1 + self._model_loading_counts[model_name] = count + task_id = self._task_ids.get("model_loading") + description = _model_loading_description(model_name, count) + if task_id is None: + self._task_ids["model_loading"] = self.progress.add_task( + description, + total=None, + ) + else: + self.progress.update(task_id, description=description) + return + if kind == "model.ready": + model_name = str(payload["model_name"]) + device = str(payload["device"]) + remaining = self._model_loading_counts.get(model_name, 0) + if remaining <= 0: + self.console.print( + f"[green]Model [bold]{model_name}[/bold] ready[/green] on {device}" + ) + return + devices = self._model_loading_devices.setdefault(model_name, set()) + devices.add(device) + remaining -= 1 + if remaining > 0: + self._model_loading_counts[model_name] = remaining + task_id = self._task_ids.get("model_loading") + if task_id is not None: + self.progress.update( + task_id, + description=_model_loading_description(model_name, remaining), + ) + return + self._model_loading_counts.pop(model_name, None) + devices = self._model_loading_devices.pop(model_name, set()) + task_id = self._task_ids.pop("model_loading", None) + if task_id is not None: + self.progress.remove_task(task_id) + if len(devices) > 1: + self.console.print( + f"[green]Model [bold]{model_name}[/bold] ready[/green] on {len(devices)} GPUs" + ) + else: + self.console.print( + f"[green]Model [bold]{model_name}[/bold] ready[/green] on {device}" + ) + return if kind == "embedding.started": self._task_ids["embedding"] = self.progress.add_task("Embedding slides", total=payload["slide_count"]) return if kind == "embedding.slide.started": - tile_task = self._task_ids.get("tiles") + tile_task_key = _progress_task_key("tiles", payload) + tile_task = self._task_ids.get(tile_task_key) + description = _progress_subject(payload) if tile_task is None: - self._task_ids["tiles"] = self.progress.add_task( - f"{payload['sample_id']}", + self._task_ids[tile_task_key] = self.progress.add_task( + description, total=payload["total_tiles"], ) else: self.progress.update( tile_task, - description=payload["sample_id"], + description=description, total=payload["total_tiles"], completed=0, visible=True, ) return if kind == "embedding.tile.progress": - task_id = self._task_ids.get("tiles") + task_id = self._task_ids.get(_progress_task_key("tiles", payload)) if task_id is not None: self.progress.update(task_id, completed=payload["processed"], total=payload["total"]) return if kind == "aggregation.started": - if "aggregation" not in self._task_ids: - self._task_ids["aggregation"] = self.progress.add_task( - f"Aggregating {payload['sample_id']}", + aggregation_task_key = _progress_task_key("aggregation", payload) + description = f"Aggregating {_progress_subject(payload)}" + if aggregation_task_key not in self._task_ids: + self._task_ids[aggregation_task_key] = self.progress.add_task( + description, total=None, ) else: - self.progress.update(self._task_ids["aggregation"], description=f"Aggregating {payload['sample_id']}") + self.progress.update(self._task_ids[aggregation_task_key], description=description) return if kind == "aggregation.finished": - task_id = self._task_ids.get("aggregation") + task_id = self._task_ids.get(_progress_task_key("aggregation", payload)) if task_id is not None: self.progress.remove_task(task_id) - self._task_ids.pop("aggregation", None) + self._task_ids.pop(_progress_task_key("aggregation", payload), None) return if kind == "embedding.slide.finished": embed_task = self._task_ids.get("embedding") if embed_task is not None: self.progress.advance(embed_task, 1) - tile_task = self._task_ids.get("tiles") + tile_task = self._task_ids.get(_progress_task_key("tiles", payload)) if tile_task is not None: self.progress.update(tile_task, completed=payload["num_tiles"]) return @@ -336,6 +405,42 @@ def ranked_progress_events_path(base_path: str | Path, rank: int) -> Path: return path.with_name(f"{path.stem}.rank{rank}{path.suffix}") +def _model_loading_description(model_name: str, worker_count: int) -> str: + if worker_count <= 1: + return f"Loading model [bold]{model_name}[/bold]..." + return f"Loading model [bold]{model_name}[/bold] on {worker_count} GPUs..." + + +def _with_progress_label(payload: dict[str, Any], progress_label: str | None) -> dict[str, Any]: + if progress_label is None or "progress_label" in payload: + return dict(payload) + tagged_payload = dict(payload) + tagged_payload["progress_label"] = progress_label + return tagged_payload + + +def _progress_label(payload: dict[str, Any]) -> str | None: + label = payload.get("progress_label") + if label is None or label == "": + return None + return str(label) + + +def _progress_subject(payload: dict[str, Any]) -> str: + sample_id = str(payload["sample_id"]) + label = _progress_label(payload) + if label is None: + return sample_id + return f"{label}: {sample_id}" + + +def _progress_task_key(base: str, payload: dict[str, Any]) -> str: + label = _progress_label(payload) + if label is None: + return base + return f"{base}:{label}" + + def _embedding_summary_rows(payload: dict[str, Any]) -> list[tuple[str, str]]: slide_count = int(payload["slide_count"]) completed = int(payload["slides_completed"]) diff --git a/slide2vec/utils/config.py b/slide2vec/utils/config.py index e835f3c..80b55c7 100644 --- a/slide2vec/utils/config.py +++ b/slide2vec/utils/config.py @@ -7,6 +7,7 @@ from omegaconf import OmegaConf import slide2vec.distributed as distributed +from slide2vec.model_settings import validate_model_settings from slide2vec.utils import initialize_wandb, fix_random_seeds, get_sha, setup_logging from slide2vec.configs import default_preprocessing_config, default_model_config @@ -28,6 +29,21 @@ def validate_removed_options(cfg) -> None: ) +def validate_model_recommended_settings(cfg, *, run_on_cpu: bool = False) -> None: + model_cfg = getattr(cfg, "model", None) + tiling = getattr(cfg, "tiling", None) + tiling_params = getattr(tiling, "params", None) if tiling is not None else None + validate_model_settings( + model_name=getattr(model_cfg, "name", None), + requested_input_size=getattr(model_cfg, "input_size", None), + target_spacing_um=getattr(tiling_params, "target_spacing_um", None), + requested_precision=None if run_on_cpu else getattr(getattr(cfg, "speed", None), "precision", None), + allow_non_recommended_settings=bool( + getattr(model_cfg, "allow_non_recommended_settings", False) + ), + ) + + def write_config(cfg, output_dir, name="config.yaml"): logger.info(OmegaConf.to_yaml(cfg)) saved_cfg_path = os.path.join(output_dir, name) @@ -47,6 +63,7 @@ def get_cfg_from_args(args): cfg = OmegaConf.merge(default_cfg, cfg, OmegaConf.from_cli(args.opts)) OmegaConf.resolve(cfg) validate_removed_options(cfg) + validate_model_recommended_settings(cfg, run_on_cpu=bool(getattr(args, "run_on_cpu", False))) return cfg diff --git a/slide2vec/utils/tiling_io.py b/slide2vec/utils/tiling_io.py index 2d82fa2..4cfd308 100644 --- a/slide2vec/utils/tiling_io.py +++ b/slide2vec/utils/tiling_io.py @@ -14,8 +14,8 @@ "mask_path", "tiling_status", "num_tiles", - "tiles_npz_path", - "tiles_meta_path", + "coordinates_npz_path", + "coordinates_meta_path", "error", "traceback", ) @@ -105,6 +105,8 @@ def load_process_df( needs_feature_status = include_feature_status or include_aggregation_status if "spacing_at_level_0" not in df.columns: df["spacing_at_level_0"] = [None] * len(df) + if "tiles_tar_path" not in df.columns: + df["tiles_tar_path"] = [None] * len(df) if needs_feature_status and "feature_status" not in df.columns: df["feature_status"] = ["tbp"] * len(df) if include_aggregation_status and "aggregation_status" not in df.columns: @@ -116,9 +118,10 @@ def load_process_df( "spacing_at_level_0", "tiling_status", "num_tiles", - "tiles_npz_path", - "tiles_meta_path", + "coordinates_npz_path", + "coordinates_meta_path", ] + ordered_columns.append("tiles_tar_path") if needs_feature_status: ordered_columns.append("feature_status") if include_aggregation_status: @@ -130,9 +133,10 @@ def load_process_df( def load_tiling_result_from_row(row): hs2p = _hs2p_exports() tiling_result = hs2p["load_tiling_result"]( - tiles_npz_path=Path(row["tiles_npz_path"]), - tiles_meta_path=Path(row["tiles_meta_path"]), + coordinates_npz_path=Path(row["coordinates_npz_path"]), + coordinates_meta_path=Path(row["coordinates_meta_path"]), ) - setattr(tiling_result, "tiles_npz_path", Path(row["tiles_npz_path"])) - setattr(tiling_result, "tiles_meta_path", Path(row["tiles_meta_path"])) + setattr(tiling_result, "coordinates_npz_path", Path(row["coordinates_npz_path"])) + setattr(tiling_result, "coordinates_meta_path", Path(row["coordinates_meta_path"])) + setattr(tiling_result, "tiles_tar_path", _optional_path(row.get("tiles_tar_path"))) return tiling_result diff --git a/tests/conftest.py b/tests/conftest.py new file mode 100644 index 0000000..86d5548 --- /dev/null +++ b/tests/conftest.py @@ -0,0 +1,19 @@ +import importlib.util +import sys +import types + + +if importlib.util.find_spec("transformers") is None: + transformers_module = types.ModuleType("transformers") + transformers_module.__path__ = [] + + image_processing_utils = types.ModuleType("transformers.image_processing_utils") + + class BaseImageProcessor: + pass + + image_processing_utils.BaseImageProcessor = BaseImageProcessor + transformers_module.image_processing_utils = image_processing_utils + + sys.modules.setdefault("transformers", transformers_module) + sys.modules["transformers.image_processing_utils"] = image_processing_utils diff --git a/tests/fixtures/gt/test-wsi.tiles.meta.json b/tests/fixtures/gt/test-wsi.coordinates.meta.json similarity index 100% rename from tests/fixtures/gt/test-wsi.tiles.meta.json rename to tests/fixtures/gt/test-wsi.coordinates.meta.json diff --git a/tests/fixtures/gt/test-wsi.tiles.npz b/tests/fixtures/gt/test-wsi.coordinates.npz similarity index 100% rename from tests/fixtures/gt/test-wsi.tiles.npz rename to tests/fixtures/gt/test-wsi.coordinates.npz diff --git a/tests/fixtures/gt/test-wsi.pt b/tests/fixtures/gt/test-wsi.pt index e50b8e1..d7ac0a0 100644 Binary files a/tests/fixtures/gt/test-wsi.pt and b/tests/fixtures/gt/test-wsi.pt differ diff --git a/tests/test_batch_collator_timing.py b/tests/test_batch_collator_timing.py new file mode 100644 index 0000000..e9e8c5a --- /dev/null +++ b/tests/test_batch_collator_timing.py @@ -0,0 +1,104 @@ +from pathlib import Path +from types import SimpleNamespace + +import numpy as np +import pytest + + +def test_batch_tile_collator_emits_worker_and_reader_timing(monkeypatch: pytest.MonkeyPatch): + torch = pytest.importorskip("torch") + from slide2vec.data import dataset + + timings = { + "reader_open_ms": 1.25, + "reader_read_ms": 8.5, + } + + class FakeReader: + def __init__(self, tar_path: Path, tile_size_px: int): + self.tar_path = tar_path + self.tile_size_px = tile_size_px + + def read_batch_with_timing(self, tile_indices): + tensor = torch.zeros((len(tile_indices), 3, self.tile_size_px, self.tile_size_px), dtype=torch.uint8) + return tensor, dict(timings) + + monkeypatch.setattr(dataset, "TarTileReader", FakeReader) + + collator = dataset.BatchTileCollator( + tar_path=Path("/tmp/fake.tiles.tar"), + tiling_result=SimpleNamespace(target_tile_size_px=4), + ) + + indices, tensor, timing = collator([2, 5]) + + assert indices.tolist() == [2, 5] + assert tuple(tensor.shape) == (2, 3, 4, 4) + assert timing["reader_open_ms"] == pytest.approx(1.25) + assert timing["reader_read_ms"] == pytest.approx(8.5) + assert timing["worker_batch_ms"] >= 0.0 + + +def test_wsd_collator_emits_worker_and_reader_timing(monkeypatch: pytest.MonkeyPatch): + torch = pytest.importorskip("torch") + import slide2vec.data.wsd_tile_reader as wsd_tile_reader + + class FakeReader: + ordered_indices = None + + def __init__(self, image_path, tiling_result, *, backend: str, use_supertiles: bool): + self.tile_size = int(tiling_result.read_tile_size_px) + + def read_batch_with_timing(self, tile_indices): + tensor = torch.zeros((len(tile_indices), 3, self.tile_size, self.tile_size), dtype=torch.uint8) + return tensor, {"reader_open_ms": 0.75, "reader_read_ms": 5.5} + + monkeypatch.setattr(wsd_tile_reader, "WSDTileReader", FakeReader) + + collator = wsd_tile_reader.WSDOnTheFlyBatchTileCollator( + image_path=Path("/tmp/fake.svs"), + tiling_result=SimpleNamespace(read_tile_size_px=4), + backend="asap", + use_supertiles=False, + ) + + indices, tensor, timing = collator([1, 3]) + + assert indices.tolist() == [1, 3] + assert tuple(tensor.shape) == (2, 3, 4, 4) + assert timing["reader_open_ms"] == pytest.approx(0.75) + assert timing["reader_read_ms"] == pytest.approx(5.5) + assert timing["worker_batch_ms"] >= 0.0 + + +def test_cucim_collator_emits_worker_and_reader_timing(monkeypatch: pytest.MonkeyPatch): + torch = pytest.importorskip("torch") + import slide2vec.data.cucim_tile_reader as cucim_tile_reader + + class FakeReader: + ordered_indices = None + + def __init__(self, image_path, tiling_result, *, num_cucim_workers: int, gpu_decode: bool, use_supertiles: bool): + self.tile_size = int(tiling_result.read_tile_size_px) + + def read_batch_with_timing(self, tile_indices): + tensor = torch.zeros((len(tile_indices), 3, self.tile_size, self.tile_size), dtype=torch.uint8) + return tensor, {"reader_open_ms": 2.0, "reader_read_ms": 7.25} + + monkeypatch.setattr(cucim_tile_reader, "CuCIMTileReader", FakeReader) + + collator = cucim_tile_reader.OnTheFlyBatchTileCollator( + image_path=Path("/tmp/fake.svs"), + tiling_result=SimpleNamespace(read_tile_size_px=4), + num_cucim_workers=4, + gpu_decode=False, + use_supertiles=False, + ) + + indices, tensor, timing = collator([0, 4]) + + assert indices.tolist() == [0, 4] + assert tuple(tensor.shape) == (2, 3, 4, 4) + assert timing["reader_open_ms"] == pytest.approx(2.0) + assert timing["reader_read_ms"] == pytest.approx(7.25) + assert timing["worker_batch_ms"] >= 0.0 diff --git a/tests/test_benchmark_embedding_throughput.py b/tests/test_benchmark_embedding_throughput.py new file mode 100644 index 0000000..74c4763 --- /dev/null +++ b/tests/test_benchmark_embedding_throughput.py @@ -0,0 +1,767 @@ +import csv +import importlib.util +import json +import sys +import types +from pathlib import Path + +import pytest + + +ROOT = Path(__file__).resolve().parents[1] +MODULE_PATH = ROOT / "scripts" / "benchmark_embedding_throughput.py" + + +@pytest.fixture(scope="module") +def benchmark_module(): + spec = importlib.util.spec_from_file_location("benchmark_embedding_throughput", MODULE_PATH) + if spec is None or spec.loader is None: + raise RuntimeError(f"Unable to load benchmark module from {MODULE_PATH}") + module = importlib.util.module_from_spec(spec) + sys.modules[spec.name] = module + spec.loader.exec_module(module) + return module + + +def test_build_balanced_sample_stratifies_by_file_size(benchmark_module): + slides = [ + {"sample_id": f"slide-{idx}", "image_path": Path(f"/tmp/slide-{idx}.svs"), "mask_path": None, "size_bytes": size} + for idx, size in enumerate([10, 20, 30, 40, 50, 60, 70, 80, 90], start=1) + ] + + sampled = benchmark_module.build_balanced_sample(slides, n_slides=6, seed=7) + + assert len(sampled) == 6 + sizes = sorted(slide["size_bytes"] for slide in sampled) + assert sum(size < 37 for size in sizes) == 2 + assert sum(37 <= size < 63 for size in sizes) == 2 + assert sum(size >= 63 for size in sizes) == 2 + + +def test_write_slides_csv_preserves_optional_spacing(benchmark_module, tmp_path: Path): + slides = [ + { + "sample_id": "slide-a", + "image_path": Path("/tmp/slide-a.svs"), + "mask_path": None, + "spacing_at_level_0": 0.25, + }, + { + "sample_id": "slide-b", + "image_path": Path("/tmp/slide-b.svs"), + "mask_path": Path("/tmp/slide-b.png"), + "spacing_at_level_0": None, + }, + ] + csv_path = tmp_path / "slides.csv" + + benchmark_module.write_slides_csv(slides, csv_path) + + with csv_path.open(newline="") as handle: + rows = list(csv.DictReader(handle)) + + assert rows == [ + { + "sample_id": "slide-a", + "image_path": "/tmp/slide-a.svs", + "mask_path": "", + "spacing_at_level_0": "0.25", + }, + { + "sample_id": "slide-b", + "image_path": "/tmp/slide-b.svs", + "mask_path": "/tmp/slide-b.png", + "spacing_at_level_0": "", + }, + ] + + +def test_prepend_repo_root_to_sys_path_places_repo_first_without_duplicates(benchmark_module): + repo_root = str(benchmark_module.REPO_ROOT) + starting = ["/tmp/elsewhere", repo_root, "/tmp/more"] + + updated = benchmark_module._prepend_repo_root_to_sys_path(starting) + + assert updated[0] == repo_root + assert updated.count(repo_root) == 1 + assert updated[1:] == ["/tmp/elsewhere", "/tmp/more"] + + +def test_build_trial_config_normalizes_non_benchmark_runtime_flags(benchmark_module, tmp_path: Path): + base_config = { + "csv": "/original/slides.csv", + "output_dir": "/original/output", + "resume": True, + "save_previews": True, + "model": { + "name": "virchow2", + "level": "tile", + "batch_size": 8, + }, + "tiling": { + "backend": "asap", + "params": {"target_spacing_um": 0.5, "target_tile_size_px": 224}, + }, + "speed": { + "num_workers": 12, + "num_workers_embedding": 6, + "precision": "fp16", + }, + "wandb": {"enable": True}, + } + + cfg = benchmark_module.build_trial_config( + base_config, + csv_path=tmp_path / "slides.csv", + output_dir=tmp_path / "trial-output", + batch_size=32, + embedding_workers=3, + ) + + assert cfg.csv == str(tmp_path / "slides.csv") + assert cfg.output_dir == str(tmp_path / "trial-output") + assert cfg.resume is False + assert cfg.save_previews is False + assert cfg.wandb.enable is False + assert cfg.model.batch_size == 32 + assert cfg.speed.num_workers_embedding == 3 + assert cfg.speed.num_workers == 12 + assert cfg.speed.precision == "fp16" + assert cfg.tiling.backend == "asap" + + +def test_load_cli_merged_config_uses_regular_cli_loader( + benchmark_module, tmp_path: Path, monkeypatch: pytest.MonkeyPatch +): + config_path = tmp_path / "benchmark.yaml" + config_path.write_text( + "\n".join( + [ + 'model:', + ' name: "h0-mini"', + 'tiling:', + ' params:', + ' target_spacing_um: 0.5', + ' target_tile_size_px: 224', + 'speed: {}', + 'wandb:', + ' enable: false', + ] + ) + + "\n", + encoding="utf-8", + ) + + captured: dict[str, object] = {} + + def fake_get_cfg_from_args(args): + captured["config_file"] = args.config_file + captured["output_dir"] = args.output_dir + captured["opts"] = list(args.opts) + return { + "model": {"name": "h0-mini", "mode": "cls"}, + "speed": {"num_workers_embedding": 8}, + "wandb": {"enable": False}, + } + + monkeypatch.setitem( + sys.modules, + "slide2vec.utils.config", + types.SimpleNamespace(get_cfg_from_args=fake_get_cfg_from_args), + ) + + loaded = benchmark_module._load_cli_merged_config(config_path) + + assert captured == { + "config_file": str(config_path), + "output_dir": None, + "opts": [], + } + assert loaded == { + "model": {"name": "h0-mini", "mode": "cls"}, + "speed": {"num_workers_embedding": 8}, + "wandb": {"enable": False}, + } + + +def test_extract_stage_seconds_reads_progress_jsonl_and_leaves_missing_stages_none(benchmark_module, tmp_path: Path): + progress_path = tmp_path / "progress.jsonl" + records = [ + {"kind": "tiling.started", "payload": {}, "timestamp": 10.0}, + {"kind": "tiling.finished", "payload": {}, "timestamp": 14.5}, + {"kind": "embedding.started", "payload": {}, "timestamp": 14.5}, + {"kind": "aggregation.started", "payload": {"sample_id": "slide-a"}, "timestamp": 18.0}, + {"kind": "aggregation.finished", "payload": {"sample_id": "slide-a"}, "timestamp": 19.25}, + {"kind": "embedding.finished", "payload": {}, "timestamp": 21.0}, + ] + progress_path.write_text("".join(json.dumps(record) + "\n" for record in records), encoding="utf-8") + + stage_seconds = benchmark_module.extract_stage_seconds(progress_path) + + assert stage_seconds == { + "tiling_seconds": 4.5, + "embedding_seconds": 6.5, + "aggregation_seconds": 1.25, + } + + +def test_extract_batch_timing_metrics_summarizes_loader_and_forward_costs(benchmark_module, tmp_path: Path): + progress_path = tmp_path / "progress.jsonl" + records = [ + { + "kind": "embedding.batch.timing", + "payload": { + "batch_size": 16, + "loader_wait_ms": 10.0, + "ready_wait_ms": 2.0, + "preprocess_ms": 6.0, + "forward_ms": 20.0, + "worker_batch_ms": 9.0, + "reader_open_ms": 1.0, + "reader_read_ms": 8.0, + "gpu_busy_fraction": 0.7000, + }, + "timestamp": 1.0, + }, + { + "kind": "embedding.batch.timing", + "payload": { + "batch_size": 16, + "loader_wait_ms": 14.0, + "ready_wait_ms": 1.0, + "preprocess_ms": 8.0, + "forward_ms": 18.0, + "worker_batch_ms": 13.0, + "reader_open_ms": 0.0, + "reader_read_ms": 12.0, + "gpu_busy_fraction": 0.6429, + }, + "timestamp": 2.0, + }, + ] + progress_path.write_text("".join(json.dumps(record) + "\n" for record in records), encoding="utf-8") + + metrics = benchmark_module.extract_batch_timing_metrics(progress_path) + + assert metrics == { + "timed_batches": 2, + "mean_loader_wait_ms": 12.0, + "max_loader_wait_ms": 14.0, + "mean_ready_wait_ms": 1.5, + "mean_preprocess_ms": 7.0, + "mean_forward_ms": 19.0, + "mean_worker_batch_ms": 11.0, + "mean_reader_open_ms": 0.5, + "mean_reader_read_ms": 10.0, + "loader_wait_fraction": 0.3418, + "gpu_busy_fraction": 0.6714, + } + + +def test_aggregate_and_select_best_results_uses_deterministic_tie_breaks(benchmark_module): + trial_rows = [ + { + "gpu_label": "A100", + "model_label": "PathoJEPA-S", + "size_label": "S", + "config_file": "/tmp/pathojepa-s.yaml", + "batch_size": 64, + "embedding_workers": 8, + "repeat_index": 1, + "tiles_per_second": 100.0, + "end_to_end_seconds": 10.0, + "slides_per_second": 1.0, + "mean_loader_wait_ms": 12.0, + "max_loader_wait_ms": 14.0, + "mean_ready_wait_ms": 1.5, + "mean_preprocess_ms": 7.0, + "mean_forward_ms": 19.0, + "mean_worker_batch_ms": 11.0, + "mean_reader_open_ms": 0.5, + "mean_reader_read_ms": 10.0, + "loader_wait_fraction": 0.4286, + "gpu_busy_fraction": 0.5714, + }, + { + "gpu_label": "A100", + "model_label": "PathoJEPA-S", + "size_label": "S", + "config_file": "/tmp/pathojepa-s.yaml", + "batch_size": 64, + "embedding_workers": 8, + "repeat_index": 2, + "tiles_per_second": 100.0, + "end_to_end_seconds": 10.0, + "slides_per_second": 1.0, + "mean_loader_wait_ms": 10.0, + "max_loader_wait_ms": 12.0, + "mean_ready_wait_ms": 1.0, + "mean_preprocess_ms": 6.0, + "mean_forward_ms": 20.0, + "mean_worker_batch_ms": 10.0, + "mean_reader_open_ms": 0.0, + "mean_reader_read_ms": 9.0, + "loader_wait_fraction": 0.3846, + "gpu_busy_fraction": 0.6154, + }, + { + "gpu_label": "A100", + "model_label": "PathoJEPA-S", + "size_label": "S", + "config_file": "/tmp/pathojepa-s.yaml", + "batch_size": 32, + "embedding_workers": 4, + "repeat_index": 1, + "tiles_per_second": 100.0, + "end_to_end_seconds": 10.0, + "slides_per_second": 1.0, + "mean_loader_wait_ms": 9.0, + "max_loader_wait_ms": 10.0, + "mean_ready_wait_ms": 0.5, + "mean_preprocess_ms": 5.0, + "mean_forward_ms": 21.0, + "mean_worker_batch_ms": 8.0, + "mean_reader_open_ms": 0.0, + "mean_reader_read_ms": 7.0, + "loader_wait_fraction": 0.3214, + "gpu_busy_fraction": 0.6786, + }, + { + "gpu_label": "A100", + "model_label": "PathoJEPA-S", + "size_label": "S", + "config_file": "/tmp/pathojepa-s.yaml", + "batch_size": 32, + "embedding_workers": 4, + "repeat_index": 2, + "tiles_per_second": 100.0, + "end_to_end_seconds": 10.0, + "slides_per_second": 1.0, + "mean_loader_wait_ms": 8.0, + "max_loader_wait_ms": 11.0, + "mean_ready_wait_ms": 0.5, + "mean_preprocess_ms": 5.0, + "mean_forward_ms": 20.0, + "mean_worker_batch_ms": 7.0, + "mean_reader_open_ms": 0.0, + "mean_reader_read_ms": 6.0, + "loader_wait_fraction": 0.2963, + "gpu_busy_fraction": 0.7037, + }, + ] + + aggregated = benchmark_module.aggregate_trial_results(trial_rows) + best = benchmark_module.select_best_results(aggregated) + + assert len(aggregated) == 2 + assert best == [ + { + "gpu_label": "A100", + "model_label": "PathoJEPA-S", + "size_label": "S", + "config_file": "/tmp/pathojepa-s.yaml", + "batch_size": 32, + "embedding_workers": 4, + "num_gpus": 1, + "repeat_count": 2, + "mean_tiles_per_second": 100.0, + "std_tiles_per_second": 0.0, + "mean_end_to_end_seconds": 10.0, + "mean_slides_per_second": 1.0, + "mean_loader_wait_ms": 8.5, + "max_loader_wait_ms": 11.0, + "mean_ready_wait_ms": 0.5, + "mean_preprocess_ms": 5.0, + "mean_forward_ms": 20.5, + "mean_worker_batch_ms": 7.5, + "mean_reader_open_ms": 0.0, + "mean_reader_read_ms": 6.5, + "loader_wait_fraction": 0.3089, + "gpu_busy_fraction": 0.6911, + } + ] + + +def test_load_trial_results_csvs_merges_multiple_files(benchmark_module, tmp_path: Path): + left = tmp_path / "left.csv" + right = tmp_path / "right.csv" + fieldnames = [ + "gpu_label", + "model_label", + "size_label", + "config_file", + "batch_size", + "embedding_workers", + "repeat_index", + "tiles_per_second", + "mean_loader_wait_ms", + ] + with left.open("w", newline="") as handle: + writer = csv.DictWriter(handle, fieldnames=fieldnames) + writer.writeheader() + writer.writerow( + { + "gpu_label": "A100", + "model_label": "PathoJEPA-S", + "size_label": "S", + "config_file": "/tmp/pathojepa-s.yaml", + "batch_size": 32, + "embedding_workers": 4, + "repeat_index": 1, + "tiles_per_second": 123.4, + "mean_loader_wait_ms": 8.5, + } + ) + with right.open("w", newline="") as handle: + writer = csv.DictWriter(handle, fieldnames=fieldnames) + writer.writeheader() + writer.writerow( + { + "gpu_label": "H100", + "model_label": "PathoJEPA-B", + "size_label": "B", + "config_file": "/tmp/pathojepa-b.yaml", + "batch_size": 64, + "embedding_workers": 8, + "repeat_index": 1, + "tiles_per_second": 234.5, + "mean_loader_wait_ms": 6.0, + } + ) + + rows = benchmark_module.load_trial_results_csvs([left, right]) + + assert [row["gpu_label"] for row in rows] == ["A100", "H100"] + assert rows[0]["size_label"] == "S" + assert rows[0]["tiles_per_second"] == 123.4 + assert rows[0]["mean_loader_wait_ms"] == 8.5 + assert rows[1]["batch_size"] == 64 + + +def test_resolve_model_specs_supports_single_config_backwards_compatibility(benchmark_module, tmp_path: Path): + args = benchmark_module.argparse.Namespace( + config_file=tmp_path / "single.yaml", + config_files=None, + model_labels=None, + size_labels=None, + ) + + specs = benchmark_module.resolve_model_specs(args) + + assert specs == [ + { + "config_file": tmp_path / "single.yaml", + "model_label": "single", + "size_label": "unspecified", + } + ] + + +def test_resolve_model_specs_validates_multi_model_label_lengths(benchmark_module, tmp_path: Path): + args = benchmark_module.argparse.Namespace( + config_file=None, + config_files=[tmp_path / "a.yaml", tmp_path / "b.yaml"], + model_labels=["A"], + size_labels=["S", "B"], + ) + + with pytest.raises(ValueError, match="model-labels"): + benchmark_module.resolve_model_specs(args) + + +def test_build_trial_plan_creates_warmup_and_measurement_runs_per_model(benchmark_module, tmp_path: Path): + model_specs = [ + {"config_file": tmp_path / "pathojepa-s.yaml", "model_label": "PathoJEPA-S", "size_label": "S"}, + {"config_file": tmp_path / "pathojepa-b.yaml", "model_label": "PathoJEPA-B", "size_label": "B"}, + ] + plan = benchmark_module.build_trial_plan( + output_root=tmp_path, + model_specs=model_specs, + batch_sizes=[16, 32], + embedding_workers=[2], + num_gpus=[1], + repeat=2, + ) + + assert [item["kind"] for item in plan[:2]] == ["warmup", "measure"] + assert sum(item["kind"] == "warmup" for item in plan) == 4 + assert sum(item["kind"] == "measure" for item in plan) == 8 + assert plan[0]["model_label"] == "PathoJEPA-S" + assert plan[0]["size_label"] == "S" + assert plan[0]["num_gpus"] == 1 + assert plan[0]["run_dir"] == tmp_path / "runs" / "pathojepa-s" / "ng-01" / "bs-0016" / "ew-02" / "warmup" + assert plan[-1]["run_dir"] == tmp_path / "runs" / "pathojepa-b" / "ng-01" / "bs-0032" / "ew-02" / "rep-02" + + +def test_build_size_plot_rows_keeps_best_model_per_gpu_and_size(benchmark_module): + best_rows = [ + { + "gpu_label": "A100", + "model_label": "PathoJEPA-S", + "size_label": "S", + "config_file": "/tmp/pathojepa-s.yaml", + "batch_size": 16, + "embedding_workers": 4, + "repeat_count": 2, + "mean_tiles_per_second": 100.0, + "std_tiles_per_second": 0.0, + "mean_end_to_end_seconds": 10.0, + "mean_slides_per_second": 1.0, + }, + { + "gpu_label": "A100", + "model_label": "Kaiko-S", + "size_label": "S", + "config_file": "/tmp/kaiko-s.yaml", + "batch_size": 32, + "embedding_workers": 8, + "repeat_count": 2, + "mean_tiles_per_second": 120.0, + "std_tiles_per_second": 0.0, + "mean_end_to_end_seconds": 8.0, + "mean_slides_per_second": 1.2, + }, + { + "gpu_label": "A100", + "model_label": "PathoJEPA-B", + "size_label": "B", + "config_file": "/tmp/pathojepa-b.yaml", + "batch_size": 8, + "embedding_workers": 2, + "repeat_count": 2, + "mean_tiles_per_second": 90.0, + "std_tiles_per_second": 0.0, + "mean_end_to_end_seconds": 11.0, + "mean_slides_per_second": 0.9, + }, + ] + + collapsed = benchmark_module.build_size_plot_rows(best_rows) + + assert collapsed == [ + { + "gpu_label": "A100", + "size_label": "B", + "model_label": "PathoJEPA-B", + "mean_tiles_per_second": 90.0, + }, + { + "gpu_label": "A100", + "size_label": "S", + "model_label": "Kaiko-S", + "mean_tiles_per_second": 120.0, + }, + ] + + +def test_run_benchmark_builds_warmup_and_repeat_trials(monkeypatch, benchmark_module, tmp_path: Path): + observed_trial_specs = [] + observed_rows = [] + observed_saved_csvs = [] + + configs = { + tmp_path / "pathojepa-s.yaml": {"csv": "/input/slides.csv", "model": {"name": "pathojepa"}, "speed": {}, "tiling": {"params": {}}}, + tmp_path / "pathojepa-b.yaml": {"csv": "/input/slides.csv", "model": {"name": "pathojepa"}, "speed": {}, "tiling": {"params": {}}}, + } + + monkeypatch.setattr(benchmark_module, "_load_cli_merged_config", lambda path: configs[path]) + monkeypatch.setattr( + benchmark_module, + "load_slides_from_csv", + lambda path: [{"sample_id": "slide-a", "image_path": Path("/tmp/slide-a.svs"), "mask_path": None, "size_bytes": 10}], + ) + monkeypatch.setattr(benchmark_module, "build_balanced_sample", lambda slides, **kwargs: list(slides)) + monkeypatch.setattr(benchmark_module, "_resolve_gpu_label", lambda value: "A100") + + def fake_run_trial(*, trial_spec, slides, shared_csv_path, base_config, gpu_label): + observed_trial_specs.append(dict(trial_spec)) + return { + "gpu_label": gpu_label, + "model_label": trial_spec["model_label"], + "size_label": trial_spec["size_label"], + "config_file": str(trial_spec["config_file"]), + "batch_size": int(trial_spec["batch_size"]), + "embedding_workers": int(trial_spec["embedding_workers"]), + "num_gpus": int(trial_spec["num_gpus"]), + "repeat_index": int(trial_spec["repeat_index"]), + "run_kind": str(trial_spec["kind"]), + "exit_code": 0, + "slides_total": 1, + "slides_with_tiles": 1, + "failed_slides": 0, + "total_tiles": 12, + "end_to_end_seconds": 2.0, + "tiles_per_second": 6.0, + "slides_per_second": 0.5, + "tiling_seconds": 0.5, + "embedding_seconds": 1.25, + "aggregation_seconds": "", + "error": "", + } + + def fake_save_csv(rows, path): + observed_saved_csvs.append(path) + observed_rows.extend(rows) + + monkeypatch.setattr(benchmark_module, "run_trial", fake_run_trial) + monkeypatch.setattr(benchmark_module, "save_csv", fake_save_csv) + monkeypatch.setattr(benchmark_module, "_prepare_chart_outputs", lambda rows, output_dir: 0) + + args = benchmark_module.argparse.Namespace( + config_file=None, + config_files=[tmp_path / "pathojepa-s.yaml", tmp_path / "pathojepa-b.yaml"], + model_labels=["PathoJEPA-S", "PathoJEPA-B"], + size_labels=["S", "B"], + csv=None, + output_dir=tmp_path / "benchmark", + repeat=2, + seed=42, + n_slides=1, + batch_sizes=[16], + embedding_workers=[2], + num_gpus=[1], + gpu_label="manual", + copy_locally=False, + local_dir=tmp_path / "local", + chart_only=None, + internal_harness=False, + metrics_json=None, + progress_jsonl=None, + ) + + exit_code = benchmark_module.run_benchmark(args) + + assert exit_code == 0 + assert [spec["kind"] for spec in observed_trial_specs] == ["warmup", "measure", "measure", "warmup", "measure", "measure"] + assert observed_trial_specs[0]["run_dir"] == tmp_path / "benchmark" / "runs" / "pathojepa-s" / "ng-01" / "bs-0016" / "ew-02" / "warmup" + assert observed_trial_specs[3]["run_dir"] == tmp_path / "benchmark" / "runs" / "pathojepa-b" / "ng-01" / "bs-0016" / "ew-02" / "warmup" + assert all(spec["shared_csv_path"] == tmp_path / "benchmark" / "sampled_slides.csv" for spec in observed_trial_specs) + assert all(spec["num_gpus"] == 1 for spec in observed_trial_specs) + assert [row["run_kind"] for row in observed_rows] == ["measure", "measure", "measure", "measure"] + assert observed_saved_csvs[0] == tmp_path / "benchmark" / "trial_results.csv" + + +def test_run_trial_writes_metrics_error_when_harness_log_is_empty( + benchmark_module, tmp_path: Path, monkeypatch: pytest.MonkeyPatch +): + run_dir = tmp_path / "warmup" + metrics_payload = { + "error": "CUDA out of memory on device 0", + "slides_total": 0, + "slides_with_tiles": 0, + "failed_slides": 0, + "total_tiles": 0, + "end_to_end_seconds": 0.1, + "tiles_per_second": 0.0, + "slides_per_second": 0.0, + } + + def fake_run_trial_subprocess(*, config_path: Path, metrics_path: Path, progress_path: Path, log_path: Path): + log_path.write_text("", encoding="utf-8") + metrics_path.write_text(json.dumps(metrics_payload), encoding="utf-8") + return benchmark_module.subprocess.CompletedProcess(args=["benchmark"], returncode=1, stdout="", stderr="") + + monkeypatch.setattr(benchmark_module, "_run_trial_subprocess", fake_run_trial_subprocess) + monkeypatch.setattr(benchmark_module, "_write_yaml", lambda data, path: None) + monkeypatch.setattr(benchmark_module, "cleanup_trial_output", lambda output_dir: None) + + row = benchmark_module.run_trial( + trial_spec={ + "run_dir": run_dir, + "model_label": "H0-mini", + "size_label": "ViT-S", + "config_file": tmp_path / "model.yaml", + "batch_size": 1, + "embedding_workers": 4, + "num_gpus": 1, + "repeat_index": 0, + "kind": "warmup", + }, + slides=[], + shared_csv_path=tmp_path / "slides.csv", + base_config={ + "model": {"name": "pathojepa"}, + "speed": {}, + "tiling": {"params": {}}, + }, + gpu_label="A100", + ) + + assert row["exit_code"] == 1 + assert row["error"] == "CUDA out of memory on device 0" + assert (run_dir / "harness.log").read_text(encoding="utf-8") == "ERROR: CUDA out of memory on device 0\n" + + +def test_run_benchmark_surfaces_warmup_error_details( + monkeypatch, benchmark_module, tmp_path: Path, capsys: pytest.CaptureFixture[str] +): + config_path = tmp_path / "model.yaml" + monkeypatch.setattr( + benchmark_module, + "_load_cli_merged_config", + lambda path: {"csv": "/input/slides.csv", "model": {"name": "pathojepa"}, "speed": {}, "tiling": {"params": {}}}, + ) + monkeypatch.setattr( + benchmark_module, + "load_slides_from_csv", + lambda path: [{"sample_id": "slide-a", "image_path": Path("/tmp/slide-a.svs"), "mask_path": None, "size_bytes": 10}], + ) + monkeypatch.setattr(benchmark_module, "build_balanced_sample", lambda slides, **kwargs: list(slides)) + monkeypatch.setattr(benchmark_module, "_resolve_gpu_label", lambda value: "A100") + + def fake_run_trial(*, trial_spec, slides, shared_csv_path, base_config, gpu_label): + return { + "gpu_label": gpu_label, + "model_label": trial_spec["model_label"], + "size_label": trial_spec["size_label"], + "config_file": str(trial_spec["config_file"]), + "batch_size": int(trial_spec["batch_size"]), + "embedding_workers": int(trial_spec["embedding_workers"]), + "num_gpus": int(trial_spec["num_gpus"]), + "repeat_index": int(trial_spec["repeat_index"]), + "run_kind": str(trial_spec["kind"]), + "exit_code": 1, + "slides_total": 0, + "slides_with_tiles": 0, + "failed_slides": 0, + "total_tiles": 0, + "end_to_end_seconds": 0.1, + "tiles_per_second": 0.0, + "slides_per_second": 0.0, + "tiling_seconds": "", + "embedding_seconds": "", + "aggregation_seconds": "", + "error": "CUDA out of memory on device 0", + } + + monkeypatch.setattr(benchmark_module, "run_trial", fake_run_trial) + + args = benchmark_module.argparse.Namespace( + config_file=config_path, + config_files=None, + model_labels=None, + size_labels=None, + csv=None, + output_dir=tmp_path / "benchmark", + repeat=1, + seed=42, + n_slides=1, + batch_sizes=[1], + embedding_workers=[4], + num_gpus=[1], + gpu_label="manual", + copy_locally=False, + local_dir=tmp_path / "local", + chart_only=None, + internal_harness=False, + metrics_json=None, + progress_jsonl=None, + ) + + exit_code = benchmark_module.run_benchmark(args) + captured = capsys.readouterr() + + assert exit_code == 1 + assert "Warmup failed for model [unspecified] bs=1 workers=4 gpus=1" in captured.err + assert "harness.log" in captured.err + assert "CUDA out of memory on device 0" in captured.err diff --git a/tests/test_benchmark_end_to_end_paths.py b/tests/test_benchmark_end_to_end_paths.py new file mode 100644 index 0000000..051c189 --- /dev/null +++ b/tests/test_benchmark_end_to_end_paths.py @@ -0,0 +1,424 @@ +import importlib.util +import json +import sys +from argparse import Namespace +from pathlib import Path +import types + + +ROOT = Path(__file__).resolve().parents[1] +MODULE_PATH = ROOT / "scripts" / "benchmark_end_to_end_paths.py" + + +def _load_module(): + spec = importlib.util.spec_from_file_location("benchmark_end_to_end_paths", MODULE_PATH) + if spec is None or spec.loader is None: + raise RuntimeError(f"Unable to load benchmark module from {MODULE_PATH}") + module = importlib.util.module_from_spec(spec) + sys.modules[spec.name] = module + spec.loader.exec_module(module) + return module + + +def test_extract_batch_timing_metrics_reports_subpath_totals(tmp_path: Path): + module = _load_module() + progress_path = tmp_path / "progress.jsonl" + records = [ + { + "kind": "embedding.batch.timing", + "payload": { + "loader_wait_ms": 10.0, + "ready_wait_ms": 5.0, + "preprocess_ms": 15.0, + "worker_batch_ms": 100.0, + "reader_open_ms": 2.0, + "reader_read_ms": 80.0, + "forward_ms": 70.0, + "gpu_busy_fraction": 0.8, + }, + }, + { + "kind": "embedding.batch.timing", + "payload": { + "loader_wait_ms": 30.0, + "ready_wait_ms": 5.0, + "preprocess_ms": 5.0, + "worker_batch_ms": 110.0, + "reader_open_ms": 1.0, + "reader_read_ms": 90.0, + "forward_ms": 90.0, + "gpu_busy_fraction": 0.9, + }, + }, + ] + progress_path.write_text("".join(json.dumps(record) + "\n" for record in records), encoding="utf-8") + + metrics = module.extract_batch_timing_metrics(progress_path) + + assert metrics["timed_batches"] == 2 + assert metrics["mean_loader_wait_ms"] == 20.0 + assert metrics["mean_forward_ms"] == 80.0 + assert metrics["data_pipeline_seconds"] == 0.07 + assert metrics["forward_seconds"] == 0.16 + assert metrics["accounted_embedding_seconds"] == 0.23 + assert metrics["data_pipeline_fraction"] == 0.3043 + assert metrics["forward_fraction"] == 0.6957 + assert metrics["loader_wait_fraction"] == 0.2174 + + +def test_run_trial_cleans_stale_run_dir_before_execution(monkeypatch, tmp_path: Path): + module = _load_module() + run_dir = tmp_path / "runs" / "tar" / "rep-01" + (run_dir / "output" / "tiles").mkdir(parents=True) + (run_dir / "progress.jsonl").write_text("stale progress\n", encoding="utf-8") + (run_dir / "metrics.json").write_text('{"success": false}\n', encoding="utf-8") + (run_dir / "stale.txt").write_text("old file\n", encoding="utf-8") + + def fake_run_trial_subprocess(*, config_path, metrics_path, progress_path, log_path): + assert not (run_dir / "stale.txt").exists() + assert not progress_path.exists() + metrics_path.write_text( + json.dumps( + { + "slides_total": 1, + "slides_with_tiles": 1, + "failed_slides": 0, + "total_tiles": 100, + "end_to_end_seconds": 2.5, + "tiles_per_second": 40.0, + "tiling_seconds": 0.2, + "embedding_seconds": 2.0, + "timed_batches": 3, + "mean_loader_wait_ms": 1.0, + "max_loader_wait_ms": 4.0, + "mean_ready_wait_ms": 0.1, + "mean_preprocess_ms": 0.2, + "mean_worker_batch_ms": 5.0, + "mean_reader_open_ms": 0.3, + "mean_reader_read_ms": 2.0, + "mean_forward_ms": 3.0, + "data_pipeline_seconds": 0.5, + "forward_seconds": 1.5, + "accounted_embedding_seconds": 2.0, + "data_pipeline_fraction": 0.25, + "forward_fraction": 0.75, + "loader_wait_fraction": 0.2, + "gpu_busy_fraction": 0.8, + } + ) + + "\n", + encoding="utf-8", + ) + log_path.write_text("fresh log\n", encoding="utf-8") + return types.SimpleNamespace(returncode=0) + + monkeypatch.setattr(module, "_run_trial_subprocess", fake_run_trial_subprocess) + monkeypatch.setattr(module, "cleanup_trial_output", lambda output_dir: None) + + row = module.run_trial( + mode="tar", + kind="measure", + repeat_index=1, + run_dir=run_dir, + config={"tiling": {}, "output_dir": str(tmp_path / "unused")}, + ) + + assert row["exit_code"] == 0 + assert row["total_tiles"] == 100 + assert row["mean_forward_ms"] == 3.0 + assert not (run_dir / "stale.txt").exists() + assert (run_dir / "progress.jsonl").exists() is False + + +def test_apply_mode_overrides_sets_tar_wsd_and_cucim_modes(): + module = _load_module() + base = { + "tiling": { + "on_the_fly": True, + "backend": "asap", + "use_supertiles": False, + "adaptive_batching": True, + "jpeg_backend": "pil", + "read_coordinates_from": "/tmp/coords", + "read_tiles_from": "/tmp/tiles", + } + } + + tar_cfg = module._apply_mode_overrides(base, "tar") + wsd_cfg = module._apply_mode_overrides(base, "wsd_single") + cucim_cfg = module._apply_mode_overrides(base, "cucim_supertiles") + + assert tar_cfg["tiling"]["on_the_fly"] is False + assert tar_cfg["tiling"]["backend"] == "cucim" + assert tar_cfg["tiling"]["use_supertiles"] is True + assert tar_cfg["tiling"]["read_coordinates_from"] is None + assert tar_cfg["tiling"]["read_tiles_from"] is None + + assert wsd_cfg["tiling"]["on_the_fly"] is True + assert wsd_cfg["tiling"]["backend"] == "asap" + assert wsd_cfg["tiling"]["use_supertiles"] is False + assert wsd_cfg["tiling"]["adaptive_batching"] is False + assert wsd_cfg["tiling"]["read_coordinates_from"] is None + assert wsd_cfg["tiling"]["read_tiles_from"] is None + + assert cucim_cfg["tiling"]["on_the_fly"] is True + assert cucim_cfg["tiling"]["backend"] == "cucim" + assert cucim_cfg["tiling"]["use_supertiles"] is True + assert cucim_cfg["tiling"]["adaptive_batching"] is False + + +def test_aggregate_trial_results_groups_by_mode(): + module = _load_module() + rows = [ + { + "mode": "tar", + "exit_code": 0, + "total_tiles": 1000, + "end_to_end_seconds": 10.0, + "tiles_per_second": 100.0, + "tiling_seconds": 1.0, + "embedding_seconds": 8.0, + "mean_loader_wait_ms": 1.0, + "mean_forward_ms": 2.0, + "data_pipeline_seconds": 1.5, + "forward_seconds": 8.0, + "accounted_embedding_seconds": 9.5, + "data_pipeline_fraction": 0.1579, + "forward_fraction": 0.8421, + "gpu_busy_fraction": 0.9, + }, + { + "mode": "tar", + "exit_code": 0, + "total_tiles": 1000, + "end_to_end_seconds": 12.0, + "tiles_per_second": 80.0, + "tiling_seconds": 1.5, + "embedding_seconds": 9.0, + "mean_loader_wait_ms": 2.0, + "mean_forward_ms": 3.0, + "data_pipeline_seconds": 2.0, + "forward_seconds": 9.0, + "accounted_embedding_seconds": 11.0, + "data_pipeline_fraction": 0.1818, + "forward_fraction": 0.8182, + "gpu_busy_fraction": 0.8, + }, + { + "mode": "wsd_single", + "exit_code": 0, + "total_tiles": 1000, + "end_to_end_seconds": 11.5, + "tiles_per_second": 87.0, + "tiling_seconds": 0.7, + "embedding_seconds": 10.1, + "mean_loader_wait_ms": 4.0, + "mean_forward_ms": 2.8, + "data_pipeline_seconds": 2.4, + "forward_seconds": 10.1, + "accounted_embedding_seconds": 12.5, + "data_pipeline_fraction": 0.192, + "forward_fraction": 0.808, + "gpu_busy_fraction": 0.88, + }, + { + "mode": "cucim_supertiles", + "exit_code": 0, + "total_tiles": 1000, + "end_to_end_seconds": 9.0, + "tiles_per_second": 111.0, + "tiling_seconds": 0.5, + "embedding_seconds": 8.2, + "mean_loader_wait_ms": 0.2, + "mean_forward_ms": 2.5, + "data_pipeline_seconds": 0.8, + "forward_seconds": 8.2, + "accounted_embedding_seconds": 9.0, + "data_pipeline_fraction": 0.0889, + "forward_fraction": 0.9111, + "gpu_busy_fraction": 0.95, + }, + ] + + aggregated = module.aggregate_trial_results(rows) + + assert [row["mode"] for row in aggregated] == ["tar", "wsd_single", "cucim_supertiles"] + assert aggregated[0]["mean_end_to_end_seconds"] == 11.0 + assert aggregated[0]["mean_tiles_per_second"] == 90.0 + assert aggregated[0]["mean_data_pipeline_seconds"] == 1.75 + assert aggregated[0]["mean_forward_seconds"] == 8.5 + assert aggregated[1]["mean_end_to_end_seconds"] == 11.5 + assert aggregated[2]["mean_end_to_end_seconds"] == 9.0 + + +def test_prepare_chart_outputs_writes_summary_and_plots(monkeypatch, tmp_path: Path): + module = _load_module() + called = {"save_csv": 0, "end_to_end": 0, "stage": 0, "embedding_subpath": 0} + + monkeypatch.setattr(module, "save_csv", lambda rows, path: called.__setitem__("save_csv", called["save_csv"] + 1)) + monkeypatch.setattr(module, "plot_end_to_end_by_path", lambda rows, path: called.__setitem__("end_to_end", called["end_to_end"] + 1)) + monkeypatch.setattr(module, "plot_stage_breakdown", lambda rows, path: called.__setitem__("stage", called["stage"] + 1)) + monkeypatch.setattr( + module, + "plot_embedding_subpath_breakdown", + lambda rows, path: called.__setitem__("embedding_subpath", called["embedding_subpath"] + 1), + ) + + rows = [ + { + "mode": "tar", + "exit_code": 0, + "total_tiles": 1000, + "end_to_end_seconds": 10.0, + "tiles_per_second": 100.0, + "tiling_seconds": 1.0, + "embedding_seconds": 8.0, + "mean_loader_wait_ms": 1.0, + "mean_forward_ms": 2.0, + "data_pipeline_seconds": 1.5, + "forward_seconds": 8.0, + "accounted_embedding_seconds": 9.5, + "data_pipeline_fraction": 0.1579, + "forward_fraction": 0.8421, + "gpu_busy_fraction": 0.9, + }, + { + "mode": "wsd_single", + "exit_code": 0, + "total_tiles": 1000, + "end_to_end_seconds": 11.5, + "tiles_per_second": 87.0, + "tiling_seconds": 0.7, + "embedding_seconds": 10.1, + "mean_loader_wait_ms": 4.0, + "mean_forward_ms": 2.8, + "data_pipeline_seconds": 2.4, + "forward_seconds": 10.1, + "accounted_embedding_seconds": 12.5, + "data_pipeline_fraction": 0.192, + "forward_fraction": 0.808, + "gpu_busy_fraction": 0.88, + }, + { + "mode": "cucim_supertiles", + "exit_code": 0, + "total_tiles": 1000, + "end_to_end_seconds": 9.0, + "tiles_per_second": 111.0, + "tiling_seconds": 0.5, + "embedding_seconds": 8.2, + "mean_loader_wait_ms": 0.2, + "mean_forward_ms": 2.5, + "data_pipeline_seconds": 0.8, + "forward_seconds": 8.2, + "accounted_embedding_seconds": 9.0, + "data_pipeline_fraction": 0.0889, + "forward_fraction": 0.9111, + "gpu_busy_fraction": 0.95, + }, + ] + + summary = module._prepare_chart_outputs(rows, tmp_path) + + assert len(summary) == 3 + assert called == {"save_csv": 1, "end_to_end": 1, "stage": 1, "embedding_subpath": 1} + + +def test_load_cli_merged_config_uses_regular_cli_loader(monkeypatch, tmp_path: Path): + module = _load_module() + config_path = tmp_path / "benchmark.yaml" + config_path.write_text( + "\n".join( + [ + "model:", + ' name: "h0-mini"', + "tiling:", + " params:", + " target_spacing_um: 0.5", + " target_tile_size_px: 224", + "speed: {}", + "wandb:", + " enable: false", + ] + ) + + "\n", + encoding="utf-8", + ) + + captured: dict[str, object] = {} + + def fake_get_cfg_from_args(args): + captured["config_file"] = args.config_file + captured["output_dir"] = args.output_dir + captured["opts"] = list(args.opts) + return { + "model": {"name": "h0-mini"}, + "speed": {"num_workers_embedding": 8}, + "wandb": {"enable": False}, + } + + monkeypatch.setitem( + sys.modules, + "slide2vec.utils.config", + types.SimpleNamespace(get_cfg_from_args=fake_get_cfg_from_args), + ) + + loaded = module._load_cli_merged_config(config_path) + + assert captured == { + "config_file": str(config_path), + "output_dir": None, + "opts": [], + } + assert loaded == { + "model": {"name": "h0-mini"}, + "speed": {"num_workers_embedding": 8}, + "wandb": {"enable": False}, + } + + +def test_run_benchmark_requires_config_file_only_for_normal_cli(monkeypatch, tmp_path: Path): + module = _load_module() + + console_messages: list[str] = [] + + class FakeConsole: + def print(self, message, *args, **kwargs): + console_messages.append(str(message)) + + monkeypatch.setitem(sys.modules, "rich.console", types.SimpleNamespace(Console=lambda: FakeConsole())) + monkeypatch.setitem(sys.modules, "rich.panel", types.SimpleNamespace(Panel=lambda *args, **kwargs: "panel")) + monkeypatch.setitem( + sys.modules, + "rich.progress", + types.SimpleNamespace( + SpinnerColumn=lambda *args, **kwargs: None, + TextColumn=lambda *args, **kwargs: None, + BarColumn=lambda *args, **kwargs: None, + TaskProgressColumn=lambda *args, **kwargs: None, + TimeElapsedColumn=lambda *args, **kwargs: None, + Progress=None, + ), + ) + + args = Namespace( + csv=tmp_path / "slides.csv", + config_file=None, + repeat=1, + warmup=0, + batch_size=32, + num_dataloader_workers=32, + num_cucim_workers=4, + num_preprocessing_workers=8, + output_dir=tmp_path / "out", + chart_only=None, + internal_harness=False, + harness_config=None, + metrics_json=None, + progress_jsonl=None, + ) + + result = module.run_benchmark(args) + + assert result == 1 + assert any("--config-file is required" in message for message in console_messages) diff --git a/tests/test_benchmark_tile_read_strategies.py b/tests/test_benchmark_tile_read_strategies.py new file mode 100644 index 0000000..fb93ae6 --- /dev/null +++ b/tests/test_benchmark_tile_read_strategies.py @@ -0,0 +1,151 @@ +import importlib.util +import sys +from argparse import Namespace +from pathlib import Path + + +ROOT = Path(__file__).resolve().parents[1] +MODULE_PATH = ROOT / "scripts" / "benchmark_tile_read_strategies.py" + + +def _load_module(): + spec = importlib.util.spec_from_file_location("benchmark_tile_read_strategies", MODULE_PATH) + if spec is None or spec.loader is None: + raise RuntimeError(f"Unable to load benchmark module from {MODULE_PATH}") + module = importlib.util.module_from_spec(spec) + sys.modules[spec.name] = module + spec.loader.exec_module(module) + return module + + +def test_resolve_batch_sizes_prefers_sweep_and_deduplicates(): + module = _load_module() + + args = Namespace(batch_size=256, batch_sizes=[32, 64, 32, 128]) + + assert module._resolve_batch_sizes(args) == [32, 64, 128] + + +def test_aggregate_trial_results_groups_by_mode_and_batch_size(): + module = _load_module() + + rows = [ + { + "mode": "tar", + "batch_size": 64, + "exit_code": 0, + "total_tiles": 1000, + "tiles_per_second": 100.0, + "end_to_end_seconds": 10.0, + "mean_loader_wait_ms": 1.0, + "max_loader_wait_ms": 2.0, + "mean_ready_wait_ms": 0.5, + "mean_preprocess_ms": 3.0, + "mean_worker_batch_ms": 4.0, + "mean_reader_open_ms": 0.1, + "mean_reader_read_ms": 2.5, + "mean_forward_ms": 5.0, + "loader_wait_fraction": 0.1, + "gpu_busy_fraction": 0.9, + }, + { + "mode": "tar", + "batch_size": 128, + "exit_code": 0, + "total_tiles": 1000, + "tiles_per_second": 150.0, + "end_to_end_seconds": 8.0, + "mean_loader_wait_ms": 0.5, + "max_loader_wait_ms": 1.0, + "mean_ready_wait_ms": 0.2, + "mean_preprocess_ms": 2.0, + "mean_worker_batch_ms": 3.0, + "mean_reader_open_ms": 0.1, + "mean_reader_read_ms": 2.0, + "mean_forward_ms": 4.0, + "loader_wait_fraction": 0.05, + "gpu_busy_fraction": 0.95, + }, + { + "mode": "cucim_single", + "batch_size": 64, + "exit_code": 0, + "total_tiles": 1000, + "tiles_per_second": 80.0, + "end_to_end_seconds": 12.0, + "mean_loader_wait_ms": 2.0, + "max_loader_wait_ms": 4.0, + "mean_ready_wait_ms": 0.8, + "mean_preprocess_ms": 3.5, + "mean_worker_batch_ms": 7.0, + "mean_reader_open_ms": 0.2, + "mean_reader_read_ms": 6.0, + "mean_forward_ms": 5.5, + "loader_wait_fraction": 0.2, + "gpu_busy_fraction": 0.8, + }, + ] + + aggregated = module.aggregate_trial_results(rows) + + assert [(row["mode"], row["batch_size"]) for row in aggregated] == [ + ("tar", 64), + ("cucim_single", 64), + ("tar", 128), + ] + assert aggregated[0]["mean_tiles_per_second"] == 100.0 + assert aggregated[2]["mean_tiles_per_second"] == 150.0 + + +def test_prepare_chart_outputs_uses_batch_size_plot_for_sweeps(monkeypatch, tmp_path: Path): + module = _load_module() + called = {"save_csv": 0, "strategy": 0, "timing": 0, "batch_curve": 0} + + monkeypatch.setattr(module, "save_csv", lambda rows, path: called.__setitem__("save_csv", called["save_csv"] + 1)) + monkeypatch.setattr(module, "plot_throughput_by_strategy", lambda rows, path: called.__setitem__("strategy", called["strategy"] + 1)) + monkeypatch.setattr(module, "plot_timing_breakdown", lambda rows, path: called.__setitem__("timing", called["timing"] + 1)) + monkeypatch.setattr(module, "plot_throughput_vs_batch_size", lambda rows, path: called.__setitem__("batch_curve", called["batch_curve"] + 1)) + + trial_rows = [ + { + "mode": "tar", + "batch_size": 64, + "exit_code": 0, + "total_tiles": 100, + "tiles_per_second": 100.0, + "end_to_end_seconds": 1.0, + "mean_loader_wait_ms": 1.0, + "max_loader_wait_ms": 1.0, + "mean_ready_wait_ms": 0.1, + "mean_preprocess_ms": 1.0, + "mean_worker_batch_ms": 2.0, + "mean_reader_open_ms": 0.1, + "mean_reader_read_ms": 1.5, + "mean_forward_ms": 3.0, + "loader_wait_fraction": 0.1, + "gpu_busy_fraction": 0.9, + }, + { + "mode": "tar", + "batch_size": 128, + "exit_code": 0, + "total_tiles": 100, + "tiles_per_second": 120.0, + "end_to_end_seconds": 1.0, + "mean_loader_wait_ms": 1.0, + "max_loader_wait_ms": 1.0, + "mean_ready_wait_ms": 0.1, + "mean_preprocess_ms": 1.0, + "mean_worker_batch_ms": 2.0, + "mean_reader_open_ms": 0.1, + "mean_reader_read_ms": 1.5, + "mean_forward_ms": 3.0, + "loader_wait_fraction": 0.1, + "gpu_busy_fraction": 0.9, + }, + ] + + summary_rows = module._prepare_chart_outputs(trial_rows, tmp_path) + + assert len(summary_rows) == 2 + assert called == {"save_csv": 1, "strategy": 2, "timing": 2, "batch_curve": 1} diff --git a/tests/test_dependency_split.py b/tests/test_dependency_split.py index a41ead4..25867ed 100644 --- a/tests/test_dependency_split.py +++ b/tests/test_dependency_split.py @@ -128,10 +128,6 @@ def test_readme_documents_core_and_models_installs(): assert 'pip install "slide2vec[models]"' in readme -def test_tile_dataset_uses_direct_transformers_type_check(): - source = (ROOT / "slide2vec" / "data" / "dataset.py").read_text(encoding="utf-8") - - assert "from transformers.image_processing_utils import BaseImageProcessor" in source - assert "isinstance(self.transforms, BaseImageProcessor)" in source +def test_models_module_imports_transformers(): imported_modules = _top_level_imported_modules(ROOT / "slide2vec" / "models" / "models.py") assert "transformers" in imported_modules diff --git a/tests/test_hs2p_package_cutover.py b/tests/test_hs2p_package_cutover.py index 62fba82..2035ed8 100644 --- a/tests/test_hs2p_package_cutover.py +++ b/tests/test_hs2p_package_cutover.py @@ -96,8 +96,8 @@ def test_load_process_df_requires_hs2p_process_list_columns(tmp_path: Path): process_list = tmp_path / "process_list.csv" process_list.write_text( - "sample_id,image_path,mask_path,tiling_status,num_tiles,tiles_npz_path,tiles_meta_path,error,traceback\n" - "slide-1,/data/slide-1.svs,/data/slide-1-mask.png,success,4,/tmp/slide-1.tiles.npz,/tmp/slide-1.tiles.meta.json,,\n", + "sample_id,image_path,mask_path,tiling_status,num_tiles,coordinates_npz_path,coordinates_meta_path,error,traceback\n" + "slide-1,/data/slide-1.svs,/data/slide-1-mask.png,success,4,/tmp/slide-1.coordinates.npz,/tmp/slide-1.coordinates.meta.json,,\n", encoding="utf-8", ) df = helper.load_process_df( @@ -112,8 +112,9 @@ def test_load_process_df_requires_hs2p_process_list_columns(tmp_path: Path): "spacing_at_level_0", "tiling_status", "num_tiles", - "tiles_npz_path", - "tiles_meta_path", + "coordinates_npz_path", + "coordinates_meta_path", + "tiles_tar_path", "feature_status", "aggregation_status", "error", diff --git a/tests/test_output_consistency.py b/tests/test_output_consistency.py index 63cb6d5..96acbe4 100644 --- a/tests/test_output_consistency.py +++ b/tests/test_output_consistency.py @@ -69,7 +69,7 @@ # -- speed -- SPEED_PARAMS = dict( - fp16=True, # override (default: false) + precision="fp16", # override (default: fp32) num_workers=4, # override (default: 8) num_workers_embedding=4, # override (default: 8) ) @@ -122,7 +122,9 @@ def test_output_consistency(wsi_path, mask_path, tmp_path): "save_previews": False, # override (default: true) "seed": 0, "tiling": { + "read_coordinates_from": None, "read_tiles_from": None, + "on_the_fly": True, "backend": "asap", "params": TILING_PARAMS, "seg_params": TILING_SEG_PARAMS, @@ -149,12 +151,13 @@ def test_output_consistency(wsi_path, mask_path, tmp_path): ) # 4. Assert coordinates match exactly (tiling is deterministic) - gt_coords = np.load(GT_DIR / "test-wsi.tiles.npz", allow_pickle=False) - coords = np.load(tmp_path / "coordinates" / "test-wsi.tiles.npz", allow_pickle=False) + gt_coords = np.load(GT_DIR / "test-wsi.coordinates.npz", allow_pickle=False) + coords = np.load(tmp_path / "tiles" / "test-wsi.coordinates.npz", allow_pickle=False) np.testing.assert_array_equal(coords, gt_coords) - meta = json.loads((tmp_path / "coordinates" / "test-wsi.tiles.meta.json").read_text()) + meta = json.loads((tmp_path / "tiles" / "test-wsi.coordinates.meta.json").read_text()) assert meta["sample_id"] == "test-wsi" + assert meta["backend"] == "asap" assert meta["target_spacing_um"] == pytest.approx(0.5) assert meta["target_tile_size_px"] == 224 diff --git a/tests/test_progress.py b/tests/test_progress.py index 6355f9d..545503c 100644 --- a/tests/test_progress.py +++ b/tests/test_progress.py @@ -1,5 +1,7 @@ import io import json +import sys +import types from contextlib import nullcontext from pathlib import Path from types import SimpleNamespace @@ -24,6 +26,68 @@ def write_log(self, message, *, stream=None): self.log_lines.append(message) +def _install_fake_rich_runtime(monkeypatch): + fake_rich = types.ModuleType("rich") + fake_console = types.ModuleType("rich.console") + fake_progress = types.ModuleType("rich.progress") + + class FakeConsole: + def __init__(self, file=None): + self.file = file + self.is_terminal = True + self.lines = [] + + def print(self, message, **kwargs): + self.lines.append((message, kwargs)) + + class FakeProgress: + def __init__(self, *args, **kwargs): + self.tasks = {} + self.next_task_id = 1 + + def start(self): + return None + + def stop(self): + return None + + def add_task(self, description, total=None, completed=0, visible=True): + task_id = self.next_task_id + self.next_task_id += 1 + self.tasks[task_id] = { + "description": description, + "total": total, + "completed": completed, + "visible": visible, + } + return task_id + + def update(self, task_id, **kwargs): + self.tasks[task_id].update(kwargs) + + def remove_task(self, task_id): + self.tasks.pop(task_id, None) + + def advance(self, task_id, advance=1): + self.tasks[task_id]["completed"] = self.tasks[task_id].get("completed", 0) + advance + + fake_console.Console = FakeConsole + fake_progress.Progress = FakeProgress + fake_progress.BarColumn = lambda *args, **kwargs: None + fake_progress.MofNCompleteColumn = lambda *args, **kwargs: None + fake_progress.SpinnerColumn = lambda *args, **kwargs: None + fake_progress.TaskProgressColumn = lambda *args, **kwargs: None + fake_progress.TextColumn = lambda *args, **kwargs: None + fake_progress.TimeElapsedColumn = lambda *args, **kwargs: None + fake_progress.TimeRemainingColumn = lambda *args, **kwargs: None + fake_rich.console = fake_console + fake_rich.progress = fake_progress + monkeypatch.setitem(sys.modules, "rich", fake_rich) + monkeypatch.setitem(sys.modules, "rich.console", fake_console) + monkeypatch.setitem(sys.modules, "rich.progress", fake_progress) + return FakeConsole, FakeProgress + + def test_cli_main_installs_progress_reporter_only_during_pipeline_run(monkeypatch, tmp_path: Path): import slide2vec.cli as cli import slide2vec.progress as progress @@ -170,6 +234,49 @@ def __call__(self, image): assert all(payload["sample_id"] == "slide-a" for payload in payloads) +def test_run_forward_pass_emits_batch_timing_events(): + torch = pytest.importorskip("torch") + import slide2vec.inference as inference + import slide2vec.progress as progress + + reporter = RecordingReporter() + + class FakeModel: + def __call__(self, image): + batch_size = image.shape[0] + return {"embedding": torch.ones((batch_size, 3), dtype=torch.float32)} + + dataloader = [ + (torch.tensor([0, 1]), torch.ones((2, 3, 4, 4), dtype=torch.float32)), + (torch.tensor([2]), torch.ones((1, 3, 4, 4), dtype=torch.float32)), + ] + loaded = SimpleNamespace(device="cpu", feature_dim=3, model=FakeModel()) + + with progress.activate_progress_reporter(reporter): + inference._run_forward_pass( + dataloader, + loaded, + nullcontext(), + sample_id="slide-a", + total_items=3, + unit_label="tile", + ) + + timing_payloads = [event.payload for event in reporter.events if event.kind == "embedding.batch.timing"] + assert len(timing_payloads) == 2 + assert [payload["batch_size"] for payload in timing_payloads] == [2, 1] + assert all(payload["sample_id"] == "slide-a" for payload in timing_payloads) + assert all(payload["loader_wait_ms"] >= 0.0 for payload in timing_payloads) + assert all(payload["ready_wait_ms"] >= 0.0 for payload in timing_payloads) + assert all(payload["forward_ms"] >= 0.0 for payload in timing_payloads) + assert all(payload["preprocess_ms"] >= 0.0 for payload in timing_payloads) + assert all(payload["worker_batch_ms"] >= 0.0 for payload in timing_payloads) + assert all(payload["reader_open_ms"] >= 0.0 for payload in timing_payloads) + assert all(payload["reader_read_ms"] >= 0.0 for payload in timing_payloads) + assert all(payload["gpu_busy_fraction"] >= 0.0 for payload in timing_payloads) + assert all(payload["gpu_busy_fraction"] <= 1.0 for payload in timing_payloads) + + def test_read_tiling_progress_snapshot_summarizes_process_list(tmp_path: Path): import slide2vec.progress as progress @@ -265,6 +372,103 @@ def wait(self, timeout=None): assert [event.kind for event in reporter.events] == ["embedding.slide.started"] +def test_rich_reporter_collapses_multi_gpu_model_loading_into_one_task(monkeypatch): + import slide2vec.progress as progress + + FakeConsole, _FakeProgress = _install_fake_rich_runtime(monkeypatch) + console = FakeConsole() + reporter = progress.RichCliProgressReporter(console=console) + + reporter.emit(progress.ProgressEvent(kind="model.loading", payload={"model_name": "h0-mini"})) + reporter.emit(progress.ProgressEvent(kind="model.loading", payload={"model_name": "h0-mini"})) + + assert len(reporter.progress.tasks) == 1 + + reporter.emit( + progress.ProgressEvent( + kind="model.ready", + payload={"model_name": "h0-mini", "device": "cuda:0"}, + ) + ) + + assert len(reporter.progress.tasks) == 1 + + reporter.emit( + progress.ProgressEvent( + kind="model.ready", + payload={"model_name": "h0-mini", "device": "cuda:1"}, + ) + ) + + assert reporter.progress.tasks == {} + assert len(console.lines) == 1 + + +def test_jsonl_progress_reporter_tags_worker_events_with_gpu_label(tmp_path: Path): + import slide2vec.progress as progress + + progress_path = tmp_path / "logs" / "worker.progress.jsonl" + reporter = progress.JsonlProgressReporter( + progress_path, + rank=1, + progress_label="cuda:1", + ) + reporter.emit( + progress.ProgressEvent( + kind="embedding.slide.started", + payload={"sample_id": "slide-b", "total_tiles": 8}, + ) + ) + reporter.close() + + events, _offsets = progress.read_progress_events(progress_path) + + assert [event.kind for event in events] == ["embedding.slide.started"] + assert events[0].payload["progress_label"] == "cuda:1" + + +def test_rich_reporter_tracks_multi_gpu_embedding_rows_separately(monkeypatch): + import slide2vec.progress as progress + + FakeConsole, _FakeProgress = _install_fake_rich_runtime(monkeypatch) + reporter = progress.RichCliProgressReporter(console=FakeConsole()) + + reporter.emit( + progress.ProgressEvent( + kind="embedding.slide.started", + payload={"sample_id": "slide-a", "total_tiles": 5, "progress_label": "cuda:0"}, + ) + ) + reporter.emit( + progress.ProgressEvent( + kind="embedding.slide.started", + payload={"sample_id": "slide-b", "total_tiles": 7, "progress_label": "cuda:1"}, + ) + ) + + descriptions = sorted(task["description"] for task in reporter.progress.tasks.values()) + assert descriptions == ["cuda:0: slide-a", "cuda:1: slide-b"] + + reporter.emit( + progress.ProgressEvent( + kind="embedding.tile.progress", + payload={ + "sample_id": "slide-a", + "processed": 3, + "total": 5, + "unit": "tile", + "progress_label": "cuda:0", + }, + ) + ) + + task_by_description = { + task["description"]: task for task in reporter.progress.tasks.values() + } + assert task_by_description["cuda:0: slide-a"]["completed"] == 3 + assert task_by_description["cuda:1: slide-b"]["completed"] == 0 + + def test_progress_aware_log_handler_routes_logs_through_active_reporter(): import logging diff --git a/tests/test_regression_core.py b/tests/test_regression_core.py index b684531..f8e7380 100644 --- a/tests/test_regression_core.py +++ b/tests/test_regression_core.py @@ -28,132 +28,32 @@ def test_resource_loading_uses_packaged_configs(): assert "model" in cfg assert config_resource("preprocessing", "default").name == "default.yaml" -def test_tile_dataset_scales_coordinates_and_returns_transformed_tiles(monkeypatch): - pytest.importorskip("torch") - pytest.importorskip("wholeslidedata") - from slide2vec.data.dataset import TileDataset - - tiling_result = SimpleNamespace( - target_spacing_um=0.5, - target_tile_size_px=4, - read_spacing_um=0.5, - read_tile_size_px=2, - tile_size_lv0=224, - x=np.array([10, 30]), - y=np.array([20, 40]), - ) - - class FakeWholeSlideImage: - constructor_calls = [] - patch_calls = [] - def __init__(self, path, backend): - self.path = path - self.backend = backend - self.spacings = [0.25] - type(self).constructor_calls.append((Path(path), backend)) +def test_packaged_model_presets_align_with_recommended_settings(): + pytest.importorskip("omegaconf") - def get_patch(self, x, y, width, height, spacing, center): - type(self).patch_calls.append((x, y, width, height, spacing, center)) - return np.full((height, width, 3), fill_value=64, dtype=np.uint8) + from slide2vec.utils.config import get_cfg_from_args - monkeypatch.setattr("slide2vec.data.dataset.wsd.WholeSlideImage", FakeWholeSlideImage) + preset_dir = ROOT / "slide2vec" / "configs" / "models" + preset_paths = sorted(path for path in preset_dir.glob("*.yaml") if path.name != "default.yaml") - seen_shapes = [] + for preset_path in preset_paths: + args = SimpleNamespace(config_file=str(preset_path), output_dir=None, opts=[]) + cfg = get_cfg_from_args(args) + assert cfg.model.name - def transform(tile): - arr = np.asarray(tile) - seen_shapes.append(arr.shape) - return arr - dataset = TileDataset( - sample_id="slide-a", - wsi_path=Path("/tmp/slide-a.svs"), - mask_path=None, - tiling_result=tiling_result, - backend="asap", - transforms=transform, - ) +def test_packaged_non_default_model_presets_do_not_contain_comments(): + preset_dir = ROOT / "slide2vec" / "configs" / "models" + preset_paths = sorted(path for path in preset_dir.glob("*.yaml") if path.name != "default.yaml") - np.testing.assert_array_equal(dataset.coordinates, np.array([[10, 20], [30, 40]])) - np.testing.assert_array_equal(dataset.scaled_coordinates, np.array([[5, 10], [15, 20]])) - assert len(dataset) == 2 - - idx, tile = dataset[1] - - assert idx == 1 - assert tile.shape == (4, 4, 3) - assert seen_shapes == [(4, 4, 3)] - assert FakeWholeSlideImage.patch_calls == [(30, 40, 2, 2, 0.5, False)] - assert len(FakeWholeSlideImage.constructor_calls) == 2 - -def test_tile_dataset_requires_coordinate_arrays(): - pytest.importorskip("torch") - pytest.importorskip("wholeslidedata") - from slide2vec.data.dataset import TileDataset - - tiling_result = SimpleNamespace( - target_spacing_um=0.5, - target_tile_size_px=4, - read_spacing_um=0.5, - read_tile_size_px=2, - tile_size_lv0=224, - x=np.array([10]), - y=None, - ) - - with pytest.raises(ValueError, match="Tiling result must expose x/y coordinates"): - TileDataset( - sample_id="slide-a", - wsi_path=Path("/tmp/slide-a.svs"), - mask_path=None, - tiling_result=tiling_result, - backend="asap", - transforms=None, - ) - -def test_tile_dataset_load_coordinates_delegates_to_shared_helpers(monkeypatch): - pytest.importorskip("torch") - pytest.importorskip("wholeslidedata") - from slide2vec.data.dataset import TileDataset - - captured = {} - - def fake_coordinate_arrays(tiling_result): - captured["arrays_arg"] = tiling_result - return np.array([9, 10]), np.array([11, 12]) - - def fake_coordinate_matrix(tiling_result): - captured["matrix_arg"] = tiling_result - return np.array([[9, 11], [10, 12]], dtype=np.int64) - - monkeypatch.setattr("slide2vec.data.dataset.coordinate_arrays", fake_coordinate_arrays) - monkeypatch.setattr("slide2vec.data.dataset.coordinate_matrix", fake_coordinate_matrix) - monkeypatch.setattr(TileDataset, "scale_coordinates", lambda self: np.array([[1, 2], [3, 4]], dtype=np.int64)) - - tiling_result = SimpleNamespace( - target_spacing_um=0.5, - target_tile_size_px=4, - read_spacing_um=0.5, - read_tile_size_px=2, - tile_size_lv0=224, - x=np.array([0]), - y=np.array([1]), - ) - dataset = TileDataset( - sample_id="slide-a", - wsi_path=Path("/tmp/slide-a.svs"), - mask_path=None, - tiling_result=tiling_result, - backend="asap", - transforms=None, - ) - - assert captured["arrays_arg"] is tiling_result - assert captured["matrix_arg"] is tiling_result - np.testing.assert_array_equal(dataset.x, np.array([9, 10])) - np.testing.assert_array_equal(dataset.y, np.array([11, 12])) - np.testing.assert_array_equal(dataset.coordinates, np.array([[9, 11], [10, 12]], dtype=np.int64)) + for preset_path in preset_paths: + lines_with_comments = [ + f"{index}: {line}" + for index, line in enumerate(preset_path.read_text().splitlines(), start=1) + if "#" in line + ] + assert lines_with_comments == [], f"{preset_path} still contains comments: {lines_with_comments}" def test_npz_artifacts_round_trip(tmp_path: Path): features = np.arange(12, dtype=np.float32).reshape(3, 4) @@ -162,7 +62,7 @@ def test_npz_artifacts_round_trip(tmp_path: Path): features, output_dir=tmp_path, output_format="npz", - metadata={"tiles_npz_path": "/tmp/sample-a.tiles.npz"}, + metadata={"coordinates_npz_path": "/tmp/sample-a.coordinates.npz"}, tile_index=np.array([0, 1, 2], dtype=np.int64), ) @@ -172,7 +72,7 @@ def test_npz_artifacts_round_trip(tmp_path: Path): np.testing.assert_array_equal(loaded, features) assert artifact.path == tmp_path / "tile_embeddings" / "sample-a.npz" assert metadata["sample_id"] == "sample-a" - assert metadata["tiles_npz_path"] == "/tmp/sample-a.tiles.npz" + assert metadata["coordinates_npz_path"] == "/tmp/sample-a.coordinates.npz" def test_pt_artifacts_round_trip(tmp_path: Path): torch = pytest.importorskip("torch") @@ -224,6 +124,12 @@ def test_execution_options_validate_num_gpus(): with pytest.raises(ValueError, match="num_gpus"): ExecutionOptions(num_gpus=0) +def test_model_from_pretrained_canonicalizes_conchv15_alias(): + model = Model.from_pretrained("conchv1.5") + + assert model.name == "conchv15" + assert model.level == "tile" + def test_execution_options_defaults_to_all_available_gpus(monkeypatch): import slide2vec.api as api @@ -243,10 +149,13 @@ def test_execution_options_from_config_maps_cli_fields(tmp_path: Path): save_latents=True, ), speed=SimpleNamespace( - fp16=True, + precision="bf16", num_workers=6, num_workers_embedding=2, num_gpus=3, + prefetch_factor_embedding=5, + persistent_workers_embedding=False, + gpu_batch_preprocessing=False, ), ) @@ -257,7 +166,10 @@ def test_execution_options_from_config_maps_cli_fields(tmp_path: Path): assert execution.batch_size == 4 assert execution.num_workers == 2 assert execution.num_gpus == 3 - assert execution.mixed_precision is True + assert execution.precision == "bf16" + assert execution.prefetch_factor == 5 + assert execution.persistent_workers is False + assert execution.gpu_batch_preprocessing is False assert execution.save_tile_embeddings is True assert execution.save_latents is True @@ -273,18 +185,25 @@ def test_execution_options_from_config_defaults_to_all_available_gpus_when_unset save_latents=False, ), speed=SimpleNamespace( - fp16=False, + precision="fp32", num_workers=6, num_workers_embedding=2, num_gpus=None, + prefetch_factor_embedding=3, + persistent_workers_embedding=True, + gpu_batch_preprocessing=True, ), ) execution = api.ExecutionOptions.from_config(cfg) assert execution.num_gpus == 6 + assert execution.precision == "fp32" + assert execution.prefetch_factor == 3 + assert execution.persistent_workers is True + assert execution.gpu_batch_preprocessing is True -def test_execution_options_from_config_disables_mixed_precision_for_cpu_runs(monkeypatch, tmp_path: Path): +def test_execution_options_from_config_forces_fp32_for_cpu_runs(monkeypatch, tmp_path: Path): import slide2vec.api as api monkeypatch.setattr(api, "_default_num_gpus", lambda: 8) @@ -296,16 +215,19 @@ def test_execution_options_from_config_disables_mixed_precision_for_cpu_runs(mon save_latents=False, ), speed=SimpleNamespace( - fp16=True, + precision="bf16", num_workers=4, num_workers_embedding=4, num_gpus=1, + prefetch_factor_embedding=4, + persistent_workers_embedding=True, + gpu_batch_preprocessing=True, ), ) execution = api.ExecutionOptions.from_config(cfg, run_on_cpu=True) - assert execution.mixed_precision is False + assert execution.precision == "fp32" assert execution.num_gpus == 1 def test_preprocessing_with_backend_preserves_other_fields(): @@ -318,6 +240,7 @@ def test_preprocessing_with_backend_preserves_other_fields(): tissue_threshold=0.4, drop_holes=True, use_padding=False, + read_coordinates_from=Path("/tmp/coordinates"), read_tiles_from=Path("/tmp/tiles"), resume=True, segmentation={"downsample": 32}, @@ -333,8 +256,14 @@ def test_preprocessing_with_backend_preserves_other_fields(): assert updated.segmentation == base.segmentation assert updated.filtering == base.filtering assert updated.preview == base.preview + assert updated.read_coordinates_from == base.read_coordinates_from + assert updated.read_tiles_from == base.read_tiles_from assert updated is not base + +def test_preprocessing_config_defaults_backend_to_auto(): + assert PreprocessingConfig().backend == "auto" + def test_execution_options_with_output_dir_preserves_other_fields(tmp_path: Path): base = ExecutionOptions( output_dir=None, @@ -342,7 +271,10 @@ def test_execution_options_with_output_dir_preserves_other_fields(tmp_path: Path batch_size=8, num_workers=3, num_gpus=2, - mixed_precision=True, + precision="bf16", + prefetch_factor=6, + persistent_workers=False, + gpu_batch_preprocessing=False, save_tile_embeddings=True, save_latents=True, ) @@ -354,7 +286,10 @@ def test_execution_options_with_output_dir_preserves_other_fields(tmp_path: Path assert updated.batch_size == base.batch_size assert updated.num_workers == base.num_workers assert updated.num_gpus == base.num_gpus - assert updated.mixed_precision == base.mixed_precision + assert updated.precision == base.precision + assert updated.prefetch_factor == base.prefetch_factor + assert updated.persistent_workers == base.persistent_workers + assert updated.gpu_batch_preprocessing == base.gpu_batch_preprocessing assert updated.save_tile_embeddings == base.save_tile_embeddings assert updated.save_latents == base.save_latents assert updated is not base @@ -376,13 +311,14 @@ def test_cli_build_model_and_pipeline_delegates_to_public_api(monkeypatch, tmp_p input_size=224, patch_size=256, token_size=16, + allow_non_recommended_settings=True, save_tile_embeddings=False, save_latents=False, ), - speed=SimpleNamespace(fp16=False, num_workers=2, num_workers_embedding=3, num_gpus=2), + speed=SimpleNamespace(precision="fp32", num_workers=2, num_workers_embedding=3, num_gpus=2), tiling=SimpleNamespace( backend="asap", - read_tiles_from=None, + read_coordinates_from=None, params=SimpleNamespace( target_spacing_um=0.5, target_tile_size_px=224, @@ -422,17 +358,189 @@ def fake_from_pretrained(*model_args, **model_kwargs): assert returned_cfg is cfg assert captured["model_args"] == ("virchow2",) assert captured["model_kwargs"]["device"] == "cpu" + assert captured["model_kwargs"]["allow_non_recommended_settings"] is True assert captured["preprocessing"].backend == "asap" assert captured["execution"].output_dir == tmp_path assert captured["execution"].num_gpus == 1 -def test_preprocessing_config_from_config_combines_user_facing_preprocessing_fields(): + +def test_get_cfg_from_args_rejects_non_recommended_model_settings_by_default(tmp_path: Path): + pytest.importorskip("omegaconf") + + from slide2vec.utils.config import get_cfg_from_args + + config_path = tmp_path / "config.yaml" + config_path.write_text( + "\n".join( + [ + "csv: /tmp/slides.csv", + "output_dir: output", + "tiling:", + " params:", + " target_spacing_um: 1.0", + " target_tile_size_px: 256", + "model:", + " name: virchow2", + " level: tile", + ] + ) + ) + + args = SimpleNamespace(config_file=str(config_path), output_dir=None, opts=[]) + + with pytest.raises(ValueError, match="allow_non_recommended_settings"): + get_cfg_from_args(args) + + +def test_get_cfg_from_args_warns_when_non_recommended_model_settings_are_allowed( + tmp_path: Path, + caplog: pytest.LogCaptureFixture, +): + pytest.importorskip("omegaconf") + + from slide2vec.utils.config import get_cfg_from_args + + config_path = tmp_path / "config.yaml" + config_path.write_text( + "\n".join( + [ + "csv: /tmp/slides.csv", + "output_dir: output", + "tiling:", + " params:", + " target_spacing_um: 1.0", + " target_tile_size_px: 256", + "model:", + " name: virchow2", + " level: tile", + " allow_non_recommended_settings: true", + ] + ) + ) + + args = SimpleNamespace(config_file=str(config_path), output_dir=None, opts=[]) + + with caplog.at_level("WARNING", logger="slide2vec"): + cfg = get_cfg_from_args(args) + + assert cfg.model.allow_non_recommended_settings is True + assert "virchow2" in caplog.text + assert "recommended" in caplog.text + + +def test_get_cfg_from_args_rejects_non_recommended_model_precision_by_default(tmp_path: Path): + pytest.importorskip("omegaconf") + + from slide2vec.utils.config import get_cfg_from_args + + config_path = tmp_path / "config.yaml" + config_path.write_text( + "\n".join( + [ + "csv: /tmp/slides.csv", + "output_dir: output", + "tiling:", + " params:", + " target_spacing_um: 0.5", + " target_tile_size_px: 224", + "model:", + " name: virchow2", + " level: tile", + "speed:", + " precision: fp32", + ] + ) + ) + + args = SimpleNamespace(config_file=str(config_path), output_dir=None, opts=[]) + + with pytest.raises(ValueError, match="requested precision=fp32"): + get_cfg_from_args(args) + + +def test_get_cfg_from_args_warns_when_non_recommended_model_precision_is_allowed( + tmp_path: Path, + caplog: pytest.LogCaptureFixture, +): + pytest.importorskip("omegaconf") + + from slide2vec.utils.config import get_cfg_from_args + + config_path = tmp_path / "config.yaml" + config_path.write_text( + "\n".join( + [ + "csv: /tmp/slides.csv", + "output_dir: output", + "tiling:", + " params:", + " target_spacing_um: 0.5", + " target_tile_size_px: 224", + "model:", + " name: virchow2", + " level: tile", + " allow_non_recommended_settings: true", + "speed:", + " precision: fp32", + ] + ) + ) + + args = SimpleNamespace(config_file=str(config_path), output_dir=None, opts=[]) + + with caplog.at_level("WARNING", logger="slide2vec"): + cfg = get_cfg_from_args(args) + + assert cfg.speed.precision == "fp32" + assert "requested precision=fp32" in caplog.text + + +def test_get_cfg_from_args_allows_cpu_runs_with_non_recommended_precision(tmp_path: Path): + pytest.importorskip("omegaconf") + + from slide2vec.utils.config import get_cfg_from_args + + config_path = tmp_path / "config.yaml" + config_path.write_text( + "\n".join( + [ + "csv: /tmp/slides.csv", + "output_dir: output", + "tiling:", + " params:", + " target_spacing_um: 0.5", + " target_tile_size_px: 224", + "model:", + " name: prism", + " level: slide", + "speed:", + " precision: fp32", + ] + ) + ) + + args = SimpleNamespace( + config_file=str(config_path), + output_dir=None, + opts=[], + run_on_cpu=True, + ) + + cfg = get_cfg_from_args(args) + + assert cfg.model.name == "prism" + assert cfg.speed.precision == "fp32" + + +def test_preprocessing_config_from_config_defaults_read_coordinates_from_output_dir(): cfg = SimpleNamespace( resume=True, + output_dir="/tmp/run-001", save_previews=False, tiling=SimpleNamespace( backend="asap", - read_tiles_from="/tmp/precomputed", + read_coordinates_from=None, + read_tiles_from=None, params=SimpleNamespace( target_spacing_um=0.5, target_tile_size_px=224, @@ -452,7 +560,9 @@ def test_preprocessing_config_from_config_combines_user_facing_preprocessing_fie assert preprocessing.backend == "asap" assert preprocessing.target_tile_size_px == 224 - assert preprocessing.read_tiles_from == Path("/tmp/precomputed") + assert preprocessing.read_coordinates_from == Path("/tmp/run-001/coordinates") + assert preprocessing.read_tiles_from is None + assert not hasattr(preprocessing, "save_tiles") assert preprocessing.resume is True assert preprocessing.segmentation == {"downsample": 64} assert preprocessing.filtering == {"ref_tile_size": 224} @@ -462,6 +572,69 @@ def test_preprocessing_config_from_config_combines_user_facing_preprocessing_fie "downsample": 32, } + +def test_preprocessing_config_from_config_preserves_tile_store_dir(): + cfg = SimpleNamespace( + output_dir="/tmp/run-002", + resume=False, + save_previews=True, + speed=SimpleNamespace(num_cucim_workers=6), + tiling=SimpleNamespace( + backend="asap", + read_coordinates_from=None, + read_tiles_from="/tmp/tile-store", + params=SimpleNamespace( + target_spacing_um=0.5, + target_tile_size_px=224, + tolerance=0.07, + overlap=0.0, + tissue_threshold=0.1, + drop_holes=False, + use_padding=True, + ), + seg_params={"downsample": 64}, + filter_params={"ref_tile_size": 224}, + preview=SimpleNamespace(downsample=32), + ), + ) + + preprocessing = PreprocessingConfig.from_config(cfg) + + assert preprocessing.read_coordinates_from == Path("/tmp/run-002/coordinates") + assert preprocessing.read_tiles_from == Path("/tmp/tile-store") + assert preprocessing.num_cucim_workers == 6 + assert not hasattr(preprocessing, "save_tiles") + + +def test_preprocessing_config_from_config_falls_back_to_legacy_tiling_num_cucim_workers(): + cfg = SimpleNamespace( + output_dir="/tmp/run-003", + resume=False, + save_previews=False, + tiling=SimpleNamespace( + backend="asap", + num_cucim_workers=5, + read_coordinates_from=None, + read_tiles_from=None, + params=SimpleNamespace( + target_spacing_um=0.5, + target_tile_size_px=224, + tolerance=0.07, + overlap=0.0, + tissue_threshold=0.1, + drop_holes=False, + use_padding=True, + ), + seg_params={"downsample": 64}, + filter_params={"ref_tile_size": 224}, + preview=SimpleNamespace(downsample=32), + ), + ) + + preprocessing = PreprocessingConfig.from_config(cfg) + + assert preprocessing.num_cucim_workers == 5 + def test_validate_removed_options_rejects_legacy_preview_keys(): pytest.importorskip("omegaconf") from omegaconf import OmegaConf diff --git a/tests/test_regression_inference.py b/tests/test_regression_inference.py index 73039f8..2b665ca 100644 --- a/tests/test_regression_inference.py +++ b/tests/test_regression_inference.py @@ -2,6 +2,7 @@ import sys from pathlib import Path from types import SimpleNamespace +import types import numpy as np import pandas as pd @@ -39,6 +40,234 @@ def make_slide( spacing_at_level_0=spacing_at_level_0, ) + +def test_load_model_merges_preprocessing_defaults_for_cross_file_interpolations(monkeypatch): + import slide2vec.inference as inference + + class AttrDict(dict): + def __getattr__(self, name): + try: + return self[name] + except KeyError as exc: + raise AttributeError(name) from exc + + def __setattr__(self, name, value): + self[name] = value + + def convert(value): + if isinstance(value, dict): + return AttrDict({key: convert(item) for key, item in value.items()}) + if isinstance(value, list): + return [convert(item) for item in value] + return value + + def merge_values(left, right): + if isinstance(left, dict) and isinstance(right, dict): + merged = AttrDict({key: convert(value) for key, value in left.items()}) + for key, value in right.items(): + if key in merged: + merged[key] = merge_values(merged[key], value) + else: + merged[key] = convert(value) + return merged + return convert(right) + + def lookup(root, path): + current = root + for segment in path.split("."): + current = current[segment] + return current + + def resolve_value(root, value): + if isinstance(value, dict): + for key, item in list(value.items()): + value[key] = resolve_value(root, item) + return value + if isinstance(value, list): + for index, item in enumerate(list(value)): + value[index] = resolve_value(root, item) + return value + if isinstance(value, str) and value.startswith("${") and value.endswith("}"): + return resolve_value(root, lookup(root, value[2:-1])) + return value + + class FakeOmegaConf: + @staticmethod + def create(value): + return convert(value) + + @staticmethod + def merge(*values): + merged = AttrDict() + for value in values: + merged = merge_values(merged, convert(value)) + return merged + + @staticmethod + def resolve(value): + resolve_value(value, value) + + captured: dict[str, object] = {} + + class FakeBackend: + device = "cpu" + features_dim = 128 + + def to(self, _device): + return self + + def get_transforms(self): + return "TRANSFORMS" + + class FakeModelFactory: + def __init__(self, options): + captured["options"] = options + + def get_model(self): + return FakeBackend() + + def fake_load_config(*parts): + if parts == ("preprocessing", "default"): + return { + "tiling": { + "params": { + "target_tile_size_px": 256, + } + } + } + if parts == ("models", "default"): + return { + "model": { + "mode": "cls", + "input_size": "${tiling.params.target_tile_size_px}", + "patch_size": 256, + "token_size": 16, + "normalize_embeddings": False, + } + } + if parts == ("models", "h0-mini"): + return {"model": {}} + raise AssertionError(parts) + + monkeypatch.setitem(sys.modules, "omegaconf", types.SimpleNamespace(OmegaConf=FakeOmegaConf)) + monkeypatch.setitem(sys.modules, "slide2vec.models", types.SimpleNamespace(ModelFactory=FakeModelFactory)) + monkeypatch.setitem(sys.modules, "slide2vec.resources", types.SimpleNamespace(load_config=fake_load_config)) + monkeypatch.setattr(inference, "_resolve_device", lambda requested, device: device) + + loaded = inference.load_model(name="h0-mini", level="tile") + + assert captured["options"].name == "h0-mini" + assert captured["options"].level == "tile" + assert captured["options"].mode == "cls" + assert captured["options"].input_size == 256 + assert loaded.feature_dim == 128 + + +def test_load_model_uses_conchv15_preset_for_canonicalized_alias(monkeypatch): + import slide2vec.inference as inference + + class AttrDict(dict): + def __getattr__(self, name): + try: + return self[name] + except KeyError as exc: + raise AttributeError(name) from exc + + def __setattr__(self, name, value): + self[name] = value + + def convert(value): + if isinstance(value, dict): + return AttrDict({key: convert(item) for key, item in value.items()}) + if isinstance(value, list): + return [convert(item) for item in value] + return value + + def merge_values(left, right): + if isinstance(left, dict) and isinstance(right, dict): + merged = AttrDict({key: convert(value) for key, value in left.items()}) + for key, value in right.items(): + if key in merged: + merged[key] = merge_values(merged[key], value) + else: + merged[key] = convert(value) + return merged + return convert(right) + + def lookup(root, path): + current = root + for segment in path.split("."): + current = current[segment] + return current + + def resolve_value(root, value): + if isinstance(value, dict): + for key, item in list(value.items()): + value[key] = resolve_value(root, item) + return value + if isinstance(value, list): + for index, item in enumerate(list(value)): + value[index] = resolve_value(root, item) + return value + if isinstance(value, str) and value.startswith("${") and value.endswith("}"): + return resolve_value(root, lookup(root, value[2:-1])) + return value + + class FakeOmegaConf: + @staticmethod + def create(value): + return convert(value) + + @staticmethod + def merge(*values): + merged = AttrDict() + for value in values: + merged = merge_values(merged, convert(value)) + return merged + + @staticmethod + def resolve(value): + resolve_value(value, value) + + captured = {} + + class FakeBackend: + device = "cpu" + features_dim = 768 + + def to(self, _device): + return self + + def get_transforms(self): + return "TRANSFORMS" + + class FakeModelFactory: + def __init__(self, options): + captured["options"] = options + + def get_model(self): + return FakeBackend() + + def fake_load_config(*parts): + if parts == ("preprocessing", "default"): + return {"tiling": {"params": {"target_tile_size_px": 256}}} + if parts == ("models", "default"): + return {"model": {"mode": "cls", "input_size": "${tiling.params.target_tile_size_px}"}} + if parts == ("models", "conchv15"): + return {"model": {"name": "conchv15", "input_size": 448}} + raise AssertionError(parts) + + monkeypatch.setitem(sys.modules, "omegaconf", types.SimpleNamespace(OmegaConf=FakeOmegaConf)) + monkeypatch.setitem(sys.modules, "slide2vec.models", types.SimpleNamespace(ModelFactory=FakeModelFactory)) + monkeypatch.setitem(sys.modules, "slide2vec.resources", types.SimpleNamespace(load_config=fake_load_config)) + monkeypatch.setattr(inference, "_resolve_device", lambda requested, device: device) + + loaded = inference.load_model(name="conchv15", level="tile") + + assert captured["options"].name == "conchv15" + assert captured["options"].input_size == 448 + assert loaded.feature_dim == 768 + def test_pipeline_run_uses_distributed_embedding_path_when_num_gpus_is_greater_than_one( monkeypatch, tmp_path: Path, @@ -316,9 +545,9 @@ def test_run_pipeline_local_branch_persists_completed_slides_before_later_failur ] process_list_path = tmp_path / "process_list.csv" process_list_path.write_text( - "sample_id,image_path,mask_path,spacing_at_level_0,tiling_status,num_tiles,tiles_npz_path,tiles_meta_path,feature_status,error,traceback\n" - "slide-a,/tmp/slide-a.svs,,,success,1,/tmp/slide-a.tiles.npz,/tmp/slide-a.tiles.meta.json,tbp,,\n" - "slide-b,/tmp/slide-b.svs,,,success,1,/tmp/slide-b.tiles.npz,/tmp/slide-b.tiles.meta.json,tbp,,\n", + "sample_id,image_path,mask_path,spacing_at_level_0,tiling_status,num_tiles,coordinates_npz_path,coordinates_meta_path,feature_status,error,traceback\n" + "slide-a,/tmp/slide-a.svs,,,success,1,/tmp/slide-a.coordinates.npz,/tmp/slide-a.coordinates.meta.json,tbp,,\n" + "slide-b,/tmp/slide-b.svs,,,success,1,/tmp/slide-b.coordinates.npz,/tmp/slide-b.coordinates.meta.json,tbp,,\n", encoding="utf-8", ) @@ -367,9 +596,9 @@ def test_run_pipeline_resume_skips_successful_local_embeddings(monkeypatch, tmp_ ] process_list_path = tmp_path / "process_list.csv" process_list_path.write_text( - "sample_id,image_path,mask_path,spacing_at_level_0,tiling_status,num_tiles,tiles_npz_path,tiles_meta_path,feature_status,error,traceback\n" - "slide-a,/tmp/slide-a.svs,,,success,1,/tmp/slide-a.tiles.npz,/tmp/slide-a.tiles.meta.json,success,,\n" - "slide-b,/tmp/slide-b.svs,,,success,1,/tmp/slide-b.tiles.npz,/tmp/slide-b.tiles.meta.json,tbp,,\n", + "sample_id,image_path,mask_path,spacing_at_level_0,tiling_status,num_tiles,coordinates_npz_path,coordinates_meta_path,feature_status,error,traceback\n" + "slide-a,/tmp/slide-a.svs,,,success,1,/tmp/slide-a.coordinates.npz,/tmp/slide-a.coordinates.meta.json,success,,\n" + "slide-b,/tmp/slide-b.svs,,,success,1,/tmp/slide-b.coordinates.npz,/tmp/slide-b.coordinates.meta.json,tbp,,\n", encoding="utf-8", ) write_tile_embeddings( @@ -423,24 +652,24 @@ def test_run_pipeline_local_persists_completed_embeddings_before_later_slide_fai x=np.array([0, 1]), y=np.array([2, 3]), tile_size_lv0=224, - tiles_npz_path=Path("/tmp/slide-a.tiles.npz"), - tiles_meta_path=Path("/tmp/slide-a.tiles.meta.json"), + coordinates_npz_path=Path("/tmp/slide-a.coordinates.npz"), + coordinates_meta_path=Path("/tmp/slide-a.coordinates.meta.json"), ), SimpleNamespace( x=np.array([4, 5]), y=np.array([6, 7]), tile_size_lv0=224, - tiles_npz_path=Path("/tmp/slide-b.tiles.npz"), - tiles_meta_path=Path("/tmp/slide-b.tiles.meta.json"), + coordinates_npz_path=Path("/tmp/slide-b.coordinates.npz"), + coordinates_meta_path=Path("/tmp/slide-b.coordinates.meta.json"), ), ] process_list_path = tmp_path / "process_list.csv" process_list_path.write_text( - "sample_id,image_path,mask_path,spacing_at_level_0,tiling_status,num_tiles,tiles_npz_path,tiles_meta_path,error,traceback\n" + "sample_id,image_path,mask_path,spacing_at_level_0,tiling_status,num_tiles,coordinates_npz_path,coordinates_meta_path,error,traceback\n" "slide-a,/tmp/slide-a.svs,,," # spacing_at_level_0 - "success,2,/tmp/slide-a.tiles.npz,/tmp/slide-a.tiles.meta.json,,\n" + "success,2,/tmp/slide-a.coordinates.npz,/tmp/slide-a.coordinates.meta.json,,\n" "slide-b,/tmp/slide-b.svs,,," - "success,2,/tmp/slide-b.tiles.npz,/tmp/slide-b.tiles.meta.json,,\n", + "success,2,/tmp/slide-b.coordinates.npz,/tmp/slide-b.coordinates.meta.json,,\n", encoding="utf-8", ) @@ -517,13 +746,40 @@ def fake_tile_slides(slides, **kwargs): inference._tile_slides( [slide], - PreprocessingConfig(), + PreprocessingConfig(on_the_fly=False), output_dir=tmp_path, num_workers=0, ) assert captured["slides"][0].spacing_at_level_0 == pytest.approx(0.25) assert captured["kwargs"]["preview"] == "preview" + assert captured["kwargs"]["save_tiles"] is True + + +def test_tile_slides_skips_saving_tiles_when_external_store_is_configured(monkeypatch, tmp_path: Path): + import slide2vec.inference as inference + + captured = {} + + def fake_tile_slides(slides, **kwargs): + captured["slides"] = list(slides) + captured["kwargs"] = kwargs + + monkeypatch.setitem(sys.modules, "hs2p", SimpleNamespace(tile_slides=fake_tile_slides)) + monkeypatch.setattr( + inference, + "_build_hs2p_configs", + lambda preprocessing: ("tiling", "segmentation", "filtering", "preview", None, False), + ) + + inference._tile_slides( + [make_slide("slide-a")], + PreprocessingConfig(read_tiles_from=Path("/tmp/existing-tiles")), + output_dir=tmp_path, + num_workers=0, + ) + + assert captured["kwargs"]["save_tiles"] is False def test_build_hs2p_configs_constructs_preview_config(monkeypatch): @@ -570,7 +826,7 @@ def __init__(self, **kwargs): preview={"save_mask_preview": True, "save_tiling_preview": False, "downsample": 32}, ) - tiling_cfg, segmentation_cfg, filtering_cfg, preview_cfg, read_tiles_from, resume = ( + tiling_cfg, segmentation_cfg, filtering_cfg, preview_cfg, read_coordinates_from, resume = ( inference._build_hs2p_configs(preprocessing) ) @@ -582,7 +838,7 @@ def __init__(self, **kwargs): "save_tiling_preview": False, "downsample": 32, } - assert read_tiles_from is None + assert read_coordinates_from is None assert resume is False @@ -591,8 +847,8 @@ def test_prepare_tiled_slides_records_spacing_at_level_0_in_process_list(monkeyp process_list_path = tmp_path / "process_list.csv" process_list_path.write_text( - "sample_id,image_path,mask_path,tiling_status,num_tiles,tiles_npz_path,tiles_meta_path,error,traceback\n" - "slide-a,/tmp/slide-a.svs,,success,1,/tmp/slide-a.tiles.npz,/tmp/slide-a.tiles.meta.json,,\n", + "sample_id,image_path,mask_path,tiling_status,num_tiles,coordinates_npz_path,coordinates_meta_path,error,traceback\n" + "slide-a,/tmp/slide-a.svs,,success,1,/tmp/slide-a.coordinates.npz,/tmp/slide-a.coordinates.meta.json,,\n", encoding="utf-8", ) @@ -612,13 +868,22 @@ def test_prepare_tiled_slides_records_spacing_at_level_0_in_process_list(monkeyp assert recorded.loc[0, "spacing_at_level_0"] == pytest.approx(0.25) +def test_resolve_slide_backend_uses_tiling_result_backend_for_auto(): + import slide2vec.inference as inference + + assert inference._resolve_slide_backend(PreprocessingConfig(backend="auto"), SimpleNamespace(backend="cucim")) == "cucim" + assert inference._resolve_slide_backend(PreprocessingConfig(backend="auto"), SimpleNamespace(backend="asap")) == "asap" + assert inference._resolve_slide_backend(PreprocessingConfig(backend="auto"), SimpleNamespace()) == "asap" + assert inference._resolve_slide_backend(PreprocessingConfig(backend="cucim"), SimpleNamespace(backend="asap")) == "cucim" + + def test_load_successful_tiled_slides_preserves_spacing_at_level_0(monkeypatch, tmp_path: Path): import slide2vec.inference as inference process_list_path = tmp_path / "process_list.csv" process_list_path.write_text( - "sample_id,image_path,mask_path,spacing_at_level_0,tiling_status,num_tiles,tiles_npz_path,tiles_meta_path,error,traceback\n" - "slide-a,/tmp/slide-a.svs,,0.25,success,1,/tmp/slide-a.tiles.npz,/tmp/slide-a.tiles.meta.json,,\n", + "sample_id,image_path,mask_path,spacing_at_level_0,tiling_status,num_tiles,coordinates_npz_path,coordinates_meta_path,error,traceback\n" + "slide-a,/tmp/slide-a.svs,,0.25,success,1,/tmp/slide-a.coordinates.npz,/tmp/slide-a.coordinates.meta.json,,\n", encoding="utf-8", ) @@ -908,8 +1173,8 @@ def test_direct_embed_slides_allows_no_output_dir_and_optional_persistence(monke x=np.array([0, 1]), y=np.array([2, 3]), tile_size_lv0=224, - tiles_npz_path=Path("/tmp/slide-a.tiles.npz"), - tiles_meta_path=Path("/tmp/slide-a.tiles.meta.json"), + coordinates_npz_path=Path("/tmp/slide-a.coordinates.npz"), + coordinates_meta_path=Path("/tmp/slide-a.coordinates.meta.json"), ) embedded = EmbeddedSlide( sample_id="slide-a", @@ -975,24 +1240,24 @@ def test_direct_embed_slides_persists_completed_embeddings_before_later_slide_fa x=np.array([0, 1]), y=np.array([2, 3]), tile_size_lv0=224, - tiles_npz_path=Path("/tmp/slide-a.tiles.npz"), - tiles_meta_path=Path("/tmp/slide-a.tiles.meta.json"), + coordinates_npz_path=Path("/tmp/slide-a.coordinates.npz"), + coordinates_meta_path=Path("/tmp/slide-a.coordinates.meta.json"), ), SimpleNamespace( x=np.array([4, 5]), y=np.array([6, 7]), tile_size_lv0=224, - tiles_npz_path=Path("/tmp/slide-b.tiles.npz"), - tiles_meta_path=Path("/tmp/slide-b.tiles.meta.json"), + coordinates_npz_path=Path("/tmp/slide-b.coordinates.npz"), + coordinates_meta_path=Path("/tmp/slide-b.coordinates.meta.json"), ), ] process_list_path = tmp_path / "process_list.csv" process_list_path.write_text( - "sample_id,image_path,mask_path,spacing_at_level_0,tiling_status,num_tiles,tiles_npz_path,tiles_meta_path,error,traceback\n" + "sample_id,image_path,mask_path,spacing_at_level_0,tiling_status,num_tiles,coordinates_npz_path,coordinates_meta_path,error,traceback\n" "slide-a,/tmp/slide-a.svs,,," # spacing_at_level_0 - "success,2,/tmp/slide-a.tiles.npz,/tmp/slide-a.tiles.meta.json,,\n" + "success,2,/tmp/slide-a.coordinates.npz,/tmp/slide-a.coordinates.meta.json,,\n" "slide-b,/tmp/slide-b.svs,,," - "success,2,/tmp/slide-b.tiles.npz,/tmp/slide-b.tiles.meta.json,,\n", + "success,2,/tmp/slide-b.coordinates.npz,/tmp/slide-b.coordinates.meta.json,,\n", encoding="utf-8", ) @@ -1050,8 +1315,8 @@ def test_slide_level_pipeline_skips_tile_artifacts_when_save_tile_embeddings_is_ x=np.array([0, 1]), y=np.array([2, 3]), tile_size_lv0=224, - tiles_npz_path=Path("/tmp/slide-a.tiles.npz"), - tiles_meta_path=Path("/tmp/slide-a.tiles.meta.json"), + coordinates_npz_path=Path("/tmp/slide-a.coordinates.npz"), + coordinates_meta_path=Path("/tmp/slide-a.coordinates.meta.json"), ) embedded = EmbeddedSlide( sample_id="slide-a", @@ -1278,3 +1543,830 @@ def __call__(self, image): assert result.shape == (0, 5) assert result.dtype == torch.float32 + + +def test_region_batch_preprocessor_resizes_whole_region_before_unfolding(): + import slide2vec.inference as inference + torch = pytest.importorskip("torch") + + loaded = inference.LoadedModel( + name="region-model", + level="region", + model=SimpleNamespace(tile_size=2), + transforms=SimpleNamespace(transforms=[]), + feature_dim=3, + device=torch.device("cpu"), + ) + tiling_result = SimpleNamespace( + target_tile_size_px=4, + read_tile_size_px=2, + ) + execution = ExecutionOptions(gpu_batch_preprocessing=False) + + preprocess = inference._build_batch_preprocessor( + loaded, + SimpleNamespace(level="region"), + tiling_result, + execution=execution, + ) + + batch = torch.full((1, 3, 2, 2), 255, dtype=torch.uint8) + processed = preprocess(batch) + + assert processed.shape == (1, 4, 3, 2, 2) + assert processed.dtype == torch.float32 + assert torch.allclose(processed, torch.ones_like(processed)) + + +def test_region_batch_preprocessor_unfolds_then_applies_tile_transforms(): + import slide2vec.inference as inference + torch = pytest.importorskip("torch") + + class Resize: + def __init__(self, size): + self.size = size + + loaded = inference.LoadedModel( + name="region-model", + level="region", + model=SimpleNamespace(tile_size=2), + transforms=SimpleNamespace(transforms=[Resize(1)]), + feature_dim=3, + device=torch.device("cpu"), + ) + tiling_result = SimpleNamespace( + target_tile_size_px=4, + read_tile_size_px=4, + ) + execution = ExecutionOptions(gpu_batch_preprocessing=False) + + preprocess = inference._build_batch_preprocessor( + loaded, + SimpleNamespace(level="region"), + tiling_result, + execution=execution, + ) + + quadrant_values = torch.tensor( + [ + [ + [0, 0, 85, 85], + [0, 0, 85, 85], + [170, 170, 255, 255], + [170, 170, 255, 255], + ] + ], + dtype=torch.uint8, + ) + batch = quadrant_values.unsqueeze(0).repeat(1, 3, 1, 1) + processed = preprocess(batch) + + expected = torch.tensor([0.0, 85.0 / 255.0, 170.0 / 255.0, 1.0], dtype=torch.float32) + + assert processed.shape == (1, 4, 3, 1, 1) + assert torch.allclose(processed[0, :, 0, 0, 0], expected, atol=1e-5) + assert torch.allclose(processed[0, :, 1, 0, 0], expected, atol=1e-5) + assert torch.allclose(processed[0, :, 2, 0, 0], expected, atol=1e-5) + + +def test_build_batch_transform_spec_supports_nested_region_unfolding_transform(): + import slide2vec.inference as inference + + class Compose: + def __init__(self, transforms): + self.transforms = transforms + + class RegionUnfolding: + def __init__(self, tile_size): + self.tile_size = tile_size + + class Normalize: + def __init__(self, mean, std): + self.mean = mean + self.std = std + + transforms = Compose( + [ + Compose( + [ + RegionUnfolding(8), + Normalize((0.5, 0.4, 0.3), (0.2, 0.3, 0.4)), + ] + ) + ] + ) + + spec = inference._build_batch_transform_spec(transforms) + + assert spec is not None + assert spec.region_unfold_tile_size == 8 + assert spec.mean == (0.5, 0.4, 0.3) + assert spec.std == (0.2, 0.3, 0.4) + + +def test_region_batch_preprocessor_uses_region_unfolding_from_transform_stack(): + import slide2vec.inference as inference + torch = pytest.importorskip("torch") + + class Compose: + def __init__(self, transforms): + self.transforms = transforms + + class RegionUnfolding: + def __init__(self, tile_size): + self.tile_size = tile_size + + loaded = inference.LoadedModel( + name="region-model", + level="region", + model=SimpleNamespace(tile_size=4), + transforms=Compose([RegionUnfolding(4)]), + feature_dim=3, + device=torch.device("cpu"), + ) + tiling_result = SimpleNamespace( + target_tile_size_px=8, + read_tile_size_px=8, + ) + + preprocess = inference._build_batch_preprocessor( + loaded, + SimpleNamespace(level="region"), + tiling_result, + execution=ExecutionOptions(gpu_batch_preprocessing=False), + ) + + batch = torch.ones((1, 3, 8, 8), dtype=torch.uint8) + processed = preprocess(batch) + + assert processed.shape == (1, 4, 3, 4, 4) + + +def test_region_batch_preprocessor_rejects_mismatched_region_unfolding_tile_size(): + import slide2vec.inference as inference + torch = pytest.importorskip("torch") + + class Compose: + def __init__(self, transforms): + self.transforms = transforms + + class RegionUnfolding: + def __init__(self, tile_size): + self.tile_size = tile_size + + loaded = inference.LoadedModel( + name="region-model", + level="region", + model=SimpleNamespace(tile_size=2), + transforms=Compose([RegionUnfolding(4)]), + feature_dim=3, + device=torch.device("cpu"), + ) + tiling_result = SimpleNamespace( + target_tile_size_px=8, + read_tile_size_px=8, + ) + + preprocess = inference._build_batch_preprocessor( + loaded, + SimpleNamespace(level="region"), + tiling_result, + execution=ExecutionOptions(gpu_batch_preprocessing=False), + ) + + with pytest.raises(ValueError, match="tile_size"): + preprocess(torch.ones((1, 3, 8, 8), dtype=torch.uint8)) + + +def test_serialize_execution_preserves_loader_optimization_fields(): + import slide2vec.inference as inference + + execution = ExecutionOptions( + output_dir=Path("/tmp/output"), + batch_size=64, + num_workers=8, + num_gpus=2, + precision="bf16", + prefetch_factor=7, + persistent_workers=False, + gpu_batch_preprocessing=False, + save_tile_embeddings=True, + save_latents=True, + ) + + payload = inference._serialize_execution(execution) + restored = inference.deserialize_execution(payload) + + assert payload["prefetch_factor"] == 7 + assert payload["persistent_workers"] is False + assert payload["gpu_batch_preprocessing"] is False + assert payload["precision"] == "bf16" + assert restored.prefetch_factor == 7 + assert restored.persistent_workers is False + assert restored.gpu_batch_preprocessing is False + assert restored.precision == "bf16" + + +def test_compute_tile_embeddings_for_slide_uses_batched_loader_knobs(monkeypatch): + import slide2vec.inference as inference + torch = pytest.importorskip("torch") + + captured = {} + + class DummyLoader: + def __init__(self, dataset, **kwargs): + captured["dataset"] = dataset + captured["kwargs"] = kwargs + + def __iter__(self): + yield ( + torch.tensor([0, 1], dtype=torch.long), + torch.zeros((2, 3, 4, 4), dtype=torch.uint8), + ) + + def __len__(self): + return 1 + + class DummyEncoder: + pretrained_cfg = {} + + class DummyModel: + encoder = DummyEncoder() + + def __call__(self, image): + return {"embedding": torch.ones((image.shape[0], 3), dtype=torch.float32, device=image.device)} + + fake_dataset_module = types.SimpleNamespace( + BatchTileCollator=lambda **kwargs: ("collator", kwargs), + TileIndexDataset=lambda tile_indices: list(tile_indices), + ) + fake_data_package = types.ModuleType("slide2vec.data") + fake_data_package.__path__ = [] + fake_data_package.dataset = fake_dataset_module + + monkeypatch.setitem(sys.modules, "slide2vec.data", fake_data_package) + monkeypatch.setitem(sys.modules, "slide2vec.data.dataset", fake_dataset_module) + + monkeypatch.setattr(torch.utils.data, "DataLoader", DummyLoader) + monkeypatch.setattr(inference, "_build_batch_preprocessor", lambda *args, **kwargs: lambda batch: batch.float()) + + loaded = inference.LoadedModel( + name="prov-gigapath", + level="tile", + model=DummyModel(), + transforms=object(), + feature_dim=3, + device=torch.device("cpu"), + ) + slide = make_slide("slide-a") + tiling_result = SimpleNamespace( + x=np.array([0, 10]), + y=np.array([5, 15]), + target_spacing_um=0.5, + target_tile_size_px=4, + read_spacing_um=0.5, + read_tile_size_px=4, + tile_size_lv0=224, + tiles_tar_path=Path("/tmp/slide-a.tiles.tar"), + ) + execution = ExecutionOptions( + batch_size=2, + num_workers=3, + num_gpus=1, + prefetch_factor=9, + persistent_workers=True, + gpu_batch_preprocessing=True, + ) + + result = inference._compute_tile_embeddings_for_slide( + loaded, + SimpleNamespace(level="tile"), + slide, + tiling_result, + preprocessing=PreprocessingConfig(on_the_fly=False), + execution=execution, + ) + + assert result.shape == (2, 3) + assert captured["kwargs"]["num_workers"] == 3 + assert captured["kwargs"]["persistent_workers"] is True + assert captured["kwargs"]["prefetch_factor"] == 9 + assert captured["kwargs"]["collate_fn"] == ( + "collator", + { + "tar_path": Path("/tmp/slide-a.tiles.tar"), + "tiling_result": tiling_result, + }, + ) + + +def test_compute_tile_embeddings_for_slide_prefers_explicit_tile_store_root(monkeypatch): + import slide2vec.inference as inference + torch = pytest.importorskip("torch") + + captured = {} + + class DummyLoader: + def __init__(self, dataset, **kwargs): + captured["dataset"] = dataset + captured["kwargs"] = kwargs + + def __iter__(self): + yield ( + torch.tensor([0], dtype=torch.long), + torch.zeros((1, 3, 4, 4), dtype=torch.uint8), + ) + + def __len__(self): + return 1 + + class DummyEncoder: + pretrained_cfg = {} + + class DummyModel: + encoder = DummyEncoder() + + def __call__(self, image): + return {"embedding": torch.ones((image.shape[0], 3), dtype=torch.float32, device=image.device)} + + fake_dataset_module = types.SimpleNamespace( + BatchTileCollator=lambda **kwargs: ("collator", kwargs), + TileIndexDataset=lambda tile_indices: list(tile_indices), + ) + fake_data_package = types.ModuleType("slide2vec.data") + fake_data_package.__path__ = [] + fake_data_package.dataset = fake_dataset_module + + monkeypatch.setitem(sys.modules, "slide2vec.data", fake_data_package) + monkeypatch.setitem(sys.modules, "slide2vec.data.dataset", fake_dataset_module) + monkeypatch.setattr(torch.utils.data, "DataLoader", DummyLoader) + monkeypatch.setattr(inference, "_build_batch_preprocessor", lambda *args, **kwargs: lambda batch: batch.float()) + + loaded = inference.LoadedModel( + name="prov-gigapath", + level="tile", + model=DummyModel(), + transforms=object(), + feature_dim=3, + device=torch.device("cpu"), + ) + slide = make_slide("slide-a") + tiling_result = SimpleNamespace( + x=np.array([0]), + y=np.array([5]), + target_spacing_um=0.5, + target_tile_size_px=4, + read_spacing_um=0.5, + read_tile_size_px=4, + tile_size_lv0=224, + tiles_tar_path=Path("/tmp/current-run.tiles.tar"), + ) + + result = inference._compute_tile_embeddings_for_slide( + loaded, + SimpleNamespace(level="tile"), + slide, + tiling_result, + preprocessing=PreprocessingConfig(read_tiles_from=Path("/tmp/external-tiles")), + execution=ExecutionOptions(batch_size=1, num_workers=0, num_gpus=1), + ) + + assert result.shape == (1, 3) + assert captured["kwargs"]["collate_fn"] == ( + "collator", + { + "tar_path": Path("/tmp/external-tiles/slide-a.tiles.tar"), + "tiling_result": tiling_result, + }, + ) + + +def test_resolve_on_the_fly_num_workers_caps_to_slurm_allocation(monkeypatch): + import slide2vec.inference as inference + + monkeypatch.setattr(inference.os, "cpu_count", lambda: 96) + monkeypatch.setenv("SLURM_JOB_CPUS_PER_NODE", "32") + monkeypatch.delenv("SLURM_CPUS_PER_TASK", raising=False) + + workers, details = inference._resolve_on_the_fly_num_workers(4) + + assert workers == 8 + assert "cpu_count=96" in details + assert "slurm_cpu_limit=32" in details + assert "num_cucim_workers=4" in details + + +def test_compute_tile_embeddings_for_slide_caps_on_the_fly_workers_to_slurm(monkeypatch): + import slide2vec.inference as inference + torch = pytest.importorskip("torch") + + captured = {} + + class DummyLoader: + def __init__(self, dataset, **kwargs): + captured["dataset"] = dataset + captured["kwargs"] = kwargs + + def __iter__(self): + yield ( + torch.tensor([0, 1], dtype=torch.long), + torch.zeros((2, 3, 4, 4), dtype=torch.uint8), + {"worker_batch_ms": 0.0, "reader_open_ms": 0.0, "reader_read_ms": 0.0}, + ) + + def __len__(self): + return 1 + + class DummyEncoder: + pretrained_cfg = {} + + class DummyModel: + encoder = DummyEncoder() + + def __call__(self, image): + return {"embedding": torch.ones((image.shape[0], 3), dtype=torch.float32, device=image.device)} + + class DummyCollator: + ordered_indices = None + + def __init__(self, **kwargs): + captured["collator_kwargs"] = kwargs + + def __call__(self, batch_indices): + tile_indices = torch.as_tensor(batch_indices, dtype=torch.long) + batch = torch.zeros((len(batch_indices), 3, 4, 4), dtype=torch.uint8) + return tile_indices, batch, {"worker_batch_ms": 0.0, "reader_open_ms": 0.0, "reader_read_ms": 0.0} + + fake_cucim_module = types.SimpleNamespace(OnTheFlyBatchTileCollator=DummyCollator) + monkeypatch.setitem(sys.modules, "slide2vec.data.cucim_tile_reader", fake_cucim_module) + monkeypatch.setattr(torch.utils.data, "DataLoader", DummyLoader) + monkeypatch.setattr(inference, "_build_batch_preprocessor", lambda *args, **kwargs: lambda batch: batch.float()) + monkeypatch.setattr(inference.os, "cpu_count", lambda: 96) + monkeypatch.setenv("SLURM_JOB_CPUS_PER_NODE", "32") + monkeypatch.delenv("SLURM_CPUS_PER_TASK", raising=False) + + loaded = inference.LoadedModel( + name="prov-gigapath", + level="tile", + model=DummyModel(), + transforms=object(), + feature_dim=3, + device=torch.device("cpu"), + ) + slide = make_slide("slide-a") + tiling_result = SimpleNamespace( + x=np.array([0, 10]), + y=np.array([5, 15]), + target_spacing_um=0.5, + target_tile_size_px=4, + read_spacing_um=0.5, + read_tile_size_px=4, + tile_size_lv0=224, + ) + execution = ExecutionOptions( + batch_size=2, + num_workers=99, + num_gpus=1, + prefetch_factor=9, + persistent_workers=True, + gpu_batch_preprocessing=True, + ) + + result = inference._compute_tile_embeddings_for_slide( + loaded, + SimpleNamespace(level="tile"), + slide, + tiling_result, + preprocessing=PreprocessingConfig(on_the_fly=True, backend="cucim", num_cucim_workers=4), + execution=execution, + ) + + assert result.shape == (2, 3) + assert captured["kwargs"]["num_workers"] == 8 + assert captured["kwargs"]["persistent_workers"] is True + assert captured["kwargs"]["prefetch_factor"] == 9 + + +def test_compute_tile_embeddings_for_slide_uses_resolved_cucim_backend_when_auto(monkeypatch): + import slide2vec.inference as inference + torch = pytest.importorskip("torch") + + captured = {} + + class DummyLoader: + def __init__(self, dataset, **kwargs): + captured["kwargs"] = kwargs + + def __iter__(self): + yield ( + torch.tensor([0, 1], dtype=torch.long), + torch.zeros((2, 3, 4, 4), dtype=torch.uint8), + {"worker_batch_ms": 0.0, "reader_open_ms": 0.0, "reader_read_ms": 0.0}, + ) + + def __len__(self): + return 1 + + class DummyEncoder: + pretrained_cfg = {} + + class DummyModel: + encoder = DummyEncoder() + + def __call__(self, image): + return {"embedding": torch.ones((image.shape[0], 3), dtype=torch.float32, device=image.device)} + + class DummyCucimCollator: + ordered_indices = None + + def __init__(self, **kwargs): + captured["cucim_collator_kwargs"] = kwargs + + def __call__(self, batch_indices): + tile_indices = torch.as_tensor(batch_indices, dtype=torch.long) + batch = torch.zeros((len(batch_indices), 3, 4, 4), dtype=torch.uint8) + return tile_indices, batch, {"worker_batch_ms": 0.0, "reader_open_ms": 0.0, "reader_read_ms": 0.0} + + class DummyWSDCollator: + def __init__(self, **kwargs): + raise AssertionError("wsd collator should not be used") + + monkeypatch.setitem(sys.modules, "slide2vec.data.cucim_tile_reader", types.SimpleNamespace(OnTheFlyBatchTileCollator=DummyCucimCollator)) + monkeypatch.setitem(sys.modules, "slide2vec.data.wsd_tile_reader", types.SimpleNamespace(WSDOnTheFlyBatchTileCollator=DummyWSDCollator)) + monkeypatch.setattr(torch.utils.data, "DataLoader", DummyLoader) + monkeypatch.setattr(inference, "_build_batch_preprocessor", lambda *args, **kwargs: lambda batch: batch.float()) + monkeypatch.setattr(inference.os, "cpu_count", lambda: 32) + + loaded = inference.LoadedModel( + name="prov-gigapath", + level="tile", + model=DummyModel(), + transforms=object(), + feature_dim=3, + device=torch.device("cpu"), + ) + + result = inference._compute_tile_embeddings_for_slide( + loaded, + SimpleNamespace(level="tile"), + make_slide("slide-a"), + SimpleNamespace( + x=np.array([0, 10]), + y=np.array([5, 15]), + backend="cucim", + target_spacing_um=0.5, + target_tile_size_px=4, + read_spacing_um=0.5, + read_tile_size_px=4, + tile_size_lv0=224, + ), + preprocessing=PreprocessingConfig(on_the_fly=True, backend="auto", num_cucim_workers=4), + execution=ExecutionOptions(batch_size=2, num_workers=8, num_gpus=1), + ) + + assert result.shape == (2, 3) + assert captured["cucim_collator_kwargs"]["num_cucim_workers"] == 4 + + +def test_compute_tile_embeddings_for_slide_uses_resolved_wsd_backend_when_auto(monkeypatch): + import slide2vec.inference as inference + torch = pytest.importorskip("torch") + + captured = {} + + class DummyLoader: + def __init__(self, dataset, **kwargs): + captured["kwargs"] = kwargs + + def __iter__(self): + yield ( + torch.tensor([0, 1], dtype=torch.long), + torch.zeros((2, 3, 4, 4), dtype=torch.uint8), + {"worker_batch_ms": 0.0, "reader_open_ms": 0.0, "reader_read_ms": 0.0}, + ) + + def __len__(self): + return 1 + + class DummyEncoder: + pretrained_cfg = {} + + class DummyModel: + encoder = DummyEncoder() + + def __call__(self, image): + return {"embedding": torch.ones((image.shape[0], 3), dtype=torch.float32, device=image.device)} + + class DummyCucimCollator: + def __init__(self, **kwargs): + raise AssertionError("cucim collator should not be used") + + class DummyWSDCollator: + ordered_indices = None + + def __init__(self, **kwargs): + captured["wsd_collator_kwargs"] = kwargs + + def __call__(self, batch_indices): + tile_indices = torch.as_tensor(batch_indices, dtype=torch.long) + batch = torch.zeros((len(batch_indices), 3, 4, 4), dtype=torch.uint8) + return tile_indices, batch, {"worker_batch_ms": 0.0, "reader_open_ms": 0.0, "reader_read_ms": 0.0} + + monkeypatch.setitem(sys.modules, "slide2vec.data.cucim_tile_reader", types.SimpleNamespace(OnTheFlyBatchTileCollator=DummyCucimCollator)) + monkeypatch.setitem(sys.modules, "slide2vec.data.wsd_tile_reader", types.SimpleNamespace(WSDOnTheFlyBatchTileCollator=DummyWSDCollator)) + monkeypatch.setattr(torch.utils.data, "DataLoader", DummyLoader) + monkeypatch.setattr(inference, "_build_batch_preprocessor", lambda *args, **kwargs: lambda batch: batch.float()) + monkeypatch.setattr(inference.os, "cpu_count", lambda: 32) + + loaded = inference.LoadedModel( + name="prov-gigapath", + level="tile", + model=DummyModel(), + transforms=object(), + feature_dim=3, + device=torch.device("cpu"), + ) + + result = inference._compute_tile_embeddings_for_slide( + loaded, + SimpleNamespace(level="tile"), + make_slide("slide-a"), + SimpleNamespace( + x=np.array([0, 10]), + y=np.array([5, 15]), + backend="asap", + target_spacing_um=0.5, + target_tile_size_px=4, + read_spacing_um=0.5, + read_tile_size_px=4, + tile_size_lv0=224, + ), + preprocessing=PreprocessingConfig(on_the_fly=True, backend="auto", num_cucim_workers=4), + execution=ExecutionOptions(batch_size=2, num_workers=8, num_gpus=1), + ) + + assert result.shape == (2, 3) + assert captured["wsd_collator_kwargs"]["backend"] == "asap" + + +def test_persist_embedded_slide_records_resolved_backend_when_auto(monkeypatch, tmp_path: Path): + import slide2vec.inference as inference + + embedded = EmbeddedSlide( + sample_id="slide-a", + tile_embeddings=np.zeros((2, 4), dtype=np.float32), + slide_embedding=None, + coordinates=np.array([[0, 2], [1, 3]], dtype=np.int64), + tile_size_lv0=224, + image_path=Path("/tmp/slide-a.svs"), + mask_path=None, + ) + captured = {} + + monkeypatch.setattr( + inference, + "_write_tile_embedding_artifact", + lambda sample_id, features, *, execution, metadata: captured.setdefault("metadata", metadata) or SimpleNamespace(), + ) + + inference._persist_embedded_slide( + SimpleNamespace(name="prov-gigapath", level="tile"), + embedded, + SimpleNamespace( + backend="cucim", + coordinates_npz_path=Path("/tmp/slide-a.coordinates.npz"), + coordinates_meta_path=Path("/tmp/slide-a.coordinates.meta.json"), + tiles_tar_path=Path("/tmp/slide-a.tiles.tar"), + ), + preprocessing=PreprocessingConfig(backend="auto"), + execution=ExecutionOptions(output_dir=tmp_path), + ) + + assert captured["metadata"]["backend"] == "cucim" + + +def test_compute_tile_embeddings_for_slide_requires_current_run_tile_store_without_explicit_override(): + import slide2vec.inference as inference + torch = pytest.importorskip("torch") + + loaded = inference.LoadedModel( + name="prov-gigapath", + level="tile", + model=object(), + transforms=object(), + feature_dim=3, + device=torch.device("cpu"), + ) + + with pytest.raises(ValueError, match="missing tiles_tar_path"): + inference._compute_tile_embeddings_for_slide( + loaded, + SimpleNamespace(level="tile"), + make_slide("slide-a"), + SimpleNamespace( + x=np.array([0]), + y=np.array([1]), + target_spacing_um=0.5, + target_tile_size_px=4, + read_spacing_um=0.5, + read_tile_size_px=4, + tile_size_lv0=224, + tiles_tar_path=None, + ), + preprocessing=PreprocessingConfig(on_the_fly=False), + execution=ExecutionOptions(batch_size=1, num_workers=0, num_gpus=1), + ) + + +def test_compute_tile_embeddings_for_slide_uses_batched_loader_for_region_models(monkeypatch): + import slide2vec.inference as inference + torch = pytest.importorskip("torch") + + captured = {} + + class DummyLoader: + def __init__(self, dataset, **kwargs): + captured["dataset"] = dataset + captured["kwargs"] = kwargs + + def __iter__(self): + yield ( + torch.tensor([0, 1], dtype=torch.long), + torch.full((2, 3, 4, 4), 255, dtype=torch.uint8), + ) + + def __len__(self): + return 1 + + class DummyRegionModel: + tile_size = 2 + + def __call__(self, image): + assert image.ndim == 5 + assert image.shape[1:] == (4, 3, 2, 2) + return {"embedding": torch.ones((image.shape[0], image.shape[1], 3), dtype=torch.float32, device=image.device)} + + fake_dataset_module = types.SimpleNamespace( + BatchTileCollator=lambda **kwargs: ("collator", kwargs), + TileIndexDataset=lambda tile_indices: list(tile_indices), + ) + fake_data_package = types.ModuleType("slide2vec.data") + fake_data_package.__path__ = [] + fake_data_package.dataset = fake_dataset_module + + monkeypatch.setitem(sys.modules, "slide2vec.data", fake_data_package) + monkeypatch.setitem(sys.modules, "slide2vec.data.dataset", fake_dataset_module) + monkeypatch.setattr(torch.utils.data, "DataLoader", DummyLoader) + + class Normalize: + def __init__(self, mean, std): + self.mean = mean + self.std = std + + loaded = inference.LoadedModel( + name="region-model", + level="region", + model=DummyRegionModel(), + transforms=SimpleNamespace(transforms=[Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]), + feature_dim=3, + device=torch.device("cpu"), + ) + slide = make_slide("slide-a") + tiling_result = SimpleNamespace( + x=np.array([0, 10]), + y=np.array([5, 15]), + target_spacing_um=0.5, + target_tile_size_px=4, + read_spacing_um=0.5, + read_tile_size_px=4, + tile_size_lv0=224, + tiles_tar_path=Path("/tmp/slide-a.tiles.tar"), + ) + execution = ExecutionOptions( + batch_size=2, + num_workers=3, + num_gpus=1, + prefetch_factor=9, + persistent_workers=True, + gpu_batch_preprocessing=False, + ) + + result = inference._compute_tile_embeddings_for_slide( + loaded, + SimpleNamespace(level="region"), + slide, + tiling_result, + preprocessing=PreprocessingConfig(on_the_fly=False), + execution=execution, + ) + + assert result.shape == (2, 4, 3) + assert captured["kwargs"]["persistent_workers"] is True + assert captured["kwargs"]["prefetch_factor"] == 9 + assert captured["kwargs"]["collate_fn"] == ( + "collator", + { + "tar_path": Path("/tmp/slide-a.tiles.tar"), + "tiling_result": tiling_result, + }, + ) diff --git a/tests/test_regression_models.py b/tests/test_regression_models.py index d9fecb8..b56baf1 100644 --- a/tests/test_regression_models.py +++ b/tests/test_regression_models.py @@ -117,6 +117,208 @@ def test_build_region_tile_encoder_raises_clear_error_for_missing_dino_arch(): ) ) +@pytest.mark.parametrize("name", ["virchow", "virchow2"]) +def test_build_region_tile_encoder_passes_mode_to_virchow_variants(monkeypatch, name): + pytest.importorskip("timm") + import slide2vec.models.models as models_module + + captured = {} + + def fake_encoder(*, mode): + captured["mode"] = mode + return f"{name}:{mode}" + + monkeypatch.setattr( + models_module, + "Virchow" if name == "virchow" else "Virchow2", + fake_encoder, + ) + + result = models_module._build_region_tile_encoder( + SimpleNamespace( + name=name, + arch=None, + pretrained_weights=None, + input_size=224, + token_size=16, + patch_size=256, + normalize_embeddings=False, + mode="cls", + ) + ) + + assert result == f"{name}:cls" + assert captured["mode"] == "cls" + +def test_build_tile_model_supports_conchv15(monkeypatch): + pytest.importorskip("timm") + import slide2vec.models.models as models_module + + monkeypatch.setattr(models_module, "CONCHv15", lambda: "CONCHV15") + + result = models_module._build_tile_model( + SimpleNamespace( + name="conchv15", + arch=None, + pretrained_weights=None, + input_size=448, + token_size=16, + patch_size=256, + normalize_embeddings=False, + mode="cls", + ) + ) + + assert result == "CONCHV15" + + +def test_build_region_tile_encoder_supports_conchv15(monkeypatch): + pytest.importorskip("timm") + import slide2vec.models.models as models_module + + monkeypatch.setattr(models_module, "CONCHv15", lambda: "CONCHV15") + + result = models_module._build_region_tile_encoder( + SimpleNamespace( + name="conchv15", + arch=None, + pretrained_weights=None, + input_size=448, + token_size=16, + patch_size=256, + normalize_embeddings=False, + mode="cls", + ) + ) + + assert result == "CONCHV15" + + +def test_conchv15_loader_uses_titan_return_conch(monkeypatch): + pytest.importorskip("timm") + import slide2vec.models.models as models_module + + captured = {} + + class FakeEncoder: + def __call__(self, x): + captured["encoder_input"] = x + return "EMBEDDING" + + class FakeTitan: + def return_conch(self): + captured["return_conch"] = True + return FakeEncoder(), "TRANSFORM" + + def fake_from_pretrained(model_name, trust_remote_code): + captured["model_name"] = model_name + captured["trust_remote_code"] = trust_remote_code + return FakeTitan() + + monkeypatch.setattr( + models_module, + "AutoModel", + SimpleNamespace(from_pretrained=fake_from_pretrained), + ) + + model = models_module.CONCHv15() + + assert captured["model_name"] == "MahmoodLab/TITAN" + assert captured["trust_remote_code"] is True + assert captured["return_conch"] is True + assert model.get_transforms() == "TRANSFORM" + assert model.features_dim == 768 + assert model.forward("BATCH") == {"embedding": "EMBEDDING"} + assert captured["encoder_input"] == "BATCH" + + +def test_virchow_defaults_to_concatenated_cls_and_mean_patches(monkeypatch): + torch = pytest.importorskip("torch") + import slide2vec.models.models as models_module + + encoded = torch.tensor( + [[[1.0, 2.0], [10.0, 20.0], [30.0, 40.0]]], + dtype=torch.float32, + ) + + class FakeEncoder: + def __call__(self, x): + return encoded + + monkeypatch.setattr(models_module.Virchow, "build_encoder", lambda self: FakeEncoder()) + + model = models_module.Virchow() + output = model.forward(torch.zeros((1, 3, 224, 224), dtype=torch.float32)) + + assert model.mode == "full" + assert model.features_dim == 2560 + torch.testing.assert_close( + output["embedding"], + torch.tensor([[1.0, 2.0, 20.0, 30.0]], dtype=torch.float32), + ) + + +def test_virchow2_defaults_to_concatenated_cls_and_mean_patches(monkeypatch): + torch = pytest.importorskip("torch") + import slide2vec.models.models as models_module + + encoded = torch.tensor( + [ + [ + [1.0, 2.0], + [100.0, 100.0], + [101.0, 101.0], + [102.0, 102.0], + [103.0, 103.0], + [10.0, 20.0], + [30.0, 40.0], + ] + ], + dtype=torch.float32, + ) + + class FakeEncoder: + def __call__(self, x): + return encoded + + monkeypatch.setattr(models_module.Virchow2, "build_encoder", lambda self: FakeEncoder()) + + model = models_module.Virchow2() + output = model.forward(torch.zeros((1, 3, 224, 224), dtype=torch.float32)) + + assert model.mode == "full" + assert model.features_dim == 2560 + torch.testing.assert_close( + output["embedding"], + torch.tensor([[1.0, 2.0, 20.0, 30.0]], dtype=torch.float32), + ) + + +def test_musk_forward_disables_ms_aug(monkeypatch): + torch = pytest.importorskip("torch") + import slide2vec.models.models as models_module + + captured = {} + + class FakeEncoder: + def __call__(self, **kwargs): + captured.update(kwargs) + return [torch.tensor([[1.0, 2.0]], dtype=torch.float32)] + + monkeypatch.setattr(models_module.MUSK, "build_encoder", lambda self: FakeEncoder()) + + model = models_module.MUSK() + output = model.forward(torch.zeros((1, 3, 384, 384), dtype=torch.float32)) + + assert captured["ms_aug"] is False + assert captured["with_head"] is False + assert captured["out_norm"] is False + assert captured["return_global"] is True + torch.testing.assert_close( + output["embedding"], + torch.tensor([[1.0, 2.0]], dtype=torch.float32), + ) + @pytest.mark.parametrize( ("options", "message"), [ @@ -322,6 +524,112 @@ def fake_embed_slides(model_arg, slides, **kwargs): assert captured["slides"] == ["/tmp/slide-a.svs", "/tmp/slide-b.svs"] assert captured["kwargs"]["execution"].num_gpus == 2 + +def test_model_embed_slides_rejects_non_recommended_preprocessing_by_default(): + model = Model.from_pretrained("virchow2") + + with pytest.raises(ValueError, match="allow_non_recommended_settings"): + model.embed_slides( + [{"sample_id": "slide-a", "image_path": "/tmp/slide-a.svs"}], + preprocessing=PreprocessingConfig(target_spacing_um=1.0, target_tile_size_px=256), + ) + + +def test_model_embed_slides_warns_when_non_recommended_settings_are_allowed( + monkeypatch, + caplog: pytest.LogCaptureFixture, +): + model = Model.from_pretrained("virchow2", allow_non_recommended_settings=True) + expected = [ + EmbeddedSlide( + sample_id="slide-a", + tile_embeddings=np.zeros((1, 2), dtype=np.float32), + slide_embedding=None, + coordinates=np.array([[0, 0]], dtype=np.int64), + tile_size_lv0=224, + image_path=Path("/tmp/slide-a.svs"), + mask_path=None, + ), + ] + + monkeypatch.setattr("slide2vec.inference.embed_slides", lambda *args, **kwargs: expected) + + with caplog.at_level("WARNING", logger="slide2vec"): + result = model.embed_slides( + [{"sample_id": "slide-a", "image_path": "/tmp/slide-a.svs"}], + preprocessing=PreprocessingConfig(target_spacing_um=1.0, target_tile_size_px=256), + ) + + assert result == expected + assert "virchow2" in caplog.text + assert "recommended" in caplog.text + + +def test_model_embed_slides_rejects_non_recommended_precision_by_default(): + model = Model.from_pretrained("virchow2") + + with pytest.raises(ValueError, match="requested precision=fp32"): + model.embed_slides( + [{"sample_id": "slide-a", "image_path": "/tmp/slide-a.svs"}], + preprocessing=PreprocessingConfig(), + execution=ExecutionOptions(precision="fp32"), + ) + + +def test_model_embed_slides_warns_when_non_recommended_precision_is_allowed( + monkeypatch, + caplog: pytest.LogCaptureFixture, +): + model = Model.from_pretrained("virchow2", allow_non_recommended_settings=True) + expected = [ + EmbeddedSlide( + sample_id="slide-a", + tile_embeddings=np.zeros((1, 2), dtype=np.float32), + slide_embedding=None, + coordinates=np.array([[0, 0]], dtype=np.int64), + tile_size_lv0=224, + image_path=Path("/tmp/slide-a.svs"), + mask_path=None, + ), + ] + + monkeypatch.setattr("slide2vec.inference.embed_slides", lambda *args, **kwargs: expected) + + with caplog.at_level("WARNING", logger="slide2vec"): + result = model.embed_slides( + [{"sample_id": "slide-a", "image_path": "/tmp/slide-a.svs"}], + preprocessing=PreprocessingConfig(), + execution=ExecutionOptions(precision="fp32"), + ) + + assert result == expected + assert "requested precision=fp32" in caplog.text + + +def test_model_embed_slides_allows_cpu_execution_with_fp32_precision(monkeypatch): + model = Model.from_pretrained("prism", device="cpu") + expected = [ + EmbeddedSlide( + sample_id="slide-a", + tile_embeddings=np.zeros((1, 2), dtype=np.float32), + slide_embedding=np.zeros((2,), dtype=np.float32), + coordinates=np.array([[0, 0]], dtype=np.int64), + tile_size_lv0=224, + image_path=Path("/tmp/slide-a.svs"), + mask_path=None, + ), + ] + + monkeypatch.setattr("slide2vec.inference.embed_slides", lambda *args, **kwargs: expected) + + result = model.embed_slides( + [{"sample_id": "slide-a", "image_path": "/tmp/slide-a.svs"}], + preprocessing=PreprocessingConfig(), + execution=ExecutionOptions(precision="fp32"), + ) + + assert result == expected + def test_model_embed_tiles_requires_output_dir_at_api_boundary(): model = Model.from_pretrained("virchow2") diff --git a/tests/test_release.py b/tests/test_release.py new file mode 100644 index 0000000..b4caf2a --- /dev/null +++ b/tests/test_release.py @@ -0,0 +1,67 @@ +import release + + +def test_push_branch_and_tag_uses_plain_version_tags(monkeypatch, capsys): + commands: list[str] = [] + + def fake_run(cmd: str, check: bool = True) -> str: + commands.append(cmd) + return "" + + monkeypatch.setattr(release, "run", fake_run) + + release.push_branch_and_tag("release-2.0.3", "2.0.3") + + assert commands == [ + "git push origin release-2.0.3", + "git tag 2.0.3", + "git push origin 2.0.3", + ] + assert "Creating and pushing tag 2.0.3" in capsys.readouterr().out + + +def test_push_tag_and_branch_uses_plain_version_names(monkeypatch, capsys): + commands: list[str] = [] + + def fake_run(cmd: str, check: bool = True) -> str: + commands.append(cmd) + if cmd == "git tag": + return "" + return "" + + monkeypatch.setattr(release, "run", fake_run) + + branch = release.push_tag_and_branch("2.0.3") + + assert branch == "release-2.0.3" + assert commands == [ + "git checkout -b release-2.0.3", + "git push origin release-2.0.3", + "git tag", + "git tag 2.0.3", + "git push origin 2.0.3", + ] + output = capsys.readouterr().out + assert "Creating branch release-2.0.3" in output + assert "Creating tag 2.0.3" in output + + +def test_create_pull_request_and_release_draft_use_plain_versions(monkeypatch, capsys): + commands: list[str] = [] + + def fake_run(cmd: str, check: bool = True) -> str: + commands.append(cmd) + if cmd == "git remote get-url origin": + return "git@github.com:example/slide2vec.git" + return "" + + monkeypatch.setattr(release, "run", fake_run) + + release.create_pull_request("release-2.0.3", "2.0.3") + release.open_release_draft("2.0.3") + + assert commands == [ + 'gh pr create --title "Release 2.0.3" --body "This PR bumps the version to 2.0.3 and tags the release." --base main --head release-2.0.3', + "git remote get-url origin", + ] + assert "releases/new?tag=2.0.3&title=2.0.3" in capsys.readouterr().out diff --git a/tests/test_tile_store.py b/tests/test_tile_store.py new file mode 100644 index 0000000..a374fb3 --- /dev/null +++ b/tests/test_tile_store.py @@ -0,0 +1,74 @@ +"""Tests for TarTileReader — the tar-based batch tile reader.""" + +import io +import tarfile +from pathlib import Path + +import numpy as np +import torch +from PIL import Image + +from slide2vec.data.tile_store import TarTileReader # noqa: direct import avoids cucim dep in __init__ + + +def _create_test_tar(tar_path: Path, colors: list[tuple[int, int, int]], tile_size: int = 64): + """Create a tar with solid-color JPEG tiles.""" + with tarfile.open(tar_path, "w") as tf: + for i, color in enumerate(colors): + img = Image.new("RGB", (tile_size, tile_size), color) + buf = io.BytesIO() + img.save(buf, format="JPEG", quality=95) + buf.seek(0) + info = tarfile.TarInfo(name=f"{i:06d}.jpg") + info.size = buf.getbuffer().nbytes + tf.addfile(info, buf) + + +class TestTarTileReader: + def test_read_batch_returns_correct_shape(self, tmp_path: Path): + colors = [(255, 0, 0), (0, 255, 0), (0, 0, 255)] + tar_path = tmp_path / "tiles.tar" + _create_test_tar(tar_path, colors, tile_size=64) + + reader = TarTileReader(tar_path, tile_size_px=64) + batch = reader.read_batch(np.array([0, 1, 2], dtype=np.int64)) + + assert batch.shape == (3, 3, 64, 64) + assert batch.dtype == torch.uint8 + + def test_read_batch_pixel_values_within_jpeg_tolerance(self, tmp_path: Path): + tar_path = tmp_path / "tiles.tar" + _create_test_tar(tar_path, [(200, 100, 50)], tile_size=32) + + reader = TarTileReader(tar_path, tile_size_px=32) + batch = reader.read_batch(np.array([0], dtype=np.int64)) + + # JPEG is lossy — check within tolerance + r_mean = batch[0, 0].float().mean().item() + g_mean = batch[0, 1].float().mean().item() + b_mean = batch[0, 2].float().mean().item() + assert abs(r_mean - 200) < 5 + assert abs(g_mean - 100) < 5 + assert abs(b_mean - 50) < 5 + + def test_read_batch_subset_indices(self, tmp_path: Path): + colors = [(255, 0, 0), (0, 255, 0), (0, 0, 255)] + tar_path = tmp_path / "tiles.tar" + _create_test_tar(tar_path, colors, tile_size=32) + + reader = TarTileReader(tar_path, tile_size_px=32) + batch = reader.read_batch(np.array([2], dtype=np.int64)) + + assert batch.shape == (1, 3, 32, 32) + # Blue tile: channel 2 should dominate + assert batch[0, 2].float().mean() > batch[0, 0].float().mean() + + def test_read_batch_empty_indices(self, tmp_path: Path): + tar_path = tmp_path / "tiles.tar" + _create_test_tar(tar_path, [(128, 128, 128)], tile_size=16) + + reader = TarTileReader(tar_path, tile_size_px=16) + batch = reader.read_batch(np.array([], dtype=np.int64)) + + assert batch.shape == (0, 3, 16, 16) + assert batch.dtype == torch.uint8