diff --git a/src/main/java/at/medunigraz/imi/bst/n2c2/nn/BILSTMC3GClassifier.java b/src/main/java/at/medunigraz/imi/bst/n2c2/nn/BILSTMC3GClassifier.java
index ef8ac27..a847b49 100644
--- a/src/main/java/at/medunigraz/imi/bst/n2c2/nn/BILSTMC3GClassifier.java
+++ b/src/main/java/at/medunigraz/imi/bst/n2c2/nn/BILSTMC3GClassifier.java
@@ -1,18 +1,10 @@
 package at.medunigraz.imi.bst.n2c2.nn;
 
-import java.io.File;
-import java.io.FileInputStream;
-import java.io.FileNotFoundException;
-import java.io.FileOutputStream;
 import java.io.IOException;
-import java.io.ObjectInputStream;
-import java.io.ObjectOutputStream;
-import java.util.ArrayList;
-import java.util.HashMap;
-import java.util.List;
-import java.util.Map;
+import java.util.Properties;
 
 import at.medunigraz.imi.bst.n2c2.model.Criterion;
+import at.medunigraz.imi.bst.n2c2.nn.input.CharacterTrigram;
 import at.medunigraz.imi.bst.n2c2.nn.iterator.NGramIterator;
 import org.apache.logging.log4j.LogManager;
 import org.apache.logging.log4j.Logger;
@@ -30,10 +22,7 @@
 import org.deeplearning4j.nn.weights.WeightInit;
 import org.deeplearning4j.optimize.listeners.ScoreIterationListener;
 import org.nd4j.linalg.activations.Activation;
-import org.nd4j.linalg.api.ndarray.INDArray;
 import org.nd4j.linalg.factory.Nd4j;
-import org.nd4j.linalg.indexing.INDArrayIndex;
-import org.nd4j.linalg.indexing.NDArrayIndex;
 import org.nd4j.linalg.learning.config.AdaGrad;
 import org.nd4j.linalg.lossfunctions.LossFunctions;
 
@@ -47,40 +36,24 @@ public class BILSTMC3GClassifier extends BaseNNClassifier {
 
    private static final Logger LOG = LogManager.getLogger();
 
+   @Override
    public void initializeNetworkFromFile(String pathToModel) {
-
        // settings for memory management:
        // https://deeplearning4j.org/workspaces
        Nd4j.getMemoryManager().setAutoGcWindow(10000);
        // Nd4j.getMemoryManager().togglePeriodicGc(false);
 
-       // instantiating generator
-       fullSetIterator = new NGramIterator();
-
+       // TODO move to iterator.
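+       // The truncate length is restored from the model properties here; the trigram
+       // vocabulary itself is restored by fullSetIterator.load() in the superclass call below.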
+       Properties prop = null;
        try {
-           // read char 3-grams and index
-           FileInputStream fis = new FileInputStream(new File(pathToModel, "characterNGram_3"));
-           ObjectInputStream ois = new ObjectInputStream(fis);
-           ArrayList<String> characterNGram_3 = (ArrayList<String>) ois.readObject();
-
-           ((NGramIterator)fullSetIterator).characterNGram_3 = characterNGram_3;
-           ((NGramIterator)fullSetIterator).vectorSize = characterNGram_3.size();
-           this.vectorSize = ((NGramIterator)fullSetIterator).vectorSize;
-
-           // read char 3-grams index
-           fis = new FileInputStream(new File(pathToModel, "char3GramToIdxMap"));
-           ois = new ObjectInputStream(fis);
-           Map<String, Integer> char3GramToIdxMap_0 = (HashMap<String, Integer>) ois.readObject();
-           ((NGramIterator)fullSetIterator).char3GramToIdxMap = char3GramToIdxMap_0;
-
-           Nd4j.getRandom().setSeed(12345);
-
+           prop = loadProperties(pathToModel);
+           final int truncateLength = Integer.parseInt(prop.getProperty(getModelName() + ".truncateLength"));
+           fullSetIterator = new NGramIterator(new CharacterTrigram(), truncateLength, BATCH_SIZE);
        } catch (IOException e) {
-           e.printStackTrace();
-       } catch (ClassNotFoundException e) {
-           // TODO Auto-generated catch block
-           e.printStackTrace();
+           throw new RuntimeException(e);
        }
 
        super.initializeNetworkFromFile(pathToModel);
    }
@@ -102,14 +73,7 @@ protected void initializeNetworkBinaryMultiLabelDeep() {
 
        Nd4j.getMemoryManager().setAutoGcWindow(10000);
        // Nd4j.getMemoryManager().togglePeriodicGc(false);
 
-       try {
-           fullSetIterator = new NGramIterator(patientExamples, BATCH_SIZE);
-       } catch (IOException e) {
-           // TODO Auto-generated catch block
-           e.printStackTrace();
-       }
-       vectorSize = ((NGramIterator)fullSetIterator).vectorSize;
-       truncateLength = ((NGramIterator)fullSetIterator).maxSentences;
+       fullSetIterator = new NGramIterator(patientExamples, new CharacterTrigram(NGramIterator.createPatientLines(patientExamples)), BATCH_SIZE);
 
        int nOutFF = 150;
        int lstmLayerSize = 128;
@@ -118,8 +82,6 @@ protected void initializeNetworkBinaryMultiLabelDeep() {
        double adaGradDense = 0.01;
        double adaGradGraves = 0.008;
 
-       saveParams();
-
        // seed for reproducibility
        final int seed = 12345;
        MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder().seed(seed)
@@ -128,7 +90,7 @@
                .gradientNormalizationThreshold(1.0).trainingWorkspaceMode(WorkspaceMode.SINGLE)
                .inferenceWorkspaceMode(WorkspaceMode.SINGLE).list()
 
-               .layer(0, new DenseLayer.Builder().activation(Activation.RELU).nIn(vectorSize).nOut(nOutFF)
+               .layer(0, new DenseLayer.Builder().activation(Activation.RELU).nIn(fullSetIterator.getInputRepresentation().getVectorSize()).nOut(nOutFF)
                        .weightInit(WeightInit.RELU).updater(AdaGrad.builder().learningRate(adaGradDense).build())
                        .gradientNormalization(GradientNormalization.ClipElementWiseAbsoluteValue)
                        .gradientNormalizationThreshold(10).build())
@@ -170,57 +132,4 @@ protected String getModelName() {
        return "BILSTMC3G_MBL";
    }
-
-   protected void saveParams() {
-       File root = getModelDirectory(patientExamples);
-
-       try {
-           // writing our character n-grams
-           FileOutputStream fos = new FileOutputStream(new File(root, "characterNGram_3"));
-           ObjectOutputStream oos = new ObjectOutputStream(fos);
-           oos.writeObject(((NGramIterator)fullSetIterator).characterNGram_3);
-           oos.flush();
-           oos.close();
-           fos.close();
-
-           // writing our character n-grams
-           fos = new FileOutputStream(new File(root, "char3GramToIdxMap"));
-           oos = new ObjectOutputStream(fos);
-           oos.writeObject(((NGramIterator)fullSetIterator).char3GramToIdxMap);
-           oos.flush();
-           oos.close();
-           fos.close();
-       } catch (FileNotFoundException e) {
-           e.printStackTrace();
-       } catch (IOException e) {
-           e.printStackTrace();
-       }
-   }
-
-   /**
-    * Load features from narrative.
-    *
-    * @param reviewContents
-    *            Narrative content.
-    * @param maxLength
-    *            Maximum length of token series length.
-    * @return Time series feature presentation of narrative.
-    */
-   protected INDArray loadFeaturesForNarrative(String reviewContents, int maxLength) {
-
-       List<String> sentences = DataUtilities.getSentences(reviewContents);
-
-       int outputLength = Math.min(maxLength, sentences.size());
-       INDArray features = Nd4j.create(1, vectorSize, outputLength);
-
-       for (int j = 0; j < sentences.size() && j < outputLength; j++) {
-           String sentence = sentences.get(j);
-           INDArray vector = ((NGramIterator)fullSetIterator).getChar3GramVectorToSentence(sentence);
-           features.put(new INDArrayIndex[] { NDArrayIndex.point(0), NDArrayIndex.all(), NDArrayIndex.point(j) },
-                   vector);
-       }
-       return features;
-   }
-
 }
diff --git a/src/main/java/at/medunigraz/imi/bst/n2c2/nn/BaseNNClassifier.java b/src/main/java/at/medunigraz/imi/bst/n2c2/nn/BaseNNClassifier.java
index 72bf125..860f511 100644
--- a/src/main/java/at/medunigraz/imi/bst/n2c2/nn/BaseNNClassifier.java
+++ b/src/main/java/at/medunigraz/imi/bst/n2c2/nn/BaseNNClassifier.java
@@ -5,6 +5,7 @@
 import at.medunigraz.imi.bst.n2c2.model.Criterion;
 import at.medunigraz.imi.bst.n2c2.model.Eligibility;
 import at.medunigraz.imi.bst.n2c2.model.Patient;
+import at.medunigraz.imi.bst.n2c2.nn.iterator.BaseNNIterator;
 import at.medunigraz.imi.bst.n2c2.util.DatasetUtil;
 import org.apache.commons.io.FileUtils;
 import org.apache.logging.log4j.LogManager;
@@ -19,7 +20,6 @@
 import org.deeplearning4j.util.ModelSerializer;
 import org.nd4j.linalg.api.ndarray.INDArray;
 import org.nd4j.linalg.dataset.DataSet;
-import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
 import org.nd4j.linalg.indexing.NDArrayIndex;
 
 import java.io.*;
@@ -35,18 +35,13 @@ public abstract class BaseNNClassifier extends PatientBasedClassifier {
    // size of mini-batch for training
    protected static final int BATCH_SIZE = 10;
 
-   // specifies time series length
-   protected int truncateLength = 64;
-
-   public int vectorSize;
-
    // training data
    protected List<Patient> patientExamples;
 
    // multi layer network
    protected MultiLayerNetwork net;
 
-   public DataSetIterator fullSetIterator;
+   public BaseNNIterator fullSetIterator;
 
    /**
     * Training for binary multi label classifcation.
@@ -135,7 +130,7 @@ protected void saveModel(int epoch) {
            props.setProperty(getModelName() + ".bestModelEpoch", new Integer(epoch).toString());
 
            // TODO truncateLength does not change each epoch, this could be persisted in saveParams()
-           props.setProperty(getModelName() + ".truncateLength", new Integer(truncateLength).toString());
+           props.setProperty(getModelName() + ".truncateLength", new Integer(fullSetIterator.getTruncateLength()).toString());
 
            File f = new File(root, getModelName() + ".properties");
            OutputStream out = new FileOutputStream(f);
@@ -149,6 +144,8 @@
        } catch (IOException e) {
            e.printStackTrace();
        }
+
+       fullSetIterator.save(root);
    }
 
    protected abstract String getModelName();
@@ -162,7 +159,7 @@ protected void saveModel(int epoch) {
    public Map<Criterion, Eligibility> predict(Patient p) {
        String patientNarrative = p.getText();
 
-       INDArray features = loadFeaturesForNarrative(patientNarrative, this.truncateLength);
+       INDArray features = fullSetIterator.loadFeaturesForNarrative(patientNarrative, fullSetIterator.getTruncateLength());
 
        INDArray networkOutput = net.output(features);
        int timeSeriesLength = networkOutput.size(2);
@@ -212,7 +209,7 @@ public void train(List<Patient> examples) {
        // initializeMonitoring();
 
        LOG.info("Minibatchsize :\t" + BATCH_SIZE);
-       LOG.info("Truncate length:\t" + truncateLength);
+       LOG.info("Truncate length:\t" + fullSetIterator.getTruncateLength());
 
        trainFullSetBML();
    }
@@ -236,19 +233,20 @@ public boolean isTrained(List<Patient> patients) {
        return new File(getModelPath(patients), getModelName() + ".properties").exists();
    }
 
-   protected abstract INDArray loadFeaturesForNarrative(String reviewContents, int maxLength);
-
    public void initializeNetworkFromFile(String pathToModel) {
        try {
            Properties prop = loadProperties(pathToModel);
            final int bestEpoch = Integer.parseInt(prop.getProperty(getModelName() + ".bestModelEpoch"));
-           this.truncateLength = Integer.parseInt(prop.getProperty(getModelName() + ".truncateLength"));
 
            File networkFile = new File(pathToModel, getModelName() + "_" + bestEpoch + ".zip");
            this.net = ModelSerializer.restoreMultiLayerNetwork(networkFile);
+
+
        } catch (IOException e) {
            e.printStackTrace();
        }
+
+       fullSetIterator.load(new File(pathToModel));
    }
 
    /**
diff --git a/src/main/java/at/medunigraz/imi/bst/n2c2/nn/LSTMClassifier.java b/src/main/java/at/medunigraz/imi/bst/n2c2/nn/LSTMClassifier.java
index 090b521..f1a9339 100644
--- a/src/main/java/at/medunigraz/imi/bst/n2c2/nn/LSTMClassifier.java
+++ b/src/main/java/at/medunigraz/imi/bst/n2c2/nn/LSTMClassifier.java
@@ -1,16 +1,13 @@
 package at.medunigraz.imi.bst.n2c2.nn;
 
 import java.io.File;
-import java.util.ArrayList;
-import java.util.HashSet;
-import java.util.List;
-import java.util.Set;
+import java.io.IOException;
+import java.util.Properties;
 
+import at.medunigraz.imi.bst.n2c2.nn.input.WordEmbedding;
 import at.medunigraz.imi.bst.n2c2.nn.iterator.N2c2PatientIteratorBML;
 import org.apache.logging.log4j.LogManager;
 import org.apache.logging.log4j.Logger;
-import org.deeplearning4j.models.embeddings.loader.WordVectorSerializer;
-import org.deeplearning4j.models.embeddings.wordvectors.WordVectors;
 import org.deeplearning4j.nn.conf.GradientNormalization;
 import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
 import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
@@ -20,14 +17,8 @@
 import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
 import org.deeplearning4j.nn.weights.WeightInit;
 import org.deeplearning4j.optimize.listeners.ScoreIterationListener;
-import org.deeplearning4j.text.tokenization.tokenizer.preprocessor.CommonPreprocessor;
-import org.deeplearning4j.text.tokenization.tokenizerfactory.DefaultTokenizerFactory;
-import org.deeplearning4j.text.tokenization.tokenizerfactory.TokenizerFactory;
 import org.nd4j.linalg.activations.Activation;
-import org.nd4j.linalg.api.ndarray.INDArray;
 import org.nd4j.linalg.factory.Nd4j;
-import org.nd4j.linalg.indexing.INDArrayIndex;
-import org.nd4j.linalg.indexing.NDArrayIndex;
 import org.nd4j.linalg.learning.config.Adam;
 import org.nd4j.linalg.lossfunctions.LossFunctions;
 
@@ -41,34 +32,14 @@
  */
 public class LSTMClassifier extends BaseNNClassifier {
 
-   // accessing word vectors
-   private WordVectors wordVectors;
-
-   // tokenizer logic
-   private TokenizerFactory tokenizerFactory;
-
    // location of precalculated vectors
    private static final File PRETRAINED_VECTORS = new File(LSTMClassifier.class.getClassLoader().getResource("vectors.vec").getFile());
 
-   // word vector size
-   private static final int PRETRAINED_VECTORS_DIMENSION = 200;
-
    // logging
    private static final Logger LOG = LogManager.getLogger();
 
-   public LSTMClassifier() {
-       this.wordVectors = WordVectorSerializer.loadStaticModel(PRETRAINED_VECTORS);
-       initializeTokenizer();
-   }
-
-   private void initializeTokenizer() {
-       tokenizerFactory = new DefaultTokenizerFactory();
-       tokenizerFactory.setTokenPreProcessor(new CommonPreprocessor());
-   }
-
    @Override
    protected void initializeNetwork() {
-       initializeTruncateLength();
        initializeNetworkBinaryMultiLabelDebug();
    }
 
@@ -76,7 +47,7 @@ private void initializeNetworkBinaryMultiLabelDebug() {
 
        Nd4j.getMemoryManager().setAutoGcWindow(10000); // https://deeplearning4j.org/workspaces
 
-       fullSetIterator = new N2c2PatientIteratorBML(patientExamples, wordVectors, BATCH_SIZE, truncateLength);
+       fullSetIterator = new N2c2PatientIteratorBML(patientExamples, new WordEmbedding(PRETRAINED_VECTORS), BATCH_SIZE);
 
        // Set up network configuration
        MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder().seed(0)
@@ -84,7 +55,7 @@
                .gradientNormalization(GradientNormalization.ClipElementWiseAbsoluteValue)
                .gradientNormalizationThreshold(1.0).trainingWorkspaceMode(WorkspaceMode.SEPARATE)
                .inferenceWorkspaceMode(WorkspaceMode.SEPARATE) // https://deeplearning4j.org/workspaces
-               .list().layer(0, new GravesLSTM.Builder().nIn(PRETRAINED_VECTORS_DIMENSION).nOut(256).activation(Activation.TANH).build())
+               .list().layer(0, new GravesLSTM.Builder().nIn(fullSetIterator.getInputRepresentation().getVectorSize()).nOut(256).activation(Activation.TANH).build())
 
                .layer(1, new RnnOutputLayer.Builder().activation(Activation.SIGMOID)
                        .lossFunction(LossFunctions.LossFunction.XENT).nIn(256).nOut(13).build())
@@ -99,84 +70,22 @@
        this.net.setListeners(new ScoreIterationListener(1));
    }
 
-   /**
-    * Get longest token sequence of all patients with respect to existing word
-    * vector out of Google corpus.
-    *
-    */
-   private void initializeTruncateLength() {
-
-       // type coverage
-       Set<String> corpusTypes = new HashSet<String>();
-       Set<String> matchedTypes = new HashSet<String>();
-
-       // token coverage
-       int filteredSum = 0;
-       int tokenSum = 0;
-
-       List<List<String>> allTokens = new ArrayList<>(patientExamples.size());
-       int maxLength = 0;
-
-       for (Patient patient : patientExamples) {
-           String narrative = patient.getText();
-           String cleaned = narrative.replaceAll("[\r\n]+", " ").replaceAll("\\s+", " ");
-           List<String> tokens = tokenizerFactory.create(cleaned).getTokens();
-           tokenSum += tokens.size();
-
-           List<String> tokensFiltered = new ArrayList<>();
-           for (String token : tokens) {
-               corpusTypes.add(token);
-               if (wordVectors.hasWord(token)) {
-                   tokensFiltered.add(token);
-                   matchedTypes.add(token);
-               } else {
-                   LOG.info("Word2vec representation missing:\t" + token);
-               }
-           }
-           allTokens.add(tokensFiltered);
-           filteredSum += tokensFiltered.size();
-
-           maxLength = Math.max(maxLength, tokensFiltered.size());
-       }
-
-       LOG.info("Matched " + matchedTypes.size() + " types out of " + corpusTypes.size());
-       LOG.info("Matched " + filteredSum + " tokens out of " + tokenSum);
-
-       this.truncateLength = maxLength;
-   }
-
    @Override
    protected String getModelName() {
        return "LSTMW2V_MBL";
    }
 
-   /**
-    * Load features from narrative.
-    *
-    * @param reviewContents
-    *            Narrative content.
-    * @param maxLength
-    *            Maximum length of token series length.
-    * @return Time series feature presentation of narrative.
-    */
-   protected INDArray loadFeaturesForNarrative(String reviewContents, int maxLength) {
-
-       List<String> tokens = tokenizerFactory.create(reviewContents).getTokens();
-       List<String> tokensFiltered = new ArrayList<>();
-       for (String t : tokens) {
-           if (wordVectors.hasWord(t))
-               tokensFiltered.add(t);
+   @Override
+   public void initializeNetworkFromFile(String pathToModel) {
+       Properties prop = null;
+       try {
+           prop = loadProperties(pathToModel);
+           final int truncateLength = Integer.parseInt(prop.getProperty(getModelName() + ".truncateLength"));
+           fullSetIterator = new N2c2PatientIteratorBML(new WordEmbedding(PRETRAINED_VECTORS), truncateLength, BATCH_SIZE);
+       } catch (IOException e) {
+           throw new RuntimeException(e);
        }
-       int outputLength = Math.min(maxLength, tokensFiltered.size());
-
-       INDArray features = Nd4j.create(1, PRETRAINED_VECTORS_DIMENSION, outputLength);
-
-       for (int j = 0; j < tokensFiltered.size() && j < maxLength; j++) {
-           String token = tokensFiltered.get(j);
-           INDArray vector = wordVectors.getWordVectorMatrix(token);
-           features.put(new INDArrayIndex[] { NDArrayIndex.point(0), NDArrayIndex.all(), NDArrayIndex.point(j) },
-                   vector);
-       }
-       return features;
+       super.initializeNetworkFromFile(pathToModel);
    }
 }
diff --git a/src/main/java/at/medunigraz/imi/bst/n2c2/nn/input/CharacterTrigram.java b/src/main/java/at/medunigraz/imi/bst/n2c2/nn/input/CharacterTrigram.java
new file mode 100644
index 0000000..c79397c
--- /dev/null
+++ b/src/main/java/at/medunigraz/imi/bst/n2c2/nn/input/CharacterTrigram.java
@@ -0,0 +1,166 @@
+package at.medunigraz.imi.bst.n2c2.nn.input;
+
+import at.medunigraz.imi.bst.n2c2.nn.DataUtilities;
+import org.nd4j.linalg.api.ndarray.INDArray;
+import org.nd4j.linalg.factory.Nd4j;
+
+import java.io.*;
+import java.util.*;
+
+public class CharacterTrigram implements InputRepresentation {
+
+   private ArrayList<String> characterNGram_3 = new ArrayList<String>();
+
+   private Map<String, Integer> char3GramToIdxMap = new HashMap<String, Integer>();
+
+   public CharacterTrigram() {
+       // TODO same as load
+   }
+
+   public CharacterTrigram(Map<Integer, List<String>> integerListMap) {
+       // generate char 3 grams
+       try {
+           fillCharNGramsMaps(integerListMap);
+       } catch (IOException e) {
+           throw new RuntimeException(e);
+       }
+
+       // generate index
+       this.createIndizes();
+   }
+
+   public void save(File model) {
+       File root = model;
+
+       try {
+           // writing our character n-grams
+           FileOutputStream fos = new FileOutputStream(new File(root, "characterNGram_3"));
+           ObjectOutputStream oos = new ObjectOutputStream(fos);
+           oos.writeObject(characterNGram_3);
+           oos.flush();
+           oos.close();
+           fos.close();
+
+           // writing our character n-gram index
+           fos = new FileOutputStream(new File(root, "char3GramToIdxMap"));
+           oos = new ObjectOutputStream(fos);
+           oos.writeObject(char3GramToIdxMap);
+           oos.flush();
+           oos.close();
+           fos.close();
+       } catch (FileNotFoundException e) {
+           e.printStackTrace();
+       } catch (IOException e) {
+           e.printStackTrace();
+       }
+   }
+
+   public void load(File model) {
+       try {
+           // read char 3-grams and index
+           FileInputStream fis = new FileInputStream(new File(model, "characterNGram_3"));
+           ObjectInputStream ois = new ObjectInputStream(fis);
+           ArrayList<String> characterNGram_3 = (ArrayList<String>) ois.readObject();
+
+           this.characterNGram_3 = characterNGram_3;
+
+           // read char 3-grams index
+           fis = new FileInputStream(new File(model, "char3GramToIdxMap"));
+           ois = new ObjectInputStream(fis);
+           Map<String, Integer> char3GramToIdxMap_0 = (HashMap<String, Integer>) ois.readObject();
+           this.char3GramToIdxMap = char3GramToIdxMap_0;
+
+           Nd4j.getRandom().setSeed(12345);
+
+       } catch (IOException e) {
+           e.printStackTrace();
+       } catch (ClassNotFoundException e) {
+           // TODO Auto-generated catch block
+           e.printStackTrace();
+       }
+
+   }
+
+   @Override
+   public INDArray getVector(String unit) {
+       return getChar3GramVectorToSentence(unit);
+   }
+
+   @Override
+   public boolean hasRepresentation(String unit) {
+       return char3GramToIdxMap.containsKey(unit);
+   }
+
+   @Override
+   public int getVectorSize() {
+       return characterNGram_3.size();
+   }
+
+   /**
+    * Creates index for character 3-grams.
+    */
+   private void createIndizes() {
+       // store indexes
+       for (int i = 0; i < characterNGram_3.size(); i++)
+           char3GramToIdxMap.put(characterNGram_3.get(i), i);
+   }
+
+   /**
+    * Fills character 3-gram dictionary.
+    *
+    * @throws IOException
+    */
+   private void fillCharNGramsMaps(Map<Integer, List<String>> integerListMap) throws IOException {
+       // TODO operate on a single List with all sentences
+       for (Map.Entry<Integer, List<String>> entry : integerListMap.entrySet()) {
+           for (String line : entry.getValue()) {
+               String normalized = DataUtilities.processTextReduced(line);
+               String char3Grams = DataUtilities.getChar3GramRepresentation(normalized);
+
+               // process character n-grams
+               String[] char3Splits = char3Grams.split("\\s+");
+
+               for (String split : char3Splits) {
+                   if (!characterNGram_3.contains(split)) {
+                       characterNGram_3.add(split);
+                   }
+               }
+           }
+       }
+
+       // add the out-of-dictionary entry once, after all patients are processed
+       characterNGram_3.add("OOD");
+   }
+
+   /**
+    * Sentence will be transformed to a character 3-gram vector.
+    *
+    * @param sentence
+    *            Sentence which gets vector representation.
+    * @return Multi-hot character 3-gram vector.
+    */
+   public INDArray getChar3GramVectorToSentence(String sentence) {
+
+       INDArray featureVector = Nd4j.zeros(getVectorSize());
+       try {
+           String normalized = DataUtilities.processTextReduced(sentence);
+           String char3Grams = DataUtilities.getChar3GramRepresentation(normalized);
+
+           // process character n-grams
+           String[] char3Splits = char3Grams.split("\\s+");
+
+           for (String split : char3Splits) {
+               if (char3GramToIdxMap.get(split) == null) {
+                   int nGramIndex = char3GramToIdxMap.get("OOD");
+                   featureVector.putScalar(new int[] { nGramIndex }, 1.0);
+               } else {
+                   int nGramIndex = char3GramToIdxMap.get(split);
+                   featureVector.putScalar(new int[] { nGramIndex }, 1.0);
+               }
+           }
+       } catch (IOException e) {
+           e.printStackTrace();
+       }
+       return featureVector;
+   }
+}
diff --git a/src/main/java/at/medunigraz/imi/bst/n2c2/nn/input/InputRepresentation.java b/src/main/java/at/medunigraz/imi/bst/n2c2/nn/input/InputRepresentation.java
new file mode 100644
index 0000000..dfcf81f
--- /dev/null
+++ b/src/main/java/at/medunigraz/imi/bst/n2c2/nn/input/InputRepresentation.java
@@ -0,0 +1,21 @@
+package at.medunigraz.imi.bst.n2c2.nn.input;
+
+import org.nd4j.linalg.api.ndarray.INDArray;
+
+import java.io.*;
+
+/**
+ * A pluggable mapping from an input unit (e.g. a token or a sentence) to its vector representation.
+ */
+public interface InputRepresentation {
+
+   INDArray getVector(String unit);
+
+   boolean hasRepresentation(String unit);
+
+   int getVectorSize();
+
+   void save(File model);
+
+   void load(File model);
+}
diff --git a/src/main/java/at/medunigraz/imi/bst/n2c2/nn/input/WordEmbedding.java b/src/main/java/at/medunigraz/imi/bst/n2c2/nn/input/WordEmbedding.java
new file mode 100644
index 0000000..cf1603f
--- /dev/null
+++ b/src/main/java/at/medunigraz/imi/bst/n2c2/nn/input/WordEmbedding.java
@@ -0,0 +1,44 @@
+package at.medunigraz.imi.bst.n2c2.nn.input;
+
+import java.io.File;
+
+import org.deeplearning4j.models.embeddings.loader.WordVectorSerializer;
+import org.deeplearning4j.models.embeddings.wordvectors.WordVectors;
+import org.nd4j.linalg.api.ndarray.INDArray;
+
+/**
+ * Facade for dl4j's WordVectors
+ */
+public class WordEmbedding implements InputRepresentation {
+
+   private final WordVectors wordVectors;
+
+   public WordEmbedding(File embeddings) {
+       wordVectors = WordVectorSerializer.loadStaticModel(embeddings);
+   }
+
+   @Override
+   public INDArray getVector(String unit) {
+       return wordVectors.getWordVectorMatrix(unit);
+   }
+
+   @Override
+   public boolean hasRepresentation(String unit) {
+       return wordVectors.hasWord(unit);
+   }
+
+   @Override
+   public int getVectorSize() {
+       return wordVectors.getWordVector(wordVectors.vocab().wordAtIndex(0)).length;
+   }
+
+   @Override
+   public void save(File model) {
+       // nothing to persist; the pretrained vectors live in an external file
+   }
+
+   @Override
+   public void load(File model) {
+       // TODO same as constructor
+   }
+}
diff --git a/src/main/java/at/medunigraz/imi/bst/n2c2/nn/iterator/BaseNNIterator.java b/src/main/java/at/medunigraz/imi/bst/n2c2/nn/iterator/BaseNNIterator.java
index 5e93050..e502275 100644
--- a/src/main/java/at/medunigraz/imi/bst/n2c2/nn/iterator/BaseNNIterator.java
+++ b/src/main/java/at/medunigraz/imi/bst/n2c2/nn/iterator/BaseNNIterator.java
@@ -3,22 +3,32 @@
 import at.medunigraz.imi.bst.n2c2.model.Criterion;
 import at.medunigraz.imi.bst.n2c2.model.Eligibility;
 import at.medunigraz.imi.bst.n2c2.model.Patient;
+import at.medunigraz.imi.bst.n2c2.nn.input.InputRepresentation;
+import org.apache.logging.log4j.LogManager;
+import org.apache.logging.log4j.Logger;
+import org.nd4j.linalg.api.ndarray.INDArray;
 import org.nd4j.linalg.dataset.DataSet;
 import org.nd4j.linalg.dataset.api.DataSetPreProcessor;
 import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
 
-import java.util.Arrays;
-import java.util.List;
-import java.util.NoSuchElementException;
+import java.io.File;
+import java.util.*;
 
 public abstract class BaseNNIterator implements DataSetIterator {
 
+   private static final Logger LOG = LogManager.getLogger();
+   protected final InputRepresentation inputRepresentation;
 
-   protected List<Patient> patients;
+   protected int truncateLength;
+
+   protected List<Patient> patients; // TODO separate train and test data, allow iterator on test
 
    protected int cursor = 0;
    protected int batchSize;
-   public int vectorSize;
+
+   public BaseNNIterator(InputRepresentation inputRepresentation) {
+       this.inputRepresentation = inputRepresentation;
+   }
 
    /**
     * Fill multi-hot vector for mulit label classification.
@@ -31,6 +41,10 @@ protected void fillBinaryMultiHotVector(List<Boolean> binaryMultiHotVector) {
        }
    }
 
+   public InputRepresentation getInputRepresentation() {
+       return inputRepresentation;
+   }
+
    /*
    * (non-Javadoc)
    *
@@ -48,7 +62,7 @@ public int totalExamples() {
    */
    @Override
    public int inputColumns() {
-       return vectorSize;
+       return inputRepresentation.getVectorSize();
    }
 
    /*
@@ -198,5 +212,19 @@ public DataSet next(int num) {
        return getNext(num);
    }
 
+   public int getTruncateLength() {
+       return truncateLength;
+   }
+
    public abstract DataSet getNext(int num);
+
+   public abstract INDArray loadFeaturesForNarrative(String reviewContents, int maxLength);
+
+   public void save(File model) {
+       inputRepresentation.save(model);
+   }
+
+   public void load(File model) {
+       inputRepresentation.load(model);
+   }
 }
diff --git a/src/main/java/at/medunigraz/imi/bst/n2c2/nn/iterator/N2c2PatientIteratorBML.java b/src/main/java/at/medunigraz/imi/bst/n2c2/nn/iterator/N2c2PatientIteratorBML.java
index 1078c1c..3fe4a11 100644
--- a/src/main/java/at/medunigraz/imi/bst/n2c2/nn/iterator/N2c2PatientIteratorBML.java
+++ b/src/main/java/at/medunigraz/imi/bst/n2c2/nn/iterator/N2c2PatientIteratorBML.java
@@ -1,10 +1,10 @@
 package at.medunigraz.imi.bst.n2c2.nn.iterator;
 
-import java.util.ArrayList;
-import java.util.HashMap;
-import java.util.List;
+import java.util.*;
 
-import org.deeplearning4j.models.embeddings.wordvectors.WordVectors;
+import at.medunigraz.imi.bst.n2c2.nn.input.InputRepresentation;
+import org.apache.logging.log4j.LogManager;
+import org.apache.logging.log4j.Logger;
 import org.deeplearning4j.text.tokenization.tokenizer.preprocessor.CommonPreprocessor;
 import org.deeplearning4j.text.tokenization.tokenizerfactory.DefaultTokenizerFactory;
 import org.deeplearning4j.text.tokenization.tokenizerfactory.TokenizerFactory;
@@ -23,36 +23,43 @@
  *
  */
 public class N2c2PatientIteratorBML extends BaseNNIterator {
+   private static final Logger LOG = LogManager.getLogger();
 
    private static final long serialVersionUID = 1L;
 
-   private final WordVectors wordVectors;
-
-   private final int truncateLength;
-
    private final TokenizerFactory tokenizerFactory;
 
    /**
    * Patient data iterator for the n2c2 task.
    *
    * @param patients
    *            Patient data.
-   * @param wordVectors
-   *            Word vectors object.
    * @param batchSize
    *            Mini batch size use for processing.
-   * @param truncateLength
-   *            Maximum length of token sequence.
    */
-   public N2c2PatientIteratorBML(List<Patient> patients, WordVectors wordVectors, int batchSize, int truncateLength) {
+   public N2c2PatientIteratorBML(List<Patient> patients, InputRepresentation inputRepresentation, int batchSize) {
+       super(inputRepresentation);
        this.patients = patients;
        this.batchSize = batchSize;
 
-       this.vectorSize = wordVectors.getWordVector(wordVectors.vocab().wordAtIndex(0)).length;
-       this.wordVectors = wordVectors;
+       tokenizerFactory = new DefaultTokenizerFactory();
+       tokenizerFactory.setTokenPreProcessor(new CommonPreprocessor());
+
+       initializeTruncateLength();
+   }
+
+   /**
+    * Restores an iterator for prediction; the truncate length is read from the saved model properties rather than computed from training data.
+    * @param inputRepresentation Input representation to use.
+    * @param truncateLength Maximum length of token sequence.
+    * @param batchSize Mini batch size used for processing.
+    */
+   public N2c2PatientIteratorBML(InputRepresentation inputRepresentation, int truncateLength, int batchSize) {
+       super(inputRepresentation);
+
+       this.truncateLength = truncateLength;
+       this.batchSize = batchSize;
 
        tokenizerFactory = new DefaultTokenizerFactory();
        tokenizerFactory.setTokenPreProcessor(new CommonPreprocessor());
@@ -90,7 +97,7 @@ public DataSet getNext(int num) {
            List<String> tokens = tokenizerFactory.create(narrative).getTokens();
            List<String> tokensFiltered = new ArrayList<>();
            for (String token : tokens) {
-               if (wordVectors.hasWord(token))
+               if (inputRepresentation.hasRepresentation(token))
                    tokensFiltered.add(token);
            }
            allTokens.add(tokensFiltered);
@@ -98,11 +105,11 @@ public DataSet getNext(int num) {
        }
 
        // truncate if sequence is longer than truncateLength
-       if (maxLength > truncateLength)
-           maxLength = truncateLength;
+       if (maxLength > getTruncateLength())
+           maxLength = getTruncateLength();
 
-       INDArray features = Nd4j.create(narratives.size(), vectorSize, maxLength);
-       INDArray labels = Nd4j.create(narratives.size(), 13, maxLength);
+       INDArray features = Nd4j.create(narratives.size(), inputRepresentation.getVectorSize(), maxLength);
+       INDArray labels = Nd4j.create(narratives.size(), totalOutcomes(), maxLength);
 
        INDArray featuresMask = Nd4j.zeros(narratives.size(), maxLength);
        INDArray labelsMask = Nd4j.zeros(narratives.size(), maxLength);
@@ -115,7 +122,7 @@ public DataSet getNext(int num) {
            // get word vectors for each token in narrative
            for (int j = 0; j < tokens.size() && j < maxLength; j++) {
                String token = tokens.get(j);
-               INDArray vector = wordVectors.getWordVectorMatrix(token);
+               INDArray vector = inputRepresentation.getVector(token);
 
                features.put(new INDArrayIndex[] { NDArrayIndex.point(i), NDArrayIndex.all(), NDArrayIndex.point(j) },
                        vector);
@@ -137,4 +144,81 @@ public DataSet getNext(int num) {
        }
        return new DataSet(features, labels, featuresMask, labelsMask);
    }
+
+   /**
+    * Get longest token sequence of all patients with respect to tokens
+    * covered by the input representation.
+    *
+    */
+   private void initializeTruncateLength() {
+
+       // type coverage
+       Set<String> corpusTypes = new HashSet<String>();
+       Set<String> matchedTypes = new HashSet<String>();
+
+       // token coverage
+       int filteredSum = 0;
+       int tokenSum = 0;
+
+       List<List<String>> allTokens = new ArrayList<>(patients.size());
+       int maxLength = 0;
+
+       for (Patient patient : patients) {
+           String narrative = patient.getText();
+           String cleaned = narrative.replaceAll("[\r\n]+", " ").replaceAll("\\s+", " ");
+           List<String> tokens = tokenizerFactory.create(cleaned).getTokens();
+           tokenSum += tokens.size();
+
+           List<String> tokensFiltered = new ArrayList<>();
+           for (String token : tokens) {
+               corpusTypes.add(token);
+               if (inputRepresentation.hasRepresentation(token)) {
+                   tokensFiltered.add(token);
+                   matchedTypes.add(token);
+               } else {
+                   LOG.info("Input representation missing:\t" + token);
+               }
+           }
+           allTokens.add(tokensFiltered);
+           filteredSum += tokensFiltered.size();
+
+           maxLength = Math.max(maxLength, tokensFiltered.size());
+       }
+
+       LOG.info("Matched " + matchedTypes.size() + " types out of " + corpusTypes.size());
+       LOG.info("Matched " + filteredSum + " tokens out of " + tokenSum);
+
+       this.truncateLength = maxLength;
+   }
+
+   /**
+    * Load features from narrative.
+    *
+    * @param reviewContents
+    *            Narrative content.
+    * @param maxLength
+    *            Maximum length of token sequence.
+    * @return Time series feature representation of narrative.
+    */
+   @Override
+   public INDArray loadFeaturesForNarrative(String reviewContents, int maxLength) {
+
+       List<String> tokens = tokenizerFactory.create(reviewContents).getTokens();
+       List<String> tokensFiltered = new ArrayList<>();
+       for (String t : tokens) {
+           if (inputRepresentation.hasRepresentation(t))
+               tokensFiltered.add(t);
+       }
+       int outputLength = Math.min(maxLength, tokensFiltered.size());
+
+       INDArray features = Nd4j.create(1, inputRepresentation.getVectorSize(), outputLength);
+
+       for (int j = 0; j < tokensFiltered.size() && j < maxLength; j++) {
+           String token = tokensFiltered.get(j);
+           INDArray vector = inputRepresentation.getVector(token);
+           features.put(new INDArrayIndex[] { NDArrayIndex.point(0), NDArrayIndex.all(), NDArrayIndex.point(j) },
+                   vector);
+       }
+       return features;
+   }
 }
diff --git a/src/main/java/at/medunigraz/imi/bst/n2c2/nn/iterator/NGramIterator.java b/src/main/java/at/medunigraz/imi/bst/n2c2/nn/iterator/NGramIterator.java
index d77ef3f..d9c32fc 100644
--- a/src/main/java/at/medunigraz/imi/bst/n2c2/nn/iterator/NGramIterator.java
+++ b/src/main/java/at/medunigraz/imi/bst/n2c2/nn/iterator/NGramIterator.java
@@ -1,12 +1,12 @@
 package at.medunigraz.imi.bst.n2c2.nn.iterator;
 
-import java.io.IOException;
 import java.util.ArrayList;
 import java.util.HashMap;
 import java.util.List;
 import java.util.Map;
 
 import at.medunigraz.imi.bst.n2c2.nn.DataUtilities;
+import at.medunigraz.imi.bst.n2c2.nn.input.InputRepresentation;
 import org.nd4j.linalg.api.ndarray.INDArray;
 import org.nd4j.linalg.dataset.DataSet;
 import org.nd4j.linalg.factory.Nd4j;
@@ -24,20 +24,7 @@ public class NGramIterator extends BaseNNIterator {
 
    private static final long serialVersionUID = 1L;
 
-   public ArrayList<String> characterNGram_3 = new ArrayList<String>();
-
-   public Map<String, Integer> char3GramToIdxMap = new HashMap<String, Integer>();
-
-   public int maxSentences = 0;
-
-   Map<Integer, List<String>> patientLines;
-
-   /**
-    * Default constructor.
-    *
-    */
-   public NGramIterator() {
-   }
+   private Map<Integer, List<String>> patientLines;
 
    /**
    * Iterator representing sentences as character 3-grams.
@@ -46,20 +33,28 @@ public NGramIterator() {
    *            List of patients.
    * @param batchSize
    *            Minibatch size.
-   * @throws IOException
    */
-   public NGramIterator(List<Patient> patients, int batchSize) throws IOException {
+   public NGramIterator(List<Patient> patients, InputRepresentation inputRepresentation, int batchSize) {
+       super(inputRepresentation);
        this.patients = patients;
        this.batchSize = batchSize;
 
        this.patientLines = createPatientLines(patients);
+       this.truncateLength = calculateMaxSentences(patients);
+   }
 
-       // generate char 3 grams
-       this.fillCharNGramsMaps();
+   /**
+    * Restores an iterator for prediction; the truncate length is read from the saved model properties.
+    * @param inputRepresentation Input representation to use.
+    * @param truncateLength Maximum number of sentences per narrative.
+    * @param batchSize Minibatch size.
+    */
+   public NGramIterator(InputRepresentation inputRepresentation, int truncateLength, int batchSize) {
+       super(inputRepresentation);
 
-       // generate index
-       this.createIndizes();
+       this.truncateLength = truncateLength;
+       this.batchSize = batchSize;
    }
 
    /**
@@ -68,57 +63,27 @@ public NGramIterator(List<Patient> patients, int batchSize) throws IOException {
    * @param patients
    * @return
    */
-   private Map<Integer, List<String>> createPatientLines(List<Patient> patients) {
-       this.patientLines = new HashMap<Integer, List<String>>();
+   public static Map<Integer, List<String>> createPatientLines(List<Patient> patients) {
+       // TODO return List<List<String>>
+       Map<Integer, List<String>> integerListMap = new HashMap<Integer, List<String>>();
 
        int patientIndex = 0;
        for (Patient patient : patients) {
            List<String> tmpLines = DataUtilities.getSentences(patient.getText());
-           this.maxSentences = tmpLines.size() > maxSentences ? tmpLines.size() : maxSentences;
-           this.patientLines.put(patientIndex++, tmpLines);
+           integerListMap.put(patientIndex++, tmpLines);
        }
-       return this.patientLines;
-   }
-
-   /**
-    * Creates index for character 3-grams.
-    */
-   private void createIndizes() {
-
-       // store indexes
-       for (int i = 0; i < characterNGram_3.size(); i++)
-           char3GramToIdxMap.put(characterNGram_3.get(i), i);
+       return integerListMap;
    }
 
-   /**
-    * Fills character 3-gram dictionary.
-    *
-    * @throws IOException
-    */
-   private void fillCharNGramsMaps() throws IOException {
-
-       for (Map.Entry<Integer, List<String>> entry : patientLines.entrySet()) {
-           for (String line : entry.getValue()) {
-               String normalized = DataUtilities.processTextReduced(line);
-               String char3Grams = DataUtilities.getChar3GramRepresentation(normalized);
-
-               // process character n-grams
-               String[] char3Splits = char3Grams.split("\\s+");
-
-               for (String split : char3Splits) {
-                   if (!characterNGram_3.contains(split)) {
-                       characterNGram_3.add(split);
-                   }
-               }
-           }
-
-           // adding out of dictionary entries
-           characterNGram_3.add("OOD");
+   private int calculateMaxSentences(List<Patient> patients) {
+       // TODO reuse patientLines?
+       int maxSentences = 0;
+       for (Patient patient : patients) {
+           List<String> tmpLines = DataUtilities.getSentences(patient.getText());
+           maxSentences = tmpLines.size() > maxSentences ? tmpLines.size() : maxSentences;
        }
-
-       // set vector dimensionality
-       vectorSize = characterNGram_3.size();
+       return maxSentences;
    }
 
    /**
@@ -136,6 +101,7 @@ public DataSet getNext(int num) {
        int maxLength = 0;
        Map<Integer, List<String>> patientBatch = new HashMap<Integer, List<String>>(batchSize);
        for (int i = 0; i < num && cursor < totalExamples(); i++) {
+           // TODO regenerate sentences and do not depend on patientLines?
            List<String> sentences = patientLines.get(cursor);
            patientBatch.put(i, sentences);
@@ -149,10 +115,10 @@ public DataSet getNext(int num) {
        }
 
        // truncate if sequence is longer than maxSentences
-       if (maxLength > maxSentences)
-           maxLength = maxSentences;
+       if (maxLength > getTruncateLength())
+           maxLength = getTruncateLength();
 
-       INDArray features = Nd4j.create(new int[] { patientBatch.size(), vectorSize, maxLength }, 'f');
+       INDArray features = Nd4j.create(new int[] { patientBatch.size(), inputRepresentation.getVectorSize(), maxLength }, 'f');
        INDArray labels = Nd4j.create(new int[] { patientBatch.size(), totalOutcomes(), maxLength }, 'f');
 
        INDArray featuresMask = Nd4j.zeros(patientBatch.size(), maxLength);
@@ -168,7 +134,7 @@ public DataSet getNext(int num) {
                String sentence = sentences.get(j);
 
                // get vector presentation of sentence
-               INDArray vector = getChar3GramVectorToSentence(sentence);
+               INDArray vector = inputRepresentation.getVector(sentence);
 
                features.put(new INDArrayIndex[] { NDArrayIndex.point(i), NDArrayIndex.all(), NDArrayIndex.point(j) },
                        vector);
@@ -193,34 +159,29 @@ public DataSet getNext(int num) {
    }
 
    /**
-    * Sentence will be transformed to a character 3-gram vector.
-    *
-    * @param sentence
-    *            Sentence which gets vector representation.
-    * @return
+    * Load features from narrative.
+    *
+    * @param reviewContents
+    *            Narrative content.
+    * @param maxLength
+    *            Maximum number of sentences to encode.
+    * @return Time series feature representation of narrative.
    */
-   public INDArray getChar3GramVectorToSentence(String sentence) {
-
-       INDArray featureVector = Nd4j.zeros(vectorSize);
-       try {
-           String normalized = DataUtilities.processTextReduced(sentence);
-           String char3Grams = DataUtilities.getChar3GramRepresentation(normalized);
-
-           // process character n-grams
-           String[] char3Splits = char3Grams.split("\\s+");
-
-           for (String split : char3Splits) {
-               if (char3GramToIdxMap.get(split) == null) {
-                   int nGramIndex = char3GramToIdxMap.get("OOD");
-                   featureVector.putScalar(new int[] { nGramIndex }, 1.0);
-               } else {
-                   int nGramIndex = char3GramToIdxMap.get(split);
-                   featureVector.putScalar(new int[] { nGramIndex }, 1.0);
-               }
-           }
-       } catch (IOException e) {
-           e.printStackTrace();
+   @Override
+   public INDArray loadFeaturesForNarrative(String reviewContents, int maxLength) {
+
+       List<String> sentences = DataUtilities.getSentences(reviewContents);
+
+       int outputLength = Math.min(maxLength, sentences.size());
+       INDArray features = Nd4j.create(1, inputRepresentation.getVectorSize(), outputLength);
+
+       for (int j = 0; j < sentences.size() && j < outputLength; j++) {
+           String sentence = sentences.get(j);
+           INDArray vector = inputRepresentation.getVector(sentence);
+           features.put(new INDArrayIndex[] { NDArrayIndex.point(0), NDArrayIndex.all(), NDArrayIndex.point(j) },
+                   vector);
        }
-       return featureVector;
+       return features;
    }
+
 }
diff --git a/src/test/java/at/medunigraz/imi/bst/n2c2/nn/BILSTMC3GClassifierTest.java b/src/test/java/at/medunigraz/imi/bst/n2c2/nn/BILSTMC3GClassifierTest.java
index 497de8e..b1ff768 100644
--- a/src/test/java/at/medunigraz/imi/bst/n2c2/nn/BILSTMC3GClassifierTest.java
+++ b/src/test/java/at/medunigraz/imi/bst/n2c2/nn/BILSTMC3GClassifierTest.java
@@ -7,7 +7,6 @@
 import java.util.ArrayList;
 import java.util.List;
 
-import at.medunigraz.imi.bst.n2c2.nn.iterator.NGramIterator;
 import org.junit.Test;
 import org.xml.sax.SAXException;
 
@@ -39,19 +38,6 @@ public void saveAndLoad() throws IOException, SAXException {
        // TODO use Mockito to call train() and ensure trainFullSetBMC is NOT called.
        testClassifier.initializeNetworkFromFile(BaseNNClassifier.getModelPath(train));
 
-       // Check classifier attributes are properly initialized
-       assertEquals(trainClassifier.truncateLength, testClassifier.truncateLength);
-       assertEquals(trainClassifier.vectorSize, testClassifier.vectorSize);
-
-       // Check maps are properly initialized
-       NGramIterator trainIterator = (NGramIterator)trainClassifier.fullSetIterator;
-       NGramIterator testIterator = (NGramIterator)testClassifier.fullSetIterator;
-       assertEquals(trainIterator.characterNGram_3, testIterator.characterNGram_3);
-       assertEquals(trainIterator.char3GramToIdxMap, testIterator.char3GramToIdxMap);
-
-       // XXX maxSentences is not initialized, but is not needed for prediction
-       //assertEquals(trainIterator.maxSentences, testIterator.maxSentences);
-
        assertSamplePatient(testClassifier, p);
    }
 }
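Note: after this change, every unit-to-vector lookup goes through the InputRepresentation interface introduced above, so adding a new feature type means implementing its five methods and handing the instance to an iterator; the classifiers themselves stay untouched. As a minimal sketch of what a third implementation could look like, consider a hypothetical OneHotWord class (not part of this change set; persistence is stubbed out):

package at.medunigraz.imi.bst.n2c2.nn.input;

import java.io.File;
import java.util.ArrayList;
import java.util.List;

import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;

// Hypothetical example implementation illustrating the InputRepresentation
// contract that CharacterTrigram and WordEmbedding fulfil above.
public class OneHotWord implements InputRepresentation {

    private final List<String> vocabulary;

    public OneHotWord(List<String> vocabulary) {
        this.vocabulary = new ArrayList<String>(vocabulary);
    }

    @Override
    public INDArray getVector(String unit) {
        // one-hot vector over the vocabulary; all zeros for unknown units
        INDArray vector = Nd4j.zeros(getVectorSize());
        int index = vocabulary.indexOf(unit);
        if (index >= 0) {
            vector.putScalar(index, 1.0);
        }
        return vector;
    }

    @Override
    public boolean hasRepresentation(String unit) {
        return vocabulary.contains(unit);
    }

    @Override
    public int getVectorSize() {
        return vocabulary.size();
    }

    @Override
    public void save(File model) {
        // persistence omitted in this sketch
    }

    @Override
    public void load(File model) {
        // persistence omitted in this sketch
    }
}

Such an instance would plug into the existing iterators without further changes, e.g. new N2c2PatientIteratorBML(patients, new OneHotWord(vocabulary), BATCH_SIZE).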