@@ -9526,6 +9526,54 @@ def forward(self, x):
95269526 compute_unit = compute_unit
95279527 )
95289528
9529+ class TestISTFT (TorchBaseTest ):
9530+ @pytest .mark .slow
9531+ @pytest .mark .parametrize (
9532+ "compute_unit, backend, input_shape, complex, n_fft, hop_length, win_length, window, center, pad_mode, normalized, onesided, length" ,
9533+ itertools .product (
9534+ compute_units ,
9535+ backends ,
9536+ [(1 , 32 , 9 ), (32 , 9 ), (3 , 32 , 9 )], # input shape
9537+ [False , True ], # complex
9538+ [16 ], # n_fft
9539+ [None , 4 , 5 ], # hop_length
9540+ [None , 16 , 9 ], # win_length
9541+ [None , torch .hann_window ], # window
9542+ [None , False , True ], # center
9543+ ["constant" , "reflect" , "replicate" ], # pad mode
9544+ [False , True ], # normalized
9545+ [None , False , True ], # onesided
9546+ [None , 60 ], # length
9547+ )
9548+ )
9549+ def test_istft (self , compute_unit , backend , input_shape , complex , n_fft , hop_length , win_length , window , center , pad_mode , normalized , onesided ):
9550+ if complex and onesided :
9551+ pytest .skip ("Onesided stft not possible for complex inputs" )
9552+
9553+ class ISTFTModel (torch .nn .Module ):
9554+ def forward (self , x ):
9555+ applied_window = window (win_length ) if window and win_length else None
9556+ x = torch .complex (x , x )
9557+ x = torch .istft (
9558+ x ,
9559+ n_fft = n_fft ,
9560+ hop_length = hop_length ,
9561+ win_length = win_length ,
9562+ window = applied_window ,
9563+ center = center ,
9564+ normalized = normalized ,
9565+ onesided = onesided ,
9566+ length = length ,
9567+ return_complex = True )
9568+ x = torch .stack ([torch .real (x ), torch .imag (x )], dim = 0 )
9569+ return x
9570+
9571+ TorchBaseTest .run_compare_torch (
9572+ input_shape ,
9573+ ISTFTModel (),
9574+ backend = backend ,
9575+ compute_unit = compute_unit
9576+ )
95299577
95309578if _HAS_TORCH_AUDIO :
95319579
0 commit comments