diff --git a/mllib/src/main/scala/org/apache/spark/mllib/util/MLUtils.scala b/mllib/src/main/scala/org/apache/spark/mllib/util/MLUtils.scala index 541f3288b6c43e6d873263afe8f9c9a4d89dc0a9..52d6468a72af7217fcff3c557d262ee8e0dacad3 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/util/MLUtils.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/util/MLUtils.scala @@ -82,6 +82,18 @@ object MLUtils { val value = indexAndValue(1).toDouble (index, value) }.unzip + + // check if indices are one-based and in ascending order + var previous = -1 + var i = 0 + val indicesLength = indices.length + while (i < indicesLength) { + val current = indices(i) + require(current > previous, "indices should be one-based and in ascending order" ) + previous = current + i += 1 + } + (label, indices.toArray, values.toArray) } diff --git a/mllib/src/test/scala/org/apache/spark/mllib/util/MLUtilsSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/util/MLUtilsSuite.scala index 734b7babec7bedca019f67c203b8a2f5bfd9ac11..70219e9ad9d3ef5a5621d54d1b27467f171f411c 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/util/MLUtilsSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/util/MLUtilsSuite.scala @@ -25,6 +25,7 @@ import breeze.linalg.{squaredDistance => breezeSquaredDistance} import com.google.common.base.Charsets import com.google.common.io.Files +import org.apache.spark.SparkException import org.apache.spark.SparkFunSuite import org.apache.spark.mllib.linalg.{DenseVector, SparseVector, Vectors} import org.apache.spark.mllib.regression.LabeledPoint @@ -108,6 +109,40 @@ class MLUtilsSuite extends SparkFunSuite with MLlibTestSparkContext { Utils.deleteRecursively(tempDir) } + test("loadLibSVMFile throws IllegalArgumentException when indices is zero-based") { + val lines = + """ + |0 + |0 0:4.0 4:5.0 6:6.0 + """.stripMargin + val tempDir = Utils.createTempDir() + val file = new File(tempDir.getPath, "part-00000") + Files.write(lines, file, Charsets.US_ASCII) + val path = tempDir.toURI.toString + + intercept[SparkException] { + loadLibSVMFile(sc, path).collect() + } + Utils.deleteRecursively(tempDir) + } + + test("loadLibSVMFile throws IllegalArgumentException when indices is not in ascending order") { + val lines = + """ + |0 + |0 3:4.0 2:5.0 6:6.0 + """.stripMargin + val tempDir = Utils.createTempDir() + val file = new File(tempDir.getPath, "part-00000") + Files.write(lines, file, Charsets.US_ASCII) + val path = tempDir.toURI.toString + + intercept[SparkException] { + loadLibSVMFile(sc, path).collect() + } + Utils.deleteRecursively(tempDir) + } + test("saveAsLibSVMFile") { val examples = sc.parallelize(Seq( LabeledPoint(1.1, Vectors.sparse(3, Seq((0, 1.23), (2, 4.56)))),