Skip to content

Commit cc2d735

Browse files
authored
Merge pull request #7 from syh6585/master
modify testCF
2 parents ac72084 + 8d28cdf commit cc2d735

File tree

6 files changed

+133
-49
lines changed

6 files changed

+133
-49
lines changed

.idea/copyright/libble.xml

Lines changed: 6 additions & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

.idea/copyright/profiles_settings.xml

Lines changed: 5 additions & 1 deletion
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

data/testMF.txt

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,30 @@
1+
1,1,3.0
2+
1,2,4.0
3+
1,3,2.8
4+
1,4,4.0
5+
1,5,3.7
6+
1,6,4.7
7+
2,1,2.0
8+
2,2,5.0
9+
2,3,4.8
10+
2,4,2.6
11+
2,5,4.2
12+
2,6,3.0
13+
3,1,4.3
14+
3,2,3.2
15+
3,3,5.0
16+
3,4,4.9
17+
3,5,3.2
18+
3,6,4.0
19+
4,1,3.0
20+
4,2,4.3
21+
4,3,4.3
22+
4,4,1.0
23+
4,5,3.2
24+
4,6,2.3
25+
5,1,4.0
26+
5,2,4.3
27+
5,3,4.5
28+
5,4,2.3
29+
5,5,2.0
30+
5,6,1.0

src/main/scala/collaborativeFiltering/MatrixFactorization.scala

Lines changed: 37 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,20 @@
1+
/*
2+
*
3+
* Copyright (c) 2016 LIBBLE team supervised by Dr. Wu-Jun LI at Nanjing University.
4+
* All Rights Reserved.
5+
* Licensed under the Apache License, Version 2.0 (the "License");
6+
* You may not use this file except in compliance with the License.
7+
* You may obtain a copy of the License at
8+
*
9+
* http://www.apache.org/licenses/LICENSE-2.0
10+
*
11+
* Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
*
16+
*/
17+
118
/**
219
* Created by syh on 2016/12/9.
320
*/
@@ -90,26 +107,26 @@ class MatrixFactorization extends Serializable{
90107
var testTime = 0L
91108
var i = 0
92109
while (i < numIters){
93-
//loss
94-
val testTimeStart = System.currentTimeMillis()
95-
val bc_test_itemFactors = ratingsByRow.context.broadcast(itemFactors)
96-
val loss = ratingsByRow.mapPartitionsWithIndex {(index,iter) =>
97-
val localV = bc_test_itemFactors.value
98-
val localU = MatrixFactorization.workerstore.get[Map[Int, Vector]](s"userFactors_$index")
99-
val reguV = localV.mapValues(v => lambda_v * v.dot(v))
100-
val reguU = localU.mapValues(u => lambda_u * u.dot(u))
101-
val ls = iter.foldLeft(0.0) { (l, r) =>
102-
val uh = localU.get(r.index_x).get
103-
val vj = localV.get(r.index_y).get
104-
val residual = r.rating - uh.dot(vj)
105-
l + residual * residual + reguU.get(r.index_x).get + reguV.get(r.index_y).get
106-
}
107-
Iterator.single(ls)
108-
}.reduce(_ + _) / numRatings
109-
bc_test_itemFactors.unpersist()
110-
print(s"$loss\t")
111-
testTime += (System.currentTimeMillis() - testTimeStart)
112-
println(s"${System.currentTimeMillis() - testTime - startTime}")
110+
// //loss
111+
// val testTimeStart = System.currentTimeMillis()
112+
// val bc_test_itemFactors = ratingsByRow.context.broadcast(itemFactors)
113+
// val loss = ratingsByRow.mapPartitionsWithIndex {(index,iter) =>
114+
// val localV = bc_test_itemFactors.value
115+
// val localU = MatrixFactorization.workerstore.get[Map[Int, Vector]](s"userFactors_$index")
116+
// val reguV = localV.mapValues(v => lambda_v * v.dot(v))
117+
// val reguU = localU.mapValues(u => lambda_u * u.dot(u))
118+
// val ls = iter.foldLeft(0.0) { (l, r) =>
119+
// val uh = localU.get(r.index_x).get
120+
// val vj = localV.get(r.index_y).get
121+
// val residual = r.rating - uh.dot(vj)
122+
// l + residual * residual + reguU.get(r.index_x).get + reguV.get(r.index_y).get
123+
// }
124+
// Iterator.single(ls)
125+
// }.reduce(_ + _) / numRatings
126+
// bc_test_itemFactors.unpersist()
127+
// print(s"$loss\t")
128+
// testTime += (System.currentTimeMillis() - testTimeStart)
129+
// println(s"${System.currentTimeMillis() - testTime - startTime}")
113130
//broadcast V to p workers
114131
val bc_itemFactors = ratingsByRow.context.broadcast(itemFactors)
115132
//for each woker i parallelly do
Lines changed: 41 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,27 @@
1+
/*
2+
*
3+
* Copyright (c) 2016 LIBBLE team supervised by Dr. Wu-Jun LI at Nanjing University.
4+
* All Rights Reserved.
5+
* Licensed under the Apache License, Version 2.0 (the "License");
6+
* You may not use this file except in compliance with the License.
7+
* You may obtain a copy of the License at
8+
*
9+
* http://www.apache.org/licenses/LICENSE-2.0
10+
*
11+
* Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
*
16+
*/
17+
118
/**
219
* We licence this file to you under the Apache Licence 2.0; you could get a copy
320
* of the licence from http://www.apache.org/licenses/LICENSE-2.0.
421
*/
522
package libble.examples
623

7-
import libble.collaborativeFiltering.{MatrixFactorization, Rating}
24+
import libble.collaborativeFiltering.{MatrixFactorizationByScope, MatrixFactorization, Rating}
825
import org.apache.log4j.{Level, Logger}
926
import org.apache.spark.{SparkConf, SparkContext}
1027

@@ -16,47 +33,53 @@ import scala.collection.mutable
1633
*/
1734
object testCF {
1835
def main(args: Array[String]) {
19-
20-
if (args.length < 1) {
21-
System.err.println("Usage: ~ path:String --numIters=Int --numParts=Int --rank=Int --regParam_u=Double --regParam_v=Double --stepsize=Double")
22-
System.exit(1)
23-
}
24-
25-
val optionsList = args.drop(1).map { arg =>
36+
val optionsList = args.map { arg =>
2637
arg.dropWhile(_ == '-').split('=') match {
2738
case Array(opt, v) => (opt -> v)
2839
case _ => throw new IllegalArgumentException("Invalid argument: " + arg)
2940
}
3041
}
3142
val options = mutable.Map(optionsList: _*)
32-
System.setProperty("hadoop.home.dir", "D:\\Program Files\\hadoop-2.6.0")
3343
Logger.getLogger("org.apache.spark").setLevel(Level.WARN)
3444
Logger.getLogger("org.eclipse.jetty.server").setLevel(Level.OFF)
3545

3646
val conf = new SparkConf()
3747
.setAppName("testMF")
3848
val sc = new SparkContext(conf)
3949

50+
val trainsetPath = options.remove("trainset").map(_.toString).getOrElse("data\\testMF.txt")
51+
val stepsize = options.remove("stepsize").map(_.toDouble).getOrElse(0.1)
52+
val regParam_u = options.remove("regParam_u").map(_.toDouble).getOrElse(0.1)
53+
val regParam_v = options.remove("regParam_u").map(_.toDouble).getOrElse(0.1)
54+
val numIters = options.remove("numIters").map(_.toInt).getOrElse(50)
55+
val numParts = options.remove("numParts").map(_.toInt).getOrElse(2)
56+
val rank = options.remove("rank").map(_.toInt).getOrElse(10)
57+
val testsetPath = options.remove("testset").map(_.toString).getOrElse("data\\testMF.txt")
4058

41-
val stepsize = options.remove("stepsize").map(_.toDouble).getOrElse(0.01)
42-
val regParam_u = options.remove("regParam_u").map(_.toDouble).getOrElse(0.05)
43-
val regParam_v = options.remove("regParam_u").map(_.toDouble).getOrElse(0.05)
44-
val numIters = options.remove("numIters").map(_.toInt).getOrElse(200)
45-
val numParts = options.remove("numParts").map(_.toInt).getOrElse(16)
46-
val rank = options.remove("rank").map(_.toInt).getOrElse(40)
47-
48-
val trainSet = sc.textFile(args(0), numParts)
59+
val trainSet = sc.textFile(trainsetPath, numParts)
60+
.map(_.split(',') match { case Array(user, item, rate) =>
61+
Rating(rate.toDouble, user.toInt, item.toInt)
62+
})
63+
val testSet = sc.textFile(testsetPath, numParts)
4964
.map(_.split(',') match { case Array(user, item, rate) =>
5065
Rating(rate.toDouble, user.toInt, item.toInt)
5166
})
5267

53-
val model = new MatrixFactorization()
68+
val model = new MatrixFactorizationByScope()
5469
.train(trainSet,
5570
numIters,
5671
numParts,
5772
rank,
5873
regParam_u,
5974
regParam_v,
6075
stepsize)
76+
77+
val result = model.predict(testSet.map(r=>(r.index_x,r.index_y)))
78+
val rmse = result.map(r=>((r.index_x,r.index_y), r.rating))
79+
.join(testSet.map(r=>((r.index_x,r.index_y), r.rating)))
80+
.values
81+
.map(i => math.pow(i._1 - i._2, 2))
82+
.sum() / testSet.count()
83+
println(s"rmse of test set: $rmse")
6184
}
6285
}

src/main/scala/examples/testLR.scala

Lines changed: 14 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,18 @@
11
/*
2-
* Copyright (c) 2016 LIBBLE team supervised by Dr. Wu-Jun LI at Nanjing University.
3-
* All Rights Reserved.
4-
* Licensed under the Apache License, Version 2.0 (the "License");
5-
* You may not use this file except in compliance with the License.
6-
* You may obtain a copy of the License at
72
*
8-
* http://www.apache.org/licenses/LICENSE-2.0
3+
* Copyright (c) 2016 LIBBLE team supervised by Dr. Wu-Jun LI at Nanjing University.
4+
* All Rights Reserved.
5+
* Licensed under the Apache License, Version 2.0 (the "License");
6+
* You may not use this file except in compliance with the License.
7+
* You may obtain a copy of the License at
8+
*
9+
* http://www.apache.org/licenses/LICENSE-2.0
10+
*
11+
* Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
915
*
10-
* Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS,
11-
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12-
* See the License for the specific language governing permissions and
13-
* limitations under the License.
1416
*/
1517
package libble.examples
1618

@@ -55,5 +57,7 @@ object testLR {
5557
val training = sc.loadLIBBLEFile(args(0), numPart)
5658
val m = new LogisticRegression(stepSize, regParam, elasticF, numIter, numPart)
5759
m.train(training)
60+
61+
5862
}
5963
}

0 commit comments

Comments
 (0)