diff --git a/mllib/pom.xml b/mllib/pom.xml
index 925b5422a54cc2e81e78174f31398b83a6b1b381..a906c9e02cd4cfcc7649e296bb7cd0158dfe7eed 100644
--- a/mllib/pom.xml
+++ b/mllib/pom.xml
@@ -60,6 +60,20 @@
       <artifactId>spark-sql_${scala.binary.version}</artifactId>
       <version>${project.version}</version>
     </dependency>
+    <dependency>
+      <groupId>org.apache.spark</groupId>
+      <artifactId>spark-catalyst_${scala.binary.version}</artifactId>
+      <version>${project.version}</version>
+      <type>test-jar</type>
+      <scope>test</scope>
+    </dependency>
+    <dependency>
+      <groupId>org.apache.spark</groupId>
+      <artifactId>spark-sql_${scala.binary.version}</artifactId>
+      <version>${project.version}</version>
+      <type>test-jar</type>
+      <scope>test</scope>
+    </dependency>
     <dependency>
       <groupId>org.apache.spark</groupId>
       <artifactId>spark-graphx_${scala.binary.version}</artifactId>
diff --git a/mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala
index 0e0be58dbf022ca867014bf8bb6628f390cd075a..aec5ac0e758964b9d3a3b081f1d2a949aa7bbfda 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala
@@ -24,13 +24,12 @@ import org.apache.spark.ml.feature.Instance
 import org.apache.spark.ml.feature.LabeledPoint
 import org.apache.spark.ml.linalg.{DenseVector, Vector, Vectors}
 import org.apache.spark.ml.param.{ParamMap, ParamsSuite}
-import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTestingUtils}
+import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTest, MLTestingUtils}
 import org.apache.spark.ml.util.TestingUtils._
 import org.apache.spark.mllib.util.{LinearDataGenerator, MLlibTestSparkContext}
 import org.apache.spark.sql.{DataFrame, Row}
 
-class LinearRegressionSuite
-  extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest {
+class LinearRegressionSuite extends MLTest with DefaultReadWriteTest {
 
   import testImplicits._
 
@@ -233,7 +232,8 @@ class LinearRegressionSuite
       assert(model2.intercept ~== interceptR relTol 1E-3)
       assert(model2.coefficients ~= coefficientsR relTol 1E-3)
 
-      model1.transform(datasetWithDenseFeature).select("features", "prediction").collect().foreach {
+      testTransformer[(Double, Vector)](datasetWithDenseFeature, model1,
+        "features", "prediction") {
         case Row(features: DenseVector, prediction1: Double) =>
           val prediction2 =
             features(0) * model1.coefficients(0) + features(1) * model1.coefficients(1) +
diff --git a/mllib/src/test/scala/org/apache/spark/ml/util/MLTest.scala b/mllib/src/test/scala/org/apache/spark/ml/util/MLTest.scala
new file mode 100644
index 0000000000000000000000000000000000000000..7a5426ebadaa5c9dc2db21c4e489b01e679f8856
--- /dev/null
+++ b/mllib/src/test/scala/org/apache/spark/ml/util/MLTest.scala
@@ -0,0 +1,103 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.ml.util
+
+import java.io.File
+
+import org.scalatest.Suite
+
+import org.apache.spark.SparkContext
+import org.apache.spark.ml.{PipelineModel, Transformer}
+import org.apache.spark.sql.{DataFrame, Encoder, Row}
+import org.apache.spark.sql.execution.streaming.MemoryStream
+import org.apache.spark.sql.streaming.StreamTest
+import org.apache.spark.sql.test.TestSparkSession
+import org.apache.spark.util.Utils
+
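+/**
+ * Base trait for MLlib test suites. Sets up a TestSparkSession with a checkpoint directory and
+ * provides helpers for checking a Transformer against both batch and streaming DataFrames.
+ */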
+trait MLTest extends StreamTest with TempDirectory { self: Suite =>
+
+  @transient var sc: SparkContext = _
+  @transient var checkpointDir: String = _
+
+  protected override def createSparkSession: TestSparkSession = {
+    new TestSparkSession(new SparkContext("local[2]", "MLlibUnitTest", sparkConf))
+  }
+
+  override def beforeAll(): Unit = {
+    super.beforeAll()
+    sc = spark.sparkContext
+    checkpointDir = Utils.createDirectory(tempDir.getCanonicalPath, "checkpoints").toString
+    sc.setCheckpointDir(checkpointDir)
+  }
+
+  override def afterAll() {
+    try {
+      Utils.deleteRecursively(new File(checkpointDir))
+    } finally {
+      super.afterAll()
+    }
+  }
+
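+  /**
+   * Replays `dataframe` through a MemoryStream, applies `transformer` to the resulting
+   * streaming DataFrame, and runs `checkFunction` on every row of the selected result columns.
+   */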
+  def testTransformerOnStreamData[A : Encoder](
+      dataframe: DataFrame,
+      transformer: Transformer,
+      firstResultCol: String,
+      otherResultCols: String*)
+      (checkFunction: Row => Unit): Unit = {
+
+    val columnNames = dataframe.schema.fieldNames
+    val stream = MemoryStream[A]
+    val streamDF = stream.toDS().toDF(columnNames: _*)
+
+    val data = dataframe.as[A].collect()
+
+    val streamOutput = transformer.transform(streamDF)
+      .select(firstResultCol, otherResultCols: _*)
+    testStream(streamOutput) (
+      AddData(stream, data: _*),
+      CheckAnswer(checkFunction)
+    )
+  }
+
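+  /**
+   * Applies `checkFunction` to the transformer's output twice: once with `dataframe` replayed
+   * as a stream (via `testTransformerOnStreamData`) and once with it as a plain batch DataFrame.
+   */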
+  def testTransformer[A : Encoder](
+      dataframe: DataFrame,
+      transformer: Transformer,
+      firstResultCol: String,
+      otherResultCols: String*)
+      (checkFunction: Row => Unit): Unit = {
+    testTransformerOnStreamData(dataframe, transformer, firstResultCol,
+      otherResultCols: _*)(checkFunction)
+
+    val dfOutput = transformer.transform(dataframe)
+    dfOutput.select(firstResultCol, otherResultCols: _*).collect().foreach { row =>
+      checkFunction(row)
+    }
+  }
+}
diff --git a/mllib/src/test/scala/org/apache/spark/ml/util/MLTestSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/util/MLTestSuite.scala
new file mode 100644
index 0000000000000000000000000000000000000000..56217ec4f3b0cf6abaa928c54bec5ec6a44e23fa
--- /dev/null
+++ b/mllib/src/test/scala/org/apache/spark/ml/util/MLTestSuite.scala
@@ -0,0 +1,48 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.ml.util
+
+import org.apache.spark.ml.{PipelineModel, Transformer}
+import org.apache.spark.ml.feature.StringIndexer
+import org.apache.spark.sql.Row
+
+class MLTestSuite extends MLTest {
+
+  import testImplicits._
+
+  test("test transformer on stream data") {
+
+    val data = Seq((0, "a"), (1, "b"), (2, "c"), (3, "d"), (4, "e"), (5, "f"))
+      .toDF("id", "label")
+    val indexer = new StringIndexer().setStringOrderType("alphabetAsc")
+      .setInputCol("label").setOutputCol("indexed")
+    val indexerModel = indexer.fit(data)
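+    // With "alphabetAsc" ordering, labels "a".."f" get indices 0 to 5, matching the id column.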
+    testTransformerOnStreamData[(Int, String)](data, indexerModel, "id", "indexed") {
+      case Row(id: Int, indexed: Double) =>
+        assert(id === indexed.toInt)
+    }
+
+    intercept[Exception] {
+      testTransformerOnStreamData[(Int, String)](data, indexerModel, "id", "indexed") {
+        case Row(id: Int, indexed: Double) =>
+          assert(id != indexed.toInt)
+      }
+    }
+  }
+}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamTest.scala
index e68fca050571f7238c8a75fea87d1d2cf1f89fba..dc5b998ad68b5f80e8c962b1e3f46ccbe4e5a664 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamTest.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamTest.scala
@@ -133,6 +133,9 @@ trait StreamTest extends QueryTest with SharedSQLContext with TimeLimits with Be
     }
 
     def apply(rows: Row*): CheckAnswerRows = CheckAnswerRows(rows, false, false)
+
+    def apply(checkFunction: Row => Unit): CheckAnswerRowsByFunc =
+      CheckAnswerRowsByFunc(checkFunction, false)
   }
 
   /**
@@ -154,6 +157,9 @@ trait StreamTest extends QueryTest with SharedSQLContext with TimeLimits with Be
     }
 
     def apply(rows: Row*): CheckAnswerRows = CheckAnswerRows(rows, true, false)
+
+    def apply(checkFunction: Row => Unit): CheckAnswerRowsByFunc =
+      CheckAnswerRowsByFunc(checkFunction, true)
   }
 
   case class CheckAnswerRows(expectedAnswer: Seq[Row], lastOnly: Boolean, isSorted: Boolean)
@@ -162,6 +168,12 @@ trait StreamTest extends QueryTest with SharedSQLContext with TimeLimits with Be
     private def operatorName = if (lastOnly) "CheckLastBatch" else "CheckAnswer"
   }
 
+  case class CheckAnswerRowsByFunc(checkFunction: Row => Unit, lastOnly: Boolean)
+      extends StreamAction with StreamMustBeRunning {
+    override def toString: String = s"$operatorName: ${checkFunction.toString()}"
+    private def operatorName = if (lastOnly) "CheckLastBatchByFunc" else "CheckAnswerByFunc"
+  }
+
   /** Stops the stream. It must currently be running. */
   case object StopStream extends StreamAction with StreamMustBeRunning
 
@@ -352,6 +364,31 @@
          """.stripMargin)
     }
 
+    def fetchStreamAnswer(currentStream: StreamExecution, lastOnly: Boolean) = {
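+      // Returns the sink's contents once all added data has been processed: the latest batch
+      // only when `lastOnly` is set, otherwise every row collected so far.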
+      verify(currentStream != null, "stream not running")
+      // Get the map of source index to the current source objects
+      val indexToSource = currentStream
+        .logicalPlan
+        .collect { case StreamingExecutionRelation(s, _) => s }
+        .zipWithIndex
+        .map(_.swap)
+        .toMap
+
+      // Block until all data added has been processed for all the sources
+      awaiting.foreach { case (sourceIndex, offset) =>
+        failAfter(streamingTimeout) {
+          currentStream.awaitOffset(indexToSource(sourceIndex), offset)
+        }
+      }
+
+      try if (lastOnly) sink.latestBatchData else sink.allData catch {
+        case e: Exception =>
+          failTest("Exception while getting data from sink", e)
+      }
+    }
+
     var manualClockExpectedTime = -1L
     val defaultCheckpointLocation =
       Utils.createTempDir(namePrefix = "streaming.metadata").getCanonicalPath
@@ -552,30 +589,21 @@
             e.runAction()
 
           case CheckAnswerRows(expectedAnswer, lastOnly, isSorted) =>
-            verify(currentStream != null, "stream not running")
-            // Get the map of source index to the current source objects
-            val indexToSource = currentStream
-              .logicalPlan
-              .collect { case StreamingExecutionRelation(s, _) => s }
-              .zipWithIndex
-              .map(_.swap)
-              .toMap
-
-            // Block until all data added has been processed for all the source
-            awaiting.foreach { case (sourceIndex, offset) =>
-              failAfter(streamingTimeout) {
-                currentStream.awaitOffset(indexToSource(sourceIndex), offset)
-              }
-            }
-
-            val sparkAnswer = try if (lastOnly) sink.latestBatchData else sink.allData catch {
-              case e: Exception =>
-                failTest("Exception while getting data from sink", e)
-            }
-
+            val sparkAnswer = fetchStreamAnswer(currentStream, lastOnly)
             QueryTest.sameRows(expectedAnswer, sparkAnswer, isSorted).foreach {
               error => failTest(error)
             }
+
+          case CheckAnswerRowsByFunc(checkFunction, lastOnly) =>
+            val sparkAnswer = fetchStreamAnswer(currentStream, lastOnly)
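+            // Run the user-supplied check on every row; surface any error as a test failure.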
+            sparkAnswer.foreach { row =>
+              try {
+                checkFunction(row)
+              } catch {
+                case e: Throwable => failTest(e.toString)
+              }
+            }
         }
         pos += 1
       }
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/test/TestSQLContext.scala b/sql/core/src/test/scala/org/apache/spark/sql/test/TestSQLContext.scala
index 959edf9a4937181317ea3c3a01fe451242db9afb..4286e8a6ca2c81a3488f63a7d1a8cf76ff33c3bc 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/test/TestSQLContext.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/test/TestSQLContext.scala
@@ -24,7 +24,7 @@ import org.apache.spark.sql.internal.{SessionState, SessionStateBuilder, SQLConf
 /**
  * A special `SparkSession` prepared for testing.
  */
-private[sql] class TestSparkSession(sc: SparkContext) extends SparkSession(sc) { self =>
+private[spark] class TestSparkSession(sc: SparkContext) extends SparkSession(sc) { self =>
   def this(sparkConf: SparkConf) {
     this(new SparkContext("local[2]", "test-sql-context",
       sparkConf.set("spark.sql.testkey", "true")))