diff --git a/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala b/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala index 509c944fed74cb63a9e52b5fd196d9441bdc03b1..f257382d2205cc10b27b0db27a1730e972af815e 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala @@ -42,7 +42,7 @@ import org.apache.spark.mllib.optimization.NNLS import org.apache.spark.rdd.RDD import org.apache.spark.sql.{DataFrame, Dataset} import org.apache.spark.sql.functions._ -import org.apache.spark.sql.types.{DoubleType, FloatType, IntegerType, StructType} +import org.apache.spark.sql.types._ import org.apache.spark.storage.StorageLevel import org.apache.spark.util.Utils import org.apache.spark.util.collection.{OpenHashMap, OpenHashSet, SortDataFormat, Sorter} @@ -53,24 +53,43 @@ import org.apache.spark.util.random.XORShiftRandom */ private[recommendation] trait ALSModelParams extends Params with HasPredictionCol { /** - * Param for the column name for user ids. + * Param for the column name for user ids. Ids must be integers. Other + * numeric types are supported for this column, but will be cast to integers as long as they + * fall within the integer value range. * Default: "user" * @group param */ - val userCol = new Param[String](this, "userCol", "column name for user ids") + val userCol = new Param[String](this, "userCol", "column name for user ids. Ids must be within " + + "the integer value range.") /** @group getParam */ def getUserCol: String = $(userCol) /** - * Param for the column name for item ids. + * Param for the column name for item ids. Ids must be integers. Other + * numeric types are supported for this column, but will be cast to integers as long as they + * fall within the integer value range. * Default: "item" * @group param */ - val itemCol = new Param[String](this, "itemCol", "column name for item ids") + val itemCol = new Param[String](this, "itemCol", "column name for item ids. Ids must be within " + + "the integer value range.") /** @group getParam */ def getItemCol: String = $(itemCol) + + /** + * Attempts to safely cast a user/item id to an Int. Throws an exception if the value is + * out of integer range. + */ + protected val checkedCast = udf { (n: Double) => + if (n > Int.MaxValue || n < Int.MinValue) { + throw new IllegalArgumentException(s"ALS only supports values in Integer range for columns " + + s"${$(userCol)} and ${$(itemCol)}. Value $n was out of Integer range.") + } else { + n.toInt + } + } } /** @@ -193,10 +212,11 @@ private[recommendation] trait ALSParams extends ALSModelParams with HasMaxIter w * @return output schema */ protected def validateAndTransformSchema(schema: StructType): StructType = { - SchemaUtils.checkColumnType(schema, $(userCol), IntegerType) - SchemaUtils.checkColumnType(schema, $(itemCol), IntegerType) - val ratingType = schema($(ratingCol)).dataType - require(ratingType == FloatType || ratingType == DoubleType) + // user and item will be cast to Int + SchemaUtils.checkNumericType(schema, $(userCol)) + SchemaUtils.checkNumericType(schema, $(itemCol)) + // rating will be cast to Float + SchemaUtils.checkNumericType(schema, $(ratingCol)) SchemaUtils.appendColumn(schema, $(predictionCol), FloatType) } } @@ -232,6 +252,7 @@ class ALSModel private[ml] ( @Since("2.0.0") override def transform(dataset: Dataset[_]): DataFrame = { + transformSchema(dataset.schema) // Register a UDF for DataFrame, and then // create a new column named map(predictionCol) by running the predict UDF. val predict = udf { (userFeatures: Seq[Float], itemFeatures: Seq[Float]) => @@ -242,16 +263,19 @@ class ALSModel private[ml] ( } } dataset - .join(userFactors, dataset($(userCol)) === userFactors("id"), "left") - .join(itemFactors, dataset($(itemCol)) === itemFactors("id"), "left") + .join(userFactors, + checkedCast(dataset($(userCol)).cast(DoubleType)) === userFactors("id"), "left") + .join(itemFactors, + checkedCast(dataset($(itemCol)).cast(DoubleType)) === itemFactors("id"), "left") .select(dataset("*"), predict(userFactors("features"), itemFactors("features")).as($(predictionCol))) } @Since("1.3.0") override def transformSchema(schema: StructType): StructType = { - SchemaUtils.checkColumnType(schema, $(userCol), IntegerType) - SchemaUtils.checkColumnType(schema, $(itemCol), IntegerType) + // user and item will be cast to Int + SchemaUtils.checkNumericType(schema, $(userCol)) + SchemaUtils.checkNumericType(schema, $(itemCol)) SchemaUtils.appendColumn(schema, $(predictionCol), FloatType) } @@ -430,10 +454,13 @@ class ALS(@Since("1.4.0") override val uid: String) extends Estimator[ALSModel] @Since("2.0.0") override def fit(dataset: Dataset[_]): ALSModel = { + transformSchema(dataset.schema) import dataset.sparkSession.implicits._ + val r = if ($(ratingCol) != "") col($(ratingCol)).cast(FloatType) else lit(1.0f) val ratings = dataset - .select(col($(userCol)).cast(IntegerType), col($(itemCol)).cast(IntegerType), r) + .select(checkedCast(col($(userCol)).cast(DoubleType)), + checkedCast(col($(itemCol)).cast(DoubleType)), r) .rdd .map { row => Rating(row.getInt(0), row.getInt(1), row.getFloat(2)) diff --git a/mllib/src/test/scala/org/apache/spark/ml/recommendation/ALSSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/recommendation/ALSSuite.scala index bbfc415cbb9b7cdecee80c4404cfbeab286bc4fb..59b5edc4013e8b5b7d1e0e3d042290dd3168344f 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/recommendation/ALSSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/recommendation/ALSSuite.scala @@ -39,6 +39,7 @@ import org.apache.spark.mllib.util.MLlibTestSparkContext import org.apache.spark.rdd.RDD import org.apache.spark.scheduler.{SparkListener, SparkListenerStageCompleted} import org.apache.spark.sql.{DataFrame, Row, SparkSession} +import org.apache.spark.sql.types.{FloatType, IntegerType} import org.apache.spark.storage.StorageLevel import org.apache.spark.util.Utils @@ -205,7 +206,6 @@ class ALSSuite /** * Generates an explicit feedback dataset for testing ALS. - * * @param numUsers number of users * @param numItems number of items * @param rank rank @@ -246,7 +246,6 @@ class ALSSuite /** * Generates an implicit feedback dataset for testing ALS. - * * @param numUsers number of users * @param numItems number of items * @param rank rank @@ -265,7 +264,6 @@ class ALSSuite /** * Generates random user/item factors, with i.i.d. values drawn from U(a, b). - * * @param size number of users/items * @param rank number of features * @param random random number generator @@ -284,7 +282,6 @@ class ALSSuite /** * Test ALS using the given training/test splits and parameters. - * * @param training training dataset * @param test test dataset * @param rank rank of the matrix factorization @@ -486,6 +483,62 @@ class ALSSuite assert(getFactors(model.userFactors) === getFactors(model2.userFactors)) assert(getFactors(model.itemFactors) === getFactors(model2.itemFactors)) } + + test("input type validation") { + val spark = this.spark + import spark.implicits._ + + // check that ALS can handle all numeric types for rating column + // and user/item columns (when the user/item ids are within Int range) + val als = new ALS().setMaxIter(1).setRank(1) + Seq(("user", IntegerType), ("item", IntegerType), ("rating", FloatType)).foreach { + case (colName, sqlType) => + MLTestingUtils.checkNumericTypesALS(als, spark, colName, sqlType) { + (ex, act) => + ex.userFactors.first().getSeq[Float](1) === act.userFactors.first.getSeq[Float](1) + } { (ex, act, _) => + ex.transform(_: DataFrame).select("prediction").first.getFloat(0) ~== + act.transform(_: DataFrame).select("prediction").first.getFloat(0) absTol 1e-6 + } + } + // check user/item ids falling outside of Int range + val big = Int.MaxValue.toLong + 1 + val small = Int.MinValue.toDouble - 1 + val df = Seq( + (0, 0L, 0d, 1, 1L, 1d, 3.0), + (0, big, small, 0, big, small, 2.0), + (1, 1L, 1d, 0, 0L, 0d, 5.0) + ).toDF("user", "user_big", "user_small", "item", "item_big", "item_small", "rating") + withClue("fit should fail when ids exceed integer range. ") { + assert(intercept[IllegalArgumentException] { + als.fit(df.select(df("user_big").as("user"), df("item"), df("rating"))) + }.getMessage.contains("was out of Integer range")) + assert(intercept[IllegalArgumentException] { + als.fit(df.select(df("user_small").as("user"), df("item"), df("rating"))) + }.getMessage.contains("was out of Integer range")) + assert(intercept[IllegalArgumentException] { + als.fit(df.select(df("item_big").as("item"), df("user"), df("rating"))) + }.getMessage.contains("was out of Integer range")) + assert(intercept[IllegalArgumentException] { + als.fit(df.select(df("item_small").as("item"), df("user"), df("rating"))) + }.getMessage.contains("was out of Integer range")) + } + withClue("transform should fail when ids exceed integer range. ") { + val model = als.fit(df) + assert(intercept[SparkException] { + model.transform(df.select(df("user_big").as("user"), df("item"))).first + }.getMessage.contains("was out of Integer range")) + assert(intercept[SparkException] { + model.transform(df.select(df("user_small").as("user"), df("item"))).first + }.getMessage.contains("was out of Integer range")) + assert(intercept[SparkException] { + model.transform(df.select(df("item_big").as("item"), df("user"))).first + }.getMessage.contains("was out of Integer range")) + assert(intercept[SparkException] { + model.transform(df.select(df("item_small").as("item"), df("user"))).first + }.getMessage.contains("was out of Integer range")) + } + } } class ALSCleanerSuite extends SparkFunSuite { diff --git a/mllib/src/test/scala/org/apache/spark/ml/util/MLTestingUtils.scala b/mllib/src/test/scala/org/apache/spark/ml/util/MLTestingUtils.scala index 6aae625fc83f2283c1632b8c85ab0b4af75e51d8..80b976914cbdfcaaf6717df5f82ab0899375fa0c 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/util/MLTestingUtils.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/util/MLTestingUtils.scala @@ -22,6 +22,7 @@ import org.apache.spark.ml.{Estimator, Model} import org.apache.spark.ml.evaluation.Evaluator import org.apache.spark.ml.linalg.Vectors import org.apache.spark.ml.param.ParamMap +import org.apache.spark.ml.recommendation.{ALS, ALSModel} import org.apache.spark.ml.tree.impl.TreeTests import org.apache.spark.sql.{DataFrame, SparkSession} import org.apache.spark.sql.functions._ @@ -58,6 +59,30 @@ object MLTestingUtils extends SparkFunSuite { "Column label must be of type NumericType but was actually of type StringType")) } + def checkNumericTypesALS( + estimator: ALS, + spark: SparkSession, + column: String, + baseType: NumericType) + (check: (ALSModel, ALSModel) => Unit) + (check2: (ALSModel, ALSModel, DataFrame) => Unit): Unit = { + val dfs = genRatingsDFWithNumericCols(spark, column) + val expected = estimator.fit(dfs(baseType)) + val actuals = dfs.keys.filter(_ != baseType).map(t => (t, estimator.fit(dfs(t)))) + actuals.foreach { case (_, actual) => check(expected, actual) } + actuals.foreach { case (t, actual) => check2(expected, actual, dfs(t)) } + + val baseDF = dfs(baseType) + val others = baseDF.columns.toSeq.diff(Seq(column)).map(col(_)) + val cols = Seq(col(column).cast(StringType)) ++ others + val strDF = baseDF.select(cols: _*) + val thrown = intercept[IllegalArgumentException] { + estimator.fit(strDF) + } + assert(thrown.getMessage.contains( + s"$column must be of type NumericType but was actually of type StringType")) + } + def checkNumericTypes[T <: Evaluator](evaluator: T, spark: SparkSession): Unit = { val dfs = genEvaluatorDFWithNumericLabelCol(spark, "label", "prediction") val expected = evaluator.evaluate(dfs(DoubleType)) @@ -116,6 +141,26 @@ object MLTestingUtils extends SparkFunSuite { }.toMap } + def genRatingsDFWithNumericCols( + spark: SparkSession, + column: String): Map[NumericType, DataFrame] = { + val df = spark.createDataFrame(Seq( + (0, 10, 1.0), + (1, 20, 2.0), + (2, 30, 3.0), + (3, 40, 4.0), + (4, 50, 5.0) + )).toDF("user", "item", "rating") + + val others = df.columns.toSeq.diff(Seq(column)).map(col(_)) + val types: Seq[NumericType] = + Seq(ShortType, LongType, IntegerType, FloatType, ByteType, DoubleType, DecimalType(10, 0)) + types.map { t => + val cols = Seq(col(column).cast(t)) ++ others + t -> df.select(cols: _*) + }.toMap + } + def genEvaluatorDFWithNumericLabelCol( spark: SparkSession, labelColName: String = "label", diff --git a/python/pyspark/ml/recommendation.py b/python/pyspark/ml/recommendation.py index d7cb65846574b7822a3e6859e2dda79f18489b82..86c00d91652d167358317a8907877009b0d66345 100644 --- a/python/pyspark/ml/recommendation.py +++ b/python/pyspark/ml/recommendation.py @@ -110,10 +110,10 @@ class ALS(JavaEstimator, HasCheckpointInterval, HasMaxIter, HasPredictionCol, Ha typeConverter=TypeConverters.toBoolean) alpha = Param(Params._dummy(), "alpha", "alpha for implicit preference", typeConverter=TypeConverters.toFloat) - userCol = Param(Params._dummy(), "userCol", "column name for user ids", - typeConverter=TypeConverters.toString) - itemCol = Param(Params._dummy(), "itemCol", "column name for item ids", - typeConverter=TypeConverters.toString) + userCol = Param(Params._dummy(), "userCol", "column name for user ids. Ids must be within " + + "the integer value range.", typeConverter=TypeConverters.toString) + itemCol = Param(Params._dummy(), "itemCol", "column name for item ids. Ids must be within " + + "the integer value range.", typeConverter=TypeConverters.toString) ratingCol = Param(Params._dummy(), "ratingCol", "column name for ratings", typeConverter=TypeConverters.toString) nonnegative = Param(Params._dummy(), "nonnegative",