diff --git a/mllib/src/main/scala/org/apache/spark/ml/optim/WeightedLeastSquares.scala b/mllib/src/main/scala/org/apache/spark/ml/optim/WeightedLeastSquares.scala index 8617722ae542f7dc5fab85626b042542abcb66e5..797870eb8ce8abf0869e20f50733313da2b57c9c 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/optim/WeightedLeastSquares.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/optim/WeightedLeastSquares.scala @@ -86,6 +86,24 @@ private[ml] class WeightedLeastSquares( val aaBar = summary.aaBar val aaValues = aaBar.values + if (bStd == 0) { + if (fitIntercept) { + logWarning(s"The standard deviation of the label is zero, so the coefficients will be " + + s"zeros and the intercept will be the mean of the label; as a result, " + + s"training is not needed.") + val coefficients = new DenseVector(Array.ofDim(k-1)) + val intercept = bBar + val diagInvAtWA = new DenseVector(Array(0D)) + return new WeightedLeastSquaresModel(coefficients, intercept, diagInvAtWA) + } else { + require(!(regParam > 0.0 && standardizeLabel), + "The standard deviation of the label is zero. " + + "Model cannot be regularized with standardization=true") + logWarning(s"The standard deviation of the label is zero. " + + "Consider setting fitIntercept=true.") + } + } + // add regularization to diagonals var i = 0 var j = 2 @@ -94,8 +112,7 @@ private[ml] class WeightedLeastSquares( if (standardizeFeatures) { lambda *= aVar(j - 2) } - if (standardizeLabel) { - // TODO: handle the case when bStd = 0 + if (standardizeLabel && bStd != 0) { lambda /= bStd } aaValues(i) += lambda diff --git a/mllib/src/test/scala/org/apache/spark/ml/optim/WeightedLeastSquaresSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/optim/WeightedLeastSquaresSuite.scala index b542ba3dc54d273446535fe76bf582368ac24e41..0b58a9821f57b85435ce62bfbd2ea3018379ccbe 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/optim/WeightedLeastSquaresSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/optim/WeightedLeastSquaresSuite.scala @@ -27,6 +27,7 @@ import org.apache.spark.rdd.RDD class WeightedLeastSquaresSuite extends SparkFunSuite with MLlibTestSparkContext { private var instances: RDD[Instance] = _ + private var instancesConstLabel: RDD[Instance] = _ override def beforeAll(): Unit = { super.beforeAll() @@ -43,6 +44,20 @@ class WeightedLeastSquaresSuite extends SparkFunSuite with MLlibTestSparkContext Instance(23.0, 3.0, Vectors.dense(2.0, 11.0)), Instance(29.0, 4.0, Vectors.dense(3.0, 13.0)) ), 2) + + /* + R code: + + A <- matrix(c(0, 1, 2, 3, 5, 7, 11, 13), 4, 2) + b.const <- c(17, 17, 17, 17) + w <- c(1, 2, 3, 4) + */ + instancesConstLabel = sc.parallelize(Seq( + Instance(17.0, 1.0, Vectors.dense(0.0, 5.0).toSparse), + Instance(17.0, 2.0, Vectors.dense(1.0, 7.0)), + Instance(17.0, 3.0, Vectors.dense(2.0, 11.0)), + Instance(17.0, 4.0, Vectors.dense(3.0, 13.0)) + ), 2) } test("WLS against lm") { @@ -65,15 +80,59 @@ class WeightedLeastSquaresSuite extends SparkFunSuite with MLlibTestSparkContext var idx = 0 for (fitIntercept <- Seq(false, true)) { - val wls = new WeightedLeastSquares( - fitIntercept, regParam = 0.0, standardizeFeatures = false, standardizeLabel = false) - .fit(instances) - val actual = Vectors.dense(wls.intercept, wls.coefficients(0), wls.coefficients(1)) - assert(actual ~== expected(idx) absTol 1e-4) + for (standardization <- Seq(false, true)) { + val wls = new WeightedLeastSquares( + fitIntercept, regParam = 0.0, standardizeFeatures = standardization, + standardizeLabel = standardization).fit(instances) + val actual = Vectors.dense(wls.intercept, wls.coefficients(0), wls.coefficients(1)) + assert(actual ~== expected(idx) absTol 1e-4) + } + idx += 1 + } + } + + test("WLS against lm when label is constant and no regularization") { + /* + R code: + + df.const.label <- as.data.frame(cbind(A, b.const)) + for (formula in c(b.const ~ . -1, b.const ~ .)) { + model <- lm(formula, data=df.const.label, weights=w) + print(as.vector(coef(model))) + } + + [1] -9.221298 3.394343 + [1] 17 0 0 + */ + + val expected = Seq( + Vectors.dense(0.0, -9.221298, 3.394343), + Vectors.dense(17.0, 0.0, 0.0)) + + var idx = 0 + for (fitIntercept <- Seq(false, true)) { + for (standardization <- Seq(false, true)) { + val wls = new WeightedLeastSquares( + fitIntercept, regParam = 0.0, standardizeFeatures = standardization, + standardizeLabel = standardization).fit(instancesConstLabel) + val actual = Vectors.dense(wls.intercept, wls.coefficients(0), wls.coefficients(1)) + assert(actual ~== expected(idx) absTol 1e-4) + } idx += 1 } } + test("WLS with regularization when label is constant") { + // if regParam is non-zero and standardization is true, the problem is ill-defined and + // an exception is thrown. + val wls = new WeightedLeastSquares( + fitIntercept = false, regParam = 0.1, standardizeFeatures = true, + standardizeLabel = true) + intercept[IllegalArgumentException]{ + wls.fit(instancesConstLabel) + } + } + test("WLS against glmnet") { /* R code: