diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/RFormula.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/RFormula.scala
index ab5f4a1a9a6c452648b50e597ea3bb2ecbb84cfa..e7ca7ada74c8cade37adfa90e26c3e1fdf9529c1 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/RFormula.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/RFormula.scala
@@ -20,12 +20,14 @@ package org.apache.spark.ml.feature
 import scala.collection.mutable
 import scala.collection.mutable.ArrayBuffer
 
-import org.apache.spark.annotation.Experimental
+import org.apache.hadoop.fs.Path
+
+import org.apache.spark.annotation.{Experimental, Since}
 import org.apache.spark.ml.{Estimator, Model, Pipeline, PipelineModel, PipelineStage, Transformer}
 import org.apache.spark.ml.attribute.AttributeGroup
 import org.apache.spark.ml.param.{Param, ParamMap}
 import org.apache.spark.ml.param.shared.{HasFeaturesCol, HasLabelCol}
-import org.apache.spark.ml.util.Identifiable
+import org.apache.spark.ml.util._
 import org.apache.spark.mllib.linalg.VectorUDT
 import org.apache.spark.sql.DataFrame
 import org.apache.spark.sql.types._
@@ -68,7 +70,8 @@ private[feature] trait RFormulaBase extends HasFeaturesCol with HasLabelCol {
  * will be created from the specified response variable in the formula.
  */
 @Experimental
-class RFormula(override val uid: String) extends Estimator[RFormulaModel] with RFormulaBase {
+class RFormula(override val uid: String)
+  extends Estimator[RFormulaModel] with RFormulaBase with DefaultParamsWritable {
 
   def this() = this(Identifiable.randomUID("rFormula"))
 
@@ -180,6 +183,13 @@ class RFormula(override val uid: String) extends Estimator[RFormulaModel] with R
   override def toString: String = s"RFormula(${get(formula)}) (uid=$uid)"
 }
 
+@Since("2.0.0")
+object RFormula extends DefaultParamsReadable[RFormula] {
+
+  @Since("2.0.0")
+  override def load(path: String): RFormula = super.load(path)
+}
+
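With `DefaultParamsWritable` mixed into the estimator and `DefaultParamsReadable` on its new companion object, an unfitted `RFormula` can round-trip through the standard ML persistence API. A minimal sketch of the usage this hunk enables (the formula and path are illustrative):

```scala
import org.apache.spark.ml.feature.RFormula

val formula = new RFormula()
  .setFormula("clicked ~ country + hour") // illustrative formula
  .setFeaturesCol("features")
  .setLabelCol("label")

// write comes from DefaultParamsWritable; overwrite() replaces any existing output.
formula.write.overwrite().save("/tmp/rformula-estimator")

// load comes from DefaultParamsReadable via the companion object.
val restored = RFormula.load("/tmp/rformula-estimator")
assert(restored.getFormula == formula.getFormula)
```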
 /**
  * :: Experimental ::
  * A fitted RFormula. Fitting is required to determine the factor levels of formula terms.
@@ -189,9 +199,9 @@ class RFormula(override val uid: String) extends Estimator[RFormulaModel] with R
 @Experimental
 class RFormulaModel private[feature](
     override val uid: String,
-    resolvedFormula: ResolvedRFormula,
-    pipelineModel: PipelineModel)
-  extends Model[RFormulaModel] with RFormulaBase {
+    private[ml] val resolvedFormula: ResolvedRFormula,
+    private[ml] val pipelineModel: PipelineModel)
+  extends Model[RFormulaModel] with RFormulaBase with MLWritable {
 
   override def transform(dataset: DataFrame): DataFrame = {
     checkCanTransform(dataset.schema)
@@ -246,14 +256,71 @@ class RFormulaModel private[feature](
       !columnNames.contains($(labelCol)) || schema($(labelCol)).dataType == DoubleType,
       "Label column already exists and is not of type DoubleType.")
   }
+
+  @Since("2.0.0")
+  override def write: MLWriter = new RFormulaModel.RFormulaModelWriter(this)
+}
+
+@Since("2.0.0")
+object RFormulaModel extends MLReadable[RFormulaModel] {
+
+  @Since("2.0.0")
+  override def read: MLReader[RFormulaModel] = new RFormulaModelReader
+
+  @Since("2.0.0")
+  override def load(path: String): RFormulaModel = super.load(path)
+
+  /** [[MLWriter]] instance for [[RFormulaModel]] */
+  private[RFormulaModel] class RFormulaModelWriter(instance: RFormulaModel) extends MLWriter {
+
+    override protected def saveImpl(path: String): Unit = {
+      // Save metadata and Params
+      DefaultParamsWriter.saveMetadata(instance, path, sc)
+      // Save model data: resolvedFormula
+      val dataPath = new Path(path, "data").toString
+      sqlContext.createDataFrame(Seq(instance.resolvedFormula))
+        .repartition(1).write.parquet(dataPath)
+      // Save pipeline model
+      val pmPath = new Path(path, "pipelineModel").toString
+      instance.pipelineModel.save(pmPath)
+    }
+  }
+
+  private class RFormulaModelReader extends MLReader[RFormulaModel] {
+
+    /** Checked against metadata when loading model */
+    private val className = classOf[RFormulaModel].getName
+
+    override def load(path: String): RFormulaModel = {
+      val metadata = DefaultParamsReader.loadMetadata(path, sc, className)
+
+      val dataPath = new Path(path, "data").toString
+      val data = sqlContext.read.parquet(dataPath).select("label", "terms", "hasIntercept").head()
+      val label = data.getString(0)
+      val terms = data.getAs[Seq[Seq[String]]](1)
+      val hasIntercept = data.getBoolean(2)
+      val resolvedRFormula = ResolvedRFormula(label, terms, hasIntercept)
+
+      val pmPath = new Path(path, "pipelineModel").toString
+      val pipelineModel = PipelineModel.load(pmPath)
+
+      val model = new RFormulaModel(metadata.uid, resolvedRFormula, pipelineModel)
+
+      DefaultParamsReader.getAndSetParams(model, metadata)
+      model
+    }
+  }
 }
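Reading `saveImpl` and the matching `load` above, a saved `RFormulaModel` directory has three parts: JSON Params metadata, a one-row Parquet file holding the `ResolvedRFormula` fields, and the nested `PipelineModel` saved under its own subdirectory. A sketch of the round trip, assuming `model` is a fitted `RFormulaModel` and the path is illustrative:

```scala
// Layout written by RFormulaModelWriter (inferred from saveImpl above):
//   <path>/metadata      - uid, class name, and Params as JSON
//   <path>/data          - resolvedFormula (label, terms, hasIntercept) as Parquet
//   <path>/pipelineModel - the underlying PipelineModel, saved recursively
model.save("/tmp/rformula-model") // MLWritable.save delegates to write.save

val restored = RFormulaModel.load("/tmp/rformula-model")
```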
 
 /**
  * Utility transformer for removing temporary columns from a DataFrame.
  * TODO(ekl) make this a public transformer
  */
-private class ColumnPruner(columnsToPrune: Set[String]) extends Transformer {
-  override val uid = Identifiable.randomUID("columnPruner")
+private class ColumnPruner(override val uid: String, val columnsToPrune: Set[String])
+  extends Transformer with MLWritable {
+
+  def this(columnsToPrune: Set[String]) =
+    this(Identifiable.randomUID("columnPruner"), columnsToPrune)
 
   override def transform(dataset: DataFrame): DataFrame = {
     val columnsToKeep = dataset.columns.filter(!columnsToPrune.contains(_))
@@ -265,6 +332,48 @@ private class ColumnPruner(columnsToPrune: Set[String]) extends Transformer {
   }
 
   override def copy(extra: ParamMap): ColumnPruner = defaultCopy(extra)
+
+  override def write: MLWriter = new ColumnPruner.ColumnPrunerWriter(this)
+}
+
+private object ColumnPruner extends MLReadable[ColumnPruner] {
+
+  override def read: MLReader[ColumnPruner] = new ColumnPrunerReader
+
+  override def load(path: String): ColumnPruner = super.load(path)
+
+  /** [[MLWriter]] instance for [[ColumnPruner]] */
+  private[ColumnPruner] class ColumnPrunerWriter(instance: ColumnPruner) extends MLWriter {
+
+    private case class Data(columnsToPrune: Seq[String])
+
+    override protected def saveImpl(path: String): Unit = {
+      // Save metadata and Params
+      DefaultParamsWriter.saveMetadata(instance, path, sc)
+      // Save model data: columnsToPrune
+      val data = Data(instance.columnsToPrune.toSeq)
+      val dataPath = new Path(path, "data").toString
+      sqlContext.createDataFrame(Seq(data)).repartition(1).write.parquet(dataPath)
+    }
+  }
+
+  private class ColumnPrunerReader extends MLReader[ColumnPruner] {
+
+    /** Checked against metadata when loading model */
+    private val className = classOf[ColumnPruner].getName
+
+    override def load(path: String): ColumnPruner = {
+      val metadata = DefaultParamsReader.loadMetadata(path, sc, className)
+
+      val dataPath = new Path(path, "data").toString
+      val data = sqlContext.read.parquet(dataPath).select("columnsToPrune").head()
+      val columnsToPrune = data.getAs[Seq[String]](0).toSet
+      val pruner = new ColumnPruner(metadata.uid, columnsToPrune)
+
+      DefaultParamsReader.getAndSetParams(pruner, metadata)
+      pruner
+    }
+  }
 }
 
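`ColumnPruner`'s writer and reader distill the private-transformer persistence recipe this patch repeats: wrap the state in a local `Data` case class, write it as a one-row Parquet file beside the Params metadata, and on load select the columns by name and extract them positionally. A standalone sketch of just that round trip (the path and column names are hypothetical; `sqlContext` is assumed in scope, as in the writer above):

```scala
import org.apache.hadoop.fs.Path

// Hypothetical mirror of the private Data holder above.
case class Data(columnsToPrune: Seq[String])

val dataPath = new Path("/tmp/column-pruner", "data").toString

// Write: a single row of transformer state, coalesced to one file.
sqlContext.createDataFrame(Seq(Data(Seq("a_idx", "b_idx"))))
  .repartition(1).write.parquet(dataPath)

// Read back: select by name, extract positionally, rebuild the Set.
val row = sqlContext.read.parquet(dataPath).select("columnsToPrune").head()
val columnsToPrune = row.getAs[Seq[String]](0).toSet
```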
 /**
@@ -278,11 +387,13 @@ private class ColumnPruner(columnsToPrune: Set[String]) extends Transformer {
  * by the value in the map.
  */
 private class VectorAttributeRewriter(
-    vectorCol: String,
-    prefixesToRewrite: Map[String, String])
-  extends Transformer {
+    override val uid: String,
+    val vectorCol: String,
+    val prefixesToRewrite: Map[String, String])
+  extends Transformer with MLWritable {
 
-  override val uid = Identifiable.randomUID("vectorAttrRewriter")
+  def this(vectorCol: String, prefixesToRewrite: Map[String, String]) =
+    this(Identifiable.randomUID("vectorAttrRewriter"), vectorCol, prefixesToRewrite)
 
   override def transform(dataset: DataFrame): DataFrame = {
     val metadata = {
@@ -315,4 +426,48 @@ private class VectorAttributeRewriter(
   }
 
   override def copy(extra: ParamMap): VectorAttributeRewriter = defaultCopy(extra)
+
+  override def write: MLWriter = new VectorAttributeRewriter.VectorAttributeRewriterWriter(this)
+}
+
+private object VectorAttributeRewriter extends MLReadable[VectorAttributeRewriter] {
+
+  override def read: MLReader[VectorAttributeRewriter] = new VectorAttributeRewriterReader
+
+  override def load(path: String): VectorAttributeRewriter = super.load(path)
+
+  /** [[MLWriter]] instance for [[VectorAttributeRewriter]] */
+  private[VectorAttributeRewriter]
+  class VectorAttributeRewriterWriter(instance: VectorAttributeRewriter) extends MLWriter {
+
+    private case class Data(vectorCol: String, prefixesToRewrite: Map[String, String])
+
+    override protected def saveImpl(path: String): Unit = {
+      // Save metadata and Params
+      DefaultParamsWriter.saveMetadata(instance, path, sc)
+      // Save model data: vectorCol, prefixesToRewrite
+      val data = Data(instance.vectorCol, instance.prefixesToRewrite)
+      val dataPath = new Path(path, "data").toString
+      sqlContext.createDataFrame(Seq(data)).repartition(1).write.parquet(dataPath)
+    }
+  }
+
+  private class VectorAttributeRewriterReader extends MLReader[VectorAttributeRewriter] {
+
+    /** Checked against metadata when loading model */
+    private val className = classOf[VectorAttributeRewriter].getName
+
+    override def load(path: String): VectorAttributeRewriter = {
+      val metadata = DefaultParamsReader.loadMetadata(path, sc, className)
+
+      val dataPath = new Path(path, "data").toString
+      val data = sqlContext.read.parquet(dataPath).select("vectorCol", "prefixesToRewrite").head()
+      val vectorCol = data.getString(0)
+      val prefixesToRewrite = data.getAs[Map[String, String]](1)
+      val rewriter = new VectorAttributeRewriter(metadata.uid, vectorCol, prefixesToRewrite)
+
+      DefaultParamsReader.getAndSetParams(rewriter, metadata)
+      rewriter
+    }
+  }
 }
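With all three of `RFormulaModel`'s internal stages now `MLWritable`, the `instance.pipelineModel.save(pmPath)` call above can succeed, since `PipelineModel` persistence requires every stage to be writable. This is also what unblocks the `CrossValidator` change in the next file. A quick sanity check (a sketch; `pipelineModel` is `private[ml]`, so this would live in an `org.apache.spark.ml` test, with `model` a fitted `RFormulaModel`):

```scala
import org.apache.spark.ml.util.MLWritable

// Every stage of the internal pipeline must be MLWritable for save to work.
assert(model.pipelineModel.stages.forall(_.isInstanceOf[MLWritable]))
```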
diff --git a/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala b/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala
index 010e7d2686571ce378eb1f3bacddb0942e2e71dd..3d7a91dd39a71d11cd316c5ce879891f08379747 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala
@@ -221,10 +221,7 @@ object CrossValidator extends MLReadable[CrossValidator] {
         // TODO: SPARK-11892: This case may require special handling.
         throw new UnsupportedOperationException("CrossValidator write will fail because it" +
           " cannot yet handle an estimator containing type: ${ovr.getClass.getName}")
-      case rform: RFormulaModel =>
-        // TODO: SPARK-11891: This case may require special handling.
-        throw new UnsupportedOperationException("CrossValidator write will fail because it" +
-          " cannot yet handle an estimator containing an RFormulaModel")
+      case rformModel: RFormulaModel => Array(rformModel.pipelineModel)
       case _: Params => Array()
     }
     val subStageMaps = subStages.map(getUidMapImpl).foldLeft(List.empty[(String, Params)])(_ ++ _)
diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/RFormulaSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/RFormulaSuite.scala
index 16e565d8b588b6574c3339c774f07febf5c40055..e1b269b5b681f1955102bfd8aead41709e0f4144 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/feature/RFormulaSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/feature/RFormulaSuite.scala
@@ -20,10 +20,11 @@ package org.apache.spark.ml.feature
 import org.apache.spark.SparkFunSuite
 import org.apache.spark.ml.attribute._
 import org.apache.spark.ml.param.ParamsSuite
+import org.apache.spark.ml.util.DefaultReadWriteTest
 import org.apache.spark.mllib.linalg.Vectors
 import org.apache.spark.mllib.util.MLlibTestSparkContext
 
-class RFormulaSuite extends SparkFunSuite with MLlibTestSparkContext {
+class RFormulaSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest {
   test("params") {
     ParamsSuite.checkParams(new RFormula())
   }
@@ -252,4 +253,41 @@ class RFormulaSuite extends SparkFunSuite with MLlibTestSparkContext {
       new NumericAttribute(Some("a_foo:b_zz"), Some(4))))
     assert(attrs === expectedAttrs)
   }
+
+  test("read/write: RFormula") {
+    val rFormula = new RFormula()
+      .setFormula("id ~ a:b")
+      .setFeaturesCol("myFeatures")
+      .setLabelCol("myLabels")
+
+    testDefaultReadWrite(rFormula)
+  }
+
+  test("read/write: RFormulaModel") {
+    def checkModelData(model: RFormulaModel, model2: RFormulaModel): Unit = {
+      assert(model.uid === model2.uid)
+
+      assert(model.resolvedFormula.label === model2.resolvedFormula.label)
+      assert(model.resolvedFormula.terms === model2.resolvedFormula.terms)
+      assert(model.resolvedFormula.hasIntercept === model2.resolvedFormula.hasIntercept)
+
+      assert(model.pipelineModel.uid === model2.pipelineModel.uid)
+
+      model.pipelineModel.stages.zip(model2.pipelineModel.stages).foreach {
+        case (transformer1, transformer2) =>
+          assert(transformer1.uid === transformer2.uid)
+          assert(transformer1.params === transformer2.params)
+      }
+    }
+
+    val dataset = sqlContext.createDataFrame(
+      Seq((1, "foo", "zq"), (2, "bar", "zq"), (3, "bar", "zz"))
+    ).toDF("id", "a", "b")
+
+    val rFormula = new RFormula().setFormula("id ~ a:b")
+
+    val model = rFormula.fit(dataset)
+    val newModel = testDefaultReadWrite(model)
+    checkModelData(model, newModel)
+  }
 }
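Outside the test harness, the round trip that `testDefaultReadWrite` exercises looks roughly like this (a sketch; the path is illustrative and `sqlContext` is assumed in scope, as in the suite):

```scala
val dataset = sqlContext.createDataFrame(
  Seq((1, "foo", "zq"), (2, "bar", "zq"), (3, "bar", "zz"))
).toDF("id", "a", "b")

val model = new RFormula().setFormula("id ~ a:b").fit(dataset)
model.write.overwrite().save("/tmp/rformula-model")

// The restored model should transform identically to the original.
val restored = RFormulaModel.load("/tmp/rformula-model")
assert(restored.transform(dataset).columns.sameElements(model.transform(dataset).columns))
```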