diff --git a/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/CoordinateMatrix.scala b/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/CoordinateMatrix.scala
index 26ca1ef9be8700c0c1c2a7b37e11bc2a432ed8e3..0d223de9b6f7e824a305d9d3d75824b92867e610 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/CoordinateMatrix.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/CoordinateMatrix.scala
@@ -125,6 +125,13 @@ class CoordinateMatrix @Since("1.0.0") (
       s"colsPerBlock needs to be greater than 0. colsPerBlock: $colsPerBlock")
     val m = numRows()
     val n = numCols()
+
+    // Since block matrices require an integer row and column block index
+    require(math.ceil(m.toDouble / rowsPerBlock) <= Int.MaxValue,
+      "Number of rows divided by rowsPerBlock cannot exceed maximum integer.")
+    require(math.ceil(n.toDouble / colsPerBlock) <= Int.MaxValue,
+      "Number of cols divided by colsPerBlock cannot exceed maximum integer.")
+
     val numRowBlocks = math.ceil(m.toDouble / rowsPerBlock).toInt
     val numColBlocks = math.ceil(n.toDouble / colsPerBlock).toInt
     val partitioner = GridPartitioner(numRowBlocks, numColBlocks, entries.partitions.length)
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/IndexedRowMatrix.scala b/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/IndexedRowMatrix.scala
index d7255d527f036391c6222dcb26f0e30fc24ecda6..8890662d99b52f16bd6ed4d1b6ebb504c3f968b5 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/IndexedRowMatrix.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/IndexedRowMatrix.scala
@@ -91,7 +91,7 @@ class IndexedRowMatrix @Since("1.0.0") (
   }
 
   /**
-   * Converts to BlockMatrix. Creates blocks of `SparseMatrix` with size 1024 x 1024.
+   * Converts to BlockMatrix. Creates blocks with size 1024 x 1024.
    */
   @Since("1.3.0")
   def toBlockMatrix(): BlockMatrix = {
@@ -99,7 +99,7 @@ class IndexedRowMatrix @Since("1.0.0") (
   }
 
   /**
-   * Converts to BlockMatrix. Creates blocks of `SparseMatrix`.
+   * Converts to BlockMatrix. Blocks may be sparse or dense depending on the sparsity of the rows.
    * @param rowsPerBlock The number of rows of each block. The blocks at the bottom edge may have
    *                     a smaller value. Must be an integer value greater than 0.
    * @param colsPerBlock The number of columns of each block. The blocks at the right edge may have
@@ -108,8 +108,77 @@
    */
   @Since("1.3.0")
   def toBlockMatrix(rowsPerBlock: Int, colsPerBlock: Int): BlockMatrix = {
-    // TODO: This implementation may be optimized
-    toCoordinateMatrix().toBlockMatrix(rowsPerBlock, colsPerBlock)
+    require(rowsPerBlock > 0,
+      s"rowsPerBlock needs to be greater than 0. rowsPerBlock: $rowsPerBlock")
+    require(colsPerBlock > 0,
+      s"colsPerBlock needs to be greater than 0. colsPerBlock: $colsPerBlock")
+
+    val m = numRows()
+    val n = numCols()
+
+    // Since block matrices require an integer row block index (the column block index
+    // always fits in an Int because vector indices are Ints)
+    require(math.ceil(m.toDouble / rowsPerBlock) <= Int.MaxValue,
+      "Number of rows divided by rowsPerBlock cannot exceed maximum integer.")
+
+    // The remainder calculations only matter when m % rowsPerBlock != 0 or n % colsPerBlock != 0
+    val remainderRowBlockIndex = m / rowsPerBlock
+    val remainderColBlockIndex = n / colsPerBlock
+    val remainderRowBlockSize = (m % rowsPerBlock).toInt
+    val remainderColBlockSize = (n % colsPerBlock).toInt
+    val numRowBlocks = math.ceil(m.toDouble / rowsPerBlock).toInt
+    val numColBlocks = math.ceil(n.toDouble / colsPerBlock).toInt
+
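+    // Split each row into per-block pieces keyed by (block row index, block column index),
+    // carrying the row's offset within the block and its (value, column offset) pairs.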
+    val blocks = rows.flatMap { ir: IndexedRow =>
+      val blockRow = ir.index / rowsPerBlock
+      val rowInBlock = ir.index % rowsPerBlock
+
+      ir.vector match {
+        case SparseVector(size, indices, values) =>
+          indices.zip(values).map { case (index, value) =>
+            val blockColumn = index / colsPerBlock
+            val columnInBlock = index % colsPerBlock
+            ((blockRow.toInt, blockColumn.toInt), (rowInBlock.toInt, Array((value, columnInBlock))))
+          }
+        case DenseVector(values) =>
+          values.grouped(colsPerBlock)
+            .zipWithIndex
+            .map { case (values, blockColumn) =>
+              ((blockRow.toInt, blockColumn), (rowInBlock.toInt, values.zipWithIndex))
+            }
+      }
+    }.groupByKey(GridPartitioner(numRowBlocks, numColBlocks, rows.getNumPartitions)).map {
+      case ((blockRow, blockColumn), itr) =>
+        val actualNumRows =
+          if (blockRow == remainderRowBlockIndex) remainderRowBlockSize else rowsPerBlock
+        val actualNumColumns =
+          if (blockColumn == remainderColBlockIndex) remainderColBlockSize else colsPerBlock
+
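+        // Assemble the block into a column-major array (the layout DenseMatrix expects)
+        // and track how many of its entries are populated.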
+        val arraySize = actualNumRows * actualNumColumns
+        val matrixAsArray = new Array[Double](arraySize)
+        var countForValues = 0
+        itr.foreach { case (rowWithinBlock, valuesWithColumns) =>
+          valuesWithColumns.foreach { case (value, columnWithinBlock) =>
+            matrixAsArray.update(columnWithinBlock * actualNumRows + rowWithinBlock, value)
+            countForValues += 1
+          }
+        }
+        val denseMatrix = new DenseMatrix(actualNumRows, actualNumColumns, matrixAsArray)
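+        // Keep the block dense if at least 10% of its entries were populated;
+        // otherwise convert it to a sparse block.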
+        val finalMatrix = if (countForValues / arraySize.toDouble >= 0.1) {
+          denseMatrix
+        } else {
+          denseMatrix.toSparse
+        }
+
+        ((blockRow, blockColumn), finalMatrix)
+    }
+    new BlockMatrix(blocks, rowsPerBlock, colsPerBlock, m, n)
   }
 
   /**
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/linalg/distributed/IndexedRowMatrixSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/linalg/distributed/IndexedRowMatrixSuite.scala
index 99af5fa10d999e12851d9d53db94d8dc08a29d2b..566ce95be084a551a7c0f375791d0dd82a89c6c4 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/linalg/distributed/IndexedRowMatrixSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/linalg/distributed/IndexedRowMatrixSuite.scala
@@ -20,7 +20,7 @@ package org.apache.spark.mllib.linalg.distributed
 import breeze.linalg.{diag => brzDiag, DenseMatrix => BDM, DenseVector => BDV}
 
 import org.apache.spark.SparkFunSuite
-import org.apache.spark.mllib.linalg.{Matrices, Vectors}
+import org.apache.spark.mllib.linalg._
 import org.apache.spark.mllib.util.MLlibTestSparkContext
 import org.apache.spark.rdd.RDD
 
@@ -87,19 +87,100 @@
     assert(coordMat.toBreeze() === idxRowMat.toBreeze())
   }
 
-  test("toBlockMatrix") {
-    val idxRowMat = new IndexedRowMatrix(indexedRows)
-    val blockMat = idxRowMat.toBlockMatrix(2, 2)
+  test("toBlockMatrix dense backing") {
+    val idxRowMatDense = new IndexedRowMatrix(indexedRows)
+
+    // Tests when n % colsPerBlock != 0
+    val blockMat = idxRowMatDense.toBlockMatrix(2, 2)
     assert(blockMat.numRows() === m)
     assert(blockMat.numCols() === n)
-    assert(blockMat.toBreeze() === idxRowMat.toBreeze())
+    assert(blockMat.toBreeze() === idxRowMatDense.toBreeze())
+
+    // Tests when m % rowsPerBlock != 0
+    val blockMat2 = idxRowMatDense.toBlockMatrix(3, 1)
+    assert(blockMat2.numRows() === m)
+    assert(blockMat2.numCols() === n)
+    assert(blockMat2.toBreeze() === idxRowMatDense.toBreeze())
 
     intercept[IllegalArgumentException] {
-      idxRowMat.toBlockMatrix(-1, 2)
+      idxRowMatDense.toBlockMatrix(-1, 2)
     }
     intercept[IllegalArgumentException] {
-      idxRowMat.toBlockMatrix(2, 0)
+      idxRowMatDense.toBlockMatrix(2, 0)
     }
+
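+    // The input rows are dense, so every block should stay above the 10% fill threshold
+    // and be backed by a DenseMatrix.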
+    assert(blockMat.blocks.map { case (_, matrix: Matrix) =>
+      matrix.isInstanceOf[DenseMatrix]
+    }.reduce(_ && _))
+    assert(blockMat2.blocks.map { case (_, matrix: Matrix) =>
+      matrix.isInstanceOf[DenseMatrix]
+    }.reduce(_ && _))
+  }
+
+  test("toBlockMatrix sparse backing") {
+    val sparseData = Seq(
+      (15L, Vectors.sparse(12, Seq((0, 4.0))))
+    ).map(x => IndexedRow(x._1, x._2))
+
+    // Use larger m and n here so the matrices can easily be completely sparse.
+    val m = 16
+    val n = 12
+
+    val idxRowMatSparse = new IndexedRowMatrix(sc.parallelize(sparseData))
+
+    // Tests when n % colsPerBlock != 0
+    val blockMat = idxRowMatSparse.toBlockMatrix(8, 8)
+    assert(blockMat.numRows() === m)
+    assert(blockMat.numCols() === n)
+    assert(blockMat.toBreeze() === idxRowMatSparse.toBreeze())
+
+    // Tests when m % rowsPerBlock != 0
+    val blockMat2 = idxRowMatSparse.toBlockMatrix(6, 6)
+    assert(blockMat2.numRows() === m)
+    assert(blockMat2.numCols() === n)
+    assert(blockMat2.toBreeze() === idxRowMatSparse.toBreeze())
+
+    assert(blockMat.blocks.collect().forall { case (_, matrix: Matrix) =>
+      matrix.isInstanceOf[SparseMatrix]
+    })
+    assert(blockMat2.blocks.collect().forall { case (_, matrix: Matrix) =>
+      matrix.isInstanceOf[SparseMatrix]
+    })
+  }
+
+  test("toBlockMatrix mixed backing") {
+    val m = 24
+    val n = 18
+
+    val mixedData = Seq(
+      (0L, Vectors.dense((0 to 17).map(_.toDouble).toArray)),
+      (1L, Vectors.dense((0 to 17).map(_.toDouble).toArray)),
+      (23L, Vectors.sparse(18, Seq((0, 4.0)))))
+      .map(x => IndexedRow(x._1, x._2))
+
+    val idxRowMatMixed = new IndexedRowMatrix(
+      sc.parallelize(mixedData))
+
+    // Tests when n % colsPerBlock != 0
+    val blockMat = idxRowMatMixed.toBlockMatrix(12, 12)
+    assert(blockMat.numRows() === m)
+    assert(blockMat.numCols() === n)
+    assert(blockMat.toBreeze() === idxRowMatMixed.toBreeze())
+
+    // Tests when m % rowsPerBlock != 0
+    val blockMat2 = idxRowMatMixed.toBlockMatrix(18, 6)
+    assert(blockMat2.numRows() === m)
+    assert(blockMat2.numCols() === n)
+    assert(blockMat2.toBreeze() === idxRowMatMixed.toBreeze())
+
+    val blocks = blockMat.blocks.collect()
+
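+    // Rows 0 and 1 put the blocks in block row 0 above the 10% fill threshold; the only
+    // other populated block holds a single entry from row 23, so it should be sparse.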
+    assert(blocks.forall { case ((row, col), matrix) =>
+      if (row == 0) matrix.isInstanceOf[DenseMatrix] else matrix.isInstanceOf[SparseMatrix] })
   }
 
   test("multiply a local matrix") {