diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
index 8abec85ee102a1c52f09677419fa0a90112a0db7..f7637e005f31733be14c901435559433df69f19b 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
@@ -131,7 +131,7 @@ private[sql] object Dataset {
  *
  *   people.filter("age > 30")
  *     .join(department, people("deptId") === department("id"))
- *     .groupBy(department("name"), "gender")
+ *     .groupBy(department("name"), people("gender"))
  *     .agg(avg(people("salary")), max(people("age")))
  * }}}
  *
@@ -141,9 +141,9 @@ private[sql] object Dataset {
  *   Dataset<Row> people = spark.read().parquet("...");
  *   Dataset<Row> department = spark.read().parquet("...");
  *
- *   people.filter("age".gt(30))
- *     .join(department, people.col("deptId").equalTo(department("id")))
- *     .groupBy(department.col("name"), "gender")
+ *   people.filter(people.col("age").gt(30))
+ *     .join(department, people.col("deptId").equalTo(department.col("id")))
+ *     .groupBy(department.col("name"), people.col("gender"))
  *     .agg(avg(people.col("salary")), max(people.col("age")));
  * }}}
  *