Commit 08641932 authored by Matei Zaharia

Merge pull request #33 from AndreSchumacher/pyspark_partition_key_change

Fixing SPARK-602: PythonPartitioner

Currently PythonPartitioner determines partition ID by hashing a
byte-array representation of PySpark's key. This PR lets
PythonPartitioner use the actual partition ID, which is required e.g.
for sorting via PySpark.
parents 232765f7 c84946fe
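For orientation, a minimal, self-contained sketch of the mechanism this change relies on (the object name and the locally reproduced nonNegativeMod helper are illustrative, not part of the patch): the Python worker now ships the target partition ID as an 8-byte big-endian long, and the JVM side decodes that long and only applies a defensive modulo before using it as the partition.

import java.nio.{ByteBuffer, ByteOrder}

object PartitionIdRoundTrip {
  // Stand-in for Utils.nonNegativeMod, reproduced so the sketch compiles on its own.
  def nonNegativeMod(x: Int, mod: Int): Int = {
    val raw = x % mod
    if (raw < 0) raw + mod else raw
  }

  def main(args: Array[String]): Unit = {
    val numPartitions = 4
    // What pack_long produces on the Python side: the split number as 8 big-endian bytes.
    val fromPython: Array[Byte] =
      ByteBuffer.allocate(8).order(ByteOrder.BIG_ENDIAN).putLong(3L).array()
    // What deserializeLongValue recovers on the JVM side.
    val partitionId = ByteBuffer.wrap(fromPython).getLong
    // PythonPartitioner then only guards against out-of-range values with a modulo.
    println(nonNegativeMod(partitionId.toInt, numPartitions)) // prints 3
  }
}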
@@ -17,12 +17,13 @@
package org.apache.spark.api.python
import org.apache.spark.Partitioner
import java.util.Arrays
import org.apache.spark.Partitioner
import org.apache.spark.util.Utils
/**
* A [[org.apache.spark.Partitioner]] that performs handling of byte arrays, for use by the Python API.
* A [[org.apache.spark.Partitioner]] that performs handling of long-valued keys, for use by the Python API.
*
* Stores the unique id() of the Python-side partitioning function so that it is incorporated into
* equality comparisons. Correctness requires that the id is a unique identifier for the
@@ -30,6 +31,7 @@ import org.apache.spark.util.Utils
* function). This can be ensured by using the Python id() function and maintaining a reference
* to the Python partitioning function so that its id() is not reused.
*/
private[spark] class PythonPartitioner(
override val numPartitions: Int,
val pyPartitionFunctionId: Long)
@@ -37,7 +39,9 @@ private[spark] class PythonPartitioner(
override def getPartition(key: Any): Int = key match {
case null => 0
case key: Array[Byte] => Utils.nonNegativeMod(Arrays.hashCode(key), numPartitions)
// we don't trust the Python partition function to return valid partition IDs, so
// let's do a modulo numPartitions in any case
case key: Long => Utils.nonNegativeMod(key.toInt, numPartitions)
case _ => Utils.nonNegativeMod(key.hashCode(), numPartitions)
}
......
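The class comment above hinges on partitioner equality: Spark can skip a shuffle when two RDDs report equal partitioners, so two PythonPartitioners with the same numPartitions but different Python partition functions must not compare equal. A simplified stand-in (a plain case class with made-up id() values, not the real implementation) illustrates that intent:

// Illustrative only: simplified stand-in for PythonPartitioner's equality semantics.
case class SimplePythonPartitioner(numPartitions: Int, pyPartitionFunctionId: Long)

object PartitionerEqualityDemo {
  def main(args: Array[String]): Unit = {
    val a = SimplePythonPartitioner(8, pyPartitionFunctionId = 140230001L) // id() of one Python function
    val b = SimplePythonPartitioner(8, pyPartitionFunctionId = 140230002L) // id() of a different function
    val c = SimplePythonPartitioner(8, pyPartitionFunctionId = 140230001L) // same function as a
    println(a == b) // false: same partition count, different functions, so co-partitioning can't be assumed
    println(a == c) // true: keys were routed by the same function, so a re-shuffle can be avoided
  }
}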
@@ -187,14 +187,14 @@ private class PythonException(msg: String) extends Exception(msg)
* This is used by PySpark's shuffle operations.
*/
private class PairwiseRDD(prev: RDD[Array[Byte]]) extends
RDD[(Array[Byte], Array[Byte])](prev) {
RDD[(Long, Array[Byte])](prev) {
override def getPartitions = prev.partitions
override def compute(split: Partition, context: TaskContext) =
prev.iterator(split, context).grouped(2).map {
case Seq(a, b) => (a, b)
case Seq(a, b) => (Utils.deserializeLongValue(a), b)
case x => throw new SparkException("PairwiseRDD: unexpected value: " + x)
}
val asJavaPairRDD : JavaPairRDD[Array[Byte], Array[Byte]] = JavaPairRDD.fromRDD(this)
val asJavaPairRDD : JavaPairRDD[Long, Array[Byte]] = JavaPairRDD.fromRDD(this)
}
private[spark] object PythonRDD {
......
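PairwiseRDD above depends on the Python worker emitting an alternating stream of elements, the packed split number followed by the pickled batch for that split, which grouped(2) then reassembles into key/value pairs. A toy illustration of that pairing (plain strings standing in for the real byte payloads):

object GroupedPairingDemo {
  def main(args: Array[String]): Unit = {
    // The worker writes: split, batch, split, batch, ...
    val stream = Iterator("split-0", "batch-A", "split-1", "batch-B")
    val pairs = stream.grouped(2).map {
      case Seq(a, b) => (a, b)
      case x => throw new IllegalStateException("unexpected value: " + x)
    }
    pairs.foreach(println) // (split-0,batch-A) then (split-1,batch-B)
  }
}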
@@ -70,6 +70,19 @@ private[spark] object Utils extends Logging {
return ois.readObject.asInstanceOf[T]
}
/** Deserialize a Long value (used for {@link org.apache.spark.api.python.PythonPartitioner}) */
def deserializeLongValue(bytes: Array[Byte]) : Long = {
// Note: we assume that we are given a Long value encoded in network (big-endian) byte order
var result = bytes(7) & 0xFFL
result = result + ((bytes(6) & 0xFFL) << 8)
result = result + ((bytes(5) & 0xFFL) << 16)
result = result + ((bytes(4) & 0xFFL) << 24)
result = result + ((bytes(3) & 0xFFL) << 32)
result = result + ((bytes(2) & 0xFFL) << 40)
result = result + ((bytes(1) & 0xFFL) << 48)
result + ((bytes(0) & 0xFFL) << 56)
}
/** Serialize via nested stream using specific serializer */
def serializeViaNestedStream(os: OutputStream, ser: SerializerInstance)(f: SerializationStream => Unit) = {
val osWrapper = ser.serializeStream(new OutputStream {
......
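The byte-by-byte assembly in deserializeLongValue assumes network (big-endian) byte order, which is also the default order of a java.nio.ByteBuffer; a quick standalone check of that equivalence (the object and helper names here are illustrative, and the test value is the same one used in UtilsSuite below):

import java.nio.ByteBuffer

object LongDecodeCheck {
  // The same big-endian reconstruction as deserializeLongValue, written as a fold.
  def decodeManually(bytes: Array[Byte]): Long =
    (0 until 8).foldLeft(0L)((acc, i) => acc | ((bytes(i) & 0xFFL) << (56 - 8 * i)))

  def main(args: Array[String]): Unit = {
    val bytes = ByteBuffer.allocate(8).putLong(9730889947L).array() // ByteBuffer defaults to big-endian
    println(decodeManually(bytes))          // 9730889947
    println(ByteBuffer.wrap(bytes).getLong) // 9730889947
  }
}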
@@ -20,6 +20,7 @@ package org.apache.spark.util
import com.google.common.base.Charsets
import com.google.common.io.Files
import java.io.{ByteArrayOutputStream, ByteArrayInputStream, FileOutputStream, File}
import java.nio.{ByteBuffer, ByteOrder}
import org.scalatest.FunSuite
import org.apache.commons.io.FileUtils
import scala.util.Random
@@ -135,5 +136,15 @@ class UtilsSuite extends FunSuite {
FileUtils.deleteDirectory(tmpDir2)
}
test("deserialize long value") {
val testval : Long = 9730889947L
val bbuf = ByteBuffer.allocate(8)
assert(bbuf.hasArray)
bbuf.order(ByteOrder.BIG_ENDIAN)
bbuf.putLong(testval)
assert(bbuf.array.length === 8)
assert(Utils.deserializeLongValue(bbuf.array) === testval)
}
}
@@ -29,7 +29,7 @@ from threading import Thread
from pyspark import cloudpickle
from pyspark.serializers import batched, Batch, dump_pickle, load_pickle, \
read_from_pickle_file
read_from_pickle_file, pack_long
from pyspark.join import python_join, python_left_outer_join, \
python_right_outer_join, python_cogroup
from pyspark.statcounter import StatCounter
@@ -690,11 +690,13 @@ class RDD(object):
# form the hash buckets in Python, transferring O(numPartitions) objects
# to Java. Each object is a (splitNumber, [objects]) pair.
def add_shuffle_key(split, iterator):
    buckets = defaultdict(list)
    for (k, v) in iterator:
        buckets[partitionFunc(k) % numPartitions].append((k, v))
    for (split, items) in buckets.iteritems():
        yield str(split)
        yield pack_long(split)
        yield dump_pickle(Batch(items))
keyed = PipelinedRDD(self, add_shuffle_key)
keyed._bypass_serializer = True
@@ -831,8 +833,8 @@ class RDD(object):
>>> sorted(x.subtractByKey(y).collect())
[('b', 4), ('b', 5)]
"""
filter_func = lambda tpl: len(tpl[1][0]) > 0 and len(tpl[1][1]) == 0
map_func = lambda tpl: [(tpl[0], val) for val in tpl[1][0]]
filter_func = lambda (key, vals): len(vals[0]) > 0 and len(vals[1]) == 0
map_func = lambda (key, vals): [(key, val) for val in vals[0]]
return self.cogroup(other, numPartitions).filter(filter_func).flatMap(map_func)
def subtract(self, other, numPartitions=None):
......
@@ -67,6 +67,10 @@ def write_long(value, stream):
    stream.write(struct.pack("!q", value))
def pack_long(value):
    return struct.pack("!q", value)
def read_int(stream):
    length = stream.read(4)
    if length == "":
......