@@ -30,6 +30,7 @@

 FEATURES = "features"
 FEATURE_FAISS = "features_faiss"
+Time = float


 @dataclass(eq=True, frozen=True)  # Generates __hash__
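
With `eq=True, frozen=True`, the dataclass machinery synthesizes `__hash__` (hence the comment), which is what lets instances such as `PredictionTableKey` serve as keys in the `Dict[PredictionTableKey, Dataset]` caches introduced below. A minimal sketch of that behavior; `ExampleKey` is a hypothetical stand-in, not a class from this codebase:

    from dataclasses import dataclass
    from typing import Dict

    @dataclass(eq=True, frozen=True)  # eq + frozen => __hash__ is generated
    class ExampleKey:  # hypothetical stand-in for PredictionTableKey
        pipeline_index: int
        threshold: float

    tables: Dict[ExampleKey, str] = {}
    tables[ExampleKey(0, 0.5)] = "table_a"
    assert ExampleKey(0, 0.5) in tables  # equal fields => equal hash => same entry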
@@ -88,22 +89,30 @@ def __init__(
         self._index_path = pjoin(self._base_dataset_path, "index.faiss")
         self._features_path = pjoin(self._base_dataset_path, "features.faiss.npy")
         self._file_lock = pjoin(self._hf_path, f"{name}.lock")
-        self.last_update = -1
         # Load the dataset_split from disk.
-        cached_dataset_split = self._load_dataset_split()
-        if cached_dataset_split is None:
+        self._base_dataset_split_last_update: Time = -1
+        cached_base_dataset_split = self._load_latest_base_dataset_split()
+        if cached_base_dataset_split is None:
             if dataset_split is None:
                 raise ValueError("No dataset_split cached, can't initialize.")
             log.info("Initializing tags", tags=initial_tags)
-            self._base_dataset_split, self._malformed_dataset = self._load_base_dataset_split(
+            self._base_dataset_split, self._malformed_dataset = self._process_base_dataset_split(
                 dataset_split
             )
             self._save_base_dataset_split()
         else:
-            self._base_dataset_split, self._malformed_dataset = cached_dataset_split
+            self._base_dataset_split, self._malformed_dataset = cached_base_dataset_split
         self._prediction_tables: Dict[PredictionTableKey, Dataset] = {}
+        self._prediction_tables_last_update: Dict[PredictionTableKey, Time] = defaultdict(int)
         self._validate_columns()

+    @property
+    def last_update(self) -> Time:
+        return max(
+            [self._base_dataset_split_last_update]
+            + list(self._prediction_tables_last_update.values())
+        )
+
     def get_dataset_split(self, table_key: Optional[PredictionTableKey] = None) -> Dataset:
         """Return a dataset_split concatenated with the config predictions.

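
Two details in this hunk deserve a note: `_prediction_tables_last_update` is a `defaultdict(int)`, so a table key that has never been loaded reports 0 instead of raising `KeyError`, and the new `last_update` property reduces the base-split timestamp together with every per-table timestamp via `max`, so the newest write anywhere wins. The same aggregation in isolation (names are illustrative):

    import time
    from collections import defaultdict
    from typing import Dict

    Time = float

    base_last_update: Time = -1  # nothing saved yet
    tables_last_update: Dict[str, Time] = defaultdict(int)  # unseen key -> 0

    tables_last_update["pipeline_0"] = time.time()  # one table was just saved

    # Same reduction as the `last_update` property: the newest timestamp wins.
    newest = max([base_last_update] + list(tables_last_update.values()))
    assert newest == tables_last_update["pipeline_0"]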
@@ -113,6 +122,11 @@ def get_dataset_split(self, table_key: Optional[PredictionTableKey] = None) -> D
         Returns:
             Dataset with predictions if available.
         """
+        current_last_update = self._base_dataset_split_last_update
+        newest_base_ds, last_update = self.load_latest_cache(self._save_path, current_last_update)
+        if newest_base_ds:
+            self._base_dataset_split = newest_base_ds
+            self._base_dataset_split_last_update = last_update
         if table_key is None:
             return self._base_dataset_split
         return self.dataset_split_with_predictions(table_key=table_key)
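
`get_dataset_split` now refreshes on every read: it asks the cache layer for anything newer than the timestamp it already holds and swaps its in-memory copy only on a hit. The shape of that read-through pattern, sketched on its own under the assumption (matched by `load_latest_cache` below) that the loader returns `(None, -1)` when nothing newer exists:

    from typing import Any, Callable, Tuple

    Time = float

    def refresh(current: Any, current_last_update: Time,
                load_latest: Callable[[Time], Tuple[Any, Time]]) -> Tuple[Any, Time]:
        # Keep `current` unless a strictly newer snapshot exists on disk.
        newest, last_update = load_latest(current_last_update)
        if newest is not None:
            return newest, last_update
        return current, current_last_update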
@@ -144,16 +158,16 @@ def classification_columns(self):
             DatasetColumn.postprocessed_prediction,
         }

-    def _load_dataset_split(self) -> Optional[Tuple[Dataset, Dataset]]:
+    def _load_latest_base_dataset_split(self) -> Optional[Tuple[Dataset, Dataset]]:
         if os.path.exists(self._save_path):
-            log.debug("Reloading base dataset_split.", path=self._save_path)
-            with FileLock(self._file_lock):
-                ds = self.load_cache(self._save_path)
-                malformed = self.load_cache(self._malformed_path)
-                return ds, malformed
+            current_update = self._base_dataset_split_last_update
+            base_ds, last_update = self.load_latest_cache(self._save_path, current_update)
+            malformed, _ = self.load_latest_cache(self._malformed_path, current_update)
+            self._base_dataset_split_last_update = last_update
+            return assert_not_none(base_ds), assert_not_none(malformed)
         return None

-    def _load_base_dataset_split(self, dataset_split) -> Tuple[Dataset, Dataset]:
+    def _process_base_dataset_split(self, dataset_split) -> Tuple[Dataset, Dataset]:
         base_dataset_split, malformed_dataset = self._split_malformed(dataset_split)

         # Checking if a persistent id was provided.
@@ -171,10 +185,13 @@ def _load_base_dataset_split(self, dataset_split) -> Tuple[Dataset, Dataset]:
     def _save_base_dataset_split(self):
         # NOTE: We should not have the Index in `self.dataset_split`.
         with FileLock(self._file_lock):
-            self._base_dataset_split.save_to_disk(self._get_new_version_path(self._save_path))
-            self._malformed_dataset.save_to_disk(self._get_new_version_path(self._malformed_path))
-            self.last_update = time.time()
-            log.debug("Base dataset split saved.", path=self._save_path)
+            version_path, last_update = self._get_new_version_path(self._save_path)
+            self._base_dataset_split.save_to_disk(version_path)
+            self._malformed_dataset.save_to_disk(
+                self._get_new_version_path(self._malformed_path)[0]
+            )
+            self._base_dataset_split_last_update = last_update
+            log.debug("Base dataset split saved.", path=version_path)

     def get_dataset_split_with_class_names(
         self, table_key: Optional[PredictionTableKey] = None
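
`_save_base_dataset_split` now writes each save into a fresh `version_{timestamp}.arrow` directory instead of overwriting in place, and `_get_new_version_path` hands back the timestamp along with the path, so the recorded `last_update` is exactly the value baked into the filename (a single `time.time()` call, no drift between the two). A self-contained sketch of that contract, stdlib only:

    import os
    import time
    from typing import Tuple

    Time = float

    def get_new_version_path(directory: str) -> Tuple[str, Time]:
        # One clock read: the returned timestamp always matches the filename.
        now = time.time()
        return os.path.join(directory, f"version_{now}.arrow"), now

    version_path, last_update = get_new_version_path("/tmp/base_split")
    assert version_path.endswith(f"version_{last_update}.arrow")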
@@ -476,37 +493,38 @@ def _get_prediction_table(self, table_key: PredictionTableKey) -> Dataset:
             A table for this key; one will be created if it doesn't exist.
         """

-        path = self._prediction_path(table_key=table_key)
-        if table_key not in self._prediction_tables and os.path.exists(path):
-            with FileLock(self._file_lock):
-                self._prediction_tables[table_key] = self.load_cache(path)
-        elif table_key not in self._prediction_tables:
+        pred_path = self._prediction_path(table_key=table_key)
+        if not os.path.exists(pred_path):
             empty_ds = Dataset.from_dict({"pred_row_idx": list(range(self.num_rows))})
             self._prediction_tables[table_key] = self._init_dataset_split(
                 empty_ds, self._prediction_tags
             ).remove_columns([DatasetColumn.row_idx])
             self.save_prediction_table(table_key)
+        else:
+            current_last_update = self._prediction_tables_last_update[table_key]
+            newest_pred_ds, last_update = self.load_latest_cache(pred_path, current_last_update)
+            if newest_pred_ds:
+                self._prediction_tables[table_key] = newest_pred_ds
+                self._prediction_tables_last_update[table_key] = last_update
         return self._prediction_tables[table_key]

-    def _prediction_path(self, table_key: PredictionTableKey):
+    def _prediction_path(self, table_key: PredictionTableKey) -> str:
         """Path to table file."""
         folder = pjoin(self._hf_path, "prediction_tables")
         table_name = "_".join(
             f"{k}={v:.2f}" if type(v) is float else f"{k}={v}" for k, v in asdict(table_key).items()
         )
         os.makedirs(folder, exist_ok=True)
-        pt = pjoin(
-            folder,
-            f"{table_name}_cache_ds.arrow",
-        )
-        return pt
+        return pjoin(folder, f"{table_name}_cache_ds.arrow")

     def save_prediction_table(self, table_key: PredictionTableKey):
         """Save the prediction to disk."""
         with FileLock(self._file_lock):
-            pt = self._prediction_path(table_key=table_key)
-            self._prediction_tables[table_key].save_to_disk(self._get_new_version_path(pt))
-            self.last_update = int(time.time())
+            pred_path = self._prediction_path(table_key=table_key)
+            version_path, last_update = self._get_new_version_path(pred_path)
+            self._prediction_tables[table_key].save_to_disk(version_path)
+            self._prediction_tables_last_update[table_key] = last_update
+            log.debug("Prediction dataset split saved.", path=version_path)

     def add_column_to_prediction_table(
         self, key: str, features: List[Any], table_key: PredictionTableKey, **kwargs
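
`_prediction_path` derives a deterministic filename from the fields of the frozen key, formatting floats to two decimals so equivalent thresholds map to the same file. A short demonstration with a hypothetical key type (the real `PredictionTableKey` may carry different fields):

    from dataclasses import asdict, dataclass

    @dataclass(eq=True, frozen=True)
    class ExampleKey:  # hypothetical stand-in for PredictionTableKey
        pipeline_index: int
        threshold: float

    key = ExampleKey(pipeline_index=0, threshold=0.5)
    table_name = "_".join(
        f"{k}={v:.2f}" if type(v) is float else f"{k}={v}" for k, v in asdict(key).items()
    )
    assert table_name == "pipeline_index=0_threshold=0.50"
    # On disk: pipeline_index=0_threshold=0.50_cache_ds.arrow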
@@ -568,15 +586,20 @@ def _validate_columns(self):
         if len(self._base_dataset_split) == 0:
             raise AzimuthValidationError(f"No rows found from dataset {self.name}")

-    def load_cache(self, folder: str) -> Dataset:
-        """
-        Load the latest cache.
+    def load_latest_cache(
+        self, folder: str, current_last_update: Time
+    ) -> Tuple[Optional[Dataset], Time]:
+        """Load the latest dataset saved in a given folder.
+
+        Notes:
+            Only load when the dataset is more recent than what is currently saved in the DM.

         Args:
             folder: The folder in which to look.
+            current_last_update: Current version saved in the DM.

         Returns:
-            The cached dataset or the original.
+            The cached dataset if it is more recent; otherwise None.

         Raises:
             FileNotFoundError if no cache found.
@@ -591,13 +614,21 @@ def load_cache(self, folder: str) -> Dataset:
             ),
             None,
         )
-        if cache_file:
-            log.debug(f"Loading latest cache: {cache_file.split('/')[-1]}")
-            return Dataset.load_from_disk(cache_file)
-        raise FileNotFoundError(f"No previously saved dataset in {folder}")
+        if not cache_file:
+            raise FileNotFoundError(f"No previously saved dataset in {folder}")
+
+        last_update = float(cache_file.split("_")[-1][:-6])
+        if current_last_update >= last_update:
+            return None, -1
+
+        with FileLock(self._file_lock):
+            log.debug("Loading latest dataset in cache.", path=cache_file)
+            return Dataset.load_from_disk(cache_file), last_update

-    def _get_new_version_path(self, directory):
-        return pjoin(directory, f"version_{time.time()}.arrow")
+    @staticmethod
+    def _get_new_version_path(directory) -> Tuple[str, Time]:
+        now = time.time()
+        return pjoin(directory, f"version_{now}.arrow"), now

     def save_proposed_actions_to_csv(self) -> str:
         """Save proposed actions to a csv file.