Skip to content
Snippets Groups Projects
Commit 7d29c72f authored by Eric Liang's avatar Eric Liang Committed by Reynold Xin
Browse files

[SPARK-14359] Unit tests for java 8 lambda syntax with typed aggregates

## What changes were proposed in this pull request?

Adds unit tests for java 8 lambda syntax with typed aggregates as a follow-up to #12168

## How was this patch tested?

Unit tests.

Author: Eric Liang <ekl@databricks.com>

Closes #12181 from ericl/sc-2794-2.
parent 1146c534
No related branches found
No related tags found
No related merge requests found
......@@ -58,6 +58,18 @@
<type>test-jar</type>
<scope>test</scope>
</dependency>
<dependency>
<groupId>org.apache.spark</groupId>
<artifactId>spark-sql_${scala.binary.version}</artifactId>
<version>${project.version}</version>
</dependency>
<dependency>
<groupId>org.apache.spark</groupId>
<artifactId>spark-sql_${scala.binary.version}</artifactId>
<version>${project.version}</version>
<type>test-jar</type>
<scope>test</scope>
</dependency>
<dependency>
<groupId>org.apache.spark</groupId>
<artifactId>spark-test-tags_${scala.binary.version}</artifactId>
......
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package test.org.apache.spark.sql.sources;
import java.util.Arrays;
import org.junit.Assert;
import org.junit.Test;
import scala.Tuple2;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.KeyValueGroupedDataset;
import org.apache.spark.sql.expressions.java.typed;
/**
 * Suite that replicates tests in JavaDatasetAggregatorSuite using java 8 lambda syntax.
 */
public class Java8DatasetAggregatorSuite extends JavaDatasetAggregatorSuiteBase {
  @Test
  public void testTypedAggregationAverage() {
    KeyValueGroupedDataset<String, Tuple2<String, Integer>> grouped = generateGroupedDataset();
    // Input is ("a",1),("a",2),("b",3); averaging the doubled values gives
    // "a" -> (2+4)/2 = 3.0 and "b" -> 6/1 = 6.0.
    Dataset<Tuple2<String, Double>> agged = grouped.agg(typed.avg(v -> (double)(v._2() * 2)));
    Assert.assertEquals(Arrays.asList(tuple2("a", 3.0), tuple2("b", 6.0)), agged.collectAsList());
  }

  @Test
  public void testTypedAggregationCount() {
    KeyValueGroupedDataset<String, Tuple2<String, Integer>> grouped = generateGroupedDataset();
    Dataset<Tuple2<String, Long>> agged = grouped.agg(typed.count(v -> v));
    // typed.count yields Long values, so the expected tuples use Long literals rather
    // than relying on Scala's cross-type boxed numeric equality (Integer 2 vs Long 2L).
    Assert.assertEquals(Arrays.asList(tuple2("a", 2L), tuple2("b", 1L)), agged.collectAsList());
  }

  @Test
  public void testTypedAggregationSumDouble() {
    KeyValueGroupedDataset<String, Tuple2<String, Integer>> grouped = generateGroupedDataset();
    Dataset<Tuple2<String, Double>> agged = grouped.agg(typed.sum(v -> (double)v._2()));
    Assert.assertEquals(Arrays.asList(tuple2("a", 3.0), tuple2("b", 3.0)), agged.collectAsList());
  }

  @Test
  public void testTypedAggregationSumLong() {
    KeyValueGroupedDataset<String, Tuple2<String, Integer>> grouped = generateGroupedDataset();
    Dataset<Tuple2<String, Long>> agged = grouped.agg(typed.sumLong(v -> (long)v._2()));
    // typed.sumLong yields Long values; use Long literals for a type-consistent expectation.
    Assert.assertEquals(Arrays.asList(tuple2("a", 3L), tuple2("b", 3L)), agged.collectAsList());
  }
}
......@@ -41,46 +41,7 @@ import org.apache.spark.sql.test.TestSQLContext;
/**
* Suite for testing the aggregate functionality of Datasets in Java.
*/
public class JavaDatasetAggregatorSuite implements Serializable {
private transient JavaSparkContext jsc;
private transient TestSQLContext context;
@Before
public void setUp() {
// Trigger static initializer of TestData
SparkContext sc = new SparkContext("local[*]", "testing");
jsc = new JavaSparkContext(sc);
context = new TestSQLContext(sc);
context.loadTestData();
}
@After
public void tearDown() {
context.sparkContext().stop();
context = null;
jsc = null;
}
private <T1, T2> Tuple2<T1, T2> tuple2(T1 t1, T2 t2) {
return new Tuple2<>(t1, t2);
}
private KeyValueGroupedDataset<String, Tuple2<String, Integer>> generateGroupedDataset() {
Encoder<Tuple2<String, Integer>> encoder = Encoders.tuple(Encoders.STRING(), Encoders.INT());
List<Tuple2<String, Integer>> data =
Arrays.asList(tuple2("a", 1), tuple2("a", 2), tuple2("b", 3));
Dataset<Tuple2<String, Integer>> ds = context.createDataset(data, encoder);
return ds.groupByKey(
new MapFunction<Tuple2<String, Integer>, String>() {
@Override
public String call(Tuple2<String, Integer> value) throws Exception {
return value._1();
}
},
Encoders.STRING());
}
public class JavaDatasetAggregatorSuite extends JavaDatasetAggregatorSuiteBase {
@Test
public void testTypedAggregationAnonClass() {
KeyValueGroupedDataset<String, Tuple2<String, Integer>> grouped = generateGroupedDataset();
......@@ -100,7 +61,6 @@ public class JavaDatasetAggregatorSuite implements Serializable {
}
static class IntSumOf extends Aggregator<Tuple2<String, Integer>, Integer, Integer> {
@Override
public Integer zero() {
return 0;
......@@ -170,3 +130,47 @@ public class JavaDatasetAggregatorSuite implements Serializable {
Assert.assertEquals(Arrays.asList(tuple2("a", 3), tuple2("b", 3)), agged.collectAsList());
}
}
/**
 * Common test base shared across this and Java8DatasetAggregatorSuite, providing the
 * Spark/SQL context lifecycle plus helpers for building the grouped test dataset.
 */
class JavaDatasetAggregatorSuiteBase implements Serializable {
  protected transient JavaSparkContext jsc;
  protected transient TestSQLContext context;

  @Before
  public void setUp() {
    // Trigger static initializer of TestData
    SparkContext sparkContext = new SparkContext("local[*]", "testing");
    jsc = new JavaSparkContext(sparkContext);
    context = new TestSQLContext(sparkContext);
    context.loadTestData();
  }

  @After
  public void tearDown() {
    // Stop Spark and drop references so each test starts from a clean slate.
    context.sparkContext().stop();
    context = null;
    jsc = null;
  }

  /** Convenience factory so test expectations read more naturally than {@code new Tuple2<>}. */
  protected <T1, T2> Tuple2<T1, T2> tuple2(T1 t1, T2 t2) {
    return new Tuple2<>(t1, t2);
  }

  /** Builds the rows ("a",1),("a",2),("b",3) grouped by their String key. */
  protected KeyValueGroupedDataset<String, Tuple2<String, Integer>> generateGroupedDataset() {
    Encoder<Tuple2<String, Integer>> rowEncoder =
      Encoders.tuple(Encoders.STRING(), Encoders.INT());
    List<Tuple2<String, Integer>> rows =
      Arrays.asList(tuple2("a", 1), tuple2("a", 2), tuple2("b", 3));
    Dataset<Tuple2<String, Integer>> dataset = context.createDataset(rows, rowEncoder);

    // NOTE(review): kept as an anonymous class rather than a lambda — presumably so this
    // base compiles for the non-Java-8 suite as well; confirm before converting.
    MapFunction<Tuple2<String, Integer>, String> keyFunc =
      new MapFunction<Tuple2<String, Integer>, String>() {
        @Override
        public String call(Tuple2<String, Integer> value) throws Exception {
          return value._1();
        }
      };
    return dataset.groupByKey(keyFunc, Encoders.STRING());
  }
}
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment