Skip to content

Commit a7503cf

Browse files
authored
fix floating point precision and port (#123)
1 parent 9544aba commit a7503cf

16 files changed

Lines changed: 1021 additions & 223 deletions

.gitignore

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -165,6 +165,5 @@ cython_debug/
165165
output/
166166
outputs/
167167
archive/
168-
tasks/
169168
docs/20*-*.md
170169
data/

slide2vec/api.py

Lines changed: 14 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -159,8 +159,9 @@ class ExecutionOptions:
159159
output_format: str = "pt"
160160
#: Number of tiles per forward pass.
161161
batch_size: int = 32
162-
#: DataLoader worker count. ``None`` means auto (capped by CPU / SLURM limit).
163-
num_workers: int | None = None
162+
#: DataLoader worker count per GPU rank. ``None`` means auto
163+
#: (capped by CPU / SLURM limit, then split across the resolved GPU count).
164+
num_workers_per_gpu: int | None = None
164165
#: Tiling worker count. ``None`` means auto (capped by CPU / SLURM limit).
165166
num_preprocessing_workers: int | None = None
166167
#: Number of GPUs to use. ``None`` defaults to all available GPUs.
@@ -170,8 +171,6 @@ class ExecutionOptions:
170171
precision: str | None = None
171172
#: DataLoader prefetch queue depth per worker (default ``4``).
172173
prefetch_factor: int = 4
173-
#: Keep DataLoader workers alive between batches (default ``True``).
174-
persistent_workers: bool = True
175174
#: Persist tile embeddings to disk when running a slide-level model.
176175
save_tile_embeddings: bool = False
177176
#: Persist slide embeddings to disk when running a patient-level model.
@@ -183,14 +182,13 @@ class ExecutionOptions:
183182
def from_config(cls, cfg: Any, *, run_on_cpu: bool = False) -> "ExecutionOptions":
184183
configured_num_gpus = cfg.speed.num_gpus
185184
requested_precision = normalize_precision_name(cfg.speed.precision)
186-
num_workers = cfg.speed.num_dataloader_workers
185+
num_workers_per_gpu = cfg.speed.num_dataloader_workers
187186
prefetch_factor = int(cfg.speed.prefetch_factor_embedding)
188-
persistent_workers = bool(cfg.speed.persistent_workers_embedding)
189187
return cls(
190188
output_dir=Path(cfg.output_dir),
191189
output_format="pt",
192190
batch_size=int(cfg.model.batch_size),
193-
num_workers=int(num_workers) if num_workers is not None else None,
191+
num_workers_per_gpu=int(num_workers_per_gpu) if num_workers_per_gpu is not None else None,
194192
num_preprocessing_workers=(
195193
int(cfg.speed.num_preprocessing_workers)
196194
if cfg.speed.num_preprocessing_workers is not None
@@ -199,7 +197,6 @@ def from_config(cls, cfg: Any, *, run_on_cpu: bool = False) -> "ExecutionOptions
199197
num_gpus=1 if run_on_cpu else (int(configured_num_gpus) if configured_num_gpus is not None else None),
200198
precision="fp32" if run_on_cpu else requested_precision,
201199
prefetch_factor=prefetch_factor,
202-
persistent_workers=persistent_workers,
203200
save_tile_embeddings=bool(cfg.model.save_tile_embeddings),
204201
save_slide_embeddings=bool(cfg.model.save_slide_embeddings),
205202
save_latents=bool(cfg.model.save_latents),
@@ -222,23 +219,25 @@ def __post_init__(self) -> None:
222219
object.__setattr__(self, "num_preprocessing_workers", capped_num_preprocessing_workers)
223220
logger = logging.getLogger(__name__)
224221
cap_source = f"slurm_cpu_limit={slurm_limit}" if slurm_limit is not None else f"cpu_count={cpu_count}"
225-
resolved_num_workers = self.resolved_num_workers()
226-
num_workers_label = (
222+
resolved_num_workers = self.resolved_num_workers_per_gpu()
223+
num_workers_per_gpu_label = (
227224
f"{resolved_num_workers} (requested=auto)"
228-
if self.num_workers is None
225+
if self.num_workers_per_gpu is None
229226
else str(resolved_num_workers)
230227
)
231228
logger.info(
232-
"ExecutionOptions: num_workers=%s, num_preprocessing_workers=%d "
229+
"ExecutionOptions: num_workers_per_gpu=%s, num_preprocessing_workers=%d "
233230
"(preprocessing cap=%d via %s)",
234-
num_workers_label,
231+
num_workers_per_gpu_label,
235232
capped_num_preprocessing_workers,
236233
cap,
237234
cap_source,
238235
)
239236

240-
def resolved_num_workers(self) -> int:
241-
return cpu_worker_limit() if self.num_workers is None else int(self.num_workers)
237+
def resolved_num_workers_per_gpu(self) -> int:
238+
if self.num_workers_per_gpu is not None:
239+
return self.num_workers_per_gpu
240+
return max(1, cpu_worker_limit() // self.num_gpus)
242241

243242
def with_output_dir(self, output_dir: PathLike | None) -> "ExecutionOptions":
244243
if output_dir is None:

slide2vec/configs/default.yaml

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,7 @@ tiling:
4242
sthresh_up: 255 # upper threshold value for scaling the binary mask
4343
mthresh: 7 # median filter size (positive, odd integer)
4444
close: 4 # additional morphological closing to apply following initial thresholding (positive integer)
45-
method: "hsv" # tissue segmentation method: "hsv", "otsu", "threshold", or "sam2"
45+
method: # tissue segmentation method: "hsv", "otsu", "threshold", or "sam2"; ignored when precomputed tissue masks are provided
4646
sam2_checkpoint_path: # optional when method="sam2"; if empty, hs2p downloads the default AtlasPatch checkpoint from Hugging Face
4747
sam2_config_path: # optional local override for the SAM2 model config; if empty, hs2p downloads the default AtlasPatch config from Hugging Face
4848
sam2_device: "cpu" # device for SAM2 inference, e.g. "cpu", "cuda", or "cuda:0"
@@ -71,12 +71,11 @@ tiling:
7171

7272
speed:
7373
precision: # model inference precision ["fp32", "fp16", "bf16"]; if not set, determined automatically based on model recommendations
74-
num_dataloader_workers: # number of DataLoader worker processes for reading tiles during embedding; defaults to auto (job CPU budget, except cuCIM on-the-fly uses cpu_budget // speed.num_cucim_workers)
74+
num_dataloader_workers: # number of DataLoader worker processes per GPU rank for reading tiles during embedding; defaults to auto (job CPU budget split across GPUs, except cuCIM on-the-fly uses per-GPU budget // speed.num_cucim_workers)
7575
num_gpus: # number of GPUs to use for feature extraction; defaults to all available GPUs
7676
num_preprocessing_workers: # number of workers for hs2p tiling (WSI reading, JPEG encoding, tar writing); defaults to the runtime CPU budget capped at 64
7777
num_cucim_workers: 4 # number of internal cucim threads per read_region call (embedding path, on-the-fly only); DataLoader workers are auto-set to cpu_count // num_cucim_workers
7878
prefetch_factor_embedding: 4 # prefetch factor for tile embedding dataloaders
79-
persistent_workers_embedding: true # keep DataLoader workers alive across epochs/batches
8079

8180
wandb:
8281
enable: false

slide2vec/distributed/direct_embed_worker.py

Lines changed: 12 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -119,20 +119,24 @@ def main(argv=None) -> int:
119119
return 0
120120
assigned_slides = [paired_by_sample[sample_id][0] for sample_id in assigned_ids]
121121
assigned_tiling_results = [paired_by_sample[sample_id][1] for sample_id in assigned_ids]
122-
embedded_slides = _compute_embedded_slides(
123-
model,
124-
assigned_slides,
125-
assigned_tiling_results,
126-
preprocessing=preprocessing,
127-
execution=execution,
128-
)
129-
for embedded_slide in embedded_slides:
122+
123+
def _persist_embedded_slide(slide, tiling_result, embedded_slide) -> None:
130124
payload = {
131125
"tile_embeddings": _to_cpu_payload(embedded_slide.tile_embeddings),
132126
"slide_embedding": _to_cpu_payload(embedded_slide.slide_embedding),
133127
"latents": _to_cpu_payload(embedded_slide.latents),
134128
}
135129
torch.save(payload, coordination_dir / f"{embedded_slide.sample_id}.embedded.pt")
130+
131+
_compute_embedded_slides(
132+
model,
133+
assigned_slides,
134+
assigned_tiling_results,
135+
preprocessing=preprocessing,
136+
execution=execution,
137+
on_embedded_slide=_persist_embedded_slide,
138+
collect_results=False,
139+
)
136140
return 0
137141
finally:
138142
if dist.is_available() and dist.is_initialized():

slide2vec/distributed/pipeline_worker.py

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -19,8 +19,8 @@ def main(argv=None) -> int:
1919
import slide2vec.distributed as distributed
2020
from slide2vec.api import Model
2121
from slide2vec.inference import (
22+
_build_incremental_persist_callback,
2223
_compute_embedded_slides,
23-
_persist_embedded_slide,
2424
load_successful_tiled_slides,
2525
)
2626
from slide2vec.progress import JsonlProgressReporter, activate_progress_reporter
@@ -70,21 +70,21 @@ def main(argv=None) -> int:
7070
)
7171
context = activate_progress_reporter(reporter) if reporter is not None else nullcontext()
7272
with context:
73-
embedded_slides = _compute_embedded_slides(
73+
persist_callback, _, _ = _build_incremental_persist_callback(
74+
model=model,
75+
preprocessing=preprocessing,
76+
execution=execution,
77+
process_list_path=None,
78+
)
79+
_compute_embedded_slides(
7480
model,
7581
assigned_slides,
7682
assigned_tiling_results,
7783
preprocessing=preprocessing,
7884
execution=execution,
85+
on_embedded_slide=persist_callback,
86+
collect_results=False,
7987
)
80-
for embedded_slide, tiling_result in zip(embedded_slides, assigned_tiling_results):
81-
_persist_embedded_slide(
82-
model,
83-
embedded_slide,
84-
tiling_result,
85-
preprocessing=preprocessing,
86-
execution=execution,
87-
)
8888
return 0
8989
finally:
9090
if dist.is_available() and dist.is_initialized():

slide2vec/encoders/validation.py

Lines changed: 13 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -63,14 +63,18 @@ def validate_encoder_config(
6363
if not mismatches:
6464
return
6565

66-
message = (
67-
f"Model '{encoder_name}' is configured with "
68-
f"{'; '.join(mismatches)}. "
69-
"Set `model.allow_non_recommended_settings=true` in YAML/CLI or "
70-
"`allow_non_recommended_settings=True` in `Model.from_preset(...)` "
71-
"to continue with a warning."
72-
)
7366
if allow_non_recommended:
74-
logger.warning(message)
67+
logger.warning(
68+
f"Model '{encoder_name}' is configured with "
69+
f"{'; '.join(mismatches)}. "
70+
"Warning-only mode is enabled because "
71+
"`allow_non_recommended_settings=True`."
72+
)
7573
else:
76-
raise ValueError(message)
74+
raise ValueError(
75+
f"Model '{encoder_name}' is configured with "
76+
f"{'; '.join(mismatches)}. "
77+
"Set `model.allow_non_recommended_settings=true` in YAML/CLI or "
78+
"`allow_non_recommended_settings=True` in `Model.from_preset(...)` "
79+
"to continue."
80+
)

0 commit comments

Comments (0)