This repository contains the code for training many of the models in nussl. It contains a few extra utilities, such as a simple way to schedule jobs with GPUs and recipes for creating datasets.
conda env create -f conda.yml
conda activate nussl-models
Set the paths appropriately in env.sh. Please ensure all paths in env.sh are absolute paths. Then source env.sh.
If you update the conda.yml file to say, add new requirements, update the conda environment after activating it:
conda env update -f conda.yml --prune
Note that the --prune option will get rid of any packages that were removed in conda.yml.
All of the experiments in this repository use gin-config, which is a flexible way to specify hyperparameters and configuration in a hierarchical and reusable fashion. There are 5 main functions whose arguments are all configured with gin:
train: This function will train a model. This needs datasets (seewham/wham8k.ginfor an example), and model and training parameters (seewham/exp/dpcl.ginfor an example).cache: This function takes the datasets and caches them to the desired location.evaluate: This function evaluates a separation algorithm on the test dataset.instantiate: This function instantiates an experiment by compiling a gin config file on the fly, possibly sweeping across some hyperparameters, and writing it all to an output directory.analyze: This analyzes the experiment after evaluation. Spits out a report card with the metrics.
Finally all will run train, then evaluate, then analyze.
The main.py script takes 3 arguments and 1 positional argument. It takes an
environment config file (--environment_config, -env) which sets all the
machine-specific variables like path to data directories and so on, a
data config file (--data_config, -dat) which describes all of the datasets,
and an experiment config file (--experiment_config, -exp) which describes
the model settings, the optimizer, the algorithm settings, and whatever else
is needed. Generally, you'll want to follow a process like this:
python main.py --dat [path/to/data.gin] --env [path/to/env.gin] --exp [path/to/exp.gin] cache
This will cache all of the datasets so that things train faster. Now, go and set the key
that says cache_populated = False to cache_populated = True in your data.gin file.
Then:
python main.py --dat [path/to/data.gin] --env [path/to/env.gin] --exp [path/to/exp.gin] instantiate
This will instantiate your experiments into a single config file and place that config
into a folder. The location of the folder depends on the location of the path to the
experiment config. If the config file, for example, was at wham/exp/dpcl.gin, then
the output will be at wham/exp/out/dpcl/run0:[descriptive_string]/config.gin. If you
had a sweep in your experiment config file, then you'll see multiple of these. For
example, if you had lines in your experiment config like this:
sweep.parameters = {
'add_clip_gradient_handler.clip_value': [1e-4, 1e-3, 1e-2]
}
Then, the output configs will be at:
wham/exp/out/dpcl/run0:clip_value:0.0001/config.gin
wham/exp/out/dpcl/run1:clip_value:0.001/config.gin
wham/exp/out/dpcl/run2:clip_value:0.01/config.gin
Now, you're ready to run experiments. To run, say the first experiment you would do:
python main.py --exp wham/exp/out/dpcl/run0:clip_value:0.0001/config.gin train
python main.py --exp wham/exp/out/dpcl/run0:clip_value:0.0001/config.gin evaluate
python main.py --exp wham/exp/out/dpcl/run0:clip_value:0.0001/config.gin analyze
Or to do all of them in one fell swoop:
python main.py --exp wham/exp/out/dpcl/run0:clip_value:0.0001/config.gin all
At the end, you'll see some output like this:
❯ python main.py --exp wham/exp/out/dpcl/run0:clip_value:0.0001/config.gin all
/Users/prem/miniconda3/envs/nussl-models/lib/python3.7/site-packages/torch/nn/modules/rnn.py:50: UserWarning: dropout option adds dropout after all but last recurrent layer, so non-zero dropout expects num_layers greater than 1, but got dropout=0.3 and num_layers=1
"num_layers={}".format(dropout, num_layers))
2020-04-01:02:36:57,921 [train.py:39] SeparationModel(
(layers): ModuleDict(
(embedding): Embedding(
(linear): Linear(in_features=100, out_features=2580, bias=True)
)
(log_spectrogram): AmplitudeToDB()
(normalization): BatchNorm(
(batch_norm): BatchNorm1d(1, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
)
(recurrent_stack): RecurrentStack(
(rnn): LSTM(129, 50, batch_first=True, dropout=0.3, bidirectional=True)
)
)
)
Number of parameters: 332982
2020-04-01:02:36:57,922 [train.py:45] Saving to /Users/prem/Dropbox/research/nussl-models/wham/exp/out/dpcl/run0:clip_value:0.0001
2020-04-01:02:36:57,924 [engine.py:837] Engine run starting with max_epochs=1.
2020-04-01:02:36:59,235 [engine.py:939] Epoch[1] Complete. Time taken: 00:00:01
2020-04-01:02:36:59,236 [engine.py:837] Engine run starting with max_epochs=1.
2020-04-01:02:36:59,875 [engine.py:939] Epoch[1] Complete. Time taken: 00:00:00
2020-04-01:02:36:59,876 [engine.py:947] Engine run complete. Time taken 00:00:00
2020-04-01:02:36:59,907 [trainer.py:340]
EPOCH SUMMARY
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
- Epoch number: 0001 / 0001
- Training loss: 0.553462
- Validation loss: 0.552009
- Epoch took: 00:00:01
- Time since start: 00:00:01
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
Saving to /Users/prem/Dropbox/research/nussl-models/wham/exp/out/dpcl/run0:clip_value:0.0001/checkpoints/best.model.pth.
Output @ /Users/prem/Dropbox/research/nussl-models/wham/exp/out/dpcl/run0:clip_value:0.0001
2020-04-01:02:36:59,911 [engine.py:947] Engine run complete. Time taken 00:00:01
/Users/prem/miniconda3/envs/nussl-models/lib/python3.7/site-packages/nussl/separation/base/separation_base.py:71: UserWarning: input_audio_signal has no data!
warnings.warn('input_audio_signal has no data!')
/Users/prem/miniconda3/envs/nussl-models/lib/python3.7/site-packages/nussl/core/audio_signal.py:445: UserWarning: Initializing STFT with data that is non-complex. This might lead to weird results!
warnings.warn('Initializing STFT with data that is non-complex. '
100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 5/5 [00:16<00:00, 3.37s/it]
┌────────────────────┬───────────────────┬────────────────────┐
│ │ OVERALL (N = 20) │ │
╞════════════════════╪═══════════════════╪════════════════════╡
│ SAR │ SDR │ SIR │
├────────────────────┼───────────────────┼────────────────────┤
│ 13.487467288970947 │ -5.10146589204669 │ -5.028965532779694 │
└────────────────────┴───────────────────┴────────────────────┘