Skip to content
Snippets Groups Projects
Commit fd53f2fc authored by Matei Zaharia
Browse files

Merge pull request #510 from markhamstra/WithThing

mapWith, flatMapWith and filterWith
parents 4c5efcf6 ab33e27c
No related branches found
No related tags found
No related merge requests found
...@@ -364,6 +364,62 @@ abstract class RDD[T: ClassManifest]( ...@@ -364,6 +364,62 @@ abstract class RDD[T: ClassManifest](
preservesPartitioning: Boolean = false): RDD[U] = preservesPartitioning: Boolean = false): RDD[U] =
new MapPartitionsWithIndexRDD(this, sc.clean(f), preservesPartitioning) new MapPartitionsWithIndexRDD(this, sc.clean(f), preservesPartitioning)
/**
 * Transforms every element of this RDD with f, which also receives a value of type A.
 * That value is created once per partition by calling constructA with the partition's
 * index, and the same value is then passed to f for each element of that partition.
 */
def mapWith[A: ClassManifest, U: ClassManifest](constructA: Int => A, preservesPartitioning: Boolean = false)
    (f:(T, A) => U): RDD[U] = {
  def mapPartition(partitionIndex: Int, elements: Iterator[T]): Iterator[U] = {
    // Built a single time for the whole partition, before any element is visited.
    val perPartition = constructA(partitionIndex)
    for (elem <- elements) yield f(elem, perPartition)
  }
  new MapPartitionsWithIndexRDD(this, sc.clean(mapPartition _), preservesPartitioning)
}
/**
 * Expands every element of this RDD into a sequence of results using f, which also
 * receives a value of type A. That value is created once per partition by calling
 * constructA with the partition's index; the sequences returned by f are concatenated.
 */
def flatMapWith[A: ClassManifest, U: ClassManifest](constructA: Int => A, preservesPartitioning: Boolean = false)
    (f:(T, A) => Seq[U]): RDD[U] = {
  def expandPartition(partitionIndex: Int, elements: Iterator[T]): Iterator[U] = {
    // Built a single time for the whole partition, before any element is visited.
    val perPartition = constructA(partitionIndex)
    elements.flatMap { elem => f(elem, perPartition) }
  }
  new MapPartitionsWithIndexRDD(this, sc.clean(expandPartition _), preservesPartitioning)
}
/**
 * Runs f on every element of this RDD, where f also receives a value of type A.
 * That value is created once per partition by calling constructA with the partition's
 * index. f is invoked purely for its effect; nothing is returned.
 */
def foreachWith[A: ClassManifest](constructA: Int => A)
    (f:(T, A) => Unit): Unit = {
  def visitPartition(partitionIndex: Int, elements: Iterator[T]): Iterator[T] = {
    val perPartition = constructA(partitionIndex)
    // Call f for its effect and pass the element through unchanged.
    elements.map { elem =>
      f(elem, perPartition)
      elem
    }
  }
  val visited = new MapPartitionsWithIndexRDD(this, sc.clean(visitPartition _), true)
  // The no-op foreach is what actually drives the iterators built above.
  visited.foreach(_ => {})
}
/**
 * Keeps only the elements of this RDD for which p holds, where p also receives a value
 * of type A. That value is created once per partition by calling constructA with the
 * partition's index.
 */
def filterWith[A: ClassManifest](constructA: Int => A)
    (p:(T, A) => Boolean): RDD[T] = {
  def keepMatching(partitionIndex: Int, elements: Iterator[T]): Iterator[T] = {
    // Built a single time for the whole partition, before any element is tested.
    val perPartition = constructA(partitionIndex)
    elements.filter { elem => p(elem, perPartition) }
  }
  new MapPartitionsWithIndexRDD(this, sc.clean(keepMatching _), true)
}
/** /**
* Zips this RDD with another one, returning key-value pairs with the first element in each RDD, * Zips this RDD with another one, returning key-value pairs with the first element in each RDD,
* second element in each RDD, etc. Assumes that the two RDDs have the *same number of * second element in each RDD, etc. Assumes that the two RDDs have the *same number of
...@@ -382,6 +438,14 @@ abstract class RDD[T: ClassManifest]( ...@@ -382,6 +438,14 @@ abstract class RDD[T: ClassManifest](
sc.runJob(this, (iter: Iterator[T]) => iter.foreach(cleanF)) sc.runJob(this, (iter: Iterator[T]) => iter.foreach(cleanF))
} }
/**
 * Applies a function f to each partition of this RDD.
 *
 * f is handed an iterator over the elements of one partition and is called for its
 * effect only; nothing is returned. The function is first passed through sc.clean,
 * and it is that cleaned function which is submitted to sc.runJob, in the same way
 * foreach submits its cleaned function.
 */
