diff --git a/python/pyspark/rdd.py b/python/pyspark/rdd.py
index 3eefc878d274e910a00980ca54853a67da7b0674..bdd8bc82869fb371b7d9aef8cb9885b7e9100063 100644
--- a/python/pyspark/rdd.py
+++ b/python/pyspark/rdd.py
@@ -810,23 +810,37 @@ class RDD(object):
 
         return self.mapPartitions(func).fold(zeroValue, combOp)
 
-    def max(self):
+    def max(self, key=None):
         """
         Find the maximum item in this RDD.
 
-        >>> sc.parallelize([1.0, 5.0, 43.0, 10.0]).max()
+        @param key: A function used to generate the key for comparing
+
+        >>> rdd = sc.parallelize([1.0, 5.0, 43.0, 10.0])
+        >>> rdd.max()
         43.0
+        >>> rdd.max(key=str)
+        5.0
         """
-        return self.reduce(max)
+        if key is None:
+            return self.reduce(max)
+        return self.reduce(lambda a, b: max(a, b, key=key))
 
-    def min(self):
+    def min(self, key=None):
         """
         Find the minimum item in this RDD.
 
-        >>> sc.parallelize([1.0, 5.0, 43.0, 10.0]).min()
-        1.0
+        @param key: A function used to generate the key for comparing
+
+        >>> rdd = sc.parallelize([2.0, 5.0, 43.0, 10.0])
+        >>> rdd.min()
+        2.0
+        >>> rdd.min(key=str)
+        10.0
         """
-        return self.reduce(min)
+        if key is None:
+            return self.reduce(min)
+        return self.reduce(lambda a, b: min(a, b, key=key))
 
     def sum(self):
         """
@@ -924,7 +938,7 @@ class RDD(object):
             return m1
         return self.mapPartitions(countPartition).reduce(mergeMaps)
 
-    def top(self, num):
+    def top(self, num, key=None):
         """
         Get the top N elements from a RDD.
 
@@ -933,20 +947,16 @@ class RDD(object):
         [12]
         >>> sc.parallelize([2, 3, 4, 5, 6], 2).top(2)
         [6, 5]
+        >>> sc.parallelize([10, 4, 2, 12, 3]).top(3, key=str)
+        [4, 3, 2]
         """
         def topIterator(iterator):
-            q = []
-            for k in iterator:
-                if len(q) < num:
-                    heapq.heappush(q, k)
-                else:
-                    heapq.heappushpop(q, k)
-            yield q
+            yield heapq.nlargest(num, iterator, key=key)
 
         def merge(a, b):
-            return next(topIterator(a + b))
+            return heapq.nlargest(num, a + b, key=key)
 
-        return sorted(self.mapPartitions(topIterator).reduce(merge), reverse=True)
+        return self.mapPartitions(topIterator).reduce(merge)
 
     def takeOrdered(self, num, key=None):
         """