diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala index 90923fe31a0637c521d4dadc31760542b2176ac4..f0fd9a8b9a46e8fbe4d058c046017ac87270ff33 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala @@ -19,8 +19,9 @@ package org.apache.spark.sql.catalyst.planning import scala.annotation.tailrec -import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.Logging +import org.apache.spark.sql.catalyst.trees.TreeNodeRef +import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans._ import org.apache.spark.sql.catalyst.plans.logical._ @@ -134,8 +135,8 @@ object PartialAggregation { // Only do partial aggregation if supported by all aggregate expressions. if (allAggregates.size == partialAggregates.size) { // Create a map of expressions to their partial evaluations for all aggregate expressions. - val partialEvaluations: Map[Long, SplitEvaluation] = - partialAggregates.map(a => (a.id, a.asPartial)).toMap + val partialEvaluations: Map[TreeNodeRef, SplitEvaluation] = + partialAggregates.map(a => (new TreeNodeRef(a), a.asPartial)).toMap // We need to pass all grouping expressions though so the grouping can happen a second // time. However some of them might be unnamed so we alias them allowing them to be @@ -148,8 +149,8 @@ object PartialAggregation { // Replace aggregations with a new expression that computes the result from the already // computed partial evaluations and grouping values. val rewrittenAggregateExpressions = aggregateExpressions.map(_.transformUp { - case e: Expression if partialEvaluations.contains(e.id) => - partialEvaluations(e.id).finalEvaluation + case e: Expression if partialEvaluations.contains(new TreeNodeRef(e)) => + partialEvaluations(new TreeNodeRef(e)).finalEvaluation case e: Expression if namedGroupingExpressions.contains(e) => namedGroupingExpressions(e).toAttribute }).asInstanceOf[Seq[NamedExpression]] diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala index 1e177e28f80b31d149f1935c03c80ea7ce100a29..af9e4d86e995a2d5f80b7f21c4665cea6ee3de35 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala @@ -50,11 +50,11 @@ abstract class QueryPlan[PlanType <: TreeNode[PlanType]] extends TreeNode[PlanTy @inline def transformExpressionDown(e: Expression) = { val newE = e.transformDown(rule) - if (newE.id != e.id && newE != e) { + if (newE.fastEquals(e)) { + e + } else { changed = true newE - } else { - e } } @@ -82,11 +82,11 @@ abstract class QueryPlan[PlanType <: TreeNode[PlanType]] extends TreeNode[PlanTy @inline def transformExpressionUp(e: Expression) = { val newE = e.transformUp(rule) - if (newE.id != e.id && newE != e) { + if (newE.fastEquals(e)) { + e + } else { changed = true newE - } else { - e } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala index 96ce35939e2cc966260615ea805274688053dad6..2013ae4f7bd13d535b1a0c3d33271c3208ec95bd 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala @@ -19,11 +19,6 @@ package org.apache.spark.sql.catalyst.trees import org.apache.spark.sql.catalyst.errors._ -object TreeNode { - private val currentId = new java.util.concurrent.atomic.AtomicLong - protected def nextId() = currentId.getAndIncrement() -} - /** Used by [[TreeNode.getNodeNumbered]] when traversing the tree for a given number */ private class MutableInt(var i: Int) @@ -33,29 +28,13 @@ abstract class TreeNode[BaseType <: TreeNode[BaseType]] { /** Returns a Seq of the children of this node */ def children: Seq[BaseType] - /** - * A globally unique id for this specific instance. Not preserved across copies. - * Unlike `equals`, `id` can be used to differentiate distinct but structurally - * identical branches of a tree. - */ - val id = TreeNode.nextId() - - /** - * Returns true if other is the same [[catalyst.trees.TreeNode TreeNode]] instance. Unlike - * `equals` this function will return false for different instances of structurally identical - * trees. - */ - def sameInstance(other: TreeNode[_]): Boolean = { - this.id == other.id - } - /** * Faster version of equality which short-circuits when two treeNodes are the same instance. * We don't just override Object.Equals, as doing so prevents the scala compiler from from * generating case class `equals` methods */ def fastEquals(other: TreeNode[_]): Boolean = { - sameInstance(other) || this == other + this.eq(other) || this == other } /** @@ -393,3 +372,4 @@ trait UnaryNode[BaseType <: TreeNode[BaseType]] { def child: BaseType def children = child :: Nil } + diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/package.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/package.scala index d725a92c06f7bf38c02443f87f84b35d3833450a..79a8e06d4b4d4a4be03e9a47343b92548e4b68e8 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/package.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/package.scala @@ -37,4 +37,15 @@ package object trees extends Logging { // Since we want tree nodes to be lightweight, we create one logger for all treenode instances. protected override def logName = "catalyst.trees" + /** + * A [[TreeNode]] companion for reference equality for Hash based Collection. + */ + class TreeNodeRef(val obj: TreeNode[_]) { + override def equals(o: Any) = o match { + case that: TreeNodeRef => that.obj.eq(obj) + case _ => false + } + + override def hashCode = if (obj == null) 0 else obj.hashCode + } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/trees/TreeNodeSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/trees/TreeNodeSuite.scala index 296202543e2ca97d98bd8e947079c6f0f266e381..036fd3fa1d6a193cc665372c6a95da429198d99c 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/trees/TreeNodeSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/trees/TreeNodeSuite.scala @@ -51,7 +51,10 @@ class TreeNodeSuite extends FunSuite { val after = before transform { case Literal(5, _) => Literal(1)} assert(before === after) - assert(before.map(_.id) === after.map(_.id)) + // Ensure that the objects after are the same objects before the transformation. + before.map(identity[Expression]).zip(after.map(identity[Expression])).foreach { + case (b, a) => assert(b eq a) + } } test("collect") { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/GeneratedAggregate.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/GeneratedAggregate.scala index 31ad5e8aabb0e038dc4434594aa85b380471f5b3..b3edd5020fa8cd7b82c49dce7ab2b4dee8b05505 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/GeneratedAggregate.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/GeneratedAggregate.scala @@ -19,6 +19,7 @@ package org.apache.spark.sql.execution import org.apache.spark.annotation.DeveloperApi import org.apache.spark.sql.SQLContext +import org.apache.spark.sql.catalyst.trees._ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans.physical._ import org.apache.spark.sql.catalyst.types._ @@ -141,9 +142,10 @@ case class GeneratedAggregate( val computationSchema = computeFunctions.flatMap(_.schema) - val resultMap: Map[Long, Expression] = aggregatesToCompute.zip(computeFunctions).map { - case (agg, func) => agg.id -> func.result - }.toMap + val resultMap: Map[TreeNodeRef, Expression] = + aggregatesToCompute.zip(computeFunctions).map { + case (agg, func) => new TreeNodeRef(agg) -> func.result + }.toMap val namedGroups = groupingExpressions.zipWithIndex.map { case (ne: NamedExpression, _) => (ne, ne) @@ -156,7 +158,7 @@ case class GeneratedAggregate( // The set of expressions that produce the final output given the aggregation buffer and the // grouping expressions. val resultExpressions = aggregateExpressions.map(_.transform { - case e: Expression if resultMap.contains(e.id) => resultMap(e.id) + case e: Expression if resultMap.contains(new TreeNodeRef(e)) => resultMap(new TreeNodeRef(e)) case e: Expression if groupMap.contains(e) => groupMap(e) }) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/debug/package.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/debug/package.scala index 5b896c55b73938aa18325a1c316509b462e73eef..8ff757bbe3508f65c2134b36208b733586880a4b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/debug/package.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/debug/package.scala @@ -23,6 +23,7 @@ import org.apache.spark.{AccumulatorParam, Accumulator, SparkContext} import org.apache.spark.annotation.DeveloperApi import org.apache.spark.SparkContext._ import org.apache.spark.sql.{SchemaRDD, Row} +import org.apache.spark.sql.catalyst.trees.TreeNodeRef /** * :: DeveloperApi :: @@ -43,10 +44,10 @@ package object debug { implicit class DebugQuery(query: SchemaRDD) { def debug(): Unit = { val plan = query.queryExecution.executedPlan - val visited = new collection.mutable.HashSet[Long]() + val visited = new collection.mutable.HashSet[TreeNodeRef]() val debugPlan = plan transform { - case s: SparkPlan if !visited.contains(s.id) => - visited += s.id + case s: SparkPlan if !visited.contains(new TreeNodeRef(s)) => + visited += new TreeNodeRef(s) DebugNode(s) } println(s"Results returned: ${debugPlan.execute().count()}") diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/pythonUdfs.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/pythonUdfs.scala index aef6ebf86b1ebd08a9c04e210ef75bec1d1d8839..3dc8be245678103b1db9794b271b2ea6750d60af 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/pythonUdfs.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/pythonUdfs.scala @@ -98,7 +98,7 @@ private[spark] object ExtractPythonUdfs extends Rule[LogicalPlan] { logical.Project( l.output, l.transformExpressions { - case p: PythonUDF if p.id == udf.id => evaluation.resultAttribute + case p: PythonUDF if p.fastEquals(udf) => evaluation.resultAttribute }.withNewChildren(newChildren)) } }