From 2287f3d0b85730995bedc489a017de5700d6e1e4 Mon Sep 17 00:00:00 2001
From: wangzhenhua <wangzhenhua@huawei.com>
Date: Sat, 1 Apr 2017 22:19:08 +0800
Subject: [PATCH] [SPARK-20186][SQL] BroadcastHint should use child's stats

## What changes were proposed in this pull request?

`BroadcastHint` should use child's statistics and set `isBroadcastable` to true.

## How was this patch tested?

Added a new stats estimation test for `BroadcastHint`.

Author: wangzhenhua <wangzhenhua@huawei.com>

Closes #17504 from wzhfy/broadcastHintEstimation.
---
 .../plans/logical/basicLogicalOperators.scala |  2 +-
 .../BasicStatsEstimationSuite.scala           | 21 ++++++++++++++++++-
 2 files changed, 21 insertions(+), 2 deletions(-)

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala
index 5cbf263d1c..19db42c808 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala
@@ -383,7 +383,7 @@ case class BroadcastHint(child: LogicalPlan) extends UnaryNode {
 
   // set isBroadcastable to true so the child will be broadcasted
   override def computeStats(conf: CatalystConf): Statistics =
-    super.computeStats(conf).copy(isBroadcastable = true)
+    child.stats(conf).copy(isBroadcastable = true)
 }
 
 /**
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/BasicStatsEstimationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/BasicStatsEstimationSuite.scala
index e5dc811c8b..0d92c1e355 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/BasicStatsEstimationSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/BasicStatsEstimationSuite.scala
@@ -35,6 +35,23 @@ class BasicStatsEstimationSuite extends StatsEstimationTestBase {
     // row count * (overhead + column size)
     size = Some(10 * (8 + 4)))
 
+  test("BroadcastHint estimation") {
+    val filter = Filter(Literal(true), plan)
+    val filterStatsCboOn = Statistics(sizeInBytes = 10 * (8 +4), isBroadcastable = false,
+      rowCount = Some(10), attributeStats = AttributeMap(Seq(attribute -> colStat)))
+    val filterStatsCboOff = Statistics(sizeInBytes = 10 * (8 +4), isBroadcastable = false)
+    checkStats(
+      filter,
+      expectedStatsCboOn = filterStatsCboOn,
+      expectedStatsCboOff = filterStatsCboOff)
+
+    val broadcastHint = BroadcastHint(filter)
+    checkStats(
+      broadcastHint,
+      expectedStatsCboOn = filterStatsCboOn.copy(isBroadcastable = true),
+      expectedStatsCboOff = filterStatsCboOff.copy(isBroadcastable = true))
+  }
+
   test("limit estimation: limit < child's rowCount") {
     val localLimit = LocalLimit(Literal(2), plan)
     val globalLimit = GlobalLimit(Literal(2), plan)
@@ -97,8 +114,10 @@ class BasicStatsEstimationSuite extends StatsEstimationTestBase {
       plan: LogicalPlan,
       expectedStatsCboOn: Statistics,
       expectedStatsCboOff: Statistics): Unit = {
-    assert(plan.stats(conf.copy(cboEnabled = true)) == expectedStatsCboOn)
     // Invalidate statistics
+    plan.invalidateStatsCache()
+    assert(plan.stats(conf.copy(cboEnabled = true)) == expectedStatsCboOn)
+
     plan.invalidateStatsCache()
     assert(plan.stats(conf.copy(cboEnabled = false)) == expectedStatsCboOff)
   }
-- 
GitLab