-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathconf_utils.py
executable file
·181 lines (151 loc) · 6.48 KB
/
conf_utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
"""Contains utility objects necessary to process configuration values and experiment setups in a sane matter
Uses plain text files or Mongodb v3 for actual storage"""
import datetime
import itertools
from functools import reduce
from operator import mul
from typing import Iterable, List, Union
from pymongo.collection import Collection
__encoding__ = "utf-8"
__author__ = 'Alex Pyattaev'
from lib.mongo_stuff import recursive_clean
class config_container(object):
    """Empty attribute namespace, meant as a container in global config modules (e.g. SLS.py).

    Instances start with no attributes; config modules attach whatever fields they need.
    """
def unwrap_inner(params):
    """Expand tuple-keyed parameter groups of *params* in place.

    A key that is a tuple of names (e.g. ``("a", "b")``) maps to a sequence of
    value sets. Within a value set, an entry that is a string or is not iterable
    counts as a single choice; any other iterable is a list of choices. Each
    value set is replaced by the Cartesian product of its entries, and the
    group's value becomes the concatenation of those combinations, each stored
    as a ``list`` (``unwrap_tuples`` later zips them against the key's names).

    The rewritten key is popped and re-inserted, so it moves to the end of the
    dict's iteration order. Keys that are not tuples are left untouched.
    """
    # Snapshot the items: the dict is modified while we walk it.
    for group_key, value_sets in list(params.items()):
        if isinstance(group_key, tuple):
            params.pop(group_key)
            combos = params[group_key] = []
            for value_set in value_sets:
                # Wrap scalars (and strings, which must not be split into characters)
                # so every entry is an iterable of choices for itertools.product.
                choices = [
                    (entry,) if isinstance(entry, str) or not isinstance(entry, Iterable) else entry
                    for entry in value_set
                ]
                combos.extend(list(combo) for combo in itertools.product(*choices))
def unwrap_tuples(params):
    """Split tuple-keyed entries of *params* into individual keys, in place.

    Each key that is a tuple of names is removed and its value is zipped
    against those names, so ``{("a", "b"): [1, 2]}`` becomes
    ``{"a": 1, "b": 2}``. New keys are appended to the dict's ordering;
    a name that already exists keeps its position but gets the new value.
    ``zip`` stops at the shorter side, so surplus names or values are dropped.
    """
    # Iterate over a snapshot because keys are removed and added as we go.
    for names, values in list(params.items()):
        if not isinstance(names, tuple):
            continue
        del params[names]
        params.update(zip(names, values))
