From 0e70fd61b4bc92bd744fc44dd3cbe91443207c72 Mon Sep 17 00:00:00 2001
From: Davies Liu <davies@databricks.com>
Date: Fri, 20 May 2016 13:21:53 -0700
Subject: [PATCH] [SPARK-15438][SQL] improve explain of whole stage codegen

## What changes were proposed in this pull request?

Currently, the explain of a query with whole-stage codegen looks like this
```
>>> df = sqlCtx.range(1000);df2 = sqlCtx.range(1000);df.join(pyspark.sql.functions.broadcast(df2), 'id').explain()
== Physical Plan ==
WholeStageCodegen
:  +- Project [id#1L]
:     +- BroadcastHashJoin [id#1L], [id#4L], Inner, BuildRight, None
:        :- Range 0, 1, 4, 1000, [id#1L]
:        +- INPUT
+- BroadcastExchange HashedRelationBroadcastMode(List(input[0, bigint]))
   +- WholeStageCodegen
      :  +- Range 0, 1, 4, 1000, [id#4L]
```

The problem is that the plan looks much different than logical plan, make us hard to understand the plan (especially when the logical plan is not showed together).

This PR will change it to:

```
>>> df = sqlCtx.range(1000);df2 = sqlCtx.range(1000);df.join(pyspark.sql.functions.broadcast(df2), 'id').explain()
== Physical Plan ==
*Project [id#0L]
+- *BroadcastHashJoin [id#0L], [id#3L], Inner, BuildRight, None
   :- *Range 0, 1, 4, 1000, [id#0L]
   +- BroadcastExchange HashedRelationBroadcastMode(List(input[0, bigint, false]))
      +- *Range 0, 1, 4, 1000, [id#3L]
```

The `*`before the plan means that it's part of whole-stage codegen, it's easy to understand.

## How was this patch tested?

Manually ran some queries and check the explain.

Author: Davies Liu <davies@databricks.com>

Closes #13204 from davies/explain_codegen.
---
 .../spark/sql/catalyst/trees/TreeNode.scala   | 57 +++----------------
 .../sql/execution/WholeStageCodegenExec.scala | 29 +++++-----
 .../sql/execution/exchange/Exchange.scala     |  3 -
 3 files changed, 22 insertions(+), 67 deletions(-)

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala
index 5eb8fdf048..e8e2a7bbab 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala
@@ -467,50 +467,7 @@ abstract class TreeNode[BaseType <: TreeNode[BaseType]] extends Product {
   }
 
   /**
-   * All the nodes that will be used to generate tree string.
-   *
-   * For example:
-   *
-   *   WholeStageCodegen
-   *   +-- SortMergeJoin
-   *       |-- InputAdapter
-   *       |   +-- Sort
-   *       +-- InputAdapter
-   *           +-- Sort
-   *
-   * the treeChildren of WholeStageCodegen will be Seq(Sort, Sort), it will generate a tree string
-   * like this:
-   *
-   *   WholeStageCodegen
-   *   : +- SortMergeJoin
-   *   :    :- INPUT
-   *   :    :- INPUT
-   *   :-  Sort
-   *   :-  Sort
-   */
-  protected def treeChildren: Seq[BaseType] = children
-
-  /**
-   * All the nodes that are parts of this node.
-   *
-   * For example:
-   *
-   *   WholeStageCodegen
-   *   +- SortMergeJoin
-   *      |-- InputAdapter
-   *      |   +-- Sort
-   *      +-- InputAdapter
-   *          +-- Sort
-   *
-   * the innerChildren of WholeStageCodegen will be Seq(SortMergeJoin), it will generate a tree
-   * string like this:
-   *
-   *   WholeStageCodegen
-   *   : +- SortMergeJoin
-   *   :    :- INPUT
-   *   :    :- INPUT
-   *   :-  Sort
-   *   :-  Sort
+   * All the nodes that are parts of this node, this is used by subquries.
    */
   protected def innerChildren: Seq[BaseType] = Nil
 
