From 2ea17afb63f976500273518bf1b32f9efe250812 Mon Sep 17 00:00:00 2001
From: WeichenXu <weichen.xu@databricks.com>
Date: Fri, 29 Dec 2017 20:06:56 -0800
Subject: [PATCH] [SPARK-22881][ML][TEST] ML regression package testsuite add
 StructuredStreaming test

## What changes were proposed in this pull request?

Add StructuredStreaming tests to the ML regression package test suites.

To make the test suites easier to modify, a new helper function was added to `MLTest`:
```
def testTransformerByGlobalCheckFunc[A : Encoder](
      dataframe: DataFrame,
      transformer: Transformer,
      firstResultCol: String,
      otherResultCols: String*)
      (globalCheckFunction: Seq[Row] => Unit): Unit
```
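
Unlike the per-row `testTransformer` helper (which is now implemented on top of this one), the new helper hands the check function all collected output rows at once, and runs it against both a batch DataFrame and a `MemoryStream`-backed streaming query. As a sketch of the intended usage, adapted from the updated `IsotonicRegressionSuite` in this patch:
```
testTransformerByGlobalCheckFunc[(Double, Double, Double)](dataset, model,
  "prediction") { rows: Seq[Row] =>
  val predictions = rows.map(_.getDouble(0))
  assert(predictions === Array(1, 2, 2, 2, 6, 16.5, 16.5, 17, 18))
}
```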

## How was this patch tested?

N/A (this change only adds and updates tests; all modified files are under `src/test`).

Author: WeichenXu <weichen.xu@databricks.com>
Author: Bago Amirbekian <bago@databricks.com>

Closes #19979 from WeichenXu123/ml_stream_test.
---
 .../AFTSurvivalRegressionSuite.scala          | 19 ++++----
 .../DecisionTreeRegressorSuite.scala          | 43 ++++++++---------
 .../ml/regression/GBTRegressorSuite.scala     | 23 ++++-----
 .../GeneralizedLinearRegressionSuite.scala    | 47 ++++++++++---------
 .../regression/IsotonicRegressionSuite.scala  | 43 +++++++----------
 .../ml/regression/LinearRegressionSuite.scala | 25 +++++-----
 .../org/apache/spark/ml/util/MLTest.scala     | 39 +++++++++++----
 .../apache/spark/ml/util/MLTestSuite.scala    | 12 ++++-
 .../spark/sql/streaming/StreamTest.scala      | 27 +++++------
 9 files changed, 147 insertions(+), 131 deletions(-)

diff --git a/mllib/src/test/scala/org/apache/spark/ml/regression/AFTSurvivalRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/regression/AFTSurvivalRegressionSuite.scala
index 02e5c6d294..4e4ff71c9d 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/regression/AFTSurvivalRegressionSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/regression/AFTSurvivalRegressionSuite.scala
@@ -19,19 +19,16 @@ package org.apache.spark.ml.regression
 
 import scala.util.Random
 
-import org.apache.spark.SparkFunSuite
 import org.apache.spark.ml.linalg.{Vector, Vectors}
 import org.apache.spark.ml.param.ParamsSuite
-import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTestingUtils}
+import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTest, MLTestingUtils}
 import org.apache.spark.ml.util.TestingUtils._
 import org.apache.spark.mllib.random.{ExponentialGenerator, WeibullGenerator}
-import org.apache.spark.mllib.util.MLlibTestSparkContext
 import org.apache.spark.sql.{DataFrame, Row}
 import org.apache.spark.sql.functions.{col, lit}
 import org.apache.spark.sql.types._
 
-class AFTSurvivalRegressionSuite
-  extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest {
+class AFTSurvivalRegressionSuite extends MLTest with DefaultReadWriteTest {
 
   import testImplicits._
 
@@ -191,8 +188,8 @@ class AFTSurvivalRegressionSuite
     assert(model.predict(features) ~== responsePredictR relTol 1E-3)
     assert(model.predictQuantiles(features) ~== quantilePredictR relTol 1E-3)
 
-    model.transform(datasetUnivariate).select("features", "prediction", "quantiles")
-      .collect().foreach {
+    testTransformer[(Vector, Double, Double)](datasetUnivariate, model,
+      "features", "prediction", "quantiles") {
         case Row(features: Vector, prediction: Double, quantiles: Vector) =>
           assert(prediction ~== model.predict(features) relTol 1E-5)
           assert(quantiles ~== model.predictQuantiles(features) relTol 1E-5)
@@ -261,8 +258,8 @@ class AFTSurvivalRegressionSuite
     assert(model.predict(features) ~== responsePredictR relTol 1E-3)
     assert(model.predictQuantiles(features) ~== quantilePredictR relTol 1E-3)
 
-    model.transform(datasetMultivariate).select("features", "prediction", "quantiles")
-      .collect().foreach {
+    testTransformer[(Vector, Double, Double)](datasetMultivariate, model,
+      "features", "prediction", "quantiles") {
         case Row(features: Vector, prediction: Double, quantiles: Vector) =>
           assert(prediction ~== model.predict(features) relTol 1E-5)
           assert(quantiles ~== model.predictQuantiles(features) relTol 1E-5)
@@ -331,8 +328,8 @@ class AFTSurvivalRegressionSuite
     assert(model.predict(features) ~== responsePredictR relTol 1E-3)
     assert(model.predictQuantiles(features) ~== quantilePredictR relTol 1E-3)
 
-    model.transform(datasetMultivariate).select("features", "prediction", "quantiles")
-      .collect().foreach {
+    testTransformer[(Vector, Double, Double)](datasetMultivariate, model,
+      "features", "prediction", "quantiles") {
         case Row(features: Vector, prediction: Double, quantiles: Vector) =>
           assert(prediction ~== model.predict(features) relTol 1E-5)
           assert(quantiles ~== model.predictQuantiles(features) relTol 1E-5)
diff --git a/mllib/src/test/scala/org/apache/spark/ml/regression/DecisionTreeRegressorSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/regression/DecisionTreeRegressorSuite.scala
index 642f266891..68a1218c23 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/regression/DecisionTreeRegressorSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/regression/DecisionTreeRegressorSuite.scala
@@ -21,19 +21,18 @@ import org.apache.spark.SparkFunSuite
 import org.apache.spark.ml.feature.LabeledPoint
 import org.apache.spark.ml.linalg.Vector
 import org.apache.spark.ml.tree.impl.TreeTests
-import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTestingUtils}
+import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTest, MLTestingUtils}
 import org.apache.spark.ml.util.TestingUtils._
 import org.apache.spark.mllib.regression.{LabeledPoint => OldLabeledPoint}
 import org.apache.spark.mllib.tree.{DecisionTree => OldDecisionTree,
   DecisionTreeSuite => OldDecisionTreeSuite}
-import org.apache.spark.mllib.util.MLlibTestSparkContext
 import org.apache.spark.rdd.RDD
 import org.apache.spark.sql.{DataFrame, Row}
 
-class DecisionTreeRegressorSuite
-  extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest {
+class DecisionTreeRegressorSuite extends MLTest with DefaultReadWriteTest {
 
   import DecisionTreeRegressorSuite.compareAPIs
+  import testImplicits._
 
   private var categoricalDataPointsRDD: RDD[LabeledPoint] = _
 
@@ -89,14 +88,11 @@ class DecisionTreeRegressorSuite
     val df = TreeTests.setMetadata(categoricalDataPointsRDD, categoricalFeatures, numClasses = 0)
     val model = dt.fit(df)
 
-    val predictions = model.transform(df)
-      .select(model.getFeaturesCol, model.getVarianceCol)
-      .collect()
-
-    predictions.foreach { case Row(features: Vector, variance: Double) =>
-      val expectedVariance = model.rootNode.predictImpl(features).impurityStats.calculate()
-      assert(variance === expectedVariance,
-        s"Expected variance $expectedVariance but got $variance.")
+    testTransformer[(Vector, Double)](df, model, "features", "variance") {
+      case Row(features: Vector, variance: Double) =>
+        val expectedVariance = model.rootNode.predictImpl(features).impurityStats.calculate()
+        assert(variance === expectedVariance,
+          s"Expected variance $expectedVariance but got $variance.")
     }
 
     val varianceData: RDD[LabeledPoint] = TreeTests.varianceData(sc)
@@ -104,18 +100,19 @@ class DecisionTreeRegressorSuite
     dt.setMaxDepth(1)
       .setMaxBins(6)
       .setSeed(0)
-    val transformVarDF = dt.fit(varianceDF).transform(varianceDF)
-    val calculatedVariances = transformVarDF.select(dt.getVarianceCol).collect().map {
-      case Row(variance: Double) => variance
-    }
 
-    // Since max depth is set to 1, the best split point is that which splits the data
-    // into (0.0, 1.0, 2.0) and (10.0, 12.0, 14.0). The predicted variance for each
-    // data point in the left node is 0.667 and for each data point in the right node
-    // is 2.667
-    val expectedVariances = Array(0.667, 0.667, 0.667, 2.667, 2.667, 2.667)
-    calculatedVariances.zip(expectedVariances).foreach { case (actual, expected) =>
-      assert(actual ~== expected absTol 1e-3)
+    testTransformerByGlobalCheckFunc[(Vector, Double)](varianceDF, dt.fit(varianceDF),
+      "variance") { case rows: Seq[Row] =>
+      val calculatedVariances = rows.map(_.getDouble(0))
+
+      // Since max depth is set to 1, the best split point is that which splits the data
+      // into (0.0, 1.0, 2.0) and (10.0, 12.0, 14.0). The predicted variance for each
+      // data point in the left node is 0.667 and for each data point in the right node
+      // is 2.667
+      val expectedVariances = Array(0.667, 0.667, 0.667, 2.667, 2.667, 2.667)
+      calculatedVariances.zip(expectedVariances).foreach { case (actual, expected) =>
+        assert(actual ~== expected absTol 1e-3)
+      }
     }
   }
 
diff --git a/mllib/src/test/scala/org/apache/spark/ml/regression/GBTRegressorSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/regression/GBTRegressorSuite.scala
index ecbb57126d..11c593b521 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/regression/GBTRegressorSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/regression/GBTRegressorSuite.scala
@@ -19,22 +19,20 @@ package org.apache.spark.ml.regression
 
 import org.apache.spark.SparkFunSuite
 import org.apache.spark.ml.feature.LabeledPoint
-import org.apache.spark.ml.linalg.Vectors
+import org.apache.spark.ml.linalg.{Vector, Vectors}
 import org.apache.spark.ml.tree.impl.TreeTests
-import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTestingUtils}
+import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTest, MLTestingUtils}
 import org.apache.spark.mllib.regression.{LabeledPoint => OldLabeledPoint}
 import org.apache.spark.mllib.tree.{EnsembleTestHelper, GradientBoostedTrees => OldGBT}
 import org.apache.spark.mllib.tree.configuration.{Algo => OldAlgo}
-import org.apache.spark.mllib.util.MLlibTestSparkContext
 import org.apache.spark.rdd.RDD
-import org.apache.spark.sql.DataFrame
+import org.apache.spark.sql.{DataFrame, Row}
 import org.apache.spark.util.Utils
 
 /**
  * Test suite for [[GBTRegressor]].
  */
-class GBTRegressorSuite extends SparkFunSuite with MLlibTestSparkContext
-  with DefaultReadWriteTest {
+class GBTRegressorSuite extends MLTest with DefaultReadWriteTest {
 
   import GBTRegressorSuite.compareAPIs
   import testImplicits._
@@ -91,11 +89,14 @@ class GBTRegressorSuite extends SparkFunSuite with MLlibTestSparkContext
     val model = gbt.fit(df)
 
     MLTestingUtils.checkCopyAndUids(gbt, model)
-    val preds = model.transform(df)
-    val predictions = preds.select("prediction").rdd.map(_.getDouble(0))
-    // Checks based on SPARK-8736 (to ensure it is not doing classification)
-    assert(predictions.max() > 2)
-    assert(predictions.min() < -1)
+
+    testTransformerByGlobalCheckFunc[(Double, Vector)](df, model, "prediction") {
+      case rows: Seq[Row] =>
+        val predictions = rows.map(_.getDouble(0))
+        // Checks based on SPARK-8736 (to ensure it is not doing classification)
+        assert(predictions.max > 2)
+        assert(predictions.min < -1)
+    }
   }
 
   test("Checkpointing") {
diff --git a/mllib/src/test/scala/org/apache/spark/ml/regression/GeneralizedLinearRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/regression/GeneralizedLinearRegressionSuite.scala
index df7dee869d..ef2ff94a5e 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/regression/GeneralizedLinearRegressionSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/regression/GeneralizedLinearRegressionSuite.scala
@@ -25,7 +25,7 @@ import org.apache.spark.ml.feature.{Instance, OffsetInstance}
 import org.apache.spark.ml.feature.{LabeledPoint, RFormula}
 import org.apache.spark.ml.linalg.{BLAS, DenseVector, Vector, Vectors}
 import org.apache.spark.ml.param.{ParamMap, ParamsSuite}
-import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTestingUtils}
+import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTest, MLTestingUtils}
 import org.apache.spark.ml.util.TestingUtils._
 import org.apache.spark.mllib.random._
 import org.apache.spark.mllib.util.MLlibTestSparkContext
@@ -33,8 +33,7 @@ import org.apache.spark.sql.{DataFrame, Row}
 import org.apache.spark.sql.functions._
 import org.apache.spark.sql.types.FloatType
 
-class GeneralizedLinearRegressionSuite
-  extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest {
+class GeneralizedLinearRegressionSuite extends MLTest with DefaultReadWriteTest {
 
   import testImplicits._
 
@@ -268,8 +267,8 @@ class GeneralizedLinearRegressionSuite
           s"$link link and fitIntercept = $fitIntercept.")
 
         val familyLink = FamilyAndLink(trainer)
-        model.transform(dataset).select("features", "prediction", "linkPrediction").collect()
-          .foreach {
+        testTransformer[(Double, Vector)](dataset, model,
+          "features", "prediction", "linkPrediction") {
             case Row(features: DenseVector, prediction1: Double, linkPrediction1: Double) =>
               val eta = BLAS.dot(features, model.coefficients) + model.intercept
               val prediction2 = familyLink.fitted(eta)
@@ -278,7 +277,7 @@ class GeneralizedLinearRegressionSuite
                 s"gaussian family, $link link and fitIntercept = $fitIntercept.")
               assert(linkPrediction1 ~= linkPrediction2 relTol 1E-5, "Link Prediction mismatch: " +
                 s"GLM with gaussian family, $link link and fitIntercept = $fitIntercept.")
-          }
+        }
 
         idx += 1
       }
@@ -384,8 +383,8 @@ class GeneralizedLinearRegressionSuite
           s"$link link and fitIntercept = $fitIntercept.")
 
         val familyLink = FamilyAndLink(trainer)
-        model.transform(dataset).select("features", "prediction", "linkPrediction").collect()
-          .foreach {
+        testTransformer[(Double, Vector)](dataset, model,
+          "features", "prediction", "linkPrediction") {
             case Row(features: DenseVector, prediction1: Double, linkPrediction1: Double) =>
               val eta = BLAS.dot(features, model.coefficients) + model.intercept
               val prediction2 = familyLink.fitted(eta)
@@ -394,7 +393,7 @@ class GeneralizedLinearRegressionSuite
                 s"binomial family, $link link and fitIntercept = $fitIntercept.")
               assert(linkPrediction1 ~= linkPrediction2 relTol 1E-5, "Link Prediction mismatch: " +
                 s"GLM with binomial family, $link link and fitIntercept = $fitIntercept.")
-          }
+        }
 
         idx += 1
       }
@@ -456,8 +455,8 @@ class GeneralizedLinearRegressionSuite
           s"$link link and fitIntercept = $fitIntercept.")
 
         val familyLink = FamilyAndLink(trainer)
-        model.transform(dataset).select("features", "prediction", "linkPrediction").collect()
-          .foreach {
+        testTransformer[(Double, Vector)](dataset, model,
+          "features", "prediction", "linkPrediction") {
             case Row(features: DenseVector, prediction1: Double, linkPrediction1: Double) =>
               val eta = BLAS.dot(features, model.coefficients) + model.intercept
               val prediction2 = familyLink.fitted(eta)
@@ -466,7 +465,7 @@ class GeneralizedLinearRegressionSuite
                 s"poisson family, $link link and fitIntercept = $fitIntercept.")
               assert(linkPrediction1 ~= linkPrediction2 relTol 1E-5, "Link Prediction mismatch: " +
                 s"GLM with poisson family, $link link and fitIntercept = $fitIntercept.")
-          }
+        }
 
         idx += 1
       }
@@ -562,8 +561,8 @@ class GeneralizedLinearRegressionSuite
           s"$link link and fitIntercept = $fitIntercept.")
 
         val familyLink = FamilyAndLink(trainer)
-        model.transform(dataset).select("features", "prediction", "linkPrediction").collect()
-          .foreach {
+        testTransformer[(Double, Vector)](dataset, model,
+          "features", "prediction", "linkPrediction") {
             case Row(features: DenseVector, prediction1: Double, linkPrediction1: Double) =>
               val eta = BLAS.dot(features, model.coefficients) + model.intercept
               val prediction2 = familyLink.fitted(eta)
@@ -572,7 +571,7 @@ class GeneralizedLinearRegressionSuite
                 s"gamma family, $link link and fitIntercept = $fitIntercept.")
               assert(linkPrediction1 ~= linkPrediction2 relTol 1E-5, "Link Prediction mismatch: " +
                 s"GLM with gamma family, $link link and fitIntercept = $fitIntercept.")
-          }
+        }
 
         idx += 1
       }
@@ -649,8 +648,8 @@ class GeneralizedLinearRegressionSuite
         s"and variancePower = $variancePower.")
 
       val familyLink = FamilyAndLink(trainer)
-      model.transform(datasetTweedie).select("features", "prediction", "linkPrediction").collect()
-        .foreach {
+      testTransformer[(Double, Double, Vector)](datasetTweedie, model,
+        "features", "prediction", "linkPrediction") {
           case Row(features: DenseVector, prediction1: Double, linkPrediction1: Double) =>
             val eta = BLAS.dot(features, model.coefficients) + model.intercept
             val prediction2 = familyLink.fitted(eta)
@@ -661,7 +660,8 @@ class GeneralizedLinearRegressionSuite
             assert(linkPrediction1 ~= linkPrediction2 relTol 1E-5, "Link Prediction mismatch: " +
               s"GLM with tweedie family, linkPower = $linkPower, fitIntercept = $fitIntercept " +
               s"and variancePower = $variancePower.")
-        }
+      }
+
       idx += 1
     }
   }
@@ -724,8 +724,8 @@ class GeneralizedLinearRegressionSuite
           s"fitIntercept = $fitIntercept and variancePower = $variancePower.")
 
         val familyLink = FamilyAndLink(trainer)
-        model.transform(datasetTweedie).select("features", "prediction", "linkPrediction").collect()
-          .foreach {
+        testTransformer[(Double, Double, Vector)](datasetTweedie, model,
+          "features", "prediction", "linkPrediction") {
             case Row(features: DenseVector, prediction1: Double, linkPrediction1: Double) =>
               val eta = BLAS.dot(features, model.coefficients) + model.intercept
               val prediction2 = familyLink.fitted(eta)
@@ -736,7 +736,8 @@ class GeneralizedLinearRegressionSuite
               assert(linkPrediction1 ~= linkPrediction2 relTol 1E-5, "Link Prediction mismatch: " +
                 s"GLM with tweedie family, fitIntercept = $fitIntercept " +
                 s"and variancePower = $variancePower.")
-          }
+        }
+
         idx += 1
       }
     }
@@ -861,8 +862,8 @@ class GeneralizedLinearRegressionSuite
           s" and fitIntercept = $fitIntercept.")
 
         val familyLink = FamilyAndLink(trainer)
-        model.transform(dataset).select("features", "offset", "prediction", "linkPrediction")
-          .collect().foreach {
+        testTransformer[(Double, Double, Double, Vector)](dataset, model,
+          "features", "offset", "prediction", "linkPrediction") {
           case Row(features: DenseVector, offset: Double, prediction1: Double,
           linkPrediction1: Double) =>
             val eta = BLAS.dot(features, model.coefficients) + model.intercept + offset
diff --git a/mllib/src/test/scala/org/apache/spark/ml/regression/IsotonicRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/regression/IsotonicRegressionSuite.scala
index 180f5f7ce5..18fbbce936 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/regression/IsotonicRegressionSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/regression/IsotonicRegressionSuite.scala
@@ -17,15 +17,12 @@
 
 package org.apache.spark.ml.regression
 
-import org.apache.spark.SparkFunSuite
 import org.apache.spark.ml.linalg.Vectors
 import org.apache.spark.ml.param.ParamsSuite
-import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTestingUtils}
-import org.apache.spark.mllib.util.MLlibTestSparkContext
+import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTest, MLTestingUtils}
 import org.apache.spark.sql.{DataFrame, Row}
 
-class IsotonicRegressionSuite
-  extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest {
+class IsotonicRegressionSuite extends MLTest with DefaultReadWriteTest {
 
   import testImplicits._
 
@@ -44,13 +41,11 @@ class IsotonicRegressionSuite
 
     val model = ir.fit(dataset)
 
-    val predictions = model
-      .transform(dataset)
-      .select("prediction").rdd.map { case Row(pred) =>
-        pred
-      }.collect()
-
-    assert(predictions === Array(1, 2, 2, 2, 6, 16.5, 16.5, 17, 18))
+    testTransformerByGlobalCheckFunc[(Double, Double, Double)](dataset, model,
+      "prediction") { case rows: Seq[Row] =>
+      val predictions = rows.map(_.getDouble(0))
+      assert(predictions === Array(1, 2, 2, 2, 6, 16.5, 16.5, 17, 18))
+    }
 
     assert(model.boundaries === Vectors.dense(0, 1, 3, 4, 5, 6, 7, 8))
     assert(model.predictions === Vectors.dense(1, 2, 2, 6, 16.5, 16.5, 17.0, 18.0))
@@ -64,13 +59,11 @@ class IsotonicRegressionSuite
     val model = ir.fit(dataset)
     val features = generatePredictionInput(Seq(-2.0, -1.0, 0.5, 0.75, 1.0, 2.0, 9.0))
 
-    val predictions = model
-      .transform(features)
-      .select("prediction").rdd.map {
-        case Row(pred) => pred
-      }.collect()
-
-    assert(predictions === Array(7, 7, 6, 5.5, 5, 4, 1))
+    testTransformerByGlobalCheckFunc[Tuple1[Double]](features, model,
+      "prediction") { case rows: Seq[Row] =>
+      val predictions = rows.map(_.getDouble(0))
+      assert(predictions === Array(7, 7, 6, 5.5, 5, 4, 1))
+    }
   }
 
   test("params validation") {
@@ -157,13 +150,11 @@ class IsotonicRegressionSuite
 
     val features = generatePredictionInput(Seq(2.0, 3.0, 4.0, 5.0))
 
-    val predictions = model
-      .transform(features)
-      .select("prediction").rdd.map {
-      case Row(pred) => pred
-    }.collect()
-
-    assert(predictions === Array(3.5, 5.0, 5.0, 5.0))
+    testTransformerByGlobalCheckFunc[Tuple1[Double]](features, model,
+      "prediction") { case rows: Seq[Row] =>
+      val predictions = rows.map(_.getDouble(0))
+      assert(predictions === Array(3.5, 5.0, 5.0, 5.0))
+    }
   }
 
   test("read/write") {
diff --git a/mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala
index 9bb2895858..d42cb17144 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala
@@ -19,14 +19,13 @@ package org.apache.spark.ml.regression
 
 import scala.util.Random
 
-import org.apache.spark.SparkFunSuite
 import org.apache.spark.ml.feature.Instance
 import org.apache.spark.ml.feature.LabeledPoint
 import org.apache.spark.ml.linalg.{DenseVector, Vector, Vectors}
 import org.apache.spark.ml.param.{ParamMap, ParamsSuite}
 import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTest, MLTestingUtils}
 import org.apache.spark.ml.util.TestingUtils._
-import org.apache.spark.mllib.util.{LinearDataGenerator, MLlibTestSparkContext}
+import org.apache.spark.mllib.util.LinearDataGenerator
 import org.apache.spark.sql.{DataFrame, Row}
 
 class LinearRegressionSuite extends MLTest with DefaultReadWriteTest {
@@ -363,8 +362,8 @@ class LinearRegressionSuite extends MLTest with DefaultReadWriteTest {
       assert(model2.intercept ~== interceptR2 relTol 1E-3)
       assert(model2.coefficients ~= coefficientsR2 relTol 1E-3)
 
-      model1.transform(datasetWithDenseFeature).select("features", "prediction")
-        .collect().foreach {
+      testTransformer[(Double, Vector)](datasetWithDenseFeature, model1,
+        "features", "prediction") {
           case Row(features: DenseVector, prediction1: Double) =>
             val prediction2 =
               features(0) * model1.coefficients(0) + features(1) * model1.coefficients(1) +
@@ -416,8 +415,8 @@ class LinearRegressionSuite extends MLTest with DefaultReadWriteTest {
       assert(model2.intercept ~== interceptR2 absTol 1E-2)
       assert(model2.coefficients ~= coefficientsR2 relTol 1E-2)
 
-      model1.transform(datasetWithDenseFeature).select("features", "prediction")
-        .collect().foreach {
+      testTransformer[(Double, Vector)](datasetWithDenseFeature, model1,
+        "features", "prediction") {
           case Row(features: DenseVector, prediction1: Double) =>
             val prediction2 =
               features(0) * model1.coefficients(0) + features(1) * model1.coefficients(1) +
@@ -467,7 +466,8 @@ class LinearRegressionSuite extends MLTest with DefaultReadWriteTest {
       assert(model2.intercept ~== interceptR2 relTol 1E-2)
       assert(model2.coefficients ~= coefficientsR2 relTol 1E-2)
 
-      model1.transform(datasetWithDenseFeature).select("features", "prediction").collect().foreach {
+      testTransformer[(Double, Vector)](datasetWithDenseFeature, model1,
+        "features", "prediction") {
         case Row(features: DenseVector, prediction1: Double) =>
           val prediction2 =
             features(0) * model1.coefficients(0) + features(1) * model1.coefficients(1) +
@@ -518,7 +518,8 @@ class LinearRegressionSuite extends MLTest with DefaultReadWriteTest {
       assert(model2.intercept ~== interceptR2 absTol 1E-2)
       assert(model2.coefficients ~= coefficientsR2 relTol 1E-2)
 
-      model1.transform(datasetWithDenseFeature).select("features", "prediction").collect().foreach {
+      testTransformer[(Double, Vector)](datasetWithDenseFeature, model1,
+        "features", "prediction") {
         case Row(features: DenseVector, prediction1: Double) =>
           val prediction2 =
             features(0) * model1.coefficients(0) + features(1) * model1.coefficients(1) +
@@ -570,8 +571,8 @@ class LinearRegressionSuite extends MLTest with DefaultReadWriteTest {
       assert(model2.intercept ~== interceptR2 relTol 1E-2)
       assert(model2.coefficients ~= coefficientsR2 relTol 1E-2)
 
-      model1.transform(datasetWithDenseFeature).select("features", "prediction")
-        .collect().foreach {
+      testTransformer[(Double, Vector)](datasetWithDenseFeature, model1,
+        "features", "prediction") {
         case Row(features: DenseVector, prediction1: Double) =>
           val prediction2 =
             features(0) * model1.coefficients(0) + features(1) * model1.coefficients(1) +
@@ -624,8 +625,8 @@ class LinearRegressionSuite extends MLTest with DefaultReadWriteTest {
       assert(model2.intercept ~== interceptR2 absTol 1E-2)
       assert(model2.coefficients ~= coefficientsR2 relTol 1E-2)
 
-      model1.transform(datasetWithDenseFeature).select("features", "prediction")
-        .collect().foreach {
+      testTransformer[(Double, Vector)](datasetWithDenseFeature, model1,
+        "features", "prediction") {
         case Row(features: DenseVector, prediction1: Double) =>
           val prediction2 =
             features(0) * model1.coefficients(0) + features(1) * model1.coefficients(1) +
diff --git a/mllib/src/test/scala/org/apache/spark/ml/util/MLTest.scala b/mllib/src/test/scala/org/apache/spark/ml/util/MLTest.scala
index 7a5426ebad..17678aa611 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/util/MLTest.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/util/MLTest.scala
@@ -53,12 +53,12 @@ trait MLTest extends StreamTest with TempDirectory { self: Suite =>
     }
   }
 
-  def testTransformerOnStreamData[A : Encoder](
+  private[util] def testTransformerOnStreamData[A : Encoder](
       dataframe: DataFrame,
       transformer: Transformer,
       firstResultCol: String,
       otherResultCols: String*)
-      (checkFunction: Row => Unit): Unit = {
+      (globalCheckFunction: Seq[Row] => Unit): Unit = {
 
     val columnNames = dataframe.schema.fieldNames
     val stream = MemoryStream[A]
@@ -70,22 +70,43 @@ trait MLTest extends StreamTest with TempDirectory { self: Suite =>
       .select(firstResultCol, otherResultCols: _*)
     testStream(streamOutput) (
       AddData(stream, data: _*),
-      CheckAnswer(checkFunction)
+      CheckAnswer(globalCheckFunction)
     )
   }
 
+  private[util] def testTransformerOnDF(
+      dataframe: DataFrame,
+      transformer: Transformer,
+      firstResultCol: String,
+      otherResultCols: String*)
+      (globalCheckFunction: Seq[Row] => Unit): Unit = {
+    val dfOutput = transformer.transform(dataframe)
+    val outputs = dfOutput.select(firstResultCol, otherResultCols: _*).collect()
+    globalCheckFunction(outputs)
+  }
+
   def testTransformer[A : Encoder](
       dataframe: DataFrame,
       transformer: Transformer,
       firstResultCol: String,
       otherResultCols: String*)
       (checkFunction: Row => Unit): Unit = {
-    testTransformerOnStreamData(dataframe, transformer, firstResultCol,
-      otherResultCols: _*)(checkFunction)
+    testTransformerByGlobalCheckFunc(
+      dataframe,
+      transformer,
+      firstResultCol,
+      otherResultCols: _*) { rows: Seq[Row] => rows.foreach(checkFunction(_)) }
+  }
 
-    val dfOutput = transformer.transform(dataframe)
-    dfOutput.select(firstResultCol, otherResultCols: _*).collect().foreach { row =>
-      checkFunction(row)
-    }
+  def testTransformerByGlobalCheckFunc[A : Encoder](
+      dataframe: DataFrame,
+      transformer: Transformer,
+      firstResultCol: String,
+      otherResultCols: String*)
+      (globalCheckFunction: Seq[Row] => Unit): Unit = {
+    testTransformerOnStreamData(dataframe, transformer, firstResultCol,
+      otherResultCols: _*)(globalCheckFunction)
+    testTransformerOnDF(dataframe, transformer, firstResultCol,
+      otherResultCols: _*)(globalCheckFunction)
   }
 }
diff --git a/mllib/src/test/scala/org/apache/spark/ml/util/MLTestSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/util/MLTestSuite.scala
index 56217ec4f3..20c5b5395f 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/util/MLTestSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/util/MLTestSuite.scala
@@ -17,7 +17,6 @@
 
 package org.apache.spark.ml.util
 
-import org.apache.spark.ml.{PipelineModel, Transformer}
 import org.apache.spark.ml.feature.StringIndexer
 import org.apache.spark.sql.Row
 
@@ -32,10 +31,13 @@ class MLTestSuite extends MLTest {
     val indexer = new StringIndexer().setStringOrderType("alphabetAsc")
       .setInputCol("label").setOutputCol("indexed")
     val indexerModel = indexer.fit(data)
-    testTransformerOnStreamData[(Int, String)](data, indexerModel, "id", "indexed") {
+    testTransformer[(Int, String)](data, indexerModel, "id", "indexed") {
       case Row(id: Int, indexed: Double) =>
         assert(id === indexed.toInt)
     }
+    testTransformerByGlobalCheckFunc[(Int, String)] (data, indexerModel, "id", "indexed") { rows =>
+      assert(rows.map(_.getDouble(1)).max === 5.0)
+    }
 
     intercept[Exception] {
       testTransformerOnStreamData[(Int, String)](data, indexerModel, "id", "indexed") {
@@ -43,5 +45,11 @@ class MLTestSuite extends MLTest {
           assert(id != indexed.toInt)
       }
     }
+    intercept[Exception] {
+      testTransformerOnStreamData[(Int, String)](data, indexerModel, "id", "indexed") {
+        rows: Seq[Row] =>
+          assert(rows.map(_.getDouble(1)).max === 1.0)
+      }
+    }
   }
 }
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamTest.scala
index fb9ebc81dd..4b7f0fbe97 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamTest.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamTest.scala
@@ -137,8 +137,8 @@ trait StreamTest extends QueryTest with SharedSQLContext with TimeLimits with Be
 
     def apply(rows: Row*): CheckAnswerRows = CheckAnswerRows(rows, false, false)
 
-    def apply(checkFunction: Row => Unit): CheckAnswerRowsByFunc =
-      CheckAnswerRowsByFunc(checkFunction, false)
+    def apply(globalCheckFunction: Seq[Row] => Unit): CheckAnswerRowsByFunc =
+      CheckAnswerRowsByFunc(globalCheckFunction, false)
   }
 
   /**
@@ -161,8 +161,8 @@ trait StreamTest extends QueryTest with SharedSQLContext with TimeLimits with Be
 
     def apply(rows: Row*): CheckAnswerRows = CheckAnswerRows(rows, true, false)
 
-    def apply(checkFunction: Row => Unit): CheckAnswerRowsByFunc =
-      CheckAnswerRowsByFunc(checkFunction, true)
+    def apply(globalCheckFunction: Seq[Row] => Unit): CheckAnswerRowsByFunc =
+      CheckAnswerRowsByFunc(globalCheckFunction, true)
   }
 
   case class CheckAnswerRows(expectedAnswer: Seq[Row], lastOnly: Boolean, isSorted: Boolean)
@@ -177,9 +177,10 @@ trait StreamTest extends QueryTest with SharedSQLContext with TimeLimits with Be
     private def operatorName = if (lastOnly) "CheckLastBatch" else "CheckAnswer"
   }
 
-  case class CheckAnswerRowsByFunc(checkFunction: Row => Unit, lastOnly: Boolean)
-      extends StreamAction with StreamMustBeRunning {
-    override def toString: String = s"$operatorName: ${checkFunction.toString()}"
+  case class CheckAnswerRowsByFunc(
+      globalCheckFunction: Seq[Row] => Unit,
+      lastOnly: Boolean) extends StreamAction with StreamMustBeRunning {
+    override def toString: String = s"$operatorName"
     private def operatorName = if (lastOnly) "CheckLastBatchByFunc" else "CheckAnswerByFunc"
   }
 
@@ -639,14 +640,12 @@ trait StreamTest extends QueryTest with SharedSQLContext with TimeLimits with Be
               error => failTest(error)
             }
 
-          case CheckAnswerRowsByFunc(checkFunction, lastOnly) =>
+          case CheckAnswerRowsByFunc(globalCheckFunction, lastOnly) =>
             val sparkAnswer = fetchStreamAnswer(currentStream, lastOnly)
-            sparkAnswer.foreach { row =>
-              try {
-                checkFunction(row)
-              } catch {
-                case e: Throwable => failTest(e.toString)
-              }
+            try {
+              globalCheckFunction(sparkAnswer)
+            } catch {
+              case e: Throwable => failTest(e.toString)
             }
         }
         pos += 1
-- 
GitLab