From a5a3189974ea4628e9489eb50099a5432174e80c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Yan=20Facai=20=28=E9=A2=9C=E5=8F=91=E6=89=8D=29?= <facai.yan@gmail.com> Date: Fri, 28 Jul 2017 10:10:35 +0800 Subject: [PATCH] [SPARK-21306][ML] OneVsRest should support setWeightCol MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ## What changes were proposed in this pull request? add `setWeightCol` method for OneVsRest. `weightCol` is ignored if classifier doesn't inherit HasWeightCol trait. ## How was this patch tested? + [x] add an unit test. Author: Yan Facai (é¢œå‘æ‰) <facai.yan@gmail.com> Closes #18554 from facaiy/BUG/oneVsRest_missing_weightCol. --- .../spark/ml/classification/OneVsRest.scala | 39 +++++++++++++++++-- .../ml/classification/OneVsRestSuite.scala | 10 +++++ python/pyspark/ml/classification.py | 27 ++++++++++--- python/pyspark/ml/tests.py | 14 +++++++ 4 files changed, 81 insertions(+), 9 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/OneVsRest.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/OneVsRest.scala index 7cbcccf272..05b8c3ab54 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/OneVsRest.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/OneVsRest.scala @@ -34,6 +34,7 @@ import org.apache.spark.ml._ import org.apache.spark.ml.attribute._ import org.apache.spark.ml.linalg.Vector import org.apache.spark.ml.param.{Param, ParamMap, ParamPair, Params} +import org.apache.spark.ml.param.shared.HasWeightCol import org.apache.spark.ml.util._ import org.apache.spark.sql.{DataFrame, Dataset, Row} import org.apache.spark.sql.functions._ @@ -53,7 +54,8 @@ private[ml] trait ClassifierTypeTrait { /** * Params for [[OneVsRest]]. */ -private[ml] trait OneVsRestParams extends PredictorParams with ClassifierTypeTrait { +private[ml] trait OneVsRestParams extends PredictorParams + with ClassifierTypeTrait with HasWeightCol { /** * param for the base binary classifier that we reduce multiclass classification into. @@ -294,6 +296,18 @@ final class OneVsRest @Since("1.4.0") ( @Since("1.5.0") def setPredictionCol(value: String): this.type = set(predictionCol, value) + /** + * Sets the value of param [[weightCol]]. + * + * This is ignored if weight is not supported by [[classifier]]. + * If this is not set or empty, we treat all instance weights as 1.0. + * Default is not set, so all instances have weight one. + * + * @group setParam + */ + @Since("2.3.0") + def setWeightCol(value: String): this.type = set(weightCol, value) + @Since("1.4.0") override def transformSchema(schema: StructType): StructType = { validateAndTransformSchema(schema, fitting = true, getClassifier.featuresDataType) @@ -317,7 +331,20 @@ final class OneVsRest @Since("1.4.0") ( val numClasses = MetadataUtils.getNumClasses(labelSchema).fold(computeNumClasses())(identity) instr.logNumClasses(numClasses) - val multiclassLabeled = dataset.select($(labelCol), $(featuresCol)) + val weightColIsUsed = isDefined(weightCol) && $(weightCol).nonEmpty && { + getClassifier match { + case _: HasWeightCol => true + case c => + logWarning(s"weightCol is ignored, as it is not supported by $c now.") + false + } + } + + val multiclassLabeled = if (weightColIsUsed) { + dataset.select($(labelCol), $(featuresCol), $(weightCol)) + } else { + dataset.select($(labelCol), $(featuresCol)) + } // persist if underlying dataset is not persistent. val handlePersistence = dataset.rdd.getStorageLevel == StorageLevel.NONE @@ -337,7 +364,13 @@ final class OneVsRest @Since("1.4.0") ( paramMap.put(classifier.labelCol -> labelColName) paramMap.put(classifier.featuresCol -> getFeaturesCol) paramMap.put(classifier.predictionCol -> getPredictionCol) - classifier.fit(trainingDataset, paramMap) + if (weightColIsUsed) { + val classifier_ = classifier.asInstanceOf[ClassifierType with HasWeightCol] + paramMap.put(classifier_.weightCol -> getWeightCol) + classifier_.fit(trainingDataset, paramMap) + } else { + classifier.fit(trainingDataset, paramMap) + } }.toArray[ClassificationModel[_, _]] instr.logNumFeatures(models.head.numFeatures) diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/OneVsRestSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/OneVsRestSuite.scala index c02e38ad64..17f82827b7 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/classification/OneVsRestSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/classification/OneVsRestSuite.scala @@ -156,6 +156,16 @@ class OneVsRestSuite extends SparkFunSuite with MLlibTestSparkContext with Defau assert(output.schema.fieldNames.toSet === Set("label", "features", "prediction")) } + test("SPARK-21306: OneVsRest should support setWeightCol") { + val dataset2 = dataset.withColumn("weight", lit(1)) + // classifier inherits hasWeightCol + val ova = new OneVsRest().setWeightCol("weight").setClassifier(new LogisticRegression()) + assert(ova.fit(dataset2) !== null) + // classifier doesn't inherit hasWeightCol + val ova2 = new OneVsRest().setWeightCol("weight").setClassifier(new DecisionTreeClassifier()) + assert(ova2.fit(dataset2) !== null) + } + test("OneVsRest.copy and OneVsRestModel.copy") { val lr = new LogisticRegression() .setMaxIter(1) diff --git a/python/pyspark/ml/classification.py b/python/pyspark/ml/classification.py index 82207f6644..4af6f71e19 100644 --- a/python/pyspark/ml/classification.py +++ b/python/pyspark/ml/classification.py @@ -1447,7 +1447,7 @@ class MultilayerPerceptronClassificationModel(JavaModel, JavaPredictionModel, Ja return self._call_java("weights") -class OneVsRestParams(HasFeaturesCol, HasLabelCol, HasPredictionCol): +class OneVsRestParams(HasFeaturesCol, HasLabelCol, HasWeightCol, HasPredictionCol): """ Parameters for OneVsRest and OneVsRestModel. """ @@ -1517,10 +1517,10 @@ class OneVsRest(Estimator, OneVsRestParams, JavaMLReadable, JavaMLWritable): @keyword_only def __init__(self, featuresCol="features", labelCol="label", predictionCol="prediction", - classifier=None): + classifier=None, weightCol=None): """ __init__(self, featuresCol="features", labelCol="label", predictionCol="prediction", \ - classifier=None) + classifier=None, weightCol=None) """ super(OneVsRest, self).__init__() kwargs = self._input_kwargs @@ -1528,9 +1528,11 @@ class OneVsRest(Estimator, OneVsRestParams, JavaMLReadable, JavaMLWritable): @keyword_only @since("2.0.0") - def setParams(self, featuresCol=None, labelCol=None, predictionCol=None, classifier=None): + def setParams(self, featuresCol=None, labelCol=None, predictionCol=None, + classifier=None, weightCol=None): """ - setParams(self, featuresCol=None, labelCol=None, predictionCol=None, classifier=None): + setParams(self, featuresCol=None, labelCol=None, predictionCol=None, \ + classifier=None, weightCol=None): Sets params for OneVsRest. """ kwargs = self._input_kwargs @@ -1546,7 +1548,18 @@ class OneVsRest(Estimator, OneVsRestParams, JavaMLReadable, JavaMLWritable): numClasses = int(dataset.agg({labelCol: "max"}).head()["max("+labelCol+")"]) + 1 - multiclassLabeled = dataset.select(labelCol, featuresCol) + weightCol = None + if (self.isDefined(self.weightCol) and self.getWeightCol()): + if isinstance(classifier, HasWeightCol): + weightCol = self.getWeightCol() + else: + warnings.warn("weightCol is ignored, " + "as it is not supported by {} now.".format(classifier)) + + if weightCol: + multiclassLabeled = dataset.select(labelCol, featuresCol, weightCol) + else: + multiclassLabeled = dataset.select(labelCol, featuresCol) # persist if underlying dataset is not persistent. handlePersistence = \ @@ -1562,6 +1575,8 @@ class OneVsRest(Estimator, OneVsRestParams, JavaMLReadable, JavaMLWritable): paramMap = dict([(classifier.labelCol, binaryLabelCol), (classifier.featuresCol, featuresCol), (classifier.predictionCol, predictionCol)]) + if weightCol: + paramMap[classifier.weightCol] = weightCol return classifier.fit(trainingDataset, paramMap) # TODO: Parallel training for all classes. diff --git a/python/pyspark/ml/tests.py b/python/pyspark/ml/tests.py index 6c71e69c9b..a9ca346fa5 100755 --- a/python/pyspark/ml/tests.py +++ b/python/pyspark/ml/tests.py @@ -1394,6 +1394,20 @@ class OneVsRestTests(SparkSessionTestCase): output = model.transform(df) self.assertEqual(output.columns, ["label", "features", "prediction"]) + def test_support_for_weightCol(self): + df = self.spark.createDataFrame([(0.0, Vectors.dense(1.0, 0.8), 1.0), + (1.0, Vectors.sparse(2, [], []), 1.0), + (2.0, Vectors.dense(0.5, 0.5), 1.0)], + ["label", "features", "weight"]) + # classifier inherits hasWeightCol + lr = LogisticRegression(maxIter=5, regParam=0.01) + ovr = OneVsRest(classifier=lr, weightCol="weight") + self.assertIsNotNone(ovr.fit(df)) + # classifier doesn't inherit hasWeightCol + dt = DecisionTreeClassifier() + ovr2 = OneVsRest(classifier=dt, weightCol="weight") + self.assertIsNotNone(ovr2.fit(df)) + class HashingTFTest(SparkSessionTestCase): -- GitLab