diff --git a/python/pyspark/rdd.py b/python/pyspark/rdd.py
index 39c402b4120a375f87fdf041397cc9e3b41a0a29..7dfabb0b7d3d7f3472c24825bbd7ae6b50e2756b 100644
--- a/python/pyspark/rdd.py
+++ b/python/pyspark/rdd.py
@@ -263,7 +263,71 @@ class RDD(object):
             raise TypeError
         return self.union(other)
 
-    # TODO: sort
+    def sortByKey(self, ascending=True, numPartitions=None, keyfunc=lambda x: x):
+        """
+        Sorts this RDD, which is assumed to consist of (key, value) pairs.
+
+        >>> tmp = [('a', 1), ('b', 2), ('1', 3), ('d', 4), ('2', 5)]
+        >>> sc.parallelize(tmp).sortByKey(True, 2).collect()
+        [('1', 3), ('2', 5), ('a', 1), ('b', 2), ('d', 4)]
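+        >>> sc.parallelize(tmp).sortByKey(False).collect()
+        [('d', 4), ('b', 2), ('a', 1), ('2', 5), ('1', 3)]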
+        >>> tmp2 = [('Mary', 1), ('had', 2), ('a', 3), ('little', 4), ('lamb', 5)]
+        >>> tmp2.extend([('whose', 6), ('fleece', 7), ('was', 8), ('white', 9)])
+        >>> sc.parallelize(tmp2).sortByKey(True, 3, keyfunc=lambda k: k.lower()).collect()
+        [('a', 3), ('fleece', 7), ('had', 2), ('lamb', 5), ('little', 4), ('Mary', 1), ('was', 8), ('white', 9), ('whose', 6)]
+        """
+        if numPartitions is None:
+            numPartitions = self.ctx.defaultParallelism
+
+        bounds = []
+
+        # first compute the boundaries of each partition via sampling: we want
+        # to partition the key space into bins such that each bin receives
+        # roughly the same number of (key, value) pairs
+        if numPartitions > 1:
+            rddSize = self.count()
+            maxSampleSize = numPartitions * 20.0 # constant from Spark's RangePartitioner
+            fraction = min(maxSampleSize / max(rddSize, 1), 1.0)
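+            # fraction is capped at 1.0 so small RDDs are sampled in full;
+            # max(rddSize, 1) guards against division by zero on an empty RDD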
+
+            samples = self.sample(False, fraction, 1).map(lambda kv: kv[0]).collect()
+            samples = sorted(samples, key=keyfunc)
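+            # the samples are kept in ascending order regardless of the
+            # requested sort direction; descending output falls out of
+            # mirroring the partition index in rangePartitionFunc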
+
+            # we need numPartitions - 1 boundaries, since the last partition's
+            # upper bound is implicit
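+            # the boundaries sit at evenly spaced quantiles of the sorted
+            # sample; keyfunc is applied so that they compare directly with
+            # the keys tested in rangePartitionFunc below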
+            for i in range(0, numPartitions - 1):
+                index = (len(samples) - 1) * (i + 1) // numPartitions
+                bounds.append(keyfunc(samples[index]))
+
+        def rangePartitionFunc(k):
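+            # linear scan: count how many boundaries the key exceeds; for a
+            # descending sort the partition indices are simply mirrored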
+            p = 0
+            while p < len(bounds) and keyfunc(k) > bounds[p]:
+                p += 1
+            if ascending:
+                return p
+            else:
+                return numPartitions - 1 - p
+
+        def mapFunc(iterator):
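+            # yield one locally sorted list per partition; the flatMap in the
+            # return expression unrolls it without disturbing the partitioning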
+            yield sorted(iterator, reverse=(not ascending), key=lambda kv: keyfunc(kv[0]))
+
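+        # range-partition so that every key in partition i sorts before the
+        # keys in partition i + 1, then sort each partition locally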
+        return (self.partitionBy(numPartitions, partitionFunc=rangePartitionFunc)
+                    .mapPartitions(mapFunc, preservesPartitioning=True)
+                    .flatMap(lambda x: x, preservesPartitioning=True))
 
     def glom(self):
         """