diff --git a/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/RowMatrix.scala b/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/RowMatrix.scala
index 99cb6516e065ce17c3a9a9291698950b7efceeec..711e32a330d7d3a17db89a844fa8979b007d160a 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/RowMatrix.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/RowMatrix.scala
@@ -28,138 +28,7 @@ import org.apache.spark.annotation.Experimental
 import org.apache.spark.mllib.linalg._
 import org.apache.spark.rdd.RDD
 import org.apache.spark.Logging
-import org.apache.spark.mllib.stat.MultivariateStatisticalSummary
-
-/**
- * Column statistics aggregator implementing
- * [[org.apache.spark.mllib.stat.MultivariateStatisticalSummary]]
- * together with add() and merge() function.
- * A numerically stable algorithm is implemented to compute sample mean and variance:
- * [[http://en.wikipedia.org/wiki/Algorithms_for_calculating_variance variance-wiki]].
- * Zero elements (including explicit zero values) are skipped when calling add() and merge(),
- * to have time complexity O(nnz) instead of O(n) for each column.
- */
-private class ColumnStatisticsAggregator(private val n: Int)
-    extends MultivariateStatisticalSummary with Serializable {
-
-  private val currMean: BDV[Double] = BDV.zeros[Double](n)
-  private val currM2n: BDV[Double] = BDV.zeros[Double](n)
-  private var totalCnt = 0.0
-  private val nnz: BDV[Double] = BDV.zeros[Double](n)
-  private val currMax: BDV[Double] = BDV.fill(n)(Double.MinValue)
-  private val currMin: BDV[Double] = BDV.fill(n)(Double.MaxValue)
-
-  override def mean: Vector = {
-    val realMean = BDV.zeros[Double](n)
-    var i = 0
-    while (i < n) {
-      realMean(i) = currMean(i) * nnz(i) / totalCnt
-      i += 1
-    }
-    Vectors.fromBreeze(realMean)
-  }
-
-  override def variance: Vector = {
-    val realVariance = BDV.zeros[Double](n)
-
-    val denominator = totalCnt - 1.0
-
-    // Sample variance is computed, if the denominator is less than 0, the variance is just 0.
-    if (denominator > 0.0) {
-      val deltaMean = currMean
-      var i = 0
-      while (i < currM2n.size) {
-        realVariance(i) =
-          currM2n(i) + deltaMean(i) * deltaMean(i) * nnz(i) * (totalCnt - nnz(i)) / totalCnt
-        realVariance(i) /= denominator
-        i += 1
-      }
-    }
-
-    Vectors.fromBreeze(realVariance)
-  }
-
-  override def count: Long = totalCnt.toLong
-
-  override def numNonzeros: Vector = Vectors.fromBreeze(nnz)
-
-  override def max: Vector = {
-    var i = 0
-    while (i < n) {
-      if ((nnz(i) < totalCnt) && (currMax(i) < 0.0)) currMax(i) = 0.0
-      i += 1
-    }
-    Vectors.fromBreeze(currMax)
-  }
-
-  override def min: Vector = {
-    var i = 0
-    while (i < n) {
-      if ((nnz(i) < totalCnt) && (currMin(i) > 0.0)) currMin(i) = 0.0
-      i += 1
-    }
-    Vectors.fromBreeze(currMin)
-  }
-
-  /**
-   * Aggregates a row.
-   */
-  def add(currData: BV[Double]): this.type = {
-    currData.activeIterator.foreach {
-      case (_, 0.0) => // Skip explicit zero elements.
-      case (i, value) =>
-        if (currMax(i) < value) {
-          currMax(i) = value
-        }
-        if (currMin(i) > value) {
-          currMin(i) = value
-        }
-
-        val tmpPrevMean = currMean(i)
-        currMean(i) = (currMean(i) * nnz(i) + value) / (nnz(i) + 1.0)
-        currM2n(i) += (value - currMean(i)) * (value - tmpPrevMean)
-
-        nnz(i) += 1.0
-    }
-
-    totalCnt += 1.0
-    this
-  }
-
-  /**
-   * Merges another aggregator.
-   */
-  def merge(other: ColumnStatisticsAggregator): this.type = {
-    require(n == other.n, s"Dimensions mismatch. Expecting $n but got ${other.n}.")
-
-    totalCnt += other.totalCnt
-    val deltaMean = currMean - other.currMean
-
-    var i = 0
-    while (i < n) {
-      // merge mean together
-      if (other.currMean(i) != 0.0) {
-        currMean(i) = (currMean(i) * nnz(i) + other.currMean(i) * other.nnz(i)) /
-          (nnz(i) + other.nnz(i))
-      }
-      // merge m2n together
-      if (nnz(i) + other.nnz(i) != 0.0) {
-        currM2n(i) += other.currM2n(i) + deltaMean(i) * deltaMean(i) * nnz(i) * other.nnz(i) /
-          (nnz(i) + other.nnz(i))
-      }
-      if (currMax(i) < other.currMax(i)) {
-        currMax(i) = other.currMax(i)
-      }
-      if (currMin(i) > other.currMin(i)) {
-        currMin(i) = other.currMin(i)
-      }
-      i += 1
-    }
-
-    nnz += other.nnz
-    this
-  }
-}
+import org.apache.spark.mllib.stat.{MultivariateOnlineSummarizer, MultivariateStatisticalSummary}
 
 /**
  * :: Experimental ::
@@ -478,8 +347,7 @@ class RowMatrix(
    * Computes column-wise summary statistics.
    */
   def computeColumnSummaryStatistics(): MultivariateStatisticalSummary = {
-    val zeroValue = new ColumnStatisticsAggregator(numCols().toInt)
-    val summary = rows.map(_.toBreeze).aggregate[ColumnStatisticsAggregator](zeroValue)(
+    val summary = rows.aggregate[MultivariateOnlineSummarizer](new MultivariateOnlineSummarizer)(
       (aggregator, data) => aggregator.add(data),
       (aggregator1, aggregator2) => aggregator1.merge(aggregator2)
     )
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/stat/MultivariateOnlineSummarizer.scala b/mllib/src/main/scala/org/apache/spark/mllib/stat/MultivariateOnlineSummarizer.scala
new file mode 100644
index 0000000000000000000000000000000000000000..5105b5c37aaaa5b640bcf67edd58e986b38e4bd6
--- /dev/null
+++ b/mllib/src/main/scala/org/apache/spark/mllib/stat/MultivariateOnlineSummarizer.scala
@@ -0,0 +1,201 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.mllib.stat
+
+import breeze.linalg.{DenseVector => BDV}
+
+import org.apache.spark.annotation.DeveloperApi
+import org.apache.spark.mllib.linalg.{Vectors, Vector}
+
+/**
+ * :: DeveloperApi ::
+ * MultivariateOnlineSummarizer implements [[MultivariateStatisticalSummary]] to compute the mean,
+ * variance, minimum, maximum, counts, and nonzero counts for samples in sparse or dense vector
+ * format in a online fashion.
+ *
+ * Two MultivariateOnlineSummarizer can be merged together to have a statistical summary of
+ * the corresponding joint dataset.
+ *
+ * A numerically stable algorithm is implemented to compute sample mean and variance:
+ * Reference: [[http://en.wikipedia.org/wiki/Algorithms_for_calculating_variance variance-wiki]]
+ * Zero elements (including explicit zero values) are skipped when calling add(),
+ * to have time complexity O(nnz) instead of O(n) for each column.
+ */
+@DeveloperApi
+class MultivariateOnlineSummarizer extends MultivariateStatisticalSummary with Serializable {
+
+  private var n = 0
+  private var currMean: BDV[Double] = _
+  private var currM2n: BDV[Double] = _
+  private var totalCnt: Long = 0
+  private var nnz: BDV[Double] = _
+  private var currMax: BDV[Double] = _
+  private var currMin: BDV[Double] = _
+
+  /**
+   * Add a new sample to this summarizer, and update the statistical summary.
+   *
+   * @param sample The sample in dense/sparse vector format to be added into this summarizer.
+   * @return This MultivariateOnlineSummarizer object.
+   */
+  def add(sample: Vector): this.type = {
+    if (n == 0) {
+      require(sample.toBreeze.length > 0, s"Vector should have dimension larger than zero.")
+      n = sample.toBreeze.length
+
+      currMean = BDV.zeros[Double](n)
+      currM2n = BDV.zeros[Double](n)
+      nnz = BDV.zeros[Double](n)
+      currMax = BDV.fill(n)(Double.MinValue)
+      currMin = BDV.fill(n)(Double.MaxValue)
+    }
+
+    require(n == sample.toBreeze.length, s"Dimensions mismatch when adding new sample." +
+      s" Expecting $n but got ${sample.toBreeze.length}.")
+
+    sample.toBreeze.activeIterator.foreach {
+      case (_, 0.0) => // Skip explicit zero elements.
+      case (i, value) =>
+        if (currMax(i) < value) {
+          currMax(i) = value
+        }
+        if (currMin(i) > value) {
+          currMin(i) = value
+        }
+
+        val tmpPrevMean = currMean(i)
+        currMean(i) = (currMean(i) * nnz(i) + value) / (nnz(i) + 1.0)
+        currM2n(i) += (value - currMean(i)) * (value - tmpPrevMean)
+
+        nnz(i) += 1.0
+    }
+
+    totalCnt += 1
+    this
+  }
+
+  /**
+   * Merge another MultivariateOnlineSummarizer, and update the statistical summary.
+   * (Note that it's in place merging; as a result, `this` object will be modified.)
+   *
+   * @param other The other MultivariateOnlineSummarizer to be merged.
+   * @return This MultivariateOnlineSummarizer object.
+   */
+  def merge(other: MultivariateOnlineSummarizer): this.type = {
+   if (this.totalCnt != 0 && other.totalCnt != 0) {
+      require(n == other.n, s"Dimensions mismatch when merging with another summarizer. " +
+        s"Expecting $n but got ${other.n}.")
+      totalCnt += other.totalCnt
+      val deltaMean: BDV[Double] = currMean - other.currMean
+      var i = 0
+      while (i < n) {
+        // merge mean together
+        if (other.currMean(i) != 0.0) {
+          currMean(i) = (currMean(i) * nnz(i) + other.currMean(i) * other.nnz(i)) /
+            (nnz(i) + other.nnz(i))
+        }
+        // merge m2n together
+        if (nnz(i) + other.nnz(i) != 0.0) {
+          currM2n(i) += other.currM2n(i) + deltaMean(i) * deltaMean(i) * nnz(i) * other.nnz(i) /
+            (nnz(i) + other.nnz(i))
+        }
+        if (currMax(i) < other.currMax(i)) {
+          currMax(i) = other.currMax(i)
+        }
+        if (currMin(i) > other.currMin(i)) {
+          currMin(i) = other.currMin(i)
+        }
+        i += 1
+      }
+      nnz += other.nnz
+    } else if (totalCnt == 0 && other.totalCnt != 0) {
+      this.n = other.n
+      this.currMean = other.currMean.copy
+      this.currM2n = other.currM2n.copy
+      this.totalCnt = other.totalCnt
+      this.nnz = other.nnz.copy
+      this.currMax = other.currMax.copy
+      this.currMin = other.currMin.copy
+    }
+    this
+  }
+
+  override def mean: Vector = {
+    require(totalCnt > 0, s"Nothing has been added to this summarizer.")
+
+    val realMean = BDV.zeros[Double](n)
+    var i = 0
+    while (i < n) {
+      realMean(i) = currMean(i) * (nnz(i) / totalCnt)
+      i += 1
+    }
+    Vectors.fromBreeze(realMean)
+  }
+
+  override def variance: Vector = {
+    require(totalCnt > 0, s"Nothing has been added to this summarizer.")
+
+    val realVariance = BDV.zeros[Double](n)
+
+    val denominator = totalCnt - 1.0
+
+    // Sample variance is computed, if the denominator is less than 0, the variance is just 0.
+    if (denominator > 0.0) {
+      val deltaMean = currMean
+      var i = 0
+      while (i < currM2n.size) {
+        realVariance(i) =
+          currM2n(i) + deltaMean(i) * deltaMean(i) * nnz(i) * (totalCnt - nnz(i)) / totalCnt
+        realVariance(i) /= denominator
+        i += 1
+      }
+    }
+
+    Vectors.fromBreeze(realVariance)
+  }
+
+  override def count: Long = totalCnt
+
+  override def numNonzeros: Vector = {
+    require(totalCnt > 0, s"Nothing has been added to this summarizer.")
+
+    Vectors.fromBreeze(nnz)
+  }
+
+  override def max: Vector = {
+    require(totalCnt > 0, s"Nothing has been added to this summarizer.")
+
+    var i = 0
+    while (i < n) {
+      if ((nnz(i) < totalCnt) && (currMax(i) < 0.0)) currMax(i) = 0.0
+      i += 1
+    }
+    Vectors.fromBreeze(currMax)
+  }
+
+  override def min: Vector = {
+    require(totalCnt > 0, s"Nothing has been added to this summarizer.")
+
+    var i = 0
+    while (i < n) {
+      if ((nnz(i) < totalCnt) && (currMin(i) > 0.0)) currMin(i) = 0.0
+      i += 1
+    }
+    Vectors.fromBreeze(currMin)
+  }
+}
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/stat/MultivariateOnlineSummarizerSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/stat/MultivariateOnlineSummarizerSuite.scala
new file mode 100644
index 0000000000000000000000000000000000000000..4b7b019d820b4fac23fd829c3ac784183d656c04
--- /dev/null
+++ b/mllib/src/test/scala/org/apache/spark/mllib/stat/MultivariateOnlineSummarizerSuite.scala
@@ -0,0 +1,209 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.mllib.stat
+
+import org.scalatest.FunSuite
+
+import org.apache.spark.mllib.linalg.Vectors
+import org.apache.spark.mllib.util.TestingUtils._
+
+class MultivariateOnlineSummarizerSuite extends FunSuite {
+
+  test("basic error handing") {
+    val summarizer = new MultivariateOnlineSummarizer
+
+    assert(summarizer.count === 0, "should be zero since nothing is added.")
+
+    withClue("Getting numNonzeros from empty summarizer should throw exception.") {
+      intercept[IllegalArgumentException] {
+        summarizer.numNonzeros
+      }
+    }
+
+    withClue("Getting variance from empty summarizer should throw exception.") {
+      intercept[IllegalArgumentException] {
+        summarizer.variance
+      }
+    }
+
+    withClue("Getting mean from empty summarizer should throw exception.") {
+      intercept[IllegalArgumentException] {
+        summarizer.mean
+      }
+    }
+
+    withClue("Getting max from empty summarizer should throw exception.") {
+      intercept[IllegalArgumentException] {
+        summarizer.max
+      }
+    }
+
+    withClue("Getting min from empty summarizer should throw exception.") {
+      intercept[IllegalArgumentException] {
+        summarizer.min
+      }
+    }
+
+    summarizer.add(Vectors.dense(-1.0, 2.0, 6.0)).add(Vectors.sparse(3, Seq((0, -2.0), (1, 6.0))))
+
+    withClue("Adding a new dense sample with different array size should throw exception.") {
+      intercept[IllegalArgumentException] {
+        summarizer.add(Vectors.dense(3.0, 1.0))
+      }
+    }
+
+    withClue("Adding a new sparse sample with different array size should throw exception.") {
+      intercept[IllegalArgumentException] {
+        summarizer.add(Vectors.sparse(5, Seq((0, -2.0), (1, 6.0))))
+      }
+    }
+
+    val summarizer2 = (new MultivariateOnlineSummarizer).add(Vectors.dense(1.0, -2.0, 0.0, 4.0))
+    withClue("Merging a new summarizer with different dimensions should throw exception.") {
+      intercept[IllegalArgumentException] {
+        summarizer.merge(summarizer2)
+      }
+    }
+  }
+
+  test("dense vector input") {
+    // For column 2, the maximum will be 0.0, and it's not explicitly added since we ignore all
+    // the zeros; it's a case we need to test. For column 3, the minimum will be 0.0 which we
+    // need to test as well.
+    val summarizer = (new MultivariateOnlineSummarizer)
+      .add(Vectors.dense(-1.0, 0.0, 6.0))
+      .add(Vectors.dense(3.0, -3.0, 0.0))
+
+    assert(summarizer.mean.almostEquals(Vectors.dense(1.0, -1.5, 3.0)), "mean mismatch")
+
+    assert(summarizer.min.almostEquals(Vectors.dense(-1.0, -3, 0.0)), "min mismatch")
+
+    assert(summarizer.max.almostEquals(Vectors.dense(3.0, 0.0, 6.0)), "max mismatch")
+
+    assert(summarizer.numNonzeros.almostEquals(Vectors.dense(2, 1, 1)), "numNonzeros mismatch")
+
+    assert(summarizer.variance.almostEquals(Vectors.dense(8.0, 4.5, 18.0)), "variance mismatch")
+
+    assert(summarizer.count === 2)
+  }
+
+  test("sparse vector input") {
+    val summarizer = (new MultivariateOnlineSummarizer)
+      .add(Vectors.sparse(3, Seq((0, -1.0), (2, 6.0))))
+      .add(Vectors.sparse(3, Seq((0, 3.0), (1, -3.0))))
+
+    assert(summarizer.mean.almostEquals(Vectors.dense(1.0, -1.5, 3.0)), "mean mismatch")
+
+    assert(summarizer.min.almostEquals(Vectors.dense(-1.0, -3, 0.0)), "min mismatch")
+
+    assert(summarizer.max.almostEquals(Vectors.dense(3.0, 0.0, 6.0)), "max mismatch")
+
+    assert(summarizer.numNonzeros.almostEquals(Vectors.dense(2, 1, 1)), "numNonzeros mismatch")
+
+    assert(summarizer.variance.almostEquals(Vectors.dense(8.0, 4.5, 18.0)), "variance mismatch")
+
+    assert(summarizer.count === 2)
+  }
+
+  test("mixing dense and sparse vector input") {
+    val summarizer = (new MultivariateOnlineSummarizer)
+      .add(Vectors.sparse(3, Seq((0, -2.0), (1, 2.3))))
+      .add(Vectors.dense(0.0, -1.0, -3.0))
+      .add(Vectors.sparse(3, Seq((1, -5.1))))
+      .add(Vectors.dense(3.8, 0.0, 1.9))
+      .add(Vectors.dense(1.7, -0.6, 0.0))
+      .add(Vectors.sparse(3, Seq((1, 1.9), (2, 0.0))))
+
+    assert(summarizer.mean.almostEquals(
+      Vectors.dense(0.583333333333, -0.416666666666, -0.183333333333)), "mean mismatch")
+
+    assert(summarizer.min.almostEquals(Vectors.dense(-2.0, -5.1, -3)), "min mismatch")
+
+    assert(summarizer.max.almostEquals(Vectors.dense(3.8, 2.3, 1.9)), "max mismatch")
+
+    assert(summarizer.numNonzeros.almostEquals(Vectors.dense(3, 5, 2)), "numNonzeros mismatch")
+
+    assert(summarizer.variance.almostEquals(
+      Vectors.dense(3.857666666666, 7.0456666666666, 2.48166666666666)), "variance mismatch")
+
+    assert(summarizer.count === 6)
+  }
+
+  test("merging two summarizers") {
+    val summarizer1 = (new MultivariateOnlineSummarizer)
+      .add(Vectors.sparse(3, Seq((0, -2.0), (1, 2.3))))
+      .add(Vectors.dense(0.0, -1.0, -3.0))
+
+    val summarizer2 = (new MultivariateOnlineSummarizer)
+      .add(Vectors.sparse(3, Seq((1, -5.1))))
+      .add(Vectors.dense(3.8, 0.0, 1.9))
+      .add(Vectors.dense(1.7, -0.6, 0.0))
+      .add(Vectors.sparse(3, Seq((1, 1.9), (2, 0.0))))
+
+    val summarizer = summarizer1.merge(summarizer2)
+
+    assert(summarizer.mean.almostEquals(
+      Vectors.dense(0.583333333333, -0.416666666666, -0.183333333333)), "mean mismatch")
+
+    assert(summarizer.min.almostEquals(Vectors.dense(-2.0, -5.1, -3)), "min mismatch")
+
+    assert(summarizer.max.almostEquals(Vectors.dense(3.8, 2.3, 1.9)), "max mismatch")
+
+    assert(summarizer.numNonzeros.almostEquals(Vectors.dense(3, 5, 2)), "numNonzeros mismatch")
+
+    assert(summarizer.variance.almostEquals(
+      Vectors.dense(3.857666666666, 7.0456666666666, 2.48166666666666)), "variance mismatch")
+
+    assert(summarizer.count === 6)
+  }
+
+  test("merging summarizer with empty summarizer") {
+    // If one of two is non-empty, this should return the non-empty summarizer.
+    // If both of them are empty, then just return the empty summarizer.
+    val summarizer1 = (new MultivariateOnlineSummarizer)
+      .add(Vectors.dense(0.0, -1.0, -3.0)).merge(new MultivariateOnlineSummarizer)
+    assert(summarizer1.count === 1)
+
+    val summarizer2 = (new MultivariateOnlineSummarizer)
+      .merge((new MultivariateOnlineSummarizer).add(Vectors.dense(0.0, -1.0, -3.0)))
+    assert(summarizer2.count === 1)
+
+    val summarizer3 = (new MultivariateOnlineSummarizer).merge(new MultivariateOnlineSummarizer)
+    assert(summarizer3.count === 0)
+
+    assert(summarizer1.mean.almostEquals(Vectors.dense(0.0, -1.0, -3.0)), "mean mismatch")
+
+    assert(summarizer2.mean.almostEquals(Vectors.dense(0.0, -1.0, -3.0)), "mean mismatch")
+
+    assert(summarizer1.min.almostEquals(Vectors.dense(0.0, -1.0, -3.0)), "min mismatch")
+
+    assert(summarizer2.min.almostEquals(Vectors.dense(0.0, -1.0, -3.0)), "min mismatch")
+
+    assert(summarizer1.max.almostEquals(Vectors.dense(0.0, -1.0, -3.0)), "max mismatch")
+
+    assert(summarizer2.max.almostEquals(Vectors.dense(0.0, -1.0, -3.0)), "max mismatch")
+
+    assert(summarizer1.numNonzeros.almostEquals(Vectors.dense(0, 1, 1)), "numNonzeros mismatch")
+
+    assert(summarizer2.numNonzeros.almostEquals(Vectors.dense(0, 1, 1)), "numNonzeros mismatch")
+
+    assert(summarizer1.variance.almostEquals(Vectors.dense(0, 0, 0)), "variance mismatch")
+
+    assert(summarizer2.variance.almostEquals(Vectors.dense(0, 0, 0)), "variance mismatch")
+  }
+}
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/util/TestingUtils.scala b/mllib/src/test/scala/org/apache/spark/mllib/util/TestingUtils.scala
new file mode 100644
index 0000000000000000000000000000000000000000..64b1ba75271836cb3700e949a37455584fbb1a10
--- /dev/null
+++ b/mllib/src/test/scala/org/apache/spark/mllib/util/TestingUtils.scala
@@ -0,0 +1,45 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.mllib.util
+
+import org.apache.spark.mllib.linalg.Vector
+
+object TestingUtils {
+
+  implicit class DoubleWithAlmostEquals(val x: Double) {
+    // An improved version of AlmostEquals would always divide by the larger number.
+    // This will avoid the problem of diving by zero.
+    def almostEquals(y: Double, epsilon: Double = 1E-10): Boolean = {
+      if(x == y) {
+        true
+      } else if(math.abs(x) > math.abs(y)) {
+        math.abs(x - y) / math.abs(x) < epsilon
+      } else {
+        math.abs(x - y) / math.abs(y) < epsilon
+      }
+    }
+  }
+
+  implicit class VectorWithAlmostEquals(val x: Vector) {
+    def almostEquals(y: Vector, epsilon: Double = 1E-10): Boolean = {
+      x.toArray.corresponds(y.toArray) {
+        _.almostEquals(_, epsilon)
+      }
+    }
+  }
+}
diff --git a/project/MimaExcludes.scala b/project/MimaExcludes.scala
index 3b7b87b80cda00d5355650d27e8f1c8f7aaea407..d67c6571a0623f2f88789d6532a5d9a8b2890121 100644
--- a/project/MimaExcludes.scala
+++ b/project/MimaExcludes.scala
@@ -75,6 +75,7 @@ object MimaExcludes {
             ProblemFilters.exclude[MissingMethodProblem](
               "org.apache.spark.mllib.recommendation.ALS.org$apache$spark$mllib$recommendation$ALS$$<init>$default$7")
           ) ++
+          MimaBuild.excludeSparkClass("mllib.linalg.distributed.ColumnStatisticsAggregator") ++
           MimaBuild.excludeSparkClass("rdd.ZippedRDD") ++
           MimaBuild.excludeSparkClass("rdd.ZippedPartition") ++
           MimaBuild.excludeSparkClass("util.SerializableHyperLogLog") ++