diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegen.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegen.scala
index c15fabab805a78f43047b80bca103eb879e5ba64..57f4945de980455620f794491fcb4cd3a6981ea3 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegen.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegen.scala
@@ -264,12 +264,16 @@ case class WholeStageCodegen(plan: CodegenSupport, children: Seq[SparkPlan])
   */
 private[sql] case class CollapseCodegenStages(sqlContext: SQLContext) extends Rule[SparkPlan] {
 
+  private def supportCodegen(e: Expression): Boolean = e match {
+    case e: LeafExpression => true
+    // CodegenFallback requires the input to be an InternalRow, which a fused pipeline does not materialize
+    case e: CodegenFallback => false
+    case _ => true
+  }
+
   private def supportCodegen(plan: SparkPlan): Boolean = plan match {
     case plan: CodegenSupport if plan.supportCodegen =>
-      // Non-leaf with CodegenFallback does not work with whole stage codegen
-      val willFallback = plan.expressions.exists(
-        _.find(e => e.isInstanceOf[CodegenFallback] && !e.isInstanceOf[LeafExpression]).isDefined
-      )
+      val willFallback = plan.expressions.exists(_.find(e => !supportCodegen(e)).isDefined)
       // the generated code will be huge if there are too many columns
       val haveManyColumns = plan.output.length > 200
       !willFallback && !haveManyColumns
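
For intuition, here is a minimal standalone sketch of what the refactored predicate rejects (treeSupportsCodegen is an invented name, not part of the patch): any non-leaf CodegenFallback in an expression tree disables whole-stage codegen, because such expressions evaluate against a materialized InternalRow that a fused pipeline never builds.

    import org.apache.spark.sql.catalyst.expressions._
    import org.apache.spark.sql.catalyst.expressions.codegen.CodegenFallback

    // Equivalent to the rule above: supported only if no non-leaf
    // CodegenFallback appears anywhere in the tree. Leaves match first,
    // so a leaf that happens to mix in CodegenFallback stays supported.
    def treeSupportsCodegen(root: Expression): Boolean =
      root.find {
        case _: LeafExpression => false   // leaves take no input row
        case _: CodegenFallback => true   // would need an InternalRow
        case _ => false
      }.isEmpty
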
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregate.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregate.scala
index 8dcbab4c8cfbc6ca5c9204eb64ca2a246951d2d8..23e54f344d2526b5e55b5248324dfa6580924d79 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregate.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregate.scala
@@ -21,9 +21,10 @@ import org.apache.spark.rdd.RDD
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.errors._
 import org.apache.spark.sql.catalyst.expressions._
-import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression
+import org.apache.spark.sql.catalyst.expressions.aggregate._
+import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode}
 import org.apache.spark.sql.catalyst.plans.physical._
-import org.apache.spark.sql.execution.{SparkPlan, UnaryNode, UnsafeFixedWidthAggregationMap}
+import org.apache.spark.sql.execution.{CodegenSupport, SparkPlan, UnaryNode, UnsafeFixedWidthAggregationMap}
 import org.apache.spark.sql.execution.metric.SQLMetrics
 import org.apache.spark.sql.types.StructType
 
@@ -35,7 +36,7 @@ case class TungstenAggregate(
     initialInputBufferOffset: Int,
     resultExpressions: Seq[NamedExpression],
     child: SparkPlan)
-  extends UnaryNode {
+  extends UnaryNode with CodegenSupport {
 
   private[this] val aggregateBufferAttributes = {
     aggregateExpressions.flatMap(_.aggregateFunction.aggBufferAttributes)
@@ -113,6 +114,86 @@ case class TungstenAggregate(
     }
   }
 
+  override def supportCodegen: Boolean = {
+    groupingExpressions.isEmpty &&
+      // ImperativeAggregate is not supported right now
+      !aggregateExpressions.exists(_.aggregateFunction.isInstanceOf[ImperativeAggregate]) &&
+      // the final aggregation produces only a single row, so codegen is not needed
+      !aggregateExpressions.exists(e => e.mode == Final || e.mode == Complete)
+  }
+
+  // The variables used as the aggregation buffer
+  private var bufVars: Seq[ExprCode] = _
+
+  private val modes = aggregateExpressions.map(_.mode).distinct
+
+  protected override def doProduce(ctx: CodegenContext): (RDD[InternalRow], String) = {
+    val initAgg = ctx.freshName("initAgg")
+    ctx.addMutableState("boolean", initAgg, s"$initAgg = false;")
+
+    // generate variables for the aggregation buffer
+    val functions = aggregateExpressions.map(_.aggregateFunction.asInstanceOf[DeclarativeAggregate])
+    val initExpr = functions.flatMap(f => f.initialValues)
+    bufVars = initExpr.map { e =>
+      val isNull = ctx.freshName("bufIsNull")
+      val value = ctx.freshName("bufValue")
+      // The initial expression must not reference any input column
+      val ev = e.gen(ctx)
+      val initVars = s"""
+         | boolean $isNull = ${ev.isNull};
+         | ${ctx.javaType(e.dataType)} $value = ${ev.value};
+       """.stripMargin
+      ExprCode(ev.code + initVars, isNull, value)
+    }
+
+    val (rdd, childSource) = child.asInstanceOf[CodegenSupport].produce(ctx, this)
+    val source =
+      s"""
+         | if (!$initAgg) {
+         |   $initAgg = true;
+         |
+         |   // initialize aggregation buffer
+         |   ${bufVars.map(_.code).mkString("\n")}
+         |
+         |   $childSource
+         |
+         |   // output the result
+         |   ${consume(ctx, bufVars)}
+         | }
+       """.stripMargin
+
+    (rdd, source)
+  }
+
+  override def doConsume(ctx: CodegenContext, child: SparkPlan, input: Seq[ExprCode]): String = {
+    // only DeclarativeAggregate functions can reach here (guaranteed by supportCodegen)
+    val functions = aggregateExpressions.map(_.aggregateFunction.asInstanceOf[DeclarativeAggregate])
+    // the mode can only be Partial or PartialMerge here
+    val updateExpr = if (modes.contains(Partial)) {
+      functions.flatMap(_.updateExpressions)
+    } else {
+      functions.flatMap(_.mergeExpressions)
+    }
+
+    val inputAttr = functions.flatMap(_.aggBufferAttributes) ++ child.output
+    val boundExpr = updateExpr.map(e => BindReferences.bindReference(e, inputAttr))
+    ctx.currentVars = bufVars ++ input
+    // TODO: support subexpression elimination
+    val codes = boundExpr.zipWithIndex.map { case (e, i) =>
+      val ev = e.gen(ctx)
+      s"""
+         | ${ev.code}
+         | ${bufVars(i).isNull} = ${ev.isNull};
+         | ${bufVars(i).value} = ${ev.value};
+       """.stripMargin
+    }
+
+    s"""
+       | // do the aggregation and update the aggregation buffer
+       | ${codes.mkString("")}
+     """.stripMargin
+  }
+
   override def simpleString: String = {
     val allAggregateExpressions = aggregateExpressions
 
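
To make the produce/consume contract concrete, here is a hand-written sketch of the Java-like source these two methods stitch together for a partial sum(id) with no grouping keys. All names are invented for illustration; this is the shape, not actual generator output.

    // doProduce wraps everything in a run-once block guarded by initAgg;
    // the loop comes from the child's produce, and the loop body is this
    // operator's doConsume writing back into the buffer variables.
    val generatedShape =
      """
        | if (!initAgg) {
        |   initAgg = true;
        |   // initialize aggregation buffer (sum starts as null)
        |   boolean bufIsNull = true;
        |   long bufValue = -1L;
        |
        |   while (input.hasNext()) {   // childSource, e.g. from a Range scan
        |     long id = ...;            // column variable supplied by the child
        |     // doConsume: evaluate the bound update expression and
        |     // write the result back into the buffer variables
        |     long newValue = (bufIsNull ? 0L : bufValue) + id;
        |     bufIsNull = false;
        |     bufValue = newValue;
        |   }
        |   // consume(ctx, bufVars): emit the single result row downstream
        | }
      """.stripMargin
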
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/BenchmarkWholeStageCodegen.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/BenchmarkWholeStageCodegen.scala
index 788b04fcf8c2e9cb7be22c0c87124f6f739bdd00..c4aad398bfa5466381a8263c8d9c4f5161ef1ab7 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/BenchmarkWholeStageCodegen.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/BenchmarkWholeStageCodegen.scala
@@ -46,10 +46,10 @@ class BenchmarkWholeStageCodegen extends SparkFunSuite {
 
     /*
       Intel(R) Core(TM) i7-4558U CPU @ 2.80GHz
-      Single Int Column Scan:      Avg Time(ms)    Avg Rate(M/s)  Relative Rate
-      -------------------------------------------------------------------------
-      Without whole stage codegen       6725.52            31.18         1.00 X
-      With whole stage codegen          2233.05            93.91         3.01 X
+      Single Int Column Scan:            Avg Time(ms)    Avg Rate(M/s)  Relative Rate
+      -------------------------------------------------------------------------------
+      Without whole stage codegen             7775.53            26.97         1.00 X
+      With whole stage codegen                 342.15           612.94        22.73 X
     */
     benchmark.run()
   }
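
The two rows above can be reproduced by hand by flipping the flag that CollapseCodegenStages honors. A sketch, assuming a local sqlContext and an arbitrary row count; the suite itself drives this through the Benchmark utility:

    // Interpreted path: each operator pulls rows through its own iterator.
    sqlContext.setConf("spark.sql.codegen.wholeStage", "false")
    sqlContext.range(1L << 28).groupBy().sum("id").collect()

    // Fused path: the partial TungstenAggregate is compiled into one loop.
    sqlContext.setConf("spark.sql.codegen.wholeStage", "true")
    sqlContext.range(1L << 28).groupBy().sum("id").collect()
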
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala
index c54fc6ba2de3d8b44b02d4ea25bdbb6c23b3c4bd..300788c88ab2f6059efa0c161e363310c79a4da8 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala
@@ -17,7 +17,10 @@
 
 package org.apache.spark.sql.execution
 
+import org.apache.spark.sql.Row
 import org.apache.spark.sql.catalyst.dsl.expressions._
+import org.apache.spark.sql.execution.aggregate.TungstenAggregate
+import org.apache.spark.sql.functions.{avg, col, max}
 import org.apache.spark.sql.test.SharedSQLContext
 
 class WholeStageCodegenSuite extends SparkPlanTest with SharedSQLContext {
@@ -35,4 +38,13 @@ class WholeStageCodegenSuite extends SparkPlanTest with SharedSQLContext {
       sortAnswers = false
     )
   }
+
+  test("Aggregate should be included in WholeStageCodegen") {
+    val df = sqlContext.range(10).groupBy().agg(max(col("id")), avg(col("id")))
+    val plan = df.queryExecution.executedPlan
+    assert(plan.find(p =>
+      p.isInstanceOf[WholeStageCodegen] &&
+        p.asInstanceOf[WholeStageCodegen].plan.isInstanceOf[TungstenAggregate]).isDefined)
+    assert(df.collect() === Array(Row(9, 4.5)))
+  }
 }
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala
index 4339f7260dcb9d732ac9bf8718ce39c1e018aceb..51285431a47ed03fa8354f0674593381a8d95080 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala
@@ -71,7 +71,9 @@ class SQLMetricsSuite extends SparkFunSuite with SharedSQLContext {
       expectedNumOfJobs: Int,
       expectedMetrics: Map[Long, (String, Map[String, Any])]): Unit = {
     val previousExecutionIds = sqlContext.listener.executionIdToData.keySet
-    df.collect()
+    withSQLConf("spark.sql.codegen.wholeStage" -> "false") {
+      df.collect()
+    }
     sparkContext.listenerBus.waitUntilEmpty(10000)
     val executionIds = sqlContext.listener.executionIdToData.keySet.diff(previousExecutionIds)
     assert(executionIds.size === 1)
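
Why the flag is forced off here: once CollapseCodegenStages fuses operators into a WholeStageCodegen node, the operators this helper matches metrics against are wrapped, so the node-by-node accumulator lookup would no longer line up. If an assertion ever needs to run with codegen enabled, the walk has to unwrap the fused node first; a sketch, assuming df is the DataFrame under test:

    // The wrapped operator (w.plan) is the one carrying the SQLMetrics,
    // not the WholeStageCodegen wrapper itself.
    df.queryExecution.executedPlan.foreach {
      case w: WholeStageCodegen => println(w.plan.simpleString) // fused operator
      case p => println(p.simpleString)
    }
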