From 6702324b60f99dab55912c08ccd3d03805f6b7b2 Mon Sep 17 00:00:00 2001
From: Liang-Chi Hsieh <viirya@gmail.com>
Date: Thu, 30 Apr 2015 15:13:43 -0700
Subject: [PATCH] [SPARK-7196][SQL] Support precision and scale of decimal type
 for JDBC

JIRA: https://issues.apache.org/jira/browse/SPARK-7196

Author: Liang-Chi Hsieh <viirya@gmail.com>

Closes #5777 from viirya/jdbc_precision and squashes the following commits:

f40f5e6 [Liang-Chi Hsieh] Support precision and scale for NUMERIC type.
49acbf9 [Liang-Chi Hsieh] Add unit test.
a509e19 [Liang-Chi Hsieh] Support precision and scale of decimal type for JDBC.
---
 .../main/scala/org/apache/spark/sql/jdbc/JDBCRDD.scala | 10 ++++++++--
 .../scala/org/apache/spark/sql/jdbc/JDBCSuite.scala    |  2 ++
 2 files changed, 10 insertions(+), 2 deletions(-)

diff --git a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JDBCRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JDBCRDD.scala
index f3b5455574..cef92abbdc 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JDBCRDD.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JDBCRDD.scala
@@ -37,7 +37,7 @@ private[sql] object JDBCRDD extends Logging {
    * @param sqlType - A field of java.sql.Types
    * @return The Catalyst type corresponding to sqlType.
    */
-  private def getCatalystType(sqlType: Int): DataType = {
+  private def getCatalystType(sqlType: Int, precision: Int, scale: Int): DataType = {
     val answer = sqlType match {
       case java.sql.Types.ARRAY         => null
       case java.sql.Types.BIGINT        => LongType
@@ -49,6 +49,8 @@ private[sql] object JDBCRDD extends Logging {
       case java.sql.Types.CLOB          => StringType
       case java.sql.Types.DATALINK      => null
       case java.sql.Types.DATE          => DateType
+      case java.sql.Types.DECIMAL
+        if precision != 0 || scale != 0 => DecimalType(precision, scale)
       case java.sql.Types.DECIMAL       => DecimalType.Unlimited
       case java.sql.Types.DISTINCT      => null
       case java.sql.Types.DOUBLE        => DoubleType
@@ -61,6 +63,8 @@ private[sql] object JDBCRDD extends Logging {
       case java.sql.Types.NCHAR         => StringType
       case java.sql.Types.NCLOB         => StringType
       case java.sql.Types.NULL          => null
+      case java.sql.Types.NUMERIC
+        if precision != 0 || scale != 0 => DecimalType(precision, scale)
       case java.sql.Types.NUMERIC       => DecimalType.Unlimited
       case java.sql.Types.NVARCHAR      => StringType
       case java.sql.Types.OTHER         => null
@@ -109,10 +113,11 @@ private[sql] object JDBCRDD extends Logging {
           val dataType = rsmd.getColumnType(i + 1)
           val typeName = rsmd.getColumnTypeName(i + 1)
           val fieldSize = rsmd.getPrecision(i + 1)
+          val fieldScale = rsmd.getScale(i + 1)
           val nullable = rsmd.isNullable(i + 1) != ResultSetMetaData.columnNoNulls
           val metadata = new MetadataBuilder().putString("name", columnName)
           var columnType = quirks.getCatalystType(dataType, typeName, fieldSize, metadata)
-          if (columnType == null) columnType = getCatalystType(dataType)
+          if (columnType == null) columnType = getCatalystType(dataType, fieldSize, fieldScale)
           fields(i) = StructField(columnName, columnType, nullable, metadata.build())
           i = i + 1
         }
@@ -307,6 +312,7 @@ private[sql] class JDBCRDD(
       case BooleanType           => BooleanConversion
       case DateType              => DateConversion
       case DecimalType.Unlimited => DecimalConversion
+      case DecimalType.Fixed(d)  => DecimalConversion
       case DoubleType            => DoubleConversion
       case FloatType             => FloatConversion
       case IntegerType           => IntegerConversion
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala
index 856a806781..b165ab2b1d 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala
@@ -22,6 +22,7 @@ import java.sql.DriverManager
 import java.util.{Calendar, GregorianCalendar, Properties}
 
 import org.apache.spark.sql.test._
+import org.apache.spark.sql.types._
 import org.h2.jdbc.JdbcSQLException
 import org.scalatest.{FunSuite, BeforeAndAfter}
 import TestSQLContext._
@@ -271,6 +272,7 @@ class JDBCSuite extends FunSuite with BeforeAndAfter {
     assert(rows(0).getDouble(1) === 1.00000011920928955) // Yes, I meant ==.
     assert(rows(0).getAs[BigDecimal](2)
         .equals(new BigDecimal("123456789012345.54321543215432100000")))
+    assert(rows(0).schema.fields(2).dataType === DecimalType(40, 20))
   }
 
   test("SQL query as table name") {
-- 
GitLab