diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala index 583338da57117e90cb79368df39cd9ff6f56d993..476ac2b7cb4748c42045e0bc38c31c95d9e5f311 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala @@ -40,7 +40,7 @@ trait CheckAnalysis { def containsMultipleGenerators(exprs: Seq[Expression]): Boolean = { exprs.flatMap(_.collect { case e: Generator => true - }).length >= 1 + }).nonEmpty } def checkAnalysis(plan: LogicalPlan): Unit = { @@ -85,12 +85,12 @@ trait CheckAnalysis { case Aggregate(groupingExprs, aggregateExprs, child) => def checkValidAggregateExpression(expr: Expression): Unit = expr match { case _: AggregateExpression => // OK - case e: Attribute if groupingExprs.find(_ semanticEquals e).isEmpty => + case e: Attribute if !groupingExprs.exists(_.semanticEquals(e)) => failAnalysis( s"expression '${e.prettyString}' is neither present in the group by, " + s"nor is it an aggregate function. " + "Add to group by or wrap in first() if you don't care which value you get.") - case e if groupingExprs.find(_ semanticEquals e).isDefined => // OK + case e if groupingExprs.exists(_.semanticEquals(e)) => // OK case e if e.references.isEmpty => // OK case e => e.children.foreach(checkValidAggregateExpression) } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/AbstractDataType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/AbstractDataType.scala index e5dc99fb625d8d320d8d6548a29f01d9f7c837a7..ffefb0e7837e9963b31e65a9682100cba5ba831a 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/AbstractDataType.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/AbstractDataType.scala @@ -37,6 +37,9 @@ private[sql] abstract class AbstractDataType { * Returns true if this data type is a parent of the `childCandidate`. */ private[sql] def isParentOf(childCandidate: DataType): Boolean + + /** Readable string representation for the type. */ + private[sql] def simpleString: String } @@ -56,6 +59,10 @@ private[sql] class TypeCollection(private val types: Seq[DataType]) extends Abst private[sql] override def defaultConcreteType: DataType = types.head private[sql] override def isParentOf(childCandidate: DataType): Boolean = false + + private[sql] override def simpleString: String = { + types.map(_.simpleString).mkString("(", " or ", ")") + } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/ArrayType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/ArrayType.scala index 8ea6cb14c360ea38d8bf30d667f394caace16835..43413ec761e6bcddbb1a5de32abc1be3c261f6ea 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/ArrayType.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/ArrayType.scala @@ -31,6 +31,8 @@ object ArrayType extends AbstractDataType { private[sql] override def isParentOf(childCandidate: DataType): Boolean = { childCandidate.isInstanceOf[ArrayType] } + + private[sql] override def simpleString: String = "array" } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DecimalType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DecimalType.scala index 434fc037aad4f0db399dae66da870b8c5c9259aa..127b16ff85bed1b17cf811883edeac90b39a756c 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DecimalType.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DecimalType.scala @@ -90,6 +90,8 @@ object DecimalType extends AbstractDataType { childCandidate.isInstanceOf[DecimalType] } + private[sql] override def simpleString: String = "decimal" + val Unlimited: DecimalType = DecimalType(None) private[sql] object Fixed { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/MapType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/MapType.scala index 2b25617ec6655d638b94fa4ee9c3f92f5fe2b570..868dea13d971e42596f33ecf0b1016af8b996833 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/MapType.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/MapType.scala @@ -75,6 +75,8 @@ object MapType extends AbstractDataType { childCandidate.isInstanceOf[MapType] } + private[sql] override def simpleString: String = "map" + /** * Construct a [[MapType]] object with the given key type and value type. * The `valueContainsNull` is true. diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/StructType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/StructType.scala index 7e77b77e739403aeb383f3fdea175eb7034f0df6..3b17566d54d9b6503b2289b249a640c293b8ea86 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/StructType.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/StructType.scala @@ -309,6 +309,8 @@ object StructType extends AbstractDataType { childCandidate.isInstanceOf[StructType] } + private[sql] override def simpleString: String = "struct" + def apply(fields: Seq[StructField]): StructType = StructType(fields.toArray) def apply(fields: java.util.List[StructField]): StructType = { diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercionSuite.scala index 60e727c6c7d4de8b90a84c5774a1785064045d9f..67d05ab536b7fa6b8cfda9ec3e247941a7dca7c8 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercionSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercionSuite.scala @@ -26,7 +26,7 @@ import org.apache.spark.sql.types._ class HiveTypeCoercionSuite extends PlanTest { - test("implicit type cast") { + test("eligible implicit type cast") { def shouldCast(from: DataType, to: AbstractDataType, expected: DataType): Unit = { val got = HiveTypeCoercion.ImplicitTypeCasts.implicitCast(Literal.create(null, from), to) assert(got.map(_.dataType) == Option(expected), @@ -68,6 +68,29 @@ class HiveTypeCoercionSuite extends PlanTest { shouldCast(IntegerType, TypeCollection(BinaryType, IntegerType), IntegerType) shouldCast(BinaryType, TypeCollection(BinaryType, IntegerType), BinaryType) shouldCast(BinaryType, TypeCollection(IntegerType, BinaryType), BinaryType) + + shouldCast(IntegerType, TypeCollection(StringType, BinaryType), StringType) + shouldCast(IntegerType, TypeCollection(BinaryType, StringType), StringType) + } + + test("ineligible implicit type cast") { + def shouldNotCast(from: DataType, to: AbstractDataType): Unit = { + val got = HiveTypeCoercion.ImplicitTypeCasts.implicitCast(Literal.create(null, from), to) + assert(got.isEmpty, s"Should not be able to cast $from to $to, but got $got") + } + + shouldNotCast(IntegerType, DateType) + shouldNotCast(IntegerType, TimestampType) + shouldNotCast(LongType, DateType) + shouldNotCast(LongType, TimestampType) + shouldNotCast(DecimalType.Unlimited, DateType) + shouldNotCast(DecimalType.Unlimited, TimestampType) + + shouldNotCast(IntegerType, TypeCollection(DateType, TimestampType)) + + shouldNotCast(IntegerType, ArrayType) + shouldNotCast(IntegerType, MapType) + shouldNotCast(IntegerType, StructType) } test("tightest common bound for types") {