+          file: Dockerfile
+          push: false
+          load: true
+          tags: slide2vec:${{ github.sha }}
+          cache-from: type=gha
+          cache-to: type=gha,mode=max
+
+      - name: Guard required secret (if needed)
+        if: ${{ github.event_name != 'pull_request' }} # or drop the condition if PRs from same repo need it
+        run: |
+          set -euo pipefail
+          test -n "${HF_TOKEN:-}" || { echo "HF_TOKEN is required but not set"; exit 1; }
+
+      - name: Generate outputs in container
+        run: |
+          set -euo pipefail
+          docker run --rm \
+            -e HF_TOKEN="$HF_TOKEN" \
+            -v "$GITHUB_WORKSPACE/test/input:/input" \
+            -v "$GITHUB_WORKSPACE/test/output:/output" \
+            slide2vec:${{ github.sha }} \
+            python slide2vec/main.py \
+            --config-file /input/config.yaml \
+            --skip-datetime \
+            --run-on-cpu
+
+      - name: Verify output consistency (inside container)
+        run: |
+          set -euo pipefail
+          docker run --rm \
+            -v "$GITHUB_WORKSPACE/test/gt:/gt" \
+            -v "$GITHUB_WORKSPACE/test/output:/output" \
+            slide2vec:${{ github.sha }} \
+            bash -lc "python - <<'PY'
+          import numpy as np, torch
+          from numpy.testing import assert_array_equal
+
+          # coordinates must match exactly (deterministic tiling)
+          gt_coordinates = np.load('/gt/test-wsi.npy')
+          coordinates = np.load('/output/coordinates/test-wsi.npy')
+          assert_array_equal(coordinates, gt_coordinates, err_msg=f'Coordinates mismatch: {coordinates} vs {gt_coordinates}')
+
+          # embeddings: allow tiny numeric drift
+          gt = torch.load('/gt/test-wsi.pt', map_location='cpu')
+          emb = torch.load('/output/features/test-wsi.pt', map_location='cpu')
+          assert emb.shape == gt.shape, f'Shape mismatch: {emb.shape} vs {gt.shape}'
+
+          cos = torch.nn.functional.cosine_similarity(emb, gt, dim=-1)
+          mean_cos = float(cos.mean())
+          atol, rtol = 1e-2, 1e-3
+          if not torch.allclose(emb, gt, atol=atol, rtol=rtol):
+              if mean_cos < 0.99:
+                  raise AssertionError(f'Embedding mismatch: mean cosine similarity={mean_cos:.4f} (atol={atol}, rtol={rtol})')
+              else:
+                  print(f'WARNING: embeddings not allclose, but mean cosine similarity={mean_cos:.4f} (atol={atol}, rtol={rtol})')
+          else:
+              print(f'OK: embeddings within tolerance; mean cosine similarity={mean_cos:.4f}')
+          PY"
diff --git a/slide2vec/aggregate.py b/slide2vec/aggregate.py
index 47e4e17..55d4743 100644
--- a/slide2vec/aggregate.py
+++ b/slide2vec/aggregate.py
@@ -33,6 +33,9 @@ def get_args_parser(add_help: bool = True):
         default="",
         help="Name of output subdirectory",
     )
+    parser.add_argument(
+        "--run-on-cpu", action="store_true", help="run inference on cpu"
+    )
     return parser
@@ -49,6 +52,7 @@ def scale_coordinates(wsi_fp, coordinates, spacing, backend):
 def main(args):
     # setup configuration
+    run_on_cpu = args.run_on_cpu
     cfg = get_cfg_from_file(args.config_file)
     output_dir = Path(cfg.output_dir, args.run_id)
     cfg.output_dir = str(output_dir)
@@ -87,7 +91,7 @@ def main(args):
     autocast_context = (
         torch.autocast(device_type="cuda", dtype=torch.float16)
-        if cfg.speed.fp16
+        if (cfg.speed.fp16 and not run_on_cpu)
         else nullcontext()
     )
     feature_aggregation_updates = {}
@@ -136,7 +140,8 @@ def main(args):
             torch.save(wsi_feature, feature_path)
             del wsi_feature
