Commit 5f2b4f2

Making AgePredict work

1 parent 290e6fb commit 5f2b4f2

4 files changed: +117 -108 lines changed

README.md (+7)

````diff
@@ -69,6 +69,13 @@ Note: Each document must be followed by an empty line to be detected as a separate
 Usage: bin/authorage AgeClassify model < documents
 ```
 
+```shell
+Usage: bin/authorage AgePredict ./model/classify-unigram.bin ./model/regression-global.bin data/sample_test.txt
+```
+
+# Downloads
+For AgePredict to work, you need to download `en-pos-maxent.bin`, `en-sent.bin` and `en-token.bin` from [http://opennlp.sourceforge.net/models-1.5/](http://opennlp.sourceforge.net/models-1.5/) to `model/opennlp/`.
+
 # Contributors
 * Joey Hong, Caltech, CA
````
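The Downloads step above can be scripted. A minimal sketch, assuming `wget` is available and the command is run from the repository root; the file names and base URL are exactly the ones the README lists:

```shell
# Fetch the three OpenNLP 1.5 models AgePredict needs into model/opennlp/.
# Assumes wget is installed; curl -O in the target directory would work too.
mkdir -p model/opennlp
for f in en-pos-maxent.bin en-sent.bin en-token.bin; do
    wget -P model/opennlp "http://opennlp.sourceforge.net/models-1.5/$f"
done
```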

bin/authorage (+1 -1)

```diff
@@ -17,7 +17,7 @@
 # specific language governing permissions and limitations
 # under the License.
 
-export SPARK_HOME="spark-2.0.0-bin-hadoop2.7"
+# export SPARK_HOME="spark-2.0.0-bin-hadoop2.7"
 
 # Created JAR Application
 export JAR="target/age-predictor-1.0-SNAPSHOT-jar-with-dependencies.jar"
```
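With the hard-coded `export SPARK_HOME` commented out, the launcher presumably relies on `SPARK_HOME` coming from the caller's environment. A hedged sketch of a typical invocation; the install path below is illustrative, not something this commit ships:

```shell
# Point SPARK_HOME at your own Spark 2.x install before running the tool.
# /opt/spark-2.0.0-bin-hadoop2.7 is an example path, adjust to your machine.
export SPARK_HOME=/opt/spark-2.0.0-bin-hadoop2.7
bin/authorage AgePredict ./model/classify-unigram.bin ./model/regression-global.bin data/sample_test.txt
```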

data/sample_test.txt (+3)

```diff
@@ -0,0 +1,3 @@
+Can AI really predict my age through what I wrote?
+
+That will be so cool of AI
```

src/main/java/gov/nasa/jpl/ml/cmdline/spark/authorage/AgePredictTool.java (+106 -107)

```diff
@@ -19,6 +19,7 @@
 
 import java.io.File;
 import java.io.IOException;
+import java.io.Serializable;
 import java.util.ArrayList;
 import java.util.Collection;
 import java.util.List;
@@ -27,7 +28,8 @@
 import org.apache.spark.api.java.function.VoidFunction;
 import org.apache.spark.ml.feature.CountVectorizerModel;
 import org.apache.spark.ml.feature.Normalizer;
-import org.apache.spark.mllib.linalg.Vector;
+import org.apache.spark.ml.linalg.SparseVector;
+import org.apache.spark.mllib.linalg.Vectors;
 import org.apache.spark.mllib.regression.LassoModel;
 import org.apache.spark.sql.Dataset;
 import org.apache.spark.sql.Row;
@@ -54,115 +56,112 @@
 /**
  * TODO: Documentation
  */
-public class AgePredictTool extends BasicCmdLineTool {
-
-    @Override
-    public String getShortDescription() {
-        return "age predictor";
-    }
-
-    @Override
-    public String getHelp() {
-        return "Usage: " + CLI.CMD + " " + getName() + " [MaxEntModel] RegressionModel Documents";
-    }
-
-    @Override
-    public void run(String[] args) {
-        AgePredictModel model = null;
-        AgeClassifyME classify = null;
-        if (args.length == 3) {
-            try {
-                AgeClassifyModel classifyModel = new AgeClassifyModel(new File(args[0]));
-
-                classify = new AgeClassifyME(classifyModel);
-                model = AgePredictModel.readModel(new File(args[1]));
-            } catch (Exception e) {
-                e.printStackTrace();
-                return;
-            }
-        }
-        else if (args.length == 2) {
-            try {
-                model = AgePredictModel.readModel(new File(args[0]));
-            } catch (Exception e) {
-                e.printStackTrace();
-                return;
-            }
+public class AgePredictTool extends BasicCmdLineTool implements Serializable {
+
+    @Override
+    public String getShortDescription() {
+        return "age predictor";
     }
-    else {
-        System.out.println(getHelp());
-        return;
+
+    @Override
+    public String getHelp() {
+        return "Usage: " + CLI.CMD + " " + getName() + " [MaxEntModel] RegressionModel Documents";
     }
-
-        ObjectStream<String> documentStream;
-        List<Row> data = new ArrayList<Row>();
-
-        SparkSession spark = SparkSession
-            .builder()
-            .appName("AgePredict")
-            .getOrCreate();
-
-        try {
-            documentStream = new ParagraphStream(
-                new PlainTextByLineStream(new SystemInputStreamFactory(), SystemInputStreamFactory.encoding()));
-
-            String document;
-            FeatureGenerator[] featureGenerators = model.getContext().getFeatureGenerators();
-            while ((document = documentStream.read()) != null) {
-                String[] tokens = model.getContext().getTokenizer().tokenize(document);
-
-                double prob[] = classify.getProbabilities(tokens);
-                String category = classify.getBestCategory(prob);
-
-                Collection<String> context = new ArrayList<String>();
-
-                for (FeatureGenerator featureGenerator : featureGenerators) {
-                    Collection<String> extractedFeatures =
-                        featureGenerator.extractFeatures(tokens);
-                    context.addAll(extractedFeatures);
+
+    @Override
+    public void run(String[] args) {
+        AgePredictModel model = null;
+        AgeClassifyME classify = null;
+        if (args.length == 3) {
+            try {
+                AgeClassifyModel classifyModel = new AgeClassifyModel(new File(args[0]));
+
+                classify = new AgeClassifyME(classifyModel);
+                model = AgePredictModel.readModel(new File(args[1]));
+            } catch (Exception e) {
+                e.printStackTrace();
+                return;
+            }
+        } else if (args.length == 2) {
+            try {
+                model = AgePredictModel.readModel(new File(args[0]));
+            } catch (Exception e) {
+                e.printStackTrace();
+                return;
+            }
+        } else {
+            System.out.println(getHelp());
+            return;
        }
-
-                if (category != null) {
-                    for (int i = 0; i < tokens.length / 18; i++) {
-                        context.add("cat="+ category);
-                    }
+
+        ObjectStream<String> documentStream;
+        List<Row> data = new ArrayList<Row>();
+
+        SparkSession spark = SparkSession.builder().appName("AgePredict").getOrCreate();
+
+        try {
+            System.out.println("Please enter your text separated by newline. When done, press Ctrl+D to terminate system input.");
+            documentStream = new ParagraphStream(
+                    new PlainTextByLineStream(new SystemInputStreamFactory(), SystemInputStreamFactory.encoding()));
+
+            String document;
+            FeatureGenerator[] featureGenerators = model.getContext().getFeatureGenerators();
+            while ((document = documentStream.read()) != null) {
+                String[] tokens = model.getContext().getTokenizer().tokenize(document);
+
+                double prob[] = classify.getProbabilities(tokens);
+                String category = classify.getBestCategory(prob);
+
+                Collection<String> context = new ArrayList<String>();
+
+                for (FeatureGenerator featureGenerator : featureGenerators) {
+                    Collection<String> extractedFeatures = featureGenerator.extractFeatures(tokens);
+                    context.addAll(extractedFeatures);
+                }
+
+                if (category != null) {
+                    for (int i = 0; i < tokens.length / 18; i++) {
+                        context.add("cat=" + category);
+                    }
+                }
+                if (context.size() > 0) {
+                    data.add(RowFactory.create(document, context.toArray()));
+                }
+
+            }
+        } catch (IOException e) {
+            e.printStackTrace();
+            CmdLineUtil.handleStdinIoError(e);
        }
-                if (context.size() > 0) {
-                    data.add(RowFactory.create(document, context.toArray()));
-                }
-            }
-        } catch (IOException e) {
-            CmdLineUtil.handleStdinIoError(e);
+
+        StructType schema = new StructType(
+                new StructField[] { new StructField("document", DataTypes.StringType, false, Metadata.empty()),
+                        new StructField("text", new ArrayType(DataTypes.StringType, true), false, Metadata.empty()) });
+
+        Dataset<Row> df = spark.createDataFrame(data, schema);
+
+        CountVectorizerModel cvm = new CountVectorizerModel(model.getVocabulary()).setInputCol("text")
+                .setOutputCol("feature");
+
+        Dataset<Row> eventDF = cvm.transform(df);
+
+        Normalizer normalizer = new Normalizer().setInputCol("feature").setOutputCol("normFeature").setP(1.0);
+
+        JavaRDD<Row> normEventDF = normalizer.transform(eventDF).javaRDD();
+
+        // org.apache.spark.ml.linalg.SparseVector cannot be cast to org.apache.spark.mllib.linalg.Vector; convert below.
+
+        final LassoModel linModel = model.getModel();
+        normEventDF.foreach(new VoidFunction<Row>() {
+            public void call(Row event) {
+                SparseVector sp = (SparseVector) event.getAs("normFeature");
+
+                double prediction = linModel.predict(Vectors.sparse(sp.size(), sp.indices(), sp.values()));
+                System.out.println((String) event.getAs("document"));
+                System.out.println("Prediction: " + prediction);
+            }
+        });
+
+        spark.stop();
        }
-        StructType schema = new StructType(new StructField [] {
-            new StructField("document", DataTypes.StringType, false, Metadata.empty()),
-            new StructField("text", new ArrayType(DataTypes.StringType, true), false, Metadata.empty())
-        });
-
-        Dataset<Row> df = spark.createDataFrame(data, schema);
-
-        CountVectorizerModel cvm = new CountVectorizerModel(model.getVocabulary())
-            .setInputCol("text")
-            .setOutputCol("feature");
-
-        Dataset<Row> eventDF = cvm.transform(df);
-
-        Normalizer normalizer = new Normalizer()
-            .setInputCol("feature")
-            .setOutputCol("normFeature")
-            .setP(1.0);
-
-        JavaRDD<Row> normEventDF = normalizer.transform(eventDF).javaRDD();
-
-        final LassoModel linModel = model.getModel();
-        normEventDF.foreach( new VoidFunction<Row>() {
-            public void call(Row event) {
-                double prediction = linModel.predict((Vector) event.getAs("normFeature"));
-                System.out.println((String) event.getAs("document"));
-                System.out.println("Prediction: "+ prediction);
-            }
-        });
-
-        spark.stop();
-    }
 }
```
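The substantive fix in this file is the vector hand-off. CountVectorizerModel and Normalizer belong to the DataFrame-based `spark.ml` API and emit `org.apache.spark.ml.linalg` vectors, while LassoModel is an RDD-based `spark.mllib` model whose `predict` expects `org.apache.spark.mllib.linalg.Vector`, so the old `(Vector)` cast failed at runtime (the ClassCastException quoted in the code comment). The `implements Serializable` change is related: the anonymous `VoidFunction` passed to `foreach` captures the enclosing tool instance, which Spark must serialize to ship the closure to executors. A minimal, standalone sketch of the conversion under Spark 2.x; `VectorBridge` and `toMllib` are illustrative names, not part of this repo:

```java
import org.apache.spark.ml.linalg.SparseVector;   // new DataFrame-based API (spark.ml)
import org.apache.spark.mllib.linalg.Vector;      // old RDD-based API (spark.mllib)
import org.apache.spark.mllib.linalg.Vectors;

public class VectorBridge {

    // Rebuild an mllib vector from an ml SparseVector's size/indices/values,
    // the same conversion the commit performs before calling LassoModel.predict().
    public static Vector toMllib(SparseVector sp) {
        return Vectors.sparse(sp.size(), sp.indices(), sp.values());
    }

    public static void main(String[] args) {
        SparseVector sp = new SparseVector(5, new int[] { 0, 3 }, new double[] { 0.5, 0.5 });
        System.out.println(toMllib(sp)); // prints (5,[0,3],[0.5,0.5])
    }
}
```

Spark 2.0+ also ships `org.apache.spark.mllib.linalg.Vectors.fromML(...)`, which performs the same conversion for dense and sparse vectors alike.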
