diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala
index df66f9a082aee6c2d2496f7a2776881ccba27be5..7375a0bcbae75594122ce005e49239cf8b2dc885 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala
@@ -340,8 +340,18 @@ abstract class TreeNode[BaseType <: TreeNode[BaseType]] extends Product {
               arg
             }
           case tuple@(arg1: TreeNode[_], arg2: TreeNode[_]) =>
-            val newChild1 = f(arg1.asInstanceOf[BaseType])
-            val newChild2 = f(arg2.asInstanceOf[BaseType])
+            val newChild1 = if (containsChild(arg1)) {
+              f(arg1.asInstanceOf[BaseType])
+            } else {
+              arg1.asInstanceOf[BaseType]
+            }
+
+            val newChild2 = if (containsChild(arg2)) {
+              f(arg2.asInstanceOf[BaseType])
+            } else {
+              arg2.asInstanceOf[BaseType]
+            }
+
             if (!(newChild1 fastEquals arg1) || !(newChild2 fastEquals arg2)) {
               changed = true
               (newChild1, newChild2)
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/trees/TreeNodeSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/trees/TreeNodeSuite.scala
index 712841835acd5199dd99cbfe3c8475869c7f7916..819078218c546039ad3b86a059f5c53ea3782500 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/trees/TreeNodeSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/trees/TreeNodeSuite.scala
@@ -54,13 +54,21 @@ case class ComplexPlan(exprs: Seq[Seq[Expression]])
   override def output: Seq[Attribute] = Nil
 }
 
-case class ExpressionInMap(map: Map[String, Expression]) extends Expression with Unevaluable {
+case class ExpressionInMap(map: Map[String, Expression]) extends Unevaluable {
   override def children: Seq[Expression] = map.values.toSeq
   override def nullable: Boolean = true
   override def dataType: NullType = NullType
   override lazy val resolved = true
 }
 
+case class SeqTupleExpression(sons: Seq[(Expression, Expression)],
+    nonSons: Seq[(Expression, Expression)]) extends Unevaluable {
+  override def children: Seq[Expression] = sons.flatMap(t => Iterator(t._1, t._2))
+  override def nullable: Boolean = true
+  override def dataType: NullType = NullType
+  override lazy val resolved = true
+}
+
 case class JsonTestTreeNode(arg: Any) extends LeafNode {
   override def output: Seq[Attribute] = Seq.empty[Attribute]
 }
@@ -146,6 +154,17 @@ class TreeNodeSuite extends SparkFunSuite {
     assert(actual === Dummy(None))
   }
 
+  test("mapChildren should only works on children") {
+    val children = Seq((Literal(1), Literal(2)))
+    val nonChildren = Seq((Literal(3), Literal(4)))
+    val before = SeqTupleExpression(children, nonChildren)
+    val toZero: PartialFunction[Expression, Expression] = { case Literal(_, _) => Literal(0) }
+    val expect = SeqTupleExpression(Seq((Literal(0), Literal(0))), nonChildren)
+
+    val actual = before mapChildren toZero
+    assert(actual === expect)
+  }
+
   test("preserves origin") {
     CurrentOrigin.setPosition(1, 1)
     val add = Add(Literal(1), Literal(1))