diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveInspectors.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveInspectors.scala new file mode 100644 index 0000000000000000000000000000000000000000..ad7dc0ecdb1bf70f442ac2bb21cd7cde292b9ad4 --- /dev/null +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveInspectors.scala @@ -0,0 +1,230 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.hive + +import org.apache.hadoop.hive.common.`type`.HiveDecimal +import org.apache.hadoop.hive.serde2.objectinspector._ +import org.apache.hadoop.hive.serde2.objectinspector.primitive._ +import org.apache.hadoop.hive.serde2.{io => hiveIo} +import org.apache.hadoop.{io => hadoopIo} + +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.types +import org.apache.spark.sql.catalyst.types._ + +/* Implicit conversions */ +import scala.collection.JavaConversions._ + +private[hive] trait HiveInspectors { + + def javaClassToDataType(clz: Class[_]): DataType = clz match { + // writable + case c: Class[_] if c == classOf[hadoopIo.DoubleWritable] => DoubleType + case c: Class[_] if c == classOf[hiveIo.DoubleWritable] => DoubleType + case c: Class[_] if c == classOf[hiveIo.HiveDecimalWritable] => DecimalType + case c: Class[_] if c == classOf[hiveIo.ByteWritable] => ByteType + case c: Class[_] if c == classOf[hiveIo.ShortWritable] => ShortType + case c: Class[_] if c == classOf[hiveIo.TimestampWritable] => TimestampType + case c: Class[_] if c == classOf[hadoopIo.Text] => StringType + case c: Class[_] if c == classOf[hadoopIo.IntWritable] => IntegerType + case c: Class[_] if c == classOf[hadoopIo.LongWritable] => LongType + case c: Class[_] if c == classOf[hadoopIo.FloatWritable] => FloatType + case c: Class[_] if c == classOf[hadoopIo.BooleanWritable] => BooleanType + case c: Class[_] if c == classOf[hadoopIo.BytesWritable] => BinaryType + + // java class + case c: Class[_] if c == classOf[java.lang.String] => StringType + case c: Class[_] if c == classOf[java.sql.Timestamp] => TimestampType + case c: Class[_] if c == classOf[HiveDecimal] => DecimalType + case c: Class[_] if c == classOf[java.math.BigDecimal] => DecimalType + case c: Class[_] if c == classOf[Array[Byte]] => BinaryType + case c: Class[_] if c == classOf[java.lang.Short] => ShortType + case c: Class[_] if c == classOf[java.lang.Integer] => IntegerType + case c: Class[_] if c == classOf[java.lang.Long] => LongType + case c: Class[_] if c == classOf[java.lang.Double] => DoubleType + case c: Class[_] if c == classOf[java.lang.Byte] => ByteType + case c: Class[_] if c == classOf[java.lang.Float] => FloatType + case c: Class[_] if c == classOf[java.lang.Boolean] => BooleanType + + // primitive type + case c: Class[_] if c == java.lang.Short.TYPE => ShortType + case c: Class[_] if c == java.lang.Integer.TYPE => IntegerType + case c: Class[_] if c == java.lang.Long.TYPE => LongType + case c: Class[_] if c == java.lang.Double.TYPE => DoubleType + case c: Class[_] if c == java.lang.Byte.TYPE => ByteType + case c: Class[_] if c == java.lang.Float.TYPE => FloatType + case c: Class[_] if c == java.lang.Boolean.TYPE => BooleanType + + case c: Class[_] if c.isArray => ArrayType(javaClassToDataType(c.getComponentType)) + } + + /** Converts hive types to native catalyst types. */ + def unwrap(a: Any): Any = a match { + case null => null + case i: hadoopIo.IntWritable => i.get + case t: hadoopIo.Text => t.toString + case l: hadoopIo.LongWritable => l.get + case d: hadoopIo.DoubleWritable => d.get + case d: hiveIo.DoubleWritable => d.get + case s: hiveIo.ShortWritable => s.get + case b: hadoopIo.BooleanWritable => b.get + case b: hiveIo.ByteWritable => b.get + case b: hadoopIo.FloatWritable => b.get + case b: hadoopIo.BytesWritable => { + val bytes = new Array[Byte](b.getLength) + System.arraycopy(b.getBytes(), 0, bytes, 0, b.getLength) + bytes + } + case t: hiveIo.TimestampWritable => t.getTimestamp + case b: hiveIo.HiveDecimalWritable => BigDecimal(b.getHiveDecimal().bigDecimalValue()) + case list: java.util.List[_] => list.map(unwrap) + case map: java.util.Map[_,_] => map.map { case (k, v) => (unwrap(k), unwrap(v)) }.toMap + case array: Array[_] => array.map(unwrap).toSeq + case p: java.lang.Short => p + case p: java.lang.Long => p + case p: java.lang.Float => p + case p: java.lang.Integer => p + case p: java.lang.Double => p + case p: java.lang.Byte => p + case p: java.lang.Boolean => p + case str: String => str + case p: java.math.BigDecimal => p + case p: Array[Byte] => p + case p: java.sql.Timestamp => p + } + + def unwrapData(data: Any, oi: ObjectInspector): Any = oi match { + case hvoi: HiveVarcharObjectInspector => + if (data == null) null else hvoi.getPrimitiveJavaObject(data).getValue + case hdoi: HiveDecimalObjectInspector => + if (data == null) null else BigDecimal(hdoi.getPrimitiveJavaObject(data).bigDecimalValue()) + case pi: PrimitiveObjectInspector => pi.getPrimitiveJavaObject(data) + case li: ListObjectInspector => + Option(li.getList(data)) + .map(_.map(unwrapData(_, li.getListElementObjectInspector)).toSeq) + .orNull + case mi: MapObjectInspector => + Option(mi.getMap(data)).map( + _.map { + case (k,v) => + (unwrapData(k, mi.getMapKeyObjectInspector), + unwrapData(v, mi.getMapValueObjectInspector)) + }.toMap).orNull + case si: StructObjectInspector => + val allRefs = si.getAllStructFieldRefs + new GenericRow( + allRefs.map(r => + unwrapData(si.getStructFieldData(data,r), r.getFieldObjectInspector)).toArray) + } + + /** Converts native catalyst types to the types expected by Hive */ + def wrap(a: Any): AnyRef = a match { + case s: String => new hadoopIo.Text(s) // TODO why should be Text? + case i: Int => i: java.lang.Integer + case b: Boolean => b: java.lang.Boolean + case f: Float => f: java.lang.Float + case d: Double => d: java.lang.Double + case l: Long => l: java.lang.Long + case l: Short => l: java.lang.Short + case l: Byte => l: java.lang.Byte + case b: BigDecimal => b.bigDecimal + case b: Array[Byte] => b + case t: java.sql.Timestamp => t + case s: Seq[_] => seqAsJavaList(s.map(wrap)) + case m: Map[_,_] => + mapAsJavaMap(m.map { case (k, v) => wrap(k) -> wrap(v) }) + case null => null + } + + def toInspector(dataType: DataType): ObjectInspector = dataType match { + case ArrayType(tpe) => ObjectInspectorFactory.getStandardListObjectInspector(toInspector(tpe)) + case MapType(keyType, valueType) => + ObjectInspectorFactory.getStandardMapObjectInspector( + toInspector(keyType), toInspector(valueType)) + case StringType => PrimitiveObjectInspectorFactory.javaStringObjectInspector + case IntegerType => PrimitiveObjectInspectorFactory.javaIntObjectInspector + case DoubleType => PrimitiveObjectInspectorFactory.javaDoubleObjectInspector + case BooleanType => PrimitiveObjectInspectorFactory.javaBooleanObjectInspector + case LongType => PrimitiveObjectInspectorFactory.javaLongObjectInspector + case FloatType => PrimitiveObjectInspectorFactory.javaFloatObjectInspector + case ShortType => PrimitiveObjectInspectorFactory.javaShortObjectInspector + case ByteType => PrimitiveObjectInspectorFactory.javaByteObjectInspector + case NullType => PrimitiveObjectInspectorFactory.javaVoidObjectInspector + case BinaryType => PrimitiveObjectInspectorFactory.javaByteArrayObjectInspector + case TimestampType => PrimitiveObjectInspectorFactory.javaTimestampObjectInspector + case DecimalType => PrimitiveObjectInspectorFactory.javaHiveDecimalObjectInspector + case StructType(fields) => + ObjectInspectorFactory.getStandardStructObjectInspector( + fields.map(f => f.name), fields.map(f => toInspector(f.dataType))) + } + + def inspectorToDataType(inspector: ObjectInspector): DataType = inspector match { + case s: StructObjectInspector => + StructType(s.getAllStructFieldRefs.map(f => { + types.StructField( + f.getFieldName, inspectorToDataType(f.getFieldObjectInspector), nullable = true) + })) + case l: ListObjectInspector => ArrayType(inspectorToDataType(l.getListElementObjectInspector)) + case m: MapObjectInspector => + MapType( + inspectorToDataType(m.getMapKeyObjectInspector), + inspectorToDataType(m.getMapValueObjectInspector)) + case _: WritableStringObjectInspector => StringType + case _: JavaStringObjectInspector => StringType + case _: WritableIntObjectInspector => IntegerType + case _: JavaIntObjectInspector => IntegerType + case _: WritableDoubleObjectInspector => DoubleType + case _: JavaDoubleObjectInspector => DoubleType + case _: WritableBooleanObjectInspector => BooleanType + case _: JavaBooleanObjectInspector => BooleanType + case _: WritableLongObjectInspector => LongType + case _: JavaLongObjectInspector => LongType + case _: WritableShortObjectInspector => ShortType + case _: JavaShortObjectInspector => ShortType + case _: WritableByteObjectInspector => ByteType + case _: JavaByteObjectInspector => ByteType + case _: WritableFloatObjectInspector => FloatType + case _: JavaFloatObjectInspector => FloatType + case _: WritableBinaryObjectInspector => BinaryType + case _: JavaBinaryObjectInspector => BinaryType + case _: WritableHiveDecimalObjectInspector => DecimalType + case _: JavaHiveDecimalObjectInspector => DecimalType + case _: WritableTimestampObjectInspector => TimestampType + case _: JavaTimestampObjectInspector => TimestampType + } + + implicit class typeInfoConversions(dt: DataType) { + import org.apache.hadoop.hive.serde2.typeinfo._ + import TypeInfoFactory._ + + def toTypeInfo: TypeInfo = dt match { + case BinaryType => binaryTypeInfo + case BooleanType => booleanTypeInfo + case ByteType => byteTypeInfo + case DoubleType => doubleTypeInfo + case FloatType => floatTypeInfo + case IntegerType => intTypeInfo + case LongType => longTypeInfo + case ShortType => shortTypeInfo + case StringType => stringTypeInfo + case DecimalType => decimalTypeInfo + case TimestampType => timestampTypeInfo + case NullType => voidTypeInfo + } + } +} diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala index 53480a521dd147d2aa019b6117703d1c735a3d8e..c4ca9f362a04d0dc2cadf254abef22bc974bf33e 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala @@ -42,8 +42,6 @@ private[hive] case class ShellCommand(cmd: String) extends Command private[hive] case class SourceCommand(filePath: String) extends Command -private[hive] case class AddJar(jarPath: String) extends Command - private[hive] case class AddFile(filePath: String) extends Command /** Provides a mapping from HiveQL statements to catalyst logical plans and expression trees. */ @@ -229,7 +227,7 @@ private[hive] object HiveQl { } else if (sql.trim.toLowerCase.startsWith("uncache table")) { CacheCommand(sql.trim.drop(14).trim, false) } else if (sql.trim.toLowerCase.startsWith("add jar")) { - AddJar(sql.trim.drop(8)) + NativeCommand(sql) } else if (sql.trim.toLowerCase.startsWith("add file")) { AddFile(sql.trim.drop(9)) } else if (sql.trim.toLowerCase.startsWith("dfs")) { diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUdfs.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUdfs.scala index fc33c5b460d7059953497c505e56ba57fd3c23a4..057eb60a0261235a2deaeacedcefaf29fee83acc 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUdfs.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUdfs.scala @@ -24,22 +24,19 @@ import org.apache.hadoop.hive.ql.exec.UDF import org.apache.hadoop.hive.ql.exec.{FunctionInfo, FunctionRegistry} import org.apache.hadoop.hive.ql.udf.{UDFType => HiveUDFType} import org.apache.hadoop.hive.ql.udf.generic._ -import org.apache.hadoop.hive.serde2.objectinspector._ -import org.apache.hadoop.hive.serde2.objectinspector.primitive._ -import org.apache.hadoop.hive.serde2.{io => hiveIo} -import org.apache.hadoop.{io => hadoopIo} import org.apache.spark.sql.Logging import org.apache.spark.sql.catalyst.analysis import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.catalyst.types import org.apache.spark.sql.catalyst.types._ +import org.apache.spark.util.Utils.getContextOrSparkClassLoader /* Implicit conversions */ import scala.collection.JavaConversions._ -private[hive] object HiveFunctionRegistry - extends analysis.FunctionRegistry with HiveFunctionFactory with HiveInspectors { +private[hive] object HiveFunctionRegistry extends analysis.FunctionRegistry with HiveInspectors { + + def getFunctionInfo(name: String) = FunctionRegistry.getFunctionInfo(name) def lookupFunction(name: String, children: Seq[Expression]): Expression = { // We only look it up to see if it exists, but do not include it in the HiveUDF since it is @@ -47,111 +44,37 @@ private[hive] object HiveFunctionRegistry val functionInfo: FunctionInfo = Option(FunctionRegistry.getFunctionInfo(name)).getOrElse( sys.error(s"Couldn't find function $name")) + val functionClassName = functionInfo.getFunctionClass.getName() + if (classOf[UDF].isAssignableFrom(functionInfo.getFunctionClass)) { - val function = createFunction[UDF](name) + val function = functionInfo.getFunctionClass.newInstance().asInstanceOf[UDF] val method = function.getResolver.getEvalMethod(children.map(_.dataType.toTypeInfo)) lazy val expectedDataTypes = method.getParameterTypes.map(javaClassToDataType) HiveSimpleUdf( - name, + functionClassName, children.zip(expectedDataTypes).map { case (e, t) => Cast(e, t) } ) } else if (classOf[GenericUDF].isAssignableFrom(functionInfo.getFunctionClass)) { - HiveGenericUdf(name, children) + HiveGenericUdf(functionClassName, children) } else if ( classOf[AbstractGenericUDAFResolver].isAssignableFrom(functionInfo.getFunctionClass)) { - HiveGenericUdaf(name, children) + HiveGenericUdaf(functionClassName, children) } else if (classOf[GenericUDTF].isAssignableFrom(functionInfo.getFunctionClass)) { - HiveGenericUdtf(name, Nil, children) + HiveGenericUdtf(functionClassName, Nil, children) } else { sys.error(s"No handler for udf ${functionInfo.getFunctionClass}") } } - - def javaClassToDataType(clz: Class[_]): DataType = clz match { - // writable - case c: Class[_] if c == classOf[hadoopIo.DoubleWritable] => DoubleType - case c: Class[_] if c == classOf[hiveIo.DoubleWritable] => DoubleType - case c: Class[_] if c == classOf[hiveIo.HiveDecimalWritable] => DecimalType - case c: Class[_] if c == classOf[hiveIo.ByteWritable] => ByteType - case c: Class[_] if c == classOf[hiveIo.ShortWritable] => ShortType - case c: Class[_] if c == classOf[hiveIo.TimestampWritable] => TimestampType - case c: Class[_] if c == classOf[hadoopIo.Text] => StringType - case c: Class[_] if c == classOf[hadoopIo.IntWritable] => IntegerType - case c: Class[_] if c == classOf[hadoopIo.LongWritable] => LongType - case c: Class[_] if c == classOf[hadoopIo.FloatWritable] => FloatType - case c: Class[_] if c == classOf[hadoopIo.BooleanWritable] => BooleanType - case c: Class[_] if c == classOf[hadoopIo.BytesWritable] => BinaryType - - // java class - case c: Class[_] if c == classOf[java.lang.String] => StringType - case c: Class[_] if c == classOf[java.sql.Timestamp] => TimestampType - case c: Class[_] if c == classOf[HiveDecimal] => DecimalType - case c: Class[_] if c == classOf[java.math.BigDecimal] => DecimalType - case c: Class[_] if c == classOf[Array[Byte]] => BinaryType - case c: Class[_] if c == classOf[java.lang.Short] => ShortType - case c: Class[_] if c == classOf[java.lang.Integer] => IntegerType - case c: Class[_] if c == classOf[java.lang.Long] => LongType - case c: Class[_] if c == classOf[java.lang.Double] => DoubleType - case c: Class[_] if c == classOf[java.lang.Byte] => ByteType - case c: Class[_] if c == classOf[java.lang.Float] => FloatType - case c: Class[_] if c == classOf[java.lang.Boolean] => BooleanType - - // primitive type - case c: Class[_] if c == java.lang.Short.TYPE => ShortType - case c: Class[_] if c == java.lang.Integer.TYPE => IntegerType - case c: Class[_] if c == java.lang.Long.TYPE => LongType - case c: Class[_] if c == java.lang.Double.TYPE => DoubleType - case c: Class[_] if c == java.lang.Byte.TYPE => ByteType - case c: Class[_] if c == java.lang.Float.TYPE => FloatType - case c: Class[_] if c == java.lang.Boolean.TYPE => BooleanType - - case c: Class[_] if c.isArray => ArrayType(javaClassToDataType(c.getComponentType)) - } } private[hive] trait HiveFunctionFactory { - def getFunctionInfo(name: String) = FunctionRegistry.getFunctionInfo(name) - def getFunctionClass(name: String) = getFunctionInfo(name).getFunctionClass - def createFunction[UDFType](name: String) = - getFunctionClass(name).newInstance.asInstanceOf[UDFType] - - /** Converts hive types to native catalyst types. */ - def unwrap(a: Any): Any = a match { - case null => null - case i: hadoopIo.IntWritable => i.get - case t: hadoopIo.Text => t.toString - case l: hadoopIo.LongWritable => l.get - case d: hadoopIo.DoubleWritable => d.get - case d: hiveIo.DoubleWritable => d.get - case s: hiveIo.ShortWritable => s.get - case b: hadoopIo.BooleanWritable => b.get - case b: hiveIo.ByteWritable => b.get - case b: hadoopIo.FloatWritable => b.get - case b: hadoopIo.BytesWritable => { - val bytes = new Array[Byte](b.getLength) - System.arraycopy(b.getBytes(), 0, bytes, 0, b.getLength) - bytes - } - case t: hiveIo.TimestampWritable => t.getTimestamp - case b: hiveIo.HiveDecimalWritable => BigDecimal(b.getHiveDecimal().bigDecimalValue()) - case list: java.util.List[_] => list.map(unwrap) - case map: java.util.Map[_,_] => map.map { case (k, v) => (unwrap(k), unwrap(v)) }.toMap - case array: Array[_] => array.map(unwrap).toSeq - case p: java.lang.Short => p - case p: java.lang.Long => p - case p: java.lang.Float => p - case p: java.lang.Integer => p - case p: java.lang.Double => p - case p: java.lang.Byte => p - case p: java.lang.Boolean => p - case str: String => str - case p: java.math.BigDecimal => p - case p: Array[Byte] => p - case p: java.sql.Timestamp => p - } + val functionClassName: String + + def createFunction[UDFType]() = + getContextOrSparkClassLoader.loadClass(functionClassName).newInstance.asInstanceOf[UDFType] } private[hive] abstract class HiveUdf extends Expression with Logging with HiveFunctionFactory { @@ -160,19 +83,17 @@ private[hive] abstract class HiveUdf extends Expression with Logging with HiveFu type UDFType type EvaluatedType = Any - val name: String - def nullable = true def references = children.flatMap(_.references).toSet - // FunctionInfo is not serializable so we must look it up here again. - lazy val functionInfo = getFunctionInfo(name) - lazy val function = createFunction[UDFType](name) + lazy val function = createFunction[UDFType]() - override def toString = s"$nodeName#${functionInfo.getDisplayName}(${children.mkString(",")})" + override def toString = s"$nodeName#$functionClassName(${children.mkString(",")})" } -private[hive] case class HiveSimpleUdf(name: String, children: Seq[Expression]) extends HiveUdf { +private[hive] case class HiveSimpleUdf(functionClassName: String, children: Seq[Expression]) + extends HiveUdf { + import org.apache.spark.sql.hive.HiveFunctionRegistry._ type UDFType = UDF @@ -226,7 +147,7 @@ private[hive] case class HiveSimpleUdf(name: String, children: Seq[Expression]) } } -private[hive] case class HiveGenericUdf(name: String, children: Seq[Expression]) +private[hive] case class HiveGenericUdf(functionClassName: String, children: Seq[Expression]) extends HiveUdf with HiveInspectors { import org.apache.hadoop.hive.ql.udf.generic.GenericUDF._ @@ -277,131 +198,8 @@ private[hive] case class HiveGenericUdf(name: String, children: Seq[Expression]) } } -private[hive] trait HiveInspectors { - - def unwrapData(data: Any, oi: ObjectInspector): Any = oi match { - case hvoi: HiveVarcharObjectInspector => - if (data == null) null else hvoi.getPrimitiveJavaObject(data).getValue - case hdoi: HiveDecimalObjectInspector => - if (data == null) null else BigDecimal(hdoi.getPrimitiveJavaObject(data).bigDecimalValue()) - case pi: PrimitiveObjectInspector => pi.getPrimitiveJavaObject(data) - case li: ListObjectInspector => - Option(li.getList(data)) - .map(_.map(unwrapData(_, li.getListElementObjectInspector)).toSeq) - .orNull - case mi: MapObjectInspector => - Option(mi.getMap(data)).map( - _.map { - case (k,v) => - (unwrapData(k, mi.getMapKeyObjectInspector), - unwrapData(v, mi.getMapValueObjectInspector)) - }.toMap).orNull - case si: StructObjectInspector => - val allRefs = si.getAllStructFieldRefs - new GenericRow( - allRefs.map(r => - unwrapData(si.getStructFieldData(data,r), r.getFieldObjectInspector)).toArray) - } - - /** Converts native catalyst types to the types expected by Hive */ - def wrap(a: Any): AnyRef = a match { - case s: String => new hadoopIo.Text(s) // TODO why should be Text? - case i: Int => i: java.lang.Integer - case b: Boolean => b: java.lang.Boolean - case f: Float => f: java.lang.Float - case d: Double => d: java.lang.Double - case l: Long => l: java.lang.Long - case l: Short => l: java.lang.Short - case l: Byte => l: java.lang.Byte - case b: BigDecimal => b.bigDecimal - case b: Array[Byte] => b - case t: java.sql.Timestamp => t - case s: Seq[_] => seqAsJavaList(s.map(wrap)) - case m: Map[_,_] => - mapAsJavaMap(m.map { case (k, v) => wrap(k) -> wrap(v) }) - case null => null - } - - def toInspector(dataType: DataType): ObjectInspector = dataType match { - case ArrayType(tpe) => ObjectInspectorFactory.getStandardListObjectInspector(toInspector(tpe)) - case MapType(keyType, valueType) => - ObjectInspectorFactory.getStandardMapObjectInspector( - toInspector(keyType), toInspector(valueType)) - case StringType => PrimitiveObjectInspectorFactory.javaStringObjectInspector - case IntegerType => PrimitiveObjectInspectorFactory.javaIntObjectInspector - case DoubleType => PrimitiveObjectInspectorFactory.javaDoubleObjectInspector - case BooleanType => PrimitiveObjectInspectorFactory.javaBooleanObjectInspector - case LongType => PrimitiveObjectInspectorFactory.javaLongObjectInspector - case FloatType => PrimitiveObjectInspectorFactory.javaFloatObjectInspector - case ShortType => PrimitiveObjectInspectorFactory.javaShortObjectInspector - case ByteType => PrimitiveObjectInspectorFactory.javaByteObjectInspector - case NullType => PrimitiveObjectInspectorFactory.javaVoidObjectInspector - case BinaryType => PrimitiveObjectInspectorFactory.javaByteArrayObjectInspector - case TimestampType => PrimitiveObjectInspectorFactory.javaTimestampObjectInspector - case DecimalType => PrimitiveObjectInspectorFactory.javaHiveDecimalObjectInspector - case StructType(fields) => - ObjectInspectorFactory.getStandardStructObjectInspector( - fields.map(f => f.name), fields.map(f => toInspector(f.dataType))) - } - - def inspectorToDataType(inspector: ObjectInspector): DataType = inspector match { - case s: StructObjectInspector => - StructType(s.getAllStructFieldRefs.map(f => { - types.StructField( - f.getFieldName, inspectorToDataType(f.getFieldObjectInspector), nullable = true) - })) - case l: ListObjectInspector => ArrayType(inspectorToDataType(l.getListElementObjectInspector)) - case m: MapObjectInspector => - MapType( - inspectorToDataType(m.getMapKeyObjectInspector), - inspectorToDataType(m.getMapValueObjectInspector)) - case _: WritableStringObjectInspector => StringType - case _: JavaStringObjectInspector => StringType - case _: WritableIntObjectInspector => IntegerType - case _: JavaIntObjectInspector => IntegerType - case _: WritableDoubleObjectInspector => DoubleType - case _: JavaDoubleObjectInspector => DoubleType - case _: WritableBooleanObjectInspector => BooleanType - case _: JavaBooleanObjectInspector => BooleanType - case _: WritableLongObjectInspector => LongType - case _: JavaLongObjectInspector => LongType - case _: WritableShortObjectInspector => ShortType - case _: JavaShortObjectInspector => ShortType - case _: WritableByteObjectInspector => ByteType - case _: JavaByteObjectInspector => ByteType - case _: WritableFloatObjectInspector => FloatType - case _: JavaFloatObjectInspector => FloatType - case _: WritableBinaryObjectInspector => BinaryType - case _: JavaBinaryObjectInspector => BinaryType - case _: WritableHiveDecimalObjectInspector => DecimalType - case _: JavaHiveDecimalObjectInspector => DecimalType - case _: WritableTimestampObjectInspector => TimestampType - case _: JavaTimestampObjectInspector => TimestampType - } - - implicit class typeInfoConversions(dt: DataType) { - import org.apache.hadoop.hive.serde2.typeinfo._ - import TypeInfoFactory._ - - def toTypeInfo: TypeInfo = dt match { - case BinaryType => binaryTypeInfo - case BooleanType => booleanTypeInfo - case ByteType => byteTypeInfo - case DoubleType => doubleTypeInfo - case FloatType => floatTypeInfo - case IntegerType => intTypeInfo - case LongType => longTypeInfo - case ShortType => shortTypeInfo - case StringType => stringTypeInfo - case DecimalType => decimalTypeInfo - case TimestampType => timestampTypeInfo - case NullType => voidTypeInfo - } - } -} - private[hive] case class HiveGenericUdaf( - name: String, + functionClassName: String, children: Seq[Expression]) extends AggregateExpression with HiveInspectors with HiveFunctionFactory { @@ -409,7 +207,7 @@ private[hive] case class HiveGenericUdaf( type UDFType = AbstractGenericUDAFResolver @transient - protected lazy val resolver: AbstractGenericUDAFResolver = createFunction(name) + protected lazy val resolver: AbstractGenericUDAFResolver = createFunction() @transient protected lazy val objectInspector = { @@ -426,9 +224,9 @@ private[hive] case class HiveGenericUdaf( def references: Set[Attribute] = children.map(_.references).flatten.toSet - override def toString = s"$nodeName#$name(${children.mkString(",")})" + override def toString = s"$nodeName#$functionClassName(${children.mkString(",")})" - def newInstance() = new HiveUdafFunction(name, children, this) + def newInstance() = new HiveUdafFunction(functionClassName, children, this) } /** @@ -443,7 +241,7 @@ private[hive] case class HiveGenericUdaf( * user defined aggregations, which have clean semantics even in a partitioned execution. */ private[hive] case class HiveGenericUdtf( - name: String, + functionClassName: String, aliasNames: Seq[String], children: Seq[Expression]) extends Generator with HiveInspectors with HiveFunctionFactory { @@ -451,7 +249,7 @@ private[hive] case class HiveGenericUdtf( override def references = children.flatMap(_.references).toSet @transient - protected lazy val function: GenericUDTF = createFunction(name) + protected lazy val function: GenericUDTF = createFunction() protected lazy val inputInspectors = children.map(_.dataType).map(toInspector) @@ -506,11 +304,11 @@ private[hive] case class HiveGenericUdtf( } } - override def toString = s"$nodeName#$name(${children.mkString(",")})" + override def toString = s"$nodeName#$functionClassName(${children.mkString(",")})" } private[hive] case class HiveUdafFunction( - functionName: String, + functionClassName: String, exprs: Seq[Expression], base: AggregateExpression) extends AggregateFunction @@ -519,7 +317,7 @@ private[hive] case class HiveUdafFunction( def this() = this(null, null, null) - private val resolver = createFunction[AbstractGenericUDAFResolver](functionName) + private val resolver = createFunction[AbstractGenericUDAFResolver]() private val inspectors = exprs.map(_.dataType).map(toInspector).toArray