Commit 9909f6d3 authored by wangzhenhua, committed by Xiao Li

[SPARK-19350][SQL] Cardinality estimation of Limit and Sample

## What changes were proposed in this pull request?

Before this PR, LocalLimit/GlobalLimit/Sample propagated the same row count and column stats as their child, which is incorrect.
We can get the correct rowCount in Statistics for GlobalLimit/Sample whether CBO is enabled or not.
We don't know the rowCount for LocalLimit because we don't know the number of partitions at that time. Column stats should not be propagated, because we don't know the distribution of columns after a Limit or Sample.
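
To make the new rules concrete, here is a minimal sketch of the intended behavior. The simplified `SimpleStats` type and the function names are illustrative stand-ins, not the patch itself — the real code goes through Spark's `Statistics` and `EstimationUtils`, as shown in the diffs below:

```scala
import scala.math.BigDecimal.RoundingMode

// Illustrative stand-in for Spark's Statistics (no column stats, which
// are intentionally dropped by Limit and Sample in this patch).
case class SimpleStats(sizeInBytes: BigInt, rowCount: Option[BigInt])

// GlobalLimit: the output has at most `limit` rows, so clamp the child's row count.
def globalLimitStats(limit: Int, child: SimpleStats, rowSize: Long): SimpleStats = {
  val rows = child.rowCount.map(_.min(limit)).getOrElse(BigInt(limit))
  // Keep sizeInBytes >= 1 so a product over children never collapses to zero.
  SimpleStats((rows * rowSize).max(1), Some(rows))
}

// Sample: scale size and row count by the sampling ratio, rounding up.
def sampleStats(ratio: Double, child: SimpleStats): SimpleStats = {
  def ceil(x: BigDecimal): BigInt = x.setScale(0, RoundingMode.CEILING).toBigInt
  SimpleStats(
    ceil(BigDecimal(child.sizeInBytes) * ratio).max(1),
    child.rowCount.map(c => ceil(BigDecimal(c) * ratio)))
}

// LocalLimit: the per-partition row counts are unknown here, so the child's
// stats are kept as-is (only the column stats are dropped in the real code).
def localLimitStats(child: SimpleStats): SimpleStats = child
```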

## How was this patch tested?

Added test cases.

Author: wangzhenhua <wangzhenhua@huawei.com>

Closes #16696 from wzhfy/limitEstimation.
parent b0a5cd89
@@ -752,14 +752,13 @@ case class GlobalLimit(limitExpr: Expression, child: LogicalPlan) extends UnaryNode {
   }
 
   override def computeStats(conf: CatalystConf): Statistics = {
     val limit = limitExpr.eval().asInstanceOf[Int]
-    val sizeInBytes = if (limit == 0) {
-      // sizeInBytes can't be zero, or sizeInBytes of BinaryNode will also be zero
-      // (product of children).
-      1
-    } else {
-      (limit: Long) * output.map(a => a.dataType.defaultSize).sum
-    }
-    child.stats(conf).copy(sizeInBytes = sizeInBytes)
+    val childStats = child.stats(conf)
+    val rowCount: BigInt = childStats.rowCount.map(_.min(limit)).getOrElse(limit)
+    // Don't propagate column stats, because we don't know the distribution after a limit operation
+    Statistics(
+      sizeInBytes = EstimationUtils.getOutputSize(output, rowCount, childStats.attributeStats),
+      rowCount = Some(rowCount),
+      isBroadcastable = childStats.isBroadcastable)
   }
 }
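
As a sanity check on the new GlobalLimit path, the test fixture added below has a child with rowCount = 10 and a single int column, sized at 10 × (8 + 4) bytes — 8 bytes of assumed per-row overhead plus 4 bytes per int, per the fixture's own comment. With limit = 2, the clamped row count is 2 and getOutputSize yields 2 × (8 + 4) = 24 bytes, exactly what the `limit < child's rowCount` test asserts.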
@@ -773,14 +772,21 @@ case class LocalLimit(limitExpr: Expression, child: LogicalPlan) extends UnaryNode {
   }
 
   override def computeStats(conf: CatalystConf): Statistics = {
     val limit = limitExpr.eval().asInstanceOf[Int]
-    val sizeInBytes = if (limit == 0) {
+    val childStats = child.stats(conf)
+    if (limit == 0) {
       // sizeInBytes can't be zero, or sizeInBytes of BinaryNode will also be zero
       // (product of children).
-      1
+      Statistics(
+        sizeInBytes = 1,
+        rowCount = Some(0),
+        isBroadcastable = childStats.isBroadcastable)
     } else {
-      (limit: Long) * output.map(a => a.dataType.defaultSize).sum
+      // The output row count of LocalLimit should be the sum of row counts from each partition.
+      // However, since the number of partitions is not available here, we just use statistics of
+      // the child. Because the distribution after a limit operation is unknown, we do not propagate
+      // the column stats.
+      childStats.copy(attributeStats = AttributeMap(Nil))
     }
-    child.stats(conf).copy(sizeInBytes = sizeInBytes)
   }
 }
@@ -816,12 +822,14 @@ case class Sample(
 
   override def computeStats(conf: CatalystConf): Statistics = {
     val ratio = upperBound - lowerBound
-    // BigInt can't multiply with Double
-    var sizeInBytes = child.stats(conf).sizeInBytes * (ratio * 100).toInt / 100
+    val childStats = child.stats(conf)
+    var sizeInBytes = EstimationUtils.ceil(BigDecimal(childStats.sizeInBytes) * ratio)
     if (sizeInBytes == 0) {
       sizeInBytes = 1
     }
-    child.stats(conf).copy(sizeInBytes = sizeInBytes)
+    val sampledRowCount = childStats.rowCount.map(c => EstimationUtils.ceil(BigDecimal(c) * ratio))
+    // Don't propagate column stats, because we don't know the distribution after a sample operation
+    Statistics(sizeInBytes, sampledRowCount, isBroadcastable = childStats.isBroadcastable)
   }
 
   override protected def otherCopyArgs: Seq[AnyRef] = isTableSample :: Nil
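
EstimationUtils.ceil rounds the scaled estimate up to a whole row/byte count. A quick standalone check of that arithmetic, using the numbers from the `sample estimation` test below (the local `ceil` here is an assumed stand-in for the real helper):

```scala
import scala.math.BigDecimal.RoundingMode

// Assumed stand-in for EstimationUtils.ceil: round up and return a BigInt.
def ceil(x: BigDecimal): BigInt = x.setScale(0, RoundingMode.CEILING).toBigInt

assert(ceil(BigDecimal(10) * 0.5) == BigInt(5))    // 10 rows sampled at ratio 0.5 -> 5 rows
assert(ceil(BigDecimal(120) * 0.11) == BigInt(14)) // 120 bytes at ratio 0.11 -> 13.2, rounded up to 14
```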
@@ -18,12 +18,59 @@
 package org.apache.spark.sql.catalyst.statsEstimation
 
 import org.apache.spark.sql.catalyst.CatalystConf
-import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeMap, AttributeReference}
-import org.apache.spark.sql.catalyst.plans.logical.{ColumnStat, LogicalPlan, Statistics}
+import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeMap, AttributeReference, Literal}
+import org.apache.spark.sql.catalyst.plans.logical._
 import org.apache.spark.sql.types.IntegerType
 
-class StatsConfSuite extends StatsEstimationTestBase {
+class BasicStatsEstimationSuite extends StatsEstimationTestBase {
+  val attribute = attr("key")
+  val colStat = ColumnStat(distinctCount = 10, min = Some(1), max = Some(10),
+    nullCount = 0, avgLen = 4, maxLen = 4)
+
+  val plan = StatsTestPlan(
+    outputList = Seq(attribute),
+    attributeStats = AttributeMap(Seq(attribute -> colStat)),
+    rowCount = 10,
+    // row count * (overhead + column size)
+    size = Some(10 * (8 + 4)))
+
+  test("limit estimation: limit < child's rowCount") {
+    val localLimit = LocalLimit(Literal(2), plan)
+    val globalLimit = GlobalLimit(Literal(2), plan)
+    // LocalLimit's stats is just its child's stats except column stats
+    checkStats(localLimit, plan.stats(conf).copy(attributeStats = AttributeMap(Nil)))
+    checkStats(globalLimit, Statistics(sizeInBytes = 24, rowCount = Some(2)))
+  }
+
+  test("limit estimation: limit > child's rowCount") {
+    val localLimit = LocalLimit(Literal(20), plan)
+    val globalLimit = GlobalLimit(Literal(20), plan)
+    checkStats(localLimit, plan.stats(conf).copy(attributeStats = AttributeMap(Nil)))
+    // Limit is larger than child's rowCount, so GlobalLimit's stats is equal to its child's stats.
+    checkStats(globalLimit, plan.stats(conf).copy(attributeStats = AttributeMap(Nil)))
+  }
+
+  test("limit estimation: limit = 0") {
+    val localLimit = LocalLimit(Literal(0), plan)
+    val globalLimit = GlobalLimit(Literal(0), plan)
+    val stats = Statistics(sizeInBytes = 1, rowCount = Some(0))
+    checkStats(localLimit, stats)
+    checkStats(globalLimit, stats)
+  }
+
+  test("sample estimation") {
+    val sample = Sample(0.0, 0.5, withReplacement = false, (math.random * 1000).toLong, plan)()
+    checkStats(sample, Statistics(sizeInBytes = 60, rowCount = Some(5)))
+
+    // Child doesn't have rowCount in stats
+    val childStats = Statistics(sizeInBytes = 120)
+    val childPlan = DummyLogicalPlan(childStats, childStats)
+    val sample2 =
+      Sample(0.0, 0.11, withReplacement = false, (math.random * 1000).toLong, childPlan)()
+    checkStats(sample2, Statistics(sizeInBytes = 14))
+  }
+
   test("estimate statistics when the conf changes") {
     val expectedDefaultStats =
       Statistics(
@@ -41,13 +88,24 @@ class StatsConfSuite extends StatsEstimationTestBase {
       isBroadcastable = false)
 
     val plan = DummyLogicalPlan(defaultStats = expectedDefaultStats, cboStats = expectedCboStats)
-    // Return the statistics estimated by cbo
-    assert(plan.stats(conf.copy(cboEnabled = true)) == expectedCboStats)
+    checkStats(
+      plan, expectedStatsCboOn = expectedCboStats, expectedStatsCboOff = expectedDefaultStats)
+  }
+
+  /** Check estimated stats when cbo is turned on/off. */
+  private def checkStats(
+      plan: LogicalPlan,
+      expectedStatsCboOn: Statistics,
+      expectedStatsCboOff: Statistics): Unit = {
+    assert(plan.stats(conf.copy(cboEnabled = true)) == expectedStatsCboOn)
     // Invalidate statistics
     plan.invalidateStatsCache()
-    // Return the simple statistics
-    assert(plan.stats(conf.copy(cboEnabled = false)) == expectedDefaultStats)
+    assert(plan.stats(conf.copy(cboEnabled = false)) == expectedStatsCboOff)
   }
+
+  /** Check estimated stats when it's the same whether cbo is turned on or off. */
+  private def checkStats(plan: LogicalPlan, expectedStats: Statistics): Unit =
+    checkStats(plan, expectedStats, expectedStats)
 }
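
The two `checkStats` overloads keep each test to a single call: one for the common case where the expected stats are identical with CBO on and off, and one for the conf-change test where they differ. With limit estimation now covered by these unit tests, the older end-to-end `estimates the size of limit` test appears to be superseded and is removed below.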
@@ -112,30 +112,6 @@ class StatisticsCollectionSuite extends StatisticsCollectionTestBase with SharedSQLContext {
       spark.sessionState.conf.autoBroadcastJoinThreshold)
   }
 
-  test("estimates the size of limit") {
-    withTempView("test") {
-      Seq(("one", 1), ("two", 2), ("three", 3), ("four", 4)).toDF("k", "v")
-        .createOrReplaceTempView("test")
-      Seq((0, 1), (1, 24), (2, 48)).foreach { case (limit, expected) =>
-        val df = sql(s"""SELECT * FROM test limit $limit""")
-
-        val sizesGlobalLimit = df.queryExecution.analyzed.collect { case g: GlobalLimit =>
-          g.stats(conf).sizeInBytes
-        }
-        assert(sizesGlobalLimit.size === 1, s"Size wrong for:\n ${df.queryExecution}")
-        assert(sizesGlobalLimit.head === BigInt(expected),
-          s"expected exact size $expected for table 'test', got: ${sizesGlobalLimit.head}")
-
-        val sizesLocalLimit = df.queryExecution.analyzed.collect { case l: LocalLimit =>
-          l.stats(conf).sizeInBytes
-        }
-        assert(sizesLocalLimit.size === 1, s"Size wrong for:\n ${df.queryExecution}")
-        assert(sizesLocalLimit.head === BigInt(expected),
-          s"expected exact size $expected for table 'test', got: ${sizesLocalLimit.head}")
-      }
-    }
-  }
-
   test("column stats round trip serialization") {
     // Make sure we serialize and then deserialize and we will get the result data
     val df = data.toDF(stats.keys.toSeq :+ "carray" : _*)