diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/CountVectorizer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/CountVectorizer.scala
index a3845d39777a45ce631fc20e74aa848f9d10d208..5694b3890fba4de660b894d62cef040d7bad9d5c 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/CountVectorizer.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/CountVectorizer.scala
@@ -207,13 +207,12 @@ class CountVectorizerModel(override val uid: String, val vocabulary: Array[Strin
   def setMinTF(value: Double): this.type = set(minTF, value)
 
   /**
-    * Binary toggle to control the output vector values.
-    * If True, all non zero counts are set to 1. This is useful for discrete probabilistic
-    * models that model binary events rather than integer counts
-    *
-    * Default: false
-    * @group param
-    */
+   * Binary toggle to control the output vector values.
+   * If True, all nonzero counts (after minTF filter applied) are set to 1. This is useful for
+   * discrete probabilistic models that model binary events rather than integer counts.
+   * Default: false
+   * @group param
+   */
   val binary: BooleanParam =
     new BooleanParam(this, "binary", "If True, all non zero counts are set to 1. " +
       "This is useful for discrete probabilistic models that model binary events rather " +
@@ -248,17 +247,13 @@ class CountVectorizerModel(override val uid: String, val vocabulary: Array[Strin
         }
         tokenCount += 1
       }
-      val effectiveMinTF = if (minTf >= 1.0) {
-        minTf
-      } else {
-        tokenCount * minTf
-      }
+      val effectiveMinTF = if (minTf >= 1.0) minTf else tokenCount * minTf
       val effectiveCounts = if ($(binary)) {
         termCounts.filter(_._2 >= effectiveMinTF).map(p => (p._1, 1.0)).toSeq
-      }
-      else {
+      } else {
         termCounts.filter(_._2 >= effectiveMinTF).toSeq
       }
+
       Vectors.sparse(dictBr.value.size, effectiveCounts)
     }
     dataset.withColumn($(outputCol), vectorizer(col($(inputCol))))
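
For reviewers, a minimal usage sketch (not part of the patch) of how the revised binary semantics read from the caller's side: the minTF filter is applied first, and any count that survives is then emitted as 1.0. The vocabulary, column names, object name, and threshold below are illustrative assumptions, and the setBinary setter is assumed to accompany the binary param on CountVectorizerModel.

import org.apache.spark.ml.feature.CountVectorizerModel
import org.apache.spark.sql.SparkSession

object CountVectorizerBinarySketch {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder()
      .appName("CountVectorizerBinarySketch")
      .master("local[*]")
      .getOrCreate()
    import spark.implicits._

    // One toy document: "a" occurs 3 times, "b" twice, "c" once.
    val docs = Seq((0, Seq("a", "a", "a", "b", "b", "c"))).toDF("id", "words")

    // Model built from a fixed vocabulary; indices follow array order: a -> 0, b -> 1, c -> 2.
    val model = new CountVectorizerModel(Array("a", "b", "c"))
      .setInputCol("words")
      .setOutputCol("features")
      .setMinTF(2.0)    // >= 1.0, so it acts as an absolute count threshold (effectiveMinTF = 2.0)
      .setBinary(true)  // counts that survive the minTF filter are emitted as 1.0

    // "c" (count 1) is dropped by minTF; "a" and "b" survive and are clipped to 1.0.
    // Expected features value: (3,[0,1],[1.0,1.0])
    model.transform(docs).select("features").show(false)

    spark.stop()
  }
}

Without setBinary(true), the same input would produce (3,[0,1],[3.0,2.0]), which is what the "(after minTF filter applied)" wording in the updated scaladoc is meant to pin down.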