Commit 852d8107 authored by Evan Sparks

Merge pull request #819 from shivaram/sgd-cleanup

Change SVM to use {0,1} labels
parents ca716209 dc06b528
spark/mllib/classification/LogisticRegression.scala
@@ -17,12 +17,13 @@
package spark.mllib.classification
+ import scala.math.round
import spark.{Logging, RDD, SparkContext}
import spark.mllib.optimization._
import spark.mllib.regression._
import spark.mllib.util.MLUtils
- import scala.math.round
+ import spark.mllib.util.DataValidators
import org.jblas.DoubleMatrix
@@ -47,26 +48,29 @@ class LogisticRegressionModel(
/**
* Train a classification model for Logistic Regression using Stochastic Gradient Descent.
+ * NOTE: Labels used in Logistic Regression should be {0, 1}
*/
class LogisticRegressionWithSGD private (
var stepSize: Double,
var numIterations: Int,
var regParam: Double,
- var miniBatchFraction: Double,
- var addIntercept: Boolean)
+ var miniBatchFraction: Double)
extends GeneralizedLinearAlgorithm[LogisticRegressionModel]
with Serializable {
val gradient = new LogisticGradient()
val updater = new SimpleUpdater()
- val optimizer = new GradientDescent(gradient, updater).setStepSize(stepSize)
- .setNumIterations(numIterations)
- .setRegParam(regParam)
- .setMiniBatchFraction(miniBatchFraction)
+ override val optimizer = new GradientDescent(gradient, updater)
+ .setStepSize(stepSize)
+ .setNumIterations(numIterations)
+ .setRegParam(regParam)
+ .setMiniBatchFraction(miniBatchFraction)
+ override val validators = List(DataValidators.classificationLabels)
/**
* Construct a LogisticRegression object with default parameters
*/
- def this() = this(1.0, 100, 0.0, 1.0, true)
+ def this() = this(1.0, 100, 0.0, 1.0)
def createModel(weights: Array[Double], intercept: Double) = {
new LogisticRegressionModel(weights, intercept)
@@ -75,6 +79,7 @@ class LogisticRegressionWithSGD private (
/**
* Top-level methods for calling Logistic Regression.
+ * NOTE: Labels used in Logistic Regression should be {0, 1}
*/
object LogisticRegressionWithSGD {
// NOTE(shivaram): We use multiple train methods instead of default arguments to support
@@ -85,6 +90,7 @@ object LogisticRegressionWithSGD {
* number of iterations of gradient descent using the specified step size. Each iteration uses
* `miniBatchFraction` fraction of the data to calculate the gradient. The weights used in
* gradient descent are initialized using the initial weights provided.
+ * NOTE: Labels used in Logistic Regression should be {0, 1}
*
* @param input RDD of (label, array of features) pairs.
* @param numIterations Number of iterations of gradient descent to run.
@@ -101,7 +107,7 @@ object LogisticRegressionWithSGD {
initialWeights: Array[Double])
: LogisticRegressionModel =
{
- new LogisticRegressionWithSGD(stepSize, numIterations, 0.0, miniBatchFraction, true).run(
+ new LogisticRegressionWithSGD(stepSize, numIterations, 0.0, miniBatchFraction).run(
input, initialWeights)
}
@@ -109,6 +115,7 @@ object LogisticRegressionWithSGD {
* Train a logistic regression model given an RDD of (label, features) pairs. We run a fixed
* number of iterations of gradient descent using the specified step size. Each iteration uses
* `miniBatchFraction` fraction of the data to calculate the gradient.
+ * NOTE: Labels used in Logistic Regression should be {0, 1}
*
* @param input RDD of (label, array of features) pairs.
* @param numIterations Number of iterations of gradient descent to run.
@@ -123,7 +130,7 @@ object LogisticRegressionWithSGD {
miniBatchFraction: Double)
: LogisticRegressionModel =
{
- new LogisticRegressionWithSGD(stepSize, numIterations, 0.0, miniBatchFraction, true).run(
+ new LogisticRegressionWithSGD(stepSize, numIterations, 0.0, miniBatchFraction).run(
input)
}
@@ -131,6 +138,7 @@ object LogisticRegressionWithSGD {
* Train a logistic regression model given an RDD of (label, features) pairs. We run a fixed
* number of iterations of gradient descent using the specified step size. We use the entire data
* set to update the gradient in each iteration.
+ * NOTE: Labels used in Logistic Regression should be {0, 1}
*
* @param input RDD of (label, array of features) pairs.
* @param stepSize Step size to be used for each iteration of Gradient Descent.
@@ -151,6 +159,7 @@ object LogisticRegressionWithSGD {
* Train a logistic regression model given an RDD of (label, features) pairs. We run a fixed
* number of iterations of gradient descent using a step size of 1.0. We use the entire data set
* to update the gradient in each iteration.
+ * NOTE: Labels used in Logistic Regression should be {0, 1}
*
* @param input RDD of (label, array of features) pairs.
* @param numIterations Number of iterations of gradient descent to run.
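For reference, a minimal sketch of driving the cleaned-up API from user code, using the train(input, numIterations) overload shown in the diff above. The object name LRExample and the local SparkContext setup are assumptions for illustration; labels must be 0.0 or 1.0 or the new validator rejects the input:

import spark.SparkContext
import spark.mllib.classification.LogisticRegressionWithSGD
import spark.mllib.regression.LabeledPoint

object LRExample {
  def main(args: Array[String]) {
    val sc = new SparkContext("local", "LRExample")
    // Labels must now be 0.0 or 1.0; anything else fails DataValidators.classificationLabels.
    val points = sc.parallelize(Seq(
      LabeledPoint(1.0, Array(1.2, 0.9)),
      LabeledPoint(0.0, Array(-0.7, -1.1))))
    val model = LogisticRegressionWithSGD.train(points, 100)
    // predictPoint rounds the logistic output, so this prints 0.0 or 1.0
    println(model.predict(Array(0.5, 0.5)))
    sc.stop()
  }
}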
spark/mllib/classification/SVM.scala
@@ -18,10 +18,12 @@
package spark.mllib.classification
import scala.math.signum
import spark.{Logging, RDD, SparkContext}
import spark.mllib.optimization._
import spark.mllib.regression._
import spark.mllib.util.MLUtils
+ import spark.mllib.util.DataValidators
import org.jblas.DoubleMatrix
@@ -39,31 +41,36 @@ class SVMModel(
override def predictPoint(dataMatrix: DoubleMatrix, weightMatrix: DoubleMatrix,
intercept: Double) = {
- signum(dataMatrix.dot(weightMatrix) + intercept)
+ val margin = dataMatrix.dot(weightMatrix) + intercept
+ if (margin < 0) 0.0 else 1.0
}
}
/**
* Train an SVM using Stochastic Gradient Descent.
+ * NOTE: Labels used in SVM should be {0, 1}
*/
class SVMWithSGD private (
var stepSize: Double,
var numIterations: Int,
var regParam: Double,
- var miniBatchFraction: Double,
- var addIntercept: Boolean)
+ var miniBatchFraction: Double)
extends GeneralizedLinearAlgorithm[SVMModel] with Serializable {
val gradient = new HingeGradient()
val updater = new SquaredL2Updater()
- val optimizer = new GradientDescent(gradient, updater).setStepSize(stepSize)
- .setNumIterations(numIterations)
- .setRegParam(regParam)
- .setMiniBatchFraction(miniBatchFraction)
+ override val optimizer = new GradientDescent(gradient, updater)
+ .setStepSize(stepSize)
+ .setNumIterations(numIterations)
+ .setRegParam(regParam)
+ .setMiniBatchFraction(miniBatchFraction)
+ override val validators = List(DataValidators.classificationLabels)
/**
* Construct a SVM object with default parameters
*/
- def this() = this(1.0, 100, 1.0, 1.0, true)
+ def this() = this(1.0, 100, 1.0, 1.0)
def createModel(weights: Array[Double], intercept: Double) = {
new SVMModel(weights, intercept)
@@ -71,7 +78,7 @@ class SVMWithSGD private (
}
/**
- * Top-level methods for calling SVM.
+ * Top-level methods for calling SVM. NOTE: Labels used in SVM should be {0, 1}
*/
object SVMWithSGD {
@@ -80,6 +87,7 @@ object SVMWithSGD {
* of iterations of gradient descent using the specified step size. Each iteration uses
* `miniBatchFraction` fraction of the data to calculate the gradient. The weights used in
* gradient descent are initialized using the initial weights provided.
+ * NOTE: Labels used in SVM should be {0, 1}
*
* @param input RDD of (label, array of features) pairs.
* @param numIterations Number of iterations of gradient descent to run.
@@ -98,7 +106,7 @@ object SVMWithSGD {
initialWeights: Array[Double])
: SVMModel =
{
- new SVMWithSGD(stepSize, numIterations, regParam, miniBatchFraction, true).run(input,
+ new SVMWithSGD(stepSize, numIterations, regParam, miniBatchFraction).run(input,
initialWeights)
}
@@ -106,6 +114,7 @@ object SVMWithSGD {
* Train a SVM model given an RDD of (label, features) pairs. We run a fixed number
* of iterations of gradient descent using the specified step size. Each iteration uses
* `miniBatchFraction` fraction of the data to calculate the gradient.
+ * NOTE: Labels used in SVM should be {0, 1}
*
* @param input RDD of (label, array of features) pairs.
* @param numIterations Number of iterations of gradient descent to run.
@@ -121,13 +130,14 @@ object SVMWithSGD {
miniBatchFraction: Double)
: SVMModel =
{
- new SVMWithSGD(stepSize, numIterations, regParam, miniBatchFraction, true).run(input)
+ new SVMWithSGD(stepSize, numIterations, regParam, miniBatchFraction).run(input)
}
/**
* Train a SVM model given an RDD of (label, features) pairs. We run a fixed number
* of iterations of gradient descent using the specified step size. We use the entire data set to
* update the gradient in each iteration.
+ * NOTE: Labels used in SVM should be {0, 1}
*
* @param input RDD of (label, array of features) pairs.
* @param stepSize Step size to be used for each iteration of Gradient Descent.
@@ -149,6 +159,7 @@ object SVMWithSGD {
* Train a SVM model given an RDD of (label, features) pairs. We run a fixed number
* of iterations of gradient descent using a step size of 1.0. We use the entire data set to
* update the gradient in each iteration.
+ * NOTE: Labels used in SVM should be {0, 1}
*
* @param input RDD of (label, array of features) pairs.
* @param numIterations Number of iterations of gradient descent to run.
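The SVM follows the same contract after this change; a minimal sketch analogous to the one above (object name and SparkContext setup are again illustrative assumptions). The key behavioral difference is that predictions now come back as 0.0/1.0 rather than -1/+1:

import spark.SparkContext
import spark.mllib.classification.SVMWithSGD
import spark.mllib.regression.LabeledPoint

object SVMExample {
  def main(args: Array[String]) {
    val sc = new SparkContext("local", "SVMExample")
    val points = sc.parallelize(Seq(
      LabeledPoint(1.0, Array(2.0, 1.0)),
      LabeledPoint(0.0, Array(-2.0, -1.0))))
    val model = SVMWithSGD.train(points, 100)
    // predictPoint now thresholds the margin at 0, so this prints 0.0 or 1.0
    println(model.predict(Array(1.5, 0.5)))
    sc.stop()
  }
}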
spark/mllib/optimization/Gradient.scala
@@ -77,16 +77,22 @@ class SquaredGradient extends Gradient {
/**
* Compute gradient and loss for a Hinge loss function.
+ * NOTE: This assumes that the labels are {0,1}
*/
class HingeGradient extends Gradient {
override def compute(data: DoubleMatrix, label: Double, weights: DoubleMatrix):
(DoubleMatrix, Double) = {
val dotProduct = data.dot(weights)
- if (1.0 > label * dotProduct)
- (data.mul(-label), 1.0 - label * dotProduct)
- else
- (DoubleMatrix.zeros(1,weights.length), 0.0)
+ // Our loss function with {0, 1} labels is max(0, 1 - (2y - 1) (f_w(x)))
+ // Therefore the gradient is -(2y - 1)*x
+ val labelScaled = 2 * label - 1.0
+ if (1.0 > labelScaled * dotProduct) {
+ (data.mul(-labelScaled), 1.0 - labelScaled * dotProduct)
+ } else {
+ (DoubleMatrix.zeros(1, weights.length), 0.0)
+ }
}
}
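The key step above is the rescaling 2y - 1, which maps a {0, 1} label back to the conventional {-1, +1} encoding the hinge loss is usually written in. A self-contained sanity check of the loss values (plain Scala, no jblas; the helper name hingeLoss is illustrative, not part of the patch):

object HingeCheck {
  // Hinge loss as computed by HingeGradient, with label in {0, 1}
  def hingeLoss(label: Double, dotProduct: Double): Double = {
    val labelScaled = 2 * label - 1.0   // 0.0 -> -1.0, 1.0 -> +1.0
    math.max(0.0, 1.0 - labelScaled * dotProduct)
  }

  def main(args: Array[String]) {
    println(hingeLoss(1.0, 2.0))  // 0.0: correct side, outside the margin
    println(hingeLoss(1.0, 0.5))  // 0.5: correct side, inside the margin
    println(hingeLoss(0.0, 2.0))  // 3.0: wrong side for a 0 label
  }
}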
spark/mllib/regression/GeneralizedLinearAlgorithm.scala
@@ -17,7 +17,7 @@
package spark.mllib.regression
- import spark.{Logging, RDD}
+ import spark.{Logging, RDD, SparkException}
import spark.mllib.optimization._
import org.jblas.DoubleMatrix
@@ -83,15 +83,19 @@ abstract class GeneralizedLinearModel(val weights: Array[Double], val intercept:
abstract class GeneralizedLinearAlgorithm[M <: GeneralizedLinearModel]
extends Logging with Serializable {
+ protected val validators: Seq[RDD[LabeledPoint] => Boolean] = List()
val optimizer: Optimizer
+ protected var addIntercept: Boolean = true
+ protected var validateData: Boolean = true
/**
* Create a model given the weights and intercept
*/
protected def createModel(weights: Array[Double], intercept: Double): M
- protected var addIntercept: Boolean
/**
* Set if the algorithm should add an intercept. Default true.
*/
@@ -100,6 +104,14 @@ abstract class GeneralizedLinearAlgorithm[M <: GeneralizedLinearModel]
this
}
+ /**
+ * Set if the algorithm should validate data before training. Default true.
+ */
+ def setValidateData(validateData: Boolean): this.type = {
+ this.validateData = validateData
+ this
+ }
/**
* Run the algorithm with the configured parameters on an input
* RDD of LabeledPoint entries.
@@ -116,6 +128,11 @@ abstract class GeneralizedLinearAlgorithm[M <: GeneralizedLinearModel]
*/
def run(input: RDD[LabeledPoint], initialWeights: Array[Double]) : M = {
+ // Check the data properties before running the optimizer
+ if (validateData && !validators.forall(func => func(input))) {
+ throw new SparkException("Input validation failed.")
+ }
// Add an extra variable consisting of all 1.0's for the intercept.
val data = if (addIntercept) {
input.map(labeledPoint => (labeledPoint.label, Array(1.0, labeledPoint.features:_*)))
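Since validators is just a Seq[RDD[LabeledPoint] => Boolean], a subclass can compose additional checks alongside the one this patch adds. A hypothetical sketch of the shape involved (the trait MyChecks and the nonEmpty check are illustrations, not part of the patch):

import spark.RDD
import spark.mllib.regression.LabeledPoint
import spark.mllib.util.DataValidators

trait MyChecks {
  // Hypothetical extra check with the same function shape
  val nonEmpty: RDD[LabeledPoint] => Boolean = data => data.count() > 0

  // What a subclass of GeneralizedLinearAlgorithm could plug in:
  val validators: Seq[RDD[LabeledPoint] => Boolean] =
    List(DataValidators.classificationLabels, nonEmpty)
}

Callers who knowingly train on unchecked data can skip all of this with setValidateData(false), as the new SVMSuite test below demonstrates.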
spark/mllib/regression/Lasso.scala
@@ -48,8 +48,7 @@ class LassoWithSGD private (
var stepSize: Double,
var numIterations: Int,
var regParam: Double,
- var miniBatchFraction: Double,
- var addIntercept: Boolean)
+ var miniBatchFraction: Double)
extends GeneralizedLinearAlgorithm[LassoModel]
with Serializable {
@@ -63,7 +62,7 @@ class LassoWithSGD private (
/**
* Construct a Lasso object with default parameters
*/
- def this() = this(1.0, 100, 1.0, 1.0, true)
+ def this() = this(1.0, 100, 1.0, 1.0)
def createModel(weights: Array[Double], intercept: Double) = {
new LassoModel(weights, intercept)
@@ -98,7 +97,7 @@ object LassoWithSGD {
initialWeights: Array[Double])
: LassoModel =
{
- new LassoWithSGD(stepSize, numIterations, regParam, miniBatchFraction, true).run(input,
+ new LassoWithSGD(stepSize, numIterations, regParam, miniBatchFraction).run(input,
initialWeights)
}
@@ -121,7 +120,7 @@ object LassoWithSGD {
miniBatchFraction: Double)
: LassoModel =
{
- new LassoWithSGD(stepSize, numIterations, regParam, miniBatchFraction, true).run(input)
+ new LassoWithSGD(stepSize, numIterations, regParam, miniBatchFraction).run(input)
}
/**
spark/mllib/util/DataValidators.scala (new file)
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package spark.mllib.util
import spark.{RDD, Logging}
import spark.mllib.regression.LabeledPoint
/**
* A collection of methods used to validate data before applying ML algorithms.
*/
object DataValidators extends Logging {
/**
* Function to check if labels used for classification are either zero or one.
*
* @param data - input data set that needs to be checked
*
* @return True if labels are all zero or one, false otherwise.
*/
val classificationLabels: RDD[LabeledPoint] => Boolean = { data =>
val numInvalid = data.filter(x => x.label != 1.0 && x.label != 0.0).count()
if (numInvalid != 0) {
logError("Classification labels should be 0 or 1. Found " + numInvalid + " invalid labels")
}
numInvalid == 0
}
}
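The validator can also be invoked on its own before training, since it is an ordinary function value. A minimal sketch (object name and SparkContext setup are assumptions for illustration):

import spark.SparkContext
import spark.mllib.regression.LabeledPoint
import spark.mllib.util.DataValidators

object ValidateExample {
  def main(args: Array[String]) {
    val sc = new SparkContext("local", "ValidateExample")
    val bad = sc.parallelize(Seq(LabeledPoint(-1.0, Array(1.0))))
    // Prints false; classificationLabels also logs the count of invalid labels
    println(DataValidators.classificationLabels(bad))
    sc.stop()
  }
}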
spark/mllib/util/SVMDataGenerator.scala
package spark.mllib.util
import scala.util.Random
- import scala.math.signum
import spark.{RDD, SparkContext}
@@ -30,8 +29,8 @@ object SVMDataGenerator {
val sc = new SparkContext(sparkMaster, "SVMGenerator")
val globalRnd = new Random(94720)
- val trueWeights = new DoubleMatrix(1, nfeatures+1,
- Array.fill[Double](nfeatures + 1) { globalRnd.nextGaussian() }:_*)
+ val trueWeights = new DoubleMatrix(1, nfeatures + 1,
+ Array.fill[Double](nfeatures + 1)(globalRnd.nextGaussian()):_*)
val data: RDD[LabeledPoint] = sc.parallelize(0 until nexamples, parts).map { idx =>
val rnd = new Random(42 + idx)
@@ -39,11 +38,13 @@
val x = Array.fill[Double](nfeatures) {
rnd.nextDouble() * 2.0 - 1.0
}
- val y = signum((new DoubleMatrix(1, x.length, x:_*)).dot(trueWeights) + rnd.nextGaussian() * 0.1)
+ val yD = (new DoubleMatrix(1, x.length, x:_*)).dot(trueWeights) + rnd.nextGaussian() * 0.1
+ val y = if (yD < 0) 0.0 else 1.0
LabeledPoint(y, x)
}
MLUtils.saveLabeledData(data, outputPath)
+ sc.stop()
}
}
spark/mllib/classification/SVMSuite.scala (test suite)
@@ -48,13 +48,11 @@ object SVMSuite {
val rnd = new Random(seed)
val weightsMat = new DoubleMatrix(1, weights.length, weights:_*)
val x = Array.fill[Array[Double]](nPoints)(
- Array.fill[Double](weights.length)(rnd.nextGaussian()))
+ Array.fill[Double](weights.length)(rnd.nextDouble() * 2.0 - 1.0))
val y = x.map { xi =>
- signum(
- (new DoubleMatrix(1, xi.length, xi:_*)).dot(weightsMat) +
- intercept +
- 0.1 * rnd.nextGaussian()
- ).toInt
+ val yD = (new DoubleMatrix(1, xi.length, xi:_*)).dot(weightsMat) +
+ intercept + 0.01 * rnd.nextGaussian()
+ if (yD < 0) 0.0 else 1.0
}
y.zip(x).map(p => LabeledPoint(p._1, p._2))
}
@@ -85,7 +83,8 @@ class SVMSuite extends FunSuite with BeforeAndAfterAll {
test("SVM using local random SGD") {
val nPoints = 10000
- val A = 2.0
+ // NOTE: Intercept should be small for generating equal 0s and 1s
+ val A = 0.01
val B = -1.5
val C = 1.0
@@ -100,7 +99,7 @@ class SVMSuite extends FunSuite with BeforeAndAfterAll {
val model = svm.run(testRDD)
val validationData = SVMSuite.generateSVMInput(A, Array[Double](B,C), nPoints, 17)
- val validationRDD = sc.parallelize(validationData,2)
+ val validationRDD = sc.parallelize(validationData, 2)
// Test prediction on RDD.
validatePrediction(model.predict(validationRDD.map(_.features)).collect(), validationData)
@@ -112,7 +111,8 @@ class SVMSuite extends FunSuite with BeforeAndAfterAll {
test("SVM local random SGD with initial weights") {
val nPoints = 10000
- val A = 2.0
+ // NOTE: Intercept should be small for generating equal 0s and 1s
+ val A = 0.01
val B = -1.5
val C = 1.0
@@ -139,4 +139,31 @@ class SVMSuite extends FunSuite with BeforeAndAfterAll {
// Test prediction on Array.
validatePrediction(validationData.map(row => model.predict(row.features)), validationData)
}
test("SVM with invalid labels") {
val nPoints = 10000
// NOTE: Intercept should be small for generating equal 0s and 1s
val A = 0.01
val B = -1.5
val C = 1.0
val testData = SVMSuite.generateSVMInput(A, Array[Double](B,C), nPoints, 42)
val testRDD = sc.parallelize(testData, 2)
val testRDDInvalid = testRDD.map { lp =>
if (lp.label == 0.0) {
LabeledPoint(-1.0, lp.features)
} else {
lp
}
}
intercept[spark.SparkException] {
val model = SVMWithSGD.train(testRDDInvalid, 100)
}
// Turning off data validation should not throw an exception
val noValidationModel = new SVMWithSGD().setValidateData(false).run(testRDDInvalid)
}
}