1
1
# Copyright ServiceNow, Inc. 2021 – 2022
2
2
# This source code is licensed under the Apache 2.0 license found in the LICENSE file
3
3
# in the root directory of this source tree.
4
- from multiprocessing import Lock
5
- from typing import Callable , Dict , Optional
4
+ from collections import defaultdict
5
+ from typing import Callable , Dict
6
6
7
+ import structlog
7
8
from datasets import DatasetDict
8
9
9
10
from azimuth .config import AzimuthConfig
18
19
Hash = int
19
20
20
21
22
+ log = structlog .get_logger ()
23
+
24
+
25
class Singleton:
    """
    A non-thread-safe helper class to ease implementing singletons.
    This should be used as a decorator -- not a metaclass -- to the
    class that should be a singleton.

    The decorated class is instantiated without arguments. To get the
    singleton instance, use the `instance` method. Trying to use `__call__`
    will result in a `TypeError` being raised.

    Args:
        decorated: Decorated class
    """

    def __init__(self, decorated):
        self._decorated = decorated

    def instance(self):
        """
        Returns the singleton instance. Upon its first call, it creates a
        new instance of the decorated class and calls its `__init__` method.
        On all subsequent calls, the already created instance is returned.

        Returns:
            Instance of the decorated class
        """
        # Build the instance outside of any `except` block so that an error
        # raised by the decorated constructor is not reported as happening
        # "during handling of" an unrelated AttributeError.
        if not hasattr(self, "_instance"):
            self._instance = self._decorated()
        return self._instance

    def __call__(self, *args, **kwargs):
        # Accept any arguments so that the explanatory message below is always
        # the one raised, rather than an arity error about `__call__` itself.
        raise TypeError("Singletons must be accessed through `instance()`.")

    def __instancecheck__(self, inst):
        """Let `isinstance(obj, DecoratedName)` check against the wrapped class."""
        return isinstance(inst, self._decorated)

    def clear_instance(self):
        """For test purposes only"""
        # No-op when no instance has been created yet.
        self.__dict__.pop("_instance", None)
63
+
64
+
65
+ @Singleton
21
66
class ArtifactManager :
22
67
"""This class is a singleton which holds different artifacts.
23
68
24
69
Artifacts include dataset_split_managers, datasets and models for each config, so they don't
25
- need to be reloaded many times for a same module.
70
+ need to be reloaded many times for a same module. Inspired from
71
+ https://stackoverflow.com/questions/31875/is-there-a-simple-elegant-way-to-define-singletons.
26
72
"""
27
73
28
- instance : Optional ["ArtifactManager" ] = None
29
-
30
74
def __init__(self):
    """Set up the empty caches held by the manager.

    - ``dataset_dict_mapping``: config hash -> ``DatasetDict``.
    - ``dataset_split_managers_mapping``: config hash -> split name -> manager.
    - ``models_mapping``: hash -> pipeline index -> loaded model.
    - ``metrics``: hash -> loaded metric.

    The two nested mappings are ``defaultdict(dict)`` so that the inner dict is
    created on first access to a new hash.
    """
    self.dataset_dict_mapping: Dict[Hash, DatasetDict] = {}

    split_managers: Dict[Hash, Dict[DatasetSplitName, DatasetSplitManager]] = defaultdict(dict)
    self.dataset_split_managers_mapping = split_managers

    loaded_models: Dict[Hash, Dict[int, Callable]] = defaultdict(dict)
    self.models_mapping = loaded_models

    self.metrics = {}

    # The id makes it possible to tell apart managers in the debug logs.
    log.debug(f"Creating new Artifact Manager {id(self)}.")
46
83
47
84
def get_dataset_split_manager (
48
85
self , config : AzimuthConfig , name : DatasetSplitName
@@ -68,8 +105,6 @@ def get_dataset_split_manager(
68
105
f"Found { tuple (dataset_dict .keys ())} ."
69
106
)
70
107
project_hash : Hash = config .get_project_hash ()
71
- if project_hash not in self .dataset_split_managers_mapping :
72
- self .dataset_split_managers_mapping [project_hash ] = {}
73
108
if name not in self .dataset_split_managers_mapping [project_hash ]:
74
109
self .dataset_split_managers_mapping [project_hash ][name ] = DatasetSplitManager (
75
110
name = name ,
@@ -78,6 +113,7 @@ def get_dataset_split_manager(
78
113
initial_prediction_tags = ALL_PREDICTION_TAGS ,
79
114
dataset_split = dataset_dict [name ],
80
115
)
116
+ log .debug (f"New { name } DM in Artifact Manager { id (self )} " )
81
117
return self .dataset_split_managers_mapping [project_hash ][name ]
82
118
83
119
def get_dataset_dict (self , config ) -> DatasetDict :
@@ -106,25 +142,23 @@ def get_model(self, config: AzimuthConfig, pipeline_idx: int):
106
142
Returns:
107
143
Loaded model.
108
144
"""
109
-
110
- project_hash : Hash = config .get_project_hash ()
111
- if project_hash not in self .models_mapping :
112
- self .models_mapping [project_hash ] = {}
113
- if pipeline_idx not in self .models_mapping [project_hash ]:
145
+ model_contract_hash : Hash = config .get_model_contract_hash ()
146
+ if pipeline_idx not in self .models_mapping [model_contract_hash ]:
147
+ log .debug (f"Loading pipeline { pipeline_idx } ." )
114
148
pipelines = assert_not_none (config .pipelines )
115
- self .models_mapping [project_hash ][pipeline_idx ] = load_custom_object (
149
+ self .models_mapping [model_contract_hash ][pipeline_idx ] = load_custom_object (
116
150
assert_not_none (pipelines [pipeline_idx ].model ), azimuth_config = config
117
151
)
118
152
119
- return self .models_mapping [project_hash ][pipeline_idx ]
153
+ return self .models_mapping [model_contract_hash ][pipeline_idx ]
120
154
121
155
def get_metric(self, config, name: str, **kwargs):
    """Return the metric called ``name``, loading it only on first request.

    The cache key is the md5 hash of ``name`` together with ``kwargs``, so the
    same metric requested with different keyword arguments is loaded separately.

    Args:
        config: Config whose ``metrics[name]`` entry describes the metric to load.
        name: Name of the metric in ``config.metrics``.
        **kwargs: Extra arguments forwarded to ``load_custom_object``.

    Returns:
        The cached (or freshly loaded) metric object.
    """
    cache_key: Hash = md5_hash({"name": name, **kwargs})
    if cache_key in self.metrics:
        return self.metrics[cache_key]

    loaded_metric = load_custom_object(config.metrics[name], **kwargs)
    self.metrics[cache_key] = loaded_metric
    return loaded_metric
126
160
127
161
@classmethod
128
- def clear_cache (cls ) -> None :
129
- with Lock ():
130
- cls . instance = None
162
+ def instance (cls ):
163
+ # Implemented in decorator
164
+ raise NotImplementedError
0 commit comments