forked from bst-mug/n2c2
Extract architecture code into own interface
This allows other combinations as required by bst-mug#107 and bst-mug#110.
Showing 5 changed files with 128 additions and 99 deletions.
src/main/java/at/medunigraz/imi/bst/n2c2/nn/architecture/Architecture.java (8 additions, 0 deletions)

@@ -0,0 +1,8 @@
package at.medunigraz.imi.bst.n2c2.nn.architecture;

import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;

public interface Architecture {

    MultiLayerNetwork getNetwork(int nIn);
}
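With network construction behind this interface, training and evaluation code can be written once against Architecture and handed any implementation, which is the flexibility the commit message refers to. A minimal sketch of such a combination point, using the two implementations added below; the demo class and the useBidirectional flag are hypothetical, not part of this commit:

package at.medunigraz.imi.bst.n2c2.nn.architecture;

import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;

// Hypothetical caller: implementations can be swapped here without
// touching the training code that consumes the returned network.
public class ArchitectureDemo {
    public static MultiLayerNetwork build(int nIn, boolean useBidirectional) {
        Architecture arch = useBidirectional ? new BiLSTMArchitecture() : new LSTMArchitecture();
        return arch.getNetwork(nIn);
    }
}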
src/main/java/at/medunigraz/imi/bst/n2c2/nn/architecture/BiLSTMArchitecture.java (77 additions, 0 deletions)

@@ -0,0 +1,77 @@
package at.medunigraz.imi.bst.n2c2.nn.architecture;

import at.medunigraz.imi.bst.n2c2.model.Criterion;
import org.deeplearning4j.nn.conf.GradientNormalization;
import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.WorkspaceMode;
import org.deeplearning4j.nn.conf.layers.DenseLayer;
import org.deeplearning4j.nn.conf.layers.GravesBidirectionalLSTM;
import org.deeplearning4j.nn.conf.layers.GravesLSTM;
import org.deeplearning4j.nn.conf.layers.RnnOutputLayer;
import org.deeplearning4j.nn.conf.preprocessor.FeedForwardToRnnPreProcessor;
import org.deeplearning4j.nn.conf.preprocessor.RnnToFeedForwardPreProcessor;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.deeplearning4j.nn.weights.WeightInit;
import org.deeplearning4j.optimize.listeners.ScoreIterationListener;
import org.nd4j.linalg.activations.Activation;
import org.nd4j.linalg.learning.config.AdaGrad;
import org.nd4j.linalg.lossfunctions.LossFunctions;

public class BiLSTMArchitecture implements Architecture {
    @Override
    public MultiLayerNetwork getNetwork(int nIn) {
        int nOutFF = 150;
        int lstmLayerSize = 128;
        double l2Regulization = 0.01;
        double adaGradCore = 0.04;
        double adaGradDense = 0.01;
        double adaGradGraves = 0.008;

        // seed for reproducibility
        final int seed = 12345;
        MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder().seed(seed)
                .updater(AdaGrad.builder().learningRate(adaGradCore).build()).regularization(true).l2(l2Regulization)
                .weightInit(WeightInit.XAVIER).gradientNormalization(GradientNormalization.ClipElementWiseAbsoluteValue)
                .gradientNormalizationThreshold(1.0).trainingWorkspaceMode(WorkspaceMode.SINGLE)
                .inferenceWorkspaceMode(WorkspaceMode.SINGLE).list()

                // three ReLU dense layers, applied per time step via the preprocessors below
                .layer(0, new DenseLayer.Builder().activation(Activation.RELU).nIn(nIn).nOut(nOutFF)
                        .weightInit(WeightInit.RELU).updater(AdaGrad.builder().learningRate(adaGradDense).build())
                        .gradientNormalization(GradientNormalization.ClipElementWiseAbsoluteValue)
                        .gradientNormalizationThreshold(10).build())

                .layer(1, new DenseLayer.Builder().activation(Activation.RELU).nIn(nOutFF).nOut(nOutFF)
                        .weightInit(WeightInit.RELU).updater(AdaGrad.builder().learningRate(adaGradDense).build())
                        .gradientNormalization(GradientNormalization.ClipElementWiseAbsoluteValue)
                        .gradientNormalizationThreshold(10).build())

                .layer(2, new DenseLayer.Builder().activation(Activation.RELU).nIn(nOutFF).nOut(nOutFF)
                        .weightInit(WeightInit.RELU).updater(AdaGrad.builder().learningRate(adaGradDense).build())
                        .gradientNormalization(GradientNormalization.ClipElementWiseAbsoluteValue)
                        .gradientNormalizationThreshold(10).build())

                // bidirectional LSTM over the dense features, followed by a unidirectional LSTM
                .layer(3,
                        new GravesBidirectionalLSTM.Builder().nIn(nOutFF).nOut(lstmLayerSize)
                                .updater(AdaGrad.builder().learningRate(adaGradGraves).build())
                                .activation(Activation.SOFTSIGN).build())

                .layer(4,
                        new GravesLSTM.Builder().nIn(lstmLayerSize).nOut(lstmLayerSize)
                                .updater(AdaGrad.builder().learningRate(adaGradGraves).build())
                                .activation(Activation.SOFTSIGN).build())

                // one sigmoid output per classifiable criterion, binary cross-entropy loss
                .layer(5, new RnnOutputLayer.Builder().activation(Activation.SIGMOID)
                        .lossFunction(LossFunctions.LossFunction.XENT).nIn(lstmLayerSize)
                        .nOut(Criterion.classifiableValues().length).build())

                .inputPreProcessor(0, new RnnToFeedForwardPreProcessor())
                .inputPreProcessor(3, new FeedForwardToRnnPreProcessor()).pretrain(false).backprop(true).build();

        // for truncated backpropagation over time:
        // .backpropType(BackpropType.TruncatedBPTT).tBPTTForwardLength(tbpttLength).tBPTTBackwardLength(tbpttLength)

        MultiLayerNetwork net = new MultiLayerNetwork(conf);
        net.init();
        net.setListeners(new ScoreIterationListener(1));
        return net;
    }
}
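Because BiLSTMArchitecture mixes recurrent and dense layers, the two preprocessors dictate the data format at the boundaries. A minimal forward-pass sketch, assuming the DL4J 0.9.x API used in this diff; the demo class and all dimensions are hypothetical:

package at.medunigraz.imi.bst.n2c2.nn.architecture;

import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;

public class BiLSTMShapeDemo {
    public static void main(String[] args) {
        int nIn = 300;       // hypothetical input feature vector size
        int timeSteps = 40;  // hypothetical sequence length
        MultiLayerNetwork net = new BiLSTMArchitecture().getNetwork(nIn);

        // RNN-format input [miniBatch, nIn, timeSteps]: the RnnToFeedForwardPreProcessor
        // at layer 0 flattens time steps for the dense layers, and the
        // FeedForwardToRnnPreProcessor at layer 3 restores the sequence shape.
        INDArray features = Nd4j.zeros(1, nIn, timeSteps);
        INDArray out = net.output(features); // [1, nOut, timeSteps] sigmoid scores per criterion
        System.out.println(java.util.Arrays.toString(out.shape()));
    }
}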
src/main/java/at/medunigraz/imi/bst/n2c2/nn/architecture/LSTMArchitecture.java (39 additions, 0 deletions)

@@ -0,0 +1,39 @@
package at.medunigraz.imi.bst.n2c2.nn.architecture;

import org.deeplearning4j.nn.conf.GradientNormalization;
import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.WorkspaceMode;
import org.deeplearning4j.nn.conf.layers.GravesLSTM;
import org.deeplearning4j.nn.conf.layers.RnnOutputLayer;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.deeplearning4j.nn.weights.WeightInit;
import org.deeplearning4j.optimize.listeners.ScoreIterationListener;
import org.nd4j.linalg.activations.Activation;
import org.nd4j.linalg.learning.config.Adam;
import org.nd4j.linalg.lossfunctions.LossFunctions;

public class LSTMArchitecture implements Architecture {
    @Override
    public MultiLayerNetwork getNetwork(int nIn) {
        MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder().seed(0)
                .updater(Adam.builder().learningRate(2e-2).build()).regularization(true).l2(1e-5)
                .weightInit(WeightInit.XAVIER)
                .gradientNormalization(GradientNormalization.ClipElementWiseAbsoluteValue)
                .gradientNormalizationThreshold(1.0).trainingWorkspaceMode(WorkspaceMode.SEPARATE)
                .inferenceWorkspaceMode(WorkspaceMode.SEPARATE) // https://deeplearning4j.org/workspaces
                .list().layer(0, new GravesLSTM.Builder().nIn(nIn).nOut(256).activation(Activation.TANH).build())
                // 13 outputs: hardcoded criteria count (cf. Criterion.classifiableValues() above)
                .layer(1,
                        new RnnOutputLayer.Builder().activation(Activation.SIGMOID)
                                .lossFunction(LossFunctions.LossFunction.XENT).nIn(256).nOut(13).build())
                .pretrain(false).backprop(true).build();

        // for truncated backpropagation over time
        // .backpropType(BackpropType.TruncatedBPTT).tBPTTForwardLength(tbpttLength)
        // .tBPTTBackwardLength(tbpttLength).pretrain(false).backprop(true).build();

        MultiLayerNetwork net = new MultiLayerNetwork(conf);
        net.init();
        net.setListeners(new ScoreIterationListener(1));
        return net;
    }
}
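Both implementations keep the truncated-BPTT configuration commented out. If full backpropagation through long documents proved too costly, it could be enabled roughly as follows; this is a sketch mirroring LSTMArchitecture's layer stack, where the class name and the tbpttLength parameter are hypothetical, not part of this commit:

package at.medunigraz.imi.bst.n2c2.nn.architecture;

import org.deeplearning4j.nn.conf.BackpropType;
import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.layers.GravesLSTM;
import org.deeplearning4j.nn.conf.layers.RnnOutputLayer;
import org.nd4j.linalg.activations.Activation;
import org.nd4j.linalg.lossfunctions.LossFunctions;

public class TbpttSketch {
    public static MultiLayerConfiguration config(int nIn, int tbpttLength) {
        return new NeuralNetConfiguration.Builder().seed(0).list()
                .layer(0, new GravesLSTM.Builder().nIn(nIn).nOut(256).activation(Activation.TANH).build())
                .layer(1, new RnnOutputLayer.Builder().activation(Activation.SIGMOID)
                        .lossFunction(LossFunctions.LossFunction.XENT).nIn(256).nOut(13).build())
                // gradients flow back only tbpttLength steps instead of the full sequence
                .backpropType(BackpropType.TruncatedBPTT)
                .tBPTTForwardLength(tbpttLength).tBPTTBackwardLength(tbpttLength)
                .pretrain(false).backprop(true)
                .build();
    }
}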