diff --git a/python/DNN_training/README.md b/python/DNN_training/README.md
new file mode 100644
index 00000000..3a7d668a
--- /dev/null
+++ b/python/DNN_training/README.md
@@ -0,0 +1,137 @@
+# DNNs For ASR
+This work is part of a Google Summer of Code project whose goal was to integrate DNNs with CMUSphinx. This repository contains convenience scripts that wrap Keras code and allow for easy training of DNNs.
+## Getting Started
+Start by cloning the repository.
+### Prerequisites
+The required Python libraries available from PyPI are listed in the requirements.txt file. Install them by running:
+```
+pip install -r requirements.txt
+```
+Additional libraries not available from PyPI:
+- tfrbm - for DBN-DNN pretraining.
+  - available at https://github.com/meownoid/tensorfow-rbm
+## File Formats
+Since the project is primarily intended to be used with PocketSphinx, the feature files, state-segmentation output files and prediction files are all in Sphinx format.
+### Feature File Format
+```
+N: number of frames
+M: dimensions of the feature vector
+N*M (4 bytes)
+Frame 1: f_1...f_M (4*M bytes)
+.
+.
+.
+Frame N: f_1...f_M (4*M bytes)
+```
+See readMFC in utils.py for reference code that reads this format.
+### State-Segmentation File Format
+Format for each frame:
+```
+ 2   2   2    1   4   (bytes)
+st1 [st2 st3] pos scr
+```
+### Prediction Output Format
+Format for each frame:
+```
+N: number of states
+N (2 bytes)
+scr_1...scr_N (2*N bytes)
+```
+### Wrapper Scripts
+```
+runDatasetGen.py -train_fileids -val_fileids [-test_fileids] -nfilts -feat_dir -feat_ext -stseg_dir -stseg_ext -mdef [-outfile_prefix] [-keep_utts]
+```
+runDatasetGen takes feature files and state-segmentation files stored in Sphinx format, along with the model-definition (mdef) file of the GMM-HMM model, and generates a set of numpy arrays that form a Python-readable dataset.
+runDatasetGen writes the following files in the directory it was called from:
+- Data files
+  - _train.npy
+  - _dev.npy
+  - _test.npy
+- Label files
+  - _train_label.npy
+  - _dev_label.npy
+  - _test_label.npy
+- Metadata file
+  - _meta.npz
+
+The metadata file is a zipped collection of arrays with the following keys:
+- File names for utterances
+  - filenames_Train
+  - filenames_Dev
+  - filenames_Test
+- Number of frames per utterance (useful if -keep_utts is not set)
+  - framePos_Train
+  - framePos_Dev
+  - framePos_Test
+- State frequencies (useful for scaling in some cases)
+  - state_freq_Train
+  - state_freq_Dev
+  - state_freq_Test
+```
+runNNTrain.py -train_data -train_labels -val_data -val_labels -nn_config -n_epochs [-context_win] [-cuda_device_id] [-pretrain] [-keras_model] -model_name
+```
+runNNTrain takes the training and validation data files (as generated by runDatasetGen) and trains a neural network on them.
+The architecture and parameters of the neural network are defined in a text file. Currently this script supports 4 network types:
+- MLP (mlp)
+- Convolutional Neural Network (conv)
+- MLP with shortcut connections (resnet)
+- Convolutional Network with residual connections in the fully connected layers (conv+resnet)
+
+See sample_nn.cfg for an example. The configuration file consists of ```param``` and ```value``` pairs; if a value has multiple elements (represented by ... below), they should be separated by spaces.
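+For illustration, a minimal configuration for a plain MLP might look like the sketch below. The parameter names come from the list that follows; the specific values here are only placeholders, and sample_nn.cfg remains the authoritative example:
+```
+type mlp
+width 1024
+depth 5
+activation sigmoid
+optimizer adam
+lr 0.001
+batch_size 512
+```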
+Params and possible values:
+- **type** mlp, conv, resnet, conv+resnet
+- **width** any integer value
+- **depth** any integer value
+- **dropout** float in (0,1)
+- **batch_norm** -
+- **activation** sigmoid, hard_sigmoid, elu, relu, selu, tanh, softplus, softsign, softmax, linear
+- **optimizer** sgd, adam, adagrad
+- **lr** float in (0,1)
+- **batch_size** any integer value
+- **ctc_loss** -
+- for type = conv and type = conv+resnet
+  - **conv** [n_filters, filter_window]...
+  - **pooling** None, [max/avg, window_size, stride_size]
+- for type = resnet and type = conv+resnet
+  - **block_depth** any integer value
+  - **n_blocks** any integer value
+```
+runNNPredict.py -keras_model -ctldir -ctllist -inext -outdir -outext -nfilts [-acoustic_weight] [-context_win] [-cuda_device_id]
+```
+runNNPredict takes a Keras model and a list of feature files and generates predictions. The predictions are stored as binary files in the Sphinx-readable format defined above.
+Please ensure that the dimensionality of the feature vectors matches nfilts and that the context window is the same as the one the model was trained with.
+The acoustic_weight is used to scale the output scores. This is required because if the scores passed to a decoder like PocketSphinx are too small or too large, decoding performance suffers. One way of estimating this weight is to generate scores from the GMM-HMM decoder being used, fit a linear regression between the GMM-HMM scores and the NN scores, and use the coefficient as the weight.
+```
+readSen.py -gmm_score_dir -gmm_ctllist -nn_score_dir -nn_ctllist [-gmm_ext] [-nn_ext]
+```
+readSen takes scores (stored in Sphinx-readable binary files) obtained from a GMM-HMM decoder and a NN, and fits a regression to them.
+
+## Example workflow with CMUSphinx
+- Feature extraction using sphinx_fe:
+  ```
+  sphinx_fe -argfile ../../en_us.ci_cont/feat.params -c etc/wsj0_train.fileids -di wav/ -do feat_ci_mls -mswav yes -eo mls -ei wav -ofmt sphinx -logspec yes
+  ```
+- State segmentation using sphinx3_align:
+  ```
+  sphinx3_align -hmm ../../en_us.ci_cont/ -dict etc/cmudict.0.6d.wsj0 -ctl etc/wsj0_train.fileids -cepdir feat_ci_mls/ -cepext .mfc -insent etc/wsj0.transcription -outsent wsj0.out -stsegdir stateseg_ci_dir/ -cmn batch
+  ```
- Generate the dataset using runDatasetGen.py
+- Train the NN using runNNTrain.py
+### EITHER
+- Generate predictions from the NN using runNNPredict.py
+- Generate predictions from PocketSphinx:
+```
+pocketsphinx_batch -hmm ../../en_us.ci_cont/ -lm ../../tcb20onp.Z.DMP -cepdir feat_ci_mfc/ -ctl ../../GSOC/SI_ET_20.NDX -dict etc/cmudict.0.6d.wsj0 -senlogdir sendump_ci/ -compallsen yes -bestpath no -fwdflat no -remove_noise no -remove_silence no -logbase 1.0001 -pl_window 0
+```
+- Compute the acoustic weight using readSen.py
+- Decode the scaled NN predictions with PocketSphinx:
+```
+pocketsphinx_batch -hmm ../../en_us.ci_cont/ -lm ../../tcb20onp.Z.DMP -cepdir senscores/ -cepext .sen -hyp NN2.hyp -ctl ../../GSOC/SI_ET_20.NDX -dict etc/cmudict.0.6d.wsj0 -compallsen yes -logbase 1.0001 -pl_window 0 -senin yes
+```
+### OR
+- Predict and decode with the PocketSphinx DNN decoder by passing your Keras model to it and setting the other required parameters.
+``` +pocketsphinx_batch -hmm ../../en_us.ci_cont_2/ -lm ../../tcb20onp.Z.DMP -cepdir feat_ci_dev_mls/ -cmn batch -hyp test_ci_2-2.hyp -ctl etc/wsj0_dev.fileids -dict etc/cmudict.0.6d.wsj0 -nnmgau ../../GSOC/bestModels/best_CI.h5 -pl_window 0 -ceplen 25 -ncep 25 -cudaid 2 +``` +NOTE: If you are using the PocketSphinx DNN decoder please ensure that you select the appropriate feature type for your model. You need to be extra careful if you are training models that process the data utterance-wise instead of frame-wise since the default behaviour of Pocketsphinx is to perform frame-wise classification. diff --git a/python/DNN_training/requirements.txt b/python/DNN_training/requirements.txt new file mode 100644 index 00000000..c2d0f232 --- /dev/null +++ b/python/DNN_training/requirements.txt @@ -0,0 +1,67 @@ +appdirs==1.4.3 +audioread==2.1.5 +backports.weakref==1.0rc1 +bleach==1.5.0 +cycler==0.10.0 +Cython==0.25.2 +daemonize==2.4.7 +decorator==4.0.6 +editdistance==0.3.1 +funcsigs==1.0.2 +functools32==3.2.3.post2 +graphviz==0.7.1 +guppy==0.1.10 +h5py==2.7.0 +htk-io==0.5 +html5lib==0.9999999 +ipython==2.4.1 +joblib==0.11 +Keras==2.0.6 +Lasagne==0.2.dev1 +librosa==0.5.1 +Mako==1.0.6 +Markdown==2.2.0 +MarkupSafe==1.0 +matplotlib==2.0.2 +memory-profiler==0.47 +mock==2.0.0 +nose==1.3.7 +numpy==1.13.1 +packaging==16.8 +pbr==3.1.1 +pexpect==4.0.1 +posix-ipc==1.0.0 +protobuf==3.3.0 +ptyprocess==0.5 +py==1.4.33 +pycurl==7.43.0 +pydot==1.2.3 +pydot-ng==1.0.0 +pyfst==0.2.3 +pygpu==0.6.5 +pyliblzma==0.5.3 +pyparsing==2.2.0 +pysqlite==1.0.1 +pytest==3.0.7 +python-apt==1.1.0b1 +python-dateutil==2.6.0 +python-speech-features==0.5 +pytools==2016.2.6 +pytz==2017.2 +PyYAML==3.12 +resampy==0.1.5 +rpm-python==4.12.0.1 +scikit-learn==0.18.1 +scipy==0.19.1 +simplegeneric==0.8.1 +six==1.10.0 +subprocess32==3.2.7 +tensorflow-gpu==1.2.0 +tfrbm==0.0.2 +Theano==0.9.0 +tkinter==0.2.0 +tqdm==4.14.0 +urlgrabber==3.9.1 +virtualenv==15.0.1 +Werkzeug==0.12.2 +yum-metadata-parser==1.1.4 diff --git a/python/DNN_training/test.csv b/python/DNN_training/test.csv new file mode 100644 index 00000000..798e666a --- /dev/null +++ b/python/DNN_training/test.csv @@ -0,0 +1,9 @@ +epoch,acc,loss,lr,val_acc,val_loss +0,0.11622738801684533,5.8389106012230449,0.001,0.10160732136648122,5.7048709947973366 +1,0.16153809341500766,4.8325580953638916,0.001,0.10914953494436717,5.4223426403024435 +2,0.18681057379402757,4.4621437494093934,0.001,0.12780201451668985,5.2290357653056798 +3,0.21127847434915772,4.1832528039470747,0.001,0.14062635822322669,5.0586822548562527 +4,0.2326776057618683,3.9420465787738608,0.001,0.15207753824756606,4.8969079582681241 +0,0.11584453962480858,5.8249653400724917,0.001,0.10603784553198888,5.6919887321217502 +1,0.16048675583843797,4.8297243269807897,0.001,0.10991421462100139,5.4031377342712235 +2,0.18789032589969373,4.4382342056329911,0.001,0.12922543245827539,5.2118432286385863 diff --git a/python/DNN_training/test.png b/python/DNN_training/test.png new file mode 100644 index 00000000..080e0e73 Binary files /dev/null and b/python/DNN_training/test.png differ diff --git a/python/Makefile.am b/python/Makefile.am index a6721a1a..9df05301 100644 --- a/python/Makefile.am +++ b/python/Makefile.am @@ -72,6 +72,13 @@ nobase_scripts_SCRIPTS = \ cmusphinx/prune_mixw.py \ cmusphinx/quantize_mixw.py \ cmusphinx/lat2dot.py \ - cmusphinx/qmwx.pyx + cmusphinx/qmwx.pyx \ + cmusphinx/readSen.py \ + cmusphinx/runDatasetGen.py \ + cmusphinx/runNNPredict.py \ + cmusphinx/runNNTrain.py \ + cmusphinx/genLabels.py \ + cmusphinx/utils.py \ + 
cmusphinx/NNTrain.py EXTRA_DIST = $(nobase_scripts_SCRIPTS) diff --git a/python/cmusphinx/NNTrain.py b/python/cmusphinx/NNTrain.py new file mode 100644 index 00000000..b9e5d57e --- /dev/null +++ b/python/cmusphinx/NNTrain.py @@ -0,0 +1,386 @@ +from keras.models import Sequential, Model +from keras.optimizers import SGD,Adagrad, Adam +from keras.layers.normalization import BatchNormalization +from keras.layers import ( + Input, + Dense, + Activation, + Dropout, + Conv1D, + Conv2D, + LocallyConnected2D, + MaxPooling2D, + AveragePooling2D, + Reshape, + Flatten, + Masking) +from keras.layers.core import Lambda +from keras.layers.merge import add, concatenate +from keras.utils import to_categorical, plot_model +from keras.models import load_model, Model +from keras.callbacks import History,ModelCheckpoint,CSVLogger,ReduceLROnPlateau +from keras import regularizers +from keras.preprocessing.sequence import pad_sequences +import keras.backend as K +import numpy as np +import matplotlib.pyplot as plt +import tensorflow as tf +from tensorflow.python.ops import math_ops +from tensorflow.python.ops import array_ops +from tfrbm import GBRBM,BBRBM +from sys import stdout +from sklearn.preprocessing import StandardScaler +import struct +from utils import * + +""" + This module provides functions for training different neural network architectures + Below is a bried summary of the different functions and their purpose, more detail + can be found below: + - mlp1: creates a MLP consisting of a number dense layers + - mlp_wCTC: creates a MLP that calculates the ctc loss during training + - mlp4: create convolutional neural networks and residual networks + - DBN_DNN: performs DBN pretraining on simple MLPs + - preTrain: performs layer-wise pretraining on simple MLPs + - trainAndTest: runs training on a provided model using the given dataset +""" + +def ctc_lambda_func(args): + import tensorflow as tf + y_pred, labels, input_length, label_length = args + label_length = K.cast(tf.squeeze(label_length), 'int32') + input_length = K.cast(tf.squeeze(input_length), 'int32') + # return K.ctc_batch_cost(labels, y_pred, input_length, label_length) + labels = K.ctc_label_dense_to_sparse(labels,label_length) + return tf.nn.ctc_loss(labels,y_pred,input_length, + preprocess_collapse_repeated=True, + ctc_merge_repeated=False, + time_major=False, + ignore_longer_outputs_than_inputs=True) + +def ler(y_true, y_pred, **kwargs): + """ + Label Error Rate. 
For more information see 'tf.edit_distance' + """ + return tf.reduce_mean(tf.edit_distance(y_pred, y_true, **kwargs)) + +def decode_output_shape(inputs_shape): + y_pred_shape, seq_len_shape = inputs_shape + return (y_pred_shape[:1], None) + +def decode(args): + import tensorflow as tf + y_pred, label_len = args + label_len = K.cast(tf.squeeze(label_len), 'int32') + # ctc_labels = tf.nn.ctc_greedy_decoder(y_pred, label_len)[0][0] + # return ctc_labels + ctc_labels = K.ctc_decode(y_pred,label_len,greedy=False)[0][0] + return K.ctc_label_dense_to_sparse(ctc_labels, label_len) + +def mlp_wCTC(input_dim,output_dim,depth,width,BN=False): + print locals() + x = Input(name='x', shape=(1000,input_dim)) + h = x + h = Masking()(h) + for i in range(depth): + h = Dense(width)(h) + if BN: + h = BatchNormalization()(h) + h = Activation('sigmoid')(h) + out = Dense(output_dim,name='out')(h) + softmax = Activation('softmax', name='softmax')(out) + # a = 1.0507 * 1.67326 + # b = -1 + # # out = Lambda(lambda x : a * K.pow(x,3) + b)(h) + # out = Lambda(lambda x: a * K.exp(x) + b, name='out')(h) + y = Input(name='y',shape=[None,],dtype='int32') + x_len = Input(name='x_len', shape=[1],dtype='int32') + y_len = Input(name='y_len', shape=[1],dtype='int32') + + dec = Lambda(decode, output_shape=decode_output_shape, name='decoder')([out,x_len]) + # edit_distance = Lambda(ler, output_shape=(1,), name='edit_distance')([out,y,x_len,y_len]) + + loss_out = Lambda(ctc_lambda_func, output_shape=(1,), name='ctc')([out, y, x_len, y_len]) + model = Model(inputs=[x, y, x_len, y_len], outputs=[loss_out,dec,softmax]) + + sgd = SGD(lr=0.001, decay=1e-6, momentum=0.99, nesterov=True, clipnorm=5) + opt = Adam(lr=0.0001, clipnorm=5) + model.compile(loss={'ctc': dummy_loss, + 'decoder': decoder_dummy_loss, + 'softmax': 'sparse_categorical_crossentropy'}, + optimizer=opt, + metrics={'decoder': ler, + 'softmax': 'accuracy'}, + loss_weights=[1,0,0]) + return model +def ctc_model(model): + x = model.get_layer(name='x').input + + out = model.get_layer(name='out').output + softmax = Activation('softmax', name='softmax')(out) + y = Input(name='y',shape=[None,],dtype='int32') + x_len = Input(name='x_len', shape=[1],dtype='int32') + y_len = Input(name='y_len', shape=[1],dtype='int32') + + dec = Lambda(decode, output_shape=decode_output_shape, name='decoder')([out,x_len]) + # edit_distance = Lambda(ler, output_shape=(1,), name='edit_distance')([out,y,x_len,y_len]) + + loss_out = Lambda(ctc_lambda_func, output_shape=(1,), name='ctc')([out, y, x_len, y_len]) + model = Model(inputs=[x, y, x_len, y_len], outputs=[loss_out,dec,softmax]) + + sgd = SGD(lr=0.001, decay=1e-6, momentum=0.99, nesterov=True, clipnorm=5) + opt = Adam(lr=0.0001, clipnorm=5) + model.compile(loss={'ctc': dummy_loss, + 'decoder': decoder_dummy_loss, + 'softmax': 'sparse_categorical_crossentropy'}, + optimizer=opt, + metrics={'decoder': ler, + 'softmax': 'accuracy'}, + loss_weights=[1,0,0]) + return model +def _bn_act(input,activation='relu'): + """Helper to build a BN -> relu block + """ + norm = BatchNormalization()(input) + return Activation(activation)(norm) + +def make_dense_res_block(inp, size, width, drop=False,BN=False,regularize=False): + x = inp + for i in range(size): + x = Dense(width, + kernel_regularizer=regularizers.l2(0.05) if regularize else None)(x) + if i < size - 1: + if drop: + x = Dropout(0.15)(x) + if BN: + x = _bn_relu(x) + return x + +def mlp4(input_dim,output_dim,nBlocks,width, n_frames, block_depth=1, + n_filts=[84], filt_dims=[(11,8)], 
pooling=[['max',(6,6),(2,2)]], + block_width=None, dropout=False, BN=False, activation='relu', + parallelize=False, conv=False, regularize=False, + exp_boost=False, quad_boost=False, shortcut=False, + opt='adam', lr=0.001): + + print locals() + if block_width == None: + block_width = width + inp = Input(shape=input_dim, name='x') + x = inp + if conv: + x = Reshape((n_frames,input_dim/n_frames,1))(x) + for i in range(len(n_filts)): + print i + + x = LocallyConnected2D(n_filts[i],filt_dims[i], + padding='valid')(x) + x = _bn_act(x,activation=activation) + if pooling[i] != None: + pooling_type, win_size, stride = pooling[i] + if pooling_type == 'max': + x = MaxPooling2D(win_size,strides=stride,padding='same')(x) + if pooling_type == 'avg': + x = AveragePooling2D(win_size,strides=stride,padding='same')(x) + x = Flatten()(x) + if block_width != width: + x = Dense(block_width)(x) + for i in range(nBlocks): + y = make_dense_res_block(x,block_depth,block_width,BN=BN,drop=dropout,regularize=regularize) + if shortcut: + x = add([x,y]) + else: + x = y + if dropout: + x = Dropout(dropout)(x) + if BN: + x = _bn_act(x,activation=activation) + else: + x = Activation(activation)(x) + + if exp_boost: + x = Dense(output_dim)(x) + z = Lambda(lambda x : K.exp(x))(x) + if quad_boost: + x = Dense(output_dim)(x) + a = 0.001 + b = 0.4 + z = Lambda(lambda x : a * K.pow(x,3) + b)(x) + else: + z = Dense(output_dim, name='out')(x) + z = Activation('softmax', name='softmax')(z) + model = Model(inputs=inp, outputs=z) + if parallelize: + model = make_parallel(model, len(CUDA_VISIBLE_DEVICES.split(','))) + # opt = Adam(lr=25/(np.sqrt(width * output_dim))) + if opt == 'sgd': + opt = SGD + if opt == 'adam': + opt = Adam + if opt == 'adagrad': + opt = Adagrad + opt = opt(lr=lr) + # opt = SGD(lr=1/(np.sqrt(input_dim * width)), decay=1e-6, momentum=0.9, nesterov=True) + model.compile(optimizer=opt, + loss='sparse_categorical_crossentropy', + metrics=['accuracy']) + return model + +def resnet_wrapper(input_dim,output_dim,depth,width,reshape_layer): + builder = resnet.ResnetBuilder() + model = builder.build_resnet_18(input_dim, output_dim,reshape_layer) + x = model.get_layer(name='flatten_1').get_output_at(-1) + for i in range(depth): + x = Dense(width,activation='relu')(x) + softmax = Dense(output_dim,activation='softmax')(x) + model = Model(inputs=model.inputs, outputs=softmax) + opt = Adam(lr=10/np.sqrt(input_dim * output_dim)) + # opt = SGD(lr=1/(np.sqrt(input_dim * width)), decay=1e-6, momentum=0.9, nesterov=True) + model.compile(optimizer=opt, + loss='sparse_categorical_crossentropy', + metrics=['accuracy']) + return model +def DBN_DNN(inp,nClasses,depth,width,batch_size=2048): + RBMs = [] + weights = [] + bias = [] + # batch_size = inp.shape + nEpoches = 5 + if len(inp.shape) == 3: + inp = inp.reshape((inp.shape[0] * inp.shape[1],inp.shape[2])) + sigma = np.std(inp) + # sigma = 1 + rbm = GBRBM(n_visible=inp.shape[-1],n_hidden=width,learning_rate=0.002, momentum=0.90, use_tqdm=True,sample_visible=True,sigma=sigma) + rbm.fit(inp,n_epoches=15,batch_size=batch_size,shuffle=True) + RBMs.append(rbm) + for i in range(depth - 1): + print 'training DBN layer', i + rbm = BBRBM(n_visible=width,n_hidden=width,learning_rate=0.02, momentum=0.90, use_tqdm=True) + for e in range(nEpoches): + batch_size *= 1 + (e*0.5) + n_batches = (inp.shape[-2] / batch_size) + (1 if inp.shape[-2]%batch_size != 0 else 0) + for j in range(n_batches): + stdout.write("\r%d batch no %d/%d epoch no %d/%d" % (int(time.time()),j+1,n_batches,e,nEpoches)) + 
stdout.flush() + b = np.array(inp[j*batch_size:min((j+1)*batch_size, inp.shape[0])]) + for r in RBMs: + b = r.transform(b) + rbm.partial_fit(b) + RBMs.append(rbm) + for r in RBMs: + (W,_,Bh) = r.get_weights() + weights.append(W) + bias.append(Bh) + model = mlp1(x_train.shape[1],nClasses,depth-1,width) + print len(weights), len(model.layers) + assert len(weights) == len(model.layers) - 1 + for i in range(len(weights)): + W = [weights[i],bias[i]] + model.layers[i].set_weights(W) + return model +# def gen_data(active): + + +def preTrain(model,modelName,x_train,y_train,meta,skip_layers=[],outEqIn=False,fit_generator=None): + print model.summary() + layers = model.layers + output = layers[-1] + outdim = output.output_shape[-1] + for i in range(len(layers) - 1): + if i in skip_layers: + print 'skipping layer ',i + continue + if len(model.layers[i].get_weights()) == 0: + print 'skipping layer ',i + continue + last = model.layers[i].get_output_at(-1) + if outEqIn: + preds = Dense(outdim)(last) + else: + preds = Dense(outdim,activation='softmax')(last) + model_new = Model(model.input,preds) + for j in range(len(model_new.layers) - 2): + print "untrainable layer ",j + model_new.layers[j].trainable=False + model_new.compile(optimizer='adam', + loss='sparse_categorical_crossentropy', + metrics=['accuracy']) + print model_new.summary() + batch_size = 2048 + if fit_generator == None: + model_new.fit(x_train,y_train,epochs=1,batch_size=2048) + else: + history = model.fit_generator(fit_generator(x_train,y_train,batch_size), + steps_per_epoch=(meta['nFrames_Train'])/batch_size, epochs=3) + # model.fit_generator(gen_bracketed_data(x_train,y_train,meta['framePos_Train'],4), + # steps_per_epoch=len(meta['framePos_Train']), epochs=3, + # callbacks=[ModelCheckpoint('%s_CP.h5' % modelName,monitor='loss',mode='min')]) + # model.fit_generator(gen_data(x_train,y_train,batch_size), + # steps_per_epoch = x_train.shape[0] / batch_size, + # epochs = 1) + model.layers[i].set_weights(model_new.layers[-2].get_weights()) + for l in model.layers: + l.trainable = True + return model + +def trainNtest(model,x_train,y_train,x_test,y_test,meta, + modelName,testOnly=False,pretrain=False, batch_size=512, + init_epoch=0, fit_generator=None, ctc_train=False, n_epochs=50): + print 'TRAINING MODEL:',modelName + if not testOnly: + if pretrain: + print 'pretraining model...' + model = preTrain(model,modelName,x_train,y_train,meta,fit_generator=fit_generator) + if ctc_train: + model = ctc_model(model) + print model.summary() + print 'starting fit...' + callback_arr = [ModelCheckpoint('%s_CP.h5' % modelName,save_best_only=True,verbose=1), + ReduceLROnPlateau(patience=5,factor=0.5,min_lr=10**(-6), verbose=1), + CSVLogger(modelName+'.csv',append=True)] + + if fit_generator == None: + history = model.fit(x_train,y_train,epochs=n_epochs,batch_size=batch_size, + initial_epoch=init_epoch, + validation_data=(x_test,y_test), + callbacks=callback_arr) + else: + history = model.fit_generator(fit_generator(x_train,y_train,batch_size), + steps_per_epoch= (meta['nFrames_Train'])/batch_size, epochs=n_epochs, + validation_data=fit_generator(x_test,y_test,batch_size), + validation_steps = (meta['nFrames_Dev']) / batch_size, + callbacks=callback_arr) + model = Model(inputs=[model.get_layer(name='x').input], + outputs=[model.get_layer(name='softmax').output]) + print model.summary() + model.compile(loss='sparse_categorical_crossentropy', + optimizer='adam', + metrics=['accuracy']) + print 'saving model...' 
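+        # What gets saved is the inference model rebuilt above (input layer 'x'
+        # through the 'softmax' output), so the .h5 file can be loaded directly by
+        # runNNPredict or the PocketSphinx DNN decoder without any CTC scaffolding.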
+ model.save(modelName+'.h5') + # model.save_weights(modelName+'_W.h5') + print(history.history.keys()) + print history.history['lr'] + print 'plotting graphs...' + # summarize history for accuracy + fig, ax1 = plt.subplots() + ax1.plot(history.history['acc']) + ax1.plot(history.history['val_acc']) + ax2 = ax1.twinx() + ax2.plot(history.history['loss'],color='r') + ax2.plot(history.history['val_loss'],color='g') + plt.title('model loss & accuracy') + ax1.set_ylabel('accuracy') + ax2.set_ylabel('loss') + ax1.set_xlabel('epoch') + ax1.legend(['training acc', 'testing acc']) + ax2.legend(['training loss', 'testing loss']) + fig.tight_layout() + plt.savefig(modelName+'.png') + plt.clf() + else: + model = load_model(modelName) + print 'scoring...' + score = model.evaluate_generator(gen_bracketed_data(x_test,y_test,meta['framePos_Dev'],4), + len(meta['framePos_Dev'])) + print score diff --git a/python/cmusphinx/genLabels.py b/python/cmusphinx/genLabels.py new file mode 100644 index 00000000..ec75b2c0 --- /dev/null +++ b/python/cmusphinx/genLabels.py @@ -0,0 +1,279 @@ +import numpy as np +import os +import sys +from sklearn.preprocessing import StandardScaler +import utils +from keras.preprocessing.sequence import pad_sequences +import librosa +""" + TODO: Test with different model and file naming scheme +""" +def ping(): + curr_time = int(time.time()) + while done == 0: + if (int(time.time()) != curr_time): + curr_time = int(time.time()) + sys.stdout.write('.') + sys.stdout.flush() + + +def read_sen_labels_from_mdef(fname): + labels = np.loadtxt(fname,dtype=str,skiprows=10,usecols=(0,1,2,3,6,7,8)) + labels = map(lambda x: + [reduce(lambda a,b: a+' '+b, + filter(lambda y: y != '-', x[:4]))] + list(x[4:]), labels) + + phone2state = {} + for r in labels: + phone2state[r[0]] = map(int, r[1:]) + return phone2state + +def frame2state(fname, phone2state): + with open(fname,'r') as f: + lines = f.readlines()[2:] + lines = map(lambda x: x.split()[2:], lines) + + lines = map(lambda x: [x[0], reduce(lambda a,b: a+' '+b,x[1:])],lines) + for l in lines: + if l[1] not in phone2state: + l[1] = l[1].split()[0] + states = map(lambda x: phone2state[x[1]][int(x[0])], lines) + return (list(states)) + +def loadDict(filename): + def mySplit(line): + line = line.split() + for i in range(1,len(line)): + line[i] = "{0}1 {0}2 {0}3".format(line[i]) + line = [line[0], reduce(lambda x,y: x+ ' ' +y, line[1:])] + return line + with open(filename) as f: + d = f.readlines() + d = map(lambda x: x.split(),d) + myDict = {} + for r in d: + myDict[r[0]] = r[1:] + return myDict + +def loadTrans(trans_file,pDict): + trans = {} + with open(trans_file) as f: + lines = f.readlines() + for line in lines: + line = line.split() + fname = line[-1][1:-1] + labels = map(lambda x: pDict.setdefault(x,-1), line[:-1]) + labels = filter(lambda x: x!=-1, labels) + if labels == []: + continue + labels = reduce(lambda x,y: x + y, labels) + trans[fname] = labels + return trans + +def trans2labels(trans,phone2state): + d = {} + for u in trans: + labels = trans[u] + labels = map(lambda x: phone2state[x], labels) + labels = reduce(lambda x,y: x + y, labels) + d[u] = labels + return d + +def genDataset(train_flist, dev_flist, test_flist, n_feats, + feat_path, feat_ext, stseg_path, stseg_ext, mdef_fname, + outfile_prefix, context_len=None, + keep_utts=False, ctc_labels=False, pDict_file=None, + trans_file=None, make_graph=False,cqt=False, + max_len=None,n_deltas=0,pad=False): + assert(os.path.exists(stseg_path)) + train_files = 
np.loadtxt(train_flist,dtype=str) + dev_files = np.loadtxt(dev_flist,dtype=str) + if test_flist != None: + test_files = np.loadtxt(test_flist,dtype=str) + else: + test_files = [] + phone2state = read_sen_labels_from_mdef(mdef_fname) + + print len(train_files) + stseg_files_train = map(lambda x: x.split('/')[-1]+stseg_ext,train_files) + stseg_files_test = map(lambda x: x.split('/')[-1]+stseg_ext,test_files) + stseg_files_dev = map(lambda x: x.split('/')[-1]+stseg_ext,dev_files) + stseg_files_train = filter(lambda x: os.path.exists(stseg_path + '/' + x), stseg_files_train) + stseg_files_test = filter(lambda x: os.path.exists(stseg_path + '/' + x), stseg_files_test) + stseg_files_dev = filter(lambda x: os.path.exists(stseg_path + '/' + x), stseg_files_dev) + + stseg_files = stseg_files_train + stseg_files_dev + stseg_files_test + print "Training Files: %d Dev Files: %d Testing Files: %d" % (len(stseg_files_train), len(stseg_files_dev), len(stseg_files_test)) + + train_files = map(lambda x: feat_path+'/'+x+feat_ext,train_files) + dev_files = map(lambda x: feat_path+'/'+x+feat_ext,dev_files) + test_files = map(lambda x: feat_path+'/'+x+feat_ext,test_files) + + X_Train = [] + Y_Train = [] + X_Test = [] + Y_Test = [] + X_Dev = [] + Y_Dev = [] + framePos_Train = [] + framePos_Test = [] + framePos_Dev = [] + filenames_Train = [] + filenames_Test = [] + filenames_Dev = [] + active_states_Train = [] + active_states_Test = [] + active_states_Dev = [] + # allData = [] + # allLabels = [] + pos = 0 + scaler = StandardScaler(copy=False,with_std=False) + n_states = np.max(phone2state.values())+1 + print n_states + state_freq_Train = [0]*n_states + state_freq_Dev = [0]*n_states + state_freq_Test = [0]*n_states + + + + for i in range(len(stseg_files)): + if i < len(stseg_files_train): + # print '\n train' + frames = framePos_Train + allData = X_Train + allLabels = Y_Train + filenames = filenames_Train + state_freq = state_freq_Train + files = train_files + active_state = active_states_Train + elif i < len(stseg_files_train) + len(stseg_files_dev): + # print '\n dev' + frames = framePos_Dev + allData = X_Dev + allLabels = Y_Dev + filenames = filenames_Dev + state_freq = state_freq_Dev + files = dev_files + active_state = active_states_Dev + else: + # print '\n test' + frames = framePos_Test + allData = X_Test + allLabels = Y_Test + filenames = filenames_Test + state_freq = state_freq_Test + files = test_files + active_state = active_states_Test + + sys.stdout.write("\r%d/%d " % (i,len(stseg_files))) + sys.stdout.flush() + f = stseg_files[i] + + [data_file] = filter(lambda x: f[:-9] in x, files) + if cqt: + y,fs=librosa.load(data_file,sr=None) + data = np.absolute(librosa.cqt(y,sr=fs,window=np.hamming, + hop_length=160, n_bins=64, bins_per_octave=32).T) + # print data.shape + else: + data = utils.readMFC(data_file,n_feats).astype('float32') + data = scaler.fit_transform(data) + labels = frame2state(stseg_path + '/' + f, phone2state) + nFrames = min(len(labels), data.shape[0]) + sys.stdout.write('(%d,%d) (%d,)' % (data.shape + np.array(labels).shape)) + data = data[:nFrames] + labels = labels[:nFrames] + if context_len != None: + pad_top = np.zeros((context_len,data.shape[1])) + pad_bot = np.zeros((context_len,data.shape[1])) + padded_data = np.concatenate((pad_top,data),axis=0) + padded_data = np.concatenate((padded_data,pad_bot),axis=0) + + data = [] + for j in range(context_len,len(padded_data) - context_len): + new_row = padded_data[j - context_len: j + context_len + 1] + new_row = new_row.flatten() + 
data.append(new_row) + data = np.array(data) + if n_deltas > 0: + pad_top = np.zeros((n_deltas,data.shape[1])) + pad_bot = np.zeros((n_deltas,data.shape[1])) + padded_data = np.concatenate((pad_top,data),axis=0) + padded_data = np.concatenate((padded_data,pad_bot),axis=0) + data = [] + for j in range(n_deltas,len(padded_data) - n_deltas): + delta_top = padded_data - padded_data[j - n_deltas:j] + delta_bot = padded_data - padded_data[j:j + n_deltas] + new_row = delta_top + padded_data + delta_bot + data.append(new_row) + for l in labels: + state_freq[l] += 1 + filenames.append(data_file) + frames.append(nFrames) + if keep_utts: + allData.append(data) + allLabels.append(np.array(labels)) + else: + allData += list(data) + allLabels += list(labels) + pos += nFrames + if not ctc_labels: + assert(len(allLabels) == len(allData)) + # print allData + print len(allData), len(allLabels) + if max_len == None: + max_len = 100 * ((max(map(len,X_Train)) + 99)/ 100) + print 'max_len', max_len + if keep_utts and pad: + X_Train = pad_sequences(X_Train,maxlen=max_len,dtype='float32',padding='post') + Y_Train = pad_sequences(Y_Train,maxlen=max_len,dtype='float32',padding='post',value=n_states) + Y_Train = Y_Train.reshape(Y_Train.shape[0],Y_Train.shape[1],1) + X_Dev = pad_sequences(X_Dev,maxlen=max_len,dtype='float32',padding='post') + Y_Dev = pad_sequences(Y_Dev,maxlen=max_len,dtype='float32',padding='post',value=n_states) + Y_Dev = Y_Dev.reshape(Y_Dev.shape[0],Y_Dev.shape[1],1) + X_Test = pad_sequences(X_Test,maxlen=max_len,dtype='float32',padding='post') + Y_Test = pad_sequences(Y_Test,maxlen=max_len,dtype='float32',padding='post',value=n_states) + Y_Test = Y_Test.reshape(Y_Test.shape[0],Y_Test.shape[1],1) + # np.savez('wsj0_phonelabels_NFrames',NFrames_Train=NFrames_Train,NFrames_Test=NFrames_Test) + # t = threading.Thread(target=ping) + # t.start() + if context_len != None: + np.save(outfile_prefix + 'bracketed_train.npy',X_Train) + np.save(outfile_prefix + 'bracketed_dev.npy',X_Dev) + np.save(outfile_prefix + 'bracketed_train_labels.npy',Y_Train) + np.save(outfile_prefix + 'bracketed_dev_labels.npy',Y_Dev) + if len(X_Test) != 0: + np.save(outfile_prefix + 'bracketed_test.npy',X_Test) + np.save(outfile_prefix + 'bracketed_test_labels.npy',Y_Test) + np.savez(outfile_prefix + 'bracketed_meta.npz',framePos_Train=framePos_Train, + framePos_Test=framePos_Test, + framePos_Dev=framePos_Dev, + filenames_Train=filenames_Train, + filenames_Dev=filenames_Dev, + filenames_Test=filenames_Test, + state_freq_Train=state_freq_Train, + state_freq_Dev=state_freq_Dev, + state_freq_Test=state_freq_Test) + else: + np.save(outfile_prefix + 'train.npy',X_Train) + np.save(outfile_prefix + 'dev.npy',X_Dev) + np.save(outfile_prefix + 'train_labels.npy',Y_Train) + np.save(outfile_prefix + 'dev_labels.npy',Y_Dev) + if len(X_Test) != 0: + np.save(outfile_prefix + 'test.npy',X_Test) + np.save(outfile_prefix + 'test_labels.npy',Y_Test) + np.savez(outfile_prefix + 'meta.npz',framePos_Train=framePos_Train, + framePos_Test=framePos_Test, + framePos_Dev=framePos_Dev, + filenames_Train=filenames_Train, + filenames_Dev=filenames_Dev, + filenames_Test=filenames_Test, + state_freq_Train=state_freq_Train, + state_freq_Dev=state_freq_Dev, + state_freq_Test=state_freq_Test) + +if __name__ == '__main__': + genDataset('../wsj/wsj0/etc/wsj0_train.fileids','../wsj/wsj0/etc/wsj0_dev.fileids','../wsj/wsj0/etc/wsj0_test.fileids',40,'../wsj/wsj0/feat_cd_mls/','../wsj/wsj0/stateseg_ci_dir/','../en_us.ci_cont/mdef', + keep_utts=True, context_len=None, 
cqt=True, + trans_file='../wsj/wsj0/etc/wsj0.transcription', + pDict_file='../wsj/wsj0/etc/cmudict.0.6d.wsj0') diff --git a/python/cmusphinx/makeFileListFromTrans.py b/python/cmusphinx/makeFileListFromTrans.py new file mode 100644 index 00000000..5de9ac1c --- /dev/null +++ b/python/cmusphinx/makeFileListFromTrans.py @@ -0,0 +1,19 @@ +import numpy as np +def getFileListFromTran(trans,outFile): + with open(trans,'r') as f: + lines = f.readlines() + lines = map(lambda x: x.strip(), lines) + files = map(lambda x: x.split('(')[-1][:-1], lines) + np.savetxt(outFile,files,fmt='%s') + +def fixTran(trans): + with open(trans,'r') as f: + lines = f.readlines() + lines = map(lambda x: x.strip(), lines) + lines = map(lambda x: x.split('('), lines) + for l in lines: + l[-1] = l[-1].split('/')[-1] + lines = map(lambda x: reduce(lambda a,b: a+'('+b,x),lines) + np.savetxt(trans+'2',lines,fmt='%s') +# getFileListFromTran("../wsj/wsj0/transcripts/wsj0/wsj0.trans", "../wsj/wsj0/wsj0.filelist") +fixTran("../wsj/wsj0/transcripts/wsj0/wsj0.trans") \ No newline at end of file diff --git a/python/cmusphinx/readSen.py b/python/cmusphinx/readSen.py new file mode 100644 index 00000000..b47b38a8 --- /dev/null +++ b/python/cmusphinx/readSen.py @@ -0,0 +1,96 @@ +import os +import struct +import numpy as np +from sklearn.linear_model import LinearRegression +from argparse import ArgumentParser + +def readSen(fname, print_most_prob_sen=False): + print fname + f = open(fname,'rb') + s = '' + while 'endhdr\n' not in s: + v = f.read(1) + s += struct.unpack('s',v)[0] + magic_num = struct.unpack('I',f.read(4))[0] + assert magic_num == 0x11223344 + count = 0 + data = [] + while v: + v = f.read(2) + if not v: + continue + n_active = struct.unpack('h',v)[0] + # print n_active + assert n_active == 138 + + v = f.read(2*n_active) + scores = list(struct.unpack('%sh' % n_active, v)) + # print np.argmax(scores) + count += 1 + data += scores + print count + return np.array(data) + +if __name__=='__main__': + parser = ArgumentParser(description="""Given two sets of sen files fits a regression to them and returns the coefficient and the intercept. + Useful in determining the appropriate acoustic weight to scale the outputs of the NN by. Improperly + scaled output perform much worse than appropriatly scaled outputs""", + usage='%(prog)s [options] \nUse --help for option list') + parser.add_argument('-gmm_score_dir',type=str, required=True, + help="The directory where the sen files generated by GMM-HMM decoder are stored. Preppended to file paths in gmm_ctllist") + parser.add_argument('-gmm_ctllist', type=str, required=True, + help='List of all the sen files generated by the GMM-HMM decoder') + parser.add_argument('-nn_score_dir',type=str, required=True, + help="The directory where the sen files generated by a ANN are stored. 
Preppended to file paths in gmm_ctllist") + parser.add_argument('-nn_ctllist', type=str, required=True, + help='List of all the sen files generated by the ANN') + parser.add_argument('-gmm_ext', type=str, required=False, default='', + help='the file extension applied to all the files in gmm_ctllist') + parser.add_argument('-nn_ext', type=str, required=False, default='', + help='the file extension applied to all the files in nn_ctllist') + args = vars(parser.parse_args()) + # readSen('../wsj/wsj0/senscores/11_14_1/wsj0/si_et_20/440/440c0401.wv1.flac.sen') + ndx_list = map(lambda x: args['nn_score_dir']+x+args['nn_ext'], np.loadtxt(args['nn_ctllist'],dtype=str)) + file_list = map(lambda x: args['gmm_score_dir']+x+args['gmm_ext'], np.loadtxt(args['gmm_ctllist'],dtype=str)) + # file_list = map(lambda x: '../wsj/wsj0/sendump_dev_ci/' + x, os.listdir('../wsj/wsj0/sendump_dev_ci/')) + # file_list.sort() + # file_list = file_list[:-1] + # ndx_list = ['../wsj/wsj0/single_dev_NN/11_14_1/wsj0/si_et_20/445/445c0403.wv1.flac.sen'] + # file_list = ['../wsj/wsj0/single_dev/11_14_1/wsj0/si_et_20/445/445c0403.wv1.flac.sen'] + x = [] + y = [] + for i in range(len(file_list)): + if i >= 0: + if os.path.exists(ndx_list[i]): + print i,ndx_list[i], file_list[i] + _y = list(readSen(ndx_list[i])) + _x = list(readSen(file_list[i])) + if len(_x) != len(_y): + continue + y += _y + x += _x + frame_len = min(len(x),len(y)) + # x = x[:frame_len] + # y = y[:frame_len] + print len(x),len(y), len(x)/138, len(y)/138 + assert len(x) == len(y) + else: + continue + else: + print i,ndx_list[i+1], file_list[i] + y += list(readSen(ndx_list[i+1])) + x += list(readSen(file_list[i])) + x = np.array(x).reshape(-1,1) + y = np.array(y).reshape(-1,1) + # print x.shape, y.shape + data = np.concatenate((x,y),axis=1) + # np.save('data4regression.npy',data) + + # data = np.load('data4regression.npy') + lreg = LinearRegression(normalize=True,n_jobs=-1) + lreg.fit(data[:,[1]],data[:,[0]]) + print "coefficient: %f\t\tintercept: %f" % (lreg.coef_, lreg.intercept_) + + # vs = np.std(data[:,[1]]) + # va = np.std(data[:,[0]]) + # print va/vs diff --git a/python/cmusphinx/runDatasetGen.py b/python/cmusphinx/runDatasetGen.py new file mode 100644 index 00000000..d44d47b6 --- /dev/null +++ b/python/cmusphinx/runDatasetGen.py @@ -0,0 +1,41 @@ +from argparse import ArgumentParser + +parser = ArgumentParser(description="Generate numpy array from features and alignments produced by Sphinx", + usage='%(prog)s [options] \nUse --help for option list') +parser.add_argument('-train_fileids', type=str, required=True, + help="list of training files") +parser.add_argument('-val_fileids',type=str,required=True, + help='list of validation files') +parser.add_argument('-test_fileids',type=str,required=False, + help='list of test files') +parser.add_argument('-nfilts',type=int,required=True, + help='number of filters used for extracting features. 
(The dimensionality of the feature vector)')
+parser.add_argument('-feat_dir', type=str, required=True,
+                    help='the directory where feature files are stored (prepended to filepaths in the train, val and test filelists when looking for features)')
+parser.add_argument('-feat_ext',type=str,required=True,
+                    help='extension to be appended to each file path when looking for feature files')
+parser.add_argument('-stseg_dir',type=str,required=True,
+                    help='directory where the state segmentation for each feature file is stored (prepended to filepaths in the train, val and test filelists when looking for labels)')
+parser.add_argument('-stseg_ext',type=str,required=True,
+                    help='extension to be appended to each file path when looking for state-segmentation files')
+parser.add_argument('-mdef',type=str,required=True,
+                    help='path to the mdef file for the Sphinx model. Needed to map phones/triphones in the segmentation to state labels')
+parser.add_argument('-outfile_prefix',type=str,default="", required=False,
+                    help='prepended to the names of the output files')
+parser.add_argument('-keep_utts',nargs='?',required=False, default=False, const=True,
+                    help='store features and labels in a 3D array in which each index points to the list of features/labels for one utterance')
+args = vars(parser.parse_args())
+
+from genLabels import genDataset
+genDataset(args['train_fileids'],
+           args['val_fileids'],
+           args['test_fileids'],
+           args['nfilts'],
+           args['feat_dir'],
+           args['feat_ext'],
+           args['stseg_dir'],
+           args['stseg_ext'],
+           args['mdef'],
+           args['outfile_prefix'],
+           keep_utts=args['keep_utts'])
+
diff --git a/python/cmusphinx/runNNPredict.py b/python/cmusphinx/runNNPredict.py
new file mode 100644
index 00000000..eb9227ad
--- /dev/null
+++ b/python/cmusphinx/runNNPredict.py
@@ -0,0 +1,44 @@
+from argparse import ArgumentParser
+
+parser = ArgumentParser(description="Generate predictions from a Keras model and save them in PocketSphinx-readable .sen files.",
+                        usage='%(prog)s [options] \nUse --help for option list')
+parser.add_argument('-keras_model', type=str, required=True,
+                    help="Keras model to be used for prediction in hd5 format (must be compatible with keras.load_model)")
+parser.add_argument('-ctldir',type=str, required=True,
+                    help="the directory for the control files, prepended to each file path in ctllist")
+parser.add_argument('-ctllist', type=str,required=True,
+                    help='list of input files, each representing an utterance')
+parser.add_argument('-inext', type=str, required=True,
+                    help='the extension of the control files, appended to each file path in ctllist')
+parser.add_argument('-outdir', type=str,required=True,
+                    help='directory where the predictions are stored. The structure of this directory will be identical to ctldir')
+parser.add_argument('-outext', type=str, required=True,
+                    help='the extension of the output files')
+parser.add_argument('-nfilts', type=int, required=True,
+                    help='dimensionality of the feature vectors')
+parser.add_argument('-acoustic_weight',type=float,required=False,
+                    help='the weight to scale the predictions by. Sometimes needed to get meaningful predictions from PocketSphinx (defaults to 0.1)')
+parser.add_argument('-context_win',type=int, required=False, default=0,
+                    help='number of contextual frames to include from before and after the target frame (defaults to 0)')
+parser.add_argument('-cuda_device_id', type=str, required=False, default="",
+                    help="The CUDA-capable GPU device to use for execution. If none is specified, the code will execute on the CPU. If specifying multiple devices, separate the ids with commas")
+args = vars(parser.parse_args())
+
+import os
+CUDA_VISIBLE_DEVICES = args["cuda_device_id"]
+os.environ["CUDA_VISIBLE_DEVICES"] = CUDA_VISIBLE_DEVICES
+from keras.models import load_model
+from utils import *
+
+model = load_model(args['keras_model'],custom_objects={'dummy_loss':dummy_loss,
+                                                       'decoder_dummy_loss':decoder_dummy_loss,
+                                                       'ler':ler})
+print model.summary()
+getPredsFromFilelist(model,args['ctllist'],
+                     args['ctldir'],
+                     args['inext'],
+                     args['outdir'],
+                     args['outext'],
+                     context_len=args['context_win'],
+                     weight=args['acoustic_weight'] if args['acoustic_weight'] != None else 0.1,
+                     n_feat=args['nfilts'])
\ No newline at end of file
diff --git a/python/cmusphinx/runNNTrain.py b/python/cmusphinx/runNNTrain.py
new file mode 100644
index 00000000..3a5b210f
--- /dev/null
+++ b/python/cmusphinx/runNNTrain.py
@@ -0,0 +1,153 @@
+from argparse import ArgumentParser
+import numpy as np
+
+parser = ArgumentParser(description="Train a Keras neural network model.",
+                        usage='%(prog)s [options] \nUse --help for option list')
+parser.add_argument('-train_data',type=str, required=True,
+                    help="the training data for the neural network as a saved numpy array of 2D numpy arrays")
+parser.add_argument('-train_labels',type=str, required=True,
+                    help="the training labels for the neural network as a saved numpy array of 1D numpy arrays")
+parser.add_argument('-val_data',type=str, required=True,
+                    help="the validation data for the neural network as a saved numpy array of 2D numpy arrays")
+parser.add_argument('-val_labels',type=str, required=True,
+                    help="the validation labels for the neural network as a saved numpy array of 1D numpy arrays")
+parser.add_argument('-nn_config',type=str, required=True,
+                    help='file containing the neural network configuration information (look at sample_mlp.cfg)')
+parser.add_argument('-n_epochs',type=int, required=True,
+                    help='number of epochs for which to train the model')
+parser.add_argument('-context_win',type=int, required=False, default=0,
+                    help='number of contextual frames to include from before and after the target frame (defaults to 0)')
+parser.add_argument('-cuda_device_id', type=str, required=False, default="",
+                    help="The CUDA-capable GPU device to use for execution. If none is specified, the code will execute on the CPU. If specifying multiple devices, separate the ids with commas")
+parser.add_argument('-pretrain', nargs='?', required=False, default=False, const=True,
+                    help='Perform layer-wise pretraining of the MLP before starting training.
(Use only with dense MLPs)') +parser.add_argument('-keras_model', type=str, required=False, + help="Keras model to be trained in hd5 format (must be compatible with keras.load_model)") +parser.add_argument('-model_name', type=str, required=True, + help='Name to be assigned to the output files') +args = vars(parser.parse_args()) + +import os +CUDA_VISIBLE_DEVICES = args["cuda_device_id"] +os.environ["CUDA_VISIBLE_DEVICES"] = CUDA_VISIBLE_DEVICES +from NNTrain import * +from utils import * +from keras.models import load_model + +def read_config(filename): + with open(filename) as f: + lines = f.readlines() + lines = filter(lambda x: x[0] != '#' and len(x) > 2, lines) + args = {} + for l in lines: + split = l.split() + if split[0] in args: + args[split[0]].append(split[1:]) + else: + args[split[0]] = split[1:] if len(split) > 1 else [] + if len(args[split[0]]) == 1: + args[split[0]] = args[split[0]][0] + return args + +def init_model(args,input_dim,output_dim, nframes): + print args + nn_type = args['type'] + + if nn_type == 'mlp': + model = mlp4(input_dim, + output_dim, + int(args['depth']), + int(args['width']), + nframes, + BN = 'batch_norm' in args, + dropout = float(args.setdefault('dropout',False)), + activation = args.setdefault('activation','sigmoid')) + if nn_type == 'resnet': + model = mlp4(input_dim, + output_dim, + int(args['n_blocks']), + int(args['width']), + nframes, + block_depth=int(args['block_depth']), + dropout=float(args.setdefault('dropout',False)), + BN='batch_norm' in args, + shortcut=True) + if nn_type == 'conv' or nn_type == 'conv+resnet': + print args + assert(len(args['conv']) == len(args['pooling']) or + (type(args['conv']) == str and + type(args['pooling'] == str))) + filts = [] + filt_dims = [] + pooling = [] + max='max' + avg='avg' + if type(args['conv']) == str: + conv = [args['conv']] + pool = [args['pooling']] + else: + conv = args['conv'] + pool = args['pooling'] + for i in range(len(conv)): + filt, dims = eval(conv[i]) + filts.append(int(filt)) + filt_dims.append(dims) + + pooling.append(eval(pool[i])) + if nn_type == 'conv': + model = mlp4(input_dim, + output_dim, + int(args['depth']), + int(args['width']), + nframes, + n_filts=filts, + filt_dims=filt_dims, + pooling=pooling, + conv=True) + else: + model = mlp4(input_dim, + output_dim, + int(args['depth']), + int(args['width']), + nframes, + block_depth=int(args['block_depth']), + n_filts=filts, + filt_dims=filt_dims, + pooling=pooling, + conv=True, + shortcut=True) + + if 'ctc_loss' in args: + model = ctc_model(model) + return model + +x_train = np.load(args['train_data']) +y_train = np.load(args['train_labels']) + +x_test = np.load(args['val_data']) +y_test = np.load(args['val_labels']) + +nClasses = max(map(np.max,y_train)) + 1 + +meta = {} +meta['nFrames_Train'] = sum(map(lambda x: x.shape[0], x_train)) +meta['nFrames_Dev'] = sum(map(lambda x: x.shape[0], x_test)) +meta['state_freq_train'] = np.zeros(nClasses) +for u in y_train: + for r in u: + meta['state_freq_train'][r] += 1 + +conf = read_config(args['nn_config']) +context_len = args['context_win'] +if args['keras_model'] != None: + model = load_model(args['keras_model']) +else: + if 'ctc_loss' in conf: + model = init_model(conf, (None,x_train[0].shape[-1],), + nClasses + 1, 2 * context_len + 1) + fg = gen_bracketed_data(for_CTC=True, n_states=nClasses) + else: + model = init_model(conf, (x_train[0].shape[-1] * (2 * context_len + 1),), + nClasses, 2 * context_len + 1) + fg = gen_bracketed_data(context_len=context_len) 
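+# At this point `model` is a compiled Keras model and `fg` is a generator
+# factory: trainNtest calls fg(x, y, batch_size) to obtain batches and uses
+# meta['nFrames_Train'] / batch_size as the number of steps per epoch.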
+trainNtest(model,x_train,y_train,x_test,y_test,meta,args['model_name'],batch_size=int(conf['batch_size']),ctc_train=False,fit_generator=fg, pretrain=args['pretrain'], n_epochs=args['n_epochs']) diff --git a/python/cmusphinx/utils.py b/python/cmusphinx/utils.py new file mode 100644 index 00000000..c9398b89 --- /dev/null +++ b/python/cmusphinx/utils.py @@ -0,0 +1,333 @@ +import numpy as np +import struct +import matplotlib.pyplot as plt +import pylab as pl +from sys import stdout +import os +from keras.preprocessing.sequence import pad_sequences +import keras.backend as K +from scipy.sparse import coo_matrix +from sklearn.preprocessing import StandardScaler + +def dummy_loss(y_true,y_pred): + return y_pred +def decoder_dummy_loss(y_true,y_pred): + return K.zeros((1,)) +def ler(y_true, y_pred, **kwargs): + """ + Label Error Rate. For more information see 'tf.edit_distance' + """ + return tf.reduce_mean(tf.edit_distance(y_pred, y_true, **kwargs)) + +def dense2sparse(a): + # rows,cols = a.nonzero() + # data = map(lambda i: a[rows[i],cols[i]], range(rows.shape[0])) + # return coo_matrix((data,(rows,cols)), shape=a.shape,dtype='int32') + return coo_matrix(a,shape=a.shape,dtype='int32') + +def readMFC(fname,nFeats): + data = [] + with open(fname,'rb') as f: + v = f.read(4) + head = struct.unpack('I',v)[0] + v = f.read(nFeats * 4) + while v: + frame = list(struct.unpack('%sf' % nFeats, v)) + data .append(frame) + v = f.read(nFeats * 4) + data = np.array(data) + # print data.shape, head + assert(data.shape[0] * data.shape[1] == head) + return data + +def ctc_labels(labels, blank_labels = []): + new_labels = [] + for i in range(len(labels)): + l_curr = labels[i] + if l_curr not in blank_labels: + if i == 0: + new_labels.append(l_curr) + else: + if l_curr != labels[i-1]: + new_labels.append(l_curr) + return np.array(new_labels) +def _gen_bracketed_data_2D(x,y,nFrames, + context_len,fix_length, + for_CTC): + max_len = ((np.max(nFrames) + 50)/100) * 100 #rounding off to the nearest 100 + batch_size = 2 + while 1: + pos = 0 + nClasses = np.max(y) + 1 + if for_CTC: + alldata = [] + alllabels = [] + for i in xrange(len(nFrames)): + data = x[pos:pos + nFrames[i]] + labels = y[pos:pos + nFrames[i]] + # if for_CTC: + # labels = ctc_labels(labels,blank_labels=range(18) + [108,109,110]) + # if len(labels.shape) == 1: + # labels = to_categorical(labels,num_classes=nClasses) + if context_len != None: + pad_top = np.zeros((context_len,data.shape[1])) + pad_bot = np.zeros((context_len,data.shape[1])) + padded_data = np.concatenate((pad_top,data),axis=0) + padded_data = np.concatenate((padded_data,pad_bot),axis=0) + + data = [] + for j in range(context_len,len(padded_data) - context_len): + new_row = padded_data[j - context_len: j + context_len + 1] + new_row = new_row.flatten() + data.append(new_row) + data = np.array(data) + if for_CTC: + if batch_size != None: + alldata.append(data) + alllabels.append(labels) + + if len(alldata) == batch_size: + alldata = np.array(alldata) + alllabels = np.array(alllabels) + if fix_length: + alldata = pad_sequences(alldata,maxlen=1000,dtype='float32',truncating='post') + alllabels = pad_sequences(alllabels,maxlen=1000,dtype='float32',value=138,truncating='post') + inputs = {'x': alldata, + 'y': alllabels, + 'x_len': np.array(map(lambda x: len(x), alldata)), + 'y_len': np.array(map(lambda x: len(x), alllabels))} + outputs = {'ctc': np.ones([batch_size])} + yield (inputs,outputs) + alldata = [] + alllabels = [] + else: + data = np.array([data]) + labels = np.array([labels]) + 
inputs = {'x': data, + 'y': labels, + 'x_len': [data.shape[0]], + 'y_len': [labels.shape[0]]} + outputs = {'ctc': labels} + yield (inputs,outputs) + else: + yield (data,labels) + pos += nFrames[i] + +def _gen_bracketed_data_3D(x,y,batch_size,context_len): + epoch_no = 1 + while 1: + print epoch_no + batch_data = [] + batch_labels = [] + for i in range(len(x)): + data = x[i] + labels = y[i] + if context_len != 0: + pad_top = np.zeros((context_len,data.shape[1])) + pad_bot = np.zeros((context_len,data.shape[1])) + padded_data = np.concatenate((pad_top,data),axis=0) + padded_data = np.concatenate((padded_data,pad_bot),axis=0) + + data = [] + for j in range(context_len,len(padded_data) - context_len): + new_row = padded_data[j - context_len: j + context_len + 1] + new_row = new_row.flatten() + data.append(new_row) + data = np.array(data) + seq_len = 0 + while seq_len < len(data) and data[seq_len].any(): + seq_len += 1 + idxs = range(seq_len) + np.random.shuffle(idxs) + for j in idxs: + if len(batch_data) < batch_size: + batch_data.append(data[j]) + batch_labels.append(labels[j]) + else: + batch_data = np.array(batch_data) + batch_labels = np.array(batch_labels) + yield(batch_data,batch_labels) + batch_data = [] + batch_labels = [] + epoch_no += 1 +def gen_ctc_data(alldata,alllabels,batch_size, n_states): + while 1: + for i in range(batch_size,alldata.shape[0]+1,batch_size): + x = alldata[i-batch_size:i] + y = alllabels[i-batch_size:i] + # .reshape(batch_size,alllabels.shape[1]) + max_len = max(map(len,x)) + x = pad_sequences(x,maxlen=max_len,dtype='float32',padding='post') + y = pad_sequences(y,maxlen=max_len,dtype='float32',padding='post',value=n_states) + x = np.array(x) + y = np.array(y) + # print x.shape, y.shape + y_len = [] + # print y + for b in y: + # # print b[-1], int(b[-1]) != 138 + pad_len = 0 + while pad_len < len(b) and int(b[pad_len]) != 138: + pad_len += 1 + y_len.append(pad_len) + y_len = np.array(y_len) + x_len = [] + for b in x: + # print b[-1], int(b[-1]) != 138 + pad_len = 0 + while pad_len < len(b) and b[pad_len].any(): + pad_len += 1 + x_len.append(pad_len) + x_len = np.array(x_len) + # x_len = np.array(map(lambda x: len(x), x)) + # y_len = np.array(map(lambda x: len(x), y)) + # print x.shape,y.shape,x_len,y_len + # print y.shape + # y = dense2sparse(y) + inputs = {'x': x, + 'y': y, + 'x_len': x_len, + 'y_len': y_len} + outputs = {'ctc': np.ones([batch_size]), + 'decoder': dense2sparse(y), + 'softmax': y.reshape(y.shape[0],y.shape[1],1)} + yield(inputs,outputs) + +def gen_bracketed_data(context_len=None,fix_length=False, + for_CTC=False, n_states=None): + if for_CTC: + assert(n_states != None) + return lambda x,y,batch_size: gen_ctc_data(x,y,batch_size,n_states) + else: + return lambda x,y,batch_size: _gen_bracketed_data_3D(x,y,batch_size,context_len) + # return lambda x,y,nf: _gen_bracketed_data(x,y,nf,context_len,fix_length, + # for_CTC) + +def plotFromCSV(modelName, loss_cols=[2,5], acc_cols=[1,4]): + data = np.loadtxt(modelName+'.csv',skiprows=1,delimiter=',') + epoch = data[:,[0]] + acc = data[:,[acc_cols[0]]] + loss = data[:,[loss_cols[0]]] + val_acc = data[:,[acc_cols[1]]] + val_loss = data[:,[loss_cols[1]]] + + fig, ax1 = plt.subplots() + ax1.plot(acc) + ax1.plot(val_acc) + ax2 = ax1.twinx() + ax2.plot(loss,color='r') + ax2.plot(val_loss,color='g') + plt.title('model loss & accuracy') + ax1.set_ylabel('accuracy') + ax2.set_ylabel('loss') + ax1.set_xlabel('epoch') + ax1.legend(['training acc', 'testing acc']) + ax2.legend(['training loss', 'testing loss']) + 
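+    # Rebuilds the same dual-axis accuracy/loss figure that trainNtest draws,
+    # but from the CSV written by the CSVLogger callback during training.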
fig.tight_layout() + plt.savefig(modelName+'.png') + plt.clf() + +def writeSenScores(filename,scores,weight,offset): + n_active = scores.shape[1] + s = '' + s = """s3 +version 0.1 +mdef_file ../../en_us.cd_cont_4000/mdef +n_sen 138 +logbase 1.000100 +endhdr +""" + s += struct.pack('I',0x11223344) + + scores = np.log(scores)/np.log(1.0001) + scores *= -1 + scores -= np.min(scores,axis=1).reshape(-1,1) + # scores = scores.astype(int) + scores *= weight + scores += offset + truncateToShort = lambda x: 32676 if x > 32767 else (-32768 if x < -32768 else x) + vf = np.vectorize(truncateToShort) + scores = vf(scores) + # scores /= np.sum(scores,axis=0) + for r in scores: + # print np.argmin(r) + s += struct.pack('h',n_active) + r_str = struct.pack('%sh' % len(r), *r) + # r_str = reduce(lambda x,y: x+y,r_str) + s += r_str + with open(filename,'w') as f: + f.write(s) + +def getPredsFromArray(model,data,nFrames,filenames,res_dir,res_ext,freqs,preds_in=False,weight=0.1,offset=0): + if preds_in: + preds = data + else: + preds = model.predict(data,verbose=1,batch_size=2048) + pos = 0 + for i in range(len(nFrames)): + fname = filenames[i][:-4] + fname = reduce(lambda x,y: x+'/'+y,fname.split('/')[4:]) + stdout.write("\r%d/%d " % (i,len(filenames))) + stdout.flush() + res_file_path = res_dir+fname+res_ext + dirname = os.path.dirname(res_file_path) + if not os.path.exists(dirname): + os.makedirs(dirname) + # preds = model.predict(data[pos:pos+nFrames[i]],batch_size=nFrames[i]) + writeSenScores(res_file_path,preds[pos:pos+nFrames[i]],freqs,weight,offset) + pos += nFrames[i] + +def getPredsFromFilelist(model,filelist,file_dir,file_ext, + res_dir,res_ext,n_feat=40,context_len=None, + weight=1,offset=0, data_preproc_fn=None, + data_postproc_fn=None): + with open(filelist) as f: + files = f.readlines() + files = map(lambda x: x.strip(),files) + filepaths = map(lambda x: file_dir+x+file_ext,files) + scaler = StandardScaler(copy=False,with_std=False) + for i in range(len(filepaths)): + stdout.write("\r%d/%d " % (i,len(filepaths))) + stdout.flush() + + f = filepaths[i] + if not os.path.exists(f): + print ("\n",f) + continue + data = readMFC(f,n_feat) + data = scaler.fit_transform(data) + + if context_len != None: + pad_top = np.zeros((context_len,data.shape[1])) + data[0] + pad_bot = np.zeros((context_len,data.shape[1])) + data[-1] + padded_data = np.concatenate((pad_top,data),axis=0) + padded_data = np.concatenate((padded_data,pad_bot),axis=0) + + data = [] + for j in range(context_len,len(padded_data) - context_len): + new_row = padded_data[j - context_len: j + context_len + 1] + new_row = new_row.flatten() + data.append(new_row) + data = np.array(data) + if data_preproc_fn != None: + _data = data_preproc_fn(data) + preds = model.predict(_data) + preds = np.squeeze(preds) + else: + preds = model.predict(data) + + if data_postproc_fn != None: + preds = data_postproc_fn(preds) + if preds.shape[0] != data.shape[0]: + preds = preds[:data.shape[0]] + # print np.sum(preds) + # print preds.shape + res_file_path = res_dir+files[i]+res_ext + dirname = os.path.dirname(res_file_path) + if not os.path.exists(dirname): + os.makedirs(dirname) + writeSenScores(res_file_path,preds,weight,offset) + +# a = dense2sparse(np.array([[1,2,3],[4,0,6]])) +# print a.shape +# print np.asarray(a,dtype=long) \ No newline at end of file diff --git a/scripts/19.nn_train/nn_train.pl b/scripts/19.nn_train/nn_train.pl new file mode 100644 index 00000000..e5dceee3 --- /dev/null +++ b/scripts/19.nn_train/nn_train.pl @@ -0,0 +1,167 @@ 
diff --git a/scripts/19.nn_train/nn_train.pl b/scripts/19.nn_train/nn_train.pl
new file mode 100644
index 00000000..e5dceee3
--- /dev/null
+++ b/scripts/19.nn_train/nn_train.pl
@@ -0,0 +1,167 @@
+#!/usr/bin/perl
+## ====================================================================
+##
+## Copyright (c) 2006 Carnegie Mellon University. All rights
+## reserved.
+##
+## Redistribution and use in source and binary forms, with or without
+## modification, are permitted provided that the following conditions
+## are met:
+##
+## 1. Redistributions of source code must retain the above copyright
+##    notice, this list of conditions and the following disclaimer.
+##
+## 2. Redistributions in binary form must reproduce the above copyright
+##    notice, this list of conditions and the following disclaimer in
+##    the documentation and/or other materials provided with the
+##    distribution.
+##
+## This work was supported in part by funding from the Defense Advanced
+## Research Projects Agency and the National Science Foundation of the
+## United States of America, and the CMU Sphinx Speech Consortium.
+##
+## THIS SOFTWARE IS PROVIDED BY CARNEGIE MELLON UNIVERSITY ``AS IS'' AND
+## ANY EXPRESSED OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO,
+## THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
+## PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL CARNEGIE MELLON UNIVERSITY
+## NOR ITS EMPLOYEES BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
+## SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
+## LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
+## DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
+## THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
+## (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+## OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+##
+## ====================================================================
+##
+## Author: David Huggins-Daines
+##
+
+use strict;
+use File::Copy;
+use File::Basename;
+use File::Spec::Functions;
+use File::Path;
+use File::Temp;
+
+use lib catdir(dirname($0), updir(), 'lib');
+use SphinxTrain::Config;
+use SphinxTrain::Util;
+
+if ($ST::CFG_TRAIN_DNN ne "yes") {
+    Log("Skipped: \$ST::CFG_TRAIN_DNN set to \'$ST::CFG_TRAIN_DNN\' in sphinx_train.cfg\n");
+    exit(0);
+}
+my $hmm_dir = defined($ST::CFG_FORCE_ALIGN_MODELDIR)
+    ? $ST::CFG_FORCE_ALIGN_MODELDIR
+    : "$ST::CFG_MODEL_DIR/$ST::CFG_EXPTNAME.ci_$ST::CFG_DIRLABEL";
+my $logdir = "$ST::CFG_LOG_DIR/19.nn_train";
+mkdir $logdir, 0755;
+my $outdir = "$ST::CFG_BASE_DIR/falignout";
+my $outfile = "$outdir/$ST::CFG_EXPTNAME.alignedtranscripts.d";
+
+my $statepdeffn = $ST::CFG_HMM_TYPE; # indicates the type of HMMs
+my $mwfloor = 1e-8;
+my $minvar = 1e-4;
+my $listoffiles = $ST::CFG_LISTOFFILES;
+my $transcriptfile = "$outdir/$ST::CFG_EXPTNAME.aligninput";
+my $dict = defined($ST::CFG_FORCE_ALIGN_DICTIONARY)
+    ? $ST::CFG_FORCE_ALIGN_DICTIONARY
+    : "$outdir/$ST::CFG_EXPTNAME.falign.dict";
+my $fdict = defined($ST::CFG_FORCE_ALIGN_FILLERDICT)
+    ? $ST::CFG_FORCE_ALIGN_FILLERDICT
+    : "$outdir/$ST::CFG_EXPTNAME.falign.fdict";
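+
+# A hypothetical sphinx_train.cfg fragment that enables this stage and
+# supplies the optional variables referenced above (values are
+# illustrative placeholders, not recommendations):
+#
+#   $CFG_TRAIN_DNN = 'yes';
+#   $CFG_STSEG_DIR = "$CFG_BASE_DIR/stateseg_ci_dir";
+#   $CFG_FORCE_ALIGN_BEAM = 1e-100;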
+my $beam = defined($ST::CFG_FORCE_ALIGN_BEAM)
+    ? $ST::CFG_FORCE_ALIGN_BEAM : 1e-100;
+my $logfile = "$logdir/${ST::CFG_EXPTNAME}.falign.log";
+
+# Get the number of utterances
+open INPUT,"${ST::CFG_LISTOFFILES}" or die "Failed to open $ST::CFG_LISTOFFILES: $!";
+# Check control file format (determines if we add ,CTL below): lines that
+# contain a path separator, or consist of a single field, are treated as
+# full file paths rather than bare utterance IDs.
+my $line = <INPUT>;
+my $ctlext;
+if (split(" ", $line) == 1 or $line =~ m,/,) {
+    # Use full file path
+    $ctlext = ",CTL";
+}
+else {
+    # Use utterance ID
+    $ctlext = "";
+}
+my $ctl_counter = 1;
+while (<INPUT>) {
+    $ctl_counter++;
+}
+close INPUT;
+$ctl_counter = 1 unless ($ctl_counter);
+
+my @phsegdir;
+
+if (defined($ST::CFG_STSEG_DIR)) {
+    push @phsegdir, (-stsegdir => "$ST::CFG_STSEG_DIR");
+}
+else {
+    LogError("Please specify CFG_STSEG_DIR");
+    exit 1;
+}
+
+# The stseg-read binary shipped with this step can be rebuilt if needed:
+#   gcc -o $ST::CFG_SPHINXTRAIN_DIR/scripts/19.nn_train/stseg-read \
+#       $ST::CFG_SPHINXTRAIN_DIR/scripts/19.nn_train/stseg-read.c
+
+my $return_value = system("sphinx3_align -hmm $hmm_dir -senmgau $statepdeffn -mixwfloor $mwfloor -varfloor $minvar -dict $dict -fdict $fdict -ctl $ST::CFG_LISTOFFILES -cepdir $ST::CFG_FEATFILES_DIR -cepext .$ST::CFG_FEATFILE_EXTENSION -insent $transcriptfile -outsent $outfile @phsegdir -beam $beam -agc $ST::CFG_AGC -cmn $ST::CFG_CMN -varnorm $ST::CFG_VARNORM -feat $ST::CFG_FEATURE -ceplen $ST::CFG_VECTOR_LENGTH > $logfile 2>&1");
+
+if ($return_value) {
+    LogError("Failed to run sphinx3_align");
+    exit 1;
+}
+
+Log('converting stseg files to ASCII');
+$logfile = "$logdir/stseg2ascii.log";
+
+$return_value = system("$ST::CFG_SPHINXTRAIN_DIR/scripts/19.nn_train/readStSegs.sh $ST::CFG_STSEG_DIR $ST::CFG_SPHINXTRAIN_DIR/scripts/19.nn_train > $logfile 2>&1");
+if ($return_value) {
+    LogError("Failed to convert stseg files");
+    exit 1;
+}
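+
+# Each converted <utt>.stseg.txt holds one line per frame as printed by
+# stseg-read: frame index, acoustic score, HMM state index, the CI phone
+# and, for context-dependent models, left/right context phones plus a word
+# position flag. runDatasetGen.py reads these files via -stseg_ext .stseg.txt.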
+
+Log('generating dataset');
+$logfile = catfile($logdir, "${ST::CFG_EXPTNAME}.gendataset.log");
+$ENV{PYTHONPATH} .= ':' . File::Spec->catdir($ST::CFG_SPHINXTRAIN_DIR, 'python');
+
+my $gen_cmd = "python $ST::CFG_SPHINXTRAIN_DIR/python/cmusphinx/runDatasetGen.py -train_fileids $ST::CFG_LISTOFFILES -val_fileids $ST::CFG_LISTOFFILES -nfilts $ST::CFG_VECTOR_LENGTH -feat_dir $ST::CFG_FEATFILES_DIR -feat_ext .$ST::CFG_FEATFILE_EXTENSION -stseg_dir $ST::CFG_STSEG_DIR -stseg_ext .stseg.txt -mdef $hmm_dir/mdef -outfile_prefix $ST::CFG_BASE_DIR/ -keep_utts";
+Log($gen_cmd);
+$return_value = system($gen_cmd);
+
+if ($return_value) {
+    LogError("Failed to run runDatasetGen.py");
+    exit 1;
+}
+
+Log('training nn');
+$logfile = catfile($logdir, "${ST::CFG_EXPTNAME}.training.log");
+$return_value = system("python $ST::CFG_SPHINXTRAIN_DIR/python/cmusphinx/runNNTrain.py -train_data $ST::CFG_BASE_DIR/train.npy -train_labels $ST::CFG_BASE_DIR/train_labels.npy -val_data $ST::CFG_BASE_DIR/dev.npy -val_labels $ST::CFG_BASE_DIR/dev_labels.npy -nn_config $ST::CFG_SPHINXTRAIN_DIR/scripts/19.nn_train/sample_nn.cfg -context_win 4 -model_name $hmm_dir/keras_model.h5 -n_epochs 3");
+if ($return_value) {
+    LogError("Failed to run runNNTrain.py");
+}
+
+exit ($return_value);
diff --git a/scripts/19.nn_train/readStSegs.sh b/scripts/19.nn_train/readStSegs.sh
new file mode 100755
index 00000000..ece62a12
--- /dev/null
+++ b/scripts/19.nn_train/readStSegs.sh
@@ -0,0 +1,9 @@
+#!/bin/sh
+# Convert every binary .stseg file under $1 to ASCII with the stseg-read
+# binary found in $2.
+stseg_fldr="$1"
+read_fldr="$2"
+echo "$(pwd)"
+find "$stseg_fldr/" -name "*.stseg" | while read -r f
+do
+    echo "CONVERTING: $f"
+    "$read_fldr/stseg-read" < "$f" > "$f.txt"
+done
diff --git a/scripts/19.nn_train/sample_nn.cfg b/scripts/19.nn_train/sample_nn.cfg
new file mode 100644
index 00000000..851e9aba
--- /dev/null
+++ b/scripts/19.nn_train/sample_nn.cfg
@@ -0,0 +1,22 @@
+#type resnet
+#width 2048
+#block_depth 3
+#n_blocks 5
+
+type mlp
+width 2048
+depth 3
+#dropout 0.25
+batch_norm
+activation sigmoid
+#ctc_loss
+batch_size 512
+optimizer sgd
+lr 0.001
+
+#type conv
+#conv [84,(6,8)] [168,(3,4)]
+#pooling None [max,(3,3),(1,1)]
+#width 2048
+#depth 3
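+
+# Illustrative conv+resnet variant, combining the commented conv and resnet
+# settings above (parameters as documented in python/DNN_training/README.md;
+# values are placeholders):
+#type conv+resnet
+#conv [84,(6,8)] [168,(3,4)]
+#pooling None [max,(3,3),(1,1)]
+#width 2048
+#depth 3
+#block_depth 3
+#n_blocks 5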
diff --git a/scripts/19.nn_train/stseg-read b/scripts/19.nn_train/stseg-read
new file mode 100755
index 00000000..92086dc0
Binary files /dev/null and b/scripts/19.nn_train/stseg-read differ
diff --git a/scripts/19.nn_train/stseg-read.c b/scripts/19.nn_train/stseg-read.c
new file mode 100644
index 00000000..9c0688a7
--- /dev/null
+++ b/scripts/19.nn_train/stseg-read.c
@@ -0,0 +1,155 @@
+/* ====================================================================
+ * Copyright (c) 1995-2002 Carnegie Mellon University. All rights
+ * reserved.
+ *
+ * Redistribution and use in source and binary forms, with or without
+ * modification, are permitted provided that the following conditions
+ * are met:
+ *
+ * 1. Redistributions of source code must retain the above copyright
+ *    notice, this list of conditions and the following disclaimer.
+ *
+ * 2. Redistributions in binary form must reproduce the above copyright
+ *    notice, this list of conditions and the following disclaimer in
+ *    the documentation and/or other materials provided with the
+ *    distribution.
+ *
+ * This work was supported in part by funding from the Defense Advanced
+ * Research Projects Agency and the National Science Foundation of the
+ * United States of America, and the CMU Sphinx Speech Consortium.
+ *
+ * THIS SOFTWARE IS PROVIDED BY CARNEGIE MELLON UNIVERSITY ``AS IS'' AND
+ * ANY EXPRESSED OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO,
+ * THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
+ * PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL CARNEGIE MELLON UNIVERSITY
+ * NOR ITS EMPLOYEES BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
+ * SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
+ * LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
+ * DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
+ * THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
+ * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+ *
+ * ====================================================================
+ *
+ */
+/*
+ * stseg.c -- Read and display .stseg file created by s3align.
+ *
+ * **********************************************
+ * CMU ARPA Speech Project
+ *
+ * Copyright (c) 1996 Carnegie Mellon University.
+ * ALL RIGHTS RESERVED.
+ * **********************************************
+ *
+ * HISTORY
+ *
+ * 19-Jul-96  M K Ravishankar (rkm@cs.cmu.edu) at Carnegie Mellon University
+ *            Created.
+ */
+
+
+#include <stdio.h>
+#include <stdlib.h>
+#include <string.h>
+#include <assert.h>
+#include <stdint.h>
+
+/* "\nCI.8 LC.8 RC.8 POS.3(HI)-ST.5(LO) SCR(32)" */
+
+static char *phone[100];
+static int n_phone;
+static char *posname = "besiu";
+
+
+static void
+skip_line (FILE *fp)
+{
+    int c;
+
+    while (((c = fgetc (fp)) >= 0) && (c != '\n'));
+}
+
+
+int
+main (void)
+{
+    int i, k, nf, scr;
+    int16_t c;
+    FILE *fp;
+    int16_t str[4];    /* three phone IDs plus a terminator slot */
+    char str1[1024];
+    fp = stdin;
+    n_phone = 0;
+
+    /* Skip version# string */
+    skip_line (fp);
+
+    /* Read CI phone names */
+    for (;;) {
+        for (i = 0;; i++) {
+            if (((c = fgetc(fp)) == ' ') || (c == '\n'))
+                break;
+            str1[i] = c;
+        }
+        str1[i] = '\0';
+
+        if (c == ' ') {
+            phone[n_phone] = (char *) malloc (i+1);
+            strcpy (phone[n_phone], str1);
+            n_phone++;
+        } else
+            break;
+    }
+    printf ("%d phones\n", n_phone);
+
+    /* Skip format line */
+    skip_line (fp);
+
+    /* Skip end-comment line */
+    skip_line (fp);
+
+    /* Read byteorder magic no. */
+    fread (&i, sizeof(int), 1, fp);
+    assert (i == 0x11223344);
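+
+    /*
+     * Per-frame record layout (cf. the state-segmentation format in
+     * python/DNN_training/README.md): three int16 phone IDs (the CI phone
+     * plus left/right context phones, -1 when absent), one byte packing
+     * the word position (high 3 bits) and HMM state (low 5 bits), and a
+     * 32-bit acoustic score.
+     */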
+    /* Read no. of frames */
+    fread (&nf, sizeof(int), 1, fp);
+    printf ("#frames = %d\n", nf);
+
+    char pos[1];
+    /* Read state info per frame */
+    for (i = 0; i < nf; i++) {
+        k = fread (str, sizeof(uint16_t), 3, fp);
+        str[3] = 0;
+        assert (k == 3);
+        k = fread (pos, sizeof(char), 1, fp);
+        assert (k == 1);
+        k = fread (&scr, sizeof(int), 1, fp);
+        assert (k == 1);
+
+        /* CI phone */
+        c = str[0];
+        assert ((c >= 0) && (c < n_phone));
+        printf ("%5d %11d %2d %s", i, scr, pos[0] & 0x001f, phone[c]);
+
+        /* Left context phone, -1 if absent */
+        c = str[1];
+        if (c != -1) {
+            assert ((c >= 0) && (c < n_phone));
+            printf (" %s", phone[c]);
+        }
+
+        /* Right context phone, -1 if absent */
+        c = str[2];
+        if (c != -1) {
+            assert ((c >= 0) && (c < n_phone));
+            printf (" %s", phone[c]);
+        }
+
+        /* Word position flag: begin/end/single/internal */
+        c = (pos[0] >> 5) & 0x07;
+        if ((c >= 0) && (c < 4))
+            printf (" %c", posname[c]);
+
+        printf ("\n");
+    }
+    return 0;
+}
diff --git a/scripts/Makefile.am b/scripts/Makefile.am
index 991cc7d8..5de7f52b 100644
--- a/scripts/Makefile.am
+++ b/scripts/Makefile.am
@@ -32,6 +32,11 @@ nobase_scripts_SCRIPTS = \
 	11.force_align/slave_align.pl \
 	12.vtln_align/slave_align.pl \
+	19.nn_train/nn_train.pl \
+	19.nn_train/readStSegs.sh \
+	19.nn_train/stseg-read.c \
+	19.nn_train/stseg-read \
+	19.nn_train/sample_nn.cfg \
 	20.ci_hmm/baum_welch.pl \
 	20.ci_hmm/norm_and_launchbw.pl \
 	20.ci_hmm/norm.pl \
diff --git a/scripts/sphinxtrain b/scripts/sphinxtrain
old mode 100644
new mode 100755
index 61529114..893e5528
--- a/scripts/sphinxtrain
+++ b/scripts/sphinxtrain
@@ -74,6 +74,7 @@ steps = [
 "10.falign_ci_hmm/slave_convg.pl",
 "11.force_align/slave_align.pl",
 "12.vtln_align/slave_align.pl",
+"19.nn_train/nn_train.pl",
 "20.ci_hmm/slave_convg.pl",
 "30.cd_hmm_untied/slave_convg.pl",
 "40.buildtrees/slave.treebuilder.pl",