-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathinverse_utils.py
66 lines (52 loc) · 1.92 KB
/
inverse_utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
import torch
from torch import optim
from tqdm import tqdm
import os
import json
import librosa
import torch.nn.functional as F
import random
import numpy as np
import pickle
import scipy
import generator.glow.utils as glowutils
import generator.glow.models as glowmodels
import generator.glow.commons as commons
# Divisor that get_spec applies to raw sample values before running the STFT.
# 32768 == 2**15; presumably the int16 full-scale value - TODO confirm against
# whatever loads the audio.
MAX_WAV_VALUE = 32768.0
def get_spec(audio, stft):
    """Compute the magnitude spectrogram and the phase of an audio array.

    Args:
        audio: NumPy array of samples. It is cast to float32 and divided by
            ``MAX_WAV_VALUE`` before the transform. Presumably shaped
            ``(1, n_samples)`` so the leading axis of the magnitude can be
            squeezed away - TODO confirm against callers.
        stft: Object whose ``stft_fn.transform`` returns a
            ``(magnitude, phase)`` pair for the scaled waveform.

    Returns:
        Tuple ``(spec, pha)``: the magnitude's ``.data`` with dimension 0
        squeezed out (when it has size 1), and the phase exactly as returned
        by the transform.
    """
    scaled = torch.FloatTensor(audio.astype(np.float32)) / MAX_WAV_VALUE
    magnitude, phase = stft.stft_fn.transform(scaled)
    # ``.data`` hands back the values without autograd history.
    return magnitude.data.squeeze(0), phase
def load_glow(glowFolder='./generator/glow/logs/', modelName="musdb", epoch=None):
    """Load a trained Glow generator and the STFT operator for its data config.

    Args:
        glowFolder: Directory that holds one sub-folder per trained model.
        modelName: Name of the model sub-folder inside ``glowFolder``.
        epoch: Identifier of the checkpoint to load, i.e. the file
            ``G_<epoch>.pth`` inside the model folder. When ``None``, the path
            returned by ``glowutils.latest_checkpoint_path`` is used instead.

    Returns:
        Tuple ``(generator, stft)``: the ``FlowGenerator`` with the checkpoint
        loaded, moved to CUDA and switched to eval mode, and a
        ``TacotronSTFT`` built from the data hyper-parameters.

    Note:
        The generator is unconditionally moved with ``.cuda()``, so a CUDA
        device must be available.
    """
    modelDir = os.path.join(glowFolder, modelName)
    hps = glowutils.get_hparams_from_dir(modelDir)

    # Identity check: ``== None`` would go through ``__eq__`` and misbehaves
    # for array-like arguments; ``is None`` is the reliable test.
    if epoch is None:
        checkpointPath = glowutils.latest_checkpoint_path(modelDir)
    else:
        checkpointPath = os.path.join(modelDir, 'G_' + str(epoch) + '.pth')

    # The model section of the hyper-parameters is forwarded as keyword args.
    generator = glowmodels.FlowGenerator(
        n_speakers=1,
        out_channels=hps.data.n_ipt_channels,
        **hps.model
    ).cuda()
    glowutils.load_checkpoint(checkpointPath, generator)
    generator.eval()

    # Build the STFT operator from the same data settings as the checkpoint.
    hparams = hps.data
    stft = commons.TacotronSTFT(
        hparams.filter_length,
        hparams.hop_length,
        hparams.win_length,
        hparams.sampling_rate,
    )
    return generator, stft
def resynthesize_from_spec(specIPT, mixWav, stft):
    """Invert a magnitude spectrogram to a waveform using the mixture's phase.

    Args:
        specIPT: Magnitude tensor; a leading axis is added before inversion.
        mixWav: NumPy array of mixture samples. It is cast to float32 and
            given a leading axis before the forward transform.
        stft: Object exposing ``stft_fn.transform`` (returning a
            ``(magnitude, phase)`` pair) and ``stft_fn.inverse``.

    Returns:
        NumPy array: the first entry along the leading axis of the inverse
        transform's output, copied to the CPU.

    NOTE(review): ``mixWav`` is passed to the transform without any
    ``MAX_WAV_VALUE`` scaling - confirm callers supply it in the intended
    range.
    """
    # Only the phase of the mixture is kept; its magnitude is discarded.
    batchedMix = torch.FloatTensor(mixWav.astype(np.float32)).unsqueeze(0)
    _unusedMagnitude, mixPhase = stft.stft_fn.transform(batchedMix)

    # Inverse STFT: pair the given magnitude with the mixture phase.
    waveBatch = stft.stft_fn.inverse(specIPT.unsqueeze(0), mixPhase)
    return waveBatch.cpu().numpy()[0]
def main():
    """Smoke-test entry point: load the default Glow model on GPU index 3."""
    # Restrict CUDA to physical device 3; this is set before load_glow makes
    # its ``.cuda()`` call.
    os.environ.update(CUDA_VISIBLE_DEVICES="3")
    load_glow()


# Run the smoke test only when the file is executed as a script.
if __name__ == "__main__":
    main()