@@ -399,6 +399,7 @@ def _istft(
399399 window : Optional [Var ],
400400 normalized : Optional [Var ],
401401 onesided : Optional [Var ],
402+ length : Optional [Var ],
402403 before_op : Operation ,
403404) -> Tuple [Var , Var ]:
404405 """
@@ -419,7 +420,7 @@ def _istft(
419420 input_shape = mb .shape (x = x , before_op = before_op )
420421 n_frames = input_shape .val [- 1 ]
421422 fft_size = input_shape .val [- 2 ]
422- expected_output_signal_len = n_fft .val + hop_length .val * (n_frames - 1 )
423+ # expected_output_signal_len = n_fft.val + hop_length.val * (n_frames - 1)
423424
424425 is_onesided = onesided .val if onesided else fft_size != n_fft
425426 cos_base , sin_base = _calculate_dft_matrix (n_fft , onesided = is_onesided , before_op = before_op )
@@ -478,10 +479,14 @@ def _istft(
478479 real_result = mb .real_div (x = real_result , y = window_envelope , before_op = before_op )
479480 imag_result = mb .real_div (x = imag_result , y = window_envelope , before_op = before_op )
480481
481- # reduce the rank of the output
482- if should_increase_rank :
483- real_result = mb .squeeze (x = real_result , axes = (0 ,), before_op = before_op )
484- imag_result = mb .squeeze (x = imag_result , axes = (0 ,), before_op = before_op )
482+ # We need to adapt last dimension
483+ if length is not None :
484+ if length > expected_output_signal_len :
485+ real_result = mb .pad (x = real_result , pad = , mode = "constant" , constant_val = 0 , before_op = before_op )
486+ imag_result = mb .pad (x = imag_result , pad = , mode = "constant" , constant_val = 0 , before_op = before_op )
487+ elif length < expected_output_signal_len :
488+ real_result = mb .slice_by_size (x = real_result , begin = [0 ], size = [length ], before_op = before_op )
489+ imag_result = mb .slice_by_size (x = imag_result , begin = [0 ], size = [length ], before_op = before_op )
485490
486491 return real_result , imag_result
487492
0 commit comments