Skip to content
Snippets Groups Projects
Commit 7d2a7a91 authored by Michael Armbrust's avatar Michael Armbrust
Browse files

[SPARK-3235][SQL] Ensure in-memory tables don't always broadcast.

Author: Michael Armbrust <michael@databricks.com>

Closes #2147 from marmbrus/inMemDefaultSize and squashes the following commits:

5390360 [Michael Armbrust] Merge remote-tracking branch 'origin/master' into inMemDefaultSize
14204d3 [Michael Armbrust] Set the context before creating SparkLogicalPlans.
8da4414 [Michael Armbrust] Make sure we throw errors when leaf nodes fail to provide statistics
18ce029 [Michael Armbrust] Ensure in-memory tables don't always broadcast.
parent 65253502
No related branches found
No related tags found
No related merge requests found
...@@ -41,9 +41,14 @@ abstract class LogicalPlan extends QueryPlan[LogicalPlan] { ...@@ -41,9 +41,14 @@ abstract class LogicalPlan extends QueryPlan[LogicalPlan] {
/**
 * Estimated statistics about the data a logical plan produces. Used by the
 * planner, e.g. to decide whether a relation is small enough to broadcast
 * in a join.
 *
 * @param sizeInBytes estimated physical size of this plan's output.
 *                    Declared as BigInt because the default estimate is the
 *                    product of all children's sizes, which can overflow a Long.
 */
case class Statistics(
    sizeInBytes: BigInt
)
/**
 * Estimated statistics for this plan. The default implementation multiplies
 * the size estimates of all children (a deliberately pessimistic upper bound
 * for joins/products). Leaf nodes have no children to derive an estimate
 * from, so they MUST override this; otherwise we fail fast instead of
 * silently returning product-of-empty == 1, which would make every leaf look
 * broadcastable.
 */
lazy val statistics: Statistics = {
  if (children.isEmpty) {
    throw new UnsupportedOperationException(s"LeafNode $nodeName must implement statistics.")
  }
  Statistics(sizeInBytes = children.map(_.statistics.sizeInBytes).product)
}
/** /**
* Returns the set of attributes that this node takes as * Returns the set of attributes that this node takes as
...@@ -117,9 +122,6 @@ abstract class LogicalPlan extends QueryPlan[LogicalPlan] { ...@@ -117,9 +122,6 @@ abstract class LogicalPlan extends QueryPlan[LogicalPlan] {
*/ */
/**
 * A logical plan node with no children. Note: concrete leaf nodes are
 * expected to provide their own `statistics`; the inherited default throws
 * for childless plans rather than producing a bogus size estimate.
 */
abstract class LeafNode extends LogicalPlan with trees.LeafNode[LogicalPlan] {
  self: Product =>
}
/** /**
......
...@@ -89,8 +89,10 @@ class SQLContext(@transient val sparkContext: SparkContext) ...@@ -89,8 +89,10 @@ class SQLContext(@transient val sparkContext: SparkContext)
* *
* @group userf * @group userf
*/ */
/**
 * Implicitly converts an RDD of Products (e.g. case classes) to a SchemaRDD.
 *
 * @group userf
 */
implicit def createSchemaRDD[A <: Product: TypeTag](rdd: RDD[A]) = {
  // Install this SQLContext in the thread-local BEFORE building the
  // SparkLogicalPlan, so the plan (and its SparkPlan children) capture the
  // correct context via SparkPlan.currentContext.
  SparkPlan.currentContext.set(self)
  new SchemaRDD(this, SparkLogicalPlan(ExistingRdd.fromProductRdd(rdd))(self))
}
/** /**
* :: DeveloperApi :: * :: DeveloperApi ::
......
...@@ -39,6 +39,9 @@ private[sql] case class InMemoryRelation( ...@@ -39,6 +39,9 @@ private[sql] case class InMemoryRelation(
(private var _cachedColumnBuffers: RDD[Array[ByteBuffer]] = null) (private var _cachedColumnBuffers: RDD[Array[ByteBuffer]] = null)
extends LogicalPlan with MultiInstanceRelation { extends LogicalPlan with MultiInstanceRelation {
// Report the context's defaultSizeInBytes rather than deriving a size from the
// child: cached in-memory tables have no reliable estimate yet, and the default
// is chosen to exceed autoBroadcastJoinThreshold so cached tables are not
// always picked for broadcast joins (see the "default size avoids broadcast"
// test in InMemoryColumnarQuerySuite).
override lazy val statistics =
Statistics(sizeInBytes = child.sqlContext.defaultSizeInBytes)
// If the cached column buffers were not passed in, we calculate them in the constructor. // If the cached column buffers were not passed in, we calculate them in the constructor.
// As in Spark, the actual work of caching is lazy. // As in Spark, the actual work of caching is lazy.
if (_cachedColumnBuffers == null) { if (_cachedColumnBuffers == null) {
......
...@@ -49,7 +49,7 @@ abstract class SparkPlan extends QueryPlan[SparkPlan] with Logging with Serializ ...@@ -49,7 +49,7 @@ abstract class SparkPlan extends QueryPlan[SparkPlan] with Logging with Serializ
* populated by the query planning infrastructure. * populated by the query planning infrastructure.
*/ */
// Captured from the thread-local at construction time; populated by the query
// planning infrastructure. Visibility is protected[spark] (not just protected)
// so other Spark SQL internals — e.g. InMemoryRelation reading
// defaultSizeInBytes — can access it.
@transient
protected[spark] val sqlContext = SparkPlan.currentContext.get()

protected def sparkContext = sqlContext.sparkContext
......
...@@ -33,6 +33,14 @@ class InMemoryColumnarQuerySuite extends QueryTest { ...@@ -33,6 +33,14 @@ class InMemoryColumnarQuerySuite extends QueryTest {
checkAnswer(scan, testData.collect().toSeq) checkAnswer(scan, testData.collect().toSeq)
} }
// Regression test for SPARK-3235: a cached in-memory table's estimated size
// must stay above autoBroadcastJoinThreshold, so caching a table does not
// silently cause it to be broadcast in every join.
test("default size avoids broadcast") {
// TODO: Improve this test when we have better statistics
sparkContext.parallelize(1 to 10).map(i => TestData(i, i.toString)).registerTempTable("sizeTst")
cacheTable("sizeTst")
// The logical plan for the cached table reports defaultSizeInBytes, which is
// expected to exceed the broadcast threshold.
assert(
table("sizeTst").queryExecution.logical.statistics.sizeInBytes > autoBroadcastJoinThreshold)
}
test("projection") { test("projection") {
val plan = TestSQLContext.executePlan(testData.select('value, 'key).logicalPlan).executedPlan val plan = TestSQLContext.executePlan(testData.select('value, 'key).logicalPlan).executedPlan
val scan = InMemoryRelation(useCompression = true, 5, plan) val scan = InMemoryRelation(useCompression = true, 5, plan)
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment