From bee445c927586136673f39259f23642a5a6e8efe Mon Sep 17 00:00:00 2001
From: Hossein Falaki <falaki@gmail.com>
Date: Tue, 31 Dec 2013 16:58:18 -0800
Subject: [PATCH] Make the code more compact and readable

---
 .../org/apache/spark/rdd/PairRDDFunctions.scala    | 12 ++----------
 core/src/main/scala/org/apache/spark/rdd/RDD.scala | 14 +-------------
 .../spark/util/SerializableHyperLogLog.scala       |  5 +++++
 3 files changed, 8 insertions(+), 23 deletions(-)

diff --git a/core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala b/core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala
index 088b298aad..04a8d05988 100644
--- a/core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala
+++ b/core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala
@@ -218,19 +218,11 @@ class PairRDDFunctions[K: ClassTag, V: ClassTag](self: RDD[(K, V)])
    * Partitioner to partition the output RDD.
    */
   def countApproxDistinctByKey(relativeSD: Double, partitioner: Partitioner): RDD[(K, Long)] = {
-    val createHLL = (v: V) => {
-      val hll = new SerializableHyperLogLog(new HyperLogLog(relativeSD))
-      hll.value.offer(v)
-      hll
-    }
-    val mergeValueHLL = (hll: SerializableHyperLogLog, v: V) => {
-      hll.value.offer(v)
-      hll
-    }
+    val createHLL = (v: V) => new SerializableHyperLogLog(new HyperLogLog(relativeSD)).add(v)
+    val mergeValueHLL = (hll: SerializableHyperLogLog, v: V) => hll.add(v)
     val mergeHLL = (h1: SerializableHyperLogLog, h2: SerializableHyperLogLog) => h1.merge(h2)
 
     combineByKey(createHLL, mergeValueHLL, mergeHLL, partitioner).mapValues(_.value.cardinality())
-
   }
 
   /**
diff --git a/core/src/main/scala/org/apache/spark/rdd/RDD.scala b/core/src/main/scala/org/apache/spark/rdd/RDD.scala
index 161fd067e1..4960e6e82f 100644
--- a/core/src/main/scala/org/apache/spark/rdd/RDD.scala
+++ b/core/src/main/scala/org/apache/spark/rdd/RDD.scala
@@ -798,20 +798,8 @@ abstract class RDD[T: ClassTag](
    * relativeSD is 0.05.
    */
   def countApproxDistinct(relativeSD: Double = 0.05): Long = {
-
-    def hllCountPartition(iter: Iterator[T]): Iterator[SerializableHyperLogLog] = {
-      val hllCounter = new SerializableHyperLogLog(new HyperLogLog(relativeSD))
-      while (iter.hasNext) {
-        val v = iter.next()
-        hllCounter.value.offer(v)
-      }
-      Iterator(hllCounter)
-    }
-    def mergeCounters(c1: SerializableHyperLogLog, c2: SerializableHyperLogLog) = c1.merge(c2)
-
     val zeroCounter = new SerializableHyperLogLog(new HyperLogLog(relativeSD))
-    mapPartitions(hllCountPartition).aggregate(zeroCounter)(mergeCounters, mergeCounters)
-      .value.cardinality()
+    aggregate(zeroCounter)(_.add(_), _.merge(_)).value.cardinality()
   }
 
   /**
diff --git a/core/src/main/scala/org/apache/spark/util/SerializableHyperLogLog.scala b/core/src/main/scala/org/apache/spark/util/SerializableHyperLogLog.scala
index 9cfd41407f..8b4e7c104c 100644
--- a/core/src/main/scala/org/apache/spark/util/SerializableHyperLogLog.scala
+++ b/core/src/main/scala/org/apache/spark/util/SerializableHyperLogLog.scala
@@ -30,6 +30,11 @@ class SerializableHyperLogLog(var value: ICardinality) extends Externalizable {
 
   def merge(other: SerializableHyperLogLog) = new SerializableHyperLogLog(value.merge(other.value))
 
+  def add[T](elem: T) = {
+    this.value.offer(elem)
+    this
+  }
+
   def readExternal(in: ObjectInput) {
     val byteLength = in.readInt()
     val bytes = new Array[Byte](byteLength)
-- 
GitLab