diff --git a/R/pkg/R/mllib_classification.R b/R/pkg/R/mllib_classification.R
index 82d2428f3c4448d90186e329fc60d5367905cf35..15af8298ba484cb4bdabab39636b8c4633a8e29f 100644
--- a/R/pkg/R/mllib_classification.R
+++ b/R/pkg/R/mllib_classification.R
@@ -69,6 +69,11 @@ setClass("NaiveBayesModel", representation(jobj = "jobj"))
 #' @param aggregationDepth The depth for treeAggregate (greater than or equal to 2). If the dimensions of features
 #'                         or the number of partitions are large, this param could be adjusted to a larger size.
 #'                         This is an expert parameter. Default value should be good for most cases.
+#' @param handleInvalid How to handle invalid data (unseen labels or NULL values) in features and label
+#'                      column of string type.
+#'                      Supported options: "skip" (filter out rows with invalid data),
+#'                      "error" (throw an error), "keep" (put invalid data in a special additional
+#'                      bucket, at index numLabels). Default is "error".
 #' @param ... additional arguments passed to the method.
 #' @return \code{spark.svmLinear} returns a fitted linear SVM model.
 #' @rdname spark.svmLinear
@@ -98,7 +103,8 @@ setClass("NaiveBayesModel", representation(jobj = "jobj"))
 #' @note spark.svmLinear since 2.2.0
 setMethod("spark.svmLinear", signature(data = "SparkDataFrame", formula = "formula"),
           function(data, formula, regParam = 0.0, maxIter = 100, tol = 1E-6, standardization = TRUE,
-                   threshold = 0.0, weightCol = NULL, aggregationDepth = 2) {
+                   threshold = 0.0, weightCol = NULL, aggregationDepth = 2,
+                   handleInvalid = c("error", "keep", "skip")) {
             formula <- paste(deparse(formula), collapse = "")

             if (!is.null(weightCol) && weightCol == "") {
@@ -107,10 +113,12 @@ setMethod("spark.svmLinear", signature(data = "SparkDataFrame", formula = "formu
               weightCol <- as.character(weightCol)
             }

+            handleInvalid <- match.arg(handleInvalid)
+
             jobj <- callJStatic("org.apache.spark.ml.r.LinearSVCWrapper", "fit",
                                 data@sdf, formula, as.numeric(regParam), as.integer(maxIter),
                                 as.numeric(tol), as.logical(standardization), as.numeric(threshold),
-                                weightCol, as.integer(aggregationDepth))
+                                weightCol, as.integer(aggregationDepth), handleInvalid)
             new("LinearSVCModel", jobj = jobj)
           })

@@ -218,6 +226,11 @@ function(object, path, overwrite = FALSE) {
 #' @param upperBoundsOnIntercepts The upper bounds on intercepts if fitting under bound constrained optimization.
 #'                                The bound vector size must be equal to 1 for binomial regression, or the number
 #'                                of classes for multinomial regression.
+#' @param handleInvalid How to handle invalid data (unseen labels or NULL values) in features and label
+#'                      column of string type.
+#'                      Supported options: "skip" (filter out rows with invalid data),
+#'                      "error" (throw an error), "keep" (put invalid data in a special additional
+#'                      bucket, at index numLabels). Default is "error".
 #' @param ... additional arguments passed to the method.
 #' @return \code{spark.logit} returns a fitted logistic regression model.
 #' @rdname spark.logit
@@ -257,7 +270,8 @@ setMethod("spark.logit", signature(data = "SparkDataFrame", formula = "formula")
                    tol = 1E-6, family = "auto", standardization = TRUE,
                    thresholds = 0.5, weightCol = NULL, aggregationDepth = 2,
                    lowerBoundsOnCoefficients = NULL, upperBoundsOnCoefficients = NULL,
-                   lowerBoundsOnIntercepts = NULL, upperBoundsOnIntercepts = NULL) {
+                   lowerBoundsOnIntercepts = NULL, upperBoundsOnIntercepts = NULL,
+                   handleInvalid = c("error", "keep", "skip")) {
             formula <- paste(deparse(formula), collapse = "")
             row <- 0
             col <- 0
@@ -304,6 +318,8 @@ setMethod("spark.logit", signature(data = "SparkDataFrame", formula = "formula")
               upperBoundsOnCoefficients <- as.array(as.vector(upperBoundsOnCoefficients))
             }

+            handleInvalid <- match.arg(handleInvalid)
+
             jobj <- callJStatic("org.apache.spark.ml.r.LogisticRegressionWrapper", "fit",
                                 data@sdf, formula, as.numeric(regParam),
                                 as.numeric(elasticNetParam), as.integer(maxIter),
@@ -312,7 +328,8 @@ setMethod("spark.logit", signature(data = "SparkDataFrame", formula = "formula")
                                 weightCol, as.integer(aggregationDepth),
                                 as.integer(row), as.integer(col),
                                 lowerBoundsOnCoefficients, upperBoundsOnCoefficients,
-                                lowerBoundsOnIntercepts, upperBoundsOnIntercepts)
+                                lowerBoundsOnIntercepts, upperBoundsOnIntercepts,
+                                handleInvalid)
             new("LogisticRegressionModel", jobj = jobj)
           })

@@ -394,7 +411,12 @@ setMethod("write.ml", signature(object = "LogisticRegressionModel", path = "char
 #' @param stepSize stepSize parameter.
 #' @param seed seed parameter for weights initialization.
 #' @param initialWeights initialWeights parameter for weights initialization, it should be a
-#'        numeric vector.
+#'                       numeric vector.
+#' @param handleInvalid How to handle invalid data (unseen labels or NULL values) in features and label
+#'                      column of string type.
+#'                      Supported options: "skip" (filter out rows with invalid data),
+#'                      "error" (throw an error), "keep" (put invalid data in a special additional
+#'                      bucket, at index numLabels). Default is "error".
 #' @param ... additional arguments passed to the method.
 #' @return \code{spark.mlp} returns a fitted Multilayer Perceptron Classification Model.
 #' @rdname spark.mlp
@@ -426,7 +448,8 @@ setMethod("write.ml", signature(object = "LogisticRegressionModel", path = "char
 #' @note spark.mlp since 2.1.0
 setMethod("spark.mlp", signature(data = "SparkDataFrame", formula = "formula"),
           function(data, formula, layers, blockSize = 128, solver = "l-bfgs", maxIter = 100,
-                   tol = 1E-6, stepSize = 0.03, seed = NULL, initialWeights = NULL) {
+                   tol = 1E-6, stepSize = 0.03, seed = NULL, initialWeights = NULL,
+                   handleInvalid = c("error", "keep", "skip")) {
             formula <- paste(deparse(formula), collapse = "")
             if (is.null(layers)) {
               stop ("layers must be a integer vector with length > 1.")
@@ -441,10 +464,11 @@ setMethod("spark.mlp", signature(data = "SparkDataFrame", formula = "formula"),
             if (!is.null(initialWeights)) {
               initialWeights <- as.array(as.numeric(na.omit(initialWeights)))
             }
+            handleInvalid <- match.arg(handleInvalid)
             jobj <- callJStatic("org.apache.spark.ml.r.MultilayerPerceptronClassifierWrapper",
                                 "fit", data@sdf, formula, as.integer(blockSize), as.array(layers),
                                 as.character(solver), as.integer(maxIter), as.numeric(tol),
-                                as.numeric(stepSize), seed, initialWeights)
+                                as.numeric(stepSize), seed, initialWeights, handleInvalid)
             new("MultilayerPerceptronClassificationModel", jobj = jobj)
           })

@@ -514,6 +538,11 @@ setMethod("write.ml", signature(object = "MultilayerPerceptronClassificationMode
 #' @param formula a symbolic description of the model to be fitted. Currently only a few formula
 #'                operators are supported, including '~', '.', ':', '+', and '-'.
 #' @param smoothing smoothing parameter.
+#' @param handleInvalid How to handle invalid data (unseen labels or NULL values) in features and label
+#'                      column of string type.
+#'                      Supported options: "skip" (filter out rows with invalid data),
+#'                      "error" (throw an error), "keep" (put invalid data in a special additional
+#'                      bucket, at index numLabels). Default is "error".
 #' @param ... additional argument(s) passed to the method. Currently only \code{smoothing}.
 #' @return \code{spark.naiveBayes} returns a fitted naive Bayes model.
 #' @rdname spark.naiveBayes
@@ -543,10 +572,12 @@ setMethod("write.ml", signature(object = "MultilayerPerceptronClassificationMode
 #' }
 #' @note spark.naiveBayes since 2.0.0
 setMethod("spark.naiveBayes", signature(data = "SparkDataFrame", formula = "formula"),
-          function(data, formula, smoothing = 1.0) {
+          function(data, formula, smoothing = 1.0,
+                   handleInvalid = c("error", "keep", "skip")) {
             formula <- paste(deparse(formula), collapse = "")
+            handleInvalid <- match.arg(handleInvalid)
             jobj <- callJStatic("org.apache.spark.ml.r.NaiveBayesWrapper", "fit",
-                                formula, data@sdf, smoothing)
+                                formula, data@sdf, smoothing, handleInvalid)
             new("NaiveBayesModel", jobj = jobj)
           })

diff --git a/R/pkg/R/mllib_tree.R b/R/pkg/R/mllib_tree.R
index 75b1a74ee8c7cfb6b746f29158d61c3ca8551ecc..33c4653f4c184d84c7965e580b57ad1be9d5f588 100644
--- a/R/pkg/R/mllib_tree.R
+++ b/R/pkg/R/mllib_tree.R
@@ -164,6 +164,11 @@ print.summary.decisionTree <- function(x) {
 #'                     nodes. If TRUE, the algorithm will cache node IDs for each instance. Caching
 #'                     can speed up training of deeper trees. Users can set how often should the
 #'                     cache be checkpointed or disable it by setting checkpointInterval.
+#' @param handleInvalid How to handle invalid data (unseen labels or NULL values) in features and label
+#'                      column of string type in classification model.
+#'                      Supported options: "skip" (filter out rows with invalid data),
+#'                      "error" (throw an error), "keep" (put invalid data in a special additional
+#'                      bucket, at index numLabels). Default is "error".
 #' @param ... additional arguments passed to the method.
 #' @aliases spark.gbt,SparkDataFrame,formula-method
 #' @return \code{spark.gbt} returns a fitted Gradient Boosted Tree model.
@@ -205,7 +210,8 @@ setMethod("spark.gbt", signature(data = "SparkDataFrame", formula = "formula"),
           function(data, formula, type = c("regression", "classification"),
                    maxDepth = 5, maxBins = 32, maxIter = 20, stepSize = 0.1, lossType = NULL,
                    seed = NULL, subsamplingRate = 1.0, minInstancesPerNode = 1, minInfoGain = 0.0,
-                   checkpointInterval = 10, maxMemoryInMB = 256, cacheNodeIds = FALSE) {
+                   checkpointInterval = 10, maxMemoryInMB = 256, cacheNodeIds = FALSE,
+                   handleInvalid = c("error", "keep", "skip")) {
             type <- match.arg(type)
             formula <- paste(deparse(formula), collapse = "")
             if (!is.null(seed)) {
@@ -225,6 +231,7 @@ setMethod("spark.gbt", signature(data = "SparkDataFrame", formula = "formula"),
                      new("GBTRegressionModel", jobj = jobj)
                    },
                    classification = {
+                     handleInvalid <- match.arg(handleInvalid)
                      if (is.null(lossType)) lossType <- "logistic"
                      lossType <- match.arg(lossType, "logistic")
                      jobj <- callJStatic("org.apache.spark.ml.r.GBTClassifierWrapper",
@@ -233,7 +240,8 @@ setMethod("spark.gbt", signature(data = "SparkDataFrame", formula = "formula"),
                                          as.numeric(stepSize), as.integer(minInstancesPerNode),
                                          as.numeric(minInfoGain), as.integer(checkpointInterval),
                                          lossType, seed, as.numeric(subsamplingRate),
-                                         as.integer(maxMemoryInMB), as.logical(cacheNodeIds))
+                                         as.integer(maxMemoryInMB), as.logical(cacheNodeIds),
+                                         handleInvalid)
                      new("GBTClassificationModel", jobj = jobj)
                    }
           )
@@ -374,10 +382,11 @@ setMethod("write.ml", signature(object = "GBTClassificationModel", path = "chara
 #'                     nodes. If TRUE, the algorithm will cache node IDs for each instance. Caching
 #'                     can speed up training of deeper trees. Users can set how often should the
 #'                     cache be checkpointed or disable it by setting checkpointInterval.
-#' @param handleInvalid How to handle invalid data (unseen labels or NULL values) in classification model.
-#'                      Supported options: "skip" (filter out rows with invalid data),
-#'                      "error" (throw an error), "keep" (put invalid data in a special additional
-#'                      bucket, at index numLabels). Default is "error".
+#' @param handleInvalid How to handle invalid data (unseen labels or NULL values) in features and label
+#'                      column of string type in classification model.
+#'                      Supported options: "skip" (filter out rows with invalid data),
+#'                      "error" (throw an error), "keep" (put invalid data in a special additional
+#'                      bucket, at index numLabels). Default is "error".
 #' @param ... additional arguments passed to the method.
 #' @aliases spark.randomForest,SparkDataFrame,formula-method
 #' @return \code{spark.randomForest} returns a fitted Random Forest model.
@@ -583,6 +592,11 @@ setMethod("write.ml", signature(object = "RandomForestClassificationModel", path
 #'                     nodes. If TRUE, the algorithm will cache node IDs for each instance. Caching
 #'                     can speed up training of deeper trees. Users can set how often should the
 #'                     cache be checkpointed or disable it by setting checkpointInterval.
+#' @param handleInvalid How to handle invalid data (unseen labels or NULL values) in features and label
+#'                      column of string type in classification model.
+#'                      Supported options: "skip" (filter out rows with invalid data),
+#'                      "error" (throw an error), "keep" (put invalid data in a special additional
+#'                      bucket, at index numLabels). Default is "error".
 #' @param ... additional arguments passed to the method.
 #' @aliases spark.decisionTree,SparkDataFrame,formula-method
 #' @return \code{spark.decisionTree} returns a fitted Decision Tree model.
@@ -617,7 +631,8 @@ setMethod("spark.decisionTree", signature(data = "SparkDataFrame", formula = "fo
           function(data, formula, type = c("regression", "classification"),
                    maxDepth = 5, maxBins = 32, impurity = NULL, seed = NULL,
                    minInstancesPerNode = 1, minInfoGain = 0.0, checkpointInterval = 10,
-                   maxMemoryInMB = 256, cacheNodeIds = FALSE) {
+                   maxMemoryInMB = 256, cacheNodeIds = FALSE,
+                   handleInvalid = c("error", "keep", "skip")) {
             type <- match.arg(type)
             formula <- paste(deparse(formula), collapse = "")
             if (!is.null(seed)) {
@@ -636,6 +651,7 @@ setMethod("spark.decisionTree", signature(data = "SparkDataFrame", formula = "fo
                      new("DecisionTreeRegressionModel", jobj = jobj)
                    },
                    classification = {
+                     handleInvalid <- match.arg(handleInvalid)
                      if (is.null(impurity)) impurity <- "gini"
                      impurity <- match.arg(impurity, c("gini", "entropy"))
                      jobj <- callJStatic("org.apache.spark.ml.r.DecisionTreeClassifierWrapper",
@@ -643,7 +659,8 @@ setMethod("spark.decisionTree", signature(data = "SparkDataFrame", formula = "fo
                                          as.integer(maxBins), impurity,
                                          as.integer(minInstancesPerNode), as.numeric(minInfoGain),
                                          as.integer(checkpointInterval), seed,
-                                         as.integer(maxMemoryInMB), as.logical(cacheNodeIds))
+                                         as.integer(maxMemoryInMB), as.logical(cacheNodeIds),
+                                         handleInvalid)
                      new("DecisionTreeClassificationModel", jobj = jobj)
                    }
           )
diff --git a/R/pkg/tests/fulltests/test_mllib_classification.R b/R/pkg/tests/fulltests/test_mllib_classification.R
index 3d75f4ce11ec808df3dabb8a556339435b773c06..a4d0397236d177eea69ca8e00993b5f02b26258c 100644
--- a/R/pkg/tests/fulltests/test_mllib_classification.R
+++ b/R/pkg/tests/fulltests/test_mllib_classification.R
@@ -70,6 +70,20 @@ test_that("spark.svmLinear", {
   prediction <- collect(select(predict(model, df), "prediction"))
   expect_equal(sort(prediction$prediction), c("0.0", "0.0", "0.0", "1.0", "1.0"))

+  # Test unseen labels
+  data <- data.frame(clicked = base::sample(c(0, 1), 10, replace = TRUE),
+                    someString = base::sample(c("this", "that"), 10, replace = TRUE),
+                    stringsAsFactors = FALSE)
+  trainidxs <- base::sample(nrow(data), nrow(data) * 0.7)
+  traindf <- as.DataFrame(data[trainidxs, ])
+  testdf <- as.DataFrame(rbind(data[-trainidxs, ], c(0, "the other")))
+  model <- spark.svmLinear(traindf, clicked ~ ., regParam = 0.1)
+  predictions <- predict(model, testdf)
+  expect_error(collect(predictions))
+  model <- spark.svmLinear(traindf, clicked ~ ., regParam = 0.1, handleInvalid = "skip")
+  predictions <- predict(model, testdf)
+  expect_equal(class(collect(predictions)$clicked[1]), "list")
+
 })

 test_that("spark.logit", {
@@ -263,6 +277,21 @@
   virginicaCoefs <- summary$coefficients[, "virginica"]
   expect_true(all(abs(versicolorCoefsR - versicolorCoefs) < 0.1))
   expect_true(all(abs(virginicaCoefsR - virginicaCoefs) < 0.1))
+
+  # Test unseen labels
+  data <- data.frame(clicked = base::sample(c(0, 1), 10, replace = TRUE),
+                    someString = base::sample(c("this", "that"), 10, replace = TRUE),
+                    stringsAsFactors = FALSE)
+  trainidxs <- base::sample(nrow(data), nrow(data) * 0.7)
+  traindf <- as.DataFrame(data[trainidxs, ])
+  testdf <- as.DataFrame(rbind(data[-trainidxs, ], c(0, "the other")))
+  model <- spark.logit(traindf, clicked ~ ., regParam = 0.5)
+  predictions <- predict(model, testdf)
+  expect_error(collect(predictions))
+  model <- spark.logit(traindf, clicked ~ ., regParam = 0.5, handleInvalid = "keep")
+  predictions <- predict(model, testdf)
+  expect_equal(class(collect(predictions)$clicked[1]), "character")
+
 })

 test_that("spark.mlp", {
@@ -344,6 +373,21 @@
   expect_equal(summary$numOfOutputs, 3)
   expect_equal(summary$layers, c(4, 3))
   expect_equal(length(summary$weights), 15)
+
+  # Test unseen labels
+  data <- data.frame(clicked = base::sample(c(0, 1), 10, replace = TRUE),
+                    someString = base::sample(c("this", "that"), 10, replace = TRUE),
+                    stringsAsFactors = FALSE)
+  trainidxs <- base::sample(nrow(data), nrow(data) * 0.7)
+  traindf <- as.DataFrame(data[trainidxs, ])
+  testdf <- as.DataFrame(rbind(data[-trainidxs, ], c(0, "the other")))
+  model <- spark.mlp(traindf, clicked ~ ., layers = c(1, 3))
+  predictions <- predict(model, testdf)
+  expect_error(collect(predictions))
+  model <- spark.mlp(traindf, clicked ~ ., layers = c(1, 3), handleInvalid = "skip")
+  predictions <- predict(model, testdf)
+  expect_equal(class(collect(predictions)$clicked[1]), "list")
+
 })

 test_that("spark.naiveBayes", {
@@ -427,6 +471,20 @@
   expect_equal(as.double(s$apriori[1, 1]), 0.5833333, tolerance = 1e-6)
   expect_equal(sum(s$apriori), 1)
   expect_equal(as.double(s$tables[1, "Age_Adult"]), 0.5714286, tolerance = 1e-6)
+
+  # Test unseen labels
+  data <- data.frame(clicked = base::sample(c(0, 1), 10, replace = TRUE),
+                    someString = base::sample(c("this", "that"), 10, replace = TRUE),
+                    stringsAsFactors = FALSE)
+  trainidxs <- base::sample(nrow(data), nrow(data) * 0.7)
+  traindf <- as.DataFrame(data[trainidxs, ])
+  testdf <- as.DataFrame(rbind(data[-trainidxs, ], c(0, "the other")))
+  model <- spark.naiveBayes(traindf, clicked ~ ., smoothing = 0.0)
+  predictions <- predict(model, testdf)
+  expect_error(collect(predictions))
+  model <- spark.naiveBayes(traindf, clicked ~ ., smoothing = 0.0, handleInvalid = "keep")
+  predictions <- predict(model, testdf)
+  expect_equal(class(collect(predictions)$clicked[1]), "character")
 })

 sparkR.session.stop()
diff --git a/R/pkg/tests/fulltests/test_mllib_tree.R b/R/pkg/tests/fulltests/test_mllib_tree.R
index e31a65f8dfedb9abcd6c77da13cb731c599ea1b9..799f94401d0086b32318edee6c214559ff491f94 100644
--- a/R/pkg/tests/fulltests/test_mllib_tree.R
+++ b/R/pkg/tests/fulltests/test_mllib_tree.R
@@ -109,6 +109,20 @@ test_that("spark.gbt", {
     model <- spark.gbt(data, label ~ features, "classification")
     expect_equal(summary(model)$numFeatures, 692)
   }
+
+  # Test unseen labels
+  data <- data.frame(clicked = base::sample(c(0, 1), 10, replace = TRUE),
+                    someString = base::sample(c("this", "that"), 10, replace = TRUE),
+                    stringsAsFactors = FALSE)
+  trainidxs <- base::sample(nrow(data), nrow(data) * 0.7)
+  traindf <- as.DataFrame(data[trainidxs, ])
+  testdf <- as.DataFrame(rbind(data[-trainidxs, ], c(0, "the other")))
+  model <- spark.gbt(traindf, clicked ~ ., type = "classification")
+  predictions <- predict(model, testdf)
+  expect_error(collect(predictions))
+  model <- spark.gbt(traindf, clicked ~ ., type = "classification", handleInvalid = "keep")
+  predictions <- predict(model, testdf)
+  expect_equal(class(collect(predictions)$clicked[1]), "character")
 })

 test_that("spark.randomForest", {
@@ -328,6 +342,22 @@
     model <- spark.decisionTree(data, label ~ features, "classification")
     expect_equal(summary(model)$numFeatures, 4)
   }
+
+  # Test unseen labels
+  data <- data.frame(clicked = base::sample(c(0, 1), 10, replace = TRUE),
+                    someString = base::sample(c("this", "that"), 10, replace = TRUE),
+                    stringsAsFactors = FALSE)
+  trainidxs <- base::sample(nrow(data), nrow(data) * 0.7)
+  traindf <- as.DataFrame(data[trainidxs, ])
+  testdf <- as.DataFrame(rbind(data[-trainidxs, ], c(0, "the other")))
+  model <- spark.decisionTree(traindf, clicked ~ ., type = "classification",
+                              maxDepth = 5, maxBins = 16)
+  predictions <- predict(model, testdf)
+  expect_error(collect(predictions))
+  model <- spark.decisionTree(traindf, clicked ~ ., type = "classification",
+                              maxDepth = 5, maxBins = 16, handleInvalid = "keep")
+  predictions <- predict(model, testdf)
+  expect_equal(class(collect(predictions)$clicked[1]), "character")
 })

 sparkR.session.stop()
diff --git a/mllib/src/main/scala/org/apache/spark/ml/r/DecisionTreeClassificationWrapper.scala b/mllib/src/main/scala/org/apache/spark/ml/r/DecisionTreeClassificationWrapper.scala
index 7f59825504d8e1eae88b4cf44585ae3960d23c41..a90cae5869b2a372569539ae4cd53b75191a684b 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/r/DecisionTreeClassificationWrapper.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/r/DecisionTreeClassificationWrapper.scala
@@ -73,11 +73,13 @@ private[r] object DecisionTreeClassifierWrapper extends MLReadable[DecisionTreeC
       checkpointInterval: Int,
       seed: String,
       maxMemoryInMB: Int,
-      cacheNodeIds: Boolean): DecisionTreeClassifierWrapper = {
+      cacheNodeIds: Boolean,
+      handleInvalid: String): DecisionTreeClassifierWrapper = {

     val rFormula = new RFormula()
       .setFormula(formula)
       .setForceIndexLabel(true)
+      .setHandleInvalid(handleInvalid)
     checkDataColumns(rFormula, data)
     val rFormulaModel = rFormula.fit(data)

diff --git a/mllib/src/main/scala/org/apache/spark/ml/r/GBTClassificationWrapper.scala b/mllib/src/main/scala/org/apache/spark/ml/r/GBTClassificationWrapper.scala
index c07eadb30a4d2d2c72a5259e3c23c721eb4ed466..ecaeec5a7791a10162c831b4b137c24bcf5b6a99 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/r/GBTClassificationWrapper.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/r/GBTClassificationWrapper.scala
@@ -78,11 +78,13 @@ private[r] object GBTClassifierWrapper extends MLReadable[GBTClassifierWrapper]
       seed: String,
       subsamplingRate: Double,
       maxMemoryInMB: Int,
-      cacheNodeIds: Boolean): GBTClassifierWrapper = {
+      cacheNodeIds: Boolean,
+      handleInvalid: String): GBTClassifierWrapper = {

     val rFormula = new RFormula()
       .setFormula(formula)
       .setForceIndexLabel(true)
+      .setHandleInvalid(handleInvalid)
     checkDataColumns(rFormula, data)
     val rFormulaModel = rFormula.fit(data)

diff --git a/mllib/src/main/scala/org/apache/spark/ml/r/LinearSVCWrapper.scala b/mllib/src/main/scala/org/apache/spark/ml/r/LinearSVCWrapper.scala
index 0dd1f1146fbf807c5cb264ee60d7fff4654652af..7a22a71c3a8194ee28e5e9c8fe81fd263a4eab66 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/r/LinearSVCWrapper.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/r/LinearSVCWrapper.scala
@@ -79,12 +79,14 @@ private[r] object LinearSVCWrapper
       standardization: Boolean,
       threshold: Double,
       weightCol: String,
-      aggregationDepth: Int
+      aggregationDepth: Int,
+      handleInvalid: String
       ): LinearSVCWrapper = {

     val rFormula = new RFormula()
       .setFormula(formula)
       .setForceIndexLabel(true)
+      .setHandleInvalid(handleInvalid)
     checkDataColumns(rFormula, data)
     val rFormulaModel = rFormula.fit(data)

diff --git a/mllib/src/main/scala/org/apache/spark/ml/r/LogisticRegressionWrapper.scala b/mllib/src/main/scala/org/apache/spark/ml/r/LogisticRegressionWrapper.scala
index b96481acf46d7ecbac8cd8dc93b856fb0e7cbd20..18acf7d21656f0db5e9eabd96e322dcc322db20b 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/r/LogisticRegressionWrapper.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/r/LogisticRegressionWrapper.scala
@@ -103,12 +103,14 @@ private[r] object LogisticRegressionWrapper
       lowerBoundsOnCoefficients: Array[Double],
       upperBoundsOnCoefficients: Array[Double],
       lowerBoundsOnIntercepts: Array[Double],
-      upperBoundsOnIntercepts: Array[Double]
+      upperBoundsOnIntercepts: Array[Double],
+      handleInvalid: String
       ): LogisticRegressionWrapper = {

     val rFormula = new RFormula()
       .setFormula(formula)
       .setForceIndexLabel(true)
+      .setHandleInvalid(handleInvalid)
     checkDataColumns(rFormula, data)
     val rFormulaModel = rFormula.fit(data)

diff --git a/mllib/src/main/scala/org/apache/spark/ml/r/MultilayerPerceptronClassifierWrapper.scala b/mllib/src/main/scala/org/apache/spark/ml/r/MultilayerPerceptronClassifierWrapper.scala
index 48c87743dee605773f4951918d01b9ef4beb6f1f..62f642142701b77077731172838b022c8996a153 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/r/MultilayerPerceptronClassifierWrapper.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/r/MultilayerPerceptronClassifierWrapper.scala
@@ -62,7 +62,7 @@ private[r] object MultilayerPerceptronClassifierWrapper
   val PREDICTED_LABEL_INDEX_COL = "pred_label_idx"
   val PREDICTED_LABEL_COL = "prediction"

-  def fit(
+  def fit( // scalastyle:ignore
       data: DataFrame,
       formula: String,
       blockSize: Int,
@@ -72,11 +72,13 @@ private[r] object MultilayerPerceptronClassifierWrapper
       tol: Double,
       stepSize: Double,
       seed: String,
-      initialWeights: Array[Double]
+      initialWeights: Array[Double],
+      handleInvalid: String
       ): MultilayerPerceptronClassifierWrapper = {
     val rFormula = new RFormula()
       .setFormula(formula)
       .setForceIndexLabel(true)
+      .setHandleInvalid(handleInvalid)
     checkDataColumns(rFormula, data)
     val rFormulaModel = rFormula.fit(data)
     // get labels and feature names from output schema
diff --git a/mllib/src/main/scala/org/apache/spark/ml/r/NaiveBayesWrapper.scala b/mllib/src/main/scala/org/apache/spark/ml/r/NaiveBayesWrapper.scala
index 0afea4be3d1ddd27eef2982f1d0af83be5567899..fbf9f462ff5f67897d9a7acbc99726195bcf1b43 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/r/NaiveBayesWrapper.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/r/NaiveBayesWrapper.scala
@@ -57,10 +57,15 @@ private[r] object NaiveBayesWrapper extends MLReadable[NaiveBayesWrapper] {
   val PREDICTED_LABEL_INDEX_COL = "pred_label_idx"
   val PREDICTED_LABEL_COL = "prediction"

-  def fit(formula: String, data: DataFrame, smoothing: Double): NaiveBayesWrapper = {
+  def fit(
+      formula: String,
+      data: DataFrame,
+      smoothing: Double,
+      handleInvalid: String): NaiveBayesWrapper = {
     val rFormula = new RFormula()
       .setFormula(formula)
       .setForceIndexLabel(true)
+      .setHandleInvalid(handleInvalid)
     checkDataColumns(rFormula, data)
     val rFormulaModel = rFormula.fit(data)
     // get labels and feature names from output schema
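
Usage sketch (not part of the patch): the snippet below mirrors the new tests and shows how the added handleInvalid argument behaves when a string feature level appears only at prediction time. It assumes an attached SparkR package and an active session; spark.logit is used here as a stand-in for any of the classifiers touched by this change, and the toy data frames are illustrative only.

  library(SparkR)
  sparkR.session()

  train <- data.frame(clicked = c(0, 1, 0, 1),
                      someString = c("this", "that", "this", "that"),
                      stringsAsFactors = FALSE)
  test <- data.frame(clicked = c(0, 1),
                     someString = c("this", "the other"),  # "the other" never seen in training
                     stringsAsFactors = FALSE)
  traindf <- as.DataFrame(train)
  testdf <- as.DataFrame(test)

  # Default "error": evaluating predictions for the unseen level fails.
  model <- spark.logit(traindf, clicked ~ .)
  # collect(predict(model, testdf))  # would throw an error on "the other"

  # "keep": unseen levels go into an extra bucket at index numLabels.
  model <- spark.logit(traindf, clicked ~ ., handleInvalid = "keep")
  head(collect(predict(model, testdf)))

  # "skip": rows containing unseen levels are filtered out before prediction.
  model <- spark.logit(traindf, clicked ~ ., handleInvalid = "skip")
  head(collect(predict(model, testdf)))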