Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 7 additions & 1 deletion slide2vec/aggregate.py
Original file line number Diff line number Diff line change
Expand Up @@ -132,11 +132,17 @@ def main(args):
with autocast_context:
features = torch.load(feature_path).to(model.device)
tile_size_lv0 = coordinates_arr["tile_size_lv0"][0]
wsi_feature = model.forward_slide(
output = model.forward_slide(
features,
tile_coordinates=coordinates,
tile_size_lv0=tile_size_lv0,
)
wsi_feature = output["embedding"].cpu()
if cfg.model.name == "prism" and cfg.model.save_latents:
latent_path = features_dir / f"{name}-latents.pt"
latents = output["latents"].cpu()
torch.save(latents, latent_path)
del latents

torch.save(wsi_feature, feature_path)
del wsi_feature
Expand Down
1 change: 1 addition & 0 deletions slide2vec/configs/default.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@ model:
tile_size: ${tiling.params.tile_size}
patch_size: 256 # if level is "region", size used to unroll the region into patches
save_tile_embeddings: false # whether to save tile embeddings alongside the pooled slide embedding when level is "slide"
save_latents: false # whether to save the latent representations from the model alongside the slide embedding (only supported for 'prism')

speed:
fp16: false # use mixed precision during model inference
Expand Down
1 change: 1 addition & 0 deletions slide2vec/configs/prism.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ model:
name: "prism"
batch_size: 32
save_tile_embeddings: true # whether to save tile embeddings alongside the pooled slide embedding
save_latents: false # whether to save the latent representations from the model alongside the slide embedding (only supported for 'prism')

speed:
fp16: true
Expand Down
4 changes: 2 additions & 2 deletions slide2vec/embed.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,7 @@ def run_inference(dataloader, model, device, autocast_context, unit, batch_size,
):
idx, image = batch
image = image.to(device, non_blocking=True)
feature = model(image).cpu().numpy()
feature = model(image)["embedding"].cpu().numpy()
features.resize(features.shape[0] + feature.shape[0], axis=0)
features[-feature.shape[0]:] = feature
indices.resize(indices.shape[0] + idx.shape[0], axis=0)
Expand Down Expand Up @@ -227,7 +227,7 @@ def main(args):
with torch.inference_mode(), autocast_context:
sample_batch = next(iter(dataloader))
sample_image = sample_batch[1].to(model.device)
sample_feature = model(sample_image).cpu().numpy()
sample_feature = model(sample_image)["embedding"].cpu().numpy()
feature_dim = sample_feature.shape[1:]
dtype = sample_feature.dtype

Expand Down
65 changes: 45 additions & 20 deletions slide2vec/models/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -191,8 +191,9 @@ def get_transforms(self):
return transform

def forward(self, x):
return self.encoder(x)

embedding = self.encoder(x)
output = {"embedding": embedding}
return output

class CustomViT(FeatureExtractor):
def __init__(
Expand Down Expand Up @@ -263,7 +264,9 @@ def get_transforms(self):
return transform

def forward(self, x):
return self.encoder(x)
embedding = self.encoder(x)
output = {"embedding": embedding}
return output


class UNI(FeatureExtractor):
Expand All @@ -281,7 +284,9 @@ def build_encoder(self):
return encoder

def forward(self, x):
return self.encoder(x)
embedding = self.encoder(x)
output = {"embedding": embedding}
return output


class UNI2(FeatureExtractor):
Expand Down Expand Up @@ -311,7 +316,9 @@ def build_encoder(self):
return encoder

def forward(self, x):
return self.encoder(x)
embedding = self.encoder(x)
output = {"embedding": embedding}
return output


class Virchow(FeatureExtractor):
Expand All @@ -338,12 +345,13 @@ def forward(self, x):
:, 1:
] # size: 1 x 256 x 1280, token 0 is the class token so we skip it (Virchow v1 has no register tokens)
if self.mode == "cls":
return class_token
output = {"embedding": class_token}
elif self.mode == "full":
embedding = torch.cat(
[class_token, patch_tokens.mean(1)], dim=-1
) # size: 1 x 2560
return embedding
output = {"embedding": embedding}
return output


class Virchow2(FeatureExtractor):
Expand All @@ -370,12 +378,13 @@ def forward(self, x):
:, 5:
] # size: 1 x 256 x 1280, tokens 1-4 are register tokens so we ignore those
if self.mode == "cls":
return class_token
output = {"embedding": class_token}
elif self.mode == "full":
embedding = torch.cat(
[class_token, patch_tokens.mean(1)], dim=-1
) # size: 1 x 2560
return embedding
output = {"embedding": embedding}
return output


