diff --git a/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala b/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala
index 509c944fed74cb63a9e52b5fd196d9441bdc03b1..f257382d2205cc10b27b0db27a1730e972af815e 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala
@@ -42,7 +42,7 @@ import org.apache.spark.mllib.optimization.NNLS
 import org.apache.spark.rdd.RDD
 import org.apache.spark.sql.{DataFrame, Dataset}
 import org.apache.spark.sql.functions._
-import org.apache.spark.sql.types.{DoubleType, FloatType, IntegerType, StructType}
+import org.apache.spark.sql.types._
 import org.apache.spark.storage.StorageLevel
 import org.apache.spark.util.Utils
 import org.apache.spark.util.collection.{OpenHashMap, OpenHashSet, SortDataFormat, Sorter}
@@ -53,24 +53,43 @@ import org.apache.spark.util.random.XORShiftRandom
  */
 private[recommendation] trait ALSModelParams extends Params with HasPredictionCol {
   /**
-   * Param for the column name for user ids.
+   * Param for the column name for user ids. Ids must be integers. Other
+   * numeric types are supported for this column, but will be cast to integers as long as they
+   * fall within the integer value range.
    * Default: "user"
    * @group param
    */
-  val userCol = new Param[String](this, "userCol", "column name for user ids")
+  val userCol = new Param[String](this, "userCol", "column name for user ids. Ids must be within " +
+    "the integer value range.")
 
   /** @group getParam */
   def getUserCol: String = $(userCol)
 
   /**
-   * Param for the column name for item ids.
+   * Param for the column name for item ids. Ids must be integers. Other
+   * numeric types are supported for this column, but will be cast to integers as long as they
+   * fall within the integer value range.
    * Default: "item"
    * @group param
    */
-  val itemCol = new Param[String](this, "itemCol", "column name for item ids")
+  val itemCol = new Param[String](this, "itemCol", "column name for item ids. Ids must be within " +
+    "the integer value range.")
 
   /** @group getParam */
   def getItemCol: String = $(itemCol)
+
+  /**
+   * Attempts to safely cast a user/item id to an Int. Throws an exception if the value is
+   * out of integer range.
+   */
+  protected val checkedCast = udf { (n: Double) =>
+    if (n > Int.MaxValue || n < Int.MinValue) {
+      throw new IllegalArgumentException(s"ALS only supports values in Integer range for columns " +
+        s"${$(userCol)} and ${$(itemCol)}. Value $n was out of Integer range.")
+    } else {
+      n.toInt
+    }
+  }
 }
 
 /**
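
The heart of this change is the `checkedCast` UDF above: callers first cast the raw id column to `DoubleType`, and the UDF then range-checks the value before truncating it to `Int`. A minimal standalone sketch of the same semantics in plain Scala (no Spark required; the sample values are hypothetical):

```scala
// Sketch of the checkedCast semantics from the UDF above, as a plain function.
def checkedCast(n: Double): Int = {
  if (n > Int.MaxValue || n < Int.MinValue) {
    throw new IllegalArgumentException(
      s"ALS only supports values in Integer range. Value $n was out of Integer range.")
  } else {
    // Note: toInt truncates, so a fractional id like 3.7 silently becomes 3.
    n.toInt
  }
}

checkedCast(42.0)                   // 42
checkedCast(Int.MaxValue.toDouble)  // 2147483647
// checkedCast(Int.MaxValue + 1.0)  // throws IllegalArgumentException
```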
@@ -193,10 +212,11 @@ private[recommendation] trait ALSParams extends ALSModelParams with HasMaxIter w
    * @return output schema
    */
   protected def validateAndTransformSchema(schema: StructType): StructType = {
-    SchemaUtils.checkColumnType(schema, $(userCol), IntegerType)
-    SchemaUtils.checkColumnType(schema, $(itemCol), IntegerType)
-    val ratingType = schema($(ratingCol)).dataType
-    require(ratingType == FloatType || ratingType == DoubleType)
+    // user and item will be cast to Int
+    SchemaUtils.checkNumericType(schema, $(userCol))
+    SchemaUtils.checkNumericType(schema, $(itemCol))
+    // rating will be cast to Float
+    SchemaUtils.checkNumericType(schema, $(ratingCol))
     SchemaUtils.appendColumn(schema, $(predictionCol), FloatType)
   }
 }
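
With `checkColumnType` replaced by `checkNumericType`, any numeric user/item/rating column now passes validation; the concrete cast happens later in `fit`/`transform`. A hypothetical pair of schemas illustrating what is now accepted and what still fails:

```scala
import org.apache.spark.sql.types._

// Accepted after this change: all three columns are NumericType,
// even though none of them is the "native" Int/Float type.
val ok = StructType(Seq(
  StructField("user", LongType),
  StructField("item", ShortType),
  StructField("rating", DecimalType(10, 0))))

// Still rejected: a StringType id column is not a NumericType,
// so validateAndTransformSchema throws before any job runs.
val bad = StructType(Seq(
  StructField("user", StringType),
  StructField("item", IntegerType),
  StructField("rating", DoubleType)))
```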
@@ -232,6 +252,7 @@ class ALSModel private[ml] (
 
   @Since("2.0.0")
   override def transform(dataset: Dataset[_]): DataFrame = {
+    transformSchema(dataset.schema)
     // Register a UDF for DataFrame, and then
     // create a new column named map(predictionCol) by running the predict UDF.
     val predict = udf { (userFeatures: Seq[Float], itemFeatures: Seq[Float]) =>
@@ -242,16 +263,19 @@ class ALSModel private[ml] (
       }
     }
     dataset
-      .join(userFactors, dataset($(userCol)) === userFactors("id"), "left")
-      .join(itemFactors, dataset($(itemCol)) === itemFactors("id"), "left")
+      .join(userFactors,
+        checkedCast(dataset($(userCol)).cast(DoubleType)) === userFactors("id"), "left")
+      .join(itemFactors,
+        checkedCast(dataset($(itemCol)).cast(DoubleType)) === itemFactors("id"), "left")
       .select(dataset("*"),
         predict(userFactors("features"), itemFactors("features")).as($(predictionCol)))
   }
 
   @Since("1.3.0")
   override def transformSchema(schema: StructType): StructType = {
-    SchemaUtils.checkColumnType(schema, $(userCol), IntegerType)
-    SchemaUtils.checkColumnType(schema, $(itemCol), IntegerType)
+    // user and item will be cast to Int
+    SchemaUtils.checkNumericType(schema, $(userCol))
+    SchemaUtils.checkNumericType(schema, $(itemCol))
     SchemaUtils.appendColumn(schema, $(predictionCol), FloatType)
   }
 
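
`transform` now routes both id columns through `cast(DoubleType)` and `checkedCast` before the factor joins. Because the UDF only runs when the join actually executes, an out-of-range id fails at action time, wrapped in a `SparkException` (the tests below intercept exactly that). A hedged usage sketch, assuming an existing `SparkSession` named `spark` and a fitted `model`:

```scala
import spark.implicits._

// Long ids are fine as long as they fit within Int range.
val inRange = Seq((0L, 1L), (1L, 0L)).toDF("user", "item")
model.transform(inRange).show()

// An id past Int.MaxValue only fails once an action executes the join, so the
// IllegalArgumentException from checkedCast surfaces wrapped in a SparkException.
val outOfRange = Seq((Int.MaxValue.toLong + 1, 0L)).toDF("user", "item")
// model.transform(outOfRange).first()  // throws SparkException
```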
@@ -430,10 +454,13 @@ class ALS(@Since("1.4.0") override val uid: String) extends Estimator[ALSModel]
 
   @Since("2.0.0")
   override def fit(dataset: Dataset[_]): ALSModel = {
+    transformSchema(dataset.schema)
     import dataset.sparkSession.implicits._
+
     val r = if ($(ratingCol) != "") col($(ratingCol)).cast(FloatType) else lit(1.0f)
     val ratings = dataset
-      .select(col($(userCol)).cast(IntegerType), col($(itemCol)).cast(IntegerType), r)
+      .select(checkedCast(col($(userCol)).cast(DoubleType)),
+        checkedCast(col($(itemCol)).cast(DoubleType)), r)
       .rdd
       .map { row =>
         Rating(row.getInt(0), row.getInt(1), row.getFloat(2))
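
`fit` applies the same double-cast-plus-`checkedCast` pipeline before building the `Rating` RDD, and the rating column is cast to `Float` as before. A sketch of what callers can now do, again assuming an existing `SparkSession` named `spark`:

```scala
import org.apache.spark.ml.recommendation.ALS
import spark.implicits._

// Long ids and Double ratings are accepted now; fit casts the ids to Int
// (failing if any value is out of Integer range) and the rating to Float.
val ratings = Seq(
  (0L, 10L, 4.0),
  (1L, 20L, 3.5),
  (2L, 10L, 5.0)
).toDF("user", "item", "rating")

val model = new ALS().setMaxIter(1).setRank(1).fit(ratings)
```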
diff --git a/mllib/src/test/scala/org/apache/spark/ml/recommendation/ALSSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/recommendation/ALSSuite.scala
index bbfc415cbb9b7cdecee80c4404cfbeab286bc4fb..59b5edc4013e8b5b7d1e0e3d042290dd3168344f 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/recommendation/ALSSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/recommendation/ALSSuite.scala
@@ -39,6 +39,7 @@ import org.apache.spark.mllib.util.MLlibTestSparkContext
 import org.apache.spark.rdd.RDD
 import org.apache.spark.scheduler.{SparkListener, SparkListenerStageCompleted}
 import org.apache.spark.sql.{DataFrame, Row, SparkSession}
+import org.apache.spark.sql.types.{FloatType, IntegerType}
 import org.apache.spark.storage.StorageLevel
 import org.apache.spark.util.Utils
 
@@ -205,7 +206,6 @@ class ALSSuite
 
   /**
    * Generates an explicit feedback dataset for testing ALS.
-   *
    * @param numUsers number of users
    * @param numItems number of items
    * @param rank rank
@@ -246,7 +246,6 @@ class ALSSuite
 
   /**
    * Generates an implicit feedback dataset for testing ALS.
-   *
    * @param numUsers number of users
    * @param numItems number of items
    * @param rank rank
@@ -265,7 +264,6 @@ class ALSSuite
 
   /**
    * Generates random user/item factors, with i.i.d. values drawn from U(a, b).
-   *
    * @param size number of users/items
    * @param rank number of features
    * @param random random number generator
@@ -284,7 +282,6 @@ class ALSSuite
 
   /**
    * Test ALS using the given training/test splits and parameters.
-   *
    * @param training training dataset
    * @param test test dataset
    * @param rank rank of the matrix factorization
@@ -486,6 +483,62 @@ class ALSSuite
     assert(getFactors(model.userFactors) === getFactors(model2.userFactors))
     assert(getFactors(model.itemFactors) === getFactors(model2.itemFactors))
   }
+
+  test("input type validation") {
+    val spark = this.spark
+    import spark.implicits._
+
+    // check that ALS can handle all numeric types for rating column
+    // and user/item columns (when the user/item ids are within Int range)
+    val als = new ALS().setMaxIter(1).setRank(1)
+    Seq(("user", IntegerType), ("item", IntegerType), ("rating", FloatType)).foreach {
+      case (colName, sqlType) =>
+        MLTestingUtils.checkNumericTypesALS(als, spark, colName, sqlType) { (ex, act) =>
+          assert(ex.userFactors.first().getSeq[Float](1) ===
+            act.userFactors.first().getSeq[Float](1))
+        } { (ex, act, df) =>
+          assert(ex.transform(df).select("prediction").first().getFloat(0) ~==
+            act.transform(df).select("prediction").first().getFloat(0) absTol 1e-6)
+        }
+    }
+    // check user/item ids falling outside of Int range
+    val big = Int.MaxValue.toLong + 1
+    val small = Int.MinValue.toDouble - 1
+    val df = Seq(
+      (0, 0L, 0d, 1, 1L, 1d, 3.0),
+      (0, big, small, 0, big, small, 2.0),
+      (1, 1L, 1d, 0, 0L, 0d, 5.0)
+    ).toDF("user", "user_big", "user_small", "item", "item_big", "item_small", "rating")
+    withClue("fit should fail when ids exceed integer range. ") {
+      assert(intercept[IllegalArgumentException] {
+        als.fit(df.select(df("user_big").as("user"), df("item"), df("rating")))
+      }.getMessage.contains("was out of Integer range"))
+      assert(intercept[IllegalArgumentException] {
+        als.fit(df.select(df("user_small").as("user"), df("item"), df("rating")))
+      }.getMessage.contains("was out of Integer range"))
+      assert(intercept[IllegalArgumentException] {
+        als.fit(df.select(df("item_big").as("item"), df("user"), df("rating")))
+      }.getMessage.contains("was out of Integer range"))
+      assert(intercept[IllegalArgumentException] {
+        als.fit(df.select(df("item_small").as("item"), df("user"), df("rating")))
+      }.getMessage.contains("was out of Integer range"))
+    }
+    withClue("transform should fail when ids exceed integer range. ") {
+      val model = als.fit(df)
+      assert(intercept[SparkException] {
+        model.transform(df.select(df("user_big").as("user"), df("item"))).first
+      }.getMessage.contains("was out of Integer range"))
+      assert(intercept[SparkException] {
+        model.transform(df.select(df("user_small").as("user"), df("item"))).first
+      }.getMessage.contains("was out of Integer range"))
+      assert(intercept[SparkException] {
+        model.transform(df.select(df("item_big").as("item"), df("user"))).first
+      }.getMessage.contains("was out of Integer range"))
+      assert(intercept[SparkException] {
+        model.transform(df.select(df("item_small").as("item"), df("user"))).first
+      }.getMessage.contains("was out of Integer range"))
+    }
+  }
 }
 
 class ALSCleanerSuite extends SparkFunSuite {
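
One detail worth noting in the test above: the fit clues intercept `IllegalArgumentException` while the transform clues intercept `SparkException`. `fit` surfaces the raw exception from `checkedCast`, whereas `transform` is lazy and the check only fires inside the executed job, so the same cause arrives wrapped. A condensed sketch of the distinction, reusing the hypothetical `als`, `model`, and `df` from the test:

```scala
// fit: the bad id is rejected with the original IllegalArgumentException.
intercept[IllegalArgumentException] {
  als.fit(df.select(df("user_big").as("user"), df("item"), df("rating")))
}

// transform: nothing runs until an action, so the same failure reaches the
// driver wrapped in a SparkException.
intercept[SparkException] {
  model.transform(df.select(df("user_big").as("user"), df("item"))).first()
}
```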
diff --git a/mllib/src/test/scala/org/apache/spark/ml/util/MLTestingUtils.scala b/mllib/src/test/scala/org/apache/spark/ml/util/MLTestingUtils.scala
index 6aae625fc83f2283c1632b8c85ab0b4af75e51d8..80b976914cbdfcaaf6717df5f82ab0899375fa0c 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/util/MLTestingUtils.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/util/MLTestingUtils.scala
@@ -22,6 +22,7 @@ import org.apache.spark.ml.{Estimator, Model}
 import org.apache.spark.ml.evaluation.Evaluator
 import org.apache.spark.ml.linalg.Vectors
 import org.apache.spark.ml.param.ParamMap
+import org.apache.spark.ml.recommendation.{ALS, ALSModel}
 import org.apache.spark.ml.tree.impl.TreeTests
 import org.apache.spark.sql.{DataFrame, SparkSession}
 import org.apache.spark.sql.functions._
@@ -58,6 +59,30 @@ object MLTestingUtils extends SparkFunSuite {
       "Column label must be of type NumericType but was actually of type StringType"))
   }
 
+  def checkNumericTypesALS(
+      estimator: ALS,
+      spark: SparkSession,
+      column: String,
+      baseType: NumericType)
+      (check: (ALSModel, ALSModel) => Unit)
+      (check2: (ALSModel, ALSModel, DataFrame) => Unit): Unit = {
+    val dfs = genRatingsDFWithNumericCols(spark, column)
+    val expected = estimator.fit(dfs(baseType))
+    val actuals = dfs.keys.filter(_ != baseType).map(t => (t, estimator.fit(dfs(t))))
+    actuals.foreach { case (_, actual) => check(expected, actual) }
+    actuals.foreach { case (t, actual) => check2(expected, actual, dfs(t)) }
+
+    val baseDF = dfs(baseType)
+    val others = baseDF.columns.toSeq.diff(Seq(column)).map(col(_))
+    val cols = Seq(col(column).cast(StringType)) ++ others
+    val strDF = baseDF.select(cols: _*)
+    val thrown = intercept[IllegalArgumentException] {
+      estimator.fit(strDF)
+    }
+    assert(thrown.getMessage.contains(
+      s"$column must be of type NumericType but was actually of type StringType"))
+  }
+
   def checkNumericTypes[T <: Evaluator](evaluator: T, spark: SparkSession): Unit = {
     val dfs = genEvaluatorDFWithNumericLabelCol(spark, "label", "prediction")
     val expected = evaluator.evaluate(dfs(DoubleType))
@@ -116,6 +141,26 @@ object MLTestingUtils extends SparkFunSuite {
       }.toMap
   }
 
+  def genRatingsDFWithNumericCols(
+      spark: SparkSession,
+      column: String): Map[NumericType, DataFrame] = {
+    val df = spark.createDataFrame(Seq(
+      (0, 10, 1.0),
+      (1, 20, 2.0),
+      (2, 30, 3.0),
+      (3, 40, 4.0),
+      (4, 50, 5.0)
+    )).toDF("user", "item", "rating")
+
+    val others = df.columns.toSeq.diff(Seq(column)).map(col(_))
+    val types: Seq[NumericType] =
+      Seq(ShortType, LongType, IntegerType, FloatType, ByteType, DoubleType, DecimalType(10, 0))
+    types.map { t =>
+      val cols = Seq(col(column).cast(t)) ++ others
+      t -> df.select(cols: _*)
+    }.toMap
+  }
+
   def genEvaluatorDFWithNumericLabelCol(
       spark: SparkSession,
       labelColName: String = "label",
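
`genRatingsDFWithNumericCols` returns one DataFrame per numeric type, with only the targeted column cast, and `checkNumericTypesALS` fits a model on each and hands the results to the caller's two check functions. A sketch of driving the helper directly, mirroring the ALSSuite call (assumes `spark` plus the ScalaTest/`TestingUtils` implicits for `===` and `~==` are in scope):

```scala
val als = new ALS().setMaxIter(1).setRank(1)
MLTestingUtils.checkNumericTypesALS(als, spark, "user", IntegerType) { (ex, act) =>
  // factors learned from a cast column should match the IntegerType baseline
  assert(ex.userFactors.first().getSeq[Float](1) ===
    act.userFactors.first().getSeq[Float](1))
} { (ex, act, df) =>
  // and predictions on the same rows should agree within tolerance
  assert(ex.transform(df).select("prediction").first().getFloat(0) ~==
    act.transform(df).select("prediction").first().getFloat(0) absTol 1e-6)
}
```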
diff --git a/python/pyspark/ml/recommendation.py b/python/pyspark/ml/recommendation.py
index d7cb65846574b7822a3e6859e2dda79f18489b82..86c00d91652d167358317a8907877009b0d66345 100644
--- a/python/pyspark/ml/recommendation.py
+++ b/python/pyspark/ml/recommendation.py
@@ -110,10 +110,10 @@ class ALS(JavaEstimator, HasCheckpointInterval, HasMaxIter, HasPredictionCol, Ha
                           typeConverter=TypeConverters.toBoolean)
     alpha = Param(Params._dummy(), "alpha", "alpha for implicit preference",
                   typeConverter=TypeConverters.toFloat)
-    userCol = Param(Params._dummy(), "userCol", "column name for user ids",
-                    typeConverter=TypeConverters.toString)
-    itemCol = Param(Params._dummy(), "itemCol", "column name for item ids",
-                    typeConverter=TypeConverters.toString)
+    userCol = Param(Params._dummy(), "userCol", "column name for user ids. Ids must be within " +
+                    "the integer value range.", typeConverter=TypeConverters.toString)
+    itemCol = Param(Params._dummy(), "itemCol", "column name for item ids. Ids must be within " +
+                    "the integer value range.", typeConverter=TypeConverters.toString)
     ratingCol = Param(Params._dummy(), "ratingCol", "column name for ratings",
                       typeConverter=TypeConverters.toString)
     nonnegative = Param(Params._dummy(), "nonnegative",