diff --git a/core/src/main/scala/spark/api/python/PythonPartitioner.scala b/core/src/main/scala/spark/api/python/PythonPartitioner.scala
index 648d9402b07951b654a06238ef0cf2763bc577aa..519e31032304e16522319ae103ab417aa05348f4 100644
--- a/core/src/main/scala/spark/api/python/PythonPartitioner.scala
+++ b/core/src/main/scala/spark/api/python/PythonPartitioner.scala
@@ -6,8 +6,17 @@ import java.util.Arrays
 
 /**
  * A [[spark.Partitioner]] that performs handling of byte arrays, for use by the Python API.
+ *
+ * Stores the unique id() of the Python-side partitioning function so that it is incorporated into
+ * equality comparisons. Correctness requires that the id is a unique identifier for the
+ * lifetime of the job (i.e. that it is not re-used as the id of a different partitioning
+ * function). This can be ensured by using the Python id() function and maintaining a reference
+ * to the Python partitioning function so that its id() is not reused.
  */
-private[spark] class PythonPartitioner(override val numPartitions: Int) extends Partitioner {
+private[spark] class PythonPartitioner(
+  override val numPartitions: Int,
+  val pyPartitionFunctionId: Long)
+  extends Partitioner {
 
   override def getPartition(key: Any): Int = {
     if (key == null) {
@@ -32,7 +41,7 @@ private[spark] class PythonPartitioner(override val numPartitions: Int) extends
 
   override def equals(other: Any): Boolean = other match {
     case h: PythonPartitioner =>
-      h.numPartitions == numPartitions
+      h.numPartitions == numPartitions && h.pyPartitionFunctionId == pyPartitionFunctionId
     case _ =>
       false
   }
diff --git a/core/src/main/scala/spark/api/python/PythonRDD.scala b/core/src/main/scala/spark/api/python/PythonRDD.scala
index 89f7c316dc7798b754fc1a821a2529a8c908375b..e4c0530241556ed4d5584a20e175ec52ecf6d7ae 100644
--- a/core/src/main/scala/spark/api/python/PythonRDD.scala
+++ b/core/src/main/scala/spark/api/python/PythonRDD.scala
@@ -252,11 +252,6 @@ private object Pickle {
   val APPENDS: Byte = 'e'
 }
 
-private class ExtractValue extends spark.api.java.function.Function[(Array[Byte],
-  Array[Byte]), Array[Byte]] {
-  override def call(pair: (Array[Byte], Array[Byte])) : Array[Byte] = pair._2
-}
-
 private class BytesToString extends spark.api.java.function.Function[Array[Byte], String] {
   override def call(arr: Array[Byte]) : String = new String(arr, "UTF-8")
 }
diff --git a/python/pyspark/rdd.py b/python/pyspark/rdd.py
index d705f0f9e102909c48e53e8e08134b6db6757002..b58bf24e3e46c8c2f44b6b7c0282cb2c26696ede 100644
--- a/python/pyspark/rdd.py
+++ b/python/pyspark/rdd.py
@@ -33,6 +33,7 @@ class RDD(object):
         self._jrdd = jrdd
         self.is_cached = False
         self.ctx = ctx
+        self._partitionFunc = None
 
     @property
     def context(self):
@@ -497,7 +498,7 @@ class RDD(object):
         return python_right_outer_join(self, other, numSplits)
 
     # TODO: add option to control map-side combining
-    def partitionBy(self, numSplits, hashFunc=hash):
+    def partitionBy(self, numSplits, partitionFunc=hash):
         """
         Return a copy of the RDD partitioned using the specified partitioner.
 
@@ -514,17 +515,21 @@ class RDD(object):
         def add_shuffle_key(split, iterator):
             buckets = defaultdict(list)
             for (k, v) in iterator:
-                buckets[hashFunc(k) % numSplits].append((k, v))
+                buckets[partitionFunc(k) % numSplits].append((k, v))
             for (split, items) in buckets.iteritems():
                 yield str(split)
                 yield dump_pickle(Batch(items))
         keyed = PipelinedRDD(self, add_shuffle_key)
         keyed._bypass_serializer = True
         pairRDD = self.ctx.jvm.PairwiseRDD(keyed._jrdd.rdd()).asJavaPairRDD()
-        partitioner = self.ctx.jvm.spark.api.python.PythonPartitioner(numSplits)
-        jrdd = pairRDD.partitionBy(partitioner)
-        jrdd = jrdd.map(self.ctx.jvm.ExtractValue())
-        return RDD(jrdd, self.ctx)
+        partitioner = self.ctx.jvm.PythonPartitioner(numSplits,
+                                                     id(partitionFunc))
+        jrdd = pairRDD.partitionBy(partitioner).values()
+        rdd = RDD(jrdd, self.ctx)
+        # This is required so that id(partitionFunc) remains unique, even if
+        # partitionFunc is a lambda:
+        rdd._partitionFunc = partitionFunc
+        return rdd
 
     # TODO: add control over map-side aggregation
     def combineByKey(self, createCombiner, mergeValue, mergeCombiners,
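
Note (not part of the patch): a minimal usage sketch of the new behaviour, assuming a
running SparkContext named sc; the function name by_mod and the sample data are
illustrative only. PythonPartitioner equality now compares both numPartitions and
id(partitionFunc), and equal partitioners are how Spark recognises data that is already
co-partitioned. Two RDDs partitioned with the same Python function object therefore get
JVM-side partitioners that compare equal, while two distinct lambdas (even with identical
bodies) have different ids and compare unequal. The reference kept in rdd._partitionFunc
keeps the function object alive so that its id() cannot be reused by another object.

    def by_mod(k):
        # Partitioning function; its id() becomes pyPartitionFunctionId on the JVM side.
        return k % 2

    a = sc.parallelize([(1, 'a'), (2, 'b'), (3, 'c')]).partitionBy(4, by_mod)
    b = sc.parallelize([(1, 'x'), (4, 'y')]).partitionBy(4, by_mod)
    # Same function object and same numSplits -> equal PythonPartitioners.

    c = sc.parallelize([(1, 'x')]).partitionBy(4, lambda k: k % 2)
    d = sc.parallelize([(2, 'y')]).partitionBy(4, lambda k: k % 2)
    # Two separate lambda objects -> different id() values -> unequal partitioners,
    # even though the functions compute the same result.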