From 7d29c72f64f8637d8182fb7c495f87ab7ce86ea0 Mon Sep 17 00:00:00 2001
From: Eric Liang <ekl@databricks.com>
Date: Tue, 5 Apr 2016 21:22:20 -0500
Subject: [PATCH] [SPARK-14359] Unit tests for java 8 lambda syntax with typed
 aggregates

## What changes were proposed in this pull request?

Adds unit tests for java 8 lambda syntax with typed aggregates as a follow-up to #12168

## How was this patch tested?

Unit tests.

Author: Eric Liang <ekl@databricks.com>

Closes #12181 from ericl/sc-2794-2.
---
 external/java8-tests/pom.xml                  | 12 +++
 .../sql/Java8DatasetAggregatorSuite.java      | 61 +++++++++++++
 .../sources/JavaDatasetAggregatorSuite.java   | 86 ++++++++++---------
 3 files changed, 118 insertions(+), 41 deletions(-)
 create mode 100644 external/java8-tests/src/test/java/org/apache/spark/sql/Java8DatasetAggregatorSuite.java

diff --git a/external/java8-tests/pom.xml b/external/java8-tests/pom.xml
index f5a06467ee..1ea9196e9d 100644
--- a/external/java8-tests/pom.xml
+++ b/external/java8-tests/pom.xml
@@ -58,6 +58,18 @@
       <type>test-jar</type>
       <scope>test</scope>
     </dependency>
+    <dependency>
+      <groupId>org.apache.spark</groupId>
+      <artifactId>spark-sql_${scala.binary.version}</artifactId>
+      <version>${project.version}</version>
+    </dependency>
+    <dependency>
+      <groupId>org.apache.spark</groupId>
+      <artifactId>spark-sql_${scala.binary.version}</artifactId>
+      <version>${project.version}</version>
+      <type>test-jar</type>
+      <scope>test</scope>
+    </dependency>
     <dependency>
       <groupId>org.apache.spark</groupId>
       <artifactId>spark-test-tags_${scala.binary.version}</artifactId>
diff --git a/external/java8-tests/src/test/java/org/apache/spark/sql/Java8DatasetAggregatorSuite.java b/external/java8-tests/src/test/java/org/apache/spark/sql/Java8DatasetAggregatorSuite.java
new file mode 100644
index 0000000000..23abfa3970
--- /dev/null
+++ b/external/java8-tests/src/test/java/org/apache/spark/sql/Java8DatasetAggregatorSuite.java
@@ -0,0 +1,61 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package test.org.apache.spark.sql.sources;
+
+import java.util.Arrays;
+
+import org.junit.Assert;
+import org.junit.Test;
+import scala.Tuple2;
+
+import org.apache.spark.sql.Dataset;
+import org.apache.spark.sql.KeyValueGroupedDataset;
+import org.apache.spark.sql.expressions.java.typed;
+
+/**
+ * Suite that replicates tests in JavaDatasetAggregatorSuite using lambda syntax.
+ */
+public class Java8DatasetAggregatorSuite extends JavaDatasetAggregatorSuiteBase {
+  @Test
+  public void testTypedAggregationAverage() {
+    KeyValueGroupedDataset<String, Tuple2<String, Integer>> grouped = generateGroupedDataset();
+    Dataset<Tuple2<String, Double>> agged = grouped.agg(typed.avg(v -> (double)(v._2() * 2)));
+    Assert.assertEquals(Arrays.asList(tuple2("a", 3.0), tuple2("b", 6.0)), agged.collectAsList());
+  }
+
+  @Test
+  public void testTypedAggregationCount() {
+    KeyValueGroupedDataset<String, Tuple2<String, Integer>> grouped = generateGroupedDataset();
+    Dataset<Tuple2<String, Long>> agged = grouped.agg(typed.count(v -> v));
+    Assert.assertEquals(Arrays.asList(tuple2("a", 2), tuple2("b", 1)), agged.collectAsList());
+  }
+
+  @Test
+  public void testTypedAggregationSumDouble() {
+    KeyValueGroupedDataset<String, Tuple2<String, Integer>> grouped = generateGroupedDataset();
+    Dataset<Tuple2<String, Double>> agged = grouped.agg(typed.sum(v -> (double)v._2()));
+    Assert.assertEquals(Arrays.asList(tuple2("a", 3.0), tuple2("b", 3.0)), agged.collectAsList());
+  }
+
+  @Test
+  public void testTypedAggregationSumLong() {
+    KeyValueGroupedDataset<String, Tuple2<String, Integer>> grouped = generateGroupedDataset();
+    Dataset<Tuple2<String, Long>> agged = grouped.agg(typed.sumLong(v -> (long)v._2()));
+    Assert.assertEquals(Arrays.asList(tuple2("a", 3), tuple2("b", 3)), agged.collectAsList());
+  }
+}
diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/sources/JavaDatasetAggregatorSuite.java b/sql/core/src/test/java/test/org/apache/spark/sql/sources/JavaDatasetAggregatorSuite.java
index c8d0eecd5c..594f4675bd 100644
--- a/sql/core/src/test/java/test/org/apache/spark/sql/sources/JavaDatasetAggregatorSuite.java
+++ b/sql/core/src/test/java/test/org/apache/spark/sql/sources/JavaDatasetAggregatorSuite.java
@@ -41,46 +41,7 @@ import org.apache.spark.sql.test.TestSQLContext;
 /**
  * Suite for testing the aggregate functionality of Datasets in Java.
  */
