Commit 9398dced authored by Xinghao

Changed Classification to return Int instead of Double

Also minor changes to formatting and comments
parent 67de051b
@@ -7,15 +7,15 @@ trait ClassificationModel extends Serializable {
    * Predict values for the given data set using the model trained.
    *
    * @param testData RDD representing data points to be predicted
-   * @return RDD[Double] where each entry contains the corresponding prediction
+   * @return RDD[Int] where each entry contains the corresponding prediction
    */
-  def predict(testData: RDD[Array[Double]]): RDD[Double]
+  def predict(testData: RDD[Array[Double]]): RDD[Int]

   /**
    * Predict values for a single data point using the model trained.
    *
    * @param testData array representing a single data point
-   * @return Double prediction from the trained model
+   * @return Int prediction from the trained model
    */
-  def predict(testData: Array[Double]): Double
+  def predict(testData: Array[Double]): Int
 }
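
Note: callers of the trait now receive integer class labels directly. A hypothetical usage sketch (`model` is assumed to be some ClassificationModel implementation, not part of this patch):

  val label: Int = model.predict(Array(1.0, 0.5))  // previously returned a Double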
@@ -35,21 +35,21 @@ class LogisticRegressionModel(
   // Create a column vector that can be used for predictions
   private val weightsMatrix = new DoubleMatrix(weights.length, 1, weights:_*)

-  override def predict(testData: spark.RDD[Array[Double]]) = {
+  override def predict(testData: spark.RDD[Array[Double]]): RDD[Int] = {
     // A small optimization to avoid serializing the entire model. Only the weightsMatrix
     // and intercept is needed.
     val localWeights = weightsMatrix
     val localIntercept = intercept
     testData.map { x =>
       val margin = new DoubleMatrix(1, x.length, x:_*).mmul(localWeights).get(0) + localIntercept
-      1.0/ (1.0 + math.exp(margin * -1))
+      (1.0/ (1.0 + math.exp(margin * -1))).toInt
     }
   }

-  override def predict(testData: Array[Double]): Double = {
+  override def predict(testData: Array[Double]): Int = {
     val dataMat = new DoubleMatrix(1, testData.length, testData:_*)
     val margin = dataMat.mmul(weightsMatrix).get(0) + this.intercept
-    1.0/ (1.0 + math.exp(margin * -1))
+    (1.0/ (1.0 + math.exp(margin * -1))).toInt
   }
 }
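
Note on this hunk: the sigmoid 1.0 / (1.0 + math.exp(-margin)) lies strictly in (0, 1), and Double#toInt truncates toward zero, so the converted value is 0 unless the denominator rounds to exactly 1.0 (i.e. only for very large positive margins). A minimal sketch of the usual 0.5-threshold labeling, equivalent to testing margin > 0 (predictLabel is a hypothetical helper, not part of the patch):

  def predictLabel(margin: Double): Int =
    if (1.0 / (1.0 + math.exp(-margin)) > 0.5) 1 else 0
  // e.g. predictLabel(2.0) == 1, predictLabel(-2.0) == 0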
@@ -70,14 +70,6 @@ class LogisticRegressionLocalRandomSGD private (var stepSize: Double, var miniBa
     this
   }

   /**
    * Set fraction of data to be used for each SGD iteration. Default 1.0.
    */
...
@@ -35,19 +35,19 @@ class SVMModel(
   // Create a column vector that can be used for predictions
   private val weightsMatrix = new DoubleMatrix(weights.length, 1, weights:_*)

-  override def predict(testData: spark.RDD[Array[Double]]) = {
+  override def predict(testData: spark.RDD[Array[Double]]): RDD[Int] = {
     // A small optimization to avoid serializing the entire model. Only the weightsMatrix
     // and intercept is needed.
     val localWeights = weightsMatrix
     val localIntercept = intercept
     testData.map { x =>
-      signum(new DoubleMatrix(1, x.length, x:_*).dot(localWeights) + localIntercept)
+      signum(new DoubleMatrix(1, x.length, x:_*).dot(localWeights) + localIntercept).toInt
     }
   }

-  override def predict(testData: Array[Double]): Double = {
+  override def predict(testData: Array[Double]): Int = {
     val dataMat = new DoubleMatrix(1, testData.length, testData:_*)
-    signum(dataMat.dot(weightsMatrix) + this.intercept)
+    signum(dataMat.dot(weightsMatrix) + this.intercept).toInt
   }
 }
...
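
Note on this hunk: math.signum returns -1.0, 0.0, or 1.0, so the .toInt conversion is exact and SVM predictions land in {-1, 0, 1}, with 0 only when the margin is exactly zero. A plain-array sketch of the same decision rule (svmPredict is a hypothetical helper):

  import math.signum
  def svmPredict(x: Array[Double], w: Array[Double], intercept: Double): Int =
    signum(x.zip(w).map { case (xi, wi) => xi * wi }.sum + intercept).toInt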
@@ -70,8 +70,8 @@ class HingeGradient extends Gradient {
     val dotProduct = data.dot(weights)
     if (1.0 > label * dotProduct)
       (data.mul(-label), 1.0 - label * dotProduct)
     else
       (DoubleMatrix.zeros(1,weights.length), 0.0)
   }
 }
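
For reference, this computes a subgradient of the hinge loss max(0, 1 - label * (w . x)): when the margin constraint is violated (1 > label * w . x) the subgradient is -label * x, otherwise it is zero. A self-contained sketch on plain arrays, mirroring the branch above (hingeGradient is a hypothetical helper):

  def hingeGradient(x: Array[Double], label: Double, w: Array[Double]): (Array[Double], Double) = {
    val dot = x.zip(w).map { case (xi, wi) => xi * wi }.sum
    if (1.0 > label * dot) (x.map(_ * -label), 1.0 - label * dot)
    else (Array.fill(w.length)(0.0), 0.0)
  }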
@@ -76,10 +76,10 @@ object GradientDescent {
       weights = update._1
       reg_val = update._2
       stochasticLossHistory.append(lossSum / miniBatchSize + reg_val)
-      /***
-      Xinghao: The loss here is sum of lossSum computed using the weights before applying updater,
-      and reg_val using weights after applying updater
-      ***/
+      /*
+       * NOTE(Xinghao): The loss here is sum of lossSum computed using the weights before applying updater,
+       * and reg_val using weights after applying updater
+       */
     }
     (weights.toArray, stochasticLossHistory.toArray)
...
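
The comment in this hunk is worth unpacking: each history entry is the mini-batch data loss evaluated at the pre-update weights plus the regularization value of the post-update weights. A one-dimensional illustration of that mixing (all names hypothetical):

  val stepSize = 0.1; val regParam = 1.0
  var w = 2.0
  val lossAtOldW = w * w                  // data loss f(w) = w^2, before the step
  w -= stepSize * (2.0 * w)               // updater applies the gradient step
  val regAtNewW = regParam * math.abs(w)  // regularizer, after the step
  val historyEntry = lossAtOldW + regAtNewW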
@@ -46,17 +46,25 @@ class SimpleUpdater extends Updater {
 }

 /**
-L1 regularization -- corresponding proximal operator is the soft-thresholding function
+ * L1 regularization -- corresponding proximal operator is the soft-thresholding function
+ * That is, each weight component is shrunk towards 0 by shrinkageVal
+ * If w > shrinkageVal, set weight component to w-shrinkageVal.
+ * If w < -shrinkageVal, set weight component to w+shrinkageVal.
+ * If -shrinkageVal < w < shrinkageVal, set weight component to 0.
+ * Equivalently, set weight component to signum(w) * max(0.0, abs(w) - shrinkageVal)
 **/
 class L1Updater extends Updater {
   override def compute(weightsOld: DoubleMatrix, gradient: DoubleMatrix,
       stepSize: Double, iter: Int, regParam: Double): (DoubleMatrix, Double) = {
     val thisIterStepSize = stepSize / math.sqrt(iter)
     val normGradient = gradient.mul(thisIterStepSize)
+    // Take gradient step
     val newWeights = weightsOld.sub(normGradient)
+    // Soft thresholding
+    val shrinkageVal = regParam * thisIterStepSize
     (0 until newWeights.length).foreach(i => {
       val wi = newWeights.get(i)
-      newWeights.put(i, signum(wi) * max(0.0, abs(wi) - regParam * thisIterStepSize))
+      newWeights.put(i, signum(wi) * max(0.0, abs(wi) - shrinkageVal))
     })
     (newWeights, newWeights.norm1 * regParam)
   }
...
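
The soft-thresholding operator documented above is the proximal operator of the L1 penalty. A standalone sketch (softThreshold is a hypothetical helper; the example values are chosen so the arithmetic is exact in Double):

  import math.{signum, max, abs}
  def softThreshold(w: Array[Double], shrinkageVal: Double): Array[Double] =
    w.map(wi => signum(wi) * max(0.0, abs(wi) - shrinkageVal))
  // e.g. softThreshold(Array(0.5, -1.5, 0.05), 0.25) gives Array(0.25, -1.25, 0.0)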
@@ -25,8 +25,6 @@ import org.scalatest.FunSuite
 import spark.SparkContext

-import java.io._

 class SVMSuite extends FunSuite with BeforeAndAfterAll {

   val sc = new SparkContext("local", "test")
...