From dc06b528790c69b2e6de85cba84266fea81dd4f4 Mon Sep 17 00:00:00 2001
From: Shivaram Venkataraman <shivaram@eecs.berkeley.edu>
Date: Sun, 25 Aug 2013 23:14:35 -0700
Subject: [PATCH] Add an option to turn off data validation, and test it. Also
 move addIntercept to default to true, making it consistent with the
 validateData option

---
 .../classification/LogisticRegression.scala      |  9 ++++-----
 .../scala/spark/mllib/classification/SVM.scala   |  9 ++++-----
 .../regression/GeneralizedLinearAlgorithm.scala  | 16 +++++++++++++---
 .../scala/spark/mllib/regression/Lasso.scala     |  9 ++++-----
 .../spark/mllib/classification/SVMSuite.scala    |  3 +++
 5 files changed, 28 insertions(+), 18 deletions(-)

diff --git a/mllib/src/main/scala/spark/mllib/classification/LogisticRegression.scala b/mllib/src/main/scala/spark/mllib/classification/LogisticRegression.scala
index 474ca6e97c..482e4a6745 100644
--- a/mllib/src/main/scala/spark/mllib/classification/LogisticRegression.scala
+++ b/mllib/src/main/scala/spark/mllib/classification/LogisticRegression.scala
@@ -54,8 +54,7 @@ class LogisticRegressionWithSGD private (
     var stepSize: Double,
     var numIterations: Int,
     var regParam: Double,
-    var miniBatchFraction: Double,
-    var addIntercept: Boolean)
+    var miniBatchFraction: Double)
   extends GeneralizedLinearAlgorithm[LogisticRegressionModel]
   with Serializable {
 
@@ -71,7 +70,7 @@ class LogisticRegressionWithSGD private (
   /**
    * Construct a LogisticRegression object with default parameters
    */
-  def this() = this(1.0, 100, 0.0, 1.0, true)
+  def this() = this(1.0, 100, 0.0, 1.0)
 
   def createModel(weights: Array[Double], intercept: Double) = {
     new LogisticRegressionModel(weights, intercept)
@@ -108,7 +107,7 @@ object LogisticRegressionWithSGD {
       initialWeights: Array[Double])
     : LogisticRegressionModel =
   {
-    new LogisticRegressionWithSGD(stepSize, numIterations, 0.0, miniBatchFraction, true).run(
+    new LogisticRegressionWithSGD(stepSize, numIterations, 0.0, miniBatchFraction).run(
       input, initialWeights)
   }
 
@@ -131,7 +130,7 @@ object LogisticRegressionWithSGD {
       miniBatchFraction: Double)
     : LogisticRegressionModel =
   {
-    new LogisticRegressionWithSGD(stepSize, numIterations, 0.0, miniBatchFraction, true).run(
+    new LogisticRegressionWithSGD(stepSize, numIterations, 0.0, miniBatchFraction).run(
       input)
   }
 
diff --git a/mllib/src/main/scala/spark/mllib/classification/SVM.scala b/mllib/src/main/scala/spark/mllib/classification/SVM.scala
index b680d81e86..69393cd7b0 100644
--- a/mllib/src/main/scala/spark/mllib/classification/SVM.scala
+++ b/mllib/src/main/scala/spark/mllib/classification/SVM.scala
@@ -54,8 +54,7 @@ class SVMWithSGD private (
     var stepSize: Double,
     var numIterations: Int,
     var regParam: Double,
-    var miniBatchFraction: Double,
-    var addIntercept: Boolean)
+    var miniBatchFraction: Double)
   extends GeneralizedLinearAlgorithm[SVMModel] with Serializable {
 
   val gradient = new HingeGradient()
@@ -71,7 +70,7 @@ class SVMWithSGD private (
   /**
    * Construct a SVM object with default parameters
    */
-  def this() = this(1.0, 100, 1.0, 1.0, true)
+  def this() = this(1.0, 100, 1.0, 1.0)
 
   def createModel(weights: Array[Double], intercept: Double) = {
     new SVMModel(weights, intercept)
@@ -107,7 +106,7 @@ object SVMWithSGD {
       initialWeights: Array[Double])
     : SVMModel =
   {
-    new SVMWithSGD(stepSize, numIterations, regParam, miniBatchFraction, true).run(input,
+    new SVMWithSGD(stepSize, numIterations, regParam, miniBatchFraction).run(input,
       initialWeights)
   }
 
@@ -131,7 +130,7 @@ object SVMWithSGD {
       miniBatchFraction: Double)
     : SVMModel =
   {
-    new SVMWithSGD(stepSize, numIterations, regParam, miniBatchFraction, true).run(input)
+    new SVMWithSGD(stepSize, numIterations, regParam, miniBatchFraction).run(input)
   }
 
   /**
diff --git a/mllib/src/main/scala/spark/mllib/regression/GeneralizedLinearAlgorithm.scala b/mllib/src/main/scala/spark/mllib/regression/GeneralizedLinearAlgorithm.scala
index 03f991df39..d164d415d6 100644
--- a/mllib/src/main/scala/spark/mllib/regression/GeneralizedLinearAlgorithm.scala
+++ b/mllib/src/main/scala/spark/mllib/regression/GeneralizedLinearAlgorithm.scala
@@ -87,13 +87,15 @@ abstract class GeneralizedLinearAlgorithm[M <: GeneralizedLinearModel]
 
   val optimizer: Optimizer
 
+  protected var addIntercept: Boolean = true
+
+  protected var validateData: Boolean = true
+
   /**
    * Create a model given the weights and intercept
    */
   protected def createModel(weights: Array[Double], intercept: Double): M
 
-  protected var addIntercept: Boolean
-
   /**
    * Set if the algorithm should add an intercept. Default true.
    */
@@ -102,6 +104,14 @@ abstract class GeneralizedLinearAlgorithm[M <: GeneralizedLinearModel]
     this
   }
 
+  /**
+   * Set if the algorithm should validate data before training. Default true.
+   */
+  def setValidateData(validateData: Boolean): this.type = {
+    this.validateData = validateData
+    this
+  }
+
   /**
    * Run the algorithm with the configured parameters on an input
    * RDD of LabeledPoint entries.
@@ -119,7 +129,7 @@ abstract class GeneralizedLinearAlgorithm[M <: GeneralizedLinearModel]
   def run(input: RDD[LabeledPoint], initialWeights: Array[Double]) : M = {
 
     // Check the data properties before running the optimizer
-    if (!validators.forall(func => func(input))) {
+    if (validateData && !validators.forall(func => func(input))) {
       throw new SparkException("Input validation failed.")
     }
 
diff --git a/mllib/src/main/scala/spark/mllib/regression/Lasso.scala b/mllib/src/main/scala/spark/mllib/regression/Lasso.scala
index 6bbc990a5a..89f791e85a 100644
--- a/mllib/src/main/scala/spark/mllib/regression/Lasso.scala
+++ b/mllib/src/main/scala/spark/mllib/regression/Lasso.scala
@@ -48,8 +48,7 @@ class LassoWithSGD private (
     var stepSize: Double,
     var numIterations: Int,
     var regParam: Double,
-    var miniBatchFraction: Double,
-    var addIntercept: Boolean)
+    var miniBatchFraction: Double)
   extends GeneralizedLinearAlgorithm[LassoModel]
   with Serializable {
 
@@ -63,7 +62,7 @@ class LassoWithSGD private (
   /**
    * Construct a Lasso object with default parameters
    */
-  def this() = this(1.0, 100, 1.0, 1.0, true)
+  def this() = this(1.0, 100, 1.0, 1.0)
 
   def createModel(weights: Array[Double], intercept: Double) = {
     new LassoModel(weights, intercept)
@@ -98,7 +97,7 @@ object LassoWithSGD {
       initialWeights: Array[Double])
     : LassoModel =
   {
-    new LassoWithSGD(stepSize, numIterations, regParam, miniBatchFraction, true).run(input,
+    new LassoWithSGD(stepSize, numIterations, regParam, miniBatchFraction).run(input,
         initialWeights)
   }
 
@@ -121,7 +120,7 @@ object LassoWithSGD {
       miniBatchFraction: Double)
     : LassoModel =
   {
-    new LassoWithSGD(stepSize, numIterations, regParam, miniBatchFraction, true).run(input)
+    new LassoWithSGD(stepSize, numIterations, regParam, miniBatchFraction).run(input)
   }
 
   /**
diff --git a/mllib/src/test/scala/spark/mllib/classification/SVMSuite.scala b/mllib/src/test/scala/spark/mllib/classification/SVMSuite.scala
index 8fa9e4639b..894ae458ad 100644
--- a/mllib/src/test/scala/spark/mllib/classification/SVMSuite.scala
+++ b/mllib/src/test/scala/spark/mllib/classification/SVMSuite.scala
@@ -162,5 +162,8 @@ class SVMSuite extends FunSuite with BeforeAndAfterAll {
     intercept[spark.SparkException] {
       val model = SVMWithSGD.train(testRDDInvalid, 100)
     }
+
+    // Turning off data validation should not throw an exception
+    val noValidationModel = new SVMWithSGD().setValidateData(false).run(testRDDInvalid)
   }
 }
-- 
GitLab