diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/CostBasedJoinReorder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/CostBasedJoinReorder.scala
index fc37720809ba25d3b174e6ba29183621cc1ede2e..cbd506465ae6a0dd3c3ae85687e72d69de02564f 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/CostBasedJoinReorder.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/CostBasedJoinReorder.scala
@@ -40,10 +40,10 @@ case class CostBasedJoinReorder(conf: SQLConf) extends Rule[LogicalPlan] with Pr
       val result = plan transformDown {
         // Start reordering with a joinable item, which is an InnerLike join with conditions.
         case j @ Join(_, _, _: InnerLike, Some(cond)) =>
-          reorder(j, j.outputSet)
+          reorder(j, j.output)
         case p @ Project(projectList, Join(_, _, _: InnerLike, Some(cond)))
           if projectList.forall(_.isInstanceOf[Attribute]) =>
-          reorder(p, p.outputSet)
+          reorder(p, p.output)
       }
       // After reordering is finished, convert OrderedJoin back to Join
       result transformDown {
@@ -52,7 +52,7 @@
     }
   }
 
-  private def reorder(plan: LogicalPlan, output: AttributeSet): LogicalPlan = {
+  private def reorder(plan: LogicalPlan, output: Seq[Attribute]): LogicalPlan = {
     val (items, conditions) = extractInnerJoins(plan)
     // TODO: Compute the set of star-joins and use them in the join enumeration
     // algorithm to prune un-optimal plan choices.
@@ -140,7 +140,7 @@ object JoinReorderDP extends PredicateHelper with Logging {
       conf: SQLConf,
       items: Seq[LogicalPlan],
       conditions: Set[Expression],
-      topOutput: AttributeSet): LogicalPlan = {
+      output: Seq[Attribute]): LogicalPlan = {
 
     val startTime = System.nanoTime()
     // Level i maintains all found plans for i + 1 items.
@@ -152,9 +152,10 @@
 
     // Build plans for next levels until the last level has only one plan. This plan contains
     // all items that can be joined, so there's no need to continue.
+    val topOutputSet = AttributeSet(output)
     while (foundPlans.size < items.length && foundPlans.last.size > 1) {
       // Build plans for the next level.
-      foundPlans += searchLevel(foundPlans, conf, conditions, topOutput)
+      foundPlans += searchLevel(foundPlans, conf, conditions, topOutputSet)
     }
 
     val durationInMs = (System.nanoTime() - startTime) / (1000 * 1000)
@@ -163,7 +164,14 @@
 
     // The last level must have one and only one plan, because all items are joinable.
     assert(foundPlans.size == items.length && foundPlans.last.size == 1)
-    foundPlans.last.head._2.plan
+    foundPlans.last.head._2.plan match {
+      case p @ Project(projectList, j: Join) if projectList != output =>
+        assert(topOutputSet == p.outputSet)
+        // Keep the same order of final output attributes.
+        p.copy(projectList = output)
+      case finalPlan =>
+        finalPlan
+    }
   }
 
   /** Find all possible plans at the next level, based on existing levels. */
@@ -254,10 +262,10 @@ object JoinReorderDP extends PredicateHelper with Logging {
     val collectedJoinConds = joinConds ++ oneJoinPlan.joinConds ++ otherJoinPlan.joinConds
     val remainingConds = conditions -- collectedJoinConds
     val neededAttr = AttributeSet(remainingConds.flatMap(_.references)) ++ topOutput
-    val neededFromNewJoin = newJoin.outputSet.filter(neededAttr.contains)
+    val neededFromNewJoin = newJoin.output.filter(neededAttr.contains)
     val newPlan =
       if ((newJoin.outputSet -- neededFromNewJoin).nonEmpty) {
-        Project(neededFromNewJoin.toSeq, newJoin)
+        Project(neededFromNewJoin, newJoin)
       } else {
         newJoin
       }
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/JoinReorderSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/JoinReorderSuite.scala
index 05b839b0119f481add5d4fabe7532ae68f0a1b82..d74008c1b302787494c6d657655aed62f36ddce5 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/JoinReorderSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/JoinReorderSuite.scala
@@ -198,6 +198,19 @@ class JoinReorderSuite extends PlanTest with StatsEstimationTestBase {
     assertEqualPlans(originalPlan, bestPlan)
   }
 
+  test("keep the order of attributes in the final output") {
+    val outputLists = Seq("t1.k-1-2", "t1.v-1-10", "t3.v-1-100").permutations
+    while (outputLists.hasNext) {
+      val expectedOrder = outputLists.next().map(nameToAttr)
+      val expectedPlan =
+        t1.join(t3, Inner, Some(nameToAttr("t1.v-1-10") === nameToAttr("t3.v-1-100")))
+          .join(t2, Inner, Some(nameToAttr("t1.k-1-2") === nameToAttr("t2.k-1-5")))
+          .select(expectedOrder: _*)
+      // The plan should not change after optimization
+      assertEqualPlans(expectedPlan, expectedPlan)
+    }
+  }
+
   private def assertEqualPlans(
       originalPlan: LogicalPlan,
       groundTruthBestPlan: LogicalPlan): Unit = {
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/PlanTest.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/PlanTest.scala
index 2a9d0570148ad4e98916155d69dbb927403dbca8..c73dfaf3f8fe3dc2361a9be548e07b8d6b8e77b0 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/PlanTest.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/PlanTest.scala
@@ -126,8 +126,8 @@ abstract class PlanTest extends SparkFunSuite with PredicateHelper {
       case (j1: Join, j2: Join) =>
        (sameJoinPlan(j1.left, j2.left) && sameJoinPlan(j1.right, j2.right)) ||
          (sameJoinPlan(j1.left, j2.right) && sameJoinPlan(j1.right, j2.left))
-      case _ if plan1.children.nonEmpty && plan2.children.nonEmpty =>
-        (plan1.children, plan2.children).zipped.forall { case (c1, c2) => sameJoinPlan(c1, c2) }
+      case (p1: Project, p2: Project) =>
+        p1.projectList == p2.projectList && sameJoinPlan(p1.child, p2.child)
       case _ =>
        plan1 == plan2
     }
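
Note on the patch above (reviewer commentary, not part of the diff): the reordered join's final output used to be derived from an AttributeSet, whose hash-based iteration order need not match the plan's original output order, so CostBasedJoinReorder could silently permute the output columns. Threading a Seq[Attribute] through reorder()/search() and restoring it as the top-level projectList keeps the caller-visible column order stable. Below is a minimal sketch of the underlying distinction, assuming spark-catalyst of this vintage on the classpath; the object name AttributeOrderSketch is hypothetical.

import org.apache.spark.sql.catalyst.expressions.{AttributeReference, AttributeSet}
import org.apache.spark.sql.types.IntegerType

object AttributeOrderSketch {
  def main(args: Array[String]): Unit = {
    // Three attributes in a fixed declaration order: a, b, c.
    val attrs = Seq("a", "b", "c").map(name => AttributeReference(name, IntegerType)())

    // A Seq[Attribute] always preserves declaration order.
    println(attrs.map(_.name))                      // List(a, b, c)

    // AttributeSet was hash-based at the time of this patch, so its
    // iteration order is an implementation detail and may differ from
    // the declaration order above.
    println(AttributeSet(attrs).toSeq.map(_.name))
  }
}

This is also why the new test iterates over all permutations of the projected columns: whatever ordering the query asks for must survive join reordering unchanged, which assertEqualPlans(expectedPlan, expectedPlan) checks via the tightened Project case added to sameJoinPlan.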