This work would not have been possible without cloud resources provided by Google's TPU Research Cloud (TRC) program. I also thank the TRC support team for quickly resolving whatever issues I had: you're awesome!
This is a PyTorch implementation of the LEAF audio frontend [1], made using the official TensorFlow implementation as a direct reference.
This implementation supports training on TPUs using `torch-xla`.
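For reference, a typical `torch-xla` training step looks like the following. This is a generic sketch of standard `torch-xla` usage, not this repo's training loop; the `nn.Linear` stand-in model and dummy batch are purely illustrative.

```python
import torch
import torch.nn as nn
import torch_xla.core.xla_model as xm

device = xm.xla_device()                   # acquire a TPU core as a torch device
model = nn.Linear(64, 10).to(device)       # stand-in model, not the repo's classifier
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
loss_fn = nn.CrossEntropyLoss()

x = torch.randn(8, 64, device=device)      # dummy batch
y = torch.randint(0, 10, (8,), device=device)

optimizer.zero_grad()
loss_fn(model(x), y).backward()
xm.optimizer_step(optimizer)               # all-reduce grads (if replicated) + step
xm.mark_step()                             # cut the XLA graph and execute it
```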
- Will be evaluated on the AudioSet, SpeechCommands and VoxCeleb1 datasets, and pretrained weights will be made available.
- Currently, `torch-xla` has some issues with certain `complex64` operations: `torch.view_as_real(comp)`, `comp.real`, `comp.imag`, as highlighted in issue #3070. These are used primarily for generating Gabor impulse responses. To bypass this shortcoming, an alternate implementation using manual complex-number operations is provided (see the sketch after this list).
- Matched performance on SpeechCommands; experiments on other datasets are ongoing.
- More details and commands to replicate the experiments will be added shortly.
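To illustrate the workaround, here is a minimal sketch of computing a Gabor impulse response with explicit real/imaginary parts instead of `complex64` tensors. The function name, signature, and shapes are illustrative assumptions, not the repo's exact code.

```python
import torch

def gabor_impulse_response_manual(t: torch.Tensor,
                                  center: torch.Tensor,
                                  fwhm: torch.Tensor) -> torch.Tensor:
    """Real/imag parts of Gabor filters, with no complex dtypes (hypothetical).

    t: (kernel_size,) time indices; center, fwhm: (num_filters,).
    Returns (num_filters, kernel_size, 2), matching torch.view_as_real's layout.
    """
    # Gaussian envelope exp(-t^2 / (2 * fwhm^2)), one row per filter
    envelope = torch.exp(-torch.outer(1.0 / (2.0 * fwhm ** 2), t ** 2))
    # exp(1j * center * t) expanded via Euler's formula: cos + i*sin,
    # so no complex64 tensor is ever materialized
    phase = torch.outer(center, t)
    real = envelope * torch.cos(phase)
    imag = envelope * torch.sin(phase)
    return torch.stack([real, imag], dim=-1)
```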
```
torch >= 1.9.0
torchaudio >= 0.9.0
torch-audiomentations==0.9.0
SoundFile==0.10.3.post1
[Optional] torch_xla == 1.9
```
Additional dependencies include [WavAugment](https://github.com/facebookresearch/WavAugment).
All experiments on VoxCeleb1 and SpeechCommands were repeated at least 5 times, and 95% confidence intervals are reported.
Model | Dataset | Metric | Features | Official | This repo | Weights |
---|---|---|---|---|---|---|
EfficientNet-b0 | SpeechCommands v2 | Accuracy | LEAF | 93.4±0.3 | 94.5±0.3 | ckpt |
ResNet-18 | SpeechCommands v2 | Accuracy | LEAF | N/A | 94.05±0.3 | ckpt |
EfficientNet-b0 | VoxCeleb1 | Accuracy | LEAF | 33.1±0.7 | 40.9±1.8 | ckpt |
ResNet-18 | VoxCeleb1 | Accuracy | LEAF | N/A | 44.7±2.9 | ckpt |
- ResNet-18 likely works better on VoxCeleb1 simply because VoxCeleb1 is a more difficult task than SpeechCommands and ResNet-18 has more parameters.
To evaluate how non-Mel initialization schemes for `complex_conv` perform, experiments were repeated with the `xavier_normal`, `kaiming_normal` and `randn` init schemes on the SpeechCommands dataset (a sketch of how such re-initialization might look follows the table below).
Model | Features | Init | Test Accuracy |
---|---|---|---|
EfficientNet-b0 | LEAF | Default (Mel) | 94.5±0.3 |
EfficientNet-b0 | LEAF | randn | 84.7±1.6 |
EfficientNet-b0 | LEAF | kaiming_normal | 84.7±2.3 |
EfficientNet-b0 | LEAF | xavier_normal | 79.1±0.7 |
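For context, swapping the initializer amounts to re-initializing the frontend's filter parameters before training. A hypothetical helper might look like this; the `complex_conv` attribute and its `weight` are assumptions about the module layout, not the repo's verified API.

```python
import torch
import torch.nn as nn

def reinit_complex_conv(frontend: nn.Module, scheme: str = "xavier_normal"):
    # Hypothetical helper: `frontend.complex_conv.weight` is an assumption
    # about where this repo stores the learnable Gabor filter parameters.
    weight = frontend.complex_conv.weight
    with torch.no_grad():
        if scheme == "xavier_normal":
            nn.init.xavier_normal_(weight)
        elif scheme == "kaiming_normal":
            nn.init.kaiming_normal_(weight)
        elif scheme == "randn":
            weight.normal_()
        else:
            raise ValueError(f"unknown init scheme: {scheme}")
```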
- Download and extract the desired ckpt from Results.
```python
import os
import pickle

import torch

from models.classifier import Classifier

results_dir = "<path to results folder>"
hparams_path = os.path.join(results_dir, "hparams.pickle")
ckpt_path = os.path.join(results_dir, "ckpts", "<checkpoint.pth>")

# load the checkpoint and the hyperparameters it was trained with
checkpoint = torch.load(ckpt_path)
with open(hparams_path, "rb") as fp:
    hparams = pickle.load(fp)

# build the classifier from the saved config and restore its weights;
# printing the result shows missing/unexpected keys, if any
model = Classifier(hparams.cfg)
print(model.load_state_dict(checkpoint["model_state_dict"]))

# to access just the pretrained LEAF frontend
frontend = model.features
```
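Assuming 16 kHz mono input (a guess at the expected input shape, to be checked against the repo), extracting features with the loaded frontend would then look like:

```python
# dummy batch of four 1-second mono clips at 16 kHz (assumed input layout)
wav = torch.randn(4, 1, 16000)
with torch.no_grad():
    feats = frontend(wav)
print(feats.shape)  # time-frequency features produced by the LEAF frontend
```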
[1] If you use this repository, kindly cite the LEAF paper:
```bibtex
@article{zeghidour2021leaf,
  title={LEAF: A Learnable Frontend for Audio Classification},
  author={Zeghidour, Neil and Teboul, Olivier and de Chaumont Quitry, F{\'e}lix and Tagliasacchi, Marco},
  journal={ICLR},
  year={2021}
}
```
Please also consider citing this implementation using the citation widget in the sidebar.