diff --git a/core/src/main/scala/org/apache/spark/api/python/PythonPartitioner.scala b/core/src/main/scala/org/apache/spark/api/python/PythonPartitioner.scala
index b090c6edf3c07adad0e796000cbc81194ebae6be..2be4e323bec9808338e8d48fd3463551ef504f69 100644
--- a/core/src/main/scala/org/apache/spark/api/python/PythonPartitioner.scala
+++ b/core/src/main/scala/org/apache/spark/api/python/PythonPartitioner.scala
@@ -17,12 +17,13 @@
 
 package org.apache.spark.api.python
 
-import org.apache.spark.Partitioner
 import java.util.Arrays
+
+import org.apache.spark.Partitioner
 import org.apache.spark.util.Utils
 
 /**
- * A [[org.apache.spark.Partitioner]] that performs handling of byte arrays, for use by the Python API.
+ * A [[org.apache.spark.Partitioner]] that performs handling of long-valued keys, for use by the Python API.
  *
  * Stores the unique id() of the Python-side partitioning function so that it is incorporated into
 * equality comparisons. Correctness requires that the id is a unique identifier for the
@@ -30,6 +31,7 @@ import org.apache.spark.util.Utils
 * function). This can be ensured by using the Python id() function and maintaining a reference
 * to the Python partitioning function so that its id() is not reused.
 */
+
 private[spark] class PythonPartitioner(
   override val numPartitions: Int,
   val pyPartitionFunctionId: Long)
@@ -37,7 +39,9 @@ private[spark] class PythonPartitioner(
 
   override def getPartition(key: Any): Int = key match {
     case null => 0
-    case key: Array[Byte] => Utils.nonNegativeMod(Arrays.hashCode(key), numPartitions)
+    // we don't trust the Python partition function to return valid partition ID's so
+    // let's do a modulo numPartitions in any case
+    case key: Long => Utils.nonNegativeMod(key.toInt, numPartitions)
     case _ => Utils.nonNegativeMod(key.hashCode(), numPartitions)
   }
 
diff --git a/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala b/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala
index ccd383396460254d5ba433d0c654004f8a5d5a23..1f8ad688a649dec0d14228e4db97c600d1877e7f 100644
--- a/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala
+++ b/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala
@@ -187,14 +187,14 @@ private class PythonException(msg: String) extends Exception(msg)
  * This is used by PySpark's shuffle operations.
  */
 private class PairwiseRDD(prev: RDD[Array[Byte]]) extends
-  RDD[(Array[Byte], Array[Byte])](prev) {
+  RDD[(Long, Array[Byte])](prev) {
   override def getPartitions = prev.partitions
   override def compute(split: Partition, context: TaskContext) =
     prev.iterator(split, context).grouped(2).map {
-      case Seq(a, b) => (a, b)
+      case Seq(a, b) => (Utils.deserializeLongValue(a), b)
       case x => throw new SparkException("PairwiseRDD: unexpected value: " + x)
     }
-  val asJavaPairRDD : JavaPairRDD[Array[Byte], Array[Byte]] = JavaPairRDD.fromRDD(this)
+  val asJavaPairRDD : JavaPairRDD[Long, Array[Byte]] = JavaPairRDD.fromRDD(this)
 }
 
 private[spark] object PythonRDD {
diff --git a/core/src/main/scala/org/apache/spark/util/Utils.scala b/core/src/main/scala/org/apache/spark/util/Utils.scala
index 886f071503569af0bb6f14c475bad79e9e3c743e..f384875cc9af80889312ee61be325bc57ab9f7a3 100644
--- a/core/src/main/scala/org/apache/spark/util/Utils.scala
+++ b/core/src/main/scala/org/apache/spark/util/Utils.scala
@@ -70,6 +70,19 @@ private[spark] object Utils extends Logging {
     return ois.readObject.asInstanceOf[T]
   }
 
+  /** Deserialize a Long value (used for {@link org.apache.spark.api.python.PythonPartitioner}) */
+  def deserializeLongValue(bytes: Array[Byte]) : Long = {
+    // Note: we assume that we are given a Long value encoded in network (big-endian) byte order
+    var result = bytes(7) & 0xFFL
+    result = result + ((bytes(6) & 0xFFL) << 8)
+    result = result + ((bytes(5) & 0xFFL) << 16)
+    result = result + ((bytes(4) & 0xFFL) << 24)
+    result = result + ((bytes(3) & 0xFFL) << 32)
+    result = result + ((bytes(2) & 0xFFL) << 40)
+    result = result + ((bytes(1) & 0xFFL) << 48)
+    result + ((bytes(0) & 0xFFL) << 56)
+  }
+
   /** Serialize via nested stream using specific serializer */
   def serializeViaNestedStream(os: OutputStream, ser: SerializerInstance)(f: SerializationStream => Unit) = {
     val osWrapper = ser.serializeStream(new OutputStream {
diff --git a/core/src/test/scala/org/apache/spark/util/UtilsSuite.scala b/core/src/test/scala/org/apache/spark/util/UtilsSuite.scala
index e2859caf58d5a9fbfbce963cd496e3bf740ad87b..4684c8c972d3c31fa7fd2cb223958640fe475d4a 100644
--- a/core/src/test/scala/org/apache/spark/util/UtilsSuite.scala
+++ b/core/src/test/scala/org/apache/spark/util/UtilsSuite.scala
@@ -20,6 +20,7 @@ package org.apache.spark.util
 import com.google.common.base.Charsets
 import com.google.common.io.Files
 import java.io.{ByteArrayOutputStream, ByteArrayInputStream, FileOutputStream, File}
+import java.nio.{ByteBuffer, ByteOrder}
 import org.scalatest.FunSuite
 import org.apache.commons.io.FileUtils
 import scala.util.Random
@@ -135,5 +136,15 @@ class UtilsSuite extends FunSuite {
 
     FileUtils.deleteDirectory(tmpDir2)
   }
+
+  test("deserialize long value") {
+    val testval : Long = 9730889947L
+    val bbuf = ByteBuffer.allocate(8)
+    assert(bbuf.hasArray)
+    bbuf.order(ByteOrder.BIG_ENDIAN)
+    bbuf.putLong(testval)
+    assert(bbuf.array.length === 8)
+    assert(Utils.deserializeLongValue(bbuf.array) === testval)
+  }
 
 }
diff --git a/python/pyspark/rdd.py b/python/pyspark/rdd.py
index 58e1849cadac8611c0a4e85681cc67ea6bb7ca25..39c402b4120a375f87fdf041397cc9e3b41a0a29 100644
--- a/python/pyspark/rdd.py
+++ b/python/pyspark/rdd.py
@@ -29,7 +29,7 @@ from threading import Thread
 
 from pyspark import cloudpickle
 from pyspark.serializers import batched, Batch, dump_pickle, load_pickle, \
-    read_from_pickle_file
+    read_from_pickle_file, pack_long
 from pyspark.join import python_join, python_left_outer_join, \
     python_right_outer_join, python_cogroup
 from pyspark.statcounter import StatCounter
@@ -690,11 +690,13 @@ class RDD(object):
         # form the hash buckets in Python, transferring O(numPartitions) objects
         # to Java. Each object is a (splitNumber, [objects]) pair.
         def add_shuffle_key(split, iterator):
+
             buckets = defaultdict(list)
+
             for (k, v) in iterator:
                 buckets[partitionFunc(k) % numPartitions].append((k, v))
             for (split, items) in buckets.iteritems():
-                yield str(split)
+                yield pack_long(split)
                 yield dump_pickle(Batch(items))
         keyed = PipelinedRDD(self, add_shuffle_key)
         keyed._bypass_serializer = True
@@ -831,8 +833,8 @@ class RDD(object):
         >>> sorted(x.subtractByKey(y).collect())
         [('b', 4), ('b', 5)]
         """
-        filter_func = lambda tpl: len(tpl[1][0]) > 0 and len(tpl[1][1]) == 0
-        map_func = lambda tpl: [(tpl[0], val) for val in tpl[1][0]]
+        filter_func = lambda (key, vals): len(vals[0]) > 0 and len(vals[1]) == 0
+        map_func = lambda (key, vals): [(key, val) for val in vals[0]]
         return self.cogroup(other, numPartitions).filter(filter_func).flatMap(map_func)
 
     def subtract(self, other, numPartitions=None):
diff --git a/python/pyspark/serializers.py b/python/pyspark/serializers.py
index fecacd12416af2e703fd67c953c02162bb841c13..54fed1c9c70f66e503abb5c523d6327bb9bae8b4 100644
--- a/python/pyspark/serializers.py
+++ b/python/pyspark/serializers.py
@@ -67,6 +67,10 @@ def write_long(value, stream):
     stream.write(struct.pack("!q", value))
 
 
+def pack_long(value):
+    return struct.pack("!q", value)
+
+
 def read_int(stream):
     length = stream.read(4)
     if length == "":
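
Editor's note on the wire format this patch relies on (an illustration, not part of the patch): the Python worker now emits each bucket's partition number as an 8-byte big-endian long via pack_long (struct.pack("!q", ...)), PairwiseRDD decodes those bytes with Utils.deserializeLongValue, and PythonPartitioner takes the decoded Long modulo numPartitions so an out-of-range value from the Python partition function can never yield an invalid partition ID. The standalone Scala sketch below mirrors that round trip using local stand-ins for the patch's helpers; the object and method names here are illustrative only and are not Spark API.

import java.nio.ByteBuffer

object LongKeyRoundTrip {
  // Big-endian decode, equivalent to the patch's Utils.deserializeLongValue:
  // bytes(0) is the most significant byte, bytes(7) the least significant.
  def deserializeLongValue(bytes: Array[Byte]): Long =
    (0 until 8).foldLeft(0L)((acc, i) => (acc << 8) | (bytes(i) & 0xFFL))

  // Same contract as Utils.nonNegativeMod: always returns a value in [0, mod).
  def nonNegativeMod(x: Int, mod: Int): Int = {
    val rawMod = x % mod
    rawMod + (if (rawMod < 0) mod else 0)
  }

  def main(args: Array[String]): Unit = {
    val numPartitions = 4
    // ByteBuffer defaults to big-endian, matching Python's struct.pack("!q", 9730889947).
    val keyBytes = ByteBuffer.allocate(8).putLong(9730889947L).array()
    val key = deserializeLongValue(keyBytes)              // recovers 9730889947L
    val partition = nonNegativeMod(key.toInt, numPartitions)
    println(s"key=$key -> partition=$partition")          // always within [0, numPartitions)
  }
}

Because both sides agree on network (big-endian) byte order, the Long decoded on the JVM is exactly the integer Python packed, which is what makes the modulo arithmetic in PythonPartitioner well-defined.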