diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
index ec0b3c78ed72c16b903223e157337760d1d50d7d..703ea4d1498cebb764e7a86a954c899927eda865 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
@@ -1178,32 +1178,6 @@ class Dataset[T] private[sql](
       withGroupingKey.newColumns)
   }
 
-  /**
-   * :: Experimental ::
-   * Returns a [[KeyValueGroupedDataset]] where the data is grouped by the given [[Column]]
-   * expressions.
-   *
-   * @group typedrel
-   * @since 2.0.0
-   */
-  @Experimental
-  @scala.annotation.varargs
-  def groupByKey(cols: Column*): KeyValueGroupedDataset[Row, T] = {
-    val withKeyColumns = logicalPlan.output ++ cols.map(_.expr).map(UnresolvedAlias(_))
-    val withKey = Project(withKeyColumns, logicalPlan)
-    val executed = sqlContext.executePlan(withKey)
-
-    val dataAttributes = executed.analyzed.output.dropRight(cols.size)
-    val keyAttributes = executed.analyzed.output.takeRight(cols.size)
-
-    new KeyValueGroupedDataset(
-      RowEncoder(keyAttributes.toStructType),
-      encoderFor[T],
-      executed,
-      dataAttributes,
-      keyAttributes)
-  }
-
   /**
    * :: Experimental ::
    * (Java-specific)
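The typed variants of groupByKey are unaffected by this removal: the Scala function form used in the updated tests below (ds.groupByKey(_._1)) and the Java-specific form documented in the surviving context above, which takes a MapFunction and an Encoder. A minimal migration sketch for callers of the removed Column-based overload, assuming the session implicits are in scope; the name `spark` and the sample data are illustrative, not part of this patch:

    // Assumes a SparkSession named `spark` (or this era's sqlContext.implicits._).
    import spark.implicits._

    val ds = Seq(("a", 10), ("a", 20), ("b", 1), ("b", 2), ("c", 1)).toDS()

    // Before (removed above): ds.groupByKey($"_1").keyAs[String]
    // After: extract the key with a function; the key encoder is resolved
    // implicitly, so no keyAs step is needed.
    val grouped = ds.groupByKey(_._1)
    val summed = grouped.mapGroups { case (g, iter) => (g, iter.map(_._2).sum) }
    // summed contains ("a", 30), ("b", 3), ("c", 1), matching the removed tests.
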
diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java
index 18f17a85a9dbda60a126208f310d8b38ea5d9aa4..86db8df4c00fdc303b50e0f00c433a93e1dbcffd 100644
--- a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java
+++ b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java
@@ -245,29 +245,6 @@ public class JavaDatasetSuite implements Serializable {
     Assert.assertEquals(asSet("1a#2", "3foobar#6", "5#10"), toSet(cogrouped.collectAsList()));
   }
 
-  @Test
-  public void testGroupByColumn() {
-    List<String> data = Arrays.asList("a", "foo", "bar");
-    Dataset<String> ds = context.createDataset(data, Encoders.STRING());
-    KeyValueGroupedDataset<Integer, String> grouped =
-      ds.groupByKey(length(col("value"))).keyAs(Encoders.INT());
-
-    Dataset<String> mapped = grouped.mapGroups(
-      new MapGroupsFunction<Integer, String, String>() {
-        @Override
-        public String call(Integer key, Iterator<String> data) throws Exception {
-          StringBuilder sb = new StringBuilder(key.toString());
-          while (data.hasNext()) {
-            sb.append(data.next());
-          }
-          return sb.toString();
-        }
-      },
-      Encoders.STRING());
-
-    Assert.assertEquals(asSet("1a", "3foobar"), toSet(mapped.collectAsList()));
-  }
-
   @Test
   public void testSelect() {
     List<Integer> data = Arrays.asList(2, 6);
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetCacheSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetCacheSuite.scala
index 2e5179a8d2c9561d384a4427458d27f90284923e..942cc09b6d58e6f4c8eec5785c5edf0e287e24e4 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DatasetCacheSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetCacheSuite.scala
@@ -63,7 +63,7 @@ class DatasetCacheSuite extends QueryTest with SharedSQLContext {
 
   test("persist and then groupBy columns asKey, map") {
     val ds = Seq(("a", 10), ("a", 20), ("b", 1), ("b", 2), ("c", 1)).toDS()
-    val grouped = ds.groupByKey($"_1").keyAs[String]
+    val grouped = ds.groupByKey(_._1)
     val agged = grouped.mapGroups { case (g, iter) => (g, iter.map(_._2).sum) }
     agged.persist()
 
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
index 0bcc512d7137dde2044615a92bc0a9bdf323edde..553bc436a64561643365f3c0daf26b9235fea50b 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
@@ -322,55 +322,6 @@ class DatasetSuite extends QueryTest with SharedSQLContext {
     )
   }
 
-  test("groupBy columns, map") {
-    val ds = Seq(("a", 10), ("a", 20), ("b", 1), ("b", 2), ("c", 1)).toDS()
-    val grouped = ds.groupByKey($"_1")
-    val agged = grouped.mapGroups { case (g, iter) => (g.getString(0), iter.map(_._2).sum) }
-
-    checkDataset(
-      agged,
-      ("a", 30), ("b", 3), ("c", 1))
-  }
-
-  test("groupBy columns, count") {
-    val ds = Seq("a" -> 1, "b" -> 1, "a" -> 2).toDS()
-    val count = ds.groupByKey($"_1").count()
-
-    checkDataset(
-      count,
-      (Row("a"), 2L), (Row("b"), 1L))
-  }
-
-  test("groupBy columns asKey, map") {
-    val ds = Seq(("a", 10), ("a", 20), ("b", 1), ("b", 2), ("c", 1)).toDS()
-    val grouped = ds.groupByKey($"_1").keyAs[String]
-    val agged = grouped.mapGroups { case (g, iter) => (g, iter.map(_._2).sum) }
-
-    checkDataset(
-      agged,
-      ("a", 30), ("b", 3), ("c", 1))
-  }
-
-  test("groupBy columns asKey tuple, map") {
-    val ds = Seq(("a", 10), ("a", 20), ("b", 1), ("b", 2), ("c", 1)).toDS()
-    val grouped = ds.groupByKey($"_1", lit(1)).keyAs[(String, Int)]
-    val agged = grouped.mapGroups { case (g, iter) => (g, iter.map(_._2).sum) }
-
-    checkDataset(
-      agged,
-      (("a", 1), 30), (("b", 1), 3), (("c", 1), 1))
-  }
-
-  test("groupBy columns asKey class, map") {
-    val ds = Seq(("a", 10), ("a", 20), ("b", 1), ("b", 2), ("c", 1)).toDS()
-    val grouped = ds.groupByKey($"_1".as("a"), lit(1).as("b")).keyAs[ClassData]
-    val agged = grouped.mapGroups { case (g, iter) => (g, iter.map(_._2).sum) }
-
-    checkDataset(
-      agged,
-      (ClassData("a", 1), 30), (ClassData("b", 1), 3), (ClassData("c", 1), 1))
-  }
-
   test("typed aggregation: expr") {
     val ds = Seq(("a", 10), ("a", 20), ("b", 1), ("b", 2), ("c", 1)).toDS()