Skip to content
Snippets Groups Projects
Commit 87ab0cec authored by Xianyang Liu's avatar Xianyang Liu Committed by Wenchen Fan
Browse files

[SPARK-21072][SQL] TreeNode.mapChildren should only apply to the children node.

## What changes were proposed in this pull request?

Just as the function name and comments of `TreeNode.mapChildren` mentioned, the function should be apply to all currently node children. So, the follow code should judge whether it is the children node.

https://github.com/apache/spark/blob/master/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala#L342

## How was this patch tested?

Existing tests.

Author: Xianyang Liu <xianyang.liu@intel.com>

Closes #18284 from ConeyLiu/treenode.
parent 5d35d5c1
No related branches found
No related tags found
No related merge requests found
...@@ -340,8 +340,18 @@ abstract class TreeNode[BaseType <: TreeNode[BaseType]] extends Product { ...@@ -340,8 +340,18 @@ abstract class TreeNode[BaseType <: TreeNode[BaseType]] extends Product {
arg arg
} }
case tuple@(arg1: TreeNode[_], arg2: TreeNode[_]) => case tuple@(arg1: TreeNode[_], arg2: TreeNode[_]) =>
val newChild1 = f(arg1.asInstanceOf[BaseType]) val newChild1 = if (containsChild(arg1)) {
val newChild2 = f(arg2.asInstanceOf[BaseType]) f(arg1.asInstanceOf[BaseType])
} else {
arg1.asInstanceOf[BaseType]
}
val newChild2 = if (containsChild(arg2)) {
f(arg2.asInstanceOf[BaseType])
} else {
arg2.asInstanceOf[BaseType]
}
if (!(newChild1 fastEquals arg1) || !(newChild2 fastEquals arg2)) { if (!(newChild1 fastEquals arg1) || !(newChild2 fastEquals arg2)) {
changed = true changed = true
(newChild1, newChild2) (newChild1, newChild2)
......
...@@ -54,13 +54,21 @@ case class ComplexPlan(exprs: Seq[Seq[Expression]]) ...@@ -54,13 +54,21 @@ case class ComplexPlan(exprs: Seq[Seq[Expression]])
override def output: Seq[Attribute] = Nil override def output: Seq[Attribute] = Nil
} }
case class ExpressionInMap(map: Map[String, Expression]) extends Expression with Unevaluable { case class ExpressionInMap(map: Map[String, Expression]) extends Unevaluable {
override def children: Seq[Expression] = map.values.toSeq override def children: Seq[Expression] = map.values.toSeq
override def nullable: Boolean = true override def nullable: Boolean = true
override def dataType: NullType = NullType override def dataType: NullType = NullType
override lazy val resolved = true override lazy val resolved = true
} }
case class SeqTupleExpression(sons: Seq[(Expression, Expression)],
nonSons: Seq[(Expression, Expression)]) extends Unevaluable {
override def children: Seq[Expression] = sons.flatMap(t => Iterator(t._1, t._2))
override def nullable: Boolean = true
override def dataType: NullType = NullType
override lazy val resolved = true
}
case class JsonTestTreeNode(arg: Any) extends LeafNode { case class JsonTestTreeNode(arg: Any) extends LeafNode {
override def output: Seq[Attribute] = Seq.empty[Attribute] override def output: Seq[Attribute] = Seq.empty[Attribute]
} }
...@@ -146,6 +154,17 @@ class TreeNodeSuite extends SparkFunSuite { ...@@ -146,6 +154,17 @@ class TreeNodeSuite extends SparkFunSuite {
assert(actual === Dummy(None)) assert(actual === Dummy(None))
} }
test("mapChildren should only works on children") {
val children = Seq((Literal(1), Literal(2)))
val nonChildren = Seq((Literal(3), Literal(4)))
val before = SeqTupleExpression(children, nonChildren)
val toZero: PartialFunction[Expression, Expression] = { case Literal(_, _) => Literal(0) }
val expect = SeqTupleExpression(Seq((Literal(0), Literal(0))), nonChildren)
val actual = before mapChildren toZero
assert(actual === expect)
}
test("preserves origin") { test("preserves origin") {
CurrentOrigin.setPosition(1, 1) CurrentOrigin.setPosition(1, 1)
val add = Add(Literal(1), Literal(1)) val add = Add(Literal(1), Literal(1))
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment