From 9398dced0331c0ec098ef5eb4616571874ceefb6 Mon Sep 17 00:00:00 2001
From: Xinghao <pxinghao@gmail.com>
Date: Sun, 28 Jul 2013 21:39:19 -0700
Subject: [PATCH] Changed Classification to return Int instead of Double

Also minor changes to formatting and comments
---
 .../mllib/classification/Classification.scala     |  8 ++++----
 .../classification/LogisticRegression.scala       | 16 ++++------------
 .../scala/spark/mllib/classification/SVM.scala    |  8 ++++----
 .../spark/mllib/optimization/Gradient.scala       |  4 ++--
 .../mllib/optimization/GradientDescent.scala      |  8 ++++----
 .../scala/spark/mllib/optimization/Updater.scala  | 12 ++++++++++--
 .../spark/mllib/classification/SVMSuite.scala     |  2 --
 7 files changed, 28 insertions(+), 30 deletions(-)

diff --git a/mllib/src/main/scala/spark/mllib/classification/Classification.scala b/mllib/src/main/scala/spark/mllib/classification/Classification.scala
index 96d7a54f18..d6154b66ae 100644
--- a/mllib/src/main/scala/spark/mllib/classification/Classification.scala
+++ b/mllib/src/main/scala/spark/mllib/classification/Classification.scala
@@ -7,15 +7,15 @@ trait ClassificationModel extends Serializable {
    * Predict values for the given data set using the model trained.
    *
    * @param testData RDD representing data points to be predicted
-   * @return RDD[Double] where each entry contains the corresponding prediction
+   * @return RDD[Int] where each entry contains the corresponding prediction
    */
-  def predict(testData: RDD[Array[Double]]): RDD[Double]
+  def predict(testData: RDD[Array[Double]]): RDD[Int]
 
   /**
    * Predict values for a single data point using the model trained.
    *
    * @param testData array representing a single data point
-   * @return Double prediction from the trained model
+   * @return Int prediction from the trained model
    */
-  def predict(testData: Array[Double]): Double
+  def predict(testData: Array[Double]): Int
 }
diff --git a/mllib/src/main/scala/spark/mllib/classification/LogisticRegression.scala b/mllib/src/main/scala/spark/mllib/classification/LogisticRegression.scala
index 1b093187f2..0a7effb1d7 100644
--- a/mllib/src/main/scala/spark/mllib/classification/LogisticRegression.scala
+++ b/mllib/src/main/scala/spark/mllib/classification/LogisticRegression.scala
@@ -35,21 +35,21 @@ class LogisticRegressionModel(
   // Create a column vector that can be used for predictions
   private val weightsMatrix = new DoubleMatrix(weights.length, 1, weights:_*)
 
-  override def predict(testData: spark.RDD[Array[Double]]) = {
+  override def predict(testData: spark.RDD[Array[Double]]): RDD[Int] = {
     // A small optimization to avoid serializing the entire model. Only the weightsMatrix
     // and intercept is needed.
     val localWeights = weightsMatrix
     val localIntercept = intercept
     testData.map { x =>
       val margin = new DoubleMatrix(1, x.length, x:_*).mmul(localWeights).get(0) + localIntercept
-      1.0/ (1.0 + math.exp(margin * -1))
+      (1.0/ (1.0 + math.exp(margin * -1))).toInt
     }
   }
 
-  override def predict(testData: Array[Double]): Double = {
+  override def predict(testData: Array[Double]): Int = {
     val dataMat = new DoubleMatrix(1, testData.length, testData:_*)
     val margin = dataMat.mmul(weightsMatrix).get(0) + this.intercept
-    1.0/ (1.0 + math.exp(margin * -1))
+    (1.0/ (1.0 + math.exp(margin * -1))).toInt
   }
 }
 
@@ -70,14 +70,6 @@ class LogisticRegressionLocalRandomSGD private (var stepSize: Double, var miniBa
     this
   }
 
-
-
-
-
-
-
-
-
   /**
    * Set fraction of data to be used for each SGD iteration. Default 1.0.
    */
diff --git a/mllib/src/main/scala/spark/mllib/classification/SVM.scala b/mllib/src/main/scala/spark/mllib/classification/SVM.scala
index 76844f6b9c..30766a4c64 100644
--- a/mllib/src/main/scala/spark/mllib/classification/SVM.scala
+++ b/mllib/src/main/scala/spark/mllib/classification/SVM.scala
@@ -35,19 +35,19 @@ class SVMModel(
   // Create a column vector that can be used for predictions
   private val weightsMatrix = new DoubleMatrix(weights.length, 1, weights:_*)
 
-  override def predict(testData: spark.RDD[Array[Double]]) = {
+  override def predict(testData: spark.RDD[Array[Double]]): RDD[Int] = {
     // A small optimization to avoid serializing the entire model. Only the weightsMatrix
     // and intercept is needed.
     val localWeights = weightsMatrix
     val localIntercept = intercept
     testData.map { x =>
-      signum(new DoubleMatrix(1, x.length, x:_*).dot(localWeights) + localIntercept)
+      signum(new DoubleMatrix(1, x.length, x:_*).dot(localWeights) + localIntercept).toInt
     }
   }
 
-  override def predict(testData: Array[Double]): Double = {
+  override def predict(testData: Array[Double]): Int = {
     val dataMat = new DoubleMatrix(1, testData.length, testData:_*)
-    signum(dataMat.dot(weightsMatrix) + this.intercept)
+    signum(dataMat.dot(weightsMatrix) + this.intercept).toInt
   }
 }
 
diff --git a/mllib/src/main/scala/spark/mllib/optimization/Gradient.scala b/mllib/src/main/scala/spark/mllib/optimization/Gradient.scala
index 4864ab7ccf..22b2ec5ed6 100644
--- a/mllib/src/main/scala/spark/mllib/optimization/Gradient.scala
+++ b/mllib/src/main/scala/spark/mllib/optimization/Gradient.scala
@@ -70,8 +70,8 @@ class HingeGradient extends Gradient {
     val dotProduct = data.dot(weights)
 
     if (1.0 > label * dotProduct)
-      (data.mul(-label), 1.0 - label * dotProduct)
+      (data.mul(-label), 1.0 - label * dotProduct)
     else
-      (DoubleMatrix.zeros(1,weights.length), 0.0)
+      (DoubleMatrix.zeros(1,weights.length), 0.0)
   }
 }
diff --git a/mllib/src/main/scala/spark/mllib/optimization/GradientDescent.scala b/mllib/src/main/scala/spark/mllib/optimization/GradientDescent.scala
index 8387d4939b..d4b83a1456 100644
--- a/mllib/src/main/scala/spark/mllib/optimization/GradientDescent.scala
+++ b/mllib/src/main/scala/spark/mllib/optimization/GradientDescent.scala
@@ -76,10 +76,10 @@ object GradientDescent {
       weights = update._1
       reg_val = update._2
       stochasticLossHistory.append(lossSum / miniBatchSize + reg_val)
-      /***
-      Xinghao: The loss here is sum of lossSum computed using the weights before applying updater,
-      and reg_val using weights after applying updater
-      ***/
+      /*
+       * NOTE(Xinghao): The loss here is sum of lossSum computed using the weights before applying updater,
+       * and reg_val using weights after applying updater
+       */
     }
 
     (weights.toArray, stochasticLossHistory.toArray)
diff --git a/mllib/src/main/scala/spark/mllib/optimization/Updater.scala b/mllib/src/main/scala/spark/mllib/optimization/Updater.scala
index cd344a6680..188fe7d972 100644
--- a/mllib/src/main/scala/spark/mllib/optimization/Updater.scala
+++ b/mllib/src/main/scala/spark/mllib/optimization/Updater.scala
@@ -46,17 +46,25 @@
 }
 
 /**
-L1 regularization -- corresponding proximal operator is the soft-thresholding function
+* L1 regularization -- corresponding proximal operator is the soft-thresholding function
+* That is, each weight component is shrunk towards 0 by shrinkageVal
+* If w > shrinkageVal, set weight component to w-shrinkageVal.
+* If w < -shrinkageVal, set weight component to w+shrinkageVal.
+* If -shrinkageVal < w < shrinkageVal, set weight component to 0.
+* Equivalently, set weight component to signum(w) * max(0.0, abs(w) - shrinkageVal)
 **/
 class L1Updater extends Updater {
   override def compute(weightsOld: DoubleMatrix, gradient: DoubleMatrix,
       stepSize: Double, iter: Int, regParam: Double): (DoubleMatrix, Double) = {
     val thisIterStepSize = stepSize / math.sqrt(iter)
     val normGradient = gradient.mul(thisIterStepSize)
+    // Take gradient step
     val newWeights = weightsOld.sub(normGradient)
+    // Soft thresholding
+    val shrinkageVal = regParam * thisIterStepSize
     (0 until newWeights.length).foreach(i => {
       val wi = newWeights.get(i)
-      newWeights.put(i, signum(wi) * max(0.0, abs(wi) - regParam * thisIterStepSize))
+      newWeights.put(i, signum(wi) * max(0.0, abs(wi) - shrinkageVal))
     })
     (newWeights, newWeights.norm1 * regParam)
   }
diff --git a/mllib/src/test/scala/spark/mllib/classification/SVMSuite.scala b/mllib/src/test/scala/spark/mllib/classification/SVMSuite.scala
index 2a23825acc..91c037e9b1 100644
--- a/mllib/src/test/scala/spark/mllib/classification/SVMSuite.scala
+++ b/mllib/src/test/scala/spark/mllib/classification/SVMSuite.scala
@@ -25,8 +25,6 @@ import org.scalatest.FunSuite
 
 import spark.SparkContext
 
-import java.io._
-
 class SVMSuite extends FunSuite with BeforeAndAfterAll {
   val sc = new SparkContext("local", "test")
 
-- 
GitLab
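
For reference, the soft-thresholding rule spelled out in the new L1Updater comment can be written as a small standalone Scala sketch. This is only an illustration under stated assumptions: it works on plain Array[Double] rather than the jblas DoubleMatrix the real updater uses, and the names L1StepSketch and l1Step are invented for the example, not part of the MLlib API.

// Illustrative, standalone sketch of the soft-thresholding update described in the
// L1Updater comment above. Assumptions: plain Array[Double] instead of jblas
// DoubleMatrix, and L1StepSketch / l1Step are hypothetical names for this example only.
object L1StepSketch {
  def l1Step(
      weightsOld: Array[Double],
      gradient: Array[Double],
      stepSize: Double,
      iter: Int,
      regParam: Double): (Array[Double], Double) = {
    val thisIterStepSize = stepSize / math.sqrt(iter)
    // Take gradient step
    val stepped = weightsOld.zip(gradient).map { case (w, g) => w - thisIterStepSize * g }
    // Soft thresholding: shrink each component towards 0 by shrinkageVal
    val shrinkageVal = regParam * thisIterStepSize
    val newWeights = stepped.map(w => math.signum(w) * math.max(0.0, math.abs(w) - shrinkageVal))
    // reg_val is the L1 norm of the new weights scaled by regParam, as in L1Updater
    (newWeights, newWeights.map(math.abs).sum * regParam)
  }

  def main(args: Array[String]): Unit = {
    val (w, regVal) = l1Step(Array(0.5, -0.02, 0.3), Array(1.0, 0.1, -0.5), 0.1, 1, 0.1)
    println(w.mkString("[", ", ", "]") + s", reg_val = $regVal")
    // Expected (up to floating point): roughly [0.39, -0.02, 0.34], reg_val ~ 0.075
  }
}

Components whose magnitude after the gradient step falls below shrinkageVal are set exactly to 0, which is what lets the L1 update produce sparse weight vectors.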