diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala
index 5cbf263d1ce42d8617410b9fad0dc54800f34543..19db42c80895cfce21514bbd06b3ea8f55600e03 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala
@@ -383,7 +383,7 @@ case class BroadcastHint(child: LogicalPlan) extends UnaryNode {
 
   // set isBroadcastable to true so the child will be broadcasted
   override def computeStats(conf: CatalystConf): Statistics =
-    super.computeStats(conf).copy(isBroadcastable = true)
+    child.stats(conf).copy(isBroadcastable = true)
 }
 
 /**
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/BasicStatsEstimationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/BasicStatsEstimationSuite.scala
index e5dc811c8b7db3c2d87da949fffa7f3e6f31bba1..0d92c1e35565abb19296c7305c786e5c56a25a44 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/BasicStatsEstimationSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/BasicStatsEstimationSuite.scala
@@ -35,6 +35,23 @@ class BasicStatsEstimationSuite extends StatsEstimationTestBase {
     // row count * (overhead + column size)
     size = Some(10 * (8 + 4)))
 
+  test("BroadcastHint estimation") {
+    val filter = Filter(Literal(true), plan)
+    val filterStatsCboOn = Statistics(sizeInBytes = 10 * (8 +4), isBroadcastable = false,
+      rowCount = Some(10), attributeStats = AttributeMap(Seq(attribute -> colStat)))
+    val filterStatsCboOff = Statistics(sizeInBytes = 10 * (8 +4), isBroadcastable = false)
+    checkStats(
+      filter,
+      expectedStatsCboOn = filterStatsCboOn,
+      expectedStatsCboOff = filterStatsCboOff)
+
+    val broadcastHint = BroadcastHint(filter)
+    checkStats(
+      broadcastHint,
+      expectedStatsCboOn = filterStatsCboOn.copy(isBroadcastable = true),
+      expectedStatsCboOff = filterStatsCboOff.copy(isBroadcastable = true))
+  }
+
   test("limit estimation: limit < child's rowCount") {
     val localLimit = LocalLimit(Literal(2), plan)
     val globalLimit = GlobalLimit(Literal(2), plan)
@@ -97,8 +114,10 @@ class BasicStatsEstimationSuite extends StatsEstimationTestBase {
       plan: LogicalPlan,
       expectedStatsCboOn: Statistics,
       expectedStatsCboOff: Statistics): Unit = {
-    assert(plan.stats(conf.copy(cboEnabled = true)) == expectedStatsCboOn)
     // Invalidate statistics
+    plan.invalidateStatsCache()
+    assert(plan.stats(conf.copy(cboEnabled = true)) == expectedStatsCboOn)
+
     plan.invalidateStatsCache()
     assert(plan.stats(conf.copy(cboEnabled = false)) == expectedStatsCboOff)
   }