From 3e2e1873b2762d07e49de8f9ea709bf3fa2d171c Mon Sep 17 00:00:00 2001
From: Yin Huai <yhuai@databricks.com>
Date: Sun, 15 Nov 2015 13:59:59 -0800
Subject: [PATCH] [SPARK-11738] [SQL] Making ArrayType orderable

https://issues.apache.org/jira/browse/SPARK-11738

Author: Yin Huai <yhuai@databricks.com>

Closes #9718 from yhuai/makingArrayOrderable.
---
 .../sql/catalyst/analysis/CheckAnalysis.scala |  32 +----
 .../expressions/codegen/CodeGenerator.scala   |  43 ++++++
 .../expressions/collectionOperations.scala    |   2 +
 .../sql/catalyst/expressions/ordering.scala   |   6 +
 .../spark/sql/catalyst/util/TypeUtils.scala   |   1 +
 .../spark/sql/types/AbstractDataType.scala    |   1 +
 .../apache/spark/sql/types/ArrayType.scala    |  48 +++++++
 .../analysis/AnalysisErrorSuite.scala         |  37 ++++--
 .../ExpressionTypeCheckingSuite.scala         |  32 ++---
 .../sql/catalyst/analysis/TestRelations.scala |   3 +
 .../expressions/CodeGenerationSuite.scala     |  36 -----
 .../catalyst/expressions/OrderingSuite.scala  | 124 ++++++++++++++++++
 .../spark/sql/DataFrameFunctionsSuite.scala   |  12 +-
 .../execution/AggregationQuerySuite.scala     |  52 ++++++++
 14 files changed, 335 insertions(+), 94 deletions(-)
 create mode 100644 sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/OrderingSuite.scala

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala
index 5a4b0c1e39..7b2c93d63d 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala
@@ -137,32 +137,14 @@ trait CheckAnalysis {
               case e => e.children.foreach(checkValidAggregateExpression)
             }
 
