diff --git a/core/src/main/scala/org/apache/spark/rdd/PartitionerAwareUnionRDD.scala b/core/src/main/scala/org/apache/spark/rdd/PartitionerAwareUnionRDD.scala
new file mode 100644
index 0000000000000000000000000000000000000000..4c625d062eb9bef5abaf71a180af13cafaf7f18c
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/rdd/PartitionerAwareUnionRDD.scala
@@ -0,0 +1,110 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.rdd
+
+import java.io.{IOException, ObjectOutputStream}
+
+import scala.reflect.ClassTag
+
+import org.apache.spark.{OneToOneDependency, Partition, SparkContext, TaskContext}
+
+/**
+ * Class representing partitions of PartitionerAwareUnionRDD, which maintains the list of
+ * corresponding partitions of parent RDDs.
+ */
+private[spark]
+class PartitionerAwareUnionRDDPartition(
+    @transient val rdds: Seq[RDD[_]],
+    val idx: Int
+  ) extends Partition {
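+  // Note: `parents` is re-resolved on task serialization (see writeObject below)
+  // so that checkpointed parents ship their updated partitions with each task.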
+  var parents = rdds.map(_.partitions(idx)).toArray
+
+  override val index = idx
+  override def hashCode(): Int = idx
+
+  @throws(classOf[IOException])
+  private def writeObject(oos: ObjectOutputStream) {
+    // Update the reference to parent partition at the time of task serialization
+    parents = rdds.map(_.partitions(index)).toArray
+    oos.defaultWriteObject()
+  }
+}
+
+/**
+ * Class representing an RDD that can take multiple RDDs partitioned by the same partitioner and
+ * unify them into a single RDD while preserving the partitioner. So m RDDs with p partitions each
+ * will be unified to a single RDD with p partitions and the same partitioner. The preferred
+ * location for each partition of the unified RDD will be the most common preferred location
+ * of the corresponding partitions of the parent RDDs. For example, location of partition 0
+ * of the unified RDD will be where most of partition 0 of the parent RDDs are located.
+ */
+private[spark]
+class PartitionerAwareUnionRDD[T: ClassTag](
+    sc: SparkContext,
+    var rdds: Seq[RDD[T]]
+  ) extends RDD[T](sc, rdds.map(x => new OneToOneDependency(x))) {
+  require(rdds.nonEmpty)
+  require(rdds.forall(_.partitioner.isDefined), "All parent RDDs must have a partitioner")
+  require(rdds.flatMap(_.partitioner).toSet.size == 1,
+    "Parent RDDs have different partitioners: " + rdds.flatMap(_.partitioner))
+
+  override val partitioner = rdds.head.partitioner
+
+  override def getPartitions: Array[Partition] = {
+    val numPartitions = partitioner.get.numPartitions
+    (0 until numPartitions).map(index => {
+      new PartitionerAwareUnionRDDPartition(rdds, index)
+    }).toArray
+  }
+
+  // Get the location where most of the partitions of parent RDDs are located
+  override def getPreferredLocations(s: Partition): Seq[String] = {
+    logDebug("Finding preferred location for " + this + ", partition " + s.index)
+    val parentPartitions = s.asInstanceOf[PartitionerAwareUnionRDDPartition].parents
+    val locations = rdds.zip(parentPartitions).flatMap {
+      case (rdd, part) => {
+        val parentLocations = currPrefLocs(rdd, part)
+        logDebug("Location of " + rdd + " partition " + part.index + " = " + parentLocations)
+        parentLocations
+      }
+    }
+    val location = if (locations.isEmpty) {
+      None
+    } else {
+      // Find the location that maximum number of parent partitions prefer
+      Some(locations.groupBy(x => x).maxBy(_._2.length)._1)
+    }
+    logDebug("Selected location for " + this + ", partition " + s.index + " = " + location)
+    location.toSeq
+  }
+
+  override def compute(s: Partition, context: TaskContext): Iterator[T] = {
+    val parentPartitions = s.asInstanceOf[PartitionerAwareUnionRDDPartition].parents
+    rdds.zip(parentPartitions).iterator.flatMap {
+      case (rdd, p) => rdd.iterator(p, context)
+    }
+  }
+
+  override def clearDependencies() {
+    super.clearDependencies()
+    rdds = null
+  }
+
+  // Get the *current* preferred locations from the DAGScheduler (as opposed to the static ones)
+  private def currPrefLocs(rdd: RDD[_], part: Partition): Seq[String] = {
+    rdd.context.getPreferredLocs(rdd, part.index).map(tl => tl.host)
+  }
+}
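
The unification contract is easiest to see with a small example. Below is a minimal sketch, assuming a local SparkContext; since PartitionerAwareUnionRDD is private[spark], code like this has to live under the org.apache.spark package (as the tests further down do). It contrasts the new RDD with the plain UnionRDD, which concatenates partitions and drops the partitioner:

```scala
import org.apache.spark.{HashPartitioner, SparkContext}
import org.apache.spark.SparkContext._
import org.apache.spark.rdd.{PartitionerAwareUnionRDD, UnionRDD}

val sc = new SparkContext("local", "union-example")
val part = new HashPartitioner(2)
val rdd1 = sc.makeRDD(1 to 4).map(x => (x, x)).partitionBy(part)
val rdd2 = sc.makeRDD(5 to 8).map(x => (x, x)).partitionBy(part)

// m = 2 RDDs with p = 2 partitions each unify into one RDD with p = 2
// partitions; output partition i is computed from partition i of each parent.
val unified = new PartitionerAwareUnionRDD(sc, Seq(rdd1, rdd2))
assert(unified.partitions.length == 2)
assert(unified.partitioner == Some(part))

// The plain UnionRDD concatenates partitions (m * p = 4) and loses the partitioner.
val plain = new UnionRDD(sc, Seq(rdd1, rdd2))
assert(plain.partitions.length == 4)
assert(plain.partitioner == None)
```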
diff --git a/core/src/main/scala/org/apache/spark/rdd/RDDCheckpointData.scala b/core/src/main/scala/org/apache/spark/rdd/RDDCheckpointData.scala
index 642dabaad560e1d2a2f68f243e721530d5882efb..bc688110f47362cef8cdb5bf9e92d3e71b22f9c9 100644
--- a/core/src/main/scala/org/apache/spark/rdd/RDDCheckpointData.scala
+++ b/core/src/main/scala/org/apache/spark/rdd/RDDCheckpointData.scala
@@ -40,7 +40,7 @@ private[spark] object CheckpointState extends Enumeration {
  * manages the post-checkpoint state by providing the updated partitions, iterator and preferred locations
  * of the checkpointed RDD.
  */
-private[spark] class RDDCheckpointData[T: ClassTag](rdd: RDD[T])
+private[spark] class RDDCheckpointData[T: ClassTag](@transient rdd: RDD[T])
   extends Logging with Serializable {
 
   import CheckpointState._
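
This @transient is the crux of the change: RDDCheckpointData lives inside the RDD it manages, so a non-transient back-reference would drag the entire RDD (and its lineage) into the byte stream whenever the checkpoint data is serialized. The new assert in getSerializedSizes below guards against exactly that regression. A self-contained sketch of the mechanism, using plain JVM serialization rather than Spark code (all names here are illustrative only):

```scala
import java.io.{ByteArrayOutputStream, ObjectOutputStream}

object TransientDemo {
  // Stand-ins for RDD and RDDCheckpointData.
  class Owner extends Serializable {
    val bigLineage = new Array[Byte](100000)       // plays the role of the lineage
    val checkpointData = new CheckpointData(this)  // back-reference to the owner
  }
  // Without @transient, serializing CheckpointData would serialize Owner too.
  class CheckpointData(@transient val owner: Owner) extends Serializable

  def serializedSize(o: AnyRef): Int = {
    val bytes = new ByteArrayOutputStream()
    val oos = new ObjectOutputStream(bytes)
    oos.writeObject(o)
    oos.close()
    bytes.size()
  }

  def main(args: Array[String]) {
    val owner = new Owner
    // The checkpoint data serializes to a few bytes, not ~100 KB.
    assert(serializedSize(owner.checkpointData) < serializedSize(owner))
  }
}
```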
diff --git a/core/src/test/scala/org/apache/spark/CheckpointSuite.scala b/core/src/test/scala/org/apache/spark/CheckpointSuite.scala
index 70bfb81661381e87003afd2c2f8aefbe59afd537..ec13b329b25a8534078e5a55da8fe25801f3e0ca 100644
--- a/core/src/test/scala/org/apache/spark/CheckpointSuite.scala
+++ b/core/src/test/scala/org/apache/spark/CheckpointSuite.scala
@@ -55,15 +55,15 @@ class CheckpointSuite extends FunSuite with LocalSparkContext with Logging {
   }
 
   test("RDDs with one-to-one dependencies") {
-    testCheckpointing(_.map(x => x.toString))
-    testCheckpointing(_.flatMap(x => 1 to x))
-    testCheckpointing(_.filter(_ % 2 == 0))
-    testCheckpointing(_.sample(false, 0.5, 0))
-    testCheckpointing(_.glom())
-    testCheckpointing(_.mapPartitions(_.map(_.toString)))
-    testCheckpointing(_.map(x => (x % 2, 1)).reduceByKey(_ + _).mapValues(_.toString))
-    testCheckpointing(_.map(x => (x % 2, 1)).reduceByKey(_ + _).flatMapValues(x => 1 to x))
-    testCheckpointing(_.pipe(Seq("cat")))
+    testRDD(_.map(x => x.toString))
+    testRDD(_.flatMap(x => 1 to x))
+    testRDD(_.filter(_ % 2 == 0))
+    testRDD(_.sample(false, 0.5, 0))
+    testRDD(_.glom())
+    testRDD(_.mapPartitions(_.map(_.toString)))
+    testRDD(_.map(x => (x % 2, 1)).reduceByKey(_ + _).mapValues(_.toString))
+    testRDD(_.map(x => (x % 2, 1)).reduceByKey(_ + _).flatMapValues(x => 1 to x))
+    testRDD(_.pipe(Seq("cat")))
   }
 
   test("ParallelCollection") {
@@ -95,7 +95,7 @@ class CheckpointSuite extends FunSuite with LocalSparkContext with Logging {
   }
 
   test("ShuffledRDD") {
-    testCheckpointing(rdd => {
+    testRDD(rdd => {
       // Creating ShuffledRDD directly as PairRDDFunctions.combineByKey produces a MapPartitionedRDD
       new ShuffledRDD[Int, Int, (Int, Int)](rdd.map(x => (x % 2, 1)), partitioner)
     })
@@ -103,25 +103,17 @@ class CheckpointSuite extends FunSuite with LocalSparkContext with Logging {
 
   test("UnionRDD") {
     def otherRDD = sc.makeRDD(1 to 10, 1)
-
-    // Test whether the size of UnionRDDPartitions reduce in size after parent RDD is checkpointed.
-    // Current implementation of UnionRDD has transient reference to parent RDDs,
-    // so only the partitions will reduce in serialized size, not the RDD.
-    testCheckpointing(_.union(otherRDD), false, true)
-    testParentCheckpointing(_.union(otherRDD), false, true)
+    testRDD(_.union(otherRDD))
+    testRDDPartitions(_.union(otherRDD))
   }
 
   test("CartesianRDD") {
     def otherRDD = sc.makeRDD(1 to 10, 1)
-    testCheckpointing(new CartesianRDD(sc, _, otherRDD))
-
-    // Test whether size of CoalescedRDD reduce in size after parent RDD is checkpointed
-    // Current implementation of CoalescedRDDPartition has transient reference to parent RDD,
-    // so only the RDD will reduce in serialized size, not the partitions.
-    testParentCheckpointing(new CartesianRDD(sc, _, otherRDD), true, false)
+    testRDD(new CartesianRDD(sc, _, otherRDD))
+    testRDDPartitions(new CartesianRDD(sc, _, otherRDD))
 
     // Test that the CartesianRDD updates parent partitions (CartesianRDD.s1/s2) after
-    // the parent RDD has been checkpointed and parent partitions have been changed to HadoopPartitions.
+    // the parent RDD has been checkpointed and parent partitions have been changed.
     // Note that this test is very specific to the current implementation of CartesianRDD.
     val ones = sc.makeRDD(1 to 100, 10).map(x => x)
     ones.checkpoint() // checkpoint that MappedRDD
@@ -132,23 +124,20 @@ class CheckpointSuite extends FunSuite with LocalSparkContext with Logging {
     val splitAfterCheckpoint =
       serializeDeserialize(cartesian.partitions.head.asInstanceOf[CartesianPartition])
     assert(
-      (splitAfterCheckpoint.s1 != splitBeforeCheckpoint.s1) &&
-        (splitAfterCheckpoint.s2 != splitBeforeCheckpoint.s2),
-      "CartesianRDD.parents not updated after parent RDD checkpointed"
+      (splitAfterCheckpoint.s1.getClass != splitBeforeCheckpoint.s1.getClass) &&
+        (splitAfterCheckpoint.s2.getClass != splitBeforeCheckpoint.s2.getClass),
+      "CartesianRDD.s1 and CartesianRDD.s2 not updated after parent RDD is checkpointed"
     )
   }
 
   test("CoalescedRDD") {
-    testCheckpointing(_.coalesce(2))
+    testRDD(_.coalesce(2))
+    testRDDPartitions(_.coalesce(2))
 
-    // Test whether size of CoalescedRDD reduce in size after parent RDD is checkpointed
-    // Current implementation of CoalescedRDDPartition has transient reference to parent RDD,
-    // so only the RDD will reduce in serialized size, not the partitions.
-    testParentCheckpointing(_.coalesce(2), true, false)
-
-    // Test that the CoalescedRDDPartition updates parent partitions (CoalescedRDDPartition.parents) after
-    // the parent RDD has been checkpointed and parent partitions have been changed to HadoopPartitions.
-    // Note that this test is very specific to the current implementation of CoalescedRDDPartitions
+    // Test that the CoalescedRDDPartition updates parent partitions (CoalescedRDDPartition.parents)
+    // after the parent RDD has been checkpointed and parent partitions have been changed.
+    // Note that this test is very specific to the current implementation of
+    // CoalescedRDDPartitions.
     val ones = sc.makeRDD(1 to 100, 10).map(x => x)
     ones.checkpoint() // checkpoint that MappedRDD
     val coalesced = new CoalescedRDD(ones, 2)
@@ -158,33 +147,78 @@ class CheckpointSuite extends FunSuite with LocalSparkContext with Logging {
     val splitAfterCheckpoint =
       serializeDeserialize(coalesced.partitions.head.asInstanceOf[CoalescedRDDPartition])
     assert(
-      splitAfterCheckpoint.parents.head != splitBeforeCheckpoint.parents.head,
-      "CoalescedRDDPartition.parents not updated after parent RDD checkpointed"
+      splitAfterCheckpoint.parents.head.getClass != splitBeforeCheckpoint.parents.head.getClass,
+      "CoalescedRDDPartition.parents not updated after parent RDD is checkpointed"
     )
   }
 
   test("CoGroupedRDD") {
-    val longLineageRDD1 = generateLongLineageRDDForCoGroupedRDD()
-    testCheckpointing(rdd => {
+    val longLineageRDD1 = generateFatPairRDD()
+    testRDD(rdd => {
       CheckpointSuite.cogroup(longLineageRDD1, rdd.map(x => (x % 2, 1)), partitioner)
-    }, false, true)
+    })
 
-    val longLineageRDD2 = generateLongLineageRDDForCoGroupedRDD()
-    testParentCheckpointing(rdd => {
+    val longLineageRDD2 = generateFatPairRDD()
+    testRDDPartitions(rdd => {
       CheckpointSuite.cogroup(
         longLineageRDD2, sc.makeRDD(1 to 2, 2).map(x => (x % 2, 1)), partitioner)
-    }, false, true)
+    })
   }
 
   test("ZippedRDD") {
-    testCheckpointing(
-      rdd => new ZippedRDD(sc, rdd, rdd.map(x => x)), true, false)
-
-    // Test whether size of ZippedRDD reduce in size after parent RDD is checkpointed
-    // Current implementation of ZippedRDDPartitions has transient references to parent RDDs,
-    // so only the RDD will reduce in serialized size, not the partitions.
-    testParentCheckpointing(
-      rdd => new ZippedRDD(sc, rdd, rdd.map(x => x)), true, false)
+    testRDD(rdd => new ZippedRDD(sc, rdd, rdd.map(x => x)))
+    testRDDPartitions(rdd => new ZippedRDD(sc, rdd, rdd.map(x => x)))
+
+    // Test that the ZippedPartition updates parent partitions
+    // after the parent RDD has been checkpointed and parent partitions have been changed.
+    // Note that this test is very specific to the current implementation of ZippedRDD.
+    val rdd = generateFatRDD()
+    val zippedRDD = new ZippedRDD(sc, rdd, rdd.map(x => x))
+    zippedRDD.rdd1.checkpoint()
+    zippedRDD.rdd2.checkpoint()
+    val partitionBeforeCheckpoint =
+      serializeDeserialize(zippedRDD.partitions.head.asInstanceOf[ZippedPartition[_, _]])
+    zippedRDD.count()
+    val partitionAfterCheckpoint =
+      serializeDeserialize(zippedRDD.partitions.head.asInstanceOf[ZippedPartition[_, _]])
+    assert(
+      partitionAfterCheckpoint.partition1.getClass != partitionBeforeCheckpoint.partition1.getClass &&
+        partitionAfterCheckpoint.partition2.getClass != partitionBeforeCheckpoint.partition2.getClass,
+      "ZippedRDD.partition1 and ZippedRDD.partition2 not updated after parent RDD is checkpointed"
+    )
+  }
+
+  test("PartitionerAwareUnionRDD") {
+    testRDD(rdd => {
+      new PartitionerAwareUnionRDD[(Int, Int)](sc, Array(
+        generateFatPairRDD(),
+        rdd.map(x => (x % 2, 1)).reduceByKey(partitioner, _ + _)
+      ))
+    })
+
+    testRDDPartitions(rdd => {
+      new PartitionerAwareUnionRDD[(Int, Int)](sc, Array(
+        generateFatPairRDD(),
+        rdd.map(x => (x % 2, 1)).reduceByKey(partitioner, _ + _)
+      ))
+    })
+
+    // Test that the PartitionerAwareUnionRDD updates parent partitions
+    // (PartitionerAwareUnionRDDPartition.parents) after the parent RDD has been checkpointed
+    // and parent partitions have been changed. Note that this test is very specific to the
+    // current implementation of PartitionerAwareUnionRDD.
+    val pairRDD = generateFatPairRDD()
+    pairRDD.checkpoint()
+    val unionRDD = new PartitionerAwareUnionRDD(sc, Array(pairRDD))
+    val partitionBeforeCheckpoint = serializeDeserialize(
+      unionRDD.partitions.head.asInstanceOf[PartitionerAwareUnionRDDPartition])
+    pairRDD.count()
+    val partitionAfterCheckpoint = serializeDeserialize(
+      unionRDD.partitions.head.asInstanceOf[PartitionerAwareUnionRDDPartition])
+    assert(
+      partitionBeforeCheckpoint.parents.head.getClass != partitionAfterCheckpoint.parents.head.getClass,
+      "PartitionerAwareUnionRDDPartition.parents not updated after parent RDD is checkpointed"
+    )
   }
 
   test("CheckpointRDD with zero partitions") {
@@ -198,29 +232,32 @@ class CheckpointSuite extends FunSuite with LocalSparkContext with Logging {
   }
 
   /**
-   * Test checkpointing of the final RDD generated by the given operation. By default,
-   * this method tests whether the size of serialized RDD has reduced after checkpointing or not.
-   * It can also test whether the size of serialized RDD partitions has reduced after checkpointing or
-   * not, but this is not done by default as usually the partitions do not refer to any RDD and
-   * therefore never store the lineage.
+   * Test checkpointing of the RDD generated by the given operation. It tests whether the
+   * serialized size of the RDD is reduced after checkpointing. This function should be called on
+   * all RDDs that have a parent RDD (i.e., do not call it on ParallelCollection, BlockRDD, etc.).
    */
-  def testCheckpointing[U: ClassTag](
-      op: (RDD[Int]) => RDD[U],
-      testRDDSize: Boolean = true,
-      testRDDPartitionSize: Boolean = false
-    ) {
+  def testRDD[U: ClassTag](op: (RDD[Int]) => RDD[U]) {
     // Generate the final RDD using given RDD operation
-    val baseRDD = generateLongLineageRDD()
+    val baseRDD = generateFatRDD()
     val operatedRDD = op(baseRDD)
     val parentRDD = operatedRDD.dependencies.headOption.orNull
     val rddType = operatedRDD.getClass.getSimpleName
     val numPartitions = operatedRDD.partitions.length
 
+    // Force initialization of all the data structures in RDDs
+    // Without this, serializing the RDD will give a wrong estimate of the size of the RDD
+    initializeRdd(operatedRDD)
+
+    val partitionsBeforeCheckpoint = operatedRDD.partitions
+
     // Find serialized sizes before and after the checkpoint
-    val (rddSizeBeforeCheckpoint, splitSizeBeforeCheckpoint) = getSerializedSizes(operatedRDD)
+    logInfo("RDD after checkpoint: " + operatedRDD + "\n" + operatedRDD.toDebugString)
+    val (rddSizeBeforeCheckpoint, partitionSizeBeforeCheckpoint) = getSerializedSizes(operatedRDD)
     operatedRDD.checkpoint()
     val result = operatedRDD.collect()
-    val (rddSizeAfterCheckpoint, splitSizeAfterCheckpoint) = getSerializedSizes(operatedRDD)
+    operatedRDD.collect() // force re-initialization of post-checkpoint lazy variables
+    val (rddSizeAfterCheckpoint, partitionSizeAfterCheckpoint) = getSerializedSizes(operatedRDD)
+    logInfo("RDD after checkpoint: " + operatedRDD + "\n" + operatedRDD.toDebugString)
 
     // Test whether the checkpoint file has been created
     assert(sc.checkpointFile[U](operatedRDD.getCheckpointFile.get).collect() === result)
@@ -228,6 +265,9 @@ class CheckpointSuite extends FunSuite with LocalSparkContext with Logging {
     // Test whether dependencies have been changed from its earlier parent RDD
     assert(operatedRDD.dependencies.head.rdd != parentRDD)
 
+    // Test whether the partitions have been changed from its earlier partitions
+    assert(operatedRDD.partitions.toList != partitionsBeforeCheckpoint.toList)
+
     // Test whether the partitions have been changed to the new Hadoop partitions
     assert(operatedRDD.partitions.toList === operatedRDD.checkpointData.get.getPartitions.toList)
 
@@ -237,122 +277,72 @@ class CheckpointSuite extends FunSuite with LocalSparkContext with Logging {
     // Test whether the data in the checkpointed RDD is same as original
     assert(operatedRDD.collect() === result)
 
-    // Test whether serialized size of the RDD has reduced. If the RDD
-    // does not have any dependency to another RDD (e.g., ParallelCollection,
-    // ShuffleRDD with ShuffleDependency), it may not reduce in size after checkpointing.
-    if (testRDDSize) {
-      logInfo("Size of " + rddType +
-        "[" + rddSizeBeforeCheckpoint + " --> " + rddSizeAfterCheckpoint + "]")
-      assert(
-        rddSizeAfterCheckpoint < rddSizeBeforeCheckpoint,
-        "Size of " + rddType + " did not reduce after checkpointing " +
-          "[" + rddSizeBeforeCheckpoint + " --> " + rddSizeAfterCheckpoint + "]"
-      )
-    }
+    // Test whether serialized size of the RDD has reduced.
+    logInfo("Size of " + rddType +
+      " [" + rddSizeBeforeCheckpoint + " --> " + rddSizeAfterCheckpoint + "]")
+    assert(
+      rddSizeAfterCheckpoint < rddSizeBeforeCheckpoint,
+      "Size of " + rddType + " did not reduce after checkpointing " +
+        " [" + rddSizeBeforeCheckpoint + " --> " + rddSizeAfterCheckpoint + "]"
+    )
 
-    // Test whether serialized size of the partitions has reduced. If the partitions
-    // do not have any non-transient reference to another RDD or another RDD's partitions, it
-    // does not refer to a lineage and therefore may not reduce in size after checkpointing.
-    // However, if the original partitions before checkpointing do refer to a parent RDD, the partitions
-    // must be forgotten after checkpointing (to remove all reference to parent RDDs) and
-    // replaced with the HadooPartitions of the checkpointed RDD.
-    if (testRDDPartitionSize) {
-      logInfo("Size of " + rddType + " partitions "
-        + "[" + splitSizeBeforeCheckpoint + " --> " + splitSizeAfterCheckpoint + "]")
-      assert(
-        splitSizeAfterCheckpoint < splitSizeBeforeCheckpoint,
-        "Size of " + rddType + " partitions did not reduce after checkpointing " +
-          "[" + splitSizeBeforeCheckpoint + " --> " + splitSizeAfterCheckpoint + "]"
-      )
-    }
   }
 
   /**
    * Test whether checkpointing of the parent of the generated RDD also
    * truncates the lineage or not. Some RDDs like CoGroupedRDD hold on to its parent
    * RDDs partitions. So even if the parent RDD is checkpointed and its partitions changed,
-   * this RDD will remember the partitions and therefore potentially the whole lineage.
+   * the generated RDD will remember the partitions and therefore potentially the whole lineage.
+   * This function should be called on only those RDDs whose partitions refer to a parent RDD's
+   * partitions (i.e., do not call it on simple RDDs like MappedRDD).
    */
-  def testParentCheckpointing[U: ClassTag](
-      op: (RDD[Int]) => RDD[U],
-      testRDDSize: Boolean,
-      testRDDPartitionSize: Boolean
-    ) {
+  def testRDDPartitions[U: ClassTag](op: (RDD[Int]) => RDD[U]) {
     // Generate the final RDD using given RDD operation
-    val baseRDD = generateLongLineageRDD()
+    val baseRDD = generateFatRDD()
     val operatedRDD = op(baseRDD)
-    val parentRDD = operatedRDD.dependencies.head.rdd
+    val parentRDDs = operatedRDD.dependencies.map(_.rdd)
     val rddType = operatedRDD.getClass.getSimpleName
-    val parentRDDType = parentRDD.getClass.getSimpleName
 
-    // Get the partitions and dependencies of the parent in case they're lazily computed
-    parentRDD.dependencies
-    parentRDD.partitions
+    // Force initialization of all the data structures in RDDs
+    // Without this, serializing the RDD will give a wrong estimate of the size of the RDD
+    initializeRdd(operatedRDD)
 
     // Find serialized sizes before and after the checkpoint
-    val (rddSizeBeforeCheckpoint, splitSizeBeforeCheckpoint) = getSerializedSizes(operatedRDD)
-    parentRDD.checkpoint()  // checkpoint the parent RDD, not the generated one
-    val result = operatedRDD.collect()
-    val (rddSizeAfterCheckpoint, splitSizeAfterCheckpoint) = getSerializedSizes(operatedRDD)
+    logInfo("RDD after checkpoint: " + operatedRDD + "\n" + operatedRDD.toDebugString)
+    val (rddSizeBeforeCheckpoint, partitionSizeBeforeCheckpoint) = getSerializedSizes(operatedRDD)
+    parentRDDs.foreach(_.checkpoint())  // checkpoint the parent RDD, not the generated one
+    val result = operatedRDD.collect()  // force checkpointing
+    operatedRDD.collect() // force re-initialization of post-checkpoint lazy variables
+    val (rddSizeAfterCheckpoint, partitionSizeAfterCheckpoint) = getSerializedSizes(operatedRDD)
+    logInfo("RDD after checkpoint: " + operatedRDD + "\n" + operatedRDD.toDebugString)
 
     // Test whether the data in the checkpointed RDD is same as original
     assert(operatedRDD.collect() === result)
 
-    // Test whether serialized size of the RDD has reduced because of its parent being
-    // checkpointed. If this RDD or its parent RDD do not have any dependency
-    // to another RDD (e.g., ParallelCollection, ShuffleRDD with ShuffleDependency), it may
-    // not reduce in size after checkpointing.
-    if (testRDDSize) {
-      assert(
-        rddSizeAfterCheckpoint < rddSizeBeforeCheckpoint,
-        "Size of " + rddType + " did not reduce after checkpointing parent " + parentRDDType +
-          "[" + rddSizeBeforeCheckpoint + " --> " + rddSizeAfterCheckpoint + "]"
-      )
-    }
-
-    // Test whether serialized size of the partitions has reduced because of its parent being
-    // checkpointed. If the partitions do not have any non-transient reference to another RDD
-    // or another RDD's partitions, it does not refer to a lineage and therefore may not reduce
-    // in size after checkpointing. However, if the partitions do refer to the *partitions* of a parent
-    // RDD, then these partitions must update reference to the parent RDD partitions as the parent RDD's
-    // partitions must have changed after checkpointing.
-    if (testRDDPartitionSize) {
-      assert(
-        splitSizeAfterCheckpoint < splitSizeBeforeCheckpoint,
-        "Size of " + rddType + " partitions did not reduce after checkpointing parent " + parentRDDType +
-          "[" + splitSizeBeforeCheckpoint + " --> " + splitSizeAfterCheckpoint + "]"
-      )
-    }
-
+    // Test whether serialized size of the partitions has reduced
+    logInfo("Size of partitions of " + rddType +
+      " [" + partitionSizeBeforeCheckpoint + " --> " + partitionSizeAfterCheckpoint + "]")
+    assert(
+      partitionSizeAfterCheckpoint < partitionSizeBeforeCheckpoint,
+      "Size of " + rddType + " partitions did not reduce after checkpointing parent RDDs" +
+        " [" + partitionSizeBeforeCheckpoint + " --> " + partitionSizeAfterCheckpoint + "]"
+    )
   }
 
   /**
-   * Generate an RDD with a long lineage of one-to-one dependencies.
+   * Generate an RDD such that both the RDD and its partitions have a large serialized size.
    */
-  def generateLongLineageRDD(): RDD[Int] = {
-    var rdd = sc.makeRDD(1 to 100, 4)
-    for (i <- 1 to 50) {
-      rdd = rdd.map(x => x + 1)
-    }
-    rdd
+  def generateFatRDD(): RDD[Int] = {
+    new FatRDD(sc.makeRDD(1 to 100, 4)).map(x => x)
   }
 
   /**
-   * Generate an RDD with a long lineage specifically for CoGroupedRDD.
-   * A CoGroupedRDD can have a long lineage only one of its parents have a long lineage
-   * and narrow dependency with this RDD. This method generate such an RDD by a sequence
-   * of cogroups and mapValues which creates a long lineage of narrow dependencies.
+   * Generate a pair RDD (with a partitioner) such that both the RDD and its partitions
+   * have a large serialized size.
    */
-  def generateLongLineageRDDForCoGroupedRDD() = {
-    val add = (x: (Seq[Int], Seq[Int])) => (x._1 ++ x._2).reduce(_ + _)
-
-    def ones: RDD[(Int, Int)] = sc.makeRDD(1 to 2, 2).map(x => (x % 2, 1)).reduceByKey(partitioner, _ + _)
-
-    var cogrouped: RDD[(Int, (Seq[Int], Seq[Int]))] = ones.cogroup(ones)
-    for(i <- 1 to 10) {
-      cogrouped = cogrouped.mapValues(add).cogroup(ones)
-    }
-    cogrouped.mapValues(add)
+  def generateFatPairRDD() = {
+    new FatPairRDD(sc.makeRDD(1 to 100, 4), partitioner).mapValues(x => x)
   }
 
   /**
@@ -360,8 +350,26 @@ class CheckpointSuite extends FunSuite with LocalSparkContext with Logging {
    * upon checkpointing. Ignores the checkpointData field, which may grow when we checkpoint.
    */
   def getSerializedSizes(rdd: RDD[_]): (Int, Int) = {
-    (Utils.serialize(rdd).length - Utils.serialize(rdd.checkpointData).length,
-     Utils.serialize(rdd.partitions).length)
+    val rddSize = Utils.serialize(rdd).size
+    val rddCpDataSize = Utils.serialize(rdd.checkpointData).size
+    val rddPartitionSize = Utils.serialize(rdd.partitions).size
+    val rddDependenciesSize = Utils.serialize(rdd.dependencies).size
+
+    // Print detailed size, helps in debugging
+    logInfo("Serialized sizes of " + rdd +
+      ": RDD = " + rddSize +
+      ", RDD checkpoint data = " + rddCpDataSize +
+      ", RDD partitions = " + rddPartitionSize +
+      ", RDD dependencies = " + rddDependenciesSize
+    )
+    // This makes sure that serializing the RDD's checkpoint data does not
+    // serialize the whole RDD as well.
+    assert(
+      rddSize > rddCpDataSize,
+      "RDD's checkpoint data (" + rddCpDataSize + ") is equal to or larger than the " +
+        "whole RDD with checkpoint data (" + rddSize + ")"
+    )
+    (rddSize - rddCpDataSize, rddPartitionSize)
   }
 
   /**
@@ -373,8 +381,49 @@ class CheckpointSuite extends FunSuite with LocalSparkContext with Logging {
     val bytes = Utils.serialize(obj)
     Utils.deserialize[T](bytes)
   }
+
+  /**
+   * Recursively force the initialization of all the members of an RDD and its parents.
+   */
+  def initializeRdd(rdd: RDD[_]) {
+    rdd.partitions // forces evaluation of the partitions array
+    rdd.dependencies.map(_.rdd).foreach(initializeRdd(_))
+  }
 }
 
+/** RDD partition that has large serialized size. */
+class FatPartition(val partition: Partition) extends Partition {
+  val bigData = new Array[Byte](10000)
+  def index: Int = partition.index
+}
+
+/** RDD that has large serialized size. */
+class FatRDD(parent: RDD[Int]) extends RDD[Int](parent) {
+  val bigData = new Array[Byte](100000)
+
+  protected def getPartitions: Array[Partition] = {
+    parent.partitions.map(p => new FatPartition(p))
+  }
+
+  def compute(split: Partition, context: TaskContext): Iterator[Int] = {
+    parent.compute(split.asInstanceOf[FatPartition].partition, context)
+  }
+}
+
+/** Pair RDD that has large serialized size. */
+class FatPairRDD(parent: RDD[Int], _partitioner: Partitioner) extends RDD[(Int, Int)](parent) {
+  val bigData = new Array[Byte](100000)
+
+  protected def getPartitions: Array[Partition] = {
+    parent.partitions.map(p => new FatPartition(p))
+  }
+
+  @transient override val partitioner = Some(_partitioner)
+
+  def compute(split: Partition, context: TaskContext): Iterator[(Int, Int)] = {
+    parent.compute(split.asInstanceOf[FatPartition].partition, context).map(x => (x, x))
+  }
+}
 
 object CheckpointSuite {
   // This is a custom cogroup function that does not use mapValues like
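
For reference, the behaviour that testRDD asserts (dependencies and partitions replaced after the first action following checkpoint()) can be reproduced by hand. A minimal sketch, assuming a local SparkContext and a writable checkpoint directory (the path is hypothetical):

```scala
import org.apache.spark.SparkContext

val sc = new SparkContext("local", "checkpoint-example")
sc.setCheckpointDir("/tmp/spark-checkpoint")

val rdd = sc.makeRDD(1 to 100, 4).map(_ + 1).map(_ + 1)
rdd.checkpoint()            // only marks the RDD; nothing is written yet
println(rdd.toDebugString)  // long lineage of MappedRDDs

rdd.count()                 // first action materializes the checkpoint
// The lineage is now truncated: the dependency points at a CheckpointRDD
// reading the saved files, and rdd.partitions have been replaced.
println(rdd.toDebugString)
sc.stop()
```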
diff --git a/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala b/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala
index 1383359f85db0ce77663a8a646c374537e0f3c8a..559ea051d35335de1dfc53a8ce8a973e86b46d1e 100644
--- a/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala
+++ b/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala
@@ -84,6 +84,33 @@ class RDDSuite extends FunSuite with SharedSparkContext {
     assert(sc.union(Seq(nums, nums)).collect().toList === List(1, 2, 3, 4, 1, 2, 3, 4))
   }
 
+  test("partitioner aware union") {
+    import SparkContext._
+    def makeRDDWithPartitioner(seq: Seq[Int]) = {
+      sc.makeRDD(seq, 1)
+        .map(x => (x, null))
+        .partitionBy(new HashPartitioner(2))
+        .mapPartitions(_.map(_._1), true)
+    }
+
+    val nums1 = makeRDDWithPartitioner(1 to 4)
+    val nums2 = makeRDDWithPartitioner(5 to 8)
+    assert(nums1.partitioner == nums2.partitioner)
+    assert(new PartitionerAwareUnionRDD(sc, Seq(nums1)).collect().toSet === Set(1, 2, 3, 4))
+
+    val union = new PartitionerAwareUnionRDD(sc, Seq(nums1, nums2))
+    assert(union.collect().toSet === Set(1, 2, 3, 4, 5, 6, 7, 8))
+    val nums1Parts = nums1.collectPartitions()
+    val nums2Parts = nums2.collectPartitions()
+    val unionParts = union.collectPartitions()
+    assert(nums1Parts.length === 2)
+    assert(nums2Parts.length === 2)
+    assert(unionParts.length === 2)
+    assert((nums1Parts(0) ++ nums2Parts(0)).toList === unionParts(0).toList)
+    assert((nums1Parts(1) ++ nums2Parts(1)).toList === unionParts(1).toList)
+    assert(union.partitioner === nums1.partitioner)
+  }
+
   test("aggregate") {
     val pairs = sc.makeRDD(Array(("a", 1), ("b", 2), ("a", 2), ("c", 5), ("a", 3)))
     type StringMap = HashMap[String, Int]
diff --git a/streaming/src/main/scala/org/apache/spark/streaming/PairDStreamFunctions.scala b/streaming/src/main/scala/org/apache/spark/streaming/PairDStreamFunctions.scala
index 80af96c060a14b994f8b8ae9873512df77dd94c6..56dbcbda23a7029bda765437ce235430136fa4df 100644
--- a/streaming/src/main/scala/org/apache/spark/streaming/PairDStreamFunctions.scala
+++ b/streaming/src/main/scala/org/apache/spark/streaming/PairDStreamFunctions.scala
@@ -108,8 +108,9 @@ extends Serializable {
     createCombiner: V => C,
     mergeValue: (C, V) => C,
     mergeCombiner: (C, C) => C,
-    partitioner: Partitioner) : DStream[(K, C)] = {
-    new ShuffledDStream[K, V, C](self, createCombiner, mergeValue, mergeCombiner, partitioner)
+    partitioner: Partitioner,
+    mapSideCombine: Boolean = true): DStream[(K, C)] = {
+    new ShuffledDStream[K, V, C](
+      self, createCombiner, mergeValue, mergeCombiner, partitioner, mapSideCombine)
   }
 
   /**
@@ -173,7 +174,13 @@ extends Serializable {
       slideDuration: Duration,
       partitioner: Partitioner
     ): DStream[(K, Seq[V])] = {
-    self.window(windowDuration, slideDuration).groupByKey(partitioner)
+    val createCombiner = (v: Seq[V]) => new ArrayBuffer[V] ++= v
+    val mergeValue = (buf: ArrayBuffer[V], v: Seq[V]) => buf ++= v
+    val mergeCombiner = (buf1: ArrayBuffer[V], buf2: ArrayBuffer[V]) => buf1 ++= buf2
+    self.groupByKey(partitioner)
+        .window(windowDuration, slideDuration)
+        .combineByKey[ArrayBuffer[V]](createCombiner, mergeValue, mergeCombiner, partitioner)
+        .asInstanceOf[DStream[(K, Seq[V])]]
   }
 
   /**
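
The groupByKeyAndWindow rewrite above moves the per-key grouping in front of the window: each batch is grouped once with the target partitioner, so every RDD entering the window carries that partitioner and WindowedDStream (changed below) can union them without a shuffle; the window-level combineByKey then merely concatenates the per-key buffers. A minimal usage sketch, assuming a local StreamingContext and a hypothetical socket source:

```scala
import org.apache.spark.HashPartitioner
import org.apache.spark.streaming.{Seconds, StreamingContext}
import org.apache.spark.streaming.StreamingContext._

val ssc = new StreamingContext("local[2]", "window-example", Seconds(1))
val words = ssc.socketTextStream("localhost", 9999).flatMap(_.split(" "))
val pairs = words.map(word => (word, 1))

// Each 1s batch is grouped with the partitioner once; the 30s window then
// unions 30 co-partitioned RDDs via PartitionerAwareUnionRDD and only
// concatenates the per-key buffers, instead of re-grouping the whole window.
val grouped = pairs.groupByKeyAndWindow(Seconds(30), Seconds(1),
  new HashPartitioner(4))
grouped.print()
ssc.start()
```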
diff --git a/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaPairDStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaPairDStream.scala
index dfd6e27c3e9101dbe125c26b0e5a4b99469c4a27..6c3467d4056d721a90e0efc55e903862af12526e 100644
--- a/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaPairDStream.scala
+++ b/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaPairDStream.scala
@@ -155,7 +155,7 @@ class JavaPairDStream[K, V](val dstream: DStream[(K, V)])(
 
   /**
    * Combine elements of each key in DStream's RDDs using custom function. This is similar to the
-   * combineByKey for RDDs. Please refer to combineByKey in [[org.apache.spark.PairRDDFunctions]] for more
+   * combineByKey for RDDs. Please refer to combineByKey in
+   * [[org.apache.spark.rdd.PairRDDFunctions]] for more
    * information.
    */
   def combineByKey[C](createCombiner: JFunction[V, C],
@@ -168,6 +168,22 @@ class JavaPairDStream[K, V](val dstream: DStream[(K, V)])(
     dstream.combineByKey(createCombiner, mergeValue, mergeCombiners, partitioner)
   }
 
+  /**
+   * Combine elements of each key in DStream's RDDs using custom function. This is similar to the
+   * combineByKey for RDDs. Please refer to combineByKey in
+   * [[org.apache.spark.rdd.PairRDDFunctions]] for more information.
+   */
+  def combineByKey[C](createCombiner: JFunction[V, C],
+      mergeValue: JFunction2[C, V, C],
+      mergeCombiners: JFunction2[C, C, C],
+      partitioner: Partitioner,
+      mapSideCombine: Boolean
+    ): JavaPairDStream[K, C] = {
+    implicit val cm: ClassTag[C] =
+      implicitly[ClassTag[AnyRef]].asInstanceOf[ClassTag[C]]
+    dstream.combineByKey(createCombiner, mergeValue, mergeCombiners, partitioner, mapSideCombine)
+  }
+
   /**
    * Return a new DStream by applying `groupByKey` over a sliding window. This is similar to
    * `DStream.groupByKey()` but applies it over a sliding window. The new DStream generates RDDs
diff --git a/streaming/src/main/scala/org/apache/spark/streaming/dstream/ShuffledDStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/dstream/ShuffledDStream.scala
index e6e00220979981c2f5a5db69d8254e26add226dc..84e69f277b22e97d2ac8303bc5d784faa8b96b06 100644
--- a/streaming/src/main/scala/org/apache/spark/streaming/dstream/ShuffledDStream.scala
+++ b/streaming/src/main/scala/org/apache/spark/streaming/dstream/ShuffledDStream.scala
@@ -29,8 +29,9 @@ class ShuffledDStream[K: ClassTag, V: ClassTag, C: ClassTag](
     createCombiner: V => C,
     mergeValue: (C, V) => C,
     mergeCombiner: (C, C) => C,
-    partitioner: Partitioner
-  ) extends DStream [(K,C)] (parent.ssc) {
+    partitioner: Partitioner,
+    mapSideCombine: Boolean = true
+  ) extends DStream[(K, C)](parent.ssc) {
 
   override def dependencies = List(parent)
 
@@ -38,8 +39,8 @@ class ShuffledDStream[K: ClassTag, V: ClassTag, C: ClassTag](
 
   override def compute(validTime: Time): Option[RDD[(K,C)]] = {
     parent.getOrCompute(validTime) match {
-      case Some(rdd) =>
-        Some(rdd.combineByKey[C](createCombiner, mergeValue, mergeCombiner, partitioner))
+      case Some(rdd) => Some(rdd.combineByKey[C](
+          createCombiner, mergeValue, mergeCombiner, partitioner, mapSideCombine))
       case None => None
     }
   }
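
The flag added here is threaded straight through to RDD.combineByKey. A small RDD-level sketch of what it controls, assuming a SparkContext named sc: for grouping-style combiners that merely accumulate values into buffers, map-side combining shrinks nothing before the shuffle, which is presumably why the DStream API now exposes the flag:

```scala
import scala.collection.mutable.ArrayBuffer
import org.apache.spark.HashPartitioner
import org.apache.spark.SparkContext._

val pairs = sc.makeRDD(Seq(("a", 1), ("a", 2), ("b", 3)))
val grouped = pairs.combineByKey[ArrayBuffer[Int]](
  (v: Int) => ArrayBuffer(v),                                // createCombiner
  (buf: ArrayBuffer[Int], v: Int) => buf += v,               // mergeValue
  (b1: ArrayBuffer[Int], b2: ArrayBuffer[Int]) => b1 ++= b2, // mergeCombiners
  new HashPartitioner(2),
  mapSideCombine = false)  // skip the useless map-side pass when only grouping
```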
diff --git a/streaming/src/main/scala/org/apache/spark/streaming/dstream/WindowedDStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/dstream/WindowedDStream.scala
index 73d959331a3c24974cfe8c42e809783cf56eed28..89c43ff935bb415183abe74e943ed953e4531406 100644
--- a/streaming/src/main/scala/org/apache/spark/streaming/dstream/WindowedDStream.scala
+++ b/streaming/src/main/scala/org/apache/spark/streaming/dstream/WindowedDStream.scala
@@ -17,10 +17,10 @@
 
 package org.apache.spark.streaming.dstream
 
-import org.apache.spark.rdd.RDD
-import org.apache.spark.rdd.UnionRDD
+import org.apache.spark.rdd.{PartitionerAwareUnionRDD, RDD, UnionRDD}
 import org.apache.spark.storage.StorageLevel
-import org.apache.spark.streaming.{Duration, Interval, Time, DStream}
+import org.apache.spark.streaming._
 
 import scala.reflect.ClassTag
 
@@ -51,6 +51,14 @@ class WindowedDStream[T: ClassTag](
 
   override def compute(validTime: Time): Option[RDD[T]] = {
     val currentWindow = new Interval(validTime - windowDuration + parent.slideDuration, validTime)
-    Some(new UnionRDD(ssc.sc, parent.slice(currentWindow)))
+    val rddsInWindow = parent.slice(currentWindow)
+    val windowRDD = if (rddsInWindow.forall(_.partitioner.isDefined) &&
+        rddsInWindow.flatMap(_.partitioner).distinct.length == 1) {
+      logDebug("Using partitioner-aware union for windowing at " + validTime)
+      new PartitionerAwareUnionRDD(ssc.sc, rddsInWindow)
+    } else {
+      logDebug("Using normal union for windowing at " + validTime)
+      new UnionRDD(ssc.sc, rddsInWindow)
+    }
+    Some(windowRDD)
   }
 }
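
The branch above keys off the partitioners of the RDDs in the window. A sketch of the guard, assuming a SparkContext named sc: flatMap silently drops RDDs whose partitioner is None, which is why the condition also requires forall(_.partitioner.isDefined) before taking the partitioner-aware path:

```scala
import org.apache.spark.HashPartitioner
import org.apache.spark.SparkContext._

val part = new HashPartitioner(2)
val a = sc.makeRDD(1 to 4).map(x => (x, x)).partitionBy(part)
val b = sc.makeRDD(5 to 8).map(x => (x, x)).partitionBy(part)
val c = sc.makeRDD(9 to 12).map(x => (x, x))  // partitioner == None

// Co-partitioned window: the partitioner-aware union is used.
assert(Seq(a, b).forall(_.partitioner.isDefined) &&
  Seq(a, b).flatMap(_.partitioner).distinct.length == 1)

// Mixed window: flatMap alone would still see one distinct partitioner,
// so it is the forall check that routes this case to the plain UnionRDD.
assert(Seq(a, b, c).flatMap(_.partitioner).distinct.length == 1)
assert(!Seq(a, b, c).forall(_.partitioner.isDefined))
```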
diff --git a/streaming/src/test/scala/org/apache/spark/streaming/WindowOperationsSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/WindowOperationsSuite.scala
index c92c34d49bd367da19f36cc0f16f64bb38970436..c39abfc21b3ba219a08c6bf6e514db052dbcd7f0 100644
--- a/streaming/src/test/scala/org/apache/spark/streaming/WindowOperationsSuite.scala
+++ b/streaming/src/test/scala/org/apache/spark/streaming/WindowOperationsSuite.scala
@@ -224,9 +224,7 @@ class WindowOperationsSuite extends TestSuiteBase {
     val slideDuration = Seconds(1)
     val numBatches = expectedOutput.size * (slideDuration / batchDuration).toInt
     val operation = (s: DStream[(String, Int)]) => {
-      s.groupByKeyAndWindow(windowDuration, slideDuration)
-       .map(x => (x._1, x._2.toSet))
-       .persist()
+      s.groupByKeyAndWindow(windowDuration, slideDuration).map(x => (x._1, x._2.toSet))
     }
     testOperation(input, operation, expectedOutput, numBatches, true)
   }