diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala index 23956c512c8a647ea2f8adff6af4fcb03a9ff94d..9db3b29e10d69460ff8f0cafc193cfa3eef70762 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala @@ -23,10 +23,9 @@ import org.apache.spark.ml.{Estimator, Model} import org.apache.spark.ml.attribute.NominalAttribute import org.apache.spark.ml.param._ import org.apache.spark.ml.param.shared._ -import org.apache.spark.ml.util.SchemaUtils import org.apache.spark.sql.DataFrame import org.apache.spark.sql.functions._ -import org.apache.spark.sql.types.{StringType, StructType} +import org.apache.spark.sql.types.{NumericType, StringType, StructType} import org.apache.spark.util.collection.OpenHashMap /** @@ -37,7 +36,11 @@ private[feature] trait StringIndexerBase extends Params with HasInputCol with Ha /** Validates and transforms the input schema. */ protected def validateAndTransformSchema(schema: StructType, paramMap: ParamMap): StructType = { val map = extractParamMap(paramMap) - SchemaUtils.checkColumnType(schema, map(inputCol), StringType) + val inputColName = map(inputCol) + val inputDataType = schema(inputColName).dataType + require(inputDataType == StringType || inputDataType.isInstanceOf[NumericType], + s"The input column $inputColName must be either string type or numeric type, " + + s"but got $inputDataType.") val inputFields = schema.fields val outputColName = map(outputCol) require(inputFields.forall(_.name != outputColName), @@ -51,6 +54,7 @@ private[feature] trait StringIndexerBase extends Params with HasInputCol with Ha /** * :: AlphaComponent :: * A label indexer that maps a string column of labels to an ML column of label indices. + * If the input column is numeric, we cast it to string and index the string values. * The indices are in [0, numLabels), ordered by label frequencies. * So the most frequent label gets index 0. */ @@ -67,7 +71,9 @@ class StringIndexer extends Estimator[StringIndexerModel] with StringIndexerBase override def fit(dataset: DataFrame, paramMap: ParamMap): StringIndexerModel = { val map = extractParamMap(paramMap) - val counts = dataset.select(map(inputCol)).map(_.getString(0)).countByValue() + val counts = dataset.select(col(map(inputCol)).cast(StringType)) + .map(_.getString(0)) + .countByValue() val labels = counts.toSeq.sortBy(-_._2).map(_._1).toArray val model = new StringIndexerModel(this, map, labels) Params.inheritValues(map, this, model) @@ -119,7 +125,8 @@ class StringIndexerModel private[ml] ( val outputColName = map(outputCol) val metadata = NominalAttribute.defaultAttr .withName(outputColName).withValues(labels).toMetadata() - dataset.select(col("*"), indexer(dataset(map(inputCol))).as(outputColName, metadata)) + dataset.select(col("*"), + indexer(dataset(map(inputCol)).cast(StringType)).as(outputColName, metadata)) } override def transformSchema(schema: StructType, paramMap: ParamMap): StructType = { diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/StringIndexerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/StringIndexerSuite.scala index 00b5d094d82f10170175c1f4da2634acb1ea2991..b6939e587041068be5dfded5abf90b523ab93604 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/StringIndexerSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/StringIndexerSuite.scala @@ -49,4 +49,23 @@ class StringIndexerSuite extends FunSuite with MLlibTestSparkContext { val expected = Set((0, 0.0), (1, 2.0), (2, 1.0), (3, 0.0), (4, 0.0), (5, 1.0)) assert(output === expected) } + + test("StringIndexer with a numeric input column") { + val data = sc.parallelize(Seq((0, 100), (1, 200), (2, 300), (3, 100), (4, 100), (5, 300)), 2) + val df = sqlContext.createDataFrame(data).toDF("id", "label") + val indexer = new StringIndexer() + .setInputCol("label") + .setOutputCol("labelIndex") + .fit(df) + val transformed = indexer.transform(df) + val attr = Attribute.fromStructField(transformed.schema("labelIndex")) + .asInstanceOf[NominalAttribute] + assert(attr.values.get === Array("100", "300", "200")) + val output = transformed.select("id", "labelIndex").map { r => + (r.getInt(0), r.getDouble(1)) + }.collect().toSet + // 100 -> 0, 200 -> 2, 300 -> 1 + val expected = Set((0, 0.0), (1, 2.0), (2, 1.0), (3, 0.0), (4, 0.0), (5, 1.0)) + assert(output === expected) + } }