Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ pip install slide2vec

A good starting point is the default configuration file `slide2vec/configs/default.yaml` where parameters are documented.<br>
We've also added default configuration files for each of the foundation models currently supported:
- tile-level: `uni`, `uni2`, `virchow`, `virchow2`, `prov-gigapath`, `h-optimus-0`, `h-optimus-1`, `h0-mini`
- tile-level: `uni`, `uni2`, `virchow`, `virchow2`, `prov-gigapath`, `h-optimus-0`, `h-optimus-1`, `h0-mini`, `conch`, `musk`, `phikonv2`, `hibou`, `kaiko`
Comment thread
clemsgrs marked this conversation as resolved.
Outdated
- slide-level: `prov-gigapath`, `titan`, `prism`


Expand Down
4 changes: 4 additions & 0 deletions requirements.in
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,10 @@ environs
xformers>=0.0.31
matplotlib

## MUSK & Conch
Comment thread
clemsgrs marked this conversation as resolved.
Outdated
git+https://github.com/lilab-stanford/MUSK.git
git+https://github.com/Mahmoodlab/CONCH.git

## gigapath
torchmetrics>=0.10.3
fvcore
Expand Down
18 changes: 18 additions & 0 deletions slide2vec/configs/conch.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
csv: # path to csv containing slide paths

output_dir: "output" # output directory

visualize: true

tiling:
params:
spacing: 0.5 # spacing at which to tile the slide, in microns per pixel
tile_size: 448 # size of the tiles to extract, in pixels

model:
level: "tile" # level at which to extract the features ("tile", "region" or "slide")
name: "conch"
batch_size: 1

speed:
fp16: true # use mixed precision during model inference
18 changes: 18 additions & 0 deletions slide2vec/configs/hibou.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
csv: # path to csv containing slide paths

output_dir: "output" # output directory

visualize: true

tiling:
params:
spacing: 0.5 # spacing at which to tile the slide, in microns per pixel
tile_size: 224 # size of the tiles to extract, in pixels

model:
level: "tile" # level at which to extract the features ("tile", "region" or "slide")
name: "hibou"
Comment thread
clemsgrs marked this conversation as resolved.
batch_size: 1

speed:
fp16: true # use mixed precision during model inference
19 changes: 19 additions & 0 deletions slide2vec/configs/kaiko.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
csv: # path to csv containing slide paths

output_dir: "output" # output directory

visualize: true

tiling:
params:
spacing: 0.5 # spacing at which to tile the slide, in microns per pixel
tile_size: 224 # size of the tiles to extract, in pixels

model:
level: "tile" # level at which to extract the features ("tile", "region" or "slide")
name: "kaiko"
mode: "vitl14" # kaiko model mode, options: ("vits8", "vits16", "vitb8", "vitb16", "vitl14")
Comment thread
clemsgrs marked this conversation as resolved.
Outdated
batch_size: 1

speed:
fp16: true # use mixed precision during model inference
18 changes: 18 additions & 0 deletions slide2vec/configs/musk.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
csv: # path to csv containing slide paths

output_dir: "output" # output directory

visualize: true

tiling:
params:
spacing: 0.5 # spacing at which to tile the slide, in microns per pixel
tile_size: 384 # size of the tiles to extract, in pixels

model:
level: "tile" # level at which to extract the features ("tile", "region" or "slide")
name: "musk"
batch_size: 1

speed:
fp16: true # use mixed precision during model inference
18 changes: 18 additions & 0 deletions slide2vec/configs/phikonv2.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
csv: # path to csv containing slide paths

output_dir: "output" # output directory

visualize: true

tiling:
params:
spacing: 0.5 # spacing at which to tile the slide, in microns per pixel
tile_size: 224 # size of the tiles to extract, in pixels

model:
level: "tile" # level at which to extract the features ("tile", "region" or "slide")
name: "phikonv2"
batch_size: 1

speed:
fp16: true # use mixed precision during model inference
6 changes: 5 additions & 1 deletion slide2vec/data/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import numpy as np
import wholeslidedata as wsd

from transformers.image_processing_utils import BaseImageProcessor
from PIL import Image
from pathlib import Path

