Skip to content
Snippets Groups Projects
Commit d36e6735 authored by Xiangrui Meng, committed by Joseph K. Bradley
Browse files

[SPARK-6965] [MLLIB] StringIndexer handles numeric input.

Cast numeric types to String for indexing. Boolean type is not handled in this PR. CC: jkbradley

Author: Xiangrui Meng <meng@databricks.com>

Closes #5753 from mengxr/SPARK-6965 and squashes the following commits:

2e34f3c [Xiangrui Meng] add actual type in the error message
ad938bf [Xiangrui Meng] StringIndexer handles numeric input.
parent 555213eb
No related branches found
No related tags found
No related merge requests found
......@@ -23,10 +23,9 @@ import org.apache.spark.ml.{Estimator, Model}
import org.apache.spark.ml.attribute.NominalAttribute
import org.apache.spark.ml.param._
import org.apache.spark.ml.param.shared._
import org.apache.spark.ml.util.SchemaUtils
import org.apache.spark.sql.DataFrame
import org.apache.spark.sql.functions._
import org.apache.spark.sql.types.{StringType, StructType}
import org.apache.spark.sql.types.{NumericType, StringType, StructType}
import org.apache.spark.util.collection.OpenHashMap
/**
......@@ -37,7 +36,11 @@ private[feature] trait StringIndexerBase extends Params with HasInputCol with Ha
/** Validates and transforms the input schema. */
protected def validateAndTransformSchema(schema: StructType, paramMap: ParamMap): StructType = {
val map = extractParamMap(paramMap)
SchemaUtils.checkColumnType(schema, map(inputCol), StringType)
val inputColName = map(inputCol)
val inputDataType = schema(inputColName).dataType
require(inputDataType == StringType || inputDataType.isInstanceOf[NumericType],
s"The input column $inputColName must be either string type or numeric type, " +
s"but got $inputDataType.")
val inputFields = schema.fields
val outputColName = map(outputCol)
require(inputFields.forall(_.name != outputColName),
......@@ -51,6 +54,7 @@ private[feature] trait StringIndexerBase extends Params with HasInputCol with Ha
/**
* :: AlphaComponent ::
* A label indexer that maps a string column of labels to an ML column of label indices.
* If the input column is numeric, we cast it to string and index the string values.
* The indices are in [0, numLabels), ordered by label frequencies.
* So the most frequent label gets index 0.
*/
......@@ -67,7 +71,9 @@ class StringIndexer extends Estimator[StringIndexerModel] with StringIndexerBase
override def fit(dataset: DataFrame, paramMap: ParamMap): StringIndexerModel = {
val map = extractParamMap(paramMap)
val counts = dataset.select(map(inputCol)).map(_.getString(0)).countByValue()
val counts = dataset.select(col(map(inputCol)).cast(StringType))
.map(_.getString(0))
.countByValue()
val labels = counts.toSeq.sortBy(-_._2).map(_._1).toArray
val model = new StringIndexerModel(this, map, labels)
Params.inheritValues(map, this, model)
......@@ -119,7 +125,8 @@ class StringIndexerModel private[ml] (
val outputColName = map(outputCol)
val metadata = NominalAttribute.defaultAttr
.withName(outputColName).withValues(labels).toMetadata()
dataset.select(col("*"), indexer(dataset(map(inputCol))).as(outputColName, metadata))
dataset.select(col("*"),
indexer(dataset(map(inputCol)).cast(StringType)).as(outputColName, metadata))
}
override def transformSchema(schema: StructType, paramMap: ParamMap): StructType = {
......
......@@ -49,4 +49,23 @@ class StringIndexerSuite extends FunSuite with MLlibTestSparkContext {
val expected = Set((0, 0.0), (1, 2.0), (2, 1.0), (3, 0.0), (4, 0.0), (5, 1.0))
assert(output === expected)
}
// Checks that an integer label column is indexed via its string form:
// the fitted model's labels are strings, ordered from most to least frequent.
test("StringIndexer with a numeric input column") {
  val rows = Seq((0, 100), (1, 200), (2, 300), (3, 100), (4, 100), (5, 300))
  val df = sqlContext.createDataFrame(sc.parallelize(rows, 2)).toDF("id", "label")

  val estimator = new StringIndexer()
    .setInputCol("label")
    .setOutputCol("labelIndex")
  val model = estimator.fit(df)
  val indexed = model.transform(df)

  // The output column's metadata should list the labels as strings,
  // with 100 (3 occurrences) first, then 300 (2), then 200 (1).
  val labelAttr = Attribute
    .fromStructField(indexed.schema("labelIndex"))
    .asInstanceOf[NominalAttribute]
  assert(labelAttr.values.get === Array("100", "300", "200"))

  val actual = indexed
    .select("id", "labelIndex")
    .map(row => (row.getInt(0), row.getDouble(1)))
    .collect()
    .toSet
  // 100 -> 0, 200 -> 2, 300 -> 1
  val expected = Set((0, 0.0), (1, 2.0), (2, 1.0), (3, 0.0), (4, 0.0), (5, 1.0))
  assert(actual === expected)
}
}
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment