Skip to content

Commit

Permalink
execute pylint test
Browse files Browse the repository at this point in the history
  • Loading branch information
Tech. Prototyping그룹 정승환 committed Dec 27, 2018
1 parent 601e4e1 commit 4b29569
Show file tree
Hide file tree
Showing 5 changed files with 243 additions and 94 deletions.
65 changes: 48 additions & 17 deletions data_loader.py
Original file line number Diff line number Diff line change
@@ -1,29 +1,60 @@
"""
Description : Set DataSet module for Wavenet
"""
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License
import os
import numpy as np
from scipy.io import wavfile
from utils import *
from mxnet import gluon, autograd, nd

from mxnet import nd
from utils import encode_mu_law
# pylint: disable=invalid-name, too-many-arguments
def load_wav(file_nm):
    """
    Description : load wav file

    Reads ``<current working directory>/data/<file_nm>`` with
    ``scipy.io.wavfile.read``.

    file_nm : name of the wav file inside the ``data`` directory
    return  : tuple ``(fs, data)`` exactly as returned by ``wavfile.read``
    """
    # Assemble the path from its components rather than concatenating
    # strings with a hard-coded '/' separator.
    wav_path = os.path.join(os.getcwd(), 'data', file_nm)
    fs, data = wavfile.read(wav_path)
    return fs, data

def data_generation(data, framerate, seq_size, mu, ctx, gen_mode=None):
    """
    Description : data generation to loading data

    Infinite generator. Every item is one randomly positioned window of
    seq_size samples taken from the peak-normalised signal, passed through
    encode_mu_law and wrapped in an NDArray created on ctx.

    data      : numpy array of audio samples, windowed along axis 0
                (replaced by a synthetic signal when gen_mode == 'sin')
    framerate : only used to size the synthetic signal (framerate*5 points)
    seq_size  : number of samples in each yielded window
    mu        : forwarded unchanged to encode_mu_law
    ctx       : context passed to nd.array for the yielded window
    gen_mode  : 'sin' -> ignore data and use a 220 Hz + 224 Hz sine mixture
                over t in [0, 5]; any other value -> use data as given
    raises    : ValueError on the first iteration when data holds fewer
                than seq_size samples
    """
    if gen_mode == 'sin':
        t = np.linspace(0, 5, framerate*5)
        data = np.sin(2*np.pi*220*t) + np.sin(2*np.pi*224*t)
    if data.shape[0] < seq_size:
        raise ValueError('data has %d samples but seq_size is %d'
                         % (data.shape[0], seq_size))
    # Take the peak in float64: abs() of the int16 minimum (-32768) overflows
    # and stays negative, which would under-estimate the peak of PCM data.
    peak = float(np.abs(np.asarray(data, dtype=np.float64)).max())
    # A silent (all-zero) signal is divided by 1 instead of 0 to avoid NaNs.
    data = data / (peak if peak > 0 else 1.0)
    # randint's upper bound is exclusive; +1 makes the final window reachable
    # and keeps data of exactly seq_size samples valid.
    n_starts = data.shape[0] - seq_size + 1
    while True:
        start = np.random.randint(0, n_starts)
        ys = encode_mu_law(data[start:start+seq_size], mu)
        yield nd.array(ys, ctx=ctx)

def data_generation_sample(data, framerate, seq_size, mu, ctx, gen_mode=None):
    """
    Description : sample data generation to loading data

    Returns a single window: the first seq_size samples of the
    peak-normalised signal, passed through encode_mu_law and wrapped in an
    NDArray created on ctx.

    data      : numpy array of audio samples, sliced along axis 0
                (replaced by a synthetic signal when gen_mode == 'sin')
    framerate : only used to size the synthetic signal (framerate*5 points)
    seq_size  : number of leading samples to keep
    mu        : forwarded unchanged to encode_mu_law
    ctx       : context passed to nd.array for the result
    gen_mode  : 'sin' -> ignore data and use a 220 Hz + 224 Hz sine mixture
                over t in [0, 5]; any other value -> use data as given
    """
    if gen_mode == 'sin':
        t = np.linspace(0, 5, framerate*5)
        data = np.sin(2*np.pi*220*t) + np.sin(2*np.pi*224*t)
    # Take the peak in float64: abs() of the int16 minimum (-32768) overflows
    # and stays negative, which would under-estimate the peak of PCM data.
    peak = float(np.abs(np.asarray(data, dtype=np.float64)).max())
    # A silent (all-zero) signal is divided by 1 instead of 0 to avoid NaNs.
    data = data / (peak if peak > 0 else 1.0)
    # The window always starts at sample 0.
    ys = encode_mu_law(data[:seq_size], mu)
    return nd.array(ys, ctx=ctx)
