diff --git a/core/src/main/scala/spark/RDD.scala b/core/src/main/scala/spark/RDD.scala
index 338dff40618cc8469f693e092439f88e7913e8b2..4ffec433a823460866a9200b81dac93821e8595d 100644
--- a/core/src/main/scala/spark/RDD.scala
+++ b/core/src/main/scala/spark/RDD.scala
@@ -1,17 +1,17 @@
 package spark
 
 import java.io.EOFException
-import java.net.URL
 import java.io.ObjectInputStream
-import java.util.concurrent.atomic.AtomicLong
+import java.net.URL
 import java.util.Random
 import java.util.Date
 import java.util.{HashMap => JHashMap}
+import java.util.concurrent.atomic.AtomicLong
 
-import scala.collection.mutable.ArrayBuffer
 import scala.collection.Map
-import scala.collection.mutable.HashMap
 import scala.collection.JavaConversions.mapAsScalaMap
+import scala.collection.mutable.ArrayBuffer
+import scala.collection.mutable.HashMap
 
 import org.apache.hadoop.io.BytesWritable
 import org.apache.hadoop.io.NullWritable
@@ -47,7 +47,7 @@ import spark.storage.StorageLevel
 import SparkContext._
 
 /**
- * A Resilient Distributed Dataset (RDD), the basic abstraction in Spark. Represents an immutable, 
+ * A Resilient Distributed Dataset (RDD), the basic abstraction in Spark. Represents an immutable,
  * partitioned collection of elements that can be operated on in parallel. This class contains the
  * basic operations available on all RDDs, such as `map`, `filter`, and `persist`. In addition,
  * [[spark.PairRDDFunctions]] contains operations available only on RDDs of key-value pairs, such
@@ -86,28 +86,28 @@ abstract class RDD[T: ClassManifest](@transient sc: SparkContext) extends Serial
   @transient val dependencies: List[Dependency[_]]
 
   // Methods available on all RDDs:
-  
+
   /** Record user function generating this RDD. */
   private[spark] val origin = Utils.getSparkCallSite
-  
+
   /** Optionally overridden by subclasses to specify how they are partitioned. */
   val partitioner: Option[Partitioner] = None
 
   /** Optionally overridden by subclasses to specify placement preferences. */
   def preferredLocations(split: Split): Seq[String] = Nil
-  
+
   /** The [[spark.SparkContext]] that this RDD was created on. */
   def context = sc
 
   private[spark] def elementClassManifest: ClassManifest[T] = classManifest[T]
-  
+
   /** A unique ID for this RDD (within its SparkContext). */
   val id = sc.newRddId()
-  
+
   // Variables relating to persistence
   private var storageLevel: StorageLevel = StorageLevel.NONE
-  
-  /** 
+
+  /**
    * Set this RDD's storage level to persist its values across operations after the first time
    * it is computed. Can only be called once on each RDD.
    */
@@ -123,32 +123,32 @@ abstract class RDD[T: ClassManifest](@transient sc: SparkContext) extends Serial
 
   /** Persist this RDD with the default storage level (`MEMORY_ONLY`). */
   def persist(): RDD[T] = persist(StorageLevel.MEMORY_ONLY)
-  
+
   /** Persist this RDD with the default storage level (`MEMORY_ONLY`). */
   def cache(): RDD[T] = persist()
 
   /** Get the RDD's current storage level, or StorageLevel.NONE if none is set. */
   def getStorageLevel = storageLevel
-  
+
   private[spark] def checkpoint(level: StorageLevel = StorageLevel.MEMORY_AND_DISK_2): RDD[T] = {
     if (!level.useDisk && level.replication < 2) {
       throw new Exception("Cannot checkpoint without using disk or replication (level requested was " + level + ")")
-    } 
-    
+    }
+
     // This is a hack. Ideally this should re-use the code used by the CacheTracker
     // to generate the key.
     def getSplitKey(split: Split) = "rdd_%d_%d".format(this.id, split.index)
-    
+
     persist(level)
     sc.runJob(this, (iter: Iterator[T]) => {} )
-    
+
     val p = this.partitioner
-    
+
     new BlockRDD[T](sc, splits.map(getSplitKey).toArray) {
-      override val partitioner = p 
+      override val partitioner = p
     }
   }
