diff --git a/docs/mllib-naive-bayes.md b/docs/mllib-naive-bayes.md index 86d94aebd9442a39fa21e341f8d20106cb2d69d7..7f9d4c656394450c45799e9b114763bcbaca14fd 100644 --- a/docs/mllib-naive-bayes.md +++ b/docs/mllib-naive-bayes.md @@ -17,7 +17,8 @@ Bayes](http://en.wikipedia.org/wiki/Naive_Bayes_classifier#Multinomial_naive_Bay which is typically used for [document classification](http://nlp.stanford.edu/IR-book/html/htmledition/naive-bayes-text-classification-1.html). Within that context, each observation is a document and each -feature represents a term whose value is the frequency of the term. +feature represents a term whose value is the frequency of the term. +Feature values must be nonnegative to represent term frequencies. [Additive smoothing](http://en.wikipedia.org/wiki/Lidstone_smoothing) can be used by setting the parameter $\lambda$ (default to $1.0$). For document classification, the input feature vectors are usually sparse, and sparse vectors should be supplied as input to take advantage of diff --git a/mllib/src/main/scala/org/apache/spark/mllib/classification/NaiveBayes.scala b/mllib/src/main/scala/org/apache/spark/mllib/classification/NaiveBayes.scala index 6c7be0a4f1dcb1ee06d117a712706f060752fcf6..8c8e4a161aa5b5d561897bf2c5f17a9afdb2e016 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/classification/NaiveBayes.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/classification/NaiveBayes.scala @@ -19,9 +19,9 @@ package org.apache.spark.mllib.classification import breeze.linalg.{DenseMatrix => BDM, DenseVector => BDV, argmax => brzArgmax, sum => brzSum} -import org.apache.spark.Logging +import org.apache.spark.{SparkException, Logging} import org.apache.spark.SparkContext._ -import org.apache.spark.mllib.linalg.Vector +import org.apache.spark.mllib.linalg.{DenseVector, SparseVector, Vector} import org.apache.spark.mllib.regression.LabeledPoint import org.apache.spark.rdd.RDD @@ -73,7 +73,7 @@ class NaiveBayesModel private[mllib] ( * This is the Multinomial NB ([[http://tinyurl.com/lsdw6p]]) which can handle all kinds of * discrete data. For example, by converting documents into TF-IDF vectors, it can be used for * document classification. By making every vector a 0-1 vector, it can also be used as - * Bernoulli NB ([[http://tinyurl.com/p7c96j6]]). + * Bernoulli NB ([[http://tinyurl.com/p7c96j6]]). The input feature values must be nonnegative. */ class NaiveBayes private (private var lambda: Double) extends Serializable with Logging { @@ -91,12 +91,30 @@ class NaiveBayes private (private var lambda: Double) extends Serializable with * @param data RDD of [[org.apache.spark.mllib.regression.LabeledPoint]]. */ def run(data: RDD[LabeledPoint]) = { + val requireNonnegativeValues: Vector => Unit = (v: Vector) => { + val values = v match { + case sv: SparseVector => + sv.values + case dv: DenseVector => + dv.values + } + if (!values.forall(_ >= 0.0)) { + throw new SparkException(s"Naive Bayes requires nonnegative feature values but found $v.") + } + } + // Aggregates term frequencies per label. // TODO: Calling combineByKey and collect creates two stages, we can implement something // TODO: similar to reduceByKeyLocally to save one stage. val aggregated = data.map(p => (p.label, p.features)).combineByKey[(Long, BDV[Double])]( - createCombiner = (v: Vector) => (1L, v.toBreeze.toDenseVector), - mergeValue = (c: (Long, BDV[Double]), v: Vector) => (c._1 + 1L, c._2 += v.toBreeze), + createCombiner = (v: Vector) => { + requireNonnegativeValues(v) + (1L, v.toBreeze.toDenseVector) + }, + mergeValue = (c: (Long, BDV[Double]), v: Vector) => { + requireNonnegativeValues(v) + (c._1 + 1L, c._2 += v.toBreeze) + }, mergeCombiners = (c1: (Long, BDV[Double]), c2: (Long, BDV[Double])) => (c1._1 + c2._1, c1._2 += c2._2) ).collect() diff --git a/mllib/src/test/scala/org/apache/spark/mllib/classification/NaiveBayesSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/classification/NaiveBayesSuite.scala index 06cdd04f5fdaed9423899159232ba5bcc3f09c7c..80989bc074e8400bd03668a311c79d2f87d040ad 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/classification/NaiveBayesSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/classification/NaiveBayesSuite.scala @@ -21,6 +21,7 @@ import scala.util.Random import org.scalatest.FunSuite +import org.apache.spark.SparkException import org.apache.spark.mllib.linalg.Vectors import org.apache.spark.mllib.regression.LabeledPoint import org.apache.spark.mllib.util.{LocalClusterSparkContext, LocalSparkContext} @@ -95,6 +96,33 @@ class NaiveBayesSuite extends FunSuite with LocalSparkContext { // Test prediction on Array. validatePrediction(validationData.map(row => model.predict(row.features)), validationData) } + + test("detect negative values") { + val dense = Seq( + LabeledPoint(1.0, Vectors.dense(1.0)), + LabeledPoint(0.0, Vectors.dense(-1.0)), + LabeledPoint(1.0, Vectors.dense(1.0)), + LabeledPoint(1.0, Vectors.dense(0.0))) + intercept[SparkException] { + NaiveBayes.train(sc.makeRDD(dense, 2)) + } + val sparse = Seq( + LabeledPoint(1.0, Vectors.sparse(1, Array(0), Array(1.0))), + LabeledPoint(0.0, Vectors.sparse(1, Array(0), Array(-1.0))), + LabeledPoint(1.0, Vectors.sparse(1, Array(0), Array(1.0))), + LabeledPoint(1.0, Vectors.sparse(1, Array.empty, Array.empty))) + intercept[SparkException] { + NaiveBayes.train(sc.makeRDD(sparse, 2)) + } + val nan = Seq( + LabeledPoint(1.0, Vectors.sparse(1, Array(0), Array(1.0))), + LabeledPoint(0.0, Vectors.sparse(1, Array(0), Array(Double.NaN))), + LabeledPoint(1.0, Vectors.sparse(1, Array(0), Array(1.0))), + LabeledPoint(1.0, Vectors.sparse(1, Array.empty, Array.empty))) + intercept[SparkException] { + NaiveBayes.train(sc.makeRDD(nan, 2)) + } + } } class NaiveBayesClusterSuite extends FunSuite with LocalClusterSparkContext {