33 changes: 27 additions & 6 deletions main.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,33 @@
import mxnet as mx
"""
Description : main module to run code
"""
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
import argparse
from trainer import Train

def main():
"""
Description : run code using argument info
"""
parser = argparse.ArgumentParser()
parser = argparse.ArgumentParser()
parser.add_argument('--batch_size', type=int, default=64)
parser.add_argument('--epoches', type=int, default=10`)
parser.add_argument('--epoches', type=int, default=10)
parser.add_argument('--mu', type=int, default=128)
parser.add_argument('--n_residue', type=int, default=24)
parser.add_argument('--n_skip', type=int, default=128)
Expand All @@ -16,12 +37,12 @@ def main():
parser.add_argument('--use_gpu', type=bool, default=True)
parser.add_argument('--generation', type=bool, default=True)
config = parser.parse_args()

trainer = Train(config)

trainer.train()
if (config.generation):
if config.generation:
trainer.generation()

# Run the command-line entry point only when executed as a script,
# not when this module is imported.
if __name__ == "__main__":
    main()
109 changes: 70 additions & 39 deletions models.py
Original file line number Diff line number Diff line change
@@ -1,31 +1,56 @@
import mxnet as mx
from mxnet import gluon, autograd, nd
from mxnet.gluon import nn,utils
import mxnet.ndarray as F

# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

"""
Module: WaveNet network module
"""
from mxnet import nd
from mxnet.gluon import nn
import mxnet.ndarray as F
# pylint: disable=invalid-name, too-many-arguments, arguments-differ, attribute-defined-outside-init, too-many-instance-attributes, invalid-sequence-index, no-self-use
class One_Hot(nn.Block):
    """
    Description : generate one hot result

    Block that maps class indices to one-hot rows by indexing a
    depth x depth one-hot table built with nd.one_hot.

    depth : number of classes, i.e. the length of every one-hot row
    """
    def __init__(self, depth):
        super(One_Hot, self).__init__()
        self.depth = depth

    def forward(self, X_in):
        """
        Description : return the one-hot rows selected by X_in

        X_in : NDArray of class indices (presumably values in
               [0, depth) -- confirm against the caller)
        """
        # Build the lookup table inside the input's context so the table
        # and the indices are on the same device.
        with X_in.context:
            self.ones = nd.one_hot(nd.arange(self.depth), self.depth)
            return self.ones[X_in, :]

    def __repr__(self):
        return self.__class__.__name__ + "({})".format(self.depth)

class WaveNet(nn.Block):
def __init__(self, mu=256,n_residue=32, n_skip= 512, dilation_depth=10, n_repeat=5):
# mu: audio quantization size
# n_residue: residue channels
# n_skip: skip channels
# dilation_depth & n_repeat: dilation layer setup
"""
mu: audio quantization size
n_residue: residue channels
n_skip: skip channels
dilation_depth & n_repeat: dilation layer setup
"""
def __init__(self, mu=256, n_residue=32, n_skip=512, dilation_depth=10, n_repeat=5):
super(WaveNet, self).__init__()
self.dilation_depth = dilation_depth
self.dilations = [2**i for i in range(dilation_depth)] * n_repeat
self.dilations = [2**i for i in range(dilation_depth)] * n_repeat
with self.name_scope():
self.one_hot = One_Hot(mu)
self.from_input = nn.Conv1D(in_channels=mu, channels=n_residue, kernel_size=1)
Expand All @@ -34,54 +59,60 @@ def __init__(self, mu=256,n_residue=32, n_skip= 512, dilation_depth=10, n_repeat
self.skip_scale = nn.Sequential()
self.residue_scale = nn.Sequential()
for d in self.dilations:
self.conv_sigmoid.add(nn.Conv1D(in_channels=n_residue, channels=n_residue, kernel_size=2, dilation=d))
self.conv_tanh.add(nn.Conv1D(in_channels=n_residue, channels=n_residue, kernel_size=2, dilation=d))
self.skip_scale.add(nn.Conv1D(in_channels=n_residue, channels=n_skip, kernel_size=1, dilation=d))
self.residue_scale.add(nn.Conv1D(in_channels=n_residue, channels=n_residue, kernel_size=1, dilation=d))
self.conv_sigmoid.add(nn.Conv1D(in_channels=n_residue,\
channels=n_residue, kernel_size=2, dilation=d))
self.conv_tanh.add(nn.Conv1D(in_channels=n_residue,\
channels=n_residue, kernel_size=2, dilation=d))
self.skip_scale.add(nn.Conv1D(in_channels=n_residue,\
channels=n_skip, kernel_size=1, dilation=d))
self.residue_scale.add(nn.Conv1D(in_channels=n_residue,\
channels=n_residue, kernel_size=1, dilation=d))
self.conv_post_1 = nn.Conv1D(in_channels=n_skip, channels=n_skip, kernel_size=1)
self.conv_post_2 = nn.Conv1D(in_channels=n_skip, channels=mu, kernel_size=1)
def forward(self, x):
    """
    Description : run the full network on x

    x is preprocessed, pushed through every dilated layer while the
    per-layer skip outputs are collected, and the summed skips are
    postprocessed into the returned output.
    """
    with x.context:
        hidden = self.preprocess(x)
        layers = zip(self.conv_sigmoid, self.conv_tanh,
                     self.skip_scale, self.residue_scale)
        skips = []  # per-layer skip outputs, summed below
        for gate_conv, filter_conv, to_skip, to_residue in layers:
            hidden, skip = self.residue_forward(hidden, gate_conv, filter_conv,
                                                to_skip, to_residue)
            skips.append(skip)
        # Every skip is cropped to the last layer's length along axis 2
        # before summing, so all terms have the same shape.
        width = hidden.shape[2]
        total = sum([s[:, :, -width:] for s in skips])
        return self.postprocess(total)

def preprocess(self, x):
    """
    Description : module for preprocess

    One-hot encodes x, adds a leading axis, swaps the last two axes and
    applies the from_input convolution.
    """
    encoded = self.one_hot(x)
    batched = encoded.expand_dims(0)
    # presumably (1, time, mu) -> (1, mu, time); confirm against One_Hot
    swapped = F.transpose(batched, axes=(0, 2, 1))
    return self.from_input(swapped)

def postprocess(self, x):
    """
    Description : module for postprocess

    Applies relu -> conv_post_1 -> relu -> conv_post_2, drops the leading
    axis by reshaping to (shape[1], shape[2]) and returns the transpose.
    """
    hidden = self.conv_post_1(F.relu(x))
    scores = self.conv_post_2(F.relu(hidden))
    # NOTE(review): the reshape only works when the leading axis has
    # size 1 -- confirm callers never pass a larger batch.
    rows, cols = scores.shape[1], scores.shape[2]
    scores = nd.reshape(scores, (rows, cols))
    return F.transpose(scores, axes=(1, 0))

def residue_forward(self, x, conv_sigmoid, conv_tanh, skip_scale, residue_scale):
    """
    Description : module for residue forward

    Computes sigmoid(conv_sigmoid(x)) * tanh(conv_tanh(x)), then returns
    (residue_scale(gated) + cropped x, skip_scale(gated)).
    """
    pre_gate = conv_sigmoid(x)
    pre_filter = conv_tanh(x)
    gated = F.sigmoid(pre_gate) * F.tanh(pre_filter)
    skip = skip_scale(gated)
    residual = residue_scale(gated)
    # x is cropped to its trailing entries along axis 2 so that it matches
    # the length of the convolved branch before the residual add.
    cropped_input = x[:, :, -residual.shape[2]:]
    return residual + cropped_input, skip

def generate_slow(self, x, n=100):
    """
    Description : extend x by n samples, one forward pass per sample

    Each step feeds the trailing sum(self.dilations) + 1 samples to
    forward and appends the argmax (axis 1) of the last output row.
    Returns a list holding the samples of x followed by the n new ones.
    """
    with x.context:
        samples = list(x.asnumpy())
        # Number of trailing samples fed to each forward pass.
        window = sum(self.dilations) + 1
        for _ in range(n):
            recent = nd.array(samples[-window:])
            prediction = self.forward(recent)
            samples.append(prediction.argmax(1).asnumpy()[-1])
        return samples

Loading

0 comments on commit 4b29569

Please sign in to comment.