-  
+
   /**
    * Internal method to this RDD; will read from cache if applicable, or otherwise compute it.
    * This should ''not'' be called by users directly, but is available for implementors of custom
@@ -161,9 +161,9 @@ abstract class RDD[T: ClassManifest](@transient sc: SparkContext) extends Serial
       compute(split)
     }
   }
-  
+
   // Transformations (return a new RDD)
-  
+
   /**
    * Return a new RDD by applying a function to all elements of this RDD.
    */
@@ -199,13 +199,13 @@ abstract class RDD[T: ClassManifest](@transient sc: SparkContext) extends Serial
     var multiplier = 3.0
     var initialCount = count()
     var maxSelected = 0
-    
+
     if (initialCount > Integer.MAX_VALUE - 1) {
       maxSelected = Integer.MAX_VALUE - 1
     } else {
       maxSelected = initialCount.toInt
     }
-    
+
     if (num > initialCount) {
       total = maxSelected
       fraction = math.min(multiplier * (maxSelected + 1) / initialCount, 1.0)
@@ -215,14 +215,14 @@ abstract class RDD[T: ClassManifest](@transient sc: SparkContext) extends Serial
       fraction = math.min(multiplier * (num + 1) / initialCount, 1.0)
       total = num
     }
-  
+
     val rand = new Random(seed)
     var samples = this.sample(withReplacement, fraction, rand.nextInt).collect()
-  
+
     while (samples.length < total) {
       samples = this.sample(withReplacement, fraction, rand.nextInt).collect()
     }
-  
+
     Utils.randomizeInPlace(samples, rand).take(total)
   }
 
@@ -290,8 +290,10 @@ abstract class RDD[T: ClassManifest](@transient sc: SparkContext) extends Serial
    * Return a new RDD by applying a function to each partition of this RDD, while tracking the index
    * of the original partition.
    */
-  def mapPartitionsWithSplit[U: ClassManifest](f: (Int, Iterator[T]) => Iterator[U]): RDD[U] =
-    new MapPartitionsWithSplitRDD(this, sc.clean(f))
+  def mapPartitionsWithSplit[U: ClassManifest](
+    f: (Int, Iterator[T]) => Iterator[U],
+    preservesPartitioning: Boolean = false): RDD[U] =
+    new MapPartitionsWithSplitRDD(this, sc.clean(f), preservesPartitioning)
 
   // Actions (launch a job to return a value to the user program)
 
@@ -342,7 +344,7 @@ abstract class RDD[T: ClassManifest](@transient sc: SparkContext) extends Serial
 
   /**
    * Aggregate the elements of each partition, and then the results for all the partitions, using a
-   * given associative function and a neutral "zero value". The function op(t1, t2) is allowed to 
+   * given associative function and a neutral "zero value". The function op(t1, t2) is allowed to
    * modify t1 and return it as its result value to avoid object allocation; however, it should not
    * modify t2.
    */
@@ -443,7 +445,7 @@ abstract class RDD[T: ClassManifest](@transient sc: SparkContext) extends Serial
     val evaluator = new GroupedCountEvaluator[T](splits.size, confidence)
     sc.runApproximateJob(this, countPartition, evaluator, timeout)
   }
-  
+
   /**
    * Take the first num elements of the RDD. This currently scans the partitions *one by one*, so
    * it will be slow if a lot of partitions are required. In that case, use collect() to get the
diff --git a/core/src/main/scala/spark/rdd/MapPartitionsWithSplitRDD.scala b/core/src/main/scala/spark/rdd/MapPartitionsWithSplitRDD.scala
index adc541694e73ba1ef887d61439fd6dc35b2ef350..14e390c43b9cf3922b3678bcc67ab645cb9d3798 100644
--- a/core/src/main/scala/spark/rdd/MapPartitionsWithSplitRDD.scala
+++ b/core/src/main/scala/spark/rdd/MapPartitionsWithSplitRDD.scala
@@ -12,9 +12,11 @@ import spark.Split
 private[spark]
 class MapPartitionsWithSplitRDD[U: ClassManifest, T: ClassManifest](
     prev: RDD[T],
-    f: (Int, Iterator[T]) => Iterator[U])
+    f: (Int, Iterator[T]) => Iterator[U],
+    preservesPartitioning: Boolean)
   extends RDD[U](prev.context) {
 
+  override val partitioner = if (preservesPartitioning) prev.partitioner else None
   override def splits = prev.splits
   override val dependencies = List(new OneToOneDependency(prev))
   override def compute(split: Split) = f(split.index, prev.iterator(split))