diff --git a/CHANGELOG.md b/CHANGELOG.md index 0b86cb8520..0593f63d03 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -10,6 +10,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ### Added - Added support for `tidy3d-extras`, an optional plugin that enables more accurate local mode solving via subpixel averaging. +- Added configurable local simulation result caching with checksum validation, eviction limits, and per-call overrides across `web.run`, `web.load`, and job workflows. ### Changed - Improved performance of antenna metrics calculation by utilizing cached wave amplitude calculations instead of recomputing wave amplitudes for each port excitation in the `TerminalComponentModelerData`. diff --git a/docs/index.rst b/docs/index.rst index 5cada50d0e..d53ffef1af 100644 --- a/docs/index.rst +++ b/docs/index.rst @@ -168,6 +168,13 @@ This will produce the following plot, which visualizes the electromagnetic field You can now postprocess simulation data using the same python session, or view the results of this simulation on our web-based `graphical user interface (GUI) `_. +.. tip:: + + Repeated runs of the same simulation can reuse solver results by enabling the optional + local cache: ``td.config.simulation_cache.enabled = True``. The cache location and limits are + configurable (see ``~/.tidy3d/config``), entries are checksum-validated, and you can clear + all stored artifacts with ``tidy3d.web.cache.clear()``. + .. 
`TODO: open example in colab `_ @@ -262,4 +269,3 @@ Contents - diff --git a/tests/test_components/autograd/test_autograd.py b/tests/test_components/autograd/test_autograd.py index cfd55079b0..95204a601f 100644 --- a/tests/test_components/autograd/test_autograd.py +++ b/tests/test_components/autograd/test_autograd.py @@ -662,7 +662,7 @@ def plot_sim(sim: td.Simulation, plot_eps: bool = True) -> None: # args = [("polyslab", "mode")] -def get_functions(structure_key: str, monitor_key: str) -> typing.Callable: +def get_functions(structure_key: str, monitor_key: str) -> dict[str, typing.Callable]: if structure_key == ALL_KEY: structure_keys = structure_keys_ else: diff --git a/tests/test_web/test_simulation_cache.py b/tests/test_web/test_simulation_cache.py new file mode 100644 index 0000000000..161e042afb --- /dev/null +++ b/tests/test_web/test_simulation_cache.py @@ -0,0 +1,350 @@ +from __future__ import annotations + +from pathlib import Path + +import pytest + +import tidy3d as td +from tests.test_components.autograd.test_autograd import ALL_KEY, get_functions, params0 +from tidy3d import config +from tidy3d.web import Job, common, run_async +from tidy3d.web.api import webapi as web +from tidy3d.web.api.container import WebContainer +from tidy3d.web.cache import ( + CACHE_ARTIFACT_NAME, + get_cache, + resolve_simulation_cache, +) + +common.CONNECTION_RETRY_TIME = 0.1 + +MOCK_TASK_ID = "task-xyz" +# --- Fake pipeline global maps / queue --- +TASK_TO_SIM: dict[str, td.Simulation] = {} # task_id -> Simulation +PATH_TO_SIM: dict[str, td.Simulation] = {} # artifact path -> Simulation + + +def _reset_fake_maps(): + TASK_TO_SIM.clear() + PATH_TO_SIM.clear() + + +class _FakeStubData: + def __init__(self, simulation: td.Simulation): + self.simulation = simulation + + +@pytest.fixture +def basic_simulation(): + pulse = td.GaussianPulse(freq0=200e12, fwidth=20e12) + pt_dipole = td.PointDipole(source_time=pulse, polarization="Ex") + return td.Simulation( + size=(1, 1, 1), + 
grid_spec=td.GridSpec.auto(wavelength=1.0), + run_time=1e-12, + sources=[pt_dipole], + ) + + +@pytest.fixture(autouse=True) +def fake_data(monkeypatch, basic_simulation): + """Patch postprocess to return stub data bound to the correct simulation.""" + calls = {"postprocess": 0} + + def _fake_postprocess(path: str, lazy: bool = False): + calls["postprocess"] += 1 + p = Path(path) + sim = PATH_TO_SIM.get(str(p)) + if sim is None: + # Try to recover task_id from file payload written by _fake_download + try: + txt = p.read_text() + if "payload:" in txt: + task_id = txt.split("payload:", 1)[1].strip() + sim = TASK_TO_SIM.get(task_id) + except Exception: + pass + if sim is None: + # Last-resort fallback (keeps tests from crashing even if mapping failed) + sim = basic_simulation + return _FakeStubData(sim) + + monkeypatch.setattr(web.Tidy3dStubData, "postprocess", staticmethod(_fake_postprocess)) + return calls + + +def _patch_run_pipeline(monkeypatch): + """Patch upload, start, monitor, and download to avoid network calls and map sims.""" + counters = {"upload": 0, "start": 0, "monitor": 0, "download": 0} + _reset_fake_maps() # isolate between tests + + def _extract_simulation(kwargs): + """Extract the first td.Simulation object from upload kwargs.""" + if "simulation" in kwargs and isinstance(kwargs["simulation"], td.Simulation): + return kwargs["simulation"] + if "simulations" in kwargs: + sims = kwargs["simulations"] + if isinstance(sims, dict): + for sim in sims.values(): + if isinstance(sim, td.Simulation): + return sim + elif isinstance(sims, (list, tuple)): + for sim in sims: + if isinstance(sim, td.Simulation): + return sim + return None + + def _fake_upload(**kwargs): + counters["upload"] += 1 + task_id = f"{MOCK_TASK_ID}{kwargs['simulation']._hash_self()}" + sim = _extract_simulation(kwargs) + if sim is not None: + TASK_TO_SIM[task_id] = sim + return task_id + + def _fake_start(task_id, **kwargs): + counters["start"] += 1 + + def _fake_monitor(task_id, 
verbose=True): + counters["monitor"] += 1 + + def _fake_download(*, task_id, path, **kwargs): + counters["download"] += 1 + # Ensure we have a simulation for this task id (even if upload wasn't called) + sim = TASK_TO_SIM.get(task_id) + Path(path).write_text(f"payload:{task_id}") + if sim is not None: + PATH_TO_SIM[str(Path(path))] = sim + + def _fake__check_folder(*args, **kwargs): + pass + + def _fake_status(self): + return "success" + + monkeypatch.setattr(WebContainer, "_check_folder", _fake__check_folder) + monkeypatch.setattr(web, "upload", _fake_upload) + monkeypatch.setattr(web, "start", _fake_start) + monkeypatch.setattr(web, "monitor", _fake_monitor) + monkeypatch.setattr(web, "download", _fake_download) + monkeypatch.setattr(web, "estimate_cost", lambda *args, **kwargs: 0.0) + monkeypatch.setattr(Job, "status", property(_fake_status)) + monkeypatch.setattr( + web, + "get_info", + lambda task_id, verbose=True: type( + "_Info", (), {"solverVersion": "solver-1", "taskType": "FDTD"} + )(), + ) + return counters + + +def _reset_counters(counters: dict[str, int]) -> None: + for key in counters: + counters[key] = 0 + + +def _test_run_cache_hit(monkeypatch, tmp_path, basic_simulation, fake_data): + counters = _patch_run_pipeline(monkeypatch) + out_path = tmp_path / "result.hdf5" + get_cache().clear() + + data = web.run(basic_simulation, task_name="demo", path=str(out_path), use_cache=True) + assert isinstance(data, _FakeStubData) + assert counters == {"upload": 1, "start": 1, "monitor": 1, "download": 1} + + _reset_counters(counters) + data2 = web.run(basic_simulation, task_name="demo", path=str(out_path), use_cache=True) + assert isinstance(data2, _FakeStubData) + assert counters == {"upload": 0, "start": 0, "monitor": 0, "download": 0} + + +def _test_run_cache_hit_async(monkeypatch, basic_simulation, tmp_path): + counters = _patch_run_pipeline(monkeypatch) + monkeypatch.setattr(config.simulation_cache, "max_entries", 128) + 
monkeypatch.setattr(config.simulation_cache, "max_size_gb", 10) + cache = resolve_simulation_cache(use_cache=True) + cache.clear() + _reset_fake_maps() + + _reset_counters(counters) + sim2 = basic_simulation.updated_copy(shutoff=1e-4) + sim3 = basic_simulation.updated_copy(shutoff=1e-3) + + data = run_async( + {"task1": basic_simulation, "task2": sim2}, use_cache=True, path_dir=str(tmp_path) + ) + data_task1 = data["task1"] # access to store in cache + data_task2 = data["task2"] # access to store in cache + assert counters["download"] == 2 + assert isinstance(data_task1, _FakeStubData) + assert isinstance(data_task2, _FakeStubData) + assert len(cache) == 2 + + _reset_counters(counters) + run_async({"task1": basic_simulation, "task2": sim2}, use_cache=True, path_dir=str(tmp_path)) + assert counters["download"] == 0 + assert isinstance(data_task1, _FakeStubData) + assert len(cache) == 2 + + _reset_counters(counters) + data = run_async( + {"task1": basic_simulation, "task3": sim3}, use_cache=True, path_dir=str(tmp_path) + ) + + data_task1 = data["task1"] + data_task2 = data["task3"] # access to store in cache + print(counters["download"]) + assert counters["download"] == 1 # sim3 is new + assert isinstance(data_task1, _FakeStubData) + assert isinstance(data_task2, _FakeStubData) + assert len(cache) == 3 + + +def _test_job_run_cache(monkeypatch, basic_simulation): + counters = _patch_run_pipeline(monkeypatch) + cache = resolve_simulation_cache(use_cache=True) + cache.clear() + job = Job(simulation=basic_simulation, use_cache=True, task_name="test") + job.run() + + assert len(cache) == 1 + + _reset_counters(counters) + + job2 = Job(simulation=basic_simulation, use_cache=True, task_name="test") + job2.run() + assert len(cache) == 1 + assert counters["download"] == 0 + + +def _test_autograd_cache(monkeypatch): + counters = _patch_run_pipeline(monkeypatch) + cache = resolve_simulation_cache(use_cache=True) + cache.clear() + + functions = get_functions(ALL_KEY, "mode") + 
make_sim = functions["sim"] + sim = make_sim(params0) + web.run(sim, use_cache=True) + assert counters["download"] == 1 + assert len(cache) == 1 + + _reset_counters(counters) + sim = make_sim(params0) + web.run(sim, use_cache=True) + assert counters["download"] == 0 + assert len(cache) == 1 + + +def _test_load_cache_hit(monkeypatch, tmp_path, basic_simulation, fake_data): + get_cache().clear() + counters = _patch_run_pipeline(monkeypatch) + out_path = tmp_path / "load.hdf5" + + cache = get_cache() + + web.run(basic_simulation, task_name="demo", path=str(out_path), use_cache=True) + assert counters["download"] == 1 + assert len(cache) == 1 + + _reset_counters(counters) + data = web.load(None, path=str(out_path), from_cache=True) + assert isinstance(data, _FakeStubData) + assert counters["download"] == 0 # served from cache + assert len(cache) == 1 # still 1 item in cache + + +def _test_checksum_mismatch_triggers_refresh(monkeypatch, tmp_path, basic_simulation): + out_path = tmp_path / "checksum.hdf5" + get_cache().clear() + + web.run(basic_simulation, task_name="demo", path=str(out_path), use_cache=True) + + cache = get_cache() + metadata = cache.list()[0] + corrupted_path = cache.root / metadata["cache_key"] / CACHE_ARTIFACT_NAME + corrupted_path.write_text("corrupted") + + cache._fetch(metadata["cache_key"]) + assert len(cache) == 0 + + +def _test_cache_eviction_by_entries(monkeypatch, tmp_path_factory, basic_simulation): + monkeypatch.setattr(config.simulation_cache, "max_entries", 1) + cache = resolve_simulation_cache(use_cache=True) + cache.clear() + + file1 = tmp_path_factory.mktemp("art1") / CACHE_ARTIFACT_NAME + file1.write_text("a") + cache.store_result(_FakeStubData(basic_simulation), MOCK_TASK_ID, str(file1), "FDTD") + assert len(cache) == 1 + + sim2 = basic_simulation.updated_copy(shutoff=1e-4) + file2 = tmp_path_factory.mktemp("art2") / CACHE_ARTIFACT_NAME + file2.write_text("b") + cache.store_result(_FakeStubData(sim2), MOCK_TASK_ID, str(file2), 
"FDTD") + + entries = cache.list() + assert len(entries) == 1 + assert entries[0]["simulation_hash"] == sim2._hash_self() + + +def _test_cache_eviction_by_size(monkeypatch, tmp_path_factory, basic_simulation): + monkeypatch.setattr(config.simulation_cache, "max_size_gb", float(10_000 * 1e-9)) + cache = resolve_simulation_cache(use_cache=True) + cache.clear() + + file1 = tmp_path_factory.mktemp("art1") / CACHE_ARTIFACT_NAME + file1.write_text("a" * 8_000) + cache.store_result(_FakeStubData(basic_simulation), MOCK_TASK_ID, str(file1), "FDTD") + assert len(cache) == 1 + + sim2 = basic_simulation.updated_copy(shutoff=1e-4) + file2 = tmp_path_factory.mktemp("art2") / CACHE_ARTIFACT_NAME + file2.write_text("b" * 8_000) + cache.store_result(_FakeStubData(sim2), MOCK_TASK_ID, str(file2), "FDTD") + + entries = cache.list() + assert len(cache) == 1 + assert entries[0]["simulation_hash"] == sim2._hash_self() + + +def test_configure_cache_roundtrip(monkeypatch, tmp_path): + monkeypatch.setattr(config.simulation_cache, "enabled", True) + monkeypatch.setattr(config.simulation_cache, "directory", tmp_path) + monkeypatch.setattr(config.simulation_cache, "max_size_gb", 1.23) + monkeypatch.setattr(config.simulation_cache, "max_entries", 5) + + cfg = resolve_simulation_cache().config + assert cfg.enabled is True + assert cfg.directory == tmp_path + assert cfg.max_size_gb == 1.23 + assert cfg.max_entries == 5 + + +def test_env_var_overrides(monkeypatch, tmp_path): + monkeypatch.setenv("TIDY3D_CACHE_ENABLED", "true") + monkeypatch.setenv("TIDY3D_CACHE_DIR", str(tmp_path)) + monkeypatch.setenv("TIDY3D_CACHE_MAX_SIZE_GB", "0.5") + + monkeypatch.setattr(config.simulation_cache, "max_entries", 5) + monkeypatch.setenv("TIDY3D_CACHE_MAX_ENTRIES", "7") + + cfg = resolve_simulation_cache().config + assert cfg.enabled is True + assert cfg.directory == tmp_path + assert cfg.max_size_gb == 0.5 + assert cfg.max_entries == 7 + + +def test_cache_end_to_end(monkeypatch, tmp_path, tmp_path_factory, 
basic_simulation, fake_data): + """Run all critical cache tests in sequence to ensure end-to-end stability.""" + _test_run_cache_hit(monkeypatch, tmp_path, basic_simulation, fake_data) + _test_load_cache_hit(monkeypatch, tmp_path, basic_simulation, fake_data) + _test_checksum_mismatch_triggers_refresh(monkeypatch, tmp_path, basic_simulation) + _test_cache_eviction_by_entries(monkeypatch, tmp_path_factory, basic_simulation) + _test_cache_eviction_by_size(monkeypatch, tmp_path_factory, basic_simulation) + _test_run_cache_hit_async(monkeypatch, basic_simulation, tmp_path) + _test_job_run_cache(monkeypatch, basic_simulation) + _test_autograd_cache(monkeypatch) diff --git a/tidy3d/config.py b/tidy3d/config.py index 9dfc7b702c..7fa0c6c169 100644 --- a/tidy3d/config.py +++ b/tidy3d/config.py @@ -2,12 +2,42 @@ from __future__ import annotations +from pathlib import Path from typing import Optional import pydantic.v1 as pd from .log import DEFAULT_LEVEL, LogLevel, set_log_suppression, set_logging_level +_DEFAULT_CACHE_DIR = Path.home() / ".tidy3d" / "cache" / "simulations" + + +class SimulationCacheSettings(pd.BaseModel): + """Settings controlling the optional local simulation cache.""" + + enabled: bool = pd.Field( + False, + description="Enable or disable the local simulation cache.", + ) + directory: Path = pd.Field( + _DEFAULT_CACHE_DIR, + description="Directory where cached simulation artifacts are stored.", + ) + max_size_gb: float = pd.Field( + 10.0, + description="Maximum cache size in gigabytes. Set to 0 for no size limit.", + ge=0.0, + ) + max_entries: int = pd.Field( + 128, + description="Maximum number of cache entries. 
Set to 0 for no limit.", + ge=0, + ) + + @pd.validator("directory", pre=True, always=True) + def _validate_directory(cls, value): + return Path(value).expanduser() + class Tidy3dConfig(pd.BaseModel): """configuration of tidy3d""" @@ -43,6 +73,12 @@ class Config: "averaging will be used if 'tidy3d-extras' is installed and not used otherwise.", ) + simulation_cache: SimulationCacheSettings = pd.Field( + default_factory=SimulationCacheSettings, + title="Simulation Cache", + description="Configuration for the optional local simulation cache.", + ) + @pd.validator("logging_level", pre=True, always=True) def _set_logging_level(cls, val): """Set the logging level if logging_level is changed.""" diff --git a/tidy3d/web/api/asynchronous.py b/tidy3d/web/api/asynchronous.py index da628261f3..34c03569c1 100644 --- a/tidy3d/web/api/asynchronous.py +++ b/tidy3d/web/api/asynchronous.py @@ -24,6 +24,7 @@ def run_async( reduce_simulation: Literal["auto", True, False] = "auto", pay_type: Union[PayType, str] = PayType.AUTO, priority: Optional[int] = None, + use_cache: Optional[bool] = None, ) -> BatchData: """Submits a set of Union[:class:`.Simulation`, :class:`.HeatSimulation`, :class:`.EMESimulation`] objects to server, starts running, monitors progress, downloads, and loads results as a :class:`.BatchData` object. @@ -56,6 +57,10 @@ def run_async( priority: int = None Priority of the simulation in the Virtual GPU (vGPU) queue (1 = lowest, 10 = highest). It affects only simulations from vGPU licenses and does not impact simulations using FlexCredits. + use_cache: Optional[bool] = None + Whether to use local cache if identical simulation is rerun. If not provided, cache settings from config or + environment variables will be used. 
+ Returns ------ :class:`BatchData` @@ -91,6 +96,7 @@ def run_async( parent_tasks=parent_tasks, reduce_simulation=reduce_simulation, pay_type=pay_type, + use_cache=use_cache, ) batch_data = batch.run(path_dir=path_dir, priority=priority) diff --git a/tidy3d/web/api/autograd/autograd.py b/tidy3d/web/api/autograd/autograd.py index 7958dc57c9..5bd4e611e5 100644 --- a/tidy3d/web/api/autograd/autograd.py +++ b/tidy3d/web/api/autograd/autograd.py @@ -117,6 +117,7 @@ def run( reduce_simulation: typing.Literal["auto", True, False] = "auto", pay_type: typing.Union[PayType, str] = PayType.AUTO, priority: typing.Optional[int] = None, + use_cache: typing.Optional[bool] = None, ) -> WorkflowDataType: """ Submits a :class:`.Simulation` to server, starts running, monitors progress, downloads, @@ -158,6 +159,9 @@ def run( Which method to pay for the simulation. priority: int = None Task priority for vGPU queue (1=lowest, 10=highest). + use_cache: Optional[bool] = None + Whether to use local cache if identical simulation is rerun. If not provided, cache settings from config or + environment variables will be used. 
Returns ------- Union[:class:`.SimulationData`, :class:`.HeatSimulationData`, :class:`.EMESimulationData`, :class:`.ModalComponentModelerData`, :class:`.TerminalComponentModelerData`] @@ -248,6 +252,7 @@ def run( max_num_adjoint_per_fwd=max_num_adjoint_per_fwd, pay_type=pay_type, priority=priority, + use_cache=use_cache, ) return run_webapi( @@ -266,6 +271,7 @@ def run( reduce_simulation=reduce_simulation, pay_type=pay_type, priority=priority, + use_cache=use_cache, ) @@ -284,6 +290,7 @@ def run_async( reduce_simulation: typing.Literal["auto", True, False] = "auto", pay_type: typing.Union[PayType, str] = PayType.AUTO, priority: typing.Optional[int] = None, + use_cache: typing.Optional[bool] = None, ) -> BatchData: """Submits a set of Union[:class:`.Simulation`, :class:`.HeatSimulation`, :class:`.EMESimulation`] objects to server, starts running, monitors progress, downloads, and loads results as a :class:`.BatchData` object. @@ -318,6 +325,9 @@ def run_async( Whether to reduce structures in the simulation to the simulation domain only. Note: currently only implemented for the mode solver. pay_type: typing.Union[PayType, str] = PayType.AUTO Specify the payment method. + use_cache: Optional[bool] = None + Whether to use local cache if identical simulation is rerun. If not provided, cache settings from config or + environment variables will be used. 
Returns ------ @@ -360,6 +370,7 @@ def run_async( max_num_adjoint_per_fwd=max_num_adjoint_per_fwd, pay_type=pay_type, priority=priority, + use_cache=use_cache, ) return run_async_webapi( @@ -375,6 +386,7 @@ def run_async( reduce_simulation=reduce_simulation, pay_type=pay_type, priority=priority, + use_cache=use_cache, ) diff --git a/tidy3d/web/api/autograd/engine.py b/tidy3d/web/api/autograd/engine.py index 2cd0abe451..bf7713f534 100644 --- a/tidy3d/web/api/autograd/engine.py +++ b/tidy3d/web/api/autograd/engine.py @@ -10,7 +10,7 @@ def parse_run_kwargs(**run_kwargs): """Parse the ``run_kwargs`` to extract what should be passed to the ``Job``/``Batch`` init.""" - job_fields = [*list(Job._upload_fields), "solver_version", "pay_type"] + job_fields = [*list(Job._upload_fields), "solver_version", "pay_type", "use_cache"] job_init_kwargs = {k: v for k, v in run_kwargs.items() if k in job_fields} return job_init_kwargs diff --git a/tidy3d/web/api/container.py b/tidy3d/web/api/container.py index dc6a96b5a3..d8ecf6eb12 100644 --- a/tidy3d/web/api/container.py +++ b/tidy3d/web/api/container.py @@ -4,13 +4,16 @@ import concurrent import os +import shutil import time +import uuid from abc import ABC from collections.abc import Mapping from concurrent.futures import ThreadPoolExecutor from typing import Literal, Optional, Union import pydantic.v1 as pd +from pydantic.v1 import PrivateAttr from rich.progress import BarColumn, Progress, TaskProgressColumn, TextColumn, TimeElapsedColumn from tidy3d.components.base import Tidy3dBaseModel, cached_property @@ -21,6 +24,10 @@ from tidy3d.log import get_logging_console, log from tidy3d.web.api import webapi as web from tidy3d.web.api.tidy3d_stub import Tidy3dStub +from tidy3d.web.api.webapi import ( + restore_simulation_if_cached, +) +from tidy3d.web.cache import TMP_BATCH_PREFIX, resolve_simulation_cache from tidy3d.web.core.constants import TaskId, TaskName from tidy3d.web.core.task_core import Folder from tidy3d.web.core.task_info 
import RunInfo, TaskInfo @@ -224,6 +231,14 @@ class Job(WebContainer): "reduce_simulation", ) + _cache_file_moved: bool = PrivateAttr(default=False) + + use_cache: Optional[bool] = pd.Field( + None, + title="Use Cache", + description="Whether to use local cache for retrieving Simulation results.", + ) + def to_file(self, fname: str) -> None: """Exports :class:`Tidy3dBaseModel` instance to .yaml, .json, or .hdf5 file @@ -241,7 +256,9 @@ def to_file(self, fname: str) -> None: super(Job, self).to_file(fname=fname) # noqa: UP008 def run( - self, path: str = DEFAULT_DATA_PATH, priority: Optional[int] = None + self, + path: str = DEFAULT_DATA_PATH, + priority: Optional[int] = None, ) -> WorkflowDataType: """Run :class:`Job` all the way through and return data. @@ -257,17 +274,58 @@ def run( :class:`WorkflowDataType` Object containing simulation results. """ - self.upload() - if priority is None: - self.start() - else: - self.start(priority=priority) - self.monitor() - return self.load(path=path) + self._check_path_dir(path=path) + + loaded_from_cache = self.load_if_cached + if not loaded_from_cache: + self.upload() + if priority is None: + self.start() + else: + self.start(priority=priority) + self.monitor() + data = self.load(path=path) + + return data + + @cached_property + def data_cache_path(self) -> Optional[str]: + "Temporary path where cached results are stored." + cache = resolve_simulation_cache(self.use_cache) + if cache is not None: + path = os.path.join(cache._root, TMP_BATCH_PREFIX, f"{self.task_id_if_cached}.hdf5") + return path + return None + + @cached_property + def load_if_cached(self) -> bool: + """Checks if data is already cached. + + Returns + ------- + bool + Whether item was found in cache. 
+ """ + path = self.data_cache_path + if path is None: + return False + self._check_path_dir(path=path) + return restore_simulation_if_cached( + simulation=self.simulation, + path=path, + use_cache=self.use_cache, + reduce_simulation=self.reduce_simulation, + ) + + @property + def task_id_if_cached(self) -> str: + return "cached_" + self.task_name + "_" + str(uuid.uuid4()) @cached_property def task_id(self) -> TaskId: """The task ID for this ``Job``. Uploads the ``Job`` if it hasn't already been uploaded.""" + if self.load_if_cached: + return self.task_id_if_cached if self.task_id_cached: return self.task_id_cached self._check_folder(self.folder_name) @@ -281,7 +339,9 @@ def _upload(self) -> TaskId: return task_id def upload(self) -> None: - """Upload this ``Job``.""" + """Upload this ``Job`` if not already got cached results.""" + if self.load_if_cached: + return _ = self.task_id def get_info(self) -> TaskInfo: @@ -298,6 +358,8 @@ def get_info(self) -> TaskInfo: @property def status(self): """Return current status of :class:`Job`.""" + if self.load_if_cached: + return "success" return self.get_info().status def start(self, priority: Optional[int] = None) -> None: @@ -312,13 +374,16 @@ def start(self, priority: Optional[int] = None) -> None: Note ---- To monitor progress of the :class:`Job`, call :meth:`Job.monitor` after started. + Function has no effect if cache is enabled and data was found in cache. """ - web.start( - self.task_id, - solver_version=self.solver_version, - pay_type=self.pay_type, - priority=priority, - ) + loaded = self.load_if_cached + if not loaded: + web.start( + self.task_id, + solver_version=self.solver_version, + pay_type=self.pay_type, + priority=priority, + ) def get_run_info(self) -> RunInfo: """Return information about the running :class:`Job`. @@ -338,6 +403,8 @@ def monitor(self) -> None: To load the output of completed simulation into :class:`.SimulationData` objects, call :meth:`Job.load`. 
""" + if self.load_if_cached: + return web.monitor(self.task_id, verbose=self.verbose) def download(self, path: str = DEFAULT_DATA_PATH) -> None: @@ -352,9 +419,21 @@ def download(self, path: str = DEFAULT_DATA_PATH) -> None: ---- To load the data after download, use :meth:`Job.load`. """ + if self.load_if_cached: + self.move_cache_file(path=path) + return self._check_path_dir(path=path) web.download(task_id=self.task_id, path=path, verbose=self.verbose) + def move_cache_file(self, path: str) -> None: + if self._cache_file_moved: + return + if os.path.exists(self.data_cache_path): + shutil.move(self.data_cache_path, path) + self._cache_file_moved = True + else: + raise FileNotFoundError(f"Cached file does not longer exist in {self.data_cache_path}.") + def load(self, path: str = DEFAULT_DATA_PATH) -> WorkflowDataType: """Download job results and load them into a data object. @@ -368,8 +447,17 @@ def load(self, path: str = DEFAULT_DATA_PATH) -> WorkflowDataType: Union[:class:`.SimulationData`, :class:`.HeatSimulationData`, :class:`.EMESimulationData`] Object containing simulation results. """ + if self.load_if_cached: + self.move_cache_file(path=path) + self._check_path_dir(path=path) - data = web.load(task_id=self.task_id, path=path, verbose=self.verbose) + data = web.load( + task_id=self.task_id, + path=path, + verbose=self.verbose, + use_cache=self.use_cache, + from_cache=self.load_if_cached, + ) if isinstance(self.simulation, ModeSolver): self.simulation._patch_data(data=data) return data @@ -411,6 +499,8 @@ def estimate_cost(self, verbose: bool = True) -> float: Cost is calculated assuming the simulation runs for the full ``run_time``. If early shut-off is triggered, the cost is adjusted proportionately. 
""" + if self.load_if_cached: + return 0.0 return web.estimate_cost(self.task_id, verbose=verbose, solver_version=self.solver_version) @staticmethod @@ -477,13 +567,41 @@ class BatchData(Tidy3dBaseModel, Mapping): True, title="Verbose", description="Whether to print info messages and progressbars." ) + cached_tasks: Optional[dict[TaskName, bool]] = pd.Field( + None, + title="Cached Tasks", + description="Whether the data of a task came from the cache.", + ) + + use_cache: Optional[bool] = pd.Field( + None, + title="Use Cache", + description="Whether to use local cache for retrieving Simulation results.", + ) + + is_downloaded: Optional[bool] = pd.Field( + False, + title="Is Downloaded", + description="Whether the simulation data was downloaded before.", + ) + + def load_sim_data(self, task_name: str) -> WorkflowDataType: """Load a simulation data object from file by task name.""" task_data_path = self.task_paths[task_name] task_id = self.task_ids[task_name] - web.get_info(task_id) - - return web.load(task_id=task_id, path=task_data_path, verbose=False) + from_cache = self.cached_tasks[task_name] if self.cached_tasks else False + if not from_cache: + web.get_info(task_id) + + return web.load( + task_id=task_id, + path=task_data_path, + verbose=False, + from_cache=from_cache, + use_cache=self.use_cache, + replace_existing=not (from_cache or self.is_downloaded), + ) def __getitem__(self, task_name: TaskName) -> WorkflowDataType: """Get the simulation data object for a given ``task_name``.""" @@ -623,6 +741,12 @@ class Batch(WebContainer): "fields that were not used to create the task will cause errors.", ) + use_cache: Optional[bool] = pd.Field( + None, + title="Use Cache", + description="Whether to use local cache for retrieving Simulation results.", + ) + _job_type = Job def run( @@ -659,14 +783,16 @@ def run( rather it iterates over the task names and loads the corresponding data from file one by one. If no file exists for that task, it downloads it. 
""" - self._check_path_dir(path_dir) - self.upload() - self.to_file(self._batch_path(path_dir=path_dir)) - if priority is None: - self.start() - else: - self.start(priority=priority) - self.monitor() + loaded = [job.load_if_cached for job in self.jobs.values()] + if not all(loaded): + self._check_path_dir(path_dir) + self.upload() + self.to_file(self._batch_path(path_dir=path_dir)) + if priority is None: + self.start() + else: + self.start(priority=priority) + self.monitor() return self.load(path_dir=path_dir) @cached_property @@ -708,6 +834,7 @@ def jobs(self) -> dict[TaskName, Job]: job_kwargs["solver_version"] = self.solver_version job_kwargs["pay_type"] = self.pay_type job_kwargs["reduce_simulation"] = self.reduce_simulation + job_kwargs["use_cache"] = self.use_cache if self.parent_tasks and task_name in self.parent_tasks: job_kwargs["parent_tasks"] = self.parent_tasks[task_name] job = JobType(**job_kwargs) @@ -1087,15 +1214,24 @@ def load(self, path_dir: str = DEFAULT_DATA_DIR, replace_existing: bool = False) task_paths[task_name] = self._job_data_path(task_id=job.task_id, path_dir=path_dir) task_ids[task_name] = self.jobs[task_name].task_id - data = BatchData(task_paths=task_paths, task_ids=task_ids, verbose=self.verbose) + loaded = {task_name: job.load_if_cached for task_name, job in self.jobs.items()} + + self.download(path_dir=path_dir, replace_existing=replace_existing) + + data = BatchData( + task_paths=task_paths, + task_ids=task_ids, + verbose=self.verbose, + cached_tasks=loaded, + use_cache=self.use_cache, + is_downloaded=True, + ) for task_name, job in self.jobs.items(): if isinstance(job.simulation, ModeSolver): job_data = data[task_name] job.simulation._patch_data(data=job_data) - self.download(path_dir=path_dir, replace_existing=replace_existing) - return data def delete(self) -> None: diff --git a/tidy3d/web/api/webapi.py b/tidy3d/web/api/webapi.py index e165b83b05..2fd4f17856 100644 --- a/tidy3d/web/api/webapi.py +++ b/tidy3d/web/api/webapi.py 
# NOTE(review): reconstructed from a patch; only the fully-visible definitions are
# rewritten here. The `upload` URL-logging refactor hunk is unchanged in spirit
# (it delegates to `_get_task_urls` below).
from pathlib import Path

from tidy3d.web.cache import CacheEntry, resolve_simulation_cache


def _copy_simulation_data_from_cache_entry(entry: CacheEntry, path: str) -> bool:
    """Copy a cache entry's artifact into ``path``.

    Returns ``True`` on success and ``False`` when ``entry`` is ``None`` or the
    copy fails for any reason — a cache failure must never break a run.
    """
    if entry is None:
        return False
    try:
        entry.materialize(Path(path))
    except Exception:
        # Best-effort: a corrupt or vanished cache artifact falls back to a normal run.
        return False
    return True


def restore_simulation_if_cached(
    simulation: WorkflowType,
    path: str,
    use_cache: Optional[bool] = None,
    reduce_simulation: Literal["auto", True, False] = "auto",
    verbose: bool = True,
) -> bool:
    """Try to restore a previously cached solver result for ``simulation`` into ``path``.

    Parameters
    ----------
    simulation : WorkflowType
        Workflow object whose hash identifies the cache entry.
    path : str
        Destination file for the cached artifact.
    use_cache : Optional[bool] = None
        Per-call cache override; ``None`` defers to config / environment settings.
    reduce_simulation : Literal["auto", True, False] = "auto"
        Mode-solver reduction setting; the reduced form is what gets hashed/cached.
    verbose : bool = True
        Whether to log a cache-hit message.

    Returns
    -------
    bool
        ``True`` if the cached artifact was actually copied into ``path``.
    """
    simulation_cache = resolve_simulation_cache(use_cache)
    if simulation_cache is None:
        return False

    sim_for_cache = simulation
    if isinstance(simulation, (ModeSolver, ModeSimulation)):
        # Mode solvers are cached by their reduced form, matching what would be uploaded.
        sim_for_cache = get_reduced_simulation(simulation, reduce_simulation)

    entry = simulation_cache.try_fetch(simulation=sim_for_cache, verbose=verbose)
    if entry is None:
        return False

    copied_from_cache = _copy_simulation_data_from_cache_entry(entry, path)
    # Only announce a hit when the artifact actually landed in ``path`` (the
    # original logged unconditionally, even after a failed copy).
    if copied_from_cache and verbose:
        cached_task_id = entry.metadata.get("task_id")
        cached_workflow_type = entry.metadata.get("workflow_type")
        if cached_task_id is not None and cached_workflow_type is not None:
            console = get_logging_console()
            url, _ = _get_task_urls(cached_workflow_type, simulation, cached_task_id)
            console.log(
                f"Loaded simulation from local cache.\nView cached task using web UI at "
                f"[link={url}]'{url}'[/link]."
            )
    return copied_from_cache


def load_simulation_if_cached(
    simulation: WorkflowType,
    path: str,
    use_cache: Optional[bool] = None,
    reduce_simulation: Literal["auto", True, False] = "auto",
) -> Optional[WorkflowDataType]:
    """Return cached simulation data for ``simulation`` if available, else ``None``.

    On a hit the artifact is copied to ``path`` and loaded; mode solvers are
    additionally patched with the loaded data.
    """
    restored = restore_simulation_if_cached(simulation, path, use_cache, reduce_simulation)
    if not restored:
        return None
    data = load(task_id=None, path=path, from_cache=True)
    if isinstance(simulation, ModeSolver):
        simulation._patch_data(data=data)
    return data


@wait_for_connection
def run(
    # NOTE(review): parameter list reconstructed from patch context and call sites —
    # confirm ordering/defaults against the pre-patch signature.
    simulation: WorkflowType,
    task_name: str,
    folder_name: str = "default",
    path: str = "simulation_data.hdf5",
    callback_url: Optional[str] = None,
    verbose: bool = True,
    progress_callback_upload: Optional[Callable[[float], None]] = None,
    progress_callback_download: Optional[Callable[[float], None]] = None,
    solver_version: Optional[str] = None,
    worker_group: Optional[str] = None,
    simulation_type: str = "tidy3d",
    parent_tasks: Optional[list[str]] = None,
    reduce_simulation: Literal["auto", True, False] = "auto",
    pay_type: Union[PayType, str] = PayType.AUTO,
    priority: Optional[int] = None,
    use_cache: Optional[bool] = None,
    lazy: bool = False,
) -> WorkflowDataType:
    """Submit a :class:`.Simulation` to the server, run it, and return its data.

    If an identical simulation result is present in the local cache (and caching
    is enabled), the upload/start/monitor round-trip is skipped entirely and the
    cached artifact is loaded instead.

    Parameters
    ----------
    use_cache : Optional[bool] = None
        Whether to use the local cache if an identical simulation is rerun. If not
        provided, cache settings from config or environment variables are used.
    lazy : bool = False
        Whether to defer loading the simulation data until first access.

    Returns
    -------
    Union[:class:`.SimulationData`, :class:`.HeatSimulationData`, :class:`.EMESimulationData`]
        Object containing the solver results.
    """
    copied_from_cache = restore_simulation_if_cached(
        simulation=simulation,
        path=path,
        use_cache=use_cache,
        reduce_simulation=reduce_simulation,
        verbose=verbose,
    )

    if copied_from_cache:
        # Cache hit: no remote task exists for this invocation.
        task_id = None
    else:
        task_id = upload(
            simulation=simulation,
            task_name=task_name,
            folder_name=folder_name,
            callback_url=callback_url,
            verbose=verbose,
            progress_callback=progress_callback_upload,
            simulation_type=simulation_type,
            parent_tasks=parent_tasks,
            solver_version=solver_version,
            reduce_simulation=reduce_simulation,
        )
        start(
            task_id,
            verbose=verbose,
            solver_version=solver_version,
            worker_group=worker_group,
            pay_type=pay_type,
            priority=priority,
        )
        monitor(task_id, verbose=verbose)

    data = load(
        task_id=task_id,
        path=path,
        verbose=verbose,
        progress_callback=progress_callback_download,
        use_cache=use_cache,
        from_cache=copied_from_cache,
        lazy=lazy,
    )

    if isinstance(simulation, ModeSolver):
        simulation._patch_data(data=data)
    return data


def _get_task_urls(
    task_type: str,
    simulation: WorkflowType,
    resource_id: str,
    folder_id: Optional[str] = None,
    group_id: Optional[str] = None,
) -> tuple[str, Optional[str]]:
    """Return ``(task_url, folder_url)`` for the web UI; ``folder_url`` is ``None``
    when ``folder_id`` is not given.

    RF / component-modeler tasks use the RF UI route keyed by the group id when
    available; everything else uses the standard task route.
    """
    # (removed leftover debug print of task_type)
    is_rf = task_type in ["RF", "COMPONENT_MODELER", "TERMINAL_COMPONENT_MODELER"]
    if is_rf and isinstance(simulation, TerminalComponentModeler):
        url = _get_url_rf(group_id or resource_id)
    else:
        url = _get_url(resource_id)

    folder_url = _get_folder_url(folder_id) if folder_id is not None else None
    return url, folder_url


@wait_for_connection
def load(
    task_id: Optional[TaskId],
    path: str = "simulation_data.hdf5",
    replace_existing: bool = True,
    verbose: bool = True,
    progress_callback: Optional[Callable[[float], None]] = None,
    use_cache: Optional[bool] = None,
    from_cache: bool = False,
    lazy: bool = False,
) -> WorkflowDataType:
    """Download (if needed) and load simulation results.

    Parameters
    ----------
    task_id : Optional[TaskId]
        Remote task to download from; may be ``None`` only when ``from_cache=True``.
    path : str = "simulation_data.hdf5"
        Local file to read (and download into, when not cached).
    replace_existing : bool = True
        Re-download even if ``path`` already exists.
    verbose : bool = True
        If ``True``, print progress bars and status.
    progress_callback : Callable[[float], None] = None
        Optional callback called while downloading with ``bytes_in_chunk``.
    use_cache : Optional[bool] = None
        Whether to store the loaded result in the local cache. If not provided,
        cache settings from config or environment variables are used.
    from_cache : bool = False
        Whether ``path`` was already materialized from the cache (skips download
        and re-store).
    lazy : bool = False
        Whether to load the actual data (``lazy=False``) or return a proxy that
        loads on first access (``lazy=True``).

    Returns
    -------
    Union[:class:`.SimulationData`, :class:`.HeatSimulationData`, :class:`.EMESimulationData`]
        Object containing simulation data.

    Raises
    ------
    ValueError
        If neither ``task_id`` nor ``from_cache`` is provided.
    FileNotFoundError
        If ``from_cache`` is set but ``path`` does not exist.
    """
    # Real validation, not `assert` (asserts are stripped under `python -O`).
    if not from_cache and task_id is None:
        raise ValueError("Either 'task_id' must be provided or 'from_cache' must be True.")

    # For component modeler batches, default to a clearer filename if the default was used.
    if (
        not from_cache
        and _is_modeler_batch(task_id)
        and os.path.basename(path) == "simulation_data.hdf5"
    ):
        base_dir = os.path.dirname(path) or "."
        path = os.path.join(base_dir, "cm_data.hdf5")

    if from_cache:
        if not os.path.exists(path):
            raise FileNotFoundError("Cached file not found.")
    elif not os.path.exists(path) or replace_existing:
        download(task_id=task_id, path=path, verbose=verbose, progress_callback=progress_callback)

    if verbose:
        console = get_logging_console()
        if not from_cache and _is_modeler_batch(task_id):
            console.log(f"loading component modeler data from {path}")
        else:
            console.log(f"loading simulation from {path}")

    stub_data = Tidy3dStubData.postprocess(path, lazy=lazy)

    # Populate the local cache from fresh downloads only — data that came *from*
    # the cache would be a no-op re-store.
    simulation_cache = resolve_simulation_cache(use_cache)
    if simulation_cache is not None and not from_cache:
        info = get_info(task_id, verbose=False)
        workflow_type = getattr(info, "taskType", None) or type(stub_data).__name__
        simulation_cache.store_result(
            stub_data=stub_data,
            task_id=task_id,
            path=path,
            workflow_type=workflow_type,
        )

    return stub_data
--git a/tidy3d/web/cache.py b/tidy3d/web/cache.py new file mode 100644 index 0000000000..d661e74ea2 --- /dev/null +++ b/tidy3d/web/cache.py @@ -0,0 +1,664 @@ +"""Local simulation cache manager.""" + +from __future__ import annotations + +import hashlib +import json +import os +import shutil +import tempfile +import threading +from collections.abc import Iterable +from dataclasses import dataclass, field, replace +from datetime import datetime, timezone +from enum import Enum +from functools import lru_cache +from pathlib import Path +from typing import Any, Optional + +from tidy3d import config +from tidy3d.components.types.workflow import WorkflowDataType, WorkflowType +from tidy3d.log import log +from tidy3d.web.api.tidy3d_stub import Tidy3dStub +from tidy3d.web.core.constants import TaskId +from tidy3d.web.core.http_util import get_version as _get_protocol_version + +DEFAULT_CACHE_RELATIVE_DIR = Path(".tidy3d") / "cache" / "simulations" +CACHE_ARTIFACT_NAME = "simulation_data.hdf5" +CACHE_METADATA_NAME = "metadata.json" + +ENV_ENABLE = "TIDY3D_CACHE_ENABLED" +ENV_DIRECTORY = "TIDY3D_CACHE_DIR" +ENV_MAX_SIZE = "TIDY3D_CACHE_MAX_SIZE_GB" +ENV_MAX_ENTRIES = "TIDY3D_CACHE_MAX_ENTRIES" + +TMP_PREFIX = "tidy3d-cache-" +TMP_BATCH_PREFIX = "tmp_batch" + + +_CONFIG_LOCK = threading.RLock() + + +@dataclass(frozen=True) +class SimulationCacheConfig: + """Configuration for the simulation cache.""" + + enabled: bool = False + directory: Path = field(default_factory=lambda: Path.home() / DEFAULT_CACHE_RELATIVE_DIR) + max_size_gb: float = 8.0 + max_entries: int = 32 + + +def _coerce_bool(value: str) -> Optional[bool]: + if value is None: + return None + normalized = value.strip().lower() + if normalized in {"1", "true", "yes", "on"}: + return True + if normalized in {"0", "false", "no", "off"}: + return False + return None + + +def _coerce_float(value: str) -> Optional[float]: + if value is None: + return None + try: + return float(value) + except (TypeError, ValueError): + 
return None + + +def _coerce_int(value: str) -> Optional[int]: + if value is None: + return None + try: + return int(value) + except (TypeError, ValueError): + return None + + +def _load_env_overrides() -> dict[str, Any]: + overrides: dict[str, Any] = {} + + enabled_env = _coerce_bool(os.getenv(ENV_ENABLE)) + if enabled_env is not None: + overrides["enabled"] = enabled_env + + directory_env = os.getenv(ENV_DIRECTORY) + if directory_env: + overrides["directory"] = directory_env + + size_env = _coerce_float(os.getenv(ENV_MAX_SIZE)) + if size_env is not None: + overrides["max_size_gb"] = size_env + + entries_env = _coerce_int(os.getenv(ENV_MAX_ENTRIES)) + if entries_env is not None: + overrides["max_entries"] = entries_env + + return overrides + + +def _load_effective_config() -> SimulationCacheConfig: + """ + Build the initial, global cache config at import-time. + + Precedence for fields (lowest → highest): + 1) library defaults (disabled, ~/.tidy3d/cache/simulations, limits) + 2) persisted app config (config.simulation_cache_settings), if present + 3) environment overrides (TIDY3D_CACHE_*) + + Note: per-call `use_cache` is *not* applied here; that’s handled in + resolve_simulation_cache(...), which can reconfigure the singleton later. 
+ """ + sim_cache_settings = config.simulation_cache + + cfg = SimulationCacheConfig( + enabled=sim_cache_settings.enabled, + directory=sim_cache_settings.directory, + max_size_gb=sim_cache_settings.max_size_gb, + max_entries=sim_cache_settings.max_entries, + ) + + env_overrides = _load_env_overrides() + if env_overrides: + allowed = {k: v for k, v in env_overrides.items() if v is not None} + if allowed: + cfg = replace(cfg, **allowed) + + if cfg.directory: + cfg = replace(cfg, directory=Path(cfg.directory).expanduser().resolve()) + + return cfg + + +_CACHE_CONFIG: SimulationCacheConfig = _load_effective_config() + + +def get_cache_config() -> SimulationCacheConfig: + """Thread-safe snapshot copy of the active global cache configuration.""" + with _CONFIG_LOCK: + return replace(_CACHE_CONFIG) + + +def configure_cache(new_config: SimulationCacheConfig) -> None: + """Swap the active global config and reset the cache singleton.""" + global _CACHE_CONFIG + with _CONFIG_LOCK: + _CACHE_CONFIG = new_config + get_cache.cache_clear() + + +@lru_cache +def get_cache() -> SimulationCache: + """ + Return the singleton SimulationCache built from the *current* global config. + + This is automatically refreshed whenever `configure_cache(...)` is called, + because that function clears this LRU entry. 
+ """ + cfg = get_cache_config() + return SimulationCache(cfg) + + +def _apply_overrides( + cfg: SimulationCacheConfig, overrides: dict[str, Any] +) -> SimulationCacheConfig: + """Apply dict-based overrides (enabled/directory/max_size_gb/max_entries).""" + if not overrides: + return cfg + # Filter to fields that exist on the dataclass and are not None + allowed = {k: v for k, v in overrides.items() if v is not None and hasattr(cfg, k)} + return replace(cfg, **allowed) if allowed else cfg + + +def resolve_simulation_cache(use_cache: Optional[bool] = None) -> Optional[SimulationCache]: + """ + Return a SimulationCache configured from: + 1) persisted config (directory/limits + default enabled), + 2) environment overrides (enabled + directory/limits), + 3) per-call 'use_cache' (enabled only, highest precedence). + + If effective config differs from the active global config, reconfigure the singleton. + Returns None if final 'enabled' is False. + """ + current = get_cache_config() + desired = _load_effective_config() + + if use_cache is not None: + if desired.directory != current.directory: + get_cache().clear(hard=True) + desired = replace(desired, enabled=use_cache) + + if desired != current: + configure_cache(desired) + + if not desired.enabled: + return None + + try: + return get_cache() + except Exception as err: + log.debug("Simulation cache unavailable: %s", err) + return None + + +@dataclass +class CacheEntry: + """Internal representation of a cache entry.""" + + key: str + root: Path + metadata: dict[str, Any] + + @property + def path(self) -> Path: + return self.root / self.key + + @property + def artifact_path(self) -> Path: + return self.path / CACHE_ARTIFACT_NAME + + @property + def metadata_path(self) -> Path: + return self.path / CACHE_METADATA_NAME + + def exists(self) -> bool: + return self.path.exists() and self.artifact_path.exists() and self.metadata_path.exists() + + def verify(self) -> bool: + if not self.exists(): + return False + checksum = 
self.metadata.get("checksum") + if not checksum: + return False + try: + actual_checksum, file_size = _copy_and_hash(self.artifact_path, None) + except FileNotFoundError: + return False + if checksum != actual_checksum: + log.warning( + "Simulation cache checksum mismatch for key '%s'. Removing stale entry.", self.key + ) + return False + if int(self.metadata.get("file_size", file_size)) != file_size: + self.metadata["file_size"] = file_size + _write_metadata(self.metadata_path, self.metadata) + return True + + def materialize(self, target: Path) -> Path: + """Copy cached artifact to ``target`` and return the resulting path.""" + target = Path(target) + target.parent.mkdir(parents=True, exist_ok=True) + shutil.copy2(self.artifact_path, target) + return target + + +class SimulationCache: + """Manages storing and retrieving cached simulation artifacts.""" + + def __init__(self, config: SimulationCacheConfig): + self._config = config + self._root = Path(config.directory).expanduser().resolve() + self._lock = threading.RLock() + if config.enabled: + self._root.mkdir(parents=True, exist_ok=True) + + @property + def config(self) -> SimulationCacheConfig: + return self._config + + @property + def root(self) -> Path: + return self._root + + def list(self) -> list[dict[str, Any]]: + """Return metadata for all cache entries.""" + with self._lock: + return [entry.metadata for entry in self._iter_entries()] + + def clear(self, hard=False) -> None: + """Remove all cache contents.""" + with self._lock: + if self._root.exists(): + try: + shutil.rmtree(self._root) + if not hard: + self._root.mkdir(parents=True, exist_ok=True) + except (FileNotFoundError, OSError): + pass + + def _fetch(self, key: str) -> Optional[CacheEntry]: + """Retrieve an entry by key, verifying checksum.""" + with self._lock: + entry = self._load_entry(key) + if not entry or not entry.exists(): + return None + if not entry.verify(): + self._remove_entry(entry) + return None + self._touch(entry) + return entry 
+ + def __len__(self) -> int: + """Return number of valid cache entries.""" + with self._lock: + return sum(1 for _ in self._iter_entries()) + + def _store( + self, key: str, source_path: Path, metadata: dict[str, Any] + ) -> Optional[CacheEntry]: + """Store a new cache entry from ``source_path``. + + Parameters + ---------- + key : str + Cache key computed from simulation hash and runtime context. + source_path : Path + Location of the artifact to cache. + metadata : dict[str, Any] + Additional metadata to persist alongside artifact. + + Returns + ------- + CacheEntry + Representation of the stored cache entry. + """ + source_path = Path(source_path) + if not source_path.exists(): + raise FileNotFoundError(f"Cannot cache missing artifact: {source_path}") + os.makedirs(self._root, exist_ok=True) + tmp_dir = Path(tempfile.mkdtemp(prefix=TMP_PREFIX, dir=self._root)) + tmp_artifact = tmp_dir / CACHE_ARTIFACT_NAME + tmp_meta = tmp_dir / CACHE_METADATA_NAME + os.makedirs(tmp_dir, exist_ok=True) + + checksum, file_size = _copy_and_hash(source_path, tmp_artifact) + now_iso = _now() + metadata = dict(metadata) + metadata.setdefault("cache_key", key) + metadata.setdefault("created_at", now_iso) + metadata["last_used"] = now_iso + metadata["checksum"] = checksum + metadata["file_size"] = file_size + + _write_metadata(tmp_meta, metadata) + try: + with self._lock: + self._root.mkdir(parents=True, exist_ok=True) + self._ensure_limits(file_size) + final_dir = self._root / key + backup_dir: Optional[Path] = None + + try: + if final_dir.exists(): + backup_dir = final_dir.with_name( + f"{final_dir.name}.bak.{_timestamp_suffix()}" + ) + os.replace(final_dir, backup_dir) + # move tmp_dir into place + os.replace(tmp_dir, final_dir) + except Exception: + # restore backup if needed + if backup_dir and backup_dir.exists(): + os.replace(backup_dir, final_dir) + raise + else: + entry = CacheEntry(key=key, root=self._root, metadata=metadata) + if backup_dir and backup_dir.exists(): + 
shutil.rmtree(backup_dir, ignore_errors=True) + log.debug("Stored simulation cache entry '%s' (%d bytes).", key, file_size) + return entry + finally: + try: + if tmp_dir.exists(): + shutil.rmtree(tmp_dir, ignore_errors=True) + except FileNotFoundError: + pass + + def invalidate(self, key: str) -> None: + with self._lock: + entry = self._load_entry(key) + if entry: + self._remove_entry(entry) + + def _ensure_limits(self, incoming_size: int) -> None: + max_entries = max(self._config.max_entries, 0) + max_size_bytes = int(max(0.0, self._config.max_size_gb) * (1024**3)) + + entries = list(self._iter_entries()) + if max_entries and len(entries) >= max_entries: + self._evict(entries, keep=max_entries - 1) + entries = list(self._iter_entries()) + + if not max_size_bytes: + return + + existing_size = sum(int(e.metadata.get("file_size", 0)) for e in entries) + allowed_size = max(max_size_bytes - incoming_size, 0) + if existing_size > allowed_size: + self._evict_by_size(entries, existing_size, allowed_size) + + def _evict(self, entries: Iterable[CacheEntry], keep: int) -> None: + sorted_entries = sorted(entries, key=lambda e: e.metadata.get("last_used", "")) + to_remove = sorted_entries[: max(0, len(sorted_entries) - keep)] + for entry in to_remove: + self._remove_entry(entry) + + def _evict_by_size( + self, entries: Iterable[CacheEntry], current_size: int, allowed_size: float + ) -> None: + if allowed_size < 0: + allowed_size = 0 + sorted_entries = sorted(entries, key=lambda e: e.metadata.get("last_used", "")) + reclaimed = 0 + for entry in sorted_entries: + if current_size - reclaimed <= allowed_size: + break + size = int(entry.metadata.get("file_size", 0)) + self._remove_entry(entry) + reclaimed += size + log.info(f"Simulation cache evicted entry '{entry.key}' to reclaim {size} bytes.") + + def _iter_entries(self) -> Iterable[CacheEntry]: + if not self._root.exists(): + return [] + entries: list[CacheEntry] = [] + for child in self._root.iterdir(): + if 
child.name.startswith(TMP_PREFIX) or child.name.startswith(TMP_BATCH_PREFIX): + continue + meta_path = child / CACHE_METADATA_NAME + if not meta_path.exists(): + continue + try: + metadata = json.loads(meta_path.read_text(encoding="utf-8")) + except Exception: + metadata = {} + entries.append(CacheEntry(key=child.name, root=self._root, metadata=metadata)) + return entries + + def _load_entry(self, key: str) -> Optional[CacheEntry]: + entry = CacheEntry(key=key, root=self._root, metadata={}) + if not entry.metadata_path.exists() or not entry.artifact_path.exists(): + return None + try: + metadata = json.loads(entry.metadata_path.read_text(encoding="utf-8")) + except Exception: + metadata = {} + entry.metadata = metadata + return entry + + def _touch(self, entry: CacheEntry) -> None: + entry.metadata["last_used"] = _now() + _write_metadata(entry.metadata_path, entry.metadata) + + def _remove_entry(self, entry: CacheEntry) -> None: + if entry.path.exists(): + shutil.rmtree(entry.path, ignore_errors=True) + + def try_fetch( + self, + simulation: WorkflowType, + verbose: bool = False, + ) -> Optional[CacheEntry]: + """ + Attempt to resolve and fetch a cached result entry for the given simulation context. + On miss or any cache error, returns None (the caller should proceed with upload/run). + + Notes + ----- + - Mirrors the exact cache key/context computation from `run`. + - Safe to call regardless of `use_cache` value; will no-op if cache is disabled. + """ + try: + simulation_hash = simulation._hash_self() + workflow_type = Tidy3dStub(simulation=simulation).get_type() + + versions = _get_protocol_version() + + cache_key = build_cache_key( + simulation_hash=simulation_hash, + version=versions, + ) + + entry = self._fetch(cache_key) + if not entry: + return None + if verbose: + log.info( + f"Simulation cache hit for workflow '{workflow_type}'; using local results." 
+ ) + + return entry + except Exception: + log.error("Failed to fetch cache results.") + + def store_result( + self, + stub_data: WorkflowDataType, + task_id: TaskId, + path: str, + workflow_type: str, + ) -> None: + """ + After we have the data (postprocess done), store it in the cache using the + canonical key (simulation hash + workflow type + environment + version). + Also records the task_id mapping for legacy lookups. + """ + try: + simulation_obj = getattr(stub_data, "simulation", None) + simulation_hash = simulation_obj._hash_self() if simulation_obj is not None else None + if not simulation_hash: + return + + version = _get_protocol_version() + + cache_key = build_cache_key( + simulation_hash=simulation_hash, + version=version, + ) + + metadata = build_entry_metadata( + simulation_hash=simulation_hash, + workflow_type=workflow_type, + task_id=task_id, + version=version, + path=Path(path), + ) + + self._store( + key=cache_key, + source_path=Path(path), + metadata=metadata, + ) + except Exception: + log.error("Could not store cache entry.") + + +def _copy_and_hash( + source: Path, dest: Optional[Path], existing_hash: Optional[str] = None +) -> tuple[str, int]: + """Copy ``source`` to ``dest`` while computing SHA256 checksum. + + Parameters + ---------- + source : Path + Source file path. + dest : Path or None + Destination file path. If ``None``, no copy is performed. + existing_hash : str, optional + If provided alongside ``dest`` and ``dest`` already exists, skip copying when hashes match. + + Returns + ------- + tuple[str, int] + The hexadecimal digest and file size in bytes. 
+ """ + source = Path(source) + if dest is not None: + dest = Path(dest) + sha256 = _Hasher() + size = 0 + with source.open("rb") as src: + if dest is None: + while chunk := src.read(1024 * 1024): + sha256.update(chunk) + size += len(chunk) + else: + dest.parent.mkdir(parents=True, exist_ok=True) + with dest.open("wb") as dst: + while chunk := src.read(1024 * 1024): + dst.write(chunk) + sha256.update(chunk) + size += len(chunk) + return sha256.hexdigest(), size + + +def _write_metadata(path: Path, metadata: dict[str, Any]) -> None: + tmp_path = path.with_suffix(".tmp") + with tmp_path.open("w", encoding="utf-8") as fh: + json.dump(metadata, fh, indent=2, sort_keys=True) + os.replace(tmp_path, path) + + +def _now() -> str: + return datetime.now(timezone.utc).isoformat() + + +def _timestamp_suffix() -> str: + return datetime.now(timezone.utc).strftime("%Y%m%d%H%M%S%f") + + +class _Hasher: + def __init__(self): + self._hasher = hashlib.sha256() + + def update(self, data: bytes) -> None: + self._hasher.update(data) + + def hexdigest(self) -> str: + return self._hasher.hexdigest() + + +def clear() -> None: + """Remove all cache entries.""" + get_cache().clear() + + +def _canonicalize(value: Any) -> Any: + """Convert value into a JSON-serializable object for hashing/metadata.""" + + if isinstance(value, dict): + return { + str(k): _canonicalize(v) + for k, v in sorted(value.items(), key=lambda item: str(item[0])) + } + if isinstance(value, (list, tuple)): + return [_canonicalize(v) for v in value] + if isinstance(value, set): + return sorted(_canonicalize(v) for v in value) + if isinstance(value, Enum): + return value.value + if isinstance(value, Path): + return str(value) + if isinstance(value, datetime): + return value.isoformat() + if isinstance(value, bytes): + return value.decode("utf-8", errors="ignore") + return value + + +def build_cache_key( + *, + simulation_hash: str, + version: str, +) -> str: + """Construct a deterministic cache key.""" + + payload = { + 
"simulation_hash": simulation_hash, + "versions": _canonicalize(version), + } + encoded = json.dumps(payload, sort_keys=True, separators=(",", ":")).encode("utf-8") + return hashlib.sha256(encoded).hexdigest() + + +def build_entry_metadata( + *, + simulation_hash: str, + workflow_type: str, + task_id: str, + version: str, + path: Path, +) -> dict[str, Any]: + """Create metadata dictionary for a cache entry.""" + + metadata: dict[str, Any] = { + "simulation_hash": simulation_hash, + "workflow_type": workflow_type, + "versions": _canonicalize(version), + "task_id": task_id, + "path": str(path), + } + return metadata diff --git a/tidy3d/web/core/http_util.py b/tidy3d/web/core/http_util.py index f802752bb1..f2f24e576a 100644 --- a/tidy3d/web/core/http_util.py +++ b/tidy3d/web/core/http_util.py @@ -48,7 +48,7 @@ class ResponseCodes(Enum): NOT_FOUND = 404 -def get_version() -> None: +def get_version() -> str: """Get the version for the current environment.""" return core_config.get_version()