From 0e36ba6212bc24b3185e385914fbf2d62cbfb6da Mon Sep 17 00:00:00 2001
From: WeichenXu <weichen.xu@databricks.com>
Date: Tue, 12 Dec 2017 21:28:24 -0800
Subject: [PATCH] [SPARK-22644][ML][TEST] Make ML testsuite support
 StructuredStreaming test

## What changes were proposed in this pull request?

We need to add some helper code to make testing ML transformers & models easier with streaming data. These tests might help us catch any remaining issues and we could encourage future PRs to use these tests to prevent new Models & Transformers from having issues.

I add a `MLTest` trait which extends `StreamTest` trait, and override `createSparkSession`. So ML testsuite can only extend `MLTest`, to use both ML & Stream test util functions.

I only modify one testcase in `LinearRegressionSuite`, for first pass review.

Link to #19746

## How was this patch tested?

`MLTestSuite` added.

Author: WeichenXu <weichen.xu@databricks.com>

Closes #19843 from WeichenXu123/ml_stream_test_helper.
---
 mllib/pom.xml                                 | 14 +++
 .../ml/regression/LinearRegressionSuite.scala |  8 +-
 .../org/apache/spark/ml/util/MLTest.scala     | 91 +++++++++++++++++++
 .../apache/spark/ml/util/MLTestSuite.scala    | 47 ++++++++++
 .../spark/sql/streaming/StreamTest.scala      | 67 +++++++++-----
 .../spark/sql/test/TestSQLContext.scala       |  2 +-
 6 files changed, 203 insertions(+), 26 deletions(-)
 create mode 100644 mllib/src/test/scala/org/apache/spark/ml/util/MLTest.scala
 create mode 100644 mllib/src/test/scala/org/apache/spark/ml/util/MLTestSuite.scala

diff --git a/mllib/pom.xml b/mllib/pom.xml
index 925b5422a5..a906c9e02c 100644
--- a/mllib/pom.xml
+++ b/mllib/pom.xml
@@ -60,6 +60,20 @@
       <artifactId>spark-sql_${scala.binary.version}</artifactId>
       <version>${project.version}</version>
     </dependency>
+    <dependency>
+      <groupId>org.apache.spark</groupId>
+      <artifactId>spark-catalyst_${scala.binary.version}</artifactId>
+      <version>${project.version}</version>
+      <type>test-jar</type>
+      <scope>test</scope>
+    </dependency>
+    <dependency>
+      <groupId>org.apache.spark</groupId>
+      <artifactId>spark-sql_${scala.binary.version}</artifactId>
+      <version>${project.version}</version>
+      <type>test-jar</type>
+      <scope>test</scope>
+    </dependency>
     <dependency>
       <groupId>org.apache.spark</groupId>
       <artifactId>spark-graphx_${scala.binary.version}</artifactId>
diff --git a/mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala
index 0e0be58dbf..aec5ac0e75 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala
@@ -24,13 +24,12 @@ import org.apache.spark.ml.feature.Instance
 import org.apache.spark.ml.feature.LabeledPoint
 import org.apache.spark.ml.linalg.{DenseVector, Vector, Vectors}
 import org.apache.spark.ml.param.{ParamMap, ParamsSuite}
-import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTestingUtils}
+import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTest, MLTestingUtils}
 import org.apache.spark.ml.util.TestingUtils._
 import org.apache.spark.mllib.util.{LinearDataGenerator, MLlibTestSparkContext}
 import org.apache.spark.sql.{DataFrame, Row}
 
