diff --git a/mllib/src/main/scala/spark/mllib/classification/LogisticRegression.scala b/mllib/src/main/scala/spark/mllib/classification/LogisticRegression.scala index cbc0d03ae1a256accb7078618a22d41150b65cdf..bc1c32772928b09c73ba5b0be9389115e318432a 100644 --- a/mllib/src/main/scala/spark/mllib/classification/LogisticRegression.scala +++ b/mllib/src/main/scala/spark/mllib/classification/LogisticRegression.scala @@ -21,6 +21,8 @@ import spark.{Logging, RDD, SparkContext} import spark.mllib.optimization._ import spark.mllib.util.MLUtils +import scala.math.round + import org.jblas.DoubleMatrix /** @@ -42,14 +44,14 @@ class LogisticRegressionModel( val localIntercept = intercept testData.map { x => val margin = new DoubleMatrix(1, x.length, x:_*).mmul(localWeights).get(0) + localIntercept - (1.0/ (1.0 + math.exp(margin * -1))).toInt + round(1.0/ (1.0 + math.exp(margin * -1))).toInt } } override def predict(testData: Array[Double]): Int = { val dataMat = new DoubleMatrix(1, testData.length, testData:_*) val margin = dataMat.mmul(weightsMatrix).get(0) + this.intercept - (1.0/ (1.0 + math.exp(margin * -1))).toInt + round(1.0/ (1.0 + math.exp(margin * -1))).toInt } }