diff --git a/ner/config/ner.properties b/ner/config/ner.properties index fd3c72d5c..b0f7114da 100644 --- a/ner/config/ner.properties +++ b/ner/config/ner.properties @@ -7,6 +7,7 @@ modelName = CoNLL # A path to the model files. During training this will be the location where the models are stored. # During testing this parameter can point to either a classpath or a local directory. pathToModelFile = ner/models +pathToLM = input_path_here GazetteersFeatures = 0 BrownClusterPaths = 0 diff --git a/ner/src/main/java/edu/illinois/cs/cogcomp/ner/ExpressiveFeatures/CharacterLanguageModel.java b/ner/src/main/java/edu/illinois/cs/cogcomp/ner/ExpressiveFeatures/CharacterLanguageModel.java index a7b55aab4..9f0d30118 100644 --- a/ner/src/main/java/edu/illinois/cs/cogcomp/ner/ExpressiveFeatures/CharacterLanguageModel.java +++ b/ner/src/main/java/edu/illinois/cs/cogcomp/ner/ExpressiveFeatures/CharacterLanguageModel.java @@ -1,11 +1,17 @@ package edu.illinois.cs.cogcomp.ner.ExpressiveFeatures; +import edu.illinois.cs.cogcomp.ner.IO.InFile; import edu.illinois.cs.cogcomp.core.datastructures.Pair; import edu.illinois.cs.cogcomp.core.io.LineIO; import edu.illinois.cs.cogcomp.core.utilities.StringUtils; import edu.illinois.cs.cogcomp.core.utilities.configuration.ResourceManager; import edu.illinois.cs.cogcomp.lbjava.parse.LinkedVector; import edu.illinois.cs.cogcomp.ner.LbjTagger.*; +import gnu.trove.map.hash.THashMap; + +import java.io.InputStream; +import java.io.FileInputStream; + import javax.annotation.Resource; import java.io.File; @@ -18,6 +24,7 @@ public class CharacterLanguageModel { private HashMap> counts; private int order; private String pad = "_"; + private static THashMap charlms = new THashMap<>(); public CharacterLanguageModel(){ // parameterized how? order of ngrams? 
@@ -30,6 +37,13 @@ public CharacterLanguageModel(){ order = 4; } + public static void addLM(String key, CharacterLanguageModel clm) { + charlms.put(key, clm); + } + + public static CharacterLanguageModel getLM(String key) { + return charlms.get(key); + } /** * Actually returns the log perplexity. @@ -261,23 +275,106 @@ public static void test() throws FileNotFoundException { } + + public static void test(CharacterLanguageModel eclm, CharacterLanguageModel neclm, Data testData) throws IOException { + + double correct = 0; + double total = 0; + List outpreds = new ArrayList<>(); + for(NERDocument doc : testData.documents){ + for(LinkedVector sentence : doc.sentences){ + for(int i = 0; i < sentence.size(); i++) { + NEWord word = (NEWord) sentence.get(i); + String label = word.neLabel.equals("O")? "O" : "B-ENT"; + double eppl = eclm.perplexity(string2list(word.form)); + double neppl = neclm.perplexity(string2list(word.form)); + + String pred; + + if(word.form.length() < 3){ + pred = "O"; + }else if(eppl < neppl){ + pred = "B-ENT"; + }else{ + pred = "O"; + } + + if (pred.equals(label)){ + //System.out.println(word.form + ": correct"); + correct += 1; + }else{ + System.out.println(word.form + ": WRONG***"); + } + total +=1; + + outpreds.add(word.form + " " + label + " " + pred); + } + outpreds.add(""); + } + } + + System.out.println("Accuracy: " + correct / total); + + LineIO.write("pred.txt", outpreds); + System.out.println("Wrote to pred.txt. 
Now run $ conlleval pred.txt to get F1 scores."); + + + } + + + + public static List> readList(String path) { + + List> seqs = new ArrayList<>(); + try { + List lines = LineIO.read(path); + for(String line : lines){ + String[] chars = line.trim().split(" "); + ArrayList seq = new ArrayList(Arrays.asList(chars)); + seqs.add(seq); + } + } catch (FileNotFoundException e) { + e.printStackTrace(); + } + return seqs; + } + + public static void main(String[] args) throws Exception { // this trains models, and provides perplexities. - test2(); +// test2(); + + ParametersForLbjCode params = Parameters.readConfigAndLoadExternalData("config/ner.properties", false); + + String trainpath= "/shared/corpora/ner/conll2003/eng-files/Train-json/"; + String testpath = "/shared/corpora/ner/conll2003/eng-files/Test-json/"; + +// String trainpath= "/shared/corpora/ner/lorelei-swm-new/ben/Train/"; +// String testpath = "/shared/corpora/ner/lorelei-swm-new/ben/Test/"; - //ParametersForLbjCode params = Parameters.readConfigAndLoadExternalData("config/ner.properties", false); + System.out.println("Reading List"); + String wiki_ent_file = "/shared/corpora/ner/clm/wikiEntity_train.out"; + String wiki_nonent_file = "/shared/corpora/ner/clm/wikiNotEntity_train.out"; -// String trainpath= "/shared/corpora/ner/conll2003/eng-files/Train-json/"; -// String testpath = "/shared/corpora/ner/conll2003/eng-files/Test-json/"; +// List> wiki_ent = CharacterLanguageModel.readList(wiki_ent_file); +// List> wiki_non_ent = CharacterLanguageModel.readList(wiki_nonent_file); + + System.out.println("train entity clm"); + CharacterLanguageModel eclm = new CharacterLanguageModel(); + eclm.train(CharacterLanguageModel.readList(wiki_ent_file)); + + System.out.println("train non entity clm"); + CharacterLanguageModel neclm = new CharacterLanguageModel(); + neclm.train(CharacterLanguageModel.readList(wiki_nonent_file)); - //String trainpath= 
"/shared/corpora/ner/lorelei-swm-new/ara/Train/"; - //String testpath = "/shared/corpora/ner/lorelei-swm-new/ara/Test/"; + System.out.println("Testing"); +// Data trainData = new Data(trainpath, trainpath, "-json", new String[] {}, new String[] {}, params); + Data testData = new Data(testpath, testpath, "-json", new String[] {}, new String[] {}, params); + CharacterLanguageModel.test(eclm, neclm, testData); - //Data trainData = new Data(trainpath, trainpath, "-json", new String[] {}, new String[] {}, params); - //Data testData = new Data(testpath, testpath, "-json", new String[] {}, new String[] {}, params); +// trainEntityNotEntity(trainData, testData); - //trainEntityNotEntity(trainData, testData); } diff --git a/ner/src/main/java/edu/illinois/cs/cogcomp/ner/LbjTagger/LearningCurveMultiDataset.java b/ner/src/main/java/edu/illinois/cs/cogcomp/ner/LbjTagger/LearningCurveMultiDataset.java index b8611d539..3c9b826a1 100644 --- a/ner/src/main/java/edu/illinois/cs/cogcomp/ner/LbjTagger/LearningCurveMultiDataset.java +++ b/ner/src/main/java/edu/illinois/cs/cogcomp/ner/LbjTagger/LearningCurveMultiDataset.java @@ -73,6 +73,7 @@ public static void buildFinalModel(int fixedNumIterations, String trainDataPath, * @param incremental if the model is being incremented, this is true. 
* @throws Exception */ + public static void getLearningCurve(int fixedNumIterations, String trainDataPath, String testDataPath, boolean incremental, ParametersForLbjCode params) throws Exception { getLearningCurve(fixedNumIterations, "-c", trainDataPath, testDataPath, incremental, params); diff --git a/ner/src/main/java/edu/illinois/cs/cogcomp/ner/LbjTagger/Parameters.java b/ner/src/main/java/edu/illinois/cs/cogcomp/ner/LbjTagger/Parameters.java index 491968c6b..7aaa4ff60 100644 --- a/ner/src/main/java/edu/illinois/cs/cogcomp/ner/LbjTagger/Parameters.java +++ b/ner/src/main/java/edu/illinois/cs/cogcomp/ner/LbjTagger/Parameters.java @@ -9,6 +9,7 @@ import edu.illinois.cs.cogcomp.core.constants.Language; import edu.illinois.cs.cogcomp.core.datastructures.ViewNames; +import edu.illinois.cs.cogcomp.ner.ExpressiveFeatures.CharacterLanguageModel; import edu.illinois.cs.cogcomp.ner.ExpressiveFeatures.BrownClusters; import edu.illinois.cs.cogcomp.ner.ExpressiveFeatures.GazetteersFactory; import edu.illinois.cs.cogcomp.ner.ExpressiveFeatures.TitleTextNormalizer; @@ -315,6 +316,35 @@ public static ParametersForLbjCode readAndLoadConfig(ResourceManager rm, boolean } + if (rm.containsKey("pathToLM")) { + String wiki_ent_file = "/shared/corpora/ner/clm/wikiEntity_train.out"; + String wiki_nonent_file = "/shared/corpora/ner/clm/wikiNotEntity_train.out"; + String arabic_file = "/shared/corpora/ner/clm/arabic_names.out"; + String russian_file = "/shared/corpora/ner/clm/russian_train.out"; + + System.out.println("train entity clm"); + CharacterLanguageModel eclm = new CharacterLanguageModel(); + eclm.train(CharacterLanguageModel.readList(wiki_ent_file)); + + System.out.println("train non entity clm"); + CharacterLanguageModel neclm = new CharacterLanguageModel(); + neclm.train(CharacterLanguageModel.readList(wiki_nonent_file)); + + System.out.println("train arabic clm"); + CharacterLanguageModel arabic_clm = new CharacterLanguageModel(); + 
arabic_clm.train(CharacterLanguageModel.readList(arabic_file)); + + System.out.println("train russian clm"); + CharacterLanguageModel russian_clm = new CharacterLanguageModel(); + russian_clm.train(CharacterLanguageModel.readList(russian_file)); + + CharacterLanguageModel.addLM("entity", eclm); + CharacterLanguageModel.addLM("nonentity", neclm); + CharacterLanguageModel.addLM("arabic", arabic_clm); + CharacterLanguageModel.addLM("russian", russian_clm); + + } + // If enabled, load up the brown clusters String brownDebug = ""; if (rm.containsKey("BrownClusterPaths") diff --git a/ner/src/main/lbj/LbjTagger.lbj b/ner/src/main/lbj/LbjTagger.lbj index 1077fc480..8bb99ea45 100644 --- a/ner/src/main/lbj/LbjTagger.lbj +++ b/ner/src/main/lbj/LbjTagger.lbj @@ -3,6 +3,7 @@ package edu.illinois.cs.cogcomp.ner.LbjFeatures; import java.util.*; import edu.illinois.cs.cogcomp.ner.LbjTagger.NEWord; +import edu.illinois.cs.cogcomp.ner.ExpressiveFeatures.CharacterLanguageModel; import edu.illinois.cs.cogcomp.ner.ExpressiveFeatures.WordTopicAndLayoutFeatures; import edu.illinois.cs.cogcomp.ner.ExpressiveFeatures.BrownClusters; import edu.illinois.cs.cogcomp.ner.ExpressiveFeatures.Gazetteers; @@ -363,7 +364,52 @@ discrete% PreviousTagPatternLevel1(NEWord word) <- } } -mixed% FeaturesSharedTemp(NEWord word) <- IsSentenceStart, Capitalization, nonLocalFeatures, GazetteersFeatures, FormParts, Forms, WordTypeInformation, Affixes, BrownClusterPaths, WordEmbeddingFeatures, WikifierFeatures, AffixesZH +discrete{false, true}% CharLangModelPrediction_context(NEWord word) <- +{ + int i; + CharacterLanguageModel eclm = CharacterLanguageModel.getLM("entity"); + CharacterLanguageModel neclm = CharacterLanguageModel.getLM("nonentity"); + + NEWord w = word, last = word; + for (i = 0; i <= 2 && last != null; ++i) last = (NEWord) last.next; + for (i = 0; i > -2 && w.previous != null; --i) w = (NEWord) w.previous; + for (; w != last; w = (NEWord) w.next){ + Double isEntity = 
eclm.perplexity(CharacterLanguageModel.string2list(w.form)); + Double isNotEntity = neclm.perplexity(CharacterLanguageModel.string2list(w.form)); + + if( Double.compare(isEntity, isNotEntity) < 0 ) + sense i+"NoThreshold" : true; + else + sense i+"NoThreshold" : false; + i++; + } +} + +discrete{false, true}% CharLangModelArabic(NEWord word) <- +{ + CharacterLanguageModel arabic_clm = CharacterLanguageModel.getLM("arabic"); + CharacterLanguageModel neclm = CharacterLanguageModel.getLM("nonentity"); + Double isArabic = arabic_clm.perplexity(CharacterLanguageModel.string2list(word.form)); + Double isNotEntity = neclm.perplexity(CharacterLanguageModel.string2list(word.form)); + if( Double.compare(isArabic, isNotEntity) < 0 ) + sense "Arabic" : true; + else + sense "Arabic" : false; +} + +discrete{false, true}% CharLangModelRussian(NEWord word) <- +{ + CharacterLanguageModel russian_clm = CharacterLanguageModel.getLM("russian"); + CharacterLanguageModel neclm = CharacterLanguageModel.getLM("nonentity"); + Double isRussian = russian_clm.perplexity(CharacterLanguageModel.string2list(word.form)); + Double isNotEntity = neclm.perplexity(CharacterLanguageModel.string2list(word.form)); + if( Double.compare(isRussian, isNotEntity) < 0 ) + sense "Russian" : true; + else + sense "Russian" : false; +} + +mixed% FeaturesSharedTemp(NEWord word) <- IsSentenceStart, Capitalization, nonLocalFeatures, GazetteersFeatures, FormParts, Forms, WordTypeInformation, Affixes, BrownClusterPaths, WordEmbeddingFeatures, WikifierFeatures, AffixesZH, CharLangModelPrediction_context, CharLangModelArabic, CharLangModelRussian mixed% FeaturesLevel1SharedWithLevel2(NEWord word) <- FeaturesSharedTemp /*, IsWordCaseNormalized&&FeaturesSharedTemp*/