Commit d7b2b97a authored by Wenchen Fan, committed by Michael Armbrust

[SPARK-11727][SQL] Split ExpressionEncoder into FlatEncoder and ProductEncoder

Also add more tests for encoders, and fix bugs that I found:

* when converting an array to a catalyst array, we can only skip element conversion for native types (e.g. int, long, boolean), not for every `AtomicType` (String is an `AtomicType`, but we still need to convert it; see the sketch after this list)
* we should also handle scala `BigDecimal` when converting from catalyst `Decimal`.
* complex map types should be supported
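
For illustration, a minimal sketch of the first point — it mirrors the `isNativeType` predicate added to `RowEncoder` in this patch, but the assertion code itself is not part of the commit:

```scala
import org.apache.spark.sql.types._
import org.apache.spark.unsafe.types.UTF8String

// Mirrors RowEncoder.isNativeType from this patch: only these types have
// identical internal and external representations, so only their array
// elements can skip conversion.
def isNativeType(dt: DataType): Boolean = dt match {
  case BooleanType | ByteType | ShortType | IntegerType | LongType |
       FloatType | DoubleType | BinaryType => true
  case _ => false
}

// StringType is an AtomicType, but a java.lang.String element must still
// be converted to the internal UTF8String representation.
assert(!isNativeType(StringType))
val internal: UTF8String = UTF8String.fromString("hello")
```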

Other issues that are still under investigation:

* encoding a java `BigDecimal` and decoding it back seems to lose precision info.
* when encoding a case class defined inside an object, a `ClassNotFound` exception is thrown (a minimal reproduction follows below).
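
A minimal, hypothetical reproduction of the second point (the names `Outer`/`Inner` are made up for illustration):

```scala
object Outer {
  // A case class nested inside an object. Its reflective fullName is
  // "Outer.Inner" while the JVM class name is "Outer$Inner", which is
  // presumably why resolving the class fails; still under investigation.
  case class Inner(i: Int)
}

// Hypothetical call that triggers the reported failure:
// ProductEncoder[Outer.Inner]   // throws ClassNotFoundException
```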

I'll remove unused code in a follow-up PR.

Author: Wenchen Fan <wenchen@databricks.com>

Closes #9693 from cloud-fan/split.
parent 23b8188f
Showing 766 additions and 289 deletions
@@ -75,7 +75,7 @@ trait ScalaReflection {
*
* @see SPARK-5281
*/
private def localTypeOf[T: TypeTag]: `Type` = typeTag[T].in(mirror).tpe
def localTypeOf[T: TypeTag]: `Type` = typeTag[T].in(mirror).tpe
/**
* Returns the Spark SQL DataType for a given scala type. Where this is not an exact mapping
...
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.spark.sql.catalyst.encoders
import scala.reflect.ClassTag
import scala.reflect.runtime.universe.{typeTag, TypeTag}
import org.apache.spark.sql.types.StructType
import org.apache.spark.sql.catalyst.expressions.{Literal, CreateNamedStruct, BoundReference}
import org.apache.spark.sql.catalyst.ScalaReflection
object FlatEncoder {
import ScalaReflection.schemaFor
import ScalaReflection.dataTypeFor
def apply[T : TypeTag]: ExpressionEncoder[T] = {
// We convert the non-serializable TypeTag into StructType and ClassTag.
val tpe = typeTag[T].tpe
val mirror = typeTag[T].mirror
val cls = mirror.runtimeClass(tpe)
assert(!schemaFor(tpe).dataType.isInstanceOf[StructType])
val input = BoundReference(0, dataTypeFor(tpe), nullable = true)
val toRowExpression = CreateNamedStruct(
Literal("value") :: ProductEncoder.extractorFor(input, tpe) :: Nil)
val fromRowExpression = ProductEncoder.constructorFor(tpe)
new ExpressionEncoder[T](
toRowExpression.dataType,
flat = true,
toRowExpression.flatten,
fromRowExpression,
ClassTag[T](cls))
}
}
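
A hypothetical usage sketch (not part of this commit) contrasting the two entry points: `FlatEncoder` wraps a single value in a one-field struct named `value`, while `ProductEncoder` maps a case class or tuple to a multi-field struct:

```scala
// Hypothetical usage, assuming the objects defined in this commit:
val intEncoder  = FlatEncoder[Int]               // schema: struct<value: int>
val pairEncoder = ProductEncoder[(Int, String)]  // schema: struct<_1: int, _2: string>

// Both are ExpressionEncoders, so encoding works the same way:
val row = intEncoder.toRow(42)                   // InternalRow with a single field
```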
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.spark.sql.catalyst.encoders
import org.apache.spark.util.Utils
import org.apache.spark.unsafe.types.UTF8String
import org.apache.spark.sql.types._
import org.apache.spark.sql.catalyst.ScalaReflectionLock
import org.apache.spark.sql.catalyst.ScalaReflection
import org.apache.spark.sql.catalyst.analysis.{UnresolvedAttribute, UnresolvedExtractValue}
import org.apache.spark.sql.catalyst.expressions
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.util.{DateTimeUtils, ArrayBasedMapData, GenericArrayData}
import scala.reflect.ClassTag
object ProductEncoder {
import ScalaReflection.universe._
import ScalaReflection.localTypeOf
import ScalaReflection.dataTypeFor
import ScalaReflection.Schema
import ScalaReflection.schemaFor
import ScalaReflection.arrayClassFor
def apply[T <: Product : TypeTag]: ExpressionEncoder[T] = {
// We convert the non-serializable TypeTag into StructType and ClassTag.
val tpe = typeTag[T].tpe
val mirror = typeTag[T].mirror
val cls = mirror.runtimeClass(tpe)
val inputObject = BoundReference(0, ObjectType(cls), nullable = true)
val toRowExpression = extractorFor(inputObject, tpe).asInstanceOf[CreateNamedStruct]
val fromRowExpression = constructorFor(tpe)
new ExpressionEncoder[T](
toRowExpression.dataType,
flat = false,
toRowExpression.flatten,
fromRowExpression,
ClassTag[T](cls))
}
// The Predef.Map is scala.collection.immutable.Map.
// Since the map values can be mutable, we explicitly import scala.collection.Map here.
import scala.collection.Map
def extractorFor(
inputObject: Expression,
tpe: `Type`): Expression = ScalaReflectionLock.synchronized {
if (!inputObject.dataType.isInstanceOf[ObjectType]) {
inputObject
} else {
tpe match {
case t if t <:< localTypeOf[Option[_]] =>
val TypeRef(_, _, Seq(optType)) = t
optType match {
// For primitive types we must manually unbox the value of the object.
case t if t <:< definitions.IntTpe =>
Invoke(
UnwrapOption(ObjectType(classOf[java.lang.Integer]), inputObject),
"intValue",
IntegerType)
case t if t <:< definitions.LongTpe =>
Invoke(
UnwrapOption(ObjectType(classOf[java.lang.Long]), inputObject),
"longValue",
LongType)
case t if t <:< definitions.DoubleTpe =>
Invoke(
UnwrapOption(ObjectType(classOf[java.lang.Double]), inputObject),
"doubleValue",
DoubleType)
case t if t <:< definitions.FloatTpe =>
Invoke(
UnwrapOption(ObjectType(classOf[java.lang.Float]), inputObject),
"floatValue",
FloatType)
case t if t <:< definitions.ShortTpe =>
Invoke(
UnwrapOption(ObjectType(classOf[java.lang.Short]), inputObject),
"shortValue",
ShortType)
case t if t <:< definitions.ByteTpe =>
Invoke(
UnwrapOption(ObjectType(classOf[java.lang.Byte]), inputObject),
"byteValue",
ByteType)
case t if t <:< definitions.BooleanTpe =>
Invoke(
UnwrapOption(ObjectType(classOf[java.lang.Boolean]), inputObject),
"booleanValue",
BooleanType)
// For non-primitives, we can just extract the object from the Option and then recurse.
case other =>
val className: String = optType.erasure.typeSymbol.asClass.fullName
val classObj = Utils.classForName(className)
val optionObjectType = ObjectType(classObj)
val unwrapped = UnwrapOption(optionObjectType, inputObject)
expressions.If(
IsNull(unwrapped),
expressions.Literal.create(null, schemaFor(optType).dataType),
extractorFor(unwrapped, optType))
}
case t if t <:< localTypeOf[Product] =>
val formalTypeArgs = t.typeSymbol.asClass.typeParams
val TypeRef(_, _, actualTypeArgs) = t
val constructorSymbol = t.member(nme.CONSTRUCTOR)
val params = if (constructorSymbol.isMethod) {
constructorSymbol.asMethod.paramss
} else {
// Find the primary constructor, and use its parameter ordering.
val primaryConstructorSymbol: Option[Symbol] =
constructorSymbol.asTerm.alternatives.find(s =>
s.isMethod && s.asMethod.isPrimaryConstructor)
if (primaryConstructorSymbol.isEmpty) {
sys.error("Internal SQL error: Product object did not have a primary constructor.")
} else {
primaryConstructorSymbol.get.asMethod.paramss
}
}
CreateNamedStruct(params.head.flatMap { p =>
val fieldName = p.name.toString
val fieldType = p.typeSignature.substituteTypes(formalTypeArgs, actualTypeArgs)
val fieldValue = Invoke(inputObject, fieldName, dataTypeFor(fieldType))
expressions.Literal(fieldName) :: extractorFor(fieldValue, fieldType) :: Nil
})
case t if t <:< localTypeOf[Array[_]] =>
val TypeRef(_, _, Seq(elementType)) = t
toCatalystArray(inputObject, elementType)
case t if t <:< localTypeOf[Seq[_]] =>
val TypeRef(_, _, Seq(elementType)) = t
toCatalystArray(inputObject, elementType)
case t if t <:< localTypeOf[Map[_, _]] =>
val TypeRef(_, _, Seq(keyType, valueType)) = t
val keys =
Invoke(
Invoke(inputObject, "keysIterator",
ObjectType(classOf[scala.collection.Iterator[_]])),
"toSeq",
ObjectType(classOf[scala.collection.Seq[_]]))
val convertedKeys = toCatalystArray(keys, keyType)
val values =
Invoke(
Invoke(inputObject, "valuesIterator",
ObjectType(classOf[scala.collection.Iterator[_]])),
"toSeq",
ObjectType(classOf[scala.collection.Seq[_]]))
val convertedValues = toCatalystArray(values, valueType)
val Schema(keyDataType, _) = schemaFor(keyType)
val Schema(valueDataType, valueNullable) = schemaFor(valueType)
NewInstance(
classOf[ArrayBasedMapData],
convertedKeys :: convertedValues :: Nil,
dataType = MapType(keyDataType, valueDataType, valueNullable))
case t if t <:< localTypeOf[String] =>
StaticInvoke(
classOf[UTF8String],
StringType,
"fromString",
inputObject :: Nil)
case t if t <:< localTypeOf[java.sql.Timestamp] =>
StaticInvoke(
DateTimeUtils,
TimestampType,
"fromJavaTimestamp",
inputObject :: Nil)
case t if t <:< localTypeOf[java.sql.Date] =>
StaticInvoke(
DateTimeUtils,
DateType,
"fromJavaDate",
inputObject :: Nil)
case t if t <:< localTypeOf[BigDecimal] =>
StaticInvoke(
Decimal,
DecimalType.SYSTEM_DEFAULT,
"apply",
inputObject :: Nil)
case t if t <:< localTypeOf[java.math.BigDecimal] =>
StaticInvoke(
Decimal,
DecimalType.SYSTEM_DEFAULT,
"apply",
inputObject :: Nil)
case t if t <:< localTypeOf[java.lang.Integer] =>
Invoke(inputObject, "intValue", IntegerType)
case t if t <:< localTypeOf[java.lang.Long] =>
Invoke(inputObject, "longValue", LongType)
case t if t <:< localTypeOf[java.lang.Double] =>
Invoke(inputObject, "doubleValue", DoubleType)
case t if t <:< localTypeOf[java.lang.Float] =>
Invoke(inputObject, "floatValue", FloatType)
case t if t <:< localTypeOf[java.lang.Short] =>
Invoke(inputObject, "shortValue", ShortType)
case t if t <:< localTypeOf[java.lang.Byte] =>
Invoke(inputObject, "byteValue", ByteType)
case t if t <:< localTypeOf[java.lang.Boolean] =>
Invoke(inputObject, "booleanValue", BooleanType)
case other =>
throw new UnsupportedOperationException(s"Extractor for type $other is not supported")
}
}
}
private def toCatalystArray(input: Expression, elementType: `Type`): Expression = {
val externalDataType = dataTypeFor(elementType)
val Schema(catalystType, nullable) = schemaFor(elementType)
if (RowEncoder.isNativeType(catalystType)) {
NewInstance(
classOf[GenericArrayData],
input :: Nil,
dataType = ArrayType(catalystType, nullable))
} else {
MapObjects(extractorFor(_, elementType), input, externalDataType)
}
}
def constructorFor(
tpe: `Type`,
path: Option[Expression] = None): Expression = ScalaReflectionLock.synchronized {
/** Returns the current path with a sub-field extracted. */
def addToPath(part: String): Expression = path
.map(p => UnresolvedExtractValue(p, expressions.Literal(part)))
.getOrElse(UnresolvedAttribute(part))
/** Returns the current path with a field at ordinal extracted. */
def addToPathOrdinal(ordinal: Int, dataType: DataType): Expression = path
.map(p => GetInternalRowField(p, ordinal, dataType))
.getOrElse(BoundReference(ordinal, dataType, false))
/** Returns the current path or `BoundReference`. */
def getPath: Expression = path.getOrElse(BoundReference(0, schemaFor(tpe).dataType, true))
tpe match {
case t if !dataTypeFor(t).isInstanceOf[ObjectType] => getPath
case t if t <:< localTypeOf[Option[_]] =>
val TypeRef(_, _, Seq(optType)) = t
WrapOption(null, constructorFor(optType, path))
case t if t <:< localTypeOf[java.lang.Integer] =>
val boxedType = classOf[java.lang.Integer]
val objectType = ObjectType(boxedType)
NewInstance(boxedType, getPath :: Nil, propagateNull = true, objectType)
case t if t <:< localTypeOf[java.lang.Long] =>
val boxedType = classOf[java.lang.Long]
val objectType = ObjectType(boxedType)
NewInstance(boxedType, getPath :: Nil, propagateNull = true, objectType)
case t if t <:< localTypeOf[java.lang.Double] =>
val boxedType = classOf[java.lang.Double]
val objectType = ObjectType(boxedType)
NewInstance(boxedType, getPath :: Nil, propagateNull = true, objectType)
case t if t <:< localTypeOf[java.lang.Float] =>
val boxedType = classOf[java.lang.Float]
val objectType = ObjectType(boxedType)
NewInstance(boxedType, getPath :: Nil, propagateNull = true, objectType)
case t if t <:< localTypeOf[java.lang.Short] =>
val boxedType = classOf[java.lang.Short]
val objectType = ObjectType(boxedType)
NewInstance(boxedType, getPath :: Nil, propagateNull = true, objectType)
case t if t <:< localTypeOf[java.lang.Byte] =>
val boxedType = classOf[java.lang.Byte]
val objectType = ObjectType(boxedType)
NewInstance(boxedType, getPath :: Nil, propagateNull = true, objectType)
case t if t <:< localTypeOf[java.lang.Boolean] =>
val boxedType = classOf[java.lang.Boolean]
val objectType = ObjectType(boxedType)
NewInstance(boxedType, getPath :: Nil, propagateNull = true, objectType)
case t if t <:< localTypeOf[java.sql.Date] =>
StaticInvoke(
DateTimeUtils,
ObjectType(classOf[java.sql.Date]),
"toJavaDate",
getPath :: Nil,
propagateNull = true)
case t if t <:< localTypeOf[java.sql.Timestamp] =>
StaticInvoke(
DateTimeUtils,
ObjectType(classOf[java.sql.Timestamp]),
"toJavaTimestamp",
getPath :: Nil,
propagateNull = true)
case t if t <:< localTypeOf[java.lang.String] =>
Invoke(getPath, "toString", ObjectType(classOf[String]))
case t if t <:< localTypeOf[java.math.BigDecimal] =>
Invoke(getPath, "toJavaBigDecimal", ObjectType(classOf[java.math.BigDecimal]))
case t if t <:< localTypeOf[BigDecimal] =>
Invoke(getPath, "toBigDecimal", ObjectType(classOf[BigDecimal]))
case t if t <:< localTypeOf[Array[_]] =>
val TypeRef(_, _, Seq(elementType)) = t
val primitiveMethod = elementType match {
case t if t <:< definitions.IntTpe => Some("toIntArray")
case t if t <:< definitions.LongTpe => Some("toLongArray")
case t if t <:< definitions.DoubleTpe => Some("toDoubleArray")
case t if t <:< definitions.FloatTpe => Some("toFloatArray")
case t if t <:< definitions.ShortTpe => Some("toShortArray")
case t if t <:< definitions.ByteTpe => Some("toByteArray")
case t if t <:< definitions.BooleanTpe => Some("toBooleanArray")
case _ => None
}
primitiveMethod.map { method =>
Invoke(getPath, method, arrayClassFor(elementType))
}.getOrElse {
Invoke(
MapObjects(
p => constructorFor(elementType, Some(p)),
getPath,
schemaFor(elementType).dataType),
"array",
arrayClassFor(elementType))
}
case t if t <:< localTypeOf[Seq[_]] =>
val TypeRef(_, _, Seq(elementType)) = t
val arrayData =
Invoke(
MapObjects(
p => constructorFor(elementType, Some(p)),
getPath,
schemaFor(elementType).dataType),
"array",
ObjectType(classOf[Array[Any]]))
StaticInvoke(
scala.collection.mutable.WrappedArray,
ObjectType(classOf[Seq[_]]),
"make",
arrayData :: Nil)
case t if t <:< localTypeOf[Map[_, _]] =>
val TypeRef(_, _, Seq(keyType, valueType)) = t
val keyData =
Invoke(
MapObjects(
p => constructorFor(keyType, Some(p)),
Invoke(getPath, "keyArray", ArrayType(schemaFor(keyType).dataType)),
schemaFor(keyType).dataType),
"array",
ObjectType(classOf[Array[Any]]))
val valueData =
Invoke(
MapObjects(
p => constructorFor(valueType, Some(p)),
Invoke(getPath, "valueArray", ArrayType(schemaFor(valueType).dataType)),
schemaFor(valueType).dataType),
"array",
ObjectType(classOf[Array[Any]]))
StaticInvoke(
ArrayBasedMapData,
ObjectType(classOf[Map[_, _]]),
"toScalaMap",
keyData :: valueData :: Nil)
case t if t <:< localTypeOf[Product] =>
val formalTypeArgs = t.typeSymbol.asClass.typeParams
val TypeRef(_, _, actualTypeArgs) = t
val constructorSymbol = t.member(nme.CONSTRUCTOR)
val params = if (constructorSymbol.isMethod) {
constructorSymbol.asMethod.paramss
} else {
// Find the primary constructor, and use its parameter ordering.
val primaryConstructorSymbol: Option[Symbol] =
constructorSymbol.asTerm.alternatives.find(s =>
s.isMethod && s.asMethod.isPrimaryConstructor)
if (primaryConstructorSymbol.isEmpty) {
sys.error("Internal SQL error: Product object did not have a primary constructor.")
} else {
primaryConstructorSymbol.get.asMethod.paramss
}
}
val className: String = t.erasure.typeSymbol.asClass.fullName
val cls = Utils.classForName(className)
val arguments = params.head.zipWithIndex.map { case (p, i) =>
val fieldName = p.name.toString
val fieldType = p.typeSignature.substituteTypes(formalTypeArgs, actualTypeArgs)
val dataType = schemaFor(fieldType).dataType
// For tuples, we grab the inner fields by ordinal instead of name.
if (className startsWith "scala.Tuple") {
constructorFor(fieldType, Some(addToPathOrdinal(i, dataType)))
} else {
constructorFor(fieldType, Some(addToPath(fieldName)))
}
}
val newInstance = NewInstance(cls, arguments, propagateNull = false, ObjectType(cls))
if (path.nonEmpty) {
expressions.If(
IsNull(getPath),
expressions.Literal.create(null, ObjectType(cls)),
newInstance
)
} else {
newInstance
}
}
}
}
@@ -119,9 +119,17 @@ object RowEncoder {
CreateStruct(convertedFields)
}
private def externalDataTypeFor(dt: DataType): DataType = dt match {
/**
* Returns true if the value of this data type has the same representation internally and externally.
*/
def isNativeType(dt: DataType): Boolean = dt match {
case BooleanType | ByteType | ShortType | IntegerType | LongType |
FloatType | DoubleType | BinaryType => dt
FloatType | DoubleType | BinaryType => true
case _ => false
}
private def externalDataTypeFor(dt: DataType): DataType = dt match {
case _ if isNativeType(dt) => dt
case TimestampType => ObjectType(classOf[java.sql.Timestamp])
case DateType => ObjectType(classOf[java.sql.Date])
case _: DecimalType => ObjectType(classOf[java.math.BigDecimal])
@@ -137,13 +145,13 @@ object RowEncoder {
If(
IsNull(field),
Literal.create(null, externalDataTypeFor(f.dataType)),
constructorFor(BoundReference(i, f.dataType, f.nullable), f.dataType)
constructorFor(BoundReference(i, f.dataType, f.nullable))
)
}
CreateExternalRow(fields)
}
private def constructorFor(input: Expression, dataType: DataType): Expression = dataType match {
private def constructorFor(input: Expression): Expression = input.dataType match {
case BooleanType | ByteType | ShortType | IntegerType | LongType |
FloatType | DoubleType | BinaryType => input
@@ -170,7 +178,7 @@ object RowEncoder {
case ArrayType(et, nullable) =>
val arrayData =
Invoke(
MapObjects(constructorFor(_, et), input, et),
MapObjects(constructorFor, input, et),
"array",
ObjectType(classOf[Array[_]]))
StaticInvoke(
@@ -181,10 +189,10 @@ object RowEncoder {
case MapType(kt, vt, valueNullable) =>
val keyArrayType = ArrayType(kt, false)
val keyData = constructorFor(Invoke(input, "keyArray", keyArrayType), keyArrayType)
val keyData = constructorFor(Invoke(input, "keyArray", keyArrayType))
val valueArrayType = ArrayType(vt, valueNullable)
val valueData = constructorFor(Invoke(input, "valueArray", valueArrayType), valueArrayType)
val valueData = constructorFor(Invoke(input, "valueArray", valueArrayType))
StaticInvoke(
ArrayBasedMapData,
@@ -197,42 +205,8 @@ object RowEncoder {
If(
Invoke(input, "isNullAt", BooleanType, Literal(i) :: Nil),
Literal.create(null, externalDataTypeFor(f.dataType)),
constructorFor(getField(input, i, f.dataType), f.dataType))
constructorFor(GetInternalRowField(input, i, f.dataType)))
}
CreateExternalRow(convertedFields)
}
private def getField(
row: Expression,
ordinal: Int,
dataType: DataType): Expression = dataType match {
case BooleanType =>
Invoke(row, "getBoolean", dataType, Literal(ordinal) :: Nil)
case ByteType =>
Invoke(row, "getByte", dataType, Literal(ordinal) :: Nil)
case ShortType =>
Invoke(row, "getShort", dataType, Literal(ordinal) :: Nil)
case IntegerType | DateType =>
Invoke(row, "getInt", dataType, Literal(ordinal) :: Nil)
case LongType | TimestampType =>
Invoke(row, "getLong", dataType, Literal(ordinal) :: Nil)
case FloatType =>
Invoke(row, "getFloat", dataType, Literal(ordinal) :: Nil)
case DoubleType =>
Invoke(row, "getDouble", dataType, Literal(ordinal) :: Nil)
case t: DecimalType =>
Invoke(row, "getDecimal", dataType, Seq(ordinal, t.precision, t.scale).map(Literal(_)))
case StringType =>
Invoke(row, "getUTF8String", dataType, Literal(ordinal) :: Nil)
case BinaryType =>
Invoke(row, "getBinary", dataType, Literal(ordinal) :: Nil)
case CalendarIntervalType =>
Invoke(row, "getInterval", dataType, Literal(ordinal) :: Nil)
case t: StructType =>
Invoke(row, "getStruct", dataType, Literal(ordinal) :: Literal(t.size) :: Nil)
case _: ArrayType =>
Invoke(row, "getArray", dataType, Literal(ordinal) :: Nil)
case _: MapType =>
Invoke(row, "getMap", dataType, Literal(ordinal) :: Nil)
}
}
@@ -110,7 +110,7 @@ object DateTimeUtils {
}
def stringToTime(s: String): java.util.Date = {
var indexOfGMT = s.indexOf("GMT");
val indexOfGMT = s.indexOf("GMT")
if (indexOfGMT != -1) {
// ISO8601 with a weird time zone specifier (2000-01-01T00:00GMT+01:00)
val s0 = s.substring(0, indexOfGMT)
...
@@ -23,7 +23,7 @@ import org.apache.spark.unsafe.types.{CalendarInterval, UTF8String}
class GenericArrayData(val array: Array[Any]) extends ArrayData {
def this(seq: scala.collection.GenIterable[Any]) = this(seq.toArray)
def this(seq: Seq[Any]) = this(seq.toArray)
// TODO: This is boxing. We should specialize.
def this(primitiveArray: Array[Int]) = this(primitiveArray.toSeq)
...
@@ -17,232 +17,27 @@
package org.apache.spark.sql.catalyst.encoders
import scala.collection.mutable.ArrayBuffer
import scala.reflect.runtime.universe._
import java.util.Arrays
import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.catalyst._
import org.apache.spark.sql.catalyst.expressions.AttributeReference
import org.apache.spark.sql.catalyst.util.ArrayData
import org.apache.spark.sql.types.{StructField, ArrayType}
case class RepeatedStruct(s: Seq[PrimitiveData])
case class NestedArray(a: Array[Array[Int]])
case class BoxedData(
intField: java.lang.Integer,
longField: java.lang.Long,
doubleField: java.lang.Double,
floatField: java.lang.Float,
shortField: java.lang.Short,
byteField: java.lang.Byte,
booleanField: java.lang.Boolean)
case class RepeatedData(
arrayField: Seq[Int],
arrayFieldContainsNull: Seq[java.lang.Integer],
mapField: scala.collection.Map[Int, Long],
mapFieldNull: scala.collection.Map[Int, java.lang.Long],
structField: PrimitiveData)
case class SpecificCollection(l: List[Int])
class ExpressionEncoderSuite extends SparkFunSuite {
encodeDecodeTest(1)
encodeDecodeTest(1L)
encodeDecodeTest(1.toDouble)
encodeDecodeTest(1.toFloat)
encodeDecodeTest(true)
encodeDecodeTest(false)
encodeDecodeTest(1.toShort)
encodeDecodeTest(1.toByte)
encodeDecodeTest("hello")
encodeDecodeTest(PrimitiveData(1, 1, 1, 1, 1, 1, true))
// TODO: Support creating specific subclasses of Seq.
ignore("Specific collection types") { encodeDecodeTest(SpecificCollection(1 :: Nil)) }
encodeDecodeTest(
OptionalData(
Some(2), Some(2), Some(2), Some(2), Some(2), Some(2), Some(true),
Some(PrimitiveData(1, 1, 1, 1, 1, 1, true))))
encodeDecodeTest(OptionalData(None, None, None, None, None, None, None, None))
encodeDecodeTest(
BoxedData(1, 1L, 1.0, 1.0f, 1.toShort, 1.toByte, true))
encodeDecodeTest(
BoxedData(null, null, null, null, null, null, null))
encodeDecodeTest(
RepeatedStruct(PrimitiveData(1, 1, 1, 1, 1, 1, true) :: Nil))
encodeDecodeTest(
RepeatedData(
Seq(1, 2),
Seq(new Integer(1), null, new Integer(2)),
Map(1 -> 2L),
Map(1 -> null),
PrimitiveData(1, 1, 1, 1, 1, 1, true)))
encodeDecodeTest(("nullable Seq[Integer]", Seq[Integer](1, null)))
encodeDecodeTest(("Seq[(String, String)]",
Seq(("a", "b"))))
encodeDecodeTest(("Seq[(Int, Int)]",
Seq((1, 2))))
encodeDecodeTest(("Seq[(Long, Long)]",
Seq((1L, 2L))))
encodeDecodeTest(("Seq[(Float, Float)]",
Seq((1.toFloat, 2.toFloat))))
encodeDecodeTest(("Seq[(Double, Double)]",
Seq((1.toDouble, 2.toDouble))))
encodeDecodeTest(("Seq[(Short, Short)]",
Seq((1.toShort, 2.toShort))))
encodeDecodeTest(("Seq[(Byte, Byte)]",
Seq((1.toByte, 2.toByte))))
encodeDecodeTest(("Seq[(Boolean, Boolean)]",
Seq((true, false))))
// TODO: Decoding/encoding of complex maps.
ignore("complex maps") {
encodeDecodeTest(("Map[Int, (String, String)]",
Map(1 ->("a", "b"))))
}
encodeDecodeTest(("ArrayBuffer[(String, String)]",
ArrayBuffer(("a", "b"))))
encodeDecodeTest(("ArrayBuffer[(Int, Int)]",
ArrayBuffer((1, 2))))
encodeDecodeTest(("ArrayBuffer[(Long, Long)]",
ArrayBuffer((1L, 2L))))
encodeDecodeTest(("ArrayBuffer[(Float, Float)]",
ArrayBuffer((1.toFloat, 2.toFloat))))
encodeDecodeTest(("ArrayBuffer[(Double, Double)]",
ArrayBuffer((1.toDouble, 2.toDouble))))
encodeDecodeTest(("ArrayBuffer[(Short, Short)]",
ArrayBuffer((1.toShort, 2.toShort))))
encodeDecodeTest(("ArrayBuffer[(Byte, Byte)]",
ArrayBuffer((1.toByte, 2.toByte))))
encodeDecodeTest(("ArrayBuffer[(Boolean, Boolean)]",
ArrayBuffer((true, false))))
encodeDecodeTest(("Seq[Seq[(Int, Int)]]",
Seq(Seq((1, 2)))))
encodeDecodeTestCustom(("Array[Array[(Int, Int)]]",
Array(Array((1, 2)))))
{ (l, r) => l._2(0)(0) == r._2(0)(0) }
encodeDecodeTestCustom(("Array[Array[(Int, Int)]]",
Array(Array(Array((1, 2))))))
{ (l, r) => l._2(0)(0)(0) == r._2(0)(0)(0) }
encodeDecodeTestCustom(("Array[Array[Array[(Int, Int)]]]",
Array(Array(Array(Array((1, 2)))))))
{ (l, r) => l._2(0)(0)(0)(0) == r._2(0)(0)(0)(0) }
encodeDecodeTestCustom(("Array[Array[Array[Array[(Int, Int)]]]]",
Array(Array(Array(Array(Array((1, 2))))))))
{ (l, r) => l._2(0)(0)(0)(0)(0) == r._2(0)(0)(0)(0)(0) }
encodeDecodeTestCustom(("Array[Array[Integer]]",
Array(Array[Integer](1))))
{ (l, r) => l._2(0)(0) == r._2(0)(0) }
encodeDecodeTestCustom(("Array[Array[Int]]",
Array(Array(1))))
{ (l, r) => l._2(0)(0) == r._2(0)(0) }
encodeDecodeTestCustom(("Array[Array[Int]]",
Array(Array(Array(1)))))
{ (l, r) => l._2(0)(0)(0) == r._2(0)(0)(0) }
encodeDecodeTestCustom(("Array[Array[Array[Int]]]",
Array(Array(Array(Array(1))))))
{ (l, r) => l._2(0)(0)(0)(0) == r._2(0)(0)(0)(0) }
encodeDecodeTestCustom(("Array[Array[Array[Array[Int]]]]",
Array(Array(Array(Array(Array(1)))))))
{ (l, r) => l._2(0)(0)(0)(0)(0) == r._2(0)(0)(0)(0)(0) }
encodeDecodeTest(("Array[Byte] null",
null: Array[Byte]))
encodeDecodeTestCustom(("Array[Byte]",
Array[Byte](1, 2, 3)))
{ (l, r) => java.util.Arrays.equals(l._2, r._2) }
encodeDecodeTest(("Array[Int] null",
null: Array[Int]))
encodeDecodeTestCustom(("Array[Int]",
Array[Int](1, 2, 3)))
{ (l, r) => java.util.Arrays.equals(l._2, r._2) }
encodeDecodeTest(("Array[Long] null",
null: Array[Long]))
encodeDecodeTestCustom(("Array[Long]",
Array[Long](1, 2, 3)))
{ (l, r) => java.util.Arrays.equals(l._2, r._2) }
encodeDecodeTest(("Array[Double] null",
null: Array[Double]))
encodeDecodeTestCustom(("Array[Double]",
Array[Double](1, 2, 3)))
{ (l, r) => java.util.Arrays.equals(l._2, r._2) }
encodeDecodeTest(("Array[Float] null",
null: Array[Float]))
encodeDecodeTestCustom(("Array[Float]",
Array[Float](1, 2, 3)))
{ (l, r) => java.util.Arrays.equals(l._2, r._2) }
encodeDecodeTest(("Array[Boolean] null",
null: Array[Boolean]))
encodeDecodeTestCustom(("Array[Boolean]",
Array[Boolean](true, false)))
{ (l, r) => java.util.Arrays.equals(l._2, r._2) }
encodeDecodeTest(("Array[Short] null",
null: Array[Short]))
encodeDecodeTestCustom(("Array[Short]",
Array[Short](1, 2, 3)))
{ (l, r) => java.util.Arrays.equals(l._2, r._2) }
encodeDecodeTestCustom(("java.sql.Timestamp",
new java.sql.Timestamp(1)))
{ (l, r) => l._2.toString == r._2.toString }
encodeDecodeTestCustom(("java.sql.Date", new java.sql.Date(1)))
{ (l, r) => l._2.toString == r._2.toString }
/** Simplified encodeDecodeTestCustom, where the comparison function can be `Object.equals`. */
protected def encodeDecodeTest[T : TypeTag](inputData: T) =
encodeDecodeTestCustom[T](inputData)((l, r) => l == r)
/**
* Constructs a test that round-trips `t` through an encoder, checking the results to ensure it
* matches the original.
*/
protected def encodeDecodeTestCustom[T : TypeTag](
inputData: T)(
c: (T, T) => Boolean) = {
test(s"encode/decode: $inputData - ${inputData.getClass.getName}") {
val encoder = try ExpressionEncoder[T]() catch {
case e: Exception =>
fail(s"Exception thrown generating encoder", e)
}
val convertedData = encoder.toRow(inputData)
import org.apache.spark.sql.types.ArrayType
abstract class ExpressionEncoderSuite extends SparkFunSuite {
protected def encodeDecodeTest[T](
input: T,
encoder: ExpressionEncoder[T],
testName: String): Unit = {
test(s"encode/decode for $testName: $input") {
val row = encoder.toRow(input)
val schema = encoder.schema.toAttributes
val boundEncoder = encoder.resolve(schema).bind(schema)
val convertedBack = try boundEncoder.fromRow(convertedData) catch {
val convertedBack = try boundEncoder.fromRow(row) catch {
case e: Exception =>
fail(
s"""Exception thrown while decoding
|Converted: $convertedData
|Converted: $row
|Schema: ${schema.mkString(",")}
|${encoder.schema.treeString}
|
@@ -252,18 +47,27 @@ class ExpressionEncoderSuite extends SparkFunSuite {
""".stripMargin, e)
}
if (!c(inputData, convertedBack)) {
val isCorrect = (input, convertedBack) match {
case (b1: Array[Byte], b2: Array[Byte]) => Arrays.equals(b1, b2)
case (b1: Array[Int], b2: Array[Int]) => Arrays.equals(b1, b2)
case (b1: Array[Array[_]], b2: Array[Array[_]]) =>
Arrays.deepEquals(b1.asInstanceOf[Array[AnyRef]], b2.asInstanceOf[Array[AnyRef]])
case (b1: Array[_], b2: Array[_]) =>
Arrays.equals(b1.asInstanceOf[Array[AnyRef]], b2.asInstanceOf[Array[AnyRef]])
case _ => input == convertedBack
}
if (!isCorrect) {
val types = convertedBack match {
case c: Product =>
c.productIterator.filter(_ != null).map(_.getClass.getName).mkString(",")
case other => other.getClass.getName
}
val encodedData = try {
convertedData.toSeq(encoder.schema).zip(encoder.schema).map {
case (a: ArrayData, StructField(_, at: ArrayType, _, _)) =>
a.toArray[Any](at.elementType).toSeq
row.toSeq(encoder.schema).zip(schema).map {
case (a: ArrayData, AttributeReference(_, ArrayType(et, _), _, _)) =>
a.toArray[Any](et).toSeq
case (other, _) =>
other
}.mkString("[", ",", "]")
@@ -274,7 +78,7 @@ class ExpressionEncoderSuite extends SparkFunSuite {
fail(
s"""Encoded/Decoded data does not match input data
|
|in: $inputData
|in: $input
|out: $convertedBack
|types: $types
|
@@ -282,11 +86,10 @@ class ExpressionEncoderSuite extends SparkFunSuite {
|Schema: ${schema.mkString(",")}
|${encoder.schema.treeString}
|
|Extract Expressions:
|$boundEncoder
|fromRow Expressions:
|${boundEncoder.fromRowExpression.treeString}
""".stripMargin)
}
}
}
}
}
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.spark.sql.catalyst.encoders
import java.sql.{Date, Timestamp}
class FlatEncoderSuite extends ExpressionEncoderSuite {
encodeDecodeTest(false, FlatEncoder[Boolean], "primitive boolean")
encodeDecodeTest(-3.toByte, FlatEncoder[Byte], "primitive byte")
encodeDecodeTest(-3.toShort, FlatEncoder[Short], "primitive short")
encodeDecodeTest(-3, FlatEncoder[Int], "primitive int")
encodeDecodeTest(-3L, FlatEncoder[Long], "primitive long")
encodeDecodeTest(-3.7f, FlatEncoder[Float], "primitive float")
encodeDecodeTest(-3.7, FlatEncoder[Double], "primitive double")
encodeDecodeTest(new java.lang.Boolean(false), FlatEncoder[java.lang.Boolean], "boxed boolean")
encodeDecodeTest(new java.lang.Byte(-3.toByte), FlatEncoder[java.lang.Byte], "boxed byte")
encodeDecodeTest(new java.lang.Short(-3.toShort), FlatEncoder[java.lang.Short], "boxed short")
encodeDecodeTest(new java.lang.Integer(-3), FlatEncoder[java.lang.Integer], "boxed int")
encodeDecodeTest(new java.lang.Long(-3L), FlatEncoder[java.lang.Long], "boxed long")
encodeDecodeTest(new java.lang.Float(-3.7f), FlatEncoder[java.lang.Float], "boxed float")
encodeDecodeTest(new java.lang.Double(-3.7), FlatEncoder[java.lang.Double], "boxed double")
encodeDecodeTest(BigDecimal("32131413.211321313"), FlatEncoder[BigDecimal], "scala decimal")
type JDecimal = java.math.BigDecimal
// encodeDecodeTest(new JDecimal("231341.23123"), FlatEncoder[JDecimal], "java decimal")
encodeDecodeTest("hello", FlatEncoder[String], "string")
encodeDecodeTest(Date.valueOf("2012-12-23"), FlatEncoder[Date], "date")
encodeDecodeTest(Timestamp.valueOf("2016-01-29 10:00:00"), FlatEncoder[Timestamp], "timestamp")
encodeDecodeTest(Array[Byte](13, 21, -23), FlatEncoder[Array[Byte]], "binary")
encodeDecodeTest(Seq(31, -123, 4), FlatEncoder[Seq[Int]], "seq of int")
encodeDecodeTest(Seq("abc", "xyz"), FlatEncoder[Seq[String]], "seq of string")
encodeDecodeTest(Seq("abc", null, "xyz"), FlatEncoder[Seq[String]], "seq of string with null")
encodeDecodeTest(Seq.empty[Int], FlatEncoder[Seq[Int]], "empty seq of int")
encodeDecodeTest(Seq.empty[String], FlatEncoder[Seq[String]], "empty seq of string")
encodeDecodeTest(Seq(Seq(31, -123), null, Seq(4, 67)),
FlatEncoder[Seq[Seq[Int]]], "seq of seq of int")
encodeDecodeTest(Seq(Seq("abc", "xyz"), Seq[String](null), null, Seq("1", null, "2")),
FlatEncoder[Seq[Seq[String]]], "seq of seq of string")
encodeDecodeTest(Array(31, -123, 4), FlatEncoder[Array[Int]], "array of int")
encodeDecodeTest(Array("abc", "xyz"), FlatEncoder[Array[String]], "array of string")
encodeDecodeTest(Array("a", null, "x"), FlatEncoder[Array[String]], "array of string with null")
encodeDecodeTest(Array.empty[Int], FlatEncoder[Array[Int]], "empty array of int")
encodeDecodeTest(Array.empty[String], FlatEncoder[Array[String]], "empty array of string")
encodeDecodeTest(Array(Array(31, -123), null, Array(4, 67)),
FlatEncoder[Array[Array[Int]]], "array of array of int")
encodeDecodeTest(Array(Array("abc", "xyz"), Array[String](null), null, Array("1", null, "2")),
FlatEncoder[Array[Array[String]]], "array of array of string")
encodeDecodeTest(Map(1 -> "a", 2 -> "b"), FlatEncoder[Map[Int, String]], "map")
encodeDecodeTest(Map(1 -> "a", 2 -> null), FlatEncoder[Map[Int, String]], "map with null")
encodeDecodeTest(Map(1 -> Map("a" -> 1), 2 -> Map("b" -> 2)),
FlatEncoder[Map[Int, Map[String, Int]]], "map of map")
}
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.spark.sql.catalyst.encoders
import scala.collection.mutable.ArrayBuffer
import scala.reflect.runtime.universe.TypeTag
import org.apache.spark.sql.catalyst.{OptionalData, PrimitiveData}
case class RepeatedStruct(s: Seq[PrimitiveData])
case class NestedArray(a: Array[Array[Int]]) {
override def equals(other: Any): Boolean = other match {
case NestedArray(otherArray) =>
java.util.Arrays.deepEquals(
a.asInstanceOf[Array[AnyRef]],
otherArray.asInstanceOf[Array[AnyRef]])
case _ => false
}
}
case class BoxedData(
intField: java.lang.Integer,
longField: java.lang.Long,
doubleField: java.lang.Double,
floatField: java.lang.Float,
shortField: java.lang.Short,
byteField: java.lang.Byte,
booleanField: java.lang.Boolean)
case class RepeatedData(
arrayField: Seq[Int],
arrayFieldContainsNull: Seq[java.lang.Integer],
mapField: scala.collection.Map[Int, Long],
mapFieldNull: scala.collection.Map[Int, java.lang.Long],
structField: PrimitiveData)
case class SpecificCollection(l: List[Int])
class ProductEncoderSuite extends ExpressionEncoderSuite {
productTest(PrimitiveData(1, 1, 1, 1, 1, 1, true))
productTest(
OptionalData(Some(2), Some(2), Some(2), Some(2), Some(2), Some(2), Some(true),
Some(PrimitiveData(1, 1, 1, 1, 1, 1, true))))
productTest(OptionalData(None, None, None, None, None, None, None, None))
productTest(BoxedData(1, 1L, 1.0, 1.0f, 1.toShort, 1.toByte, true))
productTest(BoxedData(null, null, null, null, null, null, null))
productTest(RepeatedStruct(PrimitiveData(1, 1, 1, 1, 1, 1, true) :: Nil))
productTest((1, "test", PrimitiveData(1, 1, 1, 1, 1, 1, true)))
productTest(
RepeatedData(
Seq(1, 2),
Seq(new Integer(1), null, new Integer(2)),
Map(1 -> 2L),
Map(1 -> null),
PrimitiveData(1, 1, 1, 1, 1, 1, true)))
productTest(NestedArray(Array(Array(1, -2, 3), null, Array(4, 5, -6))))
productTest(("Seq[(String, String)]",
Seq(("a", "b"))))
productTest(("Seq[(Int, Int)]",
Seq((1, 2))))
productTest(("Seq[(Long, Long)]",
Seq((1L, 2L))))
productTest(("Seq[(Float, Float)]",
Seq((1.toFloat, 2.toFloat))))
productTest(("Seq[(Double, Double)]",
Seq((1.toDouble, 2.toDouble))))
productTest(("Seq[(Short, Short)]",
Seq((1.toShort, 2.toShort))))
productTest(("Seq[(Byte, Byte)]",
Seq((1.toByte, 2.toByte))))
productTest(("Seq[(Boolean, Boolean)]",
Seq((true, false))))
productTest(("ArrayBuffer[(String, String)]",
ArrayBuffer(("a", "b"))))
productTest(("ArrayBuffer[(Int, Int)]",
ArrayBuffer((1, 2))))
productTest(("ArrayBuffer[(Long, Long)]",
ArrayBuffer((1L, 2L))))
productTest(("ArrayBuffer[(Float, Float)]",
ArrayBuffer((1.toFloat, 2.toFloat))))
productTest(("ArrayBuffer[(Double, Double)]",
ArrayBuffer((1.toDouble, 2.toDouble))))
productTest(("ArrayBuffer[(Short, Short)]",
ArrayBuffer((1.toShort, 2.toShort))))
productTest(("ArrayBuffer[(Byte, Byte)]",
ArrayBuffer((1.toByte, 2.toByte))))
productTest(("ArrayBuffer[(Boolean, Boolean)]",
ArrayBuffer((true, false))))
productTest(("Seq[Seq[(Int, Int)]]",
Seq(Seq((1, 2)))))
private def productTest[T <: Product : TypeTag](input: T): Unit = {
encodeDecodeTest(input, ProductEncoder[T], input.getClass.getSimpleName)
}
}
@@ -22,7 +22,7 @@ import scala.collection.JavaConverters._
import org.apache.spark.annotation.Experimental
import org.apache.spark.api.java.function._
import org.apache.spark.sql.catalyst.encoders.{ExpressionEncoder, encoderFor}
import org.apache.spark.sql.catalyst.encoders.{FlatEncoder, ExpressionEncoder, encoderFor}
import org.apache.spark.sql.catalyst.expressions.Attribute
import org.apache.spark.sql.catalyst.plans.logical._
import org.apache.spark.sql.execution.QueryExecution
@@ -56,9 +56,6 @@ class GroupedDataset[K, T] private[sql](
private val resolvedKEncoder = unresolvedKEncoder.resolve(groupingAttributes)
private val resolvedTEncoder = unresolvedTEncoder.resolve(dataAttributes)
/** Encoders for built in aggregations. */
private implicit def newLongEncoder: Encoder[Long] = ExpressionEncoder[Long](flat = true)
private def logicalPlan = queryExecution.analyzed
private def sqlContext = queryExecution.sqlContext
@@ -211,7 +208,7 @@ class GroupedDataset[K, T] private[sql](
* Returns a [[Dataset]] that contains a tuple with each key and the number of items present
* for that key.
*/
def count(): Dataset[(K, Long)] = agg(functions.count("*").as[Long])
def count(): Dataset[(K, Long)] = agg(functions.count("*").as(FlatEncoder[Long]))
/**
* Applies the given function to each cogrouped data. For each unique group, the function will
...
@@ -37,17 +37,21 @@ import org.apache.spark.unsafe.types.UTF8String
abstract class SQLImplicits {
protected def _sqlContext: SQLContext
implicit def newProductEncoder[T <: Product : TypeTag]: Encoder[T] = ExpressionEncoder[T]()
implicit def newProductEncoder[T <: Product : TypeTag]: Encoder[T] = ProductEncoder[T]
implicit def newIntEncoder: Encoder[Int] = ExpressionEncoder[Int](flat = true)
implicit def newLongEncoder: Encoder[Long] = ExpressionEncoder[Long](flat = true)
implicit def newDoubleEncoder: Encoder[Double] = ExpressionEncoder[Double](flat = true)
implicit def newFloatEncoder: Encoder[Float] = ExpressionEncoder[Float](flat = true)
implicit def newByteEncoder: Encoder[Byte] = ExpressionEncoder[Byte](flat = true)
implicit def newShortEncoder: Encoder[Short] = ExpressionEncoder[Short](flat = true)
implicit def newBooleanEncoder: Encoder[Boolean] = ExpressionEncoder[Boolean](flat = true)
implicit def newStringEncoder: Encoder[String] = ExpressionEncoder[String](flat = true)
implicit def newIntEncoder: Encoder[Int] = FlatEncoder[Int]
implicit def newLongEncoder: Encoder[Long] = FlatEncoder[Long]
implicit def newDoubleEncoder: Encoder[Double] = FlatEncoder[Double]
implicit def newFloatEncoder: Encoder[Float] = FlatEncoder[Float]
implicit def newByteEncoder: Encoder[Byte] = FlatEncoder[Byte]
implicit def newShortEncoder: Encoder[Short] = FlatEncoder[Short]
implicit def newBooleanEncoder: Encoder[Boolean] = FlatEncoder[Boolean]
implicit def newStringEncoder: Encoder[String] = FlatEncoder[String]
/**
* Creates a [[Dataset]] from an RDD.
* @since 1.6.0
*/
implicit def rddToDatasetHolder[T : Encoder](rdd: RDD[T]): DatasetHolder[T] = {
DatasetHolder(_sqlContext.createDataset(rdd))
}
...
@@ -26,7 +26,7 @@ import scala.util.Try
import org.apache.spark.annotation.Experimental
import org.apache.spark.sql.catalyst.{SqlParser, ScalaReflection}
import org.apache.spark.sql.catalyst.analysis.{UnresolvedFunction, Star}
import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
import org.apache.spark.sql.catalyst.encoders.FlatEncoder
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.aggregate._
import org.apache.spark.sql.catalyst.plans.logical.BroadcastHint
@@ -267,7 +267,7 @@ object functions extends LegacyFunctions {
* @since 1.3.0
*/
def count(columnName: String): TypedColumn[Any, Long] =
count(Column(columnName)).as(ExpressionEncoder[Long](flat = true))
count(Column(columnName)).as(FlatEncoder[Long])
/**
* Aggregate function: returns the number of distinct items in a group.
...