diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala
index 0bc893224026eb214592434a324439f04a5ff25c..6006e7bf00c1326e88e0392cb56600c4c67ffcbf 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala
@@ -17,6 +17,8 @@
 
 package org.apache.spark.sql.catalyst.analysis
 
+import javax.annotation.Nullable
+
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, Project, Union}
 import org.apache.spark.sql.catalyst.rules.Rule
@@ -713,39 +715,68 @@ object HiveTypeCoercion {
       case e: ExpectsInputTypes =>
         val children: Seq[Expression] = e.children.zip(e.inputTypes).map { case (in, expected) =>
-          implicitCast(in, expected)
+          // If we cannot do the implicit cast, just use the original input.
+          implicitCast(in, expected).getOrElse(in)
         }
         e.withNewChildren(children)
     }
 
     /**
-     * If needed, cast the expression into the expected type.
-     * If the implicit cast is not allowed, return the expression itself.
+     * Given an expected data type, try to cast the expression and return the cast expression.
+     *
+     * If the expression already fits the input type, we simply return the expression itself.
+     * If the expression has an incompatible type that cannot be implicitly cast, return None.
      */
-    def implicitCast(e: Expression, expectedType: AbstractDataType): Expression = {
+    def implicitCast(e: Expression, expectedType: AbstractDataType): Option[Expression] = {
       val inType = e.dataType
-      (inType, expectedType) match {
+
+      // Note that ret is nullable to avoid typing a lot of Some(...) in this local scope.
+      // We wrap it in an Option immediately after this match.
+      @Nullable val ret: Expression = (inType, expectedType) match {
+
+        // If the expected type is already a parent of the input type, no need to cast.
+        case _ if expectedType.isParentOf(inType) => e
+
         // Cast null type (usually from null literals) into target types
-        case (NullType, target: DataType) => Cast(e, target.defaultConcreteType)
+        case (NullType, target) => Cast(e, target.defaultConcreteType)
 
         // Implicit cast among numeric types
+        // If input is decimal, and we expect a decimal type, just use the input.
+        case (_: DecimalType, DecimalType) => e
+        // If input is a numeric type but not decimal, and we expect a decimal type,
+        // cast the input to unlimited precision decimal.
+        case (_: NumericType, DecimalType) if !inType.isInstanceOf[DecimalType] =>
+          Cast(e, DecimalType.Unlimited)
+        // For any other numeric types, implicitly cast to each other, e.g. long -> int, int -> long
         case (_: NumericType, target: NumericType) if e.dataType != target => Cast(e, target)
+        case (_: NumericType, target: NumericType) => e
 
         // Implicit cast between date time types
         case (DateType, TimestampType) => Cast(e, TimestampType)
         case (TimestampType, DateType) => Cast(e, DateType)
 
         // Implicit cast from/to string
-        case (StringType, NumericType) => Cast(e, DoubleType)
+        case (StringType, DecimalType) => Cast(e, DecimalType.Unlimited)
         case (StringType, target: NumericType) => Cast(e, target)
         case (StringType, DateType) => Cast(e, DateType)
         case (StringType, TimestampType) => Cast(e, TimestampType)
         case (StringType, BinaryType) => Cast(e, BinaryType)
         case (any, StringType) if any != StringType => Cast(e, StringType)
 
+        // Type collection.
+        // First see if we can find our input type in the type collection. If we can, then just
+        // use the current expression; otherwise, find the first one we can implicitly cast.
+        case (_, TypeCollection(types)) =>
+          if (types.exists(_.isParentOf(inType))) {
+            e
+          } else {
+            types.flatMap(implicitCast(e, _)).headOption.orNull
+          }
+
-        // Else, just return the same input expression
-        case _ => e
+        // Else, return null; Option(ret) below turns it into None, meaning no implicit cast.
+        case _ => null
       }
+      Option(ret)
     }
   }
 }
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/AbstractDataType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/AbstractDataType.scala
index 43e2f8a46e62e237fac329f1e4c9c6c7b13c15e0..e5dc99fb625d8d320d8d6548a29f01d9f7c837a7 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/AbstractDataType.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/AbstractDataType.scala
@@ -28,7 +28,45 @@ import org.apache.spark.util.Utils
  * A non-concrete data type, reserved for internal uses.
  */
 private[sql] abstract class AbstractDataType {
+  /**
+   * The default concrete type to use if we want to cast a null literal into this type.
+   */
   private[sql] def defaultConcreteType: DataType
+
+  /**
+   * Returns true if this data type is a parent of the `childCandidate`.
+   */
+  private[sql] def isParentOf(childCandidate: DataType): Boolean
+}
+
+
+/**
+ * A collection of types that can be used to specify type constraints. The sequence also specifies
+ * precedence: an earlier type takes precedence over a later type.
+ *
+ * {{{
+ *   TypeCollection(StringType, BinaryType)
+ * }}}
+ *
+ * This means that we prefer StringType over BinaryType if it is possible to cast to StringType.
+ */
+private[sql] class TypeCollection(private val types: Seq[DataType]) extends AbstractDataType {
+  require(types.nonEmpty, s"TypeCollection ($types) cannot be empty")
+
+  private[sql] override def defaultConcreteType: DataType = types.head
+
+  private[sql] override def isParentOf(childCandidate: DataType): Boolean = false
+}
+
+
+private[sql] object TypeCollection {
+
+  def apply(types: DataType*): TypeCollection = new TypeCollection(types)
+
+  def unapply(typ: AbstractDataType): Option[Seq[DataType]] = typ match {
+    case typ: TypeCollection => Some(typ.types)
+    case _ => None
+  }
 }
 
 
@@ -61,7 +99,7 @@ abstract class NumericType extends AtomicType {
 }
 
 
-private[sql] object NumericType extends AbstractDataType {
+private[sql] object NumericType {
   /**
    * Enables matching against NumericType for expressions:
    * {{{
@@ -70,12 +108,10 @@ private[sql] object NumericType extends AbstractDataType {
    * }}}
    */
   def unapply(e: Expression): Boolean = e.dataType.isInstanceOf[NumericType]
-
-  private[sql] override def defaultConcreteType: DataType = IntegerType
 }
 
 
-private[sql] object IntegralType extends AbstractDataType {
+private[sql] object IntegralType {
   /**
    * Enables matching against IntegralType for expressions:
    * {{{
@@ -84,8 +120,6 @@ private[sql] object IntegralType extends AbstractDataType {
    * }}}
    */
   def unapply(e: Expression): Boolean = e.dataType.isInstanceOf[IntegralType]
-
-  private[sql] override def defaultConcreteType: DataType = IntegerType
 }
 
 
@@ -94,7 +128,7 @@ private[sql] abstract class IntegralType extends NumericType {
 }
 
 
-private[sql] object FractionalType extends AbstractDataType {
+private[sql] object FractionalType {
  /**
    * Enables matching against FractionalType for expressions:
    * {{{
@@ -103,8 +137,6 @@ private[sql] object FractionalType extends AbstractDataType {
    * }}}
    */
   def unapply(e: Expression): Boolean =
     e.dataType.isInstanceOf[FractionalType]
-
-  private[sql] override def defaultConcreteType: DataType = DoubleType
 }
 
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/ArrayType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/ArrayType.scala
index 81553e7fc91a85e4d6c4a8c674f7d8042f63cccf..8ea6cb14c360ea38d8bf30d667f394caace16835 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/ArrayType.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/ArrayType.scala
@@ -26,7 +26,11 @@ object ArrayType extends AbstractDataType {
   /** Construct a [[ArrayType]] object with the given element type. The `containsNull` is true. */
   def apply(elementType: DataType): ArrayType = ArrayType(elementType, containsNull = true)
 
-  override def defaultConcreteType: DataType = ArrayType(NullType, containsNull = true)
+  private[sql] override def defaultConcreteType: DataType = ArrayType(NullType, containsNull = true)
+
+  private[sql] override def isParentOf(childCandidate: DataType): Boolean = {
+    childCandidate.isInstanceOf[ArrayType]
+  }
 }
 
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DataType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DataType.scala
index c333fa70d1ef4e007b3fec89a5c13b2c332947db..7d00047d08d742d5c8e6a2a31475895c2972a235 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DataType.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DataType.scala
@@ -75,7 +75,9 @@ abstract class DataType extends AbstractDataType {
    */
   private[spark] def asNullable: DataType
 
-  override def defaultConcreteType: DataType = this
+  private[sql] override def defaultConcreteType: DataType = this
+
+  private[sql] override def isParentOf(childCandidate: DataType): Boolean = this == childCandidate
 }
 
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DecimalType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DecimalType.scala
index 06373a095b1b0d53d7d2aec0a40dcb053f8e2353..434fc037aad4f0db399dae66da870b8c5c9259aa 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DecimalType.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DecimalType.scala
@@ -86,6 +86,10 @@ object DecimalType extends AbstractDataType {
 
   private[sql] override def defaultConcreteType: DataType = Unlimited
 
+  private[sql] override def isParentOf(childCandidate: DataType): Boolean = {
+    childCandidate.isInstanceOf[DecimalType]
+  }
+
   val Unlimited: DecimalType = DecimalType(None)
 
   private[sql] object Fixed {
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/MapType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/MapType.scala
index 69c2119e23436e0d9f59fd7d54a24732492814f7..2b25617ec6655d638b94fa4ee9c3f92f5fe2b570 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/MapType.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/MapType.scala
@@ -71,6 +71,10 @@ object MapType extends AbstractDataType {
 
   private[sql] override def defaultConcreteType: DataType = apply(NullType, NullType)
 
+  private[sql] override def isParentOf(childCandidate: DataType): Boolean = {
+    childCandidate.isInstanceOf[MapType]
+  }
+
   /**
    * Construct a [[MapType]] object with the given key type and value type.
    * The `valueContainsNull` is true.
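The `isParentOf` overrides above all implement the same narrow contract: a concrete `DataType` is a parent only of itself (the DataType.scala hunk), while an abstract companion such as `DecimalType`, `ArrayType`, or `MapType` is a parent of every concrete instance of its family. A minimal standalone sketch of that contract, using made-up stand-in types rather than the real Spark classes:

```scala
// Standalone sketch of the isParentOf contract. IntT, StrT, DecT and AnyDecT
// are illustrative stand-ins, not the real Spark types.
object IsParentOfSketch extends App {

  sealed trait AbstractType {
    def isParentOf(child: ConcreteType): Boolean
  }

  sealed trait ConcreteType extends AbstractType {
    // Mirrors DataType.isParentOf: a concrete type is only a parent of itself.
    def isParentOf(child: ConcreteType): Boolean = this == child
  }

  case object IntT extends ConcreteType
  case object StrT extends ConcreteType
  // A decimal with optional fixed (precision, scale), like Spark's DecimalType.
  case class DecT(precisionScale: Option[(Int, Int)]) extends ConcreteType

  // Mirrors the DecimalType companion object: parent of every decimal instance.
  case object AnyDecT extends AbstractType {
    def isParentOf(child: ConcreteType): Boolean = child.isInstanceOf[DecT]
  }

  assert(IntT.isParentOf(IntT))                   // a concrete type accepts itself
  assert(!IntT.isParentOf(StrT))                  // but nothing else
  assert(AnyDecT.isParentOf(DecT(None)))          // abstract decimal accepts unlimited precision
  assert(AnyDecT.isParentOf(DecT(Some((10, 2))))) // ... and any fixed precision
  assert(!AnyDecT.isParentOf(IntT))
  println("isParentOf sketch passes")
}
```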
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/StructType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/StructType.scala
index 6fedeabf23203a48d27d239ff4dfdf51c826b09b..7e77b77e739403aeb383f3fdea175eb7034f0df6 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/StructType.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/StructType.scala
@@ -301,7 +301,13 @@ case class StructType(fields: Array[StructField]) extends DataType with Seq[Stru
 }
 
 
-object StructType {
+object StructType extends AbstractDataType {
+
+  private[sql] override def defaultConcreteType: DataType = new StructType
+
+  private[sql] override def isParentOf(childCandidate: DataType): Boolean = {
+    childCandidate.isInstanceOf[StructType]
+  }
 
   def apply(fields: Seq[StructField]): StructType = StructType(fields.toArray)
 
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercionSuite.scala
index 498fd86a06fd91c5384c78fae858c591ddf0642b..60e727c6c7d4de8b90a84c5774a1785064045d9f 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercionSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercionSuite.scala
@@ -27,28 +27,47 @@ import org.apache.spark.sql.types._
 class HiveTypeCoercionSuite extends PlanTest {
 
   test("implicit type cast") {
-    def shouldCast(from: DataType, to: AbstractDataType): Unit = {
+    def shouldCast(from: DataType, to: AbstractDataType, expected: DataType): Unit = {
       val got = HiveTypeCoercion.ImplicitTypeCasts.implicitCast(Literal.create(null, from), to)
-      assert(got.dataType === to.defaultConcreteType)
+      assert(got.map(_.dataType) == Option(expected),
+        s"Failed to cast $from to $to")
     }
 
+    shouldCast(NullType, NullType, NullType)
+    shouldCast(NullType, IntegerType, IntegerType)
+    shouldCast(NullType, DecimalType, DecimalType.Unlimited)
+
+    // TODO: write the entire implicit cast table out for test cases.
-    shouldCast(ByteType, IntegerType)
-    shouldCast(IntegerType, IntegerType)
-    shouldCast(IntegerType, LongType)
-    shouldCast(IntegerType, DecimalType.Unlimited)
-    shouldCast(LongType, IntegerType)
-    shouldCast(LongType, DecimalType.Unlimited)
-
-    shouldCast(DateType, TimestampType)
-    shouldCast(TimestampType, DateType)
-
-    shouldCast(StringType, IntegerType)
-    shouldCast(StringType, DateType)
-    shouldCast(StringType, TimestampType)
-    shouldCast(IntegerType, StringType)
-    shouldCast(DateType, StringType)
-    shouldCast(TimestampType, StringType)
+    shouldCast(ByteType, IntegerType, IntegerType)
+    shouldCast(IntegerType, IntegerType, IntegerType)
+    shouldCast(IntegerType, LongType, LongType)
+    shouldCast(IntegerType, DecimalType, DecimalType.Unlimited)
+    shouldCast(LongType, IntegerType, IntegerType)
+    shouldCast(LongType, DecimalType, DecimalType.Unlimited)
+
+    shouldCast(DateType, TimestampType, TimestampType)
+    shouldCast(TimestampType, DateType, DateType)
+
+    shouldCast(StringType, IntegerType, IntegerType)
+    shouldCast(StringType, DateType, DateType)
+    shouldCast(StringType, TimestampType, TimestampType)
+    shouldCast(IntegerType, StringType, StringType)
+    shouldCast(DateType, StringType, StringType)
+    shouldCast(TimestampType, StringType, StringType)
+
+    shouldCast(StringType, BinaryType, BinaryType)
+    shouldCast(BinaryType, StringType, StringType)
+
+    shouldCast(NullType, TypeCollection(StringType, BinaryType), StringType)
+
+    shouldCast(StringType, TypeCollection(StringType, BinaryType), StringType)
+    shouldCast(BinaryType, TypeCollection(StringType, BinaryType), BinaryType)
+    shouldCast(StringType, TypeCollection(BinaryType, StringType), StringType)
+
+    shouldCast(IntegerType, TypeCollection(IntegerType, BinaryType), IntegerType)
+    shouldCast(IntegerType, TypeCollection(BinaryType, IntegerType), IntegerType)
+    shouldCast(BinaryType, TypeCollection(BinaryType, IntegerType), BinaryType)
+    shouldCast(BinaryType, TypeCollection(IntegerType, BinaryType), BinaryType)
   }
 
   test("tightest common bound for types") {
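Taken together, the `TypeCollection` case in `implicitCast` and the new Option-returning contract resolve a type collection in two steps: keep the input expression if some member is already a parent of its type, otherwise cast to the first member that admits an implicit cast, so earlier members take precedence per the `TypeCollection` scaladoc. (The collection's own `isParentOf` deliberately returns false; membership is checked against its element types, not the collection itself.) A self-contained sketch of that precedence rule, with simplified stand-in types and a hand-picked subset of the cast table rather than the real Spark API:

```scala
// Self-contained sketch of the TypeCollection precedence rule. The types
// (StrT, BinT, IntT, LongT) and the castable table are simplified stand-ins.
object TypeCollectionSketch extends App {

  sealed trait DType
  case object StrT  extends DType
  case object BinT  extends DType
  case object IntT  extends DType
  case object LongT extends DType

  // A small subset of the implicit-cast table from the patch above.
  def castable(from: DType, to: DType): Boolean = (from, to) match {
    case (a, b) if a == b => true
    case (StrT, BinT)     => true // string -> binary
    case (_, StrT)        => true // anything non-string -> string
    case (IntT, LongT)    => true // numeric widening
    case (LongT, IntT)    => true // numeric narrowing is also implicit here
    case _                => false
  }

  // Mirrors the new contract: Some(resultType) if an implicit cast exists,
  // None otherwise. Exact membership is checked first (the isParentOf pass),
  // then the first castable member wins, so earlier members take precedence.
  def implicitCast(in: DType, expected: Seq[DType]): Option[DType] =
    expected.find(_ == in).orElse(expected.find(castable(in, _)))

  assert(implicitCast(StrT, Seq(StrT, BinT)) == Some(StrT))
  assert(implicitCast(BinT, Seq(StrT, BinT)) == Some(BinT))  // exact match beats order
  assert(implicitCast(StrT, Seq(BinT, StrT)) == Some(StrT))  // ditto
  assert(implicitCast(IntT, Seq(BinT, IntT)) == Some(IntT))
  assert(implicitCast(LongT, Seq(BinT, StrT)) == Some(StrT)) // first castable member
  assert(implicitCast(LongT, Seq(BinT)) == None)             // no implicit cast exists
  println("TypeCollection precedence sketch passes")
}
```

The `BinT` assertions mirror the new `shouldCast` cases above: an exact member match wins regardless of its position in the collection, while `StrT` wins over `BinT` in `Seq(BinT, StrT)` only because nothing except a string can be implicitly cast to binary.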