diff --git a/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala b/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala
index 02e2384afe530086517ec5f3423d16f602b7644f..6d2c59a905ec7a10ac58764139e19277a1c7e2b0 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala
@@ -678,6 +678,8 @@ object ALS extends DefaultParamsReadable[ALS] with Logging {
       checkpointInterval: Int = 10,
       seed: Long = 0L)(
       implicit ord: Ordering[ID]): (RDD[(ID, Array[Float])], RDD[(ID, Array[Float])]) = {
+    // Fail fast with a descriptive error instead of proceeding with an empty ratings RDD.
+    require(!ratings.isEmpty(), s"No ratings available from $ratings")
     require(intermediateRDDStorageLevel != StorageLevel.NONE,
       "ALS is not designed to run without persisting intermediate RDDs.")
     val sc = ratings.sparkContext
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/recommendation/ALS.scala b/mllib/src/main/scala/org/apache/spark/mllib/recommendation/ALS.scala
index cc9ee15738ad6509fe210277de54461fc89fc8bc..0039db7ecbbc79a7004d2514ffdad55e71508d61 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/recommendation/ALS.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/recommendation/ALS.scala
@@ -236,6 +236,9 @@ class ALS private (
    */
   @Since("0.8.0")
   def run(ratings: RDD[Rating]): MatrixFactorizationModel = {
+    // Fail fast with a descriptive error instead of proceeding with an empty ratings RDD.
+    require(!ratings.isEmpty(), s"No ratings available from $ratings")
+
     val sc = ratings.context
 
     val numUserBlocks = if (this.numUserBlocks == -1) {
diff --git a/mllib/src/test/scala/org/apache/spark/ml/recommendation/ALSSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/recommendation/ALSSuite.scala
index d0aa2cdfe0fd19ca2e50dafd35852d690e189d96..b923bacce23ca813defbfc5ad105d29ab1f8b5d5 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/recommendation/ALSSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/recommendation/ALSSuite.scala
@@ -33,6 +33,7 @@ import org.apache.spark._
 import org.apache.spark.internal.Logging
 import org.apache.spark.ml.linalg.Vectors
 import org.apache.spark.ml.recommendation.ALS._
+import org.apache.spark.ml.recommendation.ALS.Rating
 import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTestingUtils}
 import org.apache.spark.ml.util.TestingUtils._
 import org.apache.spark.mllib.util.MLlibTestSparkContext
@@ -539,6 +540,15 @@ class ALSSuite
       }.getMessage.contains("was out of Integer range"))
     }
   }
+
+  test("SPARK-18268: ALS with empty RDD should fail with better message") {
+    val ratings = sc.parallelize(Array.empty[Rating[Int]])
+    val thrown = intercept[IllegalArgumentException] {
+      ALS.train(ratings)
+    }
+    // Pin the message too, so an unrelated IllegalArgumentException cannot satisfy the test.
+    assert(thrown.getMessage.contains("No ratings available"))
+  }
 }
 
 class ALSCleanerSuite extends SparkFunSuite {
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/recommendation/ALSSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/recommendation/ALSSuite.scala
index d9dc557e3b2b95317d9784e017b7ffb515541058..b08ad99f4f2049ace54b5f956d346977c559de0d 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/recommendation/ALSSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/recommendation/ALSSuite.scala
@@ -188,6 +188,15 @@ class ALSSuite extends SparkFunSuite with MLlibTestSparkContext {
     testALS(100, 200, 2, 15, 0.7, 0.4, false, false, false, -1, -1, false)
   }
 
+  test("SPARK-18268: ALS with empty RDD should fail with better message") {
+    val ratings = sc.parallelize(Array.empty[Rating])
+    val thrown = intercept[IllegalArgumentException] {
+      new ALS().run(ratings)
+    }
+    // Pin the message too, so an unrelated IllegalArgumentException cannot satisfy the test.
+    assert(thrown.getMessage.contains("No ratings available"))
+  }
+
   /**
    * Test if we can correctly factorize R = U * P where U and P are of known rank.
    *