diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala
index 437397187356ccb9c067865f6ccafa1bcd82ac28..f10d36862770785b9f20c70b495b81da9e25cdba 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala
@@ -242,6 +242,9 @@ class CodegenContext {
   private val classFunctions: mutable.Map[String, mutable.Map[String, String]] =
     mutable.Map(outerClassName -> mutable.Map.empty[String, String])
 
+  // Verbatim extra code to be added to the outer generated class.
+  private val extraCode: mutable.ListBuffer[String] = mutable.ListBuffer[String]()
+
   // Returns the size of the most recently added class.
   private def currClassSize(): Int = classSize(classes.head._1)
 
@@ -328,6 +331,22 @@ class CodegenContext {
     (inlinedFunctions ++ initNestedClasses ++ declareNestedClasses).mkString("\n")
   }
 
+  /**
+   * Emits any extra source code added with addExtraCode.
+   */
+  def emitExtraCode(): String = {
+    extraCode.mkString("\n")
+  }
+
+  /**
+   * Add extra source code to the outermost generated class.
+   * @param code verbatim source code to be added.
+   */
+  def addExtraCode(code: String): Unit = {
+    extraCode.append(code)
+    classSize(outerClassName) += code.length
+  }
+
   final val JAVA_BOOLEAN = "boolean"
   final val JAVA_BYTE = "byte"
   final val JAVA_SHORT = "short"
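
For context, a minimal sketch of how the two new hooks pair up (the `agg_FastHashMap` class body and the variable names below are illustrative, not part of this patch): an operator registers verbatim members with `addExtraCode`, and whichever template assembles the outer class splices them back in with `emitExtraCode`, next to the functions registered via `addNewFunction`.

```scala
// A minimal sketch of the new hooks; the inner class body and names
// here are illustrative, not part of this patch.
val ctx = new CodegenContext

// An operator attaches a verbatim member (e.g. a generated inner class)
// to the outermost generated class:
ctx.addExtraCode(
  s"""
     |public class agg_FastHashMap {
     |  // ... generated fast hash map implementation ...
     |}
   """.stripMargin)

// Whatever template assembles the outer class splices the extra code
// back in, next to the functions registered via addNewFunction:
val outerClassMembers =
  s"""
     |${ctx.emitExtraCode()}
     |${ctx.declareAddedFunctions()}
   """.stripMargin
```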
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegenExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegenExec.scala
index a41a7ca56a0a1ae1a7fa7a8223e22a1cf0158467..268ccfa4edfa052a64825880f1e81406c61c6544 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegenExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegenExec.scala
@@ -197,11 +197,14 @@ trait CodegenSupport extends SparkPlan {
    *
    * This should be overridden by subclasses to support codegen.
    *
-   * For example, Filter will generate the code like this:
+   * Note: The operator must not assume that an enclosing processing loop exists
+   *       which it can exit with "continue;".
    *
+   * For example, a filter could generate this:
    *   # code to evaluate the predicate expression, result is isNull1 and value2
-   *   if (isNull1 || !value2) continue;
-   *   # call consume(), which will call parent.doConsume()
+   *   if (!isNull1 && value2) {
+   *     # call consume(), which will call parent.doConsume()
+   *   }
    *
    * Note: A plan can either consume the rows as UnsafeRow (row), or a list of variables (input).
    */
@@ -329,6 +332,15 @@ case class WholeStageCodegenExec(child: SparkPlan) extends UnaryExecNode with Co
   def doCodeGen(): (CodegenContext, CodeAndComment) = {
     val ctx = new CodegenContext
     val code = child.asInstanceOf[CodegenSupport].produce(ctx, this)
+
+    // The main processNext() function, inlined into the outer class.
+    ctx.addNewFunction("processNext",
+      s"""
+        protected void processNext() throws java.io.IOException {
+          ${code.trim}
+        }
+       """, inlineToOuterClass = true)
+
     val source = s"""
       public Object generate(Object[] references) {
         return new GeneratedIterator(references);
@@ -352,9 +364,7 @@ case class WholeStageCodegenExec(child: SparkPlan) extends UnaryExecNode with Co
           ${ctx.initPartition()}
         }
 
-        protected void processNext() throws java.io.IOException {
-          ${code.trim}
-        }
+        ${ctx.emitExtraCode()}
 
         ${ctx.declareAddedFunctions()}
       }
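
With `processNext` now registered through `addNewFunction(..., inlineToOuterClass = true)`, the generated source keeps roughly the shape below. This is a sketch only: the real template also emits the references array, mutable state, and init code.

```scala
// Sketch of the source WholeStageCodegenExec now assembles (abbreviated).
def outerClassShape(ctx: CodegenContext): String =
  s"""
     |public Object generate(Object[] references) {
     |  return new GeneratedIterator(references);
     |}
     |
     |final class GeneratedIterator extends org.apache.spark.sql.execution.BufferedRowIterator {
     |  ${ctx.emitExtraCode()}
     |
     |  ${ctx.declareAddedFunctions()}  // now includes processNext(), inlined to the outer class
     |}
   """.stripMargin
```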
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala
index d77405c559c584bc60319e4eefd77f461f87e3d3..abdf9530c6c7b52f966b6a8c6370bb8a8b27f4dd 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala
@@ -425,12 +425,14 @@ case class HashAggregateExec(
 
   /**
    * Generate the code for output.
+   * @return the name of the generated function that outputs the result rows.
    */
-  private def generateResultCode(
-      ctx: CodegenContext,
-      keyTerm: String,
-      bufferTerm: String,
-      plan: String): String = {
+  private def generateResultFunction(ctx: CodegenContext): String = {
+    val funcName = ctx.freshName("doAggregateWithKeysOutput")
+    val keyTerm = ctx.freshName("keyTerm")
+    val bufferTerm = ctx.freshName("bufferTerm")
+
+    val body =
     if (modes.contains(Final) || modes.contains(Complete)) {
       // generate output using resultExpressions
       ctx.currentVars = null
@@ -462,18 +464,36 @@ case class HashAggregateExec(
        $evaluateAggResults
        ${consume(ctx, resultVars)}
        """
-
     } else if (modes.contains(Partial) || modes.contains(PartialMerge)) {
-      // This should be the last operator in a stage, we should output UnsafeRow directly
-      val joinerTerm = ctx.freshName("unsafeRowJoiner")
-      ctx.addMutableState(classOf[UnsafeRowJoiner].getName, joinerTerm,
-        s"$joinerTerm = $plan.createUnsafeJoiner();")
-      val resultRow = ctx.freshName("resultRow")
+      // Here resultExpressions is exactly groupingExpressions ++ aggregateBufferAttributes.
+      assert(resultExpressions.forall(_.isInstanceOf[Attribute]))
+      assert(resultExpressions.length ==
+        groupingExpressions.length + aggregateBufferAttributes.length)
+
+      ctx.currentVars = null
+
+      ctx.INPUT_ROW = keyTerm
+      val keyVars = groupingExpressions.zipWithIndex.map { case (e, i) =>
+        BoundReference(i, e.dataType, e.nullable).genCode(ctx)
+      }
+      val evaluateKeyVars = evaluateVariables(keyVars)
+
+      ctx.INPUT_ROW = bufferTerm
+      val resultBufferVars = aggregateBufferAttributes.zipWithIndex.map { case (e, i) =>
+        BoundReference(i, e.dataType, e.nullable).genCode(ctx)
+      }
+      val evaluateResultBufferVars = evaluateVariables(resultBufferVars)
+
+      ctx.currentVars = keyVars ++ resultBufferVars
+      val inputAttrs = resultExpressions.map(_.toAttribute)
+      val resultVars = resultExpressions.map { e =>
+        BindReferences.bindReference(e, inputAttrs).genCode(ctx)
+      }
       s"""
-       UnsafeRow $resultRow = $joinerTerm.join($keyTerm, $bufferTerm);
-       ${consume(ctx, null, resultRow)}
+       $evaluateKeyVars
+       $evaluateResultBufferVars
+       ${consume(ctx, resultVars)}
        """
-
     } else {
       // generate result based on grouping key
       ctx.INPUT_ROW = keyTerm
@@ -483,6 +503,13 @@ case class HashAggregateExec(
       }
       consume(ctx, eval)
     }
+    ctx.addNewFunction(funcName,
+      s"""
+        private void $funcName(UnsafeRow $keyTerm, UnsafeRow $bufferTerm)
+            throws java.io.IOException {
+          $body
+        }
+       """)
   }
 
   /**
@@ -581,11 +608,6 @@ case class HashAggregateExec(
     val iterTerm = ctx.freshName("mapIter")
     ctx.addMutableState(classOf[KVIterator[UnsafeRow, UnsafeRow]].getName, iterTerm, "")
 
-    val doAgg = ctx.freshName("doAggregateWithKeys")
-    val peakMemory = metricTerm(ctx, "peakMemory")
-    val spillSize = metricTerm(ctx, "spillSize")
-    val avgHashProbe = metricTerm(ctx, "avgHashProbe")
-
     def generateGenerateCode(): String = {
       if (isFastHashMapEnabled) {
         if (isVectorizedHashMapEnabled) {
@@ -599,10 +621,14 @@ case class HashAggregateExec(
         }
       } else ""
     }
+    ctx.addExtraCode(generateGenerateCode())
 
+    val doAgg = ctx.freshName("doAggregateWithKeys")
+    val peakMemory = metricTerm(ctx, "peakMemory")
+    val spillSize = metricTerm(ctx, "spillSize")
+    val avgHashProbe = metricTerm(ctx, "avgHashProbe")
     val doAggFuncName = ctx.addNewFunction(doAgg,
       s"""
-        ${generateGenerateCode}
         private void $doAgg() throws java.io.IOException {
           $hashMapTerm = $thisPlan.createHashMap();
           ${child.asInstanceOf[CodegenSupport].produce(ctx, this)}
@@ -618,7 +644,7 @@ case class HashAggregateExec(
     // generate code for output
     val keyTerm = ctx.freshName("aggKey")
     val bufferTerm = ctx.freshName("aggBuffer")
-    val outputCode = generateResultCode(ctx, keyTerm, bufferTerm, thisPlan)
+    val outputFunc = generateResultFunction(ctx)
     val numOutput = metricTerm(ctx, "numOutputRows")
 
     // The child could change `copyResult` to true, but we had already consumed all the rows,
@@ -641,7 +667,7 @@ case class HashAggregateExec(
          $numOutput.add(1);
          UnsafeRow $keyTerm = (UnsafeRow) $iterTermForFastHashMap.getKey();
          UnsafeRow $bufferTerm = (UnsafeRow) $iterTermForFastHashMap.getValue();
-         $outputCode
+         $outputFunc($keyTerm, $bufferTerm);
 
          if (shouldStop()) return;
        }
@@ -654,18 +680,23 @@ case class HashAggregateExec(
         val row = ctx.freshName("fastHashMapRow")
         ctx.currentVars = null
         ctx.INPUT_ROW = row
-        var schema: StructType = groupingKeySchema
-        bufferSchema.foreach(i => schema = schema.add(i))
-        val generateRow = GenerateUnsafeProjection.createCode(ctx, schema.toAttributes.zipWithIndex
-          .map { case (attr, i) => BoundReference(i, attr.dataType, attr.nullable) })
+        val generateKeyRow = GenerateUnsafeProjection.createCode(ctx,
+          groupingKeySchema.toAttributes.zipWithIndex
+          .map { case (attr, i) => BoundReference(i, attr.dataType, attr.nullable) }
+        )
+        val generateBufferRow = GenerateUnsafeProjection.createCode(ctx,
+          bufferSchema.toAttributes.zipWithIndex
+          .map { case (attr, i) =>
+            BoundReference(groupingKeySchema.length + i, attr.dataType, attr.nullable) })
         s"""
            | while ($iterTermForFastHashMap.hasNext()) {
            |   $numOutput.add(1);
            |   org.apache.spark.sql.execution.vectorized.ColumnarBatch.Row $row =
            |     (org.apache.spark.sql.execution.vectorized.ColumnarBatch.Row)
            |     $iterTermForFastHashMap.next();
-           |   ${generateRow.code}
-           |   ${consume(ctx, Seq.empty, {generateRow.value})}
+           |   ${generateKeyRow.code}
+           |   ${generateBufferRow.code}
+           |   $outputFunc(${generateKeyRow.value}, ${generateBufferRow.value});
            |
            |   if (shouldStop()) return;
            | }
@@ -692,7 +723,7 @@ case class HashAggregateExec(
        $numOutput.add(1);
        UnsafeRow $keyTerm = (UnsafeRow) $iterTerm.getKey();
        UnsafeRow $bufferTerm = (UnsafeRow) $iterTerm.getValue();
-       $outputCode
+       $outputFunc($keyTerm, $bufferTerm);
 
        if (shouldStop()) return;
      }
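
The net effect on the generated source: the per-row output code, previously pasted inline at each of the three call sites above, is emitted once as a private method that each iteration loop calls. A sketch of the shape (fresh names carry numeric suffixes in real output, and the `mapIter` loop header is abbreviated):

```scala
// Shape of the function generateResultFunction registers, and of one call site.
val registeredFunction =
  """
    |private void doAggregateWithKeysOutput(UnsafeRow aggKey, UnsafeRow aggBuffer)
    |    throws java.io.IOException {
    |  // evaluate grouping keys and buffer columns, then hand the row to consume()
    |}
  """.stripMargin

val callSiteShape =
  """
    |while (mapIter.next()) {
    |  numOutputRows.add(1);
    |  UnsafeRow aggKey = (UnsafeRow) mapIter.getKey();
    |  UnsafeRow aggBuffer = (UnsafeRow) mapIter.getValue();
    |  doAggregateWithKeysOutput(aggKey, aggBuffer);
    |  if (shouldStop()) return;
    |}
  """.stripMargin
```

Emitting the body once also keeps each surrounding loop method small, since the same output code no longer grows with the number of call sites.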
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala
index e4e9372447f7cde86b2f9d472e04f1afda51a346..18142c44f0295a5af801447174d1498b99602067 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala
@@ -201,11 +201,14 @@ case class FilterExec(condition: Expression, child: SparkPlan)
       ev
     }
 
+    // Note: wrap in "do { } while(false);" so the generated predicate checks can jump out with "continue;".
     s"""
-       |$generated
-       |$nullChecks
-       |$numOutput.add(1);
-       |${consume(ctx, resultVars)}
+       |do {
+       |  $generated
+       |  $nullChecks
+       |  $numOutput.add(1);
+       |  ${consume(ctx, resultVars)}
+       |} while(false);
      """.stripMargin
   }
 
@@ -316,9 +319,10 @@ case class SampleExec(
          """.stripMargin.trim)
 
       s"""
-         | if ($sampler.sample() == 0) continue;
-         | $numOutput.add(1);
-         | ${consume(ctx, input)}
+         | if ($sampler.sample() != 0) {
+         |   $numOutput.add(1);
+         |   ${consume(ctx, input)}
+         | }
        """.stripMargin.trim
     }
   }
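
Why the `do { } while(false);` wrapper is sufficient: null checks emitted inside the block may still say `continue;`, and that now exits only this innermost loop instead of skipping an iteration of whatever loop the caller happens to be in. A sketch of the shape FilterExec emits, reusing the illustrative `isNull1`/`value2` names from the CodegenSupport scaladoc:

```scala
// Shape of the code FilterExec now emits; isNull1 / value2 are illustrative.
val filterShape =
  """
    |do {
    |  // evaluate the predicate into isNull1 / value2
    |  if (isNull1 || !value2) continue;  // exits the do-while only
    |  numOutputRows.add(1);
    |  // consume(): the parent operator's generated code is appended here
    |} while(false);
  """.stripMargin
```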
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoinExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoinExec.scala
index ab7bb8ab9d87ad7a33c483c1e9dda68712484c6b..b09da9bdacb99b2d7d7ec51a3781cb200bfe27da 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoinExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoinExec.scala
@@ -186,8 +186,7 @@ case class BroadcastHashJoinExec(
    */
   private def getJoinCondition(
       ctx: CodegenContext,
-      input: Seq[ExprCode],
-      anti: Boolean = false): (String, String, Seq[ExprCode]) = {
+      input: Seq[ExprCode]): (String, String, Seq[ExprCode]) = {
     val matched = ctx.freshName("matched")
     val buildVars = genBuildSideVars(ctx, matched)
     val checkCondition = if (condition.isDefined) {
@@ -198,18 +197,12 @@ case class BroadcastHashJoinExec(
       ctx.currentVars = input ++ buildVars
       val ev =
         BindReferences.bindReference(expr, streamedPlan.output ++ buildPlan.output).genCode(ctx)
-      val skipRow = if (!anti) {
-        s"${ev.isNull} || !${ev.value}"
-      } else {
-        s"!${ev.isNull} && ${ev.value}"
-      }
+      val skipRow = s"${ev.isNull} || !${ev.value}"
       s"""
          |$eval
          |${ev.code}
-         |if ($skipRow) continue;
+         |if (!($skipRow))
        """.stripMargin
-    } else if (anti) {
-      "continue;"
     } else {
       ""
     }
@@ -235,10 +228,12 @@ case class BroadcastHashJoinExec(
          |${keyEv.code}
          |// find matches from HashedRelation
          |UnsafeRow $matched = $anyNull ? null: (UnsafeRow)$relationTerm.getValue(${keyEv.value});
-         |if ($matched == null) continue;
-         |$checkCondition
-         |$numOutput.add(1);
-         |${consume(ctx, resultVars)}
+         |if ($matched != null) {
+         |  $checkCondition {
+         |    $numOutput.add(1);
+         |    ${consume(ctx, resultVars)}
+         |  }
+         |}
        """.stripMargin
 
     } else {
@@ -250,12 +245,14 @@ case class BroadcastHashJoinExec(
          |${keyEv.code}
          |// find matches from HashRelation
          |$iteratorCls $matches = $anyNull ? null : ($iteratorCls)$relationTerm.get(${keyEv.value});
-         |if ($matches == null) continue;
-         |while ($matches.hasNext()) {
-         |  UnsafeRow $matched = (UnsafeRow) $matches.next();
-         |  $checkCondition
-         |  $numOutput.add(1);
-         |  ${consume(ctx, resultVars)}
+         |if ($matches != null) {
+         |  while ($matches.hasNext()) {
+         |    UnsafeRow $matched = (UnsafeRow) $matches.next();
+         |    $checkCondition {
+         |      $numOutput.add(1);
+         |      ${consume(ctx, resultVars)}
+         |    }
+         |  }
          |}
        """.stripMargin
     }
@@ -328,10 +325,11 @@ case class BroadcastHashJoinExec(
          |  UnsafeRow $matched = $matches != null && $matches.hasNext() ?
          |    (UnsafeRow) $matches.next() : null;
          |  ${checkCondition.trim}
-         |  if (!$conditionPassed) continue;
-         |  $found = true;
-         |  $numOutput.add(1);
-         |  ${consume(ctx, resultVars)}
+         |  if ($conditionPassed) {
+         |    $found = true;
+         |    $numOutput.add(1);
+         |    ${consume(ctx, resultVars)}
+         |  }
          |}
        """.stripMargin
     }
@@ -351,10 +349,12 @@ case class BroadcastHashJoinExec(
          |${keyEv.code}
          |// find matches from HashedRelation
          |UnsafeRow $matched = $anyNull ? null: (UnsafeRow)$relationTerm.getValue(${keyEv.value});
-         |if ($matched == null) continue;
-         |$checkCondition
-         |$numOutput.add(1);
-         |${consume(ctx, input)}
+         |if ($matched != null) {
+         |  $checkCondition {
+         |    $numOutput.add(1);
+         |    ${consume(ctx, input)}
+         |  }
+         |}
        """.stripMargin
     } else {
       val matches = ctx.freshName("matches")
@@ -365,16 +365,19 @@ case class BroadcastHashJoinExec(
          |${keyEv.code}
          |// find matches from HashRelation
          |$iteratorCls $matches = $anyNull ? null : ($iteratorCls)$relationTerm.get(${keyEv.value});
-         |if ($matches == null) continue;
-         |boolean $found = false;
-         |while (!$found && $matches.hasNext()) {
-         |  UnsafeRow $matched = (UnsafeRow) $matches.next();
-         |  $checkCondition
-         |  $found = true;
+         |if ($matches != null) {
+         |  boolean $found = false;
+         |  while (!$found && $matches.hasNext()) {
+         |    UnsafeRow $matched = (UnsafeRow) $matches.next();
+         |    $checkCondition {
+         |      $found = true;
+         |    }
+         |  }
+         |  if ($found) {
+         |    $numOutput.add(1);
+         |    ${consume(ctx, input)}
+         |  }
          |}
-         |if (!$found) continue;
-         |$numOutput.add(1);
-         |${consume(ctx, input)}
        """.stripMargin
     }
   }
@@ -386,11 +389,13 @@ case class BroadcastHashJoinExec(
     val (broadcastRelation, relationTerm) = prepareBroadcast(ctx)
     val uniqueKeyCodePath = broadcastRelation.value.keyIsUnique
     val (keyEv, anyNull) = genStreamSideJoinKey(ctx, input)
-    val (matched, checkCondition, _) = getJoinCondition(ctx, input, uniqueKeyCodePath)
+    val (matched, checkCondition, _) = getJoinCondition(ctx, input)
     val numOutput = metricTerm(ctx, "numOutputRows")
 
     if (uniqueKeyCodePath) {
+      val found = ctx.freshName("found")
       s"""
+         |boolean $found = false;
          |// generate join key for stream side
          |${keyEv.code}
          |// Check if the key has nulls.
@@ -399,17 +404,22 @@ case class BroadcastHashJoinExec(
          |  UnsafeRow $matched = (UnsafeRow)$relationTerm.getValue(${keyEv.value});
          |  if ($matched != null) {
          |    // Evaluate the condition.
-         |    $checkCondition
+         |    $checkCondition {
+         |      $found = true;
+         |    }
          |  }
          |}
-         |$numOutput.add(1);
-         |${consume(ctx, input)}
+         |if (!$found) {
+         |  $numOutput.add(1);
+         |  ${consume(ctx, input)}
+         |}
        """.stripMargin
     } else {
       val matches = ctx.freshName("matches")
       val iteratorCls = classOf[Iterator[UnsafeRow]].getName
       val found = ctx.freshName("found")
       s"""
+         |boolean $found = false;
          |// generate join key for stream side
          |${keyEv.code}
          |// Check if the key has nulls.
@@ -418,17 +428,18 @@ case class BroadcastHashJoinExec(
          |  $iteratorCls $matches = ($iteratorCls)$relationTerm.get(${keyEv.value});
          |  if ($matches != null) {
          |    // Evaluate the condition.
-         |    boolean $found = false;
          |    while (!$found && $matches.hasNext()) {
          |      UnsafeRow $matched = (UnsafeRow) $matches.next();
-         |      $checkCondition
-         |      $found = true;
+         |      $checkCondition {
+         |        $found = true;
+         |      }
          |    }
-         |    if ($found) continue;
          |  }
          |}
-         |$numOutput.add(1);
-         |${consume(ctx, input)}
+         |if (!$found) {
+         |  $numOutput.add(1);
+         |  ${consume(ctx, input)}
+         |}
        """.stripMargin
     }
   }
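
A note on the recurring `$checkCondition { ... }` idiom above: when a join condition exists, `checkCondition` now ends with a dangling `if (!(...))` and every caller supplies the braced block, so the guarded code runs only when the condition passes and no `continue;` is needed; when there is no condition, `checkCondition` is the empty string and the braces degenerate to a plain Java block, which is still valid. A sketch of the contract (with `isNull3`/`value3` standing in for the fresh names of the evaluated condition):

```scala
// The snippet getJoinCondition produces when a condition is present:
val checkCondition =
  """
    |// evaluate the join condition into isNull3 / value3
    |if (!(isNull3 || !value3))
  """.stripMargin.trim

// Every caller closes the dangling "if" with a braced block:
val callerShape =
  s"""
     |$checkCondition {
     |  numOutputRows.add(1);
     |  // consume(...)
     |}
   """.stripMargin
```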