Skip to content

Commit 99062ff

Browse files
author
Alejandro Gaston Alvarez Franceschi
committed
Simplify stft
1 parent 861d028 commit 99062ff

File tree

1 file changed

+27
-13
lines changed

1 file changed

+27
-13
lines changed

coremltools/converters/mil/mil/passes/defs/lower_complex_dialect_ops.py

Lines changed: 27 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -338,10 +338,7 @@ def _stft(
338338
input_imaginary = mb.expand_dims(x=input_imaginary, axes=(0,), before_op=before_op)
339339

340340
is_onesided = onesided and onesided.val
341-
cos_base, sin_base = _calculate_dft_matrix(
342-
n_fft,
343-
onesided=is_onesided,
344-
before_op=before_op)
341+
cos_base, sin_base = _calculate_dft_matrix(n_fft, onesided=is_onesided, before_op=before_op)
345342

346343
# create a window of centered 1s of the requested size
347344
if win_length:
@@ -352,29 +349,46 @@ def _stft(
352349
cos_base = mb.mul(x=window, y=cos_base, before_op=before_op)
353350
sin_base = mb.mul(x=window, y=sin_base, before_op=before_op)
354351

355-
# conv with DFT kernel across the input signal
356-
sin_base = mb.sub(x=0., y=sin_base, before_op=before_op)
352+
353+
# Expand
357354
cos_base = mb.expand_dims(x=cos_base, axes=(1,), before_op=before_op)
358355
sin_base = mb.expand_dims(x=sin_base, axes=(1,), before_op=before_op)
359356
hop_size = mb.expand_dims(x=hop_length, axes=(0,), before_op=before_op)
360-
361357
signal_real = mb.expand_dims(x=input_real, axes=(1,), before_op=before_op)
358+
if input_imaginary:
359+
signal_imaginary = mb.expand_dims(x=input_imaginary, axes=(1,), before_op=before_op)
360+
361+
# conv with DFT kernel across the input signal
362+
# The DFT matrix is built from e^(+2πi/N) (positive exponent), but the DFT is defined with a negative exponent:
363+
# DFT(x) => X[k] = Σ_n x[n] * e^(-2πi*k*n/N)
364+
# If x is complex, write x[n] = a[n] + i*b[n]
365+
# So the real part = Σ_n (a[n]*cos(2π*k*n/N) + b[n]*sin(2π*k*n/N))
366+
# So the imag part = Σ_n (b[n]*cos(2π*k*n/N) - a[n]*sin(2π*k*n/N))
362367
cos_windows_real = mb.conv(x=signal_real, weight=cos_base, strides=hop_size, pad_type='valid', before_op=before_op)
363368
sin_windows_real = mb.conv(x=signal_real, weight=sin_base, strides=hop_size, pad_type='valid', before_op=before_op)
364-
365369
if input_imaginary:
366-
signal_imaginary = mb.expand_dims(x=input_imaginary, axes=(1,), before_op=before_op)
367370
cos_windows_imag = mb.conv(x=signal_imaginary, weight=cos_base, strides=hop_size, pad_type='valid', before_op=before_op)
368371
sin_windows_imag = mb.conv(x=signal_imaginary, weight=sin_base, strides=hop_size, pad_type='valid', before_op=before_op)
369372

370373
# add everything together
371374
if input_imaginary:
372-
# sin base is already negative so subtract
373-
real_result = mb.sub(x=cos_windows_real, y=sin_windows_imag, before_op=before_op)
374-
imag_result = mb.add(x=sin_windows_real, y=cos_windows_imag, before_op=before_op)
375+
real_result = mb.add(x=cos_windows_real, y=sin_windows_imag, before_op=before_op)
376+
imag_result = mb.sub(x=cos_windows_imag, y=sin_windows_real, before_op=before_op)
375377
else:
376378
real_result = cos_windows_real
377-
imag_result = sin_windows_real
379+
imag_result = mb.sub(x=0., y=sin_windows_real, before_op=before_op)
380+
381+
# reduce the rank of the output
382+
if should_increase_rank:
383+
real_result = mb.squeeze(x=real_result, axes=(0,), before_op=before_op)
384+
imag_result = mb.squeeze(x=imag_result, axes=(0,), before_op=before_op)
385+
386+
if normalized and normalized.val:
387+
divisor = mb.sqrt(x=mb.cast(x=n_fft, dtype="fp32", before_op=before_op), before_op=before_op)
388+
real_result = mb.real_div(x=real_result, y=divisor, before_op=before_op)
389+
imag_result = mb.real_div(x=imag_result, y=divisor, before_op=before_op)
390+
391+
return real_result, imag_result
378392

379393
def _istft(
380394
input_real: Var,

0 commit comments

Comments
 (0)