diff --git a/core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala b/core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala
index 088b298aada26ace6c21df446c65516e0fd1e336..04a8d05988f168a2a294c939aa70e9c8ef78824c 100644
--- a/core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala
+++ b/core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala
@@ -218,19 +218,11 @@ class PairRDDFunctions[K: ClassTag, V: ClassTag](self: RDD[(K, V)])
    * Partitioner to partition the output RDD.
    */
   def countApproxDistinctByKey(relativeSD: Double, partitioner: Partitioner): RDD[(K, Long)] = {
-    val createHLL = (v: V) => {
-      val hll = new SerializableHyperLogLog(new HyperLogLog(relativeSD))
-      hll.value.offer(v)
-      hll
-    }
-    val mergeValueHLL = (hll: SerializableHyperLogLog, v: V) => {
-      hll.value.offer(v)
-      hll
-    }
+    val createHLL = (v: V) => new SerializableHyperLogLog(new HyperLogLog(relativeSD)).add(v)
+    val mergeValueHLL = (hll: SerializableHyperLogLog, v: V) => hll.add(v)
     val mergeHLL = (h1: SerializableHyperLogLog, h2: SerializableHyperLogLog) => h1.merge(h2)
 
     combineByKey(createHLL, mergeValueHLL, mergeHLL, partitioner).mapValues(_.value.cardinality())
-
   }
 
   /**
diff --git a/core/src/main/scala/org/apache/spark/rdd/RDD.scala b/core/src/main/scala/org/apache/spark/rdd/RDD.scala
index 161fd067e1ea254e1344da046eb257aa901fcd38..4960e6e82f3781ee14469e2c6c25a93b6ba471c5 100644
--- a/core/src/main/scala/org/apache/spark/rdd/RDD.scala
+++ b/core/src/main/scala/org/apache/spark/rdd/RDD.scala
@@ -798,20 +798,8 @@ abstract class RDD[T: ClassTag](
    * relativeSD is 0.05.
    */
   def countApproxDistinct(relativeSD: Double = 0.05): Long = {
-
-    def hllCountPartition(iter: Iterator[T]): Iterator[SerializableHyperLogLog] = {
-      val hllCounter = new SerializableHyperLogLog(new HyperLogLog(relativeSD))
-      while (iter.hasNext) {
-        val v = iter.next()
-        hllCounter.value.offer(v)
-      }
-      Iterator(hllCounter)
-    }
-    def mergeCounters(c1: SerializableHyperLogLog, c2: SerializableHyperLogLog) = c1.merge(c2)
-
     val zeroCounter = new SerializableHyperLogLog(new HyperLogLog(relativeSD))
-    mapPartitions(hllCountPartition).aggregate(zeroCounter)(mergeCounters, mergeCounters)
-      .value.cardinality()
+    aggregate(zeroCounter)(_.add(_), _.merge(_)).value.cardinality()
   }
 
   /**
diff --git a/core/src/main/scala/org/apache/spark/util/SerializableHyperLogLog.scala b/core/src/main/scala/org/apache/spark/util/SerializableHyperLogLog.scala
index 9cfd41407f97b78bac1fbe268e62e30327975f8a..8b4e7c104cb19424f1868c4e48fc5baa3880c34b 100644
--- a/core/src/main/scala/org/apache/spark/util/SerializableHyperLogLog.scala
+++ b/core/src/main/scala/org/apache/spark/util/SerializableHyperLogLog.scala
@@ -30,6 +30,11 @@ class SerializableHyperLogLog(var value: ICardinality) extends Externalizable {
 
   def merge(other: SerializableHyperLogLog) = new SerializableHyperLogLog(value.merge(other.value))
 
+  def add[T](elem: T) = {
+    this.value.offer(elem)
+    this
+  }
+
   def readExternal(in: ObjectInput) {
     val byteLength = in.readInt()
     val bytes = new Array[Byte](byteLength)
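
For reviewers, a brief usage sketch of the two public methods this patch simplifies. Everything below is illustrative and not part of the change: the object name, the local-mode SparkContext setup, and the sample data are assumptions.

import org.apache.spark.{HashPartitioner, SparkConf, SparkContext}
import org.apache.spark.SparkContext._  // implicit conversion to PairRDDFunctions

object CountApproxDistinctSketch {
  def main(args: Array[String]): Unit = {
    val sc = new SparkContext(new SparkConf().setAppName("hll-sketch").setMaster("local[2]"))

    // Whole-RDD variant (RDD.scala): a single HyperLogLog counter is now folded
    // over all elements with aggregate(zeroCounter)(_.add(_), _.merge(_)).
    val distinct: Long = sc.parallelize(Seq(1, 2, 2, 3, 3, 3)).countApproxDistinct(relativeSD = 0.05)
    println(s"approx distinct = $distinct")  // ~3, within the requested 5% relative SD

    // Per-key variant (PairRDDFunctions.scala): one counter per key via combineByKey,
    // with createHLL and mergeValueHLL now expressed through the fluent add.
    sc.parallelize(Seq(("a", 1), ("a", 2), ("b", 2)))
      .countApproxDistinctByKey(0.05, new HashPartitioner(2))
      .collect()
      .foreach { case (k, c) => println(s"$k -> $c") }  // a -> ~2, b -> ~1

    sc.stop()
  }
}

The fluent add works because it returns this after calling offer, so a counter can serve both as the seed and as the accumulator of aggregate or combineByKey without the wrapper closures the old code needed.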