diff --git a/core/src/main/scala/org/apache/spark/api/java/JavaPairRDD.scala b/core/src/main/scala/org/apache/spark/api/java/JavaPairRDD.scala
index 7dcfbf741c4f127259e607caa62da107d7dcd676..14fa9d8135afea3662270d27e7f245d6b3d15bab 100644
--- a/core/src/main/scala/org/apache/spark/api/java/JavaPairRDD.scala
+++ b/core/src/main/scala/org/apache/spark/api/java/JavaPairRDD.scala
@@ -228,6 +228,50 @@ class JavaPairRDD[K, V](val rdd: RDD[(K, V)])
   : PartialResult[java.util.Map[K, BoundedDouble]] =
     rdd.countByKeyApprox(timeout, confidence).map(mapAsJavaMap)
 
+  /**
+   * Aggregate the values of each key, using given combine functions and a neutral "zero value".
+   * This function can return a different result type, U, than the type of the values in this RDD,
+   * V. Thus, we need one operation for merging a V into a U and one operation for merging two U's,
+   * as in scala.TraversableOnce. The former operation is used for merging values within a
+   * partition, and the latter is used for merging values between partitions. To avoid memory
+   * allocation, both of these functions are allowed to modify and return their first argument
+   * instead of creating a new U.
+   */
+  def aggregateByKey[U](zeroValue: U, partitioner: Partitioner, seqFunc: JFunction2[U, V, U],
+      combFunc: JFunction2[U, U, U]): JavaPairRDD[K, U] = {
+    implicit val ctag: ClassTag[U] = fakeClassTag
+    fromRDD(rdd.aggregateByKey(zeroValue, partitioner)(seqFunc, combFunc))
+  }
+
+  /**
+   * Aggregate the values of each key, using given combine functions and a neutral "zero value".
+   * This function can return a different result type, U, than the type of the values in this RDD,
+   * V. Thus, we need one operation for merging a V into a U and one operation for merging two U's,
+   * as in scala.TraversableOnce. The former operation is used for merging values within a
+   * partition, and the latter is used for merging values between partitions. To avoid memory
+   * allocation, both of these functions are allowed to modify and return their first argument
+   * instead of creating a new U.
+   */
+  def aggregateByKey[U](zeroValue: U, numPartitions: Int, seqFunc: JFunction2[U, V, U],
+      combFunc: JFunction2[U, U, U]): JavaPairRDD[K, U] = {
+    implicit val ctag: ClassTag[U] = fakeClassTag
+    fromRDD(rdd.aggregateByKey(zeroValue, numPartitions)(seqFunc, combFunc))
+  }
+
+  /**
+   * Aggregate the values of each key, using given combine functions and a neutral "zero value".
+   * This function can return a different result type, U, than the type of the values in this RDD,
+   * V. Thus, we need one operation for merging a V into a U and one operation for merging two U's.
+   * The former operation is used for merging values within a partition, and the latter is used for
+   * merging values between partitions. To avoid memory allocation, both of these functions are
+   * allowed to modify and return their first argument instead of creating a new U.
+   */
+  def aggregateByKey[U](zeroValue: U, seqFunc: JFunction2[U, V, U], combFunc: JFunction2[U, U, U]):
+      JavaPairRDD[K, U] = {
+    implicit val ctag: ClassTag[U] = fakeClassTag
+    fromRDD(rdd.aggregateByKey(zeroValue)(seqFunc, combFunc))
+  }
+
   /**
    * Merge the values for each key using an associative function and a neutral "zero value" which
    * may be added to the result an arbitrary number of times, and must not change the result
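Editor's note, not part of the patch: the Java wrappers above simply delegate to the new Scala `rdd.aggregateByKey` added in PairRDDFunctions below. To illustrate the point the doc comments make, here is a small Scala usage sketch that builds a (sum, count) accumulator per key, i.e. a result type U that differs from the value type V. The SparkContext `sc` and the `scores` data are illustrative assumptions.

import org.apache.spark.SparkContext._   // brings in the pair RDD functions, including aggregateByKey

val scores = sc.parallelize(Seq(("a", 3), ("b", 5), ("a", 7)))

// U = (Int, Int) running (sum, count) pair, while V = Int
val sumAndCount = scores.aggregateByKey((0, 0))(
  (acc, v) => (acc._1 + v, acc._2 + 1),      // seqOp: fold a value into the per-partition accumulator
  (l, r) => (l._1 + r._1, l._2 + r._2))      // combOp: merge accumulators from different partitions

val means = sumAndCount.mapValues { case (sum, count) => sum.toDouble / count }
// means.collect() would give something like Array(("a", 5.0), ("b", 5.0)) (ordering not guaranteed)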
diff --git a/core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala b/core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala
index 8909980957058ef71baaf341d24a8044323b3e66..b6ad9b6c3e1682907048f4058d77cf541e0358c1 100644
--- a/core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala
+++ b/core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala
@@ -118,6 +118,56 @@ class PairRDDFunctions[K, V](self: RDD[(K, V)])
     combineByKey(createCombiner, mergeValue, mergeCombiners, new HashPartitioner(numPartitions))
   }
 
+  /**
+   * Aggregate the values of each key, using given combine functions and a neutral "zero value".
+   * This function can return a different result type, U, than the type of the values in this RDD,
+   * V. Thus, we need one operation for merging a V into a U and one operation for merging two U's,
+   * as in scala.TraversableOnce. The former operation is used for merging values within a
+   * partition, and the latter is used for merging values between partitions. To avoid memory
+   * allocation, both of these functions are allowed to modify and return their first argument
+   * instead of creating a new U.
+   */
+  def aggregateByKey[U: ClassTag](zeroValue: U, partitioner: Partitioner)(seqOp: (U, V) => U,
+      combOp: (U, U) => U): RDD[(K, U)] = {
+    // Serialize the zero value to a byte array so that we can get a new clone of it on each key
+    val zeroBuffer = SparkEnv.get.closureSerializer.newInstance().serialize(zeroValue)
+    val zeroArray = new Array[Byte](zeroBuffer.limit)
+    zeroBuffer.get(zeroArray)
+
+    lazy val cachedSerializer = SparkEnv.get.closureSerializer.newInstance()
+    def createZero() = cachedSerializer.deserialize[U](ByteBuffer.wrap(zeroArray))
+
+    combineByKey[U]((v: V) => seqOp(createZero(), v), seqOp, combOp, partitioner)
+  }
+
+  /**
+   * Aggregate the values of each key, using given combine functions and a neutral "zero value".
+   * This function can return a different result type, U, than the type of the values in this RDD,
+   * V. Thus, we need one operation for merging a V into a U and one operation for merging two U's,
+   * as in scala.TraversableOnce. The former operation is used for merging values within a
+   * partition, and the latter is used for merging values between partitions. To avoid memory
+   * allocation, both of these functions are allowed to modify and return their first argument
+   * instead of creating a new U.
+   */
+  def aggregateByKey[U: ClassTag](zeroValue: U, numPartitions: Int)(seqOp: (U, V) => U,
+      combOp: (U, U) => U): RDD[(K, U)] = {
+    aggregateByKey(zeroValue, new HashPartitioner(numPartitions))(seqOp, combOp)
+  }
+
+  /**
+   * Aggregate the values of each key, using given combine functions and a neutral "zero value".
+   * This function can return a different result type, U, than the type of the values in this RDD,
+   * V. Thus, we need one operation for merging a V into a U and one operation for merging two U's,
+   * as in scala.TraversableOnce. The former operation is used for merging values within a
+   * partition, and the latter is used for merging values between partitions. To avoid memory
+   * allocation, both of these functions are allowed to modify and return their first argument
+   * instead of creating a new U.
+   */
+  def aggregateByKey[U: ClassTag](zeroValue: U)(seqOp: (U, V) => U,
+      combOp: (U, U) => U): RDD[(K, U)] = {
+    aggregateByKey(zeroValue, defaultPartitioner(self))(seqOp, combOp)
+  }
+
   /**
    * Merge the values for each key using an associative function and a neutral "zero value" which
    * may be added to the result an arbitrary number of times, and must not change the result
diff --git a/core/src/test/java/org/apache/spark/JavaAPISuite.java b/core/src/test/java/org/apache/spark/JavaAPISuite.java
index 50a62129116f188e7e8aca9a3db2750ff00cf36d..ef41bfb88de9d79a9b3bde452c14b74651ec12df 100644
--- a/core/src/test/java/org/apache/spark/JavaAPISuite.java
+++ b/core/src/test/java/org/apache/spark/JavaAPISuite.java
@@ -317,6 +317,37 @@ public class JavaAPISuite implements Serializable {
     Assert.assertEquals(33, sum);
   }
 
+  @Test
+  public void aggregateByKey() {
+    JavaPairRDD<Integer, Integer> pairs = sc.parallelizePairs(
+      Arrays.asList(
+        new Tuple2<Integer, Integer>(1, 1),
+        new Tuple2<Integer, Integer>(1, 1),
+        new Tuple2<Integer, Integer>(3, 2),
+        new Tuple2<Integer, Integer>(5, 1),
+        new Tuple2<Integer, Integer>(5, 3)), 2);
+
+    Map<Integer, Set<Integer>> sets = pairs.aggregateByKey(new HashSet<Integer>(),
+      new Function2<Set<Integer>, Integer, Set<Integer>>() {
+        @Override
+        public Set<Integer> call(Set<Integer> a, Integer b) {
+          a.add(b);
+          return a;
+        }
+      },
+      new Function2<Set<Integer>, Set<Integer>, Set<Integer>>() {
+        @Override
+        public Set<Integer> call(Set<Integer> a, Set<Integer> b) {
+          a.addAll(b);
+          return a;
+        }
+      }).collectAsMap();
+    Assert.assertEquals(3, sets.size());
+    Assert.assertEquals(new HashSet<Integer>(Arrays.asList(1)), sets.get(1));
+    Assert.assertEquals(new HashSet<Integer>(Arrays.asList(2)), sets.get(3));
+    Assert.assertEquals(new HashSet<Integer>(Arrays.asList(1, 3)), sets.get(5));
+  }
+
   @SuppressWarnings("unchecked")
   @Test
   public void foldByKey() {
diff --git a/core/src/test/scala/org/apache/spark/rdd/PairRDDFunctionsSuite.scala b/core/src/test/scala/org/apache/spark/rdd/PairRDDFunctionsSuite.scala
index 9ddafc451878dc9dbd493a732d3791c1a978e9c4..0b9004448a63e05d4271f6f49afefbc10fdc7918 100644
--- a/core/src/test/scala/org/apache/spark/rdd/PairRDDFunctionsSuite.scala
+++ b/core/src/test/scala/org/apache/spark/rdd/PairRDDFunctionsSuite.scala
@@ -30,6 +30,19 @@ import org.apache.spark.SparkContext._
 import org.apache.spark.{Partitioner, SharedSparkContext}
 
 class PairRDDFunctionsSuite extends FunSuite with SharedSparkContext {
+  test("aggregateByKey") {
+    val pairs = sc.parallelize(Array((1, 1), (1, 1), (3, 2), (5, 1), (5, 3)), 2)
+
+    val sets = pairs.aggregateByKey(new HashSet[Int]())(_ += _, _ ++= _).collect()
+    assert(sets.size === 3)
+    val valuesFor1 = sets.find(_._1 == 1).get._2
+    assert(valuesFor1.toList.sorted === List(1))
+    val valuesFor3 = sets.find(_._1 == 3).get._2
+    assert(valuesFor3.toList.sorted === List(2))
+    val valuesFor5 = sets.find(_._1 == 5).get._2
+    assert(valuesFor5.toList.sorted === List(1, 3))
+  }
+
   test("groupByKey") {
     val pairs = sc.parallelize(Array((1, 1), (1, 2), (1, 3), (2, 1)))
     val groups = pairs.groupByKey().collect()
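Editor's note, not part of the patch: the `createZero` plumbing above exists because both functions may mutate their first argument, so the zero value itself may be mutable. Rather than seeding every key with the same instance, the implementation serializes `zeroValue` once and deserializes a fresh copy per key. A small sketch of what that buys (assumes a SparkContext `sc` and `import org.apache.spark.SparkContext._`, as in the suites above):

import scala.collection.mutable.HashSet

val pairs = sc.parallelize(Seq((1, 1), (1, 2), (2, 3)), 2)

// The zero value is a mutable set and seqOp mutates its accumulator in place; this is safe
// because each key starts from its own deserialized copy of the zero value.
val perKeySets = pairs
  .aggregateByKey(new HashSet[Int]())(_ += _, _ ++= _)
  .collectAsMap()
// perKeySets(1) == Set(1, 2) and perKeySets(2) == Set(3); no values leak between keys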
diff --git a/docs/programming-guide.md b/docs/programming-guide.md
index 7989e02dfb732b11b18a6e9ba910698bfff0913a..79784682bfd1b11919695862d1b0a63cd11b2801 100644
--- a/docs/programming-guide.md
+++ b/docs/programming-guide.md
@@ -890,6 +890,10 @@ for details.
   <td> <b>reduceByKey</b>(<i>func</i>, [<i>numTasks</i>]) </td>
   <td> When called on a dataset of (K, V) pairs, returns a dataset of (K, V) pairs where the values for each key are aggregated using the given reduce function. Like in <code>groupByKey</code>, the number of reduce tasks is configurable through an optional second argument. </td>
 </tr>
+<tr>
+  <td> <b>aggregateByKey</b>(<i>zeroValue</i>)(<i>seqOp</i>, <i>combOp</i>, [<i>numTasks</i>]) </td>
+  <td> When called on a dataset of (K, V) pairs, returns a dataset of (K, U) pairs where the values for each key are aggregated using the given combine functions and a neutral "zero" value. Allows an aggregated value type that is different from the input value type, while avoiding unnecessary allocations. Like in <code>groupByKey</code>, the number of reduce tasks is configurable through an optional second argument. </td>
+</tr>
 <tr>
   <td> <b>sortByKey</b>([<i>ascending</i>], [<i>numTasks</i>]) </td>
   <td> When called on a dataset of (K, V) pairs where K implements Ordered, returns a dataset of (K, V) pairs sorted by keys in ascending or descending order, as specified in the boolean <code>ascending</code> argument.</td>
diff --git a/project/MimaExcludes.scala b/project/MimaExcludes.scala
index c80ab9a9f8e60c14abdfdec63467474aec9ce04b..ee629794f60ad01230c3aba1191e7aede09f6fc7 100644
--- a/project/MimaExcludes.scala
+++ b/project/MimaExcludes.scala
@@ -52,7 +52,10 @@ object MimaExcludes {
             ProblemFilters.exclude[MissingMethodProblem](
               "org.apache.spark.api.java.JavaRDDLike.countApproxDistinct$default$1"),
             ProblemFilters.exclude[MissingMethodProblem](
-              "org.apache.spark.api.java.JavaDoubleRDD.countApproxDistinct$default$1")
+              "org.apache.spark.api.java.JavaDoubleRDD.countApproxDistinct$default$1"),
+            ProblemFilters.exclude[MissingMethodProblem](
+              "org.apache.spark.rdd.PairRDDFunctions.org$apache$spark$rdd$PairRDDFunctions$$" +
+              "createZero$1")
           ) ++ Seq(
             // Ignore some private methods in ALS.
             ProblemFilters.exclude[MissingMethodProblem](
diff --git a/python/pyspark/rdd.py b/python/pyspark/rdd.py
index 8a215fc51130a250f4c320dac036e1a84c6c2825..735389c69831c7a047a563ee7aae4e38e4f38ab7 100644
--- a/python/pyspark/rdd.py
+++ b/python/pyspark/rdd.py
@@ -1178,6 +1178,20 @@ class RDD(object):
                 combiners[k] = mergeCombiners(combiners[k], v)
             return combiners.iteritems()
         return shuffled.mapPartitions(_mergeCombiners)
+
+    def aggregateByKey(self, zeroValue, seqFunc, combFunc, numPartitions=None):
+        """
+        Aggregate the values of each key, using given combine functions and a neutral "zero value".
+        This function can return a different result type, U, than the type of the values in this RDD,
+        V. Thus, we need one operation for merging a V into a U and one operation for merging two U's.
+        The former operation is used for merging values within a partition, and the latter is used
+        for merging values between partitions. To avoid memory allocation, both of these functions are
+        allowed to modify and return their first argument instead of creating a new U.
+ """ + def createZero(): + return copy.deepcopy(zeroValue) + + return self.combineByKey(lambda v: seqFunc(createZero(), v), seqFunc, combFunc, numPartitions) def foldByKey(self, zeroValue, func, numPartitions=None): """ @@ -1190,7 +1204,10 @@ class RDD(object): >>> rdd.foldByKey(0, add).collect() [('a', 2), ('b', 1)] """ - return self.combineByKey(lambda v: func(zeroValue, v), func, func, numPartitions) + def createZero(): + return copy.deepcopy(zeroValue) + + return self.combineByKey(lambda v: func(createZero(), v), func, func, numPartitions) # TODO: support variant with custom partitioner diff --git a/python/pyspark/tests.py b/python/pyspark/tests.py index 184ee810b861b9c7c20c320d1b79f1bff8262202..c15bb457759edf6ed8ec8024b2c29721c6027b1c 100644 --- a/python/pyspark/tests.py +++ b/python/pyspark/tests.py @@ -188,6 +188,21 @@ class TestRDDFunctions(PySparkTestCase): os.unlink(tempFile.name) self.assertRaises(Exception, lambda: filtered_data.count()) + def testAggregateByKey(self): + data = self.sc.parallelize([(1, 1), (1, 1), (3, 2), (5, 1), (5, 3)], 2) + def seqOp(x, y): + x.add(y) + return x + + def combOp(x, y): + x |= y + return x + + sets = dict(data.aggregateByKey(set(), seqOp, combOp).collect()) + self.assertEqual(3, len(sets)) + self.assertEqual(set([1]), sets[1]) + self.assertEqual(set([2]), sets[3]) + self.assertEqual(set([1, 3]), sets[5]) class TestIO(PySparkTestCase):