diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/Word2Vec.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/Word2Vec.scala
index 4ca062c0b5adf763932b42e5002a8885827442cb..b6909b3386b717dd07145cf06d7cb5df8430f9a7 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/Word2Vec.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/Word2Vec.scala
@@ -19,6 +19,7 @@ package org.apache.spark.ml.feature
 
 import org.apache.hadoop.fs.Path
 
+import org.apache.spark.SparkContext
 import org.apache.spark.annotation.Since
 import org.apache.spark.ml.{Estimator, Model}
 import org.apache.spark.ml.linalg.{BLAS, Vector, Vectors, VectorUDT}
@@ -339,25 +340,42 @@ object Word2VecModel extends MLReadable[Word2VecModel] {
       val wordVectors = instance.wordVectors.getVectors
       val dataSeq = wordVectors.toSeq.map { case (word, vector) => Data(word, vector) }
       val dataPath = new Path(path, "data").toString
+      val bufferSizeInBytes = Utils.byteStringAsBytes(
+        sc.conf.get("spark.kryoserializer.buffer.max", "64m"))
+      val numPartitions = Word2VecModelWriter.calculateNumberOfPartitions(
+        bufferSizeInBytes, instance.wordVectors.wordIndex.size, instance.getVectorSize)
       sparkSession.createDataFrame(dataSeq)
-        .repartition(calculateNumberOfPartitions)
+        .repartition(numPartitions)
         .write
         .parquet(dataPath)
     }
+  }
 
-    def calculateNumberOfPartitions(): Int = {
-      val floatSize = 4
+  private[feature]
+  object Word2VecModelWriter {
+    /**
+     * Calculate the number of partitions to use when saving the model.
+     * [SPARK-11994] - We want to split the model into partitions smaller than
+     * spark.kryoserializer.buffer.max, so each fits in the Kryo serialization buffer.
+     * @param bufferSizeInBytes  Value of spark.kryoserializer.buffer.max
+     * @param numWords  Vocabulary size
+     * @param vectorSize  Vector length for each word
+     */
+    def calculateNumberOfPartitions(
+        bufferSizeInBytes: Long,
+        numWords: Int,
+        vectorSize: Int): Int = {
+      val floatSize = 4L  // Use Long so the size estimate below cannot overflow Int
       val averageWordSize = 15
-      // [SPARK-11994] - We want to partition the model in partitions smaller than
-      // spark.kryoserializer.buffer.max
-      val bufferSizeInBytes = Utils.byteStringAsBytes(
-        sc.conf.get("spark.kryoserializer.buffer.max", "64m"))
       // Calculate the approximate size of the model.
       // Assuming an average word size of 15 bytes, the formula is:
       // (floatSize * vectorSize + 15) * numWords
-      val numWords = instance.wordVectors.wordIndex.size
-      val approximateSizeInBytes = (floatSize * instance.getVectorSize + averageWordSize) * numWords
-      ((approximateSizeInBytes / bufferSizeInBytes) + 1).toInt
+      val approximateSizeInBytes = (floatSize * vectorSize + averageWordSize) * numWords
+      val numPartitions = (approximateSizeInBytes / bufferSizeInBytes) + 1
+      require(numPartitions < 1e9, s"Word2VecModel calculated that it needs $numPartitions " +
+        s"partitions to save this model, which is too large. Try increasing " +
+        s"spark.kryoserializer.buffer.max so that Word2VecModel can use fewer partitions.")
+      numPartitions.toInt
     }
   }
 
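The new helper is pure arithmetic, so its behavior can be checked outside Spark. The
following standalone sketch (PartitionMath and its main harness are invented here for
illustration; they are not part of the patch) reproduces the formula and works through
the two cases the test suite below exercises: a tiny vocabulary that fits in a single
partition, and a 1,000,000-word, 5000-dimensional model whose roughly 20 GB size
estimate splits into 299 partitions under the default 64m Kryo buffer.

object PartitionMath {
  def calculateNumberOfPartitions(
      bufferSizeInBytes: Long,
      numWords: Int,
      vectorSize: Int): Int = {
    val floatSize = 4L  // Long arithmetic: (4 * 5000 + 15) * 1000000 would overflow Int
    val averageWordSize = 15
    val approximateSizeInBytes = (floatSize * vectorSize + averageWordSize) * numWords
    ((approximateSizeInBytes / bufferSizeInBytes) + 1).toInt
  }

  def main(args: Array[String]): Unit = {
    val bufferSizeInBytes = 64L * 1024 * 1024  // 67108864 bytes, the "64m" default
    // (4 * 5 + 15) * 10 = 350 bytes, far below the buffer => 1 partition
    println(calculateNumberOfPartitions(bufferSizeInBytes, numWords = 10, vectorSize = 5))
    // (4 * 5000 + 15) * 1000000 = 20015000000 bytes => 298 full buffers, +1 => 299
    println(calculateNumberOfPartitions(bufferSizeInBytes, numWords = 1000000, vectorSize = 5000))
  }
}

Note that the trailing "+ 1" keeps the per-partition estimate strictly below the buffer:
20015000000 / 299 is about 66.9 million bytes, just under the 67108864-byte limit, which
is exactly the property SPARK-11994 asks for.
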
diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/Word2VecSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/Word2VecSuite.scala
index a6a1c2b4f32bde0e9ad632ff70a8f5cbfd3c6f29..6183606a7b2ac4d61ef354a98dbe808e587931a8 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/feature/Word2VecSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/feature/Word2VecSuite.scala
@@ -25,6 +25,7 @@ import org.apache.spark.ml.util.TestingUtils._
 import org.apache.spark.mllib.feature.{Word2VecModel => OldWord2VecModel}
 import org.apache.spark.mllib.util.MLlibTestSparkContext
 import org.apache.spark.sql.Row
+import org.apache.spark.util.Utils
 
 class Word2VecSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest {
 
@@ -188,6 +189,15 @@ class Word2VecSuite extends SparkFunSuite with MLlibTestSparkContext with Defaul
     assert(math.abs(similarity(5) - similarityLarger(5) / similarity(5)) > 1E-5)
   }
 
+  test("Word2Vec read/write numPartitions calculation") {
+    val smallModelNumPartitions = Word2VecModel.Word2VecModelWriter.calculateNumberOfPartitions(
+      Utils.byteStringAsBytes("64m"), numWords = 10, vectorSize = 5)
+    assert(smallModelNumPartitions === 1)
+    val largeModelNumPartitions = Word2VecModel.Word2VecModelWriter.calculateNumberOfPartitions(
+      Utils.byteStringAsBytes("64m"), numWords = 1000000, vectorSize = 5000)
+    assert(largeModelNumPartitions > 1)
+  }
+
   test("Word2Vec read/write") {
     val t = new Word2Vec()
       .setInputCol("myInputCol")
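
For completeness, here is how the buffer size the writer feeds into the helper is
resolved from configuration. WriterWiringSketch and its local byteStringAsBytes stub
are hypothetical stand-ins: the real code calls org.apache.spark.util.Utils.byteStringAsBytes,
which is private[spark] (the test above can import it only because the suite lives under
the org.apache.spark package), and the stub below handles just the "<n>m" form used by
the default value.

import org.apache.spark.SparkConf

object WriterWiringSketch {
  // Minimal stand-in for Utils.byteStringAsBytes: parses "<n>m" as n mebibytes.
  def byteStringAsBytes(str: String): Long = {
    require(str.endsWith("m"), s"this sketch only handles megabyte strings, got: $str")
    str.dropRight(1).trim.toLong * 1024L * 1024L
  }

  def main(args: Array[String]): Unit = {
    // SparkConf can be built without a SparkContext; it reads spark.* system properties.
    val conf = new SparkConf()
    val bufferSizeInBytes = byteStringAsBytes(
      conf.get("spark.kryoserializer.buffer.max", "64m"))
    println(bufferSizeInBytes)  // 67108864 unless the user overrode the buffer size
  }
}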