Skip to content
Snippets Groups Projects
Commit e4899a25 authored by Takuya UESHIN's avatar Takuya UESHIN Committed by Reynold Xin
Browse files

[SPARK-2254] [SQL] ScalaRefection should mark primitive types as non-nullable.

Author: Takuya UESHIN <ueshin@happy-camper.st>

Closes #1193 from ueshin/issues/SPARK-2254 and squashes the following commits:

cfd6088 [Takuya UESHIN] Modify ScalaRefection.schemaFor method to return nullability of Scala Type.
parent 441cdcca
No related branches found
No related tags found
No related merge requests found
...@@ -30,53 +30,56 @@ import org.apache.spark.sql.catalyst.types._ ...@@ -30,53 +30,56 @@ import org.apache.spark.sql.catalyst.types._
object ScalaReflection { object ScalaReflection {
import scala.reflect.runtime.universe._ import scala.reflect.runtime.universe._
case class Schema(dataType: DataType, nullable: Boolean)
/** Returns a Sequence of attributes for the given case class type. */ /** Returns a Sequence of attributes for the given case class type. */
def attributesFor[T: TypeTag]: Seq[Attribute] = schemaFor[T] match { def attributesFor[T: TypeTag]: Seq[Attribute] = schemaFor[T] match {
case s: StructType => case Schema(s: StructType, _) =>
s.fields.map(f => AttributeReference(f.name, f.dataType, nullable = true)()) s.fields.map(f => AttributeReference(f.name, f.dataType, f.nullable)())
} }
/** Returns a catalyst DataType for the given Scala Type using reflection. */ /** Returns a catalyst DataType and its nullability for the given Scala Type using reflection. */
def schemaFor[T: TypeTag]: DataType = schemaFor(typeOf[T]) def schemaFor[T: TypeTag]: Schema = schemaFor(typeOf[T])
/** Returns a catalyst DataType for the given Scala Type using reflection. */ /** Returns a catalyst DataType and its nullability for the given Scala Type using reflection. */
def schemaFor(tpe: `Type`): DataType = tpe match { def schemaFor(tpe: `Type`): Schema = tpe match {
case t if t <:< typeOf[Option[_]] => case t if t <:< typeOf[Option[_]] =>
val TypeRef(_, _, Seq(optType)) = t val TypeRef(_, _, Seq(optType)) = t
schemaFor(optType) Schema(schemaFor(optType).dataType, nullable = true)
case t if t <:< typeOf[Product] => case t if t <:< typeOf[Product] =>
val params = t.member("<init>": TermName).asMethod.paramss val params = t.member("<init>": TermName).asMethod.paramss
StructType( Schema(StructType(
params.head.map(p => params.head.map { p =>
StructField(p.name.toString, schemaFor(p.typeSignature), nullable = true))) val Schema(dataType, nullable) = schemaFor(p.typeSignature)
StructField(p.name.toString, dataType, nullable)
}), nullable = true)
// Need to decide if we actually need a special type here. // Need to decide if we actually need a special type here.
case t if t <:< typeOf[Array[Byte]] => BinaryType case t if t <:< typeOf[Array[Byte]] => Schema(BinaryType, nullable = true)
case t if t <:< typeOf[Array[_]] => case t if t <:< typeOf[Array[_]] =>
sys.error(s"Only Array[Byte] supported now, use Seq instead of $t") sys.error(s"Only Array[Byte] supported now, use Seq instead of $t")
case t if t <:< typeOf[Seq[_]] => case t if t <:< typeOf[Seq[_]] =>
val TypeRef(_, _, Seq(elementType)) = t val TypeRef(_, _, Seq(elementType)) = t
ArrayType(schemaFor(elementType)) Schema(ArrayType(schemaFor(elementType).dataType), nullable = true)
case t if t <:< typeOf[Map[_,_]] => case t if t <:< typeOf[Map[_,_]] =>
val TypeRef(_, _, Seq(keyType, valueType)) = t val TypeRef(_, _, Seq(keyType, valueType)) = t
MapType(schemaFor(keyType), schemaFor(valueType)) Schema(MapType(schemaFor(keyType).dataType, schemaFor(valueType).dataType), nullable = true)
case t if t <:< typeOf[String] => StringType case t if t <:< typeOf[String] => Schema(StringType, nullable = true)
case t if t <:< typeOf[Timestamp] => TimestampType case t if t <:< typeOf[Timestamp] => Schema(TimestampType, nullable = true)
case t if t <:< typeOf[BigDecimal] => DecimalType case t if t <:< typeOf[BigDecimal] => Schema(DecimalType, nullable = true)
case t if t <:< typeOf[java.lang.Integer] => IntegerType case t if t <:< typeOf[java.lang.Integer] => Schema(IntegerType, nullable = true)
case t if t <:< typeOf[java.lang.Long] => LongType case t if t <:< typeOf[java.lang.Long] => Schema(LongType, nullable = true)
case t if t <:< typeOf[java.lang.Double] => DoubleType case t if t <:< typeOf[java.lang.Double] => Schema(DoubleType, nullable = true)
case t if t <:< typeOf[java.lang.Float] => FloatType case t if t <:< typeOf[java.lang.Float] => Schema(FloatType, nullable = true)
case t if t <:< typeOf[java.lang.Short] => ShortType case t if t <:< typeOf[java.lang.Short] => Schema(ShortType, nullable = true)
case t if t <:< typeOf[java.lang.Byte] => ByteType case t if t <:< typeOf[java.lang.Byte] => Schema(ByteType, nullable = true)
case t if t <:< typeOf[java.lang.Boolean] => BooleanType case t if t <:< typeOf[java.lang.Boolean] => Schema(BooleanType, nullable = true)
// TODO: The following datatypes could be marked as non-nullable. case t if t <:< definitions.IntTpe => Schema(IntegerType, nullable = false)
case t if t <:< definitions.IntTpe => IntegerType case t if t <:< definitions.LongTpe => Schema(LongType, nullable = false)
case t if t <:< definitions.LongTpe => LongType case t if t <:< definitions.DoubleTpe => Schema(DoubleType, nullable = false)
case t if t <:< definitions.DoubleTpe => DoubleType case t if t <:< definitions.FloatTpe => Schema(FloatType, nullable = false)
case t if t <:< definitions.FloatTpe => FloatType case t if t <:< definitions.ShortTpe => Schema(ShortType, nullable = false)
case t if t <:< definitions.ShortTpe => ShortType case t if t <:< definitions.ByteTpe => Schema(ByteType, nullable = false)
case t if t <:< definitions.ByteTpe => ByteType case t if t <:< definitions.BooleanTpe => Schema(BooleanType, nullable = false)
case t if t <:< definitions.BooleanTpe => BooleanType
} }
implicit class CaseClassRelation[A <: Product : TypeTag](data: Seq[A]) { implicit class CaseClassRelation[A <: Product : TypeTag](data: Seq[A]) {
......
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.spark.sql.catalyst
import java.sql.Timestamp
import org.scalatest.FunSuite
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.types._
case class PrimitiveData(
intField: Int,
longField: Long,
doubleField: Double,
floatField: Float,
shortField: Short,
byteField: Byte,
booleanField: Boolean)
case class NullableData(
intField: java.lang.Integer,
longField: java.lang.Long,
doubleField: java.lang.Double,
floatField: java.lang.Float,
shortField: java.lang.Short,
byteField: java.lang.Byte,
booleanField: java.lang.Boolean,
stringField: String,
decimalField: BigDecimal,
timestampField: Timestamp,
binaryField: Array[Byte])
case class OptionalData(
intField: Option[Int],
longField: Option[Long],
doubleField: Option[Double],
floatField: Option[Float],
shortField: Option[Short],
byteField: Option[Byte],
booleanField: Option[Boolean])
case class ComplexData(
arrayField: Seq[Int],
mapField: Map[Int, String],
structField: PrimitiveData)
class ScalaReflectionSuite extends FunSuite {
import ScalaReflection._
test("primitive data") {
val schema = schemaFor[PrimitiveData]
assert(schema === Schema(
StructType(Seq(
StructField("intField", IntegerType, nullable = false),
StructField("longField", LongType, nullable = false),
StructField("doubleField", DoubleType, nullable = false),
StructField("floatField", FloatType, nullable = false),
StructField("shortField", ShortType, nullable = false),
StructField("byteField", ByteType, nullable = false),
StructField("booleanField", BooleanType, nullable = false))),
nullable = true))
}
test("nullable data") {
val schema = schemaFor[NullableData]
assert(schema === Schema(
StructType(Seq(
StructField("intField", IntegerType, nullable = true),
StructField("longField", LongType, nullable = true),
StructField("doubleField", DoubleType, nullable = true),
StructField("floatField", FloatType, nullable = true),
StructField("shortField", ShortType, nullable = true),
StructField("byteField", ByteType, nullable = true),
StructField("booleanField", BooleanType, nullable = true),
StructField("stringField", StringType, nullable = true),
StructField("decimalField", DecimalType, nullable = true),
StructField("timestampField", TimestampType, nullable = true),
StructField("binaryField", BinaryType, nullable = true))),
nullable = true))
}
test("optinal data") {
val schema = schemaFor[OptionalData]
assert(schema === Schema(
StructType(Seq(
StructField("intField", IntegerType, nullable = true),
StructField("longField", LongType, nullable = true),
StructField("doubleField", DoubleType, nullable = true),
StructField("floatField", FloatType, nullable = true),
StructField("shortField", ShortType, nullable = true),
StructField("byteField", ByteType, nullable = true),
StructField("booleanField", BooleanType, nullable = true))),
nullable = true))
}
test("complex data") {
val schema = schemaFor[ComplexData]
assert(schema === Schema(
StructType(Seq(
StructField("arrayField", ArrayType(IntegerType), nullable = true),
StructField("mapField", MapType(IntegerType, StringType), nullable = true),
StructField(
"structField",
StructType(Seq(
StructField("intField", IntegerType, nullable = false),
StructField("longField", LongType, nullable = false),
StructField("doubleField", DoubleType, nullable = false),
StructField("floatField", FloatType, nullable = false),
StructField("shortField", ShortType, nullable = false),
StructField("byteField", ByteType, nullable = false),
StructField("booleanField", BooleanType, nullable = false))),
nullable = true))),
nullable = true))
}
}
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment