From 2812e722008b772756cbd0ef0600a55b65d6ee0e Mon Sep 17 00:00:00 2001 From: Shivaram Venkataraman <shivaram@eecs.berkeley.edu> Date: Thu, 8 Aug 2013 16:24:31 -0700 Subject: [PATCH] Add setters for optimizer, gradient in SGD. Also remove java-specific constructor for LabeledPoint. --- .../java/spark/mllib/examples/JavaLR.java | 2 +- .../mllib/optimization/GradientDescent.scala | 19 ++++++++++++++++++- .../spark/mllib/regression/LabeledPoint.scala | 8 +------- 3 files changed, 20 insertions(+), 9 deletions(-) diff --git a/examples/src/main/java/spark/mllib/examples/JavaLR.java b/examples/src/main/java/spark/mllib/examples/JavaLR.java index e11f4830a8..bf4aeaf40f 100644 --- a/examples/src/main/java/spark/mllib/examples/JavaLR.java +++ b/examples/src/main/java/spark/mllib/examples/JavaLR.java @@ -37,7 +37,7 @@ public class JavaLR { static class ParsePoint extends Function<String, LabeledPoint> { public LabeledPoint call(String line) { String[] parts = line.split(","); - Double y = Double.parseDouble(parts[0]); + double y = Double.parseDouble(parts[0]); StringTokenizer tok = new StringTokenizer(parts[1], " "); int numTokens = tok.countTokens(); double[] x = new double[numTokens]; diff --git a/mllib/src/main/scala/spark/mllib/optimization/GradientDescent.scala b/mllib/src/main/scala/spark/mllib/optimization/GradientDescent.scala index 54793ca74d..1f04398d0c 100644 --- a/mllib/src/main/scala/spark/mllib/optimization/GradientDescent.scala +++ b/mllib/src/main/scala/spark/mllib/optimization/GradientDescent.scala @@ -24,7 +24,7 @@ import org.jblas.DoubleMatrix import scala.collection.mutable.ArrayBuffer -class GradientDescent(gradient: Gradient, updater: Updater) extends Optimizer { +class GradientDescent(var gradient: Gradient, var updater: Updater) extends Optimizer { var stepSize: Double = 1.0 var numIterations: Int = 100 @@ -63,6 +63,23 @@ class GradientDescent(gradient: Gradient, updater: Updater) extends Optimizer { this } + /** + * Set the gradient function to be used for SGD. + */ + def setGradient(gradient: Gradient): this.type = { + this.gradient = gradient + this + } + + + /** + * Set the updater function to be used for SGD. + */ + def setUpdater(updater: Updater): this.type = { + this.updater = updater + this + } + def optimize(data: RDD[(Double, Array[Double])], initialWeights: Array[Double]) : Array[Double] = { diff --git a/mllib/src/main/scala/spark/mllib/regression/LabeledPoint.scala b/mllib/src/main/scala/spark/mllib/regression/LabeledPoint.scala index 592f0b5414..3de60482c5 100644 --- a/mllib/src/main/scala/spark/mllib/regression/LabeledPoint.scala +++ b/mllib/src/main/scala/spark/mllib/regression/LabeledPoint.scala @@ -23,10 +23,4 @@ package spark.mllib.regression * @param label Label for this data point. * @param features List of features for this data point. */ -case class LabeledPoint(val label: Double, val features: Array[Double]) { - - /** - * Construct a labeled point using java.lang.Double. - */ - def this(label: java.lang.Double, features: Array[Double]) = this(label.doubleValue(), features) -} +case class LabeledPoint(val label: Double, val features: Array[Double]) -- GitLab