Skip to content
Snippets Groups Projects
Commit 9753835c authored by Imran Younus's avatar Imran Younus Committed by Xiangrui Meng
Browse files

[SPARK-12230][ML] WeightedLeastSquares.fit() should handle division by zero...

[SPARK-12230][ML] WeightedLeastSquares.fit() should handle division by zero properly if standard deviation of target variable is zero.

This fixes the behavior of WeightedLeastSquars.fit() when the standard deviation of the target variable is zero. If the fitIntercept is true, there is no need to train.

Author: Imran Younus <iyounus@us.ibm.com>

Closes #10274 from iyounus/SPARK-12230_bug_fix_in_weighted_least_squares.
parent 9bb35c5b
No related branches found
No related tags found
No related merge requests found
...@@ -86,6 +86,24 @@ private[ml] class WeightedLeastSquares( ...@@ -86,6 +86,24 @@ private[ml] class WeightedLeastSquares(
val aaBar = summary.aaBar val aaBar = summary.aaBar
val aaValues = aaBar.values val aaValues = aaBar.values
if (bStd == 0) {
if (fitIntercept) {
logWarning(s"The standard deviation of the label is zero, so the coefficients will be " +
s"zeros and the intercept will be the mean of the label; as a result, " +
s"training is not needed.")
val coefficients = new DenseVector(Array.ofDim(k-1))
val intercept = bBar
val diagInvAtWA = new DenseVector(Array(0D))
return new WeightedLeastSquaresModel(coefficients, intercept, diagInvAtWA)
} else {
require(!(regParam > 0.0 && standardizeLabel),
"The standard deviation of the label is zero. " +
"Model cannot be regularized with standardization=true")
logWarning(s"The standard deviation of the label is zero. " +
"Consider setting fitIntercept=true.")
}
}
// add regularization to diagonals // add regularization to diagonals
var i = 0 var i = 0
var j = 2 var j = 2
...@@ -94,8 +112,7 @@ private[ml] class WeightedLeastSquares( ...@@ -94,8 +112,7 @@ private[ml] class WeightedLeastSquares(
if (standardizeFeatures) { if (standardizeFeatures) {
lambda *= aVar(j - 2) lambda *= aVar(j - 2)
} }
if (standardizeLabel) { if (standardizeLabel && bStd != 0) {
// TODO: handle the case when bStd = 0
lambda /= bStd lambda /= bStd
} }
aaValues(i) += lambda aaValues(i) += lambda
......
...@@ -27,6 +27,7 @@ import org.apache.spark.rdd.RDD ...@@ -27,6 +27,7 @@ import org.apache.spark.rdd.RDD
class WeightedLeastSquaresSuite extends SparkFunSuite with MLlibTestSparkContext { class WeightedLeastSquaresSuite extends SparkFunSuite with MLlibTestSparkContext {
private var instances: RDD[Instance] = _ private var instances: RDD[Instance] = _
private var instancesConstLabel: RDD[Instance] = _
override def beforeAll(): Unit = { override def beforeAll(): Unit = {
super.beforeAll() super.beforeAll()
...@@ -43,6 +44,20 @@ class WeightedLeastSquaresSuite extends SparkFunSuite with MLlibTestSparkContext ...@@ -43,6 +44,20 @@ class WeightedLeastSquaresSuite extends SparkFunSuite with MLlibTestSparkContext
Instance(23.0, 3.0, Vectors.dense(2.0, 11.0)), Instance(23.0, 3.0, Vectors.dense(2.0, 11.0)),
Instance(29.0, 4.0, Vectors.dense(3.0, 13.0)) Instance(29.0, 4.0, Vectors.dense(3.0, 13.0))
), 2) ), 2)
/*
R code:
A <- matrix(c(0, 1, 2, 3, 5, 7, 11, 13), 4, 2)
b.const <- c(17, 17, 17, 17)
w <- c(1, 2, 3, 4)
*/
instancesConstLabel = sc.parallelize(Seq(
Instance(17.0, 1.0, Vectors.dense(0.0, 5.0).toSparse),
Instance(17.0, 2.0, Vectors.dense(1.0, 7.0)),
Instance(17.0, 3.0, Vectors.dense(2.0, 11.0)),
Instance(17.0, 4.0, Vectors.dense(3.0, 13.0))
), 2)
} }
test("WLS against lm") { test("WLS against lm") {
...@@ -65,15 +80,59 @@ class WeightedLeastSquaresSuite extends SparkFunSuite with MLlibTestSparkContext ...@@ -65,15 +80,59 @@ class WeightedLeastSquaresSuite extends SparkFunSuite with MLlibTestSparkContext
var idx = 0 var idx = 0
for (fitIntercept <- Seq(false, true)) { for (fitIntercept <- Seq(false, true)) {
val wls = new WeightedLeastSquares( for (standardization <- Seq(false, true)) {
fitIntercept, regParam = 0.0, standardizeFeatures = false, standardizeLabel = false) val wls = new WeightedLeastSquares(
.fit(instances) fitIntercept, regParam = 0.0, standardizeFeatures = standardization,
val actual = Vectors.dense(wls.intercept, wls.coefficients(0), wls.coefficients(1)) standardizeLabel = standardization).fit(instances)
assert(actual ~== expected(idx) absTol 1e-4) val actual = Vectors.dense(wls.intercept, wls.coefficients(0), wls.coefficients(1))
assert(actual ~== expected(idx) absTol 1e-4)
}
idx += 1
}
}
test("WLS against lm when label is constant and no regularization") {
/*
R code:
df.const.label <- as.data.frame(cbind(A, b.const))
for (formula in c(b.const ~ . -1, b.const ~ .)) {
model <- lm(formula, data=df.const.label, weights=w)
print(as.vector(coef(model)))
}
[1] -9.221298 3.394343
[1] 17 0 0
*/
val expected = Seq(
Vectors.dense(0.0, -9.221298, 3.394343),
Vectors.dense(17.0, 0.0, 0.0))
var idx = 0
for (fitIntercept <- Seq(false, true)) {
for (standardization <- Seq(false, true)) {
val wls = new WeightedLeastSquares(
fitIntercept, regParam = 0.0, standardizeFeatures = standardization,
standardizeLabel = standardization).fit(instancesConstLabel)
val actual = Vectors.dense(wls.intercept, wls.coefficients(0), wls.coefficients(1))
assert(actual ~== expected(idx) absTol 1e-4)
}
idx += 1 idx += 1
} }
} }
test("WLS with regularization when label is constant") {
// if regParam is non-zero and standardization is true, the problem is ill-defined and
// an exception is thrown.
val wls = new WeightedLeastSquares(
fitIntercept = false, regParam = 0.1, standardizeFeatures = true,
standardizeLabel = true)
intercept[IllegalArgumentException]{
wls.fit(instancesConstLabel)
}
}
test("WLS against glmnet") { test("WLS against glmnet") {
/* /*
R code: R code:
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment