diff --git a/mllib/src/main/scala/org/apache/spark/mllib/optimization/GradientDescent.scala b/mllib/src/main/scala/org/apache/spark/mllib/optimization/GradientDescent.scala index 8e87b98bac061deab72265ee653b348a129c323a..b967b22e818d34fab5ae6f3ff47b62c48d42afc6 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/optimization/GradientDescent.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/optimization/GradientDescent.scala @@ -149,7 +149,13 @@ object GradientDescent extends Logging { // Initialize weights as a column vector var weights = new DoubleMatrix(initialWeights.length, 1, initialWeights:_*) - var regVal = 0.0 + + /** + * For the first iteration, the regVal will be initialized as sum of sqrt of + * weights if it's L2 update; for L1 update; the same logic is followed. + */ + var regVal = updater.compute( + weights, new DoubleMatrix(initialWeights.length, 1), 0, 1, regParam)._2 for (i <- 1 to numIterations) { // Sample a subset (fraction miniBatchFraction) of the total data diff --git a/mllib/src/main/scala/org/apache/spark/mllib/optimization/Updater.scala b/mllib/src/main/scala/org/apache/spark/mllib/optimization/Updater.scala index 889a03e3e61d2ec0fcf6f4a95f098d6e49fd9ba4..bf8f731459e991b7a1db995b8728c63f56d161ee 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/optimization/Updater.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/optimization/Updater.scala @@ -111,6 +111,8 @@ class SquaredL2Updater extends Updater { val step = gradient.mul(thisIterStepSize) // add up both updates from the gradient of the loss (= step) as well as // the gradient of the regularizer (= regParam * weightsOld) + // w' = w - thisIterStepSize * (gradient + regParam * w) + // w' = (1 - thisIterStepSize * regParam) * w - thisIterStepSize * gradient val newWeights = weightsOld.mul(1.0 - thisIterStepSize * regParam).sub(step) (newWeights, 0.5 * pow(newWeights.norm2, 2.0) * regParam) } diff --git a/mllib/src/test/scala/org/apache/spark/mllib/optimization/GradientDescentSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/optimization/GradientDescentSuite.scala index a453de6767aa2214555ebbc607640ce7982b269d..631d0e2ad9cdb3696f9d38b124dd963fd5e6bfab 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/optimization/GradientDescentSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/optimization/GradientDescentSuite.scala @@ -104,4 +104,45 @@ class GradientDescentSuite extends FunSuite with LocalSparkContext with ShouldMa val lossDiff = loss.init.zip(loss.tail).map { case (lhs, rhs) => lhs - rhs } assert(lossDiff.count(_ > 0).toDouble / lossDiff.size > 0.8) } + + test("Test the loss and gradient of first iteration with regularization.") { + + val gradient = new LogisticGradient() + val updater = new SquaredL2Updater() + + // Add a extra variable consisting of all 1.0's for the intercept. + val testData = GradientDescentSuite.generateGDInput(2.0, -1.5, 10000, 42) + val data = testData.map { case LabeledPoint(label, features) => + label -> Array(1.0, features: _*) + } + + val dataRDD = sc.parallelize(data, 2).cache() + + // Prepare non-zero weights + val initialWeightsWithIntercept = Array(1.0, 0.5) + + val regParam0 = 0 + val (newWeights0, loss0) = GradientDescent.runMiniBatchSGD( + dataRDD, gradient, updater, 1, 1, regParam0, 1.0, initialWeightsWithIntercept) + + val regParam1 = 1 + val (newWeights1, loss1) = GradientDescent.runMiniBatchSGD( + dataRDD, gradient, updater, 1, 1, regParam1, 1.0, initialWeightsWithIntercept) + + def compareDouble(x: Double, y: Double, tol: Double = 1E-3): Boolean = { + math.abs(x - y) / (math.abs(y) + 1e-15) < tol + } + + assert(compareDouble( + loss1(0), + loss0(0) + (math.pow(initialWeightsWithIntercept(0), 2) + + math.pow(initialWeightsWithIntercept(1), 2)) / 2), + """For non-zero weights, the regVal should be \frac{1}{2}\sum_i w_i^2.""") + + assert( + compareDouble(newWeights1(0) , newWeights0(0) - initialWeightsWithIntercept(0)) && + compareDouble(newWeights1(1) , newWeights0(1) - initialWeightsWithIntercept(1)), + "The different between newWeights with/without regularization " + + "should be initialWeightsWithIntercept.") + } }