Expand Down Expand Up @@ -58,5 +59,8 @@ def __getitem__(self, idx):
if self.tile_size[idx] != self.tile_size_resized[idx]:
tile = tile.resize((self.tile_size[idx], self.tile_size[idx]))
if self.transforms:
tile = self.transforms(tile)
if isinstance(self.transforms, BaseImageProcessor): # Hugging Face (`transformer`)
tile = self.transforms(tile, return_tensors="pt")["pixel_values"].squeeze(0)
Comment thread
clemsgrs marked this conversation as resolved.
else: # general callable such as torchvision transforms
tile = self.transforms(tile)
return idx, tile
173 changes: 167 additions & 6 deletions slide2vec/models/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,11 +5,16 @@

from einops import rearrange
from omegaconf import DictConfig
from transformers import AutoModel
from transformers import AutoModel, AutoImageProcessor
from torchvision import transforms
from timm.data import resolve_data_config
from timm.data.constants import IMAGENET_INCEPTION_MEAN, IMAGENET_INCEPTION_STD
from timm.data.transforms_factory import create_transform

from conch.open_clip_custom import create_model_from_pretrained
from musk import modeling as musk_modeling
from musk import utils as musk_utils

import slide2vec.distributed as distributed
import slide2vec.models.vision_transformer_dino as vits_dino
import slide2vec.models.vision_transformer_dinov2 as vits_dinov2
Expand Down Expand Up @@ -41,9 +46,17 @@ def __init__(
elif options.name == "h-optimus-1":
model = Hoptimus1()
elif options.name == "h-optimus-0-mini" or options.name == "h0-mini":
model = Hoptimus0Mini(
mode=options.mode
)
model = Hoptimus0Mini(mode=options.mode)
elif options.name == "conch":
model = Conch()
Comment thread
clemsgrs marked this conversation as resolved.
Outdated
elif options.name == "musk":
model = MUSK()
elif options.name == "phikonv2":
model = PhikonV2()
elif options.name == "hibou":
model = Hibou()
Comment thread
clemsgrs marked this conversation as resolved.
Outdated
elif options.name == "kaiko":
model = Kaiko(mode=options.mode)
Comment thread
clemsgrs marked this conversation as resolved.
Outdated
elif options.name == "rumc-vit-s-50k":
model = CustomViT(
arch=options.arch,
Expand Down Expand Up @@ -71,6 +84,16 @@ def __init__(
tile_encoder = Hoptimus0()
elif options.name == "h-optimus-1":
tile_encoder = Hoptimus1()
elif options.name == "conch":
model = Conch()
Comment thread
clemsgrs marked this conversation as resolved.
Outdated
elif options.name == "musk":
model = MUSK()
elif options.name == "phikonv2":
model = PhikonV2()
elif options.name == "hibou":
Comment thread
clemsgrs marked this conversation as resolved.
model = Hibou()
elif options.name == "kaiko":
model = Kaiko(mode=options.mode)
Comment thread
clemsgrs marked this conversation as resolved.
Outdated
elif options.name == "rumc-vit-s-50k":
tile_encoder = CustomViT(
arch=options.arch,
Expand Down Expand Up @@ -161,7 +184,9 @@ def load_weights(self):
nn.modules.utils.consume_prefix_in_state_dict_if_present(
state_dict, prefix="backbone."
)
state_dict, msg = update_state_dict(model_dict=self.encoder.state_dict(), state_dict=state_dict)
state_dict, msg = update_state_dict(
model_dict=self.encoder.state_dict(), state_dict=state_dict
)
if distributed.is_main_process():
print(msg)
self.encoder.load_state_dict(state_dict, strict=False)
Expand Down Expand Up @@ -243,7 +268,9 @@ def load_weights(self):
nn.modules.utils.consume_prefix_in_state_dict_if_present(
state_dict, prefix="backbone."
)
state_dict, msg = update_state_dict(model_dict=self.encoder.state_dict(), state_dict=state_dict)
state_dict, msg = update_state_dict(
model_dict=self.encoder.state_dict(), state_dict=state_dict
)
if distributed.is_main_process():
print(msg)
self.encoder.load_state_dict(state_dict, strict=False)
Expand Down Expand Up @@ -478,6 +505,140 @@ def forward(self, x):
return output


class Conch(FeatureExtractor):
    """Tile-level feature extractor backed by the CONCH image encoder.

    Both the model and its preprocessing transform come from
    ``create_model_from_pretrained`` using the ``conch_ViT-B-16`` architecture
    and the ``hf_hub:MahmoodLab/conch`` checkpoint.
    """

    def __init__(self):
        # The embedding size is assigned before the parent constructor is
        # invoked, mirroring the other extractors defined alongside this one.
        self.features_dim = 512
        super().__init__()

    def build_encoder(self):
        """Instantiate the pretrained CONCH model.

        The transform returned together with the model is stored on the
        instance so that ``get_transforms`` can hand it out later.

        Returns:
            The pretrained CONCH model.
        """
        model, preprocess = create_model_from_pretrained(
            "conch_ViT-B-16",
            "hf_hub:MahmoodLab/conch",
        )
        self.transform = preprocess
        return model

    def get_transforms(self):
        """Return the preprocessing transform stored by ``build_encoder``."""
        return self.transform

    def forward(self, x):
        """Encode a batch of images.

        Args:
            x: input passed straight to ``encoder.encode_image``.

        Returns:
            A dict with a single ``"embedding"`` entry holding the result of
            ``encode_image`` called with ``proj_contrast=False`` and
            ``normalize=False``.
        """
        features = self.encoder.encode_image(
            x, proj_contrast=False, normalize=False
        )
        return {"embedding": features}


class MUSK(FeatureExtractor):
    """Tile-level feature extractor backed by the MUSK vision encoder.

    The ``musk_large_patch16_384`` architecture is created through ``timm`` and
    its weights are loaded from ``hf_hub:xiangjx/musk``.
    """

    def __init__(self):
        # Embedding size is assigned before the parent constructor is invoked,
        # mirroring the other extractors defined alongside this one.
        self.features_dim = 2048
        super().__init__()

    def build_encoder(self):
        """Create the MUSK model and load its pretrained weights.

        Returns:
            The model, after ``load_model_and_may_interpolate`` has been applied
            to it.
        """
        model = timm.create_model("musk_large_patch16_384")
        musk_utils.load_model_and_may_interpolate(
            "hf_hub:xiangjx/musk", model, "model|module", ""
        )
        return model

    def get_transforms(self):
        """Build the preprocessing pipeline applied to each tile.

        Returns:
            A ``transforms.Compose`` that resizes to 384, center-crops to
            384x384, converts to a tensor and normalizes with the ImageNet
            Inception mean/std constants.
        """
        resize = transforms.Resize(384, interpolation=3, antialias=True)
        center_crop = transforms.CenterCrop((384, 384))
        to_tensor = transforms.ToTensor()
        normalize = transforms.Normalize(
            mean=IMAGENET_INCEPTION_MEAN, std=IMAGENET_INCEPTION_STD
        )
        return transforms.Compose([resize, center_crop, to_tensor, normalize])

    def forward(self, x):
        """Encode a batch of images.

        Args:
            x: input forwarded to the encoder as its ``image`` argument.

        Returns:
            A dict with a single ``"embedding"`` entry holding the first
            element of the encoder output.
        """
        encoder_outputs = self.encoder(
            image=x,
            with_head=False,
            out_norm=False,
            ms_aug=True,
            return_global=True,
        )
        return {"embedding": encoder_outputs[0]}


class PhikonV2(FeatureExtractor):
    """Tile-level feature extractor backed by the ``owkin/phikon-v2`` model."""

    def __init__(self):
        # Embedding size is assigned before the parent constructor is invoked,
        # mirroring the other extractors defined alongside this one.
        self.features_dim = 1024
        super().__init__()

    def build_encoder(self):
        """Load the pretrained ``owkin/phikon-v2`` model via ``AutoModel``."""
        encoder = AutoModel.from_pretrained(
            "owkin/phikon-v2", trust_remote_code=True
        )
        return encoder

    def get_transforms(self):
        """Load the image processor published with ``owkin/phikon-v2``."""
        processor = AutoImageProcessor.from_pretrained(
            "owkin/phikon-v2", trust_remote_code=True
        )
        return processor

    def forward(self, x):
        """Encode a batch of images.

        Args:
            x: input passed directly to the encoder.

        Returns:
            A dict with a single ``"embedding"`` entry: the slice at index 0
            along the second axis of ``last_hidden_state`` (presumably the
            class token -- confirm against the model card).
        """
        hidden_states = self.encoder(x).last_hidden_state
        return {"embedding": hidden_states[:, 0, :]}


class Kaiko(FeatureExtractor):
    """Tile-level feature extractor backed by the Kaiko pathology models.

    The backbone is fetched through ``torch.hub`` from
    ``kaiko-ai/towards_large_pathology_fms``; ``mode`` selects which entry
    point (model variant) is loaded.
    """

    # Embedding size for each supported model variant.
    _FEATURES_DIM = {
        "vits8": 384,
        "vits16": 384,
        "vitb8": 768,
        "vitb16": 768,
        "vitl14": 1024,
    }

    def __init__(self, mode: str = "vits16"):
        """Select the model variant and record its embedding size.

        Args:
            mode: model variant, one of ``"vits8"``, ``"vits16"``, ``"vitb8"``,
                ``"vitb16"`` or ``"vitl14"``.

        Raises:
            ValueError: if ``mode`` is not a supported variant. Previously an
                unknown mode was silently given an embedding size of 384.
        """
        if mode not in self._FEATURES_DIM:
            supported = ", ".join(sorted(self._FEATURES_DIM))
            raise ValueError(
                f"Unsupported kaiko mode {mode!r}; expected one of: {supported}"
            )
        self.mode = mode
        # Assigned before the parent constructor is invoked, mirroring the
        # other extractors defined alongside this one.
        self.features_dim = self._FEATURES_DIM[mode]
        super().__init__()

    def build_encoder(self):
        """Load the selected variant from the Kaiko ``torch.hub`` repository."""
        encoder = torch.hub.load(
            "kaiko-ai/towards_large_pathology_fms", self.mode, trust_repo=True
        )
        return encoder

    def get_transforms(self):
        """Build the preprocessing pipeline applied to each tile.

        Returns:
            A ``transforms.Compose`` that resizes to 224, center-crops to 224,
            converts to a tensor and normalizes every channel with mean 0.5
            and std 0.5.
        """
        return transforms.Compose(
            [
                transforms.Resize(size=224),
                transforms.CenterCrop(size=224),
                transforms.ToTensor(),
                transforms.Normalize(
                    mean=(0.5, 0.5, 0.5),
                    std=(0.5, 0.5, 0.5),
                ),
            ]
        )

    def forward(self, x):
        """Encode a batch of images.

        Args:
            x: input passed directly to the encoder.

        Returns:
            A dict with a single ``"embedding"`` entry holding the encoder
            output.
        """
        # NOTE: the leftover ``ipdb.set_trace()`` breakpoint was removed; it
        # halted every forward pass and required a dev-only dependency.
        embedding = self.encoder(x)
        output = {"embedding": embedding}
        return output


class Hibou(FeatureExtractor):
    """Tile-level feature extractor backed by the ``histai/hibou-L`` model."""

    def __init__(self):
        # Embedding size is assigned before the parent constructor is invoked,
        # mirroring the other extractors defined alongside this one.
        self.features_dim = 1024
        super().__init__()

    def build_encoder(self):
        """Load the pretrained ``histai/hibou-L`` model via ``AutoModel``."""
        encoder = AutoModel.from_pretrained(
            "histai/hibou-L", trust_remote_code=True
        )
        return encoder

    def get_transforms(self):
        """Load the image processor published with ``histai/hibou-L``."""
        processor = AutoImageProcessor.from_pretrained(
            "histai/hibou-L", trust_remote_code=True
        )
        return processor

    def forward(self, x):
        """Encode a batch of images.

        Args:
            x: input passed directly to the encoder.

        Returns:
            A dict with a single ``"embedding"`` entry: the slice at index 0
            along the second axis of ``last_hidden_state`` (presumably the
            class token -- confirm against the model card).
        """
        hidden_states = self.encoder(x).last_hidden_state
        return {"embedding": hidden_states[:, 0, :]}


class RegionFeatureExtractor(nn.Module):
def __init__(
self,
Expand Down