diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala
index 23956c512c8a647ea2f8adff6af4fcb03a9ff94d..9db3b29e10d69460ff8f0cafc193cfa3eef70762 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala
@@ -23,10 +23,9 @@ import org.apache.spark.ml.{Estimator, Model}
 import org.apache.spark.ml.attribute.NominalAttribute
 import org.apache.spark.ml.param._
 import org.apache.spark.ml.param.shared._
-import org.apache.spark.ml.util.SchemaUtils
 import org.apache.spark.sql.DataFrame
 import org.apache.spark.sql.functions._
-import org.apache.spark.sql.types.{StringType, StructType}
+import org.apache.spark.sql.types.{NumericType, StringType, StructType}
 import org.apache.spark.util.collection.OpenHashMap
 
 /**
@@ -37,7 +36,11 @@ private[feature] trait StringIndexerBase extends Params with HasInputCol with Ha
   /** Validates and transforms the input schema. */
   protected def validateAndTransformSchema(schema: StructType, paramMap: ParamMap): StructType = {
     val map = extractParamMap(paramMap)
-    SchemaUtils.checkColumnType(schema, map(inputCol), StringType)
+    val inputColName = map(inputCol)
+    val inputDataType = schema(inputColName).dataType
+    require(inputDataType == StringType || inputDataType.isInstanceOf[NumericType],
+      s"The input column $inputColName must be either string type or numeric type, " +
+        s"but got $inputDataType.")
     val inputFields = schema.fields
     val outputColName = map(outputCol)
     require(inputFields.forall(_.name != outputColName),
@@ -51,6 +54,7 @@ private[feature] trait StringIndexerBase extends Params with HasInputCol with Ha
 /**
  * :: AlphaComponent ::
  * A label indexer that maps a string column of labels to an ML column of label indices.
+ * If the input column is numeric, we cast it to string and index the string values.
  * The indices are in [0, numLabels), ordered by label frequencies.
  * So the most frequent label gets index 0.
  */
@@ -67,7 +71,9 @@ class StringIndexer extends Estimator[StringIndexerModel] with StringIndexerBase
 
   override def fit(dataset: DataFrame, paramMap: ParamMap): StringIndexerModel = {
     val map = extractParamMap(paramMap)
-    val counts = dataset.select(map(inputCol)).map(_.getString(0)).countByValue()
+    val counts = dataset.select(col(map(inputCol)).cast(StringType))
+      .map(_.getString(0))
+      .countByValue()
     val labels = counts.toSeq.sortBy(-_._2).map(_._1).toArray
     val model = new StringIndexerModel(this, map, labels)
     Params.inheritValues(map, this, model)
@@ -119,7 +125,8 @@ class StringIndexerModel private[ml] (
     val outputColName = map(outputCol)
     val metadata = NominalAttribute.defaultAttr
       .withName(outputColName).withValues(labels).toMetadata()
-    dataset.select(col("*"), indexer(dataset(map(inputCol))).as(outputColName, metadata))
+    dataset.select(col("*"),
+      indexer(dataset(map(inputCol)).cast(StringType)).as(outputColName, metadata))
   }
 
   override def transformSchema(schema: StructType, paramMap: ParamMap): StructType = {
diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/StringIndexerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/StringIndexerSuite.scala
index 00b5d094d82f10170175c1f4da2634acb1ea2991..b6939e587041068be5dfded5abf90b523ab93604 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/feature/StringIndexerSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/feature/StringIndexerSuite.scala
@@ -49,4 +49,23 @@ class StringIndexerSuite extends FunSuite with MLlibTestSparkContext {
     val expected = Set((0, 0.0), (1, 2.0), (2, 1.0), (3, 0.0), (4, 0.0), (5, 1.0))
     assert(output === expected)
   }
+
+  test("StringIndexer with a numeric input column") {
+    val data = sc.parallelize(Seq((0, 100), (1, 200), (2, 300), (3, 100), (4, 100), (5, 300)), 2)
+    val df = sqlContext.createDataFrame(data).toDF("id", "label")
+    val indexer = new StringIndexer()
+      .setInputCol("label")
+      .setOutputCol("labelIndex")
+      .fit(df)
+    val transformed = indexer.transform(df)
+    val attr = Attribute.fromStructField(transformed.schema("labelIndex"))
+      .asInstanceOf[NominalAttribute]
+    assert(attr.values.get === Array("100", "300", "200"))
+    val output = transformed.select("id", "labelIndex").map { r =>
+      (r.getInt(0), r.getDouble(1))
+    }.collect().toSet
+    // 100 -> 0, 200 -> 2, 300 -> 1
+    val expected = Set((0, 0.0), (1, 2.0), (2, 1.0), (3, 0.0), (4, 0.0), (5, 1.0))
+    assert(output === expected)
+  }
 }