12 changes: 8 additions & 4 deletions Dockerfile
@@ -49,11 +49,15 @@ RUN apt-get update && curl -L ${ASAP_URL} -o /tmp/ASAP.deb && apt-get install --
apt-get clean && \
rm -rf /var/lib/apt/lists/*

# clone prov-gigapath repo
RUN git clone https://github.com/prov-gigapath/prov-gigapath.git
# clone & install relevant repositories
RUN git clone https://github.com/prov-gigapath/prov-gigapath.git && \
git clone https://github.com/lilab-stanford/MUSK.git && \
git clone https://github.com/Mahmoodlab/CONCH.git && \
python -m pip install -e /home/user/MUSK && \
python -m pip install -e /home/user/CONCH

# add gigapath folder to python path
ENV PYTHONPATH="/home/user/prov-gigapath:$PYTHONPATH"
# add folders to python path
ENV PYTHONPATH="/home/user/prov-gigapath:/home/user/CONCH:/home/user/MUSK:$PYTHONPATH"

WORKDIR /opt/app/

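Once the image is built, a quick sanity check that the clones and editable installs are importable (a hedged sketch; the package names `gigapath`, `musk`, and `conch` are assumptions about each repository's layout):

# sanity_check.py -- run inside the built container
import gigapath  # found via the PYTHONPATH entry /home/user/prov-gigapath
import musk      # editable install of lilab-stanford/MUSK
import conch     # editable install of Mahmoodlab/CONCH
print("all model packages importable")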
2 changes: 1 addition & 1 deletion README.md
@@ -41,7 +41,7 @@ pip install slide2vec

A good starting point is the default configuration file `slide2vec/configs/default.yaml` where parameters are documented.<br>
We've also added default configuration files for each of the foundation models currently supported:
- tile-level: `uni`, `uni2`, `virchow`, `virchow2`, `prov-gigapath`, `h-optimus-0`, `h-optimus-1`, `h0-mini`
- tile-level: `uni`, `uni2`, `virchow`, `virchow2`, `prov-gigapath`, `h-optimus-0`, `h-optimus-1`, `h0-mini`, `conch`, `musk`, `phikonv2`, `hibou-b`, `hibou-L`, [`kaiko`](https://github.com/kaiko-ai/towards_large_pathology_fms)
- slide-level: `prov-gigapath`, `titan`, `prism`
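Each of these ships as a ready-made YAML file under `slide2vec/configs/`. As a minimal sketch (assuming the repository layout above), a bundled config can be inspected with `omegaconf`, which slide2vec already uses internally:

from omegaconf import OmegaConf

cfg = OmegaConf.load("slide2vec/configs/conch.yaml")
print(cfg.model.name)               # "conch"
print(cfg.tiling.params.tile_size)  # 448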


18 changes: 18 additions & 0 deletions slide2vec/configs/conch.yaml
@@ -0,0 +1,18 @@
csv: # path to csv containing slide paths

output_dir: "output" # output directory

visualize: true

tiling:
params:
spacing: 0.5 # spacing at which to tile the slide, in microns per pixel
tile_size: 448 # size of the tiles to extract, in pixels

model:
level: "tile" # level at which to extract the features ("tile", "region" or "slide")
name: "conch"
batch_size: 1

speed:
fp16: true # use mixed precision during model inference
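(For scale: at 0.5 µm/pixel, a 448-pixel tile covers 224 µm of tissue, and 448×448 matches the input resolution CONCH's bundled preprocessing expects.)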
19 changes: 19 additions & 0 deletions slide2vec/configs/hibou.yaml
@@ -0,0 +1,19 @@
csv: # path to csv containing slide paths

output_dir: "output" # output directory

visualize: true

tiling:
params:
spacing: 0.5 # spacing at which to tile the slide, in microns per pixel
tile_size: 224 # size of the tiles to extract, in pixels

model:
level: "tile" # level at which to extract the features ("tile", "region" or "slide")
arch: "hibou-b" # Hibou model architectures, options: ("hibou-b", "hibou-L")
name: "hibou"
batch_size: 1

speed:
fp16: true # use mixed precision during model inference
19 changes: 19 additions & 0 deletions slide2vec/configs/kaiko.yaml
@@ -0,0 +1,19 @@
csv: # path to csv containing slide paths

output_dir: "output" # output directory

visualize: true

tiling:
params:
spacing: 0.5 # spacing at which to tile the slide, in microns per pixel
tile_size: 224 # size of the tiles to extract, in pixels

model:
level: "tile" # level at which to extract the features ("tile", "region" or "slide")
name: "kaiko"
arch: "vitl14" # kaiko model architectures, options: ("vits8", "vits16", "vitb8", "vitb16", "vitl14")
batch_size: 1

speed:
fp16: true # use mixed precision during model inference
18 changes: 18 additions & 0 deletions slide2vec/configs/musk.yaml
@@ -0,0 +1,18 @@
csv: # path to csv containing slide paths

output_dir: "output" # output directory

visualize: true

tiling:
params:
spacing: 0.5 # spacing at which to tile the slide, in microns per pixel
tile_size: 384 # size of the tiles to extract, in pixels

model:
level: "tile" # level at which to extract the features ("tile", "region" or "slide")
name: "musk"
batch_size: 1

speed:
fp16: true # use mixed precision during model inference
18 changes: 18 additions & 0 deletions slide2vec/configs/phikonv2.yaml
@@ -0,0 +1,18 @@
csv: # path to csv containing slide paths

output_dir: "output" # output directory

visualize: true

tiling:
params:
spacing: 0.5 # spacing at which to tile the slide, in microns per pixel
tile_size: 224 # size of the tiles to extract, in pixels

model:
level: "tile" # level at which to extract the features ("tile", "region" or "slide")
name: "phikonv2"
batch_size: 1

speed:
fp16: true # use mixed precision during model inference
6 changes: 5 additions & 1 deletion slide2vec/data/dataset.py
@@ -2,6 +2,7 @@
import numpy as np
import wholeslidedata as wsd

from transformers.image_processing_utils import BaseImageProcessor
from PIL import Image
from pathlib import Path

@@ -58,5 +59,8 @@ def __getitem__(self, idx):
if self.tile_size[idx] != self.tile_size_resized[idx]:
tile = tile.resize((self.tile_size[idx], self.tile_size[idx]))
if self.transforms:
tile = self.transforms(tile)
if isinstance(self.transforms, BaseImageProcessor): # Hugging Face (`transformers`)
tile = self.transforms(tile, return_tensors="pt")["pixel_values"].squeeze(0)
else: # general callable such as torchvision transforms
tile = self.transforms(tile)
return idx, tile
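The new branch in isolation (a sketch; `owkin/phikon-v2` is one of the Hugging Face processors this PR adds, and the printed shape follows its default 224×224 preprocessing):

from PIL import Image
from transformers import AutoImageProcessor
from transformers.image_processing_utils import BaseImageProcessor

transform = AutoImageProcessor.from_pretrained("owkin/phikon-v2")
tile = Image.new("RGB", (256, 256))
if isinstance(transform, BaseImageProcessor):
    # HF processors return a dict of batched tensors; unwrap the single image
    tensor = transform(tile, return_tensors="pt")["pixel_values"].squeeze(0)
else:
    tensor = transform(tile)  # e.g. a torchvision Compose
print(tensor.shape)  # torch.Size([3, 224, 224])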
179 changes: 173 additions & 6 deletions slide2vec/models/models.py
Expand Up @@ -5,11 +5,17 @@

from einops import rearrange
from omegaconf import DictConfig
from transformers import AutoModel
from transformers import AutoModel, AutoImageProcessor
from torchvision import transforms
from torchvision.transforms import v2
from timm.data import resolve_data_config
from timm.data.constants import IMAGENET_INCEPTION_MEAN, IMAGENET_INCEPTION_STD
from timm.data.transforms_factory import create_transform

from conch.open_clip_custom import create_model_from_pretrained
from musk import modeling as musk_modeling
from musk import utils as musk_utils

import slide2vec.distributed as distributed
import slide2vec.models.vision_transformer_dino as vits_dino
import slide2vec.models.vision_transformer_dinov2 as vits_dinov2
@@ -41,9 +47,17 @@ def __init__(
elif options.name == "h-optimus-1":
model = Hoptimus1()
elif options.name == "h-optimus-0-mini" or options.name == "h0-mini":
model = Hoptimus0Mini(
mode=options.mode
)
model = Hoptimus0Mini(mode=options.mode)
elif options.name == "conch":
model = CONCH()
elif options.name == "musk":
model = MUSK()
elif options.name == "phikonv2":
model = PhikonV2()
elif options.name == "hibou":
model = Hibou(arch=options.arch)
elif options.name == "kaiko":
model = Kaiko(arch=options.arch)
elif options.name == "rumc-vit-s-50k":
model = CustomViT(
arch=options.arch,
@@ -71,6 +85,16 @@ def __init__(self,
tile_encoder = Hoptimus0()
elif options.name == "h-optimus-1":
tile_encoder = Hoptimus1()
elif options.name == "conch":
model = CONCH()
elif options.name == "musk":
model = MUSK()
elif options.name == "phikonv2":
model = PhikonV2()
elif options.name == "hibou":
Comment thread
clemsgrs marked this conversation as resolved.
model = Hibou()
elif options.name == "kaiko":
model = Kaiko(arch=options.arch)
elif options.name == "rumc-vit-s-50k":
tile_encoder = CustomViT(
arch=options.arch,
@@ -161,7 +185,9 @@ def load_weights(self):
nn.modules.utils.consume_prefix_in_state_dict_if_present(
state_dict, prefix="backbone."
)
state_dict, msg = update_state_dict(model_dict=self.encoder.state_dict(), state_dict=state_dict)
state_dict, msg = update_state_dict(
model_dict=self.encoder.state_dict(), state_dict=state_dict
)
if distributed.is_main_process():
print(msg)
self.encoder.load_state_dict(state_dict, strict=False)
@@ -243,7 +269,9 @@ def load_weights(self):
nn.modules.utils.consume_prefix_in_state_dict_if_present(
state_dict, prefix="backbone."
)
state_dict, msg = update_state_dict(model_dict=self.encoder.state_dict(), state_dict=state_dict)
state_dict, msg = update_state_dict(
model_dict=self.encoder.state_dict(), state_dict=state_dict
)
if distributed.is_main_process():
print(msg)
self.encoder.load_state_dict(state_dict, strict=False)
@@ -478,6 +506,145 @@ def forward(self, x):
return output


class CONCH(FeatureExtractor):
def __init__(self):
self.features_dim = 512
super(CONCH, self).__init__()

def build_encoder(self):
encoder, transform = create_model_from_pretrained(
"conch_ViT-B-16",
"hf_hub:MahmoodLab/conch",
)
self.transform = transform
return encoder

def get_transforms(self):
return self.transform

def forward(self, x):
embedding = self.encoder.encode_image(x, proj_contrast=False, normalize=False)
output = {"embedding": embedding}
return output
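# Hypothetical usage sketch for CONCH (kept as a comment so nothing runs at
# import time; shapes assume conch_ViT-B-16's 448x448 preprocessing):
#
#   model = CONCH()
#   x = model.get_transforms()(pil_tile).unsqueeze(0)  # (1, 3, 448, 448)
#   out = model(x)  # {"embedding": tensor of shape (1, 512)}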


class MUSK(FeatureExtractor):
def __init__(self):
self.features_dim = 2048
super(MUSK, self).__init__()

def build_encoder(self):
encoder = timm.create_model("musk_large_patch16_384")
musk_utils.load_model_and_may_interpolate(
"hf_hub:xiangjx/musk", encoder, "model|module", ""
)
return encoder

def get_transforms(self):
return transforms.Compose(
[
transforms.Resize(384, interpolation=3, antialias=True),
transforms.CenterCrop((384, 384)),
transforms.ToTensor(),
transforms.Normalize(
mean=IMAGENET_INCEPTION_MEAN, std=IMAGENET_INCEPTION_STD
),
]
)

def forward(self, x):
embedding = self.encoder(
image=x,
with_head=False,
out_norm=False,
ms_aug=True,
return_global=True,
)[0]
output = {"embedding": embedding}
return output
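# Note: with ms_aug=True, MUSK concatenates multi-scale features, which is
# why features_dim above is 2048 rather than the encoder's base width of 1024.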


class PhikonV2(FeatureExtractor):
def __init__(self):
self.features_dim = 1024
super(PhikonV2, self).__init__()

def build_encoder(self):
return AutoModel.from_pretrained("owkin/phikon-v2", trust_remote_code=True)

def get_transforms(self):
return AutoImageProcessor.from_pretrained("owkin/phikon-v2", trust_remote_code=True)

def forward(self, x):
embedding = self.encoder(x).last_hidden_state[:, 0, :]
output = {"embedding": embedding}
return output
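# The [:, 0, :] slice pools the CLS token from the ViT hidden states; Hibou
# below applies the same pooling.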


class Kaiko(FeatureExtractor):
def __init__(self, arch: str = "vits16"):
self.arch = arch
# map architecture name to embedding width; default matches vits16
self.features_dim = {
"vits8": 384,
"vits16": 384,
"vitb8": 768,
"vitb16": 768,
"vitl14": 1024,
}.get(arch, 384)
super(Kaiko, self).__init__()

def build_encoder(self):
encoder = torch.hub.load(
"kaiko-ai/towards_large_pathology_fms", self.arch, trust_repo=True
)
return encoder

def get_transforms(self):
return v2.Compose(
[
v2.ToImage(),
v2.Resize(size=224),
v2.CenterCrop(size=224),
v2.ToDtype(torch.float32, scale=True),
v2.Normalize(
mean=(0.5, 0.5, 0.5),
std=(0.5, 0.5, 0.5),
),
]
)

def forward(self, x):
embedding = self.encoder(x)
output = {"embedding": embedding}
return output


class Hibou(FeatureExtractor):
def __init__(self, arch: str = "hibou-b"):
self.arch = arch
self.features_dim = 768
if arch == "hibou-L":
self.features_dim = 1024
super(Hibou, self).__init__()

def build_encoder(self):
model = f"histai/{self.arch}"
return AutoModel.from_pretrained(model, trust_remote_code=True)

def get_transforms(self):
return AutoImageProcessor.from_pretrained(
f"histai/{self.arch}", trust_remote_code=True
)

def forward(self, x):
embedding = self.encoder(x).last_hidden_state[:, 0, :]
output = {"embedding": embedding}
return output


class RegionFeatureExtractor(nn.Module):
def __init__(
self,