@@ -522,7 +479,10 @@ abstract class TreeNode[BaseType <: TreeNode[BaseType]] extends Product {
    * `lastChildren` for the root node should be empty.
    */
   def generateTreeString(
-      depth: Int, lastChildren: Seq[Boolean], builder: StringBuilder): StringBuilder = {
+      depth: Int,
+      lastChildren: Seq[Boolean],
+      builder: StringBuilder,
+      prefix: String = ""): StringBuilder = {
     if (depth > 0) {
       lastChildren.init.foreach { isLast =>
         val prefixFragment = if (isLast) "   " else ":  "
@@ -533,6 +493,7 @@ abstract class TreeNode[BaseType <: TreeNode[BaseType]] extends Product {
       builder.append(branch)
     }
 
+    builder.append(prefix)
     builder.append(simpleString)
     builder.append("\n")
 
@@ -542,9 +503,9 @@ abstract class TreeNode[BaseType <: TreeNode[BaseType]] extends Product {
       innerChildren.last.generateTreeString(depth + 2, lastChildren :+ false :+ true, builder)
     }
 
-    if (treeChildren.nonEmpty) {
-      treeChildren.init.foreach(_.generateTreeString(depth + 1, lastChildren :+ false, builder))
-      treeChildren.last.generateTreeString(depth + 1, lastChildren :+ true, builder)
+    if (children.nonEmpty) {
+      children.init.foreach(_.generateTreeString(depth + 1, lastChildren :+ false, builder, prefix))
+      children.last.generateTreeString(depth + 1, lastChildren :+ true, builder, prefix)
     }
 
     builder
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegenExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegenExec.scala
index 37fdc362b5..2a1ce735b7 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegenExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegenExec.scala
@@ -245,9 +245,13 @@ case class InputAdapter(child: SparkPlan) extends UnaryExecNode with CodegenSupp
      """.stripMargin
   }
 
-  override def simpleString: String = "INPUT"
-
-  override def treeChildren: Seq[SparkPlan] = Nil
+  override def generateTreeString(
+      depth: Int,
+      lastChildren: Seq[Boolean],
+      builder: StringBuilder,
+      prefix: String = ""): StringBuilder = {
+    child.generateTreeString(depth, lastChildren, builder, "")
+  }
 }
 
 object WholeStageCodegenExec {
@@ -398,20 +402,13 @@ case class WholeStageCodegenExec(child: SparkPlan) extends UnaryExecNode with Co
      """.stripMargin.trim
   }
 
-  override def innerChildren: Seq[SparkPlan] = {
-    child :: Nil
-  }
-
-  private def collectInputs(plan: SparkPlan): Seq[SparkPlan] = plan match {
-    case InputAdapter(c) => c :: Nil
-    case other => other.children.flatMap(collectInputs)
+  override def generateTreeString(
+      depth: Int,
+      lastChildren: Seq[Boolean],
+      builder: StringBuilder,
+      prefix: String = ""): StringBuilder = {
+    child.generateTreeString(depth, lastChildren, builder, "*")
   }
-
-  override def treeChildren: Seq[SparkPlan] = {
-    collectInputs(child)
-  }
-
-  override def simpleString: String = "WholeStageCodegen"
 }
 
 
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/Exchange.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/Exchange.scala
index 9da9df6174..9a9597d373 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/Exchange.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/Exchange.scala
@@ -60,9 +60,6 @@ case class ReusedExchangeExec(override val output: Seq[Attribute], child: Exchan
   override protected[sql] def doExecuteBroadcast[T](): broadcast.Broadcast[T] = {
     child.executeBroadcast()
   }
-
-  // Do not repeat the same tree in explain.
-  override def treeChildren: Seq[SparkPlan] = Nil
 }
 
 /**
-- 
GitLab