From dc06b528790c69b2e6de85cba84266fea81dd4f4 Mon Sep 17 00:00:00 2001 From: Shivaram Venkataraman <shivaram@eecs.berkeley.edu> Date: Sun, 25 Aug 2013 23:14:35 -0700 Subject: [PATCH] Add an option to turn off data validation, test it. Also moves addIntercept to have default true to make it similar to validateData option --- .../classification/LogisticRegression.scala | 9 ++++----- .../scala/spark/mllib/classification/SVM.scala | 9 ++++----- .../regression/GeneralizedLinearAlgorithm.scala | 16 +++++++++++++--- .../scala/spark/mllib/regression/Lasso.scala | 9 ++++----- .../spark/mllib/classification/SVMSuite.scala | 3 +++ 5 files changed, 28 insertions(+), 18 deletions(-) diff --git a/mllib/src/main/scala/spark/mllib/classification/LogisticRegression.scala b/mllib/src/main/scala/spark/mllib/classification/LogisticRegression.scala index 474ca6e97c..482e4a6745 100644 --- a/mllib/src/main/scala/spark/mllib/classification/LogisticRegression.scala +++ b/mllib/src/main/scala/spark/mllib/classification/LogisticRegression.scala @@ -54,8 +54,7 @@ class LogisticRegressionWithSGD private ( var stepSize: Double, var numIterations: Int, var regParam: Double, - var miniBatchFraction: Double, - var addIntercept: Boolean) + var miniBatchFraction: Double) extends GeneralizedLinearAlgorithm[LogisticRegressionModel] with Serializable { @@ -71,7 +70,7 @@ class LogisticRegressionWithSGD private ( /** * Construct a LogisticRegression object with default parameters */ - def this() = this(1.0, 100, 0.0, 1.0, true) + def this() = this(1.0, 100, 0.0, 1.0) def createModel(weights: Array[Double], intercept: Double) = { new LogisticRegressionModel(weights, intercept) @@ -108,7 +107,7 @@ object LogisticRegressionWithSGD { initialWeights: Array[Double]) : LogisticRegressionModel = { - new LogisticRegressionWithSGD(stepSize, numIterations, 0.0, miniBatchFraction, true).run( + new LogisticRegressionWithSGD(stepSize, numIterations, 0.0, miniBatchFraction).run( input, initialWeights) } @@ -131,7 +130,7 @@ object LogisticRegressionWithSGD { miniBatchFraction: Double) : LogisticRegressionModel = { - new LogisticRegressionWithSGD(stepSize, numIterations, 0.0, miniBatchFraction, true).run( + new LogisticRegressionWithSGD(stepSize, numIterations, 0.0, miniBatchFraction).run( input) } diff --git a/mllib/src/main/scala/spark/mllib/classification/SVM.scala b/mllib/src/main/scala/spark/mllib/classification/SVM.scala index b680d81e86..69393cd7b0 100644 --- a/mllib/src/main/scala/spark/mllib/classification/SVM.scala +++ b/mllib/src/main/scala/spark/mllib/classification/SVM.scala @@ -54,8 +54,7 @@ class SVMWithSGD private ( var stepSize: Double, var numIterations: Int, var regParam: Double, - var miniBatchFraction: Double, - var addIntercept: Boolean) + var miniBatchFraction: Double) extends GeneralizedLinearAlgorithm[SVMModel] with Serializable { val gradient = new HingeGradient() @@ -71,7 +70,7 @@ class SVMWithSGD private ( /** * Construct a SVM object with default parameters */ - def this() = this(1.0, 100, 1.0, 1.0, true) + def this() = this(1.0, 100, 1.0, 1.0) def createModel(weights: Array[Double], intercept: Double) = { new SVMModel(weights, intercept) @@ -107,7 +106,7 @@ object SVMWithSGD { initialWeights: Array[Double]) : SVMModel = { - new SVMWithSGD(stepSize, numIterations, regParam, miniBatchFraction, true).run(input, + new SVMWithSGD(stepSize, numIterations, regParam, miniBatchFraction).run(input, initialWeights) } @@ -131,7 +130,7 @@ object SVMWithSGD { miniBatchFraction: Double) : SVMModel = { - new SVMWithSGD(stepSize, numIterations, regParam, miniBatchFraction, true).run(input) + new SVMWithSGD(stepSize, numIterations, regParam, miniBatchFraction).run(input) } /** diff --git a/mllib/src/main/scala/spark/mllib/regression/GeneralizedLinearAlgorithm.scala b/mllib/src/main/scala/spark/mllib/regression/GeneralizedLinearAlgorithm.scala index 03f991df39..d164d415d6 100644 --- a/mllib/src/main/scala/spark/mllib/regression/GeneralizedLinearAlgorithm.scala +++ b/mllib/src/main/scala/spark/mllib/regression/GeneralizedLinearAlgorithm.scala @@ -87,13 +87,15 @@ abstract class GeneralizedLinearAlgorithm[M <: GeneralizedLinearModel] val optimizer: Optimizer + protected var addIntercept: Boolean = true + + protected var validateData: Boolean = true + /** * Create a model given the weights and intercept */ protected def createModel(weights: Array[Double], intercept: Double): M - protected var addIntercept: Boolean - /** * Set if the algorithm should add an intercept. Default true. */ @@ -102,6 +104,14 @@ abstract class GeneralizedLinearAlgorithm[M <: GeneralizedLinearModel] this } + /** + * Set if the algorithm should validate data before training. Default true. + */ + def setValidateData(validateData: Boolean): this.type = { + this.validateData = validateData + this + } + /** * Run the algorithm with the configured parameters on an input * RDD of LabeledPoint entries. @@ -119,7 +129,7 @@ abstract class GeneralizedLinearAlgorithm[M <: GeneralizedLinearModel] def run(input: RDD[LabeledPoint], initialWeights: Array[Double]) : M = { // Check the data properties before running the optimizer - if (!validators.forall(func => func(input))) { + if (validateData && !validators.forall(func => func(input))) { throw new SparkException("Input validation failed.") } diff --git a/mllib/src/main/scala/spark/mllib/regression/Lasso.scala b/mllib/src/main/scala/spark/mllib/regression/Lasso.scala index 6bbc990a5a..89f791e85a 100644 --- a/mllib/src/main/scala/spark/mllib/regression/Lasso.scala +++ b/mllib/src/main/scala/spark/mllib/regression/Lasso.scala @@ -48,8 +48,7 @@ class LassoWithSGD private ( var stepSize: Double, var numIterations: Int, var regParam: Double, - var miniBatchFraction: Double, - var addIntercept: Boolean) + var miniBatchFraction: Double) extends GeneralizedLinearAlgorithm[LassoModel] with Serializable { @@ -63,7 +62,7 @@ class LassoWithSGD private ( /** * Construct a Lasso object with default parameters */ - def this() = this(1.0, 100, 1.0, 1.0, true) + def this() = this(1.0, 100, 1.0, 1.0) def createModel(weights: Array[Double], intercept: Double) = { new LassoModel(weights, intercept) @@ -98,7 +97,7 @@ object LassoWithSGD { initialWeights: Array[Double]) : LassoModel = { - new LassoWithSGD(stepSize, numIterations, regParam, miniBatchFraction, true).run(input, + new LassoWithSGD(stepSize, numIterations, regParam, miniBatchFraction).run(input, initialWeights) } @@ -121,7 +120,7 @@ object LassoWithSGD { miniBatchFraction: Double) : LassoModel = { - new LassoWithSGD(stepSize, numIterations, regParam, miniBatchFraction, true).run(input) + new LassoWithSGD(stepSize, numIterations, regParam, miniBatchFraction).run(input) } /** diff --git a/mllib/src/test/scala/spark/mllib/classification/SVMSuite.scala b/mllib/src/test/scala/spark/mllib/classification/SVMSuite.scala index 8fa9e4639b..894ae458ad 100644 --- a/mllib/src/test/scala/spark/mllib/classification/SVMSuite.scala +++ b/mllib/src/test/scala/spark/mllib/classification/SVMSuite.scala @@ -162,5 +162,8 @@ class SVMSuite extends FunSuite with BeforeAndAfterAll { intercept[spark.SparkException] { val model = SVMWithSGD.train(testRDDInvalid, 100) } + + // Turning off data validation should not throw an exception + val noValidationModel = new SVMWithSGD().setValidateData(false).run(testRDDInvalid) } } -- GitLab