Skip to content

Commit 8d478df

Browse files
feat(tidy3d): FXC-3689 Per-Simulation Downloads During Batch Runs
1 parent efa6d66 commit 8d478df

File tree

4 files changed

+340
-145
lines changed

4 files changed

+340
-145
lines changed

CHANGELOG.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
3737
- `DirectivityMonitor` now forces `far_field_approx` to `True`, which was previously configurable.
3838
- Unified run submission API: `web.run(...)` is now a container-aware wrapper that accepts a single simulation or arbitrarily nested containers (`list`, `tuple`, `dict` values) and returns results in the same shape.
3939
- `web.Batch(ComponentModeler)` and `web.Job(ComponentModeler)` native support
40+
- Simulation data of batch jobs are now automatically downloaded upon their individual completion in `Batch.run()`, avoiding waiting for the entire batch to reach completion.
4041

4142
### Fixed
4243
- More robust `Sellmeier` and `Debye` material model, and prevent very large pole parameters in `PoleResidue` material model.

tests/test_web/test_webapi.py

Lines changed: 136 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
from __future__ import annotations
33

44
import os
5+
from concurrent.futures import Future
56
from types import SimpleNamespace
67

78
import numpy as np
@@ -25,7 +26,7 @@
2526
from tidy3d.web.api.asynchronous import run_async
2627
from tidy3d.web.api.container import Batch, Job, WebContainer
2728
from tidy3d.web.api.run import _collect_by_hash, run
28-
from tidy3d.web.api.tidy3d_stub import Tidy3dStubData
29+
from tidy3d.web.api.tidy3d_stub import Tidy3dStubData, task_type_name_of
2930
from tidy3d.web.api.webapi import (
3031
abort,
3132
delete,
@@ -65,6 +66,24 @@
6566
Env.dev.active()
6667

6768

69+
class ImmediateExecutor:
70+
def __init__(self, *args, **kwargs):
71+
pass
72+
73+
def submit(self, fn, *args, **kwargs):
74+
future = Future()
75+
try:
76+
result = fn(*args, **kwargs)
77+
except Exception as err: # pragma: no cover - defensive
78+
future.set_exception(err)
79+
else:
80+
future.set_result(result)
81+
return future
82+
83+
def shutdown(self, wait=True):
84+
pass
85+
86+
6887
def make_sim():
6988
"""Makes a simulation."""
7089
pulse = td.GaussianPulse(freq0=200e12, fwidth=20e12)
@@ -703,6 +722,106 @@ def mock_start_interrupt(self, *args, **kwargs):
703722
batch.run(path_dir=str(tmp_path))
704723

705724

725+
def test_batch_monitor_downloads_on_success(monkeypatch, tmp_path):
726+
events = []
727+
728+
class FakeJob:
729+
def __init__(self, task_id: str, statuses: list[str]):
730+
self.task_id = task_id
731+
self._statuses = statuses
732+
self._idx = 0
733+
734+
@property
735+
def status(self):
736+
status = self._statuses[self._idx]
737+
if self._idx < len(self._statuses) - 1:
738+
self._idx += 1
739+
events.append((self.task_id, "status", status))
740+
return status
741+
742+
def download(self, path: str):
743+
events.append((self.task_id, "download", path))
744+
745+
monkeypatch.setattr("tidy3d.web.api.container.ThreadPoolExecutor", ImmediateExecutor)
746+
monkeypatch.setattr("tidy3d.web.api.container.time.sleep", lambda *_args, **_kwargs: None)
747+
748+
sims = {"task_a": make_sim(), "task_b": make_sim()}
749+
batch = Batch(simulations=sims, folder_name=PROJECT_NAME, verbose=False)
750+
batch._cached_properties = {}
751+
fake_jobs = {
752+
"task_a": FakeJob("task_a_id", ["running", "success", "success"]),
753+
"task_b": FakeJob("task_b_id", ["running", "running", "success"]),
754+
}
755+
batch._cached_properties["jobs"] = fake_jobs
756+
757+
batch.monitor(download_on_success=True, path_dir=str(tmp_path))
758+
759+
downloads = [event for event in events if event[1] == "download"]
760+
assert len(downloads) == 2
761+
assert {event[0] for event in downloads} == {"task_a_id", "task_b_id"}
762+
763+
expected_paths = {
764+
"task_a_id": os.path.join(str(tmp_path), "task_a_id.hdf5"),
765+
"task_b_id": os.path.join(str(tmp_path), "task_b_id.hdf5"),
766+
}
767+
768+
for task_id, _, path in downloads:
769+
assert path == expected_paths[task_id]
770+
771+
job1_download_idx = next(
772+
i
773+
for i, event in enumerate(events)
774+
if event == ("task_a_id", "download", expected_paths["task_a_id"])
775+
)
776+
job2_success_idx = next(
777+
i for i, event in enumerate(events) if event == ("task_b_id", "status", "success")
778+
)
779+
780+
assert job1_download_idx < job2_success_idx, "Download should start before other jobs finish"
781+
782+
783+
def test_batch_monitor_skips_existing_download(monkeypatch, tmp_path):
784+
events = []
785+
786+
class FakeJob:
787+
def __init__(self, task_id: str, statuses: list[str]):
788+
self.task_id = task_id
789+
self._statuses = statuses
790+
self._idx = 0
791+
792+
@property
793+
def status(self):
794+
status = self._statuses[self._idx]
795+
if self._idx < len(self._statuses) - 1:
796+
self._idx += 1
797+
events.append((self.task_id, "status", status))
798+
return status
799+
800+
def download(self, path: str):
801+
events.append((self.task_id, "download", path))
802+
803+
monkeypatch.setattr("tidy3d.web.api.container.ThreadPoolExecutor", ImmediateExecutor)
804+
monkeypatch.setattr("tidy3d.web.api.container.time.sleep", lambda *_args, **_kwargs: None)
805+
806+
sims = {"task_a": make_sim(), "task_b": make_sim()}
807+
batch = Batch(simulations=sims, folder_name=PROJECT_NAME, verbose=False)
808+
batch._cached_properties = {}
809+
fake_jobs = {
810+
"task_a": FakeJob("task_a_id", ["success", "success"]),
811+
"task_b": FakeJob("task_b_id", ["running", "success"]),
812+
}
813+
batch._cached_properties["jobs"] = fake_jobs
814+
815+
existing_path = os.path.join(str(tmp_path), "task_a_id.hdf5")
816+
with open(existing_path, "w", encoding="utf8") as handle:
817+
handle.write("cached")
818+
819+
batch.monitor(download_on_success=True, path_dir=str(tmp_path))
820+
821+
downloads = [event for event in events if event[1] == "download"]
822+
assert downloads == [("task_b_id", "download", os.path.join(str(tmp_path), "task_b_id.hdf5"))]
823+
824+
706825
""" Async """
707826

708827

@@ -800,16 +919,30 @@ def _fake_load(task_id, path="simulation_data.hdf5", lazy=False, **kwargs):
800919

801920

802921
def apply_common_patches(
803-
monkeypatch, tmp_root, *, api_path="tidy3d.web.api.webapi", path_to_sim=None, taskid_to_sim=None
922+
monkeypatch,
923+
tmp_root,
924+
*,
925+
api_path="tidy3d.web.api.webapi",
926+
taskid_to_sim=None,
804927
):
805928
"""Patch start/monitor/get_info/estimate_cost/upload/_check_folder/_modesolver_patch/load."""
806929
monkeypatch.setattr(f"{api_path}.start", lambda *a, **k: True)
807930
monkeypatch.setattr(f"{api_path}.monitor", lambda *a, **k: True)
808-
monkeypatch.setattr(f"{api_path}.get_info", lambda *a, **k: SimpleNamespace(status="success"))
931+
932+
# --- make get_info return also task type ---
933+
def _fake_get_info(task_id: str, *_, **__):
934+
sim = taskid_to_sim.get(task_id) if taskid_to_sim else None
935+
task_type = task_type_name_of(sim) if sim is not None else None
936+
return SimpleNamespace(status="success", taskType=task_type)
937+
938+
monkeypatch.setattr(f"{api_path}.get_info", _fake_get_info)
939+
940+
# other patches
809941
monkeypatch.setattr(f"{api_path}.estimate_cost", lambda *a, **k: 0.0)
810942
monkeypatch.setattr(f"{api_path}.upload", lambda *a, **k: k["task_name"])
811943
monkeypatch.setattr(WebContainer, "_check_folder", lambda *a, **k: True)
812944
monkeypatch.setattr(f"{api_path}._modesolver_patch", lambda *_, **__: None, raising=False)
945+
monkeypatch.setattr(f"{api_path}.download", lambda *_, **__: None, raising=False)
813946
monkeypatch.setattr(
814947
f"{api_path}.load",
815948
_fake_load_factory(tmp_root=str(tmp_root), taskid_to_sim=taskid_to_sim),

0 commit comments

Comments
 (0)