-            torch.cuda.empty_cache()
+            if not run_on_cpu:
+                torch.cuda.empty_cache()
             gc.collect()
             feature_aggregation_updates[str(wsi_fp)] = {"status": "success"}
diff --git a/slide2vec/embed.py b/slide2vec/embed.py
index f2bb3af..bc1e5aa 100644
--- a/slide2vec/embed.py
+++ b/slide2vec/embed.py
@@ -33,6 +33,9 @@ def get_args_parser(add_help: bool = True):
         default="",
         help="Name of output subdirectory",
     )
+    parser.add_argument(
+        "--run-on-cpu", action="store_true", help="run inference on cpu"
+    )
     return parser
@@ -61,14 +64,15 @@ def create_dataset(wsi_fp, coordinates_dir, spacing, backend, transforms):
     )
-def run_inference(dataloader, model, device, autocast_context, unit, batch_size, feature_path, feature_dim, dtype):
+def run_inference(dataloader, model, device, autocast_context, unit, batch_size, feature_path, feature_dim, dtype, run_on_cpu: bool = False):
+    device_name = f"GPU {distributed.get_global_rank()}" if not run_on_cpu else "CPU"
     with h5py.File(feature_path, "w") as f:
         features = f.create_dataset("features", shape=(0, *feature_dim), maxshape=(None, *feature_dim), dtype=dtype, chunks=(batch_size, *feature_dim))
         indices = f.create_dataset("indices", shape=(0,), maxshape=(None,), dtype='int64', chunks=(batch_size,))
         with torch.inference_mode(), autocast_context:
             for batch in tqdm.tqdm(
                 dataloader,
-                desc=f"Inference on GPU {distributed.get_global_rank()}",
+                desc=f"Inference on {device_name}",
                 unit=unit,
                 unit_scale=batch_size,
                 leave=False,
@@ -86,7 +90,8 @@ def run_inference(dataloader, model, device, autocast_context, unit, batch_size,
             del image, feature
     # cleanup
-    torch.cuda.empty_cache()
+    if not run_on_cpu:
+        torch.cuda.empty_cache()
     gc.collect()
@@ -116,11 +121,13 @@ def load_sort_and_deduplicate_features(tmp_dir, name, expected_len=None):
 def main(args):
     # setup configuration
+    run_on_cpu = args.run_on_cpu
     cfg = get_cfg_from_file(args.config_file)
     output_dir = Path(cfg.output_dir, args.run_id)
     cfg.output_dir = str(output_dir)
-    setup_distributed()
+    if not run_on_cpu:
+        setup_distributed()
     if cfg.tiling.read_coordinates_from:
         coordinates_dir = Path(cfg.tiling.read_coordinates_from)
@@ -155,7 +162,8 @@ def main(args):
     model = ModelFactory(cfg.model).get_model()
     if distributed.is_main_process():
         print(f"Starting {unit}-level feature extraction...")
-    torch.distributed.barrier()
+    if not run_on_cpu:
+        torch.distributed.barrier()
     # select slides that were successfully tiled but not yet processed for feature extraction
     tiled_df = process_df[process_df.tiling_status == "success"]
@@ -174,7 +182,7 @@ def main(args):
     autocast_context = (
         torch.autocast(device_type="cuda", dtype=torch.float16)
-        if cfg.speed.fp16
+        if (cfg.speed.fp16 and not run_on_cpu)
         else nullcontext()
     )
     feature_extraction_updates = {}
@@ -231,9 +239,11 @@ def main(args):
                 tmp_feature_path,
                 feature_dim,
                 dtype,
+                run_on_cpu,
             )
-            torch.distributed.barrier()
+            if not run_on_cpu:
+                torch.distributed.barrier()
             if distributed.is_main_process():
                 wsi_feature = load_sort_and_deduplicate_features(tmp_dir, name, expected_len=len(dataset))
@@ -241,10 +251,12 @@ def main(args):
                 # cleanup
                 del wsi_feature
-                torch.cuda.empty_cache()
+                if not run_on_cpu:
+                    torch.cuda.empty_cache()
                 gc.collect()
-            torch.distributed.barrier()
+            if not run_on_cpu:
+                torch.distributed.barrier()
             feature_extraction_updates[str(wsi_fp)] = {"status": "success"}
diff --git a/slide2vec/main.py b/slide2vec/main.py
index ba6bfb4..c7925ed 100644
--- a/slide2vec/main.py
+++ b/slide2vec/main.py
@@ -21,6 +21,9 @@ def get_args_parser(add_help: bool = True):
     parser.add_argument(
         "--skip-datetime", action="store_true", help="skip run id datetime prefix"
     )
+    parser.add_argument(
+        "--run-on-cpu", action="store_true", help="run inference on cpu"
+    )
     return parser
@@ -62,7 +65,7 @@ def run_tiling(config_file, run_id):
         sys.exit(proc.returncode)
-def run_feature_extraction(config_file, run_id):
+def run_feature_extraction(config_file, run_id, run_on_cpu: bool = False):
     print("Running embed.py...")
     # find a free port
     with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
@@ -80,6 +83,16 @@ def run_feature_extraction(config_file, run_id):
         "--config-file",
         config_file,
     ]
+    if run_on_cpu:
+        cmd = [
+            sys.executable,
+            "slide2vec/embed.py",
+            "--run-id",
+            run_id,
+            "--config-file",
+            config_file,
+            "--run-on-cpu",
+        ]
     # launch in its own process group.
     proc = subprocess.Popen(
         cmd,
@@ -106,7 +119,7 @@ def run_feature_extraction(config_file, run_id):
         sys.exit(proc.returncode)
-def run_feature_aggregation(config_file, run_id):
+def run_feature_aggregation(config_file, run_id, run_on_cpu: bool = False):
     print("Running aggregate.py...")
     # find a free port
     cmd = [
@@ -117,6 +130,8 @@ def run_feature_aggregation(config_file, run_id):
         "--config-file",
         config_file,
     ]
+    if run_on_cpu:
+        cmd.append("--run-on-cpu")
     # launch in its own process group.
proc = subprocess.Popen( cmd, @@ -134,7 +149,7 @@ def run_feature_aggregation(config_file, run_id): sys.stdout.flush() proc.wait() except KeyboardInterrupt: - print("Received CTRL+C, terminating embed.py process group...") + print("Received CTRL+C, terminating aggregate.py process group...") os.killpg(os.getpgid(proc.pid), signal.SIGTERM) proc.wait() sys.exit(1) @@ -146,18 +161,17 @@ def run_feature_aggregation(config_file, run_id): def main(args): config_file = args.config_file skip_datetime = args.skip_datetime + run_on_cpu = args.run_on_cpu - cfg = setup(config_file, skip_datetime=skip_datetime) + cfg, run_id = setup(config_file, skip_datetime=skip_datetime) hf_login() - output_dir = Path(cfg.output_dir) - run_id = output_dir.stem - run_tiling(config_file, run_id) print("Tiling completed.") print("=+=" * 10) + output_dir = Path(cfg.output_dir) features_dir = output_dir / "features" if cfg.wandb.enable: stop_event = threading.Event() @@ -166,14 +180,16 @@ def main(args): ) log_thread.start() - run_feature_extraction(config_file, run_id) - print("Feature extraction completed.") - print("=+=" * 10) + run_feature_extraction(config_file, run_id, run_on_cpu) if cfg.model.level == "slide": - run_feature_aggregation(config_file, run_id) - print("Feature aggregation completed.") + run_feature_aggregation(config_file, run_id, run_on_cpu) + print("Feature extraction completed.") print("=+=" * 10) + else: + print("Feature extraction completed.") + print("=+=" * 10) + if cfg.wandb.enable: stop_event.set() @@ -184,16 +200,21 @@ def main(args): if __name__ == "__main__": - import warnings + import warnings import torchvision + torchvision.disable_beta_transforms_warning() - + warnings.filterwarnings("ignore", message=".*Could not set the permissions.*") warnings.filterwarnings("ignore", message=".*antialias.*", category=UserWarning) warnings.filterwarnings("ignore", message=".*TypedStorage.*", category=UserWarning) warnings.filterwarnings("ignore", category=FutureWarning) 
warnings.filterwarnings("ignore", message="The given NumPy array is not writable") + warnings.filterwarnings( + "ignore", + message=".*'frozen' attribute with value True was provided to the `Field`.*" + ) args = get_args_parser(add_help=True).parse_args() main(args) diff --git a/slide2vec/models/models.py b/slide2vec/models/models.py index a61e4c1..2856f19 100644 --- a/slide2vec/models/models.py +++ b/slide2vec/models/models.py @@ -578,5 +578,5 @@ def build_encoders(self): def forward_slide(self, tile_features, **kwargs): tile_features = tile_features.unsqueeze(0) reprs = self.slide_encoder.slide_representations(tile_features) - output = reprs["image_embedding"] # [1, 1280] + output = reprs["image_embedding"].squeeze(0) # [1280] return output diff --git a/slide2vec/utils/config.py b/slide2vec/utils/config.py index d58d66a..4ddeea9 100644 --- a/slide2vec/utils/config.py +++ b/slide2vec/utils/config.py @@ -58,7 +58,7 @@ def setup(config_file, skip_datetime: bool = False): output_dir = Path(cfg.output_dir, run_id) if distributed.is_main_process(): - output_dir.mkdir(exist_ok=cfg.resume, parents=True) + output_dir.mkdir(exist_ok=cfg.resume or skip_datetime, parents=True) cfg.output_dir = str(output_dir) fix_random_seeds(0) @@ -67,7 +67,7 @@ def setup(config_file, skip_datetime: bool = False): cfg_path = write_config(cfg, cfg.output_dir) if cfg.wandb.enable: wandb_run.save(cfg_path) - return cfg + return cfg, run_id def setup_distributed(): diff --git a/test/gt/mask-visu.jpg b/test/gt/mask-visu.jpg new file mode 100644 index 0000000..120d545 Binary files /dev/null and b/test/gt/mask-visu.jpg differ diff --git a/test/gt/test-wsi.npy b/test/gt/test-wsi.npy new file mode 100644 index 0000000..c68ca54 Binary files /dev/null and b/test/gt/test-wsi.npy differ diff --git a/test/gt/test-wsi.pt b/test/gt/test-wsi.pt new file mode 100644 index 0000000..e50b8e1 Binary files /dev/null and b/test/gt/test-wsi.pt differ diff --git a/test/gt/tiling-visu.jpg b/test/gt/tiling-visu.jpg 
new file mode 100644 index 0000000..ef5a989 Binary files /dev/null and b/test/gt/tiling-visu.jpg differ diff --git a/test/input/config.yaml b/test/input/config.yaml new file mode 100644 index 0000000..80783f9 --- /dev/null +++ b/test/input/config.yaml @@ -0,0 +1,26 @@ +csv: "/input/test.csv" + +output_dir: "/output" +visualize: false + +tiling: + params: + spacing: 0.5 # spacing at which to tile the slide, in microns per pixel + tolerance: 0.07 # tolerance for matching the spacing (float between 0 and 1, deciding how much the spacing can deviate from the one specified in the slide metadata) + tile_size: 224 # size of the tiles to extract, in pixels + min_tissue_percentage: 0.1 # threshold used to filter out tiles that have less tissue than this value (percentage) + filter_params: + ref_tile_size: 224 + +model: + level: "slide" + name: "prism" + batch_size: 8 + +speed: + fp16: true + num_workers_tiling: 4 + num_workers_embedding: 4 + +wandb: + enable: false \ No newline at end of file diff --git a/test/input/test-mask.tif b/test/input/test-mask.tif new file mode 100644 index 0000000..ec8682a Binary files /dev/null and b/test/input/test-mask.tif differ diff --git a/test/input/test-wsi.tif b/test/input/test-wsi.tif new file mode 100644 index 0000000..f7d4039 Binary files /dev/null and b/test/input/test-wsi.tif differ diff --git a/test/input/test.csv b/test/input/test.csv new file mode 100644 index 0000000..80c7ed2 --- /dev/null +++ b/test/input/test.csv @@ -0,0 +1,2 @@ +wsi_path,mask_path +/input/test-wsi.tif,/input/test-mask.tif \ No newline at end of file