Commit 14360aba authored by Matei Zaharia

Merge pull request #390 from JoshRosen/spark-654

Fix PythonPartitioner equality
parents fe85a075 9f211dd3
@@ -6,8 +6,17 @@ import java.util.Arrays
 /**
  * A [[spark.Partitioner]] that performs handling of byte arrays, for use by the Python API.
+ *
+ * Stores the unique id() of the Python-side partitioning function so that it is incorporated into
+ * equality comparisons. Correctness requires that the id is a unique identifier for the
+ * lifetime of the job (i.e. that it is not re-used as the id of a different partitioning
+ * function). This can be ensured by using the Python id() function and maintaining a reference
+ * to the Python partitioning function so that its id() is not reused.
  */
-private[spark] class PythonPartitioner(override val numPartitions: Int) extends Partitioner {
+private[spark] class PythonPartitioner(
+    override val numPartitions: Int,
+    val pyPartitionFunctionId: Long)
+  extends Partitioner {
   override def getPartition(key: Any): Int = {
     if (key == null) {
@@ -32,7 +41,7 @@ private[spark] class PythonPartitioner(override val numPartitions: Int) extends
   override def equals(other: Any): Boolean = other match {
     case h: PythonPartitioner =>
-      h.numPartitions == numPartitions
+      h.numPartitions == numPartitions && h.pyPartitionFunctionId == pyPartitionFunctionId
     case _ =>
       false
   }
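The Scaladoc added above relies on a CPython detail worth spelling out: id() is only unique among objects that are alive at the same time, so the Python side must keep the partitioning function reachable for as long as its id is used in equality checks. A minimal sketch of the hazard, purely illustrative and not part of this patch (whether the id is actually recycled depends on CPython's allocator):

def make_part_func():
    # each call builds a fresh function object; nothing here is PySpark code
    return lambda k: hash(k) % 2

f = make_part_func()
old_id = id(f)
del f                    # the lambda is now garbage and may be freed
g = make_part_func()     # a new lambda can land at the freed address,
print(id(g) == old_id)   # so id(g) may equal old_id (often True in CPython)

Holding the function in rdd._partitionFunc, as the rdd.py change below does, keeps the object alive so its id() cannot be handed to a different partitioning function during the job.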
@@ -252,11 +252,6 @@ private object Pickle {
   val APPENDS: Byte = 'e'
 }
-private class ExtractValue extends spark.api.java.function.Function[(Array[Byte],
-  Array[Byte]), Array[Byte]] {
-  override def call(pair: (Array[Byte], Array[Byte])) : Array[Byte] = pair._2
-}
 private class BytesToString extends spark.api.java.function.Function[Array[Byte], String] {
   override def call(arr: Array[Byte]) : String = new String(arr, "UTF-8")
 }
@@ -33,6 +33,7 @@ class RDD(object):
         self._jrdd = jrdd
         self.is_cached = False
         self.ctx = ctx
+        self._partitionFunc = None

     @property
     def context(self):
@@ -497,7 +498,7 @@ class RDD(object):
         return python_right_outer_join(self, other, numSplits)

     # TODO: add option to control map-side combining
-    def partitionBy(self, numSplits, hashFunc=hash):
+    def partitionBy(self, numSplits, partitionFunc=hash):
         """
         Return a copy of the RDD partitioned using the specified partitioner.
@@ -514,17 +515,21 @@ class RDD(object):
         def add_shuffle_key(split, iterator):
             buckets = defaultdict(list)
             for (k, v) in iterator:
-                buckets[hashFunc(k) % numSplits].append((k, v))
+                buckets[partitionFunc(k) % numSplits].append((k, v))
             for (split, items) in buckets.iteritems():
                 yield str(split)
                 yield dump_pickle(Batch(items))
         keyed = PipelinedRDD(self, add_shuffle_key)
         keyed._bypass_serializer = True
         pairRDD = self.ctx.jvm.PairwiseRDD(keyed._jrdd.rdd()).asJavaPairRDD()
-        partitioner = self.ctx.jvm.spark.api.python.PythonPartitioner(numSplits)
-        jrdd = pairRDD.partitionBy(partitioner)
-        jrdd = jrdd.map(self.ctx.jvm.ExtractValue())
-        return RDD(jrdd, self.ctx)
+        partitioner = self.ctx.jvm.PythonPartitioner(numSplits,
+                                                     id(partitionFunc))
+        jrdd = pairRDD.partitionBy(partitioner).values()
+        rdd = RDD(jrdd, self.ctx)
+        # This is required so that id(partitionFunc) remains unique, even if
+        # partitionFunc is a lambda:
+        rdd._partitionFunc = partitionFunc
+        return rdd

     # TODO: add control over map-side aggregation
     def combineByKey(self, createCombiner, mergeValue, mergeCombiners,
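With the function id folded into equals(), two partitionBy() calls now produce equal PythonPartitioners only when both the number of splits and the Python partitioning function match. A rough usage sketch of the new Python-side behavior (assumes an existing SparkContext bound to sc; data and split counts are made up):

pairs = sc.parallelize([("a", 1), ("bb", 2), ("a", 3)])

by_hash = pairs.partitionBy(4)                    # partitionFunc defaults to hash
by_len = pairs.partitionBy(4, lambda k: len(k))   # same numSplits, different function

# Spark uses partitioner equality to decide whether data is already
# partitioned the way an operation expects. Before this fix the JVM-side
# partitioners behind by_hash and by_len compared equal (only numPartitions
# was checked), even though their keys are bucketed by different functions.
# With id(partitionFunc) in the comparison they are now unequal, while
# repeated partitionBy(4) calls with the same function still yield equal
# partitioners, since id(hash) is stable for the long-lived builtin.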