forked from jonahar/lightning-systemic-attack
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathutils.py
258 lines (201 loc) · 8.05 KB
/
utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
import json
import logging
import os
import sqlite3
import sys
import time
from datetime import datetime
from functools import wraps
from logging import Logger
from typing import Any, Callable
import plyvel
from datatypes import Json
from paths import CACHES_DIR
def print_json(o: Json):
    """Pretty-print the given JSON-serializable object to stdout (4-space indent)."""
    serialized = json.dumps(o, indent=4)
    print(serialized)
def now() -> str:
    """
    Return the current local time as a string in the format YYYY-MM-DD_HH:MM.
    """
    timestamp_format = "%Y-%m-%d_%H:%M"
    return datetime.now().strftime(timestamp_format)
def timeit(logger: Logger, print_args: bool = False):
    """
    Decorator factory that logs entering and exiting the wrapped function,
    together with its total runtime.

    :param logger: logger used to emit the entry/exit messages (INFO level)
    :param print_args: if True, the function's positional and keyword
                       arguments are included in the entry message
    :return: a decorator that wraps the target function
    """

    def decorator(func):
        # the `wraps` decorator gives `wrapper` the attributes of func.
        # in particular, its name
        @wraps(func)
        def wrapper(*args, **kwargs):
            logger.info(
                f"Entering {func.__name__}"
                +
                (f" with args={args}, kwargs={kwargs}" if print_args else "")
            )
            # perf_counter is a monotonic, high-resolution clock; unlike
            # time.time() the measured duration cannot be skewed by system
            # clock adjustments (NTP, DST, manual changes)
            t0 = time.perf_counter()
            result = func(*args, **kwargs)
            t1 = time.perf_counter()
            logger.info(
                f"Exiting {func.__name__}. Total runtime: {round(t1 - t0, 3)} seconds"
            )
            return result

        return wrapper

    return decorator
def setup_logging(
    logger_name: str = None,
    console: bool = True,
    filename: str = None,
    fmt: str = None
) -> Logger:
    """
    Configure and return a logger with a console handler and/or a file handler.

    :param logger_name: name of the logger to configure; None means the root logger
    :param console: when True, attach a stdout handler at level INFO
    :param filename: when given, attach a file handler at level DEBUG writing there
    :param fmt: log-message format; when None a default format is used
    :return: the configured logger
    """
    if fmt is None:
        fmt = "%(asctime)s: %(levelname)s: %(module)s: %(funcName)s: %(message)s"
    formatter = logging.Formatter(fmt=fmt, datefmt='%Y-%m-%d %H:%M:%S')

    target_logger = logging.getLogger(logger_name)
    # the logger itself passes everything through; the handlers do the filtering
    target_logger.setLevel(logging.DEBUG)

    handlers = []
    if console:
        stream_handler = logging.StreamHandler(sys.stdout)
        stream_handler.setLevel(logging.INFO)
        handlers.append(stream_handler)
    if filename:
        file_handler = logging.FileHandler(filename)
        file_handler.setLevel(logging.DEBUG)
        handlers.append(file_handler)

    for handler in handlers:
        handler.setFormatter(formatter)
        target_logger.addHandler(handler)

    return target_logger
def get_db_str_key(*args, **kwargs) -> str:
    """
    Build a single string key out of arbitrary positional and keyword arguments.
    Keyword arguments are rendered in sorted-name order for determinism.

    Important: callers must make sure that different arguments produce
    different string representations.
    """
    parts = list(map(str, args))
    parts.extend(f"{name}={kwargs[name]}" for name in sorted(kwargs))
    return ",".join(parts)
def get_leveldb_cache_fullpath(func_name: str) -> str:
    """Return the default on-disk location of the LevelDB cache for `func_name`."""
    db_name = f"{func_name}_py_function_leveldb"
    return os.path.join(CACHES_DIR, db_name)
def leveldb_cache(
    value_to_str: Callable[[Any], str],
    str_to_value: Callable[[str], Any],
    key_to_str: Callable[..., str] = None,
    db_path: str = None,
):
    """
    This decorator caches results of function calls in a LevelDB on disk.
    The cache size is (currently) not configurable, and is unlimited.
    Each DB entry is a pair of input/output, representing function arguments and
    the function result for these arguments. Both represented as strings.
    The decision to store only strings in the DB was made to allow other
    applications (specifically, not python) to open the DB and to be able to easily
    parse and understand it.

    Args:
        value_to_str: a callable that takes results of the cached function and return
            their string representation
        str_to_value: a callable that takes string representation of some result of
            the cached function and return the value it represents
        key_to_str: a callable that takes any combination of arguments and returns
            a string representing this set of arguments. if None (default),
            a default conversion method will be used. In that case you should
            make sure that different arguments have different string representation,
            or they will be considered equal
        db_path: full path to the db file. if None (default) use a default one

    Usage examples:

        @leveldb_cache(value_to_str=str, str_to_value=float)
        def foo(arg1: str, arg2: float) -> float:
            ...

        import json

        @leveldb_cache(value_to_str=json.dumps, str_to_value=json.loads)
        def bar(arg1: str, arg2: str) -> List[str]:
            ...
    """
    if key_to_str is None:
        key_to_str = get_db_str_key

    def decorator(func):
        try:
            # TODO if we fail to open because the db is locked, go around it and
            # open it, just don't write in case new input/output arrive
            cache_fullpath = db_path if db_path else get_leveldb_cache_fullpath(func_name=func.__name__)
            db = plyvel.DB(cache_fullpath, create_if_missing=True)
        except plyvel.IOError as e:
            # best-effort: if the db can't be opened, fall back to the
            # un-cached function instead of failing
            print(
                f"WARNING: leveldb_cache: IOERROR occurred when trying to open leveldb "
                f"for function `{func.__name__}`. function will NOT be cached. "
                f"Error: {type(e)}: {str(e)}",
                file=sys.stderr,
            )
            return func

        @wraps(func)
        def wrapper(*args, **kwargs):
            db_key = key_to_str(*args, **kwargs).encode("utf-8")
            value: bytes = db.get(db_key)
            # must compare against None: db.get returns None on a miss, but an
            # empty bytes value (b"") is a legitimate cached result and the old
            # truthiness check treated it as a miss, re-running the function
            if value is not None:
                return str_to_value(value.decode("utf-8"))
            value = func(*args, **kwargs)
            db.put(db_key, value_to_str(value).encode("utf-8"))
            return value

        return wrapper

    return decorator
def get_sqlite_cache_fullpath(func_name: str) -> str:
    """Return the default on-disk location of the sqlite cache for `func_name`."""
    db_name = f"{func_name}_py_function_cache.sqlite"
    return os.path.join(CACHES_DIR, db_name)
def sqlite_cache(
    value_to_str: Callable[[Any], str],
    str_to_value: Callable[[str], Any],
    key_to_str: Callable[..., str] = None,
    db_path: str = None,
):
    """
    The sqlite-backed twin of leveldb_cache: caches function results in an
    sqlite database on disk (one table per decorated function).
    See the documentation of leveldb_cache for the parameter semantics.
    """
    key_fn = key_to_str if key_to_str is not None else get_db_str_key

    def decorator(func):
        table = func.__name__
        try:
            if db_path:
                cache_fullpath = db_path
            else:
                cache_fullpath = get_sqlite_cache_fullpath(func_name=table)
            conn = sqlite3.connect(cache_fullpath)
            cursor = conn.cursor()
            # one table per function; `input` is the serialized arguments,
            # `output` the serialized result
            cursor.execute(
                f"CREATE TABLE IF NOT EXISTS {func.__name__} "
                f"(input TEXT PRIMARY KEY, output TEXT);"
            )
        except sqlite3.Error as e:
            # best-effort: if the db can't be opened, return the function un-cached
            print(
                f"WARNING: sqlite_cache: ERROR occurred when trying to open sqlite db "
                f"for function `{func.__name__}`. function will NOT be cached. Error: {e}",
                file=sys.stderr,
            )
            return func

        @wraps(func)
        def wrapper(*args, **kwargs):
            cache_key = key_fn(*args, **kwargs)
            row = cursor.execute(
                f"select output from {func.__name__} where input=(?)",
                (cache_key,)
            ).fetchone()
            if row:
                # cache hit: deserialize the stored output
                return str_to_value(row[0])
            result = func(*args, **kwargs)
            cursor.execute(
                F"INSERT INTO {func.__name__} (input, output) values (?, ?)",
                (cache_key, value_to_str(result)),
            )
            conn.commit()
            return result

        return wrapper

    return decorator