Extract architecture code into its own interface
This allows other combinations as required by bst-mug#107 and bst-mug#110.
michelole committed Jun 3, 2019
1 parent 65ba2a4 commit 2bc8a92
Showing 5 changed files with 128 additions and 99 deletions.
@@ -3,28 +3,12 @@
import java.io.IOException;
import java.util.Properties;

import at.medunigraz.imi.bst.n2c2.model.Criterion;
import at.medunigraz.imi.bst.n2c2.nn.architecture.BiLSTMArchitecture;
import at.medunigraz.imi.bst.n2c2.nn.input.CharacterTrigram;
import at.medunigraz.imi.bst.n2c2.nn.iterator.SentenceIterator;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.deeplearning4j.nn.conf.GradientNormalization;
import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.WorkspaceMode;
import org.deeplearning4j.nn.conf.layers.DenseLayer;
import org.deeplearning4j.nn.conf.layers.GravesBidirectionalLSTM;
import org.deeplearning4j.nn.conf.layers.GravesLSTM;
import org.deeplearning4j.nn.conf.layers.RnnOutputLayer;
import org.deeplearning4j.nn.conf.preprocessor.FeedForwardToRnnPreProcessor;
import org.deeplearning4j.nn.conf.preprocessor.RnnToFeedForwardPreProcessor;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.deeplearning4j.nn.weights.WeightInit;
import org.deeplearning4j.optimize.listeners.ScoreIterationListener;
import org.nd4j.linalg.activations.Activation;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.learning.config.AdaGrad;
import org.nd4j.linalg.lossfunctions.LossFunctions;

/**
* BI-LSTM classifier for n2c2 task 2018 refactored from dl4j examples.
@@ -75,58 +59,7 @@ protected void initializeNetworkBinaryMultiLabelDeep() {

fullSetIterator = new SentenceIterator(patientExamples, new CharacterTrigram(SentenceIterator.createPatientLines(patientExamples)), BATCH_SIZE);

int nOutFF = 150;
int lstmLayerSize = 128;
double l2Regulization = 0.01;
double adaGradCore = 0.04;
double adaGradDense = 0.01;
double adaGradGraves = 0.008;

// seed for reproducibility
final int seed = 12345;
MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder().seed(seed)
.updater(AdaGrad.builder().learningRate(adaGradCore).build()).regularization(true).l2(l2Regulization)
.weightInit(WeightInit.XAVIER).gradientNormalization(GradientNormalization.ClipElementWiseAbsoluteValue)
.gradientNormalizationThreshold(1.0).trainingWorkspaceMode(WorkspaceMode.SINGLE)
.inferenceWorkspaceMode(WorkspaceMode.SINGLE).list()

.layer(0, new DenseLayer.Builder().activation(Activation.RELU).nIn(fullSetIterator.getInputRepresentation().getVectorSize()).nOut(nOutFF)
.weightInit(WeightInit.RELU).updater(AdaGrad.builder().learningRate(adaGradDense).build())
.gradientNormalization(GradientNormalization.ClipElementWiseAbsoluteValue)
.gradientNormalizationThreshold(10).build())

.layer(1, new DenseLayer.Builder().activation(Activation.RELU).nIn(nOutFF).nOut(nOutFF)
.weightInit(WeightInit.RELU).updater(AdaGrad.builder().learningRate(adaGradDense).build())
.gradientNormalization(GradientNormalization.ClipElementWiseAbsoluteValue)
.gradientNormalizationThreshold(10).build())

.layer(2, new DenseLayer.Builder().activation(Activation.RELU).nIn(nOutFF).nOut(nOutFF)
.weightInit(WeightInit.RELU).updater(AdaGrad.builder().learningRate(adaGradDense).build())
.gradientNormalization(GradientNormalization.ClipElementWiseAbsoluteValue)
.gradientNormalizationThreshold(10).build())

.layer(3,
new GravesBidirectionalLSTM.Builder().nIn(nOutFF).nOut(lstmLayerSize)
.updater(AdaGrad.builder().learningRate(adaGradGraves).build())
.activation(Activation.SOFTSIGN).build())

.layer(4,
new GravesLSTM.Builder().nIn(lstmLayerSize).nOut(lstmLayerSize)
.updater(AdaGrad.builder().learningRate(adaGradGraves).build())
.activation(Activation.SOFTSIGN).build())

.layer(5, new RnnOutputLayer.Builder().activation(Activation.SIGMOID)
.lossFunction(LossFunctions.LossFunction.XENT).nIn(lstmLayerSize).nOut(Criterion.classifiableValues().length).build())

.inputPreProcessor(0, new RnnToFeedForwardPreProcessor())
.inputPreProcessor(3, new FeedForwardToRnnPreProcessor()).pretrain(false).backprop(true).build();

// .backpropType(BackpropType.TruncatedBPTT).tBPTTForwardLength(tbpttLength).tBPTTBackwardLength(tbpttLength)

this.net = new MultiLayerNetwork(conf);
this.net.init();
this.net.setListeners(new ScoreIterationListener(1));

this.net = new BiLSTMArchitecture().getNetwork(fullSetIterator.getInputRepresentation().getVectorSize());
}

protected String getModelName() {
@@ -4,23 +4,12 @@
import java.io.IOException;
import java.util.Properties;

import at.medunigraz.imi.bst.n2c2.nn.architecture.LSTMArchitecture;
import at.medunigraz.imi.bst.n2c2.nn.input.WordEmbedding;
import at.medunigraz.imi.bst.n2c2.nn.iterator.TokenIterator;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.deeplearning4j.nn.conf.GradientNormalization;
import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.WorkspaceMode;
import org.deeplearning4j.nn.conf.layers.GravesLSTM;
import org.deeplearning4j.nn.conf.layers.RnnOutputLayer;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.deeplearning4j.nn.weights.WeightInit;
import org.deeplearning4j.optimize.listeners.ScoreIterationListener;
import org.nd4j.linalg.activations.Activation;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.learning.config.Adam;
import org.nd4j.linalg.lossfunctions.LossFunctions;

/**
* LSTM classifier for n2c2 task 2018 refactored from dl4j examples.
@@ -48,24 +37,7 @@ private void initializeNetworkBinaryMultiLabelDebug() {
fullSetIterator = new TokenIterator(patientExamples, new WordEmbedding(PRETRAINED_VECTORS), BATCH_SIZE);

// Set up network configuration
MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder().seed(0)
.updater(Adam.builder().learningRate(2e-2).build()).regularization(true).l2(1e-5).weightInit(WeightInit.XAVIER)
.gradientNormalization(GradientNormalization.ClipElementWiseAbsoluteValue)
.gradientNormalizationThreshold(1.0).trainingWorkspaceMode(WorkspaceMode.SEPARATE)
.inferenceWorkspaceMode(WorkspaceMode.SEPARATE) // https://deeplearning4j.org/workspaces
.list().layer(0, new GravesLSTM.Builder().nIn(fullSetIterator.getInputRepresentation().getVectorSize()).nOut(256).activation(Activation.TANH).build())
.layer(1,
new RnnOutputLayer.Builder().activation(Activation.SIGMOID)
.lossFunction(LossFunctions.LossFunction.XENT).nIn(256).nOut(13).build())
.pretrain(false).backprop(true).build();

// for truncated backpropagation over time
// .backpropType(BackpropType.TruncatedBPTT).tBPTTForwardLength(tbpttLength)
// .tBPTTBackwardLength(tbpttLength).pretrain(false).backprop(true).build();

this.net = new MultiLayerNetwork(conf);
this.net.init();
this.net.setListeners(new ScoreIterationListener(1));
this.net = new LSTMArchitecture().getNetwork(fullSetIterator.getInputRepresentation().getVectorSize());
}

@Override
@@ -0,0 +1,8 @@
package at.medunigraz.imi.bst.n2c2.nn.architecture;

import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;

public interface Architecture {

MultiLayerNetwork getNetwork(int nIn);
}
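
For illustration, a minimal usage sketch (hypothetical, not part of this commit): any implementation of Architecture can be instantiated and asked for a ready-to-train network, with the input size supplied by the caller. The demo class name and the placeholder vector size are assumptions.

package at.medunigraz.imi.bst.n2c2.nn.architecture;

import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;

// Hypothetical usage sketch; not part of this commit.
public class ArchitectureDemo {
    public static void main(String[] args) {
        // The input size would normally come from the input representation,
        // e.g. fullSetIterator.getInputRepresentation().getVectorSize();
        // 300 is an assumed placeholder value.
        int vectorSize = 300;

        Architecture architecture = new LSTMArchitecture();
        MultiLayerNetwork net = architecture.getNetwork(vectorSize);

        // numParams() reports the total number of trainable parameters.
        System.out.println("Parameters: " + net.numParams());
    }
}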
@@ -0,0 +1,77 @@
package at.medunigraz.imi.bst.n2c2.nn.architecture;

import at.medunigraz.imi.bst.n2c2.model.Criterion;
import org.deeplearning4j.nn.conf.GradientNormalization;
import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.WorkspaceMode;
import org.deeplearning4j.nn.conf.layers.DenseLayer;
import org.deeplearning4j.nn.conf.layers.GravesBidirectionalLSTM;
import org.deeplearning4j.nn.conf.layers.GravesLSTM;
import org.deeplearning4j.nn.conf.layers.RnnOutputLayer;
import org.deeplearning4j.nn.conf.preprocessor.FeedForwardToRnnPreProcessor;
import org.deeplearning4j.nn.conf.preprocessor.RnnToFeedForwardPreProcessor;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.deeplearning4j.nn.weights.WeightInit;
import org.deeplearning4j.optimize.listeners.ScoreIterationListener;
import org.nd4j.linalg.activations.Activation;
import org.nd4j.linalg.learning.config.AdaGrad;
import org.nd4j.linalg.lossfunctions.LossFunctions;

public class BiLSTMArchitecture implements Architecture {
@Override
public MultiLayerNetwork getNetwork(int nIn) {
int nOutFF = 150;
int lstmLayerSize = 128;
double l2Regulization = 0.01;
double adaGradCore = 0.04;
double adaGradDense = 0.01;
double adaGradGraves = 0.008;

// seed for reproducibility
final int seed = 12345;
MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder().seed(seed)
.updater(AdaGrad.builder().learningRate(adaGradCore).build()).regularization(true).l2(l2Regulization)
.weightInit(WeightInit.XAVIER).gradientNormalization(GradientNormalization.ClipElementWiseAbsoluteValue)
.gradientNormalizationThreshold(1.0).trainingWorkspaceMode(WorkspaceMode.SINGLE)
.inferenceWorkspaceMode(WorkspaceMode.SINGLE).list()

.layer(0, new DenseLayer.Builder().activation(Activation.RELU).nIn(nIn).nOut(nOutFF)
.weightInit(WeightInit.RELU).updater(AdaGrad.builder().learningRate(adaGradDense).build())
.gradientNormalization(GradientNormalization.ClipElementWiseAbsoluteValue)
.gradientNormalizationThreshold(10).build())

.layer(1, new DenseLayer.Builder().activation(Activation.RELU).nIn(nOutFF).nOut(nOutFF)
.weightInit(WeightInit.RELU).updater(AdaGrad.builder().learningRate(adaGradDense).build())
.gradientNormalization(GradientNormalization.ClipElementWiseAbsoluteValue)
.gradientNormalizationThreshold(10).build())

.layer(2, new DenseLayer.Builder().activation(Activation.RELU).nIn(nOutFF).nOut(nOutFF)
.weightInit(WeightInit.RELU).updater(AdaGrad.builder().learningRate(adaGradDense).build())
.gradientNormalization(GradientNormalization.ClipElementWiseAbsoluteValue)
.gradientNormalizationThreshold(10).build())

.layer(3,
new GravesBidirectionalLSTM.Builder().nIn(nOutFF).nOut(lstmLayerSize)
.updater(AdaGrad.builder().learningRate(adaGradGraves).build())
.activation(Activation.SOFTSIGN).build())

.layer(4,
new GravesLSTM.Builder().nIn(lstmLayerSize).nOut(lstmLayerSize)
.updater(AdaGrad.builder().learningRate(adaGradGraves).build())
.activation(Activation.SOFTSIGN).build())

.layer(5, new RnnOutputLayer.Builder().activation(Activation.SIGMOID)
.lossFunction(LossFunctions.LossFunction.XENT).nIn(lstmLayerSize).nOut(Criterion.classifiableValues().length).build())

.inputPreProcessor(0, new RnnToFeedForwardPreProcessor())
.inputPreProcessor(3, new FeedForwardToRnnPreProcessor()).pretrain(false).backprop(true).build();

// .backpropType(BackpropType.TruncatedBPTT).tBPTTForwardLength(tbpttLength).tBPTTBackwardLength(tbpttLength)

MultiLayerNetwork net = new MultiLayerNetwork(conf);
net.init();
net.setListeners(new ScoreIterationListener(1));
return net;
}
}
@@ -0,0 +1,39 @@
package at.medunigraz.imi.bst.n2c2.nn.architecture;

import org.deeplearning4j.nn.conf.GradientNormalization;
import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.WorkspaceMode;
import org.deeplearning4j.nn.conf.layers.GravesLSTM;
import org.deeplearning4j.nn.conf.layers.RnnOutputLayer;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.deeplearning4j.nn.weights.WeightInit;
import org.deeplearning4j.optimize.listeners.ScoreIterationListener;
import org.nd4j.linalg.activations.Activation;
import org.nd4j.linalg.learning.config.Adam;
import org.nd4j.linalg.lossfunctions.LossFunctions;

public class LSTMArchitecture implements Architecture {
@Override
public MultiLayerNetwork getNetwork(int nIn) {
MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder().seed(0)
.updater(Adam.builder().learningRate(2e-2).build()).regularization(true).l2(1e-5).weightInit(WeightInit.XAVIER)
.gradientNormalization(GradientNormalization.ClipElementWiseAbsoluteValue)
.gradientNormalizationThreshold(1.0).trainingWorkspaceMode(WorkspaceMode.SEPARATE)
.inferenceWorkspaceMode(WorkspaceMode.SEPARATE) // https://deeplearning4j.org/workspaces
.list().layer(0, new GravesLSTM.Builder().nIn(nIn).nOut(256).activation(Activation.TANH).build())
.layer(1,
new RnnOutputLayer.Builder().activation(Activation.SIGMOID)
.lossFunction(LossFunctions.LossFunction.XENT).nIn(256).nOut(13).build())
.pretrain(false).backprop(true).build();

// for truncated backpropagation over time
// .backpropType(BackpropType.TruncatedBPTT).tBPTTForwardLength(tbpttLength)
// .tBPTTBackwardLength(tbpttLength).pretrain(false).backprop(true).build();

MultiLayerNetwork net = new MultiLayerNetwork(conf);
net.init();
net.setListeners(new ScoreIterationListener(1));
return net;
}
}
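
The "other combinations" mentioned in the commit message can now be expressed by passing an Architecture into a classifier instead of hard-coding one. A hedged sketch, assuming a hypothetical class and method names that are not part of this commit:

package at.medunigraz.imi.bst.n2c2.nn;

import at.medunigraz.imi.bst.n2c2.nn.architecture.Architecture;
import at.medunigraz.imi.bst.n2c2.nn.architecture.BiLSTMArchitecture;
import at.medunigraz.imi.bst.n2c2.nn.architecture.LSTMArchitecture;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;

// Hypothetical sketch of a classifier parameterized by Architecture;
// class and method names are assumptions for illustration.
public class ConfigurableClassifier {

    private final Architecture architecture;
    private MultiLayerNetwork net;

    public ConfigurableClassifier(Architecture architecture) {
        this.architecture = architecture;
    }

    public void initializeNetwork(int vectorSize) {
        // vectorSize would come from the chosen input representation,
        // e.g. fullSetIterator.getInputRepresentation().getVectorSize().
        this.net = architecture.getNetwork(vectorSize);
    }

    public static void main(String[] args) {
        // The same classifier wired to two different architectures.
        new ConfigurableClassifier(new LSTMArchitecture()).initializeNetwork(300);
        new ConfigurableClassifier(new BiLSTMArchitecture()).initializeNetwork(300);
    }
}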
