diff --git a/src/main/java/at/medunigraz/imi/bst/n2c2/nn/BiLSTMCharacterTrigramClassifier.java b/src/main/java/at/medunigraz/imi/bst/n2c2/nn/BiLSTMCharacterTrigramClassifier.java
index 394db45..dd42c68 100644
--- a/src/main/java/at/medunigraz/imi/bst/n2c2/nn/BiLSTMCharacterTrigramClassifier.java
+++ b/src/main/java/at/medunigraz/imi/bst/n2c2/nn/BiLSTMCharacterTrigramClassifier.java
@@ -3,28 +3,12 @@
 import java.io.IOException;
 import java.util.Properties;
 
-import at.medunigraz.imi.bst.n2c2.model.Criterion;
+import at.medunigraz.imi.bst.n2c2.nn.architecture.BiLSTMArchitecture;
 import at.medunigraz.imi.bst.n2c2.nn.input.CharacterTrigram;
 import at.medunigraz.imi.bst.n2c2.nn.iterator.SentenceIterator;
 import org.apache.logging.log4j.LogManager;
 import org.apache.logging.log4j.Logger;
-import org.deeplearning4j.nn.conf.GradientNormalization;
-import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
-import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
-import org.deeplearning4j.nn.conf.WorkspaceMode;
-import org.deeplearning4j.nn.conf.layers.DenseLayer;
-import org.deeplearning4j.nn.conf.layers.GravesBidirectionalLSTM;
-import org.deeplearning4j.nn.conf.layers.GravesLSTM;
-import org.deeplearning4j.nn.conf.layers.RnnOutputLayer;
-import org.deeplearning4j.nn.conf.preprocessor.FeedForwardToRnnPreProcessor;
-import org.deeplearning4j.nn.conf.preprocessor.RnnToFeedForwardPreProcessor;
-import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
-import org.deeplearning4j.nn.weights.WeightInit;
-import org.deeplearning4j.optimize.listeners.ScoreIterationListener;
-import org.nd4j.linalg.activations.Activation;
 import org.nd4j.linalg.factory.Nd4j;
-import org.nd4j.linalg.learning.config.AdaGrad;
-import org.nd4j.linalg.lossfunctions.LossFunctions;
 
 /**
  * BI-LSTM classifier for n2c2 task 2018 refactored from dl4j examples.
@@ -75,58 +59,7 @@ protected void initializeNetworkBinaryMultiLabelDeep() {
 		fullSetIterator = new SentenceIterator(patientExamples,
 				new CharacterTrigram(SentenceIterator.createPatientLines(patientExamples)), BATCH_SIZE);
 
-		int nOutFF = 150;
-		int lstmLayerSize = 128;
-		double l2Regulization = 0.01;
-		double adaGradCore = 0.04;
-		double adaGradDense = 0.01;
-		double adaGradGraves = 0.008;
-
-		// seed for reproducibility
-		final int seed = 12345;
-		MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder().seed(seed)
-				.updater(AdaGrad.builder().learningRate(adaGradCore).build()).regularization(true).l2(l2Regulization)
-				.weightInit(WeightInit.XAVIER).gradientNormalization(GradientNormalization.ClipElementWiseAbsoluteValue)
-				.gradientNormalizationThreshold(1.0).trainingWorkspaceMode(WorkspaceMode.SINGLE)
-				.inferenceWorkspaceMode(WorkspaceMode.SINGLE).list()
-
-				.layer(0, new DenseLayer.Builder().activation(Activation.RELU).nIn(fullSetIterator.getInputRepresentation().getVectorSize()).nOut(nOutFF)
-						.weightInit(WeightInit.RELU).updater(AdaGrad.builder().learningRate(adaGradDense).build())
-						.gradientNormalization(GradientNormalization.ClipElementWiseAbsoluteValue)
-						.gradientNormalizationThreshold(10).build())
-
-				.layer(1, new DenseLayer.Builder().activation(Activation.RELU).nIn(nOutFF).nOut(nOutFF)
-						.weightInit(WeightInit.RELU).updater(AdaGrad.builder().learningRate(adaGradDense).build())
-						.gradientNormalization(GradientNormalization.ClipElementWiseAbsoluteValue)
-						.gradientNormalizationThreshold(10).build())
-
-				.layer(2, new DenseLayer.Builder().activation(Activation.RELU).nIn(nOutFF).nOut(nOutFF)
-						.weightInit(WeightInit.RELU).updater(AdaGrad.builder().learningRate(adaGradDense).build())
-						.gradientNormalization(GradientNormalization.ClipElementWiseAbsoluteValue)
-						.gradientNormalizationThreshold(10).build())
-
-				.layer(3,
-						new GravesBidirectionalLSTM.Builder().nIn(nOutFF).nOut(lstmLayerSize)
-								.updater(AdaGrad.builder().learningRate(adaGradGraves).build())
-								.activation(Activation.SOFTSIGN).build())
-
-				.layer(4,
-						new GravesLSTM.Builder().nIn(lstmLayerSize).nOut(lstmLayerSize)
-								.updater(AdaGrad.builder().learningRate(adaGradGraves).build())
-								.activation(Activation.SOFTSIGN).build())
-
-				.layer(5, new RnnOutputLayer.Builder().activation(Activation.SIGMOID)
-						.lossFunction(LossFunctions.LossFunction.XENT).nIn(lstmLayerSize).nOut(Criterion.classifiableValues().length).build())
-
-				.inputPreProcessor(0, new RnnToFeedForwardPreProcessor())
-				.inputPreProcessor(3, new FeedForwardToRnnPreProcessor()).pretrain(false).backprop(true).build();
-
-		// .backpropType(BackpropType.TruncatedBPTT).tBPTTForwardLength(tbpttLength).tBPTTBackwardLength(tbpttLength)
-
-		this.net = new MultiLayerNetwork(conf);
-		this.net.init();
-		this.net.setListeners(new ScoreIterationListener(1));
-
+		this.net = new BiLSTMArchitecture().getNetwork(fullSetIterator.getInputRepresentation().getVectorSize());
 	}
 
 	protected String getModelName() {
diff --git a/src/main/java/at/medunigraz/imi/bst/n2c2/nn/LSTMEmbeddingsClassifier.java b/src/main/java/at/medunigraz/imi/bst/n2c2/nn/LSTMEmbeddingsClassifier.java
index 938af30..1c080ca 100644
--- a/src/main/java/at/medunigraz/imi/bst/n2c2/nn/LSTMEmbeddingsClassifier.java
+++ b/src/main/java/at/medunigraz/imi/bst/n2c2/nn/LSTMEmbeddingsClassifier.java
@@ -4,23 +4,12 @@
 import java.io.IOException;
 import java.util.Properties;
 
+import at.medunigraz.imi.bst.n2c2.nn.architecture.LSTMArchitecture;
 import at.medunigraz.imi.bst.n2c2.nn.input.WordEmbedding;
 import at.medunigraz.imi.bst.n2c2.nn.iterator.TokenIterator;
 import org.apache.logging.log4j.LogManager;
 import org.apache.logging.log4j.Logger;
-import org.deeplearning4j.nn.conf.GradientNormalization;
-import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
-import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
-import org.deeplearning4j.nn.conf.WorkspaceMode;
-import org.deeplearning4j.nn.conf.layers.GravesLSTM;
-import org.deeplearning4j.nn.conf.layers.RnnOutputLayer;
-import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
-import org.deeplearning4j.nn.weights.WeightInit;
-import org.deeplearning4j.optimize.listeners.ScoreIterationListener;
-import org.nd4j.linalg.activations.Activation;
 import org.nd4j.linalg.factory.Nd4j;
-import org.nd4j.linalg.learning.config.Adam;
-import org.nd4j.linalg.lossfunctions.LossFunctions;
 
 /**
  * LSTM classifier for n2c2 task 2018 refactored from dl4j examples.
@@ -48,24 +37,7 @@ private void initializeNetworkBinaryMultiLabelDebug() {
 		fullSetIterator = new TokenIterator(patientExamples, new WordEmbedding(PRETRAINED_VECTORS), BATCH_SIZE);
 
 		// Set up network configuration
-		MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder().seed(0)
-				.updater(Adam.builder().learningRate(2e-2).build()).regularization(true).l2(1e-5).weightInit(WeightInit.XAVIER)
-				.gradientNormalization(GradientNormalization.ClipElementWiseAbsoluteValue)
-				.gradientNormalizationThreshold(1.0).trainingWorkspaceMode(WorkspaceMode.SEPARATE)
-				.inferenceWorkspaceMode(WorkspaceMode.SEPARATE) // https://deeplearning4j.org/workspaces
-				.list().layer(0, new GravesLSTM.Builder().nIn(fullSetIterator.getInputRepresentation().getVectorSize()).nOut(256).activation(Activation.TANH).build())
-				.layer(1,
-						new RnnOutputLayer.Builder().activation(Activation.SIGMOID)
-								.lossFunction(LossFunctions.LossFunction.XENT).nIn(256).nOut(13).build())
-				.pretrain(false).backprop(true).build();
-
-		// for truncated backpropagation over time
-		// .backpropType(BackpropType.TruncatedBPTT).tBPTTForwardLength(tbpttLength)
-		// .tBPTTBackwardLength(tbpttLength).pretrain(false).backprop(true).build();
-
-		this.net = new MultiLayerNetwork(conf);
-		this.net.init();
-		this.net.setListeners(new ScoreIterationListener(1));
+		this.net = new LSTMArchitecture().getNetwork(fullSetIterator.getInputRepresentation().getVectorSize());
 	}
 
 	@Override
diff --git a/src/main/java/at/medunigraz/imi/bst/n2c2/nn/architecture/Architecture.java b/src/main/java/at/medunigraz/imi/bst/n2c2/nn/architecture/Architecture.java
new file mode 100644
index 0000000..c67d939
--- /dev/null
+++ b/src/main/java/at/medunigraz/imi/bst/n2c2/nn/architecture/Architecture.java
@@ -0,0 +1,8 @@
+package at.medunigraz.imi.bst.n2c2.nn.architecture;
+
+import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
+
+public interface Architecture {
+
+	MultiLayerNetwork getNetwork(int nIn);
+}
diff --git a/src/main/java/at/medunigraz/imi/bst/n2c2/nn/architecture/BiLSTMArchitecture.java b/src/main/java/at/medunigraz/imi/bst/n2c2/nn/architecture/BiLSTMArchitecture.java
new file mode 100644
index 0000000..e12aaf4
--- /dev/null
+++ b/src/main/java/at/medunigraz/imi/bst/n2c2/nn/architecture/BiLSTMArchitecture.java
@@ -0,0 +1,77 @@
+package at.medunigraz.imi.bst.n2c2.nn.architecture;
+
+import at.medunigraz.imi.bst.n2c2.model.Criterion;
+import org.deeplearning4j.nn.conf.GradientNormalization;
+import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
+import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
+import org.deeplearning4j.nn.conf.WorkspaceMode;
+import org.deeplearning4j.nn.conf.layers.DenseLayer;
+import org.deeplearning4j.nn.conf.layers.GravesBidirectionalLSTM;
+import org.deeplearning4j.nn.conf.layers.GravesLSTM;
+import org.deeplearning4j.nn.conf.layers.RnnOutputLayer;
+import org.deeplearning4j.nn.conf.preprocessor.FeedForwardToRnnPreProcessor;
+import org.deeplearning4j.nn.conf.preprocessor.RnnToFeedForwardPreProcessor;
+import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
+import org.deeplearning4j.nn.weights.WeightInit;
+import org.deeplearning4j.optimize.listeners.ScoreIterationListener;
+import org.nd4j.linalg.activations.Activation;
+import org.nd4j.linalg.learning.config.AdaGrad;
+import org.nd4j.linalg.lossfunctions.LossFunctions;
+
+public class BiLSTMArchitecture implements Architecture {
+	@Override
+	public MultiLayerNetwork getNetwork(int nIn) {
+		int nOutFF = 150;
+		int lstmLayerSize = 128;
+		double l2Regulization = 0.01;
+		double adaGradCore = 0.04;
+		double adaGradDense = 0.01;
+		double adaGradGraves = 0.008;
+
+		// seed for reproducibility
+		final int seed = 12345;
+		MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder().seed(seed)
+				.updater(AdaGrad.builder().learningRate(adaGradCore).build()).regularization(true).l2(l2Regulization)
+				.weightInit(WeightInit.XAVIER).gradientNormalization(GradientNormalization.ClipElementWiseAbsoluteValue)
+				.gradientNormalizationThreshold(1.0).trainingWorkspaceMode(WorkspaceMode.SINGLE)
+				.inferenceWorkspaceMode(WorkspaceMode.SINGLE).list()
+
+				.layer(0, new DenseLayer.Builder().activation(Activation.RELU).nIn(nIn).nOut(nOutFF)
+						.weightInit(WeightInit.RELU).updater(AdaGrad.builder().learningRate(adaGradDense).build())
+						.gradientNormalization(GradientNormalization.ClipElementWiseAbsoluteValue)
+						.gradientNormalizationThreshold(10).build())
+
+				.layer(1, new DenseLayer.Builder().activation(Activation.RELU).nIn(nOutFF).nOut(nOutFF)
+						.weightInit(WeightInit.RELU).updater(AdaGrad.builder().learningRate(adaGradDense).build())
+						.gradientNormalization(GradientNormalization.ClipElementWiseAbsoluteValue)
+						.gradientNormalizationThreshold(10).build())
+
+				.layer(2, new DenseLayer.Builder().activation(Activation.RELU).nIn(nOutFF).nOut(nOutFF)
+						.weightInit(WeightInit.RELU).updater(AdaGrad.builder().learningRate(adaGradDense).build())
+						.gradientNormalization(GradientNormalization.ClipElementWiseAbsoluteValue)
+						.gradientNormalizationThreshold(10).build())
+
+				.layer(3,
+						new GravesBidirectionalLSTM.Builder().nIn(nOutFF).nOut(lstmLayerSize)
+								.updater(AdaGrad.builder().learningRate(adaGradGraves).build())
+								.activation(Activation.SOFTSIGN).build())
+
+				.layer(4,
+						new GravesLSTM.Builder().nIn(lstmLayerSize).nOut(lstmLayerSize)
+								.updater(AdaGrad.builder().learningRate(adaGradGraves).build())
+								.activation(Activation.SOFTSIGN).build())
+
+				.layer(5, new RnnOutputLayer.Builder().activation(Activation.SIGMOID)
+						.lossFunction(LossFunctions.LossFunction.XENT).nIn(lstmLayerSize).nOut(Criterion.classifiableValues().length).build())
+
+				.inputPreProcessor(0, new RnnToFeedForwardPreProcessor())
+				.inputPreProcessor(3, new FeedForwardToRnnPreProcessor()).pretrain(false).backprop(true).build();
+
+		// .backpropType(BackpropType.TruncatedBPTT).tBPTTForwardLength(tbpttLength).tBPTTBackwardLength(tbpttLength)
+
+		MultiLayerNetwork net = new MultiLayerNetwork(conf);
+		net.init();
+		net.setListeners(new ScoreIterationListener(1));
+		return net;
+	}
+}
diff --git a/src/main/java/at/medunigraz/imi/bst/n2c2/nn/architecture/LSTMArchitecture.java b/src/main/java/at/medunigraz/imi/bst/n2c2/nn/architecture/LSTMArchitecture.java
new file mode 100644
index 0000000..9c25760
--- /dev/null
+++ b/src/main/java/at/medunigraz/imi/bst/n2c2/nn/architecture/LSTMArchitecture.java
@@ -0,0 +1,39 @@
+package at.medunigraz.imi.bst.n2c2.nn.architecture;
+
+import org.deeplearning4j.nn.conf.GradientNormalization;
+import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
+import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
+import org.deeplearning4j.nn.conf.WorkspaceMode;
+import org.deeplearning4j.nn.conf.layers.GravesLSTM;
+import org.deeplearning4j.nn.conf.layers.RnnOutputLayer;
+import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
+import org.deeplearning4j.nn.weights.WeightInit;
+import org.deeplearning4j.optimize.listeners.ScoreIterationListener;
+import org.nd4j.linalg.activations.Activation;
+import org.nd4j.linalg.learning.config.Adam;
+import org.nd4j.linalg.lossfunctions.LossFunctions;
+
+public class LSTMArchitecture implements Architecture {
+	@Override
+	public MultiLayerNetwork getNetwork(int nIn) {
+		MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder().seed(0)
+				.updater(Adam.builder().learningRate(2e-2).build()).regularization(true).l2(1e-5).weightInit(WeightInit.XAVIER)
+				.gradientNormalization(GradientNormalization.ClipElementWiseAbsoluteValue)
+				.gradientNormalizationThreshold(1.0).trainingWorkspaceMode(WorkspaceMode.SEPARATE)
+				.inferenceWorkspaceMode(WorkspaceMode.SEPARATE) // https://deeplearning4j.org/workspaces
+				.list().layer(0, new GravesLSTM.Builder().nIn(nIn).nOut(256).activation(Activation.TANH).build())
+				.layer(1,
+						new RnnOutputLayer.Builder().activation(Activation.SIGMOID)
+								.lossFunction(LossFunctions.LossFunction.XENT).nIn(256).nOut(13).build())
+				.pretrain(false).backprop(true).build();
+
+		// for truncated backpropagation over time
+		// .backpropType(BackpropType.TruncatedBPTT).tBPTTForwardLength(tbpttLength)
+		// .tBPTTBackwardLength(tbpttLength).pretrain(false).backprop(true).build();
+
+		MultiLayerNetwork net = new MultiLayerNetwork(conf);
+		net.init();
+		net.setListeners(new ScoreIterationListener(1));
+		return net;
+	}
+}
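
Note on the extracted abstraction: after this patch, everything DL4J-specific about network construction (updaters, layers, preprocessors, listeners) lives behind Architecture.getNetwork(int nIn), while the classifiers keep ownership of the iterators and input representations. Below is a minimal, hypothetical sketch of how the interface is consumed; the class name ArchitectureDemo and the hard-coded vectorSize are illustrative assumptions, not part of the patch. The real call sites are the two classifiers above, which pass fullSetIterator.getInputRepresentation().getVectorSize().

// Hypothetical demo class, not part of this patch: shows how a caller
// obtains a configured, initialized network through the new interface.
package at.medunigraz.imi.bst.n2c2.nn.architecture;

import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;

public class ArchitectureDemo {

	public static void main(String[] args) {
		// Assumed input width; the classifiers derive this from
		// fullSetIterator.getInputRepresentation().getVectorSize().
		int vectorSize = 300;

		// Swapping topologies is now a one-line change at the call site.
		Architecture architecture = new LSTMArchitecture();
		MultiLayerNetwork net = architecture.getNetwork(vectorSize);

		System.out.println("Initialized network with " + net.numParams() + " parameters");
	}
}

A new topology can thus be added by implementing Architecture in a new class, without touching the classifier code.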