diff --git a/azimuth/modules/base_classes/dask_module.py b/azimuth/modules/base_classes/dask_module.py
index 898cfca3..30e48624 100644
--- a/azimuth/modules/base_classes/dask_module.py
+++ b/azimuth/modules/base_classes/dask_module.py
@@ -7,6 +7,7 @@
 import threading
 import time
 import uuid
+from enum import IntEnum
 from functools import partial
 from os.path import join as pjoin
 from typing import Any, Callable, Dict, Generic, List, Optional, TypeVar, cast
@@ -25,6 +26,11 @@
 ConfigScope = TypeVar("ConfigScope", bound=CommonFieldsConfig)
 
 
+class Worker(IntEnum):
+    model = 0
+    encoder = 0
+
+
 class DaskModule(HDF5CacheMixin, Generic[ConfigScope]):
     """Abstract class that define an item of work to be computed on the cluster.
 
@@ -37,6 +43,7 @@ class DaskModule(HDF5CacheMixin, Generic[ConfigScope]):
     """
 
     allowed_splits = {DatasetSplitName.train, DatasetSplitName.eval}
+    worker: Optional[Worker] = None
 
     def __init__(
         self,
@@ -118,6 +125,7 @@ def start_task_on_dataset_split(
                 pure=False,
                 dependencies=deps,
                 key=f"{self.task_id}_{uuid.uuid4()}",  # Unique identifier
+                workers=self.worker,
             )
             # Tell that this future is used on which indices.
             self.future.indices = self.get_caching_indices()
@@ -148,7 +156,11 @@ def start_task(self, client: Client, custom_query: Dict[str, Any]) -> "DaskModul
         log.info(f"Starting custom query {self.name}")
         # pure=false to be sure that everything is rerun.
         self.future = client.submit(
-            self.compute, custom_query, key=self.custom_query_task_id(custom_query), pure=False
+            self.compute,
+            custom_query,
+            key=self.custom_query_task_id(custom_query),
+            pure=False,
+            workers=self.worker,
         )
         # Tell that this future is for custom use only.
         self.future.is_custom = True
diff --git a/azimuth/modules/base_classes/indexable_module.py b/azimuth/modules/base_classes/indexable_module.py
index 4e08bdbe..4ae84842 100644
--- a/azimuth/modules/base_classes/indexable_module.py
+++ b/azimuth/modules/base_classes/indexable_module.py
@@ -12,6 +12,7 @@
 from azimuth.config import ModelContractConfig
 from azimuth.dataset_split_manager import DatasetSplitManager
 from azimuth.modules.base_classes import ConfigScope, Module
+from azimuth.modules.base_classes.dask_module import Worker
 from azimuth.types import (
     DatasetColumn,
     DatasetSplitName,
@@ -76,6 +77,7 @@ def save_result(self, res: List[ModuleResponse], dm: DatasetSplitManager):
 class ModelContractModule(DatasetResultModule[ModelContractConfig], abc.ABC):
     required_mod_options: Set[str] = {"pipeline_index", "model_contract_method_name"}
     optional_mod_options: Set[str] = DatasetResultModule.optional_mod_options | {"threshold"}
+    worker = Worker.model
 
     def compute(self, batch: Dataset) -> List[ModuleResponse]:
         my_func = self.route_request(assert_not_none(self.model_contract_method_name))
diff --git a/azimuth/modules/base_classes/module.py b/azimuth/modules/base_classes/module.py
index 753f608f..0d041247 100644
--- a/azimuth/modules/base_classes/module.py
+++ b/azimuth/modules/base_classes/module.py
@@ -10,6 +10,7 @@
 from azimuth.config import ModelContractConfig, PipelineDefinition
 from azimuth.dataset_split_manager import DatasetSplitManager, PredictionTableKey
 from azimuth.modules.base_classes import ArtifactManager, ConfigScope, DaskModule
+from azimuth.modules.base_classes.dask_module import Worker
 from azimuth.types import DatasetColumn, DatasetSplitName, ModuleOptions, ModuleResponse
 from azimuth.types.general.module_arguments import ModuleEffectiveArguments
 from azimuth.utils.conversion import md5_hash
@@ -164,6 +165,8 @@ def get_model(self):
         Raises:
             ValueError if no valid pipeline exists.
         """
+        if self.worker != Worker.model:
+            raise RuntimeError("This module cannot load the model. Modify self.worker.")
         _ = self.get_pipeline_definition()  # Validate current pipeline exists
         return self.artifact_manager.get_model(self.config, self.mod_options.pipeline_index)
 
diff --git a/azimuth/modules/dataset_analysis/similarity_analysis.py b/azimuth/modules/dataset_analysis/similarity_analysis.py
index 3319164c..759a70bd 100644
--- a/azimuth/modules/dataset_analysis/similarity_analysis.py
+++ b/azimuth/modules/dataset_analysis/similarity_analysis.py
@@ -16,6 +16,7 @@
 from azimuth.config import SimilarityConfig, SimilarityOptions
 from azimuth.dataset_split_manager import FEATURE_FAISS, DatasetSplitManager
 from azimuth.modules.base_classes import DatasetResultModule, IndexableModule
+from azimuth.modules.base_classes.dask_module import Worker
 from azimuth.modules.task_execution import get_task_result
 from azimuth.types import Array, DatasetColumn, DatasetSplitName, ModuleOptions
 from azimuth.types.similarity_analysis import FAISSResponse
@@ -28,6 +29,8 @@
 class FAISSModule(IndexableModule[SimilarityConfig]):
     """Compute the FAISS features for a dataset split."""
 
+    worker = Worker.encoder
+
     def __init__(
         self,
         dataset_split_name: DatasetSplitName,
@@ -45,6 +48,8 @@ def get_encoder_name_or_path(self):
         return model_name_or_path
 
     def get_encoder(self):
+        if self.worker != Worker.encoder:
+            raise RuntimeError("This module cannot load the encoder. Modify self.worker.")
         if self.encoder is None:
             with FileLock(os.path.join(self.cache_dir, "st.lock")):
                 self.encoder = SentenceTransformer(self.get_encoder_name_or_path())
diff --git a/azimuth/modules/perturbation_testing/perturbation_testing.py b/azimuth/modules/perturbation_testing/perturbation_testing.py
index e189aa5b..7f313841 100644
--- a/azimuth/modules/perturbation_testing/perturbation_testing.py
+++ b/azimuth/modules/perturbation_testing/perturbation_testing.py
@@ -10,6 +10,7 @@
 from azimuth.config import PerturbationTestingConfig
 from azimuth.dataset_split_manager import DatasetSplitManager
 from azimuth.modules.base_classes import DatasetResultModule
+from azimuth.modules.base_classes.dask_module import Worker
 from azimuth.modules.model_contract_task_mapping import model_contract_task_mapping
 from azimuth.types import (
     DatasetColumn,
@@ -58,6 +59,9 @@ class PerturbationTestingModule(DatasetResultModule[PerturbationTestingConfig]):
     """
 
     required_mod_options = {"pipeline_index"}
+    # This module doesn't call self.get_model() but requires the model (predict_task.compute(batch))
+    # TODO Find a more robust way to determine when modules require models.
+    worker = Worker.model
 
     def __init__(
         self,
diff --git a/azimuth/modules/validation/validation.py b/azimuth/modules/validation/validation.py
index 1b68cee5..94dd8d5d 100644
--- a/azimuth/modules/validation/validation.py
+++ b/azimuth/modules/validation/validation.py
@@ -9,6 +9,7 @@
 
 from azimuth.config import ModelContractConfig
 from azimuth.modules.base_classes import AggregationModule
+from azimuth.modules.base_classes.dask_module import Worker
 from azimuth.modules.model_contract_task_mapping import model_contract_task_mapping
 from azimuth.types import ModuleOptions, SupportedMethod, SupportedModelContract
 from azimuth.types.validation import ValidationResponse
@@ -36,6 +37,7 @@ def try_calling_function(self, fn, *args, **kwargs) -> Optional[Any]:
 
 class ValidationModule(AggregationModule[ModelContractConfig]):
     optional_mod_options = {"pipeline_index"}
+    worker = Worker.model
 
     def compute_on_dataset_split(self) -> List[ValidationResponse]:  # type: ignore
         cuda_available = torch.cuda.is_available()
diff --git a/azimuth/modules/word_analysis/top_words.py b/azimuth/modules/word_analysis/top_words.py
index ec12532c..787628e7 100644
--- a/azimuth/modules/word_analysis/top_words.py
+++ b/azimuth/modules/word_analysis/top_words.py
@@ -9,6 +9,7 @@
 
 from azimuth.config import TopWordsConfig
 from azimuth.modules.base_classes import FilterableModule
+from azimuth.modules.base_classes.dask_module import Worker
 from azimuth.modules.task_execution import get_task_result
 from azimuth.modules.word_analysis.tokens_to_words import TokensToWordsModule
 from azimuth.types import ModuleOptions
@@ -33,6 +34,7 @@ class TopWordsModule(FilterableModule[TopWordsConfig]):
         "th_importance",
         "force_no_saliency",
     }
+    worker = Worker.model
 
     @staticmethod
     def count_words(list_of_words: List[str], top_x: int) -> List[TopWordsResult]:
diff --git a/poetry.lock b/poetry.lock
index e2d43687..8eb91b1a 100644
--- a/poetry.lock
+++ b/poetry.lock
@@ -2383,6 +2383,21 @@ files = [
     {file = "mccabe-0.6.1.tar.gz", hash = "sha256:dd8d182285a0fe56bace7f45b5e7d1a6ebcbf524e8f3bd87eb0f125271b8831f"},
 ]
 
+[[package]]
+name = "memory-profiler"
+version = "0.61.0"
+description = "A module for monitoring memory usage of a python program"
+category = "dev"
+optional = false
+python-versions = ">=3.5"
+files = [
+    {file = "memory_profiler-0.61.0-py3-none-any.whl", hash = "sha256:400348e61031e3942ad4d4109d18753b2fb08c2f6fb8290671c5513a34182d84"},
+    {file = "memory_profiler-0.61.0.tar.gz", hash = "sha256:4e5b73d7864a1d1292fb76a03e82a3e78ef934d06828a698d9dada76da2067b0"},
+]
+
+[package.dependencies]
+psutil = "*"
+
 [[package]]
 name = "mergedeep"
 version = "1.3.4"
@@ -6271,4 +6286,4 @@ gpu = ["onnxruntime-gpu"]
 [metadata]
 lock-version = "2.0"
 python-versions = ">=3.9,<3.10"
-content-hash = "6035d2c61babc388b323282b79ec49ff779497a5bab6780f9e97c1af57e43738"
+content-hash = "e697b84ba2c9ab8175bd238a6a5d21b57fd586d3cb9140015024709b1b43a4ab"
diff --git a/pyproject.toml b/pyproject.toml
index e48bb4bd..3d610341 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -88,6 +88,7 @@ bokeh = "<3"
 # Documentation
 mkdocs = "^1.2.3"
 mkdocs-material = "^8.1.7"
+memory-profiler = "^0.61.0"
 
 [build-system]
 requires = ["poetry-core>=1.0.0"]
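Note on the dask scheduling used above (illustration only, not part of the patch): Client.submit accepts a workers= restriction, and a LocalCluster names its workers 0, 1, ... by default, so passing a Worker value (an IntEnum, hence a plain int) pins the future to that worker, while modules that leave worker = None stay unrestricted. A minimal sketch under those assumptions, with heavy_task as a hypothetical stand-in for a module's compute:

    from dask.distributed import Client, LocalCluster

    def heavy_task(x):
        # Stand-in for a compute() that loads a large artifact (model or encoder).
        return x * 2

    if __name__ == "__main__":
        # Two local workers, named 0 and 1 by default.
        cluster = LocalCluster(n_workers=2, threads_per_worker=1)
        client = Client(cluster)

        # Pin the task to worker 0, mirroring submit(..., workers=self.worker)
        # when self.worker == Worker.model; workers=None means no restriction.
        future = client.submit(heavy_task, 21, workers=0, pure=False)
        print(future.result())  # 42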