diff --git a/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala b/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala
index 3d7a91dd39a71d11cd316c5ce879891f08379747..963f81cb3ec3931464866d940c9768e257e37a99 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala
@@ -131,19 +131,7 @@ class CrossValidator @Since("1.2.0") (@Since("1.4.0") override val uid: String)
   }
 
   @Since("1.4.0")
-  override def transformSchema(schema: StructType): StructType = {
-    validateParams()
-    $(estimator).transformSchema(schema)
-  }
-
-  @Since("1.4.0")
-  override def validateParams(): Unit = {
-    super.validateParams()
-    val est = $(estimator)
-    for (paramMap <- $(estimatorParamMaps)) {
-      est.copy(paramMap).validateParams()
-    }
-  }
+  override def transformSchema(schema: StructType): StructType = transformSchemaImpl(schema)
 
   @Since("1.4.0")
   override def copy(extra: ParamMap): CrossValidator = {
@@ -331,11 +319,6 @@ class CrossValidatorModel private[ml] (
     @Since("1.5.0") val avgMetrics: Array[Double])
   extends Model[CrossValidatorModel] with CrossValidatorParams with MLWritable {
 
-  @Since("1.4.0")
-  override def validateParams(): Unit = {
-    bestModel.validateParams()
-  }
-
   @Since("1.4.0")
   override def transform(dataset: DataFrame): DataFrame = {
     transformSchema(dataset.schema, logging = true)
@@ -344,7 +327,6 @@ class CrossValidatorModel private[ml] (
 
   @Since("1.4.0")
   override def transformSchema(schema: StructType): StructType = {
-    validateParams()
     bestModel.transformSchema(schema)
   }
 
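
With validateParams gone, parameter checking for CrossValidator happens at
schema-checking time, and CrossValidatorModel simply defers to its best model.
A minimal sketch of the resulting call path, assuming a configured
cv: CrossValidator and a DataFrame df are already in scope (the driver code
below is illustrative, not part of this patch):

    import org.apache.spark.sql.types.StructType

    // Checking the schema on the validator now exercises every candidate
    // ParamMap, via the shared transformSchemaImpl helper added below:
    val checked: StructType = cv.transformSchema(df.schema)

    // The fitted model no longer pre-validates its params; its
    // transformSchema defers straight to bestModel:
    val cvModel = cv.fit(df)
    cvModel.transformSchema(df.schema)
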
diff --git a/mllib/src/main/scala/org/apache/spark/ml/tuning/TrainValidationSplit.scala b/mllib/src/main/scala/org/apache/spark/ml/tuning/TrainValidationSplit.scala
index 4587e259e8bf711e1c17c8c3b38e91f089d7b35a..70fa5f02347538cf8a0a2d1b2546126e12c9bf27 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/tuning/TrainValidationSplit.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/tuning/TrainValidationSplit.scala
@@ -117,19 +117,7 @@ class TrainValidationSplit @Since("1.5.0") (@Since("1.5.0") override val uid: St
   }
 
   @Since("1.5.0")
-  override def transformSchema(schema: StructType): StructType = {
-    validateParams()
-    $(estimator).transformSchema(schema)
-  }
-
-  @Since("1.5.0")
-  override def validateParams(): Unit = {
-    super.validateParams()
-    val est = $(estimator)
-    for (paramMap <- $(estimatorParamMaps)) {
-      est.copy(paramMap).validateParams()
-    }
-  }
+  override def transformSchema(schema: StructType): StructType = transformSchemaImpl(schema)
 
   @Since("1.5.0")
   override def copy(extra: ParamMap): TrainValidationSplit = {
@@ -160,11 +148,6 @@ class TrainValidationSplitModel private[ml] (
     @Since("1.5.0") val validationMetrics: Array[Double])
   extends Model[TrainValidationSplitModel] with TrainValidationSplitParams {
 
-  @Since("1.5.0")
-  override def validateParams(): Unit = {
-    bestModel.validateParams()
-  }
-
   @Since("1.5.0")
   override def transform(dataset: DataFrame): DataFrame = {
     transformSchema(dataset.schema, logging = true)
@@ -173,7 +156,6 @@ class TrainValidationSplitModel private[ml] (
 
   @Since("1.5.0")
   override def transformSchema(schema: StructType): StructType = {
-    validateParams()
     bestModel.transformSchema(schema)
   }
 
diff --git a/mllib/src/main/scala/org/apache/spark/ml/tuning/ValidatorParams.scala b/mllib/src/main/scala/org/apache/spark/ml/tuning/ValidatorParams.scala
index 553f25417241095e0697409a1b6c255692a5d0d7..c004644ad8e5589f12c59e65c8522c2a1a85ae8e 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/tuning/ValidatorParams.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/tuning/ValidatorParams.scala
@@ -21,6 +21,7 @@ import org.apache.spark.annotation.DeveloperApi
 import org.apache.spark.ml.Estimator
 import org.apache.spark.ml.evaluation.Evaluator
 import org.apache.spark.ml.param.{Param, ParamMap, Params}
+import org.apache.spark.sql.types.StructType
 
 /**
  * :: DeveloperApi ::
@@ -31,6 +32,7 @@ private[ml] trait ValidatorParams extends Params {
 
   /**
    * param for the estimator to be validated
+   *
    * @group param
    */
   val estimator: Param[Estimator[_]] = new Param(this, "estimator", "estimator for selection")
@@ -40,6 +42,7 @@ private[ml] trait ValidatorParams extends Params {
 
   /**
    * param for estimator param maps
+   *
    * @group param
    */
   val estimatorParamMaps: Param[Array[ParamMap]] =
@@ -50,6 +53,7 @@ private[ml] trait ValidatorParams extends Params {
 
   /**
    * param for the evaluator used to select hyper-parameters that maximize the validated metric
+   *
    * @group param
    */
   val evaluator: Param[Evaluator] = new Param(this, "evaluator",
@@ -57,4 +61,18 @@
 
   /** @group getParam */
   def getEvaluator: Evaluator = $(evaluator)
+
+  /**
+   * Checks that every candidate ParamMap yields a valid schema when applied
+   * to a copy of the estimator; returns the schema produced by the first map.
+   */
+  protected def transformSchemaImpl(schema: StructType): StructType = {
+    require($(estimatorParamMaps).nonEmpty, "Validator requires non-empty estimatorParamMaps")
+    val firstEstimatorParamMap = $(estimatorParamMaps).head
+    val est = $(estimator)
+    for (paramMap <- $(estimatorParamMaps).tail) {
+      est.copy(paramMap).transformSchema(schema)
+    }
+    est.copy(firstEstimatorParamMap).transformSchema(schema)
+  }
 }
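
Note the ordering inside the new helper: each ParamMap in the tail is checked
purely for its validation side effect, while the head map is applied last so
that its transformed schema is the value returned. A sketch of how an invalid
map now fails at schema-check time, reusing the MyEstimator and MyEvaluator
test fixtures defined later in this patch (the driver code itself is
illustrative):

    import org.apache.spark.ml.param.ParamMap
    import org.apache.spark.ml.tuning.CrossValidator
    import org.apache.spark.sql.types.StructType

    val est = new MyEstimator("est")
    val cv = new CrossValidator()
      .setEstimator(est)
      .setEstimatorParamMaps(Array(ParamMap(est.inputCol -> ""))) // empty name is invalid
      .setEvaluator(new MyEvaluator)

    // transformSchemaImpl copies est with each ParamMap and calls its
    // transformSchema, so MyEstimator's require($(inputCol).nonEmpty)
    // fails here with an IllegalArgumentException:
    cv.transformSchema(new StructType())
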
diff --git a/mllib/src/test/scala/org/apache/spark/ml/param/ParamsSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/param/ParamsSuite.scala
index 748868554fe65c264ce0a11679afd66f329843cc..a3366c0e5934c46f82b93c2846cdb1c574995fc6 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/param/ParamsSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/param/ParamsSuite.scala
@@ -268,15 +268,10 @@ class ParamsSuite extends SparkFunSuite {
       solver.getParam("abc")
     }
 
-    intercept[IllegalArgumentException] {
-      solver.validateParams()
-    }
-    solver.copy(ParamMap(inputCol -> "input")).validateParams()
     solver.setInputCol("input")
     assert(solver.isSet(inputCol))
     assert(solver.isDefined(inputCol))
     assert(solver.getInputCol === "input")
-    solver.validateParams()
     intercept[IllegalArgumentException] {
       ParamMap(maxIter -> -10)
     }
diff --git a/mllib/src/test/scala/org/apache/spark/ml/param/TestParams.scala b/mllib/src/test/scala/org/apache/spark/ml/param/TestParams.scala
index 9d23547f2844781d529a586257fdcd3e6aebc23b..7d990ce0bcfd81cd4966579eb49183f7043fa393 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/param/TestParams.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/param/TestParams.scala
@@ -34,10 +34,5 @@ class TestParams(override val uid: String) extends Params with HasHandleInvalid
 
   def clearMaxIter(): this.type = clear(maxIter)
 
-  override def validateParams(): Unit = {
-    super.validateParams()
-    require(isDefined(inputCol))
-  }
-
   override def copy(extra: ParamMap): TestParams = defaultCopy(extra)
 }
diff --git a/mllib/src/test/scala/org/apache/spark/ml/tuning/CrossValidatorSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/tuning/CrossValidatorSuite.scala
index 56545de14bd308b417885cd7b7a3d51548a345bd..7af3c6d6ede473fead44ace31c8a86fd0d5582fa 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/tuning/CrossValidatorSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/tuning/CrossValidatorSuite.scala
@@ -30,7 +30,7 @@ import org.apache.spark.mllib.classification.LogisticRegressionSuite.generateLog
 import org.apache.spark.mllib.linalg.Vectors
 import org.apache.spark.mllib.util.{LinearDataGenerator, MLlibTestSparkContext}
 import org.apache.spark.sql.DataFrame
-import org.apache.spark.sql.types.StructType
+import org.apache.spark.sql.types.{StructField, StructType}
 
 class CrossValidatorSuite
   extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest {
@@ -96,7 +96,7 @@ class CrossValidatorSuite
     assert(cvModel2.avgMetrics.length === lrParamMaps.length)
   }
 
-  test("validateParams should check estimatorParamMaps") {
+  test("transformSchema should check estimatorParamMaps") {
     import CrossValidatorSuite.{MyEstimator, MyEvaluator}
 
     val est = new MyEstimator("est")
@@ -110,12 +110,12 @@ class CrossValidatorSuite
       .setEstimatorParamMaps(paramMaps)
       .setEvaluator(eval)
 
-    cv.validateParams() // This should pass.
+    cv.transformSchema(new StructType()) // This should pass.
 
     val invalidParamMaps = paramMaps :+ ParamMap(est.inputCol -> "")
     cv.setEstimatorParamMaps(invalidParamMaps)
     intercept[IllegalArgumentException] {
-      cv.validateParams()
+      cv.transformSchema(new StructType())
     }
   }
 
@@ -311,14 +311,13 @@ object CrossValidatorSuite extends SparkFunSuite {
 
   class MyEstimator(override val uid: String) extends Estimator[MyModel] with HasInputCol {
 
-    override def validateParams(): Unit = require($(inputCol).nonEmpty)
-
     override def fit(dataset: DataFrame): MyModel = {
       throw new UnsupportedOperationException
     }
 
     override def transformSchema(schema: StructType): StructType = {
-      throw new UnsupportedOperationException
+      require($(inputCol).nonEmpty)
+      schema
     }
 
     override def copy(extra: ParamMap): MyEstimator = defaultCopy(extra)
diff --git a/mllib/src/test/scala/org/apache/spark/ml/tuning/TrainValidationSplitSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/tuning/TrainValidationSplitSuite.scala
index 5fb80091d0b4b8c1a5bf930ba87fd5e966f33510..cf8dcefebc3aacce32f0059e00d9315639305bc1 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/tuning/TrainValidationSplitSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/tuning/TrainValidationSplitSuite.scala
@@ -83,7 +83,7 @@ class TrainValidationSplitSuite extends SparkFunSuite with MLlibTestSparkContext
     assert(cvModel2.validationMetrics.length === lrParamMaps.length)
   }
 
-  test("validateParams should check estimatorParamMaps") {
+  test("transformSchema should check estimatorParamMaps") {
     import TrainValidationSplitSuite._
 
     val est = new MyEstimator("est")
@@ -97,12 +97,12 @@ class TrainValidationSplitSuite extends SparkFunSuite with MLlibTestSparkContext
       .setEstimatorParamMaps(paramMaps)
       .setEvaluator(eval)
       .setTrainRatio(0.5)
-    cv.validateParams() // This should pass.
+    cv.transformSchema(new StructType()) // This should pass.
 
     val invalidParamMaps = paramMaps :+ ParamMap(est.inputCol -> "")
     cv.setEstimatorParamMaps(invalidParamMaps)
     intercept[IllegalArgumentException] {
-      cv.validateParams()
+      cv.transformSchema(new StructType())
     }
   }
 }
@@ -113,14 +113,13 @@ object TrainValidationSplitSuite {
 
   class MyEstimator(override val uid: String) extends Estimator[MyModel] with HasInputCol {
 
-    override def validateParams(): Unit = require($(inputCol).nonEmpty)
-
     override def fit(dataset: DataFrame): MyModel = {
       throw new UnsupportedOperationException
     }
 
     override def transformSchema(schema: StructType): StructType = {
-      throw new UnsupportedOperationException
+      require($(inputCol).nonEmpty)
+      schema
     }
 
     override def copy(extra: ParamMap): MyEstimator = defaultCopy(extra)