class Experiment(object):
    """Abstraction for Experiment: a parameter sweep repeated over a list of random seeds.

    Auto-fills database fields:
    {"type":"EXPERIMENT", "tag":tag, "time":current time}

    Iterating over an Experiment yields one Trial per (parameter combination, seed) pair.
    """

    def __init__(self, params: dict, seeds: list, storage: Union[Collection, str], tag: str = "",
                 code_versions: Union[dict, None] = None):
        """
        :param params: Parameters for trial {key: array of values}. A value that is a string or is
            not iterable is held fixed; any other iterable is swept over. Tuple keys group several
            parameters that vary together (expanded by ``unwrap_inner``).
            NOTE: this dict is modified in place.
        :param seeds: Random trial integer seed list (used for Monte-Carlo analysis)
        :param storage: the database collection to store data, or a filename template containing
            exactly one ``{}`` field, which is filled with the 1-based trial index
        :param tag: Optional tag to locate the experiment in DB
        :param code_versions: the versions of relevant repos with code to run experiment; stored
            as-is in the experiment document. Defaults to None: a class object such as ``dict``
            cannot be encoded into a DB document.
        :raises TypeError: if storage is neither a database collection nor a str
        """
        self.params = params
        self.tag = tag
        self.seeds = seeds
        self.db_id = None
        self.storage = storage
        self.code_versions = code_versions
        # Expand tuple-keyed parameter groups; this mutates the caller's dict in place.
        unwrap_inner(self.params)
        if isinstance(storage, Collection):
            document = {"type": "EXPERIMENT", "tag": self.tag, "time": datetime.datetime.now(),
                        "time_completed": None, "code_versions": code_versions}
            res = storage.insert_one(document)
            self.db_id = res.inserted_id
        elif isinstance(storage, str):
            # NOTE(review): asserts are stripped under `python -O`; this check then disappears.
            assert storage.count("{}") == 1, "Should have exactly one format field for point index!"
        else:
            raise TypeError('Storage must be a database collection or a filename prefix')

    def mark_done(self, return_codes: List[int]) -> None:
        """Record the completion time and the given return codes on the experiment's DB document.

        File-based (string template) storage has no experiment document, so this is a no-op there
        instead of failing on ``str.update_one``.
        """
        if not isinstance(self.storage, Collection):
            return
        self.storage.update_one({'_id': self.db_id}, {'$set': {'time_completed': datetime.datetime.now(),
                                                               'return_codes': return_codes}})

    def cleanup(self) -> None:
        """
        Clean all junk possibly linked to this experiment. Useful if you want to abort.

        Delegates to ``recursive_clean`` with this experiment's DB id.
        NOTE(review): presumably only meaningful for Collection storage (db_id is None otherwise).
        """
        recursive_clean(self.storage, {'_id': self.db_id})

    def _make_trial(self, params: dict, seed, trial_idx: int) -> "Trial":
        """Create the Trial for one (params, seed) point, resolving where it is stored."""
        if isinstance(self.storage, Collection):
            return Trial(self, params, seed, self.storage)
        # File-based storage: substitute the trial index into the filename template.
        return Trial(self, params, seed, self.storage.format(trial_idx))

    def __iter__(self):
        """
        Experiment can and should be used as iterator.

        Every combination of the swept parameters is merged with the fixed ones and
        paired with each seed in turn. Trials are numbered from 1 (the index is used
        for text file export, e.g. Matlab).

        :return: the next Trial to be used
        :rtype: Trial
        """
        sweep_params = {k: v for k, v in self.params.items() if not isinstance(v, str) and isinstance(v, Iterable)}
        fix_params = {k: v for k, v in self.params.items() if k not in sweep_params}
        trial_idx = 0
        if sweep_params:
            sweep_keys, sweep_values = zip(*sweep_params.items())
            # Construct all possible combinations of parameter values
            for combo in itertools.product(*sweep_values):
                # Combine them back with their names, then add the fixed params
                p = dict(zip(sweep_keys, combo))
                p.update(fix_params)
                # Grouped (tuple-keyed) parameters become individual keys again
                unwrap_tuples(p)
                for s in self.seeds:
                    trial_idx += 1
                    print("Spawning trial #{}, params{} seed {}".format(trial_idx, p, s))
                    yield self._make_trial(p, s, trial_idx)
        else:
            for s in self.seeds:
                trial_idx += 1
                print("Spawning trial #{}, seed {}".format(trial_idx, s))
                yield self._make_trial(fix_params, s, trial_idx)

    def __len__(self):
        """Return the number of trials iteration is expected to produce.

        Multiplies the number of values of every parameter (1 for strings and for values
        without a length) by the number of seeds.
        NOTE(review): an iterable without ``len()`` (e.g. a generator) counts as 1 here,
        although ``__iter__`` sweeps over it.
        """
        def ll(v):
            if isinstance(v, str):
                return 1
            try:
                return len(v)
            except TypeError:
                return 1

        if self.params:
            return reduce(mul, [ll(v) for v in self.params.values()]) * len(self.seeds)
        else:
            return len(self.seeds)
class Trial(object):
    """Abstraction for Trial, i.e. a single (parameters, seed) point within an experiment set.

    With database storage, auto-fills the document:
    {"type":"TRIAL", "params":params, "seed":seed, "link":experiment.db_id}
    """

    def __init__(self, experiment: Experiment, params: dict, seed: int, storage: Union[Collection, str]):
        """
        :param experiment: the Experiment this trial belongs to; its ``db_id`` is stored as the link
        :param params: Parameters for trial {key: value}
        :param seed: Random trial seed (used for Monte-Carlo analysis)
        :param storage: the database collection to store data, or a filename (str)
        :raises TypeError: if storage is neither a database collection nor a str
        """
        self.experiment = experiment
        self.params = params
        self.seed = seed
        self.storage = storage
        self.db_id = None
        if isinstance(storage, Collection):
            # Persist the trial right away and keep the id the database assigned to it.
            self.db_id = storage.insert_one({"type": "TRIAL",
                                             "params": self.params,
                                             "seed": seed,
                                             "link": experiment.db_id}).inserted_id
        elif not isinstance(storage, str):
            raise TypeError('Storage must be a database collection or a filename prefix')

    @property
    def storage_path(self):
        """Location of this trial's results: the DB id as a string, or the filename itself."""
        return str(self.db_id) if isinstance(self.storage, Collection) else self.storage