From c12dff9b82e4869f866a9b96ce0bf05503dd7dda Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh <viirya@gmail.com> Date: Tue, 19 May 2015 13:53:08 -0700 Subject: [PATCH] [SPARK-7652] [MLLIB] Update the implementation of naive Bayes prediction with BLAS JIRA: https://issues.apache.org/jira/browse/SPARK-7652 Author: Liang-Chi Hsieh <viirya@gmail.com> Closes #6189 from viirya/naive_bayes_blas_prediction and squashes the following commits: ab611fd [Liang-Chi Hsieh] Remove unnecessary space. ddc48b9 [Liang-Chi Hsieh] Merge remote-tracking branch 'upstream/master' into naive_bayes_blas_prediction b5772b4 [Liang-Chi Hsieh] Fix binary compatibility. 2f65186 [Liang-Chi Hsieh] Remove toDense. 1b6cdfe [Liang-Chi Hsieh] Update the implementation of naive Bayes prediction with BLAS. --- .../mllib/classification/NaiveBayes.scala | 41 +++++++++++-------- 1 file changed, 24 insertions(+), 17 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/classification/NaiveBayes.scala b/mllib/src/main/scala/org/apache/spark/mllib/classification/NaiveBayes.scala index ac0ebeceaa..53fb2cba03 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/classification/NaiveBayes.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/classification/NaiveBayes.scala @@ -21,13 +21,11 @@ import java.lang.{Iterable => JIterable} import scala.collection.JavaConverters._ -import breeze.linalg.{Axis, DenseMatrix => BDM, DenseVector => BDV, argmax => brzArgmax, sum => brzSum} -import breeze.numerics.{exp => brzExp, log => brzLog} import org.json4s.JsonDSL._ import org.json4s.jackson.JsonMethods._ import org.apache.spark.{Logging, SparkContext, SparkException} -import org.apache.spark.mllib.linalg.{BLAS, DenseVector, SparseVector, Vector} +import org.apache.spark.mllib.linalg.{BLAS, DenseMatrix, DenseVector, SparseVector, Vector, Vectors} import org.apache.spark.mllib.regression.LabeledPoint import org.apache.spark.mllib.util.{Loader, Saveable} import org.apache.spark.rdd.RDD @@ -50,6 +48,9 @@ class NaiveBayesModel private[mllib] ( val modelType: String) extends ClassificationModel with Serializable with Saveable { + private val piVector = new DenseVector(pi) + private val thetaMatrix = new DenseMatrix(labels.size, theta(0).size, theta.flatten, true) + private[mllib] def this(labels: Array[Double], pi: Array[Double], theta: Array[Array[Double]]) = this(labels, pi, theta, "Multinomial") @@ -60,17 +61,18 @@ class NaiveBayesModel private[mllib] ( theta: JIterable[JIterable[Double]]) = this(labels.asScala.toArray, pi.asScala.toArray, theta.asScala.toArray.map(_.asScala.toArray)) - private val brzPi = new BDV[Double](pi) - private val brzTheta = new BDM(theta(0).length, theta.length, theta.flatten).t - // Bernoulli scoring requires log(condprob) if 1, log(1-condprob) if 0. - // This precomputes log(1.0 - exp(theta)) and its sum which are used for the linear algebra + // This precomputes log(1.0 - exp(theta)) and its sum which are used for the linear algebra // application of this condition (in predict function). - private val (brzNegTheta, brzNegThetaSum) = modelType match { + private val (thetaMinusNegTheta, negThetaSum) = modelType match { case "Multinomial" => (None, None) case "Bernoulli" => - val negTheta = brzLog((brzExp(brzTheta.copy) :*= (-1.0)) :+= 1.0) // log(1.0 - exp(x)) - (Option(negTheta), Option(brzSum(negTheta, Axis._1))) + val negTheta = thetaMatrix.map(value => math.log(1.0 - math.exp(value))) + val ones = new DenseVector(Array.fill(thetaMatrix.numCols){1.0}) + val thetaMinusNegTheta = thetaMatrix.map { value => + value - math.log(1.0 - math.exp(value)) + } + (Option(thetaMinusNegTheta), Option(negTheta.multiply(ones))) case _ => // This should never happen. throw new UnknownError(s"NaiveBayesModel was created with an unknown ModelType: $modelType") @@ -85,17 +87,22 @@ class NaiveBayesModel private[mllib] ( } override def predict(testData: Vector): Double = { - val brzData = testData.toBreeze modelType match { case "Multinomial" => - labels(brzArgmax(brzPi + brzTheta * brzData)) + val prob = thetaMatrix.multiply(testData) + BLAS.axpy(1.0, piVector, prob) + labels(prob.argmax) case "Bernoulli" => - if (!brzData.forall(v => v == 0.0 || v == 1.0)) { - throw new SparkException( - s"Bernoulli Naive Bayes requires 0 or 1 feature values but found $testData.") + testData.foreachActive { (index, value) => + if (value != 0.0 && value != 1.0) { + throw new SparkException( + s"Bernoulli Naive Bayes requires 0 or 1 feature values but found $testData.") + } } - labels(brzArgmax(brzPi + - (brzTheta - brzNegTheta.get) * brzData + brzNegThetaSum.get)) + val prob = thetaMinusNegTheta.get.multiply(testData) + BLAS.axpy(1.0, piVector, prob) + BLAS.axpy(1.0, negThetaSum.get, prob) + labels(prob.argmax) case _ => // This should never happen. throw new UnknownError(s"NaiveBayesModel was created with an unknown ModelType: $modelType") -- GitLab