From db436e36c4e20893de708a0bc07a5a8877c49563 Mon Sep 17 00:00:00 2001
From: Davies Liu <davies.liu@gmail.com>
Date: Sat, 23 Aug 2014 18:55:13 -0700
Subject: [PATCH] [SPARK-2871] [PySpark] add `key` argument for max(), min()
 and top(n)

RDD.max(key=None)

        param key: A function used to generate the key for comparison

        >>> rdd = sc.parallelize([1.0, 5.0, 43.0, 10.0])
        >>> rdd.max()
        43.0
        >>> rdd.max(key=str)
        5.0

RDD.min(key=None)

        Find the minimum item in this RDD.

        param key: A function used to generate the key for comparison

        >>> rdd = sc.parallelize([2.0, 5.0, 43.0, 10.0])
        >>> rdd.min()
        2.0
        >>> rdd.min(key=str)
        10.0

RDD.top(num, key=None)

        Get the top N elements from an RDD.

        Note: It returns the list sorted in descending order.
        >>> sc.parallelize([10, 4, 2, 12, 3]).top(1)
        [12]
        >>> sc.parallelize([2, 3, 4, 5, 6], 2).top(2)
        [6, 5]
        >>> sc.parallelize([10, 4, 2, 12, 3]).top(3, key=str)
        [4, 3, 2]

Author: Davies Liu <davies.liu@gmail.com>

Closes #2094 from davies/cmp and squashes the following commits:

ccbaf25 [Davies Liu] add `key` to top()
ad7e374 [Davies Liu] fix tests
2f63512 [Davies Liu] change `comp` to `key` in min/max
dd91e08 [Davies Liu] add `comp` argument for RDD.max() and RDD.min()
---
 python/pyspark/rdd.py | 44 ++++++++++++++++++++++++++-----------------
 1 file changed, 27 insertions(+), 17 deletions(-)

diff --git a/python/pyspark/rdd.py b/python/pyspark/rdd.py
index 3eefc878d2..bdd8bc8286 100644
--- a/python/pyspark/rdd.py
+++ b/python/pyspark/rdd.py
@@ -810,23 +810,37 @@ class RDD(object):
 
         return self.mapPartitions(func).fold(zeroValue, combOp)
 
-    def max(self):
+    def max(self, key=None):
         """
         Find the maximum item in this RDD.
 
-        >>> sc.parallelize([1.0, 5.0, 43.0, 10.0]).max()
+        @param key: A function used to generate key for comparing
+
+        >>> rdd = sc.parallelize([1.0, 5.0, 43.0, 10.0])
+        >>> rdd.max()
         43.0
+        >>> rdd.max(key=str)
+        5.0
         """
-        return self.reduce(max)
+        if key is None:
+            return self.reduce(max)
+        return self.reduce(lambda a, b: max(a, b, key=key))
 
-    def min(self):
+    def min(self, key=None):
         """
         Find the minimum item in this RDD.
 
-        >>> sc.parallelize([1.0, 5.0, 43.0, 10.0]).min()
-        1.0
+        @param key: A function used to generate key for comparing
+
+        >>> rdd = sc.parallelize([2.0, 5.0, 43.0, 10.0])
+        >>> rdd.min()
+        2.0
+        >>> rdd.min(key=str)
+        10.0
         """
-        return self.reduce(min)
+        if key is None:
+            return self.reduce(min)
+        return self.reduce(lambda a, b: min(a, b, key=key))
 
     def sum(self):
         """
@@ -924,7 +938,7 @@ class RDD(object):
             return m1
         return self.mapPartitions(countPartition).reduce(mergeMaps)
 
-    def top(self, num):
+    def top(self, num, key=None):
         """
         Get the top N elements from a RDD.
 
@@ -933,20 +947,16 @@ class RDD(object):
         [12]
         >>> sc.parallelize([2, 3, 4, 5, 6], 2).top(2)
         [6, 5]
+        >>> sc.parallelize([10, 4, 2, 12, 3]).top(3, key=str)
+        [4, 3, 2]
         """
         def topIterator(iterator):
-            q = []
-            for k in iterator:
-                if len(q) < num:
-                    heapq.heappush(q, k)
-                else:
-                    heapq.heappushpop(q, k)
-            yield q
+            yield heapq.nlargest(num, iterator, key=key)
 
         def merge(a, b):
-            return next(topIterator(a + b))
+            return heapq.nlargest(num, a + b, key=key)
 
-        return sorted(self.mapPartitions(topIterator).reduce(merge), reverse=True)
+        return self.mapPartitions(topIterator).reduce(merge)
 
     def takeOrdered(self, num, key=None):
         """
-- 
GitLab