@@ -38,10 +38,11 @@ import scala.collection.mutable
38
38
* @param maxP: max value of y
39
39
* @param feaUsed: array of used feature of the input data
40
40
*/
41
- class FMLearner (override val ctx : TaskContext , val minP : Double , val maxP : Double , val feaUsed :
42
- Array [Int ]) extends MLLearner (ctx) {
41
+ class FMLearner (override val ctx : TaskContext , val minP : Double , val maxP : Double , val feaUsed : Array [Int ])
42
+ extends MLLearner (ctx) {
43
+
43
44
val LOG : Log = LogFactory .getLog(classOf [FMLearner ])
44
- val fmmodel = new FMModel (conf, ctx)
45
+ val fmModel = new FMModel (conf, ctx)
45
46
46
47
val learnType = conf.get(MLConf .ML_FM_LEARN_TYPE , MLConf .DEFAULT_ML_FM_LEARN_TYPE )
47
48
val feaNum : Int = conf.getInt(MLConf .ML_FEATURE_NUM , MLConf .DEFAULT_ML_FEATURE_NUM )
@@ -53,12 +54,11 @@ Array[Int]) extends MLLearner(ctx) {
53
54
val reg2 : Double = conf.getDouble(MLConf .ML_FM_REG2 , MLConf .DEFAULT_ML_FM_REG2 )
54
55
val lr : Double = conf.getDouble(MLConf .ML_LEARN_RATE , MLConf .DEFAULT_ML_LEAR_RATE )
55
56
val vStddev : Double = conf.getDouble(MLConf .ML_FM_V_STDDEV , MLConf .DEFAULT_ML_FM_V_INIT )
56
- // Put used feature indexes to vIndexs
57
- // val vIndexs = feaUsed.zipWithIndex.filter((p:(Int,Int))=>p._1!=0).map((p:(Int,Int))=>p._2).array
58
- val vIndexs = new RowIndex ()
59
- feaUsed.zipWithIndex.filter((p: (Int , Int ))=> p._1!= 0 ).map((p: (Int , Int ))=> vIndexs.addRowId(p._2))
60
- val feaUsedN = vIndexs.getRowsNumber
61
- LOG .info(" vIndexs's row's number = " + vIndexs)
57
+
58
+ val vIndexes = new RowIndex ()
59
+ feaUsed.zipWithIndex.filter((p: (Int , Int ))=> p._1!= 0 ).map((p: (Int , Int ))=> vIndexes.addRowId(p._2))
60
+ val feaUsedN = vIndexes.getRowsNumber
61
+ LOG .info(" vIndexs's row's number = " + vIndexes)
62
62
63
63
/**
64
64
* Train a Factorization machines Model
@@ -68,8 +68,7 @@ Array[Int]) extends MLLearner(ctx) {
68
68
* @return : a learned model
69
69
*/
70
70
override
71
- def train (trainData : DataBlock [LabeledData ], vali : DataBlock [LabeledData ]):
72
- MLModel = {
71
+ def train (trainData : DataBlock [LabeledData ], vali : DataBlock [LabeledData ]): MLModel = {
73
72
val start = System .currentTimeMillis()
74
73
LOG .info(s " learnType= $learnType, feaNum= $feaNum, rank= $rank, #trainData= ${trainData.size}" )
75
74
LOG .info(s " reg0= $reg0, reg1= $reg1, reg2= $reg2, lr= $lr, vStev= $vStddev" )
@@ -79,7 +78,7 @@ Array[Int]) extends MLLearner(ctx) {
79
78
val initCost = System .currentTimeMillis() - beforeInit
80
79
LOG .info(s " Init matrixes cost $initCost ms. " )
81
80
82
- globalMetrics.addMetrics(fmmodel .FM_OBJ , LossMetric (trainData.size()))
81
+ globalMetrics.addMetrics(fmModel .FM_OBJ , LossMetric (trainData.size()))
83
82
84
83
while (ctx.getIteration < epochNum) {
85
84
val startIter = System .currentTimeMillis()
@@ -90,7 +89,7 @@ Array[Int]) extends MLLearner(ctx) {
90
89
val loss = evaluate(trainData, w0.get(0 ), w, v)
91
90
val valiCost = System .currentTimeMillis() - startVali
92
91
93
- globalMetrics.metrics(fmmodel .FM_OBJ , loss)
92
+ globalMetrics.metrics(fmModel .FM_OBJ , loss)
94
93
LOG .info(s " Epoch= ${ctx.getIteration}, evaluate loss= ${loss/ trainData.size()}. " +
95
94
s " trainCost= $iterCost, " +
96
95
s " valiCost= $valiCost" )
@@ -101,7 +100,7 @@ Array[Int]) extends MLLearner(ctx) {
101
100
val end = System .currentTimeMillis()
102
101
val cost = end - start
103
102
LOG .info(s " FM Learner train cost $cost ms. " )
104
- fmmodel
103
+ fmModel
105
104
}
106
105
107
106
/**
@@ -110,22 +109,24 @@ Array[Int]) extends MLLearner(ctx) {
110
109
def initModels (): Unit = {
111
110
if (ctx.getTaskId.getIndex == 0 ) {
112
111
for (row <- 0 until feaNum) {
113
- fmmodel .v.update(new RandomNormal (fmmodel .v.getMatrixId(), row, 0.0 , vStddev)).get()
112
+ fmModel .v.update(new RandomNormal (fmModel .v.getMatrixId(), row, 0.0 , vStddev)).get()
114
113
}
115
114
}
116
115
117
- fmmodel .v.clock().get()
116
+ fmModel .v.clock().get()
118
117
}
119
118
120
119
/**
121
120
* One iteration to train Factorization Machines
121
+ *
122
122
* @param dataBlock
123
123
* @return
124
124
*/
125
- def oneIteration (dataBlock : DataBlock [LabeledData ]): (DenseDoubleVector ,
126
- DenseDoubleVector , mutable.HashMap [Int , DenseDoubleVector ]) = {
125
+ def oneIteration (dataBlock : DataBlock [LabeledData ]):
126
+ (DenseDoubleVector , DenseDoubleVector , mutable.HashMap [Int , DenseDoubleVector ]) = {
127
+
127
128
val startGet = System .currentTimeMillis()
128
- val (w0, w, v) = fmmodel .pullFromPS(vIndexs )
129
+ val (w0, w, v) = fmModel .pullFromPS(vIndexes )
129
130
val getCost = System .currentTimeMillis() - startGet
130
131
LOG .info(s " Get matrixes cost $getCost ms. " )
131
132
@@ -154,7 +155,7 @@ Array[Int]) extends MLLearner(ctx) {
154
155
v(update._1).plusBy(update._2, - 1.0 ).timesBy(- 1.0 )
155
156
}
156
157
157
- fmmodel .pushToPS(w0.plusBy(_w0, - 1.0 ).timesBy(- 1.0 ).asInstanceOf [DenseDoubleVector ],
158
+ fmModel .pushToPS(w0.plusBy(_w0, - 1.0 ).timesBy(- 1.0 ).asInstanceOf [DenseDoubleVector ],
158
159
w.plusBy(_w, - 1.0 ).timesBy(- 1.0 ).asInstanceOf [DenseDoubleVector ],
159
160
v)
160
161
@@ -163,6 +164,7 @@ Array[Int]) extends MLLearner(ctx) {
163
164
164
165
/**
165
166
* Evaluate the objective value
167
+ *
166
168
* @param dataBlock
167
169
* @param w0
168
170
* @param w
@@ -188,6 +190,7 @@ Array[Int]) extends MLLearner(ctx) {
188
190
189
191
/**
190
192
* Predict an instance
193
+ *
191
194
* @param x:feature vector of instance
192
195
* @param y: label value of instance
193
196
* @param w0: w0 mat of FM
@@ -220,6 +223,7 @@ Array[Int]) extends MLLearner(ctx) {
220
223
221
224
/**
222
225
* \frac{\partial loss}{\partial x} = dm * \frac{\partial y}{\partial x}
226
+ *
223
227
* @param y: label of the instance
224
228
* @param pre: predict value of the instance
225
229
* @return : dm value
@@ -240,6 +244,7 @@ Array[Int]) extends MLLearner(ctx) {
240
244
241
245
/**
242
246
* Update v mat
247
+ *
243
248
* @param x: a train instance
244
249
* @param dm: dm value of the instance
245
250
* @param v: v mat
0 commit comments