Skip to content
Snippets Groups Projects
Unverified Commit 79f5f281 authored by Sean Owen's avatar Sean Owen
Browse files

[SPARK-18678][ML] Skewed reservoir sampling in SamplingUtils

## What changes were proposed in this pull request?

Fix reservoir sampling bias for small k. An off-by-one error meant that the probability of replacement was slightly too high -- k/(l-1) after l elements instead of k/l -- which matters for small k.

## How was this patch tested?

Existing test plus new test case.

Author: Sean Owen <sowen@cloudera.com>

Closes #16129 from srowen/SPARK-18678.
parent b8280271
No related branches found
No related tags found
No related merge requests found
......@@ -1007,10 +1007,11 @@ test_that("spark.randomForest", {
model <- spark.randomForest(data, Employed ~ ., "regression", maxDepth = 5, maxBins = 16,
numTrees = 20, seed = 123)
predictions <- collect(predict(model, data))
expect_equal(predictions$prediction, c(60.379, 61.096, 60.636, 62.258,
63.736, 64.296, 64.868, 64.300,
66.709, 67.697, 67.966, 67.252,
68.866, 69.593, 69.195, 69.658),
expect_equal(predictions$prediction, c(60.32820, 61.22315, 60.69025, 62.11070,
63.53160, 64.05470, 65.12710, 64.30450,
66.70910, 67.86125, 68.08700, 67.21865,
68.89275, 69.53180, 69.39640, 69.68250),
tolerance = 1e-4)
stats <- summary(model)
expect_equal(stats$numTrees, 20)
......
......@@ -56,11 +56,14 @@ private[spark] object SamplingUtils {
val rand = new XORShiftRandom(seed)
while (input.hasNext) {
val item = input.next()
l += 1
// There are k elements in the reservoir, and the l-th element has been
// consumed. It should be chosen with probability k/l. The expression
// below is a random long chosen uniformly from [0,l)
val replacementIndex = (rand.nextDouble() * l).toLong
if (replacementIndex < k) {
reservoir(replacementIndex.toInt) = item
}
l += 1
}
(reservoir, l)
}
......
......@@ -44,6 +44,19 @@ class SamplingUtilsSuite extends SparkFunSuite {
assert(sample3.length === 10)
}
// Regression test for SPARK-18678: repeatedly draws a size-1 reservoir sample from a
// two-element input and checks that both elements are picked about equally often.
test("SPARK-18678 reservoirSampleAndCount with tiny input") {
val input = Seq(0, 1)
// counts(v) tallies how many times the value v ended up as the sampled element
val counts = new Array[Int](input.size)
for (i <- 0 until 500) {
val (samples, inputSize) = SamplingUtils.reservoirSampleAndCount(input.iterator, 1)
// The second tuple element is expected to match the number of input elements (2 here)
assert(inputSize === 2)
assert(samples.length === 1)
// The input values are 0 and 1, so the sampled value doubles as the index into counts
counts(samples.head) += 1
}
// Over 500 trials the two tallies may differ by at most 100 before the test fails.
// If correct, should be true with prob ~ 0.99999707
assert(math.abs(counts(0) - counts(1)) <= 100)
}
test("computeFraction") {
// test that the computed fraction guarantees enough data points
// in the sample with a failure rate <= 0.0001
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment