Commit 9148b968 authored by Mark Hamstra's avatar Mark Hamstra

mapWith, flatMapWith and filterWith

parent 9f0dc829
@@ -364,6 +364,63 @@ abstract class RDD[T: ClassManifest](
      preservesPartitioning: Boolean = false): RDD[U] =
    new MapPartitionsWithIndexRDD(this, sc.clean(f), preservesPartitioning)
  /**
   * Maps f over this RDD, where f takes an additional parameter of type A. This
   * additional parameter is produced by a factory method T => A, which is called
   * on each invocation of f. One factory is built in each partition by
   * factoryBuilder, from the partition index and a seed value of type B.
   */
  def mapWith[A: ClassManifest, B: ClassManifest, U: ClassManifest](
      f: (A, T) => U,
      factoryBuilder: (Int, B) => (T => A),
      factorySeed: B,
      preservesPartitioning: Boolean = false): RDD[U] = {
    def iterF(index: Int, iter: Iterator[T]): Iterator[U] = {
      val factory = factoryBuilder(index, factorySeed)
      iter.map(t => f(factory(t), t))
    }
    new MapPartitionsWithIndexRDD(this, sc.clean(iterF _), preservesPartitioning)
  }
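
A minimal usage sketch (not part of this commit; the RDD[Double] named data and the noisy value are hypothetical): mapWith lets each partition build one deterministically seeded generator, then draw from it for every element.

// Sketch only: perturb each element with seeded Gaussian noise.
// One java.util.Random is built per partition from (partition index + seed).
val noisy = data.mapWith(
  (noise: Double, x: Double) => x + noise,
  (index: Int, seed: Int) => {
    val prng = new java.util.Random(index + seed)
    (_ => prng.nextGaussian)
  },
  42)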
  /**
   * FlatMaps f over this RDD, where f takes an additional parameter of type A. This
   * additional parameter is produced by a factory method T => A, which is called
   * on each invocation of f. One factory is built in each partition by
   * factoryBuilder, from the partition index and a seed value of type B.
   */
  def flatMapWith[A: ClassManifest, B: ClassManifest, U: ClassManifest](
      f: (A, T) => Seq[U],
      factoryBuilder: (Int, B) => (T => A),
      factorySeed: B,
      preservesPartitioning: Boolean = false): RDD[U] = {
    def iterF(index: Int, iter: Iterator[T]): Iterator[U] = {
      val factory = factoryBuilder(index, factorySeed)
      iter.flatMap(t => f(factory(t), t))
    }
    new MapPartitionsWithIndexRDD(this, sc.clean(iterF _), preservesPartitioning)
  }
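
A usage sketch for flatMapWith (not part of this commit; the RDD[String] named lines is hypothetical): because f returns a Seq, each input element can expand into a variable number of outputs, here a per-element random replication count.

// Sketch only: emit each element between 1 and 3 times.
val repeated = lines.flatMapWith(
  (n: Int, s: String) => Seq.fill(n)(s),
  (index: Int, seed: Int) => {
    val prng = new java.util.Random(index + seed)
    (_ => 1 + prng.nextInt(3))
  },
  42)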
  /**
   * Filters this RDD with p, where p takes an additional parameter of type A. This
   * additional parameter is produced by a factory method T => A, which is called
   * on each invocation of p. One factory is built in each partition by
   * factoryBuilder, from the partition index and a seed value of type B.
   */
  def filterWith[A: ClassManifest, B: ClassManifest](
      p: (A, T) => Boolean,
      factoryBuilder: (Int, B) => (T => A),
      factorySeed: B,
      preservesPartitioning: Boolean = false): RDD[T] = {
    def iterF(index: Int, iter: Iterator[T]): Iterator[T] = {
      val factory = factoryBuilder(index, factorySeed)
      iter.filter(t => p(factory(t), t))
    }
    new MapPartitionsWithIndexRDD(this, sc.clean(iterF _), preservesPartitioning)
  }
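
A usage sketch for filterWith (not part of this commit; the RDD[Int] named data is hypothetical): a deterministic per-partition Bernoulli sample, the same pattern the filterWith test below exercises with nextInt(3).

// Sketch only: keep each element with probability ~0.1.
val sampled = data.filterWith(
  (draw: Double, t: Int) => draw < 0.1,
  (index: Int, seed: Int) => {
    val prng = new java.util.Random(index + seed)
    (_ => prng.nextDouble)
  },
  42)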
  /**
   * Zips this RDD with another one, returning key-value pairs with the first element in each RDD,
   * second element in each RDD, etc. Assumes that the two RDDs have the *same number of
@@ -404,7 +461,7 @@ abstract class RDD[T: ClassManifest](
  /**
   * Return an RDD with the elements from `this` that are not in `other`.
   *
   * Uses `this` partitioner/partition size, because even if `other` is huge, the resulting
   * RDD will be <= us.
   */
...
@@ -178,4 +178,70 @@ class RDDSuite extends FunSuite with LocalSparkContext {
    assert(prunedData.size === 1)
    assert(prunedData(0) === 10)
  }
test("mapWith") {
sc = new SparkContext("local", "test")
val ones = sc.makeRDD(Array(1, 1, 1, 1, 1, 1), 2)
val randoms = ones.mapWith(
(random: Double, t: Int) => random * t,
(index: Int, seed: Int) => {
val prng = new java.util.Random(index + seed)
(_ => prng.nextDouble)},
42).
collect()
val prn42_3 = {
val prng42 = new java.util.Random(42)
prng42.nextDouble(); prng42.nextDouble(); prng42.nextDouble()
}
val prn43_3 = {
val prng43 = new java.util.Random(43)
prng43.nextDouble(); prng43.nextDouble(); prng43.nextDouble()
}
assert(randoms(2) === prn42_3)
assert(randoms(5) === prn43_3)
}
test("flatMapWith") {
sc = new SparkContext("local", "test")
val ones = sc.makeRDD(Array(1, 1, 1, 1, 1, 1), 2)
val randoms = ones.flatMapWith(
(random: Double, t: Int) => Seq(random * t, random * t * 10),
(index: Int, seed: Int) => {
val prng = new java.util.Random(index + seed)
(_ => prng.nextDouble)},
42).
collect()
val prn42_3 = {
val prng42 = new java.util.Random(42)
prng42.nextDouble(); prng42.nextDouble(); prng42.nextDouble()
}
val prn43_3 = {
val prng43 = new java.util.Random(43)
prng43.nextDouble(); prng43.nextDouble(); prng43.nextDouble()
}
assert(randoms(5) === prn42_3 * 10)
assert(randoms(11) === prn43_3 * 10)
}
test("filterWith") {
import java.util.Random
sc = new SparkContext("local", "test")
val ints = sc.makeRDD(Array(1, 2, 3, 4, 5, 6), 2)
val sample = ints.filterWith(
(random: Int, t: Int) => random == 0,
(index: Int, seed: Int) => {
val prng = new Random(index + seed)
(_ => prng.nextInt(3))},
42).
collect()
val checkSample = {
val prng42 = new Random(42)
val prng43 = new Random(43)
Array(1, 2, 3, 4, 5, 6).filter{i =>
if (i < 4) 0 == prng42.nextInt(3)
else 0 == prng43.nextInt(3)}
}
assert(sample.size === checkSample.size)
for (i <- 0 until sample.size) assert(sample(i) === checkSample(i))
}
}