diff --git a/mllib/src/main/scala/org/apache/spark/mllib/feature/IDF.scala b/mllib/src/main/scala/org/apache/spark/mllib/feature/IDF.scala
index 7ed611a857accc0bfcc4041e685c7e6e3650e115..d40d5553c1d219b1bff0981c6f396a6c26642dc5 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/feature/IDF.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/feature/IDF.scala
@@ -36,87 +36,25 @@ class IDF {
 
   // TODO: Allow different IDF formulations.
 
-  private var brzIdf: BDV[Double] = _
-
   /**
    * Computes the inverse document frequency.
    * @param dataset an RDD of term frequency vectors
    */
-  def fit(dataset: RDD[Vector]): this.type = {
-    brzIdf = dataset.treeAggregate(new IDF.DocumentFrequencyAggregator)(
+  def fit(dataset: RDD[Vector]): IDFModel = {
+    val idf = dataset.treeAggregate(new IDF.DocumentFrequencyAggregator)(
       seqOp = (df, v) => df.add(v),
       combOp = (df1, df2) => df1.merge(df2)
     ).idf()
-    this
+    new IDFModel(idf)
   }
 
   /**
    * Computes the inverse document frequency.
    * @param dataset a JavaRDD of term frequency vectors
    */
-  def fit(dataset: JavaRDD[Vector]): this.type = {
+  def fit(dataset: JavaRDD[Vector]): IDFModel = {
     fit(dataset.rdd)
   }
-
-  /**
-   * Transforms term frequency (TF) vectors to TF-IDF vectors.
-   * @param dataset an RDD of term frequency vectors
-   * @return an RDD of TF-IDF vectors
-   */
-  def transform(dataset: RDD[Vector]): RDD[Vector] = {
-    if (!initialized) {
-      throw new IllegalStateException("Haven't learned IDF yet. Call fit first.")
-    }
-    val theIdf = brzIdf
-    val bcIdf = dataset.context.broadcast(theIdf)
-    dataset.mapPartitions { iter =>
-      val thisIdf = bcIdf.value
-      iter.map { v =>
-        val n = v.size
-        v match {
-          case sv: SparseVector =>
-            val nnz = sv.indices.size
-            val newValues = new Array[Double](nnz)
-            var k = 0
-            while (k < nnz) {
-              newValues(k) = sv.values(k) * thisIdf(sv.indices(k))
-              k += 1
-            }
-            Vectors.sparse(n, sv.indices, newValues)
-          case dv: DenseVector =>
-            val newValues = new Array[Double](n)
-            var j = 0
-            while (j < n) {
-              newValues(j) = dv.values(j) * thisIdf(j)
-              j += 1
-            }
-            Vectors.dense(newValues)
-          case other =>
-            throw new UnsupportedOperationException(
-              s"Only sparse and dense vectors are supported but got ${other.getClass}.")
-        }
-      }
-    }
-  }
-
-  /**
-   * Transforms term frequency (TF) vectors to TF-IDF vectors (Java version).
-   * @param dataset a JavaRDD of term frequency vectors
-   * @return a JavaRDD of TF-IDF vectors
-   */
-  def transform(dataset: JavaRDD[Vector]): JavaRDD[Vector] = {
-    transform(dataset.rdd).toJavaRDD()
-  }
-
-  /** Returns the IDF vector. */
-  def idf(): Vector = {
-    if (!initialized) {
-      throw new IllegalStateException("Haven't learned IDF yet. Call fit first.")
-    }
-    Vectors.fromBreeze(brzIdf)
-  }
-
-  private def initialized: Boolean = brzIdf != null
 }
 
 private object IDF {
@@ -177,18 +115,72 @@ private object IDF {
     private def isEmpty: Boolean = m == 0L
 
     /** Returns the current IDF vector. */
-    def idf(): BDV[Double] = {
+    def idf(): Vector = {
       if (isEmpty) {
         throw new IllegalStateException("Haven't seen any document yet.")
       }
       val n = df.length
-      val inv = BDV.zeros[Double](n)
+      val inv = new Array[Double](n)
       var j = 0
       while (j < n) {
         inv(j) = math.log((m + 1.0)/ (df(j) + 1.0))
         j += 1
       }
-      inv
+      Vectors.dense(inv)
     }
   }
 }
+
+/**
+ * :: Experimental ::
+ * Represents an IDF model that can transform term frequency vectors.
+ */
+@Experimental
+class IDFModel private[mllib] (val idf: Vector) extends Serializable {
+
+  /**
+   * Transforms term frequency (TF) vectors to TF-IDF vectors.
+   * @param dataset an RDD of term frequency vectors
+   * @return an RDD of TF-IDF vectors
+   */
+  def transform(dataset: RDD[Vector]): RDD[Vector] = {
+    val bcIdf = dataset.context.broadcast(idf)
+    dataset.mapPartitions { iter =>
+      val thisIdf = bcIdf.value
+      iter.map { v =>
+        val n = v.size
+        v match {
+          case sv: SparseVector =>
+            val nnz = sv.indices.size
+            val newValues = new Array[Double](nnz)
+            var k = 0
+            while (k < nnz) {
+              newValues(k) = sv.values(k) * thisIdf(sv.indices(k))
+              k += 1
+            }
+            Vectors.sparse(n, sv.indices, newValues)
+          case dv: DenseVector =>
+            val newValues = new Array[Double](n)
+            var j = 0
+            while (j < n) {
+              newValues(j) = dv.values(j) * thisIdf(j)
+              j += 1
+            }
+            Vectors.dense(newValues)
+          case other =>
+            throw new UnsupportedOperationException(
+              s"Only sparse and dense vectors are supported but got ${other.getClass}.")
+        }
+      }
+    }
+  }
+
+  /**
+   * Transforms term frequency (TF) vectors to TF-IDF vectors (Java version).
+   * @param dataset a JavaRDD of term frequency vectors
+   * @return a JavaRDD of TF-IDF vectors
+   */
+  def transform(dataset: JavaRDD[Vector]): JavaRDD[Vector] = {
+    transform(dataset.rdd).toJavaRDD()
+  }
+}
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/feature/StandardScaler.scala b/mllib/src/main/scala/org/apache/spark/mllib/feature/StandardScaler.scala
index e6c9f8f67df6347e26896eaa3b5c2cf6f7d8d67b..4dfd1f0ab81349fc9876202b4de6fcac4598ce89 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/feature/StandardScaler.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/feature/StandardScaler.scala
@@ -17,8 +17,9 @@
 
 package org.apache.spark.mllib.feature
 
-import breeze.linalg.{DenseVector => BDV, SparseVector => BSV, Vector => BV}
+import breeze.linalg.{DenseVector => BDV, SparseVector => BSV}
 
+import org.apache.spark.Logging
 import org.apache.spark.annotation.Experimental
 import org.apache.spark.mllib.linalg.{Vector, Vectors}
 import org.apache.spark.mllib.rdd.RDDFunctions._
@@ -35,37 +36,55 @@ import org.apache.spark.rdd.RDD
  * @param withStd True by default. Scales the data to unit standard deviation.
  */
 @Experimental
-class StandardScaler(withMean: Boolean, withStd: Boolean) extends VectorTransformer {
+class StandardScaler(withMean: Boolean, withStd: Boolean) extends Logging {
 
   def this() = this(false, true)
 
-  require(withMean || withStd, s"withMean and withStd both equal to false. Doing nothing.")
-
-  private var mean: BV[Double] = _
-  private var factor: BV[Double] = _
+  if (!(withMean || withStd)) {
+    logWarning("Both withMean and withStd are false. The model does nothing.")
+  }
 
   /**
    * Computes the mean and variance and stores as a model to be used for later scaling.
    *
    * @param data The data used to compute the mean and variance to build the transformation model.
-   * @return This StandardScalar object.
+   * @return a StandardScalarModel
    */
-  def fit(data: RDD[Vector]): this.type = {
+  def fit(data: RDD[Vector]): StandardScalerModel = {
+    // TODO: skip computation if both withMean and withStd are false
     val summary = data.treeAggregate(new MultivariateOnlineSummarizer)(
       (aggregator, data) => aggregator.add(data),
       (aggregator1, aggregator2) => aggregator1.merge(aggregator2))
+    new StandardScalerModel(withMean, withStd, summary.mean, summary.variance)
+  }
+}
 
-    mean = summary.mean.toBreeze
-    factor = summary.variance.toBreeze
-    require(mean.length == factor.length)
+/**
+ * :: Experimental ::
+ * Represents a StandardScaler model that can transform vectors.
+ *
+ * @param withMean whether to center the data before scaling
+ * @param withStd whether to scale the data to have unit standard deviation
+ * @param mean column mean values
+ * @param variance column variance values
+ */
+@Experimental
+class StandardScalerModel private[mllib] (
+    val withMean: Boolean,
+    val withStd: Boolean,
+    val mean: Vector,
+    val variance: Vector) extends VectorTransformer {
+
+  require(mean.size == variance.size)
 
+  private lazy val factor: BDV[Double] = {
+    val f = BDV.zeros[Double](variance.size)
     var i = 0
-    while (i < factor.length) {
-      factor(i) = if (factor(i) != 0.0) 1.0 / math.sqrt(factor(i)) else 0.0
+    while (i < f.size) {
+      f(i) = if (variance(i) != 0.0) 1.0 / math.sqrt(variance(i)) else 0.0
       i += 1
     }
-
-    this
+    f
   }
 
   /**
@@ -76,13 +95,7 @@ class StandardScaler(withMean: Boolean, withStd: Boolean) extends VectorTransfor
    *         for the column with zero variance.
    */
   override def transform(vector: Vector): Vector = {
-    if (mean == null || factor == null) {
-      throw new IllegalStateException(
-        "Haven't learned column summary statistics yet. Call fit first.")
-    }
-
-    require(vector.size == mean.length)
-
+    require(mean.size == vector.size)
     if (withMean) {
       vector.toBreeze match {
         case dv: BDV[Double] =>
@@ -115,5 +128,4 @@ class StandardScaler(withMean: Boolean, withStd: Boolean) extends VectorTransfor
       vector
     }
   }
-
 }
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/feature/IDFSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/feature/IDFSuite.scala
index 78a2804ff204be467ba00ebe0cfdba9419bbe144..53d9c0c640b980d8523967940a99268432741a04 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/feature/IDFSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/feature/IDFSuite.scala
@@ -36,18 +36,12 @@ class IDFSuite extends FunSuite with LocalSparkContext {
     val m = localTermFrequencies.size
     val termFrequencies = sc.parallelize(localTermFrequencies, 2)
     val idf = new IDF
-    intercept[IllegalStateException] {
-      idf.idf()
-    }
-    intercept[IllegalStateException] {
-      idf.transform(termFrequencies)
-    }
-    idf.fit(termFrequencies)
+    val model = idf.fit(termFrequencies)
     val expected = Vectors.dense(Array(0, 3, 1, 2).map { x =>
       math.log((m.toDouble + 1.0) / (x + 1.0))
     })
-    assert(idf.idf() ~== expected absTol 1e-12)
-    val tfidf = idf.transform(termFrequencies).cache().zipWithIndex().map(_.swap).collectAsMap()
+    assert(model.idf ~== expected absTol 1e-12)
+    val tfidf = model.transform(termFrequencies).cache().zipWithIndex().map(_.swap).collectAsMap()
     assert(tfidf.size === 3)
     val tfidf0 = tfidf(0L).asInstanceOf[SparseVector]
     assert(tfidf0.indices === Array(1, 3))
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/feature/StandardScalerSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/feature/StandardScalerSuite.scala
index 5a9be923a8625812120ec3e4370e51baf4d718ac..e217b93cebbdb77f1be2cb2a20ab6a40d9c66cb0 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/feature/StandardScalerSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/feature/StandardScalerSuite.scala
@@ -50,23 +50,17 @@ class StandardScalerSuite extends FunSuite with LocalSparkContext {
     val standardizer2 = new StandardScaler()
     val standardizer3 = new StandardScaler(withMean = true, withStd = false)
 
-    withClue("Using a standardizer before fitting the model should throw exception.") {
-      intercept[IllegalStateException] {
-        data.map(standardizer1.transform)
-      }
-    }
-
-    standardizer1.fit(dataRDD)
-    standardizer2.fit(dataRDD)
-    standardizer3.fit(dataRDD)
+    val model1 = standardizer1.fit(dataRDD)
+    val model2 = standardizer2.fit(dataRDD)
+    val model3 = standardizer3.fit(dataRDD)
 
-    val data1 = data.map(standardizer1.transform)
-    val data2 = data.map(standardizer2.transform)
-    val data3 = data.map(standardizer3.transform)
+    val data1 = data.map(model1.transform)
+    val data2 = data.map(model2.transform)
+    val data3 = data.map(model3.transform)
 
-    val data1RDD = standardizer1.transform(dataRDD)
-    val data2RDD = standardizer2.transform(dataRDD)
-    val data3RDD = standardizer3.transform(dataRDD)
+    val data1RDD = model1.transform(dataRDD)
+    val data2RDD = model2.transform(dataRDD)
+    val data3RDD = model3.transform(dataRDD)
 
     val summary = computeSummary(dataRDD)
     val summary1 = computeSummary(data1RDD)
@@ -129,25 +123,25 @@ class StandardScalerSuite extends FunSuite with LocalSparkContext {
     val standardizer2 = new StandardScaler()
     val standardizer3 = new StandardScaler(withMean = true, withStd = false)
 
-    standardizer1.fit(dataRDD)
-    standardizer2.fit(dataRDD)
-    standardizer3.fit(dataRDD)
+    val model1 = standardizer1.fit(dataRDD)
+    val model2 = standardizer2.fit(dataRDD)
+    val model3 = standardizer3.fit(dataRDD)
 
-    val data2 = data.map(standardizer2.transform)
+    val data2 = data.map(model2.transform)
 
     withClue("Standardization with mean can not be applied on sparse input.") {
       intercept[IllegalArgumentException] {
-        data.map(standardizer1.transform)
+        data.map(model1.transform)
       }
     }
 
     withClue("Standardization with mean can not be applied on sparse input.") {
       intercept[IllegalArgumentException] {
-        data.map(standardizer3.transform)
+        data.map(model3.transform)
       }
     }
 
-    val data2RDD = standardizer2.transform(dataRDD)
+    val data2RDD = model2.transform(dataRDD)
 
     val summary2 = computeSummary(data2RDD)
 
@@ -181,13 +175,13 @@ class StandardScalerSuite extends FunSuite with LocalSparkContext {
     val standardizer2 = new StandardScaler(withMean = true, withStd = false)
     val standardizer3 = new StandardScaler(withMean = false, withStd = true)
 
-    standardizer1.fit(dataRDD)
-    standardizer2.fit(dataRDD)
-    standardizer3.fit(dataRDD)
+    val model1 = standardizer1.fit(dataRDD)
+    val model2 = standardizer2.fit(dataRDD)
+    val model3 = standardizer3.fit(dataRDD)
 
-    val data1 = data.map(standardizer1.transform)
-    val data2 = data.map(standardizer2.transform)
-    val data3 = data.map(standardizer3.transform)
+    val data1 = data.map(model1.transform)
+    val data2 = data.map(model2.transform)
+    val data3 = data.map(model3.transform)
 
     assert(data1.forall(_.toArray.forall(_ == 0.0)),
       "The variance is zero, so the transformed result should be 0.0")