ionerdss/nerdss_analysis/data_readers.py (31 changes: 27 additions & 4 deletions)
@@ -341,14 +341,17 @@ class DataIO:
     def __init__(self):
         """
         Initialize a cache for storing data to avoid repeated file reads.
+        Also initialize last-modified time storage for cache invalidation.
         """
         self._cache = {}
+        self._modified = {}
 
     def clear_cache(self):
         """
         Clear the cached data.
         """
         self._cache = {}
+        self._modified = {}
 
     def get_copy_numbers(self, sim_dir: str) -> Optional[pd.DataFrame]:
         """
@@ -361,12 +364,19 @@ def get_copy_numbers(self, sim_dir: str) -> Optional[pd.DataFrame]:
             Optional[pd.DataFrame]: DataFrame containing the data, or None if file not found
         """
         cache_key = (sim_dir, "copy_numbers")
+        data_file = os.path.join(sim_dir, "DATA", "copy_numbers_time.dat")
+        mod_time = os.path.getmtime(data_file) # get last modified timestamp
+
         if cache_key in self._cache:
-            return self._cache[cache_key]
+            if self._modified[cache_key] == mod_time:
+                return self._cache[cache_key]
+            else:
+                print("File modified, reloading from disk.")
 
         result = read_copy_numbers(sim_dir)
         if result is not None:
             self._cache[cache_key] = result
+            self._modified[cache_key] = mod_time
         return result
 
     def get_histogram_complexes(self, sim_dir: str) -> Dict[str, Any]:
@@ -380,11 +390,18 @@ def get_histogram_complexes(self, sim_dir: str) -> Dict[str, Any]:
             Dict[str, Any]: Dictionary containing time series and complex data
         """
         cache_key = (sim_dir, "histogram_complexes")
+        data_file = os.path.join(sim_dir, "DATA", "histogram_complexes_time.dat")
+        mod_time = os.path.getmtime(data_file) # get last modified timestamp
+
         if cache_key in self._cache:
-            return self._cache[cache_key]
+            if self._modified[cache_key] == mod_time:
+                return self._cache[cache_key]
+            else:
+                print("File modified, reloading from disk.")
 
         result = read_histogram_complexes(sim_dir)
         self._cache[cache_key] = result
+        self._modified[cache_key] = mod_time
         return result
 
     def get_transition_matrix(self, sim_dir: str, time_frame: Optional[Tuple[float, float]] = None) -> Tuple[Optional[np.ndarray], Optional[Dict[int, List[float]]]]:
@@ -405,13 +422,19 @@ def get_transition_matrix(self, sim_dir: str, time_frame: Optional[Tuple[float,
         if time_frame is not None:
             time_str = f"{time_frame[0]}-{time_frame[1]}"
             cache_key = (sim_dir, "transition_matrix", time_str)
+        data_file = os.path.join(sim_dir, "DATA", "transition_matrix_time.dat")
+        mod_time = os.path.getmtime(data_file) # get last modified timestamp
 
         if cache_key in self._cache:
-            return self._cache[cache_key]
+            if self._modified[cache_key] == mod_time:
+                return self._cache[cache_key]
+            else:
+                print("File modified, reloading from disk.")
 
         result = read_transition_matrix(sim_dir, time_frame)
         if result[0] is not None:
             self._cache[cache_key] = result
+            self._modified[cache_key] = mod_time
         return result
 
     def get_multiple_copy_numbers(self, sim_dirs: List[str]) -> List[Optional[pd.DataFrame]]:
@@ -450,4 +473,4 @@ def get_multiple_transition_matrices(self, sim_dirs: List[str], time_frame: Opti
             List[Tuple[Optional[np.ndarray], Optional[Dict[int, List[float]]]]]:
                 List of tuples containing transition matrix and lifetime data
         """
-        return [self.get_transition_matrix(sim_dir, time_frame) for sim_dir in sim_dirs]
\ No newline at end of file
+        return [self.get_transition_matrix(sim_dir, time_frame) for sim_dir in sim_dirs]
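
For orientation, a minimal usage sketch of the mtime-based cache invalidation introduced in this diff. It assumes the module is importable as ionerdss.nerdss_analysis.data_readers (inferred from the file path) and uses a hypothetical simulation directory; it is an illustration, not part of the PR.

from ionerdss.nerdss_analysis.data_readers import DataIO  # import path assumed from the file path

io = DataIO()
sim_dir = "/path/to/simulation"  # hypothetical directory containing DATA/copy_numbers_time.dat

df_first = io.get_copy_numbers(sim_dir)   # cache miss: reads from disk, stores result and file mtime
df_second = io.get_copy_numbers(sim_dir)  # mtime unchanged: returned straight from the cache

# If DATA/copy_numbers_time.dat is rewritten, its mtime changes; the next call
# prints "File modified, reloading from disk." and refreshes the cached entry.
df_third = io.get_copy_numbers(sim_dir)

io.clear_cache()  # resets both the result cache and the stored mtimes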