diff --git a/core/src/main/scala/spark/RDD.scala b/core/src/main/scala/spark/RDD.scala
index 7e38583391337baf00d1ac935a225414ff74d895..5163c8013475d73a6e25efedf9b70cc0baaeeec9 100644
--- a/core/src/main/scala/spark/RDD.scala
+++ b/core/src/main/scala/spark/RDD.scala
@@ -329,6 +329,13 @@ abstract class RDD[T: ClassManifest](@transient sc: SparkContext) extends Serial
    */
   def toArray(): Array[T] = collect()
 
+  /**
+   * Return an RDD that contains all matching values by applying `f`.
+   */
+  def collect[U: ClassManifest](f: PartialFunction[T, U]): RDD[U] = {
+    filter(f.isDefinedAt).map(f)
+  }
+
   /**
    * Reduces the elements of this RDD using the specified associative binary operator.
    */
diff --git a/core/src/test/scala/spark/RDDSuite.scala b/core/src/test/scala/spark/RDDSuite.scala
index 45e6c5f840c6f8ba82a2b378bf4c30fc7677a2b6..872b06fd080a96305b951525e934ca77542a26ba 100644
--- a/core/src/test/scala/spark/RDDSuite.scala
+++ b/core/src/test/scala/spark/RDDSuite.scala
@@ -35,6 +35,7 @@ class RDDSuite extends FunSuite with BeforeAndAfter {
     assert(nums.flatMap(x => 1 to x).collect().toList === List(1, 1, 2, 1, 2, 3, 1, 2, 3, 4))
     assert(nums.union(nums).collect().toList === List(1, 2, 3, 4, 1, 2, 3, 4))
     assert(nums.glom().map(_.toList).collect().toList === List(List(1, 2), List(3, 4)))
+    assert(nums.collect({ case i if i >= 3 => i.toString }).collect().toList === List("3", "4"))
     val partitionSums = nums.mapPartitions(iter => Iterator(iter.reduceLeft(_ + _)))
     assert(partitionSums.collect().toList === List(3, 7))
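
For context, here is a minimal usage sketch of the new overload, mirroring the test above. It is not part of the diff and assumes a `SparkContext` named `sc`; the partial function filters and transforms elements in one call, while the existing zero-argument `collect()` action is unchanged.

  // A minimal sketch, assuming a SparkContext named `sc` (not in the diff):
  val nums = sc.parallelize(List(1, 2, 3, 4))

  // The new overload: keep only the elements on which the partial function
  // is defined, then transform them; equivalent to filter(f.isDefinedAt).map(f).
  val big: RDD[String] = nums.collect { case i if i >= 3 => i.toString }

  // The pre-existing zero-argument collect() action still materializes results.
  big.collect()  // Array("3", "4")

One consequence of implementing this as `filter(f.isDefinedAt).map(f)` is that the pattern match (including any guard) is evaluated twice per element, once in `filter` and once in `map`, so `f` should be cheap and side-effect free.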