diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveOperators.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveOperators.scala
index d263c312960a9abf90022117a63dbce791fcf573..29b4b9b006e45b60f62697fb92a3b7412255e22b 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveOperators.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveOperators.scala
@@ -26,8 +26,7 @@ import org.apache.hadoop.hive.ql.plan.{TableDesc, FileSinkDesc}
 import org.apache.hadoop.hive.serde.serdeConstants
 import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorUtils.ObjectInspectorCopyOption
 import org.apache.hadoop.hive.serde2.objectinspector._
-import org.apache.hadoop.hive.serde2.objectinspector.primitive.JavaHiveDecimalObjectInspector
-import org.apache.hadoop.hive.serde2.objectinspector.primitive.JavaHiveVarcharObjectInspector
+import org.apache.hadoop.hive.serde2.objectinspector.primitive._
 import org.apache.hadoop.hive.serde2.typeinfo.TypeInfoUtils
 import org.apache.hadoop.hive.serde2.{ColumnProjectionUtils, Serializer}
 import org.apache.hadoop.io.Writable
@@ -95,29 +94,34 @@ case class HiveTableScan(
     attributes.map { a =>
       val ordinal = relation.partitionKeys.indexOf(a)
       if (ordinal >= 0) {
+        val dataType = relation.partitionKeys(ordinal).dataType
         (_: Any, partitionKeys: Array[String]) => {
-          val value = partitionKeys(ordinal)
-          val dataType = relation.partitionKeys(ordinal).dataType
-          unwrapHiveData(castFromString(value, dataType))
+          castFromString(partitionKeys(ordinal), dataType)
         }
       } else {
         val ref = objectInspector.getAllStructFieldRefs
           .find(_.getFieldName == a.name)
           .getOrElse(sys.error(s"Can't find attribute $a"))
+        val fieldObjectInspector = ref.getFieldObjectInspector
+
+        val unwrapHiveData = fieldObjectInspector match {
+          case _: HiveVarcharObjectInspector =>
+            (value: Any) => value.asInstanceOf[HiveVarchar].getValue
+          case _: HiveDecimalObjectInspector =>
+            (value: Any) => BigDecimal(value.asInstanceOf[HiveDecimal].bigDecimalValue())
+          case _ =>
+            identity[Any] _
+        }
+
         (row: Any, _: Array[String]) => {
           val data = objectInspector.getStructFieldData(row, ref)
-          unwrapHiveData(unwrapData(data, ref.getFieldObjectInspector))
+          val hiveData = unwrapData(data, fieldObjectInspector)
+          if (hiveData != null) unwrapHiveData(hiveData) else null
         }
       }
     }
   }
 
-  private def unwrapHiveData(value: Any) = value match {
-    case varchar: HiveVarchar => varchar.getValue
-    case decimal: HiveDecimal => BigDecimal(decimal.bigDecimalValue)
-    case other => other
-  }
-
   private def castFromString(value: String, dataType: DataType) = {
     Cast(Literal(value), dataType).eval(null)
   }
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveComparisonTest.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveComparisonTest.scala
index 1b5a132f9665d3509361a43af5a48e7320af76e4..0f954103a85f21fd184df741137c5d1a8b20558b 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveComparisonTest.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveComparisonTest.scala
@@ -133,15 +133,14 @@ abstract class HiveComparisonTest
     def isSorted(plan: LogicalPlan): Boolean = plan match {
       case _: Join | _: Aggregate | _: BaseRelation | _: Generate | _: Sample | _: Distinct => false
       case PhysicalOperation(_, _, Sort(_, _)) => true
-      case _ => plan.children.iterator.map(isSorted).exists(_ == true)
+      case _ => plan.children.iterator.exists(isSorted)
     }
 
     val orderedAnswer = hiveQuery.logical match {
       // Clean out non-deterministic time schema info.
       case _: NativeCommand => answer.filterNot(nonDeterministicLine).filterNot(_ == "")
       case _: ExplainCommand => answer
-      case plan if isSorted(plan) => answer
-      case _ => answer.sorted
+      case plan => if (isSorted(plan)) answer else answer.sorted
     }
     orderedAnswer.map(cleanPaths)
   }
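
A minimal, self-contained sketch of the memoization pattern the first hunk applies in HiveTableScan.attributeFunctions: the ObjectInspector dispatch is resolved once per field, outside the per-row closure, rather than re-matched for every row (the partition-key dataType lookup is hoisted the same way). The Inspector hierarchy and unwrapperFor helper below are hypothetical stand-ins for the Hive serde classes, not the actual Spark or Hive API.

object MemoizedUnwrapSketch {
  // Hypothetical stand-ins for Hive's ObjectInspector hierarchy.
  sealed trait Inspector
  case object VarcharInspector extends Inspector
  case object DecimalInspector extends Inspector
  case object PlainInspector   extends Inspector

  // Resolve the unwrapping closure once per column from its inspector;
  // per-row work is then a single function application, with no repeated
  // pattern match on the hot path.
  def unwrapperFor(oi: Inspector): Any => Any = oi match {
    case VarcharInspector => (v: Any) => v.asInstanceOf[String]
    case DecimalInspector => (v: Any) => BigDecimal(v.asInstanceOf[java.math.BigDecimal])
    case PlainInspector   => identity[Any] _
  }

  def main(args: Array[String]): Unit = {
    val unwrap = unwrapperFor(DecimalInspector)  // one match, not one per row
    val rows: Seq[Any] =
      Seq(new java.math.BigDecimal("1.5"), null, new java.math.BigDecimal("2.5"))
    // The null guard mirrors the patched closure: unwrap only non-null values.
    println(rows.map(v => if (v != null) unwrap(v) else null))
  }
}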