From 98ddbe6cdbe4141df3d008dcb675ecd682c97492 Mon Sep 17 00:00:00 2001
From: Zdenek Farana <zdenek.farana@gmail.com>
Date: Fri, 29 Aug 2014 15:39:15 -0700
Subject: [PATCH] [SPARK-3173][SQL] Timestamp support in the parser

If you have a table with a TIMESTAMP column, that column can't be used properly in a WHERE clause: comparisons against string literals are not evaluated correctly. [More](https://issues.apache.org/jira/browse/SPARK-3173)
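A minimal sketch of the failure mode, written against the Spark 1.x API of the time (the `TimestampRow` class, app name, and `timestamps` table are illustrative, not the suite's actual `TestData`):

```scala
import java.sql.Timestamp

import org.apache.spark.{SparkConf, SparkContext}
import org.apache.spark.sql.SQLContext

// Case classes used for schema inference must be top-level.
case class TimestampRow(time: Timestamp)

object Spark3173Repro extends App {
  val sc = new SparkContext(new SparkConf().setMaster("local").setAppName("SPARK-3173"))
  val sqlContext = new SQLContext(sc)
  import sqlContext.createSchemaRDD // implicit RDD[Product] => SchemaRDD

  sc.parallelize(Seq(TimestampRow(Timestamp.valueOf("1970-01-01 00:00:00.001"))))
    .registerTempTable("timestamps")

  // Before this patch, the string side of the comparison was coerced to
  // DoubleType, so the filter evaluated incorrectly and returned no rows.
  sqlContext.sql(
    "SELECT time FROM timestamps WHERE time = '1970-01-01 00:00:00.001'"
  ).collect().foreach(println)
}
```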

Motivation: http://www.aproint.com/aggregation-with-spark-sql/

- [x] modify SqlParser so it supports casting to TIMESTAMP (a workaround for item 2 below)
- [x] convert a string literal into Timestamp when the column it is compared against is Timestamp (both query shapes are shown in the sketch below)
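Both query shapes, taken from the new `SQLQuerySuite` cases (a sketch assuming the `timestamps` table above and `sql` in scope, as after `import TestSQLContext._`):

```scala
// Item 1: SqlParser now accepts TIMESTAMP as a CAST target type.
sql("SELECT time FROM timestamps WHERE time = CAST('1970-01-01 00:00:00.001' AS TIMESTAMP)")

// Item 2: HiveTypeCoercion casts the string side to TimestampType
// automatically, so the explicit CAST can be dropped.
sql("SELECT time FROM timestamps WHERE time = '1970-01-01 00:00:00.001'")

// The same coercion covers IN lists made of string literals.
sql("SELECT time FROM timestamps WHERE time IN ('1970-01-01 00:00:00.001', '1970-01-01 00:00:00.002')")
```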

Author: Zdenek Farana <zdenek.farana@gmail.com>
Author: Zdenek Farana <zdenek.farana@aproint.com>

Closes #2084 from byF/SPARK-3173 and squashes the following commits:

442b59d [Zdenek Farana] Fixed test merge conflict
2dbf4f6 [Zdenek Farana] Merge remote-tracking branch 'origin/SPARK-3173' into SPARK-3173
65b6215 [Zdenek Farana] Fixed timezone sensitivity in the test
47b27b4 [Zdenek Farana] Now works in the case of "StringLiteral=TimestampColumn"
96a661b [Zdenek Farana] Code style change
491dfcf [Zdenek Farana] Added test cases for SPARK-3173
4446b1e [Zdenek Farana] A string literal is cast into Timestamp when the column is Timestamp.
59af397 [Zdenek Farana] Added a new TIMESTAMP keyword; CAST to TIMESTAMP can now be used in SQL expressions.
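For readers new to Catalyst, the `HiveTypeCoercion` change below follows the analyzer's usual pattern-match-and-rewrite style. A standalone toy sketch of the idea, deliberately using made-up types rather than the real Catalyst classes:

```scala
object CoercionSketch {
  sealed trait Expr { def dataType: String }
  case class Column(name: String, dataType: String) extends Expr
  case class Literal(value: String) extends Expr { val dataType = "string" }
  case class Cast(child: Expr, dataType: String) extends Expr
  case class Equals(left: Expr, right: Expr) extends Expr { val dataType = "boolean" }

  // Rewrite string-vs-timestamp comparisons so the string side is cast to
  // timestamp, taking precedence over the old catch-all string-to-double rule.
  def coerce(e: Expr): Expr = e match {
    case Equals(l, r) if l.dataType == "string" && r.dataType == "timestamp" =>
      Equals(Cast(l, "timestamp"), r)
    case Equals(l, r) if l.dataType == "timestamp" && r.dataType == "string" =>
      Equals(l, Cast(r, "timestamp"))
    case other => other
  }
}

// CoercionSketch.coerce(Equals(Column("time", "timestamp"), Literal("123")))
// => Equals(Column(time,timestamp), Cast(Literal(123),timestamp))
```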
---
 .../apache/spark/sql/catalyst/SqlParser.scala |  3 +-
 .../catalyst/analysis/HiveTypeCoercion.scala  | 10 +++++
 .../org/apache/spark/sql/SQLQuerySuite.scala  | 43 ++++++++++++++++++-
 3 files changed, 54 insertions(+), 2 deletions(-)

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala
index 4f166c06b6..a88bd859fc 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala
@@ -114,6 +114,7 @@ class SqlParser extends StandardTokenParsers with PackratParsers {
   protected val STRING = Keyword("STRING")
   protected val SUM = Keyword("SUM")
   protected val TABLE = Keyword("TABLE")
+  protected val TIMESTAMP = Keyword("TIMESTAMP")
   protected val TRUE = Keyword("TRUE")
   protected val UNCACHE = Keyword("UNCACHE")
   protected val UNION = Keyword("UNION")
@@ -359,7 +360,7 @@ class SqlParser extends StandardTokenParsers with PackratParsers {
     literal
 
   protected lazy val dataType: Parser[DataType] =
-    STRING ^^^ StringType
+    STRING ^^^ StringType | TIMESTAMP ^^^ TimestampType
 }
 
 class SqlLexical(val keywords: Seq[String]) extends StdLexical {
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala
index ecfcd62d20..d6758eb5b6 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala
@@ -218,11 +218,21 @@ trait HiveTypeCoercion {
       case a: BinaryArithmetic if a.right.dataType == StringType =>
         a.makeCopy(Array(a.left, Cast(a.right, DoubleType)))
 
+      case p: BinaryPredicate if p.left.dataType == StringType
+        && p.right.dataType == TimestampType =>
+        p.makeCopy(Array(Cast(p.left, TimestampType), p.right))
+      case p: BinaryPredicate if p.left.dataType == TimestampType
+        && p.right.dataType == StringType =>
+        p.makeCopy(Array(p.left, Cast(p.right, TimestampType)))
+
       case p: BinaryPredicate if p.left.dataType == StringType && p.right.dataType != StringType =>
         p.makeCopy(Array(Cast(p.left, DoubleType), p.right))
       case p: BinaryPredicate if p.left.dataType != StringType && p.right.dataType == StringType =>
         p.makeCopy(Array(p.left, Cast(p.right, DoubleType)))
 
+      case i @ In(a, b) if a.dataType == TimestampType && b.forall(_.dataType == StringType) =>
+        i.makeCopy(Array(a, b.map(Cast(_, TimestampType))))
+
       case Sum(e) if e.dataType == StringType =>
         Sum(Cast(e, DoubleType))
       case Average(e) if e.dataType == StringType =>
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
index 4047bc0672..1ac2059377 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
@@ -19,15 +19,28 @@ package org.apache.spark.sql
 
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.test._
+import org.scalatest.BeforeAndAfterAll
+import java.util.TimeZone
 
 /* Implicits */
 import TestSQLContext._
 import TestData._
 
-class SQLQuerySuite extends QueryTest {
+class SQLQuerySuite extends QueryTest with BeforeAndAfterAll {
   // Make sure the tables are loaded.
   TestData
 
+  var origZone: TimeZone = _
+  override protected def beforeAll() {
+    origZone = TimeZone.getDefault
+    TimeZone.setDefault(TimeZone.getTimeZone("UTC"))
+  }
+
+  override protected def afterAll() {
+    TimeZone.setDefault(origZone)
+  }
+
+
   test("SPARK-2041 column name equals tablename") {
     checkAnswer(
       sql("SELECT tableName FROM tableName"),
@@ -63,6 +76,34 @@ class SQLQuerySuite extends QueryTest {
       "st")
   }
 
+  test("SPARK-3173 Timestamp support in the parser") {
+    checkAnswer(sql(
+      "SELECT time FROM timestamps WHERE time=CAST('1970-01-01 00:00:00.001' AS TIMESTAMP)"),
+      Seq(Seq(java.sql.Timestamp.valueOf("1970-01-01 00:00:00.001"))))
+
+    checkAnswer(sql(
+      "SELECT time FROM timestamps WHERE time='1970-01-01 00:00:00.001'"),
+      Seq(Seq(java.sql.Timestamp.valueOf("1970-01-01 00:00:00.001"))))
+
+    checkAnswer(sql(
+      "SELECT time FROM timestamps WHERE '1970-01-01 00:00:00.001'=time"),
+      Seq(Seq(java.sql.Timestamp.valueOf("1970-01-01 00:00:00.001"))))
+
+    checkAnswer(sql(
+      """SELECT time FROM timestamps WHERE time<'1970-01-01 00:00:00.003'
+          AND time>'1970-01-01 00:00:00.001'"""),
+      Seq(Seq(java.sql.Timestamp.valueOf("1970-01-01 00:00:00.002"))))
+
+    checkAnswer(sql(
+      "SELECT time FROM timestamps WHERE time IN ('1970-01-01 00:00:00.001','1970-01-01 00:00:00.002')"),
+      Seq(Seq(java.sql.Timestamp.valueOf("1970-01-01 00:00:00.001")),
+        Seq(java.sql.Timestamp.valueOf("1970-01-01 00:00:00.002"))))
+
+    checkAnswer(sql(
+      "SELECT time FROM timestamps WHERE time='123'"),
+      Nil)
+  }
+
   test("index into array") {
     checkAnswer(
       sql("SELECT data, data[0], data[0] + data[1], data[0 + 1] FROM arrayData"),
-- 
GitLab