def foreachPartition(f: Iterator[T] => Unit) {
  val cleanF = sc.clean(f)
  // Use the cleaned function inside the job; calling the raw `f` here would leave
  // cleanF unused and bypass the result of sc.clean.
  sc.runJob(this, (iter: Iterator[T]) => cleanF(iter))
}
/** /**
* Return an array that contains all of the elements in this RDD. * Return an array that contains all of the elements in this RDD.
*/ */
...@@ -404,7 +468,7 @@ abstract class RDD[T: ClassManifest]( ...@@ -404,7 +468,7 @@ abstract class RDD[T: ClassManifest](
/** /**
* Return an RDD with the elements from `this` that are not in `other`. * Return an RDD with the elements from `this` that are not in `other`.
* *
* Uses `this` partitioner/partition size, because even if `other` is huge, the resulting * Uses `this` partitioner/partition size, because even if `other` is huge, the resulting
* RDD will be <= us. * RDD will be <= us.
*/ */
......
...@@ -208,4 +208,64 @@ class RDDSuite extends FunSuite with LocalSparkContext { ...@@ -208,4 +208,64 @@ class RDDSuite extends FunSuite with LocalSparkContext {
assert(prunedData.size === 1) assert(prunedData.size === 1)
assert(prunedData(0) === 10) assert(prunedData(0) === 10)
} }
test("mapWith") {
  import java.util.Random

  // Third value drawn from a generator seeded with `seed`.
  def thirdDouble(seed: Int): Double = {
    val prng = new Random(seed)
    prng.nextDouble()
    prng.nextDouble()
    prng.nextDouble()
  }

  sc = new SparkContext("local", "test")
  // Six ones split across two partitions; each partition gets a generator seeded
  // with (partition index + 42).
  val ones = sc.makeRDD(Array(1, 1, 1, 1, 1, 1), 2)
  val scaled = ones.mapWith(
    (index: Int) => new Random(index + 42)) {
      (t: Int, prng: Random) => prng.nextDouble() * t
    }
  val randoms = scaled.collect()

  // Elements 2 and 5 are expected to be the third draw of the generators
  // seeded with 42 and 43 respectively.
  assert(randoms(2) === thirdDouble(42))
  assert(randoms(5) === thirdDouble(43))
}
test("flatMapWith") {
  import java.util.Random

  // Third value drawn from a generator seeded with `seed`.
  def thirdDouble(seed: Int): Double = {
    val prng = new Random(seed)
    prng.nextDouble()
    prng.nextDouble()
    prng.nextDouble()
  }

  sc = new SparkContext("local", "test")
  val ones = sc.makeRDD(Array(1, 1, 1, 1, 1, 1), 2)
  // Every input element yields two outputs: the draw itself and ten times the draw.
  val expanded = ones.flatMapWith(
    (index: Int) => new Random(index + 42)) {
      (t: Int, prng: Random) =>
        val random = prng.nextDouble()
        Seq(random * t, random * t * 10)
    }
  val randoms = expanded.collect()

  // Outputs 5 and 11 are expected to be the "times ten" value produced from the
  // third draw of the generators seeded with 42 and 43 respectively.
  assert(randoms(5) === thirdDouble(42) * 10)
  assert(randoms(11) === thirdDouble(43) * 10)
}
test("filterWith") {
  import java.util.Random
  sc = new SparkContext("local", "test")
  val ints = sc.makeRDD(Array(1, 2, 3, 4, 5, 6), 2)
  val filtered = ints.filterWith(
    (index: Int) => new Random(index + 42)) {
      (t: Int, prng: Random) => prng.nextInt(3) == 0
    }
  val sample = filtered.collect()

  // Reference result computed locally: values below 4 are tested against the
  // generator seeded with 42, the remaining values against the one seeded with 43.
  val prng42 = new Random(42)
  val prng43 = new Random(43)
  val checkSample = Array(1, 2, 3, 4, 5, 6).filter { i =>
    val prng = if (i < 4) prng42 else prng43
    0 == prng.nextInt(3)
  }

  assert(sample.size === checkSample.size)
  sample.indices.foreach { i => assert(sample(i) === checkSample(i)) }
}
} }
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment