Skip to content

Commit e245e96

Browse files
committed
improve embedding persistence
1 parent 1ec6b06 commit e245e96

3 files changed

Lines changed: 38 additions & 28 deletions

File tree

slide2vec/distributed/pipeline_worker.py

Lines changed: 9 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -19,8 +19,8 @@ def main(argv=None) -> int:
1919
import slide2vec.distributed as distributed
2020
from slide2vec.api import Model
2121
from slide2vec.inference import (
22+
_build_incremental_persist_callback,
2223
_compute_embedded_slides,
23-
_persist_embedded_slide,
2424
load_successful_tiled_slides,
2525
)
2626
from slide2vec.progress import JsonlProgressReporter, activate_progress_reporter
@@ -70,21 +70,20 @@ def main(argv=None) -> int:
7070
)
7171
context = activate_progress_reporter(reporter) if reporter is not None else nullcontext()
7272
with context:
73-
embedded_slides = _compute_embedded_slides(
73+
persist_callback, _, _ = _build_incremental_persist_callback(
74+
model=model,
75+
preprocessing=preprocessing,
76+
execution=execution,
77+
process_list_path=None,
78+
)
79+
_compute_embedded_slides(
7480
model,
7581
assigned_slides,
7682
assigned_tiling_results,
7783
preprocessing=preprocessing,
7884
execution=execution,
85+
on_embedded_slide=persist_callback,
7986
)
80-
for embedded_slide, tiling_result in zip(embedded_slides, assigned_tiling_results):
81-
_persist_embedded_slide(
82-
model,
83-
embedded_slide,
84-
tiling_result,
85-
preprocessing=preprocessing,
86-
execution=execution,
87-
)
8887
return 0
8988
finally:
9089
if dist.is_available() and dist.is_initialized():

slide2vec/encoders/validation.py

Lines changed: 13 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -63,14 +63,18 @@ def validate_encoder_config(
6363
if not mismatches:
6464
return
6565

66-
message = (
67-
f"Model '{encoder_name}' is configured with "
68-
f"{'; '.join(mismatches)}. "
69-
"Set `model.allow_non_recommended_settings=true` in YAML/CLI or "
70-
"`allow_non_recommended_settings=True` in `Model.from_preset(...)` "
71-
"to continue with a warning."
72-
)
7366
if allow_non_recommended:
74-
logger.warning(message)
67+
logger.warning(
68+
f"Model '{encoder_name}' is configured with "
69+
f"{'; '.join(mismatches)}. "
70+
"Warning-only mode is enabled because "
71+
"`allow_non_recommended_settings=True`."
72+
)
7573
else:
76-
raise ValueError(message)
74+
raise ValueError(
75+
f"Model '{encoder_name}' is configured with "
76+
f"{'; '.join(mismatches)}. "
77+
"Set `model.allow_non_recommended_settings=true` in YAML/CLI or "
78+
"`allow_non_recommended_settings=True` in `Model.from_preset(...)` "
79+
"to continue."
80+
)

slide2vec/inference.py

Lines changed: 16 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1021,24 +1021,31 @@ def run_pipeline_with_coordinates(
10211021
slide_artifacts=slide_artifacts,
10221022
process_list_path=process_list_path,
10231023
)
1024-
embedded_slides = _compute_embedded_slides(
1025-
model,
1026-
embeddable_slides,
1027-
embeddable_tiling_results,
1024+
local_persist_callback, tile_or_hier_artifacts, slide_artifacts = _build_incremental_persist_callback(
1025+
model=model,
10281026
preprocessing=resolved_preprocessing,
10291027
execution=execution,
1028+
process_list_path=process_list_path,
10301029
)
1031-
tile_artifacts, hierarchical_artifacts, slide_artifacts = _collect_local_pipeline_artifacts(
1032-
model=model,
1033-
embedded_slides=embedded_slides,
1034-
tiling_results=embeddable_tiling_results,
1030+
_compute_embedded_slides(
1031+
model,
1032+
embeddable_slides,
1033+
embeddable_tiling_results,
10351034
preprocessing=resolved_preprocessing,
10361035
execution=execution,
1036+
on_embedded_slide=local_persist_callback,
10371037
)
1038+
tile_artifacts: list[TileEmbeddingArtifact] = []
1039+
hierarchical_artifacts: list[HierarchicalEmbeddingArtifact] = []
1040+
for artifact in tile_or_hier_artifacts:
1041+
if isinstance(artifact, HierarchicalEmbeddingArtifact):
1042+
hierarchical_artifacts.append(artifact)
1043+
elif artifact is not None:
1044+
tile_artifacts.append(artifact)
10381045
return RunResult(
10391046
tile_artifacts=tile_artifacts,
10401047
hierarchical_artifacts=hierarchical_artifacts,
1041-
slide_artifacts=slide_artifacts,
1048+
slide_artifacts=list(slide_artifacts),
10421049
process_list_path=process_list_path,
10431050
)
10441051
except Exception as exc:

0 commit comments

Comments
 (0)