Commit d6f76eb3 authored by Sean Owen

[SPARK-21057][ML] Do not use a PascalDistribution in countApprox

## What changes were proposed in this pull request?

Use Poisson analysis for approx count in all cases.
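
For context, here is a minimal standalone sketch of the Poisson-only bound this change switches to. It mirrors the replacement code in the diff below, but the `poissonBound` helper name and its tuple return are illustrative only; the real `CountEvaluator.bound` returns a `BoundedDouble`, and the sum == 0 case never reaches it (see the unchanged "test count 0" expectations).

```scala
import org.apache.commons.math3.distribution.PoissonDistribution

// 'sum' elements were counted while scanning a fraction p of the data, so the
// uncounted remainder is modeled as Poisson(sum * (1 - p) / p) in all cases,
// rather than switching to a Pascal (negative binomial) model for small sums.
def poissonBound(confidence: Double, sum: Long, p: Double): (Double, Double, Double) = {
  require(sum > 0 && p > 0 && p < 1)  // the sum == 0 case is handled upstream
  val dist = new PoissonDistribution(sum * (1 - p) / p)
  // Take the interval straight from the discrete distribution (not quite symmetric),
  // then shift by 'sum' because the distribution models only the remaining count.
  val low = (sum + dist.inverseCumulativeProbability((1 - confidence) / 2)).toDouble
  val high = (sum + dist.inverseCumulativeProbability((1 + confidence) / 2)).toDouble
  (sum + dist.getNumericalMean, low, high)
}
```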

## How was this patch tested?

Existing tests.

Author: Sean Owen <sowen@cloudera.com>

Closes #18276 from srowen/SPARK-21057.
parent 4d01aa46
CountEvaluator.scala:

@@ -17,7 +17,7 @@
 package org.apache.spark.partial
 
-import org.apache.commons.math3.distribution.{PascalDistribution, PoissonDistribution}
+import org.apache.commons.math3.distribution.PoissonDistribution
 
 /**
  * An ApproximateEvaluator for counts.
@@ -48,22 +48,11 @@ private[spark] class CountEvaluator(totalOutputs: Int, confidence: Double)
 private[partial] object CountEvaluator {
 
   def bound(confidence: Double, sum: Long, p: Double): BoundedDouble = {
-    // Let the total count be N. A fraction p has been counted already, with sum 'sum',
-    // as if each element from the total data set had been seen with probability p.
-    val dist =
-      if (sum <= 10000) {
-        // The remaining count, k=N-sum, may be modeled as negative binomial (aka Pascal),
-        // where there have been 'sum' successes of probability p already. (There are several
-        // conventions, but this is the one followed by Commons Math3.)
-        new PascalDistribution(sum.toInt, p)
-      } else {
-        // For large 'sum' (certainly, > Int.MaxValue!), use a Poisson approximation, which has
-        // a different interpretation. "sum" elements have been observed having scanned a fraction
-        // p of the data. This suggests data is counted at a rate of sum / p across the whole data
-        // set. The total expected count from the rest is distributed as
-        // (1-p) Poisson(sum / p) = Poisson(sum*(1-p)/p)
-        new PoissonDistribution(sum * (1 - p) / p)
-      }
+    // "sum" elements have been observed having scanned a fraction
+    // p of the data. This suggests data is counted at a rate of sum / p across the whole data
+    // set. The total expected count from the rest is distributed as
+    // (1-p) Poisson(sum / p) = Poisson(sum*(1-p)/p)
+    val dist = new PoissonDistribution(sum * (1 - p) / p)
     // Not quite symmetric; calculate interval straight from discrete distribution
     val low = dist.inverseCumulativeProbability((1 - confidence) / 2)
     val high = dist.inverseCumulativeProbability((1 + confidence) / 2)
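One step in the retained comment is worth spelling out: if the whole data set's count is modeled as Poisson(sum / p), then the part of it falling in the unscanned fraction (1 - p) is an independent thinning of that Poisson variable, and a thinned Poisson is again Poisson with its mean scaled by the thinning probability:

$$
\text{remaining count} \sim \mathrm{Poisson}\!\left((1-p)\cdot\frac{\mathit{sum}}{p}\right) = \mathrm{Poisson}\!\left(\frac{\mathit{sum}\,(1-p)}{p}\right)
$$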
CountEvaluatorSuite.scala:

@@ -23,21 +23,21 @@ class CountEvaluatorSuite extends SparkFunSuite {
 
   test("test count 0") {
     val evaluator = new CountEvaluator(10, 0.95)
-    assert(new BoundedDouble(0.0, 0.0, 0.0, Double.PositiveInfinity) == evaluator.currentResult())
+    assert(evaluator.currentResult() === new BoundedDouble(0.0, 0.0, 0.0, Double.PositiveInfinity))
     evaluator.merge(1, 0)
-    assert(new BoundedDouble(0.0, 0.0, 0.0, Double.PositiveInfinity) == evaluator.currentResult())
+    assert(evaluator.currentResult() === new BoundedDouble(0.0, 0.0, 0.0, Double.PositiveInfinity))
   }
 
   test("test count >= 1") {
     val evaluator = new CountEvaluator(10, 0.95)
     evaluator.merge(1, 1)
-    assert(new BoundedDouble(10.0, 0.95, 1.0, 36.0) == evaluator.currentResult())
+    assert(evaluator.currentResult() === new BoundedDouble(10.0, 0.95, 5.0, 16.0))
     evaluator.merge(1, 3)
-    assert(new BoundedDouble(20.0, 0.95, 7.0, 41.0) == evaluator.currentResult())
+    assert(evaluator.currentResult() === new BoundedDouble(20.0, 0.95, 13.0, 28.0))
     evaluator.merge(1, 8)
-    assert(new BoundedDouble(40.0, 0.95, 24.0, 61.0) == evaluator.currentResult())
+    assert(evaluator.currentResult() === new BoundedDouble(40.0, 0.95, 30.0, 51.0))
     (4 to 10).foreach(_ => evaluator.merge(1, 10))
-    assert(new BoundedDouble(82.0, 1.0, 82.0, 82.0) == evaluator.currentResult())
+    assert(evaluator.currentResult() === new BoundedDouble(82.0, 1.0, 82.0, 82.0))
   }
 }
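
As a quick sanity check on the updated expectations (not part of the patch, just Commons Math3 arithmetic): after the first merge(1, 1) with totalOutputs = 10, the evaluator has seen sum = 1 over a fraction p = 0.1 of the outputs, so the remaining count is modeled as Poisson(1 * 0.9 / 0.1) = Poisson(9).

```scala
import org.apache.commons.math3.distribution.PoissonDistribution

// Reproduce the first updated expectation: p = 0.1, sum = 1 => Poisson(9) remainder.
val dist = new PoissonDistribution(9.0)
val low = 1 + dist.inverseCumulativeProbability(0.025)   // 1 + 4  = 5
val high = 1 + dist.inverseCumulativeProbability(0.975)  // 1 + 15 = 16
val mean = 1 + dist.getNumericalMean                     // 1 + 9  = 10.0
// Matches the new expected value BoundedDouble(10.0, 0.95, 5.0, 16.0) above.
```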