From 9753835cf3acc135e61bf668223046e29306c80d Mon Sep 17 00:00:00 2001
From: Imran Younus <iyounus@us.ibm.com>
Date: Wed, 20 Jan 2016 11:16:59 -0800
Subject: [PATCH] [SPARK-12230][ML] WeightedLeastSquares.fit() should handle
 division by zero properly if standard deviation of target variable is zero.

This fixes the behavior of WeightedLeastSquars.fit() when the standard deviation of the target variable is zero. If the fitIntercept is true, there is no need to train.

Author: Imran Younus <iyounus@us.ibm.com>

Closes #10274 from iyounus/SPARK-12230_bug_fix_in_weighted_least_squares.
---
 .../spark/ml/optim/WeightedLeastSquares.scala | 21 +++++-
 .../ml/optim/WeightedLeastSquaresSuite.scala  | 69 +++++++++++++++++--
 2 files changed, 83 insertions(+), 7 deletions(-)

diff --git a/mllib/src/main/scala/org/apache/spark/ml/optim/WeightedLeastSquares.scala b/mllib/src/main/scala/org/apache/spark/ml/optim/WeightedLeastSquares.scala
index 8617722ae5..797870eb8c 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/optim/WeightedLeastSquares.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/optim/WeightedLeastSquares.scala
@@ -86,6 +86,24 @@ private[ml] class WeightedLeastSquares(
     val aaBar = summary.aaBar
     val aaValues = aaBar.values
 
+    if (bStd == 0) {
+      if (fitIntercept) {
+        logWarning(s"The standard deviation of the label is zero, so the coefficients will be " +
+          s"zeros and the intercept will be the mean of the label; as a result, " +
+          s"training is not needed.")
+        val coefficients = new DenseVector(Array.ofDim(k-1))
+        val intercept = bBar
+        val diagInvAtWA = new DenseVector(Array(0D))
+        return new WeightedLeastSquaresModel(coefficients, intercept, diagInvAtWA)
+      } else {
+        require(!(regParam > 0.0 && standardizeLabel),
+          "The standard deviation of the label is zero. " +
+            "Model cannot be regularized with standardization=true")
+        logWarning(s"The standard deviation of the label is zero. " +
+          "Consider setting fitIntercept=true.")
+      }
+    }
+
     // add regularization to diagonals
     var i = 0
     var j = 2
@@ -94,8 +112,7 @@ private[ml] class WeightedLeastSquares(
       if (standardizeFeatures) {
         lambda *= aVar(j - 2)
       }
-      if (standardizeLabel) {
-        // TODO: handle the case when bStd = 0
+      if (standardizeLabel && bStd != 0) {
         lambda /= bStd
       }
       aaValues(i) += lambda
diff --git a/mllib/src/test/scala/org/apache/spark/ml/optim/WeightedLeastSquaresSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/optim/WeightedLeastSquaresSuite.scala
index b542ba3dc5..0b58a9821f 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/optim/WeightedLeastSquaresSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/optim/WeightedLeastSquaresSuite.scala
@@ -27,6 +27,7 @@ import org.apache.spark.rdd.RDD
 class WeightedLeastSquaresSuite extends SparkFunSuite with MLlibTestSparkContext {
 
   private var instances: RDD[Instance] = _
+  private var instancesConstLabel: RDD[Instance] = _
 
   override def beforeAll(): Unit = {
     super.beforeAll()
@@ -43,6 +44,20 @@ class WeightedLeastSquaresSuite extends SparkFunSuite with MLlibTestSparkContext
       Instance(23.0, 3.0, Vectors.dense(2.0, 11.0)),
       Instance(29.0, 4.0, Vectors.dense(3.0, 13.0))
     ), 2)
+
+    /*
+       R code:
+
+       A <- matrix(c(0, 1, 2, 3, 5, 7, 11, 13), 4, 2)
+       b.const <- c(17, 17, 17, 17)
+       w <- c(1, 2, 3, 4)
+     */
+    instancesConstLabel = sc.parallelize(Seq(
+      Instance(17.0, 1.0, Vectors.dense(0.0, 5.0).toSparse),
+      Instance(17.0, 2.0, Vectors.dense(1.0, 7.0)),
+      Instance(17.0, 3.0, Vectors.dense(2.0, 11.0)),
+      Instance(17.0, 4.0, Vectors.dense(3.0, 13.0))
+    ), 2)
   }
 
   test("WLS against lm") {
@@ -65,15 +80,59 @@ class WeightedLeastSquaresSuite extends SparkFunSuite with MLlibTestSparkContext
 
     var idx = 0
     for (fitIntercept <- Seq(false, true)) {
-      val wls = new WeightedLeastSquares(
-        fitIntercept, regParam = 0.0, standardizeFeatures = false, standardizeLabel = false)
-        .fit(instances)
-      val actual = Vectors.dense(wls.intercept, wls.coefficients(0), wls.coefficients(1))
-      assert(actual ~== expected(idx) absTol 1e-4)
+       for (standardization <- Seq(false, true)) {
+         val wls = new WeightedLeastSquares(
+           fitIntercept, regParam = 0.0, standardizeFeatures = standardization,
+           standardizeLabel = standardization).fit(instances)
+         val actual = Vectors.dense(wls.intercept, wls.coefficients(0), wls.coefficients(1))
+         assert(actual ~== expected(idx) absTol 1e-4)
+       }
+      idx += 1
+    }
+  }
+
+  test("WLS against lm when label is constant and no regularization") {
+    /*
+       R code:
+
+       df.const.label <- as.data.frame(cbind(A, b.const))
+       for (formula in c(b.const ~ . -1, b.const ~ .)) {
+         model <- lm(formula, data=df.const.label, weights=w)
+         print(as.vector(coef(model)))
+       }
+
+      [1] -9.221298  3.394343
+      [1] 17  0  0
+    */
+
+    val expected = Seq(
+      Vectors.dense(0.0, -9.221298, 3.394343),
+      Vectors.dense(17.0, 0.0, 0.0))
+
+    var idx = 0
+    for (fitIntercept <- Seq(false, true)) {
+      for (standardization <- Seq(false, true)) {
+        val wls = new WeightedLeastSquares(
+          fitIntercept, regParam = 0.0, standardizeFeatures = standardization,
+          standardizeLabel = standardization).fit(instancesConstLabel)
+        val actual = Vectors.dense(wls.intercept, wls.coefficients(0), wls.coefficients(1))
+        assert(actual ~== expected(idx) absTol 1e-4)
+      }
       idx += 1
     }
   }
 
+  test("WLS with regularization when label is constant") {
+    // if regParam is non-zero and standardization is true, the problem is ill-defined and
+    // an exception is thrown.
+    val wls = new WeightedLeastSquares(
+      fitIntercept = false, regParam = 0.1, standardizeFeatures = true,
+      standardizeLabel = true)
+    intercept[IllegalArgumentException]{
+      wls.fit(instancesConstLabel)
+    }
+  }
+
   test("WLS against glmnet") {
     /*
        R code:
-- 
GitLab