diff --git a/core/src/main/scala/org/apache/spark/rdd/RDD.scala b/core/src/main/scala/org/apache/spark/rdd/RDD.scala
index 0355618e435bd3d00357406ca167ab324aae4645..e2652f13c46b86e64b2ceb94b4cdafb4ffd81e27 100644
--- a/core/src/main/scala/org/apache/spark/rdd/RDD.scala
+++ b/core/src/main/scala/org/apache/spark/rdd/RDD.scala
@@ -265,6 +265,27 @@ abstract class RDD[T: ClassManifest](
 
   def distinct(): RDD[T] = distinct(partitions.size)
 
+  /**
+   * Return a new RDD that has exactly numPartitions partitions.
+   *
+   * Used to increase or decrease the level of parallelism in this RDD. By default, this will use
+   * a shuffle to redistribute data. If you are shrinking the RDD into fewer partitions, you can
+   * set skipShuffle = true to avoid a shuffle. Skipping shuffles is not supported when
+   * increasing the number of partitions.
+   *
+   * Similar to `coalesce`, but shuffles by default, allowing you to call this safely even
+   * if you don't know the current number of partitions.
+   */
+  def repartition(numPartitions: Int, skipShuffle: Boolean = false): RDD[T] = {
+    if (skipShuffle && numPartitions > this.partitions.size) {
+      val msg = "repartition must grow %s from %s to %s partitions, cannot skip shuffle.".format(
+        this.name, this.partitions.size, numPartitions
+      )
+      throw new IllegalArgumentException(msg)
+    }
+    coalesce(numPartitions, !skipShuffle)
+  }
+
   /**
    * Return a new RDD that is reduced into `numPartitions` partitions.
    *
diff --git a/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala b/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala
index 6d1bc5e296e06beb137673088229da3750c0579c..fd00183668b6174cb39d19add5715607eb2081f9 100644
--- a/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala
+++ b/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala
@@ -139,6 +139,39 @@ class RDDSuite extends FunSuite with SharedSparkContext {
     assert(rdd.union(emptyKv).collect().size === 2)
   }
 
+  test("repartitioned RDDs") {
+    val data = sc.parallelize(1 to 1000, 10)
+
+    // Coalesce partitions
+    val repartitioned1 = data.repartition(2)
+    assert(repartitioned1.partitions.size == 2)
+    val partitions1 = repartitioned1.glom().collect()
+    assert(partitions1(0).length > 0)
+    assert(partitions1(1).length > 0)
+    assert(repartitioned1.collect().toSet === (1 to 1000).toSet)
+
+    // Split partitions
+    val repartitioned2 = data.repartition(20)
+    assert(repartitioned2.partitions.size == 20)
+    val partitions2 = repartitioned2.glom().collect()
+    assert(partitions2(0).length > 0)
+    assert(partitions2(19).length > 0)
+    assert(repartitioned2.collect().toSet === (1 to 1000).toSet)
+
+    // Coalesce partitions - no shuffle
+    val repartitioned3 = data.repartition(2, skipShuffle = true)
+    assert(repartitioned3.partitions.size == 2)
+    val partitions3 = repartitioned3.glom().collect()
+    assert(partitions3(0).toList === (1 to 500).toList)
+    assert(partitions3(1).toList === (501 to 1000).toList)
+    assert(repartitioned3.collect().toSet === (1 to 1000).toSet)
+
+    // Split partitions - no shuffle (should throw exn)
+    intercept[IllegalArgumentException] {
+      data.repartition(20, skipShuffle = true)
+    }
+  }
+
   test("coalesced RDDs") {
     val data = sc.parallelize(1 to 10, 10)
 
diff --git a/docs/streaming-programming-guide.md b/docs/streaming-programming-guide.md
index 835b257238e4bf04d2e0024a37f93ffaec13f0b5..851e30fe761af664cae684acbd86aa56cff88c65 100644
--- a/docs/streaming-programming-guide.md
+++ b/docs/streaming-programming-guide.md
@@ -72,6 +72,10 @@ DStreams support many of the transformations available on normal Spark RDD's:
   <td> Similar to map, but runs separately on each partition (block) of the DStream, so <i>func</i> must be of
   type Iterator[T] => Iterator[U] when running on an DStream of type T. </td>
 </tr>
+<tr>
+  <td> <b>repartition</b>(<i>numPartitions</i>) </td>
+  <td> Changes the level of parallelism in this DStream by creating more or fewer partitions. </td>
+</tr>
 <tr>
   <td> <b>union</b>(<i>otherStream</i>) </td>
   <td> Return a new DStream that contains the union of the elements in the source DStream and the argument DStream. </td>
diff --git a/streaming/src/main/scala/org/apache/spark/streaming/DStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/DStream.scala
index 80da6bd30b4b99663f7517531e5d005f62ba8c1c..6da2261f06af400fb28ae611fd6faed3ea3983f5 100644
--- a/streaming/src/main/scala/org/apache/spark/streaming/DStream.scala
+++ b/streaming/src/main/scala/org/apache/spark/streaming/DStream.scala
@@ -438,6 +438,13 @@ abstract class DStream[T: ClassManifest] (
    */
   def glom(): DStream[Array[T]] = new GlommedDStream(this)
 
+
+  /**
+   * Return a new DStream with an increased or decreased level of parallelism. Each RDD in the
+   * returned DStream has exactly numPartitions partitions.
+   */
+  def repartition(numPartitions: Int): DStream[T] = this.transform(_.repartition(numPartitions))
+
   /**
    * Return a new DStream in which each RDD is generated by applying mapPartitions() to each RDDs
    * of this DStream. Applying mapPartitions() to an RDD applies a function to each partition
diff --git a/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaDStreamLike.scala b/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaDStreamLike.scala
index 459695b7cabab6c70da9a646d5697592ae35703b..eae517cff0e353115712633b6222d676b3dde31d 100644
--- a/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaDStreamLike.scala
+++ b/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaDStreamLike.scala
@@ -123,6 +123,13 @@ trait JavaDStreamLike[T, This <: JavaDStreamLike[T, This, R], R <: JavaRDDLike[T
   def glom(): JavaDStream[JList[T]] =
     new JavaDStream(dstream.glom().map(x => new java.util.ArrayList[T](x.toSeq)))
 
+  /**
+   * Return a new DStream with an increased or decreased level of parallelism. Each RDD in the
+   * returned DStream has exactly numPartitions partitions.
+   */
+  def repartition(numPartitions: Int): JavaDStream[T] =
+    new JavaDStream(dstream.repartition(numPartitions))
+
   /** Return the StreamingContext associated with this DStream */
   def context(): StreamingContext = dstream.context()
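
As a quick illustration of the API introduced by the patch above (this sketch is not part of the diff), here is how `repartition` might be called from a driver program. It assumes an existing SparkContext `sc` and a DStream `lines`, which are hypothetical names not defined in this patch.

// Hypothetical usage sketch; `sc` and `lines` are assumed to already exist.
val rdd = sc.parallelize(1 to 1000, 10)

// Grow or shrink the number of partitions; a shuffle is used by default.
val grown = rdd.repartition(20)     // 10 -> 20 partitions, with a shuffle
val shrunk = rdd.repartition(2)     // 10 -> 2 partitions, with a shuffle

// Shrinking can skip the shuffle, which behaves like coalesce(2, shuffle = false).
val shrunkNoShuffle = rdd.repartition(2, skipShuffle = true)

// Growing while skipping the shuffle is rejected:
// rdd.repartition(20, skipShuffle = true)   // throws IllegalArgumentException

// On a DStream, repartition is applied to every generated RDD.
val repartitionedStream = lines.repartition(4)

Defaulting to a shuffle mirrors coalesce(numPartitions, shuffle = true), which is why the Scaladoc notes that repartition can be called safely even when the current partition count is unknown.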