Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .github/workflows/lint.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -35,4 +35,4 @@ jobs:
- name: Run pylint
if: always()
run: |
pylint --exit-zero slide2vec
pylint --exit-zero slide2vec
109 changes: 109 additions & 0 deletions .github/workflows/pr-test.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,109 @@
name: Test WSI to embedding consistency

on:
  pull_request:
    types: [opened, synchronize, reopened]
  workflow_dispatch:

jobs:
  docker-test:
    runs-on: ubuntu-latest
    timeout-minutes: 60
    permissions:
      contents: read
      actions: write # needed for Buildx GHA cache

    # don't run secret-using job on forked PRs
    if: ${{ github.event_name != 'pull_request' || github.event.pull_request.head.repo.full_name == github.repository }}

    env:
      HF_TOKEN: ${{ secrets.HF_TOKEN }}

    steps:
      - name: Check out repository
        uses: actions/checkout@v4

      - name: Verify required folders exist
        run: |
          set -euo pipefail
          test -d test/input
          test -d test/gt
          mkdir -p test/output # ensure host-mapped output exists

      - name: Set up Docker Buildx
        uses: docker/setup-buildx-action@v3

      - name: Free disk space on runner
        run: |
          set -euxo pipefail
          df -h
          sudo rm -rf /usr/local/lib/android || true
          sudo rm -rf /usr/share/dotnet || true
          sudo rm -rf /opt/ghc || true
          sudo rm -rf "${AGENT_TOOLSDIRECTORY:-/opt/hostedtoolcache}" || true
          docker system prune -af || true
          sudo apt-get clean
          df -h

      - name: Build image
        uses: docker/build-push-action@v6
        with:
          context: .
          file: Dockerfile
          push: false
          load: true
          tags: slide2vec:${{ github.sha }}
          cache-from: type=gha
          cache-to: type=gha,mode=max

      - name: Guard required secret (if needed)
        if: ${{ github.event_name != 'pull_request' }} # or drop the condition if PRs from same repo need it
        run: |
          set -euo pipefail
          test -n "${HF_TOKEN:-}" || { echo "HF_TOKEN is required but not set"; exit 1; }

      - name: Generate outputs in container
        run: |
          set -euo pipefail
          docker run --rm \
            -e HF_TOKEN="$HF_TOKEN" \
            -v "$GITHUB_WORKSPACE/test/input:/input" \
            -v "$GITHUB_WORKSPACE/test/output:/output" \
            slide2vec:${{ github.sha }} \
            python slide2vec/main.py \
              --config-file /input/config.yaml \
              --skip-datetime \
              --run-on-cpu

      - name: Verify output consistency (inside container)
        run: |
          set -euo pipefail
          docker run --rm \
            -v "$GITHUB_WORKSPACE/test/gt:/gt" \
            -v "$GITHUB_WORKSPACE/test/output:/output" \
            slide2vec:${{ github.sha }} \
            bash -lc "python - <<'PY'
          import numpy as np, torch
          from numpy.testing import assert_array_equal

          # coordinates must match exactly (deterministic tiling)
          gt_coordinates = np.load('/gt/test-wsi.npy')
          coordinates = np.load('/output/coordinates/test-wsi.npy')
          # NOTE: assert_array_equal raises on mismatch itself; the message must be
          # passed via err_msg= (a trailing ', msg' would just build a dead tuple).
          assert_array_equal(coordinates, gt_coordinates, err_msg='Coordinates mismatch')

          # embeddings: allow tiny numeric drift
          gt = torch.load('/gt/test-wsi.pt', map_location='cpu')
          emb = torch.load('/output/features/test-wsi.pt', map_location='cpu')
          assert emb.shape == gt.shape, f'Shape mismatch: {emb.shape} vs {gt.shape}'

          cos = torch.nn.functional.cosine_similarity(emb, gt, dim=-1)
          mean_cos = float(cos.mean())
          atol, rtol = 1e-2, 1e-3
          if not torch.allclose(emb, gt, atol=atol, rtol=rtol):
              if mean_cos < 0.99:
                  raise AssertionError(f'Embedding mismatch: mean cosine similarity={mean_cos:.4f} (atol={atol}, rtol={rtol})')
              else:
                  print(f'WARNING: embeddings not allclose, but mean cosine similarity={mean_cos:.4f} (atol={atol}, rtol={rtol})')
          else:
              print(f'OK: embeddings within tolerance; mean cosine similarity={mean_cos:.4f}')
          PY"
9 changes: 7 additions & 2 deletions slide2vec/aggregate.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,9 @@ def get_args_parser(add_help: bool = True):
default="",
help="Name of output subdirectory",
)
parser.add_argument(
"--run-on-cpu", action="store_true", help="run inference on cpu"
)
return parser


Expand All @@ -49,6 +52,7 @@ def scale_coordinates(wsi_fp, coordinates, spacing, backend):

def main(args):
# setup configuration
run_on_cpu = args.run_on_cpu
cfg = get_cfg_from_file(args.config_file)
output_dir = Path(cfg.output_dir, args.run_id)
cfg.output_dir = str(output_dir)
Expand Down Expand Up @@ -87,7 +91,7 @@ def main(args):

autocast_context = (
torch.autocast(device_type="cuda", dtype=torch.float16)
if cfg.speed.fp16
if (cfg.speed.fp16 and not run_on_cpu)
else nullcontext()
)
feature_aggregation_updates = {}
Expand Down Expand Up @@ -136,7 +140,8 @@ def main(args):

torch.save(wsi_feature, feature_path)
del wsi_feature
torch.cuda.empty_cache()
if not run_on_cpu:
torch.cuda.empty_cache()
gc.collect()

feature_aggregation_updates[str(wsi_fp)] = {"status": "success"}
Expand Down
30 changes: 21 additions & 9 deletions slide2vec/embed.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,9 @@ def get_args_parser(add_help: bool = True):
default="",
help="Name of output subdirectory",
)
parser.add_argument(
"--run-on-cpu", action="store_true", help="run inference on cpu"
)
return parser


Expand Down Expand Up @@ -61,14 +64,15 @@ def create_dataset(wsi_fp, coordinates_dir, spacing, backend, transforms):
)


def run_inference(dataloader, model, device, autocast_context, unit, batch_size, feature_path, feature_dim, dtype):
def run_inference(dataloader, model, device, autocast_context, unit, batch_size, feature_path, feature_dim, dtype, run_on_cpu: False):
device_name = f"GPU {distributed.get_global_rank()}" if not run_on_cpu else "CPU"
with h5py.File(feature_path, "w") as f:
features = f.create_dataset("features", shape=(0, *feature_dim), maxshape=(None, *feature_dim), dtype=dtype, chunks=(batch_size, *feature_dim))
indices = f.create_dataset("indices", shape=(0,), maxshape=(None,), dtype='int64', chunks=(batch_size,))
with torch.inference_mode(), autocast_context:
for batch in tqdm.tqdm(
dataloader,
desc=f"Inference on GPU {distributed.get_global_rank()}",
desc=f"Inference on {device_name}",
unit=unit,
unit_scale=batch_size,
leave=False,
Expand All @@ -86,7 +90,8 @@ def run_inference(dataloader, model, device, autocast_context, unit, batch_size,
del image, feature

# cleanup
torch.cuda.empty_cache()
if not run_on_cpu:
torch.cuda.empty_cache()
gc.collect()


Expand Down Expand Up @@ -116,11 +121,13 @@ def load_sort_and_deduplicate_features(tmp_dir, name, expected_len=None):

def main(args):
# setup configuration
run_on_cpu = args.run_on_cpu
cfg = get_cfg_from_file(args.config_file)
output_dir = Path(cfg.output_dir, args.run_id)
cfg.output_dir = str(output_dir)

setup_distributed()
if not run_on_cpu:
setup_distributed()

if cfg.tiling.read_coordinates_from:
coordinates_dir = Path(cfg.tiling.read_coordinates_from)
Expand Down Expand Up @@ -155,7 +162,8 @@ def main(args):
model = ModelFactory(cfg.model).get_model()
if distributed.is_main_process():
print(f"Starting {unit}-level feature extraction...")
torch.distributed.barrier()
if not run_on_cpu:
torch.distributed.barrier()

# select slides that were successfully tiled but not yet processed for feature extraction
tiled_df = process_df[process_df.tiling_status == "success"]
Expand All @@ -174,7 +182,7 @@ def main(args):

autocast_context = (
torch.autocast(device_type="cuda", dtype=torch.float16)
if cfg.speed.fp16
if (cfg.speed.fp16 and not run_on_cpu)
else nullcontext()
)
feature_extraction_updates = {}
Expand Down Expand Up @@ -231,20 +239,24 @@ def main(args):
tmp_feature_path,
feature_dim,
dtype,
run_on_cpu,
)

torch.distributed.barrier()
if not run_on_cpu:
torch.distributed.barrier()

if distributed.is_main_process():
wsi_feature = load_sort_and_deduplicate_features(tmp_dir, name, expected_len=len(dataset))
torch.save(wsi_feature, feature_path)

# cleanup
del wsi_feature
torch.cuda.empty_cache()
if not run_on_cpu:
torch.cuda.empty_cache()
gc.collect()

torch.distributed.barrier()
if not run_on_cpu:
torch.distributed.barrier()

feature_extraction_updates[str(wsi_fp)] = {"status": "success"}

Expand Down
49 changes: 35 additions & 14 deletions slide2vec/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,9 @@ def get_args_parser(add_help: bool = True):
parser.add_argument(
"--skip-datetime", action="store_true", help="skip run id datetime prefix"
)
parser.add_argument(
"--run-on-cpu", action="store_true", help="run inference on cpu"
)
return parser


Expand Down Expand Up @@ -62,7 +65,7 @@ def run_tiling(config_file, run_id):
sys.exit(proc.returncode)


def run_feature_extraction(config_file, run_id):
def run_feature_extraction(config_file, run_id, run_on_cpu: False):
print("Running embed.py...")
# find a free port
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
Expand All @@ -80,6 +83,16 @@ def run_feature_extraction(config_file, run_id):
"--config-file",
config_file,
]
if run_on_cpu:
cmd = [
sys.executable,
"slide2vec/embed.py",
"--run-id",
run_id,
"--config-file",
config_file,
"--run-on-cpu",
]
# launch in its own process group.
proc = subprocess.Popen(
cmd,
Expand All @@ -106,7 +119,7 @@ def run_feature_extraction(config_file, run_id):
sys.exit(proc.returncode)


def run_feature_aggregation(config_file, run_id):
def run_feature_aggregation(config_file, run_id, run_on_cpu: False):
print("Running aggregate.py...")
# find a free port
cmd = [
Expand All @@ -117,6 +130,8 @@ def run_feature_aggregation(config_file, run_id):
"--config-file",
config_file,
]
if run_on_cpu:
cmd.append("--run-on-cpu")
# launch in its own process group.
proc = subprocess.Popen(
cmd,
Expand All @@ -134,7 +149,7 @@ def run_feature_aggregation(config_file, run_id):
sys.stdout.flush()
proc.wait()
except KeyboardInterrupt:
print("Received CTRL+C, terminating embed.py process group...")
print("Received CTRL+C, terminating aggregate.py process group...")
os.killpg(os.getpgid(proc.pid), signal.SIGTERM)
proc.wait()
sys.exit(1)
Expand All @@ -146,18 +161,17 @@ def run_feature_aggregation(config_file, run_id):
def main(args):
config_file = args.config_file
skip_datetime = args.skip_datetime
run_on_cpu = args.run_on_cpu

cfg = setup(config_file, skip_datetime=skip_datetime)
cfg, run_id = setup(config_file, skip_datetime=skip_datetime)
hf_login()

output_dir = Path(cfg.output_dir)
run_id = output_dir.stem

run_tiling(config_file, run_id)

print("Tiling completed.")
print("=+=" * 10)

output_dir = Path(cfg.output_dir)
features_dir = output_dir / "features"
if cfg.wandb.enable:
stop_event = threading.Event()
Expand All @@ -166,14 +180,16 @@ def main(args):
)
log_thread.start()

run_feature_extraction(config_file, run_id)
print("Feature extraction completed.")
print("=+=" * 10)
run_feature_extraction(config_file, run_id, run_on_cpu)

if cfg.model.level == "slide":
run_feature_aggregation(config_file, run_id)
print("Feature aggregation completed.")
run_feature_aggregation(config_file, run_id, run_on_cpu)
print("Feature extraction completed.")
print("=+=" * 10)
else:
print("Feature extraction completed.")
print("=+=" * 10)


if cfg.wandb.enable:
stop_event.set()
Expand All @@ -184,16 +200,21 @@ def main(args):


if __name__ == "__main__":
import warnings

import warnings
import torchvision

torchvision.disable_beta_transforms_warning()

warnings.filterwarnings("ignore", message=".*Could not set the permissions.*")
warnings.filterwarnings("ignore", message=".*antialias.*", category=UserWarning)
warnings.filterwarnings("ignore", message=".*TypedStorage.*", category=UserWarning)
warnings.filterwarnings("ignore", category=FutureWarning)
warnings.filterwarnings("ignore", message="The given NumPy array is not writable")
warnings.filterwarnings(
"ignore",
message=".*'frozen' attribute with value True was provided to the `Field`.*"
)

args = get_args_parser(add_help=True).parse_args()
main(args)
2 changes: 1 addition & 1 deletion slide2vec/models/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -578,5 +578,5 @@ def build_encoders(self):
def forward_slide(self, tile_features, **kwargs):
tile_features = tile_features.unsqueeze(0)
reprs = self.slide_encoder.slide_representations(tile_features)
output = reprs["image_embedding"] # [1, 1280]
output = reprs["image_embedding"].squeeze(0) # [1280]
return output
Loading
Loading