Skip to content

Commit 73cdc97

Browse files
authored
Merge pull request #11 from Ru-Xiang/master
add elastic net and L1 bug fix
2 parents bb526af + 033905a commit 73cdc97

File tree

1 file changed

+75
-1
lines changed

1 file changed

+75
-1
lines changed

src/main/scala/generalizedLinear/Updater.scala

Lines changed: 75 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -130,7 +130,7 @@ class L1Updater extends Updater {
130130

131131
override def update(data: RDD[(Double, Vector)], weights: Vector, mu: Vector, lossfunc: LossFunc, stepSize: Double, factor: Double, regParam: Double): Vector = {
132132

133-
val preFact = 1.0 - stepSize * (regParam + factor)
133+
val preFact = 1.0 - stepSize * factor
134134

135135
val upFact = -stepSize / preFact
136136

@@ -246,3 +246,77 @@ class L2Updater extends Updater {
246246
}
247247
}
248248

249+
class elasticNetUpdater(val alpha:Double) extends Updater{
250+
val beta= 1-alpha
251+
252+
253+
override def update(data: RDD[(Double, Vector)], weights: Vector, mu: Vector, lossfunc: LossFunc, stepSize: Double, factor: Double, regParam: Double): Vector = {
254+
255+
val preFact = 1.0 - stepSize * (regParam*alpha + factor)
256+
val upFact = -stepSize / preFact
257+
mu.plusax(-factor, weights)
258+
val w_0 = data.sparkContext.broadcast(weights)
259+
val fix = data.sparkContext.broadcast(mu)
260+
val partsNum = data.partitions.length
261+
val chkSize = findChkSize(preFact)
262+
263+
val l1fact=stepSize * regParam*beta
264+
265+
data.mapPartitions(iter => {
266+
val omiga = new WeightsVector(w_0.value.copy, fix.value)
267+
val indexedSeq = iter.toIndexedSeq
268+
val pNum = indexedSeq.size
269+
270+
val rand = new Random(partsNum * pNum)
271+
272+
val flags = new Array[Int](omiga.size)
273+
util.Arrays.fill(flags, 0)
274+
275+
for (j <- 1 to pNum) {
276+
val e = indexedSeq(rand.nextInt(pNum))
277+
val f1 = lossfunc.deltaF(e._2, e._1, omiga)
278+
f1 -= lossfunc.deltaF(e._2, e._1, w_0.value)
279+
// val delta = f1 x e._2
280+
// delta += mu
281+
if (j % chkSize == 0)
282+
omiga.merge()
283+
284+
val oValues = omiga.partA.toArray
285+
e._2.foreachActive { (i, v) =>
286+
val wi = omiga.apply(i)
287+
oValues(i) = (math.signum(wi) * max(0.0, abs(wi) - (j - 1 - flags(i)) * l1fact) - omiga.fac_b * omiga.partB(i)) / omiga.fac_a
288+
flags(i) = j - 1
289+
}
290+
291+
omiga.partA.plusax(upFact / omiga.fac_a, f1 x e._2)
292+
omiga.fac_a *= preFact
293+
omiga.fac_b *= preFact
294+
omiga.fac_b -= stepSize
295+
296+
}
297+
Iterator(omiga.toDense())
298+
299+
}, true).treeAggregate(new DenseVector(weights.size))(seqOp = (c, w) => {
300+
c += w
301+
}, combOp = (c1, c2) => {
302+
c1 += c2
303+
}) /= (partsNum)
304+
305+
306+
307+
}
308+
/**
309+
* In this method, we give the cost of the regularizer.
310+
*
311+
* @param weight
312+
* @param regParam
313+
* @return regCost
314+
*/
315+
override def getRegVal(weight: Vector, regParam: Double): Double = {
316+
val norm1= weight.norm1()
317+
val norm2=weight.norm2()
318+
regParam*(0.5*alpha*norm2*norm2+beta*norm1)
319+
320+
}
321+
}
322+

0 commit comments

Comments (0)