Skip to content

Commit 3884095

Browse files
author
AlkaSaliss
committed
Add readR and readWeka methods
1 parent 1d31951 commit 3884095

File tree

8 files changed

+1718
-38
lines changed

8 files changed

+1718
-38
lines changed

main.java

+1-1
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313

1414
public class main {
1515

16-
public static void main(String[] args) {
16+
public static void main2(String[] args) {
1717
// TODO Auto-generated method stub
1818

1919

methods/methods/R/methods.rdb

Whitespace-only changes.

methods/methods/R/methods.rdx

113 Bytes
Binary file not shown.

src/main/java/fr/ensai/renjin_ml/RDecisionTree.java

+5-1
Original file line numberDiff line numberDiff line change
@@ -39,11 +39,15 @@ public RDecisionTree(String form, String method, int coly) {
3939

4040
}
4141

42-
public void fit(String dataPath) throws ScriptException {
42+
public void fit(String dataPath, boolean hasRowname) throws ScriptException {
4343
/*
4444
* loading the file
4545
*/
46+
if (hasRowname) {
4647
engine.eval("data <- as.data.frame(read.csv("+dataPath+", header=T, row.names=1))");
48+
} else {
49+
engine.eval("data <- as.data.frame(read.csv("+dataPath+", header=T, row.names=NULL))");
50+
}
4751
/*
4852
* Spliting data into training and test sets
4953
*/

src/main/java/fr/ensai/renjin_ml/WekaDecisionTree.java

+25-1
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
import weka.core.converters.CSVLoader;
1515

1616
import weka.filters.Filter;
17+
import weka.filters.unsupervised.attribute.NumericToNominal;
1718
import weka.filters.unsupervised.attribute.Remove;
1819

1920

@@ -29,7 +30,7 @@ public class WekaDecisionTree {
2930

3031

3132

32-
public WekaDecisionTree(String path, boolean hasRowName) throws Exception {
33+
public WekaDecisionTree(String path, boolean hasRowName, boolean toNumeric) throws Exception {
3334
this.datapath = path;
3435

3536
//BufferedReader reader = new BufferedReader(new FileReader(path));
@@ -55,6 +56,29 @@ public WekaDecisionTree(String path, boolean hasRowName) throws Exception {
5556
}
5657

5758

59+
60+
61+
/*
62+
* Convert numeric target variable to categories
63+
* */
64+
if (toNumeric) {
65+
66+
NumericToNominal convert= new NumericToNominal();
67+
String[] options= new String[2];
68+
options[0]="-R"; //remove option
69+
options[1]= String.valueOf(this.data.numAttributes() ) ;
70+
71+
convert.setOptions(options);
72+
convert.setInputFormat(this.data);
73+
this.data =Filter.useFilter(this.data, convert);
74+
75+
}
76+
77+
78+
}
79+
80+
public int getAttributeIndex(String label) {
81+
return this.data.attribute(label).index();
5882
}
5983

6084
public String getDatapath() {

src/main/java/fr/ensai/renjin_ml/test.java

+11-34
Original file line numberDiff line numberDiff line change
@@ -17,28 +17,10 @@ public test() {
1717
public static void main(String[] args) throws Exception {
1818

1919
//create the tree object
20-
RDecisionTree tree = new RDecisionTree("Species~.", "class", 5);
21-
tree.fit("\"src/main/resources/iris.csv\"");
20+
RDecisionTree tree = new RDecisionTree("quality~.", "class", 12);
21+
tree.fit("\"src/main/resources/winequality.csv\"", false);
2222
tree.predict();
23-
24-
25-
// create a script engine manager:
26-
/* RenjinScriptEngineFactory factory = new RenjinScriptEngineFactory();
27-
// create a Renjin engine:
28-
RenjinScriptEngine engine = factory.getScriptEngine();
29-
30-
engine.eval("dat <- data.frame(x=1:10, y=(1:10)+rnorm(n=10))");
31-
engine.eval("print(typeof(dat))");*/
32-
33-
/* String chemin = "\"src/main/resources/iris.csv\"";
34-
engine.eval("data <- read.csv(" + chemin + ", header=T, row.names=1)");*/
35-
//engine.eval("print(data)");
36-
// engine.eval("print(lm(y ~ x, df))");
37-
/* engine.eval("rg <- lm(y ~ x, dat)");
38-
39-
//engine.eval("install.packages(\"rpart\")");
40-
engine.eval("library(rpart)");
41-
23+
4224
// ListVector vect = (ListVector) engine.get("rg");
4325
// System.out.println(engine.get("rg$residuals"));
4426
// List<String> test = List
@@ -47,29 +29,24 @@ public static void main(String[] args) throws Exception {
4729

4830
//System.out.println( df.class);
4931

50-
System.out.println( engine.eval("packageVersion(\"rpart\")"));*/
51-
52-
/*engine.eval("n=2*floor(nrow(data)/3)");
53-
engine.eval("train=data[1:n,]");
54-
engine.eval("library(rpart)"); //load the package
55-
engine.eval("fit <- rpart(Species ~ ., data=train)");*/
56-
32+
5733
System.out.println("*************************************************\n");
5834

59-
WekaDecisionTree treeWeka = new WekaDecisionTree("src/main/resources/iris.csv", true);
35+
WekaDecisionTree treeWeka = new WekaDecisionTree("src/main/resources/winequality.csv", false, true);
6036
J48 model = treeWeka.fit((float)0.7);
6137

62-
treeWeka.predict(model);
38+
//treeWeka.predict(model);
6339

64-
System.out.println("*************************************************\n");
65-
//System.out.println(treeWeka.getTest());
40+
System.out.println("**********************Position***************************\n");
41+
//System.out.println(treeWeka.getData());
42+
//System.out.println(treeWeka.getAttributeIndex("chlorides"));
6643

6744
System.out.println("*************************************************\n");
6845
//System.out.println(tree.getEngine().eval("print(train)"));
6946
//double a = ((DoubleArrayVector) tree.getEngine().eval("accuracy")).asReal();
70-
ListVector a = (ListVector) tree.getEngine().eval("pred");
47+
//ListVector a = (ListVector) tree.getEngine().eval("pred");
7148

72-
System.out.println(a.asReal());
49+
//System.out.println(tree.getEngine().eval("print(pred)"));
7350

7451

7552
}

src/main/java/metier/Data.java

+76-1
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,84 @@
11
package metier;
22

3+
import java.io.File;
4+
import java.io.IOException;
5+
6+
import javax.script.ScriptException;
7+
8+
import org.renjin.script.RenjinScriptEngine;
9+
10+
import weka.core.Instances;
11+
import weka.core.converters.CSVLoader;
12+
import weka.filters.Filter;
13+
import weka.filters.unsupervised.attribute.NumericToNominal;
14+
import weka.filters.unsupervised.attribute.Remove;
15+
316
public class Data {
417
String path;
5-
String header;
18+
String header="true";
619
String targetname;
20+
String hasRowNames="false"; //R
21+
String sep = ","; //R
22+
String dec = "."; //R
23+
24+
String catTarget = "false"; //weka
25+
String toNominal = "false"; //
26+
27+
28+
29+
30+
public void readR(RenjinScriptEngine engine) throws ScriptException {
31+
32+
String rownames = this.hasRowNames.toUpperCase()=="true" ? "1" : "NULL";
33+
34+
engine.eval("\"data <- read.csv(\"+path+\", header="+this.header.toUpperCase() +"," + sep + ", row.names="+rownames+")\"");
35+
36+
}
37+
38+
public Instances readWeka() throws Exception {
39+
40+
File f = new File(this.path);
41+
CSVLoader cnv = new CSVLoader();
42+
cnv.setSource(f);
43+
Instances data = cnv.getDataSet();
44+
45+
46+
/*
47+
* Retrieve the target column index and set this column as model dependant variable*/
48+
int targetIndex = data.attribute(this.targetname).index(); // target variable index
49+
data.setClassIndex(targetIndex);
50+
51+
if (this.hasRowNames.toUpperCase() == "TRUE") {
52+
String[] options = new String[2];
53+
options[0] = "-R"; // "range"
54+
options[1] = "1"; // first attribute
55+
Remove remove = new Remove(); // new instance of filter
56+
remove.setOptions(options); // set options
57+
remove.setInputFormat(data); // inform filter about dataset **AFTER** setting options
58+
data = Filter.useFilter(data, remove); // apply filter
59+
}
60+
61+
62+
63+
64+
/*
65+
* Convert numeric target variable to categories
66+
* */
67+
if (toNominal.toUpperCase()=="TRUE") {
68+
69+
NumericToNominal convert= new NumericToNominal();
70+
String[] options= new String[2];
71+
options[0]="-R";
72+
options[1]= String.valueOf(targetIndex) ;
73+
74+
convert.setOptions(options);
75+
convert.setInputFormat(data);
76+
data =Filter.useFilter(data, convert);
77+
}
78+
79+
return data;
80+
81+
}
782

883

984
}

0 commit comments

Comments
 (0)