@@ -325,7 +325,7 @@ def _stft(
325325    We can write STFT in terms of convolutions with a DFT kernel. 
326326    At the end: 
327327        * The real part output is: cos_base * input_real - sin_base * input_imag 
328-         * The imaginary part output is: - (sin_base * input_real  - cos_base  * input_imag)  
328+         * The imaginary part output is: cos_base * input_imag + sin_base * input_real 
329329    Adapted from: https://github.com/adobe-research/convmelspec/blob/main/convmelspec/mil.py 
330330    """ 
331331    hop_length  =  hop_length  or  mb .floor_div (x = n_fft , y = 4 , before_op = before_op )
@@ -358,12 +358,13 @@ def _stft(
358358    if  input_imaginary :
359359        signal_imaginary  =  mb .expand_dims (x = input_imaginary , axes = (1 ,), before_op = before_op )
360360
361-     # conv with DFT kernel across the input signal 
362-     # The DFT matrix is obtained with the equation e^(2pi/N i) but the definition is: 
363-     # DFT(x) => X[k] = Σx[n]*e^(-2kpi/N i) 
364-     # If x is complex then x[n]=(a+i*b) 
365-     # So the real part = Σ(a*cos(2kpi/N)+b*sin(2kpi/N)) 
366-     # So the imag part = Σ(b*cos(2kpi/N)-a*sin(2kpi/N)) 
361+     # Convolve the DFT kernel with the input signal 
362+     # DFT(x[n]) --> X[k] = Σx[n]*e^(-i*2π*n/N*k), then if x is complex x[n]=(a[n]+i*b[n]) 
363+     #   real(X[k]) = Σ(a[n]*cos(2π*n/N*k)+b[n]*sin(2π*n/N*k)) 
364+     #   imag(X[k]) = Σ(b[n]*cos(2π*n/N*k)-a[n]*sin(2π*n/N*k)) 
365+     # But because our DFT matrix is obtained with the conjugate --> e^(i*2π*n/N*k): 
366+     #   real(X[k]) = Σ(a[n]*cos(2π*n/N*k)-b[n]*sin(2π*n/N*k)) 
367+     #   imag(X[k]) = Σ(b[n]*cos(2π*n/N*k)+a[n]*sin(2π*n/N*k)) 
367368    cos_windows_real  =  mb .conv (x = signal_real , weight = cos_base , strides = hop_size , pad_type = 'valid' , before_op = before_op )
368369    sin_windows_real  =  mb .conv (x = signal_real , weight = sin_base , strides = hop_size , pad_type = 'valid' , before_op = before_op )
369370    if  input_imaginary :
@@ -372,11 +373,11 @@ def _stft(
372373
373374    # add everything together 
374375    if  input_imaginary :
375-         real_result  =  mb .add (x = cos_windows_real , y = sin_windows_imag , before_op = before_op )
376-         imag_result  =  mb .sub (x = cos_windows_imag , y = sin_windows_real , before_op = before_op )
376+         real_result  =  mb .sub (x = cos_windows_real , y = sin_windows_imag , before_op = before_op )
377+         imag_result  =  mb .add (x = cos_windows_imag , y = sin_windows_real , before_op = before_op )
377378    else :
378379        real_result  =  cos_windows_real 
379-         imag_result  =  mb . sub ( x = 0. ,  y = sin_windows_real ,  before_op = before_op ) 
380+         imag_result  =  sin_windows_real 
380381
381382    # reduce the rank of the output 
382383    if  should_increase_rank :
@@ -417,10 +418,10 @@ def _istft(
417418    # By default, use the entire frame 
418419    win_length  =  win_length  or  n_fft 
419420
420-     input_shape  =  mb .shape (x = x , before_op = before_op )
421+     input_shape  =  mb .shape (x = input_real , before_op = before_op )
421422    n_frames  =  input_shape .val [- 1 ]
422423    fft_size  =  input_shape .val [- 2 ]
423-     #  expected_output_signal_len = n_fft.val + hop_length.val * (n_frames - 1)
424+     expected_output_signal_len  =  n_fft .val  +  hop_length .val  *  (n_frames  -  1 )
424425
425426    is_onesided  =  onesided .val  if  onesided  else  fft_size  !=  n_fft 
426427    cos_base , sin_base  =  _calculate_dft_matrix (n_fft , onesided = is_onesided , before_op = before_op )
@@ -447,14 +448,13 @@ def _istft(
447448        signal_real  =  mb .mul (x = signal_real , y = multiplier , before_op = before_op )
448449        signal_imaginary  =  mb .mul (x = signal_imaginary , y = multiplier , before_op = before_op )
449450
450-     # Conv with  DFT kernel across  the input signal 
451-     # We can describe the IDFT in terms of DFT just by swapping the input and output 
451+     # Convolve the  DFT kernel with  the input signal 
452+     # We can describe the IDFT in terms of DFT just by swapping the real and imaginary parts of both the input and the output. 
452453    # ref: https://en.wikipedia.org/wiki/Discrete_Fourier_transform#Expressing_the_inverse_DFT_in_terms_of_the_DFT 
453-     # So IDFT(x) = (1/N) * swap(DFT(swap(x))) 
454-     # and DFT(x) = X[k] = Σx[n]*e^(-2kpi/N i) but we are using the conjugate e^(2kpi/N i) 
455-     # If x is complex then x[n]=(a+i*b) 
456-     # then real part = (1/N)*Σ(a*cos(2kpi/N)+b*sin(2kpi/N)) 
457-     # then imag part = (1/N)*Σ(b*cos(2kpi/N)-a*sin(2kpi/N)) 
454+     # IDFT(X[k]) --> x[n]=(1/N)*swap(DFT(swap(X[k]))), and K=N 
455+     # So using the definition in stft function, we get: 
456+     #   real(x[n]) = Σ(a[k]*cos(2π*k/K*n)+b[k]*sin(2π*k/K*n)) 
457+     #   imag(x[n]) = Σ(b[k]*cos(2π*k/K*n)-a[k]*sin(2π*k/K*n)) 
458458    cos_windows_real  =  mb .conv (x = signal_real , weight = cos_base , strides = hop_size , pad_type = 'valid' , before_op = before_op )
459459    sin_windows_real  =  mb .conv (x = signal_real , weight = sin_base , strides = hop_size , pad_type = 'valid' , before_op = before_op )
460460    cos_windows_imag  =  mb .conv (x = signal_imaginary , weight = cos_base , strides = hop_size , pad_type = 'valid' , before_op = before_op )
@@ -750,17 +750,21 @@ def _lower_complex_istft(op: Operation):
750750    is_complex  =  types .is_complex (op .input .dtype )
751751
752752    # check parameters for validity 
753+     if  is_complex :
754+         raise  ValueError ("Only complex inputs are allowed" )
753755    if  op .win_length  and  op .win_length .val  >  op .n_fft .val :
754756        raise  ValueError ("Window length must be less than or equal to n_fft" )
755-     if  is_complex  and  op .onesided  and  op .onesided .val :
756-         raise  ValueError ("Onesided  is only valid for real inputs " )
757+     if  op . return_complex  and  op .onesided  and  op .onesided .val :
758+         raise  ValueError ("Complex output  is not compatible with onesided " )
757759
758760    real , imag  =  _istft (
759-         op .input .real  if  is_complex  else  op .input ,
760-         op .input .imag  if  is_complex  else  None ,
761-         op .n_fft , op .hop_length , op .win_length , op .window , op .normalized , op .onesided , before_op = op )
761+         op .input .real , op .input .imag ,
762+         op .n_fft , op .hop_length , op .win_length , op .window , op .normalized , op .onesided , op .length , before_op = op )
762763
763-     return  _wrap_complex_output (op .outputs [0 ], real , imag )
764+     if  op .return_complex :
765+         return  _wrap_complex_output (op .outputs [0 ], real , imag )
766+     else 
767+         return  real 
764768
765769
766770@LowerComplex .register_lower_func (op_type = "complex_shape" ) 
0 commit comments