diff --git a/core/src/main/scala/org/apache/spark/api/java/JavaDoubleRDD.scala b/core/src/main/scala/org/apache/spark/api/java/JavaDoubleRDD.scala
index f9b6ee351a151cefcc310d17d50249fdc20b1d5b..043cb183bad179d864d92f11e68c397fc284bd29 100644
--- a/core/src/main/scala/org/apache/spark/api/java/JavaDoubleRDD.scala
+++ b/core/src/main/scala/org/apache/spark/api/java/JavaDoubleRDD.scala
@@ -93,6 +93,17 @@ class JavaDoubleRDD(val srdd: RDD[scala.Double]) extends JavaRDDLike[Double, Jav
   def coalesce(numPartitions: Int, shuffle: Boolean): JavaDoubleRDD =
     fromRDD(srdd.coalesce(numPartitions, shuffle))
 
+  /**
+   * Return a new RDD that has exactly numPartitions partitions.
+   *
+   * Can increase or decrease the level of parallelism in this RDD. Internally, this uses
+   * a shuffle to redistribute data.
+   *
+   * If you are decreasing the number of partitions in this RDD, consider using `coalesce`,
+   * which can avoid performing a shuffle.
+   */
+  def repartition(numPartitions: Int): JavaDoubleRDD = fromRDD(srdd.repartition(numPartitions))
+
   /**
    * Return an RDD with the elements from `this` that are not in `other`.
    *
diff --git a/core/src/main/scala/org/apache/spark/api/java/JavaPairRDD.scala b/core/src/main/scala/org/apache/spark/api/java/JavaPairRDD.scala
index b3eb739f4e701617ab8ad9d64589f140d73dbc7a..2142fd73278aca1a2dd3a94d7cfb3508fd8877ca 100644
--- a/core/src/main/scala/org/apache/spark/api/java/JavaPairRDD.scala
+++ b/core/src/main/scala/org/apache/spark/api/java/JavaPairRDD.scala
@@ -107,6 +107,17 @@ class JavaPairRDD[K, V](val rdd: RDD[(K, V)])(implicit val kManifest: ClassManif
   def coalesce(numPartitions: Int, shuffle: Boolean): JavaPairRDD[K, V] =
     fromRDD(rdd.coalesce(numPartitions, shuffle))
 
+  /**
+   * Return a new RDD that has exactly numPartitions partitions.
+   *
+   * Can increase or decrease the level of parallelism in this RDD. Internally, this uses
+   * a shuffle to redistribute data.
+   *
+   * If you are decreasing the number of partitions in this RDD, consider using `coalesce`,
+   * which can avoid performing a shuffle.
+   */
+  def repartition(numPartitions: Int): JavaPairRDD[K, V] = fromRDD(rdd.repartition(numPartitions))
+
   /**
    * Return a sampled subset of this RDD.
    */
diff --git a/core/src/main/scala/org/apache/spark/api/java/JavaRDD.scala b/core/src/main/scala/org/apache/spark/api/java/JavaRDD.scala
index 662990049b0938f542e5966c992a2bd1554a5586..3b359a8fd60941b8cbcdcceb843d1d52c8887842 100644
--- a/core/src/main/scala/org/apache/spark/api/java/JavaRDD.scala
+++ b/core/src/main/scala/org/apache/spark/api/java/JavaRDD.scala
@@ -81,6 +81,17 @@ JavaRDDLike[T, JavaRDD[T]] {
   def coalesce(numPartitions: Int, shuffle: Boolean): JavaRDD[T] =
     rdd.coalesce(numPartitions, shuffle)
 
+  /**
+   * Return a new RDD that has exactly numPartitions partitions.
+   *
+   * Can increase or decrease the level of parallelism in this RDD. Internally, this uses
+   * a shuffle to redistribute data.
+   *
+   * If you are decreasing the number of partitions in this RDD, consider using `coalesce`,
+   * which can avoid performing a shuffle.
+   */
+  def repartition(numPartitions: Int): JavaRDD[T] = rdd.repartition(numPartitions)
+
   /**
    * Return a sampled subset of this RDD.
    */
diff --git a/core/src/main/scala/org/apache/spark/rdd/RDD.scala b/core/src/main/scala/org/apache/spark/rdd/RDD.scala
index 0355618e435bd3d00357406ca167ab324aae4645..6e88be6f6ac64bb3d99bb56aa276c789dfba7c6a 100644
--- a/core/src/main/scala/org/apache/spark/rdd/RDD.scala
+++ b/core/src/main/scala/org/apache/spark/rdd/RDD.scala
@@ -265,6 +265,19 @@ abstract class RDD[T: ClassManifest](
 
   def distinct(): RDD[T] = distinct(partitions.size)
 
+  /**
+   * Return a new RDD that has exactly numPartitions partitions.
+   *
+   * Can increase or decrease the level of parallelism in this RDD. Internally, this uses
+   * a shuffle to redistribute data.
+   *
+   * If you are decreasing the number of partitions in this RDD, consider using `coalesce`,
+   * which can avoid performing a shuffle.
+   */
+  def repartition(numPartitions: Int): RDD[T] = {
+    coalesce(numPartitions, true)
+  }
+
   /**
    * Return a new RDD that is reduced into `numPartitions` partitions.
    *
diff --git a/core/src/test/scala/org/apache/spark/JavaAPISuite.java b/core/src/test/scala/org/apache/spark/JavaAPISuite.java
index 7b0bb89ab28ce0306e852da05263b4a94b03b9f1..352036f182e24c676c9792f83b369d43f0fdb48b 100644
--- a/core/src/test/scala/org/apache/spark/JavaAPISuite.java
+++ b/core/src/test/scala/org/apache/spark/JavaAPISuite.java
@@ -472,6 +472,27 @@ public class JavaAPISuite implements Serializable {
     Assert.assertEquals("[3, 7]", partitionSums.collect().toString());
   }
 
+  @Test
+  public void repartition() {
+    // Growing number of partitions
+    JavaRDD<Integer> in1 = sc.parallelize(Arrays.asList(1, 2, 3, 4, 5, 6, 7, 8), 2);
+    JavaRDD<Integer> repartitioned1 = in1.repartition(4);
+    List<List<Integer>> result1 = repartitioned1.glom().collect();
+    Assert.assertEquals(4, result1.size());
+    for (List<Integer> l: result1) {
+      Assert.assertTrue(l.size() > 0);
+    }
+
+    // Shrinking number of partitions
+    JavaRDD<Integer> in2 = sc.parallelize(Arrays.asList(1, 2, 3, 4, 5, 6, 7, 8), 4);
+    JavaRDD<Integer> repartitioned2 = in2.repartition(2);
+    List<List<Integer>> result2 = repartitioned2.glom().collect();
+    Assert.assertEquals(2, result2.size());
+    for (List<Integer> l: result2) {
+      Assert.assertTrue(l.size() > 0);
+    }
+  }
+
   @Test
   public void persist() {
     JavaDoubleRDD doubleRDD = sc.parallelizeDoubles(Arrays.asList(1.0, 1.0, 2.0, 3.0, 5.0, 8.0));
diff --git a/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala b/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala
index 6d1bc5e296e06beb137673088229da3750c0579c..354ab8ae5d7d5c425cd3b62ea689554cc56b3294 100644
--- a/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala
+++ b/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala
@@ -139,6 +139,26 @@ class RDDSuite extends FunSuite with SharedSparkContext {
     assert(rdd.union(emptyKv).collect().size === 2)
   }
 
+  test("repartitioned RDDs") {
+    val data = sc.parallelize(1 to 1000, 10)
+
+    // Coalesce partitions
+    val repartitioned1 = data.repartition(2)
+    assert(repartitioned1.partitions.size == 2)
+    val partitions1 = repartitioned1.glom().collect()
+    assert(partitions1(0).length > 0)
+    assert(partitions1(1).length > 0)
+    assert(repartitioned1.collect().toSet === (1 to 1000).toSet)
+
+    // Split partitions
+    val repartitioned2 = data.repartition(20)
+    assert(repartitioned2.partitions.size == 20)
+    val partitions2 = repartitioned2.glom().collect()
+    assert(partitions2(0).length > 0)
+    assert(partitions2(19).length > 0)
+    assert(repartitioned2.collect().toSet === (1 to 1000).toSet)
+  }
+
   test("coalesced RDDs") {
     val data = sc.parallelize(1 to 10, 10)
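Not part of the patch, but a rough usage sketch of the new API: as the RDD.scala hunk above shows, `repartition(n)` is simply `coalesce(n, shuffle = true)`, so it always shuffles, whereas a plain `coalesce(n)` avoids the shuffle when shrinking. The sketch assumes an already-created `SparkContext` named `sc`.

```scala
// Hypothetical usage sketch, not part of this patch. Assumes an existing SparkContext `sc`.
val data = sc.parallelize(1 to 1000, 10)

// Growing the number of partitions: repartition always shuffles, spreading data over all 20 partitions.
val wider = data.repartition(20)
assert(wider.partitions.size == 20)

// Shrinking the number of partitions: repartition(2) still shuffles,
// while coalesce(2) merges partitions locally and avoids the shuffle.
val shuffled = data.repartition(2)
val merged = data.coalesce(2)
assert(shuffled.partitions.size == 2 && merged.partitions.size == 2)

// glom() exposes each partition as an array; the tests above use it to check the distribution.
wider.glom().collect().foreach(partition => assert(partition.nonEmpty))
```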
diff --git a/docs/streaming-programming-guide.md b/docs/streaming-programming-guide.md
index 835b257238e4bf04d2e0024a37f93ffaec13f0b5..851e30fe761af664cae684acbd86aa56cff88c65 100644
--- a/docs/streaming-programming-guide.md
+++ b/docs/streaming-programming-guide.md
@@ -72,6 +72,10 @@ DStreams support many of the transformations available on normal Spark RDD's:
   <td> Similar to map, but runs separately on each partition (block) of the DStream, so <i>func</i> must be of type
   Iterator[T] => Iterator[U] when running on an DStream of type T. </td>
 </tr>
+<tr>
+  <td> <b>repartition</b>(<i>numPartitions</i>) </td>
+  <td> Changes the level of parallelism in this DStream by creating more or fewer partitions. </td>
+</tr>
 <tr>
   <td> <b>union</b>(<i>otherStream</i>) </td>
   <td> Return a new DStream that contains the union of the elements in the source DStream and the argument DStream. </td>
diff --git a/project/SparkBuild.scala b/project/SparkBuild.scala
index 8d7cbae8214e49489ae2ee673bdab74450a42d2c..45fd30a7c836408813b473df32db3197508c3be0 100644
--- a/project/SparkBuild.scala
+++ b/project/SparkBuild.scala
@@ -292,6 +292,7 @@ object SparkBuild extends Build {
       "org.apache.kafka" % "kafka_2.9.2" % "0.8.0-beta1"
         exclude("com.sun.jdmk", "jmxtools")
         exclude("com.sun.jmx", "jmxri")
+        exclude("net.sf.jopt-simple", "jopt-simple")
     )
   )
diff --git a/streaming/pom.xml b/streaming/pom.xml
index 8022c4fe18917a1a671ceed7d779c2bd54ae6705..7a9ae6a97ba7ff615e3e24c76923dc941d76b8f3 100644
--- a/streaming/pom.xml
+++ b/streaming/pom.xml
@@ -73,6 +73,10 @@
           <groupId>com.sun.jdmk</groupId>
           <artifactId>jmxtools</artifactId>
         </exclusion>
+        <exclusion>
+          <groupId>net.sf.jopt-simple</groupId>
+          <artifactId>jopt-simple</artifactId>
+        </exclusion>
       </exclusions>
     </dependency>
     <dependency>
diff --git a/streaming/src/main/scala/org/apache/spark/streaming/DStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/DStream.scala
index 38e34795b49f598bc136c6492c88996f4c42a652..9ceff754c4b7251dfc50e533d75f892505e713e6 100644
--- a/streaming/src/main/scala/org/apache/spark/streaming/DStream.scala
+++ b/streaming/src/main/scala/org/apache/spark/streaming/DStream.scala
@@ -438,6 +438,13 @@ abstract class DStream[T: ClassManifest] (
    */
  def glom(): DStream[Array[T]] = new GlommedDStream(this)
 
+
+  /**
+   * Return a new DStream with an increased or decreased level of parallelism. Each RDD in the
+   * returned DStream has exactly numPartitions partitions.
+   */
+  def repartition(numPartitions: Int): DStream[T] = this.transform(_.repartition(numPartitions))
+
   /**
    * Return a new DStream in which each RDD is generated by applying mapPartitions() to each RDDs
    * of this DStream. Applying mapPartitions() to an RDD applies a function to each partition
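Not part of the patch: since the new `DStream.repartition` is just `transform(_.repartition(numPartitions))`, every batch RDD in the returned stream ends up with exactly that many partitions. A minimal streaming sketch, where the socket source, host, and port are placeholder assumptions:

```scala
// Hypothetical sketch, not part of this patch. Assumes an existing StreamingContext `ssc`
// and a placeholder socket source on localhost:9999.
val lines = ssc.socketTextStream("localhost", 9999)

// Each batch RDD produced by `lines` is reshuffled into exactly 8 partitions,
// so the map below runs as 8 tasks per batch regardless of how the data arrived.
val widened = lines.repartition(8)
widened.map(line => line.length).print()
```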
diff --git a/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaDStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaDStream.scala
index d1932b6b05a093fbf814f274dc6dc488bca73053..1a2aeaa8797e1fcbabe04505fdcd0c78e98612a3 100644
--- a/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaDStream.scala
+++ b/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaDStream.scala
@@ -94,6 +94,12 @@ class JavaDStream[T](val dstream: DStream[T])(implicit val classManifest: ClassM
    */
   def union(that: JavaDStream[T]): JavaDStream[T] = dstream.union(that.dstream)
+
+  /**
+   * Return a new DStream with an increased or decreased level of parallelism. Each RDD in the
+   * returned DStream has exactly numPartitions partitions.
+   */
+  def repartition(numPartitions: Int): JavaDStream[T] = dstream.repartition(numPartitions)
 }
 
 object JavaDStream {
diff --git a/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaPairDStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaPairDStream.scala
index 4dd6b7d096e618b8b36a2286366cce59798aa2f7..c6cd635afa0c87f903f19a07fe18034c668f2292 100644
--- a/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaPairDStream.scala
+++ b/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaPairDStream.scala
@@ -59,6 +59,12 @@ class JavaPairDStream[K, V](val dstream: DStream[(K, V)])(
   /** Persist the RDDs of this DStream with the given storage level */
   def persist(storageLevel: StorageLevel): JavaPairDStream[K, V] = dstream.persist(storageLevel)
 
+  /**
+   * Return a new DStream with an increased or decreased level of parallelism. Each RDD in the
+   * returned DStream has exactly numPartitions partitions.
+   */
+  def repartition(numPartitions: Int): JavaPairDStream[K, V] = dstream.repartition(numPartitions)
+
   /** Method that generates a RDD for the given Duration */
   def compute(validTime: Time): JavaPairRDD[K, V] = {
     dstream.compute(validTime) match {
diff --git a/streaming/src/test/java/org/apache/spark/streaming/JavaAPISuite.java b/streaming/src/test/java/org/apache/spark/streaming/JavaAPISuite.java
index 5d4890866702e9b4eb01ab5d9de4d545bad32efe..ad4a8b95355b9a185eb76d4d4e68dabb2a34bff5 100644
--- a/streaming/src/test/java/org/apache/spark/streaming/JavaAPISuite.java
+++ b/streaming/src/test/java/org/apache/spark/streaming/JavaAPISuite.java
@@ -184,6 +184,39 @@ public class JavaAPISuite implements Serializable {
     assertOrderInvariantEquals(expected, result);
   }
 
+  @Test
+  public void testRepartitionMorePartitions() {
+    List<List<Integer>> inputData = Arrays.asList(
+      Arrays.asList(1, 2, 3, 4, 5, 6, 7, 8, 9, 10),
+      Arrays.asList(1, 2, 3, 4, 5, 6, 7, 8, 9, 10));
+    JavaDStream stream = JavaTestUtils.attachTestInputStream(ssc, inputData, 2);
+    JavaDStream repartitioned = stream.repartition(4);
+    JavaTestUtils.attachTestOutputStream(repartitioned);
+    List<List<List<Integer>>> result = JavaTestUtils.runStreamsWithPartitions(ssc, 2, 2);
+    Assert.assertEquals(2, result.size());
+    for (List<List<Integer>> rdd : result) {
+      Assert.assertEquals(4, rdd.size());
+      Assert.assertEquals(
+        10, rdd.get(0).size() + rdd.get(1).size() + rdd.get(2).size() + rdd.get(3).size());
+    }
+  }
+
+  @Test
+  public void testRepartitionFewerPartitions() {
+    List<List<Integer>> inputData = Arrays.asList(
+      Arrays.asList(1, 2, 3, 4, 5, 6, 7, 8, 9, 10),
+      Arrays.asList(1, 2, 3, 4, 5, 6, 7, 8, 9, 10));
+    JavaDStream stream = JavaTestUtils.attachTestInputStream(ssc, inputData, 4);
+    JavaDStream repartitioned = stream.repartition(2);
+    JavaTestUtils.attachTestOutputStream(repartitioned);
+    List<List<List<Integer>>> result = JavaTestUtils.runStreamsWithPartitions(ssc, 2, 2);
+    Assert.assertEquals(2, result.size());
+    for (List<List<Integer>> rdd : result) {
+      Assert.assertEquals(2, rdd.size());
+      Assert.assertEquals(10, rdd.get(0).size() + rdd.get(1).size());
+    }
+  }
+
   @Test
   public void testGlom() {
     List<List<String>> inputData = Arrays.asList(
diff --git a/streaming/src/test/java/org/apache/spark/streaming/JavaTestUtils.scala b/streaming/src/test/java/org/apache/spark/streaming/JavaTestUtils.scala
index 8a6604904de57d39626eac6ec5f2b19405dad833..5e384eeee45f385ff83bba34f93c4b2228f5faaa 100644
--- a/streaming/src/test/java/org/apache/spark/streaming/JavaTestUtils.scala
+++ b/streaming/src/test/java/org/apache/spark/streaming/JavaTestUtils.scala
@@ -33,9 +33,9 @@ trait JavaTestBase extends TestSuiteBase {
    * The stream will be derived from the supplied lists of Java objects.
    **/
   def attachTestInputStream[T](
-    ssc: JavaStreamingContext,
-    data: JList[JList[T]],
-    numPartitions: Int) = {
+      ssc: JavaStreamingContext,
+      data: JList[JList[T]],
+      numPartitions: Int) = {
     val seqData = data.map(Seq(_:_*))
 
     implicit val cm: ClassManifest[T] =
@@ -50,12 +50,11 @@ trait JavaTestBase extends TestSuiteBase {
    * [[org.apache.spark.streaming.TestOutputStream]].
    **/
   def attachTestOutputStream[T, This <: JavaDStreamLike[T, This, R], R <: JavaRDDLike[T, R]](
-    dstream: JavaDStreamLike[T, This, R]) =
+      dstream: JavaDStreamLike[T, This, R]) = {
     implicit val cm: ClassManifest[T] =
       implicitly[ClassManifest[AnyRef]].asInstanceOf[ClassManifest[T]]
-    val ostream = new TestOutputStream(dstream.dstream,
-      new ArrayBuffer[Seq[T]] with SynchronizedBuffer[Seq[T]])
+    val ostream = new TestOutputStreamWithPartitions(dstream.dstream)
     dstream.dstream.ssc.registerOutputStream(ostream)
   }
 
@@ -63,9 +62,11 @@ trait JavaTestBase extends TestSuiteBase {
    * Process all registered streams for a numBatches batches, failing if
    * numExpectedOutput RDD's are not generated. Generated RDD's are collected
    * and returned, represented as a list for each batch interval.
+   *
+   * Returns a list of items for each RDD.
    */
   def runStreams[V](
-    ssc: JavaStreamingContext, numBatches: Int, numExpectedOutput: Int): JList[JList[V]] = {
+      ssc: JavaStreamingContext, numBatches: Int, numExpectedOutput: Int): JList[JList[V]] = {
     implicit val cm: ClassManifest[V] =
       implicitly[ClassManifest[AnyRef]].asInstanceOf[ClassManifest[V]]
     val res = runStreams[V](ssc.ssc, numBatches, numExpectedOutput)
@@ -73,6 +74,27 @@ trait JavaTestBase extends TestSuiteBase {
     res.map(entry => out.append(new ArrayList[V](entry)))
     out
   }
+
+  /**
+   * Process all registered streams for numBatches batches, failing if
+   * numExpectedOutput RDDs are not generated. Generated RDDs are collected
+   * and returned, represented as a list for each batch interval.
+   *
+   * Returns a sequence of RDDs. Each RDD is represented as several sequences of items, each
+   * representing one partition.
+   */
+  def runStreamsWithPartitions[V](ssc: JavaStreamingContext, numBatches: Int,
+      numExpectedOutput: Int): JList[JList[JList[V]]] = {
+    implicit val cm: ClassManifest[V] =
+      implicitly[ClassManifest[AnyRef]].asInstanceOf[ClassManifest[V]]
+    val res = runStreamsWithPartitions[V](ssc.ssc, numBatches, numExpectedOutput)
+    val out = new ArrayList[JList[JList[V]]]()
+    res.map{entry =>
+      val lists = entry.map(new ArrayList[V](_))
+      out.append(new ArrayList[JList[V]](lists))
+    }
+    out
+  }
 }
 
 object JavaTestUtils extends JavaTestBase {
diff --git a/streaming/src/test/scala/org/apache/spark/streaming/BasicOperationsSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/BasicOperationsSuite.scala
index a2ac510a98e0c3dcef76dd8c6a8e4075c1b40ce5..259ef1608cbc5b179372995ef4a426df0101247d 100644
--- a/streaming/src/test/scala/org/apache/spark/streaming/BasicOperationsSuite.scala
+++ b/streaming/src/test/scala/org/apache/spark/streaming/BasicOperationsSuite.scala
@@ -85,6 +85,44 @@ class BasicOperationsSuite extends TestSuiteBase {
     testOperation(input, operation, output, true)
   }
 
+  test("repartition (more partitions)") {
+    val input = Seq(1 to 100, 101 to 200, 201 to 300)
+    val operation = (r: DStream[Int]) => r.repartition(5)
+    val ssc = setupStreams(input, operation, 2)
+    val output = runStreamsWithPartitions(ssc, 3, 3)
+    assert(output.size === 3)
+    val first = output(0)
+    val second = output(1)
+    val third = output(2)
+
+    assert(first.size === 5)
+    assert(second.size === 5)
+    assert(third.size === 5)
+
+    assert(first.flatten.toSet === (1 to 100).toSet)
+    assert(second.flatten.toSet === (101 to 200).toSet)
+    assert(third.flatten.toSet === (201 to 300).toSet)
+  }
+
+  test("repartition (fewer partitions)") {
+    val input = Seq(1 to 100, 101 to 200, 201 to 300)
+    val operation = (r: DStream[Int]) => r.repartition(2)
+    val ssc = setupStreams(input, operation, 5)
+    val output = runStreamsWithPartitions(ssc, 3, 3)
+    assert(output.size === 3)
+    val first = output(0)
+    val second = output(1)
+    val third = output(2)
+
+    assert(first.size === 2)
+    assert(second.size === 2)
+    assert(third.size === 2)
+
+    assert(first.flatten.toSet === (1 to 100).toSet)
+    assert(second.flatten.toSet === (101 to 200).toSet)
+    assert(third.flatten.toSet === (201 to 300).toSet)
+  }
+
   test("groupByKey") {
     testOperation(
       Seq( Seq("a", "a", "b"), Seq("", ""), Seq() ),
diff --git a/streaming/src/test/scala/org/apache/spark/streaming/CheckpointSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/CheckpointSuite.scala
index a327de80b3eb3a0f98bcbd2486e0580225b93a8a..beb20831bd7b49d5a9c9443149a4c1eb9171384a 100644
--- a/streaming/src/test/scala/org/apache/spark/streaming/CheckpointSuite.scala
+++ b/streaming/src/test/scala/org/apache/spark/streaming/CheckpointSuite.scala
@@ -366,7 +366,7 @@ class CheckpointSuite extends TestSuiteBase with BeforeAndAfter {
     logInfo("Manual clock after advancing = " + clock.time)
     Thread.sleep(batchDuration.milliseconds)
 
-    val outputStream = ssc.graph.getOutputStreams.head.asInstanceOf[TestOutputStream[V]]
-    outputStream.output
+    val outputStream = ssc.graph.getOutputStreams.head.asInstanceOf[TestOutputStreamWithPartitions[V]]
+    outputStream.output.map(_.flatten)
   }
 }
diff --git a/streaming/src/test/scala/org/apache/spark/streaming/TestSuiteBase.scala b/streaming/src/test/scala/org/apache/spark/streaming/TestSuiteBase.scala
index 37dd9c4cc61d29d28f247bbdf9044b80e8fce76b..be140699c2964456a747a076120987c9abc26f39 100644
--- a/streaming/src/test/scala/org/apache/spark/streaming/TestSuiteBase.scala
+++ b/streaming/src/test/scala/org/apache/spark/streaming/TestSuiteBase.scala
@@ -60,8 +60,11 @@ class TestInputStream[T: ClassManifest](ssc_ : StreamingContext, input: Seq[Seq[
 /**
  * This is a output stream just for the testsuites. All the output is collected into a
  * ArrayBuffer. This buffer is wiped clean on being restored from checkpoint.
+ *
+ * The buffer contains a sequence of RDDs, each containing a sequence of items.
  */
-class TestOutputStream[T: ClassManifest](parent: DStream[T], val output: ArrayBuffer[Seq[T]])
+class TestOutputStream[T: ClassManifest](parent: DStream[T],
+    val output: ArrayBuffer[Seq[T]] = ArrayBuffer[Seq[T]]())
   extends ForEachDStream[T](parent, (rdd: RDD[T], t: Time) => {
     val collected = rdd.collect()
     output += collected
@@ -75,6 +78,30 @@ class TestOutputStream[T: ClassManifest](parent: DStream[T], val output: ArrayBu
   }
 }
 
+/**
+ * This is an output stream just for the test suites. All the output is collected into an
+ * ArrayBuffer. This buffer is wiped clean on being restored from checkpoint.
+ *
+ * The buffer contains a sequence of RDDs, each containing a sequence of partitions, each
+ * containing a sequence of items.
+ */
+class TestOutputStreamWithPartitions[T: ClassManifest](parent: DStream[T],
+    val output: ArrayBuffer[Seq[Seq[T]]] = ArrayBuffer[Seq[Seq[T]]]())
+  extends ForEachDStream[T](parent, (rdd: RDD[T], t: Time) => {
+    val collected = rdd.glom().collect().map(_.toSeq)
+    output += collected
+  }) {
+
+  // This is to clear the output buffer every time it is read from a checkpoint
+  @throws(classOf[IOException])
+  private def readObject(ois: ObjectInputStream) {
+    ois.defaultReadObject()
+    output.clear()
+  }
+
+  def toTestOutputStream = new TestOutputStream[T](this.parent, this.output.map(_.flatten))
+}
+
 /**
  * This is the base trait for Spark Streaming testsuites. This provides basic functionality
  * to run user-defined set of input on user-defined stream operations, and verify the output.
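To make the new collection shape concrete: `TestOutputStreamWithPartitions` records one `Seq[Seq[T]]` per batch via `glom()`, and flattening recovers what the old `TestOutputStream` recorded, which is exactly the `output.map(_.flatten)` conversion used in `toTestOutputStream` and in the CheckpointSuite hunk above. A toy illustration, with invented values rather than real test output:

```scala
// Toy illustration of the recorded shapes; the values are invented, not real test output.
// One batch whose RDD has two partitions, [1, 2, 3] and [4, 5]:
val withPartitions: Seq[Seq[Int]] = Seq(Seq(1, 2, 3), Seq(4, 5)) // what the new stream appends
val flattened: Seq[Int] = withPartitions.flatten                  // what the old stream appended
assert(flattened == Seq(1, 2, 3, 4, 5))
```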
@@ -108,7 +135,8 @@ trait TestSuiteBase extends FunSuite with BeforeAndAfter with Logging {
    */
   def setupStreams[U: ClassManifest, V: ClassManifest](
       input: Seq[Seq[U]],
-      operation: DStream[U] => DStream[V]
+      operation: DStream[U] => DStream[V],
+      numPartitions: Int = numInputPartitions
     ): StreamingContext = {
 
     // Create StreamingContext
@@ -118,9 +146,10 @@ trait TestSuiteBase extends FunSuite with BeforeAndAfter with Logging {
     }
 
     // Setup the stream computation
-    val inputStream = new TestInputStream(ssc, input, numInputPartitions)
+    val inputStream = new TestInputStream(ssc, input, numPartitions)
     val operatedStream = operation(inputStream)
-    val outputStream = new TestOutputStream(operatedStream, new ArrayBuffer[Seq[V]] with SynchronizedBuffer[Seq[V]])
+    val outputStream = new TestOutputStreamWithPartitions(operatedStream,
+      new ArrayBuffer[Seq[Seq[V]]] with SynchronizedBuffer[Seq[Seq[V]]])
     ssc.registerInputStream(inputStream)
     ssc.registerOutputStream(outputStream)
     ssc
@@ -146,7 +175,8 @@
     val inputStream1 = new TestInputStream(ssc, input1, numInputPartitions)
     val inputStream2 = new TestInputStream(ssc, input2, numInputPartitions)
     val operatedStream = operation(inputStream1, inputStream2)
-    val outputStream = new TestOutputStream(operatedStream, new ArrayBuffer[Seq[W]] with SynchronizedBuffer[Seq[W]])
+    val outputStream = new TestOutputStreamWithPartitions(operatedStream,
+      new ArrayBuffer[Seq[Seq[W]]] with SynchronizedBuffer[Seq[Seq[W]]])
     ssc.registerInputStream(inputStream1)
     ssc.registerInputStream(inputStream2)
     ssc.registerOutputStream(outputStream)
@@ -157,18 +187,37 @@
    * Runs the streams set up in `ssc` on manual clock for `numBatches` batches and
    * returns the collected output. It will wait until `numExpectedOutput` number of
    * output data has been collected or timeout (set by `maxWaitTimeMillis`) is reached.
+   *
+   * Returns a sequence of items for each RDD.
    */
   def runStreams[V: ClassManifest](
       ssc: StreamingContext,
       numBatches: Int,
       numExpectedOutput: Int
     ): Seq[Seq[V]] = {
+    // Flatten each RDD into a single Seq
+    runStreamsWithPartitions(ssc, numBatches, numExpectedOutput).map(_.flatten.toSeq)
+  }
+
+  /**
+   * Runs the streams set up in `ssc` on manual clock for `numBatches` batches and
+   * returns the collected output. It will wait until `numExpectedOutput` number of
+   * output data has been collected or timeout (set by `maxWaitTimeMillis`) is reached.
+   *
+   * Returns a sequence of RDDs. Each RDD is represented as several sequences of items, each
+   * representing one partition.
+   */
+  def runStreamsWithPartitions[V: ClassManifest](
+      ssc: StreamingContext,
+      numBatches: Int,
+      numExpectedOutput: Int
+    ): Seq[Seq[Seq[V]]] = {
     assert(numBatches > 0, "Number of batches to run stream computation is zero")
     assert(numExpectedOutput > 0, "Number of expected outputs after " + numBatches + " is zero")
     logInfo("numBatches = " + numBatches + ", numExpectedOutput = " + numExpectedOutput)
 
     // Get the output buffer
-    val outputStream = ssc.graph.getOutputStreams.head.asInstanceOf[TestOutputStream[V]]
+    val outputStream = ssc.graph.getOutputStreams.head.asInstanceOf[TestOutputStreamWithPartitions[V]]
     val output = outputStream.output
 
     try {