diff --git a/mllib/src/main/scala/spark/mllib/classification/Classification.scala b/mllib/src/main/scala/spark/mllib/classification/Classification.scala
index 96d7a54f187b89bc1a8645727ae1e255e6b1c2b4..d6154b66aed7df4686dc0d552f6b8e8d5cde4e97 100644
--- a/mllib/src/main/scala/spark/mllib/classification/Classification.scala
+++ b/mllib/src/main/scala/spark/mllib/classification/Classification.scala
@@ -7,15 +7,15 @@ trait ClassificationModel extends Serializable {
    * Predict values for the given data set using the model trained.
    *
    * @param testData RDD representing data points to be predicted
-   * @return RDD[Double] where each entry contains the corresponding prediction
+   * @return RDD[Int] where each entry contains the corresponding prediction
    */
-  def predict(testData: RDD[Array[Double]]): RDD[Double]
+  def predict(testData: RDD[Array[Double]]): RDD[Int]
 
   /**
    * Predict values for a single data point using the model trained.
    *
    * @param testData array representing a single data point
-   * @return Double prediction from the trained model
+   * @return Int prediction from the trained model
    */
-  def predict(testData: Array[Double]): Double
+  def predict(testData: Array[Double]): Int
 }
diff --git a/mllib/src/main/scala/spark/mllib/classification/LogisticRegression.scala b/mllib/src/main/scala/spark/mllib/classification/LogisticRegression.scala
index 1b093187f29452350c908406b06521f9b9363a09..0a7effb1d77d230b6ef755f96a14c18115d33f71 100644
--- a/mllib/src/main/scala/spark/mllib/classification/LogisticRegression.scala
+++ b/mllib/src/main/scala/spark/mllib/classification/LogisticRegression.scala
@@ -35,21 +35,22 @@ class LogisticRegressionModel(
   // Create a column vector that can be used for predictions
   private val weightsMatrix = new DoubleMatrix(weights.length, 1, weights:_*)
 
-  override def predict(testData: spark.RDD[Array[Double]]) = {
+  override def predict(testData: spark.RDD[Array[Double]]): RDD[Int] = {
     // A small optimization to avoid serializing the entire model. Only the weightsMatrix
     // and intercept is needed.
     val localWeights = weightsMatrix
     val localIntercept = intercept
     testData.map { x =>
       val margin = new DoubleMatrix(1, x.length, x:_*).mmul(localWeights).get(0) + localIntercept
-      1.0/ (1.0 + math.exp(margin * -1))
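+      // Round the sigmoid output to a 0/1 class label: a non-negative margin gives a value of at least 0.5 and predicts 1, a negative margin predicts 0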
+      math.round(1.0 / (1.0 + math.exp(margin * -1))).toInt
     }
   }
 
-  override def predict(testData: Array[Double]): Double = {
+  override def predict(testData: Array[Double]): Int = {
     val dataMat = new DoubleMatrix(1, testData.length, testData:_*)
     val margin = dataMat.mmul(weightsMatrix).get(0) + this.intercept
-    1.0/ (1.0 + math.exp(margin * -1))
+    math.round(1.0 / (1.0 + math.exp(margin * -1))).toInt
   }
 }
 
@@ -70,14 +70,6 @@ class LogisticRegressionLocalRandomSGD private (var stepSize: Double, var miniBa
     this
   }
 
-
-
-
-
-
-
-
-
   /**
    * Set fraction of data to be used for each SGD iteration. Default 1.0.
    */
diff --git a/mllib/src/main/scala/spark/mllib/classification/SVM.scala b/mllib/src/main/scala/spark/mllib/classification/SVM.scala
index 76844f6b9c3392c351da9d2f49452df604ef9257..30766a4c64c8a6a7e448d363bbe625bc70800521 100644
--- a/mllib/src/main/scala/spark/mllib/classification/SVM.scala
+++ b/mllib/src/main/scala/spark/mllib/classification/SVM.scala
@@ -35,19 +35,20 @@ class SVMModel(
   // Create a column vector that can be used for predictions
   private val weightsMatrix = new DoubleMatrix(weights.length, 1, weights:_*)
 
-  override def predict(testData: spark.RDD[Array[Double]]) = {
+  override def predict(testData: spark.RDD[Array[Double]]): RDD[Int] = {
     // A small optimization to avoid serializing the entire model. Only the weightsMatrix
     // and intercept is needed.
     val localWeights = weightsMatrix
     val localIntercept = intercept
     testData.map { x => 
-      signum(new DoubleMatrix(1, x.length, x:_*).dot(localWeights) + localIntercept)
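+      // signum of the margin gives -1.0, 0.0 or 1.0; toInt turns it into the integer class label (the +1/-1 convention assumed by the hinge loss)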
+      signum(new DoubleMatrix(1, x.length, x:_*).dot(localWeights) + localIntercept).toInt
     }
   }
 
-  override def predict(testData: Array[Double]): Double = {
+  override def predict(testData: Array[Double]): Int = {
     val dataMat = new DoubleMatrix(1, testData.length, testData:_*)
-    signum(dataMat.dot(weightsMatrix) + this.intercept)
+    signum(dataMat.dot(weightsMatrix) + this.intercept).toInt
   }
 }
 
diff --git a/mllib/src/main/scala/spark/mllib/optimization/Gradient.scala b/mllib/src/main/scala/spark/mllib/optimization/Gradient.scala
index 4864ab7ccfc0c2003a4141a1af982571337a41ca..22b2ec5ed60f9fec9e731205e38d5fbaba1319b1 100644
--- a/mllib/src/main/scala/spark/mllib/optimization/Gradient.scala
+++ b/mllib/src/main/scala/spark/mllib/optimization/Gradient.scala
@@ -70,8 +70,9 @@ class HingeGradient extends Gradient {
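+    // Hinge loss is max(0, 1 - label * (weights dot data)); its subgradient is -label * data when the margin is violated, and zero otherwise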
     val dotProduct = data.dot(weights)
 
     if (1.0 > label * dotProduct)
-      (data.mul(-label),                        1.0 - label * dotProduct)
+      (data.mul(-label), 1.0 - label * dotProduct)
     else
-      (DoubleMatrix.zeros(1,weights.length),    0.0)
+      (DoubleMatrix.zeros(1, weights.length), 0.0)
   }
 }
diff --git a/mllib/src/main/scala/spark/mllib/optimization/GradientDescent.scala b/mllib/src/main/scala/spark/mllib/optimization/GradientDescent.scala
index 8387d4939b96bdec0e043e6975e26e77f1a6b4e7..d4b83a14561d3cd3d54d6650918609b9e4eafd8c 100644
--- a/mllib/src/main/scala/spark/mllib/optimization/GradientDescent.scala
+++ b/mllib/src/main/scala/spark/mllib/optimization/GradientDescent.scala
@@ -76,10 +76,10 @@ object GradientDescent {
       weights = update._1
       reg_val = update._2
       stochasticLossHistory.append(lossSum / miniBatchSize + reg_val)
-      /***
-      Xinghao: The loss here is sum of lossSum computed using the weights before applying updater,
-      and reg_val using weights after applying updater
-      ***/
+      /*
+       * NOTE(Xinghao): The appended loss is lossSum / miniBatchSize, computed with the weights from
+       * before the updater step, plus reg_val, which uses the weights from after the updater step.
+       */
     }
 
     (weights.toArray, stochasticLossHistory.toArray)
diff --git a/mllib/src/main/scala/spark/mllib/optimization/Updater.scala b/mllib/src/main/scala/spark/mllib/optimization/Updater.scala
index cd344a668040dd7d2359dd2fe09028ee44751c60..188fe7d972d7bbe2ebab2bd4958da815ad0534ad 100644
--- a/mllib/src/main/scala/spark/mllib/optimization/Updater.scala
+++ b/mllib/src/main/scala/spark/mllib/optimization/Updater.scala
@@ -46,17 +46,26 @@ class SimpleUpdater extends Updater {
 }
 
 /**
-L1 regularization -- corresponding proximal operator is the soft-thresholding function
+* L1 regularization -- the corresponding proximal operator is the soft-thresholding function.
+* That is, each weight component is shrunk towards 0 by shrinkageVal:
+* If w >  shrinkageVal, set weight component to w-shrinkageVal.
+* If w < -shrinkageVal, set weight component to w+shrinkageVal.
+* If -shrinkageVal < w < shrinkageVal, set weight component to 0.
+* Equivalently, set weight component to signum(w) * max(0.0, abs(w) - shrinkageVal).
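+* For example, with shrinkageVal = 0.1 a weight of 0.25 becomes 0.15, a weight of -0.25 becomes -0.15, and a weight of 0.05 becomes 0.0.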
 **/
 class L1Updater extends Updater {
   override def compute(weightsOld: DoubleMatrix, gradient: DoubleMatrix,
       stepSize: Double, iter: Int, regParam: Double): (DoubleMatrix, Double) = {
     val thisIterStepSize = stepSize / math.sqrt(iter)
     val normGradient = gradient.mul(thisIterStepSize)
+    // Take gradient step
     val newWeights = weightsOld.sub(normGradient)
+    // Soft thresholding
+    val shrinkageVal = regParam * thisIterStepSize
     (0 until newWeights.length).foreach(i => {
       val wi = newWeights.get(i)
-      newWeights.put(i, signum(wi) * max(0.0, abs(wi) - regParam * thisIterStepSize))
+      newWeights.put(i, signum(wi) * max(0.0, abs(wi) - shrinkageVal))
       })
     (newWeights, newWeights.norm1 * regParam)
   }
diff --git a/mllib/src/test/scala/spark/mllib/classification/SVMSuite.scala b/mllib/src/test/scala/spark/mllib/classification/SVMSuite.scala
index 2a23825acccc8fc6f2df90bf9c17c6e75d4ddd67..91c037e9b1682ca67ebb7cbea490ddf8f2175590 100644
--- a/mllib/src/test/scala/spark/mllib/classification/SVMSuite.scala
+++ b/mllib/src/test/scala/spark/mllib/classification/SVMSuite.scala
@@ -25,8 +25,6 @@ import org.scalatest.FunSuite
 
 import spark.SparkContext
 
-import java.io._
-
 class SVMSuite extends FunSuite with BeforeAndAfterAll {
   val sc = new SparkContext("local", "test")