Commit 306abb3

[cm] Refactoring TCN
1 parent 7c9b22f commit 306abb3

1 file changed: +132 −58 lines changed

neutone_sdk/tcn.py

Lines changed: 132 additions & 58 deletions
@@ -15,8 +15,8 @@
 
 class FiLM(nn.Module):
     def __init__(self,
-                 cond_dim: int,  # dim of conditioning input
-                 num_features: int,  # dim of the conv channel
+                 cond_dim: int,  # Dim of conditioning input
+                 num_features: int,  # Dim of the conv channel
                  use_bn: bool) -> None:
         super().__init__()
         self.num_features = num_features
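
For reference, FiLM (Perez et al., 2018) applies a per-channel affine transform whose scale and shift are predicted from the conditioning vector. A minimal sketch of the idea, with the optional BatchNorm controlled by use_bn omitted; FiLMSketch is a hypothetical name, not the SDK class:

    from torch import nn, Tensor

    # Sketch only: predicts gamma/beta from cond and modulates conv features
    class FiLMSketch(nn.Module):
        def __init__(self, cond_dim: int, num_features: int) -> None:
            super().__init__()
            self.adaptor = nn.Linear(cond_dim, 2 * num_features)

        def forward(self, x: Tensor, cond: Tensor) -> Tensor:
            g, b = self.adaptor(cond).chunk(2, dim=-1)  # (bs, ch) each
            # Broadcast gamma/beta over the time axis of x: (bs, ch, samples)
            return g.unsqueeze(-1) * x + b.unsqueeze(-1)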
@@ -47,8 +47,8 @@ def __init__(self,
                  dilation: int = 1,
                  bias: bool = True,
                  padding_mode: str = "zeros",
-                 is_causal: bool = True,
-                 is_cached: bool = False,
+                 causal: bool = True,
+                 cached: bool = False,
                  use_dynamic_bs: bool = True,
                  batch_size: int = 1,
                  use_ln: bool = False,
@@ -85,8 +85,8 @@ def __init__(self,
                                   dilation=dilation,
                                   bias=bias,
                                   padding_mode=padding_mode,
-                                  causal=is_causal,
-                                  cached=is_cached,
+                                  causal=causal,
+                                  cached=cached,
                                   use_dynamic_bs=use_dynamic_bs,
                                   batch_size=batch_size,
                                   debug_mode=debug_mode)
@@ -151,7 +151,7 @@ def prepare_for_inference(self) -> None:
         """
         self.debug_mode = False
         self.conv.prepare_for_inference()
-        self.eval()  # TODO(cm): check if this is applied to all modules recursively
+        self.eval()
 
     def forward(self, x: Tensor, cond: Optional[Tensor] = None) -> Tensor:
         if self.debug_mode:
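
The TODO removed above can be settled from the PyTorch docs: nn.Module.eval() delegates to self.train(False), which recurses into every submodule, so one call covers the whole module tree. A quick self-contained check (module names are illustrative):

    import torch.nn as nn

    net = nn.Sequential(nn.Conv1d(1, 8, 3), nn.BatchNorm1d(8))
    net.eval()
    # modules() yields the container and all of its descendants
    assert all(not m.training for m in net.modules())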
@@ -166,55 +166,63 @@ def forward(self, x: Tensor, cond: Optional[Tensor] = None) -> Tensor:
         if self.film is not None:
             if self.debug_mode:
                 assert cond is not None
-            x = self.film(x, cond)
+            if cond is not None:  # This if statement is needed for TorchScript
+                x = self.film(x, cond)
         if self.act is not None:
             x = self.act(x)
         if self.res is not None:
             res = self.res(x_in)
-            x_res = self.crop_fn(res, x.size(-1))  # TODO
+            right_offset = self.get_delay_samples()
+            x_res = Conv1dGeneral.right_offset_crop(res, x.size(-1), right_offset)
             x += x_res
         return x
 
 
 class TCN(nn.Module):
     def __init__(self,
+                 in_channels: int,
                  out_channels: List[int],
-                 dilations: Optional[List[int]] = None,
-                 in_ch: int = 1,
-                 kernel_size: int = 13,
+                 kernel_size: int = 3,
                  strides: Optional[List[int]] = None,
-                 padding: Optional[int] = 0,
+                 padding: Union[str, int, Tuple[int]] = "same",
+                 dilations: Optional[List[int]] = None,
+                 bias: bool = True,
+                 padding_mode: str = "zeros",
+                 causal: bool = True,
+                 cached: bool = False,
+                 use_dynamic_bs: bool = True,
+                 batch_size: int = 1,
                  use_ln: bool = False,
                  temporal_dims: Optional[List[int]] = None,
                  use_act: bool = True,
                  use_res: bool = True,
                  cond_dim: int = 0,
-                 use_film_bn: bool = False,
-                 is_causal: bool = True,
-                 is_cached: bool = False) -> None:
+                 use_film_bn: bool = True,  # TODO(cm): check if this should be false
+                 debug_mode: bool = True) -> None:
         super().__init__()
+        self.in_channels = in_channels
         self.out_channels = out_channels
-        self.in_ch = in_ch
-        self.out_ch = out_channels[-1]
         self.kernel_size = kernel_size
+        self.strides = strides
         self.padding = padding
+        self.dilations = dilations
+        self.bias = bias
+        self.padding_mode = padding_mode
+        self.causal = causal
+        self.cached = cached
+        self.use_dynamic_bs = use_dynamic_bs
+        self.batch_size = batch_size
         self.use_ln = use_ln
-        self.temporal_dims = temporal_dims  # TODO(cm): calculate automatically
+        self.temporal_dims = temporal_dims
         self.use_act = use_act
         self.use_res = use_res
         self.cond_dim = cond_dim
         self.use_film_bn = use_film_bn
-        self.is_causal = is_causal
-        self.is_cached = is_cached
-        if is_causal:
-            assert padding == 0, "If the TCN is causal, padding must be 0"
-            self.crop_fn = causal_crop
-        else:
-            self.crop_fn = center_crop
-        if is_cached:
-            assert is_causal, "If the TCN is streaming, it must be causal"
+        self.debug_mode = debug_mode
 
         self.n_blocks = len(out_channels)
+        assert self.n_blocks > 0
+
         if dilations is None:
             dilations = [4 ** idx for idx in range(self.n_blocks)]
             log.info(f"Setting dilations automatically to: {dilations}")
@@ -223,7 +231,7 @@ def __init__(self,
 
         if strides is None:
             strides = [1] * self.n_blocks
-        log.info(f"Setting strides automatically to: {strides}")
+            log.info(f"Setting strides automatically to: {strides}")
         assert len(strides) == self.n_blocks
         self.strides = strides
 
@@ -233,9 +241,11 @@ def __init__(self,
 
         self.blocks = nn.ModuleList()
         block_out_ch = None
-        for idx, (curr_out_ch, dil, stride) in enumerate(zip(out_channels, dilations, strides)):
+        for idx, (curr_out_ch, dil, stride) in enumerate(zip(out_channels,
+                                                             dilations,
+                                                             strides)):
             if idx == 0:
-                block_in_ch = in_ch
+                block_in_ch = in_channels
             else:
                 block_in_ch = block_out_ch
             block_out_ch = curr_out_ch
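
The loop above chains the blocks: the first block consumes in_channels and each later block consumes the previous block's output channels. A worked illustration with made-up values:

    in_channels = 2
    out_channels = [8, 16, 32]
    # Blocks are wired 2 -> 8, 8 -> 16, 16 -> 32
    io_pairs = list(zip([in_channels] + out_channels[:-1], out_channels))
    assert io_pairs == [(2, 8), (8, 16), (16, 32)]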
@@ -244,51 +254,115 @@ def __init__(self,
             if temporal_dims is not None:
                 temp_dim = temporal_dims[idx]
 
-            self.blocks.append(TCNBlock(
-                block_in_ch,
-                block_out_ch,
-                kernel_size,
-                dil,
-                stride,
-                padding,
-                use_ln,
-                temp_dim,
-                use_act,
-                use_res,
-                cond_dim,
-                use_film_bn,
-                is_causal,
-                is_cached
-            ))
+            self.blocks.append(TCNBlock(block_in_ch,
+                                        block_out_ch,
+                                        kernel_size,
+                                        stride,
+                                        padding,
+                                        dil,
+                                        bias,
+                                        padding_mode,
+                                        causal,
+                                        cached,
+                                        use_dynamic_bs,
+                                        batch_size,
+                                        use_ln,
+                                        temp_dim,
+                                        use_act,
+                                        use_res,
+                                        cond_dim,
+                                        use_film_bn,
+                                        debug_mode))
 
+    @tr.jit.export
     def is_conditional(self) -> bool:
+        """Returns True if the TCN is conditional, False otherwise."""
         return self.cond_dim > 0
 
-    def forward(self, x: Tensor, cond: Optional[Tensor] = None) -> Tensor:
-        assert x.ndim == 3  # (batch_size, in_ch, samples)
-        if self.is_conditional():
-            assert cond is not None
-            assert cond.shape == (x.size(0), self.cond_dim)  # (batch_size, cond_dim)
+    @tr.jit.export
+    def is_cached(self) -> bool:
+        """Returns True if the TCN is cached, False otherwise."""
+        return self.cached
+
+    @tr.jit.export
+    def set_cached(self, cached: bool) -> None:
+        """
+        Sets the TCN to cached or not cached mode and resets its state.
+
+        Args:
+            cached: If True, the TCN is cached. If False, it is not cached.
+        """
+        self.cached = cached
         for block in self.blocks:
-            x = block(x, cond)
-        return x
+            block.set_cached(cached)
+
+    @tr.jit.export
+    def reset(self, batch_size: Optional[int] = None) -> None:
+        """
+        Resets the TCN's state. If batch_size is provided, the cached padding
+        will be resized to match the new batch size.
+
+        Args:
+            batch_size: If provided, the cached padding will be resized to match
+                        the new batch size.
+        """
+        for block in self.blocks:
+            block.reset(batch_size)
+
+    @tr.jit.export
+    def get_delay_samples(self) -> int:
+        """
+        Returns the number of samples that the TCN delays the output by. This
+        should always be 0 when the TCN is causal. This is ill-defined when not
+        in cached mode since the output number of samples can be different than
+        the input number of samples, so this would typically only be used in
+        cached mode.
+        """
+        # TODO(cm): verify this
+        delay_samples = 0
+        for block in self.blocks:
+            delay_samples += block.get_delay_samples()
+        return delay_samples
 
+    @tr.jit.export
     def calc_receptive_field(self) -> int:
-        """Compute the receptive field in samples."""
+        """Computes the receptive field of the TCN in samples."""
         assert all(_ == 1 for _ in self.strides)  # TODO(cm): add support for dsTCN
         assert self.dilations[0] == 1  # TODO(cm): add support for >1 starting dilation
         rf = self.kernel_size
         for dil in self.dilations[1:]:
-            rf = rf + ((self.kernel_size - 1) * dil)
+            rf += ((self.kernel_size - 1) * dil)
         return rf
 
+    def prepare_for_inference(self) -> None:
+        """
+        Prepares the TCN for inference by disabling debug mode and ensuring the
+        TCN is in cached mode.
+        """
+        self.debug_mode = False
+        for block in self.blocks:
+            block.prepare_for_inference()
+        self.eval()
+
+    def forward(self, x: Tensor, cond: Optional[Tensor] = None) -> Tensor:
+        if self.debug_mode:
+            assert x.ndim == 3  # (bs, in_ch, samples)
+            if self.is_conditional():
+                assert cond is not None
+                assert cond.shape == (x.size(0), self.cond_dim)  # (bs, cond_dim)
+        for block in self.blocks:
+            x = block(x, cond)
+        return x
+
 
 if __name__ == '__main__':
     out_channels = [8] * 4
-    tcn = TCN(out_channels, cond_dim=3, padding=0, is_causal=True, is_cached=True)
-    log.info(tcn.calc_receptive_field())
+    tcn = TCN(1, out_channels, cond_dim=3, causal=False, cached=False, padding="valid")
+    log.info(f"Receptive field: {tcn.calc_receptive_field()}")
+    log.info(f"Delay samples: {tcn.get_delay_samples()}")
     audio = tr.rand((1, 1, 65536))
     cond = tr.rand((1, 3))
     # cond = None
     out = tcn.forward(audio, cond)
     log.info(out.shape)
+
+    script = tr.jit.script(tcn)
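
With the automatic dilations [4 ** idx for idx in range(n_blocks)] and kernel_size=3, calc_receptive_field() for the four-block TCN above works out to 3 + 2*(4 + 16 + 64) = 171 samples. Below is a hedged sketch of streaming use built on the exported methods, assuming a causal TCN in cached mode accepts fixed-size chunks (the exact padding constraints live in Conv1dGeneral and may differ):

    # Illustrative streaming loop; chunk size and constructor args are assumptions
    stream_tcn = TCN(1, [8] * 4, causal=True, cached=True)
    stream_tcn.prepare_for_inference()
    stream_tcn.reset(batch_size=1)
    chunks = tr.rand((1, 1, 65536)).split(4096, dim=-1)
    out = tr.cat([stream_tcn(chunk) for chunk in chunks], dim=-1)
    log.info(f"Streamed output shape: {out.shape}")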
