-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathutils.py
117 lines (83 loc) · 3.04 KB
/
utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
import os
import pickle
import random
from collections import defaultdict
import librosa
import numpy as np
import tensorboard_logger
import torch
import config
def save_pickle(data, filename, base_directory, verbose=False):
    """Pickle ``data`` into ``base_directory/filename``.

    The object is written with ``pickle.HIGHEST_PROTOCOL``; any existing file
    at that path is overwritten.

    Args:
        data: any picklable object.
        filename: name of the file inside ``base_directory``.
        base_directory: directory the file is written to.
        verbose: when true, print the file name before writing.
    """
    if verbose:
        print(f'Saving {filename}')
    target_path = os.path.join(base_directory, filename)
    with open(target_path, 'wb') as handle:
        pickle.dump(data, handle, protocol=pickle.HIGHEST_PROTOCOL)
def load_pickle(filename, base_directory):
    """Unpickle and return the object stored at ``base_directory/filename``.

    Warning: ``pickle.load`` can execute arbitrary code while deserializing,
    so only call this on files this project wrote itself.
    """
    source_path = os.path.join(base_directory, filename)
    with open(source_path, 'rb') as handle:
        loaded = pickle.load(handle)
    return loaded
def save_predictions(data, filename):
    """Pickle ``data`` as ``filename`` inside ``config.PREDICTIONS_PATH``."""
    predictions_dir = config.PREDICTIONS_PATH
    save_pickle(data, filename, base_directory=predictions_dir)
def load_predictions(filename):
    """Return the object pickled as ``filename`` inside ``config.PREDICTIONS_PATH``."""
    predictions_dir = config.PREDICTIONS_PATH
    return load_pickle(filename, base_directory=predictions_dir)
def make_path_func(base_dir):
    """Build a one-argument callable that joins a file name onto ``base_dir``.

    Example: ``make_path_func('data')('a.txt')`` gives
    ``os.path.join('data', 'a.txt')``.
    """
    return lambda filename: os.path.join(base_dir, filename)
def save_checkpoint(state, filename, verbose=False):
    """Write ``state`` with ``torch.save`` under ``config.SAVED_MODELS_PATH``.

    Args:
        state: object to serialize (anything ``torch.save`` accepts).
        filename: name of the file inside ``config.SAVED_MODELS_PATH``.
        verbose: when true, print the file name before saving.

    Returns:
        The full path the checkpoint was written to.
    """
    if verbose:
        print(f'Saving {filename}')
    checkpoint_path = os.path.join(config.SAVED_MODELS_PATH, filename)
    torch.save(state, checkpoint_path)
    return checkpoint_path
def load_checkpoint(filename, verbose=True, map_location=None):
    """Load a checkpoint stored under ``config.SAVED_MODELS_PATH``.

    Args:
        filename: name of the file inside ``config.SAVED_MODELS_PATH``.
        verbose: when true (the default), print the full path before loading.
        map_location: forwarded unchanged to ``torch.load``. For example,
            ``'cpu'`` lets a checkpoint whose tensors were saved on a GPU be
            loaded on a CPU-only machine. ``None`` (the default) keeps the
            previous behavior.

    Returns:
        Whatever object ``torch.load`` deserializes from the file.
    """
    path = os.path.join(config.SAVED_MODELS_PATH, filename)
    if verbose:
        print(f'Loading {path}')
    # NOTE: torch.load unpickles the file; only load checkpoints you trust.
    return torch.load(path, map_location=map_location)
