Commit 9ba2f9b
Refactor ArtifactManager and restart TaskManager (#529)
Authored on Apr 18, 2023

* Add logging to help with config update debugging
* Restart client to free memory
* Refactor artifact manager to real singleton and remove clear_cache
* Adapt based on comments

1 parent 6f98295, commit 9ba2f9b
19 files changed: +153 -146 lines
 

azimuth/app.py (+27 -3)

@@ -25,14 +25,13 @@

 from azimuth.config import AzimuthConfig, load_azimuth_config
 from azimuth.dataset_split_manager import DatasetSplitManager
-from azimuth.modules.base_classes import DaskModule
+from azimuth.modules.base_classes import ArtifactManager, DaskModule
 from azimuth.startup import startup_tasks
 from azimuth.task_manager import TaskManager
 from azimuth.types import DatasetSplitName, ModuleOptions, SupportedModule
 from azimuth.utils.cluster import default_cluster
 from azimuth.utils.conversion import JSONResponseIgnoreNan
 from azimuth.utils.logs import set_logger_config
-from azimuth.utils.project import load_dataset_split_managers_from_config
 from azimuth.utils.validation import assert_not_none

 _dataset_split_managers: Dict[DatasetSplitName, Optional[DatasetSplitManager]] = {}

@@ -296,6 +295,32 @@ def create_app() -> FastAPI:
     return app


+def load_dataset_split_managers_from_config(
+    azimuth_config: AzimuthConfig,
+) -> Dict[DatasetSplitName, Optional[DatasetSplitManager]]:
+    """
+    Load all dataset splits for the application.
+
+    Args:
+        azimuth_config: Azimuth Configuration.
+
+    Returns:
+        For all DatasetSplitName, None or a dataset_split manager.
+
+    """
+    artifact_manager = ArtifactManager.instance()
+    dataset = artifact_manager.get_dataset_dict(azimuth_config)
+
+    return {
+        dataset_split_name: None
+        if dataset_split_name not in dataset
+        else artifact_manager.get_dataset_split_manager(
+            azimuth_config, DatasetSplitName[dataset_split_name]
+        )
+        for dataset_split_name in [DatasetSplitName.eval, DatasetSplitName.train]
+    }
+
+
 def initialize_managers(azimuth_config: AzimuthConfig, cluster: SpecCluster):
     """Initialize DatasetSplitManagers and TaskManagers.

@@ -346,7 +371,6 @@ def run_validation_module(pipeline_index=None):
     else:
         for pipeline_index in range(len(config.pipelines)):
             run_validation_module(pipeline_index)
-    task_manager.clear_worker_cache()


 def run_startup_tasks(azimuth_config: AzimuthConfig, cluster: SpecCluster):
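
Since split loading now goes through the ArtifactManager singleton, repeated calls reuse the cached DatasetDict and managers instead of reloading them. The dict comprehension above also maps a split to None when it is absent from the dataset. A self-contained sketch of that pattern, where FakeManager and the plain string split names are illustrative stand-ins for the real Azimuth types:

    from typing import Dict, List, Optional

    class FakeManager:  # stand-in for DatasetSplitManager
        def __init__(self, name: str):
            self.name = name

    def load_managers(dataset: Dict[str, List[str]]) -> Dict[str, Optional[FakeManager]]:
        # Splits missing from the dataset map to None rather than raising.
        return {
            split: None if split not in dataset else FakeManager(split)
            for split in ["eval", "train"]
        }

    managers = load_managers({"train": ["hello", "world"]})
    assert managers["eval"] is None and managers["train"].name == "train"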

azimuth/config.py (+10)

@@ -358,6 +358,16 @@ def check_pipeline_names(cls, pipeline_definitions):
             raise ValueError(f"Duplicated pipeline names {pipeline_names}.")
         return pipeline_definitions

+    def get_model_contract_hash(self):
+        """Hash for fields related to model contract only (excluding fields from the parents)."""
+        return md5_hash(
+            self.dict(
+                include=ModelContractConfig.__fields__.keys()
+                - CommonFieldsConfig.__fields__.keys(),
+                by_alias=True,
+            )
+        )
+

 class MetricsConfig(ModelContractConfig):
     # Custom HuggingFace metrics
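
get_model_contract_hash hashes only the fields that ModelContractConfig adds on top of CommonFieldsConfig, which is what later lets get_model reuse a loaded pipeline across config edits that do not touch the model contract. A minimal sketch of the field-subtraction idea with toy pydantic v1 models; Base, Child, and md5_of are illustrative stand-ins, not Azimuth's classes or its md5_hash helper:

    import hashlib
    import json

    from pydantic import BaseModel  # pydantic v1 API, as used by this codebase

    class Base(BaseModel):
        name: str = "demo"

    class Child(Base):
        model_path: str = "/tmp/model"
        threshold: float = 0.5

    def md5_of(d: dict) -> str:
        return hashlib.md5(json.dumps(d, sort_keys=True).encode()).hexdigest()

    cfg = Child()
    # Fields defined on Child but not on Base: {"model_path", "threshold"}.
    child_only = Child.__fields__.keys() - Base.__fields__.keys()
    # Only those fields feed the hash, so changes to Base fields leave it stable.
    print(md5_of(cfg.dict(include=child_only)))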

azimuth/modules/base_classes/artifact_manager.py (+65 -31)

@@ -1,9 +1,10 @@
 # Copyright ServiceNow, Inc. 2021 – 2022
 # This source code is licensed under the Apache 2.0 license found in the LICENSE file
 # in the root directory of this source tree.
-from multiprocessing import Lock
-from typing import Callable, Dict, Optional
+from collections import defaultdict
+from typing import Callable, Dict

+import structlog
 from datasets import DatasetDict

 from azimuth.config import AzimuthConfig

@@ -18,31 +19,67 @@
 Hash = int


+log = structlog.get_logger()
+
+
+class Singleton:
+    """
+    A non-thread-safe helper class to ease implementing singletons.
+    This should be used as a decorator -- not a metaclass -- to the
+    class that should be a singleton.
+
+    To get the singleton instance, use the `instance` method. Trying
+    to use `__call__` will result in a `TypeError` being raised.
+
+    Args:
+        decorated: Decorated class
+    """
+
+    def __init__(self, decorated):
+        self._decorated = decorated
+
+    def instance(self):
+        """
+        Returns the singleton instance. Upon its first call, it creates a
+        new instance of the decorated class and calls its `__init__` method.
+        On all subsequent calls, the already created instance is returned.
+
+        Returns:
+            Instance of the decorated class
+        """
+        try:
+            return self._instance
+        except AttributeError:
+            self._instance = self._decorated()
+            return self._instance
+
+    def __call__(self):
+        raise TypeError("Singletons must be accessed through `instance()`.")
+
+    def clear_instance(self):
+        """For test purposes only"""
+        if hasattr(self, "_instance"):
+            delattr(self, "_instance")
+
+
+@Singleton
 class ArtifactManager:
     """This class is a singleton which holds different artifacts.

     Artifacts include dataset_split_managers, datasets and models for each config, so they don't
-    need to be reloaded many times for a same module.
+    need to be reloaded many times for a same module. Inspired from
+    https://stackoverflow.com/questions/31875/is-there-a-simple-elegant-way-to-define-singletons.
     """

-    instance: Optional["ArtifactManager"] = None
-
     def __init__(self):
         # The keys of the dict are a hash of the config.
         self.dataset_dict_mapping: Dict[Hash, DatasetDict] = {}
         self.dataset_split_managers_mapping: Dict[
             Hash, Dict[DatasetSplitName, DatasetSplitManager]
-        ] = {}
-        self.models_mapping: Dict[Hash, Dict[int, Callable]] = {}
-        self.tokenizer = None
+        ] = defaultdict(dict)
+        self.models_mapping: Dict[Hash, Dict[int, Callable]] = defaultdict(dict)
         self.metrics = {}
-
-    @classmethod
-    def get_instance(cls):
-        with Lock():
-            if cls.instance is None:
-                cls.instance = cls()
-            return cls.instance
+        log.debug(f"Creating new Artifact Manager {id(self)}.")

     def get_dataset_split_manager(
         self, config: AzimuthConfig, name: DatasetSplitName

@@ -68,8 +105,6 @@ def get_dataset_split_manager(
                 f"Found {tuple(dataset_dict.keys())}."
             )
         project_hash: Hash = config.get_project_hash()
-        if project_hash not in self.dataset_split_managers_mapping:
-            self.dataset_split_managers_mapping[project_hash] = {}
         if name not in self.dataset_split_managers_mapping[project_hash]:
             self.dataset_split_managers_mapping[project_hash][name] = DatasetSplitManager(
                 name=name,

@@ -78,6 +113,7 @@ def get_dataset_split_manager(
                 initial_prediction_tags=ALL_PREDICTION_TAGS,
                 dataset_split=dataset_dict[name],
             )
+            log.debug(f"New {name} DM in Artifact Manager {id(self)}")
         return self.dataset_split_managers_mapping[project_hash][name]

     def get_dataset_dict(self, config) -> DatasetDict:

@@ -106,25 +142,23 @@ def get_model(self, config: AzimuthConfig, pipeline_idx: int):
         Returns:
             Loaded model.
         """
-
-        project_hash: Hash = config.get_project_hash()
-        if project_hash not in self.models_mapping:
-            self.models_mapping[project_hash] = {}
-        if pipeline_idx not in self.models_mapping[project_hash]:
+        model_contract_hash: Hash = config.get_model_contract_hash()
+        if pipeline_idx not in self.models_mapping[model_contract_hash]:
+            log.debug(f"Loading pipeline {pipeline_idx}.")
             pipelines = assert_not_none(config.pipelines)
-            self.models_mapping[project_hash][pipeline_idx] = load_custom_object(
+            self.models_mapping[model_contract_hash][pipeline_idx] = load_custom_object(
                 assert_not_none(pipelines[pipeline_idx].model), azimuth_config=config
             )

-        return self.models_mapping[project_hash][pipeline_idx]
+        return self.models_mapping[model_contract_hash][pipeline_idx]

     def get_metric(self, config, name: str, **kwargs):
-        hash: Hash = md5_hash({"name": name, **kwargs})
-        if hash not in self.metrics:
-            self.metrics[hash] = load_custom_object(config.metrics[name], **kwargs)
-        return self.metrics[hash]
+        metric_hash: Hash = md5_hash({"name": name, **kwargs})
+        if metric_hash not in self.metrics:
+            self.metrics[metric_hash] = load_custom_object(config.metrics[name], **kwargs)
+        return self.metrics[metric_hash]

     @classmethod
-    def clear_cache(cls) -> None:
-        with Lock():
-            cls.instance = None
+    def instance(cls):
+        # Implemented in decorator
+        raise NotImplementedError
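
The Singleton decorator replaces the decorated class with a wrapper object, so the only way to reach an instance is through instance(). A quick demonstration of the resulting behaviour, assuming the Singleton helper above is in scope; the Counter class is illustrative only:

    @Singleton
    class Counter:
        def __init__(self):
            self.value = 0

    a = Counter.instance()
    b = Counter.instance()
    a.value += 1
    assert a is b and b.value == 1  # one shared instance per process

    try:
        Counter()  # direct construction is rejected by Singleton.__call__
    except TypeError as err:
        print(err)  # Singletons must be accessed through `instance()`.

    Counter.clear_instance()  # test hook: the next instance() builds a fresh object
    assert Counter.instance().value == 0

Note that the instance is still per process: each Dask worker builds its own ArtifactManager (hence the id(self) in the debug logs above), which is why this commit frees memory by restarting the client rather than by clearing per-worker caches.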

azimuth/modules/base_classes/module.py (+1 -4)

@@ -79,7 +79,7 @@ def get_indices(self, name: Optional[DatasetSplitName] = None) -> List[int]:
     def artifact_manager(self):
         """This is set as a property so the Module always have access to the current version of
         the ArtifactManager on the worker."""
-        return ArtifactManager.get_instance()
+        return ArtifactManager.instance()

     @property
     def available_dataset_splits(self) -> Set[DatasetSplitName]:

@@ -215,6 +215,3 @@ def get_pipeline_definition(self) -> PipelineDefinition:
         pipeline_index = assert_not_none(self.mod_options.pipeline_index)
         current_pipeline = pipelines[pipeline_index]
         return current_pipeline
-
-    def clear_cache(self):
-        self.artifact_manager.clear_cache()
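
Keeping artifact_manager a property means the singleton lookup happens on every access instead of once at construction, so a Module running on a freshly restarted worker binds to that worker's current ArtifactManager. A compact illustration of the difference, using toy classes rather than Azimuth code:

    class Registry:
        _instance = None

        @classmethod
        def instance(cls):
            if cls._instance is None:
                cls._instance = cls()
            return cls._instance

    class Eager:
        def __init__(self):
            self.registry = Registry.instance()  # frozen at construction time

    class Lazy:
        @property
        def registry(self):
            return Registry.instance()  # re-resolved on every access

    eager, lazy = Eager(), Lazy()
    Registry._instance = None  # simulate a restarted worker process
    assert lazy.registry is Registry.instance()  # picks up the new instance
    assert eager.registry is not Registry.instance()  # still holds the old one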

azimuth/routers/config.py (+3 -2)

@@ -78,25 +78,26 @@ def patch_config(
     )

     try:
+        log.info(f"Validating config change with {partial_config}.")
         new_config = update_config(old_config=config, partial_config=partial_config)
         if attribute_changed_in_config("large_dask_cluster", partial_config, config):
             cluster = default_cluster(partial_config["large_dask_cluster"])
         else:
             cluster = task_manager.cluster
         run_startup_tasks(new_config, cluster)
+        log.info(f"Config successfully updated with {partial_config}.")
     except Exception as e:
         log.error("Rollback config update due to error", exc_info=e)
         new_config = config
         initialize_managers(new_config, task_manager.cluster)
+        log.info("Config update cancelled.")
         if isinstance(e, (AzimuthValidationError, ValidationError)):
             raise HTTPException(HTTP_400_BAD_REQUEST, detail=str(e))
         else:
             raise HTTPException(
                 HTTP_500_INTERNAL_SERVER_ERROR, detail="Error when loading the new config."
             )

-    # Clear workers so that they load the correct config.
-    task_manager.clear_worker_cache()
     return new_config
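
The update flow is validate, apply, roll back: if run_startup_tasks fails, the managers are re-initialized with the old config before the HTTP error is raised, and no worker-cache clearing is needed afterwards. Reduced to a skeleton, with all names illustrative rather than the actual Azimuth helpers:

    def patch(current: dict, partial: dict, apply, reinitialize):
        try:
            candidate = {**current, **partial}  # stand-in for update_config(...)
            apply(candidate)                    # stand-in for run_startup_tasks(...)
            return candidate
        except Exception:
            reinitialize(current)  # roll back to the last known-good config
            raise                  # then surface the failure to the caller

    good = patch({"batch_size": 32}, {"batch_size": 64},
                 apply=lambda cfg: None, reinitialize=print)
    assert good["batch_size"] == 64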

azimuth/routers/export.py (+10 -4)

@@ -5,7 +5,7 @@
 import os
 import time
 from os.path import join as pjoin
-from typing import Dict, Generator, List, Optional, cast
+from typing import Dict, Generator, List, Optional

 import pandas as pd
 from fastapi import APIRouter, Depends, HTTPException

@@ -155,7 +155,10 @@ def get_export_perturbed_set(

     output = list(
         make_utterance_level_result(
-            dataset_split_manager, task_result, pipeline_index=pipeline_index_not_null
+            dataset_split_manager,
+            task_result,
+            pipeline_index=pipeline_index_not_null,
+            config=config,
         )
     )
     with open(path, "w") as f:

@@ -164,20 +167,23 @@


 def make_utterance_level_result(
-    dm: DatasetSplitManager, results: List[List[PerturbedUtteranceResult]], pipeline_index: int
+    dm: DatasetSplitManager,
+    results: List[List[PerturbedUtteranceResult]],
+    pipeline_index: int,
+    config: AzimuthConfig,
 ) -> Generator[Dict, None, None]:
     """Massage perturbation testing results for the frontend.

     Args:
         dm: Current DatasetSplitManager.
         results: Output of Perturbation Testing.
         pipeline_index: Index of the pipeline that made the results.
+        config: Azimuth config

     Returns:
         Generator that yield json-able object for the frontend.

     """
-    config = cast(AzimuthConfig, dm.config)
     for idx, (utterance, test_results) in enumerate(
         zip(
             dm.get_dataset_split(

azimuth/routers/utterances.py (-2)

@@ -229,7 +229,6 @@ def patch_utterances(
     utterances: List[UtterancePatch] = Body(...),
     config: AzimuthConfig = Depends(get_config),
     dataset_split_manager: DatasetSplitManager = Depends(get_dataset_split_manager),
-    task_manager: TaskManager = Depends(get_task_manager),
     ignore_not_found: bool = Query(False),
 ) -> List[UtterancePatch]:
     if ignore_not_found:

@@ -250,7 +249,6 @@ def patch_utterances(

     dataset_split_manager.add_tags(data_actions)

-    task_manager.clear_worker_cache()
     updated_tags = dataset_split_manager.get_tags(row_indices)

     return [

azimuth/startup.py (+10 -9)

@@ -153,8 +153,6 @@ def on_end(fut: Future, module: DaskModule, dm: DatasetSplitManager, task_manage
         # Task is done, save the result.
         if isinstance(module, DatasetResultModule):
             module.save_result(module.result(), dm)
-            # We only need to clear cache when the dataset is modified.
-            task_manager.clear_worker_cache()
     else:
         log.exception("Error in", module=module, fut=fut, exc_info=fut.exception())

@@ -257,12 +255,16 @@ def startup_tasks(

     mods = start_tasks_for_dms(config, dataset_split_managers, task_manager, start_up_tasks)

-    # Start a thread to monitor the status.
-    th = threading.Thread(
-        target=wait_for_startup, args=(mods, task_manager), name=START_UP_THREAD_NAME
-    )
-    th.setDaemon(True)
-    th.start()
+    startup_ready = all(m.done() for m in mods.values())
+    if startup_ready:
+        log.info("Loading the application from cache. It should be accessible now.")
+    else:
+        # Start a thread to monitor the status.
+        th = threading.Thread(
+            target=wait_for_startup, args=(mods, task_manager), name=START_UP_THREAD_NAME
+        )
+        th.setDaemon(True)
+        th.start()

     return mods

@@ -376,4 +378,3 @@ def log_progress():
     task_manager.restart()
     # After restarting, it is safe to unlock the task manager.
     task_manager.unlock()
-    log.info("Cluster restarted to free memory.")

azimuth/task_manager.py (+3 -6)

@@ -7,7 +7,7 @@
 from distributed import Client, SpecCluster

 from azimuth.config import AzimuthConfig
-from azimuth.modules.base_classes import ArtifactManager, DaskModule, ExpirableMixin
+from azimuth.modules.base_classes import DaskModule, ExpirableMixin
 from azimuth.modules.task_mapping import model_contract_methods, modules
 from azimuth.types import (
     DatasetSplitName,

@@ -66,7 +66,6 @@ def close(self):
                 mod.future.cancel()
             except Exception:
                 pass
-        self.clear_worker_cache()
         self.client.close()

     def register_task(self, name, cls):

@@ -213,10 +212,8 @@ def status(self):
             **self.get_all_tasks_status(task=None),
         }

-    def clear_worker_cache(self):
-        self.client.run(ArtifactManager.clear_cache)
-
     def restart(self):
-        # Clear futures to free memory.
         for task_name, module in self.current_tasks.items():
             module.future = None
+        self.client.restart()
+        log.info("Cluster restarted to free memory.")
