diff --git a/core/src/main/scala/spark/api/python/PythonPartitioner.scala b/core/src/main/scala/spark/api/python/PythonPartitioner.scala
index 648d9402b07951b654a06238ef0cf2763bc577aa..519e31032304e16522319ae103ab417aa05348f4 100644
--- a/core/src/main/scala/spark/api/python/PythonPartitioner.scala
+++ b/core/src/main/scala/spark/api/python/PythonPartitioner.scala
@@ -6,8 +6,17 @@ import java.util.Arrays
 
 /**
  * A [[spark.Partitioner]] that performs handling of byte arrays, for use by the Python API.
+ *
+ * Stores the unique id() of the Python-side partitioning function so that it is incorporated into
+ * equality comparisons.  Correctness requires that the id be a unique identifier for the
+ * lifetime of the job (i.e. that it is not re-used as the id of a different partitioning
+ * function).  This can be ensured by using the Python id() function and maintaining a reference
+ * to the Python partitioning function so that its id() is not reused.
  */
-private[spark] class PythonPartitioner(override val numPartitions: Int) extends Partitioner {
+private[spark] class PythonPartitioner(
+  override val numPartitions: Int,
+  val pyPartitionFunctionId: Long)
+  extends Partitioner {
 
   override def getPartition(key: Any): Int = {
     if (key == null) {
@@ -32,7 +41,7 @@ private[spark] class PythonPartitioner(override val numPartitions: Int) extends
 
   override def equals(other: Any): Boolean = other match {
     case h: PythonPartitioner =>
-      h.numPartitions == numPartitions
+      h.numPartitions == numPartitions && h.pyPartitionFunctionId == pyPartitionFunctionId
     case _ =>
       false
   }
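
The equality change above matters because two RDDs split into the same number of partitions by different Python functions put the same key into different buckets, so their partitioners must not compare equal. A minimal PySpark sketch of that situation, assuming a local SparkContext; the lambdas, variable names, and app name are illustrative and not taken from this patch:

```python
from pyspark import SparkContext

sc = SparkContext("local", "partitioner-equality-sketch")
data = [(k, k) for k in range(4)]

# Same numPartitions, different partitioning functions: key 0 lands in
# bucket 0 on one side and bucket 1 on the other, which is why
# id(partitionFunc) now participates in PythonPartitioner.equals().
by_mod = sc.parallelize(data).partitionBy(2, lambda k: k % 2)
by_shifted = sc.parallelize(data).partitionBy(2, lambda k: (k + 1) % 2)

print(by_mod.glom().collect())      # e.g. [[(0, 0), (2, 2)], [(1, 1), (3, 3)]]
print(by_shifted.glom().collect())  # e.g. [[(1, 1), (3, 3)], [(0, 0), (2, 2)]]
```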
diff --git a/core/src/main/scala/spark/api/python/PythonRDD.scala b/core/src/main/scala/spark/api/python/PythonRDD.scala
index 89f7c316dc7798b754fc1a821a2529a8c908375b..e4c0530241556ed4d5584a20e175ec52ecf6d7ae 100644
--- a/core/src/main/scala/spark/api/python/PythonRDD.scala
+++ b/core/src/main/scala/spark/api/python/PythonRDD.scala
@@ -252,11 +252,6 @@ private object Pickle {
   val APPENDS: Byte = 'e'
 }
 
-private class ExtractValue extends spark.api.java.function.Function[(Array[Byte],
-  Array[Byte]), Array[Byte]] {
-  override def call(pair: (Array[Byte], Array[Byte])) : Array[Byte] = pair._2
-}
-
 private class BytesToString extends spark.api.java.function.Function[Array[Byte], String] {
   override def call(arr: Array[Byte]) : String = new String(arr, "UTF-8")
 }
diff --git a/python/pyspark/rdd.py b/python/pyspark/rdd.py
index d705f0f9e102909c48e53e8e08134b6db6757002..b58bf24e3e46c8c2f44b6b7c0282cb2c26696ede 100644
--- a/python/pyspark/rdd.py
+++ b/python/pyspark/rdd.py
@@ -33,6 +33,7 @@ class RDD(object):
         self._jrdd = jrdd
         self.is_cached = False
         self.ctx = ctx
+        self._partitionFunc = None
 
     @property
     def context(self):
@@ -497,7 +498,7 @@ class RDD(object):
         return python_right_outer_join(self, other, numSplits)
 
     # TODO: add option to control map-side combining
-    def partitionBy(self, numSplits, hashFunc=hash):
+    def partitionBy(self, numSplits, partitionFunc=hash):
         """
         Return a copy of the RDD partitioned using the specified partitioner.
 
@@ -514,17 +515,21 @@ class RDD(object):
         def add_shuffle_key(split, iterator):
             buckets = defaultdict(list)
             for (k, v) in iterator:
-                buckets[hashFunc(k) % numSplits].append((k, v))
+                buckets[partitionFunc(k) % numSplits].append((k, v))
             for (split, items) in buckets.iteritems():
                 yield str(split)
                 yield dump_pickle(Batch(items))
         keyed = PipelinedRDD(self, add_shuffle_key)
         keyed._bypass_serializer = True
         pairRDD = self.ctx.jvm.PairwiseRDD(keyed._jrdd.rdd()).asJavaPairRDD()
-        partitioner = self.ctx.jvm.spark.api.python.PythonPartitioner(numSplits)
-        jrdd = pairRDD.partitionBy(partitioner)
-        jrdd = jrdd.map(self.ctx.jvm.ExtractValue())
-        return RDD(jrdd, self.ctx)
+        partitioner = self.ctx.jvm.PythonPartitioner(numSplits,
+                                                     id(partitionFunc))
+        jrdd = pairRDD.partitionBy(partitioner).values()
+        rdd = RDD(jrdd, self.ctx)
+        # Keep a reference to partitionFunc (which may be a lambda) so that it
+        # is not garbage collected and its id() reused by another function:
+        rdd._partitionFunc = partitionFunc
+        return rdd
 
     # TODO: add control over map-side aggregation
     def combineByKey(self, createCombiner, mergeValue, mergeCombiners,
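
The reference kept in rdd._partitionFunc can be motivated without Spark: CPython only guarantees that id() is unique among objects that are alive at the same time, so once the partitioning function is garbage collected its id may be handed out again, and two distinct partitioners could then compare equal on the Java side. A small illustrative sketch (the lambdas are placeholders, and the id reuse in the last line is possible rather than guaranteed):

```python
# id() is only unique among simultaneously live objects.
f = lambda k: k % 2
first_id = id(f)

del f                      # drop the only reference; the lambda can be collected
g = lambda k: (k + 1) % 2  # a different partitioning function...

# CPython often (but not always) reuses the freed object's memory, and with
# it the id, so this may print True:
print(id(g) == first_id)
```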