-public class JavaDatasetAggregatorSuite implements Serializable {
-  private transient JavaSparkContext jsc;
-  private transient TestSQLContext context;
-
-  @Before
-  public void setUp() {
-    // Trigger static initializer of TestData
-    SparkContext sc = new SparkContext("local[*]", "testing");
-    jsc = new JavaSparkContext(sc);
-    context = new TestSQLContext(sc);
-    context.loadTestData();
-  }
-
-  @After
-  public void tearDown() {
-    context.sparkContext().stop();
-    context = null;
-    jsc = null;
-  }
-
-  private <T1, T2> Tuple2<T1, T2> tuple2(T1 t1, T2 t2) {
-    return new Tuple2<>(t1, t2);
-  }
-
-  private KeyValueGroupedDataset<String, Tuple2<String, Integer>> generateGroupedDataset() {
-    Encoder<Tuple2<String, Integer>> encoder = Encoders.tuple(Encoders.STRING(), Encoders.INT());
-    List<Tuple2<String, Integer>> data =
-      Arrays.asList(tuple2("a", 1), tuple2("a", 2), tuple2("b", 3));
-    Dataset<Tuple2<String, Integer>> ds = context.createDataset(data, encoder);
-
-    return ds.groupByKey(
-      new MapFunction<Tuple2<String, Integer>, String>() {
-        @Override
-        public String call(Tuple2<String, Integer> value) throws Exception {
-          return value._1();
-        }
-      },
-      Encoders.STRING());
-  }
-
+public class JavaDatasetAggregatorSuite extends JavaDatasetAggregatorSuiteBase {
   @Test
   public void testTypedAggregationAnonClass() {
     KeyValueGroupedDataset<String, Tuple2<String, Integer>> grouped = generateGroupedDataset();
@@ -100,7 +61,6 @@ public class JavaDatasetAggregatorSuite implements Serializable {
   }
 
   static class IntSumOf extends Aggregator<Tuple2<String, Integer>, Integer, Integer> {
-
     @Override
     public Integer zero() {
       return 0;
@@ -170,3 +130,47 @@ public class JavaDatasetAggregatorSuite implements Serializable {
     Assert.assertEquals(Arrays.asList(tuple2("a", 3), tuple2("b", 3)), agged.collectAsList());
   }
 }
+
+/**
+ * Common test base shared across this and Java8DatasetAggregatorSuite.
+ */
+class JavaDatasetAggregatorSuiteBase implements Serializable {
+  protected transient JavaSparkContext jsc;
+  protected transient TestSQLContext context;
+
+  @Before
+  public void setUp() {
+    // Trigger static initializer of TestData
+    SparkContext sc = new SparkContext("local[*]", "testing");
+    jsc = new JavaSparkContext(sc);
+    context = new TestSQLContext(sc);
+    context.loadTestData();
+  }
+
+  @After
+  public void tearDown() {
+    context.sparkContext().stop();
+    context = null;
+    jsc = null;
+  }
+
+  protected <T1, T2> Tuple2<T1, T2> tuple2(T1 t1, T2 t2) {
+    return new Tuple2<>(t1, t2);
+  }
+
+  protected KeyValueGroupedDataset<String, Tuple2<String, Integer>> generateGroupedDataset() {
+    Encoder<Tuple2<String, Integer>> encoder = Encoders.tuple(Encoders.STRING(), Encoders.INT());
+    List<Tuple2<String, Integer>> data =
+      Arrays.asList(tuple2("a", 1), tuple2("a", 2), tuple2("b", 3));
+    Dataset<Tuple2<String, Integer>> ds = context.createDataset(data, encoder);
+
+    return ds.groupByKey(
+      new MapFunction<Tuple2<String, Integer>, String>() {
+        @Override
+        public String call(Tuple2<String, Integer> value) throws Exception {
+          return value._1();
+        }
+      },
+      Encoders.STRING());
+  }
+}
-- 
GitLab