|
19 | 19 |
|
20 | 20 | import java.io.File;
|
21 | 21 | import java.io.IOException;
|
| 22 | +import java.io.Serializable; |
22 | 23 | import java.util.ArrayList;
|
23 | 24 | import java.util.Collection;
|
24 | 25 | import java.util.List;
|
|
27 | 28 | import org.apache.spark.api.java.function.VoidFunction;
|
28 | 29 | import org.apache.spark.ml.feature.CountVectorizerModel;
|
29 | 30 | import org.apache.spark.ml.feature.Normalizer;
|
30 |
| -import org.apache.spark.mllib.linalg.Vector; |
| 31 | +import org.apache.spark.ml.linalg.SparseVector; |
| 32 | +import org.apache.spark.mllib.linalg.Vectors; |
31 | 33 | import org.apache.spark.mllib.regression.LassoModel;
|
32 | 34 | import org.apache.spark.sql.Dataset;
|
33 | 35 | import org.apache.spark.sql.Row;
|
|
54 | 56 | /**
|
55 | 57 | * TODO: Documentation
|
56 | 58 | */
|
57 |
| -public class AgePredictTool extends BasicCmdLineTool { |
58 |
| - |
59 |
| - @Override |
60 |
| - public String getShortDescription() { |
61 |
| - return "age predictor"; |
62 |
| - } |
63 |
| - |
64 |
| - @Override |
65 |
| - public String getHelp() { |
66 |
| - return "Usage: " + CLI.CMD + " " + getName() + " [MaxEntModel] RegressionModel Documents"; |
67 |
| - } |
68 |
| - |
69 |
| - @Override |
70 |
| - public void run(String[] args) { |
71 |
| - AgePredictModel model = null; |
72 |
| - AgeClassifyME classify = null; |
73 |
| - if (args.length == 3) { |
74 |
| - try { |
75 |
| - AgeClassifyModel classifyModel = new AgeClassifyModel(new File(args[0])); |
76 |
| - |
77 |
| - classify = new AgeClassifyME(classifyModel); |
78 |
| - model = AgePredictModel.readModel(new File(args[1])); |
79 |
| - } catch (Exception e) { |
80 |
| - e.printStackTrace(); |
81 |
| - return; |
82 |
| - } |
83 |
| - } |
84 |
| - else if (args.length == 2) { |
85 |
| - try { |
86 |
| - model = AgePredictModel.readModel(new File(args[0])); |
87 |
| - } catch (Exception e) { |
88 |
| - e.printStackTrace(); |
89 |
| - return; |
90 |
| - } |
| 59 | +public class AgePredictTool extends BasicCmdLineTool implements Serializable { |
| 60 | + |
| 61 | + @Override |
| 62 | + public String getShortDescription() { |
| 63 | + return "age predictor"; |
91 | 64 | }
|
92 |
| - else { |
93 |
| - System.out.println(getHelp()); |
94 |
| - return; |
| 65 | + |
| 66 | + @Override |
| 67 | + public String getHelp() { |
| 68 | + return "Usage: " + CLI.CMD + " " + getName() + " [MaxEntModel] RegressionModel Documents"; |
95 | 69 | }
|
96 |
| - |
97 |
| - ObjectStream<String> documentStream; |
98 |
| - List<Row> data = new ArrayList<Row>(); |
99 |
| - |
100 |
| - SparkSession spark = SparkSession |
101 |
| - .builder() |
102 |
| - .appName("AgePredict") |
103 |
| - .getOrCreate(); |
104 |
| - |
105 |
| - try { |
106 |
| - documentStream = new ParagraphStream( |
107 |
| - new PlainTextByLineStream(new SystemInputStreamFactory(), SystemInputStreamFactory.encoding())); |
108 |
| - |
109 |
| - String document; |
110 |
| - FeatureGenerator[] featureGenerators = model.getContext().getFeatureGenerators(); |
111 |
| - while ((document = documentStream.read()) != null) { |
112 |
| - String[] tokens = model.getContext().getTokenizer().tokenize(document); |
113 |
| - |
114 |
| - double prob[] = classify.getProbabilities(tokens); |
115 |
| - String category = classify.getBestCategory(prob); |
116 |
| - |
117 |
| - Collection<String> context = new ArrayList<String>(); |
118 |
| - |
119 |
| - for (FeatureGenerator featureGenerator : featureGenerators) { |
120 |
| - Collection<String> extractedFeatures = |
121 |
| - featureGenerator.extractFeatures(tokens); |
122 |
| - context.addAll(extractedFeatures); |
| 70 | + |
| 71 | + @Override |
| 72 | + public void run(String[] args) { |
| 73 | + AgePredictModel model = null; |
| 74 | + AgeClassifyME classify = null; |
| 75 | + if (args.length == 3) { |
| 76 | + try { |
| 77 | + AgeClassifyModel classifyModel = new AgeClassifyModel(new File(args[0])); |
| 78 | + |
| 79 | + classify = new AgeClassifyME(classifyModel); |
| 80 | + model = AgePredictModel.readModel(new File(args[1])); |
| 81 | + } catch (Exception e) { |
| 82 | + e.printStackTrace(); |
| 83 | + return; |
| 84 | + } |
| 85 | + } else if (args.length == 2) { |
| 86 | + try { |
| 87 | + model = AgePredictModel.readModel(new File(args[0])); |
| 88 | + } catch (Exception e) { |
| 89 | + e.printStackTrace(); |
| 90 | + return; |
| 91 | + } |
| 92 | + } else { |
| 93 | + System.out.println(getHelp()); |
| 94 | + return; |
123 | 95 | }
|
124 |
| - |
125 |
| - if (category != null) { |
126 |
| - for (int i = 0; i < tokens.length / 18; i++) { |
127 |
| - context.add("cat="+ category); |
128 |
| - } |
| 96 | + |
| 97 | + ObjectStream<String> documentStream; |
| 98 | + List<Row> data = new ArrayList<Row>(); |
| 99 | + |
| 100 | + SparkSession spark = SparkSession.builder().appName("AgePredict").getOrCreate(); |
| 101 | + |
| 102 | + try { |
| 103 | + System.out.println("Please enter your text separted by newline. When done press ctrl+d to terminate system input"); |
| 104 | + documentStream = new ParagraphStream( |
| 105 | + new PlainTextByLineStream(new SystemInputStreamFactory(), SystemInputStreamFactory.encoding())); |
| 106 | + |
| 107 | + String document; |
| 108 | + FeatureGenerator[] featureGenerators = model.getContext().getFeatureGenerators(); |
| 109 | + while ((document = documentStream.read()) != null) { |
| 110 | + String[] tokens = model.getContext().getTokenizer().tokenize(document); |
| 111 | + |
| 112 | + double prob[] = classify.getProbabilities(tokens); |
| 113 | + String category = classify.getBestCategory(prob); |
| 114 | + |
| 115 | + Collection<String> context = new ArrayList<String>(); |
| 116 | + |
| 117 | + for (FeatureGenerator featureGenerator : featureGenerators) { |
| 118 | + Collection<String> extractedFeatures = featureGenerator.extractFeatures(tokens); |
| 119 | + context.addAll(extractedFeatures); |
| 120 | + } |
| 121 | + |
| 122 | + if (category != null) { |
| 123 | + for (int i = 0; i < tokens.length / 18; i++) { |
| 124 | + context.add("cat=" + category); |
| 125 | + } |
| 126 | + } |
| 127 | + if (context.size() > 0) { |
| 128 | + data.add(RowFactory.create(document, context.toArray())); |
| 129 | + } |
| 130 | + |
| 131 | + } |
| 132 | + } catch (IOException e) { |
| 133 | + e.printStackTrace(); |
| 134 | + CmdLineUtil.handleStdinIoError(e); |
129 | 135 | }
|
130 |
| - if (context.size() > 0) { |
131 |
| - data.add(RowFactory.create(document, context.toArray())); |
132 |
| - } |
133 |
| - } |
134 |
| - } catch (IOException e) { |
135 |
| - CmdLineUtil.handleStdinIoError(e); |
| 136 | + |
| 137 | + StructType schema = new StructType( |
| 138 | + new StructField[] { new StructField("document", DataTypes.StringType, false, Metadata.empty()), |
| 139 | + new StructField("text", new ArrayType(DataTypes.StringType, true), false, Metadata.empty()) }); |
| 140 | + |
| 141 | + Dataset<Row> df = spark.createDataFrame(data, schema); |
| 142 | + |
| 143 | + CountVectorizerModel cvm = new CountVectorizerModel(model.getVocabulary()).setInputCol("text") |
| 144 | + .setOutputCol("feature"); |
| 145 | + |
| 146 | + Dataset<Row> eventDF = cvm.transform(df); |
| 147 | + |
| 148 | + Normalizer normalizer = new Normalizer().setInputCol("feature").setOutputCol("normFeature").setP(1.0); |
| 149 | + |
| 150 | + JavaRDD<Row> normEventDF = normalizer.transform(eventDF).javaRDD(); |
| 151 | + |
| 152 | + //org.apache.spark.ml.linalg.SparseVector cannot be cast to org.apache.spark.mllib.linalg.Vector |
| 153 | + |
| 154 | + final LassoModel linModel = model.getModel(); |
| 155 | + normEventDF.foreach(new VoidFunction<Row>() { |
| 156 | + public void call(Row event) { |
| 157 | + SparseVector sp = (SparseVector) event.getAs("normFeature"); |
| 158 | + |
| 159 | + double prediction = linModel.predict(Vectors.sparse(sp.size(), sp.indices(), sp.values())); |
| 160 | + System.out.println((String) event.getAs("document")); |
| 161 | + System.out.println("Prediction: " + prediction); |
| 162 | + } |
| 163 | + }); |
| 164 | + |
| 165 | + spark.stop(); |
136 | 166 | }
|
137 |
| - StructType schema = new StructType(new StructField [] { |
138 |
| - new StructField("document", DataTypes.StringType, false, Metadata.empty()), |
139 |
| - new StructField("text", new ArrayType(DataTypes.StringType, true), false, Metadata.empty()) |
140 |
| - }); |
141 |
| - |
142 |
| - Dataset<Row> df = spark.createDataFrame(data, schema); |
143 |
| - |
144 |
| - CountVectorizerModel cvm = new CountVectorizerModel(model.getVocabulary()) |
145 |
| - .setInputCol("text") |
146 |
| - .setOutputCol("feature"); |
147 |
| - |
148 |
| - Dataset<Row> eventDF = cvm.transform(df); |
149 |
| - |
150 |
| - Normalizer normalizer = new Normalizer() |
151 |
| - .setInputCol("feature") |
152 |
| - .setOutputCol("normFeature") |
153 |
| - .setP(1.0); |
154 |
| - |
155 |
| - JavaRDD<Row> normEventDF= normalizer.transform(eventDF).javaRDD(); |
156 |
| - |
157 |
| - final LassoModel linModel = model.getModel(); |
158 |
| - normEventDF.foreach( new VoidFunction<Row>() { |
159 |
| - public void call(Row event) { |
160 |
| - double prediction = linModel.predict((Vector) event.getAs("normFeature")); |
161 |
| - System.out.println((String) event.getAs("document")); |
162 |
| - System.out.println("Prediction: "+ prediction); |
163 |
| - } |
164 |
| - }); |
165 |
| - |
166 |
| - spark.stop(); |
167 |
| - } |
168 | 167 | }
|
0 commit comments