diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala
index 0bc893224026eb214592434a324439f04a5ff25c..6006e7bf00c1326e88e0392cb56600c4c67ffcbf 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala
@@ -17,6 +17,8 @@
 
 package org.apache.spark.sql.catalyst.analysis
 
+import javax.annotation.Nullable
+
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, Project, Union}
 import org.apache.spark.sql.catalyst.rules.Rule
@@ -713,39 +715,68 @@ object HiveTypeCoercion {
 
       case e: ExpectsInputTypes =>
         val children: Seq[Expression] = e.children.zip(e.inputTypes).map { case (in, expected) =>
-          implicitCast(in, expected)
+          // If we cannot do the implicit cast, just use the original input.
+          implicitCast(in, expected).getOrElse(in)
         }
         e.withNewChildren(children)
     }
 
     /**
-     * If needed, cast the expression into the expected type.
-     * If the implicit cast is not allowed, return the expression itself.
+     * Given an expected data type, try to cast the expression and return the cast expression.
+     *
+     * If the expression already fits the input type, we simply return the expression itself.
+     * If the expression has an incompatible type that cannot be implicitly cast, return None.
      */
-    def implicitCast(e: Expression, expectedType: AbstractDataType): Expression = {
+    def implicitCast(e: Expression, expectedType: AbstractDataType): Option[Expression] = {
       val inType = e.dataType
-      (inType, expectedType) match {
+
+      // Note that ret is nullable to avoid typing a lot of Some(...) in this local scope.
+      // We wrap immediately an Option after this.
+      @Nullable val ret: Expression = (inType, expectedType) match {
+
+        // If the expected type is already a parent of the input type, no need to cast.
+        case _ if expectedType.isParentOf(inType) => e
+
         // Cast null type (usually from null literals) into target types
-        case (NullType, target: DataType) => Cast(e, target.defaultConcreteType)
+        case (NullType, target) => Cast(e, target.defaultConcreteType)
 
         // Implicit cast among numeric types
+        // If input is decimal, and we expect a decimal type, just use the input.
+        case (_: DecimalType, DecimalType) => e
+        // If input is a numeric type but not decimal, and we expect a decimal type,
+        // cast the input to unlimited precision decimal.
+        case (_: NumericType, DecimalType) if !inType.isInstanceOf[DecimalType] =>
+          Cast(e, DecimalType.Unlimited)
+        // For any other numeric types, implicitly cast to each other, e.g. long -> int, int -> long
         case (_: NumericType, target: NumericType) if e.dataType != target => Cast(e, target)
+        case (_: NumericType, target: NumericType) => e
 
         // Implicit cast between date time types
         case (DateType, TimestampType) => Cast(e, TimestampType)
         case (TimestampType, DateType) => Cast(e, DateType)
 
         // Implicit cast from/to string
-        case (StringType, NumericType) => Cast(e, DoubleType)
+        case (StringType, DecimalType) => Cast(e, DecimalType.Unlimited)
         case (StringType, target: NumericType) => Cast(e, target)
         case (StringType, DateType) => Cast(e, DateType)
         case (StringType, TimestampType) => Cast(e, TimestampType)
         case (StringType, BinaryType) => Cast(e, BinaryType)
         case (any, StringType) if any != StringType => Cast(e, StringType)
 
+        // Type collection.
+        // First see if we can find our input type in the type collection. If we can, then just
+        // use the current expression; otherwise, find the first one we can implicitly cast.
+        case (_, TypeCollection(types)) =>
+          if (types.exists(_.isParentOf(inType))) {
+            e
+          } else {
+            types.flatMap(implicitCast(e, _)).headOption.orNull
+          }
+
         // Else, just return the same input expression
-        case _ => e
+        case _ => null
       }
+      Option(ret)
     }
   }
 }
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/AbstractDataType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/AbstractDataType.scala
index 43e2f8a46e62e237fac329f1e4c9c6c7b13c15e0..e5dc99fb625d8d320d8d6548a29f01d9f7c837a7 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/AbstractDataType.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/AbstractDataType.scala
@@ -28,7 +28,45 @@ import org.apache.spark.util.Utils
  * A non-concrete data type, reserved for internal uses.
  */
 private[sql] abstract class AbstractDataType {
+  /**
+   * The default concrete type to use if we want to cast a null literal into this type.
+   */
   private[sql] def defaultConcreteType: DataType
+
+  /**
+   * Returns true if this data type is a parent of the `childCandidate`.
+   */
+  private[sql] def isParentOf(childCandidate: DataType): Boolean
+}
+
+
+/**
+ * A collection of types that can be used to specify type constraints. The sequence also specifies
+ * precedence: an earlier type takes precedence over a latter type.
+ *
+ * {{{
+ *   TypeCollection(StringType, BinaryType)
+ * }}}
+ *
+ * This means that we prefer StringType over BinaryType if it is possible to cast to StringType.
+ */
+private[sql] class TypeCollection(private val types: Seq[DataType]) extends AbstractDataType {
+  require(types.nonEmpty, s"TypeCollection ($types) cannot be empty")
+
+  private[sql] override def defaultConcreteType: DataType = types.head
+
+  private[sql] override def isParentOf(childCandidate: DataType): Boolean = false
+}
+
+
+private[sql] object TypeCollection {
+
+  def apply(types: DataType*): TypeCollection = new TypeCollection(types)
+
+  def unapply(typ: AbstractDataType): Option[Seq[DataType]] = typ match {
+    case typ: TypeCollection => Some(typ.types)
+    case _ => None
+  }
 }
 
 
@@ -61,7 +99,7 @@ abstract class NumericType extends AtomicType {
 }
 
 
-private[sql] object NumericType extends AbstractDataType {
+private[sql] object NumericType {
   /**
    * Enables matching against NumericType for expressions:
    * {{{
@@ -70,12 +108,10 @@ private[sql] object NumericType extends AbstractDataType {
    * }}}
    */
   def unapply(e: Expression): Boolean = e.dataType.isInstanceOf[NumericType]
-
-  private[sql] override def defaultConcreteType: DataType = IntegerType
 }
 
 
-private[sql] object IntegralType extends AbstractDataType {
+private[sql] object IntegralType {
   /**
    * Enables matching against IntegralType for expressions:
    * {{{
@@ -84,8 +120,6 @@ private[sql] object IntegralType extends AbstractDataType {
    * }}}
    */
   def unapply(e: Expression): Boolean = e.dataType.isInstanceOf[IntegralType]
-
-  private[sql] override def defaultConcreteType: DataType = IntegerType
 }
 
 
@@ -94,7 +128,7 @@ private[sql] abstract class IntegralType extends NumericType {
 }
 
 
-private[sql] object FractionalType extends AbstractDataType {
+private[sql] object FractionalType {
   /**
    * Enables matching against FractionalType for expressions:
    * {{{
@@ -103,8 +137,6 @@ private[sql] object FractionalType extends AbstractDataType {
    * }}}
    */
   def unapply(e: Expression): Boolean = e.dataType.isInstanceOf[FractionalType]
-
-  private[sql] override def defaultConcreteType: DataType = DoubleType
 }
 
 
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/ArrayType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/ArrayType.scala
index 81553e7fc91a85e4d6c4a8c674f7d8042f63cccf..8ea6cb14c360ea38d8bf30d667f394caace16835 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/ArrayType.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/ArrayType.scala
@@ -26,7 +26,11 @@ object ArrayType extends AbstractDataType {
   /** Construct a [[ArrayType]] object with the given element type. The `containsNull` is true. */
   def apply(elementType: DataType): ArrayType = ArrayType(elementType, containsNull = true)
 
-  override def defaultConcreteType: DataType = ArrayType(NullType, containsNull = true)
+  private[sql] override def defaultConcreteType: DataType = ArrayType(NullType, containsNull = true)
+
+  private[sql] override def isParentOf(childCandidate: DataType): Boolean = {
+    childCandidate.isInstanceOf[ArrayType]
+  }
 }
 
 
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DataType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DataType.scala
index c333fa70d1ef4e007b3fec89a5c13b2c332947db..7d00047d08d742d5c8e6a2a31475895c2972a235 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DataType.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DataType.scala
@@ -75,7 +75,9 @@ abstract class DataType extends AbstractDataType {
    */
   private[spark] def asNullable: DataType
 
-  override def defaultConcreteType: DataType = this
+  private[sql] override def defaultConcreteType: DataType = this
+
+  private[sql] override def isParentOf(childCandidate: DataType): Boolean = this == childCandidate
 }
 
 
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DecimalType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DecimalType.scala
index 06373a095b1b0d53d7d2aec0a40dcb053f8e2353..434fc037aad4f0db399dae66da870b8c5c9259aa 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DecimalType.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DecimalType.scala
@@ -86,6 +86,10 @@ object DecimalType extends AbstractDataType {
 
   private[sql] override def defaultConcreteType: DataType = Unlimited
 
+  private[sql] override def isParentOf(childCandidate: DataType): Boolean = {
+    childCandidate.isInstanceOf[DecimalType]
+  }
+
   val Unlimited: DecimalType = DecimalType(None)
 
   private[sql] object Fixed {
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/MapType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/MapType.scala
index 69c2119e23436e0d9f59fd7d54a24732492814f7..2b25617ec6655d638b94fa4ee9c3f92f5fe2b570 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/MapType.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/MapType.scala
@@ -71,6 +71,10 @@ object MapType extends AbstractDataType {
 
   private[sql] override def defaultConcreteType: DataType = apply(NullType, NullType)
 
+  private[sql] override def isParentOf(childCandidate: DataType): Boolean = {
+    childCandidate.isInstanceOf[MapType]
+  }
+
   /**
    * Construct a [[MapType]] object with the given key type and value type.
    * The `valueContainsNull` is true.
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/StructType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/StructType.scala
index 6fedeabf23203a48d27d239ff4dfdf51c826b09b..7e77b77e739403aeb383f3fdea175eb7034f0df6 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/StructType.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/StructType.scala
@@ -301,7 +301,13 @@ case class StructType(fields: Array[StructField]) extends DataType with Seq[Stru
 }
 
 
-object StructType {
+object StructType extends AbstractDataType {
+
+  private[sql] override def defaultConcreteType: DataType = new StructType
+
+  private[sql] override def isParentOf(childCandidate: DataType): Boolean = {
+    childCandidate.isInstanceOf[StructType]
+  }
 
   def apply(fields: Seq[StructField]): StructType = StructType(fields.toArray)
 
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercionSuite.scala
index 498fd86a06fd91c5384c78fae858c591ddf0642b..60e727c6c7d4de8b90a84c5774a1785064045d9f 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercionSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercionSuite.scala
@@ -27,28 +27,47 @@ import org.apache.spark.sql.types._
 class HiveTypeCoercionSuite extends PlanTest {
 
   test("implicit type cast") {
-    def shouldCast(from: DataType, to: AbstractDataType): Unit = {
+    def shouldCast(from: DataType, to: AbstractDataType, expected: DataType): Unit = {
       val got = HiveTypeCoercion.ImplicitTypeCasts.implicitCast(Literal.create(null, from), to)
-      assert(got.dataType === to.defaultConcreteType)
+      assert(got.map(_.dataType) == Option(expected),
+        s"Failed to cast $from to $to")
     }
 
+    shouldCast(NullType, NullType, NullType)
+    shouldCast(NullType, IntegerType, IntegerType)
+    shouldCast(NullType, DecimalType, DecimalType.Unlimited)
+
     // TODO: write the entire implicit cast table out for test cases.
-    shouldCast(ByteType, IntegerType)
-    shouldCast(IntegerType, IntegerType)
-    shouldCast(IntegerType, LongType)
-    shouldCast(IntegerType, DecimalType.Unlimited)
-    shouldCast(LongType, IntegerType)
-    shouldCast(LongType, DecimalType.Unlimited)
-
-    shouldCast(DateType, TimestampType)
-    shouldCast(TimestampType, DateType)
-
-    shouldCast(StringType, IntegerType)
-    shouldCast(StringType, DateType)
-    shouldCast(StringType, TimestampType)
-    shouldCast(IntegerType, StringType)
-    shouldCast(DateType, StringType)
-    shouldCast(TimestampType, StringType)
+    shouldCast(ByteType, IntegerType, IntegerType)
+    shouldCast(IntegerType, IntegerType, IntegerType)
+    shouldCast(IntegerType, LongType, LongType)
+    shouldCast(IntegerType, DecimalType, DecimalType.Unlimited)
+    shouldCast(LongType, IntegerType, IntegerType)
+    shouldCast(LongType, DecimalType, DecimalType.Unlimited)
+
+    shouldCast(DateType, TimestampType, TimestampType)
+    shouldCast(TimestampType, DateType, DateType)
+
+    shouldCast(StringType, IntegerType, IntegerType)
+    shouldCast(StringType, DateType, DateType)
+    shouldCast(StringType, TimestampType, TimestampType)
+    shouldCast(IntegerType, StringType, StringType)
+    shouldCast(DateType, StringType, StringType)
+    shouldCast(TimestampType, StringType, StringType)
+
+    shouldCast(StringType, BinaryType, BinaryType)
+    shouldCast(BinaryType, StringType, StringType)
+
+    shouldCast(NullType, TypeCollection(StringType, BinaryType), StringType)
+
+    shouldCast(StringType, TypeCollection(StringType, BinaryType), StringType)
+    shouldCast(BinaryType, TypeCollection(StringType, BinaryType), BinaryType)
+    shouldCast(StringType, TypeCollection(BinaryType, StringType), StringType)
+
+    shouldCast(IntegerType, TypeCollection(IntegerType, BinaryType), IntegerType)
+    shouldCast(IntegerType, TypeCollection(BinaryType, IntegerType), IntegerType)
+    shouldCast(BinaryType, TypeCollection(BinaryType, IntegerType), BinaryType)
+    shouldCast(BinaryType, TypeCollection(IntegerType, BinaryType), BinaryType)
   }
 
   test("tightest common bound for types") {