diff --git a/core/src/main/scala/org/apache/spark/Partitioner.scala b/core/src/main/scala/org/apache/spark/Partitioner.scala
index 1484f29525a4eb18cf7af08d98488cb7753bd7e9..debbd8d7c26c9a005ea17288a708efb9bb5abc7d 100644
--- a/core/src/main/scala/org/apache/spark/Partitioner.scala
+++ b/core/src/main/scala/org/apache/spark/Partitioner.scala
@@ -108,11 +108,21 @@ class HashPartitioner(partitions: Int) extends Partitioner {
 class RangePartitioner[K : Ordering : ClassTag, V](
     partitions: Int,
     rdd: RDD[_ <: Product2[K, V]],
-    private var ascending: Boolean = true)
+    private var ascending: Boolean = true,
+    val samplePointsPerPartitionHint: Int = 20)
   extends Partitioner {
 
+  // A constructor declared in order to maintain backward compatibility for Java, when we add the
+  // 4th constructor parameter samplePointsPerPartitionHint. See SPARK-22160.
+  // This is added to make sure from a bytecode point of view, there is still a 3-arg ctor.
+  def this(partitions: Int, rdd: RDD[_ <: Product2[K, V]], ascending: Boolean) = {
+    this(partitions, rdd, ascending, samplePointsPerPartitionHint = 20)
+  }
+
   // We allow partitions = 0, which happens when sorting an empty RDD under the default settings.
   require(partitions >= 0, s"Number of partitions cannot be negative but found $partitions.")
+  require(samplePointsPerPartitionHint > 0,
+    s"Sample points per partition must be greater than 0 but found $samplePointsPerPartitionHint")
 
   private var ordering = implicitly[Ordering[K]]
 
@@ -122,7 +132,8 @@ class RangePartitioner[K : Ordering : ClassTag, V](
       Array.empty
     } else {
       // This is the sample size we need to have roughly balanced output partitions, capped at 1M.
-      val sampleSize = math.min(20.0 * partitions, 1e6)
+      // Cast to double to avoid overflowing ints or longs
+      val sampleSize = math.min(samplePointsPerPartitionHint.toDouble * partitions, 1e6)
       // Assume the input partitions are roughly balanced and over-sample a little bit.
       val sampleSizePerPartition = math.ceil(3.0 * sampleSize / rdd.partitions.length).toInt
       val (numItems, sketched) = RangePartitioner.sketch(rdd.map(_._1), sampleSizePerPartition)
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
index 358cf621490709f46ebd0e0432e7098a4fdc5191..1a73d168b9b6e85ef65f0826867774c4822a4428 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
@@ -907,6 +907,14 @@ object SQLConf {
     .booleanConf
     .createWithDefault(false)
 
+  val RANGE_EXCHANGE_SAMPLE_SIZE_PER_PARTITION =
+    buildConf("spark.sql.execution.rangeExchange.sampleSizePerPartition")
+      .internal()
+      .doc("Number of points to sample per partition in order to determine the range boundaries" +
+          " for range partitioning, typically used in global sorting (without limit).")
+      .intConf
+      .createWithDefault(100)
+
   val ARROW_EXECUTION_ENABLE =
     buildConf("spark.sql.execution.arrow.enabled")
       .internal()
@@ -1199,6 +1207,8 @@ class SQLConf extends Serializable with Logging {
 
   def supportQuotedRegexColumnName: Boolean = getConf(SUPPORT_QUOTED_REGEX_COLUMN_NAME)
 
+  def rangeExchangeSampleSizePerPartition: Int = getConf(RANGE_EXCHANGE_SAMPLE_SIZE_PER_PARTITION)
+
   def arrowEnable: Boolean = getConf(ARROW_EXECUTION_ENABLE)
 
   def arrowMaxRecordsPerBatch: Int = getConf(ARROW_EXECUTION_MAX_RECORDS_PER_BATCH)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/ShuffleExchangeExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/ShuffleExchangeExec.scala
index 11c4aa9b4acf0f4760822351130542fd162d3c3c..5a1e217082bc2e2d1cbfe7433816dca2101a9992 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/ShuffleExchangeExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/ShuffleExchangeExec.scala
@@ -30,6 +30,7 @@ import org.apache.spark.sql.catalyst.expressions.codegen.LazilyGeneratedOrdering
 import org.apache.spark.sql.catalyst.plans.physical._
 import org.apache.spark.sql.execution._
 import org.apache.spark.sql.execution.metric.SQLMetrics
+import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.util.MutablePair
 
 /**
@@ -218,7 +219,11 @@ object ShuffleExchangeExec {
           iter.map(row => mutablePair.update(row.copy(), null))
         }
         implicit val ordering = new LazilyGeneratedOrdering(sortingExpressions, outputAttributes)
-        new RangePartitioner(numPartitions, rddForSampling, ascending = true)
+        new RangePartitioner(
+          numPartitions,
+          rddForSampling,
+          ascending = true,
+          samplePointsPerPartitionHint = SQLConf.get.rangeExchangeSampleSizePerPartition)
       case SinglePartition =>
         new Partitioner {
           override def numPartitions: Int = 1
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/ConfigBehaviorSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/ConfigBehaviorSuite.scala
new file mode 100644
index 0000000000000000000000000000000000000000..2c1e5db5fd9bb043604f7843a23e8e83eefeca20
--- /dev/null
+++ b/sql/core/src/test/scala/org/apache/spark/sql/ConfigBehaviorSuite.scala
@@ -0,0 +1,66 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql
+
+import org.apache.commons.math3.stat.inference.ChiSquareTest
+
+import org.apache.spark.sql.internal.SQLConf
+import org.apache.spark.sql.test.SharedSQLContext
+
+
+class ConfigBehaviorSuite extends QueryTest with SharedSQLContext {
+
+  import testImplicits._
+
+  test("SPARK-22160 spark.sql.execution.rangeExchange.sampleSizePerPartition") {
+    // In this test, we run a sort and compute the histogram for partition size post shuffle.
+    // With a high sample count, the partition size should be more evenly distributed, and has a
+    // low chi-sq test value.
+    // Also the whole code path for range partitioning as implemented should be deterministic
+    // (it uses the partition id as the seed), so this test shouldn't be flaky.
+
+    val numPartitions = 4
+
+    def computeChiSquareTest(): Double = {
+      val n = 10000
+      // Trigger a sort
+      val data = spark.range(0, n, 1, 1).sort('id)
+        .selectExpr("SPARK_PARTITION_ID() pid", "id").as[(Int, Long)].collect()
+
+      // Compute histogram for the number of records per partition post sort
+      val dist = data.groupBy(_._1).map(_._2.length.toLong).toArray
+      assert(dist.length == 4)
+
+      new ChiSquareTest().chiSquare(
+        Array.fill(numPartitions) { n.toDouble / numPartitions },
+        dist)
+    }
+
+    withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> numPartitions.toString) {
+      // The default chi-sq value should be low
+      assert(computeChiSquareTest() < 100)
+
+      withSQLConf(SQLConf.RANGE_EXCHANGE_SAMPLE_SIZE_PER_PARTITION.key -> "1") {
+        // If we only sample one point, the range boundaries will be pretty bad and the
+        // chi-sq value would be very high.
+        assert(computeChiSquareTest() > 1000)
+      }
+    }
+  }
+
+}
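For reference, here is a minimal standalone sketch (not part of the patch) of how the two knobs introduced above could be exercised: the new samplePointsPerPartitionHint constructor parameter on RangePartitioner, and the spark.sql.execution.rangeExchange.sampleSizePerPartition config that feeds it for SQL range exchanges. The object name, local master, toy data, and output inspection are illustrative assumptions; only RangePartitioner and the config key come from the change itself.

import org.apache.spark.RangePartitioner
import org.apache.spark.sql.SparkSession

// Hypothetical demo object; names and data sizes are made up for illustration.
object SamplePointsHintDemo {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder()
      .master("local[4]")
      .appName("SPARK-22160 demo")
      .getOrCreate()

    // RDD side: pass the new 4th parameter explicitly. The extra 3-arg constructor added by the
    // patch keeps existing call sites working. Constructing a RangePartitioner runs a sampling job.
    val pairs = spark.sparkContext.parallelize(0 until 10000, 2).map(i => (i, i.toString))
    val coarse = new RangePartitioner(4, pairs, ascending = true, samplePointsPerPartitionHint = 1)
    val fine = new RangePartitioner(4, pairs, ascending = true, samplePointsPerPartitionHint = 100)
    println(s"partitions with hint=1: ${coarse.numPartitions}, with hint=100: ${fine.numPartitions}")

    // SQL side: the hint is read from the new (internal) config when a range exchange is planned,
    // e.g. for the global sort below. 100 is the new default; 1 tends to skew partition sizes.
    spark.conf.set("spark.sql.shuffle.partitions", "4")
    spark.conf.set("spark.sql.execution.rangeExchange.sampleSizePerPartition", "1")
    spark.range(0, 10000, 1, 1).sort("id")
      .selectExpr("SPARK_PARTITION_ID() AS pid")   // partition id of the sort output
      .groupBy("pid").count()                      // rows per post-sort partition
      .show()

    spark.stop()
  }
}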