From fffeb6d7c37ee673a32584f3b2fd3afe86af793a Mon Sep 17 00:00:00 2001
From: Reynold Xin <rxin@databricks.com>
Date: Wed, 14 Jun 2017 22:11:41 -0700
Subject: [PATCH] [SPARK-21092][SQL] Wire SQLConf in logical plan and
 expressions

## What changes were proposed in this pull request?
It is really painful not to have configs available in logical plans and expressions. We had to add all sorts of hacks (e.g. passing SQLConf explicitly into functions). This patch exposes SQLConf in the logical plan, using a thread-local variable and a getter closure that is set once there is an active SparkSession.

The implementation is a bit of a hack, since we didn't anticipate this need in the beginning (config was only exposed in the physical plan). The implementation is described in `SQLConf.get`.

In terms of future work, we should follow up to clean up CBO (remove the need for passing in config).

## How was this patch tested?
Updated relevant tests for constraint propagation.

Author: Reynold Xin <rxin@databricks.com>

Closes #18299 from rxin/SPARK-21092.
---
 .../sql/catalyst/optimizer/Optimizer.scala    | 25 +++++------
 .../spark/sql/catalyst/optimizer/joins.scala  |  5 +--
 .../spark/sql/catalyst/plans/QueryPlan.scala  |  3 ++
 .../catalyst/plans/QueryPlanConstraints.scala | 33 +++++----------
 .../apache/spark/sql/internal/SQLConf.scala   | 42 +++++++++++++++++++
 .../BinaryComparisonSimplificationSuite.scala |  2 +-
 .../BooleanSimplificationSuite.scala          |  2 +-
 .../InferFiltersFromConstraintsSuite.scala    | 24 +++++------
 .../optimizer/OuterJoinEliminationSuite.scala | 37 ++++++++--------
 .../PropagateEmptyRelationSuite.scala         |  4 +-
 .../optimizer/PruneFiltersSuite.scala         | 36 +++++++---------
 .../optimizer/SetOperationSuite.scala         |  2 +-
 .../plans/ConstraintPropagationSuite.scala    | 29 ++++++++-----
 .../org/apache/spark/sql/SparkSession.scala   |  5 +++
 14 files changed, 141 insertions(+), 108 deletions(-)

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
index d16689a342..3ab70fb904 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
@@ -77,12 +77,12 @@ abstract class Optimizer(sessionCatalog: SessionCatalog, conf: SQLConf)
       // Operator push down
       PushProjectionThroughUnion,
       ReorderJoin(conf),
-      EliminateOuterJoin(conf),
+      EliminateOuterJoin,
       PushPredicateThroughJoin,
       PushDownPredicate,
       LimitPushDown(conf),
       ColumnPruning,
-      InferFiltersFromConstraints(conf),
+      InferFiltersFromConstraints,
       // Operator combine
       CollapseRepartition,
       CollapseProject,
@@ -102,7 +102,7 @@ abstract class Optimizer(sessionCatalog: SessionCatalog, conf: SQLConf)
       SimplifyConditionals,
       RemoveDispensableExpressions,
       SimplifyBinaryComparison,
-      PruneFilters(conf),
+      PruneFilters,
       EliminateSorts,
       SimplifyCasts,
       SimplifyCaseConversionExpressions,
@@ -619,14 +619,15 @@ object CollapseWindow extends Rule[LogicalPlan] {
 * Note: While this optimization is applicable to all types of join, it primarily benefits Inner and
 * LeftSemi joins.
 */
-case class InferFiltersFromConstraints(conf: SQLConf)
-  extends Rule[LogicalPlan] with PredicateHelper {
-  def apply(plan: LogicalPlan): LogicalPlan = if (conf.constraintPropagationEnabled) {
-    inferFilters(plan)
-  } else {
-    plan
-  }
+object InferFiltersFromConstraints extends Rule[LogicalPlan] with PredicateHelper {
+  def apply(plan: LogicalPlan): LogicalPlan = {
+    if (SQLConf.get.constraintPropagationEnabled) {
+      inferFilters(plan)
+    } else {
+      plan
+    }
+  }
 
   private def inferFilters(plan: LogicalPlan): LogicalPlan = plan transform {
     case filter @ Filter(condition, child) =>
@@ -717,7 +718,7 @@ object EliminateSorts extends Rule[LogicalPlan] {
  * 2) by substituting a dummy empty relation when the filter will always evaluate to `false`.
  * 3) by eliminating the always-true conditions given the constraints on the child's output.
  */
-case class PruneFilters(conf: SQLConf) extends Rule[LogicalPlan] with PredicateHelper {
+object PruneFilters extends Rule[LogicalPlan] with PredicateHelper {
   def apply(plan: LogicalPlan): LogicalPlan = plan transform {
     // If the filter condition always evaluate to true, remove the filter.
     case Filter(Literal(true, BooleanType), child) => child
@@ -730,7 +731,7 @@ case class PruneFilters(conf: SQLConf) extends Rule[LogicalPlan] with PredicateH
     case f @ Filter(fc, p: LogicalPlan) =>
       val (prunedPredicates, remainingPredicates) =
         splitConjunctivePredicates(fc).partition { cond =>
-          cond.deterministic && p.getConstraints(conf.constraintPropagationEnabled).contains(cond)
+          cond.deterministic && p.constraints.contains(cond)
         }
       if (prunedPredicates.isEmpty) {
         f
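For illustration, this is the shape every migrated rule now follows: a plain object that reads `SQLConf.get` at apply time instead of taking a conf constructor parameter. A minimal sketch (`MyRule` is a hypothetical name, not part of this patch):

```scala
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
import org.apache.spark.sql.catalyst.rules.Rule
import org.apache.spark.sql.internal.SQLConf

// Hypothetical rule: no SQLConf constructor parameter; the active config is
// read each time apply() runs, so one object works for every session.
object MyRule extends Rule[LogicalPlan] {
  def apply(plan: LogicalPlan): LogicalPlan = {
    if (SQLConf.get.constraintPropagationEnabled) plan // ... rewrite here ...
    else plan
  }
}
```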
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/joins.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/joins.scala
index 2fe3039774..bb97e2c808 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/joins.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/joins.scala
@@ -113,7 +113,7 @@ case class ReorderJoin(conf: SQLConf) extends Rule[LogicalPlan] with PredicateHe
  *
  * This rule should be executed before pushing down the Filter
  */
-case class EliminateOuterJoin(conf: SQLConf) extends Rule[LogicalPlan] with PredicateHelper {
+object EliminateOuterJoin extends Rule[LogicalPlan] with PredicateHelper {
 
   /**
    * Returns whether the expression returns null or false when all inputs are nulls.
@@ -129,8 +129,7 @@ case class EliminateOuterJoin(conf: SQLConf) extends Rule[LogicalPlan] with Pred
   }
 
   private def buildNewJoinType(filter: Filter, join: Join): JoinType = {
-    val conditions = splitConjunctivePredicates(filter.condition) ++
-      filter.getConstraints(conf.constraintPropagationEnabled)
+    val conditions = splitConjunctivePredicates(filter.condition) ++ filter.constraints
     val leftConditions = conditions.filter(_.references.subsetOf(join.left.outputSet))
     val rightConditions = conditions.filter(_.references.subsetOf(join.right.outputSet))

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala
index 8bc462e1e7..9130b14763 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala
@@ -19,6 +19,7 @@ package org.apache.spark.sql.catalyst.plans
 
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.trees.TreeNode
+import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.types.{DataType, StructType}
 
 abstract class QueryPlan[PlanType <: QueryPlan[PlanType]]
@@ -27,6 +28,8 @@ abstract class QueryPlan[PlanType <: QueryPlan[PlanType]]
 
   self: PlanType =>
 
+  def conf: SQLConf = SQLConf.get
+
   def output: Seq[Attribute]
 
   /**
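Because `QueryPlan` now exposes `conf`, any plan node (and any code holding one) can reach the active config without plumbing it through constructors. A hedged sketch using the catalyst test DSL (the relation is made up for illustration):

```scala
import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.dsl.plans._
import org.apache.spark.sql.catalyst.plans.logical.LocalRelation

val plan = LocalRelation('a.int).where('a.attr > 1).analyze
// Any node can now ask for the thread-appropriate config directly:
val enabled = plan.conf.constraintPropagationEnabled
```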
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlanConstraints.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlanConstraints.scala
index 7d8a17d977..b08a009f0d 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlanConstraints.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlanConstraints.scala
@@ -27,18 +27,20 @@ trait QueryPlanConstraints[PlanType <: QueryPlan[PlanType]] { self: QueryPlan[Pl
    * example, if this set contains the expression `a = 2` then that expression is guaranteed to
    * evaluate to `true` for all rows produced.
    */
-  lazy val constraints: ExpressionSet = ExpressionSet(getRelevantConstraints(validConstraints))
-
-  /**
-   * Returns [[constraints]] depending on the config of enabling constraint propagation. If the
-   * flag is disabled, simply returning an empty constraints.
-   */
-  def getConstraints(constraintPropagationEnabled: Boolean): ExpressionSet =
-    if (constraintPropagationEnabled) {
-      constraints
+  lazy val constraints: ExpressionSet = {
+    if (conf.constraintPropagationEnabled) {
+      ExpressionSet(
+        validConstraints
+          .union(inferAdditionalConstraints(validConstraints))
+          .union(constructIsNotNullConstraints(validConstraints))
+          .filter { c =>
+            c.references.nonEmpty && c.references.subsetOf(outputSet) && c.deterministic
+          }
+      )
     } else {
       ExpressionSet(Set.empty)
     }
+  }
 
   /**
    * This method can be overridden by any child class of QueryPlan to specify a set of constraints
@@ -50,19 +52,6 @@ trait QueryPlanConstraints[PlanType <: QueryPlan[PlanType]] { self: QueryPlan[Pl
    */
   protected def validConstraints: Set[Expression] = Set.empty
 
-  /**
-   * Extracts the relevant constraints from a given set of constraints based on the attributes that
-   * appear in the [[outputSet]].
-   */
-  protected def getRelevantConstraints(constraints: Set[Expression]): Set[Expression] = {
-    constraints
-      .union(inferAdditionalConstraints(constraints))
-      .union(constructIsNotNullConstraints(constraints))
-      .filter(constraint =>
-        constraint.references.nonEmpty && constraint.references.subsetOf(outputSet) &&
-          constraint.deterministic)
-  }
-
   /**
    * Infers a set of `isNotNull` constraints from null intolerant expressions as well as
    * non-nullable attributes. For e.g., if an expression is of the form (`a > 5`), this
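The practical effect of folding the flag into `constraints` itself is that, with propagation disabled, every plan reports an empty constraint set, so callers like `PruneFilters` and `EliminateOuterJoin` degrade gracefully with no per-rule plumbing. Roughly (a sketch mirroring the test suites later in this patch):

```scala
import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.dsl.plans._
import org.apache.spark.sql.catalyst.plans.logical.LocalRelation
import org.apache.spark.sql.internal.SQLConf

val filtered = LocalRelation('a.int).where('a.attr > 10).analyze
SQLConf.get.setConf(SQLConf.CONSTRAINT_PROPAGATION_ENABLED, false)
// Gated off inside the lazy val; note the flag is read on first access,
// since constraints is lazily computed and then cached per plan instance.
assert(filtered.constraints.isEmpty)
SQLConf.get.unsetConf(SQLConf.CONSTRAINT_PROPAGATION_ENABLED)
```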
+ */ + def get: SQLConf = confGetter.get()() + val OPTIMIZER_MAX_ITERATIONS = buildConf("spark.sql.optimizer.maxIterations") .internal() .doc("The max number of iterations the optimizer and analyzer runs.") diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/BinaryComparisonSimplificationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/BinaryComparisonSimplificationSuite.scala index b29e1cbd14..2a04bd588d 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/BinaryComparisonSimplificationSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/BinaryComparisonSimplificationSuite.scala @@ -37,7 +37,7 @@ class BinaryComparisonSimplificationSuite extends PlanTest with PredicateHelper ConstantFolding, BooleanSimplification, SimplifyBinaryComparison, - PruneFilters(conf)) :: Nil + PruneFilters) :: Nil } val nullableRelation = LocalRelation('a.int.withNullability(true)) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/BooleanSimplificationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/BooleanSimplificationSuite.scala index c275f997ba..1df0a89cf0 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/BooleanSimplificationSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/BooleanSimplificationSuite.scala @@ -38,7 +38,7 @@ class BooleanSimplificationSuite extends PlanTest with PredicateHelper { NullPropagation(conf), ConstantFolding, BooleanSimplification, - PruneFilters(conf)) :: Nil + PruneFilters) :: Nil } val testRelation = LocalRelation('a.int, 'b.int, 'c.int, 'd.string) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/InferFiltersFromConstraintsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/InferFiltersFromConstraintsSuite.scala index 9a4bcdb011..cdc9f25cf8 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/InferFiltersFromConstraintsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/InferFiltersFromConstraintsSuite.scala @@ -23,7 +23,7 @@ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans._ import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.rules._ -import org.apache.spark.sql.internal.SQLConf.CONSTRAINT_PROPAGATION_ENABLED +import org.apache.spark.sql.internal.SQLConf class InferFiltersFromConstraintsSuite extends PlanTest { @@ -32,20 +32,11 @@ class InferFiltersFromConstraintsSuite extends PlanTest { Batch("InferAndPushDownFilters", FixedPoint(100), PushPredicateThroughJoin, PushDownPredicate, - InferFiltersFromConstraints(conf), + InferFiltersFromConstraints, CombineFilters, BooleanSimplification) :: Nil } - object OptimizeWithConstraintPropagationDisabled extends RuleExecutor[LogicalPlan] { - val batches = - Batch("InferAndPushDownFilters", FixedPoint(100), - PushPredicateThroughJoin, - PushDownPredicate, - InferFiltersFromConstraints(conf.copy(CONSTRAINT_PROPAGATION_ENABLED -> false)), - CombineFilters) :: Nil - } - val testRelation = LocalRelation('a.int, 'b.int, 'c.int) test("filter: filter out constraints in condition") { @@ -215,8 +206,13 @@ class InferFiltersFromConstraintsSuite extends PlanTest { } test("No inferred filter when constraint propagation is disabled") { - val originalQuery = testRelation.where('a === 1 && 'a === 
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/BinaryComparisonSimplificationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/BinaryComparisonSimplificationSuite.scala
index b29e1cbd14..2a04bd588d 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/BinaryComparisonSimplificationSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/BinaryComparisonSimplificationSuite.scala
@@ -37,7 +37,7 @@ class BinaryComparisonSimplificationSuite extends PlanTest with PredicateHelper
         ConstantFolding,
         BooleanSimplification,
         SimplifyBinaryComparison,
-        PruneFilters(conf)) :: Nil
+        PruneFilters) :: Nil
   }
 
   val nullableRelation = LocalRelation('a.int.withNullability(true))

diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/BooleanSimplificationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/BooleanSimplificationSuite.scala
index c275f997ba..1df0a89cf0 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/BooleanSimplificationSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/BooleanSimplificationSuite.scala
@@ -38,7 +38,7 @@ class BooleanSimplificationSuite extends PlanTest with PredicateHelper {
         NullPropagation(conf),
         ConstantFolding,
         BooleanSimplification,
-        PruneFilters(conf)) :: Nil
+        PruneFilters) :: Nil
   }
 
   val testRelation = LocalRelation('a.int, 'b.int, 'c.int, 'd.string)

diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/InferFiltersFromConstraintsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/InferFiltersFromConstraintsSuite.scala
index 9a4bcdb011..cdc9f25cf8 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/InferFiltersFromConstraintsSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/InferFiltersFromConstraintsSuite.scala
@@ -23,7 +23,7 @@ import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.plans._
 import org.apache.spark.sql.catalyst.plans.logical._
 import org.apache.spark.sql.catalyst.rules._
-import org.apache.spark.sql.internal.SQLConf.CONSTRAINT_PROPAGATION_ENABLED
+import org.apache.spark.sql.internal.SQLConf
 
 class InferFiltersFromConstraintsSuite extends PlanTest {
 
@@ -32,20 +32,11 @@ class InferFiltersFromConstraintsSuite extends PlanTest {
     Batch("InferAndPushDownFilters", FixedPoint(100),
       PushPredicateThroughJoin,
       PushDownPredicate,
-      InferFiltersFromConstraints(conf),
+      InferFiltersFromConstraints,
       CombineFilters,
       BooleanSimplification) :: Nil
   }
 
-  object OptimizeWithConstraintPropagationDisabled extends RuleExecutor[LogicalPlan] {
-    val batches =
-      Batch("InferAndPushDownFilters", FixedPoint(100),
-        PushPredicateThroughJoin,
-        PushDownPredicate,
-        InferFiltersFromConstraints(conf.copy(CONSTRAINT_PROPAGATION_ENABLED -> false)),
-        CombineFilters) :: Nil
-  }
-
   val testRelation = LocalRelation('a.int, 'b.int, 'c.int)
 
   test("filter: filter out constraints in condition") {
@@ -215,8 +206,13 @@ class InferFiltersFromConstraintsSuite extends PlanTest {
   }
 
   test("No inferred filter when constraint propagation is disabled") {
-    val originalQuery = testRelation.where('a === 1 && 'a === 'b).analyze
-    val optimized = OptimizeWithConstraintPropagationDisabled.execute(originalQuery)
-    comparePlans(optimized, originalQuery)
+    try {
+      SQLConf.get.setConf(SQLConf.CONSTRAINT_PROPAGATION_ENABLED, false)
+      val originalQuery = testRelation.where('a === 1 && 'a === 'b).analyze
+      val optimized = Optimize.execute(originalQuery)
+      comparePlans(optimized, originalQuery)
+    } finally {
+      SQLConf.get.unsetConf(SQLConf.CONSTRAINT_PROPAGATION_ENABLED)
+    }
   }
 }

diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/OuterJoinEliminationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/OuterJoinEliminationSuite.scala
index b7136703b7..a37bc4bca2 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/OuterJoinEliminationSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/OuterJoinEliminationSuite.scala
@@ -24,7 +24,7 @@ import org.apache.spark.sql.catalyst.expressions.{Coalesce, IsNotNull}
 import org.apache.spark.sql.catalyst.plans._
 import org.apache.spark.sql.catalyst.plans.logical._
 import org.apache.spark.sql.catalyst.rules._
-import org.apache.spark.sql.internal.SQLConf.CONSTRAINT_PROPAGATION_ENABLED
+import org.apache.spark.sql.internal.SQLConf
 
 class OuterJoinEliminationSuite extends PlanTest {
   object Optimize extends RuleExecutor[LogicalPlan] {
@@ -32,16 +32,7 @@ class OuterJoinEliminationSuite extends PlanTest {
       Batch("Subqueries", Once,
         EliminateSubqueryAliases) ::
       Batch("Outer Join Elimination", Once,
-        EliminateOuterJoin(conf),
-        PushPredicateThroughJoin) :: Nil
-  }
-
-  object OptimizeWithConstraintPropagationDisabled extends RuleExecutor[LogicalPlan] {
-    val batches =
-      Batch("Subqueries", Once,
-        EliminateSubqueryAliases) ::
-      Batch("Outer Join Elimination", Once,
-        EliminateOuterJoin(conf.copy(CONSTRAINT_PROPAGATION_ENABLED -> false)),
+        EliminateOuterJoin,
         PushPredicateThroughJoin) :: Nil
   }
 
@@ -243,19 +234,25 @@ class OuterJoinEliminationSuite extends PlanTest {
   }
 
   test("no outer join elimination if constraint propagation is disabled") {
-    val x = testRelation.subquery('x)
-    val y = testRelation1.subquery('y)
+    try {
+      SQLConf.get.setConf(SQLConf.CONSTRAINT_PROPAGATION_ENABLED, false)
 
-    // The predicate "x.b + y.d >= 3" will be inferred constraints like:
-    // "x.b != null" and "y.d != null", if constraint propagation is enabled.
-    // When we disable it, the predicate can't be evaluated on left or right plan and used to
-    // filter out nulls. So the Outer Join will not be eliminated.
-    val originalQuery =
+      val x = testRelation.subquery('x)
+      val y = testRelation1.subquery('y)
+
+      // If constraint propagation is enabled, constraints like "x.b != null" and "y.d != null"
+      // will be inferred from the predicate "x.b + y.d >= 3". When we disable it, the predicate
+      // can't be evaluated on the left or right plan and used to filter out nulls, so the outer
+      // join will not be eliminated.
+      val originalQuery =
        x.join(y, FullOuter, Option("x.a".attr === "y.d".attr))
          .where("x.b".attr + "y.d".attr >= 3)
 
-    val optimized = OptimizeWithConstraintPropagationDisabled.execute(originalQuery.analyze)
+      val optimized = Optimize.execute(originalQuery.analyze)
 
-    comparePlans(optimized, originalQuery.analyze)
+      comparePlans(optimized, originalQuery.analyze)
+    } finally {
+      SQLConf.get.unsetConf(SQLConf.CONSTRAINT_PROPAGATION_ENABLED)
+    }
   }
 }
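The set-in-try / unset-in-finally dance in these two suites recurs in every suite below; it could be factored into a small loan-pattern helper like the following (`withSQLConfValue` is a hypothetical name, not something this patch adds):

```scala
import org.apache.spark.internal.config.ConfigEntry
import org.apache.spark.sql.internal.SQLConf

// Hypothetical helper: apply `value` for the duration of `body`, then reset
// the entry. (unsetConf resets to the default; it does not restore a
// previously overridden value.)
def withSQLConfValue[T](entry: ConfigEntry[T], value: T)(body: => Unit): Unit = {
  SQLConf.get.setConf(entry, value)
  try body finally SQLConf.get.unsetConf(entry)
}

// Usage, mirroring the tests above:
// withSQLConfValue(SQLConf.CONSTRAINT_PROPAGATION_ENABLED, false) {
//   comparePlans(Optimize.execute(originalQuery), originalQuery)
// }
```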
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/PropagateEmptyRelationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/PropagateEmptyRelationSuite.scala
index 38dff4733f..2285be1693 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/PropagateEmptyRelationSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/PropagateEmptyRelationSuite.scala
@@ -33,7 +33,7 @@ class PropagateEmptyRelationSuite extends PlanTest {
         ReplaceExceptWithAntiJoin,
         ReplaceIntersectWithSemiJoin,
         PushDownPredicate,
-        PruneFilters(conf),
+        PruneFilters,
         PropagateEmptyRelation) :: Nil
   }
 
@@ -45,7 +45,7 @@ class PropagateEmptyRelationSuite extends PlanTest {
         ReplaceExceptWithAntiJoin,
         ReplaceIntersectWithSemiJoin,
         PushDownPredicate,
-        PruneFilters(conf)) :: Nil
+        PruneFilters) :: Nil
   }
 
   val testRelation1 = LocalRelation.fromExternalRows(Seq('a.int), data = Seq(Row(1)))

diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/PruneFiltersSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/PruneFiltersSuite.scala
index 741dd0cf42..706634cdd2 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/PruneFiltersSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/PruneFiltersSuite.scala
@@ -24,6 +24,7 @@ import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.plans._
 import org.apache.spark.sql.catalyst.plans.logical._
 import org.apache.spark.sql.catalyst.rules._
+import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.internal.SQLConf.CONSTRAINT_PROPAGATION_ENABLED
 
 class PruneFiltersSuite extends PlanTest {
@@ -34,18 +35,7 @@ class PruneFiltersSuite extends PlanTest {
         EliminateSubqueryAliases) ::
       Batch("Filter Pushdown and Pruning", Once,
         CombineFilters,
-        PruneFilters(conf),
-        PushDownPredicate,
-        PushPredicateThroughJoin) :: Nil
-  }
-
-  object OptimizeWithConstraintPropagationDisabled extends RuleExecutor[LogicalPlan] {
-    val batches =
-      Batch("Subqueries", Once,
-        EliminateSubqueryAliases) ::
-      Batch("Filter Pushdown and Pruning", Once,
-        CombineFilters,
-        PruneFilters(conf.copy(CONSTRAINT_PROPAGATION_ENABLED -> false)),
+        PruneFilters,
         PushDownPredicate,
         PushPredicateThroughJoin) :: Nil
   }
@@ -159,15 +149,19 @@ class PruneFiltersSuite extends PlanTest {
         ("tr1.a".attr > 10 || "tr1.c".attr < 10) &&
         'd.attr < 100)
 
-    val optimized =
-      OptimizeWithConstraintPropagationDisabled.execute(queryWithUselessFilter.analyze)
-    // When constraint propagation is disabled, the useless filter won't be pruned.
-    // It gets pushed down. Because the rule `CombineFilters` runs only once, there are redundant
-    // and duplicate filters.
-    val correctAnswer = tr1
-      .where("tr1.a".attr > 10 || "tr1.c".attr < 10).where("tr1.a".attr > 10 || "tr1.c".attr < 10)
-      .join(tr2.where('d.attr < 100).where('d.attr < 100),
+    SQLConf.get.setConf(SQLConf.CONSTRAINT_PROPAGATION_ENABLED, false)
+    try {
+      val optimized = Optimize.execute(queryWithUselessFilter.analyze)
+      // When constraint propagation is disabled, the useless filter won't be pruned.
+      // It gets pushed down. Because the rule `CombineFilters` runs only once, there are
+      // redundant and duplicate filters.
+      val correctAnswer = tr1
+        .where("tr1.a".attr > 10 || "tr1.c".attr < 10).where("tr1.a".attr > 10 || "tr1.c".attr < 10)
+        .join(tr2.where('d.attr < 100).where('d.attr < 100),
           Inner, Some("tr1.a".attr === "tr2.a".attr)).analyze
-    comparePlans(optimized, correctAnswer)
+      comparePlans(optimized, correctAnswer)
+    } finally {
+      SQLConf.get.unsetConf(SQLConf.CONSTRAINT_PROPAGATION_ENABLED)
+    }
   }
 }

diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/SetOperationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/SetOperationSuite.scala
index 756e0f35b2..21b7f49e14 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/SetOperationSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/SetOperationSuite.scala
@@ -34,7 +34,7 @@ class SetOperationSuite extends PlanTest {
       CombineUnions,
       PushProjectionThroughUnion,
       PushDownPredicate,
-      PruneFilters(conf)) :: Nil
+      PruneFilters) :: Nil
   }
 
   val testRelation = LocalRelation('a.int, 'b.int, 'c.int)

diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/ConstraintPropagationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/ConstraintPropagationSuite.scala
index 4061394b86..a3948d90b0 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/ConstraintPropagationSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/ConstraintPropagationSuite.scala
@@ -25,6 +25,7 @@ import org.apache.spark.sql.catalyst.dsl.expressions._
 import org.apache.spark.sql.catalyst.dsl.plans._
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.plans.logical._
+import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.types.{DataType, DoubleType, IntegerType, LongType, StringType}
 
 class ConstraintPropagationSuite extends SparkFunSuite {
@@ -399,20 +400,26 @@ class ConstraintPropagationSuite extends SparkFunSuite {
   }
 
   test("enable/disable constraint propagation") {
-    val tr = LocalRelation('a.int, 'b.string, 'c.int)
-    val filterRelation = tr.where('a.attr > 10)
+    try {
+      val tr = LocalRelation('a.int, 'b.string, 'c.int)
+      val filterRelation = tr.where('a.attr > 10)
 
-    verifyConstraints(
-      filterRelation.analyze.getConstraints(constraintPropagationEnabled = true),
-      filterRelation.analyze.constraints)
+      SQLConf.get.setConf(SQLConf.CONSTRAINT_PROPAGATION_ENABLED, true)
+      assert(filterRelation.analyze.constraints.nonEmpty)
 
-    assert(filterRelation.analyze.getConstraints(constraintPropagationEnabled = false).isEmpty)
+      SQLConf.get.setConf(SQLConf.CONSTRAINT_PROPAGATION_ENABLED, false)
+      assert(filterRelation.analyze.constraints.isEmpty)
 
-    val aliasedRelation = tr.where('c.attr > 10 && 'a.attr < 5)
-      .groupBy('a, 'c, 'b)('a, 'c.as("c1"), count('a).as("a3")).select('c1, 'a, 'a3)
+      val aliasedRelation = tr.where('c.attr > 10 && 'a.attr < 5)
+        .groupBy('a, 'c, 'b)('a, 'c.as("c1"), count('a).as("a3")).select('c1, 'a, 'a3)
 
-    verifyConstraints(aliasedRelation.analyze.getConstraints(constraintPropagationEnabled = true),
-      aliasedRelation.analyze.constraints)
-    assert(aliasedRelation.analyze.getConstraints(constraintPropagationEnabled = false).isEmpty)
+      SQLConf.get.setConf(SQLConf.CONSTRAINT_PROPAGATION_ENABLED, true)
+      assert(aliasedRelation.analyze.constraints.nonEmpty)
+
+      SQLConf.get.setConf(SQLConf.CONSTRAINT_PROPAGATION_ENABLED, false)
+      assert(aliasedRelation.analyze.constraints.isEmpty)
+    } finally {
+      SQLConf.get.unsetConf(SQLConf.CONSTRAINT_PROPAGATION_ENABLED)
+    }
   }
 }
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala b/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala
index d2bf350711..2c38f7d7c8 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala
@@ -87,6 +87,11 @@ class SparkSession private(
 
   sparkContext.assertNotStopped()
 
+  // If there is no active SparkSession, use the default SQL conf; otherwise, use the session's.
+  SQLConf.setSQLConfGetter(() => {
+    SparkSession.getActiveSession.map(_.sessionState.conf).getOrElse(SQLConf.getFallbackConf)
+  })
+
   /**
    * The version of Spark on which this application is running.
    *
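End to end, this wiring means `SQLConf.get` is always safe to call from catalyst: before any session exists it hands back the per-thread fallback, and once a SparkSession has been constructed (and made active) it resolves to that session's conf. A rough usage sketch, assuming a local environment:

```scala
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.internal.SQLConf

// No active SparkSession yet: the thread-local fallback is returned.
assert(SQLConf.get eq SQLConf.getFallbackConf)

// Constructing a session installs the getter shown in the diff above, so
// SQLConf.get now resolves through the active session, not the fallback.
val spark = SparkSession.builder().master("local").appName("demo").getOrCreate()
assert(SQLConf.get ne SQLConf.getFallbackConf)
spark.stop()
```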