diff --git a/.gitignore b/.gitignore index 267edc4..120cb5c 100644 --- a/.gitignore +++ b/.gitignore @@ -165,6 +165,5 @@ cython_debug/ output/ outputs/ archive/ -tasks/ docs/20*-*.md data/ diff --git a/slide2vec/api.py b/slide2vec/api.py index 443d1f1..135e57b 100644 --- a/slide2vec/api.py +++ b/slide2vec/api.py @@ -159,8 +159,9 @@ class ExecutionOptions: output_format: str = "pt" #: Number of tiles per forward pass. batch_size: int = 32 - #: DataLoader worker count. ``None`` means auto (capped by CPU / SLURM limit). - num_workers: int | None = None + #: DataLoader worker count per GPU rank. ``None`` means auto + #: (capped by CPU / SLURM limit, then split across the resolved GPU count). + num_workers_per_gpu: int | None = None #: Tiling worker count. ``None`` means auto (capped by CPU / SLURM limit). num_preprocessing_workers: int | None = None #: Number of GPUs to use. ``None`` defaults to all available GPUs. @@ -170,8 +171,6 @@ class ExecutionOptions: precision: str | None = None #: DataLoader prefetch queue depth per worker (default ``4``). prefetch_factor: int = 4 - #: Keep DataLoader workers alive between batches (default ``True``). - persistent_workers: bool = True #: Persist tile embeddings to disk when running a slide-level model. save_tile_embeddings: bool = False #: Persist slide embeddings to disk when running a patient-level model. 
@@ -183,14 +182,13 @@ class ExecutionOptions: def from_config(cls, cfg: Any, *, run_on_cpu: bool = False) -> "ExecutionOptions": configured_num_gpus = cfg.speed.num_gpus requested_precision = normalize_precision_name(cfg.speed.precision) - num_workers = cfg.speed.num_dataloader_workers + num_workers_per_gpu = cfg.speed.num_dataloader_workers prefetch_factor = int(cfg.speed.prefetch_factor_embedding) - persistent_workers = bool(cfg.speed.persistent_workers_embedding) return cls( output_dir=Path(cfg.output_dir), output_format="pt", batch_size=int(cfg.model.batch_size), - num_workers=int(num_workers) if num_workers is not None else None, + num_workers_per_gpu=int(num_workers_per_gpu) if num_workers_per_gpu is not None else None, num_preprocessing_workers=( int(cfg.speed.num_preprocessing_workers) if cfg.speed.num_preprocessing_workers is not None @@ -199,7 +197,6 @@ def from_config(cls, cfg: Any, *, run_on_cpu: bool = False) -> "ExecutionOptions num_gpus=1 if run_on_cpu else (int(configured_num_gpus) if configured_num_gpus is not None else None), precision="fp32" if run_on_cpu else requested_precision, prefetch_factor=prefetch_factor, - persistent_workers=persistent_workers, save_tile_embeddings=bool(cfg.model.save_tile_embeddings), save_slide_embeddings=bool(cfg.model.save_slide_embeddings), save_latents=bool(cfg.model.save_latents), @@ -222,23 +219,25 @@ def __post_init__(self) -> None: object.__setattr__(self, "num_preprocessing_workers", capped_num_preprocessing_workers) logger = logging.getLogger(__name__) cap_source = f"slurm_cpu_limit={slurm_limit}" if slurm_limit is not None else f"cpu_count={cpu_count}" - resolved_num_workers = self.resolved_num_workers() - num_workers_label = ( + resolved_num_workers = self.resolved_num_workers_per_gpu() + num_workers_per_gpu_label = ( f"{resolved_num_workers} (requested=auto)" - if self.num_workers is None + if self.num_workers_per_gpu is None else str(resolved_num_workers) ) logger.info( - "ExecutionOptions: 
num_workers=%s, num_preprocessing_workers=%d " + "ExecutionOptions: num_workers_per_gpu=%s, num_preprocessing_workers=%d " "(preprocessing cap=%d via %s)", - num_workers_label, + num_workers_per_gpu_label, capped_num_preprocessing_workers, cap, cap_source, ) - def resolved_num_workers(self) -> int: - return cpu_worker_limit() if self.num_workers is None else int(self.num_workers) + def resolved_num_workers_per_gpu(self) -> int: + if self.num_workers_per_gpu is not None: + return self.num_workers_per_gpu + return max(1, cpu_worker_limit() // self.num_gpus) def with_output_dir(self, output_dir: PathLike | None) -> "ExecutionOptions": if output_dir is None: diff --git a/slide2vec/configs/default.yaml b/slide2vec/configs/default.yaml index 02d3241..d138c4c 100644 --- a/slide2vec/configs/default.yaml +++ b/slide2vec/configs/default.yaml @@ -42,7 +42,7 @@ tiling: sthresh_up: 255 # upper threshold value for scaling the binary mask mthresh: 7 # median filter size (positive, odd integer) close: 4 # additional morphological closing to apply following initial thresholding (positive integer) - method: "hsv" # tissue segmentation method: "hsv", "otsu", "threshold", or "sam2" + method: # tissue segmentation method: "hsv", "otsu", "threshold", or "sam2"; ignored when precomputed tissue masks are provided sam2_checkpoint_path: # optional when method="sam2"; if empty, hs2p downloads the default AtlasPatch checkpoint from Hugging Face sam2_config_path: # optional local override for the SAM2 model config; if empty, hs2p downloads the default AtlasPatch config from Hugging Face sam2_device: "cpu" # device for SAM2 inference, e.g. 
"cpu", "cuda", or "cuda:0" @@ -71,12 +71,11 @@ tiling: speed: precision: # model inference precision ["fp32", "fp16", "bf16"]; if not set, determined automatically based on model recommendations - num_dataloader_workers: # number of DataLoader worker processes for reading tiles during embedding; defaults to auto (job CPU budget, except cuCIM on-the-fly uses cpu_budget // speed.num_cucim_workers) + num_dataloader_workers: # number of DataLoader worker processes per GPU rank for reading tiles during embedding; defaults to auto (job CPU budget split across GPUs, except cuCIM on-the-fly uses per-GPU budget // speed.num_cucim_workers) num_gpus: # number of GPUs to use for feature extraction; defaults to all available GPUs num_preprocessing_workers: # number of workers for hs2p tiling (WSI reading, JPEG encoding, tar writing); defaults to the runtime CPU budget capped at 64 num_cucim_workers: 4 # number of internal cucim threads per read_region call (embedding path, on-the-fly only); DataLoader workers are auto-set to cpu_count // num_cucim_workers prefetch_factor_embedding: 4 # prefetch factor for tile embedding dataloaders - persistent_workers_embedding: true # keep DataLoader workers alive across epochs/batches wandb: enable: false diff --git a/slide2vec/distributed/direct_embed_worker.py b/slide2vec/distributed/direct_embed_worker.py index e28f4de..0989676 100644 --- a/slide2vec/distributed/direct_embed_worker.py +++ b/slide2vec/distributed/direct_embed_worker.py @@ -119,20 +119,24 @@ def main(argv=None) -> int: return 0 assigned_slides = [paired_by_sample[sample_id][0] for sample_id in assigned_ids] assigned_tiling_results = [paired_by_sample[sample_id][1] for sample_id in assigned_ids] - embedded_slides = _compute_embedded_slides( - model, - assigned_slides, - assigned_tiling_results, - preprocessing=preprocessing, - execution=execution, - ) - for embedded_slide in embedded_slides: + + def _persist_embedded_slide(slide, tiling_result, embedded_slide) -> None: 
payload = { "tile_embeddings": _to_cpu_payload(embedded_slide.tile_embeddings), "slide_embedding": _to_cpu_payload(embedded_slide.slide_embedding), "latents": _to_cpu_payload(embedded_slide.latents), } torch.save(payload, coordination_dir / f"{embedded_slide.sample_id}.embedded.pt") + + _compute_embedded_slides( + model, + assigned_slides, + assigned_tiling_results, + preprocessing=preprocessing, + execution=execution, + on_embedded_slide=_persist_embedded_slide, + collect_results=False, + ) return 0 finally: if dist.is_available() and dist.is_initialized(): diff --git a/slide2vec/distributed/pipeline_worker.py b/slide2vec/distributed/pipeline_worker.py index ec395c1..8421e50 100644 --- a/slide2vec/distributed/pipeline_worker.py +++ b/slide2vec/distributed/pipeline_worker.py @@ -19,8 +19,8 @@ def main(argv=None) -> int: import slide2vec.distributed as distributed from slide2vec.api import Model from slide2vec.inference import ( + _build_incremental_persist_callback, _compute_embedded_slides, - _persist_embedded_slide, load_successful_tiled_slides, ) from slide2vec.progress import JsonlProgressReporter, activate_progress_reporter @@ -70,21 +70,21 @@ def main(argv=None) -> int: ) context = activate_progress_reporter(reporter) if reporter is not None else nullcontext() with context: - embedded_slides = _compute_embedded_slides( + persist_callback, _, _ = _build_incremental_persist_callback( + model=model, + preprocessing=preprocessing, + execution=execution, + process_list_path=None, + ) + _compute_embedded_slides( model, assigned_slides, assigned_tiling_results, preprocessing=preprocessing, execution=execution, + on_embedded_slide=persist_callback, + collect_results=False, ) - for embedded_slide, tiling_result in zip(embedded_slides, assigned_tiling_results): - _persist_embedded_slide( - model, - embedded_slide, - tiling_result, - preprocessing=preprocessing, - execution=execution, - ) return 0 finally: if dist.is_available() and dist.is_initialized(): diff --git 
a/slide2vec/encoders/validation.py b/slide2vec/encoders/validation.py index bbc94e1..e84a880 100644 --- a/slide2vec/encoders/validation.py +++ b/slide2vec/encoders/validation.py @@ -63,14 +63,18 @@ def validate_encoder_config( if not mismatches: return - message = ( - f"Model '{encoder_name}' is configured with " - f"{'; '.join(mismatches)}. " - "Set `model.allow_non_recommended_settings=true` in YAML/CLI or " - "`allow_non_recommended_settings=True` in `Model.from_preset(...)` " - "to continue with a warning." - ) if allow_non_recommended: - logger.warning(message) + logger.warning( + f"Model '{encoder_name}' is configured with " + f"{'; '.join(mismatches)}. " + "Warning-only mode is enabled because " + "`allow_non_recommended_settings=True`." + ) else: - raise ValueError(message) + raise ValueError( + f"Model '{encoder_name}' is configured with " + f"{'; '.join(mismatches)}. " + "Set `model.allow_non_recommended_settings=true` in YAML/CLI or " + "`allow_non_recommended_settings=True` in `Model.from_preset(...)` " + "to continue." 
+ ) diff --git a/slide2vec/inference.py b/slide2vec/inference.py index 2a68471..eb56e42 100644 --- a/slide2vec/inference.py +++ b/slide2vec/inference.py @@ -98,25 +98,29 @@ def _serialize_execution( *, preprocessing: PreprocessingConfig | None = None, ) -> dict[str, Any]: - effective_num_workers = None + effective_num_workers_per_gpu = None if preprocessing is not None and preprocessing.on_the_fly and preprocessing.read_tiles_from is None: - effective_num_workers, _ = _resolve_on_the_fly_num_workers(preprocessing.num_cucim_workers) + effective_num_workers_per_gpu, _ = _resolve_on_the_fly_num_workers( + preprocessing.num_cucim_workers, + num_gpus=execution.num_gpus, + ) return runtime_serialization.serialize_execution( execution, - effective_num_workers=effective_num_workers, + effective_num_workers_per_gpu=effective_num_workers_per_gpu, ) -def _resolve_on_the_fly_num_workers(num_cucim_workers: int) -> tuple[int, str]: +def _resolve_on_the_fly_num_workers(num_cucim_workers: int, num_gpus: int) -> tuple[int, str]: if int(num_cucim_workers) < 1: raise ValueError("num_cucim_workers must be at least 1") cpu_count = os.cpu_count() or 1 - worker_budget = cpu_worker_limit() + worker_budget = max(1, cpu_worker_limit() // max(1, int(num_gpus))) details = [f"cpu_count={cpu_count}"] slurm_limit = slurm_cpu_limit() if slurm_limit is not None: details.append(f"slurm_cpu_limit={slurm_limit}") + details.append(f"num_gpus={num_gpus}") effective_num_workers = max(1, worker_budget // num_cucim_workers) details.append(f"num_cucim_workers={num_cucim_workers}") return effective_num_workers, " // ".join(details) @@ -131,13 +135,16 @@ def _log_on_the_fly_worker_override_once( return if not any(runtime_tiling.resolve_slide_backend(preprocessing, tiling_result) == "cucim" for tiling_result in tiling_results): return - effective_num_workers, worker_context = _resolve_on_the_fly_num_workers(preprocessing.num_cucim_workers) - if effective_num_workers == execution.num_workers: + 
effective_num_workers_per_gpu, worker_context = _resolve_on_the_fly_num_workers( + preprocessing.num_cucim_workers, + num_gpus=execution.num_gpus, + ) + if effective_num_workers_per_gpu == execution.resolved_num_workers_per_gpu(): return logging.getLogger(__name__).info( - f"on-the-fly mode: setting DataLoader num_workers={effective_num_workers} " + f"on-the-fly mode: setting DataLoader num_workers_per_gpu={effective_num_workers_per_gpu} " f"({worker_context}); " - f"ignoring speed.num_dataloader_workers={execution.num_workers}" + f"ignoring speed.num_workers_per_gpu={execution.num_workers_per_gpu}" ) @@ -183,6 +190,13 @@ def _uses_cuda_runtime(device) -> bool: return str(device).startswith("cuda") and torch.cuda.is_available() +def _slide_encode_autocast_ctx(device, precision: str | None): + autocast_dtype = _autocast_dtype(torch, precision) if precision is not None else None + if autocast_dtype is None or not _uses_cuda_runtime(device): + return nullcontext() + return torch.autocast(device_type="cuda", dtype=autocast_dtype) + + def _make_slide_spec( *, sample_id: str, @@ -388,6 +402,8 @@ def _encode_slide_from_tiles( loaded: LoadedModel, tile_embeddings: torch.Tensor, tiling_result, + *, + execution: ExecutionOptions | None = None, ) -> torch.Tensor: """Run the slide encoder on already-computed tile embeddings. 
@@ -397,12 +413,13 @@ def _encode_slide_from_tiles( coordinates = np.column_stack((x_values, y_values)) coordinate_tensor = torch.tensor(coordinates, dtype=torch.int, device=loaded.device) features = tile_embeddings.to(loaded.device) - with torch.inference_mode(): - return loaded.model.encode_slide( - features, - coordinate_tensor, - tile_size_lv0=int(tiling_result.tile_size_lv0), - ).detach().cpu() + with _slide_encode_autocast_ctx(loaded.device, None if execution is None else execution.precision): + with torch.inference_mode(): + return loaded.model.encode_slide( + features, + coordinate_tensor, + tile_size_lv0=int(tiling_result.tile_size_lv0), + ).detach().cpu() def embed_patients( @@ -506,7 +523,12 @@ def embed_patients( preprocessing=preprocessing, execution=execution, ) - slide_emb = _encode_slide_from_tiles(loaded, tile_embeddings, tiling_result) + slide_emb = _encode_slide_from_tiles( + loaded, + tile_embeddings, + tiling_result, + execution=execution, + ) patient_id = patient_id_map.get(slide.sample_id, slide.sample_id) patient_slide_embeddings.setdefault(patient_id, []).append( (slide.sample_id, slide_emb) @@ -701,12 +723,13 @@ def aggregate_tiles( if not torch.is_tensor(tile_features): tile_features = torch.as_tensor(tile_features) tile_features = tile_features.to(loaded.device) - with torch.inference_mode(): - embedding = loaded.model.encode_slide( - tile_features, - coordinate_tensor, - tile_size_lv0=int(tiling_result.tile_size_lv0), - ) + with _slide_encode_autocast_ctx(loaded.device, execution.precision): + with torch.inference_mode(): + embedding = loaded.model.encode_slide( + tile_features, + coordinate_tensor, + tile_size_lv0=int(tiling_result.tile_size_lv0), + ) latents = None slide_artifact = runtime_embedding.write_slide_embedding_artifact( artifact.sample_id, @@ -888,6 +911,7 @@ def run_pipeline( preprocessing=resolved_preprocessing, execution=execution, on_embedded_slide=local_persist_callback, + collect_results=False, ) tile_artifacts, 
hierarchical_artifacts, slide_artifacts = _collect_pipeline_artifacts( embeddable_slides, @@ -1005,24 +1029,32 @@ def run_pipeline_with_coordinates( slide_artifacts=slide_artifacts, process_list_path=process_list_path, ) - embedded_slides = _compute_embedded_slides( - model, - embeddable_slides, - embeddable_tiling_results, + local_persist_callback, tile_or_hier_artifacts, slide_artifacts = _build_incremental_persist_callback( + model=model, preprocessing=resolved_preprocessing, execution=execution, + process_list_path=process_list_path, ) - tile_artifacts, hierarchical_artifacts, slide_artifacts = _collect_local_pipeline_artifacts( - model=model, - embedded_slides=embedded_slides, - tiling_results=embeddable_tiling_results, + _compute_embedded_slides( + model, + embeddable_slides, + embeddable_tiling_results, preprocessing=resolved_preprocessing, execution=execution, + on_embedded_slide=local_persist_callback, + collect_results=False, ) + tile_artifacts: list[TileEmbeddingArtifact] = [] + hierarchical_artifacts: list[HierarchicalEmbeddingArtifact] = [] + for artifact in tile_or_hier_artifacts: + if isinstance(artifact, HierarchicalEmbeddingArtifact): + hierarchical_artifacts.append(artifact) + elif artifact is not None: + tile_artifacts.append(artifact) return RunResult( tile_artifacts=tile_artifacts, hierarchical_artifacts=hierarchical_artifacts, - slide_artifacts=slide_artifacts, + slide_artifacts=list(slide_artifacts), process_list_path=process_list_path, ) except Exception as exc: @@ -1090,7 +1122,12 @@ def _run_patient_pipeline( sample_id=slide.sample_id, total_tiles=_num_embedding_items(tiling_result, preprocessing), ) - slide_emb = _encode_slide_from_tiles(loaded, tile_embeddings, tiling_result) + slide_emb = _encode_slide_from_tiles( + loaded, + tile_embeddings, + tiling_result, + execution=execution, + ) emit_progress("aggregation.finished", sample_id=slide.sample_id, has_latents=False) if execution.save_slide_embeddings: @@ -1370,6 +1407,7 @@ def 
_compute_embedded_slides( preprocessing: PreprocessingConfig, execution: ExecutionOptions, on_embedded_slide: Callable[[SlideSpec, Any, EmbeddedSlide], None] | None = None, + collect_results: bool = True, ) -> list[EmbeddedSlide]: loaded = model._load_backend() embedded_slides: list[EmbeddedSlide] = [] @@ -1424,7 +1462,8 @@ def _compute_embedded_slides( slide_embedding=slide_embedding, latents=latents, ) - embedded_slides.append(embedded_slide) + if collect_results: + embedded_slides.append(embedded_slide) if on_embedded_slide is not None: on_embedded_slide(slide, tiling_result, embedded_slide) emit_progress( @@ -1508,10 +1547,12 @@ def _compute_tile_embeddings_for_slide( loader_kwargs = _embedding_dataloader_kwargs(loaded, execution) resolved_backend = runtime_tiling.resolve_slide_backend(preprocessing, tiling_result) if preprocessing.on_the_fly and preprocessing.read_tiles_from is None and resolved_backend == "cucim": - effective_num_workers, _ = _resolve_on_the_fly_num_workers(preprocessing.num_cucim_workers) + effective_num_workers, _ = _resolve_on_the_fly_num_workers( + preprocessing.num_cucim_workers, + num_gpus=execution.num_gpus, + ) loader_kwargs["num_workers"] = effective_num_workers if effective_num_workers == 0: - loader_kwargs.pop("persistent_workers", None) loader_kwargs.pop("prefetch_factor", None) _configure_cucim_worker_stderr(loader_kwargs, backend=resolved_backend) if batch_sampler is not None: @@ -1525,7 +1566,7 @@ def _compute_tile_embeddings_for_slide( **loader_kwargs, ) def _compute_embeddings(): - return _run_forward_pass( + _batch_indices, tile_embeddings = _run_forward_pass( dataloader, loaded, autocast_context, @@ -1534,6 +1575,7 @@ def _compute_embeddings(): total_items=len(dataset), unit_label="tile", ) + return tile_embeddings if resolved_backend == "cucim": tile_embeddings = run_with_filtered_stderr(_compute_embeddings) @@ -1587,10 +1629,12 @@ def _compute_hierarchical_embeddings_for_slide( loader_kwargs = 
_embedding_dataloader_kwargs(loaded, execution) resolved_backend = runtime_tiling.resolve_slide_backend(preprocessing, tiling_result) if resolved_backend == "cucim": - effective_num_workers, _ = _resolve_on_the_fly_num_workers(preprocessing.num_cucim_workers) + effective_num_workers, _ = _resolve_on_the_fly_num_workers( + preprocessing.num_cucim_workers, + num_gpus=execution.num_gpus, + ) loader_kwargs["num_workers"] = effective_num_workers if effective_num_workers == 0: - loader_kwargs.pop("persistent_workers", None) loader_kwargs.pop("prefetch_factor", None) _configure_cucim_worker_stderr( loader_kwargs, @@ -1620,7 +1664,6 @@ def _compute_embeddings(): sample_id=slide.sample_id, total_items=len(dataset), unit_label="tile", - return_indices=True, ) if resolved_backend == "cucim": @@ -1670,10 +1713,12 @@ def _compute_hierarchical_embedding_shard_for_slide( loader_kwargs = _embedding_dataloader_kwargs(loaded, execution) resolved_backend = runtime_tiling.resolve_slide_backend(preprocessing, tiling_result) if resolved_backend == "cucim": - effective_num_workers, _worker_context = _resolve_on_the_fly_num_workers(preprocessing.num_cucim_workers) + effective_num_workers, _worker_context = _resolve_on_the_fly_num_workers( + preprocessing.num_cucim_workers, + num_gpus=execution.num_gpus, + ) loader_kwargs["num_workers"] = effective_num_workers if effective_num_workers == 0: - loader_kwargs.pop("persistent_workers", None) loader_kwargs.pop("prefetch_factor", None) _configure_cucim_worker_stderr( loader_kwargs, @@ -1699,7 +1744,6 @@ def _compute_embeddings(): sample_id=slide.sample_id, total_items=len(dataset), unit_label="tile", - return_indices=True, ) if resolved_backend == "cucim": @@ -1734,12 +1778,13 @@ def _aggregate_tile_embeddings_for_slide( if not torch.is_tensor(tile_embeddings): tile_embeddings = torch.as_tensor(tile_embeddings) features = tile_embeddings.to(loaded.device) - with torch.inference_mode(): - slide_embedding = loaded.model.encode_slide( - features, - 
coordinate_tensor, - tile_size_lv0=int(tiling_result.tile_size_lv0), - ).detach().cpu() + with _slide_encode_autocast_ctx(loaded.device, execution.precision): + with torch.inference_mode(): + slide_embedding = loaded.model.encode_slide( + features, + coordinate_tensor, + tile_size_lv0=int(tiling_result.tile_size_lv0), + ).detach().cpu() latents = None return slide_embedding, latents diff --git a/slide2vec/runtime/batching.py b/slide2vec/runtime/batching.py index 8890ecb..16d1165 100644 --- a/slide2vec/runtime/batching.py +++ b/slide2vec/runtime/batching.py @@ -20,17 +20,6 @@ def uses_cuda_runtime(device) -> bool: return str(device).startswith("cuda") and torch.cuda.is_available() -def embedding_dataloader_kwargs(loaded: LoadedModel, execution) -> dict[str, Any]: - resolved_num_workers = execution.resolved_num_workers() - kwargs: dict[str, Any] = { - "num_workers": resolved_num_workers, - "pin_memory": uses_cuda_runtime(loaded.device), - } - if resolved_num_workers > 0: - kwargs["persistent_workers"] = bool(execution.persistent_workers) - kwargs["prefetch_factor"] = int(execution.prefetch_factor) - return kwargs - def should_suppress_cucim_dataloader_stderr(dataloader) -> bool: if dataloader is None: @@ -80,6 +69,16 @@ def preprocess(batch): return preprocess +def embedding_dataloader_kwargs(loaded: LoadedModel, execution) -> dict[str, Any]: + num_workers = execution.resolved_num_workers_per_gpu() + kwargs: dict[str, Any] = {"num_workers": num_workers} + if num_workers > 0: + kwargs["prefetch_factor"] = execution.prefetch_factor + if uses_cuda_runtime(loaded.device): + kwargs["pin_memory"] = True + return kwargs + + def build_batch_transform_spec(transforms) -> BatchTransformSpec | None: if isinstance(transforms, BaseImageProcessor): crop_size = transforms.crop_size if hasattr(transforms, "crop_size") else None @@ -370,10 +369,10 @@ def run_forward_pass( sample_id: str | None = None, total_items: int | None = None, unit_label: str = "tile", - return_indices: bool = 
False, ): - outputs = [] - batch_indices = [] if return_indices else None + embeddings = None + batch_indices = None + buffer_capacity = max(0, int(total_items)) if total_items is not None else 0 processed = 0 batch_index = 0 prefetcher_context = ( @@ -389,9 +388,33 @@ def run_forward_pass( forward_start = time.perf_counter() embedding = loaded.model.encode_tiles(image).detach().cpu() forward_ms = (time.perf_counter() - forward_start) * 1000.0 - outputs.append(embedding) - if batch_indices is not None: - batch_indices.append(torch.as_tensor(prepared_batch.indices, dtype=torch.long).detach().cpu()) + batch_size = int(embedding.shape[0]) + current_indices = torch.as_tensor(prepared_batch.indices, dtype=torch.long).detach().cpu() + required_capacity = processed + batch_size + if embeddings is None: + buffer_capacity = max(buffer_capacity, required_capacity) + embeddings = torch.empty( + (buffer_capacity, int(embedding.shape[-1])), + dtype=embedding.dtype, + ) + batch_indices = torch.empty((buffer_capacity,), dtype=torch.long) + elif required_capacity > buffer_capacity: + new_capacity = max(required_capacity, max(1, buffer_capacity * 2)) + grown_embeddings = torch.empty( + (new_capacity, int(embeddings.shape[-1])), + dtype=embeddings.dtype, + ) + if processed > 0: + grown_embeddings[:processed] = embeddings[:processed] + embeddings = grown_embeddings + grown_indices = torch.empty((new_capacity,), dtype=torch.long) + if processed > 0: + grown_indices[:processed] = batch_indices[:processed] + batch_indices = grown_indices + buffer_capacity = new_capacity + + embeddings[processed:required_capacity] = embedding + batch_indices[processed:required_capacity] = current_indices processed += int(embedding.shape[0]) batch_index += 1 batch_total_ms = ( @@ -428,16 +451,11 @@ def run_forward_pass( total=int(total_items or processed), unit=unit_label, ) - if not outputs: + if embeddings is None: feature_dim = loaded.tile_feature_dim if loaded.tile_feature_dim is not None else 
loaded.feature_dim empty = torch.empty((0, int(feature_dim)), dtype=torch.float32) - if batch_indices is not None: - return torch.empty((0,), dtype=torch.long), empty - return empty - embeddings = torch.cat(outputs, dim=0) - if batch_indices is not None: - return torch.cat(batch_indices, dim=0), embeddings - return embeddings + return torch.empty((0,), dtype=torch.long), empty + return batch_indices[:processed], embeddings[:processed] def resolve_device(device: str, default_device): diff --git a/slide2vec/runtime/distributed.py b/slide2vec/runtime/distributed.py index c5de80f..a831a20 100644 --- a/slide2vec/runtime/distributed.py +++ b/slide2vec/runtime/distributed.py @@ -72,6 +72,7 @@ def run_torchrun_worker( sys.executable, "-m", "torch.distributed.run", + "--standalone", f"--nproc_per_node={num_gpus}", "-m", module, diff --git a/slide2vec/runtime/serialization.py b/slide2vec/runtime/serialization.py index c3e080d..0808537 100644 --- a/slide2vec/runtime/serialization.py +++ b/slide2vec/runtime/serialization.py @@ -36,18 +36,21 @@ def serialize_preprocessing(preprocessing: PreprocessingConfig) -> dict[str, Any def serialize_execution( execution: ExecutionOptions, *, - effective_num_workers: int | None = None, + effective_num_workers_per_gpu: int | None = None, ) -> dict[str, Any]: return { "output_dir": str(execution.output_dir) if execution.output_dir is not None else None, "output_format": execution.output_format, "batch_size": execution.batch_size, - "num_workers": effective_num_workers if effective_num_workers is not None else execution.num_workers, + "num_workers_per_gpu": ( + effective_num_workers_per_gpu + if effective_num_workers_per_gpu is not None + else execution.num_workers_per_gpu + ), "num_preprocessing_workers": execution.num_preprocessing_workers, "num_gpus": execution.num_gpus, "precision": execution.precision, "prefetch_factor": execution.prefetch_factor, - "persistent_workers": execution.persistent_workers, "save_tile_embeddings": 
execution.save_tile_embeddings, "save_slide_embeddings": execution.save_slide_embeddings, "save_latents": execution.save_latents, @@ -92,37 +95,27 @@ def deserialize_preprocessing(payload: dict[str, Any]) -> PreprocessingConfig: def deserialize_execution(payload: dict[str, Any]) -> ExecutionOptions: - output_dir = payload["output_dir"] if "output_dir" in payload else None - batch_size = payload["batch_size"] if "batch_size" in payload else None - num_workers = payload["num_workers"] if "num_workers" in payload else None - num_preprocessing_workers = ( - payload["num_preprocessing_workers"] if "num_preprocessing_workers" in payload else None - ) - num_gpus = payload["num_gpus"] if "num_gpus" in payload else 1 - precision = payload["precision"] if "precision" in payload else "fp32" - prefetch_factor = payload["prefetch_factor"] if "prefetch_factor" in payload else 4 - persistent_workers = ( - bool(payload["persistent_workers"]) if "persistent_workers" in payload else True - ) - save_tile_embeddings = ( - bool(payload["save_tile_embeddings"]) if "save_tile_embeddings" in payload else False - ) - save_slide_embeddings = ( - bool(payload["save_slide_embeddings"]) if "save_slide_embeddings" in payload else False - ) - save_latents = bool(payload["save_latents"]) if "save_latents" in payload else False + output_dir = payload.get("output_dir") + batch_size = payload.get("batch_size") + num_workers_per_gpu = payload.get("num_workers_per_gpu") + num_preprocessing_workers = payload.get("num_preprocessing_workers") + num_gpus = payload.get("num_gpus", 1) + precision = payload.get("precision", "fp32") + prefetch_factor = payload.get("prefetch_factor", 4) + save_tile_embeddings = bool(payload.get("save_tile_embeddings", False)) + save_slide_embeddings = bool(payload.get("save_slide_embeddings", False)) + save_latents = bool(payload.get("save_latents", False)) return ExecutionOptions( output_dir=Path(output_dir) if output_dir is not None else None, - 
output_format=payload["output_format"] if "output_format" in payload else "pt", + output_format=payload.get("output_format", "pt"), batch_size=batch_size, - num_workers=int(num_workers) if num_workers is not None else None, + num_workers_per_gpu=int(num_workers_per_gpu) if num_workers_per_gpu is not None else None, num_preprocessing_workers=( int(num_preprocessing_workers) if num_preprocessing_workers is not None else None ), num_gpus=int(num_gpus), precision=precision, prefetch_factor=int(prefetch_factor), - persistent_workers=persistent_workers, save_tile_embeddings=save_tile_embeddings, save_slide_embeddings=save_slide_embeddings, save_latents=save_latents, diff --git a/slide2vec/runtime/tiling.py b/slide2vec/runtime/tiling.py index 53766ce..3e37061 100644 --- a/slide2vec/runtime/tiling.py +++ b/slide2vec/runtime/tiling.py @@ -45,12 +45,12 @@ def build_hs2p_configs( else preprocessing.requested_tile_size_px ) tiling_cfg = TilingConfig( - backend=resolve_tiling_backend(preprocessing), requested_spacing_um=preprocessing.requested_spacing_um, requested_tile_size_px=requested_tile_size_px, tolerance=preprocessing.tolerance, overlap=preprocessing.overlap, tissue_threshold=preprocessing.tissue_threshold, + backend=resolve_tiling_backend(preprocessing), ) segmentation_cfg = SegmentationConfig(**dict(preprocessing.segmentation)) filtering_cfg = FilterConfig(**dict(preprocessing.filtering)) diff --git a/tests/test_output_consistency.py b/tests/test_output_consistency.py new file mode 100644 index 0000000..bc9f59f --- /dev/null +++ b/tests/test_output_consistency.py @@ -0,0 +1,192 @@ +import os +import json +import subprocess +import sys +from pathlib import Path + +import numpy as np +import pytest + +torch = pytest.importorskip("torch") +OmegaConf = pytest.importorskip("omegaconf").OmegaConf + +# --------------------------------------------------------------------------- +# Hardcoded pipeline parameters +# 
--------------------------------------------------------------------------- + +# -- tiling.params -- +TILING_PARAMS = dict( + requested_spacing_um=0.5, + tolerance=0.07, # override (default: 0.05) + requested_tile_size_px=224, # override (default: 256) + overlap=0.0, + tissue_threshold=0.1, # override (default: 0.01) +) + +# -- tiling.seg_params -- +TILING_SEG_PARAMS = dict( + downsample=64, # override (default: 16) + sthresh=8, + sthresh_up=255, + mthresh=7, + close=4, + method="hsv", +) + +# -- tiling.filter_params -- +TILING_FILTER_PARAMS = dict( + ref_tile_size=224, # override (default: 16) + a_t=4, + a_h=2, + filter_white=False, + filter_black=False, + white_threshold=220, + black_threshold=25, + fraction_threshold=0.9, +) + +# -- tiling.preview -- +TILING_PREVIEW = dict( + save_mask_preview=False, + save_tiling_preview=False, + downsample=32, + tissue_contour_color=(157, 219, 129), + mask_overlay_alpha=0.5, +) + +# -- model -- +MODEL_PARAMS = dict( + name="prism", # override (default: null) + batch_size=8, # override (default: 256) + save_tile_embeddings=True, + save_slide_embeddings=False, + save_latents=False, +) + +# -- speed -- +SPEED_PARAMS = dict( + precision="fp16", # override (default: unset, resolved automatically from model recommendations) + num_dataloader_workers=0, # keep the Prism subprocess path single-process to avoid worker SHM pressure +) + +# --------------------------------------------------------------------------- +# Paths relative to this test file +# --------------------------------------------------------------------------- +TEST_DIR = Path(__file__).parent +INPUT_DIR = TEST_DIR / "fixtures" / "input" +GT_DIR = TEST_DIR / "fixtures" / "gt" +REPO_ROOT = TEST_DIR.parent + + +@pytest.fixture(scope="module") +def wsi_path() -> Path: + p = INPUT_DIR / "test-wsi.tif" + if not p.is_file(): + pytest.skip(f"Test fixture missing: {p}") + return p + + +@pytest.fixture(scope="module") +def mask_path() -> Path: + p = INPUT_DIR / "test-mask.tif" + if not p.is_file(): + pytest.skip(f"Test fixture 
missing: {p}") + return p + + +@pytest.mark.skipif( + not os.environ.get("HF_TOKEN"), + reason="HF_TOKEN required for model weight download", +) +def test_output_consistency(wsi_path, mask_path, tmp_path): + """Running the full pipeline with hardcoded params produces x/y coordinates and + embeddings that match the ground truth fixtures in test/gt/.""" + + pytest.importorskip("transformers") + pytest.importorskip("wholeslidedata") + + # 1. Build a temporary CSV with resolved absolute paths + tmp_csv = tmp_path / "test.csv" + tmp_csv.write_text( + f"sample_id,image_path,mask_path\ntest-wsi,{wsi_path},{mask_path}\n" + ) + + # 2. Build config from hardcoded constants (no dependency on test/input/config.yaml) + cfg = OmegaConf.create({ + "csv": str(tmp_csv), + "output_dir": str(tmp_path), + "resume": False, + "resume_dirname": None, + "seed": 0, + "tiling": { + "read_coordinates_from": None, + "read_tiles_from": None, + "on_the_fly": True, + "backend": "asap", + "params": TILING_PARAMS, + "seg_params": TILING_SEG_PARAMS, + "filter_params": TILING_FILTER_PARAMS, + "preview": TILING_PREVIEW, + }, + "model": MODEL_PARAMS, + "speed": SPEED_PARAMS, + "wandb": {"enable": False}, + }) + cfg_path = tmp_path / "config.yaml" + OmegaConf.save(cfg, cfg_path) + + # 3. Run the pipeline + subprocess.run( + [ + "slide2vec", + str(cfg_path), + "--skip-datetime", + "--run-on-cpu", + ], + cwd=REPO_ROOT, + check=True, + ) + + # 4. 
Assert coordinates match exactly (tiling is deterministic) + gt_coords = np.load(GT_DIR / "test-wsi.coordinates.npz", allow_pickle=False) + coords = np.load(tmp_path / "tiles" / "test-wsi.coordinates.npz", allow_pickle=False) + np.testing.assert_array_equal(coords, gt_coords) + + meta = json.loads((tmp_path / "tiles" / "test-wsi.coordinates.meta.json").read_text()) + assert meta["provenance"]["sample_id"] == "test-wsi" + assert meta["provenance"]["backend"] == "asap" + assert meta["tiling"]["requested_spacing_um"] == pytest.approx(0.5) + assert meta["tiling"]["requested_tile_size_px"] == 224 + + # 5. Assert slide embeddings are within tolerance + gt_emb = torch.load(GT_DIR / "test-wsi.pt", map_location="cpu", weights_only=True) + emb = torch.load(tmp_path / "slide_embeddings" / "test-wsi.pt", map_location="cpu", weights_only=True) + assert emb.shape == gt_emb.shape, f"Shape mismatch: {emb.shape} vs {gt_emb.shape}" + + cos = torch.nn.functional.cosine_similarity(emb, gt_emb, dim=-1) + mean_cos = float(cos.mean()) + atol, rtol = 1e-2, 1e-3 + if not torch.allclose(emb, gt_emb, atol=atol, rtol=rtol): + assert mean_cos >= 0.99, ( + f"Embedding mismatch: mean cosine similarity={mean_cos:.4f} " + f"(atol={atol}, rtol={rtol})" + ) + else: + print(f"OK: slide embeddings within tolerance; mean cosine similarity={mean_cos:.4f}") + + # 6. 
Assert tile-level embeddings match ground truth (verifies tile ordering) + gt_tile_emb = torch.load(GT_DIR / "test-wsi.tiles.pt", map_location="cpu", weights_only=True) + tile_emb = torch.load(tmp_path / "tile_embeddings" / "test-wsi.pt", map_location="cpu", weights_only=True) + assert tile_emb.shape == gt_tile_emb.shape, ( + f"Tile embedding shape mismatch: {tile_emb.shape} vs {gt_tile_emb.shape}" + ) + tile_cos = torch.nn.functional.cosine_similarity(tile_emb, gt_tile_emb, dim=-1) + mean_tile_cos = float(tile_cos.mean()) + atol, rtol = 1e-2, 1e-3 + if not torch.allclose(tile_emb, gt_tile_emb, atol=atol, rtol=rtol): + assert mean_tile_cos >= 0.99, ( + f"Tile embedding mismatch: mean cosine similarity={mean_tile_cos:.4f} " + f"(atol={atol}, rtol={rtol})" + ) + else: + print(f"OK: tile embeddings within tolerance; mean cosine similarity={mean_tile_cos:.4f}") diff --git a/tests/test_progress.py b/tests/test_progress.py index 6138a0b..6a64c06 100644 --- a/tests/test_progress.py +++ b/tests/test_progress.py @@ -405,7 +405,7 @@ def encode_tiles(self, image): loaded = SimpleNamespace(device="cpu", feature_dim=3, model=FakeModel(), transforms=lambda image: image) with progress.activate_progress_reporter(reporter): - outputs = inference._run_forward_pass( + indices, outputs = inference._run_forward_pass( dataloader, loaded, nullcontext(), @@ -414,6 +414,7 @@ def encode_tiles(self, image): unit_label="tile", ) + assert torch.equal(indices, torch.tensor([0, 1, 2, 3, 4])) assert outputs.shape == (5, 3) payloads = [event.payload for event in reporter.events if event.kind == "embedding.tile.progress"] assert [payload["processed"] for payload in payloads] == [2, 4, 5] @@ -496,7 +497,7 @@ def encode_tiles(self, image): ) with progress.activate_progress_reporter(reporter): - outputs = inference._run_forward_pass( + indices, outputs = inference._run_forward_pass( dataloader, loaded, nullcontext(), @@ -505,6 +506,7 @@ def encode_tiles(self, image): unit_label="tile", ) + assert 
torch.equal(indices, torch.tensor([0, 1, 2])) assert outputs.shape == (3, 3) assert torch.all(outputs == 7.0) @@ -611,6 +613,46 @@ def wait(self, timeout=None): assert [event.kind for event in reporter.events] == ["embedding.slide.started"] +def test_run_torchrun_worker_uses_standalone_rendezvous(monkeypatch, tmp_path: Path): + import slide2vec.inference as inference + + request_path = tmp_path / "request.json" + request_path.write_text("{}", encoding="utf-8") + output_dir = tmp_path / "output" + output_dir.mkdir() + + observed = {} + + class FakePopen: + def __init__(self, command, **kwargs): + observed["command"] = command + self.stdout = io.StringIO("") + self.stderr = io.StringIO("") + self._returncode = 0 + + def poll(self): + return 0 + + def wait(self, timeout=None): + return 0 + + monkeypatch.setattr(inference.runtime_distributed.time, "sleep", lambda _seconds: None) + + inference.runtime_distributed.run_torchrun_worker( + module="slide2vec.distributed.direct_embed_worker", + num_gpus=2, + output_dir=output_dir, + request_path=request_path, + failure_title="boom", + popen_factory=FakePopen, + ) + + command = observed["command"] + assert "--standalone" in command + assert "--master_port" not in " ".join(command) + assert "--rdzv-endpoint" not in " ".join(command) + + def test_rich_reporter_collapses_multi_gpu_model_loading_into_one_task(monkeypatch): import slide2vec.progress as progress diff --git a/tests/test_regression_core.py b/tests/test_regression_core.py index f41b7ed..800ecfa 100644 --- a/tests/test_regression_core.py +++ b/tests/test_regression_core.py @@ -376,9 +376,9 @@ def test_execution_options_preserves_explicit_dataloader_workers(monkeypatch): monkeypatch.setattr(api, "cpu_worker_limit", lambda: 2) monkeypatch.setattr(api, "slurm_cpu_limit", lambda: 2) - execution = api.ExecutionOptions(num_workers=3) + execution = api.ExecutionOptions(num_workers_per_gpu=3) - assert execution.num_workers == 3 + assert execution.num_workers_per_gpu == 3 assert 
execution.num_preprocessing_workers == 2 def test_cpu_worker_limit_caps_large_cpu_budget_to_sixty_four(monkeypatch): @@ -393,7 +393,7 @@ def test_execution_options_default_batchis_thirty_two(): assert ExecutionOptions().batch_size == 32 def test_execution_options_default_num_workers_is_auto(): - assert ExecutionOptions().num_workers is None + assert ExecutionOptions().num_workers_per_gpu is None def test_execution_options_logs_resolved_auto_num_workers(monkeypatch, caplog): import slide2vec.api as api @@ -401,13 +401,25 @@ def test_execution_options_logs_resolved_auto_num_workers(monkeypatch, caplog): monkeypatch.setattr(api, "cpu_worker_limit", lambda: 18) monkeypatch.setattr(api, "slurm_cpu_limit", lambda: 18) monkeypatch.setattr(api.os, "cpu_count", lambda: 64) + monkeypatch.setattr(api, "_default_num_gpus", lambda: 1) with caplog.at_level("INFO"): execution = api.ExecutionOptions() - assert execution.num_workers is None - assert "ExecutionOptions: num_workers=18 (requested=auto)" in caplog.text - assert "num_workers=auto" not in caplog.text + assert execution.num_workers_per_gpu is None + assert "ExecutionOptions: num_workers_per_gpu=18 (requested=auto)" in caplog.text + assert "num_workers_per_gpu=auto" not in caplog.text + + +def test_execution_options_auto_workers_are_split_across_gpus(monkeypatch): + import slide2vec.api as api + + monkeypatch.setattr(api, "cpu_worker_limit", lambda: 18) + monkeypatch.setattr(api, "slurm_cpu_limit", lambda: 18) + + execution = api.ExecutionOptions(num_gpus=3) + + assert execution.resolved_num_workers_per_gpu() == 6 def test_hf_login_skips_hub_login_when_token_is_already_set(monkeypatch): @@ -453,11 +465,9 @@ def test_execution_options_from_config_maps_cli_fields(tmp_path: Path): assert execution.output_dir == tmp_path assert execution.output_format == "pt" assert execution.batch_size == 4 - assert execution.num_workers == 2 assert execution.num_gpus == 3 assert execution.precision == "bf16" assert execution.prefetch_factor 
== 5 - assert execution.persistent_workers is False assert execution.save_tile_embeddings is True assert execution.save_latents is True @@ -511,7 +521,7 @@ def test_execution_options_from_config_preserves_auto_num_workers(tmp_path: Path execution = ExecutionOptions.from_config(cfg) - assert execution.num_workers is None + assert execution.num_workers_per_gpu is None def test_execution_options_from_config_defaults_to_all_available_gpus_when_unset(monkeypatch, tmp_path: Path): import slide2vec.api as api @@ -540,7 +550,6 @@ def test_execution_options_from_config_defaults_to_all_available_gpus_when_unset assert execution.num_gpus == 6 assert execution.precision == "fp32" assert execution.prefetch_factor == 3 - assert execution.persistent_workers is True def test_execution_options_from_config_forces_fp32_for_cpu_runs(monkeypatch, tmp_path: Path): import slide2vec.api as api @@ -615,11 +624,10 @@ def test_execution_options_with_output_dir_preserves_other_fields(tmp_path: Path output_dir=None, output_format="npz", batch_size=8, - num_workers=3, + num_workers_per_gpu=3, num_gpus=2, precision="bf16", prefetch_factor=6, - persistent_workers=False, save_tile_embeddings=True, save_latents=True, ) @@ -629,11 +637,10 @@ def test_execution_options_with_output_dir_preserves_other_fields(tmp_path: Path assert updated.output_dir == tmp_path assert updated.output_format == base.output_format assert updated.batch_size == base.batch_size - assert updated.num_workers == base.num_workers + assert updated.num_workers_per_gpu == base.num_workers_per_gpu assert updated.num_gpus == base.num_gpus assert updated.precision == base.precision assert updated.prefetch_factor == base.prefetch_factor - assert updated.persistent_workers == base.persistent_workers assert updated.save_tile_embeddings == base.save_tile_embeddings assert updated.save_latents == base.save_latents assert updated is not base diff --git a/tests/test_regression_inference.py b/tests/test_regression_inference.py index 
65a331b..e60388c 100644 --- a/tests/test_regression_inference.py +++ b/tests/test_regression_inference.py @@ -1,5 +1,7 @@ import ast +import json import sys +from contextlib import contextmanager from dataclasses import replace from pathlib import Path from types import SimpleNamespace @@ -8,6 +10,7 @@ import numpy as np import pandas as pd import pytest +import torch from slide2vec.api import ( EmbeddedSlide, @@ -522,6 +525,138 @@ def fake_persist_embedded_slide(model, embedded_slide, tiling_result, *, preproc assert recorded.loc["slide-a", "feature_path"] == str((tmp_path / "relative-output" / "slide_embeddings" / "slide-a.pt").resolve()) +def test_aggregate_tiles_uses_autocast_for_slide_encoding(monkeypatch, tmp_path: Path): + import slide2vec.inference as inference + + autocast_active = False + + @contextmanager + def fake_autocast(*, device_type: str, dtype): + nonlocal autocast_active + assert device_type == "cuda" + assert dtype == torch.float16 + autocast_active = True + try: + yield + finally: + autocast_active = False + + def encode_slide(tile_features, coordinates, *, tile_size_lv0: int | None = None): + assert autocast_active is True + assert tile_features.shape == (1, 4) + assert coordinates.shape == (1, 2) + assert tile_size_lv0 == 224 + return torch.ones(4, dtype=torch.float32) + + monkeypatch.setattr(inference.torch, "autocast", fake_autocast) + monkeypatch.setattr(inference, "_autocast_dtype", lambda torch_module, precision: torch_module.float16) + monkeypatch.setattr(inference, "_uses_cuda_runtime", lambda device: True) + monkeypatch.setattr( + inference.runtime_tiling, + "load_tiling_result_from_paths", + lambda *_args, **_kwargs: SimpleNamespace( + x=np.array([0], dtype=np.int64), + y=np.array([1], dtype=np.int64), + tile_size_lv0=224, + requested_spacing_um=0.5, + ), + ) + monkeypatch.setattr( + inference, + "load_array", + lambda *_args, **_kwargs: np.ones((1, 4), dtype=np.float32), + ) + + captured = {} + + def 
fake_write_slide_embedding_artifact(sample_id, embedding, *, execution, metadata, latents=None): + captured["sample_id"] = sample_id + captured["embedding"] = embedding + captured["execution"] = execution + captured["metadata"] = metadata + captured["latents"] = latents + return SimpleNamespace(sample_id=sample_id, path=tmp_path / "slide_embeddings" / f"{sample_id}.pt") + + loaded = SimpleNamespace(device=torch.device("cpu"), model=SimpleNamespace(encode_slide=encode_slide)) + model = SimpleNamespace(name="prism", level="slide", _load_backend=lambda: loaded) + artifact = SimpleNamespace( + sample_id="slide-a", + path=tmp_path / "tile_embeddings" / "slide-a.pt", + metadata={ + "coordinates_npz_path": str(tmp_path / "slide-a.coordinates.npz"), + "coordinates_meta_path": str(tmp_path / "slide-a.coordinates.meta.json"), + "image_path": str(tmp_path / "slide-a.svs"), + }, + ) + + monkeypatch.setattr(inference.runtime_embedding, "write_slide_embedding_artifact", fake_write_slide_embedding_artifact) + + outputs = inference.aggregate_tiles( + model, + [artifact], + preprocessing=DEFAULT_PREPROCESSING, + execution=ExecutionOptions(output_dir=tmp_path, precision="fp16", save_slide_embeddings=True), + ) + + assert len(outputs) == 1 + assert captured["sample_id"] == "slide-a" + assert torch.equal(captured["embedding"], torch.ones(4)) + assert captured["latents"] is None + assert autocast_active is False + + +def test_aggregate_tile_embeddings_for_slide_uses_autocast(monkeypatch, tmp_path: Path): + import slide2vec.inference as inference + + autocast_active = False + + @contextmanager + def fake_autocast(*, device_type: str, dtype): + nonlocal autocast_active + assert device_type == "cuda" + assert dtype == torch.float16 + autocast_active = True + try: + yield + finally: + autocast_active = False + + def encode_slide(tile_features, coordinates, *, tile_size_lv0: int | None = None): + assert autocast_active is True + assert tile_features.shape == (1, 4) + assert 
coordinates.shape == (1, 2) + assert tile_size_lv0 == 224 + return torch.ones(4, dtype=torch.float32) + + monkeypatch.setattr(inference.torch, "autocast", fake_autocast) + monkeypatch.setattr(inference, "_autocast_dtype", lambda torch_module, precision: torch_module.float16) + monkeypatch.setattr(inference, "_uses_cuda_runtime", lambda device: True) + + loaded = SimpleNamespace(device=torch.device("cpu"), model=SimpleNamespace(encode_slide=encode_slide)) + model = SimpleNamespace(level="slide", name="prism") + slide = make_slide("slide-a") + tiling_result = SimpleNamespace( + x=np.array([0], dtype=np.int64), + y=np.array([1], dtype=np.int64), + tile_size_lv0=224, + ) + tile_embeddings = np.ones((1, 4), dtype=np.float32) + + slide_embedding, latents = inference._aggregate_tile_embeddings_for_slide( + loaded, + model, + slide, + tiling_result, + tile_embeddings, + preprocessing=DEFAULT_PREPROCESSING, + execution=ExecutionOptions(output_dir=tmp_path, precision="fp16"), + ) + + assert torch.equal(slide_embedding, torch.ones(4)) + assert latents is None + assert autocast_active is False + + def test_run_pipeline_skips_zero_tile_slides_and_counts_only_embeddable_slides(monkeypatch, tmp_path: Path): import slide2vec.inference as inference import slide2vec.progress as progress @@ -585,9 +720,19 @@ def write_log(self, message, *, stream=None): lambda *args, **kwargs: ([slide_zero, slide_full], [zero_tiling, full_tiling], process_list_path), ) - def fake_compute_embedded_slides(model, slide_records, tiling_results, *, preprocessing, execution, on_embedded_slide=None): + def fake_compute_embedded_slides( + model, + slide_records, + tiling_results, + *, + preprocessing, + execution, + on_embedded_slide=None, + collect_results=True, + ): captured["slide_records"] = [slide.sample_id for slide in slide_records] captured["tiling_results"] = [result.x.shape[0] for result in tiling_results] + captured["collect_results"] = collect_results if on_embedded_slide is not None: 
on_embedded_slide(slide_full, full_tiling, embedded_full) return [embedded_full] @@ -608,7 +753,7 @@ def fake_compute_embedded_slides(model, slide_records, tiling_results, *, prepro model, slides=[slide_zero, slide_full], preprocessing=DEFAULT_PREPROCESSING, - execution=ExecutionOptions(output_dir=tmp_path, save_tile_embeddings=True), + execution=ExecutionOptions(output_dir=tmp_path, num_gpus=1, save_tile_embeddings=True), ) zero_meta = load_metadata(tmp_path / "tile_embeddings" / "slide-zero.meta.json") @@ -741,14 +886,17 @@ def test_run_pipeline_local_branch_uses_incremental_persist_callback(monkeypatch "_prepare_tiled_slides", lambda *args, **kwargs: ([slide_record], [tiling_result], tmp_path / "process_list.csv"), ) - monkeypatch.setattr( - inference, - "_compute_embedded_slides", - lambda *args, **kwargs: [embedded], - ) - captured = {} + def fake_compute_embedded_slides(*args, **kwargs): + captured["collect_results"] = kwargs.get("collect_results") + callback = kwargs.get("on_embedded_slide") + if callback is not None: + callback(slide_record, tiling_result, embedded) + return [] + + monkeypatch.setattr(inference, "_compute_embedded_slides", fake_compute_embedded_slides) + def fake_build_callback(*, model, preprocessing, execution, process_list_path): captured["model"] = model captured["preprocessing"] = preprocessing @@ -772,10 +920,194 @@ def fake_build_callback(*, model, preprocessing, execution, process_list_path): ) assert captured["process_list_path"] == tmp_path / "process_list.csv" + assert captured["collect_results"] is False assert result.tile_artifacts == ["tile-artifact"] assert result.slide_artifacts == ["slide-artifact"] +def test_compute_embedded_slides_skips_retaining_results_when_collect_results_is_false(monkeypatch): + import slide2vec.inference as inference + + slides = [make_slide("slide-a"), make_slide("slide-b")] + tiling_results = [ + SimpleNamespace(x=np.array([0]), y=np.array([1]), tile_size_lv0=224), + 
SimpleNamespace(x=np.array([2]), y=np.array([3]), tile_size_lv0=224), + ] + seen: list[str] = [] + + model = SimpleNamespace(level="tile", _load_backend=lambda: SimpleNamespace()) + + monkeypatch.setattr(inference, "emit_progress", lambda *args, **kwargs: None) + monkeypatch.setattr(inference, "_is_hierarchical_preprocessing", lambda preprocessing: False) + monkeypatch.setattr( + inference, + "_compute_tile_embeddings_for_slide", + lambda *args, **kwargs: np.zeros((1, 2), dtype=np.float32), + ) + monkeypatch.setattr( + inference, + "_aggregate_tile_embeddings_for_slide", + lambda *args, **kwargs: (None, None), + ) + monkeypatch.setattr( + inference, + "_make_embedded_slide", + lambda *, slide, **kwargs: SimpleNamespace(sample_id=slide.sample_id), + ) + + result = inference._compute_embedded_slides( + model, + slides, + tiling_results, + preprocessing=DEFAULT_PREPROCESSING, + execution=ExecutionOptions(output_dir=Path("/tmp")), + on_embedded_slide=lambda slide, tiling_result, embedded_slide: seen.append(embedded_slide.sample_id), + collect_results=False, + ) + + assert result == [] + assert seen == ["slide-a", "slide-b"] + + +def test_pipeline_worker_disables_result_collection_when_streaming(monkeypatch, tmp_path: Path): + import torch.distributed as dist + + import slide2vec.distributed as distributed + import slide2vec.inference as inference + import slide2vec.runtime.serialization as serialization + from slide2vec.api import Model + from slide2vec.distributed import pipeline_worker + + request_path = tmp_path / "request.json" + request_path.write_text( + json.dumps( + { + "model": { + "name": "virchow2", + "allow_non_recommended_settings": False, + }, + "preprocessing": {}, + "execution": {}, + "tiling_input_dir": str(tmp_path), + } + ), + encoding="utf-8", + ) + + slide = make_slide("slide-a") + tiling_result = SimpleNamespace(x=np.array([0]), y=np.array([1]), tile_size_lv0=224) + captured = {} + + monkeypatch.setattr(distributed, "enable", lambda 
overwrite=True: None) + monkeypatch.setattr(distributed, "get_local_rank", lambda: 0) + monkeypatch.setattr(distributed, "get_global_rank", lambda: 0) + monkeypatch.setattr(distributed, "get_global_size", lambda: 1) + monkeypatch.setattr(dist, "is_available", lambda: False) + monkeypatch.setattr(dist, "is_initialized", lambda: False) + monkeypatch.setattr(Model, "from_preset", lambda *args, **kwargs: SimpleNamespace()) + monkeypatch.setattr(serialization, "deserialize_preprocessing", lambda payload: DEFAULT_PREPROCESSING) + monkeypatch.setattr( + serialization, + "deserialize_execution", + lambda payload: ExecutionOptions(output_dir=tmp_path), + ) + monkeypatch.setattr( + inference, + "load_successful_tiled_slides", + lambda tiling_input_dir: ([slide], [tiling_result]), + ) + monkeypatch.setattr( + inference, + "_build_incremental_persist_callback", + lambda **kwargs: (lambda *args, **kwargs: None, [], []), + ) + + def fake_compute_embedded_slides(*args, **kwargs): + captured["collect_results"] = kwargs.get("collect_results") + return [] + + monkeypatch.setattr(inference, "_compute_embedded_slides", fake_compute_embedded_slides) + monkeypatch.setattr( + pipeline_worker, + "assign_slides_to_ranks", + lambda slide_records, tiling_results, *, num_gpus: {0: ["slide-a"]}, + ) + + assert pipeline_worker.main(["--output-dir", str(tmp_path), "--request-path", str(request_path)]) == 0 + assert captured["collect_results"] is False + + +def test_direct_embed_worker_streams_payloads_without_retaining_results(monkeypatch, tmp_path: Path): + import torch + import torch.distributed as dist + + import slide2vec.distributed as distributed + import slide2vec.runtime.serialization as serialization + from slide2vec.api import Model + from slide2vec.distributed import direct_embed_worker + + coordination_dir = tmp_path / "coordination" + coordination_dir.mkdir() + request_path = tmp_path / "request.json" + request_path.write_text( + json.dumps( + { + "model": { + "name": "virchow2", + 
"allow_non_recommended_settings": False, + }, + "preprocessing": {}, + "execution": {}, + "coordination_dir": str(coordination_dir), + "strategy": "slide_shard", + "assignments": {"0": ["slide-a"]}, + } + ), + encoding="utf-8", + ) + + slide = make_slide("slide-a") + tiling_result = SimpleNamespace(x=np.array([0]), y=np.array([1]), tile_size_lv0=224) + captured = {} + + monkeypatch.setattr(distributed, "enable", lambda overwrite=True: None) + monkeypatch.setattr(distributed, "get_local_rank", lambda: 0) + monkeypatch.setattr(distributed, "get_global_rank", lambda: 0) + monkeypatch.setattr(distributed, "get_global_size", lambda: 1) + monkeypatch.setattr(dist, "is_available", lambda: False) + monkeypatch.setattr(dist, "is_initialized", lambda: False) + monkeypatch.setattr(Model, "from_preset", lambda *args, **kwargs: SimpleNamespace()) + monkeypatch.setattr(serialization, "deserialize_preprocessing", lambda payload: DEFAULT_PREPROCESSING) + monkeypatch.setattr( + serialization, + "deserialize_execution", + lambda payload: ExecutionOptions(output_dir=tmp_path), + ) + monkeypatch.setattr( + direct_embed_worker, + "_to_cpu_payload", + lambda value: value, + ) + + import slide2vec.inference as inference + + monkeypatch.setattr( + inference, + "load_successful_tiled_slides", + lambda output_dir: ([slide], [tiling_result]), + ) + monkeypatch.setattr( + inference, + "_compute_embedded_slides", + lambda *args, **kwargs: [], + ) + + assert direct_embed_worker.main(["--output-dir", str(tmp_path), "--request-path", str(request_path)]) == 0 + source = (ROOT / "slide2vec" / "distributed" / "direct_embed_worker.py").read_text(encoding="utf-8") + assert "collect_results=False" in source + assert "on_embedded_slide=_persist_embedded_slide" in source + + def test_run_pipeline_local_branch_persists_completed_slides_before_later_failure(monkeypatch, tmp_path: Path): import slide2vec.inference as inference @@ -945,7 +1277,7 @@ def fake_compute_tile_embeddings(loaded, model, slide, 
tiling_result, *, preproc model, slides=slides, preprocessing=DEFAULT_PREPROCESSING, - execution=ExecutionOptions(output_dir=tmp_path, save_tile_embeddings=True), + execution=ExecutionOptions(output_dir=tmp_path, num_gpus=1, save_tile_embeddings=True), ) assert (tmp_path / "tile_embeddings" / "slide-a.pt").is_file() @@ -2167,8 +2499,9 @@ def encode_tiles(self, image): device=torch.device("cpu"), ) - result = inference._run_forward_pass(dataloader, loaded, nullcontext()) + indices, result = inference._run_forward_pass(dataloader, loaded, nullcontext()) + assert indices.shape == (0,) assert result.shape == (0, 5) assert result.dtype == torch.float32 @@ -2240,7 +2573,7 @@ def encode_tiles(self, image): device=torch.device("cpu"), ) - result = inference._run_forward_pass( + indices, result = inference._run_forward_pass( DummyLoader(), loaded, nullcontext(), @@ -2249,10 +2582,81 @@ def encode_tiles(self, image): total_items=2, ) + assert torch.equal(indices, torch.tensor([0, 1], dtype=torch.long)) assert result.shape == (2, 3) assert torch.allclose(result, torch.ones((2, 3), dtype=torch.float32)) +def test_run_forward_pass_preserves_embedding_order_and_indices_across_batches(): + import slide2vec.inference as inference + torch = pytest.importorskip("torch") + from contextlib import nullcontext + + class DummyLoader: + def __iter__(self): + yield ( + torch.tensor([5, 2], dtype=torch.long), + torch.tensor( + [ + [[[10.0]]], + [[[20.0]]], + ], + dtype=torch.float32, + ), + ) + yield ( + torch.tensor([9], dtype=torch.long), + torch.tensor([[[[30.0]]]], dtype=torch.float32), + ) + yield ( + torch.tensor([4, 1], dtype=torch.long), + torch.tensor( + [ + [[[40.0]]], + [[[50.0]]], + ], + dtype=torch.float32, + ), + ) + + def __len__(self): + return 3 + + class DummyModel: + def encode_tiles(self, image): + values = image[:, 0, 0, 0] + return torch.stack((values, values + 0.5), dim=1) + + loaded = inference.LoadedModel( + name="virchow2", + level="tile", + model=DummyModel(), + 
transforms=lambda image: image, + feature_dim=2, + device=torch.device("cpu"), + ) + + indices, embeddings = inference._run_forward_pass( + DummyLoader(), + loaded, + nullcontext(), + total_items=5, + ) + + assert torch.equal(indices, torch.tensor([5, 2, 9, 4, 1], dtype=torch.long)) + expected = torch.tensor( + [ + [10.0, 10.5], + [20.0, 20.5], + [30.0, 30.5], + [40.0, 40.5], + [50.0, 50.5], + ], + dtype=torch.float32, + ) + assert torch.equal(embeddings, expected) + + def test_serialize_execution_preserves_loader_optimization_fields(): import slide2vec.inference as inference from slide2vec.runtime.serialization import deserialize_execution @@ -2260,11 +2664,10 @@ def test_serialize_execution_preserves_loader_optimization_fields(): execution = ExecutionOptions( output_dir=Path("/tmp/output"), batch_size=64, - num_workers=8, + num_workers_per_gpu=8, num_gpus=2, precision="bf16", prefetch_factor=7, - persistent_workers=False, save_tile_embeddings=True, save_latents=True, ) @@ -2273,10 +2676,8 @@ def test_serialize_execution_preserves_loader_optimization_fields(): restored = deserialize_execution(payload) assert payload["prefetch_factor"] == 7 - assert payload["persistent_workers"] is False assert payload["precision"] == "bf16" assert restored.prefetch_factor == 7 - assert restored.persistent_workers is False assert restored.precision == "bf16" @@ -2286,7 +2687,7 @@ def test_serialize_execution_preserves_slide_embedding_and_preprocessing_worker_ execution = ExecutionOptions( output_dir=Path("/tmp/output"), - num_workers=8, + num_workers_per_gpu=8, num_preprocessing_workers=3, num_gpus=2, save_tile_embeddings=True, @@ -2308,15 +2709,15 @@ def test_deserialize_execution_defaults_num_workers_to_auto(): restored = deserialize_execution({"batch_size": 4, "num_gpus": 1}) - assert restored.num_workers is None + assert restored.num_workers_per_gpu is None def test_deserialize_execution_preserves_auto_num_workers(): from slide2vec.runtime.serialization import 
deserialize_execution - restored = deserialize_execution({"batch_size": 4, "num_workers": None, "num_gpus": 1}) + restored = deserialize_execution({"batch_size": 4, "num_workers_per_gpu": None, "num_gpus": 1}) - assert restored.num_workers is None + assert restored.num_workers_per_gpu is None def test_embedding_dataloader_kwargs_resolve_auto_mode_to_cpu_budget(monkeypatch): @@ -2325,6 +2726,11 @@ def test_embedding_dataloader_kwargs_resolve_auto_mode_to_cpu_budget(monkeypatch torch = pytest.importorskip("torch") monkeypatch.setattr(api, "cpu_worker_limit", lambda: 24) + monkeypatch.setattr(api, "slurm_cpu_limit", lambda: 24) + monkeypatch.setattr(inference, "cpu_worker_limit", lambda: 24) + monkeypatch.setattr(inference, "slurm_cpu_limit", lambda: 24) + monkeypatch.setattr("slide2vec.utils.utils.cpu_worker_limit", lambda: 24) + monkeypatch.setattr("slide2vec.utils.utils.slurm_cpu_limit", lambda: 24) loaded = inference.LoadedModel( name="test", @@ -2337,11 +2743,10 @@ def test_embedding_dataloader_kwargs_resolve_auto_mode_to_cpu_budget(monkeypatch kwargs = inference._embedding_dataloader_kwargs( loaded, - ExecutionOptions(num_workers=None, num_gpus=1), + ExecutionOptions(num_workers_per_gpu=None, num_gpus=1), ) assert kwargs["num_workers"] == 24 - assert kwargs["persistent_workers"] is True assert kwargs["prefetch_factor"] == 4 @@ -2390,6 +2795,11 @@ def __call__(self, batch_indices): monkeypatch.setattr(torch.utils.data, "DataLoader", DummyLoader) monkeypatch.setattr(inference, "_build_batch_preprocessor", lambda *args, **kwargs: lambda batch: batch.float()) monkeypatch.setattr(api, "cpu_worker_limit", lambda: 24) + monkeypatch.setattr(api, "slurm_cpu_limit", lambda: 24) + monkeypatch.setattr(inference, "cpu_worker_limit", lambda: 24) + monkeypatch.setattr(inference, "slurm_cpu_limit", lambda: 24) + monkeypatch.setattr("slide2vec.utils.utils.cpu_worker_limit", lambda: 24) + monkeypatch.setattr("slide2vec.utils.utils.slurm_cpu_limit", lambda: 24) loaded = 
inference.LoadedModel( name="prov-gigapath", @@ -2415,12 +2825,11 @@ def __call__(self, batch_indices): tile_size_lv0=224, ), preprocessing=replace(DEFAULT_PREPROCESSING, on_the_fly=True, backend="auto", num_cucim_workers=4), - execution=ExecutionOptions(batch_size=2, num_workers=None, num_gpus=1), + execution=ExecutionOptions(batch_size=2, num_workers_per_gpu=None, num_gpus=1), ) assert result.shape == (2, 3) assert captured["kwargs"]["num_workers"] == 24 - assert captured["kwargs"]["persistent_workers"] is True assert captured["kwargs"]["prefetch_factor"] == 4 assert captured["wsd_collator_kwargs"]["backend"] == "asap" @@ -2480,10 +2889,9 @@ def encode_tiles(self, image): ) execution = ExecutionOptions( batch_size=2, - num_workers=3, + num_workers_per_gpu=3, num_gpus=1, prefetch_factor=9, - persistent_workers=True, ) result = inference._compute_tile_embeddings_for_slide( @@ -2497,7 +2905,6 @@ def encode_tiles(self, image): assert result.shape == (2, 3) assert captured["kwargs"]["num_workers"] == 3 - assert captured["kwargs"]["persistent_workers"] is True assert captured["kwargs"]["prefetch_factor"] == 9 assert captured["kwargs"]["collate_fn"] == ( "collator", @@ -2568,7 +2975,7 @@ def encode_tiles(self, image): slide, tiling_result, preprocessing=replace(DEFAULT_PREPROCESSING, read_tiles_from=Path("/tmp/external-tiles")), - execution=ExecutionOptions(batch_size=1, num_workers=0, num_gpus=1), + execution=ExecutionOptions(batch_size=1, num_workers_per_gpu=0, num_gpus=1), ) assert result.shape == (1, 3) @@ -2581,29 +2988,6 @@ def encode_tiles(self, image): ) -def test_resolve_on_the_fly_num_workers_caps_to_slurm_allocation(monkeypatch): - import slide2vec.inference as inference - - monkeypatch.setattr(inference.os, "cpu_count", lambda: 96) - monkeypatch.setenv("SLURM_JOB_CPUS_PER_NODE", "32") - monkeypatch.delenv("SLURM_CPUS_PER_TASK", raising=False) - monkeypatch.delenv("SLURM_CPUS_ON_NODE", raising=False) - - workers, details = 
inference._resolve_on_the_fly_num_workers(4) - - assert workers == 8 - assert "cpu_count=96" in details - assert "slurm_cpu_limit=32" in details - assert "num_cucim_workers=4" in details - - -def test_resolve_on_the_fly_num_workers_rejects_non_positive_cucim_worker_count(): - import slide2vec.inference as inference - - with pytest.raises(ValueError, match="num_cucim_workers must be at least 1"): - inference._resolve_on_the_fly_num_workers(0) - - def test_compute_tile_embeddings_for_slide_caps_on_the_fly_workers_to_slurm(monkeypatch, caplog): import slide2vec.inference as inference torch = pytest.importorskip("torch") @@ -2673,10 +3057,9 @@ def __call__(self, batch_indices): ) execution = ExecutionOptions( batch_size=2, - num_workers=99, + num_workers_per_gpu=99, num_gpus=1, prefetch_factor=9, - persistent_workers=True, ) with caplog.at_level("INFO"): @@ -2691,11 +3074,93 @@ def __call__(self, batch_indices): assert result.shape == (2, 3) assert captured["kwargs"]["num_workers"] == 8 - assert captured["kwargs"]["persistent_workers"] is True assert captured["kwargs"]["prefetch_factor"] == 9 assert "on-the-fly mode: setting DataLoader num_workers=8" not in caplog.text +def test_compute_tile_embeddings_for_slide_splits_on_the_fly_workers_across_gpus(): + import subprocess + import sys + + script = """ +import slide2vec.inference as inference +inference.cpu_worker_limit = lambda: 24 +inference.slurm_cpu_limit = lambda: 24 +workers, details = inference._resolve_on_the_fly_num_workers(4, 2) +print(workers) +print(details) +""" + result = subprocess.run( + [sys.executable, "-c", script], + check=True, + capture_output=True, + text=True, + ) + + assert result.stdout.splitlines()[0] == "3" + assert "num_gpus=2" in result.stdout + + +def test_compute_tile_embeddings_for_slide_rejects_non_positive_cucim_worker_count(monkeypatch): + import slide2vec.inference as inference + torch = pytest.importorskip("torch") + + class DummyLoader: + def __init__(self, dataset, **kwargs): + 
del dataset, kwargs + + def __iter__(self): + yield ( + torch.tensor([0], dtype=torch.long), + torch.zeros((1, 3, 4, 4), dtype=torch.uint8), + {"worker_batch_ms": 0.0, "reader_open_ms": 0.0, "reader_read_ms": 0.0}, + ) + + def __len__(self): + return 1 + + class DummyEncoder: + pretrained_cfg = {} + + class DummyModel: + encoder = DummyEncoder() + + def encode_tiles(self, image): + return torch.ones((image.shape[0], 3), dtype=torch.float32, device=image.device) + + monkeypatch.setattr(inference, "OnTheFlyBatchTileCollator", lambda **kwargs: SimpleNamespace(__call__=lambda batch_indices: None, ordered_indices=None)) + monkeypatch.setattr(torch.utils.data, "DataLoader", DummyLoader) + monkeypatch.setattr(inference, "_build_batch_preprocessor", lambda *args, **kwargs: lambda batch: batch.float()) + + loaded = inference.LoadedModel( + name="prov-gigapath", + level="tile", + model=DummyModel(), + transforms=object(), + feature_dim=3, + device=torch.device("cpu"), + ) + + with pytest.raises(ValueError, match="num_cucim_workers must be at least 1"): + inference._compute_tile_embeddings_for_slide( + loaded, + SimpleNamespace(level="tile"), + make_slide("slide-a"), + SimpleNamespace( + x=np.array([0]), + y=np.array([1]), + backend="cucim", + requested_spacing_um=0.5, + requested_tile_size_px=4, + read_spacing_um=0.5, + read_tile_size_px=4, + tile_size_lv0=224, + ), + preprocessing=replace(DEFAULT_PREPROCESSING, on_the_fly=True, backend="cucim", num_cucim_workers=0), + execution=ExecutionOptions(batch_size=2, num_workers_per_gpu=99, num_gpus=2), + ) + + def test_run_pipeline_logs_on_the_fly_worker_override_once(monkeypatch, tmp_path: Path, caplog): import slide2vec.inference as inference @@ -2748,7 +3213,7 @@ def test_run_pipeline_logs_on_the_fly_worker_override_once(monkeypatch, tmp_path execution=execution, ) - assert caplog.text.count("on-the-fly mode: setting DataLoader num_workers=") == 1 + assert caplog.text.count("on-the-fly mode: setting DataLoader 
num_workers_per_gpu=") == 1 def test_compute_tile_embeddings_for_slide_filters_on_the_fly_cucim_stderr_without_changing_workers(monkeypatch): @@ -2823,10 +3288,9 @@ def _fake_run_with_filtered_stderr(func, **kwargs): ) execution = ExecutionOptions( batch_size=2, - num_workers=99, + num_workers_per_gpu=99, num_gpus=1, prefetch_factor=9, - persistent_workers=True, ) result = inference._compute_tile_embeddings_for_slide( @@ -2840,7 +3304,6 @@ def _fake_run_with_filtered_stderr(func, **kwargs): assert result.shape == (2, 3) assert captured["kwargs"]["num_workers"] == 8 - assert captured["kwargs"]["persistent_workers"] is True assert captured["kwargs"]["prefetch_factor"] == 9 assert captured["filtered_calls"] == 1 @@ -2916,7 +3379,7 @@ def __call__(self, batch_indices): tile_size_lv0=224, ), preprocessing=replace(DEFAULT_PREPROCESSING, on_the_fly=True, backend="auto", num_cucim_workers=4), - execution=ExecutionOptions(batch_size=2, num_workers=8, num_gpus=1), + execution=ExecutionOptions(batch_size=2, num_workers_per_gpu=8, num_gpus=1), ) assert result.shape == (2, 3) @@ -2993,12 +3456,11 @@ def __call__(self, batch_indices): tile_size_lv0=224, ), preprocessing=replace(DEFAULT_PREPROCESSING, on_the_fly=True, backend="auto", num_cucim_workers=4), - execution=ExecutionOptions(batch_size=2, num_workers=8, num_gpus=1), + execution=ExecutionOptions(batch_size=2, num_workers_per_gpu=8, num_gpus=1), ) assert result.shape == (2, 3) assert captured["kwargs"]["num_workers"] == 8 - assert captured["kwargs"]["persistent_workers"] is True assert captured["kwargs"]["prefetch_factor"] == 4 assert captured["wsd_collator_kwargs"]["backend"] == "asap" @@ -3069,7 +3531,7 @@ def test_compute_tile_embeddings_for_slide_requires_current_run_tile_store_witho tiles_tar_path=None, ), preprocessing=replace(DEFAULT_PREPROCESSING, on_the_fly=False), - execution=ExecutionOptions(batch_size=1, num_workers=0, num_gpus=1), + execution=ExecutionOptions(batch_size=1, num_workers_per_gpu=0, num_gpus=1), ) 
@@ -3307,7 +3769,7 @@ def build_batch_sampler(self, *, batch_size, dataset_indices): slide, tiling_result, preprocessing=replace(DEFAULT_PREPROCESSING, region_tile_multiple=2, requested_region_size_px=448), - execution=ExecutionOptions(batch_size=4, num_workers=0, num_gpus=1), + execution=ExecutionOptions(batch_size=4, num_workers_per_gpu=0, num_gpus=1), ) assert result.shape == (2, 4, 2) diff --git a/tests/test_runtime_batching.py b/tests/test_runtime_batching.py new file mode 100644 index 0000000..37e2500 --- /dev/null +++ b/tests/test_runtime_batching.py @@ -0,0 +1,33 @@ +from __future__ import annotations + +import torch +from torchvision.transforms import functional as tvF + +from slide2vec.runtime.batching import apply_transforms_itemwise + + +class ConvertToRgbAndBack: + def __call__(self, image): + return tvF.pil_to_tensor(image.convert("RGB")) + + +def test_apply_transforms_itemwise_converts_tensor_samples_for_pil_only_transforms(): + image = torch.tensor( + [ + [ + [[0, 10], [20, 30]], + [[40, 50], [60, 70]], + [[80, 90], [100, 110]], + ], + [ + [[1, 11], [21, 31]], + [[41, 51], [61, 71]], + [[81, 91], [101, 111]], + ], + ], + dtype=torch.uint8, + ) + + transformed = apply_transforms_itemwise(image, ConvertToRgbAndBack()) + + assert torch.equal(transformed, image)