diff --git a/slide2vec/aggregate.py b/slide2vec/aggregate.py
index 55d4743..8f20c67 100644
--- a/slide2vec/aggregate.py
+++ b/slide2vec/aggregate.py
@@ -132,11 +132,17 @@ def main(args):
         with autocast_context:
             features = torch.load(feature_path).to(model.device)
             tile_size_lv0 = coordinates_arr["tile_size_lv0"][0]
-            wsi_feature = model.forward_slide(
+            output = model.forward_slide(
                 features,
                 tile_coordinates=coordinates,
                 tile_size_lv0=tile_size_lv0,
             )
+            wsi_feature = output["embedding"].cpu()
+            if cfg.model.name == "prism" and cfg.model.save_latents:
+                latent_path = features_dir / f"{name}-latents.pt"
+                latents = output["latents"].cpu()
+                torch.save(latents, latent_path)
+                del latents
         torch.save(wsi_feature, feature_path)
         del wsi_feature
 
diff --git a/slide2vec/configs/default.yaml b/slide2vec/configs/default.yaml
index 6d35c92..35acec7 100644
--- a/slide2vec/configs/default.yaml
+++ b/slide2vec/configs/default.yaml
@@ -45,6 +45,7 @@ model:
   tile_size: ${tiling.params.tile_size}
   patch_size: 256 # if level is "region", size used to unroll the region into patches
   save_tile_embeddings: false # whether to save tile embeddings alongside the pooled slide embedding when level is "slide"
+  save_latents: false # whether to save the latent representations from the model alongside the slide embedding (only supported for 'prism')
 
 speed:
   fp16: false # use mixed precision during model inference
diff --git a/slide2vec/configs/prism.yaml b/slide2vec/configs/prism.yaml
index 0198e6a..4f482f9 100644
--- a/slide2vec/configs/prism.yaml
+++ b/slide2vec/configs/prism.yaml
@@ -17,6 +17,7 @@ model:
   name: "prism"
   batch_size: 32
   save_tile_embeddings: true # whether to save tile embeddings alongside the pooled slide embedding
+  save_latents: false # whether to save the latent representations from the model alongside the slide embedding (only supported for 'prism')
 
 speed:
   fp16: true
diff --git a/slide2vec/embed.py b/slide2vec/embed.py
index 47e0d95..3c3d643 100644
--- a/slide2vec/embed.py
+++ b/slide2vec/embed.py
@@ -80,7 +80,7 @@ def run_inference(dataloader, model, device, autocast_context, unit, batch_size,
     ):
         idx, image = batch
         image = image.to(device, non_blocking=True)
-        feature = model(image).cpu().numpy()
+        feature = model(image)["embedding"].cpu().numpy()
         features.resize(features.shape[0] + feature.shape[0], axis=0)
         features[-feature.shape[0]:] = feature
         indices.resize(indices.shape[0] + idx.shape[0], axis=0)
@@ -227,7 +227,7 @@ def main(args):
     with torch.inference_mode(), autocast_context:
         sample_batch = next(iter(dataloader))
         sample_image = sample_batch[1].to(model.device)
-        sample_feature = model(sample_image).cpu().numpy()
+        sample_feature = model(sample_image)["embedding"].cpu().numpy()
         feature_dim = sample_feature.shape[1:]
         dtype = sample_feature.dtype
 
diff --git a/slide2vec/models/models.py b/slide2vec/models/models.py
index 2856f19..edce725 100644
--- a/slide2vec/models/models.py
+++ b/slide2vec/models/models.py
@@ -191,8 +191,9 @@ def get_transforms(self):
         return transform
 
     def forward(self, x):
-        return self.encoder(x)
-
+        embedding = self.encoder(x)
+        output = {"embedding": embedding}
+        return output
 
 class CustomViT(FeatureExtractor):
     def __init__(
@@ -263,7 +264,9 @@ def get_transforms(self):
         return transform
 
     def forward(self, x):
-        return self.encoder(x)
+        embedding = self.encoder(x)
+        output = {"embedding": embedding}
+        return output
 
 
 class UNI(FeatureExtractor):
@@ -281,7 +284,9 @@ def build_encoder(self):
         return encoder
 
     def forward(self, x):
-        return self.encoder(x)
+        embedding = self.encoder(x)
+        output = {"embedding": embedding}
+        return output
 
 
 class UNI2(FeatureExtractor):
@@ -311,7 +316,9 @@ def build_encoder(self):
         return encoder
 
     def forward(self, x):
-        return self.encoder(x)
+        embedding = self.encoder(x)
+        output = {"embedding": embedding}
+        return output
 
 
 class Virchow(FeatureExtractor):
@@ -338,12 +345,13 @@ def forward(self, x):
             :, 1:
         ]  # size: 1 x 256 x 1280, tokens 1-4 are register tokens so we ignore those
         if self.mode == "cls":
-            return class_token
+            output = {"embedding": class_token}
         elif self.mode == "full":
             embedding = torch.cat(
                 [class_token, patch_tokens.mean(1)], dim=-1
             )  # size: 1 x 2560
-            return embedding
+            output = {"embedding": embedding}
+        return output
 
 
 class Virchow2(FeatureExtractor):
@@ -370,12 +378,13 @@ def forward(self, x):
             :, 5:
         ]  # size: 1 x 256 x 1280, tokens 1-4 are register tokens so we ignore those
         if self.mode == "cls":
-            return class_token
+            output = {"embedding": class_token}
         elif self.mode == "full":
             embedding = torch.cat(
                 [class_token, patch_tokens.mean(1)], dim=-1
             )  # size: 1 x 2560
-            return embedding
+            output = {"embedding": embedding}
+        return output
 
 
 class ProvGigaPath(FeatureExtractor):
@@ -391,7 +400,9 @@ def build_encoder(self):
         return encoder
 
     def forward(self, x):
-        return self.encoder(x)
+        embedding = self.encoder(x)
+        output = {"embedding": embedding}
+        return output
 
 
 class Hoptimus0(FeatureExtractor):
@@ -409,7 +420,9 @@ def build_encoder(self):
         return encoder
 
     def forward(self, x):
-        return self.encoder(x)
+        embedding = self.encoder(x)
+        output = {"embedding": embedding}
+        return output
 
 
 class Hoptimus1(FeatureExtractor):
@@ -427,7 +440,9 @@ def build_encoder(self):
         return encoder
 
     def forward(self, x):
-        return self.encoder(x)
+        embedding = self.encoder(x)
+        output = {"embedding": embedding}
+        return output
 
 
 class Hoptimus0Mini(FeatureExtractor):
@@ -454,12 +469,13 @@ def forward(self, x):
            :, self.encoder.num_prefix_tokens :
        ]  # size: 1 x 256 x 768
         if self.mode == "cls":
-            return cls_features
+            output = {"embedding": cls_features}
         elif self.mode == "full":
             embedding = torch.cat(
                 [cls_features, patch_token_features.mean(1)], dim=-1
             )  # size: 1 x 1536
-            return embedding
+            output = {"embedding": embedding}
+        return output
 
 
 class RegionFeatureExtractor(nn.Module):
@@ -482,9 +498,10 @@ def forward(self, x):
         B = x.size(0)
         x = rearrange(x, "b p c w h -> (b p) c w h")  # [B*num_tiles, 3, 224, 224]
         output = self.tile_encoder(x)  # [B*num_tiles, features_dim]
-        output = rearrange(
+        embedding = rearrange(
             output, "(b p) f -> b p f", b=B
         )  # [B, num_tiles, features_dim]
+        output = {"embedding": embedding}
         return output
 
 
@@ -514,7 +531,7 @@ def forward(self, x):
         return self.tile_encoder(x)
 
     def forward_slide(self, **kwargs):
-        return self.slide_encoder(**kwargs)
+        raise NotImplementedError
 
 
 class ProvGigaPathSlide(SlideFeatureExtractor):
@@ -537,7 +554,8 @@ def build_encoders(self):
     def forward_slide(self, tile_features, tile_coordinates, **kwargs):
         tile_features = tile_features.unsqueeze(0)
         output = self.slide_encoder(tile_features, tile_coordinates)
-        output = output[0].squeeze()
+        embedding = output[0].squeeze()
+        output = {"embedding": embedding}
         return output
 
 
@@ -558,16 +576,18 @@ def get_transforms(self):
     def forward_slide(self, tile_features, tile_coordinates, tile_size_lv0, **kwargs):
         tile_features = tile_features.unsqueeze(0)
         tile_coordinates = tile_coordinates.unsqueeze(0)
-        output = self.slide_encoder.encode_slide_from_patch_features(
+        embedding = self.slide_encoder.encode_slide_from_patch_features(
             tile_features,
             tile_coordinates,
             tile_size_lv0
         )
+        output = {"embedding": embedding.squeeze(0)}
         return output
 
 
 class PRISM(SlideFeatureExtractor):
-    def __init__(self):
+    def __init__(self, return_latents: bool = False):
         super(PRISM, self).__init__()
         self.features_dim = self.tile_encoder.features_dim
+        self.return_latents = return_latents
 
     def build_encoders(self):
         self.slide_encoder = AutoModel.from_pretrained(
@@ -578,5 +598,10 @@
     def forward_slide(self, tile_features, **kwargs):
         tile_features = tile_features.unsqueeze(0)
         reprs = self.slide_encoder.slide_representations(tile_features)
-        output = reprs["image_embedding"].squeeze(0)  # [1280]
+        embedding = reprs["image_embedding"].squeeze(0)  # [1280]
+        if self.return_latents:
+            latents = reprs["image_latents"].squeeze(0)  # [512, 1280]
+            output = {"embedding": embedding, "latents": latents}
+        else:
+            output = {"embedding": embedding}
         return output
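For reference, a minimal sketch of the output contract this patch introduces, assuming PRISM weights are available to `AutoModel.from_pretrained` and using illustrative file paths; the `return_latents=True` wiring shown here is an assumption (aggregate.py gates latent saving on `cfg.model.save_latents`, and the constructor flag would presumably be set from that config):

```python
import torch

from slide2vec.models.models import PRISM

# Tile-level extractors now return a dict, so callers unpack the
# "embedding" key instead of consuming a raw tensor, e.g. in embed.py:
#   feature = model(image)["embedding"].cpu().numpy()

# Slide-level example with PRISM latents enabled (illustrative wiring).
model = PRISM(return_latents=True)
model.eval()

# Pre-extracted tile features, shape [num_tiles, features_dim]; path is illustrative.
tile_features = torch.load("features/slide_001.pt")

with torch.inference_mode():
    output = model.forward_slide(tile_features)
    embedding = output["embedding"].cpu()  # pooled slide embedding, [1280]
    if "latents" in output:  # present only when return_latents=True
        latents = output["latents"].cpu()  # PRISM image latents, [512, 1280]
        torch.save(latents, "features/slide_001-latents.pt")

torch.save(embedding, "features/slide_001-embedding.pt")
```

Returning a dict keeps the tile-level and slide-level interfaces uniform and lets models attach optional outputs (such as PRISM's latents) without changing every caller's signature.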