From d8662cd909a41575df6e0ea1630d2386d3711240 Mon Sep 17 00:00:00 2001
From: leahmcguire <lmcguire@salesforce.com>
Date: Wed, 3 Jun 2015 15:46:38 -0700
Subject: [PATCH] [SPARK-6164] [ML] CrossValidatorModel should keep stats from
 fitting

Added stats from cross validation as a val in the cross validation model to save them for user access.

Author: leahmcguire <lmcguire@salesforce.com>

Closes #5915 from leahmcguire/saveCVmetrics and squashes the following commits:

49b507b [leahmcguire] fixed tyle error
67537b1 [leahmcguire] rebased
85907f0 [leahmcguire] fixed name
59987cc [leahmcguire] changed param name and test according to comments
36e71e3 [leahmcguire] rebasing
4b8223e [leahmcguire] fixed name
4ddffc6 [leahmcguire] changed param name and test according to comments
3a995da [leahmcguire] Added stats from cross validation as a val in the cross validation model to save them for user access
---
 .../org/apache/spark/ml/tuning/CrossValidator.scala    | 10 +++++++---
 .../apache/spark/ml/tuning/CrossValidatorSuite.scala   |  1 +
 2 files changed, 8 insertions(+), 3 deletions(-)

diff --git a/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala b/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala
index 6434b64aed..cb29392e8b 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala
@@ -135,7 +135,7 @@ class CrossValidator(override val uid: String) extends Estimator[CrossValidatorM
     logInfo(s"Best set of parameters:\n${epm(bestIndex)}")
     logInfo(s"Best cross-validation metric: $bestMetric.")
     val bestModel = est.fit(dataset, epm(bestIndex)).asInstanceOf[Model[_]]
-    copyValues(new CrossValidatorModel(uid, bestModel).setParent(this))
+    copyValues(new CrossValidatorModel(uid, bestModel, metrics).setParent(this))
   }
 
   override def transformSchema(schema: StructType): StructType = {
@@ -158,7 +158,8 @@ class CrossValidator(override val uid: String) extends Estimator[CrossValidatorM
 @Experimental
 class CrossValidatorModel private[ml] (
     override val uid: String,
-    val bestModel: Model[_])
+    val bestModel: Model[_],
+    val avgMetrics: Array[Double])
   extends Model[CrossValidatorModel] with CrossValidatorParams {
 
   override def validateParams(): Unit = {
@@ -175,7 +176,10 @@ class CrossValidatorModel private[ml] (
   }
 
   override def copy(extra: ParamMap): CrossValidatorModel = {
-    val copied = new CrossValidatorModel(uid, bestModel.copy(extra).asInstanceOf[Model[_]])
+    val copied = new CrossValidatorModel(
+      uid,
+      bestModel.copy(extra).asInstanceOf[Model[_]],
+      avgMetrics.clone())
     copyValues(copied, extra)
   }
 }
diff --git a/mllib/src/test/scala/org/apache/spark/ml/tuning/CrossValidatorSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/tuning/CrossValidatorSuite.scala
index 5ba469c7b1..9b3619f004 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/tuning/CrossValidatorSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/tuning/CrossValidatorSuite.scala
@@ -56,6 +56,7 @@ class CrossValidatorSuite extends SparkFunSuite with MLlibTestSparkContext {
     val parent = cvModel.bestModel.parent.asInstanceOf[LogisticRegression]
     assert(parent.getRegParam === 0.001)
     assert(parent.getMaxIter === 10)
+    assert(cvModel.avgMetrics.length === lrParamMaps.length)
   }
 
   test("validateParams should check estimatorParamMaps") {
-- 
GitLab