diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala
index 90923fe31a0637c521d4dadc31760542b2176ac4..f0fd9a8b9a46e8fbe4d058c046017ac87270ff33 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala
@@ -19,8 +19,9 @@ package org.apache.spark.sql.catalyst.planning
 
 import scala.annotation.tailrec
 
-import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.Logging
+import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.plans._
 import org.apache.spark.sql.catalyst.plans.logical._
+import org.apache.spark.sql.catalyst.trees.TreeNodeRef
 
@@ -134,8 +135,8 @@ object PartialAggregation {
       // Only do partial aggregation if supported by all aggregate expressions.
       if (allAggregates.size == partialAggregates.size) {
         // Create a map of expressions to their partial evaluations for all aggregate expressions.
-        val partialEvaluations: Map[Long, SplitEvaluation] =
-          partialAggregates.map(a => (a.id, a.asPartial)).toMap
+        val partialEvaluations: Map[TreeNodeRef, SplitEvaluation] =
+          partialAggregates.map(a => (new TreeNodeRef(a), a.asPartial)).toMap
 
         // We need to pass all grouping expressions through so the grouping can happen a second
         // time. However, some of them might be unnamed so we alias them allowing them to be
@@ -148,8 +149,8 @@ object PartialAggregation {
         // Replace aggregations with a new expression that computes the result from the already
         // computed partial evaluations and grouping values.
         val rewrittenAggregateExpressions = aggregateExpressions.map(_.transformUp {
-          case e: Expression if partialEvaluations.contains(e.id) =>
-            partialEvaluations(e.id).finalEvaluation
+          case e: Expression if partialEvaluations.contains(new TreeNodeRef(e)) =>
+            partialEvaluations(new TreeNodeRef(e)).finalEvaluation
           case e: Expression if namedGroupingExpressions.contains(e) =>
             namedGroupingExpressions(e).toAttribute
         }).asInstanceOf[Seq[NamedExpression]]
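
The hunk above replaces the id-keyed `Map[Long, SplitEvaluation]` with an instance-keyed map. A standalone sketch of why instance keys matter (plain Scala, no Spark dependency; `Ref`, `Agg`, and the `partialSum` strings are illustrative, not part of this patch): under case-class equality two identical aggregates collapse into one map entry, while an identity wrapper keeps them distinct, which is the guarantee the removed `id` field used to provide.

    // Minimal identity wrapper, analogous to TreeNodeRef.
    final class Ref(val obj: AnyRef) {
      override def equals(o: Any): Boolean = o match {
        case r: Ref => r.obj eq obj
        case _      => false
      }
      // As in the patch: structurally equal nodes may share a hash bucket,
      // but equals still separates their instances.
      override def hashCode: Int = if (obj == null) 0 else obj.hashCode
    }

    case class Agg(fn: String, arg: String)

    object SubstitutionDemo extends App {
      val sum1 = Agg("SUM", "x")
      val sum2 = Agg("SUM", "x") // structurally equal, but a distinct node

      // A value-keyed Map[Agg, String] would collapse these into one entry.
      val partials = Map(new Ref(sum1) -> "partialSum#1", new Ref(sum2) -> "partialSum#2")

      println(partials(new Ref(sum1))) // partialSum#1
      println(partials(new Ref(sum2))) // partialSum#2
    }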
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala
index 1e177e28f80b31d149f1935c03c80ea7ce100a29..af9e4d86e995a2d5f80b7f21c4665cea6ee3de35 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala
@@ -50,11 +50,11 @@ abstract class QueryPlan[PlanType <: TreeNode[PlanType]] extends TreeNode[PlanTy
 
     @inline def transformExpressionDown(e: Expression) = {
       val newE = e.transformDown(rule)
-      if (newE.id != e.id && newE != e) {
+      if (newE.fastEquals(e)) {
+        e
+      } else {
         changed = true
         newE
-      } else {
-        e
       }
     }
 
@@ -82,11 +82,11 @@ abstract class QueryPlan[PlanType <: TreeNode[PlanType]] extends TreeNode[PlanTy
 
     @inline def transformExpressionUp(e: Expression) = {
       val newE = e.transformUp(rule)
-      if (newE.id != e.id && newE != e) {
+      if (newE.fastEquals(e)) {
+        e
+      } else {
         changed = true
         newE
-      } else {
-        e
       }
     }
 
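
Both helpers above rely on an invariant of Catalyst transforms: `transformDown`/`transformUp` return the original instance when the rule fires nowhere, so `fastEquals` can report "no change" through its cheap `eq` branch before ever comparing structurally. A self-contained sketch of that interplay (illustrative names, not Spark code):

    // Identity-first equality plus an instance-preserving transform.
    case class Expr(op: String, children: List[Expr] = Nil) {
      def fastEquals(other: Expr): Boolean = (this eq other) || this == other

      def transformDown(rule: PartialFunction[Expr, Expr]): Expr = {
        val afterRule = rule.applyOrElse(this, identity[Expr])
        val newChildren = afterRule.children.map(_.transformDown(rule))
        // Keep the existing node when no child changed, as Catalyst does.
        if (newChildren.zip(afterRule.children).forall { case (n, o) => n eq o }) afterRule
        else afterRule.copy(children = newChildren)
      }
    }

    object ChangeDetectionDemo extends App {
      val tree = Expr("+", List(Expr("a"), Expr("b")))
      val result = tree.transformDown { case e @ Expr("c", _) => e } // matches nothing
      println(result eq tree)          // true: no copy was made
      println(result.fastEquals(tree)) // true, via the eq short-circuit
    }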
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala
index 96ce35939e2cc966260615ea805274688053dad6..2013ae4f7bd13d535b1a0c3d33271c3208ec95bd 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala
@@ -19,11 +19,6 @@ package org.apache.spark.sql.catalyst.trees
 
 import org.apache.spark.sql.catalyst.errors._
 
-object TreeNode {
-  private val currentId = new java.util.concurrent.atomic.AtomicLong
-  protected def nextId() = currentId.getAndIncrement()
-}
-
 /** Used by [[TreeNode.getNodeNumbered]] when traversing the tree for a given number */
 private class MutableInt(var i: Int)
 
@@ -33,29 +28,13 @@ abstract class TreeNode[BaseType <: TreeNode[BaseType]] {
   /** Returns a Seq of the children of this node */
   def children: Seq[BaseType]
 
-  /**
-   * A globally unique id for this specific instance. Not preserved across copies.
-   * Unlike `equals`, `id` can be used to differentiate distinct but structurally
-   * identical branches of a tree.
-   */
-  val id = TreeNode.nextId()
-
-  /**
-   * Returns true if other is the same [[catalyst.trees.TreeNode TreeNode]] instance.  Unlike
-   * `equals` this function will return false for different instances of structurally identical
-   * trees.
-   */
-  def sameInstance(other: TreeNode[_]): Boolean = {
-    this.id == other.id
-  }
-
   /**
    * Faster version of equality which short-circuits when two treeNodes are the same instance.
    * We don't just override Object.equals, as doing so prevents the Scala compiler from
    * generating case class `equals` methods.
    */
   def fastEquals(other: TreeNode[_]): Boolean = {
-    sameInstance(other) || this == other
+    this.eq(other) || this == other
   }
 
   /**
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/package.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/package.scala
index d725a92c06f7bf38c02443f87f84b35d3833450a..79a8e06d4b4d4a4be03e9a47343b92548e4b68e8 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/package.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/package.scala
@@ -37,4 +37,15 @@ package object trees extends Logging {
   // Since we want tree nodes to be lightweight, we create one logger for all treenode instances.
   protected override def logName = "catalyst.trees"
 
+  /**
+   * Wraps a [[TreeNode]] to provide reference equality for hash-based collections.
+   */
+  class TreeNodeRef(val obj: TreeNode[_]) {
+    override def equals(o: Any) = o match {
+      case that: TreeNodeRef => that.obj.eq(obj)
+      case _ => false
+    }
+
+    override def hashCode = if (obj == null) 0 else obj.hashCode
+  }
 }
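
The wrapper deliberately leaves `TreeNode` itself untouched: as the `fastEquals` comment in TreeNode.scala notes, overriding `equals` would stop the compiler from generating case class equality. A minimal REPL-style sketch of the semantics, assuming spark-catalyst from this tree is on the classpath:

    import org.apache.spark.sql.catalyst.expressions.Literal
    import org.apache.spark.sql.catalyst.trees.TreeNodeRef

    val a = Literal(1)
    val b = Literal(1)                        // structurally equal, distinct node

    a == b                                    // true: case class equality
    new TreeNodeRef(a) == new TreeNodeRef(b)  // false: different instances
    new TreeNodeRef(a) == new TreeNodeRef(a)  // true: same instance

    val visited = collection.mutable.HashSet(new TreeNodeRef(a))
    visited.contains(new TreeNodeRef(b))      // false
    visited.contains(new TreeNodeRef(a))      // true

Since `hashCode` delegates to the wrapped node, structurally equal instances land in the same bucket and `equals` then separates them by identity; reference-equal wrappers always hash alike, so the hash contract holds.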
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/trees/TreeNodeSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/trees/TreeNodeSuite.scala
index 296202543e2ca97d98bd8e947079c6f0f266e381..036fd3fa1d6a193cc665372c6a95da429198d99c 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/trees/TreeNodeSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/trees/TreeNodeSuite.scala
@@ -51,7 +51,10 @@ class TreeNodeSuite extends FunSuite {
     val after = before transform { case Literal(5, _) => Literal(1)}
 
     assert(before === after)
-    assert(before.map(_.id) === after.map(_.id))
+    // Ensure that the objects after the transformation are the same instances as before it.
+    before.map(identity[Expression]).zip(after.map(identity[Expression])).foreach {
+      case (b, a) => assert(b eq a)
+    }
   }
 
   test("collect") {
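
For reference, the pairwise check in the updated test can be collapsed into a single assertion; a sketch assuming the same `before`/`after` values as the test above:

    // Every node of `after` must be the very instance at the same position
    // in `before`; TreeNode.map(identity) flattens a tree into its nodes.
    assert(before.map(identity[Expression]).zip(after.map(identity[Expression]))
      .forall { case (b, a) => b eq a })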
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/GeneratedAggregate.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/GeneratedAggregate.scala
index 31ad5e8aabb0e038dc4434594aa85b380471f5b3..b3edd5020fa8cd7b82c49dce7ab2b4dee8b05505 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/GeneratedAggregate.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/GeneratedAggregate.scala
@@ -19,6 +19,7 @@ package org.apache.spark.sql.execution
 
 import org.apache.spark.annotation.DeveloperApi
 import org.apache.spark.sql.SQLContext
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.plans.physical._
+import org.apache.spark.sql.catalyst.trees.TreeNodeRef
 import org.apache.spark.sql.catalyst.types._
@@ -141,9 +142,10 @@ case class GeneratedAggregate(
 
     val computationSchema = computeFunctions.flatMap(_.schema)
 
-    val resultMap: Map[Long, Expression] = aggregatesToCompute.zip(computeFunctions).map {
-      case (agg, func) => agg.id -> func.result
-    }.toMap
+    val resultMap: Map[TreeNodeRef, Expression] =
+      aggregatesToCompute.zip(computeFunctions).map {
+        case (agg, func) => new TreeNodeRef(agg) -> func.result
+      }.toMap
 
     val namedGroups = groupingExpressions.zipWithIndex.map {
       case (ne: NamedExpression, _) => (ne, ne)
@@ -156,7 +158,7 @@ case class GeneratedAggregate(
     // The set of expressions that produce the final output given the aggregation buffer and the
     // grouping expressions.
     val resultExpressions = aggregateExpressions.map(_.transform {
-      case e: Expression if resultMap.contains(e.id) => resultMap(e.id)
+      case e: Expression if resultMap.contains(new TreeNodeRef(e)) => resultMap(new TreeNodeRef(e))
       case e: Expression if groupMap.contains(e) => groupMap(e)
     })
 
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/debug/package.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/debug/package.scala
index 5b896c55b73938aa18325a1c316509b462e73eef..8ff757bbe3508f65c2134b36208b733586880a4b 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/debug/package.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/debug/package.scala
@@ -23,6 +23,7 @@ import org.apache.spark.{AccumulatorParam, Accumulator, SparkContext}
 import org.apache.spark.annotation.DeveloperApi
 import org.apache.spark.SparkContext._
 import org.apache.spark.sql.{SchemaRDD, Row}
+import org.apache.spark.sql.catalyst.trees.TreeNodeRef
 
 /**
  * :: DeveloperApi ::
@@ -43,10 +44,10 @@ package object debug {
   implicit class DebugQuery(query: SchemaRDD) {
     def debug(): Unit = {
       val plan = query.queryExecution.executedPlan
-      val visited = new collection.mutable.HashSet[Long]()
+      val visited = new collection.mutable.HashSet[TreeNodeRef]()
       val debugPlan = plan transform {
-        case s: SparkPlan if !visited.contains(s.id) =>
-          visited += s.id
+        case s: SparkPlan if !visited.contains(new TreeNodeRef(s)) =>
+          visited += new TreeNodeRef(s)
           DebugNode(s)
       }
       println(s"Results returned: ${debugPlan.execute().count()}")
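
Keying `visited` by identity is what makes this a wrap-once transform: `transform` recurses into the children of the node it just produced, so the freshly wrapped plan would match the rule again on the way down if the original instance were not remembered. A standalone sketch of the same pattern (illustrative `Plan`/`Scan`/`Debug` types, not Spark code), using the JDK's identity-keyed set where the patch uses a `HashSet[TreeNodeRef]`:

    sealed trait Plan {
      def children: List[Plan]
      def withChildren(cs: List[Plan]): Plan
      def transformDown(rule: PartialFunction[Plan, Plan]): Plan = {
        val afterRule = rule.applyOrElse(this, identity[Plan])
        afterRule.withChildren(afterRule.children.map(_.transformDown(rule)))
      }
    }
    case class Scan(table: String) extends Plan {
      def children = Nil
      def withChildren(cs: List[Plan]) = this
    }
    case class Debug(child: Plan) extends Plan {
      def children = List(child)
      def withChildren(cs: List[Plan]) = Debug(cs.head)
    }

    object WrapOnceDemo extends App {
      // Identity-keyed set from the JDK; TreeNodeRef gives Scala's hash
      // collections the same semantics.
      val visited = java.util.Collections.newSetFromMap(
        new java.util.IdentityHashMap[Plan, java.lang.Boolean])

      val wrapped = Scan("t").transformDown {
        case p if !visited.contains(p) =>
          visited.add(p)
          Debug(p)
      }
      // Prints Debug(Scan(t)): without `visited`, the rule would fire again
      // on Scan("t") after wrapping and the transform would never terminate.
      println(wrapped)
    }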
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/pythonUdfs.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/pythonUdfs.scala
index aef6ebf86b1ebd08a9c04e210ef75bec1d1d8839..3dc8be245678103b1db9794b271b2ea6750d60af 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/pythonUdfs.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/pythonUdfs.scala
@@ -98,7 +98,7 @@ private[spark] object ExtractPythonUdfs extends Rule[LogicalPlan] {
         logical.Project(
           l.output,
           l.transformExpressions {
-            case p: PythonUDF if p.id == udf.id => evaluation.resultAttribute
+            case p: PythonUDF if p.fastEquals(udf) => evaluation.resultAttribute
           }.withNewChildren(newChildren))
       }
   }