diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala index 8b6662ab1fae6f1c1f047b99d7170d9ef6f1c135..31000dc41be6af1c07175729680bf7f492ed53ec 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala @@ -1703,8 +1703,11 @@ class Dataset[T] private[sql]( } /** - * Returns a new [[Dataset]] with a column dropped. - * This is a no-op if schema doesn't contain column name. + * Returns a new [[Dataset]] with a column dropped. This is a no-op if schema doesn't contain + * column name. + * + * This method can only be used to drop top level columns. the colName string is treated + * literally without further interpretation. * * @group untypedrel * @since 2.0.0 @@ -1717,15 +1720,20 @@ class Dataset[T] private[sql]( * Returns a new [[Dataset]] with columns dropped. * This is a no-op if schema doesn't contain column name(s). * + * This method can only be used to drop top level columns. the colName string is treated literally + * without further interpretation. + * * @group untypedrel * @since 2.0.0 */ @scala.annotation.varargs def drop(colNames: String*): DataFrame = { val resolver = sparkSession.sessionState.analyzer.resolver - val remainingCols = - schema.filter(f => colNames.forall(n => !resolver(f.name, n))).map(f => Column(f.name)) - if (remainingCols.size == this.schema.size) { + val allColumns = queryExecution.analyzed.output + val remainingCols = allColumns.filter { attribute => + colNames.forall(n => !resolver(attribute.name, n)) + }.map(attribute => Column(attribute)) + if (remainingCols.size == allColumns.size) { toDF() } else { this.select(remainingCols: _*) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala index e2dc4d86395ee85d67400ea5d313b363bd8331f5..0e18ade09cbe5dc5a63e8dde4fa28a1e73ecb3e9 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala @@ -609,6 +609,27 @@ class DataFrameSuite extends QueryTest with SharedSQLContext { assert(df("id") == person("id")) } + test("drop top level columns that contains dot") { + val df1 = Seq((1, 2)).toDF("a.b", "a.c") + checkAnswer(df1.drop("a.b"), Row(2)) + + // Creates data set: {"a.b": 1, "a": {"b": 3}} + val df2 = Seq((1)).toDF("a.b").withColumn("a", struct(lit(3) as "b")) + // Not like select(), drop() parses the column name "a.b" literally without interpreting "." + checkAnswer(df2.drop("a.b").select("a.b"), Row(3)) + + // "`" is treated as a normal char here with no interpreting, "`a`b" is a valid column name. + assert(df2.drop("`a.b`").columns.size == 2) + } + + test("drop(name: String) search and drop all top level columns that matchs the name") { + val df1 = Seq((1, 2)).toDF("a", "b") + val df2 = Seq((3, 4)).toDF("a", "b") + checkAnswer(df1.join(df2), Row(1, 2, 3, 4)) + // Finds and drops all columns that match the name (case insensitive). + checkAnswer(df1.join(df2).drop("A"), Row(2, 4)) + } + test("withColumnRenamed") { val df = testData.toDF().withColumn("newCol", col("key") + 1) .withColumnRenamed("value", "valueRenamed")