
Commit 01566f3

[cm] Adding customizable activations to TCN block
1 parent 9f34d5d commit 01566f3

File tree

1 file changed: +70 -1 lines changed


neutone_sdk/tcn.py

Lines changed: 70 additions & 1 deletion
@@ -54,6 +54,7 @@ def __init__(self,
                  use_ln: bool = False,
                  temporal_dim: Optional[int] = None,
                  use_act: bool = True,
+                 act_name: str = "prelu",
                  use_res: bool = True,
                  cond_dim: int = 0,
                  use_film_bn: bool = True,  # TODO(cm): check if this should be false
@@ -62,6 +63,7 @@ def __init__(self,
         self.use_ln = use_ln
         self.temporal_dim = temporal_dim
         self.use_act = use_act
+        self.act_name = act_name
         self.use_res = use_res
         self.cond_dim = cond_dim
         self.use_film_bn = use_film_bn
@@ -75,7 +77,7 @@ def __init__(self,

         self.act = None
         if use_act:
-            self.act = nn.PReLU(out_channels)
+            self.act = self.get_activation(act_name, out_channels)

         self.conv = Conv1dGeneral(in_channels,
                                   out_channels,
@@ -177,6 +179,70 @@ def forward(self, x: Tensor, cond: Optional[Tensor] = None) -> Tensor:
             x += x_res
         return x

+    @staticmethod
+    def get_activation(act_name: str, out_ch: Optional[int] = None) -> nn.Module:
+        """
+        Most of the code and experimental results in this method are from
+        https://github.com/csteinmetz1/ronn
+
+        Given an activation name string, returns the corresponding activation function.
+
+        Args:
+            act_name: Name of the activation function.
+            out_ch: Optional number of output channels. Only used for determining the
+                    number of parameters in the PReLU activation function.
+
+        Returns:
+            act: PyTorch activation function.
+
+        Experimental results for randomized overdrive neural networks.
+        ----------------------
+        - ReLU: solid distortion
+        - LeakyReLU: somewhat veiled sound
+        - Tanh: insane levels of distortion with lots of aliasing (HF)
+        - Sigmoid: too gritty to be useful
+        - ELU: fading in and out
+        - RReLU: really interesting HF noise with a background sound
+        - SELU: rolled off soft distortion sound
+        - GELU: roomy, not too interesting
+        - Softplus: heavily distorted signal but with a very rolled off sound. (nice)
+        - Softshrink: super distant sounding and somewhat roomy
+        """
+        act_name = act_name.lower()
+        if act_name == "relu":
+            act = nn.ReLU()
+        elif act_name == "leakyrelu":
+            act = nn.LeakyReLU()
+        elif act_name == "tanh":
+            act = nn.Tanh()
+        elif act_name == "sigmoid":
+            act = nn.Sigmoid()
+        elif act_name == "elu":
+            act = nn.ELU()
+        elif act_name == "rrelu":
+            act = nn.RReLU()
+        elif act_name == "selu":
+            act = nn.SELU()
+        elif act_name == "gelu":
+            act = nn.GELU()
+        elif act_name == "softplus":
+            act = nn.Softplus()
+        elif act_name == "softshrink":
+            act = nn.Softshrink()
+        elif act_name == "silu" or act_name == "swish":
+            act = nn.SiLU()
+        elif act_name == "prelu":
+            if out_ch is None:
+                act = nn.PReLU()
+            else:
+                act = nn.PReLU(out_ch)
+        elif act_name == "prelu1":
+            act = nn.PReLU()
+        else:
+            raise ValueError(f"Invalid activation name: '{act_name}'.")
+
+        return act
+

 class TCN(nn.Module):
     def __init__(self,
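With this change the block's activation is looked up by name instead of being hard-wired to nn.PReLU(out_channels). A minimal sketch of how the lookup behaves is below; it assumes the static method shown above lives on the block class exported by neutone_sdk.tcn, called TCNBlock here purely for illustration since the class name is not visible in these hunks. Note that "prelu" with an out_ch gives one learnable slope per channel (matching the old behaviour), while "prelu1" always uses a single shared slope.

```python
import torch
from neutone_sdk.tcn import TCNBlock  # assumed class name; not shown in this diff

# Look up activations by name, as the block's __init__ now does internally.
gelu = TCNBlock.get_activation("gelu")                     # -> nn.GELU()
per_channel = TCNBlock.get_activation("prelu", out_ch=8)   # -> nn.PReLU(8), one slope per channel
shared = TCNBlock.get_activation("prelu1")                 # -> nn.PReLU(), a single shared slope

print(sum(p.numel() for p in per_channel.parameters()))  # 8
print(sum(p.numel() for p in shared.parameters()))       # 1

# The returned module is an ordinary nn.Module and can be applied directly.
x = torch.randn(1, 8, 64)  # (batch, channels, samples)
y = per_channel(x)
print(type(gelu).__name__, y.shape)
```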
@@ -195,6 +261,7 @@ def __init__(self,
                  use_ln: bool = False,
                  temporal_dims: Optional[List[int]] = None,
                  use_act: bool = True,
+                 act_name: str = "prelu",
                  use_res: bool = True,
                  cond_dim: int = 0,
                  use_film_bn: bool = True,  # TODO(cm): check if this should be false
@@ -215,6 +282,7 @@ def __init__(self,
         self.use_ln = use_ln
         self.temporal_dims = temporal_dims
         self.use_act = use_act
+        self.act_name = act_name
         self.use_res = use_res
         self.cond_dim = cond_dim
         self.use_film_bn = use_film_bn
@@ -269,6 +337,7 @@ def __init__(self,
                           use_ln,
                           temp_dim,
                           use_act,
+                          act_name,
                           use_res,
                           cond_dim,
                           use_film_bn,
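At the top level, the same string is threaded through the TCN constructor and handed to each block, as the hunks above show. A hedged sketch of selecting a non-default activation at construction time follows; only keyword arguments visible in this diff are used, and the required positional arguments (channel widths, kernel size, etc.) that the constructor presumably also takes are represented by a placeholder, since they are not part of this commit.

```python
from neutone_sdk.tcn import TCN

# Placeholder for the constructor's required arguments (channel widths, kernel
# size, dilations, ...), which are not shown in this diff; fill in real values.
required_args = {}

# act_name is the keyword added by this commit; previously every block used
# nn.PReLU regardless of configuration.
tcn = TCN(**required_args,
          use_act=True,
          act_name="tanh",
          cond_dim=0)
```

Because the default stays "prelu" and out_channels is still forwarded to get_activation inside each block, existing models that never pass act_name keep the original per-channel PReLU behaviour.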
