Skip to content
Snippets Groups Projects
Commit 586e716e authored by Reynold Xin
Browse files

Reservoir sampling implementation.

This is going to be used in https://issues.apache.org/jira/browse/SPARK-2568

Author: Reynold Xin <rxin@apache.org>

Closes #1478 from rxin/reservoirSample and squashes the following commits:

17bcbf3 [Reynold Xin] Added seed.
badf20d [Reynold Xin] Renamed the method.
6940010 [Reynold Xin] Reservoir sampling implementation.
parent 7f87ab98
No related branches found
No related tags found
No related merge requests found
......@@ -17,8 +17,54 @@
package org.apache.spark.util.random
import scala.reflect.ClassTag
import scala.util.Random
private[spark] object SamplingUtils {
/**
 * Reservoir sampling implementation that also returns the input size.
 *
 * Makes a single pass over `input`. The first `k` elements fill the reservoir; every later
 * element replaces a randomly chosen slot with probability k / (number of elements seen so
 * far, including itself), which keeps the reservoir a uniform sample of everything seen.
 *
 * @param input iterator over the elements to sample from; it is consumed completely
 * @param k reservoir size, i.e. the maximum number of samples returned
 * @param seed random seed
 * @return (samples, input size); `samples` has min(k, input size) elements
 */
def reservoirSampleAndCount[T: ClassTag](
    input: Iterator[T],
    k: Int,
    seed: Long = Random.nextLong())
  : (Array[T], Int) = {
  val reservoir = new Array[T](k)
  // Put the first k elements in the reservoir.
  var i = 0
  while (i < k && input.hasNext) {
    val item = input.next()
    reservoir(i) = item
    i += 1
  }

  // If we have consumed all the elements, return them. Otherwise do the replacement.
  if (i < k) {
    // If input size < k, trim the array to return only an array of input size.
    val trimReservoir = new Array[T](i)
    System.arraycopy(reservoir, 0, trimReservoir, 0, i)
    (trimReservoir, i)
  } else {
    // If input size > k, continue the sampling process.
    val rand = new XORShiftRandom(seed)
    while (input.hasNext) {
      val item = input.next()
      // `i` elements were seen before this one, so this is element number i + 1 and must be
      // kept with probability k / (i + 1). Drawing from [0, i] (bound i + 1, not i) gives
      // exactly that; a bound of `i` would always evict on the (k + 1)-th element and would
      // request an empty range when k == 0.
      val replacementIndex = rand.nextInt(i + 1)
      if (replacementIndex < k) {
        reservoir(replacementIndex) = item
      }
      i += 1
    }
    (reservoir, i)
  }
}
/**
* Returns a sampling rate that guarantees a sample of size >= sampleSizeLowerBound 99.99% of
* the time.
......
......@@ -17,11 +17,32 @@
package org.apache.spark.util.random
import scala.util.Random
import org.apache.commons.math3.distribution.{BinomialDistribution, PoissonDistribution}
import org.scalatest.FunSuite
class SamplingUtilsSuite extends FunSuite {
test("reservoirSampleAndCount") {
  // 100 random integers serve as the population for all three scenarios below.
  val input = Seq.fill(100)(Random.nextInt())

  // Reservoir larger than the input: every element comes back, in order.
  val (oversizedSample, oversizedCount) =
    SamplingUtils.reservoirSampleAndCount(input.iterator, 150)
  assert(oversizedCount === 100)
  assert(input === oversizedSample.toSeq)

  // Reservoir exactly as large as the input: again the whole input comes back, in order.
  val (exactSample, exactCount) =
    SamplingUtils.reservoirSampleAndCount(input.iterator, 100)
  assert(exactCount === 100)
  assert(input === exactSample.toSeq)

  // Reservoir smaller than the input: only k elements are kept, but all 100 are counted.
  val (reducedSample, reducedCount) =
    SamplingUtils.reservoirSampleAndCount(input.iterator, 10)
  assert(reducedCount === 100)
  assert(reducedSample.length === 10)
}
test("computeFraction") {
// test that the computed fraction guarantees enough data points
// in the sample with a failure rate <= 0.0001
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment