@@ -338,10 +338,7 @@ def _stft(
338338 input_imaginary = mb .expand_dims (x = input_imaginary , axes = (0 ,), before_op = before_op )
339339
340340 is_onesided = onesided and onesided .val
341- cos_base , sin_base = _calculate_dft_matrix (
342- n_fft ,
343- onesided = is_onesided ,
344- before_op = before_op )
341+ cos_base , sin_base = _calculate_dft_matrix (n_fft , onesided = is_onesided , before_op = before_op )
345342
346343 # create a window of centered 1s of the requested size
347344 if win_length :
@@ -352,29 +349,46 @@ def _stft(
352349 cos_base = mb .mul (x = window , y = cos_base , before_op = before_op )
353350 sin_base = mb .mul (x = window , y = sin_base , before_op = before_op )
354351
355- # conv with DFT kernel across the input signal
356- sin_base = mb . sub ( x = 0. , y = sin_base , before_op = before_op )
352+
353+ # Expand
357354 cos_base = mb .expand_dims (x = cos_base , axes = (1 ,), before_op = before_op )
358355 sin_base = mb .expand_dims (x = sin_base , axes = (1 ,), before_op = before_op )
359356 hop_size = mb .expand_dims (x = hop_length , axes = (0 ,), before_op = before_op )
360-
361357 signal_real = mb .expand_dims (x = input_real , axes = (1 ,), before_op = before_op )
358+ if input_imaginary :
359+ signal_imaginary = mb .expand_dims (x = input_imaginary , axes = (1 ,), before_op = before_op )
360+
361+ # conv with DFT kernel across the input signal
362+ # The DFT matrix is obtained with the equation e^(2pi/N i) but the definition is:
363+ # DFT(x) => X[k] = Σx[n]*e^(-2kpi/N i)
364+ # If x is complex then x[n]=(a+i*b)
365+ # So the real part = Σ(a*cos(2kpi/N)+b*sin(2kpi/N))
366+ # So the imag part = Σ(b*cos(2kpi/N)-a*sin(2kpi/N))
362367 cos_windows_real = mb .conv (x = signal_real , weight = cos_base , strides = hop_size , pad_type = 'valid' , before_op = before_op )
363368 sin_windows_real = mb .conv (x = signal_real , weight = sin_base , strides = hop_size , pad_type = 'valid' , before_op = before_op )
364-
365369 if input_imaginary :
366- signal_imaginary = mb .expand_dims (x = input_imaginary , axes = (1 ,), before_op = before_op )
367370 cos_windows_imag = mb .conv (x = signal_imaginary , weight = cos_base , strides = hop_size , pad_type = 'valid' , before_op = before_op )
368371 sin_windows_imag = mb .conv (x = signal_imaginary , weight = sin_base , strides = hop_size , pad_type = 'valid' , before_op = before_op )
369372
370373 # add everything together
371374 if input_imaginary :
372- # sin base is already negative so subtract
373- real_result = mb .sub (x = cos_windows_real , y = sin_windows_imag , before_op = before_op )
374- imag_result = mb .add (x = sin_windows_real , y = cos_windows_imag , before_op = before_op )
375+ real_result = mb .add (x = cos_windows_real , y = sin_windows_imag , before_op = before_op )
376+ imag_result = mb .sub (x = cos_windows_imag , y = sin_windows_real , before_op = before_op )
375377 else :
376378 real_result = cos_windows_real
377- imag_result = sin_windows_real
379+ imag_result = mb .sub (x = 0. , y = sin_windows_real , before_op = before_op )
380+
381+ # reduce the rank of the output
382+ if should_increase_rank :
383+ real_result = mb .squeeze (x = real_result , axes = (0 ,), before_op = before_op )
384+ imag_result = mb .squeeze (x = imag_result , axes = (0 ,), before_op = before_op )
385+
386+ if normalized and normalized .val :
387+ divisor = mb .sqrt (x = mb .cast (x = n_fft , dtype = "fp32" , before_op = before_op ), before_op = before_op )
388+ real_result = mb .real_div (x = real_result , y = divisor , before_op = before_op )
389+ imag_result = mb .real_div (x = imag_result , y = divisor , before_op = before_op )
390+
391+ return real_result , imag_result
378392
379393def _istft (
380394 input_real : Var ,
0 commit comments