@@ -15,8 +15,8 @@
 
 class FiLM(nn.Module):
     def __init__(self,
-                 cond_dim: int,  # dim of conditioning input
-                 num_features: int,  # dim of the conv channel
+                 cond_dim: int,  # Dim of conditioning input
+                 num_features: int,  # Dim of the conv channel
                  use_bn: bool) -> None:
         super().__init__()
         self.num_features = num_features
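
Note: for context on the class above, FiLM (feature-wise linear modulation, Perez et al. 2018) scales and shifts each conv channel with parameters predicted from the conditioning vector. The forward pass is outside this diff, so the sketch below only illustrates what a module with this signature typically computes; the single nn.Linear adaptor and the non-affine BatchNorm1d are assumptions, not this repo's exact implementation.

    import torch as tr
    from torch import Tensor, nn

    class FiLMSketch(nn.Module):
        def __init__(self, cond_dim: int, num_features: int, use_bn: bool) -> None:
            super().__init__()
            # Predict one (gamma, beta) pair per channel from the conditioning vector
            self.adaptor = nn.Linear(cond_dim, 2 * num_features)
            self.bn = nn.BatchNorm1d(num_features, affine=False) if use_bn else None

        def forward(self, x: Tensor, cond: Tensor) -> Tensor:
            g, b = self.adaptor(cond).chunk(2, dim=-1)  # each (bs, num_features)
            if self.bn is not None:
                x = self.bn(x)  # normalize so gamma and beta act as the affine params
            return g.unsqueeze(-1) * x + b.unsqueeze(-1)  # broadcast over time axis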
@@ -47,8 +47,8 @@ def __init__(self,
                  dilation: int = 1,
                  bias: bool = True,
                  padding_mode: str = "zeros",
-                 is_causal: bool = True,
-                 is_cached: bool = False,
+                 causal: bool = True,
+                 cached: bool = False,
                  use_dynamic_bs: bool = True,
                  batch_size: int = 1,
                  use_ln: bool = False,
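
Note: the is_causal/is_cached to causal/cached rename aligns the TCNBlock parameter names with the Conv1dGeneral keyword arguments used below. "Cached" refers to streaming convolution: the layer keeps the trailing (kernel_size - 1) * dilation input samples between calls so that block-by-block processing matches offline processing. Conv1dGeneral itself is not part of this diff, so the following is only a sketch of the idea for the causal case, with a fixed batch size of 1:

    import torch as tr
    from torch import Tensor, nn

    class CachedCausalConv1d(nn.Module):
        def __init__(self, ch: int, kernel_size: int, dilation: int = 1) -> None:
            super().__init__()
            self.pad_n = (kernel_size - 1) * dilation
            self.conv = nn.Conv1d(ch, ch, kernel_size, dilation=dilation)
            # Cached left context, initialized to silence (fixed batch size of 1)
            self.register_buffer("cache", tr.zeros((1, ch, self.pad_n)))

        def forward(self, x: Tensor) -> Tensor:
            x = tr.cat([self.cache, x], dim=-1)  # prepend cached left context
            self.cache = x[..., -self.pad_n:]  # keep the right edge for the next call
            return self.conv(x)  # output length equals input length

Feeding a signal in one call or in consecutive blocks then yields identical outputs, which is what makes real-time inference possible.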
@@ -85,8 +85,8 @@ def __init__(self,
                                   dilation=dilation,
                                   bias=bias,
                                   padding_mode=padding_mode,
-                                  causal=is_causal,
-                                  cached=is_cached,
+                                  causal=causal,
+                                  cached=cached,
                                   use_dynamic_bs=use_dynamic_bs,
                                   batch_size=batch_size,
                                   debug_mode=debug_mode)
@@ -151,7 +151,7 @@ def prepare_for_inference(self) -> None:
         """
         self.debug_mode = False
         self.conv.prepare_for_inference()
-        self.eval()  # TODO(cm): check if this is applied to all modules recursively
+        self.eval()
 
     def forward(self, x: Tensor, cond: Optional[Tensor] = None) -> Tensor:
         if self.debug_mode:
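
Note: dropping the TODO above is safe. nn.Module.eval() calls self.train(False), which recurses into all submodules, so the block's convolution, normalization, and FiLM layers are switched to eval mode as well.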
@@ -166,55 +166,63 @@ def forward(self, x: Tensor, cond: Optional[Tensor] = None) -> Tensor:
         if self.film is not None:
             if self.debug_mode:
                 assert cond is not None
-            x = self.film(x, cond)
+            if cond is not None:  # This if statement is needed for TorchScript
+                x = self.film(x, cond)
         if self.act is not None:
             x = self.act(x)
         if self.res is not None:
             res = self.res(x_in)
-            x_res = self.crop_fn(res, x.size(-1))  # TODO
+            right_offset = self.get_delay_samples()
+            x_res = Conv1dGeneral.right_offset_crop(res, x.size(-1), right_offset)
             x += x_res
         return x
 
 
 class TCN(nn.Module):
     def __init__(self,
+                 in_channels: int,
                  out_channels: List[int],
-                 dilations: Optional[List[int]] = None,
-                 in_ch: int = 1,
-                 kernel_size: int = 13,
+                 kernel_size: int = 3,
                  strides: Optional[List[int]] = None,
-                 padding: Optional[int] = 0,
+                 padding: Union[str, int, Tuple[int]] = "same",
+                 dilations: Optional[List[int]] = None,
+                 bias: bool = True,
+                 padding_mode: str = "zeros",
+                 causal: bool = True,
+                 cached: bool = False,
+                 use_dynamic_bs: bool = True,
+                 batch_size: int = 1,
                  use_ln: bool = False,
                  temporal_dims: Optional[List[int]] = None,
                  use_act: bool = True,
                  use_res: bool = True,
                  cond_dim: int = 0,
-                 use_film_bn: bool = False,
-                 is_causal: bool = True,
-                 is_cached: bool = False) -> None:
+                 use_film_bn: bool = True,  # TODO(cm): check if this should be false
+                 debug_mode: bool = True) -> None:
         super().__init__()
+        self.in_channels = in_channels
         self.out_channels = out_channels
-        self.in_ch = in_ch
-        self.out_ch = out_channels[-1]
         self.kernel_size = kernel_size
+        self.strides = strides
         self.padding = padding
+        self.dilations = dilations
+        self.bias = bias
+        self.padding_mode = padding_mode
+        self.causal = causal
+        self.cached = cached
+        self.use_dynamic_bs = use_dynamic_bs
+        self.batch_size = batch_size
         self.use_ln = use_ln
-        self.temporal_dims = temporal_dims  # TODO(cm): calculate automatically
+        self.temporal_dims = temporal_dims
         self.use_act = use_act
         self.use_res = use_res
         self.cond_dim = cond_dim
         self.use_film_bn = use_film_bn
-        self.is_causal = is_causal
-        self.is_cached = is_cached
-        if is_causal:
-            assert padding == 0, "If the TCN is causal, padding must be 0"
-            self.crop_fn = causal_crop
-        else:
-            self.crop_fn = center_crop
-        if is_cached:
-            assert is_causal, "If the TCN is streaming, it must be causal"
+        self.debug_mode = debug_mode
 
         self.n_blocks = len(out_channels)
+        assert self.n_blocks > 0
+
         if dilations is None:
             dilations = [4 ** idx for idx in range(self.n_blocks)]
             log.info(f"Setting dilations automatically to: {dilations}")
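
Note: the new "if cond is not None:" guard is a TorchScript type-refinement idiom. The compiler only narrows Optional[Tensor] to Tensor through an explicit None check in the enclosing control flow; the assert above it is nested under "if self.debug_mode:", so it cannot be used for refinement. A minimal repro of the pattern, with a hypothetical Gain module:

    import torch as tr
    from torch import Tensor, nn
    from typing import Optional

    class Gain(nn.Module):
        def forward(self, x: Tensor, g: Optional[Tensor] = None) -> Tensor:
            if g is not None:  # refines Optional[Tensor] to Tensor for the compiler
                x = x * g
            return x

    script = tr.jit.script(Gain())  # scripts cleanly thanks to the None check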
@@ -223,7 +231,7 @@ def __init__(self,
 
         if strides is None:
             strides = [1] * self.n_blocks
-        log.info(f"Setting strides automatically to: {strides}")
+            log.info(f"Setting strides automatically to: {strides}")
         assert len(strides) == self.n_blocks
         self.strides = strides
 
@@ -233,9 +241,11 @@
 
         self.blocks = nn.ModuleList()
         block_out_ch = None
-        for idx, (curr_out_ch, dil, stride) in enumerate(zip(out_channels, dilations, strides)):
+        for idx, (curr_out_ch, dil, stride) in enumerate(zip(out_channels,
+                                                             dilations,
+                                                             strides)):
             if idx == 0:
-                block_in_ch = in_ch
+                block_in_ch = in_channels
             else:
                 block_in_ch = block_out_ch
             block_out_ch = curr_out_ch
@@ -244,51 +254,115 @@ def __init__(self,
             if temporal_dims is not None:
                 temp_dim = temporal_dims[idx]
 
-            self.blocks.append(TCNBlock(
-                block_in_ch,
-                block_out_ch,
-                kernel_size,
-                dil,
-                stride,
-                padding,
-                use_ln,
-                temp_dim,
-                use_act,
-                use_res,
-                cond_dim,
-                use_film_bn,
-                is_causal,
-                is_cached
-            ))
+            self.blocks.append(TCNBlock(block_in_ch,
+                                        block_out_ch,
+                                        kernel_size,
+                                        stride,
+                                        padding,
+                                        dil,
+                                        bias,
+                                        padding_mode,
+                                        causal,
+                                        cached,
+                                        use_dynamic_bs,
+                                        batch_size,
+                                        use_ln,
+                                        temp_dim,
+                                        use_act,
+                                        use_res,
+                                        cond_dim,
+                                        use_film_bn,
+                                        debug_mode))
 
+    @tr.jit.export
     def is_conditional(self) -> bool:
+        """Returns True if the TCN is conditional, False otherwise."""
         return self.cond_dim > 0
 
-    def forward(self, x: Tensor, cond: Optional[Tensor] = None) -> Tensor:
-        assert x.ndim == 3  # (batch_size, in_ch, samples)
-        if self.is_conditional():
-            assert cond is not None
-            assert cond.shape == (x.size(0), self.cond_dim)  # (batch_size, cond_dim)
+    @tr.jit.export
+    def is_cached(self) -> bool:
+        """Returns True if the TCN is cached, False otherwise."""
+        return self.cached
+
+    @tr.jit.export
+    def set_cached(self, cached: bool) -> None:
+        """
+        Sets the TCN to cached or not cached mode and resets its state.
+
+        Args:
+            cached: If True, the TCN is cached. If False, it is not cached.
+        """
+        self.cached = cached
         for block in self.blocks:
-            x = block(x, cond)
-        return x
+            block.set_cached(cached)
+
+    @tr.jit.export
+    def reset(self, batch_size: Optional[int] = None) -> None:
+        """
+        Resets the TCN's state. If batch_size is provided, the cached padding
+        will be resized to match the new batch size.
+
+        Args:
+            batch_size: If provided, the cached padding will be resized to match the new
+                batch size.
+        """
+        for block in self.blocks:
+            block.reset(batch_size)
+
+    @tr.jit.export
+    def get_delay_samples(self) -> int:
+        """
+        Returns the number of samples that the TCN delays the output by. This
+        should always be 0 when the TCN is causal. This is ill-defined when not
+        in cached mode since the output number of samples can be different than the
+        input number of samples, so this would typically only be used in cached mode.
+        """
+        # TODO(cm): verify this
+        delay_samples = 0
+        for block in self.blocks:
+            delay_samples += block.get_delay_samples()
+        return delay_samples
 
+    @tr.jit.export
     def calc_receptive_field(self) -> int:
-        """Compute the receptive field in samples."""
+        """Computes the receptive field of the TCN in samples."""
         assert all(_ == 1 for _ in self.strides)  # TODO(cm): add support for dsTCN
         assert self.dilations[0] == 1  # TODO(cm): add support for >1 starting dilation
         rf = self.kernel_size
         for dil in self.dilations[1:]:
-            rf = rf + ((self.kernel_size - 1) * dil)
+            rf += ((self.kernel_size - 1) * dil)
         return rf
 
+    def prepare_for_inference(self) -> None:
+        """
+        Prepares the TCN for inference by disabling debug mode and ensuring the
+        TCN is in cached mode.
+        """
+        self.debug_mode = False
+        for block in self.blocks:
+            block.prepare_for_inference()
+        self.eval()
+
+    def forward(self, x: Tensor, cond: Optional[Tensor] = None) -> Tensor:
+        if self.debug_mode:
+            assert x.ndim == 3  # (bs, in_ch, samples)
+            if self.is_conditional():
+                assert cond is not None
+                assert cond.shape == (x.size(0), self.cond_dim)  # (bs, cond_dim)
+        for block in self.blocks:
+            x = block(x, cond)
+        return x
+
 
 if __name__ == '__main__':
     out_channels = [8] * 4
-    tcn = TCN(out_channels, cond_dim=3, padding=0, is_causal=True, is_cached=True)
-    log.info(tcn.calc_receptive_field())
+    tcn = TCN(1, out_channels, cond_dim=3, causal=False, cached=False, padding="valid")
+    log.info(f"Receptive field: {tcn.calc_receptive_field()}")
+    log.info(f"Delay samples: {tcn.get_delay_samples()}")
     audio = tr.rand((1, 1, 65536))
     cond = tr.rand((1, 3))
     # cond = None
     out = tcn.forward(audio, cond)
     log.info(out.shape)
+
+    script = tr.jit.script(tcn)
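
Note: with the new defaults (kernel_size=3 and automatic dilations [1, 4, 16, 64] for the four blocks above), calc_receptive_field() returns 3 + 2*4 + 2*16 + 2*64 = 171 samples. Once scripted, the exported methods support block-based streaming. A sketch of such a loop, assuming a TCN constructed with causal=True and "same" padding so each block's output length equals its input length:

    tcn.set_cached(True)  # switch to streaming mode (resets internal state)
    tcn.reset(batch_size=1)  # size the cached padding for batch size 1
    block_size = 2048
    outs = []
    for start in range(0, audio.size(-1), block_size):
        outs.append(tcn.forward(audio[..., start:start + block_size], cond))
    streamed = tr.cat(outs, dim=-1)  # matches offline output sample for sample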