diff --git a/core/src/main/scala/org/apache/spark/api/python/PythonPartitioner.scala b/core/src/main/scala/org/apache/spark/api/python/PythonPartitioner.scala
index b090c6edf3c07adad0e796000cbc81194ebae6be..2be4e323bec9808338e8d48fd3463551ef504f69 100644
--- a/core/src/main/scala/org/apache/spark/api/python/PythonPartitioner.scala
+++ b/core/src/main/scala/org/apache/spark/api/python/PythonPartitioner.scala
@@ -17,12 +17,13 @@
 
 package org.apache.spark.api.python
 
-import org.apache.spark.Partitioner
 import java.util.Arrays
+
+import org.apache.spark.Partitioner
 import org.apache.spark.util.Utils
 
 /**
- * A [[org.apache.spark.Partitioner]] that performs handling of byte arrays, for use by the Python API.
+ * A [[org.apache.spark.Partitioner]] that partitions long-valued keys, for use by the Python API.
  *
  * Stores the unique id() of the Python-side partitioning function so that it is incorporated into
  * equality comparisons.  Correctness requires that the id is a unique identifier for the
@@ -30,6 +31,7 @@ import org.apache.spark.util.Utils
  * function).  This can be ensured by using the Python id() function and maintaining a reference
  * to the Python partitioning function so that its id() is not reused.
  */
+
 private[spark] class PythonPartitioner(
   override val numPartitions: Int,
   val pyPartitionFunctionId: Long)
@@ -37,7 +39,9 @@ private[spark] class PythonPartitioner(
 
   override def getPartition(key: Any): Int = key match {
     case null => 0
-    case key: Array[Byte] => Utils.nonNegativeMod(Arrays.hashCode(key), numPartitions)
+    // We don't trust the Python partition function to return valid partition IDs, so
+    // take the key modulo numPartitions just in case.
+    case key: Long => Utils.nonNegativeMod(key.toInt, numPartitions)
     case _ => Utils.nonNegativeMod(key.hashCode(), numPartitions)
   }
 
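For illustration, here is a minimal self-contained sketch of the new getPartition behaviour. It is not Spark code: PartitionSketch and its local nonNegativeMod helper are hypothetical stand-ins for the private[spark] class and Utils.nonNegativeMod, but they mirror the same match logic, so a Long key coming from the Python side is truncated to Int and mapped back into [0, numPartitions).

    // Hypothetical stand-in for PythonPartitioner.getPartition (sketch only).
    object PartitionSketch {
      // Mirrors Utils.nonNegativeMod: a modulo that never returns a negative result.
      private def nonNegativeMod(x: Int, mod: Int): Int = {
        val rawMod = x % mod
        rawMod + (if (rawMod < 0) mod else 0)
      }

      def getPartition(key: Any, numPartitions: Int): Int = key match {
        case null => 0
        // Keys from the Python side arrive as Longs; re-apply the modulo so an
        // out-of-range or negative value still lands on a valid partition.
        case key: Long => nonNegativeMod(key.toInt, numPartitions)
        case _ => nonNegativeMod(key.hashCode(), numPartitions)
      }

      def main(args: Array[String]): Unit = {
        println(getPartition(9730889947L, 4))  // a large Long key maps into [0, 4)
        println(getPartition(-3L, 4))          // a negative key maps to 1, not -3
      }
    }
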
diff --git a/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala b/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala
index ccd383396460254d5ba433d0c654004f8a5d5a23..1f8ad688a649dec0d14228e4db97c600d1877e7f 100644
--- a/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala
+++ b/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala
@@ -187,14 +187,14 @@ private class PythonException(msg: String) extends Exception(msg)
  * This is used by PySpark's shuffle operations.
  */
 private class PairwiseRDD(prev: RDD[Array[Byte]]) extends
-  RDD[(Array[Byte], Array[Byte])](prev) {
+  RDD[(Long, Array[Byte])](prev) {
   override def getPartitions = prev.partitions
   override def compute(split: Partition, context: TaskContext) =
     prev.iterator(split, context).grouped(2).map {
-      case Seq(a, b) => (a, b)
+      case Seq(a, b) => (Utils.deserializeLongValue(a), b)
       case x          => throw new SparkException("PairwiseRDD: unexpected value: " + x)
     }
-  val asJavaPairRDD : JavaPairRDD[Array[Byte], Array[Byte]] = JavaPairRDD.fromRDD(this)
+  val asJavaPairRDD : JavaPairRDD[Long, Array[Byte]] = JavaPairRDD.fromRDD(this)
 }
 
 private[spark] object PythonRDD {
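To make the pairing concrete, here is a rough standalone sketch (hypothetical PairingSketch with made-up data, not part of the patch) of what compute now does with the stream written by the Python worker: records alternate between an 8-byte key and a pickled batch, and grouped(2) stitches them back into (Long, Array[Byte]) pairs. The key decoding is approximated here with ByteBuffer, which reads the same big-endian layout as Utils.deserializeLongValue.

    import java.nio.ByteBuffer

    // Hypothetical sketch of the pairing performed by PairwiseRDD.compute.
    object PairingSketch {
      def main(args: Array[String]): Unit = {
        // Stand-in for what the Python worker emits per bucket: an 8-byte
        // big-endian key followed by the pickled batch bytes.
        def longBytes(v: Long): Array[Byte] = ByteBuffer.allocate(8).putLong(v).array
        val stream: Iterator[Array[Byte]] =
          Iterator(longBytes(0L), "batch-0".getBytes, longBytes(1L), "batch-1".getBytes)

        // Group the flat stream into (key, value) pairs, as compute does.
        val pairs = stream.grouped(2).map {
          case Seq(k, v) => (ByteBuffer.wrap(k).getLong, new String(v))
        }
        pairs.foreach(println)  // (0,batch-0) then (1,batch-1)
      }
    }
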
diff --git a/core/src/main/scala/org/apache/spark/util/Utils.scala b/core/src/main/scala/org/apache/spark/util/Utils.scala
index 886f071503569af0bb6f14c475bad79e9e3c743e..f384875cc9af80889312ee61be325bc57ab9f7a3 100644
--- a/core/src/main/scala/org/apache/spark/util/Utils.scala
+++ b/core/src/main/scala/org/apache/spark/util/Utils.scala
@@ -70,6 +70,19 @@ private[spark] object Utils extends Logging {
     return ois.readObject.asInstanceOf[T]
   }
 
+  /** Deserialize a Long value (used by [[org.apache.spark.api.python.PythonPartitioner]]) */
+  def deserializeLongValue(bytes: Array[Byte]): Long = {
+    // Note: we assume that we are given a Long value encoded in network (big-endian) byte order
+    var result = bytes(7) & 0xFFL
+    result = result + ((bytes(6) & 0xFFL) << 8)
+    result = result + ((bytes(5) & 0xFFL) << 16)
+    result = result + ((bytes(4) & 0xFFL) << 24)
+    result = result + ((bytes(3) & 0xFFL) << 32)
+    result = result + ((bytes(2) & 0xFFL) << 40)
+    result = result + ((bytes(1) & 0xFFL) << 48)
+    result + ((bytes(0) & 0xFFL) << 56)
+  }
+
   /** Serialize via nested stream using specific serializer */
   def serializeViaNestedStream(os: OutputStream, ser: SerializerInstance)(f: SerializationStream => Unit) = {
     val osWrapper = ser.serializeStream(new OutputStream {
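The hand-rolled shifts read the eight bytes in network (big-endian) order, which is also ByteBuffer's default. As a point of comparison, a small sketch (hypothetical LongDecodeSketch, shown for illustration rather than as a replacement) that decodes the same layout via java.nio; either way, the essential property is that the byte order matches what the Python side produces with struct.pack("!q", ...).

    import java.nio.ByteBuffer

    // Hypothetical comparison: ByteBuffer.getLong reads the same big-endian
    // layout that deserializeLongValue reconstructs by hand.
    object LongDecodeSketch {
      def decode(bytes: Array[Byte]): Long = ByteBuffer.wrap(bytes).getLong

      def main(args: Array[String]): Unit = {
        val encoded = ByteBuffer.allocate(8).putLong(9730889947L).array
        println(decode(encoded))  // 9730889947
      }
    }
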
diff --git a/core/src/test/scala/org/apache/spark/util/UtilsSuite.scala b/core/src/test/scala/org/apache/spark/util/UtilsSuite.scala
index e2859caf58d5a9fbfbce963cd496e3bf740ad87b..4684c8c972d3c31fa7fd2cb223958640fe475d4a 100644
--- a/core/src/test/scala/org/apache/spark/util/UtilsSuite.scala
+++ b/core/src/test/scala/org/apache/spark/util/UtilsSuite.scala
@@ -20,6 +20,7 @@ package org.apache.spark.util
 import com.google.common.base.Charsets
 import com.google.common.io.Files
 import java.io.{ByteArrayOutputStream, ByteArrayInputStream, FileOutputStream, File}
+import java.nio.{ByteBuffer, ByteOrder}
 import org.scalatest.FunSuite
 import org.apache.commons.io.FileUtils
 import scala.util.Random
@@ -135,5 +136,15 @@ class UtilsSuite extends FunSuite {
 
     FileUtils.deleteDirectory(tmpDir2)
   }
+
+  test("deserialize long value") {
+    val testval: Long = 9730889947L
+    val bbuf = ByteBuffer.allocate(8)
+    assert(bbuf.hasArray)
+    bbuf.order(ByteOrder.BIG_ENDIAN)
+    bbuf.putLong(testval)
+    assert(bbuf.array.length === 8)
+    assert(Utils.deserializeLongValue(bbuf.array) === testval)
+  }
 }
 
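The added test covers a positive value; below is a companion sketch (hypothetical suite, not part of this patch) for a negative value, which should also round-trip because deserializeLongValue shifts the top byte into the sign bits of the result.

    package org.apache.spark.util

    import java.nio.{ByteBuffer, ByteOrder}

    import org.scalatest.FunSuite

    // Hypothetical companion suite (not in the patch): negative Longs round-trip too.
    class DeserializeNegativeLongSuite extends FunSuite {
      test("deserialize negative long value") {
        val testval: Long = -9730889947L
        val bbuf = ByteBuffer.allocate(8)
        bbuf.order(ByteOrder.BIG_ENDIAN)
        bbuf.putLong(testval)
        assert(Utils.deserializeLongValue(bbuf.array) === testval)
      }
    }
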
diff --git a/python/pyspark/rdd.py b/python/pyspark/rdd.py
index 58e1849cadac8611c0a4e85681cc67ea6bb7ca25..39c402b4120a375f87fdf041397cc9e3b41a0a29 100644
--- a/python/pyspark/rdd.py
+++ b/python/pyspark/rdd.py
@@ -29,7 +29,7 @@ from threading import Thread
 
 from pyspark import cloudpickle
 from pyspark.serializers import batched, Batch, dump_pickle, load_pickle, \
-    read_from_pickle_file
+    read_from_pickle_file, pack_long
 from pyspark.join import python_join, python_left_outer_join, \
     python_right_outer_join, python_cogroup
 from pyspark.statcounter import StatCounter
@@ -690,11 +690,13 @@ class RDD(object):
         # form the hash buckets in Python, transferring O(numPartitions) objects
         # to Java.  Each object is a (splitNumber, [objects]) pair.
         def add_shuffle_key(split, iterator):
+
             buckets = defaultdict(list)
+
             for (k, v) in iterator:
                 buckets[partitionFunc(k) % numPartitions].append((k, v))
             for (split, items) in buckets.iteritems():
-                yield str(split)
+                yield pack_long(split)
                 yield dump_pickle(Batch(items))
         keyed = PipelinedRDD(self, add_shuffle_key)
         keyed._bypass_serializer = True
@@ -831,8 +833,8 @@ class RDD(object):
         >>> sorted(x.subtractByKey(y).collect())
         [('b', 4), ('b', 5)]
         """ 
-        filter_func = lambda tpl: len(tpl[1][0]) > 0 and len(tpl[1][1]) == 0
-        map_func = lambda tpl: [(tpl[0], val) for val in tpl[1][0]]
+        filter_func = lambda (key, vals): len(vals[0]) > 0 and len(vals[1]) == 0
+        map_func = lambda (key, vals): [(key, val) for val in vals[0]]
         return self.cogroup(other, numPartitions).filter(filter_func).flatMap(map_func)
 
     def subtract(self, other, numPartitions=None):
diff --git a/python/pyspark/serializers.py b/python/pyspark/serializers.py
index fecacd12416af2e703fd67c953c02162bb841c13..54fed1c9c70f66e503abb5c523d6327bb9bae8b4 100644
--- a/python/pyspark/serializers.py
+++ b/python/pyspark/serializers.py
@@ -67,6 +67,10 @@ def write_long(value, stream):
     stream.write(struct.pack("!q", value))
 
 
+def pack_long(value):
+    return struct.pack("!q", value)
+
+
 def read_int(stream):
     length = stream.read(4)
     if length == "":