Commit b39e80d3 authored by Joseph K. Bradley

[SPARK-13761][ML] Remove remaining uses of validateParams

## What changes were proposed in this pull request?

Cleanups from https://github.com/apache/spark/pull/11620: remove the remaining uses of validateParams and move that functionality into transformSchema.
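The consolidation pattern is sketched below. This is a simplified stand-in for the real `ValidatorParams` trait in the diff, not the committed code: plain members replace the `$(...)` param accessors so the snippet compiles on its own.

```scala
import org.apache.spark.ml.Estimator
import org.apache.spark.ml.param.ParamMap
import org.apache.spark.sql.types.StructType

// Simplified stand-in for ValidatorParams (hypothetical names; the real
// trait reads these values through $(estimator) / $(estimatorParamMaps)).
trait SimpleValidatorParams {
  def estimator: Estimator[_]
  def estimatorParamMaps: Array[ParamMap]

  // Shared by CrossValidator and TrainValidationSplit: a candidate ParamMap
  // is valid iff a copy of the estimator configured with it can transform
  // the schema without throwing. The first map's output schema is returned.
  protected def transformSchemaImpl(schema: StructType): StructType = {
    require(estimatorParamMaps.nonEmpty, "Validator requires non-empty estimatorParamMaps")
    for (paramMap <- estimatorParamMaps.tail) {
      estimator.copy(paramMap).transformSchema(schema)
    }
    estimator.copy(estimatorParamMaps.head).transformSchema(schema)
  }
}
```

Both validators then reduce their `transformSchema` overrides to `transformSchemaImpl(schema)`, so param-map validation runs on every schema check instead of in a separate `validateParams` pass.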

## How was this patch tested?

Existing unit tests, modified to run the same checks through transformSchema instead of validateParams.
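Concretely, assertions of the shape below replace the old `validateParams` calls. This fragment is adapted from the suites in the diff and is not self-contained: `cv`, `est`, and `paramMaps` are as defined in CrossValidatorSuite, and `intercept` comes from ScalaTest via `SparkFunSuite`.

```scala
cv.transformSchema(new StructType()) // passes: every candidate ParamMap is valid

val invalidParamMaps = paramMaps :+ ParamMap(est.inputCol -> "")
cv.setEstimatorParamMaps(invalidParamMaps)
intercept[IllegalArgumentException] {
  cv.transformSchema(new StructType()) // the empty inputCol now fails here
}
```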

Author: Joseph K. Bradley <joseph@databricks.com>

Closes #11790 from jkbradley/SPARK-13761-cleanup.
parent 4c08e2c0
--- a/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala
@@ -131,19 +131,7 @@ class CrossValidator @Since("1.2.0") (@Since("1.4.0") override val uid: String)
   }
 
   @Since("1.4.0")
-  override def transformSchema(schema: StructType): StructType = {
-    validateParams()
-    $(estimator).transformSchema(schema)
-  }
-
-  @Since("1.4.0")
-  override def validateParams(): Unit = {
-    super.validateParams()
-    val est = $(estimator)
-    for (paramMap <- $(estimatorParamMaps)) {
-      est.copy(paramMap).validateParams()
-    }
-  }
+  override def transformSchema(schema: StructType): StructType = transformSchemaImpl(schema)
 
   @Since("1.4.0")
   override def copy(extra: ParamMap): CrossValidator = {
@@ -331,11 +319,6 @@ class CrossValidatorModel private[ml] (
     @Since("1.5.0") val avgMetrics: Array[Double])
   extends Model[CrossValidatorModel] with CrossValidatorParams with MLWritable {
 
-  @Since("1.4.0")
-  override def validateParams(): Unit = {
-    bestModel.validateParams()
-  }
-
   @Since("1.4.0")
   override def transform(dataset: DataFrame): DataFrame = {
     transformSchema(dataset.schema, logging = true)
@@ -344,7 +327,6 @@ class CrossValidatorModel private[ml] (
 
   @Since("1.4.0")
   override def transformSchema(schema: StructType): StructType = {
-    validateParams()
     bestModel.transformSchema(schema)
   }

--- a/mllib/src/main/scala/org/apache/spark/ml/tuning/TrainValidationSplit.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/tuning/TrainValidationSplit.scala
@@ -117,19 +117,7 @@ class TrainValidationSplit @Since("1.5.0") (@Since("1.5.0") override val uid: St
   }
 
   @Since("1.5.0")
-  override def transformSchema(schema: StructType): StructType = {
-    validateParams()
-    $(estimator).transformSchema(schema)
-  }
-
-  @Since("1.5.0")
-  override def validateParams(): Unit = {
-    super.validateParams()
-    val est = $(estimator)
-    for (paramMap <- $(estimatorParamMaps)) {
-      est.copy(paramMap).validateParams()
-    }
-  }
+  override def transformSchema(schema: StructType): StructType = transformSchemaImpl(schema)
 
   @Since("1.5.0")
   override def copy(extra: ParamMap): TrainValidationSplit = {
@@ -160,11 +148,6 @@ class TrainValidationSplitModel private[ml] (
     @Since("1.5.0") val validationMetrics: Array[Double])
   extends Model[TrainValidationSplitModel] with TrainValidationSplitParams {
 
-  @Since("1.5.0")
-  override def validateParams(): Unit = {
-    bestModel.validateParams()
-  }
-
   @Since("1.5.0")
   override def transform(dataset: DataFrame): DataFrame = {
     transformSchema(dataset.schema, logging = true)
@@ -173,7 +156,6 @@ class TrainValidationSplitModel private[ml] (
 
   @Since("1.5.0")
   override def transformSchema(schema: StructType): StructType = {
-    validateParams()
     bestModel.transformSchema(schema)
   }

--- a/mllib/src/main/scala/org/apache/spark/ml/tuning/ValidatorParams.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/tuning/ValidatorParams.scala
@@ -21,6 +21,7 @@ import org.apache.spark.annotation.DeveloperApi
 import org.apache.spark.ml.Estimator
 import org.apache.spark.ml.evaluation.Evaluator
 import org.apache.spark.ml.param.{Param, ParamMap, Params}
+import org.apache.spark.sql.types.StructType
 
 /**
  * :: DeveloperApi ::
@@ -31,6 +32,7 @@ private[ml] trait ValidatorParams extends Params {
 
   /**
    * param for the estimator to be validated
+   *
    * @group param
    */
   val estimator: Param[Estimator[_]] = new Param(this, "estimator", "estimator for selection")
@@ -40,6 +42,7 @@ private[ml] trait ValidatorParams extends Params {
 
   /**
    * param for estimator param maps
+   *
    * @group param
    */
   val estimatorParamMaps: Param[Array[ParamMap]] =
@@ -50,6 +53,7 @@ private[ml] trait ValidatorParams extends Params {
 
   /**
    * param for the evaluator used to select hyper-parameters that maximize the validated metric
+   *
    * @group param
    */
   val evaluator: Param[Evaluator] = new Param(this, "evaluator",
@@ -57,4 +61,14 @@ private[ml] trait ValidatorParams extends Params {
 
   /** @group getParam */
   def getEvaluator: Evaluator = $(evaluator)
+
+  protected def transformSchemaImpl(schema: StructType): StructType = {
+    require($(estimatorParamMaps).nonEmpty, s"Validator requires non-empty estimatorParamMaps")
+    val firstEstimatorParamMap = $(estimatorParamMaps).head
+    val est = $(estimator)
+    for (paramMap <- $(estimatorParamMaps).tail) {
+      est.copy(paramMap).transformSchema(schema)
+    }
+    est.copy(firstEstimatorParamMap).transformSchema(schema)
+  }
 }

--- a/mllib/src/test/scala/org/apache/spark/ml/param/ParamsSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/param/ParamsSuite.scala
@@ -268,15 +268,10 @@ class ParamsSuite extends SparkFunSuite {
       solver.getParam("abc")
     }
 
-    intercept[IllegalArgumentException] {
-      solver.validateParams()
-    }
-    solver.copy(ParamMap(inputCol -> "input")).validateParams()
     solver.setInputCol("input")
     assert(solver.isSet(inputCol))
     assert(solver.isDefined(inputCol))
     assert(solver.getInputCol === "input")
-    solver.validateParams()
     intercept[IllegalArgumentException] {
       ParamMap(maxIter -> -10)
     }

--- a/mllib/src/test/scala/org/apache/spark/ml/param/TestParams.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/param/TestParams.scala
@@ -34,10 +34,5 @@ class TestParams(override val uid: String) extends Params with HasHandleInvalid
 
   def clearMaxIter(): this.type = clear(maxIter)
 
-  override def validateParams(): Unit = {
-    super.validateParams()
-    require(isDefined(inputCol))
-  }
-
   override def copy(extra: ParamMap): TestParams = defaultCopy(extra)
 }

--- a/mllib/src/test/scala/org/apache/spark/ml/tuning/CrossValidatorSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/tuning/CrossValidatorSuite.scala
@@ -30,7 +30,7 @@ import org.apache.spark.mllib.classification.LogisticRegressionSuite.generateLog
 import org.apache.spark.mllib.linalg.Vectors
 import org.apache.spark.mllib.util.{LinearDataGenerator, MLlibTestSparkContext}
 import org.apache.spark.sql.DataFrame
-import org.apache.spark.sql.types.StructType
+import org.apache.spark.sql.types.{StructField, StructType}
 
 class CrossValidatorSuite
   extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest {
@@ -96,7 +96,7 @@ class CrossValidatorSuite
     assert(cvModel2.avgMetrics.length === lrParamMaps.length)
   }
 
-  test("validateParams should check estimatorParamMaps") {
+  test("transformSchema should check estimatorParamMaps") {
     import CrossValidatorSuite.{MyEstimator, MyEvaluator}
 
     val est = new MyEstimator("est")
@@ -110,12 +110,12 @@ class CrossValidatorSuite
       .setEstimatorParamMaps(paramMaps)
       .setEvaluator(eval)
 
-    cv.validateParams() // This should pass.
+    cv.transformSchema(new StructType()) // This should pass.
 
     val invalidParamMaps = paramMaps :+ ParamMap(est.inputCol -> "")
     cv.setEstimatorParamMaps(invalidParamMaps)
     intercept[IllegalArgumentException] {
-      cv.validateParams()
+      cv.transformSchema(new StructType())
     }
   }
@@ -311,14 +311,13 @@ object CrossValidatorSuite extends SparkFunSuite {
 
   class MyEstimator(override val uid: String) extends Estimator[MyModel] with HasInputCol {
 
-    override def validateParams(): Unit = require($(inputCol).nonEmpty)
-
     override def fit(dataset: DataFrame): MyModel = {
       throw new UnsupportedOperationException
     }
 
     override def transformSchema(schema: StructType): StructType = {
-      throw new UnsupportedOperationException
+      require($(inputCol).nonEmpty)
+      schema
     }
 
     override def copy(extra: ParamMap): MyEstimator = defaultCopy(extra)

--- a/mllib/src/test/scala/org/apache/spark/ml/tuning/TrainValidationSplitSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/tuning/TrainValidationSplitSuite.scala
@@ -83,7 +83,7 @@ class TrainValidationSplitSuite extends SparkFunSuite with MLlibTestSparkContext
     assert(cvModel2.validationMetrics.length === lrParamMaps.length)
   }
 
-  test("validateParams should check estimatorParamMaps") {
+  test("transformSchema should check estimatorParamMaps") {
     import TrainValidationSplitSuite._
 
     val est = new MyEstimator("est")
@@ -97,12 +97,12 @@ class TrainValidationSplitSuite extends SparkFunSuite with MLlibTestSparkContext
       .setEstimatorParamMaps(paramMaps)
      .setEvaluator(eval)
       .setTrainRatio(0.5)
-    cv.validateParams() // This should pass.
+    cv.transformSchema(new StructType()) // This should pass.
 
     val invalidParamMaps = paramMaps :+ ParamMap(est.inputCol -> "")
     cv.setEstimatorParamMaps(invalidParamMaps)
     intercept[IllegalArgumentException] {
-      cv.validateParams()
+      cv.transformSchema(new StructType())
     }
   }
 }
@@ -113,14 +113,13 @@ object TrainValidationSplitSuite {
 
   class MyEstimator(override val uid: String) extends Estimator[MyModel] with HasInputCol {
 
-    override def validateParams(): Unit = require($(inputCol).nonEmpty)
-
     override def fit(dataset: DataFrame): MyModel = {
       throw new UnsupportedOperationException
     }
 
     override def transformSchema(schema: StructType): StructType = {
-      throw new UnsupportedOperationException
+      require($(inputCol).nonEmpty)
+      schema
     }
 
     override def copy(extra: ParamMap): MyEstimator = defaultCopy(extra)