Skip to content
Snippets Groups Projects
Commit dc06b528 authored by Shivaram Venkataraman's avatar Shivaram Venkataraman
Browse files

Add an option to turn off data validation, test it.

Also moves addIntercept to have default true to make it similar
to validateData option
parent c8746253
No related branches found
No related tags found
No related merge requests found
......@@ -54,8 +54,7 @@ class LogisticRegressionWithSGD private (
var stepSize: Double,
var numIterations: Int,
var regParam: Double,
var miniBatchFraction: Double,
var addIntercept: Boolean)
var miniBatchFraction: Double)
extends GeneralizedLinearAlgorithm[LogisticRegressionModel]
with Serializable {
......@@ -71,7 +70,7 @@ class LogisticRegressionWithSGD private (
/**
* Construct a LogisticRegression object with default parameters
*/
def this() = this(1.0, 100, 0.0, 1.0, true)
def this() = this(1.0, 100, 0.0, 1.0)
def createModel(weights: Array[Double], intercept: Double) = {
new LogisticRegressionModel(weights, intercept)
......@@ -108,7 +107,7 @@ object LogisticRegressionWithSGD {
initialWeights: Array[Double])
: LogisticRegressionModel =
{
new LogisticRegressionWithSGD(stepSize, numIterations, 0.0, miniBatchFraction, true).run(
new LogisticRegressionWithSGD(stepSize, numIterations, 0.0, miniBatchFraction).run(
input, initialWeights)
}
......@@ -131,7 +130,7 @@ object LogisticRegressionWithSGD {
miniBatchFraction: Double)
: LogisticRegressionModel =
{
new LogisticRegressionWithSGD(stepSize, numIterations, 0.0, miniBatchFraction, true).run(
new LogisticRegressionWithSGD(stepSize, numIterations, 0.0, miniBatchFraction).run(
input)
}
......
......@@ -54,8 +54,7 @@ class SVMWithSGD private (
var stepSize: Double,
var numIterations: Int,
var regParam: Double,
var miniBatchFraction: Double,
var addIntercept: Boolean)
var miniBatchFraction: Double)
extends GeneralizedLinearAlgorithm[SVMModel] with Serializable {
val gradient = new HingeGradient()
......@@ -71,7 +70,7 @@ class SVMWithSGD private (
/**
* Construct a SVM object with default parameters
*/
def this() = this(1.0, 100, 1.0, 1.0, true)
def this() = this(1.0, 100, 1.0, 1.0)
def createModel(weights: Array[Double], intercept: Double) = {
new SVMModel(weights, intercept)
......@@ -107,7 +106,7 @@ object SVMWithSGD {
initialWeights: Array[Double])
: SVMModel =
{
new SVMWithSGD(stepSize, numIterations, regParam, miniBatchFraction, true).run(input,
new SVMWithSGD(stepSize, numIterations, regParam, miniBatchFraction).run(input,
initialWeights)
}
......@@ -131,7 +130,7 @@ object SVMWithSGD {
miniBatchFraction: Double)
: SVMModel =
{
new SVMWithSGD(stepSize, numIterations, regParam, miniBatchFraction, true).run(input)
new SVMWithSGD(stepSize, numIterations, regParam, miniBatchFraction).run(input)
}
/**
......
......@@ -87,13 +87,15 @@ abstract class GeneralizedLinearAlgorithm[M <: GeneralizedLinearModel]
val optimizer: Optimizer
protected var addIntercept: Boolean = true
protected var validateData: Boolean = true
/**
* Create a model given the weights and intercept
*/
protected def createModel(weights: Array[Double], intercept: Double): M
protected var addIntercept: Boolean
/**
* Set if the algorithm should add an intercept. Default true.
*/
......@@ -102,6 +104,14 @@ abstract class GeneralizedLinearAlgorithm[M <: GeneralizedLinearModel]
this
}
/**
* Set if the algorithm should validate data before training. Default true.
*/
def setValidateData(validateData: Boolean): this.type = {
this.validateData = validateData
this
}
/**
* Run the algorithm with the configured parameters on an input
* RDD of LabeledPoint entries.
......@@ -119,7 +129,7 @@ abstract class GeneralizedLinearAlgorithm[M <: GeneralizedLinearModel]
def run(input: RDD[LabeledPoint], initialWeights: Array[Double]) : M = {
// Check the data properties before running the optimizer
if (!validators.forall(func => func(input))) {
if (validateData && !validators.forall(func => func(input))) {
throw new SparkException("Input validation failed.")
}
......
......@@ -48,8 +48,7 @@ class LassoWithSGD private (
var stepSize: Double,
var numIterations: Int,
var regParam: Double,
var miniBatchFraction: Double,
var addIntercept: Boolean)
var miniBatchFraction: Double)
extends GeneralizedLinearAlgorithm[LassoModel]
with Serializable {
......@@ -63,7 +62,7 @@ class LassoWithSGD private (
/**
* Construct a Lasso object with default parameters
*/
def this() = this(1.0, 100, 1.0, 1.0, true)
def this() = this(1.0, 100, 1.0, 1.0)
def createModel(weights: Array[Double], intercept: Double) = {
new LassoModel(weights, intercept)
......@@ -98,7 +97,7 @@ object LassoWithSGD {
initialWeights: Array[Double])
: LassoModel =
{
new LassoWithSGD(stepSize, numIterations, regParam, miniBatchFraction, true).run(input,
new LassoWithSGD(stepSize, numIterations, regParam, miniBatchFraction).run(input,
initialWeights)
}
......@@ -121,7 +120,7 @@ object LassoWithSGD {
miniBatchFraction: Double)
: LassoModel =
{
new LassoWithSGD(stepSize, numIterations, regParam, miniBatchFraction, true).run(input)
new LassoWithSGD(stepSize, numIterations, regParam, miniBatchFraction).run(input)
}
/**
......
......@@ -162,5 +162,8 @@ class SVMSuite extends FunSuite with BeforeAndAfterAll {
intercept[spark.SparkException] {
val model = SVMWithSGD.train(testRDDInvalid, 100)
}
// Turning off data validation should not throw an exception
val noValidationModel = new SVMWithSGD().setValidateData(false).run(testRDDInvalid)
}
}
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment