-
Notifications
You must be signed in to change notification settings - Fork 2
/
Copy pathmasked_linear.py
178 lines (161 loc) · 7.5 KB
/
masked_linear.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
import torch
import torch.nn as nn
import torch.nn.init as init
import torch.nn.functional as F
import numpy as np
class L0Mask(nn.Module):
    """Learnable hard-concrete (L0) gate tensor.

    Each element of ``mask_scores`` parameterises one gate:

    * In training mode a gate is sampled with logistic noise and temperature ``b``.
    * In eval mode the noise-free ``sigmoid(score)`` is used.

    In both modes the sigmoid output is stretched from (0, 1) to (l, r) and
    clamped back to [0, 1].

    Args:
        mask_dim: Shape of the gate tensor (passed to ``torch.zeros``).
        mask_p: Value the eval-mode gate takes right after initialisation.
            Must lie strictly inside the stretch interval ``(l, r)`` = (-0.1, 1.1).
    """

    def __init__(self, mask_dim, mask_p):
        super().__init__()
        self.mask_setting = 'mask'
        self.mask_scores = nn.Parameter(torch.zeros(mask_dim))
        self.mask_p = mask_p
        # l, r: stretch interval applied to the sigmoid output; b: temperature.
        self.l, self.r, self.b = -0.1, 1.1, 2 / 3
        self.init_weights()

    def init_weights(self):
        """Fill every score so the eval-mode gate initially equals ``mask_p``.

        The score is the logit of ``p = (mask_p - l) / (r - l)``.
        ``sigmoid(score) * (r - l) + l`` then recovers ``mask_p``.

        Raises:
            ValueError: if ``mask_p`` is not strictly inside ``(l, r)``.
                The logit would then be infinite or NaN, or the division
                by ``1 - p`` would fail.
        """
        p = (self.mask_p - self.l) / (self.r - self.l)
        # Check p itself (not just mask_p) so float rounding at the bounds is caught too.
        if not (0.0 < p < 1.0):
            raise ValueError(
                "mask_p must lie strictly between {} and {}, got {}".format(
                    self.l, self.r, self.mask_p))
        init.constant_(self.mask_scores, val=np.log(p / (1 - p)))
        # init.normal_(self.mask_scores, mean=0, std=0.01)

    def set_temperature(self, temp):
        """Set the temperature ``b`` used when sampling gates in training mode."""
        self.b = temp

    def produce_mask(self):
        """Return a gate tensor shaped like ``mask_scores`` with values in [0, 1].

        Training mode draws fresh uniform noise on every call.
        Eval mode is deterministic.
        """
        if self.training:
            # Clamp keeps log(u) and log(1 - u) finite.
            u = torch.zeros_like(self.mask_scores).uniform_().clamp(0.0001, 0.9999)
            s = torch.sigmoid((u.log() - (1 - u).log() + self.mask_scores) / self.b)
        else:
            s = torch.sigmoid(self.mask_scores)
        # Stretch to (l, r), then clamp so gates can be exactly 0 or exactly 1.
        s_bar = s * (self.r - self.l) + self.l
        mask = s_bar.clamp(min=0.0, max=1.0)
        return mask

    def regularizer(self):
        """Return the mean over gates of ``sigmoid(score - b * log(-l / r))``.

        This is the per-gate probability of being non-zero under the
        hard-concrete distribution (Louizos et al., 2018).
        """
        return torch.sum(torch.sigmoid(self.mask_scores - self.b * np.log(-self.l / self.r))) / self.mask_scores.numel()
