|
2 | 2 | from __future__ import annotations |
3 | 3 |
|
4 | 4 | import os |
| 5 | +from concurrent.futures import Future |
5 | 6 | from types import SimpleNamespace |
6 | 7 |
|
7 | 8 | import numpy as np |
|
25 | 26 | from tidy3d.web.api.asynchronous import run_async |
26 | 27 | from tidy3d.web.api.container import Batch, Job, WebContainer |
27 | 28 | from tidy3d.web.api.run import _collect_by_hash, run |
28 | | -from tidy3d.web.api.tidy3d_stub import Tidy3dStubData |
| 29 | +from tidy3d.web.api.tidy3d_stub import Tidy3dStubData, task_type_name_of |
29 | 30 | from tidy3d.web.api.webapi import ( |
30 | 31 | abort, |
31 | 32 | delete, |
|
65 | 66 | Env.dev.active() |
66 | 67 |
|
67 | 68 |
|
| 69 | +class ImmediateExecutor: |
| 70 | + def __init__(self, *args, **kwargs): |
| 71 | + pass |
| 72 | + |
| 73 | + def submit(self, fn, *args, **kwargs): |
| 74 | + future = Future() |
| 75 | + try: |
| 76 | + result = fn(*args, **kwargs) |
| 77 | + except Exception as err: # pragma: no cover - defensive |
| 78 | + future.set_exception(err) |
| 79 | + else: |
| 80 | + future.set_result(result) |
| 81 | + return future |
| 82 | + |
| 83 | + def shutdown(self, wait=True): |
| 84 | + pass |
| 85 | + |
| 86 | + |
68 | 87 | def make_sim(): |
69 | 88 | """Makes a simulation.""" |
70 | 89 | pulse = td.GaussianPulse(freq0=200e12, fwidth=20e12) |
@@ -703,6 +722,106 @@ def mock_start_interrupt(self, *args, **kwargs): |
703 | 722 | batch.run(path_dir=str(tmp_path)) |
704 | 723 |
|
705 | 724 |
|
| 725 | +def test_batch_monitor_downloads_on_success(monkeypatch, tmp_path): |
| 726 | + events = [] |
| 727 | + |
| 728 | + class FakeJob: |
| 729 | + def __init__(self, task_id: str, statuses: list[str]): |
| 730 | + self.task_id = task_id |
| 731 | + self._statuses = statuses |
| 732 | + self._idx = 0 |
| 733 | + |
| 734 | + @property |
| 735 | + def status(self): |
| 736 | + status = self._statuses[self._idx] |
| 737 | + if self._idx < len(self._statuses) - 1: |
| 738 | + self._idx += 1 |
| 739 | + events.append((self.task_id, "status", status)) |
| 740 | + return status |
| 741 | + |
| 742 | + def download(self, path: str): |
| 743 | + events.append((self.task_id, "download", path)) |
| 744 | + |
| 745 | + monkeypatch.setattr("tidy3d.web.api.container.ThreadPoolExecutor", ImmediateExecutor) |
| 746 | + monkeypatch.setattr("tidy3d.web.api.container.time.sleep", lambda *_args, **_kwargs: None) |
| 747 | + |
| 748 | + sims = {"task_a": make_sim(), "task_b": make_sim()} |
| 749 | + batch = Batch(simulations=sims, folder_name=PROJECT_NAME, verbose=False) |
| 750 | + batch._cached_properties = {} |
| 751 | + fake_jobs = { |
| 752 | + "task_a": FakeJob("task_a_id", ["running", "success", "success"]), |
| 753 | + "task_b": FakeJob("task_b_id", ["running", "running", "success"]), |
| 754 | + } |
| 755 | + batch._cached_properties["jobs"] = fake_jobs |
| 756 | + |
| 757 | + batch.monitor(download_on_success=True, path_dir=str(tmp_path)) |
| 758 | + |
| 759 | + downloads = [event for event in events if event[1] == "download"] |
| 760 | + assert len(downloads) == 2 |
| 761 | + assert {event[0] for event in downloads} == {"task_a_id", "task_b_id"} |
| 762 | + |
| 763 | + expected_paths = { |
| 764 | + "task_a_id": os.path.join(str(tmp_path), "task_a_id.hdf5"), |
| 765 | + "task_b_id": os.path.join(str(tmp_path), "task_b_id.hdf5"), |
| 766 | + } |
| 767 | + |
| 768 | + for task_id, _, path in downloads: |
| 769 | + assert path == expected_paths[task_id] |
| 770 | + |
| 771 | + job1_download_idx = next( |
| 772 | + i |
| 773 | + for i, event in enumerate(events) |
| 774 | + if event == ("task_a_id", "download", expected_paths["task_a_id"]) |
| 775 | + ) |
| 776 | + job2_success_idx = next( |
| 777 | + i for i, event in enumerate(events) if event == ("task_b_id", "status", "success") |
| 778 | + ) |
| 779 | + |
| 780 | + assert job1_download_idx < job2_success_idx, "Download should start before other jobs finish" |
| 781 | + |
| 782 | + |
| 783 | +def test_batch_monitor_skips_existing_download(monkeypatch, tmp_path): |
| 784 | + events = [] |
| 785 | + |
| 786 | + class FakeJob: |
| 787 | + def __init__(self, task_id: str, statuses: list[str]): |
| 788 | + self.task_id = task_id |
| 789 | + self._statuses = statuses |
| 790 | + self._idx = 0 |
| 791 | + |
| 792 | + @property |
| 793 | + def status(self): |
| 794 | + status = self._statuses[self._idx] |
| 795 | + if self._idx < len(self._statuses) - 1: |
| 796 | + self._idx += 1 |
| 797 | + events.append((self.task_id, "status", status)) |
| 798 | + return status |
| 799 | + |
| 800 | + def download(self, path: str): |
| 801 | + events.append((self.task_id, "download", path)) |
| 802 | + |
| 803 | + monkeypatch.setattr("tidy3d.web.api.container.ThreadPoolExecutor", ImmediateExecutor) |
| 804 | + monkeypatch.setattr("tidy3d.web.api.container.time.sleep", lambda *_args, **_kwargs: None) |
| 805 | + |
| 806 | + sims = {"task_a": make_sim(), "task_b": make_sim()} |
| 807 | + batch = Batch(simulations=sims, folder_name=PROJECT_NAME, verbose=False) |
| 808 | + batch._cached_properties = {} |
| 809 | + fake_jobs = { |
| 810 | + "task_a": FakeJob("task_a_id", ["success", "success"]), |
| 811 | + "task_b": FakeJob("task_b_id", ["running", "success"]), |
| 812 | + } |
| 813 | + batch._cached_properties["jobs"] = fake_jobs |
| 814 | + |
| 815 | + existing_path = os.path.join(str(tmp_path), "task_a_id.hdf5") |
| 816 | + with open(existing_path, "w", encoding="utf8") as handle: |
| 817 | + handle.write("cached") |
| 818 | + |
| 819 | + batch.monitor(download_on_success=True, path_dir=str(tmp_path)) |
| 820 | + |
| 821 | + downloads = [event for event in events if event[1] == "download"] |
| 822 | + assert downloads == [("task_b_id", "download", os.path.join(str(tmp_path), "task_b_id.hdf5"))] |
| 823 | + |
| 824 | + |
706 | 825 | """ Async """ |
707 | 826 |
|
708 | 827 |
|
@@ -800,16 +919,30 @@ def _fake_load(task_id, path="simulation_data.hdf5", lazy=False, **kwargs): |
800 | 919 |
|
801 | 920 |
|
802 | 921 | def apply_common_patches( |
803 | | - monkeypatch, tmp_root, *, api_path="tidy3d.web.api.webapi", path_to_sim=None, taskid_to_sim=None |
| 922 | + monkeypatch, |
| 923 | + tmp_root, |
| 924 | + *, |
| 925 | + api_path="tidy3d.web.api.webapi", |
| 926 | + taskid_to_sim=None, |
804 | 927 | ): |
805 | 928 | """Patch start/monitor/get_info/estimate_cost/upload/_check_folder/_modesolver_patch/load.""" |
806 | 929 | monkeypatch.setattr(f"{api_path}.start", lambda *a, **k: True) |
807 | 930 | monkeypatch.setattr(f"{api_path}.monitor", lambda *a, **k: True) |
808 | | - monkeypatch.setattr(f"{api_path}.get_info", lambda *a, **k: SimpleNamespace(status="success")) |
| 931 | + |
| 932 | + # --- make get_info return also task type --- |
| 933 | + def _fake_get_info(task_id: str, *_, **__): |
| 934 | + sim = taskid_to_sim.get(task_id) if taskid_to_sim else None |
| 935 | + task_type = task_type_name_of(sim) if sim is not None else None |
| 936 | + return SimpleNamespace(status="success", taskType=task_type) |
| 937 | + |
| 938 | + monkeypatch.setattr(f"{api_path}.get_info", _fake_get_info) |
| 939 | + |
| 940 | + # other patches |
809 | 941 | monkeypatch.setattr(f"{api_path}.estimate_cost", lambda *a, **k: 0.0) |
810 | 942 | monkeypatch.setattr(f"{api_path}.upload", lambda *a, **k: k["task_name"]) |
811 | 943 | monkeypatch.setattr(WebContainer, "_check_folder", lambda *a, **k: True) |
812 | 944 | monkeypatch.setattr(f"{api_path}._modesolver_patch", lambda *_, **__: None, raising=False) |
| 945 | + monkeypatch.setattr(f"{api_path}.download", lambda *_, **__: None, raising=False) |
813 | 946 | monkeypatch.setattr( |
814 | 947 | f"{api_path}.load", |
815 | 948 | _fake_load_factory(tmp_root=str(tmp_root), taskid_to_sim=taskid_to_sim), |
|
0 commit comments