-            def checkSupportedGroupingDataType(
-                expressionString: String,
-                dataType: DataType): Unit = dataType match {
-              case BinaryType =>
-                failAnalysis(s"expression $expressionString cannot be used in " +
-                  s"grouping expression because it is in binary type or its inner field is " +
-                  s"in binary type")
-              case a: ArrayType =>
-                failAnalysis(s"expression $expressionString cannot be used in " +
-                  s"grouping expression because it is in array type or its inner field is " +
-                  s"in array type")
-              case m: MapType =>
-                failAnalysis(s"expression $expressionString cannot be used in " +
-                  s"grouping expression because it is in map type or its inner field is " +
-                  s"in map type")
-              case s: StructType =>
-                s.fields.foreach { f =>
-                  checkSupportedGroupingDataType(expressionString, f.dataType)
-                }
-              case udt: UserDefinedType[_] =>
-                checkSupportedGroupingDataType(expressionString, udt.sqlType)
-              case _ => // OK
-            }
-
             def checkValidGroupingExprs(expr: Expression): Unit = {
-              checkSupportedGroupingDataType(expr.prettyString, expr.dataType)
+              // Check if the data type of expr is orderable.
+              if (!RowOrdering.isOrderable(expr.dataType)) {
+                failAnalysis(
+                  s"expression ${expr.prettyString} cannot be used as a grouping expression " +
+                    s"because its data type ${expr.dataType.simpleString} is not a orderable " +
+                    s"data type.")
+              }
 
               if (!expr.deterministic) {
                 // This is just a sanity check, our analysis rule PullOutNondeterministic should
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala
index ccd91d3549..1718cfbd35 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala
@@ -267,6 +267,49 @@ class CodeGenContext {
     case dt: DataType if isPrimitiveType(dt) => s"($c1 > $c2 ? 1 : $c1 < $c2 ? -1 : 0)"
     case BinaryType => s"org.apache.spark.sql.catalyst.util.TypeUtils.compareBinary($c1, $c2)"
     case NullType => "0"
+    case array: ArrayType =>
+      val elementType = array.elementType
+      val elementA = freshName("elementA")
+      val isNullA = freshName("isNullA")
+      val elementB = freshName("elementB")
+      val isNullB = freshName("isNullB")
+      val compareFunc = freshName("compareArray")
+      val minLength = freshName("minLength")
+      val funcCode: String =
+        s"""
+          public int $compareFunc(ArrayData a, ArrayData b) {
+            int lengthA = a.numElements();
+            int lengthB = b.numElements();
+            int $minLength = (lengthA > lengthB) ? lengthB : lengthA;
+            for (int i = 0; i < $minLength; i++) {
+              boolean $isNullA = a.isNullAt(i);
+              boolean $isNullB = b.isNullAt(i);
+              if ($isNullA && $isNullB) {
+                // Nothing
+              } else if ($isNullA) {
+                return -1;
+              } else if ($isNullB) {
+                return 1;
+              } else {
+                ${javaType(elementType)} $elementA = ${getValue("a", elementType, "i")};
+                ${javaType(elementType)} $elementB = ${getValue("b", elementType, "i")};
+                int comp = ${genComp(elementType, elementA, elementB)};
+                if (comp != 0) {
+                  return comp;
+                }
+              }
+            }
+
+            if (lengthA < lengthB) {
+              return -1;
+            } else if (lengthA > lengthB) {
+              return 1;
+            }
+            return 0;
+          }
+        """
+      addNewFunction(compareFunc, funcCode)
+      s"this.$compareFunc($c1, $c2)"
     case schema: StructType =>
       val comparisons = GenerateOrdering.genComparisons(this, schema)
       val compareFunc = freshName("compareStruct")
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala
index 2cf19b939f..741ad1f3ef 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala
@@ -68,6 +68,7 @@ case class SortArray(base: Expression, ascendingOrder: Expression)
   private lazy val lt: Comparator[Any] = {
     val ordering = base.dataType match {
       case _ @ ArrayType(n: AtomicType, _) => n.ordering.asInstanceOf[Ordering[Any]]
+      case _ @ ArrayType(a: ArrayType, _) => a.interpretedOrdering.asInstanceOf[Ordering[Any]]
       case _ @ ArrayType(s: StructType, _) => s.interpretedOrdering.asInstanceOf[Ordering[Any]]
     }
 
@@ -90,6 +91,7 @@ case class SortArray(base: Expression, ascendingOrder: Expression)
   private lazy val gt: Comparator[Any] = {
     val ordering = base.dataType match {
       case _ @ ArrayType(n: AtomicType, _) => n.ordering.asInstanceOf[Ordering[Any]]
+      case _ @ ArrayType(a: ArrayType, _) => a.interpretedOrdering.asInstanceOf[Ordering[Any]]
       case _ @ ArrayType(s: StructType, _) => s.interpretedOrdering.asInstanceOf[Ordering[Any]]
     }
 
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ordering.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ordering.scala
index 6407c73bc9..6112259fed 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ordering.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ordering.scala
@@ -48,6 +48,10 @@ class InterpretedOrdering(ordering: Seq[SortOrder]) extends Ordering[InternalRow
             dt.ordering.asInstanceOf[Ordering[Any]].compare(left, right)
           case dt: AtomicType if order.direction == Descending =>
             dt.ordering.asInstanceOf[Ordering[Any]].reverse.compare(left, right)
+          case a: ArrayType if order.direction == Ascending =>
+            a.interpretedOrdering.asInstanceOf[Ordering[Any]].compare(left, right)
+          case a: ArrayType if order.direction == Descending =>
+            a.interpretedOrdering.asInstanceOf[Ordering[Any]].reverse.compare(left, right)
           case s: StructType if order.direction == Ascending =>
             s.interpretedOrdering.asInstanceOf[Ordering[Any]].compare(left, right)
           case s: StructType if order.direction == Descending =>
@@ -86,6 +90,8 @@ object RowOrdering {
     case NullType => true
     case dt: AtomicType => true
     case struct: StructType => struct.fields.forall(f => isOrderable(f.dataType))
+    case array: ArrayType => isOrderable(array.elementType)
+    case udt: UserDefinedType[_] => isOrderable(udt.sqlType)
     case _ => false
   }
 
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/TypeUtils.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/TypeUtils.scala
index bcf4d78fb9..f603cbfb0c 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/TypeUtils.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/TypeUtils.scala
@@ -57,6 +57,7 @@ object TypeUtils {
   def getInterpretedOrdering(t: DataType): Ordering[Any] = {
     t match {
       case i: AtomicType => i.ordering.asInstanceOf[Ordering[Any]]
+      case a: ArrayType => a.interpretedOrdering.asInstanceOf[Ordering[Any]]
       case s: StructType => s.interpretedOrdering.asInstanceOf[Ordering[Any]]
     }
   }
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/AbstractDataType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/AbstractDataType.scala
index 1d2d007c2b..a5ae8bb0e5 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/AbstractDataType.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/AbstractDataType.scala
@@ -84,6 +84,7 @@ private[sql] object TypeCollection {
    * Types that can be ordered/compared. In the long run we should probably make this a trait
    * that can be mixed into each data type, and perhaps create an [[AbstractDataType]].
    */
+  // TODO: Should we consolidate this with RowOrdering.isOrderable?
   val Ordered = TypeCollection(
     BooleanType,
     ByteType, ShortType, IntegerType, LongType,
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/ArrayType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/ArrayType.scala
index 5770f59b53..a001eadcc6 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/ArrayType.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/ArrayType.scala
@@ -17,10 +17,13 @@
 
 package org.apache.spark.sql.types
 
+import org.apache.spark.sql.catalyst.util.ArrayData
 import org.json4s.JsonDSL._
 
 import org.apache.spark.annotation.DeveloperApi
 
+import scala.math.Ordering
+
 
 object ArrayType extends AbstractDataType {
   /** Construct a [[ArrayType]] object with the given element type. The `containsNull` is true. */
@@ -81,4 +84,49 @@ case class ArrayType(elementType: DataType, containsNull: Boolean) extends DataT
   override private[spark] def existsRecursively(f: (DataType) => Boolean): Boolean = {
     f(this) || elementType.existsRecursively(f)
   }
+
+  @transient
+  private[sql] lazy val interpretedOrdering: Ordering[ArrayData] = new Ordering[ArrayData] {
+    private[this] val elementOrdering: Ordering[Any] = elementType match {
+      case dt: AtomicType => dt.ordering.asInstanceOf[Ordering[Any]]
+      case a : ArrayType => a.interpretedOrdering.asInstanceOf[Ordering[Any]]
+      case s: StructType => s.interpretedOrdering.asInstanceOf[Ordering[Any]]
+      case other =>
+        throw new IllegalArgumentException(s"Type $other does not support ordered operations")
+    }
+
+    def compare(x: ArrayData, y: ArrayData): Int = {
+      val leftArray = x
+      val rightArray = y
+      val minLength = scala.math.min(leftArray.numElements(), rightArray.numElements())
+      var i = 0
+      while (i < minLength) {
+        val isNullLeft = leftArray.isNullAt(i)
+        val isNullRight = rightArray.isNullAt(i)
+        if (isNullLeft && isNullRight) {
+          // Do nothing.
+        } else if (isNullLeft) {
+          return -1
+        } else if (isNullRight) {
+          return 1
+        } else {
+          val comp =
+            elementOrdering.compare(
+              leftArray.get(i, elementType),
+              rightArray.get(i, elementType))
+          if (comp != 0) {
+            return comp
+          }
+        }
+        i += 1
+      }
+      if (leftArray.numElements() < rightArray.numElements()) {
+        return -1
+      } else if (leftArray.numElements() > rightArray.numElements()) {
+        return 1
+      } else {
+        return 0
+      }
+    }
+  }
 }
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala
index 2e7c3bd67b..ee43557874 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala
@@ -23,7 +23,7 @@ import org.apache.spark.sql.catalyst.plans.logical._
 import org.apache.spark.sql.catalyst.plans.Inner
 import org.apache.spark.sql.catalyst.dsl.expressions._
 import org.apache.spark.sql.catalyst.dsl.plans._
-import org.apache.spark.sql.catalyst.util.{GenericArrayData, ArrayData}
+import org.apache.spark.sql.catalyst.util.{MapData, ArrayBasedMapData, GenericArrayData, ArrayData}
 import org.apache.spark.sql.types._
 
 import scala.beans.{BeanProperty, BeanInfo}
@@ -53,21 +53,29 @@ private[sql] class GroupableUDT extends UserDefinedType[GroupableData] {
 }
 
 @BeanInfo
-private[sql] case class UngroupableData(@BeanProperty data: Array[Int])
+private[sql] case class UngroupableData(@BeanProperty data: Map[Int, Int])
 
 private[sql] class UngroupableUDT extends UserDefinedType[UngroupableData] {
 
-  override def sqlType: DataType = ArrayType(IntegerType)
+  override def sqlType: DataType = MapType(IntegerType, IntegerType)
 
-  override def serialize(obj: Any): ArrayData = {
+  override def serialize(obj: Any): MapData = {
     obj match {
-      case groupableData: UngroupableData => new GenericArrayData(groupableData.data)
+      case groupableData: UngroupableData =>
+        val keyArray = new GenericArrayData(groupableData.data.keys.toSeq)
+        val valueArray = new GenericArrayData(groupableData.data.values.toSeq)
+        new ArrayBasedMapData(keyArray, valueArray)
     }
   }
 
   override def deserialize(datum: Any): UngroupableData = {
     datum match {
-      case data: Array[Int] => UngroupableData(data)
+      case data: MapData =>
+        val keyArray = data.keyArray().array
+        val valueArray = data.valueArray().array
+        assert(keyArray.length == valueArray.length)
+        val mapData = keyArray.zip(valueArray).toMap.asInstanceOf[Map[Int, Int]]
+        UngroupableData(mapData)
     }
   }
 
@@ -154,8 +162,8 @@ class AnalysisErrorSuite extends AnalysisTest {
 
   errorTest(
     "sorting by unsupported column types",
-    listRelation.orderBy('list.asc),
-    "sort" :: "type" :: "array<int>" :: Nil)
+    mapRelation.orderBy('map.asc),
+    "sort" :: "type" :: "map<int,int>" :: Nil)
 
   errorTest(
     "non-boolean filters",
@@ -259,32 +267,33 @@ class AnalysisErrorSuite extends AnalysisTest {
         case true =>
           assertAnalysisSuccess(plan, true)
         case false =>
-          assertAnalysisError(plan, "expression a cannot be used in grouping expression" :: Nil)
+          assertAnalysisError(plan, "expression a cannot be used as a grouping expression" :: Nil)
       }
-
     }
 
     val supportedDataTypes = Seq(
-      StringType,
+      StringType, BinaryType,
       NullType, BooleanType,
       ByteType, ShortType, IntegerType, LongType,
       FloatType, DoubleType, DecimalType(25, 5), DecimalType(6, 5),
       DateType, TimestampType,
+      ArrayType(IntegerType),
       new StructType()
         .add("f1", FloatType, nullable = true)
         .add("f2", StringType, nullable = true),
+      new StructType()
+        .add("f1", FloatType, nullable = true)
+        .add("f2", ArrayType(BooleanType, containsNull = true), nullable = true),
       new GroupableUDT())
     supportedDataTypes.foreach { dataType =>
       checkDataType(dataType, shouldSuccess = true)
     }
 
     val unsupportedDataTypes = Seq(
-      BinaryType,
-      ArrayType(IntegerType),
       MapType(StringType, LongType),
       new StructType()
         .add("f1", FloatType, nullable = true)
-        .add("f2", ArrayType(BooleanType, containsNull = true), nullable = true),
+        .add("f2", MapType(StringType, LongType), nullable = true),
       new UngroupableUDT())
     unsupportedDataTypes.foreach { dataType =>
       checkDataType(dataType, shouldSuccess = false)
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ExpressionTypeCheckingSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ExpressionTypeCheckingSuite.scala
index b902982add..ba1866efc8 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ExpressionTypeCheckingSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ExpressionTypeCheckingSuite.scala
@@ -24,7 +24,7 @@ import org.apache.spark.sql.catalyst.dsl.plans._
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.expressions.aggregate._
 import org.apache.spark.sql.catalyst.plans.logical.LocalRelation
-import org.apache.spark.sql.types.{TypeCollection, StringType}
+import org.apache.spark.sql.types.{LongType, TypeCollection, StringType}
 
 class ExpressionTypeCheckingSuite extends SparkFunSuite {
 
@@ -32,7 +32,8 @@ class ExpressionTypeCheckingSuite extends SparkFunSuite {
     'intField.int,
     'stringField.string,
     'booleanField.boolean,
-    'complexField.array(StringType))
+    'arrayField.array(StringType),
+    'mapField.map(StringType, LongType))
 
   def assertError(expr: Expression, errorMessage: String): Unit = {
     val e = intercept[AnalysisException] {
@@ -90,9 +91,9 @@ class ExpressionTypeCheckingSuite extends SparkFunSuite {
     assertError(BitwiseOr('booleanField, 'booleanField), "requires integral type")
     assertError(BitwiseXor('booleanField, 'booleanField), "requires integral type")
 
-    assertError(MaxOf('complexField, 'complexField),
+    assertError(MaxOf('mapField, 'mapField),
       s"requires ${TypeCollection.Ordered.simpleString} type")
-    assertError(MinOf('complexField, 'complexField),
+    assertError(MinOf('mapField, 'mapField),
       s"requires ${TypeCollection.Ordered.simpleString} type")
   }
 
@@ -109,20 +110,20 @@ class ExpressionTypeCheckingSuite extends SparkFunSuite {
     assertSuccess(EqualTo('intField, 'booleanField))
     assertSuccess(EqualNullSafe('intField, 'booleanField))
 
-    assertErrorForDifferingTypes(EqualTo('intField, 'complexField))
-    assertErrorForDifferingTypes(EqualNullSafe('intField, 'complexField))
+    assertErrorForDifferingTypes(EqualTo('intField, 'mapField))
+    assertErrorForDifferingTypes(EqualNullSafe('intField, 'mapField))
     assertErrorForDifferingTypes(LessThan('intField, 'booleanField))
     assertErrorForDifferingTypes(LessThanOrEqual('intField, 'booleanField))
     assertErrorForDifferingTypes(GreaterThan('intField, 'booleanField))
     assertErrorForDifferingTypes(GreaterThanOrEqual('intField, 'booleanField))
 
-    assertError(LessThan('complexField, 'complexField),
+    assertError(LessThan('mapField, 'mapField),
       s"requires ${TypeCollection.Ordered.simpleString} type")
-    assertError(LessThanOrEqual('complexField, 'complexField),
+    assertError(LessThanOrEqual('mapField, 'mapField),
       s"requires ${TypeCollection.Ordered.simpleString} type")
-    assertError(GreaterThan('complexField, 'complexField),
+    assertError(GreaterThan('mapField, 'mapField),
       s"requires ${TypeCollection.Ordered.simpleString} type")
-    assertError(GreaterThanOrEqual('complexField, 'complexField),
+    assertError(GreaterThanOrEqual('mapField, 'mapField),
       s"requires ${TypeCollection.Ordered.simpleString} type")
 
     assertError(If('intField, 'stringField, 'stringField),
@@ -130,10 +131,10 @@ class ExpressionTypeCheckingSuite extends SparkFunSuite {
     assertErrorForDifferingTypes(If('booleanField, 'intField, 'booleanField))
 
     assertError(
-      CaseWhen(Seq('booleanField, 'intField, 'booleanField, 'complexField)),
+      CaseWhen(Seq('booleanField, 'intField, 'booleanField, 'mapField)),
       "THEN and ELSE expressions should all be same type or coercible to a common type")
     assertError(
-      CaseKeyWhen('intField, Seq('intField, 'stringField, 'intField, 'complexField)),
+      CaseKeyWhen('intField, Seq('intField, 'stringField, 'intField, 'mapField)),
       "THEN and ELSE expressions should all be same type or coercible to a common type")
     assertError(
       CaseWhen(Seq('booleanField, 'intField, 'intField, 'intField)),
@@ -147,9 +148,10 @@ class ExpressionTypeCheckingSuite extends SparkFunSuite {
     // We will cast String to Double for sum and average
     assertSuccess(Sum('stringField))
     assertSuccess(Average('stringField))
+    assertSuccess(Min('arrayField))
 
-    assertError(Min('complexField), "min does not support ordering on type")
-    assertError(Max('complexField), "max does not support ordering on type")
+    assertError(Min('mapField), "min does not support ordering on type")
+    assertError(Max('mapField), "max does not support ordering on type")
     assertError(Sum('booleanField), "function sum requires numeric type")
     assertError(Average('booleanField), "function average requires numeric type")
   }
@@ -184,7 +186,7 @@ class ExpressionTypeCheckingSuite extends SparkFunSuite {
 
     assertError(Round('intField, 'intField), "Only foldable Expression is allowed")
     assertError(Round('intField, 'booleanField), "requires int type")
-    assertError(Round('intField, 'complexField), "requires int type")
+    assertError(Round('intField, 'mapField), "requires int type")
     assertError(Round('booleanField, 'intField), "requires numeric type")
   }
 }
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/TestRelations.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/TestRelations.scala
index 05b870705e..bc07b609a3 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/TestRelations.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/TestRelations.scala
@@ -48,4 +48,7 @@ object TestRelations {
 
   val listRelation = LocalRelation(
     AttributeReference("list", ArrayType(IntegerType))())
+
+  val mapRelation = LocalRelation(
+    AttributeReference("map", MapType(IntegerType, IntegerType))())
 }
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CodeGenerationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CodeGenerationSuite.scala
index e323467af5..002ed16dcf 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CodeGenerationSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CodeGenerationSuite.scala
@@ -17,8 +17,6 @@
 
 package org.apache.spark.sql.catalyst.expressions
 
-import scala.math._
-
 import org.apache.spark.SparkFunSuite
 import org.apache.spark.sql.{Row, RandomDataGenerator}
 import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow}
@@ -49,40 +47,6 @@ class CodeGenerationSuite extends SparkFunSuite with ExpressionEvalHelper {
     futures.foreach(Await.result(_, 10.seconds))
   }
 
-  // Test GenerateOrdering for all common types. For each type, we construct random input rows that
-  // contain two columns of that type, then for pairs of randomly-generated rows we check that
-  // GenerateOrdering agrees with RowOrdering.
-  (DataTypeTestUtils.atomicTypes ++ Set(NullType)).foreach { dataType =>
-    test(s"GenerateOrdering with $dataType") {
-      val rowOrdering = InterpretedOrdering.forSchema(Seq(dataType, dataType))
-      val genOrdering = GenerateOrdering.generate(
-        BoundReference(0, dataType, nullable = true).asc ::
-          BoundReference(1, dataType, nullable = true).asc :: Nil)
-      val rowType = StructType(
-        StructField("a", dataType, nullable = true) ::
-          StructField("b", dataType, nullable = true) :: Nil)
-      val maybeDataGenerator = RandomDataGenerator.forType(rowType, nullable = false)
-      assume(maybeDataGenerator.isDefined)
-      val randGenerator = maybeDataGenerator.get
-      val toCatalyst = CatalystTypeConverters.createToCatalystConverter(rowType)
-      for (_ <- 1 to 50) {
-        val a = toCatalyst(randGenerator()).asInstanceOf[InternalRow]
-        val b = toCatalyst(randGenerator()).asInstanceOf[InternalRow]
-        withClue(s"a = $a, b = $b") {
-          assert(genOrdering.compare(a, a) === 0)
-          assert(genOrdering.compare(b, b) === 0)
-          assert(rowOrdering.compare(a, a) === 0)
-          assert(rowOrdering.compare(b, b) === 0)
-          assert(signum(genOrdering.compare(a, b)) === -1 * signum(genOrdering.compare(b, a)))
-          assert(signum(rowOrdering.compare(a, b)) === -1 * signum(rowOrdering.compare(b, a)))
-          assert(
-            signum(rowOrdering.compare(a, b)) === signum(genOrdering.compare(a, b)),
-            "Generated and non-generated orderings should agree")
-        }
-      }
-    }
-  }
-
   test("SPARK-8443: split wide projections into blocks due to JVM code size limit") {
     val length = 5000
     val expressions = List.fill(length)(EqualTo(Literal(1), Literal(1)))
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/OrderingSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/OrderingSuite.scala
new file mode 100644
index 0000000000..7ad8657bde
--- /dev/null
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/OrderingSuite.scala
@@ -0,0 +1,124 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.catalyst.expressions
+
+import scala.math._
+
+import org.apache.spark.SparkFunSuite
+import org.apache.spark.sql.{Row, RandomDataGenerator}
+import org.apache.spark.sql.catalyst.{InternalRow, CatalystTypeConverters}
+import org.apache.spark.sql.catalyst.dsl.expressions._
+import org.apache.spark.sql.catalyst.expressions.codegen.GenerateOrdering
+import org.apache.spark.sql.types._
+
+class OrderingSuite extends SparkFunSuite with ExpressionEvalHelper {
+
+  def compareArrays(a: Seq[Any], b: Seq[Any], expected: Int): Unit = {
+    test(s"compare two arrays: a = $a, b = $b") {
+      val dataType = ArrayType(IntegerType)
+      val rowType = StructType(StructField("array", dataType, nullable = true) :: Nil)
+      val toCatalyst = CatalystTypeConverters.createToCatalystConverter(rowType)
+      val rowA = toCatalyst(Row(a)).asInstanceOf[InternalRow]
+      val rowB = toCatalyst(Row(b)).asInstanceOf[InternalRow]
+      Seq(Ascending, Descending).foreach { direction =>
+        val sortOrder = direction match {
+          case Ascending => BoundReference(0, dataType, nullable = true).asc
+          case Descending => BoundReference(0, dataType, nullable = true).desc
+        }
+        val expectedCompareResult = direction match {
+          case Ascending => signum(expected)
+          case Descending => -1 * signum(expected)
+        }
+        val intOrdering = new InterpretedOrdering(sortOrder :: Nil)
+        val genOrdering = GenerateOrdering.generate(sortOrder :: Nil)
+        Seq(intOrdering, genOrdering).foreach { ordering =>
+          assert(ordering.compare(rowA, rowA) === 0)
+          assert(ordering.compare(rowB, rowB) === 0)
+          assert(signum(ordering.compare(rowA, rowB)) === expectedCompareResult)
+          assert(signum(ordering.compare(rowB, rowA)) === -1 * expectedCompareResult)
+        }
+      }
+    }
+  }
+
+  // Two arrays have the same size.
+  compareArrays(Seq[Any](), Seq[Any](), 0)
+  compareArrays(Seq[Any](1), Seq[Any](1), 0)
+  compareArrays(Seq[Any](1, 2), Seq[Any](1, 2), 0)
+  compareArrays(Seq[Any](1, 2, 2), Seq[Any](1, 2, 3), -1)
+
+  // Two arrays have different sizes.
+  compareArrays(Seq[Any](), Seq[Any](1), -1)
+  compareArrays(Seq[Any](1, 2, 3), Seq[Any](1, 2, 3, 4), -1)
+  compareArrays(Seq[Any](1, 2, 3), Seq[Any](1, 2, 3, 2), -1)
+  compareArrays(Seq[Any](1, 2, 3), Seq[Any](1, 2, 2, 2), 1)
+
+  // Arrays having nulls.
+  compareArrays(Seq[Any](1, 2, 3), Seq[Any](1, 2, 3, null), -1)
+  compareArrays(Seq[Any](), Seq[Any](null), -1)
+  compareArrays(Seq[Any](null), Seq[Any](null), 0)
+  compareArrays(Seq[Any](null, null), Seq[Any](null, null), 0)
+  compareArrays(Seq[Any](null), Seq[Any](null, null), -1)
+  compareArrays(Seq[Any](null), Seq[Any](1), -1)
+  compareArrays(Seq[Any](null), Seq[Any](null, 1), -1)
+  compareArrays(Seq[Any](null, 1), Seq[Any](1, 1), -1)
+  compareArrays(Seq[Any](1, null, 1), Seq[Any](1, null, 1), 0)
+  compareArrays(Seq[Any](1, null, 1), Seq[Any](1, null, 2), -1)
+
+  // Test GenerateOrdering for all common types. For each type, we construct random input rows that
+  // contain two columns of that type, then for pairs of randomly-generated rows we check that
+  // GenerateOrdering agrees with RowOrdering.
+  {
+    val structType =
+      new StructType()
+        .add("f1", FloatType, nullable = true)
+        .add("f2", ArrayType(BooleanType, containsNull = true), nullable = true)
+    val arrayOfStructType = ArrayType(structType)
+    val complexTypes = ArrayType(IntegerType) :: structType :: arrayOfStructType :: Nil
+    (DataTypeTestUtils.atomicTypes ++ complexTypes ++ Set(NullType)).foreach { dataType =>
+      test(s"GenerateOrdering with $dataType") {
+        val rowOrdering = InterpretedOrdering.forSchema(Seq(dataType, dataType))
+        val genOrdering = GenerateOrdering.generate(
+          BoundReference(0, dataType, nullable = true).asc ::
+            BoundReference(1, dataType, nullable = true).asc :: Nil)
+        val rowType = StructType(
+          StructField("a", dataType, nullable = true) ::
+            StructField("b", dataType, nullable = true) :: Nil)
+        val maybeDataGenerator = RandomDataGenerator.forType(rowType, nullable = false)
+        assume(maybeDataGenerator.isDefined)
+        val randGenerator = maybeDataGenerator.get
+        val toCatalyst = CatalystTypeConverters.createToCatalystConverter(rowType)
+        for (_ <- 1 to 50) {
+          val a = toCatalyst(randGenerator()).asInstanceOf[InternalRow]
+          val b = toCatalyst(randGenerator()).asInstanceOf[InternalRow]
+          withClue(s"a = $a, b = $b") {
+            assert(genOrdering.compare(a, a) === 0)
+            assert(genOrdering.compare(b, b) === 0)
+            assert(rowOrdering.compare(a, a) === 0)
+            assert(rowOrdering.compare(b, b) === 0)
+            assert(signum(genOrdering.compare(a, b)) === -1 * signum(genOrdering.compare(b, a)))
+            assert(signum(rowOrdering.compare(a, b)) === -1 * signum(rowOrdering.compare(b, a)))
+            assert(
+              signum(rowOrdering.compare(a, b)) === signum(genOrdering.compare(a, b)),
+              "Generated and non-generated orderings should agree")
+          }
+        }
+      }
+    }
+  }
+}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala
index 3a3f19af14..aff9efe4b2 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala
@@ -308,10 +308,14 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSQLContext {
         Row(null, null))
     )
 
-    val df2 = Seq((Array[Array[Int]](Array(2)), "x")).toDF("a", "b")
-    assert(intercept[AnalysisException] {
-      df2.selectExpr("sort_array(a)").collect()
-    }.getMessage().contains("does not support sorting array of type array<int>"))
+    val df2 = Seq((Array[Array[Int]](Array(2), Array(1), Array(2, 4), null), "x")).toDF("a", "b")
+    checkAnswer(
+      df2.selectExpr("sort_array(a, true)", "sort_array(a, false)"),
+      Seq(
+        Row(
+          Seq[Seq[Int]](null, Seq(1), Seq(2), Seq(2, 4)),
+          Seq[Seq[Int]](Seq(2, 4), Seq(2), Seq(1), null)))
+    )
 
     val df3 = Seq(("xxx", "x")).toDF("a", "b")
     assert(intercept[AnalysisException] {
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/AggregationQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/AggregationQuerySuite.scala
index 61e3e913c2..6dde79f74d 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/AggregationQuerySuite.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/AggregationQuerySuite.scala
@@ -132,6 +132,22 @@ abstract class AggregationQuerySuite extends QueryTest with SQLTestUtils with Te
       (3, null, null)).toDF("key", "value1", "value2")
     data2.write.saveAsTable("agg2")
 
+    val data3 = Seq[(Seq[Integer], Integer, Integer)](
+      (Seq[Integer](1, 1), 10, -10),
+      (Seq[Integer](null), -60, 60),
+      (Seq[Integer](1, 1), 30, -30),
+      (Seq[Integer](1), 30, 30),
+      (Seq[Integer](2), 1, 1),
+      (null, -10, 10),
+      (Seq[Integer](2, 3), -1, null),
+      (Seq[Integer](2, 3), 1, 1),
+      (Seq[Integer](2, 3, 4), null, 1),
+      (Seq[Integer](null), 100, -10),
+      (Seq[Integer](3), null, 3),
+      (null, null, null),
+      (Seq[Integer](3), null, null)).toDF("key", "value1", "value2")
+    data3.write.saveAsTable("agg3")
+
     val emptyDF = sqlContext.createDataFrame(
       sparkContext.emptyRDD[Row],
       StructType(StructField("key", StringType) :: StructField("value", IntegerType) :: Nil))
@@ -146,6 +162,7 @@ abstract class AggregationQuerySuite extends QueryTest with SQLTestUtils with Te
   override def afterAll(): Unit = {
     sqlContext.sql("DROP TABLE IF EXISTS agg1")
     sqlContext.sql("DROP TABLE IF EXISTS agg2")
+    sqlContext.sql("DROP TABLE IF EXISTS agg3")
     sqlContext.dropTempTable("emptyTable")
   }
 
@@ -266,6 +283,41 @@ abstract class AggregationQuerySuite extends QueryTest with SQLTestUtils with Te
         Row(100, null) ::
         Row(null, 3) ::
         Row(null, null) :: Nil)
+
+    checkAnswer(
+      sqlContext.sql(
+        """
+          |SELECT DISTINCT key
+          |FROM agg3
+        """.stripMargin),
+      Row(Seq[Integer](1, 1)) ::
+        Row(Seq[Integer](null)) ::
+        Row(Seq[Integer](1)) ::
+        Row(Seq[Integer](2)) ::
+        Row(null) ::
+        Row(Seq[Integer](2, 3)) ::
+        Row(Seq[Integer](2, 3, 4)) ::
+        Row(Seq[Integer](3)) :: Nil)
+
+    checkAnswer(
+      sqlContext.sql(
+        """
+          |SELECT value1, key
+          |FROM agg3
+          |GROUP BY value1, key
+        """.stripMargin),
+      Row(10, Seq[Integer](1, 1)) ::
+        Row(-60, Seq[Integer](null)) ::
+        Row(30, Seq[Integer](1, 1)) ::
+        Row(30, Seq[Integer](1)) ::
+        Row(1, Seq[Integer](2)) ::
+        Row(-10, null) ::
+        Row(-1, Seq[Integer](2, 3)) ::
+        Row(1, Seq[Integer](2, 3)) ::
+        Row(null, Seq[Integer](2, 3, 4)) ::
+        Row(100, Seq[Integer](null)) ::
+        Row(null, Seq[Integer](3)) ::
+        Row(null, null) :: Nil)
   }
 
   test("case in-sensitive resolution") {
-- 
GitLab