Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refactor NN code #111

Merged
merged 16 commits into from
Jun 4, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,11 @@

import at.medunigraz.imi.bst.n2c2.classifier.Classifier;
import at.medunigraz.imi.bst.n2c2.model.Criterion;
import at.medunigraz.imi.bst.n2c2.nn.LSTMClassifier;
import at.medunigraz.imi.bst.n2c2.nn.LSTMEmbeddingsClassifier;

public class NNClassifierFactory implements ClassifierFactory {

private static final Classifier classifier = new LSTMClassifier();
private static final Classifier classifier = new LSTMEmbeddingsClassifier();

@Override
public Classifier getClassifier(Criterion criterion) {
Expand Down
225 changes: 0 additions & 225 deletions src/main/java/at/medunigraz/imi/bst/n2c2/nn/BILSTMC3GClassifier.java

This file was deleted.

40 changes: 18 additions & 22 deletions src/main/java/at/medunigraz/imi/bst/n2c2/nn/BaseNNClassifier.java
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import at.medunigraz.imi.bst.n2c2.model.Criterion;
import at.medunigraz.imi.bst.n2c2.model.Eligibility;
import at.medunigraz.imi.bst.n2c2.model.Patient;
import at.medunigraz.imi.bst.n2c2.nn.iterator.BaseNNIterator;
import at.medunigraz.imi.bst.n2c2.util.DatasetUtil;
import org.apache.commons.io.FileUtils;
import org.apache.logging.log4j.LogManager;
Expand All @@ -19,7 +20,6 @@
import org.deeplearning4j.util.ModelSerializer;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.dataset.DataSet;
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
import org.nd4j.linalg.indexing.NDArrayIndex;

import java.io.*;
Expand All @@ -35,18 +35,13 @@ public abstract class BaseNNClassifier extends PatientBasedClassifier {
// size of mini-batch for training
protected static final int BATCH_SIZE = 10;

// specifies time series length
protected int truncateLength = 64;

public int vectorSize;

// training data
protected List<Patient> patientExamples;

// multi layer network
protected MultiLayerNetwork net;

public DataSetIterator fullSetIterator;
public BaseNNIterator fullSetIterator;

/**
* Training for binary multi-label classification.
Expand Down Expand Up @@ -135,7 +130,7 @@ protected void saveModel(int epoch) {
props.setProperty(getModelName() + ".bestModelEpoch", new Integer(epoch).toString());

// TODO truncateLength does not change each epoch, this could be persisted in saveParams()
props.setProperty(getModelName() + ".truncateLength", new Integer(truncateLength).toString());
props.setProperty(getModelName() + ".truncateLength", new Integer(fullSetIterator.getTruncateLength()).toString());

File f = new File(root, getModelName() + ".properties");
OutputStream out = new FileOutputStream(f);
Expand All @@ -149,6 +144,8 @@ protected void saveModel(int epoch) {
} catch (IOException e) {
e.printStackTrace();
}

fullSetIterator.save(root);
}

protected abstract String getModelName();
Expand All @@ -162,7 +159,7 @@ protected void saveModel(int epoch) {
public Map<Criterion, Double> predict(Patient p) {
String patientNarrative = p.getText();

INDArray features = loadFeaturesForNarrative(patientNarrative, this.truncateLength);
INDArray features = fullSetIterator.loadFeaturesForNarrative(patientNarrative, fullSetIterator.getTruncateLength());
INDArray networkOutput = net.output(features);

int timeSeriesLength = networkOutput.size(2);
Expand Down Expand Up @@ -203,7 +200,11 @@ public Eligibility predict(Patient p, Criterion c) {
@Override
public void train(List<Patient> examples) {
if (isTrained(examples)) {
initializeNetworkFromFile(getModelPath(examples));
try {
initializeNetworkFromFile(getModelPath(examples));
} catch (IOException e) {
throw new RuntimeException(e);
}
}
else {
this.patientExamples = examples;
Expand All @@ -212,7 +213,7 @@ public void train(List<Patient> examples) {
// initializeMonitoring();

LOG.info("Minibatchsize :\t" + BATCH_SIZE);
LOG.info("Truncate length:\t" + truncateLength);
LOG.info("Truncate length:\t" + fullSetIterator.getTruncateLength());

trainFullSetBML();
}
Expand All @@ -236,19 +237,14 @@ public boolean isTrained(List<Patient> patients) {
return new File(getModelPath(patients), getModelName() + ".properties").exists();
}

protected abstract INDArray loadFeaturesForNarrative(String reviewContents, int maxLength);
public void initializeNetworkFromFile(String pathToModel) throws IOException {
Properties prop = loadProperties(pathToModel);
final int bestEpoch = Integer.parseInt(prop.getProperty(getModelName() + ".bestModelEpoch"));

public void initializeNetworkFromFile(String pathToModel) {
try {
Properties prop = loadProperties(pathToModel);
final int bestEpoch = Integer.parseInt(prop.getProperty(getModelName() + ".bestModelEpoch"));
this.truncateLength = Integer.parseInt(prop.getProperty(getModelName() + ".truncateLength"));
File networkFile = new File(pathToModel, getModelName() + "_" + bestEpoch + ".zip");
this.net = ModelSerializer.restoreMultiLayerNetwork(networkFile);

File networkFile = new File(pathToModel, getModelName() + "_" + bestEpoch + ".zip");
this.net = ModelSerializer.restoreMultiLayerNetwork(networkFile);
} catch (IOException e) {
e.printStackTrace();
}
fullSetIterator.load(new File(pathToModel));
}

/**
Expand Down
Loading