diff --git a/R/pkg/inst/tests/testthat/test_mllib.R b/R/pkg/inst/tests/testthat/test_mllib.R
index 0802a2ae48e47e79473a9fd7dea84b516573b48b..4758e40e41be55d0fe5b249667ad14029e838bd6 100644
--- a/R/pkg/inst/tests/testthat/test_mllib.R
+++ b/R/pkg/inst/tests/testthat/test_mllib.R
@@ -1007,10 +1007,11 @@ test_that("spark.randomForest", {
   model <- spark.randomForest(data, Employed ~ ., "regression", maxDepth = 5, maxBins = 16,
                               numTrees = 20, seed = 123)
   predictions <- collect(predict(model, data))
-  expect_equal(predictions$prediction, c(60.379, 61.096, 60.636, 62.258,
-                                         63.736, 64.296, 64.868, 64.300,
-                                         66.709, 67.697, 67.966, 67.252,
-                                         68.866, 69.593, 69.195, 69.658),
+  expect_equal(predictions$prediction, c(60.32820, 61.22315, 60.69025, 62.11070,
+                                         63.53160, 64.05470, 65.12710, 64.30450,
+                                         66.70910, 67.86125, 68.08700, 67.21865,
+                                         68.89275, 69.53180, 69.39640, 69.68250),
+               tolerance = 1e-4)
 
   stats <- summary(model)
   expect_equal(stats$numTrees, 20)
diff --git a/core/src/main/scala/org/apache/spark/util/random/SamplingUtils.scala b/core/src/main/scala/org/apache/spark/util/random/SamplingUtils.scala
index 297524c943e1fe49bde58bbb5c38f0f5b3624b2d..a7e0075debedbe488c754a65c9210eed32fb8ae7 100644
--- a/core/src/main/scala/org/apache/spark/util/random/SamplingUtils.scala
+++ b/core/src/main/scala/org/apache/spark/util/random/SamplingUtils.scala
@@ -56,11 +56,14 @@ private[spark] object SamplingUtils {
       val rand = new XORShiftRandom(seed)
       while (input.hasNext) {
         val item = input.next()
+        l += 1
+        // There are k elements in the reservoir, and the l-th element has been
+        // consumed. It should be chosen with probability k/l. The expression
+        // below is a random long chosen uniformly from [0,l)
         val replacementIndex = (rand.nextDouble() * l).toLong
         if (replacementIndex < k) {
           reservoir(replacementIndex.toInt) = item
         }
-        l += 1
       }
       (reservoir, l)
     }
diff --git a/core/src/test/scala/org/apache/spark/util/random/SamplingUtilsSuite.scala b/core/src/test/scala/org/apache/spark/util/random/SamplingUtilsSuite.scala
index 667a4db6f7bb6aa94aa3c976a9fe3ad9b313bb26..55c5dd5e2460ddd1309401a53ca3c969d8a35cf7 100644
--- a/core/src/test/scala/org/apache/spark/util/random/SamplingUtilsSuite.scala
+++ b/core/src/test/scala/org/apache/spark/util/random/SamplingUtilsSuite.scala
@@ -44,6 +44,19 @@ class SamplingUtilsSuite extends SparkFunSuite {
     assert(sample3.length === 10)
   }
 
+  test("SPARK-18678 reservoirSampleAndCount with tiny input") {
+    val input = Seq(0, 1)
+    val counts = new Array[Int](input.size)
+    for (i <- 0 until 500) {
+      val (samples, inputSize) = SamplingUtils.reservoirSampleAndCount(input.iterator, 1)
+      assert(inputSize === 2)
+      assert(samples.length === 1)
+      counts(samples.head) += 1
+    }
+    // If correct, should be true with prob ~ 0.99999707
+    assert(math.abs(counts(0) - counts(1)) <= 100)
+  }
+
   test("computeFraction") {
     // test that the computed fraction guarantees enough data points
     // in the sample with a failure rate <= 0.0001
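
Note (not part of the patch above): the change works because the running count l is now incremented before the replacement index is drawn, so the l-th consumed element lands in the reservoir with probability k/l. With the old ordering, the first element after the initial k was kept with probability k/k = 1, skewing samples toward early elements, which is what the new SPARK-18678 test detects. The following standalone Scala sketch mirrors the corrected update order outside of Spark; the names ReservoirSketch and reservoirSample are illustrative only, and it uses scala.util.Random instead of Spark's XORShiftRandom.

import scala.reflect.ClassTag
import scala.util.Random

// Illustrative sketch only: a standalone reservoir sampler following the
// corrected update order of SamplingUtils.reservoirSampleAndCount.
// These names are hypothetical, not Spark APIs.
object ReservoirSketch {

  def reservoirSample[T: ClassTag](input: Iterator[T], k: Int, seed: Long): (Array[T], Long) = {
    val reservoir = new Array[T](k)
    // Fill the reservoir with the first k elements. This sketch assumes the
    // input has at least k elements; Spark trims the array otherwise.
    var i = 0
    while (i < k && input.hasNext) {
      reservoir(i) = input.next()
      i += 1
    }
    var l = i.toLong
    val rand = new Random(seed)
    while (input.hasNext) {
      val item = input.next()
      // Count the element first, then keep it with probability k/l by drawing
      // a slot uniformly from [0, l) and accepting only if it falls inside the
      // reservoir. Incrementing l afterwards (the old code) would make the
      // (k+1)-th element replace a reservoir entry every time.
      l += 1
      val replacementIndex = (rand.nextDouble() * l).toLong
      if (replacementIndex < k) {
        reservoir(replacementIndex.toInt) = item
      }
    }
    (reservoir, l)
  }

  def main(args: Array[String]): Unit = {
    // Rough analogue of the new test: with input (0, 1) and k = 1, each
    // element should be picked close to half of the time.
    val counts = Array(0, 0)
    for (i <- 0 until 500) {
      val (sample, n) = reservoirSample(Seq(0, 1).iterator, 1, seed = i)
      assert(n == 2 && sample.length == 1)
      counts(sample(0)) += 1
    }
    println(s"counts(0) = ${counts(0)}, counts(1) = ${counts(1)}")
  }
}

With the corrected ordering the printed counts come out near 250/250; reverting the l += 1 back below the replacement step drives counts(1) to 500, reproducing the skew the patch fixes.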