diff --git a/mllib/src/main/scala/org/apache/spark/mllib/regression/GeneralizedLinearAlgorithm.scala b/mllib/src/main/scala/org/apache/spark/mllib/regression/GeneralizedLinearAlgorithm.scala index b9621530efa22457b8fb0dacbb0eebd187bd79b2..3e1ed91bf67291f773ad18bdfa61c30f18ece83c 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/regression/GeneralizedLinearAlgorithm.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/regression/GeneralizedLinearAlgorithm.scala @@ -136,25 +136,28 @@ abstract class GeneralizedLinearAlgorithm[M <: GeneralizedLinearModel] // Prepend an extra variable consisting of all 1.0's for the intercept. val data = if (addIntercept) { - input.map(labeledPoint => (labeledPoint.label, labeledPoint.features.+:(1.0))) + input.map(labeledPoint => (labeledPoint.label, 1.0 +: labeledPoint.features)) } else { input.map(labeledPoint => (labeledPoint.label, labeledPoint.features)) } val initialWeightsWithIntercept = if (addIntercept) { - initialWeights.+:(1.0) + 0.0 +: initialWeights } else { initialWeights } - val weights = optimizer.optimize(data, initialWeightsWithIntercept) - val intercept = weights(0) - val weightsScaled = weights.tail + val weightsWithIntercept = optimizer.optimize(data, initialWeightsWithIntercept) - val model = createModel(weightsScaled, intercept) + val (intercept, weights) = if (addIntercept) { + (weightsWithIntercept(0), weightsWithIntercept.tail) + } else { + (0.0, weightsWithIntercept) + } + + logInfo("Final weights " + weights.mkString(",")) + logInfo("Final intercept " + intercept) - logInfo("Final model weights " + model.weights.mkString(",")) - logInfo("Final model intercept " + model.intercept) - model + createModel(weights, intercept) } } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/regression/Lasso.scala b/mllib/src/main/scala/org/apache/spark/mllib/regression/Lasso.scala index fb2bc9b92a51c21f9c9a55a781eaaa0b534a5787..be63ce8538fefc2689fa7f7a094a28734f25fc89 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/regression/Lasso.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/regression/Lasso.scala @@ -36,8 +36,10 @@ class LassoModel( extends GeneralizedLinearModel(weights, intercept) with RegressionModel with Serializable { - override def predictPoint(dataMatrix: DoubleMatrix, weightMatrix: DoubleMatrix, - intercept: Double) = { + override def predictPoint( + dataMatrix: DoubleMatrix, + weightMatrix: DoubleMatrix, + intercept: Double): Double = { dataMatrix.dot(weightMatrix) + intercept } } @@ -66,7 +68,7 @@ class LassoWithSGD private ( .setMiniBatchFraction(miniBatchFraction) // We don't want to penalize the intercept, so set this to false. - setIntercept(false) + super.setIntercept(false) var yMean = 0.0 var xColMean: DoubleMatrix = _ @@ -77,10 +79,16 @@ class LassoWithSGD private ( */ def this() = this(1.0, 100, 1.0, 1.0) - def createModel(weights: Array[Double], intercept: Double) = { - val weightsMat = new DoubleMatrix(weights.length + 1, 1, (Array(intercept) ++ weights):_*) + override def setIntercept(addIntercept: Boolean): this.type = { + // TODO: Support adding intercept. + if (addIntercept) throw new UnsupportedOperationException("Adding intercept is not supported.") + this + } + + override def createModel(weights: Array[Double], intercept: Double) = { + val weightsMat = new DoubleMatrix(weights.length, 1, weights: _*) val weightsScaled = weightsMat.div(xColSd) - val interceptScaled = yMean - (weightsMat.transpose().mmul(xColMean.div(xColSd)).get(0)) + val interceptScaled = yMean - weightsMat.transpose().mmul(xColMean.div(xColSd)).get(0) new LassoModel(weightsScaled.data, interceptScaled) } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/regression/LinearRegression.scala b/mllib/src/main/scala/org/apache/spark/mllib/regression/LinearRegression.scala index 8ee40addb25d95c943fc3dd8bc34bdc976280897..f5f15d1a33f4d1ace1f5a47af60f08a2f0c97dbe 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/regression/LinearRegression.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/regression/LinearRegression.scala @@ -31,13 +31,14 @@ import org.jblas.DoubleMatrix * @param intercept Intercept computed for this model. */ class LinearRegressionModel( - override val weights: Array[Double], - override val intercept: Double) - extends GeneralizedLinearModel(weights, intercept) - with RegressionModel with Serializable { - - override def predictPoint(dataMatrix: DoubleMatrix, weightMatrix: DoubleMatrix, - intercept: Double) = { + override val weights: Array[Double], + override val intercept: Double) + extends GeneralizedLinearModel(weights, intercept) with RegressionModel with Serializable { + + override def predictPoint( + dataMatrix: DoubleMatrix, + weightMatrix: DoubleMatrix, + intercept: Double): Double = { dataMatrix.dot(weightMatrix) + intercept } } @@ -55,8 +56,7 @@ class LinearRegressionWithSGD private ( var stepSize: Double, var numIterations: Int, var miniBatchFraction: Double) - extends GeneralizedLinearAlgorithm[LinearRegressionModel] - with Serializable { + extends GeneralizedLinearAlgorithm[LinearRegressionModel] with Serializable { val gradient = new LeastSquaresGradient() val updater = new SimpleUpdater() @@ -69,7 +69,7 @@ class LinearRegressionWithSGD private ( */ def this() = this(1.0, 100, 1.0) - def createModel(weights: Array[Double], intercept: Double) = { + override def createModel(weights: Array[Double], intercept: Double) = { new LinearRegressionModel(weights, intercept) } } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/regression/RidgeRegression.scala b/mllib/src/main/scala/org/apache/spark/mllib/regression/RidgeRegression.scala index c504d3d40c773acaf48e819e04d6f7640c9ae2b8..feb100f21888ffc14aed0560afca3f50d2da0d7f 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/regression/RidgeRegression.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/regression/RidgeRegression.scala @@ -36,8 +36,10 @@ class RidgeRegressionModel( extends GeneralizedLinearModel(weights, intercept) with RegressionModel with Serializable { - override def predictPoint(dataMatrix: DoubleMatrix, weightMatrix: DoubleMatrix, - intercept: Double) = { + override def predictPoint( + dataMatrix: DoubleMatrix, + weightMatrix: DoubleMatrix, + intercept: Double): Double = { dataMatrix.dot(weightMatrix) + intercept } } @@ -67,7 +69,7 @@ class RidgeRegressionWithSGD private ( .setMiniBatchFraction(miniBatchFraction) // We don't want to penalize the intercept in RidgeRegression, so set this to false. - setIntercept(false) + super.setIntercept(false) var yMean = 0.0 var xColMean: DoubleMatrix = _ @@ -78,8 +80,14 @@ class RidgeRegressionWithSGD private ( */ def this() = this(1.0, 100, 1.0, 1.0) - def createModel(weights: Array[Double], intercept: Double) = { - val weightsMat = new DoubleMatrix(weights.length + 1, 1, (Array(intercept) ++ weights):_*) + override def setIntercept(addIntercept: Boolean): this.type = { + // TODO: Support adding intercept. + if (addIntercept) throw new UnsupportedOperationException("Adding intercept is not supported.") + this + } + + override def createModel(weights: Array[Double], intercept: Double) = { + val weightsMat = new DoubleMatrix(weights.length, 1, weights: _*) val weightsScaled = weightsMat.div(xColSd) val interceptScaled = yMean - weightsMat.transpose().mmul(xColMean.div(xColSd)).get(0) diff --git a/mllib/src/test/scala/org/apache/spark/mllib/regression/LassoSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/regression/LassoSuite.scala index 64e4cbb860f61d8f805fe956fc98fd50ff062a85..2cebac943e15f66643a6c57ee4b2934051f71eeb 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/regression/LassoSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/regression/LassoSuite.scala @@ -17,11 +17,8 @@ package org.apache.spark.mllib.regression - -import org.scalatest.BeforeAndAfterAll import org.scalatest.FunSuite -import org.apache.spark.SparkContext import org.apache.spark.mllib.util.{LinearDataGenerator, LocalSparkContext} class LassoSuite extends FunSuite with LocalSparkContext { @@ -104,4 +101,10 @@ class LassoSuite extends FunSuite with LocalSparkContext { // Test prediction on Array. validatePrediction(validationData.map(row => model.predict(row.features)), validationData) } + + test("do not support intercept") { + intercept[UnsupportedOperationException] { + new LassoWithSGD().setIntercept(true) + } + } } diff --git a/mllib/src/test/scala/org/apache/spark/mllib/regression/LinearRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/regression/LinearRegressionSuite.scala index 281f9df36ddb3a025d21e68b6f2d6e85dbabf6b9..5d251bcbf35dbc2c6334fa2ba7ad5992afb589ca 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/regression/LinearRegressionSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/regression/LinearRegressionSuite.scala @@ -17,7 +17,6 @@ package org.apache.spark.mllib.regression -import org.scalatest.BeforeAndAfterAll import org.scalatest.FunSuite import org.apache.spark.mllib.util.{LinearDataGenerator, LocalSparkContext} @@ -57,4 +56,29 @@ class LinearRegressionSuite extends FunSuite with LocalSparkContext { // Test prediction on Array. validatePrediction(validationData.map(row => model.predict(row.features)), validationData) } + + // Test if we can correctly learn Y = 10*X1 + 10*X2 + test("linear regression without intercept") { + val testRDD = sc.parallelize(LinearDataGenerator.generateLinearInput( + 0.0, Array(10.0, 10.0), 100, 42), 2).cache() + val linReg = new LinearRegressionWithSGD().setIntercept(false) + linReg.optimizer.setNumIterations(1000).setStepSize(1.0) + + val model = linReg.run(testRDD) + + assert(model.intercept === 0.0) + assert(model.weights.length === 2) + assert(model.weights(0) >= 9.0 && model.weights(0) <= 11.0) + assert(model.weights(1) >= 9.0 && model.weights(1) <= 11.0) + + val validationData = LinearDataGenerator.generateLinearInput( + 0.0, Array(10.0, 10.0), 100, 17) + val validationRDD = sc.parallelize(validationData, 2).cache() + + // Test prediction on RDD. + validatePrediction(model.predict(validationRDD.map(_.features)).collect(), validationData) + + // Test prediction on Array. + validatePrediction(validationData.map(row => model.predict(row.features)), validationData) + } } diff --git a/mllib/src/test/scala/org/apache/spark/mllib/regression/RidgeRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/regression/RidgeRegressionSuite.scala index 67dd06cc0f5eb9d1cb56dc5a231c6749bb5d5671..b2044ed0d8066793c5353a20627785cdf0ed7a55 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/regression/RidgeRegressionSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/regression/RidgeRegressionSuite.scala @@ -17,14 +17,11 @@ package org.apache.spark.mllib.regression - import org.jblas.DoubleMatrix -import org.scalatest.BeforeAndAfterAll import org.scalatest.FunSuite import org.apache.spark.mllib.util.{LinearDataGenerator, LocalSparkContext} - class RidgeRegressionSuite extends FunSuite with LocalSparkContext { def predictionError(predictions: Seq[Double], input: Seq[LabeledPoint]) = { @@ -74,4 +71,10 @@ class RidgeRegressionSuite extends FunSuite with LocalSparkContext { assert(ridgeErr < linearErr, "ridgeError (" + ridgeErr + ") was not less than linearError(" + linearErr + ")") } + + test("do not support intercept") { + intercept[UnsupportedOperationException] { + new RidgeRegressionWithSGD().setIntercept(true) + } + } }