diff --git a/mllib/src/main/scala/spark/mllib/classification/LogisticRegression.scala b/mllib/src/main/scala/spark/mllib/classification/LogisticRegression.scala
index cbc0d03ae1a256accb7078618a22d41150b65cdf..bc1c32772928b09c73ba5b0be9389115e318432a 100644
--- a/mllib/src/main/scala/spark/mllib/classification/LogisticRegression.scala
+++ b/mllib/src/main/scala/spark/mllib/classification/LogisticRegression.scala
@@ -21,6 +21,8 @@ import spark.{Logging, RDD, SparkContext}
 import spark.mllib.optimization._
 import spark.mllib.util.MLUtils
 
+import scala.math.round
+
 import org.jblas.DoubleMatrix
 
 /**
@@ -42,14 +44,14 @@ class LogisticRegressionModel(
     val localIntercept = intercept
     testData.map { x =>
       val margin = new DoubleMatrix(1, x.length, x:_*).mmul(localWeights).get(0) + localIntercept
-      (1.0/ (1.0 + math.exp(margin * -1))).toInt
+      round(1.0/ (1.0 + math.exp(margin * -1))).toInt
     }
   }
 
   override def predict(testData: Array[Double]): Int = {
     val dataMat = new DoubleMatrix(1, testData.length, testData:_*)
     val margin = dataMat.mmul(weightsMatrix).get(0) + this.intercept
-    (1.0/ (1.0 + math.exp(margin * -1))).toInt
+    round(1.0/ (1.0 + math.exp(margin * -1))).toInt
   }
 }