From d6f76eb346b691a553e89d3283f2a235661ae78c Mon Sep 17 00:00:00 2001
From: Sean Owen <sowen@cloudera.com>
Date: Wed, 14 Jun 2017 09:01:20 +0100
Subject: [PATCH] [SPARK-21057][ML] Do not use a PascalDistribution in
 countApprox

## What changes were proposed in this pull request?

Use Poisson analysis for approx count in all cases.

## How was this patch tested?

Existing tests.

Author: Sean Owen <sowen@cloudera.com>

Closes #18276 from srowen/SPARK-21057.
---
 .../apache/spark/partial/CountEvaluator.scala | 23 ++++++-----------------
 .../spark/partial/CountEvaluatorSuite.scala   | 12 ++++++------
 2 files changed, 12 insertions(+), 23 deletions(-)

diff --git a/core/src/main/scala/org/apache/spark/partial/CountEvaluator.scala b/core/src/main/scala/org/apache/spark/partial/CountEvaluator.scala
index 5a5bd7fbbe..cbee136871 100644
--- a/core/src/main/scala/org/apache/spark/partial/CountEvaluator.scala
+++ b/core/src/main/scala/org/apache/spark/partial/CountEvaluator.scala
@@ -17,7 +17,7 @@
 
 package org.apache.spark.partial
 
-import org.apache.commons.math3.distribution.{PascalDistribution, PoissonDistribution}
+import org.apache.commons.math3.distribution.PoissonDistribution
 
 /**
  * An ApproximateEvaluator for counts.
@@ -48,22 +48,11 @@ private[spark] class CountEvaluator(totalOutputs: Int, confidence: Double)
 private[partial] object CountEvaluator {
 
   def bound(confidence: Double, sum: Long, p: Double): BoundedDouble = {
-    // Let the total count be N. A fraction p has been counted already, with sum 'sum',
-    // as if each element from the total data set had been seen with probability p.
-    val dist =
-      if (sum <= 10000) {
-        // The remaining count, k=N-sum, may be modeled as negative binomial (aka Pascal),
-        // where there have been 'sum' successes of probability p already. (There are several
-        // conventions, but this is the one followed by Commons Math3.)
-        new PascalDistribution(sum.toInt, p)
-      } else {
-        // For large 'sum' (certainly, > Int.MaxValue!), use a Poisson approximation, which has
-        // a different interpretation. "sum" elements have been observed having scanned a fraction
-        // p of the data. This suggests data is counted at a rate of sum / p across the whole data
-        // set. The total expected count from the rest is distributed as
-        // (1-p) Poisson(sum / p) = Poisson(sum*(1-p)/p)
-        new PoissonDistribution(sum * (1 - p) / p)
-      }
+    // "sum" elements have been observed having scanned a fraction
+    // p of the data. This suggests data is counted at a rate of sum / p across the whole data
+    // set. The total expected count from the rest is distributed as
+    // (1-p) Poisson(sum / p) = Poisson(sum*(1-p)/p)
+    val dist = new PoissonDistribution(sum * (1 - p) / p)
     // Not quite symmetric; calculate interval straight from discrete distribution
     val low = dist.inverseCumulativeProbability((1 - confidence) / 2)
     val high = dist.inverseCumulativeProbability((1 + confidence) / 2)
diff --git a/core/src/test/scala/org/apache/spark/partial/CountEvaluatorSuite.scala b/core/src/test/scala/org/apache/spark/partial/CountEvaluatorSuite.scala
index da3256bd88..3c1208c2c3 100644
--- a/core/src/test/scala/org/apache/spark/partial/CountEvaluatorSuite.scala
+++ b/core/src/test/scala/org/apache/spark/partial/CountEvaluatorSuite.scala
@@ -23,21 +23,21 @@ class CountEvaluatorSuite extends SparkFunSuite {
 
   test("test count 0") {
     val evaluator = new CountEvaluator(10, 0.95)
-    assert(new BoundedDouble(0.0, 0.0, 0.0, Double.PositiveInfinity) == evaluator.currentResult())
+    assert(evaluator.currentResult() === new BoundedDouble(0.0, 0.0, 0.0, Double.PositiveInfinity))
     evaluator.merge(1, 0)
-    assert(new BoundedDouble(0.0, 0.0, 0.0, Double.PositiveInfinity) == evaluator.currentResult())
+    assert(evaluator.currentResult() === new BoundedDouble(0.0, 0.0, 0.0, Double.PositiveInfinity))
   }
 
   test("test count >= 1") {
     val evaluator = new CountEvaluator(10, 0.95)
     evaluator.merge(1, 1)
-    assert(new BoundedDouble(10.0, 0.95, 1.0, 36.0) == evaluator.currentResult())
+    assert(evaluator.currentResult() === new BoundedDouble(10.0, 0.95, 5.0, 16.0))
     evaluator.merge(1, 3)
-    assert(new BoundedDouble(20.0, 0.95, 7.0, 41.0) == evaluator.currentResult())
+    assert(evaluator.currentResult() === new BoundedDouble(20.0, 0.95, 13.0, 28.0))
     evaluator.merge(1, 8)
-    assert(new BoundedDouble(40.0, 0.95, 24.0, 61.0) == evaluator.currentResult())
+    assert(evaluator.currentResult() === new BoundedDouble(40.0, 0.95, 30.0, 51.0))
     (4 to 10).foreach(_ => evaluator.merge(1, 10))
-    assert(new BoundedDouble(82.0, 1.0, 82.0, 82.0) == evaluator.currentResult())
+    assert(evaluator.currentResult() === new BoundedDouble(82.0, 1.0, 82.0, 82.0))
  }
 }
-- 
GitLab
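
For context, a minimal standalone sketch of the Poisson-based bound this patch converges on. It assumes Commons Math3 on the classpath and returns a plain tuple in place of Spark's BoundedDouble; the object name and main method are illustrative only, not part of the patch:

```scala
import org.apache.commons.math3.distribution.PoissonDistribution

// Hypothetical sketch of the post-patch logic in CountEvaluator.bound:
// after scanning a fraction p of the data and counting `sum` elements,
// the unseen remainder is modeled as Poisson(sum * (1 - p) / p).
object PoissonCountBoundSketch {

  // Returns (estimate, low, high) instead of Spark's BoundedDouble.
  def bound(confidence: Double, sum: Long, p: Double): (Double, Double, Double) = {
    val dist = new PoissonDistribution(sum * (1 - p) / p)
    // The interval is not quite symmetric, so take quantiles of the
    // discrete distribution directly rather than mean +/- z * stddev.
    val low = dist.inverseCumulativeProbability((1 - confidence) / 2)
    val high = dist.inverseCumulativeProbability((1 + confidence) / 2)
    // Add `sum` back in: the distribution models only the unseen remainder.
    (sum + dist.getNumericalMean, (sum + low).toDouble, (sum + high).toDouble)
  }

  def main(args: Array[String]): Unit = {
    // Mirrors the first updated test case: 1 element counted after 1 of 10
    // partitions (p = 0.1) gives roughly (10.0, 5.0, 16.0) at 95% confidence.
    println(bound(0.95, 1L, 0.1))
  }
}
```

Note the design choice the test updates reflect: under the old code, small sums took the PascalDistribution branch (e.g. the 95% interval [1.0, 36.0] for sum = 1), while the Poisson model now yields the tighter [5.0, 16.0] for the same input, and the two branches disagreed at the sum = 10000 cutoff.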