class MetricMonitor:
    """Accumulate running (weighted) averages of named metrics.

    Each metric keeps a running ``sum``, a sample ``count`` and the current
    ``avg`` (``sum / count``). Metrics are reported in the order they were
    first updated.
    """

    def __init__(self, batch_size=None):
        # Default sample count used by ``update`` when ``n`` is not given.
        self.batch_size = batch_size
        self.reset()

    def reset(self):
        """Forget every metric accumulated so far."""
        self.metrics = defaultdict(lambda: {'sum': 0, 'count': 0, 'avg': 0})

    def update(self, metric_name, value, n=None, multiply_by_n=True):
        """Add an observation covering ``n`` samples to ``metric_name``.

        Args:
            metric_name: key under which the metric is tracked.
            value: the observed value. With ``multiply_by_n=True`` it is
                weighted by ``n`` before being added to the running sum.
                Otherwise it is added as-is (i.e. it is already a total over
                the ``n`` samples).
            n: number of samples the observation covers; defaults to
                ``self.batch_size``.
            multiply_by_n: see ``value``.

        Raises:
            ValueError: if ``n`` is omitted and ``self.batch_size`` is None.
        """
        if n is None:
            n = self.batch_size
        if n is None:
            raise ValueError(
                'n must be given when MetricMonitor has no batch_size')
        metric = self.metrics[metric_name]
        metric['sum'] += value * n if multiply_by_n else value
        metric['count'] += n
        # A zero-sample update on a fresh metric would divide by zero; keep
        # the previous average in that case.
        if metric['count']:
            metric['avg'] = metric['sum'] / metric['count']

    def get_avg(self, metric_name):
        """Return the current average of ``metric_name`` (0 if never updated).

        ``dict.get`` is used instead of indexing so that querying an unknown
        metric does not create a phantom defaultdict entry. Such an entry
        would otherwise show up in ``get_metric_values`` and ``__str__``.
        """
        metric = self.metrics.get(metric_name)
        return 0 if metric is None else metric['avg']

    def get_metric_values(self):
        """Return ``[(metric_name, avg), ...]`` in first-update order."""
        return [
            (metric, values['avg']) for metric, values in self.metrics.items()
        ]

    def __str__(self):
        return ' | '.join(
            f'{metric_name} {metric["avg"]:.6f}'
            for metric_name, metric in self.metrics.items()
        )
class TensorboardClient:
    """Small wrapper around the module-level ``tensorboard_logger`` API.

    Constructing an instance calls ``tensorboard_logger.configure`` with the
    directory ``config.TENSORBOARD_LOGS_DIR/<experiment_name>``.
    """

    def __init__(self, experiment_name):
        experiment_log_dir = os.path.join(
            config.TENSORBOARD_LOGS_DIR, experiment_name)
        tensorboard_logger.configure(experiment_log_dir)

    def log_value(self, mode, key, value, step):
        """Log ``value`` at ``step`` under the tag ``"<mode>/<key>"``."""
        tensorboard_logger.log_value(f'{mode}/{key}', value, step)
def set_seed(seed):
    """Seed every random generator this module uses with the same ``seed``.

    Covers Python's ``random``, NumPy's global RNG, ``torch.manual_seed``
    and ``torch.cuda.manual_seed``.
    """
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed,
    )
    for seeder in seeders:
        seeder(seed)
def calculate_accuracy(outputs, targets):
    """Return the number of correct top-1 predictions in a batch.

    Despite the name, this returns a *count* of correct rows, not a ratio.
    Normalizing by the batch size is left to the caller.

    Args:
        outputs: 2-D score tensor; the predicted class of each row is its
            arg-max along dim 1.
        targets: class indices, one per row of ``outputs``.

    Returns:
        float: how many rows' top-scoring class equals the target.
    """
    _, predictions = outputs.topk(1, 1, True, True)
    # (rows, 1) -> (1, rows) so it lines up with targets viewed as one row.
    predictions = predictions.t()
    correct = predictions.eq(targets.view(1, -1).expand_as(predictions))
    correct_k = correct[0].view(-1).float().sum(0)
    # ``correct_k`` is a 0-dim tensor: indexing it with ``[0]`` (as the old
    # ``.data.cpu()[0]`` did) raises IndexError on PyTorch >= 0.4, so
    # extract the Python number with ``item()`` instead.
    return correct_k.item()
def load_wav(filepath):
    """Load an audio file with librosa at ``config.AUDIO_SAMPLING_RATE``.

    Only the sample array is returned; the sampling rate that
    ``librosa.core.load`` also returns is discarded.
    """
    samples, _ = librosa.core.load(filepath, sr=config.AUDIO_SAMPLING_RATE)
    return samples