Skip to content
Snippets Groups Projects
Commit cef17887 authored by Shivaram Venkataraman's avatar Shivaram Venkataraman
Browse files

Refactor SGD options into a new class.

This refactoring pulls the SGD options shared by SVM, Lasso, and logistic
regression out into a common GradientDescentOpts class. It also includes some style cleanup.
parent 29b8cd36
No related branches found
No related tags found
No related merge requests found
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package spark.mllib.optimization
/**
 * Mutable holder for the settings used by GradientDescent based optimization
 * algorithms: step size, number of iterations, regularization parameter and
 * mini-batch fraction.
 *
 * The primary constructor is private. Instances created with the no-argument
 * constructor start from the defaults and can be adjusted through the chainable
 * setters, each of which updates one field in place and returns this instance.
 */
class GradientDescentOpts private (
    var stepSize: Double,
    var numIters: Int,
    var regParam: Double,
    var miniBatchFraction: Double) {

  /**
   * Create options with the default values: stepSize 1.0, numIters 100,
   * regParam 0.0 and miniBatchFraction 1.0.
   */
  def this() = this(1.0, 100, 0.0, 1.0)

  /**
   * Set the step size used for each iteration of SGD. Default 1.0.
   */
  def setStepSize(step: Double): GradientDescentOpts = {
    stepSize = step
    this
  }

  /**
   * Set the fraction of the data to use in each SGD iteration. Default 1.0.
   */
  def setMiniBatchFraction(fraction: Double): GradientDescentOpts = {
    miniBatchFraction = fraction
    this
  }

  /**
   * Set how many iterations of SGD to run. Default 100.
   */
  def setNumIterations(iters: Int): GradientDescentOpts = {
    numIters = iters
    this
  }

  /**
   * Set the regularization parameter used for SGD. Default 0.0.
   */
  def setRegParam(regParam: Double): GradientDescentOpts = {
    // The parameter shadows the field, so the explicit `this.` is required here.
    this.regParam = regParam
    this
  }
}
/**
 * Factory methods for building GradientDescentOpts instances without `new`.
 */
object GradientDescentOpts {

  /**
   * Build options from explicitly supplied values, passed through unchanged to
   * the class constructor.
   */
  def apply(
      stepSize: Double,
      numIters: Int,
      regParam: Double,
      miniBatchFraction: Double): GradientDescentOpts =
    new GradientDescentOpts(stepSize, numIters, regParam, miniBatchFraction)

  /**
   * Build options using the class's no-argument constructor.
   */
  def apply(): GradientDescentOpts = new GradientDescentOpts()
}
...@@ -55,38 +55,12 @@ class LogisticRegressionModel( ...@@ -55,38 +55,12 @@ class LogisticRegressionModel(
} }
} }
class LogisticRegressionLocalRandomSGD private (var stepSize: Double, var miniBatchFraction: Double, class LogisticRegression(val opts: GradientDescentOpts) extends Logging {
var numIters: Int)
extends Logging {
/** /**
* Construct a LogisticRegression object with default parameters * Construct a LogisticRegression object with default parameters
*/ */
def this() = this(1.0, 1.0, 100) def this() = this(new GradientDescentOpts())
/**
* Set the step size per-iteration of SGD. Default 1.0.
*/
def setStepSize(step: Double) = {
this.stepSize = step
this
}
/**
* Set fraction of data to be used for each SGD iteration. Default 1.0.
*/
def setMiniBatchFraction(fraction: Double) = {
this.miniBatchFraction = fraction
this
}
/**
* Set the number of iterations for SGD. Default 100.
*/
def setNumIterations(iters: Int) = {
this.numIters = iters
this
}
def train(input: RDD[(Int, Array[Double])]): LogisticRegressionModel = { def train(input: RDD[(Int, Array[Double])]): LogisticRegressionModel = {
val nfeatures: Int = input.take(1)(0)._2.length val nfeatures: Int = input.take(1)(0)._2.length
...@@ -109,11 +83,8 @@ class LogisticRegressionLocalRandomSGD private (var stepSize: Double, var miniBa ...@@ -109,11 +83,8 @@ class LogisticRegressionLocalRandomSGD private (var stepSize: Double, var miniBa
data, data,
new LogisticGradient(), new LogisticGradient(),
new SimpleUpdater(), new SimpleUpdater(),
stepSize, opts,
numIters, initalWeightsWithIntercept)
0.0,
initalWeightsWithIntercept,
miniBatchFraction)
val intercept = weights(0) val intercept = weights(0)
val weightsScaled = weights.tail val weightsScaled = weights.tail
...@@ -132,7 +103,7 @@ class LogisticRegressionLocalRandomSGD private (var stepSize: Double, var miniBa ...@@ -132,7 +103,7 @@ class LogisticRegressionLocalRandomSGD private (var stepSize: Double, var miniBa
* NOTE(shivaram): We use multiple train methods instead of default arguments to support * NOTE(shivaram): We use multiple train methods instead of default arguments to support
* Java programs. * Java programs.
*/ */
object LogisticRegressionLocalRandomSGD { object LogisticRegression {
/** /**
* Train a logistic regression model given an RDD of (label, features) pairs. We run a fixed * Train a logistic regression model given an RDD of (label, features) pairs. We run a fixed
...@@ -155,8 +126,8 @@ object LogisticRegressionLocalRandomSGD { ...@@ -155,8 +126,8 @@ object LogisticRegressionLocalRandomSGD {
initialWeights: Array[Double]) initialWeights: Array[Double])
: LogisticRegressionModel = : LogisticRegressionModel =
{ {
new LogisticRegressionLocalRandomSGD(stepSize, miniBatchFraction, numIterations).train( val sgdOpts = GradientDescentOpts(stepSize, numIterations, 0.0, miniBatchFraction)
input, initialWeights) new LogisticRegression(sgdOpts).train(input, initialWeights)
} }
/** /**
...@@ -177,7 +148,8 @@ object LogisticRegressionLocalRandomSGD { ...@@ -177,7 +148,8 @@ object LogisticRegressionLocalRandomSGD {
miniBatchFraction: Double) miniBatchFraction: Double)
: LogisticRegressionModel = : LogisticRegressionModel =
{ {
new LogisticRegressionLocalRandomSGD(stepSize, miniBatchFraction, numIterations).train(input) val sgdOpts = GradientDescentOpts(stepSize, numIterations, 0.0, miniBatchFraction)
new LogisticRegression(sgdOpts).train(input)
} }
/** /**
...@@ -225,7 +197,7 @@ object LogisticRegressionLocalRandomSGD { ...@@ -225,7 +197,7 @@ object LogisticRegressionLocalRandomSGD {
} }
val sc = new SparkContext(args(0), "LogisticRegression") val sc = new SparkContext(args(0), "LogisticRegression")
val data = MLUtils.loadLabeledData(sc, args(1)).map(yx => (yx._1.toInt, yx._2)) val data = MLUtils.loadLabeledData(sc, args(1)).map(yx => (yx._1.toInt, yx._2))
val model = LogisticRegressionLocalRandomSGD.train( val model = LogisticRegression.train(
data, args(4).toInt, args(2).toDouble, args(3).toDouble) data, args(4).toInt, args(2).toDouble, args(3).toDouble)
sc.stop() sc.stop()
......
...@@ -53,46 +53,12 @@ class SVMModel( ...@@ -53,46 +53,12 @@ class SVMModel(
class SVMLocalRandomSGD private (var stepSize: Double, var regParam: Double, class SVM(val opts: GradientDescentOpts) extends Logging {
var miniBatchFraction: Double, var numIters: Int)
extends Logging {
/** /**
* Construct a SVM object with default parameters * Construct a SVM object with default parameters
*/ */
def this() = this(1.0, 1.0, 1.0, 100) def this() = this(GradientDescentOpts(1.0, 100, 1.0, 1.0))
/**
* Set the step size per-iteration of SGD. Default 1.0.
*/
def setStepSize(step: Double) = {
this.stepSize = step
this
}
/**
* Set the regularization parameter. Default 1.0.
*/
def setRegParam(param: Double) = {
this.regParam = param
this
}
/**
* Set fraction of data to be used for each SGD iteration. Default 1.0.
*/
def setMiniBatchFraction(fraction: Double) = {
this.miniBatchFraction = fraction
this
}
/**
* Set the number of iterations for SGD. Default 100.
*/
def setNumIterations(iters: Int) = {
this.numIters = iters
this
}
def train(input: RDD[(Int, Array[Double])]): SVMModel = { def train(input: RDD[(Int, Array[Double])]): SVMModel = {
val nfeatures: Int = input.take(1)(0)._2.length val nfeatures: Int = input.take(1)(0)._2.length
...@@ -115,11 +81,8 @@ class SVMLocalRandomSGD private (var stepSize: Double, var regParam: Double, ...@@ -115,11 +81,8 @@ class SVMLocalRandomSGD private (var stepSize: Double, var regParam: Double,
data, data,
new HingeGradient(), new HingeGradient(),
new SquaredL2Updater(), new SquaredL2Updater(),
stepSize, opts,
numIters, initalWeightsWithIntercept)
regParam,
initalWeightsWithIntercept,
miniBatchFraction)
val intercept = weights(0) val intercept = weights(0)
val weightsScaled = weights.tail val weightsScaled = weights.tail
...@@ -135,10 +98,8 @@ class SVMLocalRandomSGD private (var stepSize: Double, var regParam: Double, ...@@ -135,10 +98,8 @@ class SVMLocalRandomSGD private (var stepSize: Double, var regParam: Double,
/** /**
* Top-level methods for calling SVM. * Top-level methods for calling SVM.
*/ */
object SVMLocalRandomSGD { object SVM {
/** /**
* Train a SVM model given an RDD of (label, features) pairs. We run a fixed number * Train a SVM model given an RDD of (label, features) pairs. We run a fixed number
...@@ -163,8 +124,8 @@ object SVMLocalRandomSGD { ...@@ -163,8 +124,8 @@ object SVMLocalRandomSGD {
initialWeights: Array[Double]) initialWeights: Array[Double])
: SVMModel = : SVMModel =
{ {
new SVMLocalRandomSGD(stepSize, regParam, miniBatchFraction, numIterations).train( val sgdOpts = GradientDescentOpts(stepSize, numIterations, regParam, miniBatchFraction)
input, initialWeights) new SVM(sgdOpts).train(input, initialWeights)
} }
/** /**
...@@ -186,7 +147,8 @@ object SVMLocalRandomSGD { ...@@ -186,7 +147,8 @@ object SVMLocalRandomSGD {
miniBatchFraction: Double) miniBatchFraction: Double)
: SVMModel = : SVMModel =
{ {
new SVMLocalRandomSGD(stepSize, regParam, miniBatchFraction, numIterations).train(input) val sgdOpts = GradientDescentOpts(stepSize, numIterations, regParam, miniBatchFraction)
new SVM(sgdOpts).train(input)
} }
/** /**
...@@ -234,7 +196,7 @@ object SVMLocalRandomSGD { ...@@ -234,7 +196,7 @@ object SVMLocalRandomSGD {
} }
val sc = new SparkContext(args(0), "SVM") val sc = new SparkContext(args(0), "SVM")
val data = MLUtils.loadLabeledData(sc, args(1)).map(yx => (yx._1.toInt, yx._2)) val data = MLUtils.loadLabeledData(sc, args(1)).map(yx => (yx._1.toInt, yx._2))
val model = SVMLocalRandomSGD.train(data, args(4).toInt, args(2).toDouble, args(3).toDouble) val model = SVM.train(data, args(4).toInt, args(2).toDouble, args(3).toDouble)
sc.stop() sc.stop()
} }
......
...@@ -24,7 +24,6 @@ import org.jblas.DoubleMatrix ...@@ -24,7 +24,6 @@ import org.jblas.DoubleMatrix
import scala.collection.mutable.ArrayBuffer import scala.collection.mutable.ArrayBuffer
object GradientDescent { object GradientDescent {
/** /**
...@@ -48,23 +47,20 @@ object GradientDescent { ...@@ -48,23 +47,20 @@ object GradientDescent {
data: RDD[(Double, Array[Double])], data: RDD[(Double, Array[Double])],
gradient: Gradient, gradient: Gradient,
updater: Updater, updater: Updater,
stepSize: Double, opts: GradientDescentOpts,
numIters: Int, initialWeights: Array[Double]) : (Array[Double], Array[Double]) = {
regParam: Double,
initialWeights: Array[Double],
miniBatchFraction: Double=1.0) : (Array[Double], Array[Double]) = {
val stochasticLossHistory = new ArrayBuffer[Double](numIters) val stochasticLossHistory = new ArrayBuffer[Double](opts.numIters)
val nexamples: Long = data.count() val nexamples: Long = data.count()
val miniBatchSize = nexamples * miniBatchFraction val miniBatchSize = nexamples * opts.miniBatchFraction
// Initialize weights as a column vector // Initialize weights as a column vector
var weights = new DoubleMatrix(initialWeights.length, 1, initialWeights:_*) var weights = new DoubleMatrix(initialWeights.length, 1, initialWeights:_*)
var regVal = 0.0 var regVal = 0.0
for (i <- 1 to numIters) { for (i <- 1 to opts.numIters) {
val (gradientSum, lossSum) = data.sample(false, miniBatchFraction, 42+i).map { val (gradientSum, lossSum) = data.sample(false, opts.miniBatchFraction, 42+i).map {
case (y, features) => case (y, features) =>
val featuresRow = new DoubleMatrix(features.length, 1, features:_*) val featuresRow = new DoubleMatrix(features.length, 1, features:_*)
val (grad, loss) = gradient.compute(featuresRow, y, weights) val (grad, loss) = gradient.compute(featuresRow, y, weights)
...@@ -76,7 +72,8 @@ object GradientDescent { ...@@ -76,7 +72,8 @@ object GradientDescent {
* and regVal is the regularization value computed in the previous iteration as well. * and regVal is the regularization value computed in the previous iteration as well.
*/ */
stochasticLossHistory.append(lossSum / miniBatchSize + regVal) stochasticLossHistory.append(lossSum / miniBatchSize + regVal)
val update = updater.compute(weights, gradientSum.div(miniBatchSize), stepSize, i, regParam) val update = updater.compute(
weights, gradientSum.div(miniBatchSize), opts.stepSize, i, opts.regParam)
weights = update._1 weights = update._1
regVal = update._2 regVal = update._2
} }
......
...@@ -53,46 +53,12 @@ class LassoModel( ...@@ -53,46 +53,12 @@ class LassoModel(
} }
class LassoLocalRandomSGD private (var stepSize: Double, var regParam: Double, class Lasso(val opts: GradientDescentOpts) extends Logging {
var miniBatchFraction: Double, var numIters: Int)
extends Logging {
/** /**
* Construct a Lasso object with default parameters * Construct a Lasso object with default parameters
*/ */
def this() = this(1.0, 1.0, 1.0, 100) def this() = this(GradientDescentOpts(1.0, 100, 1.0, 1.0))
/**
* Set the step size per-iteration of SGD. Default 1.0.
*/
def setStepSize(step: Double) = {
this.stepSize = step
this
}
/**
* Set the regularization parameter. Default 1.0.
*/
def setRegParam(param: Double) = {
this.regParam = param
this
}
/**
* Set fraction of data to be used for each SGD iteration. Default 1.0.
*/
def setMiniBatchFraction(fraction: Double) = {
this.miniBatchFraction = fraction
this
}
/**
* Set the number of iterations for SGD. Default 100.
*/
def setNumIterations(iters: Int) = {
this.numIters = iters
this
}
def train(input: RDD[(Double, Array[Double])]): LassoModel = { def train(input: RDD[(Double, Array[Double])]): LassoModel = {
val nfeatures: Int = input.take(1)(0)._2.length val nfeatures: Int = input.take(1)(0)._2.length
...@@ -115,11 +81,8 @@ class LassoLocalRandomSGD private (var stepSize: Double, var regParam: Double, ...@@ -115,11 +81,8 @@ class LassoLocalRandomSGD private (var stepSize: Double, var regParam: Double,
data, data,
new SquaredGradient(), new SquaredGradient(),
new L1Updater(), new L1Updater(),
stepSize, opts,
numIters, initalWeightsWithIntercept)
regParam,
initalWeightsWithIntercept,
miniBatchFraction)
val intercept = weights(0) val intercept = weights(0)
val weightsScaled = weights.tail val weightsScaled = weights.tail
...@@ -135,10 +98,8 @@ class LassoLocalRandomSGD private (var stepSize: Double, var regParam: Double, ...@@ -135,10 +98,8 @@ class LassoLocalRandomSGD private (var stepSize: Double, var regParam: Double,
/** /**
* Top-level methods for calling Lasso. * Top-level methods for calling Lasso.
*
*
*/ */
object LassoLocalRandomSGD { object Lasso {
/** /**
* Train a Lasso model given an RDD of (label, features) pairs. We run a fixed number * Train a Lasso model given an RDD of (label, features) pairs. We run a fixed number
...@@ -163,8 +124,8 @@ object LassoLocalRandomSGD { ...@@ -163,8 +124,8 @@ object LassoLocalRandomSGD {
initialWeights: Array[Double]) initialWeights: Array[Double])
: LassoModel = : LassoModel =
{ {
new LassoLocalRandomSGD(stepSize, regParam, miniBatchFraction, numIterations).train( val sgdOpts = GradientDescentOpts(stepSize, numIterations, regParam, miniBatchFraction)
input, initialWeights) new Lasso(sgdOpts).train(input, initialWeights)
} }
/** /**
...@@ -186,7 +147,8 @@ object LassoLocalRandomSGD { ...@@ -186,7 +147,8 @@ object LassoLocalRandomSGD {
miniBatchFraction: Double) miniBatchFraction: Double)
: LassoModel = : LassoModel =
{ {
new LassoLocalRandomSGD(stepSize, regParam, miniBatchFraction, numIterations).train(input) val sgdOpts = GradientDescentOpts(stepSize, numIterations, regParam, miniBatchFraction)
new Lasso(sgdOpts).train(input)
} }
/** /**
...@@ -234,7 +196,7 @@ object LassoLocalRandomSGD { ...@@ -234,7 +196,7 @@ object LassoLocalRandomSGD {
} }
val sc = new SparkContext(args(0), "Lasso") val sc = new SparkContext(args(0), "Lasso")
val data = MLUtils.loadLabeledData(sc, args(1)) val data = MLUtils.loadLabeledData(sc, args(1))
val model = LassoLocalRandomSGD.train(data, args(4).toInt, args(2).toDouble, args(3).toDouble) val model = Lasso.train(data, args(4).toInt, args(2).toDouble, args(3).toDouble)
sc.stop() sc.stop()
} }
......
...@@ -24,6 +24,7 @@ import org.scalatest.FunSuite ...@@ -24,6 +24,7 @@ import org.scalatest.FunSuite
import org.scalatest.matchers.ShouldMatchers import org.scalatest.matchers.ShouldMatchers
import spark.SparkContext import spark.SparkContext
import spark.mllib.optimization._
class LogisticRegressionSuite extends FunSuite with BeforeAndAfterAll with ShouldMatchers { class LogisticRegressionSuite extends FunSuite with BeforeAndAfterAll with ShouldMatchers {
...@@ -79,7 +80,8 @@ class LogisticRegressionSuite extends FunSuite with BeforeAndAfterAll with Shoul ...@@ -79,7 +80,8 @@ class LogisticRegressionSuite extends FunSuite with BeforeAndAfterAll with Shoul
val testRDD = sc.parallelize(testData, 2) val testRDD = sc.parallelize(testData, 2)
testRDD.cache() testRDD.cache()
val lr = new LogisticRegressionLocalRandomSGD().setStepSize(10.0).setNumIterations(20) val sgdOpts = GradientDescentOpts().setStepSize(10.0).setNumIterations(20)
val lr = new LogisticRegression(sgdOpts)
val model = lr.train(testRDD) val model = lr.train(testRDD)
...@@ -111,7 +113,8 @@ class LogisticRegressionSuite extends FunSuite with BeforeAndAfterAll with Shoul ...@@ -111,7 +113,8 @@ class LogisticRegressionSuite extends FunSuite with BeforeAndAfterAll with Shoul
testRDD.cache() testRDD.cache()
// Use half as many iterations as the previous test. // Use half as many iterations as the previous test.
val lr = new LogisticRegressionLocalRandomSGD().setStepSize(10.0).setNumIterations(10) val sgdOpts = GradientDescentOpts().setStepSize(10.0).setNumIterations(10)
val lr = new LogisticRegression(sgdOpts)
val model = lr.train(testRDD, initialWeights) val model = lr.train(testRDD, initialWeights)
......
...@@ -24,6 +24,7 @@ import org.scalatest.BeforeAndAfterAll ...@@ -24,6 +24,7 @@ import org.scalatest.BeforeAndAfterAll
import org.scalatest.FunSuite import org.scalatest.FunSuite
import spark.SparkContext import spark.SparkContext
import spark.mllib.optimization._
import org.jblas.DoubleMatrix import org.jblas.DoubleMatrix
...@@ -44,10 +45,14 @@ class SVMSuite extends FunSuite with BeforeAndAfterAll { ...@@ -44,10 +45,14 @@ class SVMSuite extends FunSuite with BeforeAndAfterAll {
val rnd = new Random(seed) val rnd = new Random(seed)
val weightsMat = new DoubleMatrix(1, weights.length, weights:_*) val weightsMat = new DoubleMatrix(1, weights.length, weights:_*)
val x = Array.fill[Array[Double]](nPoints)(Array.fill[Double](weights.length)(rnd.nextGaussian())) val x = Array.fill[Array[Double]](nPoints)(Array.fill[Double](weights.length)(rnd.nextGaussian()))
val y = x.map(xi => val y = x.map { xi =>
signum((new DoubleMatrix(1, xi.length, xi:_*)).dot(weightsMat) + intercept + 0.1 * rnd.nextGaussian()).toInt signum(
) (new DoubleMatrix(1, xi.length, xi:_*)).dot(weightsMat) +
y zip x intercept +
0.1 * rnd.nextGaussian()
).toInt
}
y.zip(x)
} }
def validatePrediction(predictions: Seq[Int], input: Seq[(Int, Array[Double])]) { def validatePrediction(predictions: Seq[Int], input: Seq[(Int, Array[Double])]) {
...@@ -58,7 +63,7 @@ class SVMSuite extends FunSuite with BeforeAndAfterAll { ...@@ -58,7 +63,7 @@ class SVMSuite extends FunSuite with BeforeAndAfterAll {
assert(numOffPredictions < input.length / 5) assert(numOffPredictions < input.length / 5)
} }
test("SVMLocalRandomSGD") { test("SVM using local random SGD") {
val nPoints = 10000 val nPoints = 10000
val A = 2.0 val A = 2.0
...@@ -70,7 +75,8 @@ class SVMSuite extends FunSuite with BeforeAndAfterAll { ...@@ -70,7 +75,8 @@ class SVMSuite extends FunSuite with BeforeAndAfterAll {
val testRDD = sc.parallelize(testData, 2) val testRDD = sc.parallelize(testData, 2)
testRDD.cache() testRDD.cache()
val svm = new SVMLocalRandomSGD().setStepSize(1.0).setRegParam(1.0).setNumIterations(100) val sgdOpts = GradientDescentOpts().setStepSize(1.0).setRegParam(1.0).setNumIterations(100)
val svm = new SVM(sgdOpts)
val model = svm.train(testRDD) val model = svm.train(testRDD)
...@@ -84,7 +90,7 @@ class SVMSuite extends FunSuite with BeforeAndAfterAll { ...@@ -84,7 +90,7 @@ class SVMSuite extends FunSuite with BeforeAndAfterAll {
validatePrediction(validationData.map(row => model.predict(row._2)), validationData) validatePrediction(validationData.map(row => model.predict(row._2)), validationData)
} }
test("SVMLocalRandomSGD with initial weights") { test("SVM local random SGD with initial weights") {
val nPoints = 10000 val nPoints = 10000
val A = 2.0 val A = 2.0
...@@ -100,7 +106,8 @@ class SVMSuite extends FunSuite with BeforeAndAfterAll { ...@@ -100,7 +106,8 @@ class SVMSuite extends FunSuite with BeforeAndAfterAll {
val testRDD = sc.parallelize(testData, 2) val testRDD = sc.parallelize(testData, 2)
testRDD.cache() testRDD.cache()
val svm = new SVMLocalRandomSGD().setStepSize(1.0).setRegParam(1.0).setNumIterations(100) val sgdOpts = GradientDescentOpts().setStepSize(1.0).setRegParam(1.0).setNumIterations(100)
val svm = new SVM(sgdOpts)
val model = svm.train(testRDD, initialWeights) val model = svm.train(testRDD, initialWeights)
......
...@@ -23,6 +23,7 @@ import org.scalatest.BeforeAndAfterAll ...@@ -23,6 +23,7 @@ import org.scalatest.BeforeAndAfterAll
import org.scalatest.FunSuite import org.scalatest.FunSuite
import spark.SparkContext import spark.SparkContext
import spark.mllib.optimization._
import org.jblas.DoubleMatrix import org.jblas.DoubleMatrix
...@@ -59,7 +60,7 @@ class LassoSuite extends FunSuite with BeforeAndAfterAll { ...@@ -59,7 +60,7 @@ class LassoSuite extends FunSuite with BeforeAndAfterAll {
assert(numOffPredictions < input.length / 5) assert(numOffPredictions < input.length / 5)
} }
test("LassoLocalRandomSGD") { test("Lasso local random SGD") {
val nPoints = 10000 val nPoints = 10000
val A = 2.0 val A = 2.0
...@@ -70,7 +71,9 @@ class LassoSuite extends FunSuite with BeforeAndAfterAll { ...@@ -70,7 +71,9 @@ class LassoSuite extends FunSuite with BeforeAndAfterAll {
val testRDD = sc.parallelize(testData, 2) val testRDD = sc.parallelize(testData, 2)
testRDD.cache() testRDD.cache()
val ls = new LassoLocalRandomSGD().setStepSize(1.0).setRegParam(0.01).setNumIterations(20)
val sgdOpts = GradientDescentOpts().setStepSize(1.0).setRegParam(0.01).setNumIterations(20)
val ls = new Lasso(sgdOpts)
val model = ls.train(testRDD) val model = ls.train(testRDD)
...@@ -90,7 +93,7 @@ class LassoSuite extends FunSuite with BeforeAndAfterAll { ...@@ -90,7 +93,7 @@ class LassoSuite extends FunSuite with BeforeAndAfterAll {
validatePrediction(validationData.map(row => model.predict(row._2)), validationData) validatePrediction(validationData.map(row => model.predict(row._2)), validationData)
} }
test("LassoLocalRandomSGD with initial weights") { test("Lasso local random SGD with initial weights") {
val nPoints = 10000 val nPoints = 10000
val A = 2.0 val A = 2.0
...@@ -105,7 +108,9 @@ class LassoSuite extends FunSuite with BeforeAndAfterAll { ...@@ -105,7 +108,9 @@ class LassoSuite extends FunSuite with BeforeAndAfterAll {
val testRDD = sc.parallelize(testData, 2) val testRDD = sc.parallelize(testData, 2)
testRDD.cache() testRDD.cache()
val ls = new LassoLocalRandomSGD().setStepSize(1.0).setRegParam(0.01).setNumIterations(20)
val sgdOpts = GradientDescentOpts().setStepSize(1.0).setRegParam(0.01).setNumIterations(20)
val ls = new Lasso(sgdOpts)
val model = ls.train(testRDD, initialWeights) val model = ls.train(testRDD, initialWeights)
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment