Skip to content

Commit

Permalink
Refactor code around new interface InputRepresentation
Browse files Browse the repository at this point in the history
Introduce the new interface `InputRepresentation` to separate logic of input representation (e.g. word embeddings, character trigrams) from iterators and classifiers. This allows new combinations required as part of bst-mug#107 and bst-mug#110.

Move data-dependent methods such as `initializeTruncateLength` and `loadFeaturesForNarrative` to the iterators.

Remove public and duplicate attributes to reduce complexity.
  • Loading branch information
michelole committed Jun 3, 2019
1 parent 5cc6602 commit d524666
Show file tree
Hide file tree
Showing 10 changed files with 460 additions and 357 deletions.
113 changes: 11 additions & 102 deletions src/main/java/at/medunigraz/imi/bst/n2c2/nn/BILSTMC3GClassifier.java
Original file line number Diff line number Diff line change
@@ -1,18 +1,10 @@
package at.medunigraz.imi.bst.n2c2.nn;

import java.io.File;
import java.io.FileInputStream;
import java.io.FileNotFoundException;
import java.io.FileOutputStream;
import java.io.IOException;
import java.io.ObjectInputStream;
import java.io.ObjectOutputStream;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Properties;

import at.medunigraz.imi.bst.n2c2.model.Criterion;
import at.medunigraz.imi.bst.n2c2.nn.input.CharacterTrigram;
import at.medunigraz.imi.bst.n2c2.nn.iterator.NGramIterator;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
Expand All @@ -30,10 +22,7 @@
import org.deeplearning4j.nn.weights.WeightInit;
import org.deeplearning4j.optimize.listeners.ScoreIterationListener;
import org.nd4j.linalg.activations.Activation;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.indexing.INDArrayIndex;
import org.nd4j.linalg.indexing.NDArrayIndex;
import org.nd4j.linalg.learning.config.AdaGrad;
import org.nd4j.linalg.lossfunctions.LossFunctions;

Expand All @@ -47,40 +36,22 @@ public class BILSTMC3GClassifier extends BaseNNClassifier {

private static final Logger LOG = LogManager.getLogger();

@Override
public void initializeNetworkFromFile(String pathToModel) {

// settings for memory management:
// https://deeplearning4j.org/workspaces

Nd4j.getMemoryManager().setAutoGcWindow(10000);
// Nd4j.getMemoryManager().togglePeriodicGc(false);

// instantiating generator
fullSetIterator = new NGramIterator();

// TODO move to iterator.
Properties prop = null;
try {
// read char 3-grams and index
FileInputStream fis = new FileInputStream(new File(pathToModel, "characterNGram_3"));
ObjectInputStream ois = new ObjectInputStream(fis);
ArrayList<String> characterNGram_3 = (ArrayList<String>) ois.readObject();

((NGramIterator)fullSetIterator).characterNGram_3 = characterNGram_3;
((NGramIterator)fullSetIterator).vectorSize = characterNGram_3.size();
this.vectorSize = ((NGramIterator)fullSetIterator).vectorSize;

// read char 3-grams index
fis = new FileInputStream(new File(pathToModel, "char3GramToIdxMap"));
ois = new ObjectInputStream(fis);
Map<String, Integer> char3GramToIdxMap_0 = (HashMap<String, Integer>) ois.readObject();
((NGramIterator)fullSetIterator).char3GramToIdxMap = char3GramToIdxMap_0;

Nd4j.getRandom().setSeed(12345);

prop = loadProperties(pathToModel);
final int truncateLength = Integer.parseInt(prop.getProperty(getModelName() + ".truncateLength"));
fullSetIterator = new NGramIterator(new CharacterTrigram(), truncateLength, BATCH_SIZE);
} catch (IOException e) {
e.printStackTrace();
} catch (ClassNotFoundException e) {
// TODO Auto-generated catch block
e.printStackTrace();
throw new RuntimeException(e);
}

super.initializeNetworkFromFile(pathToModel);
Expand All @@ -102,14 +73,7 @@ protected void initializeNetworkBinaryMultiLabelDeep() {
Nd4j.getMemoryManager().setAutoGcWindow(10000);
// Nd4j.getMemoryManager().togglePeriodicGc(false);

try {
fullSetIterator = new NGramIterator(patientExamples, BATCH_SIZE);
} catch (IOException e) {
// TODO Auto-generated catch block
e.printStackTrace();
}
vectorSize = ((NGramIterator)fullSetIterator).vectorSize;
truncateLength = ((NGramIterator)fullSetIterator).maxSentences;
fullSetIterator = new NGramIterator(patientExamples, new CharacterTrigram(NGramIterator.createPatientLines(patientExamples)), BATCH_SIZE);

int nOutFF = 150;
int lstmLayerSize = 128;
Expand All @@ -118,8 +82,6 @@ protected void initializeNetworkBinaryMultiLabelDeep() {
double adaGradDense = 0.01;
double adaGradGraves = 0.008;

saveParams();

// seed for reproducibility
final int seed = 12345;
MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder().seed(seed)
Expand All @@ -128,7 +90,7 @@ protected void initializeNetworkBinaryMultiLabelDeep() {
.gradientNormalizationThreshold(1.0).trainingWorkspaceMode(WorkspaceMode.SINGLE)
.inferenceWorkspaceMode(WorkspaceMode.SINGLE).list()

.layer(0, new DenseLayer.Builder().activation(Activation.RELU).nIn(vectorSize).nOut(nOutFF)
.layer(0, new DenseLayer.Builder().activation(Activation.RELU).nIn(fullSetIterator.getInputRepresentation().getVectorSize()).nOut(nOutFF)
.weightInit(WeightInit.RELU).updater(AdaGrad.builder().learningRate(adaGradDense).build())
.gradientNormalization(GradientNormalization.ClipElementWiseAbsoluteValue)
.gradientNormalizationThreshold(10).build())
Expand Down Expand Up @@ -170,57 +132,4 @@ protected void initializeNetworkBinaryMultiLabelDeep() {
protected String getModelName() {
return "BILSTMC3G_MBL";
}

protected void saveParams() {
File root = getModelDirectory(patientExamples);

try {
// writing our character n-grams
FileOutputStream fos = new FileOutputStream(new File(root, "characterNGram_3"));
ObjectOutputStream oos = new ObjectOutputStream(fos);
oos.writeObject(((NGramIterator)fullSetIterator).characterNGram_3);
oos.flush();
oos.close();
fos.close();

// writing our character n-grams
fos = new FileOutputStream(new File(root, "char3GramToIdxMap"));
oos = new ObjectOutputStream(fos);
oos.writeObject(((NGramIterator)fullSetIterator).char3GramToIdxMap);
oos.flush();
oos.close();
fos.close();
} catch (FileNotFoundException e) {
e.printStackTrace();
} catch (IOException e) {
e.printStackTrace();
}
}

/**
* Load features from narrative.
*
* @param reviewContents
* Narrative content.
* @param maxLength
* Maximum length of token series length.
* @return Time series feature representation of narrative.
*/
protected INDArray loadFeaturesForNarrative(String reviewContents, int maxLength) {

List<String> sentences = DataUtilities.getSentences(reviewContents);

int outputLength = Math.min(maxLength, sentences.size());
INDArray features = Nd4j.create(1, vectorSize, outputLength);

for (int j = 0; j < sentences.size() && j < outputLength; j++) {
String sentence = sentences.get(j);
INDArray vector = ((NGramIterator)fullSetIterator).getChar3GramVectorToSentence(sentence);
features.put(new INDArrayIndex[] { NDArrayIndex.point(0), NDArrayIndex.all(), NDArrayIndex.point(j) },
vector);
}
return features;
}


}
24 changes: 11 additions & 13 deletions src/main/java/at/medunigraz/imi/bst/n2c2/nn/BaseNNClassifier.java
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import at.medunigraz.imi.bst.n2c2.model.Criterion;
import at.medunigraz.imi.bst.n2c2.model.Eligibility;
import at.medunigraz.imi.bst.n2c2.model.Patient;
import at.medunigraz.imi.bst.n2c2.nn.iterator.BaseNNIterator;
import at.medunigraz.imi.bst.n2c2.util.DatasetUtil;
import org.apache.commons.io.FileUtils;
import org.apache.logging.log4j.LogManager;
Expand All @@ -19,7 +20,6 @@
import org.deeplearning4j.util.ModelSerializer;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.dataset.DataSet;
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
import org.nd4j.linalg.indexing.NDArrayIndex;

import java.io.*;
Expand All @@ -35,18 +35,13 @@ public abstract class BaseNNClassifier extends PatientBasedClassifier {
// size of mini-batch for training
protected static final int BATCH_SIZE = 10;

// specifies time series length
protected int truncateLength = 64;

public int vectorSize;

// training data
protected List<Patient> patientExamples;

// multi layer network
protected MultiLayerNetwork net;

public DataSetIterator fullSetIterator;
public BaseNNIterator fullSetIterator;

/**
* Training for binary multi label classification.
Expand Down Expand Up @@ -135,7 +130,7 @@ protected void saveModel(int epoch) {
props.setProperty(getModelName() + ".bestModelEpoch", new Integer(epoch).toString());

// TODO truncateLength does not change each epoch, this could be persisted in saveParams()
props.setProperty(getModelName() + ".truncateLength", new Integer(truncateLength).toString());
props.setProperty(getModelName() + ".truncateLength", new Integer(fullSetIterator.getTruncateLength()).toString());

File f = new File(root, getModelName() + ".properties");
OutputStream out = new FileOutputStream(f);
Expand All @@ -149,6 +144,8 @@ protected void saveModel(int epoch) {
} catch (IOException e) {
e.printStackTrace();
}

fullSetIterator.save(root);
}

protected abstract String getModelName();
Expand All @@ -162,7 +159,7 @@ protected void saveModel(int epoch) {
public Map<Criterion, Double> predict(Patient p) {
String patientNarrative = p.getText();

INDArray features = loadFeaturesForNarrative(patientNarrative, this.truncateLength);
INDArray features = fullSetIterator.loadFeaturesForNarrative(patientNarrative, fullSetIterator.getTruncateLength());
INDArray networkOutput = net.output(features);

int timeSeriesLength = networkOutput.size(2);
Expand Down Expand Up @@ -212,7 +209,7 @@ public void train(List<Patient> examples) {
// initializeMonitoring();

LOG.info("Minibatchsize :\t" + BATCH_SIZE);
LOG.info("Truncate length:\t" + truncateLength);
LOG.info("Truncate length:\t" + fullSetIterator.getTruncateLength());

trainFullSetBML();
}
Expand All @@ -236,19 +233,20 @@ public boolean isTrained(List<Patient> patients) {
return new File(getModelPath(patients), getModelName() + ".properties").exists();
}

protected abstract INDArray loadFeaturesForNarrative(String reviewContents, int maxLength);

public void initializeNetworkFromFile(String pathToModel) {
try {
Properties prop = loadProperties(pathToModel);
final int bestEpoch = Integer.parseInt(prop.getProperty(getModelName() + ".bestModelEpoch"));
this.truncateLength = Integer.parseInt(prop.getProperty(getModelName() + ".truncateLength"));

File networkFile = new File(pathToModel, getModelName() + "_" + bestEpoch + ".zip");
this.net = ModelSerializer.restoreMultiLayerNetwork(networkFile);


} catch (IOException e) {
e.printStackTrace();
}

fullSetIterator.load(new File(pathToModel));
}

/**
Expand Down
Loading

0 comments on commit d524666

Please sign in to comment.