From d8662cd909a41575df6e0ea1630d2386d3711240 Mon Sep 17 00:00:00 2001 From: leahmcguire <lmcguire@salesforce.com> Date: Wed, 3 Jun 2015 15:46:38 -0700 Subject: [PATCH] [SPARK-6164] [ML] CrossValidatorModel should keep stats from fitting Added stats from cross validation as a val in the cross validation model to save them for user access. Author: leahmcguire <lmcguire@salesforce.com> Closes #5915 from leahmcguire/saveCVmetrics and squashes the following commits: 49b507b [leahmcguire] fixed tyle error 67537b1 [leahmcguire] rebased 85907f0 [leahmcguire] fixed name 59987cc [leahmcguire] changed param name and test according to comments 36e71e3 [leahmcguire] rebasing 4b8223e [leahmcguire] fixed name 4ddffc6 [leahmcguire] changed param name and test according to comments 3a995da [leahmcguire] Added stats from cross validation as a val in the cross validation model to save them for user access --- .../org/apache/spark/ml/tuning/CrossValidator.scala | 10 +++++++--- .../apache/spark/ml/tuning/CrossValidatorSuite.scala | 1 + 2 files changed, 8 insertions(+), 3 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala b/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala index 6434b64aed..cb29392e8b 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala @@ -135,7 +135,7 @@ class CrossValidator(override val uid: String) extends Estimator[CrossValidatorM logInfo(s"Best set of parameters:\n${epm(bestIndex)}") logInfo(s"Best cross-validation metric: $bestMetric.") val bestModel = est.fit(dataset, epm(bestIndex)).asInstanceOf[Model[_]] - copyValues(new CrossValidatorModel(uid, bestModel).setParent(this)) + copyValues(new CrossValidatorModel(uid, bestModel, metrics).setParent(this)) } override def transformSchema(schema: StructType): StructType = { @@ -158,7 +158,8 @@ class CrossValidator(override val uid: String) extends Estimator[CrossValidatorM @Experimental class CrossValidatorModel private[ml] ( override val uid: String, - val bestModel: Model[_]) + val bestModel: Model[_], + val avgMetrics: Array[Double]) extends Model[CrossValidatorModel] with CrossValidatorParams { override def validateParams(): Unit = { @@ -175,7 +176,10 @@ class CrossValidatorModel private[ml] ( } override def copy(extra: ParamMap): CrossValidatorModel = { - val copied = new CrossValidatorModel(uid, bestModel.copy(extra).asInstanceOf[Model[_]]) + val copied = new CrossValidatorModel( + uid, + bestModel.copy(extra).asInstanceOf[Model[_]], + avgMetrics.clone()) copyValues(copied, extra) } } diff --git a/mllib/src/test/scala/org/apache/spark/ml/tuning/CrossValidatorSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/tuning/CrossValidatorSuite.scala index 5ba469c7b1..9b3619f004 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/tuning/CrossValidatorSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/tuning/CrossValidatorSuite.scala @@ -56,6 +56,7 @@ class CrossValidatorSuite extends SparkFunSuite with MLlibTestSparkContext { val parent = cvModel.bestModel.parent.asInstanceOf[LogisticRegression] assert(parent.getRegParam === 0.001) assert(parent.getMaxIter === 10) + assert(cvModel.avgMetrics.length === lrParamMaps.length) } test("validateParams should check estimatorParamMaps") { -- GitLab