Skip to content

Commit 825ef39

Browse files
committed
Send tasks that require the pipelines to same worker
1 parent 10e0bb7 commit 825ef39

File tree

3 files changed

+13
-7
lines changed

3 files changed

+13
-7
lines changed

azimuth/modules/base_classes/dask_module.py

+7-1
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515
from datasets import Dataset
1616
from distributed import Client, Event, Future, rejoin, secede
1717

18-
from azimuth.config import CommonFieldsConfig
18+
from azimuth.config import CommonFieldsConfig, ModelContractConfig
1919
from azimuth.modules.base_classes.caching import HDF5CacheMixin
2020
from azimuth.types import DatasetSplitName, ModuleResponse
2121
from azimuth.utils.logs import TimerLogging
@@ -94,6 +94,11 @@ def _get_config_scope(self, config) -> ConfigScope:
9494
scoped_config = base.__args__[0]
9595
return cast(ConfigScope, scoped_config.parse_obj(config.dict(by_alias=True)))
9696

97+
@property
98+
def can_load_model(self) -> bool:
99+
# TODO Not all modules that inherit from ModelContractConfig load the model. Smarter way?
100+
return isinstance(self.config, ModelContractConfig)
101+
97102
def start_task_on_dataset_split(
98103
self, client: Client, dependencies: List["DaskModule"] = None
99104
) -> "DaskModule":
@@ -118,6 +123,7 @@ def start_task_on_dataset_split(
118123
pure=False,
119124
dependencies=deps,
120125
key=f"{self.task_id}_{uuid.uuid4()}", # Unique identifier
126+
workers=0 if self.can_load_model else 1,
121127
)
122128
# Tell that this future is used on which indices.
123129
self.future.indices = self.get_caching_indices()

azimuth/utils/cluster.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,7 @@ def default_cluster(large=False) -> SpecCluster:
4040
tmp_file = pjoin(str(tempfile.mkdtemp()), "dask-worker-space")
4141
with dask.config.set({"distributed.worker.daemon": False}):
4242
cluster = distributed.LocalCluster(
43-
n_workers=2,
43+
n_workers=2, # Assignment to workers is hard-coded, so it needs to stay 2.
4444
local_directory=tmp_file,
4545
threads_per_worker=1,
4646
memory_limit=memory_limit, # "auto" doesnt work well.

tests/test_routers/conftest.py

+5-5
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
from fastapi import FastAPI
99
from starlette.testclient import TestClient
1010

11-
import azimuth.app as me_app
11+
import azimuth.app as azimuth_app
1212
from azimuth.app import get_ready_flag
1313
from azimuth.config import AzimuthConfig
1414
from tests.utils import DATASET_CFG, SIMPLE_PERTURBATION_TESTING_CONFIG
@@ -24,7 +24,7 @@ def is_set(self):
2424

2525
def create_test_app(config) -> FastAPI:
2626
json.dump(config.dict(by_alias=True), open("/tmp/config.json", "w"))
27-
return me_app.start_app("/tmp/config.json", load_config_history=False, debug=False)
27+
return azimuth_app.start_app("/tmp/config.json", load_config_history=False, debug=False)
2828

2929

3030
FAST_TEST_CFG = {
@@ -44,7 +44,7 @@ def wait_for_startup_after(app):
4444
while resp.json()["startupTasksReady"] is not True:
4545
time.sleep(1)
4646
resp = client.get("/status")
47-
task_manager = me_app.get_task_manager()
47+
task_manager = azimuth_app.get_task_manager()
4848
while task_manager.is_locked:
4949
time.sleep(1)
5050

@@ -70,7 +70,7 @@ def app() -> FastAPI:
7070
while resp.json()["startupTasksReady"] is not True:
7171
time.sleep(1)
7272
resp = client.get("/status")
73-
task_manager = me_app.get_task_manager()
73+
task_manager = azimuth_app.get_task_manager()
7474
while task_manager.is_locked:
7575
time.sleep(1)
7676
yield _app
@@ -79,7 +79,7 @@ def app() -> FastAPI:
7979
@pytest.fixture(scope="function")
8080
def app_not_started(app) -> FastAPI:
8181

82-
startup_tasks = me_app.get_startup_tasks()
82+
startup_tasks = azimuth_app.get_startup_tasks()
8383

8484
class ModuleThatWillNeverEnd:
8585
def status(self):

0 commit comments

Comments
 (0)