diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/rows.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/rows.scala index 8df150e2f855ffc97e1d71d79b841036310a35f0..73ec7a6d114f540f058f829d286edb718b5c3c5f 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/rows.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/rows.scala @@ -114,7 +114,6 @@ class GenericRow(protected[sql] val values: Array[Any]) extends Row { } override def getString(i: Int): String = { - if (values(i) == null) sys.error("Failed to check null bit for primitive String value.") values(i).asInstanceOf[String] } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Column.scala b/sql/core/src/main/scala/org/apache/spark/sql/Column.scala index 7fc8347428df489489f6703e9b00946802df5979..7f20cf8d767978aadb96e1ec96837ec2a63a8537 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/Column.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/Column.scala @@ -252,7 +252,10 @@ class Column( /** * Equality test with an expression that is safe for null values. */ - override def <=> (other: Column): Column = EqualNullSafe(expr, other.expr) + override def <=> (other: Column): Column = other match { + case null => EqualNullSafe(expr, Literal.anyToLiteral(null).expr) + case _ => EqualNullSafe(expr, other.expr) + } /** * Equality test with a literal value that is safe for null values. diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala index d0bb3640f8c1ccc90f6425d5431f06730e931625..3198215b2c3abb229430e22807777552be4c5b58 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala @@ -230,9 +230,12 @@ class DataFrame protected[sql]( /** * Selecting a single column and return it as a [[Column]]. */ - override def apply(colName: String): Column = { - val expr = resolve(colName) - new Column(Some(sqlContext), Some(Project(Seq(expr), logicalPlan)), expr) + override def apply(colName: String): Column = colName match { + case "*" => + Column("*") + case _ => + val expr = resolve(colName) + new Column(Some(sqlContext), Some(Project(Seq(expr), logicalPlan)), expr) } /** diff --git a/sql/core/src/main/scala/org/apache/spark/sql/dsl/package.scala b/sql/core/src/main/scala/org/apache/spark/sql/dsl/package.scala index 29c3d26ae56d91332bbe41936d0fb5d212a40294..4c44e178b9976e0499ba022dfe3cda9a1e24c39a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/dsl/package.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/dsl/package.scala @@ -53,6 +53,7 @@ package object dsl { def last(e: Column): Column = Last(e.expr) def min(e: Column): Column = Min(e.expr) def max(e: Column): Column = Max(e.expr) + def upper(e: Column): Column = Upper(e.expr) def lower(e: Column): Column = Lower(e.expr) def sqrt(e: Column): Column = Sqrt(e.expr) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala new file mode 100644 index 0000000000000000000000000000000000000000..825a1862ba6fffdefabcc15a71ee0636cfcaa81c --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala @@ -0,0 +1,302 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql + +import org.apache.spark.sql.dsl._ +import org.apache.spark.sql.test.TestSQLContext +import org.apache.spark.sql.types.{BooleanType, IntegerType, StructField, StructType} + + +class ColumnExpressionSuite extends QueryTest { + import org.apache.spark.sql.TestData._ + + // TODO: Add test cases for bitwise operations. + + test("star") { + checkAnswer(testData.select($"*"), testData.collect().toSeq) + } + + ignore("star qualified by data frame object") { + // This is not yet supported. + val df = testData.toDF + checkAnswer(df.select(df("*")), df.collect().toSeq) + } + + test("star qualified by table name") { + checkAnswer(testData.as("testData").select($"testData.*"), testData.collect().toSeq) + } + + test("+") { + checkAnswer( + testData2.select($"a" + 1), + testData2.collect().toSeq.map(r => Row(r.getInt(0) + 1))) + + checkAnswer( + testData2.select($"a" + $"b" + 2), + testData2.collect().toSeq.map(r => Row(r.getInt(0) + r.getInt(1) + 2))) + } + + test("-") { + checkAnswer( + testData2.select($"a" - 1), + testData2.collect().toSeq.map(r => Row(r.getInt(0) - 1))) + + checkAnswer( + testData2.select($"a" - $"b" - 2), + testData2.collect().toSeq.map(r => Row(r.getInt(0) - r.getInt(1) - 2))) + } + + test("*") { + checkAnswer( + testData2.select($"a" * 10), + testData2.collect().toSeq.map(r => Row(r.getInt(0) * 10))) + + checkAnswer( + testData2.select($"a" * $"b"), + testData2.collect().toSeq.map(r => Row(r.getInt(0) * r.getInt(1)))) + } + + test("/") { + checkAnswer( + testData2.select($"a" / 2), + testData2.collect().toSeq.map(r => Row(r.getInt(0).toDouble / 2))) + + checkAnswer( + testData2.select($"a" / $"b"), + testData2.collect().toSeq.map(r => Row(r.getInt(0).toDouble / r.getInt(1)))) + } + + + test("%") { + checkAnswer( + testData2.select($"a" % 2), + testData2.collect().toSeq.map(r => Row(r.getInt(0) % 2))) + + checkAnswer( + testData2.select($"a" % $"b"), + testData2.collect().toSeq.map(r => Row(r.getInt(0) % r.getInt(1)))) + } + + test("unary -") { + checkAnswer( + testData2.select(-$"a"), + testData2.collect().toSeq.map(r => Row(-r.getInt(0)))) + } + + test("unary !") { + checkAnswer( + complexData.select(!$"b"), + complexData.collect().toSeq.map(r => Row(!r.getBoolean(3)))) + } + + test("isNull") { + checkAnswer( + nullStrings.toDF.where($"s".isNull), + nullStrings.collect().toSeq.filter(r => r.getString(1) eq null)) + } + + test("isNotNull") { + checkAnswer( + nullStrings.toDF.where($"s".isNotNull), + nullStrings.collect().toSeq.filter(r => r.getString(1) ne null)) + } + + test("===") { + checkAnswer( + testData2.filter($"a" === 1), + testData2.collect().toSeq.filter(r => r.getInt(0) == 1)) + + checkAnswer( + testData2.filter($"a" === $"b"), + testData2.collect().toSeq.filter(r => r.getInt(0) == r.getInt(1))) + } + + test("<=>") { + checkAnswer( + testData2.filter($"a" === 1), + testData2.collect().toSeq.filter(r => r.getInt(0) == 1)) + + checkAnswer( + testData2.filter($"a" === $"b"), + testData2.collect().toSeq.filter(r => r.getInt(0) == r.getInt(1))) + } + + test("!==") { + val nullData = TestSQLContext.applySchema(TestSQLContext.sparkContext.parallelize( + Row(1, 1) :: + Row(1, 2) :: + Row(1, null) :: + Row(null, null) :: Nil), + StructType(Seq(StructField("a", IntegerType), StructField("b", IntegerType)))) + + checkAnswer( + nullData.filter($"b" <=> 1), + Row(1, 1) :: Nil) + + checkAnswer( + nullData.filter($"b" <=> null), + Row(1, null) :: Row(null, null) :: Nil) + + checkAnswer( + nullData.filter($"a" <=> $"b"), + Row(1, 1) :: Row(null, null) :: Nil) + } + + test(">") { + checkAnswer( + testData2.filter($"a" > 1), + testData2.collect().toSeq.filter(r => r.getInt(0) > 1)) + + checkAnswer( + testData2.filter($"a" > $"b"), + testData2.collect().toSeq.filter(r => r.getInt(0) > r.getInt(1))) + } + + test(">=") { + checkAnswer( + testData2.filter($"a" >= 1), + testData2.collect().toSeq.filter(r => r.getInt(0) >= 1)) + + checkAnswer( + testData2.filter($"a" >= $"b"), + testData2.collect().toSeq.filter(r => r.getInt(0) >= r.getInt(1))) + } + + test("<") { + checkAnswer( + testData2.filter($"a" < 2), + testData2.collect().toSeq.filter(r => r.getInt(0) < 2)) + + checkAnswer( + testData2.filter($"a" < $"b"), + testData2.collect().toSeq.filter(r => r.getInt(0) < r.getInt(1))) + } + + test("<=") { + checkAnswer( + testData2.filter($"a" <= 2), + testData2.collect().toSeq.filter(r => r.getInt(0) <= 2)) + + checkAnswer( + testData2.filter($"a" <= $"b"), + testData2.collect().toSeq.filter(r => r.getInt(0) <= r.getInt(1))) + } + + val booleanData = TestSQLContext.applySchema(TestSQLContext.sparkContext.parallelize( + Row(false, false) :: + Row(false, true) :: + Row(true, false) :: + Row(true, true) :: Nil), + StructType(Seq(StructField("a", BooleanType), StructField("b", BooleanType)))) + + test("&&") { + checkAnswer( + booleanData.filter($"a" && true), + Row(true, false) :: Row(true, true) :: Nil) + + checkAnswer( + booleanData.filter($"a" && false), + Nil) + + checkAnswer( + booleanData.filter($"a" && $"b"), + Row(true, true) :: Nil) + } + + test("||") { + checkAnswer( + booleanData.filter($"a" || true), + booleanData.collect()) + + checkAnswer( + booleanData.filter($"a" || false), + Row(true, false) :: Row(true, true) :: Nil) + + checkAnswer( + booleanData.filter($"a" || $"b"), + Row(false, true) :: Row(true, false) :: Row(true, true) :: Nil) + } + + test("sqrt") { + checkAnswer( + testData.select(sqrt('key)).orderBy('key.asc), + (1 to 100).map(n => Row(math.sqrt(n))) + ) + + checkAnswer( + testData.select(sqrt('value), 'key).orderBy('key.asc, 'value.asc), + (1 to 100).map(n => Row(math.sqrt(n), n)) + ) + + checkAnswer( + testData.select(sqrt(Literal(null))), + (1 to 100).map(_ => Row(null)) + ) + } + + test("abs") { + checkAnswer( + testData.select(abs('key)).orderBy('key.asc), + (1 to 100).map(n => Row(n)) + ) + + checkAnswer( + negativeData.select(abs('key)).orderBy('key.desc), + (1 to 100).map(n => Row(n)) + ) + + checkAnswer( + testData.select(abs(Literal(null))), + (1 to 100).map(_ => Row(null)) + ) + } + + test("upper") { + checkAnswer( + lowerCaseData.select(upper('l)), + ('a' to 'd').map(c => Row(c.toString.toUpperCase)) + ) + + checkAnswer( + testData.select(upper('value), 'key), + (1 to 100).map(n => Row(n.toString, n)) + ) + + checkAnswer( + testData.select(upper(Literal(null))), + (1 to 100).map(n => Row(null)) + ) + } + + test("lower") { + checkAnswer( + upperCaseData.select(lower('L)), + ('A' to 'F').map(c => Row(c.toString.toLowerCase)) + ) + + checkAnswer( + testData.select(lower('value), 'key), + (1 to 100).map(n => Row(n.toString, n)) + ) + + checkAnswer( + testData.select(lower(Literal(null))), + (1 to 100).map(n => Row(null)) + ) + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DslQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala similarity index 82% rename from sql/core/src/test/scala/org/apache/spark/sql/DslQuerySuite.scala rename to sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala index a5848f219cea9941e927660bb0e7354ca4fdd00d..6d7d5aa49358b1e81ed70ad321b86f4c1061eb03 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DslQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala @@ -25,7 +25,7 @@ import org.apache.spark.sql.test.TestSQLContext._ import scala.language.postfixOps -class DslQuerySuite extends QueryTest { +class DataFrameSuite extends QueryTest { import org.apache.spark.sql.TestData._ test("table scan") { @@ -276,71 +276,5 @@ class DslQuerySuite extends QueryTest { ) } - test("sqrt") { - checkAnswer( - testData.select(sqrt('key)).orderBy('key asc), - (1 to 100).map(n => Row(math.sqrt(n))) - ) - - checkAnswer( - testData.select(sqrt('value), 'key).orderBy('key asc, 'value asc), - (1 to 100).map(n => Row(math.sqrt(n), n)) - ) - - checkAnswer( - testData.select(sqrt(Literal(null))), - (1 to 100).map(_ => Row(null)) - ) - } - - test("abs") { - checkAnswer( - testData.select(abs('key)).orderBy('key asc), - (1 to 100).map(n => Row(n)) - ) - - checkAnswer( - negativeData.select(abs('key)).orderBy('key desc), - (1 to 100).map(n => Row(n)) - ) - - checkAnswer( - testData.select(abs(Literal(null))), - (1 to 100).map(_ => Row(null)) - ) - } - test("upper") { - checkAnswer( - lowerCaseData.select(upper('l)), - ('a' to 'd').map(c => Row(c.toString.toUpperCase)) - ) - - checkAnswer( - testData.select(upper('value), 'key), - (1 to 100).map(n => Row(n.toString, n)) - ) - - checkAnswer( - testData.select(upper(Literal(null))), - (1 to 100).map(n => Row(null)) - ) - } - - test("lower") { - checkAnswer( - upperCaseData.select(lower('L)), - ('A' to 'F').map(c => Row(c.toString.toLowerCase)) - ) - - checkAnswer( - testData.select(lower('value), 'key), - (1 to 100).map(n => Row(n.toString, n)) - ) - - checkAnswer( - testData.select(lower(Literal(null))), - (1 to 100).map(n => Row(null)) - ) - } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/TestData.scala b/sql/core/src/test/scala/org/apache/spark/sql/TestData.scala index fffa2b7dfa6e1ea8ebecc43d78f03107c509b2d3..9eefe67c0443439c41fb7711675fc855e3026649 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/TestData.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/TestData.scala @@ -161,7 +161,7 @@ object TestData { TestSQLContext.sparkContext.parallelize( NullStrings(1, "abc") :: NullStrings(2, "ABC") :: - NullStrings(3, null) :: Nil) + NullStrings(3, null) :: Nil).toDF nullStrings.registerTempTable("nullStrings") case class TableName(tableName: String)