source_separation.py
import torch
import torch.nn.functional as F
from torch import optim
from tqdm import tqdm

import inverse_utils
import generator.glow.commons as commons

EPSILON = torch.finfo(torch.float32).eps

def music_sep_batch(mixtures, genList, stft, optSpace,
lr, sigma, alpha1, alpha2, iteration,
mask=False, wiener=False, scheduler_step=800,
scheduler_gamma=0.2):
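    """Separate a batch of mixtures into sources with pretrained Glow generators.

    The generator weights are frozen and only the latent codes `zCol` (one per
    source) are optimized, so that the sum of the generated spectrograms matches
    the mixture spectrogram under a generalized KL reconstruction loss plus an
    MLE prior regularizer, computed in `z` or `x` space depending on `optSpace`.
    Returns the estimated source spectrograms and the mixture phases.
    """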
    # freeze the generator weights so that only the latent codes are optimized
numGen = len(genList)
for genUnc in genList:
for param in genUnc.parameters():
param.requires_grad = False
    # compute the mixture spectrogram (undoing the log compression) and its phase
    mixSpecs, mixPhases = inverse_utils.get_spec(mixtures, stft)  # 513 x T
    mixSpecs = F.pad(mixSpecs.unsqueeze(0), (0, 0, 0, 1), "constant", 0)  # pad to 514 x T
    batch_size = mixSpecs.shape[0]
    segLen = (mixSpecs.shape[-1] // 2) * 2  # the Glow model requires an even time dimension
    segLenTensor = torch.LongTensor([segLen] * batch_size).cuda()
    mixSpecs = mixSpecs[:, :, :segLen].cuda().requires_grad_(False)
    mixSpecs = torch.abs(mixSpecs) + EPSILON  # magnitude guard, applied once instead of on every iteration
    mixPhases = mixPhases[:, :, :segLen].requires_grad_(False)
# zCol of shape (batch_size, num_sources, *spec_shape...)
zCol = torch.randn((batch_size, numGen, mixSpecs.shape[-2], segLen),
dtype=torch.float, device='cuda')
zCol = (sigma * zCol).requires_grad_(True)
optimizer = optim.Adam([zCol], lr=lr)
scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=scheduler_step, gamma=scheduler_gamma)
    # optimization loop: synthesize sources from the latents and fit their sum to the mixture
for i in tqdm(range(iteration)):
xCol = []
mixSynSpecs = 0
logdets = []
z_masks = []
for j in range(numGen):
gen = genList[j]
xTemp, logdet, z_mask = gen(zCol[:, j, :, :], segLenTensor, gen=True)
logdets.append(-logdet) # logdet in reverse gives log|dx/dz|, we want log|dz/dx|
z_masks.append(z_mask)
            if mask:
                if i > 3:  # warm-up: use an all-ones mask for the first few iterations
                    # frame-level energy ratio of this source to the mixture, capped at 1
                    maskTemp = torch.div(torch.sum(xTemp, dim=1),
                                         torch.max(torch.sum(xTemp, dim=1),
                                                   torch.sum(mixSpecs, dim=1)) + EPSILON).unsqueeze(1)
                else:
                    maskTemp = torch.ones((batch_size, 1, segLen), dtype=torch.float, device='cuda')
                mixSynSpecs += xTemp * maskTemp
                xCol.append(xTemp * maskTemp)
            else:
                mixSynSpecs += xTemp
                xCol.append(xTemp)
        mixSynSpecs = torch.abs(mixSynSpecs) + EPSILON
        # generalized KL divergence between the mixture and the synthesized mixture
        loss_kl = (mixSpecs * torch.log(mixSpecs / mixSynSpecs + EPSILON) - mixSpecs + mixSynSpecs)
        # exclude the zero-padded 514th frequency bin from the loss
        loss_mask = torch.ones((batch_size, mixSpecs.shape[-2], segLen),
                               dtype=torch.float, device='cuda')
        loss_mask[:, -1, :] = 0.0
        loss_rec = (loss_kl * loss_mask).mean()
        # regularization: negative log-likelihood of the latents under the prior
        loss_r = 0.0
        for j in range(numGen):
            if optSpace == 'z':
                lss = 0.5 * torch.sum(zCol[:, j, :, :] ** 2)  # negative normal log-likelihood w/o the constant term
                l_mle = lss / torch.sum(torch.ones_like(zCol[:, j, :, :]) * z_masks[j])  # average over batch, channel and time axes
            elif optSpace == 'x':
                l_mle = commons.mle_loss(zCol[:, j, :, :], logdets[j], z_masks[j])  # logdets and z_masks are indexed by source
            else:
                raise ValueError(f"unknown optSpace: {optSpace!r} (expected 'z' or 'x')")
            loss_r += l_mle
        loss = alpha1 * loss_rec + alpha2 * loss_r
optimizer.zero_grad()
loss.backward()
optimizer.step()
scheduler.step()
    if wiener:
        # Wiener-style filtering: rescale each source estimate by the ratio of the
        # true mixture magnitude to the synthesized mixture magnitude
        xCol_wiener = [torch.mul(torch.div(x, mixSynSpecs), mixSpecs) for x in xCol]
        return torch.stack(xCol_wiener), mixPhases
    else:
        return torch.stack(xCol), mixPhases
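

# ---------------------------------------------------------------------------
# Hypothetical usage sketch (not part of the original file). How one might
# call music_sep_batch. The STFT object and the Glow generator list below are
# assumptions: their real classes, checkpoints, and constructor arguments live
# elsewhere in this repository, and the hyperparameter values are illustrative.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    from scipy.io.wavfile import read

    # load a mixture waveform and normalize 16-bit PCM to [-1, 1] (path is a placeholder)
    sr, wav = read("mixture.wav")
    mixtures = torch.FloatTensor(wav.astype('float32') / 32768.0)

    stft = ...       # STFT helper expected by inverse_utils.get_spec (assumed)
    genList = [...]  # one pretrained Glow generator per source (assumed)

    specs, phases = music_sep_batch(
        mixtures, genList, stft, optSpace='z',
        lr=0.05, sigma=1.0, alpha1=1.0, alpha2=0.1, iteration=1000,
        mask=True, wiener=True)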