diff --git a/docs/models.md b/docs/models.md index b22a143..1bf9943 100644 --- a/docs/models.md +++ b/docs/models.md @@ -8,7 +8,7 @@ The canonical model presets are registered in code and documented below. Use the Preset-specific behavior lives in registry metadata and, where supported, `model.output_variant`. -## Tile-level models (17) +## Tile-level models (18) | Preset | Model | Supported Spacing (um) | Notes | | --- | --- | --- | --- | @@ -20,6 +20,7 @@ Preset-specific behavior lives in registry metadata and, where supported, `model | `h0-mini` | [H0-mini](https://huggingface.co/bioptimus/H0-mini) | `0.5` | Supports `output_variant="cls"` or `"cls_patch_mean"` | | `hibou-b` | [Hibou-B](https://huggingface.co/histai/hibou-b) | `0.5` | | | `hibou-l` | [Hibou-L](https://huggingface.co/histai/hibou-L) | `0.5` | | +| `lunit` | [Lunit ViT-S/8](https://huggingface.co/1aurent/vit_small_patch8_224.lunit_dino) | `0.5` | 384-dim; used as tile backbone for MOOZY | | `midnight` | [MidNight12k](https://huggingface.co/kaiko-ai/midnight) | `0.25`, `0.5`, `1.0`, `2.0` | Alias: `kaiko-midnight` | | `musk` | [MUSK](https://huggingface.co/xiangjx/musk) | `0.25`, `0.5`, `1.0` | Supports `output_variant="ms_aug"` (2048-dim, default) or `"cls"` (1024-dim). 
| | `phikon` | [Phikon](https://huggingface.co/owkin/phikon) | `0.5` | | @@ -30,10 +31,36 @@ Preset-specific behavior lives in registry metadata and, where supported, `model | `virchow` | [Virchow](https://huggingface.co/paige-ai/Virchow) | `0.5` | Supports `output_variant="cls"` or `"cls_patch_mean"` | | `virchow2` | [Virchow2](https://huggingface.co/paige-ai/Virchow2) | `0.5`, `1.0`, `2.0` | Supports `output_variant="cls"` or `"cls_patch_mean"` | -## Slide-level models (3) +## Slide-level models (4) -| Preset | Model | Tile Encoder | Supported Spacing (um) | -| --- | --- | --- | --- | -| `gigapath-slide` | [Prov-GigaPath](https://huggingface.co/prov-gigapath/prov-gigapath) | `gigapath` | `0.5` | -| `prism` | [PRISM](https://huggingface.co/paige-ai/PRISM) | `virchow` (cls_patch_mean) | `0.5` | -| `titan` | [TITAN](https://huggingface.co/MahmoodLab/TITAN) | `conchv15` | `0.5` | +| Preset | Model | Tile Encoder | Supported Spacing (um) | Notes | +| --- | --- | --- | --- | --- | +| `gigapath-slide` | [Prov-GigaPath](https://huggingface.co/prov-gigapath/prov-gigapath) | `gigapath` | `0.5` | | +| `moozy-slide` | [MOOZY](https://huggingface.co/AtlasAnalyticsLab/MOOZY) | `lunit` | `0.5` | 768-dim slide embedding; standalone slide encoder from the MOOZY stage-2 checkpoint | +| `prism` | [PRISM](https://huggingface.co/paige-ai/PRISM) | `virchow` (cls_patch_mean) | `0.5` | | +| `titan` | [TITAN](https://huggingface.co/MahmoodLab/TITAN) | `conchv15` | `0.5` | | + +## Patient-level models (1) + +Patient-level models aggregate multiple slide embeddings for the same patient into a single patient-level embedding. They require a `patient_id` column in the input manifest CSV (or `patient_id` keys in each slide dict when using the Python API). 
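The grouping this implies can be illustrated with plain Python (a sketch only — it uses the manifest columns documented below; the paths and IDs are illustrative):

```python
import csv
import io

# Minimal manifest with the required patient_id column (illustrative paths).
manifest = """sample_id,image_path,patient_id
slide_1a,/data/slide_1a.svs,patient_1
slide_1b,/data/slide_1b.svs,patient_1
slide_2a,/data/slide_2a.svs,patient_2
"""

# Group slide IDs by patient, preserving first-appearance order.
groups: dict[str, list[str]] = {}
for row in csv.DictReader(io.StringIO(manifest)):
    groups.setdefault(row["patient_id"], []).append(row["sample_id"])

print(groups)  # {'patient_1': ['slide_1a', 'slide_1b'], 'patient_2': ['slide_2a']}
```

Each `sample_id` stays unique; only `patient_id` is shared across rows of the same patient.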
+ +| Preset | Model | Tile Encoder | Supported Spacing (um) | Notes | +| --- | --- | --- | --- | --- | +| `moozy` | [MOOZY](https://huggingface.co/AtlasAnalyticsLab/MOOZY) | `lunit` | `0.5` | 768-dim patient embedding; runs Lunit tile encoder → MOOZY slide encoder → CaseAggregator transformer | + +### Patient manifest format + +Add a `patient_id` column to the standard manifest CSV to group slides by patient: + +```csv +sample_id,image_path,patient_id +slide_1a,/data/slide_1a.svs,patient_1 +slide_1b,/data/slide_1b.svs,patient_1 +slide_2a,/data/slide_2a.svs,patient_2 +``` + +`sample_id` remains the unique slide identifier. Multiple rows may share the same `patient_id`. + +### Per-slide embeddings + +When running a patient-level model via `Pipeline`, the intermediate per-slide MOOZY embeddings can be saved alongside the patient embeddings by setting `save_slide_embeddings: true` in config (or `ExecutionOptions(save_slide_embeddings=True)` in the Python API). Saved slide embeddings are written to `slide_embeddings/` in the output directory. diff --git a/docs/python-api.md b/docs/python-api.md index ec7b2c5..7f911e3 100644 --- a/docs/python-api.md +++ b/docs/python-api.md @@ -2,7 +2,7 @@ `slide2vec` exposes two main workflows: -- direct in-memory embedding with `Model.embed_slide(...)` and `Model.embed_slides(...)` +- direct in-memory embedding with `Model.embed_slide(...)`, `Model.embed_slides(...)`, `Model.embed_patient(...)`, and `Model.embed_patients(...)` - artifact generation with `Pipeline.run(...)` ## Minimal interactive usage @@ -108,12 +108,60 @@ Common fields: - `output_dir` - `output_format` - `"pt"` (default) or `"npz"` - `save_tile_embeddings` - persist tile embeddings for slide-level models (default `False`) +- `save_slide_embeddings` - persist per-slide embeddings when running a patient-level model (default `False`) - `save_latents` - persist latent representations when available (default `False`) `num_gpus` defaults to all available GPUs. 
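In YAML config these persistence toggles live under the `model:` block (a fragment mirroring the keys in `slide2vec/configs/default.yaml`; other `model:` keys are omitted here):

```yaml
model:
  save_tile_embeddings: false  # tile embeddings alongside the pooled slide embedding
  save_slide_embeddings: true  # per-slide embeddings for patient-level models (e.g. moozy)
  save_latents: false          # latent representations, where supported (e.g. prism)
```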
`embed_slide(...)` uses tile sharding for one slide, and `embed_slides(...)` balances whole slides across GPUs while preserving input order. If you need persisted artifact generation without using `Pipeline.run(...)`, use `Model.embed_tiles(...)` and `Model.aggregate_tiles(...)`. +## Patient-level embedding + +For patient-level models (e.g. `moozy`), use `Model.embed_patient(...)` for a single patient or `Model.embed_patients(...)` for a batch of patients. + +### Single patient + +```python +from slide2vec import Model + +model = Model.from_preset("moozy") +result = model.embed_patient( + ["/data/slide_1a.svs", "/data/slide_1b.svs"], + patient_id="patient_1", +) + +print(result.patient_id) # "patient_1" +print(result.patient_embedding.shape) # torch.Size([768]) +print(result.slide_embeddings) # {"slide_1a": tensor, "slide_1b": tensor} +``` + +`embed_patient(...)` returns a single `EmbeddedPatient`. The `patient_id` argument is optional — when omitted, it is read from `patient_id` keys in the slide dicts, or falls back to `sample_id`. + +### Multiple patients + +```python +results = model.embed_patients( + [ + {"sample_id": "slide_1a", "image_path": "/data/slide_1a.svs", "patient_id": "patient_1"}, + {"sample_id": "slide_1b", "image_path": "/data/slide_1b.svs", "patient_id": "patient_1"}, + {"sample_id": "slide_2a", "image_path": "/data/slide_2a.svs", "patient_id": "patient_2"}, + ] +) + +for r in results: + print(r.patient_id, r.patient_embedding.shape) +``` + +`embed_patients(...)` returns one `EmbeddedPatient` per unique patient, ordered by first appearance. Pass an explicit `patient_id_map` dict (`{sample_id: patient_id}`) to override the per-slide `patient_id` keys. + +Each `EmbeddedPatient` has: + +- `patient_id` +- `patient_embedding` — tensor of shape `(D,)` (768 for MOOZY) +- `slide_embeddings` — `{sample_id: tensor}` for each contributing slide + +Both methods raise a `ValueError` if called on a non-patient-level model. 
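The grouping and ordering rules above can be sketched in plain Python (the function name is illustrative; it mirrors the first-appearance, `setdefault`-based accumulation described here, not the library's internals verbatim):

```python
def group_by_patient(slides, patient_id_map=None):
    """Group sample_ids by patient, preserving first-appearance order.

    Precedence: explicit patient_id_map > per-slide 'patient_id' key >
    fall back to the slide's own sample_id (one patient per slide).
    """
    groups: dict[str, list[str]] = {}
    for s in slides:
        pid = (patient_id_map or {}).get(
            s["sample_id"], s.get("patient_id", s["sample_id"])
        )
        groups.setdefault(str(pid), []).append(s["sample_id"])
    return groups


slides = [
    {"sample_id": "slide_1a", "patient_id": "patient_1"},
    {"sample_id": "slide_2a", "patient_id": "patient_2"},
    {"sample_id": "slide_1b", "patient_id": "patient_1"},
]
print(group_by_patient(slides))
# {'patient_1': ['slide_1a', 'slide_1b'], 'patient_2': ['slide_2a']}
```

Passing a `patient_id_map` re-routes only the listed slides; unlisted slides keep their own `patient_id` keys.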
+ ## Hierarchical Feature Extraction Hierarchical mode spatially groups tiles into regions before embedding, producing outputs with shape `(num_regions, tiles_per_region, feature_dim)`. This is useful for downstream models that consume region-level spatial structure rather than flat tile bags. @@ -170,9 +218,10 @@ result = pipeline.run(manifest_path="/path/to/slides.csv") - `tile_artifacts` - `hierarchical_artifacts` - `slide_artifacts` +- `patient_artifacts` — populated when using a patient-level model (e.g. `moozy`); one entry per unique patient, written to `patient_embeddings/` in the output directory - `process_list_path` -The manifest schema matches HS2P and accepts optional `mask_path` and `spacing_at_level_0` columns. +The manifest schema matches HS2P and accepts optional `mask_path` and `spacing_at_level_0` columns. Patient-level models additionally require a `patient_id` column; see [Patient manifest format](models.md#patient-manifest-format). ### Reusing pre-extracted coordinates diff --git a/pyproject.toml b/pyproject.toml index f8f3cd6..9af34cd 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -71,6 +71,9 @@ hibou = [ "scipy~=1.8.1", "scikit-image~=0.19.3", ] +moozy = [ + "moozy", +] titan = [ "torch==2.0.1", "timm==1.0.3", @@ -106,6 +109,7 @@ fm = [ "scikit-survival", "scikit-learn", "fairscale", + "moozy", "packaging==23.2", "ninja==1.11.1.1", "psutil<6", diff --git a/slide2vec/api.py b/slide2vec/api.py index e9978a4..d7ef7bf 100644 --- a/slide2vec/api.py +++ b/slide2vec/api.py @@ -11,6 +11,7 @@ from slide2vec.artifacts import ( HierarchicalEmbeddingArtifact, + PatientEmbeddingArtifact, SlideEmbeddingArtifact, TileEmbeddingArtifact, ) @@ -127,6 +128,7 @@ class ExecutionOptions: prefetch_factor: int = 4 persistent_workers: bool = True save_tile_embeddings: bool = False + save_slide_embeddings: bool = False save_latents: bool = False @classmethod @@ -151,6 +153,7 @@ def from_config(cls, cfg: Any, *, run_on_cpu: bool = False) -> "ExecutionOptions 
prefetch_factor=prefetch_factor, persistent_workers=persistent_workers, save_tile_embeddings=bool(cfg.model.save_tile_embeddings), + save_slide_embeddings=bool(cfg.model.save_slide_embeddings), save_latents=bool(cfg.model.save_latents), ) @@ -200,9 +203,17 @@ class RunResult: tile_artifacts: list[TileEmbeddingArtifact] hierarchical_artifacts: list[HierarchicalEmbeddingArtifact] slide_artifacts: list[SlideEmbeddingArtifact] + patient_artifacts: list[PatientEmbeddingArtifact] = field(default_factory=list) process_list_path: Path | None = None +@dataclass(frozen=True, kw_only=True) +class EmbeddedPatient: + patient_id: str + patient_embedding: Any # torch.Tensor [D] + slide_embeddings: dict[str, Any] # {sample_id: torch.Tensor [D]} + + @dataclass(frozen=True, kw_only=True) class EmbeddedSlide: sample_id: str @@ -343,6 +354,82 @@ def embed_slides( execution=resolved, ) + def embed_patient( + self, + slides: SlideSequence, + patient_id: str | None = None, + *, + preprocessing: PreprocessingConfig | None = None, + execution: ExecutionOptions | None = None, + ) -> "EmbeddedPatient": + """Embed a single patient's slides and return one ``EmbeddedPatient``. + + Convenience wrapper around :meth:`embed_patients` for the common case + where all *slides* belong to the same patient. + + Args: + slides: All slides for this patient. + patient_id: Optional patient identifier applied to every slide. + When omitted, ``patient_id`` is read from slide dict keys or + object attributes; slides that carry no ``patient_id`` fall + back to ``sample_id``. 
+ """ + patient_id_map: dict | None = None + if patient_id is not None: + patient_id_map = {} + for s in slides: + if isinstance(s, (str, Path)): + patient_id_map[Path(s).stem] = patient_id + elif isinstance(s, dict): + patient_id_map[str(s["sample_id"])] = patient_id + else: + patient_id_map[str(s.sample_id)] = patient_id + return self.embed_patients( + slides, + patient_id_map=patient_id_map, + preprocessing=preprocessing, + execution=execution, + )[0] + + def embed_patients( + self, + slides: SlideSequence, + patient_id_map: dict | None = None, + *, + preprocessing: PreprocessingConfig | None = None, + execution: ExecutionOptions | None = None, + ) -> "list[EmbeddedPatient]": + """Embed slides and aggregate them into patient-level embeddings. + + Requires a patient-level model (e.g. ``moozy``). For each patient + all contributing slide embeddings are aggregated by the model's + ``encode_patient`` method. + + Args: + slides: Slides to process. Each entry may be a path, a + ``SlideSpec``, or a dict with ``sample_id`` / ``image_path`` + keys. When *patient_id_map* is ``None`` a ``patient_id`` + key in each dict is used to group slides. + patient_id_map: Optional explicit ``{sample_id: patient_id}`` + mapping. When provided it takes precedence over any + ``patient_id`` key embedded in the slide dicts. When + omitted and the slide dicts carry no ``patient_id``, each + slide is treated as its own patient. 
+ """ + from slide2vec.inference import embed_patients + + resolved = _coerce_execution_options(execution, model=self) + resolved_preprocessing = _resolve_direct_api_preprocessing(self, preprocessing) + with _auto_progress_reporting(output_dir=resolved.output_dir): + _validate_model_config(self, resolved_preprocessing, resolved) + return embed_patients( + self, + slides, + patient_id_map=patient_id_map, + preprocessing=resolved_preprocessing, + execution=resolved, + ) + def _load_backend(self) -> LoadedModel: if self._backend is None: from slide2vec.inference import load_model diff --git a/slide2vec/artifacts.py b/slide2vec/artifacts.py index 1d83c39..be2d734 100644 --- a/slide2vec/artifacts.py +++ b/slide2vec/artifacts.py @@ -35,6 +35,20 @@ def metadata(self) -> dict[str, Any]: return load_metadata(self.metadata_path) +@dataclass(frozen=True, kw_only=True) +class PatientEmbeddingArtifact: + patient_id: str + path: Path + metadata_path: Path + format: str + feature_dim: int + num_slides: int + + @property + def metadata(self) -> dict[str, Any]: + return load_metadata(self.metadata_path) + + @dataclass(frozen=True, kw_only=True) class HierarchicalEmbeddingArtifact: sample_id: str @@ -223,6 +237,45 @@ def write_slide_embeddings( ) +def write_patient_embeddings( + patient_id: str, + embedding, + *, + output_dir: str | Path, + output_format: str = "pt", + metadata: dict[str, Any] | None = None, + num_slides: int = 0, +) -> PatientEmbeddingArtifact: + output_format = _validate_output_format(output_format) + artifact_path, metadata_path = _setup_artifact_paths( + output_dir, "patient_embeddings", patient_id, output_format + ) + embedding_array = _ensure_array(embedding) + if output_format == "pt": + torch.save(_ensure_tensor(embedding), artifact_path) + else: + np.savez_compressed(artifact_path, features=embedding_array) + + patient_metadata = { + "patient_id": patient_id, + "artifact_type": "patient_embeddings", + "format": output_format, + "feature_dim": 
int(embedding_array.shape[-1]) if embedding_array.ndim else 1, + "num_slides": num_slides, + } + if metadata: + patient_metadata.update(metadata) + _write_metadata(metadata_path, patient_metadata) + return PatientEmbeddingArtifact( + patient_id=patient_id, + path=artifact_path, + metadata_path=metadata_path, + format=output_format, + feature_dim=patient_metadata["feature_dim"], + num_slides=num_slides, + ) + + def write_hierarchical_embeddings( sample_id: str, features, diff --git a/slide2vec/configs/default.yaml b/slide2vec/configs/default.yaml index 0aedce9..e9b4c74 100644 --- a/slide2vec/configs/default.yaml +++ b/slide2vec/configs/default.yaml @@ -13,6 +13,7 @@ model: output_variant: # requested output variant for presets that expose multiple outputs batch_size: 32 save_tile_embeddings: false # whether to save tile embeddings alongside the pooled slide embedding when level is "slide" + save_slide_embeddings: false # whether to save per-slide embeddings when level is "patient" (e.g. moozy); requires a 'patient_id' column in the input CSV save_latents: false # whether to save the latent representations from the model alongside the slide embedding (only supported for 'prism') allow_non_recommended_settings: false # when true, non-recommended spacing / tile size / precision combinations warn instead of erroring diff --git a/slide2vec/encoders/__init__.py b/slide2vec/encoders/__init__.py index 396e9e6..0b7f88f 100644 --- a/slide2vec/encoders/__init__.py +++ b/slide2vec/encoders/__init__.py @@ -6,6 +6,7 @@ from slide2vec.encoders.base import ( Encoder, + PatientEncoder, SlideEncoder, TileEncoder, TimmTileEncoder, @@ -24,6 +25,7 @@ __all__ = [ "Encoder", + "PatientEncoder", "TileEncoder", "SlideEncoder", "TimmTileEncoder", diff --git a/slide2vec/encoders/base.py b/slide2vec/encoders/base.py index fd3cdd3..9b70ea1 100644 --- a/slide2vec/encoders/base.py +++ b/slide2vec/encoders/base.py @@ -96,6 +96,33 @@ def prepare_coordinates( return coordinates +class 
PatientEncoder(Encoder): + """Base class for encoders that aggregate slide embeddings into patient embeddings.""" + + tile_encoder: TileEncoder | None = None + + def encode_tiles(self, batch: Tensor) -> Tensor: + if self.tile_encoder is None: + raise AttributeError("patient encoders must attach a tile_encoder before encoding tiles") + return self.tile_encoder.encode_tiles(batch) + + @abstractmethod + def encode_slide( + self, + tile_features: Tensor, + coordinates: Tensor | None = None, + *, + tile_size_lv0: int | None = None, + ) -> Tensor: + """Pool tile-level features into a single slide-level embedding.""" + ... + + @abstractmethod + def encode_patient(self, slide_embeddings: Tensor) -> Tensor: + """Aggregate slide embeddings [S, D] into a single patient-level embedding [D].""" + ... + + class TimmTileEncoder(TileEncoder): """Convenience base for timm-backed tile encoders.""" diff --git a/slide2vec/encoders/models/__init__.py b/slide2vec/encoders/models/__init__.py index c799a85..92053c6 100644 --- a/slide2vec/encoders/models/__init__.py +++ b/slide2vec/encoders/models/__init__.py @@ -8,7 +8,9 @@ gigapath, hibou, hoptimus, + lunit, midnight, + moozy, musk, phikon, prost40m, @@ -23,7 +25,9 @@ "gigapath", "hibou", "hoptimus", + "lunit", "midnight", + "moozy", "musk", "phikon", "prost40m", diff --git a/slide2vec/encoders/models/lunit.py b/slide2vec/encoders/models/lunit.py new file mode 100644 index 0000000..4bf3eb8 --- /dev/null +++ b/slide2vec/encoders/models/lunit.py @@ -0,0 +1,21 @@ +"""Lunit ViT-S/8 tile encoder implementation.""" + +from slide2vec.encoders.base import TimmTileEncoder +from slide2vec.encoders.registry import register_encoder + + +@register_encoder( + "lunit", + output_variants={"default": {"encode_dim": 384}}, + default_output_variant="default", + input_size=224, + supported_spacing_um=0.5, + precision="fp32", + source="1aurent/vit_small_patch8_224.lunit_dino", +) +class LunitTileEncoder(TimmTileEncoder): + def __init__(self, *, 
output_variant: str | None = None): + super().__init__( + "hf_hub:1aurent/vit_small_patch8_224.lunit_dino", + output_variant=output_variant, + ) diff --git a/slide2vec/encoders/models/moozy.py b/slide2vec/encoders/models/moozy.py new file mode 100644 index 0000000..1f42419 --- /dev/null +++ b/slide2vec/encoders/models/moozy.py @@ -0,0 +1,114 @@ +"""MOOZY slide and patient encoder implementations.""" + +import torch + +from slide2vec.encoders.base import PatientEncoder, SlideEncoder, preferred_default_device, resolve_requested_output_variant +from slide2vec.encoders.registry import register_encoder + + +@register_encoder( + "moozy-slide", + level="slide", + tile_encoder="lunit", + tile_encoder_output_variant="default", + output_variants={"default": {"encode_dim": 768}}, + default_output_variant="default", + supported_spacing_um=0.5, + precision="fp32", + source="AtlasAnalyticsLab/MOOZY", +) +class MOOZYSlideEncoder(SlideEncoder): + def __init__(self, *, output_variant: str | None = None): + from moozy.hf_hub import ensure_checkpoint + from moozy.models.factory import load_stage2_inference_model + + ckpt_path = ensure_checkpoint() + full_model = load_stage2_inference_model(ckpt_path, device=torch.device("cpu")) + self._model = full_model.slide_encoder.eval() + self._device = preferred_default_device() + self._output_variant = resolve_requested_output_variant(output_variant) + + @property + def encode_dim(self) -> int: + return 768 + + @property + def device(self) -> torch.device: + return self._device + + def to(self, device: torch.device | str) -> "MOOZYSlideEncoder": + self._device = torch.device(device) + self._model = self._model.to(self._device) + return self + + def encode_slide( + self, + tile_features: torch.Tensor, + coordinates: torch.Tensor | None = None, + *, + tile_size_lv0: int | None = None, + ) -> torch.Tensor: + if coordinates is None or tile_size_lv0 is None: + raise ValueError("MOOZY slide encoding requires coordinates and tile_size_lv0") + # 
MOOZYSlideEncoder expects [B, crop_h, crop_w, feat_dim]; use [1, 1, N, D] + x = tile_features.unsqueeze(0).unsqueeze(0) + coords = coordinates.unsqueeze(0).to(torch.float32) + patch_sizes = torch.tensor([tile_size_lv0], dtype=torch.float32, device=tile_features.device) + cls, _, _ = self._model(x, coords_xy=coords, patch_sizes=patch_sizes) + return cls.squeeze(0) + + +@register_encoder( + "moozy", + level="patient", + tile_encoder="lunit", + tile_encoder_output_variant="default", + output_variants={"default": {"encode_dim": 768}}, + default_output_variant="default", + supported_spacing_um=0.5, + precision="fp32", + source="AtlasAnalyticsLab/MOOZY", +) +class MOOZYPatientEncoder(PatientEncoder): + def __init__(self, *, output_variant: str | None = None): + from moozy.hf_hub import ensure_checkpoint + from moozy.models.factory import load_stage2_inference_model + + ckpt_path = ensure_checkpoint() + full_model = load_stage2_inference_model(ckpt_path, device=torch.device("cpu")) + self._slide_model = full_model.slide_encoder.eval() + self._case_transformer = full_model.case_transformer.eval() + self._device = preferred_default_device() + self._output_variant = resolve_requested_output_variant(output_variant) + + @property + def encode_dim(self) -> int: + return 768 + + @property + def device(self) -> torch.device: + return self._device + + def to(self, device: torch.device | str) -> "MOOZYPatientEncoder": + self._device = torch.device(device) + self._slide_model = self._slide_model.to(self._device) + self._case_transformer = self._case_transformer.to(self._device) + return self + + def encode_slide( + self, + tile_features: torch.Tensor, + coordinates: torch.Tensor | None = None, + *, + tile_size_lv0: int | None = None, + ) -> torch.Tensor: + if coordinates is None or tile_size_lv0 is None: + raise ValueError("MOOZY patient encoding requires coordinates and tile_size_lv0") + x = tile_features.unsqueeze(0).unsqueeze(0) + coords = 
coordinates.unsqueeze(0).to(torch.float32) + patch_sizes = torch.tensor([tile_size_lv0], dtype=torch.float32, device=tile_features.device) + cls, _, _ = self._slide_model(x, coords_xy=coords, patch_sizes=patch_sizes) + return cls.squeeze(0) + + def encode_patient(self, slide_embeddings: torch.Tensor) -> torch.Tensor: + return self._case_transformer(slide_embeddings) diff --git a/slide2vec/encoders/registry.py b/slide2vec/encoders/registry.py index ef4ce10..b2772b5 100644 --- a/slide2vec/encoders/registry.py +++ b/slide2vec/encoders/registry.py @@ -26,7 +26,7 @@ def resolve_encoder_level( ) -> str: """Resolve and validate one encoder level contract.""" level = str(require_encoder_metadata_field(encoder_name, metadata, "level")) - if level not in {"tile", "slide"}: + if level not in {"tile", "slide", "patient"}: raise ValueError(f"Unsupported encoder level '{level}'") return level @@ -101,7 +101,7 @@ def resolve_preprocessing_requirements( "source_encoder": encoder_name, } - if level == "slide": + if level in {"slide", "patient"}: tile_encoder_name = str( require_encoder_metadata_field(encoder_name, info, "tile_encoder") ) @@ -164,10 +164,10 @@ def resolve_encoder_output( f"Encoder '{encoder_name}' has invalid default_output_variant " f"'{default_output_variant}'" ) - if requested_output_variant is not None and level == "slide": + if requested_output_variant is not None and level in {"slide", "patient"}: raise ValueError( - f"Slide encoder '{encoder_name}' has a fixed output_variant; " - "do not override output_variant for slide encoders." + f"Slide encoder '{encoder_name}' (level={level}) has a fixed output_variant; " + "do not override output_variant for slide or patient encoders." ) output_variant = requested_output_variant or str(default_output_variant) @@ -196,6 +196,7 @@ def resolve_tile_dependency_output( resolved["encoder_name"] = encoder_name return resolved + # Both "slide" and "patient" declare tile_encoder / tile_encoder_output_variant. 
tile_encoder_name = str( require_encoder_metadata_field(encoder_name, info, "tile_encoder") ) diff --git a/slide2vec/inference.py b/slide2vec/inference.py index 3e2300d..40e970a 100644 --- a/slide2vec/inference.py +++ b/slide2vec/inference.py @@ -22,6 +22,7 @@ from transformers.image_processing_utils import BaseImageProcessor from slide2vec.api import ( + EmbeddedPatient, EmbeddedSlide, ExecutionOptions, PreprocessingConfig, @@ -30,11 +31,13 @@ ) from slide2vec.artifacts import ( HierarchicalEmbeddingArtifact, + PatientEmbeddingArtifact, SlideEmbeddingArtifact, TileEmbeddingArtifact, write_hierarchical_embeddings, load_array, load_metadata, + write_patient_embeddings, write_slide_embeddings, write_tile_embedding_metadata, write_tile_embeddings, @@ -58,6 +61,7 @@ from slide2vec.utils.coordinates import coordinate_arrays from slide2vec.utils.tiling_io import ( load_embedding_process_df, + load_patient_id_mapping, load_slide_manifest, load_tiling_process_df, load_tiling_result_from_row, @@ -288,6 +292,7 @@ def load_model( if resolved_level == "tile": transforms = encoder.get_transform() else: + # Both "slide" and "patient" declare tile_encoder for transform resolution. tile_enc_name = info["tile_encoder"] tile_enc_ov = info["tile_encoder_output_variant"] tile_enc_cls = encoder_registry.require(tile_enc_name) @@ -430,6 +435,166 @@ def embed_slides( raise +def _encode_slide_from_tiles( + loaded: LoadedModel, + tile_embeddings: torch.Tensor, + tiling_result, +) -> torch.Tensor: + """Run the slide encoder on already-computed tile embeddings. + + Returns a CPU tensor of shape ``(D,)``. 
+ """ + x_values, y_values = coordinate_arrays(tiling_result) + coordinates = np.column_stack((x_values, y_values)) + coordinate_tensor = torch.tensor(coordinates, dtype=torch.int, device=loaded.device) + features = tile_embeddings.to(loaded.device) + with torch.inference_mode(): + return loaded.model.encode_slide( + features, + coordinate_tensor, + tile_size_lv0=int(tiling_result.tile_size_lv0), + ).detach().cpu() + + +def embed_patients( + model, + slides, + *, + patient_id_map: dict[str, str] | None = None, + preprocessing: PreprocessingConfig, + execution: ExecutionOptions, +) -> list[EmbeddedPatient]: + """Tile slides and aggregate them into patient-level embeddings in memory. + + For each slide the tile encoder and slide encoder are run to produce a + slide-level embedding. Once all slides have been processed the slide + embeddings are grouped by ``patient_id`` and passed to the model's + ``encode_patient`` method. + + Args: + model: A patient-level ``Model`` instance (e.g. ``moozy``). + slides: Slides to process. + patient_id_map: Optional explicit ``{sample_id: patient_id}`` mapping. + When omitted, ``patient_id`` is looked up from each slide dict / + object attribute; slides without any ``patient_id`` are each + treated as their own patient. + preprocessing: Tiling and preprocessing configuration. + execution: Execution options (batch size, workers, etc.). + + Returns: + One :class:`~slide2vec.api.EmbeddedPatient` per unique patient, + ordered by first appearance. + """ + if model.level != "patient": + raise ValueError( + f"embed_patients() requires a patient-level model, but '{model.name}' " + f"has level='{model.level}'. Use embed_slides() for slide-level models." + ) + slide_records = [_coerce_slide_spec(slide) for slide in slides] + if not slide_records: + raise ValueError("At least one slide is required") + + # Resolve patient_id mapping: explicit dict > slide-level attribute > identity. 
+ # Use slide_records for sample_id keys (already normalised by _coerce_slide_spec) + # but read patient_id from the original slide input (SlideSpec has no patient_id). + if patient_id_map is None: + patient_id_map = {} + for s, sr in zip(slides, slide_records): + if isinstance(s, dict) and "patient_id" in s: + patient_id_map[sr.sample_id] = str(s["patient_id"]) + elif hasattr(s, "patient_id"): + patient_id_map[sr.sample_id] = str(s.patient_id) + + emit_progress( + "run.started", + model_name=model.name, + level=model.level, + device_mode=_describe_device_mode(model, execution), + slide_count=len(slide_records), + output_dir=str(execution.output_dir or ""), + ) + with _embedding_work_dir(execution.output_dir) as work_dir: + try: + emit_progress("tiling.started", slide_count=len(slide_records)) + prepared_slides, tiling_results, process_list_path = _prepare_tiled_slides( + slide_records, + preprocessing, + output_dir=work_dir, + num_workers=execution.num_preprocessing_workers, + ) + _emit_tiling_finished( + process_list_path, + expected_total=len(slide_records), + successful_slides=prepared_slides, + tiling_results=tiling_results, + ) + embeddable_slides, embeddable_tiling_results, _ = _partition_slides_by_tile_count( + prepared_slides, + tiling_results, + ) + emit_progress("embedding.started", slide_count=len(embeddable_slides)) + loaded = model._load_backend() + + # Per-slide: tile encoding → slide encoding, accumulate for patient agg. + # Ordered dict preserves first-appearance order of patients. 
+        patient_slide_embeddings: dict[str, list[tuple[str, torch.Tensor]]] = {}
+        for slide, tiling_result in zip(embeddable_slides, embeddable_tiling_results):
+            emit_progress(
+                "embedding.slide.started",
+                sample_id=slide.sample_id,
+                total_tiles=_num_embedding_items(tiling_result, preprocessing),
+            )
+            tile_embeddings = _compute_tile_embeddings_for_slide(
+                loaded,
+                model,
+                slide,
+                tiling_result,
+                preprocessing=preprocessing,
+                execution=execution,
+            )
+            slide_emb = _encode_slide_from_tiles(loaded, tile_embeddings, tiling_result)
+            patient_id = patient_id_map.get(slide.sample_id, slide.sample_id)
+            patient_slide_embeddings.setdefault(patient_id, []).append(
+                (slide.sample_id, slide_emb)
+            )
+            emit_progress(
+                "embedding.slide.finished",
+                sample_id=slide.sample_id,
+                num_tiles=_num_embedding_items(tiling_result, preprocessing),
+            )
+
+        # Patient aggregation.
+        result: list[EmbeddedPatient] = []
+        for patient_id, slide_embs_list in patient_slide_embeddings.items():
+            stacked = torch.stack([emb for _, emb in slide_embs_list], dim=0).to(loaded.device)
+            with torch.inference_mode():
+                patient_emb = loaded.model.encode_patient(stacked).detach().cpu()
+            result.append(
+                EmbeddedPatient(
+                    patient_id=patient_id,
+                    patient_embedding=patient_emb,
+                    slide_embeddings={sid: emb for sid, emb in slide_embs_list},
+                )
+            )
+
+        emit_progress(
+            "embedding.finished",
+            slide_count=len(embeddable_slides),
+            slides_completed=len(embeddable_slides),
+            tile_artifacts=0,
+            slide_artifacts=0,
+        )
+        emit_progress(
+            "run.finished",
+            output_dir=str(work_dir),
+            logs_dir=str(work_dir / "logs"),
+        )
+        return result
+    except Exception as exc:
+        emit_progress("run.failed", stage="embedding", error=str(exc))
+        raise
+
+
 def _select_embedding_path(
     *,
     model,
@@ -604,6 +769,10 @@ def run_pipeline(
     tiling_only: bool = False,
     execution: ExecutionOptions,
 ) -> RunResult:
+    if model.level == "patient" and not tiling_only:
+        patient_id_map = _resolve_patient_id_map(slides=slides, manifest_path=manifest_path)
+    else:
+        patient_id_map = None
     slide_records = _resolve_slides(slides=slides, manifest_path=manifest_path)
     if not slide_records:
         raise ValueError("At least one slide is required")
@@ -665,6 +834,36 @@ def run_pipeline(
     )
     emit_progress("embedding.started", slide_count=len(embeddable_slides))
 
+    if model.level == "patient":
+        tile_artifacts, slide_artifacts, patient_artifacts = _run_patient_pipeline(
+            model,
+            embeddable_slides=embeddable_slides,
+            embeddable_tiling_results=embeddable_tiling_results,
+            patient_id_map=patient_id_map,
+            preprocessing=resolved_preprocessing,
+            execution=execution,
+            output_dir=output_dir,
+        )
+        emit_progress(
+            "embedding.finished",
+            slide_count=len(embeddable_slides),
+            slides_completed=len(embeddable_slides),
+            tile_artifacts=len(tile_artifacts),
+            slide_artifacts=len(slide_artifacts),
+        )
+        emit_progress(
+            "run.finished",
+            output_dir=str(output_dir),
+            logs_dir=str(output_dir / "logs"),
+        )
+        return RunResult(
+            tile_artifacts=tile_artifacts,
+            hierarchical_artifacts=[],
+            slide_artifacts=slide_artifacts,
+            patient_artifacts=patient_artifacts,
+            process_list_path=process_list_path,
+        )
+
     if execution.num_gpus > 1:
         tile_artifacts, hierarchical_artifacts, slide_artifacts = _collect_distributed_pipeline_artifacts(
             model=model,
@@ -860,6 +1059,107 @@ def run_pipeline_with_coordinates(
         raise
 
 
+def _run_patient_pipeline(
+    model,
+    *,
+    embeddable_slides: Sequence[SlideSpec],
+    embeddable_tiling_results,
+    patient_id_map: dict[str, str],
+    preprocessing: PreprocessingConfig,
+    execution: ExecutionOptions,
+    output_dir: Path,
+) -> tuple[list[TileEmbeddingArtifact], list[SlideEmbeddingArtifact], list[PatientEmbeddingArtifact]]:
+    """Run the patient-level embedding pipeline.
+
+    For each slide: extract tile features and compute a slide-level embedding.
+    After processing all slides for a patient: aggregate slide embeddings into
+    a single patient embedding via the case transformer.
+    """
+    loaded = model._load_backend()
+    tile_artifacts: list[TileEmbeddingArtifact] = []
+    slide_artifacts: list[SlideEmbeddingArtifact] = []
+
+    # Accumulate per-patient: {patient_id: [slide_embedding, ...]}
+    patient_slide_embeddings: dict[str, list[torch.Tensor]] = {}
+    patient_slide_counts: dict[str, int] = {}
+
+    for slide, tiling_result in zip(embeddable_slides, embeddable_tiling_results):
+        emit_progress(
+            "embedding.slide.started",
+            sample_id=slide.sample_id,
+            total_tiles=_num_embedding_items(tiling_result, preprocessing),
+        )
+        tile_embeddings = _compute_tile_embeddings_for_slide(
+            loaded,
+            model,
+            slide,
+            tiling_result,
+            preprocessing=preprocessing,
+            execution=execution,
+        )
+
+        if execution.save_tile_embeddings:
+            tile_artifact = _write_tile_embedding_artifact(
+                slide.sample_id,
+                tile_embeddings,
+                execution=execution,
+                metadata=_build_tile_embedding_metadata(
+                    model,
+                    tiling_result=tiling_result,
+                    image_path=slide.image_path,
+                    mask_path=slide.mask_path,
+                    tile_size_lv0=int(tiling_result.tile_size_lv0),
+                    backend=_resolve_slide_backend(preprocessing, tiling_result),
+                ),
+            )
+            tile_artifacts.append(tile_artifact)
+
+        emit_progress(
+            "aggregation.started",
+            sample_id=slide.sample_id,
+            total_tiles=_num_embedding_items(tiling_result, preprocessing),
+        )
+        slide_emb = _encode_slide_from_tiles(loaded, tile_embeddings, tiling_result)
+        emit_progress("aggregation.finished", sample_id=slide.sample_id, has_latents=False)
+
+        if execution.save_slide_embeddings:
+            slide_artifact = _write_slide_embedding_artifact(
+                slide.sample_id,
+                slide_emb,
+                execution=execution,
+                metadata=_build_slide_embedding_metadata(model, image_path=slide.image_path),
+            )
+            slide_artifacts.append(slide_artifact)
+
+        patient_id = patient_id_map.get(slide.sample_id, slide.sample_id)
+        patient_slide_embeddings.setdefault(patient_id, []).append(slide_emb)
+        patient_slide_counts[patient_id] = patient_slide_counts.get(patient_id, 0) + 1
+
+        emit_progress(
+            "embedding.slide.finished",
+            sample_id=slide.sample_id,
+            num_tiles=_num_embedding_items(tiling_result, preprocessing),
+        )
+
+    # Aggregate per patient
+    patient_artifacts: list[PatientEmbeddingArtifact] = []
+    for patient_id, slide_embs in patient_slide_embeddings.items():
+        stacked = torch.stack(slide_embs, dim=0).to(loaded.device)
+        with torch.inference_mode():
+            patient_emb = loaded.model.encode_patient(stacked).detach().cpu()
+        artifact = write_patient_embeddings(
+            patient_id,
+            patient_emb,
+            output_dir=output_dir,
+            output_format=execution.output_format,
+            metadata={"encoder_name": model.name, "encoder_level": model.level},
+            num_slides=patient_slide_counts[patient_id],
+        )
+        patient_artifacts.append(artifact)
+
+    return tile_artifacts, slide_artifacts, patient_artifacts
+
+
 def _collect_local_pipeline_artifacts(
     *,
     model,
@@ -2140,6 +2440,37 @@ def _resolve_slides(*, slides=None, manifest_path: str | Path | None = None) ->
     return [_coerce_slide_spec(slide) for slide in load_slide_manifest(manifest_path)]
 
 
+def _resolve_patient_id_map(
+    *,
+    slides=None,
+    manifest_path: str | Path | None = None,
+) -> dict[str, str]:
+    """Return {sample_id: patient_id} for patient-level models.
+
+    Reads the 'patient_id' column from the manifest CSV, or falls back to
+    inspecting slide dicts for a 'patient_id' key. Raises if neither is found.
+    """
+    if manifest_path is not None:
+        return load_patient_id_mapping(manifest_path)
+    if slides is not None:
+        result = {}
+        for slide in slides:
+            if isinstance(slide, dict) and "patient_id" in slide:
+                result[str(slide["sample_id"])] = str(slide["patient_id"])
+            elif hasattr(slide, "patient_id"):
+                result[str(slide.sample_id)] = str(slide.patient_id)
+            else:
+                raise ValueError(
+                    "Patient-level models require a 'patient_id' for every slide. "
+                    "Provide a manifest CSV with a 'patient_id' column, or include "
+                    "'patient_id' in each slide dict when calling programmatically."
+                )
+        return result
+    raise ValueError(
+        "Either slides or manifest_path must be provided for patient-level models."
+    )
+
+
 def _coerce_slide_spec(slide) -> SlideSpec:
     if isinstance(slide, SlideSpec):
         return slide
@@ -2273,7 +2604,7 @@ def _emit_tiling_finished(
 
 
 def _should_persist_tile_embeddings(model, execution: ExecutionOptions) -> bool:
-    if model.level == "slide":
+    if model.level in {"slide", "patient"}:
         return bool(execution.save_tile_embeddings)
     return True
 
diff --git a/slide2vec/utils/tiling_io.py b/slide2vec/utils/tiling_io.py
index 598354b..d7ed237 100644
--- a/slide2vec/utils/tiling_io.py
+++ b/slide2vec/utils/tiling_io.py
@@ -118,6 +118,22 @@ def load_slide_manifest(csv_path: str | Path) -> list[SlideSpec]:
     ]
 
 
+def load_patient_id_mapping(csv_path: str | Path) -> dict[str, str]:
+    """Return {sample_id: patient_id} from an input CSV with a 'patient_id' column.
+
+    Raises ValueError if the 'patient_id' column is absent.
+    """
+    manifest_path = Path(csv_path).resolve()
+    df = pd.read_csv(manifest_path)
+    if "patient_id" not in df.columns:
+        raise ValueError(
+            f"Input CSV {manifest_path} is missing the required 'patient_id' column "
+            "for patient-level models. Add a 'patient_id' column that groups slides "
+            "belonging to the same patient."
+        )
+    return dict(zip(df["sample_id"].astype(str), df["patient_id"].astype(str)))
+
+
 def _load_base_process_df(process_list_path: str | Path) -> pd.DataFrame:
     process_list_path = Path(process_list_path)
     df = pd.read_csv(process_list_path)
diff --git a/tasks/lessons.md b/tasks/lessons.md
index 9eaf26f..7332be2 100644
--- a/tasks/lessons.md
+++ b/tasks/lessons.md
@@ -15,6 +15,13 @@
 - When a user says a file is local-only, never add it to git history or include it in a PR even if they ask to modify it; keep the change local and confine the PR to repository-tracked files.
 - When the user explicitly rejects backward compatibility, remove compatibility shims and update writers/tests to emit the new schema instead of accepting old inputs.
 
+## 2026-04-11
+
+- When a config namespace may come from tests, fixtures, or older CLI payloads, use `getattr(..., default)` for optional flags instead of assuming every field is present.
+- When an upstream helper disappears, add a local wrapper at the call site rather than importing the missing symbol directly from the dependency.
+- When the sibling checkout for a dependency already exports the helper you need, align with that source instead of compensating for an older installed wheel.
+- When a field is part of the current config contract, update the test fixtures to include it rather than adding a backward-compatibility branch in production code.
+
 ## 2026-02-10
 
 - When a git submodule shows unexpected local modifications, explicitly confirm scope with the user before editing.
diff --git a/tests/test_output_consistency.py b/tests/test_output_consistency.py
index 8f47123..b648f47 100644
--- a/tests/test_output_consistency.py
+++ b/tests/test_output_consistency.py
@@ -54,6 +54,7 @@
     name="prism",  # override (default: null)
     batch_size=8,  # override (default: 256)
     save_tile_embeddings=True,
+    save_slide_embeddings=False,
     save_latents=False,
 )
 
diff --git a/tests/test_regression_core.py b/tests/test_regression_core.py
index 8ed10d1..cb007be 100644
--- a/tests/test_regression_core.py
+++ b/tests/test_regression_core.py
@@ -355,6 +355,7 @@ def test_execution_options_from_config_maps_cli_fields(tmp_path: Path):
         model=SimpleNamespace(
             batch_size=4,
             save_tile_embeddings=True,
+            save_slide_embeddings=False,
             save_latents=True,
         ),
         speed=SimpleNamespace(
@@ -389,7 +390,12 @@ def test_execution_options_from_config_defaults_preprocessing_workers_to_cpu_bud
 
     cfg = SimpleNamespace(
         output_dir=str(tmp_path),
-        model=SimpleNamespace(batch_size=4, save_tile_embeddings=False, save_latents=False),
+        model=SimpleNamespace(
+            batch_size=4,
+            save_tile_embeddings=False,
+            save_slide_embeddings=False,
+            save_latents=False,
+        ),
         speed=SimpleNamespace(
             precision="fp16",
             num_dataloader_workers=2,
@@ -407,7 +413,12 @@
 def test_execution_options_from_config_preserves_auto_num_workers(tmp_path: Path):
     cfg = SimpleNamespace(
         output_dir=str(tmp_path),
-        model=SimpleNamespace(batch_size=4, save_tile_embeddings=False, save_latents=False),
+        model=SimpleNamespace(
+            batch_size=4,
+            save_tile_embeddings=False,
+            save_slide_embeddings=False,
+            save_latents=False,
+        ),
         speed=SimpleNamespace(
             precision="fp16",
             num_dataloader_workers=None,
@@ -431,6 +442,7 @@ def test_execution_options_from_config_defaults_to_all_available_gpus_when_unset
         model=SimpleNamespace(
             batch_size=4,
             save_tile_embeddings=False,
+            save_slide_embeddings=False,
             save_latents=False,
         ),
         speed=SimpleNamespace(
@@ -459,6 +471,7 @@ def test_execution_options_from_config_forces_fp32_for_cpu_runs(monkeypatch, tmp
         model=SimpleNamespace(
             batch_size=1,
             save_tile_embeddings=False,
+            save_slide_embeddings=False,
             save_latents=False,
         ),
         speed=SimpleNamespace(
@@ -559,6 +572,7 @@ def test_cli_build_model_and_pipeline_delegates_to_public_api(monkeypatch, tmp_p
             batch_size=4,
             allow_non_recommended_settings=True,
             save_tile_embeddings=False,
+            save_slide_embeddings=False,
             save_latents=False,
         ),
         speed=SimpleNamespace(
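Reviewer note: the manifest contract enforced by `load_patient_id_mapping` can be exercised in isolation. The sketch below mirrors the helper added in `slide2vec/utils/tiling_io.py`, minus the `Path.resolve()` step so it accepts an in-memory buffer; the `slide_a`/`case_1` rows are invented purely for illustration.

```python
import io

import pandas as pd


def load_patient_id_mapping(csv_source) -> dict[str, str]:
    # Mirrors the new helper in slide2vec/utils/tiling_io.py:
    # fail fast when the 'patient_id' column is absent, otherwise
    # map each sample_id to its patient_id (both coerced to str).
    df = pd.read_csv(csv_source)
    if "patient_id" not in df.columns:
        raise ValueError("Input CSV is missing the required 'patient_id' column")
    return dict(zip(df["sample_id"].astype(str), df["patient_id"].astype(str)))


csv_text = "sample_id,patient_id\nslide_a,case_1\nslide_b,case_1\nslide_c,case_2\n"
mapping = load_patient_id_mapping(io.StringIO(csv_text))
# mapping == {"slide_a": "case_1", "slide_b": "case_1", "slide_c": "case_2"}
```

Grouping `slide_a` and `slide_b` under `case_1` is exactly the shape `_run_patient_pipeline` consumes through `patient_id_map.get(slide.sample_id, slide.sample_id)` before stacking the per-patient slide embeddings.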