From 0e36ba6212bc24b3185e385914fbf2d62cbfb6da Mon Sep 17 00:00:00 2001
From: WeichenXu <weichen.xu@databricks.com>
Date: Tue, 12 Dec 2017 21:28:24 -0800
Subject: [PATCH] [SPARK-22644][ML][TEST] Make ML testsuite support StructuredStreaming test

## What changes were proposed in this pull request?

We need helper code that makes it easier to test ML transformers and models against streaming data. These tests should help us catch any remaining streaming issues, and future PRs can adopt them to keep new Models and Transformers from introducing new ones.

I add an `MLTest` trait which extends the `StreamTest` trait and overrides `createSparkSession`, so an ML test suite only needs to extend `MLTest` to get both the ML and streaming test util functions. For this first-pass review, I only convert one test case in `LinearRegressionSuite`.

Link to #19746
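For reviewers, the sketch below shows roughly how a suite is expected to use the new helper. It is illustrative only and not part of this patch: the suite and test names are made up, and the `StringIndexer` setup simply mirrors the check in the new `MLTestSuite`. `testTransformer` applies the check function to every output row, first through a `MemoryStream`-backed streaming query and then through the plain batch `DataFrame` path:

```scala
import org.apache.spark.ml.feature.StringIndexer
import org.apache.spark.ml.util.MLTest
import org.apache.spark.sql.Row

class MyTransformerSuite extends MLTest {

  import testImplicits._

  test("StringIndexer output matches on batch and streaming input") {
    val data = Seq((0, "a"), (1, "b"), (2, "c")).toDF("id", "label")
    val model = new StringIndexer()
      .setStringOrderType("alphabetAsc")
      .setInputCol("label")
      .setOutputCol("indexed")
      .fit(data)
    // The check function runs once per output row, for both the streaming
    // (MemoryStream) execution and the batch DataFrame execution.
    testTransformer[(Int, String)](data, model, "id", "indexed") {
      case Row(id: Int, indexed: Double) =>
        assert(id === indexed.toInt)
    }
  }
}
```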
## How was this patch tested?

`MLTestSuite` added.

Author: WeichenXu <weichen.xu@databricks.com>

Closes #19843 from WeichenXu123/ml_stream_test_helper.
---
 mllib/pom.xml                                      | 14 +++
 .../ml/regression/LinearRegressionSuite.scala      |  8 +-
 .../org/apache/spark/ml/util/MLTest.scala          | 91 +++++++++++++++++++
 .../apache/spark/ml/util/MLTestSuite.scala         | 47 ++++++++++
 .../spark/sql/streaming/StreamTest.scala           | 67 +++++++++-----
 .../spark/sql/test/TestSQLContext.scala            |  2 +-
 6 files changed, 203 insertions(+), 26 deletions(-)
 create mode 100644 mllib/src/test/scala/org/apache/spark/ml/util/MLTest.scala
 create mode 100644 mllib/src/test/scala/org/apache/spark/ml/util/MLTestSuite.scala

diff --git a/mllib/pom.xml b/mllib/pom.xml
index 925b5422a5..a906c9e02c 100644
--- a/mllib/pom.xml
+++ b/mllib/pom.xml
@@ -60,6 +60,20 @@
       <artifactId>spark-sql_${scala.binary.version}</artifactId>
       <version>${project.version}</version>
     </dependency>
+    <dependency>
+      <groupId>org.apache.spark</groupId>
+      <artifactId>spark-catalyst_${scala.binary.version}</artifactId>
+      <version>${project.version}</version>
+      <type>test-jar</type>
+      <scope>test</scope>
+    </dependency>
+    <dependency>
+      <groupId>org.apache.spark</groupId>
+      <artifactId>spark-sql_${scala.binary.version}</artifactId>
+      <version>${project.version}</version>
+      <type>test-jar</type>
+      <scope>test</scope>
+    </dependency>
     <dependency>
       <groupId>org.apache.spark</groupId>
       <artifactId>spark-graphx_${scala.binary.version}</artifactId>
diff --git a/mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala
index 0e0be58dbf..aec5ac0e75 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala
@@ -24,13 +24,12 @@ import org.apache.spark.ml.feature.Instance
 import org.apache.spark.ml.feature.LabeledPoint
 import org.apache.spark.ml.linalg.{DenseVector, Vector, Vectors}
 import org.apache.spark.ml.param.{ParamMap, ParamsSuite}
-import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTestingUtils}
+import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTest, MLTestingUtils}
 import org.apache.spark.ml.util.TestingUtils._
 import org.apache.spark.mllib.util.{LinearDataGenerator, MLlibTestSparkContext}
 import org.apache.spark.sql.{DataFrame, Row}
 
-class LinearRegressionSuite
-  extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest {
+class LinearRegressionSuite extends MLTest with DefaultReadWriteTest {
 
   import testImplicits._
 
@@ -233,7 +232,8 @@ class LinearRegressionSuite
     assert(model2.intercept ~== interceptR relTol 1E-3)
     assert(model2.coefficients ~= coefficientsR relTol 1E-3)
 
-    model1.transform(datasetWithDenseFeature).select("features", "prediction").collect().foreach {
+    testTransformer[(Double, Vector)](datasetWithDenseFeature, model1,
+      "features", "prediction") {
       case Row(features: DenseVector, prediction1: Double) =>
         val prediction2 =
           features(0) * model1.coefficients(0) + features(1) * model1.coefficients(1) +
diff --git a/mllib/src/test/scala/org/apache/spark/ml/util/MLTest.scala b/mllib/src/test/scala/org/apache/spark/ml/util/MLTest.scala
new file mode 100644
index 0000000000..7a5426ebad
--- /dev/null
+++ b/mllib/src/test/scala/org/apache/spark/ml/util/MLTest.scala
@@ -0,0 +1,91 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.ml.util
+
+import java.io.File
+
+import org.scalatest.Suite
+
+import org.apache.spark.SparkContext
+import org.apache.spark.ml.{PipelineModel, Transformer}
+import org.apache.spark.sql.{DataFrame, Encoder, Row}
+import org.apache.spark.sql.execution.streaming.MemoryStream
+import org.apache.spark.sql.streaming.StreamTest
+import org.apache.spark.sql.test.TestSparkSession
+import org.apache.spark.util.Utils
+
+trait MLTest extends StreamTest with TempDirectory { self: Suite =>
+
+  @transient var sc: SparkContext = _
+  @transient var checkpointDir: String = _
+
+  protected override def createSparkSession: TestSparkSession = {
+    new TestSparkSession(new SparkContext("local[2]", "MLlibUnitTest", sparkConf))
+  }
+
+  override def beforeAll(): Unit = {
+    super.beforeAll()
+    sc = spark.sparkContext
+    checkpointDir = Utils.createDirectory(tempDir.getCanonicalPath, "checkpoints").toString
+    sc.setCheckpointDir(checkpointDir)
+  }
+
+  override def afterAll() {
+    try {
+      Utils.deleteRecursively(new File(checkpointDir))
+    } finally {
+      super.afterAll()
+    }
+  }
+
+  def testTransformerOnStreamData[A : Encoder](
+      dataframe: DataFrame,
+      transformer: Transformer,
+      firstResultCol: String,
+      otherResultCols: String*)
+      (checkFunction: Row => Unit): Unit = {
+
+    val columnNames = dataframe.schema.fieldNames
+    val stream = MemoryStream[A]
+    val streamDF = stream.toDS().toDF(columnNames: _*)
+
+    val data = dataframe.as[A].collect()
+
+    val streamOutput = transformer.transform(streamDF)
+      .select(firstResultCol, otherResultCols: _*)
+    testStream(streamOutput) (
+      AddData(stream, data: _*),
+      CheckAnswer(checkFunction)
+    )
+  }
+
+  def testTransformer[A : Encoder](
+      dataframe: DataFrame,
+      transformer: Transformer,
+      firstResultCol: String,
+      otherResultCols: String*)
+      (checkFunction: Row => Unit): Unit = {
+    testTransformerOnStreamData(dataframe, transformer, firstResultCol,
+      otherResultCols: _*)(checkFunction)
+
+    val dfOutput = transformer.transform(dataframe)
+    dfOutput.select(firstResultCol, otherResultCols: _*).collect().foreach { row =>
+      checkFunction(row)
+    }
+  }
+}
diff --git a/mllib/src/test/scala/org/apache/spark/ml/util/MLTestSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/util/MLTestSuite.scala
new file mode 100644
index 0000000000..56217ec4f3
--- /dev/null
+++ b/mllib/src/test/scala/org/apache/spark/ml/util/MLTestSuite.scala
@@ -0,0 +1,47 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.ml.util
+
+import org.apache.spark.ml.{PipelineModel, Transformer}
+import org.apache.spark.ml.feature.StringIndexer
+import org.apache.spark.sql.Row
+
+class MLTestSuite extends MLTest {
+
+  import testImplicits._
+
+  test("test transformer on stream data") {
+
+    val data = Seq((0, "a"), (1, "b"), (2, "c"), (3, "d"), (4, "e"), (5, "f"))
+      .toDF("id", "label")
+    val indexer = new StringIndexer().setStringOrderType("alphabetAsc")
+      .setInputCol("label").setOutputCol("indexed")
+    val indexerModel = indexer.fit(data)
+    testTransformerOnStreamData[(Int, String)](data, indexerModel, "id", "indexed") {
+      case Row(id: Int, indexed: Double) =>
+        assert(id === indexed.toInt)
+    }
+
+    intercept[Exception] {
+      testTransformerOnStreamData[(Int, String)](data, indexerModel, "id", "indexed") {
+        case Row(id: Int, indexed: Double) =>
+          assert(id != indexed.toInt)
+      }
+    }
+  }
+}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamTest.scala
index e68fca0505..dc5b998ad6 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamTest.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamTest.scala
@@ -133,6 +133,9 @@ trait StreamTest extends QueryTest with SharedSQLContext with TimeLimits with Be
     }
 
     def apply(rows: Row*): CheckAnswerRows = CheckAnswerRows(rows, false, false)
+
+    def apply(checkFunction: Row => Unit): CheckAnswerRowsByFunc =
+      CheckAnswerRowsByFunc(checkFunction, false)
   }
 
   /**
@@ -154,6 +157,9 @@ trait StreamTest extends QueryTest with SharedSQLContext with TimeLimits with Be
     }
 
     def apply(rows: Row*): CheckAnswerRows = CheckAnswerRows(rows, true, false)
+
+    def apply(checkFunction: Row => Unit): CheckAnswerRowsByFunc =
+      CheckAnswerRowsByFunc(checkFunction, true)
   }
 
   case class CheckAnswerRows(expectedAnswer: Seq[Row], lastOnly: Boolean, isSorted: Boolean)
@@ -162,6 +168,12 @@ trait StreamTest extends QueryTest with SharedSQLContext with TimeLimits with Be
     private def operatorName = if (lastOnly) "CheckLastBatch" else "CheckAnswer"
   }
 
+  case class CheckAnswerRowsByFunc(checkFunction: Row => Unit, lastOnly: Boolean)
+      extends StreamAction with StreamMustBeRunning {
+    override def toString: String = s"$operatorName: ${checkFunction.toString()}"
+    private def operatorName = if (lastOnly) "CheckLastBatchByFunc" else "CheckAnswerByFunc"
+  }
+
   /** Stops the stream. It must currently be running. */
   case object StopStream extends StreamAction with StreamMustBeRunning
 
@@ -352,6 +364,29 @@ trait StreamTest extends QueryTest with SharedSQLContext with TimeLimits with Be
         """.stripMargin)
     }
 
+    def fetchStreamAnswer(currentStream: StreamExecution, lastOnly: Boolean) = {
+      verify(currentStream != null, "stream not running")
+      // Get the map of source index to the current source objects
+      val indexToSource = currentStream
+        .logicalPlan
+        .collect { case StreamingExecutionRelation(s, _) => s }
+        .zipWithIndex
+        .map(_.swap)
+        .toMap
+
+      // Block until all data added has been processed for all the source
+      awaiting.foreach { case (sourceIndex, offset) =>
+        failAfter(streamingTimeout) {
+          currentStream.awaitOffset(indexToSource(sourceIndex), offset)
+        }
+      }
+
+      try if (lastOnly) sink.latestBatchData else sink.allData catch {
+        case e: Exception =>
+          failTest("Exception while getting data from sink", e)
+      }
+    }
+
     var manualClockExpectedTime = -1L
     val defaultCheckpointLocation =
       Utils.createTempDir(namePrefix = "streaming.metadata").getCanonicalPath
@@ -552,30 +587,20 @@ trait StreamTest extends QueryTest with SharedSQLContext with TimeLimits with Be
             e.runAction()
 
           case CheckAnswerRows(expectedAnswer, lastOnly, isSorted) =>
-            verify(currentStream != null, "stream not running")
-            // Get the map of source index to the current source objects
-            val indexToSource = currentStream
-              .logicalPlan
-              .collect { case StreamingExecutionRelation(s, _) => s }
-              .zipWithIndex
-              .map(_.swap)
-              .toMap
-
-            // Block until all data added has been processed for all the source
-            awaiting.foreach { case (sourceIndex, offset) =>
-              failAfter(streamingTimeout) {
-                currentStream.awaitOffset(indexToSource(sourceIndex), offset)
-              }
-            }
-
-            val sparkAnswer = try if (lastOnly) sink.latestBatchData else sink.allData catch {
-              case e: Exception =>
-                failTest("Exception while getting data from sink", e)
-            }
-
+            val sparkAnswer = fetchStreamAnswer(currentStream, lastOnly)
             QueryTest.sameRows(expectedAnswer, sparkAnswer, isSorted).foreach { error =>
               failTest(error)
             }
+
+          case CheckAnswerRowsByFunc(checkFunction, lastOnly) =>
+            val sparkAnswer = fetchStreamAnswer(currentStream, lastOnly)
+            sparkAnswer.foreach { row =>
+              try {
+                checkFunction(row)
+              } catch {
+                case e: Throwable => failTest(e.toString)
+              }
+            }
         }
         pos += 1
       }
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/test/TestSQLContext.scala b/sql/core/src/test/scala/org/apache/spark/sql/test/TestSQLContext.scala
index 959edf9a49..4286e8a6ca 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/test/TestSQLContext.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/test/TestSQLContext.scala
@@ -24,7 +24,7 @@ import org.apache.spark.sql.internal.{SessionState, SessionStateBuilder, SQLConf
 /**
  * A special `SparkSession` prepared for testing.
  */
-private[sql] class TestSparkSession(sc: SparkContext) extends SparkSession(sc) { self =>
+private[spark] class TestSparkSession(sc: SparkContext) extends SparkSession(sc) { self =>
   def this(sparkConf: SparkConf) {
     this(new SparkContext("local[2]", "test-sql-context",
       sparkConf.set("spark.sql.testkey", "true")))
--
GitLab