From c1ea3afb516c204925259f0928dfb17d0fa89621 Mon Sep 17 00:00:00 2001
From: Prashant Sharma <prashant.s@imaginea.com>
Date: Thu, 3 Apr 2014 15:42:17 -0700
Subject: [PATCH] SPARK-1162 Implemented takeOrdered in pyspark.

Python does not provide a max-heap library, and the usual tricks (such as negating values) do not work for all cases.

So we add our own max-heap implementation.
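
For example (a quick illustration, not part of the diff below), negating values to
get max-heap behaviour out of Python's min-heap-only heapq module works for
numbers but fails for values that cannot be negated, such as strings:

    import heapq

    nums = [10, 4, 2, 12, 3]
    # "invert the values" trick: a min-heap of -x acts as a max-heap of x
    print([-n for n in heapq.nsmallest(2, (-n for n in nums))])  # [12, 10]

    words = ["pear", "apple", "fig"]
    try:
        heapq.heappush([], -words[0])  # unary minus is undefined for str
    except TypeError as e:
        print("negation trick fails:", e)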

Author: Prashant Sharma <prashant.s@imaginea.com>

Closes #97 from ScrapCodes/SPARK-1162/pyspark-top-takeOrdered2 and squashes the following commits:

35f86ba [Prashant Sharma] code review
2b1124d [Prashant Sharma] fixed tests
e8a08e2 [Prashant Sharma] Code review comments.
49e6ba7 [Prashant Sharma] SPARK-1162 added takeOrdered to pyspark
---
 python/pyspark/rdd.py | 107 ++++++++++++++++++++++++++++++++++++++++--
 1 file changed, 102 insertions(+), 5 deletions(-)
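
The approach in the diff: each partition keeps only the num smallest elements,
tracked with a size-bounded max-heap (the new MaxHeapQ class), the per-partition
results are merged pairwise in reduce(), and the merged result is sorted (and
stripped of its keys, when a key function was supplied) before being returned.
A rough, Spark-free sketch of the same idea, using heapq.nsmallest in place of
MaxHeapQ:

    import heapq

    def take_ordered(partitions, num, key=None):
        # mapPartitions step: keep at most num elements per partition
        partials = [heapq.nsmallest(num, part, key=key) for part in partitions]
        # reduce(merge) step: repeatedly combine two partial results
        merged = []
        for part in partials:
            merged = heapq.nsmallest(num, merged + part, key=key)
        # final ordering, ascending or as specified by the key function
        return sorted(merged, key=key)

    print(take_ordered([[10, 1, 2, 9], [3, 4, 5, 6, 7]], 6))
    # [1, 2, 3, 4, 5, 6]
    print(take_ordered([[10, 1, 2, 9], [3, 4, 5, 6, 7]], 6, key=lambda x: -x))
    # [10, 9, 7, 6, 5, 4]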

diff --git a/python/pyspark/rdd.py b/python/pyspark/rdd.py
index 019c249699..9943296b92 100644
--- a/python/pyspark/rdd.py
+++ b/python/pyspark/rdd.py
@@ -29,7 +29,7 @@ from subprocess import Popen, PIPE
 from tempfile import NamedTemporaryFile
 from threading import Thread
 import warnings
-from heapq import heappush, heappop, heappushpop
+import heapq
 
 from pyspark.serializers import NoOpSerializer, CartesianDeserializer, \
     BatchedSerializer, CloudPickleSerializer, PairDeserializer, pack_long
@@ -41,9 +41,9 @@ from pyspark.storagelevel import StorageLevel
 
 from py4j.java_collections import ListConverter, MapConverter
 
-
 __all__ = ["RDD"]
 
+
 def _extract_concise_traceback():
     """
     This function returns the traceback info for a callsite, returns a dict
@@ -91,6 +91,73 @@ class _JavaStackTrace(object):
         if _spark_stack_depth == 0:
             self._context._jsc.setCallSite(None)
 
+class MaxHeapQ(object):
+    """
+    A size-bounded max-heap that retains the maxsize smallest values inserted.
+    >>> import pyspark.rdd
+    >>> heap = pyspark.rdd.MaxHeapQ(5)
+    >>> [heap.insert(i) for i in range(10)]
+    [None, None, None, None, None, None, None, None, None, None]
+    >>> sorted(heap.getElements())
+    [0, 1, 2, 3, 4]
+    >>> heap = pyspark.rdd.MaxHeapQ(5)
+    >>> [heap.insert(i) for i in range(9, -1, -1)]
+    [None, None, None, None, None, None, None, None, None, None]
+    >>> sorted(heap.getElements())
+    [0, 1, 2, 3, 4]
+    >>> heap = pyspark.rdd.MaxHeapQ(1)
+    >>> [heap.insert(i) for i in range(9, -1, -1)]
+    [None, None, None, None, None, None, None, None, None, None]
+    >>> heap.getElements()
+    [0]
+    """
+
+    def __init__(self, maxsize):
+        # q[0] is unused; starting at q[1] makes node k's children 2 * k and 2 * k + 1 and its parent k // 2
+        self.q = [0]
+        self.maxsize = maxsize
+
+    def _swim(self, k):
+        while (k > 1) and (self.q[k // 2] < self.q[k]):
+            self._swap(k, k // 2)
+            k = k // 2
+
+    def _swap(self, i, j):
+        t = self.q[i]
+        self.q[i] = self.q[j]
+        self.q[j] = t
+
+    def _sink(self, k):
+        N = self.size()
+        while 2 * k <= N:
+            j = 2 * k
+            # Pick the larger of the two children; if the parent already
+            # dominates it we are done, otherwise swap the parent down.
+            if j < N and self.q[j] < self.q[j + 1]:
+                j = j + 1
+            if self.q[k] > self.q[j]:
+                break
+            self._swap(k, j)
+            k = j
+
+    def size(self):
+        return len(self.q) - 1
+
+    def insert(self, value):
+        if self.size() < self.maxsize:
+            self.q.append(value)
+            self._swim(self.size())
+        else:
+            self._replaceRoot(value)
+
+    def getElements(self):
+        return self.q[1:]
+
+    def _replaceRoot(self, value):
+        if self.q[1] > value:
+            self.q[1] = value
+            self._sink(1)
+
 class RDD(object):
     """
     A Resilient Distributed Dataset (RDD), the basic abstraction in Spark.
@@ -696,16 +763,16 @@ class RDD(object):
         Note: It returns the list sorted in descending order.
         >>> sc.parallelize([10, 4, 2, 12, 3]).top(1)
         [12]
-        >>> sc.parallelize([2, 3, 4, 5, 6]).cache().top(2)
+        >>> sc.parallelize([2, 3, 4, 5, 6], 2).cache().top(2)
         [6, 5]
         """
         def topIterator(iterator):
             q = []
             for k in iterator:
                 if len(q) < num:
-                    heappush(q, k)
+                    heapq.heappush(q, k)
                 else:
-                    heappushpop(q, k)
+                    heapq.heappushpop(q, k)
             yield q
 
         def merge(a, b):
@@ -713,6 +780,36 @@ class RDD(object):
 
         return sorted(self.mapPartitions(topIterator).reduce(merge), reverse=True)
 
+    def takeOrdered(self, num, key=None):
+        """
+        Get the N elements from an RDD ordered in ascending order or as specified
+        by the optional key function.
+
+        >>> sc.parallelize([10, 1, 2, 9, 3, 4, 5, 6, 7]).takeOrdered(6)
+        [1, 2, 3, 4, 5, 6]
+        >>> sc.parallelize([10, 1, 2, 9, 3, 4, 5, 6, 7], 2).takeOrdered(6, key=lambda x: -x)
+        [10, 9, 7, 6, 5, 4]
+        """
+
+        def topNKeyedElems(iterator, key_=None):
+            q = MaxHeapQ(num)
+            for k in iterator:
+                if key_ is not None:
+                    k = (key_(k), k)
+                q.insert(k)
+            yield q.getElements()
+
+        def unKey(x, key_=None):
+            if key_ is not None:
+                x = [i[1] for i in x]
+            return x
+
+        def merge(a, b):
+            return next(topNKeyedElems(a + b))
+        result = self.mapPartitions(lambda i: topNKeyedElems(i, key)).reduce(merge)
+        return sorted(unKey(result, key), key=key)
+
+
     def take(self, num):
         """
         Take the first num elements of the RDD.
-- 
GitLab
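
As noted in the MaxHeapQ comment above, q[0] is left unused and elements start at
q[1], so node k's children sit at indices 2 * k and 2 * k + 1 and its parent at
k // 2. A small standalone check of that invariant (illustrative only):

    # q[0] is a placeholder; q[1] is the root of the max-heap
    q = [None, 9, 7, 8, 3, 5]
    for k in range(2, len(q)):
        assert q[k // 2] >= q[k]  # every parent dominates its children
    print("max-heap property holds for", q[1:])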