Commit bdfe7f67 authored by Zheng RuiFeng, committed by Yanbo Liang

[SPARK-18625][ML] OneVsRestModel should support setFeaturesCol and setPredictionCol

## What changes were proposed in this pull request?
Add `setFeaturesCol` and `setPredictionCol` setters to `OneVsRestModel`, so a fitted model can read features from, and write predictions to, user-specified columns.
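
As a quick illustration (not part of the patch itself), the new setters let a caller re-point a fitted model at renamed columns. This is a minimal sketch mirroring the test added in `OneVsRestSuite`; `df` is a placeholder for any DataFrame with `label` and `features` columns:

```scala
import org.apache.spark.ml.classification.{LogisticRegression, OneVsRest}
import org.apache.spark.sql.functions.col

// Sketch only: `df` stands for any DataFrame with "label" and "features" columns.
val ovr = new OneVsRest().setClassifier(new LogisticRegression)
val ovrModel = ovr.fit(df)

// Rename the input columns, then point the fitted model at the new names.
val renamed = df.select(col("label").as("y"), col("features").as("fea"))
val predicted = ovrModel
  .setFeaturesCol("fea")     // read features from "fea"
  .setPredictionCol("pred")  // write predictions to "pred"
  .transform(renamed)
// `predicted` now contains the columns "y", "fea", and "pred".
```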

## How was this patch tested?
Added unit tests to `OneVsRestSuite`.

Author: Zheng RuiFeng <ruifengz@foxmail.com>

Closes #16059 from zhengruifeng/ovrm_setCol.
parent e9730b70
@@ -140,6 +140,14 @@ final class OneVsRestModel private[ml] (
     this(uid, Metadata.empty, models.asScala.toArray)
   }
 
+  /** @group setParam */
+  @Since("2.1.0")
+  def setFeaturesCol(value: String): this.type = set(featuresCol, value)
+
+  /** @group setParam */
+  @Since("2.1.0")
+  def setPredictionCol(value: String): this.type = set(predictionCol, value)
+
   @Since("1.4.0")
   override def transformSchema(schema: StructType): StructType = {
     validateAndTransformSchema(schema, fitting = false, getClassifier.featuresDataType)
@@ -175,6 +183,7 @@ final class OneVsRestModel private[ml] (
       val updateUDF = udf { (predictions: Map[Int, Double], prediction: Vector) =>
         predictions + ((index, prediction(1)))
       }
+      model.setFeaturesCol($(featuresCol))
       val transformedDataset = model.transform(df).select(columns: _*)
       val updatedDataset = transformedDataset
        .withColumn(tmpColName, updateUDF(col(accColName), col(rawPredictionCol)))
@@ -22,7 +22,7 @@ import org.apache.spark.ml.attribute.NominalAttribute
 import org.apache.spark.ml.classification.LogisticRegressionSuite._
 import org.apache.spark.ml.feature.LabeledPoint
 import org.apache.spark.ml.feature.StringIndexer
-import org.apache.spark.ml.linalg.{DenseMatrix, Vectors}
+import org.apache.spark.ml.linalg.Vectors
 import org.apache.spark.ml.param.{ParamMap, ParamsSuite}
 import org.apache.spark.ml.util.{DefaultReadWriteTest, MetadataUtils, MLTestingUtils}
 import org.apache.spark.mllib.classification.LogisticRegressionWithLBFGS
@@ -33,6 +33,7 @@ import org.apache.spark.mllib.util.MLlibTestSparkContext
 import org.apache.spark.mllib.util.TestingUtils._
 import org.apache.spark.rdd.RDD
 import org.apache.spark.sql.Dataset
+import org.apache.spark.sql.functions._
 import org.apache.spark.sql.types.Metadata
 
 class OneVsRestSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest {
@@ -136,6 +137,17 @@ class OneVsRestSuite extends SparkFunSuite with MLlibTestSparkContext with Defau
     assert(outputFields.contains("p"))
   }
 
+  test("SPARK-18625 : OneVsRestModel should support setFeaturesCol and setPredictionCol") {
+    val ova = new OneVsRest().setClassifier(new LogisticRegression)
+    val ovaModel = ova.fit(dataset)
+    val dataset2 = dataset.select(col("label").as("y"), col("features").as("fea"))
+    ovaModel.setFeaturesCol("fea")
+    ovaModel.setPredictionCol("pred")
+    val transformedDataset = ovaModel.transform(dataset2)
+    val outputFields = transformedDataset.schema.fieldNames.toSet
+    assert(outputFields === Set("y", "fea", "pred"))
+  }
+
   test("SPARK-8049: OneVsRest shouldn't output temp columns") {
     val logReg = new LogisticRegression()
       .setMaxIter(1)