diff --git a/docs/ml-features.md b/docs/ml-features.md
index 5df61dd36a07037445c1d52c65b4df27df65be5a..e86f9edc4f68bfb8febaf3a5e5e7642d2ad63491 100644
--- a/docs/ml-features.md
+++ b/docs/ml-features.md
@@ -268,5 +268,88 @@ for binarized_feature, in binarizedFeatures.collect():
 </div>
 </div>
 
+## PolynomialExpansion
+
+[Polynomial expansion](http://en.wikipedia.org/wiki/Polynomial_expansion) is the process of expanding your features into a polynomial space, which is formulated by an n-degree combination of original dimensions. A [PolynomialExpansion](api/scala/index.html#org.apache.spark.ml.feature.PolynomialExpansion) class provides this functionality.  The example below shows how to expand your features into a 3-degree polynomial space.
+
+<div class="codetabs">
+<div data-lang="scala" markdown="1">
+{% highlight scala %}
+import org.apache.spark.ml.feature.PolynomialExpansion
+import org.apache.spark.mllib.linalg.Vectors
+
+val data = Array(
+  Vectors.dense(-2.0, 2.3),
+  Vectors.dense(0.0, 0.0),
+  Vectors.dense(0.6, -1.1)
+)
+val df = sqlContext.createDataFrame(data.map(Tuple1.apply)).toDF("features")
+val polynomialExpansion = new PolynomialExpansion()
+  .setInputCol("features")
+  .setOutputCol("polyFeatures")
+  .setDegree(3)
+val polyDF = polynomialExpansion.transform(df)
+polyDF.select("polyFeatures").take(3).foreach(println)
+{% endhighlight %}
+</div>
+
+<div data-lang="java" markdown="1">
+{% highlight java %}
+import com.google.common.collect.Lists;
+
+import org.apache.spark.api.java.JavaRDD;
+import org.apache.spark.api.java.JavaSparkContext;
+import org.apache.spark.mllib.linalg.Vector;
+import org.apache.spark.mllib.linalg.VectorUDT;
+import org.apache.spark.mllib.linalg.Vectors;
+import org.apache.spark.sql.DataFrame;
+import org.apache.spark.sql.Row;
+import org.apache.spark.sql.RowFactory;
+import org.apache.spark.sql.SQLContext;
+import org.apache.spark.sql.types.Metadata;
+import org.apache.spark.sql.types.StructField;
+import org.apache.spark.sql.types.StructType;
+
+JavaSparkContext jsc = ...
+SQLContext jsql = ...
+PolynomialExpansion polyExpansion = new PolynomialExpansion()
+  .setInputCol("features")
+  .setOutputCol("polyFeatures")
+  .setDegree(3);
+JavaRDD<Row> data = jsc.parallelize(Lists.newArrayList(
+  RowFactory.create(Vectors.dense(-2.0, 2.3)),
+  RowFactory.create(Vectors.dense(0.0, 0.0)),
+  RowFactory.create(Vectors.dense(0.6, -1.1))
+));
+StructType schema = new StructType(new StructField[] {
+  new StructField("features", new VectorUDT(), false, Metadata.empty()),
+});
+DataFrame df = jsql.createDataFrame(data, schema);
+DataFrame polyDF = polyExpansion.transform(df);
+Row[] row = polyDF.select("polyFeatures").take(3);
+for (Row r : row) {
+  System.out.println(r.get(0));
+}
+{% endhighlight %}
+</div>
+
+<div data-lang="python" markdown="1">
+{% highlight python %}
+from pyspark.ml.feature import PolynomialExpansion
+from pyspark.mllib.linalg import Vectors
+
+df = sqlContext.createDataFrame(
+  [(Vectors.dense([-2.0, 2.3]), ),
+  (Vectors.dense([0.0, 0.0]), ),
+  (Vectors.dense([0.6, -1.1]), )],
+  ["features"])
+px = PolynomialExpansion(degree=2, inputCol="features", outputCol="polyFeatures")
+polyDF = px.transform(df)
+for expanded in polyDF.select("polyFeatures").take(3):
+  print(expanded)
+{% endhighlight %}
+</div>
+</div>
+
 # Feature Selectors
 
diff --git a/mllib/src/test/java/org/apache/spark/ml/feature/JavaPolynomialExpansionSuite.java b/mllib/src/test/java/org/apache/spark/ml/feature/JavaPolynomialExpansionSuite.java
new file mode 100644
index 0000000000000000000000000000000000000000..5e8211c2c5118441d9638658d52520523db0d73a
--- /dev/null
+++ b/mllib/src/test/java/org/apache/spark/ml/feature/JavaPolynomialExpansionSuite.java
@@ -0,0 +1,91 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.ml.feature;
+
+import com.google.common.collect.Lists;
+import org.junit.After;
+import org.junit.Assert;
+import org.junit.Before;
+import org.junit.Test;
+
+import org.apache.spark.api.java.JavaRDD;
+import org.apache.spark.api.java.JavaSparkContext;
+import org.apache.spark.mllib.linalg.Vector;
+import org.apache.spark.mllib.linalg.VectorUDT;
+import org.apache.spark.mllib.linalg.Vectors;
+import org.apache.spark.sql.DataFrame;
+import org.apache.spark.sql.Row;
+import org.apache.spark.sql.RowFactory;
+import org.apache.spark.sql.SQLContext;
+import org.apache.spark.sql.types.Metadata;
+import org.apache.spark.sql.types.StructField;
+import org.apache.spark.sql.types.StructType;
+
+public class JavaPolynomialExpansionSuite {
+  private transient JavaSparkContext jsc;
+  private transient SQLContext jsql;
+
+  @Before
+  public void setUp() {
+    jsc = new JavaSparkContext("local", "JavaPolynomialExpansionSuite");
+    jsql = new SQLContext(jsc);
+  }
+
+  @After
+  public void tearDown() {
+    jsc.stop();
+    jsc = null;
+  }
+
+  @Test
+  public void polynomialExpansionTest() {
+    PolynomialExpansion polyExpansion = new PolynomialExpansion()
+      .setInputCol("features")
+      .setOutputCol("polyFeatures")
+      .setDegree(3);
+
+    JavaRDD<Row> data = jsc.parallelize(Lists.newArrayList(
+      RowFactory.create(
+        Vectors.dense(-2.0, 2.3),
+        Vectors.dense(-2.0, 4.0, -8.0, 2.3, -4.6, 9.2, 5.29, -10.58, 12.17)
+      ),
+      RowFactory.create(Vectors.dense(0.0, 0.0), Vectors.dense(new double[9])),
+      RowFactory.create(
+        Vectors.dense(0.6, -1.1),
+        Vectors.dense(0.6, 0.36, 0.216, -1.1, -0.66, -0.396, 1.21, 0.726, -1.331)
+      )
+    ));
+
+    StructType schema = new StructType(new StructField[] {
+      new StructField("features", new VectorUDT(), false, Metadata.empty()),
+      new StructField("expected", new VectorUDT(), false, Metadata.empty())
+    });
+
+    DataFrame dataset = jsql.createDataFrame(data, schema);
+
+    Row[] pairs = polyExpansion.transform(dataset)
+      .select("polyFeatures", "expected")
+      .collect();
+
+    for (Row r : pairs) {
+      double[] polyFeatures = ((Vector)r.get(0)).toArray();
+      double[] expected = ((Vector)r.get(1)).toArray();
+      Assert.assertArrayEquals(polyFeatures, expected, 1e-1);
+    }
+  }
+}