From 7c8fc1f7cb837ff5c32811fdeb3ee2b84de2dea4 Mon Sep 17 00:00:00 2001
From: Wenchen Fan <cloud0fan@outlook.com>
Date: Tue, 4 Aug 2015 17:05:19 -0700
Subject: [PATCH] [SPARK-9598][SQL] do not expose generic getter in internal
 row

Author: Wenchen Fan <cloud0fan@outlook.com>

Closes #7932 from cloud-fan/generic-getter and squashes the following commits:

c60de4c [Wenchen Fan] do not expose generic getter in internal row
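
For callers, the practical effect is that the untyped genericGet disappears
from InternalRow's public surface, so call sites must go through the typed
accessor. A minimal before/after sketch (not part of this patch; `row` and
`schema` are hypothetical):

    // before: untyped access, which UnsafeRow could only reject at runtime
    val v1 = row.genericGet(0)

    // after: the DataType lets specialized rows decode the field;
    // generic rows simply ignore it
    val v2 = row.get(0, schema(0).dataType)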
---
 .../sql/catalyst/expressions/UnsafeRow.java   |  5 --
 .../spark/sql/catalyst/InternalRow.scala      | 37 +++++++++--
 .../GenericSpecializedGetters.scala           | 61 -------------------
 .../sql/catalyst/expressions/Projection.scala |  4 +-
 .../expressions/SpecificMutableRow.scala      |  2 +-
 .../sql/catalyst/expressions/aggregates.scala |  2 +-
 .../codegen/GenerateProjection.scala          |  2 +-
 .../spark/sql/catalyst/expressions/rows.scala | 12 ++--
 .../spark/sql/types/GenericArrayData.scala    | 37 +++++++----
 .../datasources/DataSourceStrategy.scala      |  6 +-
 .../spark/sql/columnar/ColumnStatsSuite.scala | 20 +++---
 11 files changed, 80 insertions(+), 108 deletions(-)
 delete mode 100644 sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/GenericSpecializedGetters.scala

diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java
index e6750fce4f..e3e1622de0 100644
--- a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java
+++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java
@@ -253,11 +253,6 @@ public final class UnsafeRow extends MutableRow {
     }
   }
 
-  @Override
-  public Object genericGet(int ordinal) {
-    throw new UnsupportedOperationException();
-  }
-
   @Override
   public Object get(int ordinal, DataType dataType) {
     if (isNullAt(ordinal) || dataType instanceof NullType) {
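
Dropping this override is the point of the patch: UnsafeRow stores fields in a
packed binary format, so there is no boxed value a type-blind genericGet could
hand back, and the old implementation could only throw. The typed
get(ordinal, dataType) below dispatches on the type instead.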
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/InternalRow.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/InternalRow.scala
index 7656d054dc..7d17cca808 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/InternalRow.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/InternalRow.scala
@@ -17,15 +17,15 @@
 
 package org.apache.spark.sql.catalyst
 
-import org.apache.spark.sql.Row
 import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.types.{DataType, MapData, ArrayData, Decimal}
+import org.apache.spark.unsafe.types.{CalendarInterval, UTF8String}
 
 /**
 * An abstract class for rows used internally in Spark SQL, which only contain
 * columns of internal types.
  */
-// todo: make InternalRow just extends SpecializedGetters, remove generic getter
-abstract class InternalRow extends GenericSpecializedGetters with Serializable {
+abstract class InternalRow extends SpecializedGetters with Serializable {
 
   def numFields: Int
 
@@ -50,6 +50,31 @@ abstract class InternalRow extends GenericSpecializedGetters with Serializable {
     false
   }
 
+  // Subclasses of InternalRow should implement all the specialized getters and
+  // equals/hashCode, or implement this genericGet.
+  protected def genericGet(ordinal: Int): Any = throw new IllegalStateException(
+    "Concrete internal rows should implement genericGet, " +
+      "or implement all the specialized getters and equals/hashCode")
+
+  // default implementation (slow)
+  private def getAs[T](ordinal: Int) = genericGet(ordinal).asInstanceOf[T]
+  override def isNullAt(ordinal: Int): Boolean = getAs[AnyRef](ordinal) eq null
+  override def get(ordinal: Int, dataType: DataType): AnyRef = getAs(ordinal)
+  override def getBoolean(ordinal: Int): Boolean = getAs(ordinal)
+  override def getByte(ordinal: Int): Byte = getAs(ordinal)
+  override def getShort(ordinal: Int): Short = getAs(ordinal)
+  override def getInt(ordinal: Int): Int = getAs(ordinal)
+  override def getLong(ordinal: Int): Long = getAs(ordinal)
+  override def getFloat(ordinal: Int): Float = getAs(ordinal)
+  override def getDouble(ordinal: Int): Double = getAs(ordinal)
+  override def getDecimal(ordinal: Int, precision: Int, scale: Int): Decimal = getAs(ordinal)
+  override def getUTF8String(ordinal: Int): UTF8String = getAs(ordinal)
+  override def getBinary(ordinal: Int): Array[Byte] = getAs(ordinal)
+  override def getArray(ordinal: Int): ArrayData = getAs(ordinal)
+  override def getInterval(ordinal: Int): CalendarInterval = getAs(ordinal)
+  override def getMap(ordinal: Int): MapData = getAs(ordinal)
+  override def getStruct(ordinal: Int, numFields: Int): InternalRow = getAs(ordinal)
+
   override def equals(o: Any): Boolean = {
     if (!o.isInstanceOf[InternalRow]) {
       return false
@@ -159,15 +184,15 @@ abstract class InternalRow extends GenericSpecializedGetters with Serializable {
 
 object InternalRow {
   /**
-   * This method can be used to construct a [[Row]] with the given values.
+   * This method can be used to construct an [[InternalRow]] with the given values.
    */
   def apply(values: Any*): InternalRow = new GenericInternalRow(values.toArray)
 
   /**
-   * This method can be used to construct a [[Row]] from a [[Seq]] of values.
+   * This method can be used to construct an [[InternalRow]] from a [[Seq]] of values.
    */
   def fromSeq(values: Seq[Any]): InternalRow = new GenericInternalRow(values.toArray)
 
-  /** Returns an empty row. */
+  /** Returns an empty [[InternalRow]]. */
   val empty = apply()
 }
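
With the default getters above, a quick-and-dirty row only has to provide
genericGet plus whatever members remain abstract at this revision (numFields,
copy, toSeq). A hypothetical minimal subclass, mirroring GenericInternalRow,
to make the new comment concrete:

    import org.apache.spark.sql.catalyst.InternalRow

    // Leans entirely on InternalRow's slow default getters, which cast the
    // result of genericGet: fine for tests, too slow for hot code paths.
    class BoxedRow(values: Array[Any]) extends InternalRow {
      override def numFields: Int = values.length
      override protected def genericGet(ordinal: Int): Any = values(ordinal)
      override def toSeq: Seq[Any] = values
      override def copy(): InternalRow = new BoxedRow(values.clone())
    }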
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/GenericSpecializedGetters.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/GenericSpecializedGetters.scala
deleted file mode 100644
index 6e957928e0..0000000000
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/GenericSpecializedGetters.scala
+++ /dev/null
@@ -1,61 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements.  See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License.  You may obtain a copy of the License at
- *
- *    http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.sql.catalyst.expressions
-
-import org.apache.spark.sql.catalyst.InternalRow
-import org.apache.spark.sql.types.{DataType, MapData, ArrayData, Decimal}
-import org.apache.spark.unsafe.types.{CalendarInterval, UTF8String}
-
-trait GenericSpecializedGetters extends SpecializedGetters {
-
-  def genericGet(ordinal: Int): Any
-
-  private def getAs[T](ordinal: Int) = genericGet(ordinal).asInstanceOf[T]
-
-  override def isNullAt(ordinal: Int): Boolean = getAs[AnyRef](ordinal) eq null
-
-  override def get(ordinal: Int, elementType: DataType): AnyRef = getAs(ordinal)
-
-  override def getBoolean(ordinal: Int): Boolean = getAs(ordinal)
-
-  override def getByte(ordinal: Int): Byte = getAs(ordinal)
-
-  override def getShort(ordinal: Int): Short = getAs(ordinal)
-
-  override def getInt(ordinal: Int): Int = getAs(ordinal)
-
-  override def getLong(ordinal: Int): Long = getAs(ordinal)
-
-  override def getFloat(ordinal: Int): Float = getAs(ordinal)
-
-  override def getDouble(ordinal: Int): Double = getAs(ordinal)
-
-  override def getDecimal(ordinal: Int, precision: Int, scale: Int): Decimal = getAs(ordinal)
-
-  override def getUTF8String(ordinal: Int): UTF8String = getAs(ordinal)
-
-  override def getBinary(ordinal: Int): Array[Byte] = getAs(ordinal)
-
-  override def getInterval(ordinal: Int): CalendarInterval = getAs(ordinal)
-
-  override def getStruct(ordinal: Int, numFields: Int): InternalRow = getAs(ordinal)
-
-  override def getArray(ordinal: Int): ArrayData = getAs(ordinal)
-
-  override def getMap(ordinal: Int): MapData = getAs(ordinal)
-}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala
index 7964974102..4296b4b123 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala
@@ -207,8 +207,8 @@ class JoinedRow extends InternalRow {
 
   override def numFields: Int = row1.numFields + row2.numFields
 
-  override def genericGet(i: Int): Any =
-    if (i < row1.numFields) row1.genericGet(i) else row2.genericGet(i - row1.numFields)
+  override def get(i: Int, dt: DataType): AnyRef =
+    if (i < row1.numFields) row1.get(i, dt) else row2.get(i - row1.numFields, dt)
 
   override def isNullAt(i: Int): Boolean =
     if (i < row1.numFields) row1.isNullAt(i) else row2.isNullAt(i - row1.numFields)
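
JoinedRow keeps the same ordinal-splitting shape, just forwarding the typed
get instead of genericGet. A usage sketch (rows and type are hypothetical, and
it assumes the two-argument convenience constructor defined elsewhere in this
class):

    val joined = new JoinedRow(leftRow, rightRow)
    // ordinals below leftRow.numFields resolve on the left, the rest on the right
    val first = joined.get(0, IntegerType)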
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SpecificMutableRow.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SpecificMutableRow.scala
index d149a5b179..b94df6bd66 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SpecificMutableRow.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SpecificMutableRow.scala
@@ -232,7 +232,7 @@ final class SpecificMutableRow(val values: Array[MutableValue]) extends MutableR
     new GenericInternalRow(newValues)
   }
 
-  override def genericGet(i: Int): Any = values(i).boxed
+  override protected def genericGet(i: Int): Any = values(i).boxed
 
   override def update(ordinal: Int, value: Any) {
     if (value == null) {
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregates.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregates.scala
index 5d4b349b15..2cf8312ea5 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregates.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregates.scala
@@ -623,7 +623,7 @@ case class CombineSetsAndSumFunction(
       null
     } else {
       Cast(Literal(
-        casted.iterator.map(f => f.genericGet(0)).reduceLeft(
+        casted.iterator.map(f => f.get(0, null)).reduceLeft(
           base.dataType.asInstanceOf[NumericType].numeric.asInstanceOf[Numeric[Any]].plus)),
         base.dataType).eval(null)
     }
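
Passing null as the DataType is deliberate: the rows in the set are generic
rows, and InternalRow's new default get (added above) never touches the type
argument, as this one-line sketch of the call shows. It would not be safe
against a row that dispatches on the type, such as UnsafeRow.

    // safe only for genericGet-backed rows, where dataType is ignored
    val sum0 = f.get(0, null)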
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateProjection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateProjection.scala
index 1572b2b99a..c04fe734d5 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateProjection.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateProjection.scala
@@ -184,7 +184,7 @@ object GenerateProjection extends CodeGenerator[Seq[Expression], Projection] {
         public void setNullAt(int i) { nullBits[i] = true; }
         public boolean isNullAt(int i) { return nullBits[i]; }
 
-        public Object genericGet(int i) {
+        protected Object genericGet(int i) {
           if (isNullAt(i)) return null;
           switch (i) {
           $getCases
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/rows.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/rows.scala
index d04434b953..5e5de1d1dc 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/rows.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/rows.scala
@@ -19,8 +19,8 @@ package org.apache.spark.sql.catalyst.expressions
 
 import org.apache.spark.sql.Row
 import org.apache.spark.sql.catalyst.InternalRow
-import org.apache.spark.sql.types.{Decimal, DataType, StructType, AtomicType}
-import org.apache.spark.unsafe.types.UTF8String
+import org.apache.spark.sql.types._
+import org.apache.spark.unsafe.types.{CalendarInterval, UTF8String}
 
 /**
  * An extended interface to [[InternalRow]] that allows the values for each column to be updated.
@@ -76,13 +76,13 @@ class GenericRowWithSchema(values: Array[Any], override val schema: StructType)
 * Note that, while the array is not copied and thus could technically be mutated after creation,
  * this is not allowed.
  */
-class GenericInternalRow(protected[sql] val values: Array[Any]) extends InternalRow {
+class GenericInternalRow(private[sql] val values: Array[Any]) extends InternalRow {
   /** No-arg constructor for serialization. */
   protected def this() = this(null)
 
   def this(size: Int) = this(new Array[Any](size))
 
-  override def genericGet(ordinal: Int): Any = values(ordinal)
+  override protected def genericGet(ordinal: Int) = values(ordinal)
 
   override def toSeq: Seq[Any] = values
 
@@ -103,13 +103,13 @@ class GenericInternalRowWithSchema(values: Array[Any], val schema: StructType)
   def fieldIndex(name: String): Int = schema.fieldIndex(name)
 }
 
-class GenericMutableRow(val values: Array[Any]) extends MutableRow {
+class GenericMutableRow(values: Array[Any]) extends MutableRow {
   /** No-arg constructor for serialization. */
   protected def this() = this(null)
 
   def this(size: Int) = this(new Array[Any](size))
 
-  override def genericGet(ordinal: Int): Any = values(ordinal)
+  override protected def genericGet(ordinal: Int) = values(ordinal)
 
   override def toSeq: Seq[Any] = values
 
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/GenericArrayData.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/GenericArrayData.scala
index b314acdfe3..459fcb6fc0 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/GenericArrayData.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/GenericArrayData.scala
@@ -17,22 +17,33 @@
 
 package org.apache.spark.sql.types
 
-import scala.reflect.ClassTag
+import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.unsafe.types.{CalendarInterval, UTF8String}
 
-import org.apache.spark.sql.catalyst.expressions.GenericSpecializedGetters
-
-class GenericArrayData(private[sql] val array: Array[Any])
-  extends ArrayData with GenericSpecializedGetters {
-
-  override def genericGet(ordinal: Int): Any = array(ordinal)
+class GenericArrayData(private[sql] val array: Array[Any]) extends ArrayData {
 
   override def copy(): ArrayData = new GenericArrayData(array.clone())
 
-  // todo: Array is invariant in scala, maybe use toSeq instead?
-  override def toArray[T: ClassTag](elementType: DataType): Array[T] = array.map(_.asInstanceOf[T])
-
   override def numElements(): Int = array.length
 
+  private def getAs[T](ordinal: Int) = array(ordinal).asInstanceOf[T]
+  override def isNullAt(ordinal: Int): Boolean = getAs[AnyRef](ordinal) eq null
+  override def get(ordinal: Int, elementType: DataType): AnyRef = getAs(ordinal)
+  override def getBoolean(ordinal: Int): Boolean = getAs(ordinal)
+  override def getByte(ordinal: Int): Byte = getAs(ordinal)
+  override def getShort(ordinal: Int): Short = getAs(ordinal)
+  override def getInt(ordinal: Int): Int = getAs(ordinal)
+  override def getLong(ordinal: Int): Long = getAs(ordinal)
+  override def getFloat(ordinal: Int): Float = getAs(ordinal)
+  override def getDouble(ordinal: Int): Double = getAs(ordinal)
+  override def getDecimal(ordinal: Int, precision: Int, scale: Int): Decimal = getAs(ordinal)
+  override def getUTF8String(ordinal: Int): UTF8String = getAs(ordinal)
+  override def getBinary(ordinal: Int): Array[Byte] = getAs(ordinal)
+  override def getInterval(ordinal: Int): CalendarInterval = getAs(ordinal)
+  override def getStruct(ordinal: Int, numFields: Int): InternalRow = getAs(ordinal)
+  override def getArray(ordinal: Int): ArrayData = getAs(ordinal)
+  override def getMap(ordinal: Int): MapData = getAs(ordinal)
+
   override def toString(): String = array.mkString("[", ",", "]")
 
   override def equals(o: Any): Boolean = {
@@ -56,8 +67,8 @@ class GenericArrayData(private[sql] val array: Array[Any])
         return false
       }
       if (!isNullAt(i)) {
-        val o1 = genericGet(i)
-        val o2 = other.genericGet(i)
+        val o1 = array(i)
+        val o2 = other.array(i)
         o1 match {
           case b1: Array[Byte] =>
             if (!o2.isInstanceOf[Array[Byte]] ||
@@ -91,7 +102,7 @@ class GenericArrayData(private[sql] val array: Array[Any])
         if (isNullAt(i)) {
           0
         } else {
-          genericGet(i) match {
+          array(i) match {
             case b: Boolean => if (b) 0 else 1
             case b: Byte => b.toInt
             case s: Short => s.toInt
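
GenericArrayData now carries the casting getters inline instead of mixing in
the deleted trait. A small usage sketch (values are hypothetical):

    import org.apache.spark.sql.types.GenericArrayData

    val arr = new GenericArrayData(Array[Any](1, null, 3))
    arr.numElements()   // 3
    arr.isNullAt(1)     // true, via the eq-null check in isNullAt
    arr.getInt(2)       // 3, via the unchecked cast in getAs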
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala
index 6b91e51ca5..d9d7bc19bd 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala
@@ -187,15 +187,17 @@ private[sql] object DataSourceStrategy extends Strategy with Logging {
         // To see whether the `index`-th column is a partition column...
         val i = partitionColumns.indexOf(name)
         if (i != -1) {
+          val dt = schema(partitionColumns(i)).dataType
           // If yes, gets column value from partition values.
           (mutableRow: MutableRow, dataRow: InternalRow, ordinal: Int) => {
-            mutableRow(ordinal) = partitionValues.genericGet(i)
+            mutableRow(ordinal) = partitionValues.get(i, dt)
           }
         } else {
           // Otherwise, inherits the value from scanned data.
           val i = nonPartitionColumns.indexOf(name)
+          val dt = schema(nonPartitionColumns(i)).dataType
           (mutableRow: MutableRow, dataRow: InternalRow, ordinal: Int) => {
-            mutableRow(ordinal) = dataRow.genericGet(i)
+            mutableRow(ordinal) = dataRow.get(i, dt)
           }
         }
       }
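
Note that the DataType lookup is hoisted out of the returned closure, so it
runs once per column when the merge functions are built rather than once per
row. The same pattern in isolation (names are hypothetical):

    // resolved once, when the merge function is constructed
    val dt = schema("price").dataType
    val merge = (mutableRow: MutableRow, dataRow: InternalRow, ordinal: Int) =>
      mutableRow(ordinal) = dataRow.get(0, dt)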
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnStatsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnStatsSuite.scala
index 66014ddca0..16e0187ed2 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnStatsSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnStatsSuite.scala
@@ -61,11 +61,11 @@ class ColumnStatsSuite extends SparkFunSuite {
       val ordering = columnType.dataType.ordering.asInstanceOf[Ordering[T#InternalType]]
       val stats = columnStats.collectedStatistics
 
-      assertResult(values.min(ordering), "Wrong lower bound")(stats.genericGet(0))
-      assertResult(values.max(ordering), "Wrong upper bound")(stats.genericGet(1))
-      assertResult(10, "Wrong null count")(stats.genericGet(2))
-      assertResult(20, "Wrong row count")(stats.genericGet(3))
-      assertResult(stats.genericGet(4), "Wrong size in bytes") {
+      assertResult(values.min(ordering), "Wrong lower bound")(stats.get(0, null))
+      assertResult(values.max(ordering), "Wrong upper bound")(stats.get(1, null))
+      assertResult(10, "Wrong null count")(stats.get(2, null))
+      assertResult(20, "Wrong row count")(stats.get(3, null))
+      assertResult(stats.get(4, null), "Wrong size in bytes") {
         rows.map { row =>
           if (row.isNullAt(0)) 4 else columnType.actualSize(row, 0)
         }.sum
@@ -96,11 +96,11 @@ class ColumnStatsSuite extends SparkFunSuite {
       val ordering = columnType.dataType.ordering.asInstanceOf[Ordering[T#InternalType]]
       val stats = columnStats.collectedStatistics
 
-      assertResult(values.min(ordering), "Wrong lower bound")(stats.genericGet(0))
-      assertResult(values.max(ordering), "Wrong upper bound")(stats.genericGet(1))
-      assertResult(10, "Wrong null count")(stats.genericGet(2))
-      assertResult(20, "Wrong row count")(stats.genericGet(3))
-      assertResult(stats.genericGet(4), "Wrong size in bytes") {
+      assertResult(values.min(ordering), "Wrong lower bound")(stats.get(0, null))
+      assertResult(values.max(ordering), "Wrong upper bound")(stats.get(1, null))
+      assertResult(10, "Wrong null count")(stats.get(2, null))
+      assertResult(20, "Wrong row count")(stats.get(3, null))
+      assertResult(stats.get(4, null), "Wrong size in bytes") {
         rows.map { row =>
           if (row.isNullAt(0)) 4 else columnType.actualSize(row, 0)
         }.sum
-- 
GitLab