From 6d0411b4f3a202cfb53f638ee5fd49072b42d3a6 Mon Sep 17 00:00:00 2001
From: Cheng Hao <hao.cheng@intel.com>
Date: Sun, 5 Jul 2015 21:50:52 -0700
Subject: [PATCH] [SQL][Minor] Update the DataFrame API for encode/decode

This is a the follow up of #6843.

Author: Cheng Hao <hao.cheng@intel.com>

Closes #7230 from chenghao-intel/str_funcs2_followup and squashes the following commits:

52cc553 [Cheng Hao] update the code as comment
---
 .../expressions/stringOperations.scala        | 21 ++++++++++---------
 .../org/apache/spark/sql/functions.scala      | 14 +++++++------
 .../spark/sql/DataFrameFunctionsSuite.scala   |  8 +++++--
 3 files changed, 25 insertions(+), 18 deletions(-)

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala
index 6de40629ff..1a14a7a449 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala
@@ -392,12 +392,13 @@ case class UnBase64(child: Expression) extends UnaryExpression with ExpectsInput
 /**
  * Decodes the first argument into a String using the provided character set
  * (one of 'US-ASCII', 'ISO-8859-1', 'UTF-8', 'UTF-16BE', 'UTF-16LE', 'UTF-16').
- * If either argument is null, the result will also be null. (As of Hive 0.12.0.).
+ * If either argument is null, the result will also be null.
  */
-case class Decode(bin: Expression, charset: Expression) extends Expression with ExpectsInputTypes {
-  override def children: Seq[Expression] = bin :: charset :: Nil
-  override def foldable: Boolean = bin.foldable && charset.foldable
-  override def nullable: Boolean = bin.nullable || charset.nullable
+case class Decode(bin: Expression, charset: Expression)
+  extends BinaryExpression with ExpectsInputTypes {
+
+  override def left: Expression = bin
+  override def right: Expression = charset
   override def dataType: DataType = StringType
   override def inputTypes: Seq[DataType] = Seq(BinaryType, StringType)
 
@@ -420,13 +421,13 @@ case class Decode(bin: Expression, charset: Expression) extends Expression with
 /**
  * Encodes the first argument into a BINARY using the provided character set
  * (one of 'US-ASCII', 'ISO-8859-1', 'UTF-8', 'UTF-16BE', 'UTF-16LE', 'UTF-16').
- * If either argument is null, the result will also be null. (As of Hive 0.12.0.)
+ * If either argument is null, the result will also be null.
 */
 case class Encode(value: Expression, charset: Expression)
-  extends Expression with ExpectsInputTypes {
-  override def children: Seq[Expression] = value :: charset :: Nil
-  override def foldable: Boolean = value.foldable && charset.foldable
-  override def nullable: Boolean = value.nullable || charset.nullable
+  extends BinaryExpression with ExpectsInputTypes {
+
+  override def left: Expression = value
+  override def right: Expression = charset
   override def dataType: DataType = BinaryType
   override def inputTypes: Seq[DataType] = Seq(StringType, StringType)
 
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
index abcfc0b650..f80291776f 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
@@ -1666,18 +1666,19 @@ object functions {
    * @group string_funcs
    * @since 1.5.0
    */
-  def encode(value: Column, charset: Column): Column = Encode(value.expr, charset.expr)
+  def encode(value: Column, charset: String): Column = Encode(value.expr, lit(charset).expr)
 
   /**
    * Computes the first argument into a binary from a string using the provided character set
    * (one of 'US-ASCII', 'ISO-8859-1', 'UTF-8', 'UTF-16BE', 'UTF-16LE', 'UTF-16').
    * If either argument is null, the result will also be null.
+   * NOTE: charset represents the string value of the character set, not the column name.
    *
    * @group string_funcs
    * @since 1.5.0
    */
-  def encode(columnName: String, charsetColumnName: String): Column =
-    encode(Column(columnName), Column(charsetColumnName))
+  def encode(columnName: String, charset: String): Column =
+    encode(Column(columnName), charset)
 
   /**
    * Computes the first argument into a string from a binary using the provided character set
@@ -1687,18 +1688,19 @@ object functions {
    * @group string_funcs
    * @since 1.5.0
    */
-  def decode(value: Column, charset: Column): Column = Decode(value.expr, charset.expr)
+  def decode(value: Column, charset: String): Column = Decode(value.expr, lit(charset).expr)
 
   /**
    * Computes the first argument into a string from a binary using the provided character set
    * (one of 'US-ASCII', 'ISO-8859-1', 'UTF-8', 'UTF-16BE', 'UTF-16LE', 'UTF-16').
    * If either argument is null, the result will also be null.
+   * NOTE: charset represents the string value of the character set, not the column name.
    *
    * @group string_funcs
    * @since 1.5.0
    */
-  def decode(columnName: String, charsetColumnName: String): Column =
-    decode(Column(columnName), Column(charsetColumnName))
+  def decode(columnName: String, charset: String): Column =
+    decode(Column(columnName), charset)
 
 
   //////////////////////////////////////////////////////////////////////////////////////////////
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala
index bc455a922d..afba28515e 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala
@@ -261,11 +261,15 @@ class DataFrameFunctionsSuite extends QueryTest {
     // non ascii characters are not allowed in the code, so we disable the scalastyle here.
     val df = Seq(("大千世界", "utf-8", bytes)).toDF("a", "b", "c")
     checkAnswer(
-      df.select(encode($"a", $"b"), encode("a", "b"), decode($"c", $"b"), decode("c", "b")),
+      df.select(
+        encode($"a", "utf-8"),
+        encode("a", "utf-8"),
+        decode($"c", "utf-8"),
+        decode("c", "utf-8")),
       Row(bytes, bytes, "大千世界", "大千世界"))
 
     checkAnswer(
-      df.selectExpr("encode(a, b)", "decode(c, b)"),
+      df.selectExpr("encode(a, 'utf-8')", "decode(c, 'utf-8')"),
       Row(bytes, "大千世界"))
     // scalastyle:on
   }
-- 
GitLab