""" Utility functions """
import os
from shutil import copyfile
import torch
import torch.nn as nn


def construct(obj_map, extra_kwargs=None):
    """
    Constructs an object of the class and with the arguments specified in
    ``obj_map``, along with any additional keyword arguments passed in
    through ``extra_kwargs``.
    """
    classname = obj_map["class"]
    # Copy the args so that updating them does not mutate the caller's map.
    kwargs = dict(obj_map.get("args", {}))
    if extra_kwargs is not None:
        kwargs.update(extra_kwargs)
    c = get_from_module(classname)
    return c(**kwargs)
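
# Usage sketch for ``construct`` (the config below is hypothetical;
# ``torch.optim.SGD`` is just an illustrative target class):
#
#   obj_map = {"class": "torch.optim.SGD", "args": {"lr": 0.01}}
#   opt = construct(obj_map, extra_kwargs={"params": model.parameters()})
#
# which is equivalent to ``torch.optim.SGD(params=model.parameters(), lr=0.01)``.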


def get_from_module(attrname):
    """
    Returns the Python class/function with the fully qualified name
    ``attrname``.
    Typical usage pattern:
        m = get_from_module("this.module.MyClass")
        my_class = m(**kwargs)
    """
    parts = attrname.split('.')
    module = '.'.join(parts[:-1])
    # ``__import__`` returns the top-level package, so walk down the
    # remaining path components with ``getattr``.
    m = __import__(module)
    for comp in parts[1:]:
        m = getattr(m, comp)
    return m


def init_weights(mod):
    """
    Initializes parameters for PyTorch module ``mod``. This should only be
    called when ``mod`` has been newly instantiated and has not yet been
    trained.
    """
    for m in mod.modules():
        if isinstance(m, (nn.BatchNorm1d, nn.BatchNorm2d)):
            m.weight.data.fill_(1)
            if m.bias is not None:
                m.bias.data.zero_()
        elif isinstance(m, nn.Linear):
            m.weight.data.normal_(0, 0.01)
            if m.bias is not None:
                m.bias.data.zero_()
        elif isinstance(m, (nn.Conv1d, nn.Conv2d)):
            torch.nn.init.xavier_uniform_(
                m.weight, gain=torch.nn.init.calculate_gain('relu'))
            if m.bias is not None:
                m.bias.data.zero_()
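
# Usage sketch (hypothetical model; ``mod.modules()`` recurses into
# submodules, so any ``nn.Module`` container works):
#
#   net = nn.Sequential(nn.Linear(10, 5), nn.BatchNorm1d(5))
#   init_weights(net)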


def load_state(filename):
    """
    Loads the PyTorch state saved at ``filename``, mapping the tensors so
    that they can run on CPU or GPU, depending on what is available on the
    machine.
    """
    if torch.cuda.is_available():
        return torch.load(filename)
    return torch.load(filename, map_location=lambda storage, loc: storage)


def save_checkpoint(state_dict, save_dir, filename, is_best):
    """
    Serializes and saves dictionary ``state_dict`` to ``save_dir`` with name
    ``filename``. If ``is_best`` is ``True``, the dictionary is also saved
    under ``save_dir`` as "best.pth".
    """
    save_file = os.path.join(save_dir, filename)
    torch.save(state_dict, save_file)
    if is_best:
        copyfile(save_file, os.path.join(save_dir, 'best.pth'))
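
# Usage sketch (hypothetical training loop; the checkpoint keys below are
# illustrative conventions, not required by this module):
#
#   state = {"epoch": epoch,
#            "model": model.state_dict(),
#            "optimizer": optimizer.state_dict()}
#   save_checkpoint(state, "checkpoints", "epoch{}.pth".format(epoch),
#                   is_best=(val_loss < best_val_loss))
#
# and later, to restore the best model on CPU or GPU:
#
#   state = load_state(os.path.join("checkpoints", "best.pth"))
#   model.load_state_dict(state["model"])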


def try_cuda(x):
    """
    Sends PyTorch tensor or Variable ``x`` to the GPU, if one is available.
    """
    if torch.cuda.is_available():
        return x.cuda()
    return x
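
# Usage sketch: ``x = try_cuda(torch.zeros(4))`` moves the tensor to the
# default GPU when one is available and is a no-op otherwise.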


def write_vals(outfile, vals, names):
    """
    Writes each value ``vals[i]`` to a file whose name is formatted as
    ``outfile.format(names[i])``.
    """
    def write_value(val, outfilename):
        # Use a name other than ``outfile`` to avoid shadowing the format
        # string from the enclosing scope.
        with open(outfilename, 'a') as f:
            f.write("{}\n".format(val))
    for v, n in zip(vals, names):
        write_value(v, outfile.format(n))
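
# Usage sketch (hypothetical metric logging):
#
#   write_vals("logs/{}.txt", [0.91, 0.34], ["accuracy", "loss"])
#
# appends 0.91 to logs/accuracy.txt and 0.34 to logs/loss.txt.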