1+ /*
2+ *
3+ * Copyright (c) 2016 LIBBLE team supervised by Dr. Wu-Jun LI at Nanjing University.
4+ * All Rights Reserved.
5+ * Licensed under the Apache License, Version 2.0 (the "License");
6+ * You may not use this file except in compliance with the License.
7+ * You may obtain a copy of the License at
8+ *
9+ * http://www.apache.org/licenses/LICENSE-2.0
10+ *
11+ * Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS,
12+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+ * See the License for the specific language governing permissions and
14+ * limitations under the License.
15+ *
16+ */
17+
118/**
219 * We licence this file to you under the Apache Licence 2.0; you could get a copy
320 * of the licence from http://www.apache.org/licenses/LICENSE-2.0.
421 */
522package libble .examples
623
7- import libble .collaborativeFiltering .{MatrixFactorization , Rating }
24+ import libble .collaborativeFiltering .{MatrixFactorizationByScope , MatrixFactorization , Rating }
825import org .apache .log4j .{Level , Logger }
926import org .apache .spark .{SparkConf , SparkContext }
1027
@@ -16,47 +33,53 @@ import scala.collection.mutable
1633 */
1734object testCF {
1835 def main (args : Array [String ]) {
19-
20- if (args.length < 1 ) {
21- System .err.println(" Usage: ~ path:String --numIters=Int --numParts=Int --rank=Int --regParam_u=Double --regParam_v=Double --stepsize=Double" )
22- System .exit(1 )
23- }
24-
25- val optionsList = args.drop(1 ).map { arg =>
36+ val optionsList = args.map { arg =>
2637 arg.dropWhile(_ == '-' ).split('=' ) match {
2738 case Array (opt, v) => (opt -> v)
2839 case _ => throw new IllegalArgumentException (" Invalid argument: " + arg)
2940 }
3041 }
3142 val options = mutable.Map (optionsList : _* )
32- System .setProperty(" hadoop.home.dir" , " D:\\ Program Files\\ hadoop-2.6.0" )
3343 Logger .getLogger(" org.apache.spark" ).setLevel(Level .WARN )
3444 Logger .getLogger(" org.eclipse.jetty.server" ).setLevel(Level .OFF )
3545
3646 val conf = new SparkConf ()
3747 .setAppName(" testMF" )
3848 val sc = new SparkContext (conf)
3949
50+ val trainsetPath = options.remove(" trainset" ).map(_.toString).getOrElse(" data\\ testMF.txt" )
51+ val stepsize = options.remove(" stepsize" ).map(_.toDouble).getOrElse(0.1 )
52+ val regParam_u = options.remove(" regParam_u" ).map(_.toDouble).getOrElse(0.1 )
53+ val regParam_v = options.remove(" regParam_u" ).map(_.toDouble).getOrElse(0.1 )
54+ val numIters = options.remove(" numIters" ).map(_.toInt).getOrElse(50 )
55+ val numParts = options.remove(" numParts" ).map(_.toInt).getOrElse(2 )
56+ val rank = options.remove(" rank" ).map(_.toInt).getOrElse(10 )
57+ val testsetPath = options.remove(" testset" ).map(_.toString).getOrElse(" data\\ testMF.txt" )
4058
41- val stepsize = options.remove(" stepsize" ).map(_.toDouble).getOrElse(0.01 )
42- val regParam_u = options.remove(" regParam_u" ).map(_.toDouble).getOrElse(0.05 )
43- val regParam_v = options.remove(" regParam_u" ).map(_.toDouble).getOrElse(0.05 )
44- val numIters = options.remove(" numIters" ).map(_.toInt).getOrElse(200 )
45- val numParts = options.remove(" numParts" ).map(_.toInt).getOrElse(16 )
46- val rank = options.remove(" rank" ).map(_.toInt).getOrElse(40 )
47-
48- val trainSet = sc.textFile(args(0 ), numParts)
59+ val trainSet = sc.textFile(trainsetPath, numParts)
60+ .map(_.split(',' ) match { case Array (user, item, rate) =>
61+ Rating (rate.toDouble, user.toInt, item.toInt)
62+ })
63+ val testSet = sc.textFile(testsetPath, numParts)
4964 .map(_.split(',' ) match { case Array (user, item, rate) =>
5065 Rating (rate.toDouble, user.toInt, item.toInt)
5166 })
5267
53- val model = new MatrixFactorization ()
68+ val model = new MatrixFactorizationByScope ()
5469 .train(trainSet,
5570 numIters,
5671 numParts,
5772 rank,
5873 regParam_u,
5974 regParam_v,
6075 stepsize)
76+
77+ val result = model.predict(testSet.map(r=> (r.index_x,r.index_y)))
78+ val rmse = result.map(r=> ((r.index_x,r.index_y), r.rating))
79+ .join(testSet.map(r=> ((r.index_x,r.index_y), r.rating)))
80+ .values
81+ .map(i => math.pow(i._1 - i._2, 2 ))
82+ .sum() / testSet.count()
83+ println(s " rmse of test set: $rmse" )
6184 }
6285}
0 commit comments