From ddec173cba63df723cd94508121d8c06d8c153c6 Mon Sep 17 00:00:00 2001
From: Xiangrui Meng <meng@databricks.com>
Date: Wed, 20 May 2015 20:30:39 -0700
Subject: [PATCH] [SPARK-7774] [MLLIB] add sqlContext to MLlibTestSparkContext

This adds a shared sqlContext to MLlibTestSparkContext, so test suites that
require a SQLContext no longer need to create and manage their own.
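
The shared instance is created in beforeAll and nulled out in afterAll
together with sc. A minimal sketch of the resulting usage (ExampleSuite and
its toy data are hypothetical, for illustration only; they are not part of
this patch):

    import org.scalatest.FunSuite

    import org.apache.spark.mllib.util.MLlibTestSparkContext
    import org.apache.spark.sql.DataFrame

    class ExampleSuite extends FunSuite with MLlibTestSparkContext {

      @transient var dataset: DataFrame = _

      override def beforeAll(): Unit = {
        super.beforeAll()
        // sc and sqlContext are both inherited from MLlibTestSparkContext;
        // no "new SQLContext(sc)" boilerplate is needed here anymore.
        dataset = sqlContext.createDataFrame(
          sc.parallelize(Seq((0, "a"), (1, "b"), (2, "c")), 2)).toDF("id", "label")
      }
    }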

Author: Xiangrui Meng <meng@databricks.com>

Closes #6303 from mengxr/SPARK-7774 and squashes the following commits:

0622b5a [Xiangrui Meng] update some other test suites
e1f9b8d [Xiangrui Meng] add sqlContext to MLlibTestSparkContext
---
 .../ml/classification/LogisticRegressionSuite.scala   |  4 +---
 .../spark/ml/classification/OneVsRestSuite.scala      |  7 +++----
 .../org/apache/spark/ml/feature/BinarizerSuite.scala  |  6 +-----
 .../org/apache/spark/ml/feature/BucketizerSuite.scala |  9 +--------
 .../scala/org/apache/spark/ml/feature/IDFSuite.scala  |  9 +--------
 .../apache/spark/ml/feature/OneHotEncoderSuite.scala  |  8 +-------
 .../spark/ml/feature/PolynomialExpansionSuite.scala   | 11 ++---------
 .../apache/spark/ml/feature/StringIndexerSuite.scala  |  7 -------
 .../org/apache/spark/ml/feature/TokenizerSuite.scala  |  9 +--------
 .../spark/ml/feature/VectorAssemblerSuite.scala       |  9 +--------
 .../apache/spark/ml/feature/VectorIndexerSuite.scala  |  6 +-----
 .../org/apache/spark/ml/recommendation/ALSSuite.scala |  2 --
 .../spark/ml/regression/LinearRegressionSuite.scala   |  4 +---
 .../spark/mllib/util/MLlibTestSparkContext.scala      |  8 ++++++--
 14 files changed, 20 insertions(+), 79 deletions(-)

diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala
index 97f9749cb4..9f77d5f3ef 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala
@@ -23,18 +23,16 @@ import org.apache.spark.mllib.classification.LogisticRegressionSuite._
 import org.apache.spark.mllib.linalg.Vector
 import org.apache.spark.mllib.util.MLlibTestSparkContext
 import org.apache.spark.mllib.util.TestingUtils._
-import org.apache.spark.sql.{DataFrame, Row, SQLContext}
+import org.apache.spark.sql.{DataFrame, Row}
 
 class LogisticRegressionSuite extends FunSuite with MLlibTestSparkContext {
 
-  @transient var sqlContext: SQLContext = _
   @transient var dataset: DataFrame = _
   @transient var binaryDataset: DataFrame = _
   private val eps: Double = 1e-5
 
   override def beforeAll(): Unit = {
     super.beforeAll()
-    sqlContext = new SQLContext(sc)
 
     dataset = sqlContext.createDataFrame(generateLogisticInput(1.0, 1.0, nPoints = 100, seed = 42))
 
diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/OneVsRestSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/OneVsRestSuite.scala
index 990cfb08af..770b56890f 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/classification/OneVsRestSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/classification/OneVsRestSuite.scala
@@ -21,24 +21,23 @@ import org.scalatest.FunSuite
 
 import org.apache.spark.ml.attribute.NominalAttribute
 import org.apache.spark.ml.util.MetadataUtils
-import org.apache.spark.mllib.classification.LogisticRegressionSuite._
 import org.apache.spark.mllib.classification.LogisticRegressionWithLBFGS
+import org.apache.spark.mllib.classification.LogisticRegressionSuite._
 import org.apache.spark.mllib.evaluation.MulticlassMetrics
 import org.apache.spark.mllib.regression.LabeledPoint
 import org.apache.spark.mllib.util.MLlibTestSparkContext
 import org.apache.spark.mllib.util.TestingUtils._
 import org.apache.spark.rdd.RDD
-import org.apache.spark.sql.{DataFrame, SQLContext}
+import org.apache.spark.sql.DataFrame
 
 class OneVsRestSuite extends FunSuite with MLlibTestSparkContext {
 
-  @transient var sqlContext: SQLContext = _
   @transient var dataset: DataFrame = _
   @transient var rdd: RDD[LabeledPoint] = _
 
   override def beforeAll(): Unit = {
     super.beforeAll()
-    sqlContext = new SQLContext(sc)
+
     val nPoints = 1000
 
     // The following weights and xMean/xVariance are computed from iris dataset with lambda=0.2.
diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/BinarizerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/BinarizerSuite.scala
index caf1b75959..8f6c6b39dc 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/feature/BinarizerSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/feature/BinarizerSuite.scala
@@ -20,18 +20,14 @@ package org.apache.spark.ml.feature
 import org.scalatest.FunSuite
 
 import org.apache.spark.mllib.util.MLlibTestSparkContext
-import org.apache.spark.mllib.util.TestingUtils._
-import org.apache.spark.sql.{DataFrame, Row, SQLContext}
-
+import org.apache.spark.sql.{DataFrame, Row}
 
 class BinarizerSuite extends FunSuite with MLlibTestSparkContext {
 
   @transient var data: Array[Double] = _
-  @transient var sqlContext: SQLContext = _
 
   override def beforeAll(): Unit = {
     super.beforeAll()
-    sqlContext = new SQLContext(sc)
     data = Array(0.1, -0.5, 0.2, -0.3, 0.8, 0.7, -0.1, -0.4)
   }
 
diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/BucketizerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/BucketizerSuite.scala
index 20d2f3ac66..0391bd8427 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/feature/BucketizerSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/feature/BucketizerSuite.scala
@@ -25,17 +25,10 @@ import org.apache.spark.SparkException
 import org.apache.spark.mllib.linalg.Vectors
 import org.apache.spark.mllib.util.MLlibTestSparkContext
 import org.apache.spark.mllib.util.TestingUtils._
-import org.apache.spark.sql.{DataFrame, Row, SQLContext}
+import org.apache.spark.sql.{DataFrame, Row}
 
 class BucketizerSuite extends FunSuite with MLlibTestSparkContext {
 
-  @transient private var sqlContext: SQLContext = _
-
-  override def beforeAll(): Unit = {
-    super.beforeAll()
-    sqlContext = new SQLContext(sc)
-  }
-
   test("Bucket continuous features, without -inf,inf") {
     // Check a set of valid feature values.
     val splits = Array(-0.5, 0.0, 0.5)
diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/IDFSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/IDFSuite.scala
index eaee3443c1..f85e854716 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/feature/IDFSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/feature/IDFSuite.scala
@@ -22,17 +22,10 @@ import org.scalatest.FunSuite
 import org.apache.spark.mllib.linalg.{DenseVector, SparseVector, Vector, Vectors}
 import org.apache.spark.mllib.util.MLlibTestSparkContext
 import org.apache.spark.mllib.util.TestingUtils._
-import org.apache.spark.sql.{Row, SQLContext}
+import org.apache.spark.sql.Row
 
 class IDFSuite extends FunSuite with MLlibTestSparkContext {
 
-  @transient var sqlContext: SQLContext = _
-
-  override def beforeAll(): Unit = {
-    super.beforeAll()
-    sqlContext = new SQLContext(sc)
-  }
-
   def scaleDataWithIDF(dataSet: Array[Vector], model: Vector): Array[Vector] = {
     dataSet.map {
       case data: DenseVector =>
diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/OneHotEncoderSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/OneHotEncoderSuite.scala
index 92ec407b98..056b9eda86 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/feature/OneHotEncoderSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/feature/OneHotEncoderSuite.scala
@@ -21,16 +21,10 @@ import org.scalatest.FunSuite
 
 import org.apache.spark.mllib.linalg.Vector
 import org.apache.spark.mllib.util.MLlibTestSparkContext
-import org.apache.spark.sql.{DataFrame, SQLContext}
+import org.apache.spark.sql.DataFrame
 
 
 class OneHotEncoderSuite extends FunSuite with MLlibTestSparkContext {
-  private var sqlContext: SQLContext = _
-
-  override def beforeAll(): Unit = {
-    super.beforeAll()
-    sqlContext = new SQLContext(sc)
-  }
 
   def stringIndexed(): DataFrame = {
     val data = sc.parallelize(Seq((0, "a"), (1, "b"), (2, "c"), (3, "a"), (4, "a"), (5, "c")), 2)
diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/PolynomialExpansionSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/PolynomialExpansionSuite.scala
index c1d64fba0a..aa230ca073 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/feature/PolynomialExpansionSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/feature/PolynomialExpansionSuite.scala
@@ -18,22 +18,15 @@
 package org.apache.spark.ml.feature
 
 import org.scalatest.FunSuite
+import org.scalatest.exceptions.TestFailedException
 
 import org.apache.spark.mllib.linalg.{DenseVector, SparseVector, Vector, Vectors}
 import org.apache.spark.mllib.util.MLlibTestSparkContext
 import org.apache.spark.mllib.util.TestingUtils._
-import org.apache.spark.sql.{Row, SQLContext}
-import org.scalatest.exceptions.TestFailedException
+import org.apache.spark.sql.Row
 
 class PolynomialExpansionSuite extends FunSuite with MLlibTestSparkContext {
 
-  @transient var sqlContext: SQLContext = _
-
-  override def beforeAll(): Unit = {
-    super.beforeAll()
-    sqlContext = new SQLContext(sc)
-  }
-
   test("Polynomial expansion with default parameter") {
     val data = Array(
       Vectors.sparse(3, Seq((0, -2.0), (1, 2.3))),
diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/StringIndexerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/StringIndexerSuite.scala
index b6939e5870..89c2fe4557 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/feature/StringIndexerSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/feature/StringIndexerSuite.scala
@@ -21,15 +21,8 @@ import org.scalatest.FunSuite
 
 import org.apache.spark.ml.attribute.{Attribute, NominalAttribute}
 import org.apache.spark.mllib.util.MLlibTestSparkContext
-import org.apache.spark.sql.SQLContext
 
 class StringIndexerSuite extends FunSuite with MLlibTestSparkContext {
-  private var sqlContext: SQLContext = _
-
-  override def beforeAll(): Unit = {
-    super.beforeAll()
-    sqlContext = new SQLContext(sc)
-  }
 
   test("StringIndexer") {
     val data = sc.parallelize(Seq((0, "a"), (1, "b"), (2, "c"), (3, "a"), (4, "a"), (5, "c")), 2)
diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/TokenizerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/TokenizerSuite.scala
index d186ead8f5..a46d08d651 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/feature/TokenizerSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/feature/TokenizerSuite.scala
@@ -22,7 +22,7 @@ import scala.beans.BeanInfo
 import org.scalatest.FunSuite
 
 import org.apache.spark.mllib.util.MLlibTestSparkContext
-import org.apache.spark.sql.{DataFrame, Row, SQLContext}
+import org.apache.spark.sql.{DataFrame, Row}
 
 @BeanInfo
 case class TokenizerTestData(rawText: String, wantedTokens: Array[String])
@@ -30,13 +30,6 @@ case class TokenizerTestData(rawText: String, wantedTokens: Array[String])
 class RegexTokenizerSuite extends FunSuite with MLlibTestSparkContext {
   import org.apache.spark.ml.feature.RegexTokenizerSuite._
   
-  @transient var sqlContext: SQLContext = _
-
-  override def beforeAll(): Unit = {
-    super.beforeAll()
-    sqlContext = new SQLContext(sc)
-  }
-
   test("RegexTokenizer") {
     val tokenizer = new RegexTokenizer()
       .setInputCol("rawText")
diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/VectorAssemblerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/VectorAssemblerSuite.scala
index 0db27607bc..d0cd62c5e4 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/feature/VectorAssemblerSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/feature/VectorAssemblerSuite.scala
@@ -22,17 +22,10 @@ import org.scalatest.FunSuite
 import org.apache.spark.SparkException
 import org.apache.spark.mllib.linalg.{DenseVector, SparseVector, Vector, Vectors}
 import org.apache.spark.mllib.util.MLlibTestSparkContext
-import org.apache.spark.sql.{Row, SQLContext}
+import org.apache.spark.sql.Row
 
 class VectorAssemblerSuite extends FunSuite with MLlibTestSparkContext {
 
-  @transient var sqlContext: SQLContext = _
-
-  override def beforeAll(): Unit = {
-    super.beforeAll()
-    sqlContext = new SQLContext(sc)
-  }
-
   test("assemble") {
     import org.apache.spark.ml.feature.VectorAssembler.assemble
     assert(assemble(0.0) === Vectors.sparse(1, Array.empty, Array.empty))
diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/VectorIndexerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/VectorIndexerSuite.scala
index 38dc83b124..b11b029c63 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/feature/VectorIndexerSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/feature/VectorIndexerSuite.scala
@@ -26,15 +26,12 @@ import org.apache.spark.ml.attribute._
 import org.apache.spark.mllib.linalg.{SparseVector, Vector, Vectors}
 import org.apache.spark.mllib.util.MLlibTestSparkContext
 import org.apache.spark.rdd.RDD
-import org.apache.spark.sql.{DataFrame, SQLContext}
-
+import org.apache.spark.sql.DataFrame
 
 class VectorIndexerSuite extends FunSuite with MLlibTestSparkContext {
 
   import VectorIndexerSuite.FeatureData
 
-  @transient var sqlContext: SQLContext = _
-
   // identical, of length 3
   @transient var densePoints1: DataFrame = _
   @transient var sparsePoints1: DataFrame = _
@@ -86,7 +83,6 @@ class VectorIndexerSuite extends FunSuite with MLlibTestSparkContext {
     checkPair(densePoints1Seq, sparsePoints1Seq)
     checkPair(densePoints2Seq, sparsePoints2Seq)
 
-    sqlContext = new SQLContext(sc)
     densePoints1 = sqlContext.createDataFrame(sc.parallelize(densePoints1Seq, 2).map(FeatureData))
     sparsePoints1 = sqlContext.createDataFrame(sc.parallelize(sparsePoints1Seq, 2).map(FeatureData))
     densePoints2 = sqlContext.createDataFrame(sc.parallelize(densePoints2Seq, 2).map(FeatureData))
diff --git a/mllib/src/test/scala/org/apache/spark/ml/recommendation/ALSSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/recommendation/ALSSuite.scala
index 6cc6ec94eb..9a35555e52 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/recommendation/ALSSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/recommendation/ALSSuite.scala
@@ -38,14 +38,12 @@ import org.apache.spark.util.Utils
 
 class ALSSuite extends FunSuite with MLlibTestSparkContext with Logging {
 
-  private var sqlContext: SQLContext = _
   private var tempDir: File = _
 
   override def beforeAll(): Unit = {
     super.beforeAll()
     tempDir = Utils.createTempDir()
     sc.setCheckpointDir(tempDir.getAbsolutePath)
-    sqlContext = new SQLContext(sc)
   }
 
   override def afterAll(): Unit = {
diff --git a/mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala
index 80323ef520..50a78631fa 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala
@@ -22,11 +22,10 @@ import org.scalatest.FunSuite
 import org.apache.spark.mllib.linalg.DenseVector
 import org.apache.spark.mllib.util.{LinearDataGenerator, MLlibTestSparkContext}
 import org.apache.spark.mllib.util.TestingUtils._
-import org.apache.spark.sql.{Row, SQLContext, DataFrame}
+import org.apache.spark.sql.{DataFrame, Row}
 
 class LinearRegressionSuite extends FunSuite with MLlibTestSparkContext {
 
-  @transient var sqlContext: SQLContext = _
   @transient var dataset: DataFrame = _
 
   /**
@@ -41,7 +40,6 @@ class LinearRegressionSuite extends FunSuite with MLlibTestSparkContext {
    */
   override def beforeAll(): Unit = {
     super.beforeAll()
-    sqlContext = new SQLContext(sc)
     dataset = sqlContext.createDataFrame(
       sc.parallelize(LinearDataGenerator.generateLinearInput(
         6.3, Array(4.7, 7.2), Array(0.9, -1.3), Array(0.7, 1.2), 10000, 42, 0.1), 2))
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/util/MLlibTestSparkContext.scala b/mllib/src/test/scala/org/apache/spark/mllib/util/MLlibTestSparkContext.scala
index b658889476..5d1796ef65 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/util/MLlibTestSparkContext.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/util/MLlibTestSparkContext.scala
@@ -17,13 +17,14 @@
 
 package org.apache.spark.mllib.util
 
-import org.scalatest.Suite
-import org.scalatest.BeforeAndAfterAll
+import org.scalatest.{BeforeAndAfterAll, Suite}
 
 import org.apache.spark.{SparkConf, SparkContext}
+import org.apache.spark.sql.SQLContext
 
 trait MLlibTestSparkContext extends BeforeAndAfterAll { self: Suite =>
   @transient var sc: SparkContext = _
+  @transient var sqlContext: SQLContext = _
 
   override def beforeAll() {
     super.beforeAll()
@@ -31,12 +32,15 @@ trait MLlibTestSparkContext extends BeforeAndAfterAll { self: Suite =>
       .setMaster("local[2]")
       .setAppName("MLlibUnitTest")
     sc = new SparkContext(conf)
+    sqlContext = new SQLContext(sc)
   }
 
   override def afterAll() {
+    sqlContext = null
     if (sc != null) {
       sc.stop()
     }
+    sc = null
     super.afterAll()
   }
 }
-- 
GitLab