From 75f375730025788a5982146d97bf3df9ef69ab23 Mon Sep 17 00:00:00 2001 From: Xinghao <pxinghao@gmail.com> Date: Mon, 29 Jul 2013 09:19:56 -0700 Subject: [PATCH] Fix rounding error in LogisticRegression.scala --- .../spark/mllib/classification/LogisticRegression.scala | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/mllib/src/main/scala/spark/mllib/classification/LogisticRegression.scala b/mllib/src/main/scala/spark/mllib/classification/LogisticRegression.scala index cbc0d03ae1..bc1c327729 100644 --- a/mllib/src/main/scala/spark/mllib/classification/LogisticRegression.scala +++ b/mllib/src/main/scala/spark/mllib/classification/LogisticRegression.scala @@ -21,6 +21,8 @@ import spark.{Logging, RDD, SparkContext} import spark.mllib.optimization._ import spark.mllib.util.MLUtils +import scala.math.round + import org.jblas.DoubleMatrix /** @@ -42,14 +44,14 @@ class LogisticRegressionModel( val localIntercept = intercept testData.map { x => val margin = new DoubleMatrix(1, x.length, x:_*).mmul(localWeights).get(0) + localIntercept - (1.0/ (1.0 + math.exp(margin * -1))).toInt + round(1.0/ (1.0 + math.exp(margin * -1))).toInt } } override def predict(testData: Array[Double]): Int = { val dataMat = new DoubleMatrix(1, testData.length, testData:_*) val margin = dataMat.mmul(weightsMatrix).get(0) + this.intercept - (1.0/ (1.0 + math.exp(margin * -1))).toInt + round(1.0/ (1.0 + math.exp(margin * -1))).toInt } } -- GitLab