From 75f375730025788a5982146d97bf3df9ef69ab23 Mon Sep 17 00:00:00 2001
From: Xinghao <pxinghao@gmail.com>
Date: Mon, 29 Jul 2013 09:19:56 -0700
Subject: [PATCH] Fix rounding error in LogisticRegression.scala

---
 .../spark/mllib/classification/LogisticRegression.scala     | 6 ++++--
 1 file changed, 4 insertions(+), 2 deletions(-)

diff --git a/mllib/src/main/scala/spark/mllib/classification/LogisticRegression.scala b/mllib/src/main/scala/spark/mllib/classification/LogisticRegression.scala
index cbc0d03ae1..bc1c327729 100644
--- a/mllib/src/main/scala/spark/mllib/classification/LogisticRegression.scala
+++ b/mllib/src/main/scala/spark/mllib/classification/LogisticRegression.scala
@@ -21,6 +21,8 @@ import spark.{Logging, RDD, SparkContext}
 import spark.mllib.optimization._
 import spark.mllib.util.MLUtils
 
+import scala.math.round
+
 import org.jblas.DoubleMatrix
 
 /**
@@ -42,14 +44,14 @@ class LogisticRegressionModel(
     val localIntercept = intercept
     testData.map { x =>
       val margin = new DoubleMatrix(1, x.length, x:_*).mmul(localWeights).get(0) + localIntercept
-      (1.0/ (1.0 + math.exp(margin * -1))).toInt
+      round(1.0/ (1.0 + math.exp(margin * -1))).toInt
     }
   }
 
   override def predict(testData: Array[Double]): Int = {
     val dataMat = new DoubleMatrix(1, testData.length, testData:_*)
     val margin = dataMat.mmul(weightsMatrix).get(0) + this.intercept
-    (1.0/ (1.0 + math.exp(margin * -1))).toInt
+    round(1.0/ (1.0 + math.exp(margin * -1))).toInt
   }
 }
 
-- 
GitLab