From b39e80d39dae8e6779f9d78c1631a27585239032 Mon Sep 17 00:00:00 2001
From: "Joseph K. Bradley" <joseph@databricks.com>
Date: Thu, 17 Mar 2016 13:23:07 -0700
Subject: [PATCH] [SPARK-13761][ML] Remove remaining uses of validateParams

## What changes were proposed in this pull request?

Cleanups from [https://github.com/apache/spark/pull/11620]: remove the remaining uses of validateParams and move that validation logic into transformSchema.
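As a rough sketch of the pattern (a standalone approximation, not code from this patch; the real helper is `transformSchemaImpl` in `ValidatorParams` below), each candidate `ParamMap` is now validated by copying the estimator and asking it to transform the schema:

```scala
import org.apache.spark.ml.Estimator
import org.apache.spark.ml.param.ParamMap
import org.apache.spark.sql.types.StructType

// Standalone approximation of the new ValidatorParams.transformSchemaImpl:
// each candidate ParamMap is checked by copying the estimator and letting it
// transform the schema, so invalid params fail here instead of in a separate
// validateParams() pass. The schema produced under the first map is returned.
def transformSchemaSketch(
    est: Estimator[_],
    paramMaps: Array[ParamMap],
    schema: StructType): StructType = {
  require(paramMaps.nonEmpty, "Validator requires non-empty estimatorParamMaps")
  paramMaps.tail.foreach(m => est.copy(m).transformSchema(schema))
  est.copy(paramMaps.head).transformSchema(schema)
}
```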

## How was this patch tested?

Existing unit tests, modified to exercise transformSchema instead of validateParams.
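For reference, a minimal sketch of how the reworked tests drive the new path (setup such as `est`, `eval`, `paramMaps`, and `invalidParamMaps` is assumed from the suites below):

```scala
val cv = new CrossValidator()
  .setEstimator(est)
  .setEstimatorParamMaps(paramMaps)
  .setEvaluator(eval)
cv.transformSchema(new StructType()) // valid param maps: passes

cv.setEstimatorParamMaps(invalidParamMaps)
intercept[IllegalArgumentException] {
  cv.transformSchema(new StructType()) // invalid map now fails in transformSchema
}
```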

Author: Joseph K. Bradley <joseph@databricks.com>

Closes #11790 from jkbradley/SPARK-13761-cleanup.
---
 .../spark/ml/tuning/CrossValidator.scala      | 20 +------------------
 .../ml/tuning/TrainValidationSplit.scala      | 20 +------------------
 .../spark/ml/tuning/ValidatorParams.scala     | 14 +++++++++++++
 .../apache/spark/ml/param/ParamsSuite.scala   |  5 -----
 .../apache/spark/ml/param/TestParams.scala    |  5 -----
 .../spark/ml/tuning/CrossValidatorSuite.scala | 13 ++++++------
 .../ml/tuning/TrainValidationSplitSuite.scala | 11 +++++-----
 7 files changed, 27 insertions(+), 61 deletions(-)

diff --git a/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala b/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala
index 3d7a91dd39..963f81cb3e 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala
@@ -131,19 +131,7 @@ class CrossValidator @Since("1.2.0") (@Since("1.4.0") override val uid: String)
   }
 
   @Since("1.4.0")
-  override def transformSchema(schema: StructType): StructType = {
-    validateParams()
-    $(estimator).transformSchema(schema)
-  }
-
-  @Since("1.4.0")
-  override def validateParams(): Unit = {
-    super.validateParams()
-    val est = $(estimator)
-    for (paramMap <- $(estimatorParamMaps)) {
-      est.copy(paramMap).validateParams()
-    }
-  }
+  override def transformSchema(schema: StructType): StructType = transformSchemaImpl(schema)
 
   @Since("1.4.0")
   override def copy(extra: ParamMap): CrossValidator = {
@@ -331,11 +319,6 @@ class CrossValidatorModel private[ml] (
     @Since("1.5.0") val avgMetrics: Array[Double])
   extends Model[CrossValidatorModel] with CrossValidatorParams with MLWritable {
 
-  @Since("1.4.0")
-  override def validateParams(): Unit = {
-    bestModel.validateParams()
-  }
-
   @Since("1.4.0")
   override def transform(dataset: DataFrame): DataFrame = {
     transformSchema(dataset.schema, logging = true)
@@ -344,7 +327,6 @@ class CrossValidatorModel private[ml] (
 
   @Since("1.4.0")
   override def transformSchema(schema: StructType): StructType = {
-    validateParams()
     bestModel.transformSchema(schema)
   }
 
diff --git a/mllib/src/main/scala/org/apache/spark/ml/tuning/TrainValidationSplit.scala b/mllib/src/main/scala/org/apache/spark/ml/tuning/TrainValidationSplit.scala
index 4587e259e8..70fa5f0234 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/tuning/TrainValidationSplit.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/tuning/TrainValidationSplit.scala
@@ -117,19 +117,7 @@ class TrainValidationSplit @Since("1.5.0") (@Since("1.5.0") override val uid: St
   }
 
   @Since("1.5.0")
-  override def transformSchema(schema: StructType): StructType = {
-    validateParams()
-    $(estimator).transformSchema(schema)
-  }
-
-  @Since("1.5.0")
-  override def validateParams(): Unit = {
-    super.validateParams()
-    val est = $(estimator)
-    for (paramMap <- $(estimatorParamMaps)) {
-      est.copy(paramMap).validateParams()
-    }
-  }
+  override def transformSchema(schema: StructType): StructType = transformSchemaImpl(schema)
 
   @Since("1.5.0")
   override def copy(extra: ParamMap): TrainValidationSplit = {
@@ -160,11 +148,6 @@ class TrainValidationSplitModel private[ml] (
     @Since("1.5.0") val validationMetrics: Array[Double])
   extends Model[TrainValidationSplitModel] with TrainValidationSplitParams {
 
-  @Since("1.5.0")
-  override def validateParams(): Unit = {
-    bestModel.validateParams()
-  }
-
   @Since("1.5.0")
   override def transform(dataset: DataFrame): DataFrame = {
     transformSchema(dataset.schema, logging = true)
@@ -173,7 +156,6 @@ class TrainValidationSplitModel private[ml] (
 
   @Since("1.5.0")
   override def transformSchema(schema: StructType): StructType = {
-    validateParams()
     bestModel.transformSchema(schema)
   }
 
diff --git a/mllib/src/main/scala/org/apache/spark/ml/tuning/ValidatorParams.scala b/mllib/src/main/scala/org/apache/spark/ml/tuning/ValidatorParams.scala
index 553f254172..c004644ad8 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/tuning/ValidatorParams.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/tuning/ValidatorParams.scala
@@ -21,6 +21,7 @@ import org.apache.spark.annotation.DeveloperApi
 import org.apache.spark.ml.Estimator
 import org.apache.spark.ml.evaluation.Evaluator
 import org.apache.spark.ml.param.{Param, ParamMap, Params}
+import org.apache.spark.sql.types.StructType
 
 /**
  * :: DeveloperApi ::
@@ -31,6 +32,7 @@ private[ml] trait ValidatorParams extends Params {
 
   /**
    * param for the estimator to be validated
+   *
    * @group param
    */
   val estimator: Param[Estimator[_]] = new Param(this, "estimator", "estimator for selection")
@@ -40,6 +42,7 @@ private[ml] trait ValidatorParams extends Params {
 
   /**
    * param for estimator param maps
+   *
    * @group param
    */
   val estimatorParamMaps: Param[Array[ParamMap]] =
@@ -50,6 +53,7 @@ private[ml] trait ValidatorParams extends Params {
 
   /**
    * param for the evaluator used to select hyper-parameters that maximize the validated metric
+   *
    * @group param
    */
   val evaluator: Param[Evaluator] = new Param(this, "evaluator",
@@ -57,4 +61,14 @@ private[ml] trait ValidatorParams extends Params {
 
   /** @group getParam */
   def getEvaluator: Evaluator = $(evaluator)
+
+  protected def transformSchemaImpl(schema: StructType): StructType = {
+    require($(estimatorParamMaps).nonEmpty, "Validator requires non-empty estimatorParamMaps")
+    val firstEstimatorParamMap = $(estimatorParamMaps).head
+    val est = $(estimator)
+    for (paramMap <- $(estimatorParamMaps).tail) {
+      est.copy(paramMap).transformSchema(schema)
+    }
+    est.copy(firstEstimatorParamMap).transformSchema(schema)
+  }
 }
diff --git a/mllib/src/test/scala/org/apache/spark/ml/param/ParamsSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/param/ParamsSuite.scala
index 748868554f..a3366c0e59 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/param/ParamsSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/param/ParamsSuite.scala
@@ -268,15 +268,10 @@ class ParamsSuite extends SparkFunSuite {
       solver.getParam("abc")
     }
 
-    intercept[IllegalArgumentException] {
-      solver.validateParams()
-    }
-    solver.copy(ParamMap(inputCol -> "input")).validateParams()
     solver.setInputCol("input")
     assert(solver.isSet(inputCol))
     assert(solver.isDefined(inputCol))
     assert(solver.getInputCol === "input")
-    solver.validateParams()
     intercept[IllegalArgumentException] {
       ParamMap(maxIter -> -10)
     }
diff --git a/mllib/src/test/scala/org/apache/spark/ml/param/TestParams.scala b/mllib/src/test/scala/org/apache/spark/ml/param/TestParams.scala
index 9d23547f28..7d990ce0bc 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/param/TestParams.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/param/TestParams.scala
@@ -34,10 +34,5 @@ class TestParams(override val uid: String) extends Params with HasHandleInvalid
 
   def clearMaxIter(): this.type = clear(maxIter)
 
-  override def validateParams(): Unit = {
-    super.validateParams()
-    require(isDefined(inputCol))
-  }
-
   override def copy(extra: ParamMap): TestParams = defaultCopy(extra)
 }
diff --git a/mllib/src/test/scala/org/apache/spark/ml/tuning/CrossValidatorSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/tuning/CrossValidatorSuite.scala
index 56545de14b..7af3c6d6ed 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/tuning/CrossValidatorSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/tuning/CrossValidatorSuite.scala
@@ -30,7 +30,7 @@ import org.apache.spark.mllib.classification.LogisticRegressionSuite.generateLog
 import org.apache.spark.mllib.linalg.Vectors
 import org.apache.spark.mllib.util.{LinearDataGenerator, MLlibTestSparkContext}
 import org.apache.spark.sql.DataFrame
-import org.apache.spark.sql.types.StructType
+import org.apache.spark.sql.types.{StructField, StructType}
 
 class CrossValidatorSuite
   extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest {
@@ -96,7 +96,7 @@ class CrossValidatorSuite
     assert(cvModel2.avgMetrics.length === lrParamMaps.length)
   }
 
-  test("validateParams should check estimatorParamMaps") {
+  test("transformSchema should check estimatorParamMaps") {
     import CrossValidatorSuite.{MyEstimator, MyEvaluator}
 
     val est = new MyEstimator("est")
@@ -110,12 +110,12 @@ class CrossValidatorSuite
       .setEstimatorParamMaps(paramMaps)
       .setEvaluator(eval)
 
-    cv.validateParams() // This should pass.
+    cv.transformSchema(new StructType()) // This should pass.
 
     val invalidParamMaps = paramMaps :+ ParamMap(est.inputCol -> "")
     cv.setEstimatorParamMaps(invalidParamMaps)
     intercept[IllegalArgumentException] {
-      cv.validateParams()
+      cv.transformSchema(new StructType())
     }
   }
 
@@ -311,14 +311,13 @@ object CrossValidatorSuite extends SparkFunSuite {
 
   class MyEstimator(override val uid: String) extends Estimator[MyModel] with HasInputCol {
 
-    override def validateParams(): Unit = require($(inputCol).nonEmpty)
-
     override def fit(dataset: DataFrame): MyModel = {
       throw new UnsupportedOperationException
     }
 
     override def transformSchema(schema: StructType): StructType = {
-      throw new UnsupportedOperationException
+      require($(inputCol).nonEmpty)
+      schema
     }
 
     override def copy(extra: ParamMap): MyEstimator = defaultCopy(extra)
diff --git a/mllib/src/test/scala/org/apache/spark/ml/tuning/TrainValidationSplitSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/tuning/TrainValidationSplitSuite.scala
index 5fb80091d0..cf8dcefebc 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/tuning/TrainValidationSplitSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/tuning/TrainValidationSplitSuite.scala
@@ -83,7 +83,7 @@ class TrainValidationSplitSuite extends SparkFunSuite with MLlibTestSparkContext
     assert(cvModel2.validationMetrics.length === lrParamMaps.length)
   }
 
-  test("validateParams should check estimatorParamMaps") {
+  test("transformSchema should check estimatorParamMaps") {
     import TrainValidationSplitSuite._
 
     val est = new MyEstimator("est")
@@ -97,12 +97,12 @@ class TrainValidationSplitSuite extends SparkFunSuite with MLlibTestSparkContext
       .setEstimatorParamMaps(paramMaps)
       .setEvaluator(eval)
       .setTrainRatio(0.5)
-    cv.validateParams() // This should pass.
+    cv.transformSchema(new StructType()) // This should pass.
 
     val invalidParamMaps = paramMaps :+ ParamMap(est.inputCol -> "")
     cv.setEstimatorParamMaps(invalidParamMaps)
     intercept[IllegalArgumentException] {
-      cv.validateParams()
+      cv.transformSchema(new StructType())
     }
   }
 }
@@ -113,14 +113,13 @@ object TrainValidationSplitSuite {
 
   class MyEstimator(override val uid: String) extends Estimator[MyModel] with HasInputCol {
 
-    override def validateParams(): Unit = require($(inputCol).nonEmpty)
-
     override def fit(dataset: DataFrame): MyModel = {
       throw new UnsupportedOperationException
     }
 
     override def transformSchema(schema: StructType): StructType = {
-      throw new UnsupportedOperationException
+      require($(inputCol).nonEmpty)
+      schema
     }
 
     override def copy(extra: ParamMap): MyEstimator = defaultCopy(extra)
-- 
GitLab