Skip to content
Snippets Groups Projects
Commit 6566a19b authored by Patrick Wendell
Browse files

Merge pull request #9 from rxin/limit

Smarter take/limit implementation.
parents 9d34838b 42571d30
No related branches found
No related tags found
No related merge requests found
...@@ -753,24 +753,42 @@ abstract class RDD[T: ClassManifest]( ...@@ -753,24 +753,42 @@ abstract class RDD[T: ClassManifest](
} }
/**
 * Take the first num elements of the RDD. It works by first scanning one partition, and then
 * using the results from that partition to estimate the number of additional partitions needed
 * to satisfy the limit.
 *
 * @param num the maximum number of elements to return
 * @return an array holding at most `num` elements; it holds fewer when every partition has
 *         been scanned without collecting `num` elements
 */
def take(num: Int): Array[T] = {
  if (num == 0) {
    return new Array[T](0)
  }

  val buf = new ArrayBuffer[T]
  val totalParts = this.partitions.length
  var partsScanned = 0
  while (buf.size < num && partsScanned < totalParts) {
    // The number of partitions to try in this iteration: a single partition on the first
    // pass, an estimate derived from the rows found so far on later passes.
    var numPartsToTry = 1
    if (partsScanned > 0) {
      // If we didn't find any rows after the first iteration, just try all partitions next.
      // Otherwise, interpolate the number of partitions we need to try, but overestimate it
      // by 50%.
      if (buf.size == 0) {
        numPartsToTry = totalParts - 1
      } else {
        numPartsToTry = (1.5 * num * partsScanned / buf.size).toInt
      }
    }
    // Clamp to the partitions that are actually left (and to at least one, so the loop always
    // makes progress). For a large `num` the interpolated estimate saturates at Int.MaxValue;
    // adding that to partsScanned unclamped overflows Int, which would produce an empty
    // partition range and a negative partsScanned on the next iteration.
    numPartsToTry = math.max(1, math.min(numPartsToTry, totalParts - partsScanned))

    val left = num - buf.size
    val p = partsScanned until (partsScanned + numPartsToTry)
    val res = sc.runJob(this, (it: Iterator[T]) => it.take(left).toArray, p, allowLocal = true)

    // Each partition may return up to `left` rows, so trim while appending to never exceed num.
    res.foreach(buf ++= _.take(num - buf.size))
    partsScanned += numPartsToTry
  }

  buf.toArray
}
......
...@@ -320,6 +320,44 @@ class RDDSuite extends FunSuite with SharedSparkContext { ...@@ -320,6 +320,44 @@ class RDDSuite extends FunSuite with SharedSparkContext {
for (i <- 0 until sample.size) assert(sample(i) === checkSample(i)) for (i <- 0 until sample.size) assert(sample(i) === checkSample(i))
} }
test("take") {
  // The expected results do not depend on how the 999 elements are split up, so run the same
  // assertions against each partitioning: a single partition, two partitions, many partitions,
  // and more partitions than elements.
  for (numSlices <- Seq(1, 2, 100, 1000)) {
    val rdd = sc.makeRDD(Range(1, 1000), numSlices)
    assert(rdd.take(0).size === 0)
    assert(rdd.take(1) === Array(1))
    assert(rdd.take(3) === Array(1, 2, 3))
    assert(rdd.take(500) === (1 to 500).toArray)
    assert(rdd.take(501) === (1 to 501).toArray)
    assert(rdd.take(999) === (1 to 999).toArray)
    // Asking for more elements than exist returns everything that is there.
    assert(rdd.take(1000) === (1 to 999).toArray)
  }
}
test("top with predefined ordering") { test("top with predefined ordering") {
val nums = Array.range(1, 100000) val nums = Array.range(1, 100000)
val ints = sc.makeRDD(scala.util.Random.shuffle(nums), 2) val ints = sc.makeRDD(scala.util.Random.shuffle(nums), 2)
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment