Commit 65ba2a4
Rename iterators and classifiers for better readability
michelole committed Jun 3, 2019
1 parent d524666 commit 65ba2a4
Showing 9 changed files with 32 additions and 34 deletions.
@@ -2,11 +2,11 @@

import at.medunigraz.imi.bst.n2c2.classifier.Classifier;
import at.medunigraz.imi.bst.n2c2.model.Criterion;
import at.medunigraz.imi.bst.n2c2.nn.LSTMClassifier;
import at.medunigraz.imi.bst.n2c2.nn.LSTMEmbeddingsClassifier;

public class NNClassifierFactory implements ClassifierFactory {

private static final Classifier classifier = new LSTMClassifier();
private static final Classifier classifier = new LSTMEmbeddingsClassifier();

@Override
public Classifier getClassifier(Criterion criterion) {
@@ -5,7 +5,7 @@

import at.medunigraz.imi.bst.n2c2.model.Criterion;
import at.medunigraz.imi.bst.n2c2.nn.input.CharacterTrigram;
import at.medunigraz.imi.bst.n2c2.nn.iterator.NGramIterator;
import at.medunigraz.imi.bst.n2c2.nn.iterator.SentenceIterator;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.deeplearning4j.nn.conf.GradientNormalization;
@@ -32,7 +32,7 @@
* @author Markus
*
*/
public class BILSTMC3GClassifier extends BaseNNClassifier {
public class BiLSTMCharacterTrigramClassifier extends BaseNNClassifier {

private static final Logger LOG = LogManager.getLogger();

@@ -49,7 +49,7 @@ public void initializeNetworkFromFile(String pathToModel) {
try {
prop = loadProperties(pathToModel);
final int truncateLength = Integer.parseInt(prop.getProperty(getModelName() + ".truncateLength"));
fullSetIterator = new NGramIterator(new CharacterTrigram(), truncateLength, BATCH_SIZE);
fullSetIterator = new SentenceIterator(new CharacterTrigram(), truncateLength, BATCH_SIZE);
} catch (IOException e) {
throw new RuntimeException(e);
}
@@ -73,7 +73,7 @@ protected void initializeNetworkBinaryMultiLabelDeep() {
Nd4j.getMemoryManager().setAutoGcWindow(10000);
// Nd4j.getMemoryManager().togglePeriodicGc(false);

fullSetIterator = new NGramIterator(patientExamples, new CharacterTrigram(NGramIterator.createPatientLines(patientExamples)), BATCH_SIZE);
fullSetIterator = new SentenceIterator(patientExamples, new CharacterTrigram(SentenceIterator.createPatientLines(patientExamples)), BATCH_SIZE);

int nOutFF = 150;
int lstmLayerSize = 128;
@@ -5,7 +5,7 @@
import java.util.Properties;

import at.medunigraz.imi.bst.n2c2.nn.input.WordEmbedding;
import at.medunigraz.imi.bst.n2c2.nn.iterator.N2c2PatientIteratorBML;
import at.medunigraz.imi.bst.n2c2.nn.iterator.TokenIterator;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.deeplearning4j.nn.conf.GradientNormalization;
@@ -22,18 +22,16 @@
import org.nd4j.linalg.learning.config.Adam;
import org.nd4j.linalg.lossfunctions.LossFunctions;

import at.medunigraz.imi.bst.n2c2.model.Patient;

/**
* LSTM classifier for n2c2 task 2018 refactored from dl4j examples.
*
* @author Markus
*
*/
public class LSTMClassifier extends BaseNNClassifier {
public class LSTMEmbeddingsClassifier extends BaseNNClassifier {

// location of precalculated vectors
private static final File PRETRAINED_VECTORS = new File(LSTMClassifier.class.getClassLoader().getResource("vectors.vec").getFile());
private static final File PRETRAINED_VECTORS = new File(LSTMEmbeddingsClassifier.class.getClassLoader().getResource("vectors.vec").getFile());

// logging
private static final Logger LOG = LogManager.getLogger();
@@ -47,7 +45,7 @@ private void initializeNetworkBinaryMultiLabelDebug() {

Nd4j.getMemoryManager().setAutoGcWindow(10000); // https://deeplearning4j.org/workspaces

fullSetIterator = new N2c2PatientIteratorBML(patientExamples, new WordEmbedding(PRETRAINED_VECTORS), BATCH_SIZE);
fullSetIterator = new TokenIterator(patientExamples, new WordEmbedding(PRETRAINED_VECTORS), BATCH_SIZE);

// Set up network configuration
MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder().seed(0)
@@ -81,7 +79,7 @@ public void initializeNetworkFromFile(String pathToModel) {
try {
prop = loadProperties(pathToModel);
final int truncateLength = Integer.parseInt(prop.getProperty(getModelName() + ".truncateLength"));
fullSetIterator = new N2c2PatientIteratorBML(new WordEmbedding(PRETRAINED_VECTORS), truncateLength, BATCH_SIZE);
fullSetIterator = new TokenIterator(new WordEmbedding(PRETRAINED_VECTORS), truncateLength, BATCH_SIZE);
} catch (IOException e) {
throw new RuntimeException(e);
}
@@ -16,11 +16,11 @@
import at.medunigraz.imi.bst.n2c2.model.Patient;

/**
* A character 3-gram DataSetIterator.
* A sentence iterator.
*
* @author Markus
*/
public class NGramIterator extends BaseNNIterator {
public class SentenceIterator extends BaseNNIterator {

private static final long serialVersionUID = 1L;

@@ -34,7 +34,7 @@ public class NGramIterator extends BaseNNIterator {
* @param batchSize
* Minibatch size.
*/
public NGramIterator(List<Patient> patients, InputRepresentation inputRepresentation, int batchSize) {
public SentenceIterator(List<Patient> patients, InputRepresentation inputRepresentation, int batchSize) {
super(inputRepresentation);

this.patients = patients;
@@ -50,7 +50,7 @@ public NGramIterator(List<Patient> patients, InputRepresentation inputRepresenta
* @param truncateLength
* @param batchSize
*/
public NGramIterator(InputRepresentation inputRepresentation, int truncateLength, int batchSize) {
public SentenceIterator(InputRepresentation inputRepresentation, int truncateLength, int batchSize) {
super(inputRepresentation);

this.truncateLength = truncateLength;
@@ -22,7 +22,7 @@
* @author Markus
*
*/
public class N2c2PatientIteratorBML extends BaseNNIterator {
public class TokenIterator extends BaseNNIterator {
private static final Logger LOG = LogManager.getLogger();

private static final long serialVersionUID = 1L;
@@ -37,7 +37,7 @@ public class N2c2PatientIteratorBML extends BaseNNIterator {
* @param batchSize
* Mini batch size use for processing.
*/
public N2c2PatientIteratorBML(List<Patient> patients, InputRepresentation inputRepresentation, int batchSize) {
public TokenIterator(List<Patient> patients, InputRepresentation inputRepresentation, int batchSize) {
super(inputRepresentation);

this.patients = patients;
@@ -55,7 +55,7 @@ public N2c2PatientIteratorBML(List<Patient> patients, InputRepresentation inputR
* @param truncateLength
* @param batchSize
*/
public N2c2PatientIteratorBML(InputRepresentation inputRepresentation, int truncateLength, int batchSize) {
public TokenIterator(InputRepresentation inputRepresentation, int truncateLength, int batchSize) {
super(inputRepresentation);

this.truncateLength = truncateLength;
@@ -16,7 +16,7 @@

public abstract class BaseNNClassifierTest {

protected static final File SAMPLE = new File(BILSTMC3GClassifierTest.class.getResource("/gold-standard/sample.xml").getPath());
protected static final File SAMPLE = new File(BaseNNClassifierTest.class.getResource("/gold-standard/sample.xml").getPath());

protected BaseNNClassifier classifier;

@@ -13,10 +13,10 @@
import at.medunigraz.imi.bst.n2c2.dao.PatientDAO;
import at.medunigraz.imi.bst.n2c2.model.Patient;

public class BILSTMC3GClassifierTest extends BaseNNClassifierTest {
public class BiLSTMCharacterTrigramClassifierTest extends BaseNNClassifierTest {

public BILSTMC3GClassifierTest() {
this.classifier = new BILSTMC3GClassifier();
public BiLSTMCharacterTrigramClassifierTest() {
this.classifier = new BiLSTMCharacterTrigramClassifier();
}

@Test
@@ -27,13 +27,13 @@ public void saveAndLoad() throws IOException, SAXException {
train.add(p);

// We first train on some examples...
BILSTMC3GClassifier trainClassifier = new BILSTMC3GClassifier();
BiLSTMCharacterTrigramClassifier trainClassifier = new BiLSTMCharacterTrigramClassifier();
trainClassifier.deleteModelDir(train); // Delete any previously trained models, to ensure training is tested
trainClassifier.train(train); // This should persist models
assertTrue(trainClassifier.isTrained(train));

// ... and then try to load the model on a new instance.
BILSTMC3GClassifier testClassifier = new BILSTMC3GClassifier();
BiLSTMCharacterTrigramClassifier testClassifier = new BiLSTMCharacterTrigramClassifier();
assertTrue(testClassifier.isTrained(train));
// TODO use Mockito to call train() and ensure trainFullSetBMC is NOT called.
testClassifier.initializeNetworkFromFile(BaseNNClassifier.getModelPath(train));

This file was deleted.

@@ -0,0 +1,8 @@
package at.medunigraz.imi.bst.n2c2.nn;

public class LSTMEmbeddingsClassifierTest extends BaseNNClassifierTest {

public LSTMEmbeddingsClassifierTest() {
this.classifier = new LSTMEmbeddingsClassifier();
}
}
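
For context, a minimal usage sketch of the renamed classes follows. It is not part of this commit: the wrapper class and print statements are illustrative only, it assumes the n2c2 project (including its bundled vectors.vec resource) is on the classpath, and it mirrors only constructor calls already visible in the diff above.

package at.medunigraz.imi.bst.n2c2.nn;

// Hypothetical sketch (not part of this commit): the renamed classifiers are
// constructed exactly like the old ones; callers only need to update the class names.
public class RenameUsageSketch {

	public static void main(String[] args) {
		// Formerly LSTMClassifier.
		BaseNNClassifier embeddingsClassifier = new LSTMEmbeddingsClassifier();

		// Formerly BILSTMC3GClassifier.
		BaseNNClassifier trigramClassifier = new BiLSTMCharacterTrigramClassifier();

		System.out.println(embeddingsClassifier.getClass().getSimpleName());
		System.out.println(trigramClassifier.getClass().getSimpleName());
	}
}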
