Skip to content
Snippets Groups Projects
Commit 3e84ef0a authored by Reynold Xin's avatar Reynold Xin
Browse files

[SPARK-12770][SQL] Implement rules for branch elimination for CaseWhen

The three optimization cases are:

1. If the first branch's condition is a true literal, remove the CaseWhen and use the value from that branch.
2. If a branch's condition is a false or null literal, remove that branch.
3. If only the else branch is left, remove the CaseWhen and use the value from the else branch.

Author: Reynold Xin <rxin@databricks.com>

Closes #10827 from rxin/SPARK-12770.
parent f6f7ca9d
No related branches found
No related tags found
No related merge requests found
...@@ -635,6 +635,24 @@ object SimplifyConditionals extends Rule[LogicalPlan] with PredicateHelper { ...@@ -635,6 +635,24 @@ object SimplifyConditionals extends Rule[LogicalPlan] with PredicateHelper {
case q: LogicalPlan => q transformExpressionsUp { case q: LogicalPlan => q transformExpressionsUp {
case If(TrueLiteral, trueValue, _) => trueValue case If(TrueLiteral, trueValue, _) => trueValue
case If(FalseLiteral, _, falseValue) => falseValue case If(FalseLiteral, _, falseValue) => falseValue
case e @ CaseWhen(branches, elseValue) if branches.exists(_._1 == FalseLiteral) =>
// If there are branches that are always false, remove them.
// If there are no more branches left, just use the else value.
// Note that these two are handled together here in a single case statement because
// otherwise we cannot determine the data type for the elseValue if it is None (i.e. null).
val newBranches = branches.filter(_._1 != FalseLiteral)
if (newBranches.isEmpty) {
elseValue.getOrElse(Literal.create(null, e.dataType))
} else {
e.copy(branches = newBranches)
}
case e @ CaseWhen(branches, _) if branches.headOption.map(_._1) == Some(TrueLiteral) =>
// If the first branch is a true literal, remove the entire CaseWhen and use the value
// from that. Note that CaseWhen.branches should never be empty, and as a result the
// headOption (rather than head) added above is just a extra (and unnecessary) safeguard.
branches.head._2
} }
} }
} }
......
...@@ -23,6 +23,7 @@ import org.apache.spark.sql.catalyst.expressions.Literal.{FalseLiteral, TrueLite ...@@ -23,6 +23,7 @@ import org.apache.spark.sql.catalyst.expressions.Literal.{FalseLiteral, TrueLite
import org.apache.spark.sql.catalyst.plans.PlanTest import org.apache.spark.sql.catalyst.plans.PlanTest
import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.plans.logical._
import org.apache.spark.sql.catalyst.rules._ import org.apache.spark.sql.catalyst.rules._
import org.apache.spark.sql.types.IntegerType
class SimplifyConditionalSuite extends PlanTest with PredicateHelper { class SimplifyConditionalSuite extends PlanTest with PredicateHelper {
...@@ -37,6 +38,10 @@ class SimplifyConditionalSuite extends PlanTest with PredicateHelper { ...@@ -37,6 +38,10 @@ class SimplifyConditionalSuite extends PlanTest with PredicateHelper {
comparePlans(actual, correctAnswer) comparePlans(actual, correctAnswer)
} }
private val trueBranch = (TrueLiteral, Literal(5))
private val normalBranch = (NonFoldableLiteral(true), Literal(10))
private val unreachableBranch = (FalseLiteral, Literal(20))
test("simplify if") { test("simplify if") {
assertEquivalent( assertEquivalent(
If(TrueLiteral, Literal(10), Literal(20)), If(TrueLiteral, Literal(10), Literal(20)),
...@@ -47,4 +52,36 @@ class SimplifyConditionalSuite extends PlanTest with PredicateHelper { ...@@ -47,4 +52,36 @@ class SimplifyConditionalSuite extends PlanTest with PredicateHelper {
Literal(20)) Literal(20))
} }
test("remove unreachable branches") {
// i.e. removing branches whose conditions are always false
assertEquivalent(
CaseWhen(unreachableBranch :: normalBranch :: unreachableBranch :: Nil, None),
CaseWhen(normalBranch :: Nil, None))
}
test("remove entire CaseWhen if only the else branch is reachable") {
assertEquivalent(
CaseWhen(unreachableBranch :: unreachableBranch :: Nil, Some(Literal(30))),
Literal(30))
assertEquivalent(
CaseWhen(unreachableBranch :: unreachableBranch :: Nil, None),
Literal.create(null, IntegerType))
}
test("remove entire CaseWhen if the first branch is always true") {
assertEquivalent(
CaseWhen(trueBranch :: normalBranch :: Nil, None),
Literal(5))
// Test branch elimination and simplification in combination
assertEquivalent(
CaseWhen(unreachableBranch :: unreachableBranch:: trueBranch :: normalBranch :: Nil, None),
Literal(5))
// Make sure this doesn't trigger if there is a non-foldable branch before the true branch
assertEquivalent(
CaseWhen(normalBranch :: trueBranch :: normalBranch :: Nil, None),
CaseWhen(normalBranch :: trueBranch :: normalBranch :: Nil, None))
}
} }
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment