diff --git a/sql/catalyst/src/main/antlr3/org/apache/spark/sql/catalyst/parser/FromClauseParser.g b/sql/catalyst/src/main/antlr3/org/apache/spark/sql/catalyst/parser/FromClauseParser.g
index e83f8a7cd1b5cf3ec81984dbe5d68120f78a2cc1..1bf461c912b612043ef8f6e6898e236119ea20ca 100644
--- a/sql/catalyst/src/main/antlr3/org/apache/spark/sql/catalyst/parser/FromClauseParser.g
+++ b/sql/catalyst/src/main/antlr3/org/apache/spark/sql/catalyst/parser/FromClauseParser.g
@@ -91,10 +91,17 @@ fromClause
 joinSource
 @init { gParent.pushMsg("join source", state); }
 @after { gParent.popMsg(state); }
-    : fromSource ( joinToken^ fromSource ( KW_ON! expression {$joinToken.start.getType() != COMMA}? )? )*
+    : fromSource ( joinToken^ fromSource ( joinCond {$joinToken.start.getType() != COMMA}? )? )*
     | uniqueJoinToken^ uniqueJoinSource (COMMA! uniqueJoinSource)+
     ;
 
+joinCond
+@init { gParent.pushMsg("join expression list", state); }
+@after { gParent.popMsg(state); }
+    : KW_ON! expression
+    | KW_USING LPAREN columnNameList RPAREN -> ^(TOK_USING columnNameList)
+    ;
+
 uniqueJoinSource
 @init { gParent.pushMsg("unique join source", state); }
 @after { gParent.popMsg(state); }
diff --git a/sql/catalyst/src/main/antlr3/org/apache/spark/sql/catalyst/parser/SparkSqlParser.g b/sql/catalyst/src/main/antlr3/org/apache/spark/sql/catalyst/parser/SparkSqlParser.g
index 1db3aed65815dd310dde4640b531838abe73946c..f0c236859ddca35eed316294f3d18faa84194e7a 100644
--- a/sql/catalyst/src/main/antlr3/org/apache/spark/sql/catalyst/parser/SparkSqlParser.g
+++ b/sql/catalyst/src/main/antlr3/org/apache/spark/sql/catalyst/parser/SparkSqlParser.g
@@ -387,6 +387,7 @@ TOK_SETCONFIG;
 TOK_DFS;
 TOK_ADDFILE;
 TOK_ADDJAR;
+TOK_USING;
 }
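For orientation, here is a minimal sketch of what the new `joinCond` rule accepts at the parser level. It assumes the `CatalystQl` class changed later in this patch, that its package matches its file path, and that the snippet lives under `org.apache.spark.sql` (the class is `private[sql]`); the demo object and table names are hypothetical:

```scala
package org.apache.spark.sql.demo // hypothetical location; CatalystQl is private[sql]

import org.apache.spark.sql.catalyst.parser.CatalystQl

object UsingClauseParseDemo {
  def main(args: Array[String]): Unit = {
    val parser = new CatalystQl()
    // The ON form was already supported by joinSource:
    println(parser.parsePlan("SELECT * FROM t1 JOIN t2 ON t1.c1 = t2.c1"))
    // The new joinCond alternative turns a USING column list into a TOK_USING subtree:
    println(parser.parsePlan("SELECT * FROM t1 JOIN t2 USING (c1, c2)"))
  }
}
```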
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
index 53ea3cfef6786e3eced147eb6777acded599bf27..e4e934a01541d849b5b7ad0022637ca23c6811f5 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
@@ -87,7 +87,7 @@ class Analyzer(
       ResolveSubquery ::
       ResolveWindowOrder ::
       ResolveWindowFrame ::
-      ResolveNaturalJoin ::
+      ResolveNaturalAndUsingJoin ::
       ExtractWindowExpressions ::
       GlobalAggregates ::
       ResolveAggregateFunctions ::
@@ -1329,48 +1329,69 @@ class Analyzer(
   }
 
   /**
-   * Removes natural joins by calculating output columns based on output from two sides,
-   * Then apply a Project on a normal Join to eliminate natural join.
+   * Removes natural or using joins by calculating output columns based on the output from
+   * the two sides, then applies a Project on a normal Join to eliminate the natural or
+   * using join.
    */
-  object ResolveNaturalJoin extends Rule[LogicalPlan] {
+  object ResolveNaturalAndUsingJoin extends Rule[LogicalPlan] {
     override def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators {
+      case j @ Join(left, right, UsingJoin(joinType, usingCols), condition)
+          if left.resolved && right.resolved && j.duplicateResolved =>
+        // Resolve the column names referenced in the USING clause from both legs of the join.
+        val lCols = usingCols.flatMap(col => left.resolveQuoted(col.name, resolver))
+        val rCols = usingCols.flatMap(col => right.resolveQuoted(col.name, resolver))
+        if ((lCols.length == usingCols.length) && (rCols.length == usingCols.length)) {
+          val joinNames = lCols.map(exp => exp.name)
+          commonNaturalJoinProcessing(left, right, joinType, joinNames, None)
+        } else {
+          j
+        }
       case j @ Join(left, right, NaturalJoin(joinType), condition) if j.resolvedExceptNatural =>
         // find common column names from both sides
         val joinNames = left.output.map(_.name).intersect(right.output.map(_.name))
-        val leftKeys = joinNames.map(keyName => left.output.find(_.name == keyName).get)
-        val rightKeys = joinNames.map(keyName => right.output.find(_.name == keyName).get)
-        val joinPairs = leftKeys.zip(rightKeys)
-
-        // Add joinPairs to joinConditions
-        val newCondition = (condition ++ joinPairs.map {
-          case (l, r) => EqualTo(l, r)
-        }).reduceOption(And)
-
-        // columns not in joinPairs
-        val lUniqueOutput = left.output.filterNot(att => leftKeys.contains(att))
-        val rUniqueOutput = right.output.filterNot(att => rightKeys.contains(att))
-
-        // the output list looks like: join keys, columns from left, columns from right
-        val projectList = joinType match {
-          case LeftOuter =>
-            leftKeys ++ lUniqueOutput ++ rUniqueOutput.map(_.withNullability(true))
-          case RightOuter =>
-            rightKeys ++ lUniqueOutput.map(_.withNullability(true)) ++ rUniqueOutput
-          case FullOuter =>
-            // in full outer join, joinCols should be non-null if there is.
-            val joinedCols = joinPairs.map { case (l, r) => Alias(Coalesce(Seq(l, r)), l.name)() }
-            joinedCols ++
-              lUniqueOutput.map(_.withNullability(true)) ++
-              rUniqueOutput.map(_.withNullability(true))
-          case Inner =>
-            rightKeys ++ lUniqueOutput ++ rUniqueOutput
-          case _ =>
-            sys.error("Unsupported natural join type " + joinType)
-        }
-        // use Project to trim unnecessary fields
-        Project(projectList, Join(left, right, joinType, newCondition))
+        commonNaturalJoinProcessing(left, right, joinType, joinNames, condition)
     }
   }
 
+  private def commonNaturalJoinProcessing(
+      left: LogicalPlan,
+      right: LogicalPlan,
+      joinType: JoinType,
+      joinNames: Seq[String],
+      condition: Option[Expression]) = {
+    val leftKeys = joinNames.map(keyName => left.output.find(_.name == keyName).get)
+    val rightKeys = joinNames.map(keyName => right.output.find(_.name == keyName).get)
+    val joinPairs = leftKeys.zip(rightKeys)
+
+    val newCondition = (condition ++ joinPairs.map(EqualTo.tupled)).reduceOption(And)
+
+    // columns not in joinPairs
+    val lUniqueOutput = left.output.filterNot(att => leftKeys.contains(att))
+    val rUniqueOutput = right.output.filterNot(att => rightKeys.contains(att))
+
+    // the output list looks like: join keys, columns from left, columns from right
+    val projectList = joinType match {
+      case LeftOuter =>
+        leftKeys ++ lUniqueOutput ++ rUniqueOutput.map(_.withNullability(true))
+      case LeftSemi =>
+        leftKeys ++ lUniqueOutput
+      case RightOuter =>
+        rightKeys ++ lUniqueOutput.map(_.withNullability(true)) ++ rUniqueOutput
+      case FullOuter =>
+        // In a full outer join, a join column should be non-null if at least one side is.
+        val joinedCols = joinPairs.map { case (l, r) => Alias(Coalesce(Seq(l, r)), l.name)() }
+        joinedCols ++
+          lUniqueOutput.map(_.withNullability(true)) ++
+          rUniqueOutput.map(_.withNullability(true))
+      case Inner =>
+        leftKeys ++ lUniqueOutput ++ rUniqueOutput
+      case _ =>
+        sys.error("Unsupported natural join type " + joinType)
+    }
+    // use Project to trim unnecessary fields
+    Project(projectList, Join(left, right, joinType, newCondition))
+  }
+
   /**
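As a hedged illustration of the rewrite this rule performs, using Catalyst's test DSL (the relation shapes are assumptions, chosen to match the expectations in ResolveNaturalJoinSuite further down):

```scala
import org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute
import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.dsl.plans._
import org.apache.spark.sql.catalyst.plans.{Inner, UsingJoin}
import org.apache.spark.sql.catalyst.plans.logical.LocalRelation

val a = 'a.int
val b = 'b.string
val c = 'c.string
val r1 = LocalRelation(a, b) // left leg:  (a, b)
val r2 = LocalRelation(a, c) // right leg: (a, c)

// Unresolved input corresponding to: r1 JOIN r2 USING (a)
val usingPlan = r1.join(r2, UsingJoin(Inner, Seq(UnresolvedAttribute("a"))), None)

// After ResolveNaturalAndUsingJoin the plan is an ordinary equi-join under a
// Project that keeps a single copy of the join column:
//   r1.join(r2, Inner, Some(EqualTo(a, a))).select(a, b, c)
```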
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala
index 1e430c1fbbdf06052d56e6674e498dedd8bb2fb0..1d1e892e32cd3ba1ff34b17218801ebdb8886f23 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala
@@ -20,6 +20,7 @@ package org.apache.spark.sql.catalyst.analysis
 import org.apache.spark.sql.AnalysisException
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression
+import org.apache.spark.sql.catalyst.plans.UsingJoin
 import org.apache.spark.sql.catalyst.plans.logical._
 import org.apache.spark.sql.types._
 
@@ -109,6 +110,12 @@ trait CheckAnalysis {
               s"filter expression '${f.condition.sql}' " +
                 s"of type ${f.condition.dataType.simpleString} is not a boolean.")
 
+          case j @ Join(_, _, UsingJoin(_, cols), _) =>
+            val from = operator.inputSet.map(_.name).mkString(", ")
+            failAnalysis(
+              s"using columns [${cols.mkString(",")}] " +
+                s"cannot be resolved given input columns: [$from]")
+
           case j @ Join(_, _, _, Some(condition)) if condition.dataType != BooleanType =>
             failAnalysis(
               s"join condition '${condition.sql}' " +
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
index d0e5859d2702e35715b4c72dc5b5e2fab6fc621c..c419b5fd2204ba4abf1ae53574de7b7ff4fe5e35 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
@@ -1133,6 +1133,7 @@ object PushPredicateThroughJoin extends Rule[LogicalPlan] with PredicateHelper {
             reduceLeftOption(And).map(Filter(_, newJoin)).getOrElse(newJoin)
         case FullOuter => f // DO Nothing for Full Outer Join
         case NaturalJoin(_) => sys.error("Untransformed NaturalJoin node")
+        case UsingJoin(_, _) => sys.error("Untransformed UsingJoin node")
       }
 
     // push down the join filter into sub query scanning if applicable
@@ -1168,6 +1169,7 @@ object PushPredicateThroughJoin extends Rule[LogicalPlan] with PredicateHelper {
           Join(newLeft, newRight, LeftOuter, newJoinCond)
         case FullOuter => f
         case NaturalJoin(_) => sys.error("Untransformed NaturalJoin node")
+        case UsingJoin(_, _) => sys.error("Untransformed UsingJoin node")
       }
     }
   }
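The new CheckAnalysis case covers the failure path: if the analyzer rule above could not resolve every USING column, the `UsingJoin` survives analysis and is reported as a user-facing error instead of reaching the optimizer's `sys.error` guards just above. A hedged end-to-end sketch, using ScalaTest's `intercept` (the table names and session setup are made up):

```scala
import org.apache.spark.sql.AnalysisException
import org.scalatest.Assertions._

// Hypothetical: t1(a, b) and t2(a, c) are registered temporary tables.
val error = intercept[AnalysisException] {
  sqlContext.sql("SELECT * FROM t1 JOIN t2 USING (nonexistent)")
}
// The message follows the template added above, along the lines of:
//   using columns ['nonexistent] cannot be resolved given input columns: [a, b, a, c]
```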
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/CatalystQl.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/CatalystQl.scala
index 7d5a46873c217b2dbd51f43e902b53d5b52ee973..c188c5b108491dc2d6d00e89d6ef9109734455e3 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/CatalystQl.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/CatalystQl.scala
@@ -419,30 +419,47 @@ https://cwiki.apache.org/confluence/display/Hive/Enhanced+Aggregation%2C+Cube%2C
             sys.error(s"Unsupported join operation: $other")
         }
 
-        val joinType = joinToken match {
-          case "TOK_JOIN" => Inner
-          case "TOK_CROSSJOIN" => Inner
-          case "TOK_RIGHTOUTERJOIN" => RightOuter
-          case "TOK_LEFTOUTERJOIN" => LeftOuter
-          case "TOK_FULLOUTERJOIN" => FullOuter
-          case "TOK_LEFTSEMIJOIN" => LeftSemi
-          case "TOK_UNIQUEJOIN" => noParseRule("Unique Join", node)
-          case "TOK_ANTIJOIN" => noParseRule("Anti Join", node)
-          case "TOK_NATURALJOIN" => NaturalJoin(Inner)
-          case "TOK_NATURALRIGHTOUTERJOIN" => NaturalJoin(RightOuter)
-          case "TOK_NATURALLEFTOUTERJOIN" => NaturalJoin(LeftOuter)
-          case "TOK_NATURALFULLOUTERJOIN" => NaturalJoin(FullOuter)
-        }
+        val (joinType, joinCondition) = getJoinInfo(joinToken, other, node)
+
         Join(nodeToRelation(relation1),
           nodeToRelation(relation2),
           joinType,
-          other.headOption.map(nodeToExpr))
-
+          joinCondition)
       case _ =>
         noParseRule("Relation", node)
     }
   }
 
+  protected def getJoinInfo(
+      joinToken: String,
+      joinConditionToken: Seq[ASTNode],
+      node: ASTNode): (JoinType, Option[Expression]) = {
+    val joinType = joinToken match {
+      case "TOK_JOIN" => Inner
+      case "TOK_CROSSJOIN" => Inner
+      case "TOK_RIGHTOUTERJOIN" => RightOuter
+      case "TOK_LEFTOUTERJOIN" => LeftOuter
+      case "TOK_FULLOUTERJOIN" => FullOuter
+      case "TOK_LEFTSEMIJOIN" => LeftSemi
+      case "TOK_UNIQUEJOIN" => noParseRule("Unique Join", node)
+      case "TOK_ANTIJOIN" => noParseRule("Anti Join", node)
+      case "TOK_NATURALJOIN" => NaturalJoin(Inner)
+      case "TOK_NATURALRIGHTOUTERJOIN" => NaturalJoin(RightOuter)
+      case "TOK_NATURALLEFTOUTERJOIN" => NaturalJoin(LeftOuter)
+      case "TOK_NATURALFULLOUTERJOIN" => NaturalJoin(FullOuter)
+    }
+
+    joinConditionToken match {
+      case Token("TOK_USING", columnList :: Nil) :: Nil =>
+        val colNames = columnList.children.collect {
+          case Token(name, Nil) => UnresolvedAttribute(name)
+        }
+        (UsingJoin(joinType, colNames), None)
+      /* Join condition specified using an ON clause. */
+      case _ => (joinType, joinConditionToken.headOption.map(nodeToExpr))
+    }
+  }
+
   protected def nodeToSortOrder(node: ASTNode): SortOrder = node match {
     case Token("TOK_TABSORTCOLNAMEASC", sortExpr :: Nil) =>
       SortOrder(nodeToExpr(sortExpr), Ascending)
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/joinTypes.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/joinTypes.scala
index 27a75326eba07f26c4b46ffd5f68ee09ecd4163a..9ca4f13dd73cdef496634ab5cbf706572947652f 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/joinTypes.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/joinTypes.scala
@@ -17,6 +17,8 @@
 
 package org.apache.spark.sql.catalyst.plans
 
+import org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute
+
 object JoinType {
   def apply(typ: String): JoinType = typ.toLowerCase.replace("_", "") match {
     case "inner" => Inner
@@ -66,3 +68,9 @@ case class NaturalJoin(tpe: JoinType) extends JoinType {
     "Unsupported natural join type " + tpe)
   override def sql: String = "NATURAL " + tpe.sql
 }
+
+case class UsingJoin(tpe: JoinType, usingColumns: Seq[UnresolvedAttribute]) extends JoinType {
+  require(Seq(Inner, LeftOuter, LeftSemi, RightOuter, FullOuter).contains(tpe),
+    "Unsupported using join type " + tpe)
+  override def sql: String = "USING " + tpe.sql
+}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala
index 09ea3fea6a69414e8c6f374692190f071f615d14..ccc9916d090d289b8c9da9569437b319364ff37c 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala
@@ -298,10 +298,11 @@ case class Join(
       condition.forall(_.dataType == BooleanType)
   }
 
-  // if not a natural join, use `resolvedExceptNatural`. if it is a natural join, we still need
-  // to eliminate natural before we mark it resolved.
+  // If this is not a natural or using join, use `resolvedExceptNatural`. If it is a natural or
+  // using join, it must first be eliminated before the join can be marked resolved.
   override lazy val resolved: Boolean = joinType match {
     case NaturalJoin(_) => false
+    case UsingJoin(_, _) => false
     case _ => resolvedExceptNatural
   }
 }
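A hedged sketch of how these two pieces interact: `UsingJoin` is only a parse-time marker, and the `resolved` override keeps the Join unresolved until the analyzer rewrites it, so no other rule treats the plan as final too early (DSL names follow the suite below; the relation shapes are assumptions):

```scala
import org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute
import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.dsl.plans._
import org.apache.spark.sql.catalyst.plans.{Inner, UsingJoin}
import org.apache.spark.sql.catalyst.plans.logical.LocalRelation

val plan = LocalRelation('a.int, 'b.string)
  .join(LocalRelation('a.int, 'c.string),
    UsingJoin(Inner, Seq(UnresolvedAttribute("a"))), None)

// Even with both children resolved, the marker forces resolved == false until
// ResolveNaturalAndUsingJoin eliminates the UsingJoin:
assert(!plan.resolved)
```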
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ResolveNaturalJoinSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ResolveNaturalJoinSuite.scala
index fcf4ac1967a53942fd0559c432c5667534d69926..1423a8705af27c0b57ee16c974c217f4f81653bf 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ResolveNaturalJoinSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ResolveNaturalJoinSuite.scala
@@ -17,6 +17,7 @@
 
 package org.apache.spark.sql.catalyst.analysis
 
+import org.apache.spark.sql.AnalysisException
 import org.apache.spark.sql.catalyst.dsl.expressions._
 import org.apache.spark.sql.catalyst.dsl.plans._
 import org.apache.spark.sql.catalyst.expressions._
@@ -35,56 +36,81 @@ class ResolveNaturalJoinSuite extends AnalysisTest {
   lazy val r3 = LocalRelation(aNotNull, bNotNull)
   lazy val r4 = LocalRelation(cNotNull, bNotNull)
 
-  test("natural inner join") {
-    val plan = r1.join(r2, NaturalJoin(Inner), None)
+  test("natural/using inner join") {
+    val naturalPlan = r1.join(r2, NaturalJoin(Inner), None)
+    val usingPlan = r1.join(r2, UsingJoin(Inner, Seq(UnresolvedAttribute("a"))), None)
     val expected = r1.join(r2, Inner, Some(EqualTo(a, a))).select(a, b, c)
-    checkAnalysis(plan, expected)
+    checkAnalysis(naturalPlan, expected)
+    checkAnalysis(usingPlan, expected)
   }
 
-  test("natural left join") {
-    val plan = r1.join(r2, NaturalJoin(LeftOuter), None)
+  test("natural/using left join") {
+    val naturalPlan = r1.join(r2, NaturalJoin(LeftOuter), None)
+    val usingPlan = r1.join(r2, UsingJoin(LeftOuter, Seq(UnresolvedAttribute("a"))), None)
     val expected = r1.join(r2, LeftOuter, Some(EqualTo(a, a))).select(a, b, c)
-    checkAnalysis(plan, expected)
+    checkAnalysis(naturalPlan, expected)
+    checkAnalysis(usingPlan, expected)
   }
 
-  test("natural right join") {
-    val plan = r1.join(r2, NaturalJoin(RightOuter), None)
+  test("natural/using right join") {
+    val naturalPlan = r1.join(r2, NaturalJoin(RightOuter), None)
+    val usingPlan = r1.join(r2, UsingJoin(RightOuter, Seq(UnresolvedAttribute("a"))), None)
     val expected = r1.join(r2, RightOuter, Some(EqualTo(a, a))).select(a, b, c)
-    checkAnalysis(plan, expected)
+    checkAnalysis(naturalPlan, expected)
+    checkAnalysis(usingPlan, expected)
  }
 
-  test("natural full outer join") {
-    val plan = r1.join(r2, NaturalJoin(FullOuter), None)
+  test("natural/using full outer join") {
+    val naturalPlan = r1.join(r2, NaturalJoin(FullOuter), None)
+    val usingPlan = r1.join(r2, UsingJoin(FullOuter, Seq(UnresolvedAttribute("a"))), None)
     val expected = r1.join(r2, FullOuter, Some(EqualTo(a, a))).select(
       Alias(Coalesce(Seq(a, a)), "a")(), b, c)
-    checkAnalysis(plan, expected)
+    checkAnalysis(naturalPlan, expected)
+    checkAnalysis(usingPlan, expected)
   }
 
-  test("natural inner join with no nullability") {
-    val plan = r3.join(r4, NaturalJoin(Inner), None)
+  test("natural/using inner join with no nullability") {
+    val naturalPlan = r3.join(r4, NaturalJoin(Inner), None)
+    val usingPlan = r3.join(r4, UsingJoin(Inner, Seq(UnresolvedAttribute("b"))), None)
     val expected = r3.join(r4, Inner, Some(EqualTo(bNotNull, bNotNull))).select(
       bNotNull, aNotNull, cNotNull)
-    checkAnalysis(plan, expected)
+    checkAnalysis(naturalPlan, expected)
+    checkAnalysis(usingPlan, expected)
   }
 
-  test("natural left join with no nullability") {
-    val plan = r3.join(r4, NaturalJoin(LeftOuter), None)
+  test("natural/using left join with no nullability") {
+    val naturalPlan = r3.join(r4, NaturalJoin(LeftOuter), None)
+    val usingPlan = r3.join(r4, UsingJoin(LeftOuter, Seq(UnresolvedAttribute("b"))), None)
     val expected = r3.join(r4, LeftOuter, Some(EqualTo(bNotNull, bNotNull))).select(
       bNotNull, aNotNull, c)
-    checkAnalysis(plan, expected)
+    checkAnalysis(naturalPlan, expected)
+    checkAnalysis(usingPlan, expected)
   }
 
-  test("natural right join with no nullability") {
-    val plan = r3.join(r4, NaturalJoin(RightOuter), None)
+  test("natural/using right join with no nullability") {
+    val naturalPlan = r3.join(r4, NaturalJoin(RightOuter), None)
+    val usingPlan = r3.join(r4, UsingJoin(RightOuter, Seq(UnresolvedAttribute("b"))), None)
    val expected = r3.join(r4, RightOuter, Some(EqualTo(bNotNull, bNotNull))).select(
       bNotNull, a, cNotNull)
-    checkAnalysis(plan, expected)
+    checkAnalysis(naturalPlan, expected)
+    checkAnalysis(usingPlan, expected)
   }
 
-  test("natural full outer join with no nullability") {
-    val plan = r3.join(r4, NaturalJoin(FullOuter), None)
+  test("natural/using full outer join with no nullability") {
+    val naturalPlan = r3.join(r4, NaturalJoin(FullOuter), None)
+    val usingPlan = r3.join(r4, UsingJoin(FullOuter, Seq(UnresolvedAttribute("b"))), None)
     val expected = r3.join(r4, FullOuter, Some(EqualTo(bNotNull, bNotNull))).select(
       Alias(Coalesce(Seq(bNotNull, bNotNull)), "b")(), a, c)
-    checkAnalysis(plan, expected)
+    checkAnalysis(naturalPlan, expected)
+    checkAnalysis(usingPlan, expected)
+  }
+
+  test("using unresolved attribute") {
+    val usingPlan = r1.join(r2, UsingJoin(Inner, Seq(UnresolvedAttribute("d"))), None)
+    val error = intercept[AnalysisException] {
+      SimpleAnalyzer.checkAnalysis(usingPlan)
+    }
+    assert(error.message.contains(
+      "using columns ['d] cannot be resolved given input columns: [b, a, c]"))
   }
 }
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/CatalystQlSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/CatalystQlSuite.scala
index 048b4f12b9edf9e14cf63a41138daa8ac8db49ba..c068e895b6643245f4aa9ea1860495863b7a4b76 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/CatalystQlSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/CatalystQlSuite.scala
@@ -219,4 +219,25 @@ class CatalystQlSuite extends PlanTest {
     parser.parsePlan("select * from t where a = (select b from s)")
     parser.parsePlan("select * from t group by g having a > (select b from s)")
   }
+
+  test("using clause in JOIN") {
+    // Tests parsing of the USING clause for different join types.
+    parser.parsePlan("select * from t1 join t2 using (c1)")
+    parser.parsePlan("select * from t1 join t2 using (c1, c2)")
+    parser.parsePlan("select * from t1 left join t2 using (c1, c2)")
+    parser.parsePlan("select * from t1 right join t2 using (c1, c2)")
+    parser.parsePlan("select * from t1 full outer join t2 using (c1, c2)")
+    parser.parsePlan("select * from t1 join t2 using (c1) join t3 using (c2)")
+    // Test error cases:
+    //   (1) an empty USING clause
+    //   (2) qualified columns in USING
+    //   (3) both ON and USING clauses
+    var error = intercept[AnalysisException](parser.parsePlan("select * from t1 join t2 using ()"))
+    assert(error.message.contains("cannot recognize input near ')'"))
+    error = intercept[AnalysisException](parser.parsePlan("select * from t1 join t2 using (t1.c1)"))
+    assert(error.message.contains("mismatched input '.'"))
+    error = intercept[AnalysisException](parser.parsePlan(
+      "select * from t1 join t2 using (c1) on t1.c1 = t2.c1"))
+    assert(error.message.contains("missing EOF at 'on' near ')'"))
+  }
 }
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
index ac2ca3c5a35d75ab1f1b678e15e9078abc76e826..75f1ffd51f6d6af1a059d44370e2b130b42b0b92 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
@@ -490,41 +490,12 @@ class Dataset[T] private[sql](
       Join(logicalPlan, right.logicalPlan, joinType = JoinType(joinType), None))
       .analyzed.asInstanceOf[Join]
 
-    val condition = usingColumns.map { col =>
-      catalyst.expressions.EqualTo(
-        withPlan(joined.left).resolve(col),
-        withPlan(joined.right).resolve(col))
-    }.reduceLeftOption[catalyst.expressions.BinaryExpression] { (cond, eqTo) =>
-      catalyst.expressions.And(cond, eqTo)
-    }
-
-    // Project only one of the join columns.
-    val joinedCols = JoinType(joinType) match {
-      case Inner | LeftOuter | LeftSemi =>
-        usingColumns.map(col => withPlan(joined.left).resolve(col))
-      case RightOuter =>
-        usingColumns.map(col => withPlan(joined.right).resolve(col))
-      case FullOuter =>
-        usingColumns.map { col =>
-          val leftCol = withPlan(joined.left).resolve(col).toAttribute.withNullability(true)
-          val rightCol = withPlan(joined.right).resolve(col).toAttribute.withNullability(true)
-          Alias(Coalesce(Seq(leftCol, rightCol)), col)()
-        }
-      case NaturalJoin(_) => sys.error("NaturalJoin with using clause is not supported.")
-    }
-    // The nullability of output of joined could be different than original column,
-    // so we can only compare them by exprId
-    val joinRefs = AttributeSet(condition.toSeq.flatMap(_.references))
-    val resultCols = joinedCols ++ joined.output.filterNot(joinRefs.contains(_))
     withPlan {
-      Project(
-        resultCols,
-        Join(
-          joined.left,
-          joined.right,
-          joinType = JoinType(joinType),
-          condition)
-      )
+      Join(
+        joined.left,
+        joined.right,
+        UsingJoin(JoinType(joinType), usingColumns.map(UnresolvedAttribute(_))),
+        None)
     }
   }
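With this change, the public `Dataset.join(right, usingColumns, joinType)` overload no longer builds the equality condition and projection by hand; it simply emits an unresolved `UsingJoin` and defers the column pruning to the analyzer rule above. A hedged usage sketch (the data and the surrounding session setup are made up):

```scala
// Assuming a SQLContext of this vintage with its implicits in scope:
import sqlContext.implicits._

val left = Seq((1, "x"), (2, "y")).toDF("k", "l")
val right = Seq((1, "a"), (3, "b")).toDF("k", "r")

// Both calls now route through UsingJoin, so the result carries a single
// `k` column, chosen per join type by commonNaturalJoinProcessing:
left.join(right, Seq("k")).show()
left.join(right, Seq("k"), "fullouter").show()
```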
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
index 3efe984c09eb86ff2a85d0373995ec01e8fba075..6716982118fed931e572f653eb82b78b46ff49bd 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
@@ -2179,4 +2179,68 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext {
         Row(4) :: Nil)
     }
   }
+
+  test("join with using clause") {
+    val df1 = Seq(
+      ("r1c1", "r1c2", "t1r1c3"),
+      ("r2c1", "r2c2", "t1r2c3"),
+      ("r3c1x", "r3c2", "t1r3c3")).toDF("c1", "c2", "c3")
+    val df2 = Seq(
+      ("r1c1", "r1c2", "t2r1c3"),
+      ("r2c1", "r2c2", "t2r2c3"),
+      ("r3c1y", "r3c2", "t2r3c3")).toDF("c1", "c2", "c3")
+    val df3 = Seq(
+      (null, "r1c2", "t3r1c3"),
+      ("r2c1", "r2c2", "t3r2c3"),
+      ("r3c1y", "r3c2", "t3r3c3")).toDF("c1", "c2", "c3")
+    withTempTable("t1", "t2", "t3") {
+      df1.registerTempTable("t1")
+      df2.registerTempTable("t2")
+      df3.registerTempTable("t3")
+
+      // Inner join with one using column.
+      checkAnswer(
+        sql("SELECT * FROM t1 join t2 using (c1)"),
+        Row("r1c1", "r1c2", "t1r1c3", "r1c2", "t2r1c3") ::
+          Row("r2c1", "r2c2", "t1r2c3", "r2c2", "t2r2c3") :: Nil)
+
+      // Inner join with two using columns.
+      checkAnswer(
+        sql("SELECT * FROM t1 join t2 using (c1, c2)"),
+        Row("r1c1", "r1c2", "t1r1c3", "t2r1c3") ::
+          Row("r2c1", "r2c2", "t1r2c3", "t2r2c3") :: Nil)
+
+      // Left outer join with one using column.
+      checkAnswer(
+        sql("SELECT * FROM t1 left join t2 using (c1)"),
+        Row("r1c1", "r1c2", "t1r1c3", "r1c2", "t2r1c3") ::
+          Row("r2c1", "r2c2", "t1r2c3", "r2c2", "t2r2c3") ::
+          Row("r3c1x", "r3c2", "t1r3c3", null, null) :: Nil)
+
+      // Right outer join with one using column.
+      checkAnswer(
+        sql("SELECT * FROM t1 right join t2 using (c1)"),
+        Row("r1c1", "r1c2", "t1r1c3", "r1c2", "t2r1c3") ::
+          Row("r2c1", "r2c2", "t1r2c3", "r2c2", "t2r2c3") ::
+          Row("r3c1y", null, null, "r3c2", "t2r3c3") :: Nil)
+
+      // Full outer join with one using column.
+      checkAnswer(
+        sql("SELECT * FROM t1 full outer join t2 using (c1)"),
+        Row("r1c1", "r1c2", "t1r1c3", "r1c2", "t2r1c3") ::
+          Row("r2c1", "r2c2", "t1r2c3", "r2c2", "t2r2c3") ::
+          Row("r3c1x", "r3c2", "t1r3c3", null, null) ::
+          Row("r3c1y", null, null, "r3c2", "t2r3c3") :: Nil)
+
+      // Full outer join with a null value in the join column.
+      checkAnswer(
+        sql("SELECT * FROM t1 full outer join t3 using (c1)"),
+        Row("r1c1", "r1c2", "t1r1c3", null, null) ::
+          Row("r2c1", "r2c2", "t1r2c3", "r2c2", "t3r2c3") ::
+          Row("r3c1x", "r3c2", "t1r3c3", null, null) ::
+          Row("r3c1y", null, null, "r3c2", "t3r3c3") ::
+          Row(null, null, null, "r1c2", "t3r1c3") :: Nil)
+
+      // Self join with using columns.
+      checkAnswer(
+        sql("SELECT * FROM t1 join t1 using (c1)"),
+        Row("r1c1", "r1c2", "t1r1c3", "r1c2", "t1r1c3") ::
+          Row("r2c1", "r2c2", "t1r2c3", "r2c2", "t1r2c3") ::
+          Row("r3c1x", "r3c2", "t1r3c3", "r3c2", "t1r3c3") :: Nil)
+    }
+  }
 }