class MaskedLinear(nn.Linear):
    """``nn.Linear`` whose weight is multiplied by a learnable ``L0Mask`` gate.

    One gate is shared by ``out_w_per_mask * in_w_per_mask`` weights. ``forward``
    views the weight as ``(out_w_per_mask, out_features // out_w_per_mask,
    in_w_per_mask, in_features // in_w_per_mask)``. It broadcasts a gate of shape
    ``(1, out_features // out_w_per_mask, 1, in_features // in_w_per_mask)``
    against that view. As a result, gate ``[0, j, 0, k]`` scales:

    * weight rows ``j + a * (out_features // out_w_per_mask)``, and
    * weight columns ``k + c * (in_features // in_w_per_mask)``.

    Each group is therefore strided through the weight, not a contiguous block.
    NOTE(review): if contiguous groups are intended (e.g. one attention head per
    gate), this layout would not match; confirm against callers.

    Args:
        in_features, out_features, bias: As for ``nn.Linear``.
        mask_p: Forwarded to ``L0Mask`` (controls gate initialisation).
        out_w_per_mask: Output rows per gate; must divide ``out_features``.
        in_w_per_mask: Input columns per gate; must divide ``in_features``.
        num_heads: Stored on the instance; not read by this class.
    """

    def __init__(self, in_features: int, out_features: int, bias: bool = True,
                 mask_p: float = 0.9, out_w_per_mask=1, in_w_per_mask=1, num_heads=12):
        super().__init__(in_features=in_features, out_features=out_features, bias=bias)
        self.num_heads = num_heads
        self.out_w_per_mask = out_w_per_mask
        self.in_w_per_mask = in_w_per_mask
        assert out_features % out_w_per_mask == 0, "{} % {} not 0".format(out_features, out_w_per_mask)
        assert in_features % in_w_per_mask == 0, "{} % {} not 0".format(in_features, in_w_per_mask)
        # Singleton axes 0 and 2 broadcast over the repeated copies in forward()'s 4-D weight view.
        mask_dim = (1, out_features // out_w_per_mask, 1, in_features // in_w_per_mask)
        self.mask = L0Mask(mask_dim, mask_p)
        # Detached outputs accumulated by forward() while do_caching is True.
        self.cached_activation = None
        self.do_caching = False

    def produce_mask_reshaped(self):
        """Return the gate expanded to a full ``(out_features, in_features)`` mask.

        The layout matches the one ``forward`` applies.
        Each call invokes ``self.mask.produce_mask()`` afresh.
        """
        mask = self.mask.produce_mask()
        mask = mask.repeat(self.out_w_per_mask, 1, self.in_w_per_mask, 1)
        return mask.reshape(self.out_features, self.in_features)

    def produce_mask(self):
        """Return the compact gate tensor from the underlying ``L0Mask``."""
        mask = self.mask.produce_mask()
        return mask

    def forward(self, input: torch.Tensor):
        """Apply the linear map with the gated weight.

        If caching is active, the detached output is also appended to
        ``cached_activation``.
        """
        # "masked_weight = self.produce_mask_reshaped() * self.weight" is equivalent but slower.
        masked_weight = self.produce_mask() * self.weight.reshape(
            self.out_w_per_mask, self.out_features // self.out_w_per_mask,
            self.in_w_per_mask, self.in_features // self.in_w_per_mask)
        masked_weight = masked_weight.reshape(self.out_features, self.in_features)
        act = F.linear(input, masked_weight, self.bias)
        if self.do_caching:
            if self.cached_activation is None:
                self.cached_activation = act.detach()
            else:
                # Concatenates along dim 0, so every call must yield identical
                # trailing dimensions (e.g. sub-batches padded to one max length).
                self.cached_activation = torch.cat(
                    (self.cached_activation, act.detach()), dim=0)
        return act

    def activate_caching(self, caching=True):
        """Clear any cached activations and turn caching on (or off if ``caching`` is False)."""
        self.cached_activation = None
        self.do_caching = caching

    @classmethod
    def from_layer(cls, layer, out_w_per_mask, in_w_per_mask, mask_p):
        """Build a ``MaskedLinear`` that shares ``layer``'s weight and bias parameters.

        Args:
            layer: A plain ``nn.Linear``; subclasses are rejected.
            out_w_per_mask, in_w_per_mask, mask_p: As for the constructor.

        Returns:
            A new module whose gate lives on the same device as ``layer.weight``.
            The gate's dtype is left unchanged.
        """
        # Exact-type check (not isinstance) on purpose: subclasses are rejected.
        assert type(layer) is nn.Linear
        res = cls(mask_p=mask_p, in_features=layer.in_features, out_features=layer.out_features,
                  bias=layer.bias is not None, out_w_per_mask=out_w_per_mask, in_w_per_mask=in_w_per_mask)
        res.weight = layer.weight
        res.bias = layer.bias
        # The gate is created on the CPU; put it next to the shared weight so
        # forward() does not mix devices. Module.to() moves parameters in place.
        res.mask.to(layer.weight.device)
        return res
#
# class BinaryL0Mask(nn.Module):
# def __init__(self, mask_dim, mask_p):
# super().__init__()
# self.mask_setting = 'mask'
# self.mask_scores = nn.Parameter(torch.zeros(mask_dim))
# self.mask_p = mask_p
# self.l, self.r, self.b = -0.1, 1.1, 2 / 3
# self.init_weights()
#
# def init_weights(self):
# p = (self.mask_p - self.l) / (self.r - self.l)
# init.constant_(self.mask_scores, val=np.log(p / (1 - p)))
# # init.normal_(self.mask_scores, mean=0, std=0.01)
#
# def set_temperature(self, temp):
# self.b = temp
#
# def produce_mask(self):
# if self.training:
# u = torch.zeros_like(self.mask_scores).uniform_().clamp(0.0001, 0.9999)
# s = torch.sigmoid((u.log() - (1 - u).log() + self.mask_scores) / self.b)
# else:
# s = torch.sigmoid(self.mask_scores)
# s_bar = s * (self.r - self.l) + self.l
# mask = s_bar.clamp(min=0.0, max=1.0)
# mask = torch.round(mask)
# return mask
#
# def regularizer(self):
# return torch.sum(torch.sigmoid(self.mask_scores - self.b * np.log(-self.l / self.r))) / self.mask_scores.numel()
#
#
# class BinaryMaskedLinear(nn.Linear):
# def __init__(self, in_features: int, out_features: int, bias: bool = True,
# mask_p: float = 0.9, out_w_per_mask=1, in_w_per_mask=1, num_heads=12):
# super().__init__(in_features=in_features, out_features=out_features, bias=bias)
# self.num_heads = num_heads
# self.out_w_per_mask = out_w_per_mask
# self.in_w_per_mask = in_w_per_mask
#
# assert out_features % out_w_per_mask == 0, "{} % {} not 0".format(out_features, out_w_per_mask)
# assert in_features % in_w_per_mask == 0, "{} % {} not 0".format(in_features, in_w_per_mask)
# mask_dim = (1, out_features // out_w_per_mask, 1, in_features // in_w_per_mask)
# self.mask = BinaryL0Mask(mask_dim, mask_p)
#
# self.cached_activation = None
# self.do_caching = False
#
# def produce_mask_reshaped(self):
# mask = self.mask.produce_mask()
# mask = mask.repeat(self.out_w_per_mask, 1, self.in_w_per_mask, 1)
# return mask.reshape(self.out_features, self.in_features)
#
# def produce_mask(self):
# mask = self.mask.produce_mask()
# return mask
#
# def forward(self, input: torch.tensor):
# # "masked_weight = self.produce_mask_reshaped() * self.weight" is equivalent but slower.
# masked_weight = self.produce_mask() * self.weight.reshape(
# self.out_w_per_mask, self.out_features // self.out_w_per_mask,
# self.in_w_per_mask, self.in_features // self.in_w_per_mask)
# masked_weight = masked_weight.reshape(self.out_features, self.in_features)
#
# act = F.linear(input, masked_weight, self.bias)
# if self.do_caching:
# if self.cached_activation is None:
# self.cached_activation = act.detach()
# else: # only works if subbatched, since maxlen must be constant
# self.cached_activation = torch.cat((
# self.cached_activation, act.detach()), dim=0)
# return act
#
# def activate_caching(self, caching=True):
# self.cached_activation = None
# self.do_caching = caching
#
# @classmethod
# def from_layer(cls, layer, out_w_per_mask, in_w_per_mask, mask_p):
# assert type(layer) == nn.modules.linear.Linear
# res = cls(mask_p=mask_p, in_features=layer.in_features, out_features=layer.out_features,
# bias=layer.bias is not None, out_w_per_mask=out_w_per_mask, in_w_per_mask=in_w_per_mask)
# res.weight = layer.weight
# res.bias = layer.bias
# return res # make sure to call cuda