@@ -130,7 +130,7 @@ class L1Updater extends Updater {
 
   override def update(data: RDD[(Double, Vector)], weights: Vector, mu: Vector, lossfunc: LossFunc, stepSize: Double, factor: Double, regParam: Double): Vector = {
 
-    val preFact = 1.0 - stepSize * (regParam + factor)
+    val preFact = 1.0 - stepSize * factor
 
     val upFact = -stepSize / preFact
 
@@ -246,3 +246,77 @@ class L2Updater extends Updater {
   }
 }
 
+class elasticNetUpdater(val alpha: Double) extends Updater {
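+  // alpha scales the L2 part of the penalty; beta = 1 - alpha scales the L1 part (see getRegVal below).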
+  val beta = 1 - alpha
+
+
+  override def update(data: RDD[(Double, Vector)], weights: Vector, mu: Vector, lossfunc: LossFunc, stepSize: Double, factor: Double, regParam: Double): Vector = {
+
+    val preFact = 1.0 - stepSize * (regParam * alpha + factor)
+    val upFact = -stepSize / preFact
+    mu.plusax(-factor, weights)
+    val w_0 = data.sparkContext.broadcast(weights)
+    val fix = data.sparkContext.broadcast(mu)
+    val partsNum = data.partitions.length
+    val chkSize = findChkSize(preFact)
+
+    val l1fact = stepSize * regParam * beta
+
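+    // One local pass per partition: sample examples, take gradient steps, and apply the
+    // L1 shrinkage lazily per feature; the per-partition weights are averaged below.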
+    data.mapPartitions(iter => {
+      val omiga = new WeightsVector(w_0.value.copy, fix.value)
+      val indexedSeq = iter.toIndexedSeq
+      val pNum = indexedSeq.size
+
+      val rand = new Random(partsNum * pNum)
+
+      val flags = new Array[Int](omiga.size)
+      util.Arrays.fill(flags, 0)
+
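+      // flags(i) records the last iteration at which feature i's L1 shrinkage was applied.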
+      for (j <- 1 to pNum) {
+        val e = indexedSeq(rand.nextInt(pNum))
+        var f1 = lossfunc.deltaF(e._2, e._1, omiga)
+        f1 -= lossfunc.deltaF(e._2, e._1, w_0.value)
+        // val delta = f1 x e._2
+        // delta += mu
+        if (j % chkSize == 0)
+          omiga.merge()
+
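+        // Lazily apply the L1 soft-thresholding accumulated since each active feature was
+        // last shrunk, writing the result back through the fac_a/fac_b scaled representation.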
+        val oValues = omiga.partA.toArray
+        e._2.foreachActive { (i, v) =>
+          val wi = omiga.apply(i)
+          oValues(i) = (math.signum(wi) * math.max(0.0, math.abs(wi) - (j - 1 - flags(i)) * l1fact) - omiga.fac_b * omiga.partB(i)) / omiga.fac_a
+          flags(i) = j - 1
+        }
+
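+        // Gradient step applied to partA (rescaled by 1 / fac_a), then update the lazy scaling factors.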
+        omiga.partA.plusax(upFact / omiga.fac_a, f1 x e._2)
+        omiga.fac_a *= preFact
+        omiga.fac_b *= preFact
+        omiga.fac_b -= stepSize
+
+      }
+      Iterator(omiga.toDense())
+
+    }, true).treeAggregate(new DenseVector(weights.size))(seqOp = (c, w) => {
+      c += w
+    }, combOp = (c1, c2) => {
+      c1 += c2
+    }) /= partsNum
+
+  }
+  /**
+   * Computes the regularization cost of the elastic net penalty for the given weights.
+   *
+   * @param weight the current weight vector
+   * @param regParam the regularization parameter
+   * @return the regularization cost
+   */
+  override def getRegVal(weight: Vector, regParam: Double): Double = {
+    val norm1 = weight.norm1()
+    val norm2 = weight.norm2()
+    regParam * (0.5 * alpha * norm2 * norm2 + beta * norm1)
+  }
+}
+
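Note on getRegVal: it returns regParam * (0.5 * alpha * ||w||_2^2 + (1 - alpha) * ||w||_1), matching how the updater splits the penalty between the multiplicative decay (regParam * alpha in preFact) and the soft-thresholding term (regParam * beta in l1fact). Below is a minimal standalone sketch of the same cost, written against a plain Array[Double] instead of the repository's Vector type; the object and helper names are illustrative only, not part of this patch.

object ElasticNetCostCheck {
  // Elastic net cost as computed by getRegVal, using a plain Array[Double].
  def elasticNetCost(w: Array[Double], regParam: Double, alpha: Double): Double = {
    val beta = 1 - alpha
    val norm1 = w.map(x => math.abs(x)).sum        // L1 norm
    val norm2 = math.sqrt(w.map(x => x * x).sum)   // L2 norm
    regParam * (0.5 * alpha * norm2 * norm2 + beta * norm1)
  }

  def main(args: Array[String]): Unit =
    // 0.1 * (0.5 * 0.5 * 5.0 + 0.5 * 3.0) = 0.275
    println(elasticNetCost(Array(1.0, -2.0), regParam = 0.1, alpha = 0.5))
}

With alpha = 1 this reduces to the pure L2 (ridge) cost, and with alpha = 0 to the pure L1 (lasso) cost.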