Commit c956dbf

Assign tasks to workers (#525)

* Send tasks that require the pipelines to the same worker
* Move encoder to ArtifactManager
* Add memory-profiler to profile the memory usage
* Move the encoder out of the ArtifactManager
* Adapt based on comments
* Assign encoder and model tasks to the same worker
* Fix similarity test
* Add worker to custom tasks
* Add TODO
1 parent 40b93ed commit c956dbf
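The mechanism behind this change is Dask's `workers=` restriction on `Client.submit`: a task submitted with a worker name may only be scheduled on the worker carrying that name. Below is a minimal standalone sketch of that mechanism, assuming a LocalCluster whose workers keep their default integer names; the function and cluster size are illustrative, not from this commit:

```python
from dask.distributed import Client, LocalCluster

def predict(text: str) -> str:
    # Stand-in for a task that needs a large model held in worker memory.
    return text.upper()

if __name__ == "__main__":
    # LocalCluster names its workers 0..n-1 by default, so an IntEnum value
    # such as Worker.model (== 0) can be passed directly as a worker name.
    cluster = LocalCluster(n_workers=2, threads_per_worker=1)
    client = Client(cluster)

    # Pinning every model task to worker 0 means the model is loaded (and
    # cached) in one process instead of once per worker.
    future = client.submit(predict, "hello", pure=False, workers=0)
    print(future.result())  # "HELLO"
```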

File tree

azimuth/modules/base_classes/dask_module.py
azimuth/modules/base_classes/indexable_module.py
azimuth/modules/base_classes/module.py
azimuth/modules/dataset_analysis/similarity_analysis.py
azimuth/modules/perturbation_testing/perturbation_testing.py
azimuth/modules/validation/validation.py
azimuth/modules/word_analysis/top_words.py
poetry.lock
pyproject.toml

9 files changed: +48 -2 lines changed

azimuth/modules/base_classes/dask_module.py (+13 -1)

@@ -7,6 +7,7 @@
 import threading
 import time
 import uuid
+from enum import IntEnum
 from functools import partial
 from os.path import join as pjoin
 from typing import Any, Callable, Dict, Generic, List, Optional, TypeVar, cast
@@ -25,6 +26,11 @@
 ConfigScope = TypeVar("ConfigScope", bound=CommonFieldsConfig)


+class Worker(IntEnum):
+    model = 0
+    encoder = 0
+
+
 class DaskModule(HDF5CacheMixin, Generic[ConfigScope]):
     """Abstract class that define an item of work to be computed on the cluster.

@@ -37,6 +43,7 @@ class DaskModule(HDF5CacheMixin, Generic[ConfigScope]):
     """

     allowed_splits = {DatasetSplitName.train, DatasetSplitName.eval}
+    worker: Optional[Worker] = None

     def __init__(
         self,
@@ -118,6 +125,7 @@ def start_task_on_dataset_split(
             pure=False,
             dependencies=deps,
             key=f"{self.task_id}_{uuid.uuid4()}",  # Unique identifier
+            workers=self.worker,
         )
         # Tell that this future is used on which indices.
         self.future.indices = self.get_caching_indices()
@@ -148,7 +156,11 @@ def start_task(self, client: Client, custom_query: Dict[str, Any]) -> "DaskModule":
         log.info(f"Starting custom query {self.name}")
         # pure=false to be sure that everything is rerun.
         self.future = client.submit(
-            self.compute, custom_query, key=self.custom_query_task_id(custom_query), pure=False
+            self.compute,
+            custom_query,
+            key=self.custom_query_task_id(custom_query),
+            pure=False,
+            workers=self.worker,
         )
         # Tell that this future is for custom use only.
         self.future.is_custom = True
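One detail worth flagging in the hunk above: `model` and `encoder` share the value 0, so under `IntEnum` semantics `Worker.encoder` is an alias of `Worker.model`, and both task families resolve to the same worker name — which is exactly the "assign encoder and model tasks to same worker" item in the commit message. A quick demonstration of the alias behavior:

```python
from enum import IntEnum

class Worker(IntEnum):
    model = 0
    encoder = 0  # duplicate value: an alias of `model`, not a new member

assert Worker.encoder is Worker.model   # aliases are the same object
assert list(Worker) == [Worker.model]   # iteration yields canonical members only
assert int(Worker.encoder) == 0         # the value Dask matches to a worker name
```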

azimuth/modules/base_classes/indexable_module.py (+2)

@@ -12,6 +12,7 @@
 from azimuth.config import ModelContractConfig
 from azimuth.dataset_split_manager import DatasetSplitManager
 from azimuth.modules.base_classes import ConfigScope, Module
+from azimuth.modules.base_classes.dask_module import Worker
 from azimuth.types import (
     DatasetColumn,
     DatasetSplitName,
@@ -76,6 +77,7 @@ def save_result(self, res: List[ModuleResponse], dm: DatasetSplitManager):
 class ModelContractModule(DatasetResultModule[ModelContractConfig], abc.ABC):
     required_mod_options: Set[str] = {"pipeline_index", "model_contract_method_name"}
     optional_mod_options: Set[str] = DatasetResultModule.optional_mod_options | {"threshold"}
+    worker = Worker.model

     def compute(self, batch: Dataset) -> List[ModuleResponse]:
         my_func = self.route_request(assert_not_none(self.model_contract_method_name))

azimuth/modules/base_classes/module.py (+3)

@@ -10,6 +10,7 @@
 from azimuth.config import ModelContractConfig, PipelineDefinition
 from azimuth.dataset_split_manager import DatasetSplitManager, PredictionTableKey
 from azimuth.modules.base_classes import ArtifactManager, ConfigScope, DaskModule
+from azimuth.modules.base_classes.dask_module import Worker
 from azimuth.types import DatasetColumn, DatasetSplitName, ModuleOptions, ModuleResponse
 from azimuth.types.general.module_arguments import ModuleEffectiveArguments
 from azimuth.utils.conversion import md5_hash
@@ -164,6 +165,8 @@ def get_model(self):
         Raises:
             ValueError if no valid pipeline exists.
         """
+        if self.worker != Worker.model:
+            raise RuntimeError("This module cannot load the model. Modify self.worker.")
         _ = self.get_pipeline_definition()  # Validate current pipeline exists
         return self.artifact_manager.get_model(self.config, self.mod_options.pipeline_index)
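The guard makes the dependency explicit: a module may only load the pipeline if it has declared `worker = Worker.model`, so a misconfigured module fails fast instead of quietly pulling the model onto an arbitrary worker. A hypothetical subclass (the class name and batch layout are invented for illustration, not part of this commit):

```python
# Hypothetical module illustrating the contract enforced by get_model():
class UppercasePredictionModule(Module):
    worker = Worker.model  # declare the model dependency up front

    def compute(self, batch):
        pipeline = self.get_model()  # raises RuntimeError if worker != Worker.model
        return [pipeline(utterance) for utterance in batch["utterance"]]
```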

azimuth/modules/dataset_analysis/similarity_analysis.py (+5)

@@ -16,6 +16,7 @@
 from azimuth.config import SimilarityConfig, SimilarityOptions
 from azimuth.dataset_split_manager import FEATURE_FAISS, DatasetSplitManager
 from azimuth.modules.base_classes import DatasetResultModule, IndexableModule
+from azimuth.modules.base_classes.dask_module import Worker
 from azimuth.modules.task_execution import get_task_result
 from azimuth.types import Array, DatasetColumn, DatasetSplitName, ModuleOptions
 from azimuth.types.similarity_analysis import FAISSResponse
@@ -28,6 +29,8 @@
 class FAISSModule(IndexableModule[SimilarityConfig]):
     """Compute the FAISS features for a dataset split."""

+    worker = Worker.encoder
+
     def __init__(
         self,
         dataset_split_name: DatasetSplitName,
@@ -45,6 +48,8 @@ def get_encoder_name_or_path(self):
         return model_name_or_path

     def get_encoder(self):
+        if self.worker != Worker.encoder:
+            raise RuntimeError("This module cannot load the encoder. Modify self.worker.")
         if self.encoder is None:
             with FileLock(os.path.join(self.cache_dir, "st.lock")):
                 self.encoder = SentenceTransformer(self.get_encoder_name_or_path())
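For context, the `get_encoder` body shown above lazy-loads the SentenceTransformer under a `FileLock`, presumably so that several processes sharing one cache directory don't download or deserialize the same weights at once. A standalone sketch of that pattern, where the function, cache path, and default model name are placeholders:

```python
import os

from filelock import FileLock
from sentence_transformers import SentenceTransformer

_encoder = None  # one encoder per process


def get_encoder(cache_dir: str, name_or_path: str = "all-MiniLM-L6-v2") -> SentenceTransformer:
    # The file lock serializes the first load across processes; later calls
    # in the same process hit the cached instance and skip the lock entirely.
    global _encoder
    if _encoder is None:
        with FileLock(os.path.join(cache_dir, "st.lock")):
            _encoder = SentenceTransformer(name_or_path)
    return _encoder
```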

azimuth/modules/perturbation_testing/perturbation_testing.py (+4)

@@ -10,6 +10,7 @@
 from azimuth.config import PerturbationTestingConfig
 from azimuth.dataset_split_manager import DatasetSplitManager
 from azimuth.modules.base_classes import DatasetResultModule
+from azimuth.modules.base_classes.dask_module import Worker
 from azimuth.modules.model_contract_task_mapping import model_contract_task_mapping
 from azimuth.types import (
     DatasetColumn,
@@ -58,6 +59,9 @@ class PerturbationTestingModule(DatasetResultModule[PerturbationTestingConfig]):
     """

     required_mod_options = {"pipeline_index"}
+    # This module doesn't call self.get_model() but requires the model (predict_task.compute(batch))
+    # TODO Find a more robust way to determine when modules require models.
+    worker = Worker.model

     def __init__(
         self,

azimuth/modules/validation/validation.py (+2)

@@ -9,6 +9,7 @@

 from azimuth.config import ModelContractConfig
 from azimuth.modules.base_classes import AggregationModule
+from azimuth.modules.base_classes.dask_module import Worker
 from azimuth.modules.model_contract_task_mapping import model_contract_task_mapping
 from azimuth.types import ModuleOptions, SupportedMethod, SupportedModelContract
 from azimuth.types.validation import ValidationResponse
@@ -36,6 +37,7 @@ def try_calling_function(self, fn, *args, **kwargs) -> Optional[Any]:

 class ValidationModule(AggregationModule[ModelContractConfig]):
     optional_mod_options = {"pipeline_index"}
+    worker = Worker.model

     def compute_on_dataset_split(self) -> List[ValidationResponse]:  # type: ignore
         cuda_available = torch.cuda.is_available()

azimuth/modules/word_analysis/top_words.py (+2)

@@ -9,6 +9,7 @@

 from azimuth.config import TopWordsConfig
 from azimuth.modules.base_classes import FilterableModule
+from azimuth.modules.base_classes.dask_module import Worker
 from azimuth.modules.task_execution import get_task_result
 from azimuth.modules.word_analysis.tokens_to_words import TokensToWordsModule
 from azimuth.types import ModuleOptions
@@ -33,6 +34,7 @@ class TopWordsModule(FilterableModule[TopWordsConfig]):
         "th_importance",
         "force_no_saliency",
     }
+    worker = Worker.model

     @staticmethod
     def count_words(list_of_words: List[str], top_x: int) -> List[TopWordsResult]:

poetry.lock (+16 -1)

Diff not rendered (generated file).

pyproject.toml (+1)

@@ -88,6 +88,7 @@ bokeh = "<3"
 # Documentation
 mkdocs = "^1.2.3"
 mkdocs-material = "^8.1.7"
+memory-profiler = "^0.61.0"

 [build-system]
 requires = ["poetry-core>=1.0.0"]
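memory-profiler enters as a dev dependency; per the commit message it was added "to profile the memory usage". The commit doesn't show how it is invoked, but typical usage looks like this (the profiled function is illustrative):

```python
from memory_profiler import profile

@profile  # prints a line-by-line memory report when the function runs
def build_features(n: int) -> list:
    squares = [i * i for i in range(n)]  # allocation visible in the report
    return squares

if __name__ == "__main__":
    build_features(1_000_000)
```

Running the script directly prints the report; `mprof run script.py` followed by `mprof plot` gives a memory-over-time profile instead.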