-class LinearRegressionSuite
-  extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest {
+class LinearRegressionSuite extends MLTest with DefaultReadWriteTest {
 
   import testImplicits._
 
@@ -233,7 +232,8 @@ class LinearRegressionSuite
       assert(model2.intercept ~== interceptR relTol 1E-3)
       assert(model2.coefficients ~= coefficientsR relTol 1E-3)
 
-      model1.transform(datasetWithDenseFeature).select("features", "prediction").collect().foreach {
+      testTransformer[(Double, Vector)](datasetWithDenseFeature, model1,
+        "features", "prediction") {
         case Row(features: DenseVector, prediction1: Double) =>
           val prediction2 =
             features(0) * model1.coefficients(0) + features(1) * model1.coefficients(1) +
diff --git a/mllib/src/test/scala/org/apache/spark/ml/util/MLTest.scala b/mllib/src/test/scala/org/apache/spark/ml/util/MLTest.scala
new file mode 100644
index 0000000000..7a5426ebad
--- /dev/null
+++ b/mllib/src/test/scala/org/apache/spark/ml/util/MLTest.scala
@@ -0,0 +1,91 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.ml.util
+
+import java.io.File
+
+import org.scalatest.Suite
+
+import org.apache.spark.SparkContext
+import org.apache.spark.ml.{PipelineModel, Transformer}
+import org.apache.spark.sql.{DataFrame, Encoder, Row}
+import org.apache.spark.sql.execution.streaming.MemoryStream
+import org.apache.spark.sql.streaming.StreamTest
+import org.apache.spark.sql.test.TestSparkSession
+import org.apache.spark.util.Utils
+
+trait MLTest extends StreamTest with TempDirectory { self: Suite =>
+
+  @transient var sc: SparkContext = _
+  @transient var checkpointDir: String = _
+
+  protected override def createSparkSession: TestSparkSession = {
+    new TestSparkSession(new SparkContext("local[2]", "MLlibUnitTest", sparkConf))
+  }
+
+  override def beforeAll(): Unit = {
+    super.beforeAll()
+    sc = spark.sparkContext
+    checkpointDir = Utils.createDirectory(tempDir.getCanonicalPath, "checkpoints").toString
+    sc.setCheckpointDir(checkpointDir)
+  }
+
+  override def afterAll() {
+    try {
+      Utils.deleteRecursively(new File(checkpointDir))
+    } finally {
+      super.afterAll()
+    }
+  }
+
+  def testTransformerOnStreamData[A : Encoder](
+      dataframe: DataFrame,
+      transformer: Transformer,
+      firstResultCol: String,
+      otherResultCols: String*)
+      (checkFunction: Row => Unit): Unit = {
+
+    val columnNames = dataframe.schema.fieldNames
+    val stream = MemoryStream[A]
+    val streamDF = stream.toDS().toDF(columnNames: _*)
+
+    val data = dataframe.as[A].collect()
+
+    val streamOutput = transformer.transform(streamDF)
+      .select(firstResultCol, otherResultCols: _*)
+    testStream(streamOutput) (
+      AddData(stream, data: _*),
+      CheckAnswer(checkFunction)
+    )
+  }
+
+  def testTransformer[A : Encoder](
+      dataframe: DataFrame,
+      transformer: Transformer,
+      firstResultCol: String,
+      otherResultCols: String*)
+      (checkFunction: Row => Unit): Unit = {
+    testTransformerOnStreamData(dataframe, transformer, firstResultCol,
+      otherResultCols: _*)(checkFunction)
+
+    val dfOutput = transformer.transform(dataframe)
+    dfOutput.select(firstResultCol, otherResultCols: _*).collect().foreach { row =>
+      checkFunction(row)
+    }
+  }
+}
diff --git a/mllib/src/test/scala/org/apache/spark/ml/util/MLTestSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/util/MLTestSuite.scala
new file mode 100644
index 0000000000..56217ec4f3
--- /dev/null
+++ b/mllib/src/test/scala/org/apache/spark/ml/util/MLTestSuite.scala
@@ -0,0 +1,47 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.ml.util
+
+import org.apache.spark.ml.{PipelineModel, Transformer}
+import org.apache.spark.ml.feature.StringIndexer
+import org.apache.spark.sql.Row
+
+class MLTestSuite extends MLTest {
+
+  import testImplicits._
+
+  test("test transformer on stream data") {
+
+    val data = Seq((0, "a"), (1, "b"), (2, "c"), (3, "d"), (4, "e"), (5, "f"))
+      .toDF("id", "label")
+    val indexer = new StringIndexer().setStringOrderType("alphabetAsc")
+      .setInputCol("label").setOutputCol("indexed")
+    val indexerModel = indexer.fit(data)
+    testTransformerOnStreamData[(Int, String)](data, indexerModel, "id", "indexed") {
+      case Row(id: Int, indexed: Double) =>
+        assert(id === indexed.toInt)
+    }
+
+    intercept[Exception] {
+      testTransformerOnStreamData[(Int, String)](data, indexerModel, "id", "indexed") {
+        case Row(id: Int, indexed: Double) =>
+          assert(id != indexed.toInt)
+      }
+    }
+  }
+}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamTest.scala
index e68fca0505..dc5b998ad6 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamTest.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamTest.scala
@@ -133,6 +133,9 @@ trait StreamTest extends QueryTest with SharedSQLContext with TimeLimits with Be
     }
 
     def apply(rows: Row*): CheckAnswerRows = CheckAnswerRows(rows, false, false)
+
+    def apply(checkFunction: Row => Unit): CheckAnswerRowsByFunc =
+      CheckAnswerRowsByFunc(checkFunction, false)
   }
 
   /**
@@ -154,6 +157,9 @@ trait StreamTest extends QueryTest with SharedSQLContext with TimeLimits with Be
     }
 
     def apply(rows: Row*): CheckAnswerRows = CheckAnswerRows(rows, true, false)
+
+    def apply(checkFunction: Row => Unit): CheckAnswerRowsByFunc =
+      CheckAnswerRowsByFunc(checkFunction, true)
   }
 
   case class CheckAnswerRows(expectedAnswer: Seq[Row], lastOnly: Boolean, isSorted: Boolean)
@@ -162,6 +168,12 @@ trait StreamTest extends QueryTest with SharedSQLContext with TimeLimits with Be
     private def operatorName = if (lastOnly) "CheckLastBatch" else "CheckAnswer"
   }
 
+  case class CheckAnswerRowsByFunc(checkFunction: Row => Unit, lastOnly: Boolean)
+      extends StreamAction with StreamMustBeRunning {
+    override def toString: String = s"$operatorName: ${checkFunction.toString()}"
+    private def operatorName = if (lastOnly) "CheckLastBatchByFunc" else "CheckAnswerByFunc"
+  }
+
   /** Stops the stream. It must currently be running. */
   case object StopStream extends StreamAction with StreamMustBeRunning
 
@@ -352,6 +364,29 @@ trait StreamTest extends QueryTest with SharedSQLContext with TimeLimits with Be
          """.stripMargin)
     }
 
+    def fetchStreamAnswer(currentStream: StreamExecution, lastOnly: Boolean) = {
+      verify(currentStream != null, "stream not running")
+      // Get the map of source index to the current source objects
+      val indexToSource = currentStream
+        .logicalPlan
+        .collect { case StreamingExecutionRelation(s, _) => s }
+        .zipWithIndex
+        .map(_.swap)
+        .toMap
+
+      // Block until all data added has been processed for all the source
+      awaiting.foreach { case (sourceIndex, offset) =>
+        failAfter(streamingTimeout) {
+          currentStream.awaitOffset(indexToSource(sourceIndex), offset)
+        }
+      }
+
+      try if (lastOnly) sink.latestBatchData else sink.allData catch {
+        case e: Exception =>
+          failTest("Exception while getting data from sink", e)
+      }
+    }
+
     var manualClockExpectedTime = -1L
     val defaultCheckpointLocation =
       Utils.createTempDir(namePrefix = "streaming.metadata").getCanonicalPath
@@ -552,30 +587,20 @@ trait StreamTest extends QueryTest with SharedSQLContext with TimeLimits with Be
             e.runAction()
 
           case CheckAnswerRows(expectedAnswer, lastOnly, isSorted) =>
-            verify(currentStream != null, "stream not running")
-            // Get the map of source index to the current source objects
-            val indexToSource = currentStream
-              .logicalPlan
-              .collect { case StreamingExecutionRelation(s, _) => s }
-              .zipWithIndex
-              .map(_.swap)
-              .toMap
-
-            // Block until all data added has been processed for all the source
-            awaiting.foreach { case (sourceIndex, offset) =>
-              failAfter(streamingTimeout) {
-                currentStream.awaitOffset(indexToSource(sourceIndex), offset)
-              }
-            }
-
-            val sparkAnswer = try if (lastOnly) sink.latestBatchData else sink.allData catch {
-              case e: Exception =>
-                failTest("Exception while getting data from sink", e)
-            }
-
+            val sparkAnswer = fetchStreamAnswer(currentStream, lastOnly)
             QueryTest.sameRows(expectedAnswer, sparkAnswer, isSorted).foreach {
               error => failTest(error)
             }
+
+          case CheckAnswerRowsByFunc(checkFunction, lastOnly) =>
+            val sparkAnswer = fetchStreamAnswer(currentStream, lastOnly)
+            sparkAnswer.foreach { row =>
+              try {
+                checkFunction(row)
+              } catch {
+                case e: Throwable => failTest(e.toString)
+              }
+            }
         }
         pos += 1
       }
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/test/TestSQLContext.scala b/sql/core/src/test/scala/org/apache/spark/sql/test/TestSQLContext.scala
index 959edf9a49..4286e8a6ca 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/test/TestSQLContext.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/test/TestSQLContext.scala
@@ -24,7 +24,7 @@ import org.apache.spark.sql.internal.{SessionState, SessionStateBuilder, SQLConf
 /**
  * A special `SparkSession` prepared for testing.
  */
-private[sql] class TestSparkSession(sc: SparkContext) extends SparkSession(sc) { self =>
+private[spark] class TestSparkSession(sc: SparkContext) extends SparkSession(sc) { self =>
   def this(sparkConf: SparkConf) {
     this(new SparkContext("local[2]", "test-sql-context",
       sparkConf.set("spark.sql.testkey", "true")))
-- 
GitLab