Skip to content
Snippets Groups Projects
Commit c12dff9b authored by Liang-Chi Hsieh, committed by Xiangrui Meng
Browse files

[SPARK-7652] [MLLIB] Update the implementation of naive Bayes prediction with BLAS

JIRA: https://issues.apache.org/jira/browse/SPARK-7652

Author: Liang-Chi Hsieh <viirya@gmail.com>

Closes #6189 from viirya/naive_bayes_blas_prediction and squashes the following commits:

ab611fd [Liang-Chi Hsieh] Remove unnecessary space.
ddc48b9 [Liang-Chi Hsieh] Merge remote-tracking branch 'upstream/master' into naive_bayes_blas_prediction
b5772b4 [Liang-Chi Hsieh] Fix binary compatibility.
2f65186 [Liang-Chi Hsieh] Remove toDense.
1b6cdfe [Liang-Chi Hsieh] Update the implementation of naive Bayes prediction with BLAS.
parent 68fb2a46
No related branches found
No related tags found
No related merge requests found
...@@ -21,13 +21,11 @@ import java.lang.{Iterable => JIterable} ...@@ -21,13 +21,11 @@ import java.lang.{Iterable => JIterable}
import scala.collection.JavaConverters._ import scala.collection.JavaConverters._
import breeze.linalg.{Axis, DenseMatrix => BDM, DenseVector => BDV, argmax => brzArgmax, sum => brzSum}
import breeze.numerics.{exp => brzExp, log => brzLog}
import org.json4s.JsonDSL._ import org.json4s.JsonDSL._
import org.json4s.jackson.JsonMethods._ import org.json4s.jackson.JsonMethods._
import org.apache.spark.{Logging, SparkContext, SparkException} import org.apache.spark.{Logging, SparkContext, SparkException}
import org.apache.spark.mllib.linalg.{BLAS, DenseVector, SparseVector, Vector} import org.apache.spark.mllib.linalg.{BLAS, DenseMatrix, DenseVector, SparseVector, Vector, Vectors}
import org.apache.spark.mllib.regression.LabeledPoint import org.apache.spark.mllib.regression.LabeledPoint
import org.apache.spark.mllib.util.{Loader, Saveable} import org.apache.spark.mllib.util.{Loader, Saveable}
import org.apache.spark.rdd.RDD import org.apache.spark.rdd.RDD
...@@ -50,6 +48,9 @@ class NaiveBayesModel private[mllib] ( ...@@ -50,6 +48,9 @@ class NaiveBayesModel private[mllib] (
val modelType: String) val modelType: String)
extends ClassificationModel with Serializable with Saveable { extends ClassificationModel with Serializable with Saveable {
private val piVector = new DenseVector(pi)
private val thetaMatrix = new DenseMatrix(labels.size, theta(0).size, theta.flatten, true)
private[mllib] def this(labels: Array[Double], pi: Array[Double], theta: Array[Array[Double]]) = private[mllib] def this(labels: Array[Double], pi: Array[Double], theta: Array[Array[Double]]) =
this(labels, pi, theta, "Multinomial") this(labels, pi, theta, "Multinomial")
...@@ -60,17 +61,18 @@ class NaiveBayesModel private[mllib] ( ...@@ -60,17 +61,18 @@ class NaiveBayesModel private[mllib] (
theta: JIterable[JIterable[Double]]) = theta: JIterable[JIterable[Double]]) =
this(labels.asScala.toArray, pi.asScala.toArray, theta.asScala.toArray.map(_.asScala.toArray)) this(labels.asScala.toArray, pi.asScala.toArray, theta.asScala.toArray.map(_.asScala.toArray))
private val brzPi = new BDV[Double](pi)
private val brzTheta = new BDM(theta(0).length, theta.length, theta.flatten).t
// Bernoulli scoring requires log(condprob) if 1, log(1-condprob) if 0. // Bernoulli scoring requires log(condprob) if 1, log(1-condprob) if 0.
// This precomputes log(1.0 - exp(theta)) and its sum which are used for the linear algebra // This precomputes log(1.0 - exp(theta)) and its sum which are used for the linear algebra
// application of this condition (in predict function). // application of this condition (in predict function).
private val (brzNegTheta, brzNegThetaSum) = modelType match { private val (thetaMinusNegTheta, negThetaSum) = modelType match {
case "Multinomial" => (None, None) case "Multinomial" => (None, None)
case "Bernoulli" => case "Bernoulli" =>
val negTheta = brzLog((brzExp(brzTheta.copy) :*= (-1.0)) :+= 1.0) // log(1.0 - exp(x)) val negTheta = thetaMatrix.map(value => math.log(1.0 - math.exp(value)))
(Option(negTheta), Option(brzSum(negTheta, Axis._1))) val ones = new DenseVector(Array.fill(thetaMatrix.numCols){1.0})
val thetaMinusNegTheta = thetaMatrix.map { value =>
value - math.log(1.0 - math.exp(value))
}
(Option(thetaMinusNegTheta), Option(negTheta.multiply(ones)))
case _ => case _ =>
// This should never happen. // This should never happen.
throw new UnknownError(s"NaiveBayesModel was created with an unknown ModelType: $modelType") throw new UnknownError(s"NaiveBayesModel was created with an unknown ModelType: $modelType")
...@@ -85,17 +87,22 @@ class NaiveBayesModel private[mllib] ( ...@@ -85,17 +87,22 @@ class NaiveBayesModel private[mllib] (
} }
override def predict(testData: Vector): Double = { override def predict(testData: Vector): Double = {
val brzData = testData.toBreeze
modelType match { modelType match {
case "Multinomial" => case "Multinomial" =>
labels(brzArgmax(brzPi + brzTheta * brzData)) val prob = thetaMatrix.multiply(testData)
BLAS.axpy(1.0, piVector, prob)
labels(prob.argmax)
case "Bernoulli" => case "Bernoulli" =>
if (!brzData.forall(v => v == 0.0 || v == 1.0)) { testData.foreachActive { (index, value) =>
throw new SparkException( if (value != 0.0 && value != 1.0) {
s"Bernoulli Naive Bayes requires 0 or 1 feature values but found $testData.") throw new SparkException(
s"Bernoulli Naive Bayes requires 0 or 1 feature values but found $testData.")
}
} }
labels(brzArgmax(brzPi + val prob = thetaMinusNegTheta.get.multiply(testData)
(brzTheta - brzNegTheta.get) * brzData + brzNegThetaSum.get)) BLAS.axpy(1.0, piVector, prob)
BLAS.axpy(1.0, negThetaSum.get, prob)
labels(prob.argmax)
case _ => case _ =>
// This should never happen. // This should never happen.
throw new UnknownError(s"NaiveBayesModel was created with an unknown ModelType: $modelType") throw new UnknownError(s"NaiveBayesModel was created with an unknown ModelType: $modelType")
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment