diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala
index d3353beb09c5df24fa8fc89a52a416e5a1d67def..d4fc9e4da944aebd97539e1fc0b9fb46fda5800b 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala
@@ -236,10 +236,24 @@ case class Union(children: Seq[LogicalPlan]) extends LogicalPlan {
     })
   }
 
+  private def merge(a: Set[Expression], b: Set[Expression]): Set[Expression] = {
+    val common = a.intersect(b)
+    // The constraint with only one reference could be easily inferred as predicate
+    // Grouping the constraints by it's references so we can combine the constraints with same
+    // reference together
+    val othera = a.diff(common).filter(_.references.size == 1).groupBy(_.references.head)
+    val otherb = b.diff(common).filter(_.references.size == 1).groupBy(_.references.head)
+    // loose the constraints by: A1 && B1 || A2 && B2  ->  (A1 || A2) && (B1 || B2)
+    val others = (othera.keySet intersect otherb.keySet).map { attr =>
+      Or(othera(attr).reduceLeft(And), otherb(attr).reduceLeft(And))
+    }
+    common ++ others
+  }
+
   override protected def validConstraints: Set[Expression] = {
     children
       .map(child => rewriteConstraints(children.head.output, child.output, child.constraints))
-      .reduce(_ intersect _)
+      .reduce(merge(_, _))
   }
 }
 
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/ConstraintPropagationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/ConstraintPropagationSuite.scala
index 49c1353efb635b50268d30ed4b94e3b81cb2bbf2..81cc6b123cdd472a19e5e553ec37f5c70152105e 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/ConstraintPropagationSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/ConstraintPropagationSuite.scala
@@ -148,6 +148,20 @@ class ConstraintPropagationSuite extends SparkFunSuite {
       .analyze.constraints,
       ExpressionSet(Seq(resolveColumn(tr1, "a") > 10,
         IsNotNull(resolveColumn(tr1, "a")))))
+
+    val a = resolveColumn(tr1, "a")
+    verifyConstraints(tr1
+      .where('a.attr > 10)
+      .union(tr2.where('d.attr > 11))
+      .analyze.constraints,
+      ExpressionSet(Seq(a > 10 || a > 11, IsNotNull(a))))
+
+    val b = resolveColumn(tr1, "b")
+    verifyConstraints(tr1
+      .where('a.attr > 10 && 'b.attr < 10)
+      .union(tr2.where('d.attr > 11 && 'e.attr < 11))
+      .analyze.constraints,
+      ExpressionSet(Seq(a > 10 || a > 11, b < 10 || b < 11, IsNotNull(a), IsNotNull(b))))
   }
 
   test("propagating constraints in intersect") {