Commit 4d33219

Modify DM so it always loads the latest cache
1 parent 9cb95e8 commit 4d33219

File tree

7 files changed: +86 -51 lines changed

azimuth/dataset_split_manager.py

Lines changed: 71 additions & 40 deletions
@@ -30,6 +30,7 @@
 
 FEATURES = "features"
 FEATURE_FAISS = "features_faiss"
+Time = float
 
 
 @dataclass(eq=True, frozen=True)  # Generates __hash__
@@ -88,22 +89,30 @@ def __init__(
         self._index_path = pjoin(self._base_dataset_path, "index.faiss")
         self._features_path = pjoin(self._base_dataset_path, "features.faiss.npy")
         self._file_lock = pjoin(self._hf_path, f"{name}.lock")
-        self.last_update = -1
         # Load the dataset_split from disk.
-        cached_dataset_split = self._load_dataset_split()
-        if cached_dataset_split is None:
+        self._base_dataset_split_last_update: Time = -1
+        cached_base_dataset_split = self._load_latest_base_dataset_split()
+        if cached_base_dataset_split is None:
             if dataset_split is None:
                 raise ValueError("No dataset_split cached, can't initialize.")
             log.info("Initializing tags", tags=initial_tags)
-            self._base_dataset_split, self._malformed_dataset = self._load_base_dataset_split(
+            self._base_dataset_split, self._malformed_dataset = self._process_base_dataset_split(
                 dataset_split
             )
             self._save_base_dataset_split()
         else:
-            self._base_dataset_split, self._malformed_dataset = cached_dataset_split
+            self._base_dataset_split, self._malformed_dataset = cached_base_dataset_split
         self._prediction_tables: Dict[PredictionTableKey, Dataset] = {}
+        self._prediction_tables_last_update: Dict[PredictionTableKey, Time] = defaultdict(int)
         self._validate_columns()
 
+    @property
+    def last_update(self) -> Time:
+        return max(
+            [self._base_dataset_split_last_update]
+            + list(self._prediction_tables_last_update.values())
+        )
+
     def get_dataset_split(self, table_key: Optional[PredictionTableKey] = None) -> Dataset:
         """Return a dataset_split concatenated with the config predictions.
 
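The new `last_update` property folds the base-split timestamp and every prediction-table timestamp into one freshness value. A standalone sketch of the same pattern, with illustrative names rather than the Azimuth class itself:

import time
from collections import defaultdict

Time = float

base_last_update: Time = -1            # sentinel: base split not saved yet
pred_last_updates = defaultdict(int)   # unseen prediction tables default to 0

# With no prediction tables saved, only the base timestamp is considered.
assert max([base_last_update] + list(pred_last_updates.values())) == -1

# Saving a prediction table bumps the aggregate value.
pred_last_updates["some_table_key"] = time.time()
assert max([base_last_update] + list(pred_last_updates.values())) > 0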
@@ -113,6 +122,11 @@ def get_dataset_split(self, table_key: Optional[PredictionTableKey] = None) -> Dataset:
         Returns:
             Dataset with predictions if available.
         """
+        current_last_update = self._base_dataset_split_last_update
+        newest_base_ds, last_update = self.load_latest_cache(self._save_path, current_last_update)
+        if newest_base_ds:
+            self._base_dataset_split = newest_base_ds
+            self._base_dataset_split_last_update = last_update
         if table_key is None:
             return self._base_dataset_split
         return self.dataset_split_with_predictions(table_key=table_key)
@@ -144,16 +158,16 @@ def classification_columns(self):
             DatasetColumn.postprocessed_prediction,
         }
 
-    def _load_dataset_split(self) -> Optional[Tuple[Dataset, Dataset]]:
+    def _load_latest_base_dataset_split(self) -> Optional[Tuple[Dataset, Dataset]]:
         if os.path.exists(self._save_path):
-            log.debug("Reloading base dataset_split.", path=self._save_path)
-            with FileLock(self._file_lock):
-                ds = self.load_cache(self._save_path)
-                malformed = self.load_cache(self._malformed_path)
-            return ds, malformed
+            current_update = self._base_dataset_split_last_update
+            base_ds, last_update = self.load_latest_cache(self._save_path, current_update)
+            malformed, _ = self.load_latest_cache(self._malformed_path, current_update)
+            self._base_dataset_split_last_update = last_update
+            return assert_not_none(base_ds), assert_not_none(malformed)
         return None
 
-    def _load_base_dataset_split(self, dataset_split) -> Tuple[Dataset, Dataset]:
+    def _process_base_dataset_split(self, dataset_split) -> Tuple[Dataset, Dataset]:
         base_dataset_split, malformed_dataset = self._split_malformed(dataset_split)
 
         # Checking if a persistent id was provided.
@@ -171,10 +185,13 @@ def _load_base_dataset_split(self, dataset_split) -> Tuple[Dataset, Dataset]:
     def _save_base_dataset_split(self):
         # NOTE: We should not have the Index in `self.dataset_split`.
         with FileLock(self._file_lock):
-            self._base_dataset_split.save_to_disk(self._get_new_version_path(self._save_path))
-            self._malformed_dataset.save_to_disk(self._get_new_version_path(self._malformed_path))
-            self.last_update = time.time()
-            log.debug("Base dataset split saved.", path=self._save_path)
+            version_path, last_update = self._get_new_version_path(self._save_path)
+            self._base_dataset_split.save_to_disk(version_path)
+            self._malformed_dataset.save_to_disk(
+                self._get_new_version_path(self._malformed_path)[0]
+            )
+            self._base_dataset_split_last_update = last_update
+            log.debug("Base dataset split saved.", path=version_path)
 
     def get_dataset_split_with_class_names(
         self, table_key: Optional[PredictionTableKey] = None
@@ -476,37 +493,38 @@ def _get_prediction_table(self, table_key: PredictionTableKey) -> Dataset:
             A table for this key, will create one if it doesn't exist.
         """
 
-        path = self._prediction_path(table_key=table_key)
-        if table_key not in self._prediction_tables and os.path.exists(path):
-            with FileLock(self._file_lock):
-                self._prediction_tables[table_key] = self.load_cache(path)
-        elif table_key not in self._prediction_tables:
+        pred_path = self._prediction_path(table_key=table_key)
+        if not os.path.exists(pred_path):
             empty_ds = Dataset.from_dict({"pred_row_idx": list(range(self.num_rows))})
             self._prediction_tables[table_key] = self._init_dataset_split(
                 empty_ds, self._prediction_tags
             ).remove_columns([DatasetColumn.row_idx])
             self.save_prediction_table(table_key)
+        else:
+            current_last_update = self._prediction_tables_last_update[table_key]
+            newest_pred_ds, last_update = self.load_latest_cache(pred_path, current_last_update)
+            if newest_pred_ds:
+                self._prediction_tables[table_key] = newest_pred_ds
+                self._prediction_tables_last_update[table_key] = last_update
         return self._prediction_tables[table_key]
 
-    def _prediction_path(self, table_key: PredictionTableKey):
+    def _prediction_path(self, table_key: PredictionTableKey) -> str:
         """Path to table file."""
         folder = pjoin(self._hf_path, "prediction_tables")
         table_name = "_".join(
             f"{k}={v:.2f}" if type(v) is float else f"{k}={v}" for k, v in asdict(table_key).items()
         )
         os.makedirs(folder, exist_ok=True)
-        pt = pjoin(
-            folder,
-            f"{table_name}_cache_ds.arrow",
-        )
-        return pt
+        return pjoin(folder, f"{table_name}_cache_ds.arrow")
 
     def save_prediction_table(self, table_key: PredictionTableKey):
         """Save the prediction to disk."""
         with FileLock(self._file_lock):
-            pt = self._prediction_path(table_key=table_key)
-            self._prediction_tables[table_key].save_to_disk(self._get_new_version_path(pt))
-            self.last_update = int(time.time())
+            pred_path = self._prediction_path(table_key=table_key)
+            version_path, last_update = self._get_new_version_path(pred_path)
+            self._prediction_tables[table_key].save_to_disk(version_path)
+            self._prediction_tables_last_update[table_key] = last_update
+            log.debug("Prediction dataset split saved.", path=version_path)
 
     def add_column_to_prediction_table(
         self, key: str, features: List[Any], table_key: PredictionTableKey, **kwargs
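As a side note, `_prediction_path` flattens the table key into a file name, formatting floats with two decimals. A quick sketch with a made-up key (the real `PredictionTableKey` fields may differ):

from dataclasses import asdict, dataclass

@dataclass(frozen=True)
class Key:  # hypothetical key, for illustration only
    threshold: float
    pipeline_index: int

key = Key(threshold=0.5, pipeline_index=0)
table_name = "_".join(
    f"{k}={v:.2f}" if type(v) is float else f"{k}={v}" for k, v in asdict(key).items()
)
assert table_name == "threshold=0.50_pipeline_index=0"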
@@ -568,15 +586,20 @@ def _validate_columns(self):
         if len(self._base_dataset_split) == 0:
             raise AzimuthValidationError(f"No rows found from dataset {self.name}")
 
-    def load_cache(self, folder: str) -> Dataset:
-        """
-        Load the latest cache.
+    def load_latest_cache(
+        self, folder: str, current_last_update: Time
+    ) -> Tuple[Optional[Dataset], Time]:
+        """Load the latest dataset saved in a given folder.
+
+        Notes:
+            Only load when the dataset is more recent than what is currently saved in the DM.
 
         Args:
             folder: Where to look for.
+            current_last_update: Current version saved in the DM.
 
         Returns:
-            The cached dataset or the original.
+            The cached dataset, if more recent. Otherwise None.
 
         Raises:
             FileNotFoundError if no cache found.
@@ -591,13 +614,21 @@ def load_cache(self, folder: str) -> Dataset:
             ),
             None,
         )
-        if cache_file:
-            log.debug(f"Loading latest cache: {cache_file.split('/')[-1]}")
-            return Dataset.load_from_disk(cache_file)
-        raise FileNotFoundError(f"No previously saved dataset in {folder}")
+        if not cache_file:
+            raise FileNotFoundError(f"No previously saved dataset in {folder}")
+
+        last_update = float(cache_file.split("_")[-1][:-6])
+        if current_last_update >= last_update:
+            return None, -1
+
+        with FileLock(self._file_lock):
+            log.debug("Loading latest dataset in cache.", path=cache_file)
+            return Dataset.load_from_disk(cache_file), last_update
 
-    def _get_new_version_path(self, directory):
-        return pjoin(directory, f"version_{time.time()}.arrow")
+    @staticmethod
+    def _get_new_version_path(directory) -> Tuple[str, Time]:
+        now = time.time()
+        return pjoin(directory, f"version_{now}.arrow"), now
 
     def save_proposed_actions_to_csv(self) -> str:
         """Save proposed actions to a csv file.

azimuth/modules/base_classes/expirable_mixin.py

Lines changed: 1 addition & 1 deletion
@@ -15,6 +15,6 @@ class ExpirableMixin(abc.ABC):
 
     _time: float
 
-    def is_expired(self, compared_to: int):
+    def is_expired(self, compared_to: float):
         # Check if this Module's results are expired.
         return self._time < compared_to
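The int-to-float change matters because `time.time()` has sub-second resolution: truncating `compared_to` to an int could mis-order a module result against a dataset save that happened within the same second. A small sketch with an illustrative class mirroring `ExpirableMixin`:

import time

class Result:  # illustrative stand-in for a module result
    def __init__(self) -> None:
        self._time: float = time.time()

    def is_expired(self, compared_to: float) -> bool:
        # Expired when something was saved after this result was computed.
        return self._time < compared_to

res = Result()
assert not res.is_expired(compared_to=-1)             # nothing saved since
assert res.is_expired(compared_to=time.time() + 1.0)  # a later save expires it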

azimuth/task_manager.py

Lines changed: 1 addition & 1 deletion
@@ -103,7 +103,7 @@ def get_task(
         task_name: SupportedTask,
         dataset_split_name: DatasetSplitName,
         mod_options: Optional[ModuleOptions] = None,
-        last_update: int = -1,
+        last_update: float = -1,
         dependencies: Optional[List[DaskModule]] = None,
     ) -> Tuple[str, Optional[DaskModule]]:
         """Get the task `name` run on indices.

azimuth/utils/routers.py

Lines changed: 2 additions & 2 deletions
@@ -28,7 +28,7 @@
 from azimuth.utils.project import predictions_available
 
 
-def get_last_update(dataset_split_managers: List[Optional[DatasetSplitManager]]) -> int:
+def get_last_update(dataset_split_managers: List[Optional[DatasetSplitManager]]) -> float:
     last_update = max([dsm.last_update if dsm else -1 for dsm in dataset_split_managers])
 
     return last_update
@@ -97,7 +97,7 @@ def get_standard_task_result(
     dataset_split_name: DatasetSplitName,
     task_manager: TaskManager,
     mod_options: Optional[ModuleOptions] = None,
-    last_update: int = -1,
+    last_update: float = -1,
 ):
     """Generate the task object and get the result for standard tasks.
 

tests/conftest.py

Lines changed: 2 additions & 2 deletions
@@ -212,8 +212,8 @@ def clinc_text_config(tmp_path):
         pipelines=[PIPELINE_CFG],
         artifact_path=str(tmp_path),
     )
-    clinc_text_config.pipelines[0].postprocessors[0].temperature = 1
-    clinc_text_config.pipelines[0].postprocessors[0].kwargs["temperature"] = 1
+    clinc_text_config.pipelines[0].postprocessors[0].temperature = 1.0
+    clinc_text_config.pipelines[0].postprocessors[0].kwargs["temperature"] = 1.0
     clinc_text_config.pipelines[0].postprocessors[-1].threshold = 0.5
     clinc_text_config.pipelines[0].postprocessors[-1].kwargs["threshold"] = 0.5
     return clinc_text_config
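One plausible motivation for this fixture change (an inference, not stated in the commit): `_prediction_path` above formats values with `type(v) is float`, so an int `1` and a float `1.0` would yield different table file names, and key fields must therefore be consistently typed. For illustration:

def fmt(k, v):
    # Mirrors the f-string in _prediction_path.
    return f"{k}={v:.2f}" if type(v) is float else f"{k}={v}"

assert fmt("temperature", 1) == "temperature=1"      # int formats verbatim
assert fmt("temperature", 1.0) == "temperature=1.00" # float gets two decimals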

tests/test_dataset_manager.py

Lines changed: 4 additions & 0 deletions
@@ -298,6 +298,7 @@ def test_caching(a_text_dataset, simple_text_config):
 
     assert len(glob(pjoin(pred_path, "version_*.arrow"))) == 0, "Some preds were cached at init?"
     assert len(glob(pjoin(dm1._save_path, "version_*.arrow"))) == 1, "Dataset not saved on disk"
+    initial_last_update = dm1.last_update
 
     dm1.add_column("pink", list(range(len(a_text_dataset))))
     assert (
@@ -306,6 +307,8 @@ def test_caching(a_text_dataset, simple_text_config):
     assert (
         len(glob(pjoin(dm1._save_path, "version_*.arrow"))) == 2
     ), "Dataset not saved on disk when added a column"
+    modified_last_update = dm1.last_update
+    assert initial_last_update < modified_last_update, "Last update should be updated."
 
     dm1.add_column_to_prediction_table(
         "apple", list(range(len(a_text_dataset))), table_key=simple_table_key
@@ -317,6 +320,7 @@ def test_caching(a_text_dataset, simple_text_config):
     assert (
         len(glob(pjoin(dm1._save_path, "version_*.arrow"))) == 2
     ), "New version of based dataset detected when adding prediction"
+    assert modified_last_update < dm1.last_update, "Last update should be updated."
 
     dm2 = DatasetSplitManager(
         DatasetSplitName.eval,

tests/test_modules/test_model_contracts/test_model_contract_module.py

Lines changed: 5 additions & 5 deletions
@@ -194,18 +194,18 @@ def test_pred_smart_tags(clinc_text_config):
 
     mod.save_result(res, dm)
 
-    clinc_table_key = get_table_key(clinc_text_config)
-    assert SmartTag.correct_top_3 in dm.get_dataset_split(clinc_table_key).column_names
-    assert SmartTag.correct_low_conf in dm.get_dataset_split(clinc_table_key).column_names
+    table_key = get_table_key(clinc_text_config)
+    ds = dm.get_dataset_split(table_key)
+    assert SmartTag.correct_top_3 in ds.column_names and SmartTag.correct_low_conf in ds.column_names
 
-    assert dm.get_dataset_split(clinc_table_key)[SmartTag.correct_top_3] == [
+    assert ds[SmartTag.correct_top_3] == [
         False,
         True,
         False,
         True,
         False,
     ], "Problem with correct_top_3 smart tag"
-    assert dm.get_dataset_split(clinc_table_key)[SmartTag.correct_low_conf] == [
+    assert ds[SmartTag.correct_low_conf] == [
         False,
         False,
         True,
