@@ -54,6 +54,7 @@ def __init__(self,
                  use_ln: bool = False,
                  temporal_dim: Optional[int] = None,
                  use_act: bool = True,
+                 act_name: str = "prelu",
                  use_res: bool = True,
                  cond_dim: int = 0,
                  use_film_bn: bool = True,  # TODO(cm): check if this should be false
@@ -62,6 +63,7 @@ def __init__(self,
         self.use_ln = use_ln
         self.temporal_dim = temporal_dim
         self.use_act = use_act
+        self.act_name = act_name
         self.use_res = use_res
         self.cond_dim = cond_dim
         self.use_film_bn = use_film_bn
@@ -75,7 +77,7 @@ def __init__(self,
 
         self.act = None
         if use_act:
-            self.act = nn.PReLU(out_channels)
+            self.act = self.get_activation(act_name, out_channels)
 
         self.conv = Conv1dGeneral(in_channels,
                                   out_channels,
@@ -177,6 +179,70 @@ def forward(self, x: Tensor, cond: Optional[Tensor] = None) -> Tensor:
         x += x_res
         return x
 
+    @staticmethod
+    def get_activation(act_name: str, out_ch: Optional[int] = None) -> nn.Module:
+        """
+        Most of the code and experimental results in this method are from
+        https://github.com/csteinmetz1/ronn
+
+        Given an activation name string, returns the corresponding activation function.
+
+        Args:
+            act_name: Name of the activation function.
+            out_ch: Optional number of output channels. Only used for determining the
+                    number of parameters in the PReLU activation function.
+
+        Returns:
+            act: PyTorch activation function.
+
+        Experimental results for randomized overdrive neural networks
+        -------------------------------------------------------------
+        - ReLU: solid distortion
+        - LeakyReLU: somewhat veiled sound
+        - Tanh: insane levels of distortion with lots of aliasing (HF)
+        - Sigmoid: too gritty to be useful
+        - ELU: fading in and out
+        - RReLU: really interesting HF noise with a background sound
+        - SELU: rolled-off soft distortion sound
+        - GELU: roomy, not too interesting
+        - Softplus: heavily distorted signal but with a very rolled-off sound (nice)
+        - Softshrink: super distant sounding and somewhat roomy
+        """
+        act_name = act_name.lower()
+        if act_name == "relu":
+            act = nn.ReLU()
+        elif act_name == "leakyrelu":
+            act = nn.LeakyReLU()
+        elif act_name == "tanh":
+            act = nn.Tanh()
+        elif act_name == "sigmoid":
+            act = nn.Sigmoid()
+        elif act_name == "elu":
+            act = nn.ELU()
+        elif act_name == "rrelu":
+            act = nn.RReLU()
+        elif act_name == "selu":
+            act = nn.SELU()
+        elif act_name == "gelu":
+            act = nn.GELU()
+        elif act_name == "softplus":
+            act = nn.Softplus()
+        elif act_name == "softshrink":
+            act = nn.Softshrink()
+        elif act_name == "silu" or act_name == "swish":
+            act = nn.SiLU()
+        elif act_name == "prelu":
+            if out_ch is None:
+                act = nn.PReLU()
+            else:
+                act = nn.PReLU(out_ch)
+        elif act_name == "prelu1":
+            act = nn.PReLU()
+        else:
+            raise ValueError(f"Invalid activation name: '{act_name}'.")
+
+        return act
+
 
 class TCN(nn.Module):
     def __init__(self,
@@ -195,6 +261,7 @@ def __init__(self,
                  use_ln: bool = False,
                  temporal_dims: Optional[List[int]] = None,
                  use_act: bool = True,
+                 act_name: str = "prelu",
                  use_res: bool = True,
                  cond_dim: int = 0,
                  use_film_bn: bool = True,  # TODO(cm): check if this should be false
@@ -215,6 +282,7 @@ def __init__(self,
         self.use_ln = use_ln
         self.temporal_dims = temporal_dims
         self.use_act = use_act
+        self.act_name = act_name
         self.use_res = use_res
         self.cond_dim = cond_dim
         self.use_film_bn = use_film_bn
@@ -269,6 +337,7 @@ def __init__(self,
                              use_ln,
                              temp_dim,
                              use_act,
+                             act_name,
                              use_res,
                              cond_dim,
                              use_film_bn,
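
Not part of the commit: a minimal standalone sketch of why get_activation() threads an optional out_ch through only for PReLU, and what the separate "prelu1" option buys. This uses plain PyTorch; the tensor shape and channel count are illustrative.

import torch
from torch import nn

# PReLU is the only activation in the dispatch with learnable parameters:
# nn.PReLU(out_ch) learns one negative slope per channel, while nn.PReLU()
# (the "prelu1" branch) learns a single slope shared across all channels.
x = torch.randn(1, 8, 16)   # (batch, channels, samples); shape is illustrative
per_channel = nn.PReLU(8)   # what "prelu" builds when out_ch is given
shared = nn.PReLU()         # what "prelu1" (or "prelu" without out_ch) builds
assert per_channel(x).shape == shared(x).shape == x.shape
print(sum(p.numel() for p in per_channel.parameters()))  # 8 learnable slopes
print(sum(p.numel() for p in shared.parameters()))       # 1 learnable slope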