diff --git a/src/torchaudio/transforms/_transforms.py b/src/torchaudio/transforms/_transforms.py index 36f7ecc1cc..9e07df521e 100644 --- a/src/torchaudio/transforms/_transforms.py +++ b/src/torchaudio/transforms/_transforms.py @@ -84,7 +84,7 @@ def __init__( self.win_length = win_length if win_length is not None else n_fft self.hop_length = hop_length if hop_length is not None else self.win_length // 2 window = window_fn(self.win_length) if wkwargs is None else window_fn(self.win_length, **wkwargs) - self.register_buffer("window", window) + self.register_buffer("window", window, persistent=False) self.pad = pad self.power = power self.normalized = normalized @@ -178,7 +178,7 @@ def __init__( self.win_length = win_length if win_length is not None else n_fft self.hop_length = hop_length if hop_length is not None else self.win_length // 2 window = window_fn(self.win_length) if wkwargs is None else window_fn(self.win_length, **wkwargs) - self.register_buffer("window", window) + self.register_buffer("window", window, persistent=False) self.pad = pad self.normalized = normalized self.center = center @@ -974,7 +974,7 @@ def __init__( beta, dtype=dtype, ) - self.register_buffer("kernel", kernel) + self.register_buffer("kernel", kernel, persistent=False) def forward(self, waveform: Tensor) -> Tensor: r""" diff --git a/test/torchaudio_unittest/transforms/transforms_test.py b/test/torchaudio_unittest/transforms/transforms_test.py index 0a6e81b440..4004b623a2 100644 --- a/test/torchaudio_unittest/transforms/transforms_test.py +++ b/test/torchaudio_unittest/transforms/transforms_test.py @@ -238,18 +238,35 @@ def test_resample_size(self): sample_rate, upsample_rate, resampling_method="sinc_interp_hann" ) up_sampled = upsample_resample(waveform) - - # we expect the upsampled signal to have twice as many samples self.assertTrue(up_sampled.size(-1) == waveform.size(-1) * 2) downsample_resample = torchaudio.transforms.Resample( sample_rate, downsample_rate, resampling_method="sinc_interp_hann" ) down_sampled = downsample_resample(waveform) - - # we expect the downsampled signal to have half as many samples self.assertTrue(down_sampled.size(-1) == waveform.size(-1) // 2) + + def test_spectrogram_window_not_in_state_dict(self): + transform = transforms.Spectrogram() + state_dict = transform.state_dict() + + self.assertNotIn("window", state_dict) + + + def test_inverse_spectrogram_window_not_in_state_dict(self): + transform = transforms.InverseSpectrogram(n_fft=400) + state_dict = transform.state_dict() + + self.assertNotIn("window", state_dict) + + + def test_resample_kernel_not_in_state_dict(self): + transform = transforms.Resample(orig_freq=16000, new_freq=8000) + state_dict = transform.state_dict() + + self.assertNotIn("kernel", state_dict) + def test_compute_deltas(self): channel = 13 n_mfcc = channel * 3