class ProvGigaPath(FeatureExtractor):
Expand All @@ -391,7 +400,9 @@ def build_encoder(self):
return encoder

def forward(self, x):
return self.encoder(x)
embedding = self.encoder(x)
output = {"embedding": embedding}
return output


class Hoptimus0(FeatureExtractor):
Expand All @@ -409,7 +420,9 @@ def build_encoder(self):
return encoder

def forward(self, x):
return self.encoder(x)
embedding = self.encoder(x)
output = {"embedding": embedding}
return output


class Hoptimus1(FeatureExtractor):
Expand All @@ -427,7 +440,9 @@ def build_encoder(self):
return encoder

def forward(self, x):
return self.encoder(x)
embedding = self.encoder(x)
output = {"embedding": embedding}
return output


class Hoptimus0Mini(FeatureExtractor):
Expand All @@ -454,12 +469,13 @@ def forward(self, x):
:, self.encoder.num_prefix_tokens :
] # size: 1 x 256 x 768
if self.mode == "cls":
return cls_features
output = {"embedding": cls_features}
elif self.mode == "full":
embedding = torch.cat(
[cls_features, patch_token_features.mean(1)], dim=-1
) # size: 1 x 1536
return embedding
output = {"embedding": embedding}
return output


class RegionFeatureExtractor(nn.Module):
Expand All @@ -482,9 +498,10 @@ def forward(self, x):
B = x.size(0)
x = rearrange(x, "b p c w h -> (b p) c w h") # [B*num_tiles, 3, 224, 224]
output = self.tile_encoder(x) # [B*num_tiles, features_dim]
output = rearrange(
embedding = rearrange(
output, "(b p) f -> b p f", b=B
) # [B, num_tiles, features_dim]
output = {"embedding": embedding}
return output


Expand Down Expand Up @@ -514,7 +531,7 @@ def forward(self, x):
return self.tile_encoder(x)

def forward_slide(self, **kwargs):
return self.slide_encoder(**kwargs)
raise NotImplementedError


class ProvGigaPathSlide(SlideFeatureExtractor):
Expand All @@ -537,7 +554,8 @@ def build_encoders(self):
def forward_slide(self, tile_features, tile_coordinates, **kwargs):
tile_features = tile_features.unsqueeze(0)
output = self.slide_encoder(tile_features, tile_coordinates)
output = output[0].squeeze()
embedding = output[0].squeeze()
output = {"embedding": embedding}
return output


Expand All @@ -558,16 +576,18 @@ def get_transforms(self):
def forward_slide(self, tile_features, tile_coordinates, tile_size_lv0, **kwargs):
tile_features = tile_features.unsqueeze(0)
tile_coordinates = tile_coordinates.unsqueeze(0)
output = self.slide_encoder.encode_slide_from_patch_features(
embedding = self.slide_encoder.encode_slide_from_patch_features(
tile_features, tile_coordinates, tile_size_lv0
)
output = {"embedding": embedding.squeeze(0)}
return output


class PRISM(SlideFeatureExtractor):
def __init__(self):
def __init__(self, return_latents: bool = False):
super(PRISM, self).__init__()
self.features_dim = self.tile_encoder.features_dim
self.return_latents = return_latents

def build_encoders(self):
self.slide_encoder = AutoModel.from_pretrained(
Expand All @@ -578,5 +598,10 @@ def build_encoders(self):
def forward_slide(self, tile_features, **kwargs):
tile_features = tile_features.unsqueeze(0)
reprs = self.slide_encoder.slide_representations(tile_features)
output = reprs["image_embedding"].squeeze(0) # [1280]
embedding = reprs["image_embedding"].squeeze(0) # [1280]
if self.return_latents:
latents = reprs["image_latents"].squeeze(0) # [512, 1280]
output = {"embedding": embedding, "latents": latents}
else:
output = {"embedding": embedding}
return output
Loading