diff --git a/core/src/main/scala/spark/Partitioner.scala b/core/src/main/scala/spark/Partitioner.scala
index 03966f1c96c0109dbff67a4532ae514ee41d9319..eec0e8dd79da4e70994a57c021077245de612e56 100644
--- a/core/src/main/scala/spark/Partitioner.scala
+++ b/core/src/main/scala/spark/Partitioner.scala
@@ -10,12 +10,21 @@ abstract class Partitioner extends Serializable {
 }
 
 object Partitioner {
+
+  private val useDefaultParallelism = System.getProperty("spark.default.parallelism") != null
+
   /**
-   * Choose a partitioner to use for a cogroup-like operation between a number of RDDs. If any of
-   * the RDDs already has a partitioner, choose that one, otherwise use a default HashPartitioner.
+   * Choose a partitioner to use for a cogroup-like operation between a number of RDDs.
+   *
+   * If any of the RDDs already has a partitioner, choose that one.
    *
-   * The number of partitions will be the same as the number of partitions in the largest upstream
-   * RDD, as this should be least likely to cause out-of-memory errors.
+   * Otherwise, we use a default HashPartitioner. For the number of partitions, if
+   * spark.default.parallelism is set, then we'll use the value from SparkContext
+   * defaultParallelism, otherwise we'll use the max number of upstream partitions.
+   *
+   * Unless spark.default.parallelism is set, the number of partitions will be the
+   * same as the number of partitions in the largest upstream RDD, as this should
+   * be least likely to cause out-of-memory errors.
    *
    * We use two method parameters (rdd, others) to enforce callers passing at least 1 RDD.
    */
@@ -24,7 +33,11 @@ object Partitioner {
     for (r <- bySize if r.partitioner != None) {
       return r.partitioner.get
     }
-    return new HashPartitioner(bySize.head.partitions.size)
+    if (useDefaultParallelism) {
+      return new HashPartitioner(rdd.context.defaultParallelism)
+    } else {
+      return new HashPartitioner(bySize.head.partitions.size)
+    }
   }
 }
 
diff --git a/core/src/test/scala/spark/ShuffleSuite.scala b/core/src/test/scala/spark/ShuffleSuite.scala
index 2099999ed7ee385b66a1052a49121f3e45146f0e..8411291b2caa31e86f58bbf17f39fbf68020a669 100644
--- a/core/src/test/scala/spark/ShuffleSuite.scala
+++ b/core/src/test/scala/spark/ShuffleSuite.scala
@@ -235,7 +235,7 @@ class ShuffleSuite extends FunSuite with ShouldMatchers with LocalSparkContext {
     assert(rdd.values.collect().toList === List("a", "b"))
   }
 
-  test("default partitioner uses split size") {
+  test("default partitioner uses partition size") {
     sc = new SparkContext("local", "test")
     // specify 2000 partitions
     val a = sc.makeRDD(Array(1, 2, 3, 4), 2000)
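
For reference, a rough sketch of the behavior this patch gives Partitioner.defaultPartitioner.
This is a hypothetical snippet against this era's `spark` package API, not part of the patch;
the object name DefaultPartitionerDemo is made up for illustration.

import spark.SparkContext
import spark.SparkContext._   // implicit conversion to PairRDDFunctions for join()

// Hypothetical demo object, not part of the patch.
object DefaultPartitionerDemo {
  def main(args: Array[String]) {
    // With spark.default.parallelism unset, defaultPartitioner falls back to
    // the max number of upstream partitions -- 2000 here, taken from `a`.
    val sc = new SparkContext("local", "demo")
    val a = sc.makeRDD(Array((1, "a"), (2, "b")), 2000)
    val b = sc.makeRDD(Array((1, "x"), (2, "y")), 10)
    println(a.join(b).partitions.size)   // prints 2000

    // Had we done System.setProperty("spark.default.parallelism", "8") before
    // creating the context, the join would use sc.defaultParallelism instead.
    sc.stop()
  }
}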