diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala index ccebae3cc2701ddc6abdef3d195c518fb7f59efd..4d27ff2acdbadb18db62667120792826c9385200 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala @@ -752,14 +752,13 @@ case class GlobalLimit(limitExpr: Expression, child: LogicalPlan) extends UnaryN } override def computeStats(conf: CatalystConf): Statistics = { val limit = limitExpr.eval().asInstanceOf[Int] - val sizeInBytes = if (limit == 0) { - // sizeInBytes can't be zero, or sizeInBytes of BinaryNode will also be zero - // (product of children). - 1 - } else { - (limit: Long) * output.map(a => a.dataType.defaultSize).sum - } - child.stats(conf).copy(sizeInBytes = sizeInBytes) + val childStats = child.stats(conf) + val rowCount: BigInt = childStats.rowCount.map(_.min(limit)).getOrElse(limit) + // Don't propagate column stats, because we don't know the distribution after a limit operation + Statistics( + sizeInBytes = EstimationUtils.getOutputSize(output, rowCount, childStats.attributeStats), + rowCount = Some(rowCount), + isBroadcastable = childStats.isBroadcastable) } } @@ -773,14 +772,21 @@ case class LocalLimit(limitExpr: Expression, child: LogicalPlan) extends UnaryNo } override def computeStats(conf: CatalystConf): Statistics = { val limit = limitExpr.eval().asInstanceOf[Int] - val sizeInBytes = if (limit == 0) { + val childStats = child.stats(conf) + if (limit == 0) { // sizeInBytes can't be zero, or sizeInBytes of BinaryNode will also be zero // (product of children). - 1 + Statistics( + sizeInBytes = 1, + rowCount = Some(0), + isBroadcastable = childStats.isBroadcastable) } else { - (limit: Long) * output.map(a => a.dataType.defaultSize).sum + // The output row count of LocalLimit should be the sum of row counts from each partition. + // However, since the number of partitions is not available here, we just use statistics of + // the child. Because the distribution after a limit operation is unknown, we do not propagate + // the column stats. + childStats.copy(attributeStats = AttributeMap(Nil)) } - child.stats(conf).copy(sizeInBytes = sizeInBytes) } } @@ -816,12 +822,14 @@ case class Sample( override def computeStats(conf: CatalystConf): Statistics = { val ratio = upperBound - lowerBound - // BigInt can't multiply with Double - var sizeInBytes = child.stats(conf).sizeInBytes * (ratio * 100).toInt / 100 + val childStats = child.stats(conf) + var sizeInBytes = EstimationUtils.ceil(BigDecimal(childStats.sizeInBytes) * ratio) if (sizeInBytes == 0) { sizeInBytes = 1 } - child.stats(conf).copy(sizeInBytes = sizeInBytes) + val sampledRowCount = childStats.rowCount.map(c => EstimationUtils.ceil(BigDecimal(c) * ratio)) + // Don't propagate column stats, because we don't know the distribution after a sample operation + Statistics(sizeInBytes, sampledRowCount, isBroadcastable = childStats.isBroadcastable) } override protected def otherCopyArgs: Seq[AnyRef] = isTableSample :: Nil diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/BasicStatsEstimationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/BasicStatsEstimationSuite.scala new file mode 100644 index 0000000000000000000000000000000000000000..e5dc811c8b7db3c2d87da949fffa7f3e6f31bba1 --- /dev/null +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/BasicStatsEstimationSuite.scala @@ -0,0 +1,122 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.statsEstimation + +import org.apache.spark.sql.catalyst.CatalystConf +import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeMap, AttributeReference, Literal} +import org.apache.spark.sql.catalyst.plans.logical._ +import org.apache.spark.sql.types.IntegerType + + +class BasicStatsEstimationSuite extends StatsEstimationTestBase { + val attribute = attr("key") + val colStat = ColumnStat(distinctCount = 10, min = Some(1), max = Some(10), + nullCount = 0, avgLen = 4, maxLen = 4) + + val plan = StatsTestPlan( + outputList = Seq(attribute), + attributeStats = AttributeMap(Seq(attribute -> colStat)), + rowCount = 10, + // row count * (overhead + column size) + size = Some(10 * (8 + 4))) + + test("limit estimation: limit < child's rowCount") { + val localLimit = LocalLimit(Literal(2), plan) + val globalLimit = GlobalLimit(Literal(2), plan) + // LocalLimit's stats is just its child's stats except column stats + checkStats(localLimit, plan.stats(conf).copy(attributeStats = AttributeMap(Nil))) + checkStats(globalLimit, Statistics(sizeInBytes = 24, rowCount = Some(2))) + } + + test("limit estimation: limit > child's rowCount") { + val localLimit = LocalLimit(Literal(20), plan) + val globalLimit = GlobalLimit(Literal(20), plan) + checkStats(localLimit, plan.stats(conf).copy(attributeStats = AttributeMap(Nil))) + // Limit is larger than child's rowCount, so GlobalLimit's stats is equal to its child's stats. + checkStats(globalLimit, plan.stats(conf).copy(attributeStats = AttributeMap(Nil))) + } + + test("limit estimation: limit = 0") { + val localLimit = LocalLimit(Literal(0), plan) + val globalLimit = GlobalLimit(Literal(0), plan) + val stats = Statistics(sizeInBytes = 1, rowCount = Some(0)) + checkStats(localLimit, stats) + checkStats(globalLimit, stats) + } + + test("sample estimation") { + val sample = Sample(0.0, 0.5, withReplacement = false, (math.random * 1000).toLong, plan)() + checkStats(sample, Statistics(sizeInBytes = 60, rowCount = Some(5))) + + // Child doesn't have rowCount in stats + val childStats = Statistics(sizeInBytes = 120) + val childPlan = DummyLogicalPlan(childStats, childStats) + val sample2 = + Sample(0.0, 0.11, withReplacement = false, (math.random * 1000).toLong, childPlan)() + checkStats(sample2, Statistics(sizeInBytes = 14)) + } + + test("estimate statistics when the conf changes") { + val expectedDefaultStats = + Statistics( + sizeInBytes = 40, + rowCount = Some(10), + attributeStats = AttributeMap(Seq( + AttributeReference("c1", IntegerType)() -> ColumnStat(10, Some(1), Some(10), 0, 4, 4))), + isBroadcastable = false) + val expectedCboStats = + Statistics( + sizeInBytes = 4, + rowCount = Some(1), + attributeStats = AttributeMap(Seq( + AttributeReference("c1", IntegerType)() -> ColumnStat(1, Some(5), Some(5), 0, 4, 4))), + isBroadcastable = false) + + val plan = DummyLogicalPlan(defaultStats = expectedDefaultStats, cboStats = expectedCboStats) + checkStats( + plan, expectedStatsCboOn = expectedCboStats, expectedStatsCboOff = expectedDefaultStats) + } + + /** Check estimated stats when cbo is turned on/off. */ + private def checkStats( + plan: LogicalPlan, + expectedStatsCboOn: Statistics, + expectedStatsCboOff: Statistics): Unit = { + assert(plan.stats(conf.copy(cboEnabled = true)) == expectedStatsCboOn) + // Invalidate statistics + plan.invalidateStatsCache() + assert(plan.stats(conf.copy(cboEnabled = false)) == expectedStatsCboOff) + } + + /** Check estimated stats when it's the same whether cbo is turned on or off. */ + private def checkStats(plan: LogicalPlan, expectedStats: Statistics): Unit = + checkStats(plan, expectedStats, expectedStats) +} + +/** + * This class is used for unit-testing the cbo switch, it mimics a logical plan which computes + * a simple statistics or a cbo estimated statistics based on the conf. + */ +private case class DummyLogicalPlan( + defaultStats: Statistics, + cboStats: Statistics) extends LogicalPlan { + override def output: Seq[Attribute] = Nil + override def children: Seq[LogicalPlan] = Nil + override def computeStats(conf: CatalystConf): Statistics = + if (conf.cboEnabled) cboStats else defaultStats +} diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/StatsConfSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/StatsConfSuite.scala deleted file mode 100644 index 212d57a9bcf9505f0f5375ba06339c604f40b6c3..0000000000000000000000000000000000000000 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/StatsConfSuite.scala +++ /dev/null @@ -1,64 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.catalyst.statsEstimation - -import org.apache.spark.sql.catalyst.CatalystConf -import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeMap, AttributeReference} -import org.apache.spark.sql.catalyst.plans.logical.{ColumnStat, LogicalPlan, Statistics} -import org.apache.spark.sql.types.IntegerType - - -class StatsConfSuite extends StatsEstimationTestBase { - test("estimate statistics when the conf changes") { - val expectedDefaultStats = - Statistics( - sizeInBytes = 40, - rowCount = Some(10), - attributeStats = AttributeMap(Seq( - AttributeReference("c1", IntegerType)() -> ColumnStat(10, Some(1), Some(10), 0, 4, 4))), - isBroadcastable = false) - val expectedCboStats = - Statistics( - sizeInBytes = 4, - rowCount = Some(1), - attributeStats = AttributeMap(Seq( - AttributeReference("c1", IntegerType)() -> ColumnStat(1, Some(5), Some(5), 0, 4, 4))), - isBroadcastable = false) - - val plan = DummyLogicalPlan(defaultStats = expectedDefaultStats, cboStats = expectedCboStats) - // Return the statistics estimated by cbo - assert(plan.stats(conf.copy(cboEnabled = true)) == expectedCboStats) - // Invalidate statistics - plan.invalidateStatsCache() - // Return the simple statistics - assert(plan.stats(conf.copy(cboEnabled = false)) == expectedDefaultStats) - } -} - -/** - * This class is used for unit-testing the cbo switch, it mimics a logical plan which computes - * a simple statistics or a cbo estimated statistics based on the conf. - */ -private case class DummyLogicalPlan( - defaultStats: Statistics, - cboStats: Statistics) extends LogicalPlan { - override def output: Seq[Attribute] = Nil - override def children: Seq[LogicalPlan] = Nil - override def computeStats(conf: CatalystConf): Statistics = - if (conf.cboEnabled) cboStats else defaultStats -} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/StatisticsCollectionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/StatisticsCollectionSuite.scala index bbb31dbc8f3de43821ef9a7864277384521ff269..1f547c5a2a8ff628bb251b2bb70c463b40e92423 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/StatisticsCollectionSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/StatisticsCollectionSuite.scala @@ -112,30 +112,6 @@ class StatisticsCollectionSuite extends StatisticsCollectionTestBase with Shared spark.sessionState.conf.autoBroadcastJoinThreshold) } - test("estimates the size of limit") { - withTempView("test") { - Seq(("one", 1), ("two", 2), ("three", 3), ("four", 4)).toDF("k", "v") - .createOrReplaceTempView("test") - Seq((0, 1), (1, 24), (2, 48)).foreach { case (limit, expected) => - val df = sql(s"""SELECT * FROM test limit $limit""") - - val sizesGlobalLimit = df.queryExecution.analyzed.collect { case g: GlobalLimit => - g.stats(conf).sizeInBytes - } - assert(sizesGlobalLimit.size === 1, s"Size wrong for:\n ${df.queryExecution}") - assert(sizesGlobalLimit.head === BigInt(expected), - s"expected exact size $expected for table 'test', got: ${sizesGlobalLimit.head}") - - val sizesLocalLimit = df.queryExecution.analyzed.collect { case l: LocalLimit => - l.stats(conf).sizeInBytes - } - assert(sizesLocalLimit.size === 1, s"Size wrong for:\n ${df.queryExecution}") - assert(sizesLocalLimit.head === BigInt(expected), - s"expected exact size $expected for table 'test', got: ${sizesLocalLimit.head}") - } - } - } - test("column stats round trip serialization") { // Make sure we serialize and then deserialize and we will get the result data val df = data.toDF(stats.keys.toSeq :+ "carray" : _*)