From 9893dc975784551a62f65bbd709f8972e0204b2a Mon Sep 17 00:00:00 2001 From: Zheng RuiFeng <ruifengz@foxmail.com> Date: Fri, 27 May 2016 21:57:41 -0500 Subject: [PATCH] [SPARK-15610][ML] update error message for k in pca ## What changes were proposed in this pull request? Fix the wrong bound of `k` in `PCA` `require(k <= sources.first().size, ...` -> `require(k < sources.first().size` BTW, remove unused import in `ml.ElementwiseProduct` ## How was this patch tested? manual tests Author: Zheng RuiFeng <ruifengz@foxmail.com> Closes #13356 from zhengruifeng/fix_pca. --- .../org/apache/spark/ml/feature/ElementwiseProduct.scala | 1 - .../src/main/scala/org/apache/spark/mllib/feature/PCA.scala | 6 +++--- 2 files changed, 3 insertions(+), 4 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/ElementwiseProduct.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/ElementwiseProduct.scala index 91989c3d2f..9d2e60fa3f 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/ElementwiseProduct.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/ElementwiseProduct.scala @@ -23,7 +23,6 @@ import org.apache.spark.ml.linalg.{Vector, VectorUDT} import org.apache.spark.ml.param.Param import org.apache.spark.ml.util.{DefaultParamsReadable, DefaultParamsWritable, Identifiable} import org.apache.spark.mllib.feature -import org.apache.spark.mllib.linalg.{Vectors => OldVectors} import org.apache.spark.mllib.linalg.VectorImplicits._ import org.apache.spark.sql.types.DataType diff --git a/mllib/src/main/scala/org/apache/spark/mllib/feature/PCA.scala b/mllib/src/main/scala/org/apache/spark/mllib/feature/PCA.scala index 30c403e547..15b72205ac 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/feature/PCA.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/feature/PCA.scala @@ -40,8 +40,9 @@ class PCA @Since("1.4.0") (@Since("1.4.0") val k: Int) { */ @Since("1.4.0") def fit(sources: RDD[Vector]): PCAModel = { - require(k <= sources.first().size, - s"source vector size is ${sources.first().size} must be greater than k=$k") + val numFeatures = sources.first().size + require(k <= numFeatures, + s"source vector size $numFeatures must be no less than k=$k") val mat = new RowMatrix(sources) val (pc, explainedVariance) = mat.computePrincipalComponentsAndExplainedVariance(k) @@ -58,7 +59,6 @@ class PCA @Since("1.4.0") (@Since("1.4.0") val k: Int) { case m => throw new IllegalArgumentException("Unsupported matrix format. Expected " + s"SparseMatrix or DenseMatrix. Instead got: ${m.getClass}") - } val denseExplainedVariance = explainedVariance match { case dv: DenseVector => -- GitLab