Skip to content
Open
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 24 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,30 @@ with fused residual and skip connections.
python3 inference.py -f mel_files.txt -w checkpoints/waveglow_10000 -o . --is_fp16 -s 0.6
```

## Fine-tuning the official checkpoint with your own data

The "official" checkpoint above was trained using an older version of the code.
Therefore, you need to use `glow_old.py` to continue training from the official
checkpoint:

1. Download our [published model]
2. Update the checkpoint to comply with recent code modifications:

`python convert_model.py waveglow_old.pt waveglow_old_updated.pt`

3. Perform steps 1 and 2 from the section above

4. Set `"checkpoint_path": "./waveglow_old_updated.pt"` in `config.json`

5. Train your WaveGlow network with the `OLD_GLOW=1` environment variable set,
   as shown below (not yet tested with `distributed.py`)

```command
mkdir checkpoints
OLD_GLOW=1 python train.py -c config.json
```


[//]: # (TODO)
[//]: # (PROVIDE INSTRUCTIONS FOR DOWNLOADING LJS)
[pytorch 1.0]: https://github.com/pytorch/pytorch#installation
Expand Down
2 changes: 1 addition & 1 deletion glow.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,7 @@ def forward(self, z, reverse=False):
return z
else:
# Forward computation
log_det_W = batch_size * n_of_groups * torch.logdet(W)
log_det_W = batch_size * n_of_groups * torch.slogdet(W)[1]
z = self.conv(z)
return z, log_det_W

Expand Down
43 changes: 21 additions & 22 deletions glow_old.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,11 +116,9 @@ def __init__(self, n_mel_channels, n_flows, n_group, n_early_every,
self.n_remaining_channels = n_remaining_channels # Useful during inference

def forward(self, forward_input):
return None
"""
forward_input[0] = audio: batch x time
forward_input[1] = upsamp_spectrogram: batch x n_cond_channels x time
"""
forward_input[0] = mel_spectrogram: batch x n_mel_channels x frames
forward_input[1] = audio: batch x time
"""
spect, audio = forward_input

Expand All @@ -135,39 +133,40 @@ def forward(self, forward_input):

audio = audio.unfold(1, self.n_group, self.n_group).permute(0, 2, 1)
output_audio = []
s_list = []
s_conv_list = []
log_s_list = []
log_det_W_list = []

for k in range(self.n_flows):
if k%4 == 0 and k > 0:
output_audio.append(audio[:,:self.n_multi,:])
audio = audio[:,self.n_multi:,:]
if k % self.n_early_every == 0 and k > 0:
output_audio.append(audio[:,:self.n_early_size,:])
audio = audio[:,self.n_early_size:,:]

# project to new basis
audio, s = self.convinv[k](audio)
s_conv_list.append(s)
audio, log_det_W = self.convinv[k](audio)
log_det_W_list.append(log_det_W)

n_half = int(audio.size(1)/2)

if k%2 == 0:
audio_0 = audio[:,:n_half,:]
audio_1 = audio[:,n_half:,:]
else:
audio_1 = audio[:,:n_half,:]
audio_0 = audio[:,n_half:,:]

output = self.nn[k]((audio_0, spect))
s = output[:, n_half:, :]
output = self.WN[k]((audio_0, spect))
log_s = output[:, n_half:, :]
b = output[:, :n_half, :]
audio_1 = torch.exp(s)*audio_1 + b
s_list.append(s)
audio_1 = torch.exp(log_s)*audio_1 + b
log_s_list.append(log_s)

if k%2 != 0:
audio_0, audio_1 = audio_1, audio_0

audio = torch.cat([audio_0, audio_1],1)

if k%2 == 0:
audio = torch.cat([audio[:,:n_half,:], audio_1],1)
else:
audio = torch.cat([audio_1, audio[:,n_half:,:]], 1)
output_audio.append(audio)
return torch.cat(output_audio,1), s_list, s_conv_list
"""
return torch.cat(output_audio,1), log_s_list, log_det_W_list


def infer(self, spect, sigma=1.0):
spect = self.upsample(spect)
Expand Down
19 changes: 16 additions & 3 deletions train.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,14 +35,27 @@
#=====END: ADDED FOR DISTRIBUTED======

from torch.utils.data import DataLoader
from glow import WaveGlow, WaveGlowLoss
if 'OLD_GLOW' in os.environ and os.environ['OLD_GLOW'] == '1':
print("Warning! Using old_glow.py instead of glow.py for training")
from glow_old import WaveGlow
else:
from glow import WaveGlow

from glow import WaveGlowLoss
from mel2samp import Mel2Samp

def load_checkpoint(checkpoint_path, model, optimizer):
assert os.path.isfile(checkpoint_path)
checkpoint_dict = torch.load(checkpoint_path, map_location='cpu')
iteration = checkpoint_dict['iteration']
optimizer.load_state_dict(checkpoint_dict['optimizer'])

if 'iteration' in checkpoint_dict:
iteration = checkpoint_dict['iteration']
else:
iteration = 0

if 'optimizer' in checkpoint_dict:
optimizer.load_state_dict(checkpoint_dict['optimizer'])

model_for_loading = checkpoint_dict['model']
model.load_state_dict(model_for_loading.state_dict())
print("Loaded checkpoint '{}' (iteration {})" .format(
Expand Down