diff --git a/core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala b/core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala
index 088b298aada26ace6c21df446c65516e0fd1e336..04a8d05988f168a2a294c939aa70e9c8ef78824c 100644
--- a/core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala
+++ b/core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala
@@ -218,19 +218,11 @@ class PairRDDFunctions[K: ClassTag, V: ClassTag](self: RDD[(K, V)])
    * Partitioner to partition the output RDD.
    */
   def countApproxDistinctByKey(relativeSD: Double, partitioner: Partitioner): RDD[(K, Long)] = {
-    val createHLL = (v: V) => {
-      val hll = new SerializableHyperLogLog(new HyperLogLog(relativeSD))
-      hll.value.offer(v)
-      hll
-    }
-    val mergeValueHLL = (hll: SerializableHyperLogLog, v: V) => {
-      hll.value.offer(v)
-      hll
-    }
+    val createHLL = (v: V) => new SerializableHyperLogLog(new HyperLogLog(relativeSD)).add(v)
+    val mergeValueHLL = (hll: SerializableHyperLogLog, v: V) => hll.add(v)
     val mergeHLL = (h1: SerializableHyperLogLog, h2: SerializableHyperLogLog) => h1.merge(h2)
 
     combineByKey(createHLL, mergeValueHLL, mergeHLL, partitioner).mapValues(_.value.cardinality())
-
   }
 
   /**
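
A minimal usage sketch of the per-key counter after this change (illustrative, not part of
the patch; sc, the sample data, and the partitioner choice are assumptions):

    import org.apache.spark.HashPartitioner
    val fruits = sc.parallelize(Seq(("a", "apple"), ("a", "apple"), ("a", "pear"), ("b", "plum")))
    val approx = fruits.countApproxDistinctByKey(0.05, new HashPartitioner(2))
    approx.collect()  // roughly Array(("a", 2), ("b", 1)), within the configured relative error
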
diff --git a/core/src/main/scala/org/apache/spark/rdd/RDD.scala b/core/src/main/scala/org/apache/spark/rdd/RDD.scala
index 161fd067e1ea254e1344da046eb257aa901fcd38..4960e6e82f3781ee14469e2c6c25a93b6ba471c5 100644
--- a/core/src/main/scala/org/apache/spark/rdd/RDD.scala
+++ b/core/src/main/scala/org/apache/spark/rdd/RDD.scala
@@ -798,20 +798,9 @@ abstract class RDD[T: ClassTag](
    * relativeSD is 0.05.
    */
   def countApproxDistinct(relativeSD: Double = 0.05): Long = {
-
-    def hllCountPartition(iter: Iterator[T]): Iterator[SerializableHyperLogLog] = {
-      val hllCounter = new SerializableHyperLogLog(new HyperLogLog(relativeSD))
-      while (iter.hasNext) {
-        val v = iter.next()
-        hllCounter.value.offer(v)
-      }
-      Iterator(hllCounter)
-    }
-    def mergeCounters(c1: SerializableHyperLogLog, c2: SerializableHyperLogLog) = c1.merge(c2)
-
     val zeroCounter = new SerializableHyperLogLog(new HyperLogLog(relativeSD))
-    mapPartitions(hllCountPartition).aggregate(zeroCounter)(mergeCounters, mergeCounters)
-      .value.cardinality()
+    // Add each element to a per-partition HLL (seqOp), then merge the partition HLLs (combOp).
+    aggregate(zeroCounter)(_.add(_), _.merge(_)).value.cardinality()
   }
 
   /**
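
For reference, a sketch of the simplified countApproxDistinct (illustrative; sc and the
sample data are assumptions):

    val ids = sc.parallelize(1 to 100000).map(_ % 1000)
    ids.countApproxDistinct(0.05)  // roughly 1000, within the configured relative error

The mapPartitions pass was redundant: aggregate already folds each partition with the seqOp
(here _.add(_)) before combining the per-partition results with the combOp (here _.merge(_)).
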
diff --git a/core/src/main/scala/org/apache/spark/util/SerializableHyperLogLog.scala b/core/src/main/scala/org/apache/spark/util/SerializableHyperLogLog.scala
index 9cfd41407f97b78bac1fbe268e62e30327975f8a..8b4e7c104cb19424f1868c4e48fc5baa3880c34b 100644
--- a/core/src/main/scala/org/apache/spark/util/SerializableHyperLogLog.scala
+++ b/core/src/main/scala/org/apache/spark/util/SerializableHyperLogLog.scala
@@ -30,6 +30,12 @@ class SerializableHyperLogLog(var value: ICardinality) extends Externalizable {
 
   def merge(other: SerializableHyperLogLog) = new SerializableHyperLogLog(value.merge(other.value))
 
+  /** Adds the element to the underlying HyperLogLog and returns this for chaining. */
+  def add[T](elem: T): SerializableHyperLogLog = {
+    this.value.offer(elem)
+    this
+  }
+
   def readExternal(in: ObjectInput) {
     val byteLength = in.readInt()
     val bytes = new Array[Byte](byteLength)
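
A sketch of the new fluent API used directly (illustrative, not part of the patch):

    import com.clearspring.analytics.stream.cardinality.HyperLogLog
    import org.apache.spark.util.SerializableHyperLogLog

    val a = new SerializableHyperLogLog(new HyperLogLog(0.05)).add("x").add("y")
    val b = new SerializableHyperLogLog(new HyperLogLog(0.05)).add("y").add("z")
    a.merge(b).value.cardinality()  // roughly 3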