Skip to content
Snippets Groups Projects
Commit d36e6735 authored by Xiangrui Meng's avatar Xiangrui Meng Committed by Joseph K. Bradley
Browse files

[SPARK-6965] [MLLIB] StringIndexer handles numeric input.

Cast numeric types to String for indexing. Boolean type is not handled in this PR. jkbradley

Author: Xiangrui Meng <meng@databricks.com>

Closes #5753 from mengxr/SPARK-6965 and squashes the following commits:

2e34f3c [Xiangrui Meng] add actual type in the error message
ad938bf [Xiangrui Meng] StringIndexer handles numeric input.
parent 555213eb
No related branches found
No related tags found
No related merge requests found
...@@ -23,10 +23,9 @@ import org.apache.spark.ml.{Estimator, Model} ...@@ -23,10 +23,9 @@ import org.apache.spark.ml.{Estimator, Model}
import org.apache.spark.ml.attribute.NominalAttribute import org.apache.spark.ml.attribute.NominalAttribute
import org.apache.spark.ml.param._ import org.apache.spark.ml.param._
import org.apache.spark.ml.param.shared._ import org.apache.spark.ml.param.shared._
import org.apache.spark.ml.util.SchemaUtils
import org.apache.spark.sql.DataFrame import org.apache.spark.sql.DataFrame
import org.apache.spark.sql.functions._ import org.apache.spark.sql.functions._
import org.apache.spark.sql.types.{StringType, StructType} import org.apache.spark.sql.types.{NumericType, StringType, StructType}
import org.apache.spark.util.collection.OpenHashMap import org.apache.spark.util.collection.OpenHashMap
/** /**
...@@ -37,7 +36,11 @@ private[feature] trait StringIndexerBase extends Params with HasInputCol with Ha ...@@ -37,7 +36,11 @@ private[feature] trait StringIndexerBase extends Params with HasInputCol with Ha
/** Validates and transforms the input schema. */ /** Validates and transforms the input schema. */
protected def validateAndTransformSchema(schema: StructType, paramMap: ParamMap): StructType = { protected def validateAndTransformSchema(schema: StructType, paramMap: ParamMap): StructType = {
val map = extractParamMap(paramMap) val map = extractParamMap(paramMap)
SchemaUtils.checkColumnType(schema, map(inputCol), StringType) val inputColName = map(inputCol)
val inputDataType = schema(inputColName).dataType
require(inputDataType == StringType || inputDataType.isInstanceOf[NumericType],
s"The input column $inputColName must be either string type or numeric type, " +
s"but got $inputDataType.")
val inputFields = schema.fields val inputFields = schema.fields
val outputColName = map(outputCol) val outputColName = map(outputCol)
require(inputFields.forall(_.name != outputColName), require(inputFields.forall(_.name != outputColName),
...@@ -51,6 +54,7 @@ private[feature] trait StringIndexerBase extends Params with HasInputCol with Ha ...@@ -51,6 +54,7 @@ private[feature] trait StringIndexerBase extends Params with HasInputCol with Ha
/** /**
* :: AlphaComponent :: * :: AlphaComponent ::
* A label indexer that maps a string column of labels to an ML column of label indices. * A label indexer that maps a string column of labels to an ML column of label indices.
* If the input column is numeric, we cast it to string and index the string values.
* The indices are in [0, numLabels), ordered by label frequencies. * The indices are in [0, numLabels), ordered by label frequencies.
* So the most frequent label gets index 0. * So the most frequent label gets index 0.
*/ */
...@@ -67,7 +71,9 @@ class StringIndexer extends Estimator[StringIndexerModel] with StringIndexerBase ...@@ -67,7 +71,9 @@ class StringIndexer extends Estimator[StringIndexerModel] with StringIndexerBase
override def fit(dataset: DataFrame, paramMap: ParamMap): StringIndexerModel = { override def fit(dataset: DataFrame, paramMap: ParamMap): StringIndexerModel = {
val map = extractParamMap(paramMap) val map = extractParamMap(paramMap)
val counts = dataset.select(map(inputCol)).map(_.getString(0)).countByValue() val counts = dataset.select(col(map(inputCol)).cast(StringType))
.map(_.getString(0))
.countByValue()
val labels = counts.toSeq.sortBy(-_._2).map(_._1).toArray val labels = counts.toSeq.sortBy(-_._2).map(_._1).toArray
val model = new StringIndexerModel(this, map, labels) val model = new StringIndexerModel(this, map, labels)
Params.inheritValues(map, this, model) Params.inheritValues(map, this, model)
...@@ -119,7 +125,8 @@ class StringIndexerModel private[ml] ( ...@@ -119,7 +125,8 @@ class StringIndexerModel private[ml] (
val outputColName = map(outputCol) val outputColName = map(outputCol)
val metadata = NominalAttribute.defaultAttr val metadata = NominalAttribute.defaultAttr
.withName(outputColName).withValues(labels).toMetadata() .withName(outputColName).withValues(labels).toMetadata()
dataset.select(col("*"), indexer(dataset(map(inputCol))).as(outputColName, metadata)) dataset.select(col("*"),
indexer(dataset(map(inputCol)).cast(StringType)).as(outputColName, metadata))
} }
override def transformSchema(schema: StructType, paramMap: ParamMap): StructType = { override def transformSchema(schema: StructType, paramMap: ParamMap): StructType = {
......
...@@ -49,4 +49,23 @@ class StringIndexerSuite extends FunSuite with MLlibTestSparkContext { ...@@ -49,4 +49,23 @@ class StringIndexerSuite extends FunSuite with MLlibTestSparkContext {
val expected = Set((0, 0.0), (1, 2.0), (2, 1.0), (3, 0.0), (4, 0.0), (5, 1.0)) val expected = Set((0, 0.0), (1, 2.0), (2, 1.0), (3, 0.0), (4, 0.0), (5, 1.0))
assert(output === expected) assert(output === expected)
} }
test("StringIndexer with a numeric input column") {
val data = sc.parallelize(Seq((0, 100), (1, 200), (2, 300), (3, 100), (4, 100), (5, 300)), 2)
val df = sqlContext.createDataFrame(data).toDF("id", "label")
val indexer = new StringIndexer()
.setInputCol("label")
.setOutputCol("labelIndex")
.fit(df)
val transformed = indexer.transform(df)
val attr = Attribute.fromStructField(transformed.schema("labelIndex"))
.asInstanceOf[NominalAttribute]
assert(attr.values.get === Array("100", "300", "200"))
val output = transformed.select("id", "labelIndex").map { r =>
(r.getInt(0), r.getDouble(1))
}.collect().toSet
// 100 -> 0, 200 -> 2, 300 -> 1
val expected = Set((0, 0.0), (1, 2.0), (2, 1.0), (3, 0.0), (4, 0.0), (5, 1.0))
assert(output === expected)
}
} }
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment