diff --git a/core/src/main/scala/org/apache/spark/api/java/JavaDoubleRDD.scala b/core/src/main/scala/org/apache/spark/api/java/JavaDoubleRDD.scala
index f9b6ee351a151cefcc310d17d50249fdc20b1d5b..043cb183bad179d864d92f11e68c397fc284bd29 100644
--- a/core/src/main/scala/org/apache/spark/api/java/JavaDoubleRDD.scala
+++ b/core/src/main/scala/org/apache/spark/api/java/JavaDoubleRDD.scala
@@ -93,6 +93,17 @@ class JavaDoubleRDD(val srdd: RDD[scala.Double]) extends JavaRDDLike[Double, Jav
   def coalesce(numPartitions: Int, shuffle: Boolean): JavaDoubleRDD =
     fromRDD(srdd.coalesce(numPartitions, shuffle))
 
+  /**
+   * Return a new RDD that has exactly numPartitions partitions.
+   *
+   * Can increase or decrease the level of parallelism in this RDD. Internally, this uses
+   * a shuffle to redistribute data.
+   *
+   * If you are decreasing the number of partitions in this RDD, consider using `coalesce`,
+   * which can avoid performing a shuffle.
+   */
+  def repartition(numPartitions: Int): JavaDoubleRDD = fromRDD(srdd.repartition(numPartitions))
+
   /**
    * Return an RDD with the elements from `this` that are not in `other`.
    * 
diff --git a/core/src/main/scala/org/apache/spark/api/java/JavaPairRDD.scala b/core/src/main/scala/org/apache/spark/api/java/JavaPairRDD.scala
index b3eb739f4e701617ab8ad9d64589f140d73dbc7a..2142fd73278aca1a2dd3a94d7cfb3508fd8877ca 100644
--- a/core/src/main/scala/org/apache/spark/api/java/JavaPairRDD.scala
+++ b/core/src/main/scala/org/apache/spark/api/java/JavaPairRDD.scala
@@ -107,6 +107,17 @@ class JavaPairRDD[K, V](val rdd: RDD[(K, V)])(implicit val kManifest: ClassManif
   def coalesce(numPartitions: Int, shuffle: Boolean): JavaPairRDD[K, V] =
     fromRDD(rdd.coalesce(numPartitions, shuffle))
 
+  /**
+   * Return a new RDD that has exactly numPartitions partitions.
+   *
+   * Can increase or decrease the level of parallelism in this RDD. Internally, this uses
+   * a shuffle to redistribute data.
+   *
+   * If you are decreasing the number of partitions in this RDD, consider using `coalesce`,
+   * which can avoid performing a shuffle.
+   */
+  def repartition(numPartitions: Int): JavaPairRDD[K, V] = fromRDD(rdd.repartition(numPartitions))
+
   /**
    * Return a sampled subset of this RDD.
    */
diff --git a/core/src/main/scala/org/apache/spark/api/java/JavaRDD.scala b/core/src/main/scala/org/apache/spark/api/java/JavaRDD.scala
index 662990049b0938f542e5966c992a2bd1554a5586..3b359a8fd60941b8cbcdcceb843d1d52c8887842 100644
--- a/core/src/main/scala/org/apache/spark/api/java/JavaRDD.scala
+++ b/core/src/main/scala/org/apache/spark/api/java/JavaRDD.scala
@@ -81,6 +81,17 @@ JavaRDDLike[T, JavaRDD[T]] {
   def coalesce(numPartitions: Int, shuffle: Boolean): JavaRDD[T] =
     rdd.coalesce(numPartitions, shuffle)
 
+  /**
+   * Return a new RDD that has exactly numPartitions partitions.
+   *
+   * Can increase or decrease the level of parallelism in this RDD. Internally, this uses
+   * a shuffle to redistribute data.
+   *
+   * If you are decreasing the number of partitions in this RDD, consider using `coalesce`,
+   * which can avoid performing a shuffle.
+   */
+  def repartition(numPartitions: Int): JavaRDD[T] = rdd.repartition(numPartitions)
+
   /**
    * Return a sampled subset of this RDD.
    */
diff --git a/core/src/main/scala/org/apache/spark/rdd/RDD.scala b/core/src/main/scala/org/apache/spark/rdd/RDD.scala
index 0355618e435bd3d00357406ca167ab324aae4645..6e88be6f6ac64bb3d99bb56aa276c789dfba7c6a 100644
--- a/core/src/main/scala/org/apache/spark/rdd/RDD.scala
+++ b/core/src/main/scala/org/apache/spark/rdd/RDD.scala
@@ -265,6 +265,19 @@ abstract class RDD[T: ClassManifest](
 
   def distinct(): RDD[T] = distinct(partitions.size)
 
+  /**
+   * Return a new RDD that has exactly numPartitions partitions.
+   *
+   * Can increase or decrease the level of parallelism in this RDD. Internally, this uses
+   * a shuffle to redistribute data.
+   *
+   * If you are decreasing the number of partitions in this RDD, consider using `coalesce`,
+   * which can avoid performing a shuffle.
+   */
+  def repartition(numPartitions: Int): RDD[T] = {
+    coalesce(numPartitions, true)
+  }
+
   /**
    * Return a new RDD that is reduced into `numPartitions` partitions.
    *
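A minimal usage sketch of the new operator, assuming a live SparkContext named `sc` (illustrative, not part of the patch). Since repartition(numPartitions) delegates to coalesce(numPartitions, shuffle = true), it always incurs a shuffle, which is what allows it to grow the partition count:

    // repartition always shuffles; plain coalesce only merges partitions.
    val rdd = sc.parallelize(1 to 1000, 10)
    val grown = rdd.repartition(20)      // 10 -> 20 partitions, via shuffle
    assert(grown.partitions.size == 20)
    val shrunk = rdd.coalesce(2)         // 10 -> 2 partitions, no shuffle
    assert(shrunk.partitions.size == 2)
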
diff --git a/core/src/test/scala/org/apache/spark/JavaAPISuite.java b/core/src/test/scala/org/apache/spark/JavaAPISuite.java
index 7b0bb89ab28ce0306e852da05263b4a94b03b9f1..352036f182e24c676c9792f83b369d43f0fdb48b 100644
--- a/core/src/test/scala/org/apache/spark/JavaAPISuite.java
+++ b/core/src/test/scala/org/apache/spark/JavaAPISuite.java
@@ -472,6 +472,27 @@ public class JavaAPISuite implements Serializable {
     Assert.assertEquals("[3, 7]", partitionSums.collect().toString());
   }
 
+  @Test
+  public void repartition() {
+    // Growing the number of partitions
+    JavaRDD<Integer> in1 = sc.parallelize(Arrays.asList(1, 2, 3, 4, 5, 6, 7, 8), 2);
+    JavaRDD<Integer> repartitioned1 = in1.repartition(4);
+    List<List<Integer>> result1 = repartitioned1.glom().collect();
+    Assert.assertEquals(4, result1.size());
+    for (List<Integer> l : result1) {
+      Assert.assertTrue(l.size() > 0);
+    }
+
+    // Shrinking the number of partitions
+    JavaRDD<Integer> in2 = sc.parallelize(Arrays.asList(1, 2, 3, 4, 5, 6, 7, 8), 4);
+    JavaRDD<Integer> repartitioned2 = in2.repartition(2);
+    List<List<Integer>> result2 = repartitioned2.glom().collect();
+    Assert.assertEquals(2, result2.size());
+    for (List<Integer> l : result2) {
+      Assert.assertTrue(l.size() > 0);
+    }
+  }
+
   @Test
   public void persist() {
     JavaDoubleRDD doubleRDD = sc.parallelizeDoubles(Arrays.asList(1.0, 1.0, 2.0, 3.0, 5.0, 8.0));
diff --git a/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala b/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala
index 6d1bc5e296e06beb137673088229da3750c0579c..354ab8ae5d7d5c425cd3b62ea689554cc56b3294 100644
--- a/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala
+++ b/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala
@@ -139,6 +139,26 @@ class RDDSuite extends FunSuite with SharedSparkContext {
     assert(rdd.union(emptyKv).collect().size === 2)
   }
 
+  test("repartitioned RDDs") {
+    val data = sc.parallelize(1 to 1000, 10)
+
+    // Coalesce partitions
+    val repartitioned1 = data.repartition(2)
+    assert(repartitioned1.partitions.size === 2)
+    val partitions1 = repartitioned1.glom().collect()
+    assert(partitions1(0).length > 0)
+    assert(partitions1(1).length > 0)
+    assert(repartitioned1.collect().toSet === (1 to 1000).toSet)
+
+    // Split partitions
+    val repartitioned2 = data.repartition(20)
+    assert(repartitioned2.partitions.size === 20)
+    val partitions2 = repartitioned2.glom().collect()
+    assert(partitions2(0).length > 0)
+    assert(partitions2(19).length > 0)
+    assert(repartitioned2.collect().toSet === (1 to 1000).toSet)
+  }
+
   test("coalesced RDDs") {
     val data = sc.parallelize(1 to 10, 10)
 
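The new test's assertions lean on glom(), which makes partition boundaries observable by turning each partition into a single array. A small sketch of the shape it yields, again assuming a live SparkContext `sc`:

    // glom(): RDD[Int] => RDD[Array[Int]], one array per partition, so
    // collect() returns an Array[Array[Int]] indexed by partition.
    val parts = sc.parallelize(1 to 1000, 10).repartition(2).glom().collect()
    assert(parts.length == 2)                // one array per partition
    assert(parts.map(_.length).sum == 1000)  // nothing lost in the shuffle
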
diff --git a/docs/streaming-programming-guide.md b/docs/streaming-programming-guide.md
index 835b257238e4bf04d2e0024a37f93ffaec13f0b5..851e30fe761af664cae684acbd86aa56cff88c65 100644
--- a/docs/streaming-programming-guide.md
+++ b/docs/streaming-programming-guide.md
@@ -72,6 +72,10 @@ DStreams support many of the transformations available on normal Spark RDD's:
   <td> Similar to map, but runs separately on each partition (block) of the DStream, so <i>func</i> must be of type
     Iterator[T] => Iterator[U] when running on a DStream of type T. </td>
 </tr>
+<tr>
+  <td> <b>repartition</b>(<i>numPartitions</i>) </td>
+  <td> Changes the level of parallelism in this DStream by creating more or fewer partitions. </td>
+</tr>
 <tr>
   <td> <b>union</b>(<i>otherStream</i>) </td>
   <td> Return a new DStream that contains the union of the elements in the source DStream and the argument DStream. </td>
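To accompany the new table row, a short usage sketch in the guide's register; the socket source and its address are illustrative assumptions:

    // Spread the batches of a single-receiver stream across 8 partitions
    // so downstream work is better parallelized.
    val lines = ssc.socketTextStream("localhost", 9999)
    lines.repartition(8).map(_.length).print()
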
diff --git a/project/SparkBuild.scala b/project/SparkBuild.scala
index 8d7cbae8214e49489ae2ee673bdab74450a42d2c..45fd30a7c836408813b473df32db3197508c3be0 100644
--- a/project/SparkBuild.scala
+++ b/project/SparkBuild.scala
@@ -292,6 +292,7 @@ object SparkBuild extends Build {
       "org.apache.kafka" % "kafka_2.9.2" % "0.8.0-beta1"
         exclude("com.sun.jdmk", "jmxtools")
         exclude("com.sun.jmx", "jmxri")
+        exclude("net.sf.jopt-simple", "jopt-simple")
     )
   )
 
diff --git a/streaming/pom.xml b/streaming/pom.xml
index 8022c4fe18917a1a671ceed7d779c2bd54ae6705..7a9ae6a97ba7ff615e3e24c76923dc941d76b8f3 100644
--- a/streaming/pom.xml
+++ b/streaming/pom.xml
@@ -73,6 +73,10 @@
           <groupId>com.sun.jdmk</groupId>
           <artifactId>jmxtools</artifactId>
         </exclusion>
+        <exclusion>
+          <groupId>net.sf.jopt-simple</groupId>
+          <artifactId>jopt-simple</artifactId>
+        </exclusion>
       </exclusions>
     </dependency>
     <dependency>
diff --git a/streaming/src/main/scala/org/apache/spark/streaming/DStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/DStream.scala
index 38e34795b49f598bc136c6492c88996f4c42a652..9ceff754c4b7251dfc50e533d75f892505e713e6 100644
--- a/streaming/src/main/scala/org/apache/spark/streaming/DStream.scala
+++ b/streaming/src/main/scala/org/apache/spark/streaming/DStream.scala
@@ -438,6 +438,12 @@ abstract class DStream[T: ClassManifest] (
    */
   def glom(): DStream[Array[T]] = new GlommedDStream(this)
 
+  /**
+   * Return a new DStream with an increased or decreased level of parallelism. Each RDD in the
+   * returned DStream has exactly numPartitions partitions.
+   */
+  def repartition(numPartitions: Int): DStream[T] = this.transform(_.repartition(numPartitions))
+
   /**
    * Return a new DStream in which each RDD is generated by applying mapPartitions() to each RDD
    * of this DStream. Applying mapPartitions() to an RDD applies a function to each partition
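Because the method is defined through transform, it applies the RDD-level repartition to each batch independently. The following forms are therefore equivalent (a sketch, assuming some DStream[String] named `stream`):

    // All three reshuffle every batch RDD into 4 partitions.
    val a = stream.repartition(4)
    val b = stream.transform(_.repartition(4))
    val c = stream.transform(_.coalesce(4, shuffle = true))
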
diff --git a/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaDStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaDStream.scala
index d1932b6b05a093fbf814f274dc6dc488bca73053..1a2aeaa8797e1fcbabe04505fdcd0c78e98612a3 100644
--- a/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaDStream.scala
+++ b/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaDStream.scala
@@ -94,6 +94,12 @@ class JavaDStream[T](val dstream: DStream[T])(implicit val classManifest: ClassM
    */
   def union(that: JavaDStream[T]): JavaDStream[T] =
     dstream.union(that.dstream)
+
+  /**
+   * Return a new DStream with an increased or decreased level of parallelism. Each RDD in the
+   * returned DStream has exactly numPartitions partitions.
+   */
+  def repartition(numPartitions: Int): JavaDStream[T] = dstream.repartition(numPartitions)
 }
 
 object JavaDStream {
diff --git a/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaPairDStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaPairDStream.scala
index 4dd6b7d096e618b8b36a2286366cce59798aa2f7..c6cd635afa0c87f903f19a07fe18034c668f2292 100644
--- a/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaPairDStream.scala
+++ b/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaPairDStream.scala
@@ -59,6 +59,12 @@ class JavaPairDStream[K, V](val dstream: DStream[(K, V)])(
   /** Persist the RDDs of this DStream with the given storage level */
   def persist(storageLevel: StorageLevel): JavaPairDStream[K, V] = dstream.persist(storageLevel)
 
+  /**
+   * Return a new DStream with an increased or decreased level of parallelism. Each RDD in the
+   * returned DStream has exactly numPartitions partitions.
+   */
+  def repartition(numPartitions: Int): JavaPairDStream[K, V] = dstream.repartition(numPartitions)
+
   /** Method that generates an RDD for the given Duration */
   def compute(validTime: Time): JavaPairRDD[K, V] = {
     dstream.compute(validTime) match {
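On a pair stream the operation preserves the key/value types, but since it reuses the RDD-level shuffle it makes no key-based placement guarantee: records with equal keys may land in different partitions. A sketch, assuming some DStream[(String, Int)] named `pairs`:

    // Rebalances each batch into 8 partitions without grouping equal keys.
    val rebalanced: DStream[(String, Int)] = pairs.repartition(8)
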
diff --git a/streaming/src/test/java/org/apache/spark/streaming/JavaAPISuite.java b/streaming/src/test/java/org/apache/spark/streaming/JavaAPISuite.java
index 5d4890866702e9b4eb01ab5d9de4d545bad32efe..ad4a8b95355b9a185eb76d4d4e68dabb2a34bff5 100644
--- a/streaming/src/test/java/org/apache/spark/streaming/JavaAPISuite.java
+++ b/streaming/src/test/java/org/apache/spark/streaming/JavaAPISuite.java
@@ -184,6 +184,39 @@ public class JavaAPISuite implements Serializable {
     assertOrderInvariantEquals(expected, result);
   }
 
+  @Test
+  public void testRepartitionMorePartitions() {
+    List<List<Integer>> inputData = Arrays.asList(
+      Arrays.asList(1, 2, 3, 4, 5, 6, 7, 8, 9, 10),
+      Arrays.asList(1, 2, 3, 4, 5, 6, 7, 8, 9, 10));
+    JavaDStream<Integer> stream = JavaTestUtils.attachTestInputStream(ssc, inputData, 2);
+    JavaDStream<Integer> repartitioned = stream.repartition(4);
+    JavaTestUtils.attachTestOutputStream(repartitioned);
+    List<List<List<Integer>>> result = JavaTestUtils.runStreamsWithPartitions(ssc, 2, 2);
+    Assert.assertEquals(2, result.size());
+    for (List<List<Integer>> rdd : result) {
+      Assert.assertEquals(4, rdd.size());
+      Assert.assertEquals(
+        10, rdd.get(0).size() + rdd.get(1).size() + rdd.get(2).size() + rdd.get(3).size());
+    }
+  }
+
+  @Test
+  public void testRepartitionFewerPartitions() {
+    List<List<Integer>> inputData = Arrays.asList(
+      Arrays.asList(1, 2, 3, 4, 5, 6, 7, 8, 9, 10),
+      Arrays.asList(1, 2, 3, 4, 5, 6, 7, 8, 9, 10));
+    JavaDStream<Integer> stream = JavaTestUtils.attachTestInputStream(ssc, inputData, 4);
+    JavaDStream<Integer> repartitioned = stream.repartition(2);
+    JavaTestUtils.attachTestOutputStream(repartitioned);
+    List<List<List<Integer>>> result = JavaTestUtils.runStreamsWithPartitions(ssc, 2, 2);
+    Assert.assertEquals(2, result.size());
+    for (List<List<Integer>> rdd : result) {
+      Assert.assertEquals(2, rdd.size());
+      Assert.assertEquals(10, rdd.get(0).size() + rdd.get(1).size());
+    }
+  }
+
   @Test
   public void testGlom() {
     List<List<String>> inputData = Arrays.asList(
diff --git a/streaming/src/test/java/org/apache/spark/streaming/JavaTestUtils.scala b/streaming/src/test/java/org/apache/spark/streaming/JavaTestUtils.scala
index 8a6604904de57d39626eac6ec5f2b19405dad833..5e384eeee45f385ff83bba34f93c4b2228f5faaa 100644
--- a/streaming/src/test/java/org/apache/spark/streaming/JavaTestUtils.scala
+++ b/streaming/src/test/java/org/apache/spark/streaming/JavaTestUtils.scala
@@ -33,9 +33,9 @@ trait JavaTestBase extends TestSuiteBase {
    * The stream will be derived from the supplied lists of Java objects.
    **/
   def attachTestInputStream[T](
-    ssc: JavaStreamingContext,
-    data: JList[JList[T]],
-    numPartitions: Int) = {
+      ssc: JavaStreamingContext,
+      data: JList[JList[T]],
+      numPartitions: Int) = {
     val seqData = data.map(Seq(_:_*))
 
     implicit val cm: ClassManifest[T] =
@@ -50,12 +50,11 @@ trait JavaTestBase extends TestSuiteBase {
    * [[org.apache.spark.streaming.TestOutputStream]].
    **/
   def attachTestOutputStream[T, This <: JavaDStreamLike[T, This, R], R <: JavaRDDLike[T, R]](
-    dstream: JavaDStreamLike[T, This, R]) =
+      dstream: JavaDStreamLike[T, This, R]) =
   {
     implicit val cm: ClassManifest[T] =
       implicitly[ClassManifest[AnyRef]].asInstanceOf[ClassManifest[T]]
-    val ostream = new TestOutputStream(dstream.dstream,
-      new ArrayBuffer[Seq[T]] with SynchronizedBuffer[Seq[T]])
+    val ostream = new TestOutputStreamWithPartitions(dstream.dstream)
     dstream.dstream.ssc.registerOutputStream(ostream)
   }
 
@@ -63,9 +62,11 @@ trait JavaTestBase extends TestSuiteBase {
    * Process all registered streams for numBatches batches, failing if
    * numExpectedOutput RDDs are not generated. Generated RDDs are collected
    * and returned, represented as a list for each batch interval.
+   *
+   * Returns a list of items for each RDD, with partitions flattened together.
    */
   def runStreams[V](
-    ssc: JavaStreamingContext, numBatches: Int, numExpectedOutput: Int): JList[JList[V]] = {
+      ssc: JavaStreamingContext, numBatches: Int, numExpectedOutput: Int): JList[JList[V]] = {
     implicit val cm: ClassManifest[V] =
       implicitly[ClassManifest[AnyRef]].asInstanceOf[ClassManifest[V]]
     val res = runStreams[V](ssc.ssc, numBatches, numExpectedOutput)
@@ -73,6 +74,27 @@ trait JavaTestBase extends TestSuiteBase {
     res.map(entry => out.append(new ArrayList[V](entry)))
     out
   }
+
+  /**
+   * Process all registered streams for numBatches batches, failing if
+   * numExpectedOutput RDDs are not generated. Generated RDDs are collected
+   * and returned, represented as a list for each batch interval.
+   *
+   * Returns a sequence of RDDs. Each RDD is represented as several sequences of items,
+   * each sequence representing one partition.
+   */
+  def runStreamsWithPartitions[V](ssc: JavaStreamingContext, numBatches: Int,
+      numExpectedOutput: Int): JList[JList[JList[V]]] = {
+    implicit val cm: ClassManifest[V] =
+      implicitly[ClassManifest[AnyRef]].asInstanceOf[ClassManifest[V]]
+    val res = runStreamsWithPartitions[V](ssc.ssc, numBatches, numExpectedOutput)
+    val out = new ArrayList[JList[JList[V]]]()
+    res.map { entry =>
+      val lists = entry.map(new ArrayList[V](_))
+      out.append(new ArrayList[JList[V]](lists))
+    }
+    out
+  }
 }
 
 object JavaTestUtils extends JavaTestBase {
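The nesting convention shared by these helpers, illustrated with literal data (a self-contained sketch, no streaming context needed):

    // runStreamsWithPartitions: batches -> partitions -> items.
    val withPartitions: Seq[Seq[Seq[Int]]] = Seq(
      Seq(Seq(1, 2), Seq(3, 4)),  // batch 1, two partitions
      Seq(Seq(5), Seq(6, 7, 8))   // batch 2, two partitions
    )
    // runStreams flattens away the partition level: batches -> items.
    val flat: Seq[Seq[Int]] = withPartitions.map(_.flatten)
    assert(flat == Seq(Seq(1, 2, 3, 4), Seq(5, 6, 7, 8)))
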
diff --git a/streaming/src/test/scala/org/apache/spark/streaming/BasicOperationsSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/BasicOperationsSuite.scala
index a2ac510a98e0c3dcef76dd8c6a8e4075c1b40ce5..259ef1608cbc5b179372995ef4a426df0101247d 100644
--- a/streaming/src/test/scala/org/apache/spark/streaming/BasicOperationsSuite.scala
+++ b/streaming/src/test/scala/org/apache/spark/streaming/BasicOperationsSuite.scala
@@ -85,6 +85,44 @@ class BasicOperationsSuite extends TestSuiteBase {
     testOperation(input, operation, output, true)
   }
 
+  test("repartition (more partitions)") {
+    val input = Seq(1 to 100, 101 to 200, 201 to 300)
+    val operation = (r: DStream[Int]) => r.repartition(5)
+    val ssc = setupStreams(input, operation, 2)
+    val output = runStreamsWithPartitions(ssc, 3, 3)
+    assert(output.size === 3)
+    val first = output(0)
+    val second = output(1)
+    val third = output(2)
+
+    assert(first.size === 5)
+    assert(second.size === 5)
+    assert(third.size === 5)
+
+    assert(first.flatten.toSet === (1 to 100).toSet)
+    assert(second.flatten.toSet === (101 to 200).toSet)
+    assert(third.flatten.toSet === (201 to 300).toSet)
+  }
+
+  test("repartition (fewer partitions)") {
+    val input = Seq(1 to 100, 101 to 200, 201 to 300)
+    val operation = (r: DStream[Int]) => r.repartition(2)
+    val ssc = setupStreams(input, operation, 5)
+    val output = runStreamsWithPartitions(ssc, 3, 3)
+    assert(output.size === 3)
+    val first = output(0)
+    val second = output(1)
+    val third = output(2)
+
+    assert(first.size === 2)
+    assert(second.size === 2)
+    assert(third.size === 2)
+
+    assert(first.flatten.toSet === (1 to 100).toSet)
+    assert(second.flatten.toSet === (101 to 200).toSet)
+    assert(third.flatten.toSet === (201 to 300).toSet)
+  }
+
   test("groupByKey") {
     testOperation(
       Seq( Seq("a", "a", "b"), Seq("", ""), Seq() ),
diff --git a/streaming/src/test/scala/org/apache/spark/streaming/CheckpointSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/CheckpointSuite.scala
index a327de80b3eb3a0f98bcbd2486e0580225b93a8a..beb20831bd7b49d5a9c9443149a4c1eb9171384a 100644
--- a/streaming/src/test/scala/org/apache/spark/streaming/CheckpointSuite.scala
+++ b/streaming/src/test/scala/org/apache/spark/streaming/CheckpointSuite.scala
@@ -366,7 +366,7 @@ class CheckpointSuite extends TestSuiteBase with BeforeAndAfter {
     logInfo("Manual clock after advancing = " + clock.time)
     Thread.sleep(batchDuration.milliseconds)
 
-    val outputStream = ssc.graph.getOutputStreams.head.asInstanceOf[TestOutputStream[V]]
-    outputStream.output
+    val outputStream = ssc.graph.getOutputStreams.head.asInstanceOf[TestOutputStreamWithPartitions[V]]
+    outputStream.output.map(_.flatten)
   }
 }
diff --git a/streaming/src/test/scala/org/apache/spark/streaming/TestSuiteBase.scala b/streaming/src/test/scala/org/apache/spark/streaming/TestSuiteBase.scala
index 37dd9c4cc61d29d28f247bbdf9044b80e8fce76b..be140699c2964456a747a076120987c9abc26f39 100644
--- a/streaming/src/test/scala/org/apache/spark/streaming/TestSuiteBase.scala
+++ b/streaming/src/test/scala/org/apache/spark/streaming/TestSuiteBase.scala
@@ -60,8 +60,11 @@ class TestInputStream[T: ClassManifest](ssc_ : StreamingContext, input: Seq[Seq[
 /**
  * This is an output stream just for the testsuites. All the output is collected into an
  * ArrayBuffer. This buffer is wiped clean on being restored from checkpoint.
+ *
+ * The buffer contains a sequence of RDDs, each represented as a sequence of items.
  */
-class TestOutputStream[T: ClassManifest](parent: DStream[T], val output: ArrayBuffer[Seq[T]])
+class TestOutputStream[T: ClassManifest](parent: DStream[T],
+    val output: ArrayBuffer[Seq[T]] = ArrayBuffer[Seq[T]]())
   extends ForEachDStream[T](parent, (rdd: RDD[T], t: Time) => {
     val collected = rdd.collect()
     output += collected
@@ -75,6 +78,30 @@ class TestOutputStream[T: ClassManifest](parent: DStream[T], val output: ArrayBu
   }
 }
 
+/**
+ * This is an output stream just for the testsuites. All the output is collected into an
+ * ArrayBuffer. This buffer is wiped clean on being restored from checkpoint.
+ *
+ * The buffer contains a sequence of RDDs, each containing a sequence of partitions, each
+ * containing a sequence of items.
+ */
+class TestOutputStreamWithPartitions[T: ClassManifest](parent: DStream[T],
+    val output: ArrayBuffer[Seq[Seq[T]]] = ArrayBuffer[Seq[Seq[T]]]())
+  extends ForEachDStream[T](parent, (rdd: RDD[T], t: Time) => {
+    val collected = rdd.glom().collect().map(_.toSeq)
+    output += collected
+  }) {
+
+  // This clears the output buffer every time it is read from a checkpoint
+  @throws(classOf[IOException])
+  private def readObject(ois: ObjectInputStream) {
+    ois.defaultReadObject()
+    output.clear()
+  }
+
+  def toTestOutputStream = new TestOutputStream[T](this.parent, this.output.map(_.flatten))
+}
+
 /**
  * This is the base trait for Spark Streaming testsuites. This provides basic functionality
  * to run user-defined set of input on user-defined stream operations, and verify the output.
@@ -108,7 +135,8 @@ trait TestSuiteBase extends FunSuite with BeforeAndAfter with Logging {
    */
   def setupStreams[U: ClassManifest, V: ClassManifest](
       input: Seq[Seq[U]],
-      operation: DStream[U] => DStream[V]
+      operation: DStream[U] => DStream[V],
+      numPartitions: Int = numInputPartitions
     ): StreamingContext = {
 
     // Create StreamingContext
@@ -118,9 +146,10 @@ trait TestSuiteBase extends FunSuite with BeforeAndAfter with Logging {
     }
 
     // Setup the stream computation
-    val inputStream = new TestInputStream(ssc, input, numInputPartitions)
+    val inputStream = new TestInputStream(ssc, input, numPartitions)
     val operatedStream = operation(inputStream)
-    val outputStream = new TestOutputStream(operatedStream, new ArrayBuffer[Seq[V]] with SynchronizedBuffer[Seq[V]])
+    val outputStream = new TestOutputStreamWithPartitions(operatedStream,
+      new ArrayBuffer[Seq[Seq[V]]] with SynchronizedBuffer[Seq[Seq[V]]])
     ssc.registerInputStream(inputStream)
     ssc.registerOutputStream(outputStream)
     ssc
@@ -146,7 +175,8 @@ trait TestSuiteBase extends FunSuite with BeforeAndAfter with Logging {
     val inputStream1 = new TestInputStream(ssc, input1, numInputPartitions)
     val inputStream2 = new TestInputStream(ssc, input2, numInputPartitions)
     val operatedStream = operation(inputStream1, inputStream2)
-    val outputStream = new TestOutputStream(operatedStream, new ArrayBuffer[Seq[W]] with SynchronizedBuffer[Seq[W]])
+    val outputStream = new TestOutputStreamWithPartitions(operatedStream,
+      new ArrayBuffer[Seq[Seq[W]]] with SynchronizedBuffer[Seq[Seq[W]]])
     ssc.registerInputStream(inputStream1)
     ssc.registerInputStream(inputStream2)
     ssc.registerOutputStream(outputStream)
@@ -157,18 +187,37 @@ trait TestSuiteBase extends FunSuite with BeforeAndAfter with Logging {
    * Runs the streams set up in `ssc` on manual clock for `numBatches` batches and
    * returns the collected output. It will wait until `numExpectedOutput` batches of
    * output have been collected or the timeout (set by `maxWaitTimeMillis`) is reached.
+   *
+   * Returns a sequence of items for each RDD, with partitions flattened together.
    */
   def runStreams[V: ClassManifest](
       ssc: StreamingContext,
       numBatches: Int,
       numExpectedOutput: Int
     ): Seq[Seq[V]] = {
+    // Flatten each RDD into a single Seq
+    runStreamsWithPartitions(ssc, numBatches, numExpectedOutput).map(_.flatten.toSeq)
+  }
+
+  /**
+   * Runs the streams set up in `ssc` on manual clock for `numBatches` batches and
+   * returns the collected output. It will wait until `numExpectedOutput` batches of
+   * output have been collected or the timeout (set by `maxWaitTimeMillis`) is reached.
+   *
+   * Returns a sequence of RDDs. Each RDD is represented as several sequences of items,
+   * each sequence representing one partition.
+   */
+  def runStreamsWithPartitions[V: ClassManifest](
+      ssc: StreamingContext,
+      numBatches: Int,
+      numExpectedOutput: Int
+    ): Seq[Seq[Seq[V]]] = {
     assert(numBatches > 0, "Number of batches to run stream computation is zero")
     assert(numExpectedOutput > 0, "Number of expected outputs after " + numBatches + " is zero")
     logInfo("numBatches = " + numBatches + ", numExpectedOutput = " + numExpectedOutput)
 
     // Get the output buffer
-    val outputStream = ssc.graph.getOutputStreams.head.asInstanceOf[TestOutputStream[V]]
+    val outputStream = ssc.graph.getOutputStreams.head.asInstanceOf[TestOutputStreamWithPartitions[V]]
     val output = outputStream.output
 
     try {