Commit 08c1a42d authored by Patrick Wendell

Add a `repartition` operator.

This patch adds an operator called repartition with more straightforward
semantics than the current `coalesce` operator. There are a few use cases
where this operator is useful:

1. A user wants to increase the number of partitions in an RDD. This is
more common now with streaming, e.g. a user is ingesting data on one
node but wants to add more partitions so that subsequent operations run
in parallel across threads or the cluster.

Right now they have to call rdd.coalesce(numSplits, shuffle=true), which is
super confusing.
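
In other words, today you write the first line below to grow an RDD, and with
this patch you write the second (illustrative only; `numSplits` is whatever
partition count you want):

> rdd.coalesce(numSplits, shuffle = true)  // today: a "coalesce" that actually grows the RDD
> rdd.repartition(numSplits)               // with this patch: says what it does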

2. A user has input data where the number of partitions is not known, e.g.

> sc.textFile("some file").coalesce(50)....

This is semantically vague (am I growing or shrinking this RDD?), and it
may not work correctly if the base RDD has fewer than 50 partitions.

The new operator forces a shuffle every time, so it always produces exactly
the requested number of partitions. It also throws an exception, rather than
silently doing nothing, when given a bad input.
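
For concreteness, here is a minimal sketch of the intended semantics (assuming
a live SparkContext `sc`; the file path and partition counts are made up):

> val rdd = sc.textFile("some file")                  // partition count unknown up front
> val grown = rdd.repartition(50)                     // always shuffles; exactly 50 partitions
> val shrunk = rdd.repartition(2, skipShuffle = true) // shrink without a shuffle
> rdd.repartition(50, skipShuffle = true)             // throws IllegalArgumentException
>                                                     // whenever 50 > rdd.partitions.size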

I am currently adding streaming tests (requires refactoring some of the test
suite to allow testing at partition granularity), so this is not ready for
merge yet. But feedback is welcome.
parent 1dc776b8
@@ -265,6 +265,27 @@ abstract class RDD[T: ClassManifest](
def distinct(): RDD[T] = distinct(partitions.size)
/**
* Return a new RDD that has exactly numPartitions partitions.
*
* Used to increase or decrease the level of parallelism in this RDD. By default, this will use
* a shuffle to redistribute data. If you are shrinking the RDD into fewer partitions, you can
* set skipShuffle = true to avoid a shuffle. Skipping the shuffle is not supported when
* increasing the number of partitions.
*
* Similar to `coalesce`, but shuffles by default, allowing you to call this safely even
* if you don't know the number of partitions.
*/
def repartition(numPartitions: Int, skipShuffle: Boolean = false): RDD[T] = {
  if (skipShuffle && numPartitions > this.partitions.size) {
    val msg = "repartition must grow %s from %s to %s partitions, cannot skip shuffle.".format(
      this.name, this.partitions.size, numPartitions
    )
    throw new IllegalArgumentException(msg)
  }
  coalesce(numPartitions, !skipShuffle)
}
/**
 * Return a new RDD that is reduced into `numPartitions` partitions.
 *
...
@@ -139,6 +139,39 @@ class RDDSuite extends FunSuite with SharedSparkContext {
    assert(rdd.union(emptyKv).collect().size === 2)
  }
test("repartitioned RDDs") {
val data = sc.parallelize(1 to 1000, 10)
// Coalesce partitions
val repartitioned1 = data.repartition(2)
assert(repartitioned1.partitions.size == 2)
val partitions1 = repartitioned1.glom().collect()
assert(partitions1(0).length > 0)
assert(partitions1(1).length > 0)
assert(repartitioned1.collect().toSet === (1 to 1000).toSet)
// Split partitions
val repartitioned2 = data.repartition(20)
assert(repartitioned2.partitions.size == 20)
val partitions2 = repartitioned2.glom().collect()
assert(partitions2(0).length > 0)
assert(partitions2(19).length > 0)
assert(repartitioned2.collect().toSet === (1 to 1000).toSet)
// Coalesce partitions - no shuffle
val repartitioned3 = data.repartition(2, skipShuffle = true)
assert(repartitioned3.partitions.size == 2)
val partitions3 = repartitioned3.glom().collect()
assert(partitions3(0).toList === (1 to 500).toList)
assert(partitions3(1).toList === (501 to 1000).toList)
assert(repartitioned3.collect().toSet === (1 to 1000).toSet)
// Split partitions - no shuffle (should throw exn)
intercept[IllegalArgumentException] {
data.repartition(20, skipShuffle = true)
}
}
test("coalesced RDDs") { test("coalesced RDDs") {
val data = sc.parallelize(1 to 10, 10) val data = sc.parallelize(1 to 10, 10)
......
@@ -72,6 +72,10 @@ DStreams support many of the transformations available on normal Spark RDD's:
<td> Similar to map, but runs separately on each partition (block) of the DStream, so <i>func</i> must be of type
Iterator[T] => Iterator[U] when running on a DStream of type T. </td>
</tr>
<tr>
<td> <b>repartition</b>(<i>numPartitions</i>) </td>
<td> Changes the level of parallelism in this DStream by creating more or fewer partitions. </td>
</tr>
<tr> <tr>
<td> <b>union</b>(<i>otherStream</i>) </td> <td> <b>union</b>(<i>otherStream</i>) </td>
<td> Return a new DStream that contains the union of the elements in the source DStream and the argument DStream. </td> <td> Return a new DStream that contains the union of the elements in the source DStream and the argument DStream. </td>
......
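
As a usage sketch for the new DStream operator (a hypothetical example: `ssc` is
assumed to be an existing StreamingContext, and the socket source and partition
count are just for illustration):

> val lines = ssc.socketTextStream("localhost", 9999) // often arrives in very few partitions
> val spread = lines.repartition(8)                   // each batch RDD now has exactly 8 partitions
> spread.flatMap(_.split(" ")).map((_, 1)).reduceByKey(_ + _).print()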
@@ -438,6 +438,13 @@ abstract class DStream[T: ClassManifest] (
 */
def glom(): DStream[Array[T]] = new GlommedDStream(this)
/**
* Return a new DStream with an increased or decreased level of parallelism. Each RDD in the
* returned DStream has exactly numPartitions partitions.
*/
def repartition(numPartitions: Int): DStream[T] = this.transform(_.repartition(numPartitions))
/**
 * Return a new DStream in which each RDD is generated by applying mapPartitions() to each RDD
 * of this DStream. Applying mapPartitions() to an RDD applies a function to each partition
...
@@ -123,6 +123,13 @@ trait JavaDStreamLike[T, This <: JavaDStreamLike[T, This, R], R <: JavaRDDLike[T
def glom(): JavaDStream[JList[T]] =
  new JavaDStream(dstream.glom().map(x => new java.util.ArrayList[T](x.toSeq)))
/**
* Return a new DStream with an increased or decreased level of parallelism. Each RDD in the
* returned DStream has exactly numPartitions partitions.
*/
def repartition(numPartitions: Int): JavaDStream[T] =
new JavaDStream(dstream.repartition(numPartitions))
/** Return the StreamingContext associated with this DStream */
def context(): StreamingContext = dstream.context()
...