Skip to content

Commit 861d028

Browse files
author
Alejandro Gaston Alvarez Franceschi
committed
Add test
1 parent f9f6d9d commit 861d028

File tree

1 file changed

+48
-0
lines changed

1 file changed

+48
-0
lines changed

coremltools/converters/mil/frontend/torch/test/test_torch_ops.py

Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9526,6 +9526,54 @@ def forward(self, x):
95269526
compute_unit=compute_unit
95279527
)
95289528

9529+
class TestISTFT(TorchBaseTest):
9530+
@pytest.mark.slow
9531+
@pytest.mark.parametrize(
9532+
"compute_unit, backend, input_shape, complex, n_fft, hop_length, win_length, window, center, pad_mode, normalized, onesided, length",
9533+
itertools.product(
9534+
compute_units,
9535+
backends,
9536+
[(1, 32, 9), (32, 9), (3, 32, 9)], # input shape
9537+
[False, True], # complex
9538+
[16], # n_fft
9539+
[None, 4, 5], # hop_length
9540+
[None, 16, 9], # win_length
9541+
[None, torch.hann_window], # window
9542+
[None, False, True], # center
9543+
["constant", "reflect", "replicate"], # pad mode
9544+
[False, True], # normalized
9545+
[None, False, True], # onesided
9546+
[None, 60], # length
9547+
)
9548+
)
9549+
def test_istft(self, compute_unit, backend, input_shape, complex, n_fft, hop_length, win_length, window, center, pad_mode, normalized, onesided):
9550+
if complex and onesided:
9551+
pytest.skip("Onesided stft not possible for complex inputs")
9552+
9553+
class ISTFTModel(torch.nn.Module):
9554+
def forward(self, x):
9555+
applied_window = window(win_length) if window and win_length else None
9556+
x = torch.complex(x, x)
9557+
x = torch.istft(
9558+
x,
9559+
n_fft=n_fft,
9560+
hop_length=hop_length,
9561+
win_length=win_length,
9562+
window=applied_window,
9563+
center=center,
9564+
normalized=normalized,
9565+
onesided=onesided,
9566+
length=length,
9567+
return_complex=True)
9568+
x = torch.stack([torch.real(x), torch.imag(x)], dim=0)
9569+
return x
9570+
9571+
TorchBaseTest.run_compare_torch(
9572+
input_shape,
9573+
ISTFTModel(),
9574+
backend=backend,
9575+
compute_unit=compute_unit
9576+
)
95299577

95309578
if _HAS_TORCH_AUDIO:
95319579

0 commit comments

Comments
 (0)