forked from bst-mug/n2c2
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing 6 changed files with 146 additions and 142 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
2 changes: 1 addition & 1 deletion
2
...igraz/imi/bst/n2c2/nn/BaseNNIterator.java → .../bst/n2c2/nn/iterator/BaseNNIterator.java
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
280 changes: 140 additions & 140 deletions
280
...i/bst/n2c2/nn/N2c2PatientIteratorBML.java → ...2/nn/iterator/N2c2PatientIteratorBML.java
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,140 +1,140 @@ | ||
package at.medunigraz.imi.bst.n2c2.nn; | ||
|
||
import java.util.ArrayList; | ||
import java.util.HashMap; | ||
import java.util.List; | ||
|
||
import org.deeplearning4j.models.embeddings.wordvectors.WordVectors; | ||
import org.deeplearning4j.text.tokenization.tokenizer.preprocessor.CommonPreprocessor; | ||
import org.deeplearning4j.text.tokenization.tokenizerfactory.DefaultTokenizerFactory; | ||
import org.deeplearning4j.text.tokenization.tokenizerfactory.TokenizerFactory; | ||
import org.nd4j.linalg.api.ndarray.INDArray; | ||
import org.nd4j.linalg.dataset.DataSet; | ||
import org.nd4j.linalg.factory.Nd4j; | ||
import org.nd4j.linalg.indexing.INDArrayIndex; | ||
import org.nd4j.linalg.indexing.NDArrayIndex; | ||
|
||
import at.medunigraz.imi.bst.n2c2.model.Patient; | ||
|
||
/** | ||
* Data iterator refactored from dl4j examples. | ||
* | ||
* Builds mini batches from patient narratives: each narrative is tokenized, | ||
* tokens without a word vector are dropped, and the remaining tokens are | ||
* turned into feature, label and mask arrays (see getNext). | ||
* | ||
* @author Markus | ||
* | ||
*/ | ||
public class N2c2PatientIteratorBML extends BaseNNIterator { | ||
|
||
private static final long serialVersionUID = 1L; | ||
|
||
// Word vector model used to filter tokens and to look up per-token vectors. | ||
private final WordVectors wordVectors; | ||
|
||
// Upper bound applied to the sequence (time step) dimension of a batch. | ||
private final int truncateLength; | ||
|
||
// Tokenizer factory; set up with a CommonPreprocessor in the constructor. | ||
private final TokenizerFactory tokenizerFactory; | ||
|
||
|
||
/** | ||
* Patient data iterator for the n2c2 task. | ||
* | ||
* Stores the given arguments and creates a DefaultTokenizerFactory with a | ||
* CommonPreprocessor as token pre-processor. | ||
* | ||
* @param patients | ||
* Patient data. | ||
* @param wordVectors | ||
* Word vectors object. | ||
* @param batchSize | ||
* Mini batch size used for processing. | ||
* @param truncateLength | ||
* Maximum length of token sequence. | ||
*/ | ||
public N2c2PatientIteratorBML(List<Patient> patients, WordVectors wordVectors, int batchSize, int truncateLength) { | ||
|
||
// patients, batchSize and vectorSize are not declared in this class | ||
// (presumably inherited from BaseNNIterator - verify) | ||
this.patients = patients; | ||
this.batchSize = batchSize; | ||
// vector dimensionality is read from the vector of the vocabulary word at index 0 | ||
this.vectorSize = wordVectors.getWordVector(wordVectors.vocab().wordAtIndex(0)).length; | ||
|
||
this.wordVectors = wordVectors; | ||
this.truncateLength = truncateLength; | ||
|
||
tokenizerFactory = new DefaultTokenizerFactory(); | ||
tokenizerFactory.setTokenPreProcessor(new CommonPreprocessor()); | ||
} | ||
|
||
/** | ||
* Next data set implementation. | ||
* | ||
* Reads up to num patients starting at the current cursor, advancing the | ||
* cursor by one for every patient read, and converts their narratives | ||
* into feature, label and mask arrays. | ||
* | ||
* @param num | ||
* Mini batch size (maximum number of patients read by this call). | ||
* @return DataSet Patients data set, built from the features, labels, | ||
* features mask and labels mask arrays. | ||
*/ | ||
@Override | ||
public DataSet getNext(int num) { | ||
|
||
// maps the position of a patient within this batch to its label vector | ||
HashMap<Integer, ArrayList<Boolean>> binaryMultiHotVectorMap = new HashMap<Integer, ArrayList<Boolean>>(); | ||
|
||
// load narrative from patient | ||
List<String> narratives = new ArrayList<>(num); | ||
for (int i = 0; i < num && cursor < totalExamples(); i++) { | ||
String narrative = patients.get(cursor).getText(); | ||
narratives.add(narrative); | ||
|
||
// fillBinaryMultiHotVector is defined outside this class; presumably it | ||
// fills the list with the labels of the patient at cursor - verify | ||
ArrayList<Boolean> binaryMultiHotVector = new ArrayList<Boolean>(); | ||
fillBinaryMultiHotVector(binaryMultiHotVector); | ||
|
||
binaryMultiHotVectorMap.put(i, binaryMultiHotVector); | ||
cursor++; | ||
} | ||
|
||
// tokenize each narrative and keep only tokens that have a word vector; | ||
// maxLength tracks the longest filtered token list in this batch | ||
List<List<String>> allTokens = new ArrayList<>(narratives.size()); | ||
int maxLength = 0; | ||
for (String narrative : narratives) { | ||
List<String> tokens = tokenizerFactory.create(narrative).getTokens(); | ||
List<String> tokensFiltered = new ArrayList<>(); | ||
for (String token : tokens) { | ||
if (wordVectors.hasWord(token)) | ||
tokensFiltered.add(token); | ||
} | ||
allTokens.add(tokensFiltered); | ||
maxLength = Math.max(maxLength, tokensFiltered.size()); | ||
} | ||
|
||
// truncate if sequence is longer than truncateLength | ||
if (maxLength > truncateLength) | ||
maxLength = truncateLength; | ||
|
||
// features: [batch, vectorSize, maxLength]; labels: [batch, 13, maxLength] | ||
// NOTE(review): 13 is presumably the number of labels per patient - confirm | ||
INDArray features = Nd4j.create(narratives.size(), vectorSize, maxLength); | ||
INDArray labels = Nd4j.create(narratives.size(), 13, maxLength); | ||
|
||
// masks: [batch, maxLength], initialised to zero | ||
INDArray featuresMask = Nd4j.zeros(narratives.size(), maxLength); | ||
INDArray labelsMask = Nd4j.zeros(narratives.size(), maxLength); | ||
|
||
// reusable [example, time step] index for writing into the features mask | ||
int[] temp = new int[2]; | ||
for (int i = 0; i < narratives.size(); i++) { | ||
List<String> tokens = allTokens.get(i); | ||
temp[0] = i; | ||
|
||
// get word vectors for each token in narrative | ||
for (int j = 0; j < tokens.size() && j < maxLength; j++) { | ||
String token = tokens.get(j); | ||
INDArray vector = wordVectors.getWordVectorMatrix(token); | ||
features.put(new INDArrayIndex[] { NDArrayIndex.point(i), NDArrayIndex.all(), NDArrayIndex.point(j) }, | ||
vector); | ||
|
||
// mark this time step as holding a token for example i | ||
temp[1] = j; | ||
featuresMask.putScalar(temp, 1.0); | ||
} | ||
|
||
// number of time steps actually written for this example | ||
// NOTE(review): if no token of a narrative has a word vector, tokens is empty, | ||
// lastIdx is 0 and the index lastIdx - 1 used below is -1 - confirm this cannot happen | ||
int lastIdx = Math.min(tokens.size(), maxLength); | ||
|
||
// set binary multi-labels | ||
ArrayList<Boolean> binaryMultiHotVector = binaryMultiHotVectorMap.get(i); | ||
int labelIndex = 0; | ||
for (Boolean label : binaryMultiHotVector) { | ||
labels.putScalar(new int[] { i, labelIndex, lastIdx - 1 }, label == true ? 1.0 : 0.0); | ||
labelIndex++; | ||
} | ||
// output exists only at the final step of the sequence | ||
labelsMask.putScalar(new int[] { i, lastIdx - 1 }, 1.0); | ||
} | ||
return new DataSet(features, labels, featuresMask, labelsMask); | ||
} | ||
} | ||
package at.medunigraz.imi.bst.n2c2.nn.iterator; | ||
|
||
import java.util.ArrayList; | ||
import java.util.HashMap; | ||
import java.util.List; | ||
|
||
import org.deeplearning4j.models.embeddings.wordvectors.WordVectors; | ||
import org.deeplearning4j.text.tokenization.tokenizer.preprocessor.CommonPreprocessor; | ||
import org.deeplearning4j.text.tokenization.tokenizerfactory.DefaultTokenizerFactory; | ||
import org.deeplearning4j.text.tokenization.tokenizerfactory.TokenizerFactory; | ||
import org.nd4j.linalg.api.ndarray.INDArray; | ||
import org.nd4j.linalg.dataset.DataSet; | ||
import org.nd4j.linalg.factory.Nd4j; | ||
import org.nd4j.linalg.indexing.INDArrayIndex; | ||
import org.nd4j.linalg.indexing.NDArrayIndex; | ||
|
||
import at.medunigraz.imi.bst.n2c2.model.Patient; | ||
|
||
/** | ||
* Data iterator refactored from dl4j examples. | ||
* | ||
* Builds mini batches from patient narratives: each narrative is tokenized, | ||
* tokens without a word vector are dropped, and the remaining tokens are | ||
* turned into feature, label and mask arrays (see getNext). | ||
* | ||
* @author Markus | ||
* | ||
*/ | ||
public class N2c2PatientIteratorBML extends BaseNNIterator { | ||
|
||
private static final long serialVersionUID = 1L; | ||
|
||
// Word vector model used to filter tokens and to look up per-token vectors. | ||
private final WordVectors wordVectors; | ||
|
||
// Upper bound applied to the sequence (time step) dimension of a batch. | ||
private final int truncateLength; | ||
|
||
// Tokenizer factory; set up with a CommonPreprocessor in the constructor. | ||
private final TokenizerFactory tokenizerFactory; | ||
|
||
|
||
/** | ||
* Patient data iterator for the n2c2 task. | ||
* | ||
* Stores the given arguments and creates a DefaultTokenizerFactory with a | ||
* CommonPreprocessor as token pre-processor. | ||
* | ||
* @param patients | ||
* Patient data. | ||
* @param wordVectors | ||
* Word vectors object. | ||
* @param batchSize | ||
* Mini batch size used for processing. | ||
* @param truncateLength | ||
* Maximum length of token sequence. | ||
*/ | ||
public N2c2PatientIteratorBML(List<Patient> patients, WordVectors wordVectors, int batchSize, int truncateLength) { | ||
|
||
// patients, batchSize and vectorSize are not declared in this class | ||
// (presumably inherited from BaseNNIterator - verify) | ||
this.patients = patients; | ||
this.batchSize = batchSize; | ||
// vector dimensionality is read from the vector of the vocabulary word at index 0 | ||
this.vectorSize = wordVectors.getWordVector(wordVectors.vocab().wordAtIndex(0)).length; | ||
|
||
this.wordVectors = wordVectors; | ||
this.truncateLength = truncateLength; | ||
|
||
tokenizerFactory = new DefaultTokenizerFactory(); | ||
tokenizerFactory.setTokenPreProcessor(new CommonPreprocessor()); | ||
} | ||
|
||
/** | ||
* Next data set implementation. | ||
* | ||
* Reads up to num patients starting at the current cursor, advancing the | ||
* cursor by one for every patient read, and converts their narratives | ||
* into feature, label and mask arrays. | ||
* | ||
* @param num | ||
* Mini batch size (maximum number of patients read by this call). | ||
* @return DataSet Patients data set, built from the features, labels, | ||
* features mask and labels mask arrays. | ||
*/ | ||
@Override | ||
public DataSet getNext(int num) { | ||
|
||
// maps the position of a patient within this batch to its label vector | ||
HashMap<Integer, ArrayList<Boolean>> binaryMultiHotVectorMap = new HashMap<Integer, ArrayList<Boolean>>(); | ||
|
||
// load narrative from patient | ||
List<String> narratives = new ArrayList<>(num); | ||
for (int i = 0; i < num && cursor < totalExamples(); i++) { | ||
String narrative = patients.get(cursor).getText(); | ||
narratives.add(narrative); | ||
|
||
// fillBinaryMultiHotVector is defined outside this class; presumably it | ||
// fills the list with the labels of the patient at cursor - verify | ||
ArrayList<Boolean> binaryMultiHotVector = new ArrayList<Boolean>(); | ||
fillBinaryMultiHotVector(binaryMultiHotVector); | ||
|
||
binaryMultiHotVectorMap.put(i, binaryMultiHotVector); | ||
cursor++; | ||
} | ||
|
||
// tokenize each narrative and keep only tokens that have a word vector; | ||
// maxLength tracks the longest filtered token list in this batch | ||
List<List<String>> allTokens = new ArrayList<>(narratives.size()); | ||
int maxLength = 0; | ||
for (String narrative : narratives) { | ||
List<String> tokens = tokenizerFactory.create(narrative).getTokens(); | ||
List<String> tokensFiltered = new ArrayList<>(); | ||
for (String token : tokens) { | ||
if (wordVectors.hasWord(token)) | ||
tokensFiltered.add(token); | ||
} | ||
allTokens.add(tokensFiltered); | ||
maxLength = Math.max(maxLength, tokensFiltered.size()); | ||
} | ||
|
||
// truncate if sequence is longer than truncateLength | ||
if (maxLength > truncateLength) | ||
maxLength = truncateLength; | ||
|
||
// features: [batch, vectorSize, maxLength]; labels: [batch, 13, maxLength] | ||
// NOTE(review): 13 is presumably the number of labels per patient - confirm | ||
INDArray features = Nd4j.create(narratives.size(), vectorSize, maxLength); | ||
INDArray labels = Nd4j.create(narratives.size(), 13, maxLength); | ||
|
||
// masks: [batch, maxLength], initialised to zero | ||
INDArray featuresMask = Nd4j.zeros(narratives.size(), maxLength); | ||
INDArray labelsMask = Nd4j.zeros(narratives.size(), maxLength); | ||
|
||
// reusable [example, time step] index for writing into the features mask | ||
int[] temp = new int[2]; | ||
for (int i = 0; i < narratives.size(); i++) { | ||
List<String> tokens = allTokens.get(i); | ||
temp[0] = i; | ||
|
||
// get word vectors for each token in narrative | ||
for (int j = 0; j < tokens.size() && j < maxLength; j++) { | ||
String token = tokens.get(j); | ||
INDArray vector = wordVectors.getWordVectorMatrix(token); | ||
features.put(new INDArrayIndex[] { NDArrayIndex.point(i), NDArrayIndex.all(), NDArrayIndex.point(j) }, | ||
vector); | ||
|
||
// mark this time step as holding a token for example i | ||
temp[1] = j; | ||
featuresMask.putScalar(temp, 1.0); | ||
} | ||
|
||
// number of time steps actually written for this example | ||
// NOTE(review): if no token of a narrative has a word vector, tokens is empty, | ||
// lastIdx is 0 and the index lastIdx - 1 used below is -1 - confirm this cannot happen | ||
int lastIdx = Math.min(tokens.size(), maxLength); | ||
|
||
// set binary multi-labels | ||
ArrayList<Boolean> binaryMultiHotVector = binaryMultiHotVectorMap.get(i); | ||
int labelIndex = 0; | ||
for (Boolean label : binaryMultiHotVector) { | ||
labels.putScalar(new int[] { i, labelIndex, lastIdx - 1 }, label == true ? 1.0 : 0.0); | ||
labelIndex++; | ||
} | ||
// output exists only at the final step of the sequence | ||
labelsMask.putScalar(new int[] { i, lastIdx - 1 }, 1.0); | ||
} | ||
return new DataSet(features, labels, featuresMask, labelsMask); | ||
} | ||
} |
3 changes: 2 additions & 1 deletion
3
...nigraz/imi/bst/n2c2/nn/NGramIterator.java → ...i/bst/n2c2/nn/iterator/NGramIterator.java
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters