Skip to content

Commit

Permalink
First commit
Browse files Browse the repository at this point in the history
  • Loading branch information
jmlemercier committed Nov 16, 2023
0 parents commit 5fcac4d
Show file tree
Hide file tree
Showing 20 changed files with 815 additions and 0 deletions.
3 changes: 3 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
models/*
results/E2EpWPE+DNNPF_HA
results/E2EpWPE_HA
Binary file added 2sderev
Binary file not shown.
127 changes: 127 additions & 0 deletions derev.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,127 @@
import torch
import argparse
import os
from os.path import join
import soundfile as sf
import json
import glob

from dnn import backbones
from dsp import wpe, pf

def ensure_dir(d):
    """Create directory *d* (and any missing parents) if it does not exist.

    Uses ``exist_ok=True`` instead of the check-then-create pair, which is
    race-prone when two processes create the same results directory.
    """
    os.makedirs(d, exist_ok=True)

class Dereverberator():
    """Two-stage online dereverberation pipeline.

    Stage 1: a DNN (``dnn_wpe``) estimates the clean-speech periodogram that
    drives a multi-channel linear WPE filter (``wpe``).
    Stage 2: a second DNN (``dnn_pf``) estimates speech / interference
    periodograms that drive a single-channel post-filter (``pf``).

    All components are instantiated by name from the ``params`` dict
    (classes looked up in ``dsp.wpe``, ``dsp.pf`` and ``dnn.backbones``),
    with DNN weights and normalization statistics loaded from checkpoints.
    """

    def __init__(self, params):
        # One-sided STFT bin count and expected channel count; checked
        # against the input in process().
        self.F = params["n_fft"] // 2 + 1
        self.channels = params["channels"]

        # Stage-1 filter + its periodogram-estimating DNN.
        self.wpe = getattr(wpe, params["wpe"]["class"])(params, params["wpe"])
        self.dnn_wpe = getattr(backbones, params["dnn_wpe"]["class"])(**params["dnn_wpe"])
        self.dnn_wpe.load_state_dict(torch.load(params["dnn_wpe"]["ckpt"]))
        # Stage-2 post-filter + its DNN.
        self.pf = getattr(pf, params["pf"]["class"])(params, params["pf"])
        self.dnn_pf = getattr(backbones, params["dnn_pf"]["class"])(**params["dnn_pf"])
        self.dnn_pf.load_state_dict(torch.load(params["dnn_pf"]["ckpt"]))

        # Input-normalization statistics for both DNNs.
        self.dnn_wpe.load_stats(params["dnn_wpe"]["stats"])
        self.dnn_pf.load_stats(params["dnn_pf"]["stats"])

    def to(self, device):
        """Move all sub-modules to *device*; remembers it as ``self.device``."""
        self.device = device
        self.wpe = self.wpe.to(device)
        self.dnn_wpe = self.dnn_wpe.to(device)
        self.pf = self.pf.to(device)
        # NOTE(review): result not re-assigned, unlike the lines above —
        # harmless if dnn_pf is an nn.Module (Module.to is in-place), but
        # inconsistent; confirm dnn_pf is always an nn.Module.
        self.dnn_pf.to(device)
        return self

    def reset_state(self):
        """Reset all recursive/online state; call once per utterance."""
        self.dnn_wpe.reset_state()
        self.wpe.reset_state()
        self.dnn_pf.reset_state()
        self.pf.reset_state()

    def process(self, Y, **kwargs):
        """Dereverberate a complex STFT ``Y`` of shape [F, D, T].

        Processes frame-by-frame (online). Returns the enhanced STFT on CPU,
        same shape as the input.
        """
        F, D, T = Y.size()

        assert self.F == F, f"Mismatch in frequency bins : {self.F} model vs {F} input"
        assert self.channels == D, f"Mismatch in channels : {self.channels} model vs {D} input"

        X = torch.zeros_like(Y) #[F, D, T]
        with torch.no_grad():
            # Two passes over the signal: the first warms up the recursive
            # WPE statistics, the second produces the kept output (optional
            # trick for better performance on short files).
            for _ in range(2): #initialize WPE statistics to get the best performance (optional)
                for t in range(T):
                    Y_update = Y[..., t] #F,D

                    # DNN-WPE: estimate the clean magnitude for this frame.
                    clean_mag, _ = self.dnn_wpe(Y_update.unsqueeze(-1), return_interference=False) #F,1,1
                    clean_periodogram = torch.square(clean_mag.squeeze()) #F
                    # WPE: one online linear-filter update driven by the estimate.
                    X[..., t] = self.wpe.step_online(Y_update, clean_periodogram.squeeze()) #[F, D]
                    # DNN-PF: estimate speech and interference magnitudes.
                    speech_mag, interference_mag = self.dnn_pf(X[..., t].unsqueeze(-1), singlechannel=False, return_interference=True) #F,D,1
                    speech_periodogram, interference_periodogram = torch.square(speech_mag.squeeze()), torch.square(interference_mag.squeeze()) #F,D
                    # PF: apply the non-linear post-filter gain to this frame.
                    X[..., t] = self.pf.step_online(X[..., t], speech_periodogram, interference_periodogram) #[F, D]

        # Zero the two lowest frequency bins (DC and first bin).
        X[: 2] = .0 + 0j
        return X.cpu()



if __name__ == "__main__":

    # All parameterizations live in one JSON file; the chosen key selects
    # the model/checkpoint/filter configuration.
    with open("derev_params.json", "r") as j:
        params_dict = json.load(j)

    parser = argparse.ArgumentParser()
    parser.add_argument('--speech', type=str, help="Input multi-channel file")
    parser.add_argument('--config', type=str, choices=list(params_dict.keys()),
        default='wpe+pf_ha', help="Choice of parameterization from derev_params.json")
    args = parser.parse_args()

    params = params_dict[args.config]

    print("Dereverberating file/dir...")

    # Analysis/synthesis window; must match the one used at training time.
    if params["window_type"] == "hann":
        window = torch.hann_window(params["n_fft"])
    elif params["window_type"] == "sqrt_hann":
        window = torch.sqrt(torch.hann_window(params["n_fft"]))
    else:
        raise NotImplementedError(f"Unsupported window_type: {params['window_type']!r}")

    istft_kwargs = {
        "n_fft": params["n_fft"],
        "hop_length": params["hop_length"],
        "window": window,
        "center": True
    }
    stft_kwargs = {
        **istft_kwargs,
        "return_complex": True
    }

    # --speech may point to a single file or to a directory of .wav files.
    if os.path.isdir(args.speech):
        speech = sorted(glob.glob(join(args.speech, "*.wav")))
    else:
        speech = [args.speech]

    # Fall back to CPU when no GPU is present (previously hard-coded cuda:0,
    # which crashed on CPU-only machines). Loop-invariant work — device
    # transfer and output-directory creation — is hoisted out of the loop.
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    dereverberator = Dereverberator(params).to(device)

    out_dir = join("results", params["name"])
    ensure_dir(out_dir)

    for speech_path in speech:

        # Recursive WPE/PF statistics must be reset between utterances.
        dereverberator.reset_state()

        y, sr = sf.read(speech_path)  # soundfile returns (frames, channels)
        # transpose -> (channels, frames) for torch.stft, then put the
        # frequency axis first: (freq, channels, frames).
        Y = torch.stft(torch.from_numpy(y).transpose(0, 1), **stft_kwargs)
        Y = Y.permute(1, 0, 2)
        X = dereverberator.process(Y.to(dereverberator.device))
        X = X.permute(1, 0, 2)
        x = torch.istft(X, **istft_kwargs).transpose(0, 1).numpy()

        sf.write(join(out_dir, os.path.basename(speech_path)), x, sr)
122 changes: 122 additions & 0 deletions derev_params.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,122 @@
{
"wpe_ci": {
"name": "E2EpWPE_CI",
"fs": 16000,
"channels": 2,
"n_fft": 512,
"hop_length": 128,
"window_type": "sqrt_hann",
"wpe": {
"class": "RLSWPE",
"taps": 10,
"delay": 2,
"alpha": 0.99,
"eps_wpe": 1e-3
},
"dnn_wpe": {
"class": "LSTMNet",
"ckpt": "models/E2Ep_WPE_L1_CI_pra_crop.WPENet.pt",
"stats": "stats/reverberant/tr_{}_noiseless_reverberant_abs.pt"
},
"pf": {
"class": "Bypass"
},
"dnn_pf": {
"class": "Bypass",
"ckpt": "/export/home/lemercier/code/nwpe/.logs/subnetworks/void.pt",
"stats": "none"
}
},
"wpe+pf_ci": {
"name": "E2EpWPE+DNNPF_CI",
"fs": 16000,
"channels": 2,
"n_fft": 512,
"hop_length": 128,
"window_type": "sqrt_hann",
"wpe": {
"class": "RLSWPE",
"taps": 10,
"delay": 2,
"alpha": 0.99,
"eps_wpe": 1e-3
},
"dnn_wpe": {
"class": "LSTMNet",
"ckpt": "models/E2Ep_WPE_L1_CI_pra_crop.WPENet.pt",
"stats": "stats/reverberant/tr_{}_noiseless_reverberant_abs.pt"
},
"pf": {
"class": "WienerFilter",
"alpha_s": 0.20,
"alpha_n": 0.20,
"gmin": -20,
"bleeding": 0.00
},
"dnn_pf": {
"class": "LSTMNet",
"ckpt": "models/LSTM_L1_CI_pra_crop.PFNet.pt",
"stats": "stats/processed_CI/tr_{}_noiseless_reverberant_abs.pt"
}
},
"wpe_ha": {
"name": "E2EpWPE_HA",
"fs": 16000,
"channels": 2,
"n_fft": 512,
"hop_length": 128,
"window_type": "sqrt_hann",
"wpe": {
"class": "RLSWPE",
"taps": 10,
"delay": 5,
"alpha": 0.99,
"eps_wpe": 1e-3
},
"dnn_wpe": {
"class": "LSTMNet",
"ckpt": "models/E2Ep_WPE_L1_HA_pra_crop.WPENet.pt",
"stats": "stats/reverberant/tr_{}_noiseless_reverberant_abs.pt"
},
"pf": {
"class": "Bypass"
},
"dnn_pf": {
"class": "Bypass",
"ckpt": "/export/home/lemercier/code/nwpe/.logs/subnetworks/void.pt",
"stats": "none"
}
},
"wpe+pf_ha": {
"name": "E2EpWPE+DNNPF_HA",
"fs": 16000,
"channels": 2,
"n_fft": 512,
"hop_length": 128,
"window_type": "sqrt_hann",
"wpe": {
"class": "RLSWPE",
"taps": 10,
"delay": 5,
"alpha": 0.99,
"eps_wpe": 1e-3
},
"dnn_wpe": {
"class": "LSTMNet",
"ckpt": "models/E2Ep_WPE_L1_HA_pra_crop.WPENet.pt",
"stats": "stats/reverberant/tr_{}_noiseless_reverberant_abs.pt"
},
"pf": {
"class": "WienerFilter",
"alpha_s": 0.20,
"alpha_n": 0.20,
"gmin": -20,
"bleeding": 0.00
},
"dnn_pf": {
"class": "LSTMNet",
"ckpt": "models/LSTM_L1_HA_pra_crop.PFNet.pt",
"stats": "stats/processed_HA/tr_{}_noiseless_reverberant_abs.pt"
}
}
}
38 changes: 38 additions & 0 deletions dnn/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
### Two-stage Dereverberation Algorithm

<img src="https://raw.githubusercontent.com/sp-uhh/2sderev/main/2sderev.png" width="500" alt="Spectrograms obtained from reverberant and dereverberated files.">

# 1. Installation

Install requirements with `pip install -r requirements.txt`
GPU-acceleration is supported for both the DNNs and the DSP operations

# 2. Usage

This code is for inference only, training loops are unfortunately not being made available
To download the models, please use [this link](https://drive.google.com/drive/folders/11r2LWqeE_EUW25MfVIp3vzzwRURQYsBy?usp=drive_link) and put the obtained `.pt` files in `./models`

To perform inference, simply use
```python3 derev.py --speech <speech_file_path> --config <config_key>```

with one of the following `config_key`:
- `wpe_ci`: End-to-end optimized multi-channel linear filter targeted for cochlear implant users (few early reflections)
- `wpe_ha`: End-to-end optimized multi-channel linear filter targeted for hearing-aid users (more early reflections)
- `wpe+pf_ci`: End-to-end optimized multi-channel linear filter + non-linear single-channel post-filter targeted for cochlear implant users (few early reflections)
- `wpe+pf_ha` (default, recommended): End-to-end optimized multi-channel linear filter + non-linear single-channel post-filter targeted for hearing-aid users (more early reflections)

# References

Please consider citing our work if you found this useful:

```
@article{lemercier2022a,
author={Lemercier, Jean-Marie and Thiemann, Joachim and Koning, Raphael and Gerkmann, Timo},
title={A neural network‐supported two‐stage algorithm for lightweight dereverberation on hearing devices},
year={2023},
journal={EURASIP Journal on Audio, Speech, and Music Processing},
volume={18},
pages={1-12},
doi={https://doi.org/10.1186/s13636-023-00285-8},
}
```
Binary file added dnn/__pycache__/backbones.cpython-38.pyc
Binary file not shown.
Loading

0 comments on commit 5fcac4d

Please sign in to comment.