diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala index d3353beb09c5df24fa8fc89a52a416e5a1d67def..d4fc9e4da944aebd97539e1fc0b9fb46fda5800b 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala @@ -236,10 +236,24 @@ case class Union(children: Seq[LogicalPlan]) extends LogicalPlan { }) } + private def merge(a: Set[Expression], b: Set[Expression]): Set[Expression] = { + val common = a.intersect(b) + // The constraint with only one reference could be easily inferred as predicate + // Grouping the constraints by it's references so we can combine the constraints with same + // reference together + val othera = a.diff(common).filter(_.references.size == 1).groupBy(_.references.head) + val otherb = b.diff(common).filter(_.references.size == 1).groupBy(_.references.head) + // loose the constraints by: A1 && B1 || A2 && B2 -> (A1 || A2) && (B1 || B2) + val others = (othera.keySet intersect otherb.keySet).map { attr => + Or(othera(attr).reduceLeft(And), otherb(attr).reduceLeft(And)) + } + common ++ others + } + override protected def validConstraints: Set[Expression] = { children .map(child => rewriteConstraints(children.head.output, child.output, child.constraints)) - .reduce(_ intersect _) + .reduce(merge(_, _)) } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/ConstraintPropagationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/ConstraintPropagationSuite.scala index 49c1353efb635b50268d30ed4b94e3b81cb2bbf2..81cc6b123cdd472a19e5e553ec37f5c70152105e 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/ConstraintPropagationSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/ConstraintPropagationSuite.scala @@ -148,6 +148,20 @@ class ConstraintPropagationSuite extends SparkFunSuite { .analyze.constraints, ExpressionSet(Seq(resolveColumn(tr1, "a") > 10, IsNotNull(resolveColumn(tr1, "a"))))) + + val a = resolveColumn(tr1, "a") + verifyConstraints(tr1 + .where('a.attr > 10) + .union(tr2.where('d.attr > 11)) + .analyze.constraints, + ExpressionSet(Seq(a > 10 || a > 11, IsNotNull(a)))) + + val b = resolveColumn(tr1, "b") + verifyConstraints(tr1 + .where('a.attr > 10 && 'b.attr < 10) + .union(tr2.where('d.attr > 11 && 'e.attr < 11)) + .analyze.constraints, + ExpressionSet(Seq(a > 10 || a > 11, b < 10 || b < 11, IsNotNull(a), IsNotNull(b)))) } test("propagating constraints in intersect") {