From 6cbc61d1070584ffbc34b1f53df352c9162f414a Mon Sep 17 00:00:00 2001
From: Ruben Berenguel Montoro <ruben@mostlymaths.net>
Date: Sat, 3 Jun 2017 14:56:42 +0900
Subject: [PATCH] [SPARK-19732][SQL][PYSPARK] Add fill functions for nulls in
 bool fields of datasets

## What changes were proposed in this pull request?

Allow fill/replace of NA (null) values with booleans, in both the Python and Scala APIs.
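
With the patch applied, a minimal PySpark sketch of the new behaviour (the DataFrame and column names are illustrative, mirroring the doctest added below):

```python
from pyspark.sql import Row, SparkSession

spark = SparkSession.builder.getOrCreate()

# Boolean column `spy` containing a null, mirroring the new doctest data.
df = spark.createDataFrame([Row(name='Alice', spy=False, age=10),
                            Row(name='Bob', spy=None, age=5),
                            Row(name='Mallory', spy=True, age=None)])

# Fills nulls in boolean columns only; numeric and string columns are left untouched.
df.na.fill(False).show()

# The same fill restricted to a subset of columns.
df.fillna(True, subset=['spy']).show()
```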

## How was this patch tested?

Unit tests, doctests
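
The tests also guard against a Python-specific subtlety: `bool` is a subclass of `int`, so the numeric path in `fillna` must explicitly skip booleans, otherwise `fillna(True)` would be coerced to `1.0` and applied to numeric columns. A standalone illustration (plain Python, no Spark required):

```python
# bool is a subclass of int, so a bare isinstance(value, (int, long)) check
# would also match True/False and coerce them to floats.
print(isinstance(True, int))   # True
print(float(True))             # 1.0 -- the coercion the patch must avoid for booleans
print(isinstance(True, bool))  # True -- hence the explicit bool check comes first
```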

This PR is my original work, and I license it to the Spark project.

Author: Ruben Berenguel Montoro <ruben@mostlymaths.net>
Author: Ruben Berenguel <ruben@mostlymaths.net>

Closes #18164 from rberenguel/SPARK-19732-fillna-bools.
---
 python/pyspark/sql/dataframe.py               | 23 ++++++++++---
 python/pyspark/sql/tests.py                   | 34 ++++++++++++++-----
 .../spark/sql/DataFrameNaFunctions.scala      | 30 ++++++++++++++--
 .../spark/sql/DataFrameNaFunctionsSuite.scala | 21 ++++++++++++
 4 files changed, 94 insertions(+), 14 deletions(-)

diff --git a/python/pyspark/sql/dataframe.py b/python/pyspark/sql/dataframe.py
index 8d8b938478..99abfcc556 100644
--- a/python/pyspark/sql/dataframe.py
+++ b/python/pyspark/sql/dataframe.py
@@ -1289,7 +1289,7 @@ class DataFrame(object):
         """Replace null values, alias for ``na.fill()``.
         :func:`DataFrame.fillna` and :func:`DataFrameNaFunctions.fill` are aliases of each other.
 
-        :param value: int, long, float, string, or dict.
+        :param value: int, long, float, string, bool, or dict.
             Value to replace null values with.
             If the value is a dict, then `subset` is ignored and `value` must be a mapping
             from column name (string) to replacement value. The replacement value must be
@@ -1309,6 +1309,15 @@ class DataFrame(object):
         | 50|    50| null|
         +---+------+-----+
 
+        >>> df5.na.fill(False).show()
+        +----+-------+-----+
+        | age|   name|  spy|
+        +----+-------+-----+
+        |  10|  Alice|false|
+        |   5|    Bob|false|
+        |null|Mallory| true|
+        +----+-------+-----+
+
         >>> df4.na.fill({'age': 50, 'name': 'unknown'}).show()
         +---+------+-------+
         |age|height|   name|
@@ -1319,10 +1328,13 @@ class DataFrame(object):
         | 50|  null|unknown|
         +---+------+-------+
         """
-        if not isinstance(value, (float, int, long, basestring, dict)):
-            raise ValueError("value should be a float, int, long, string, or dict")
+        if not isinstance(value, (float, int, long, basestring, bool, dict)):
+            raise ValueError("value should be a float, int, long, string, bool or dict")
+
+        # Note that bool is a subclass of int, so isinstance(True, int) is True,
+        # but we do not want bools to be converted to floats here
 
-        if isinstance(value, (int, long)):
+        if not isinstance(value, bool) and isinstance(value, (int, long)):
             value = float(value)
 
         if isinstance(value, dict):
@@ -1819,6 +1831,9 @@ def _test():
                                    Row(name='Bob', age=5, height=None),
                                    Row(name='Tom', age=None, height=None),
                                    Row(name=None, age=None, height=None)]).toDF()
+    globs['df5'] = sc.parallelize([Row(name='Alice', spy=False, age=10),
+                                   Row(name='Bob', spy=None, age=5),
+                                   Row(name='Mallory', spy=True, age=None)]).toDF()
     globs['sdf'] = sc.parallelize([Row(name='Tom', time=1479441846),
                                    Row(name='Bob', time=1479442946)]).toDF()
 
diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py
index acea9113ee..845e1c7619 100644
--- a/python/pyspark/sql/tests.py
+++ b/python/pyspark/sql/tests.py
@@ -1697,40 +1697,58 @@ class SQLTests(ReusedPySparkTestCase):
         schema = StructType([
             StructField("name", StringType(), True),
             StructField("age", IntegerType(), True),
-            StructField("height", DoubleType(), True)])
+            StructField("height", DoubleType(), True),
+            StructField("spy", BooleanType(), True)])
 
         # fillna shouldn't change non-null values
-        row = self.spark.createDataFrame([(u'Alice', 10, 80.1)], schema).fillna(50).first()
+        row = self.spark.createDataFrame([(u'Alice', 10, 80.1, True)], schema).fillna(50).first()
         self.assertEqual(row.age, 10)
 
         # fillna with int
-        row = self.spark.createDataFrame([(u'Alice', None, None)], schema).fillna(50).first()
+        row = self.spark.createDataFrame([(u'Alice', None, None, None)], schema).fillna(50).first()
         self.assertEqual(row.age, 50)
         self.assertEqual(row.height, 50.0)
 
         # fillna with double
-        row = self.spark.createDataFrame([(u'Alice', None, None)], schema).fillna(50.1).first()
+        row = self.spark.createDataFrame(
+            [(u'Alice', None, None, None)], schema).fillna(50.1).first()
         self.assertEqual(row.age, 50)
         self.assertEqual(row.height, 50.1)
 
+        # fillna with bool
+        row = self.spark.createDataFrame(
+            [(u'Alice', None, None, None)], schema).fillna(True).first()
+        self.assertEqual(row.age, None)
+        self.assertEqual(row.spy, True)
+
         # fillna with string
-        row = self.spark.createDataFrame([(None, None, None)], schema).fillna("hello").first()
+        row = self.spark.createDataFrame([(None, None, None, None)], schema).fillna("hello").first()
         self.assertEqual(row.name, u"hello")
         self.assertEqual(row.age, None)
 
         # fillna with subset specified for numeric cols
         row = self.spark.createDataFrame(
-            [(None, None, None)], schema).fillna(50, subset=['name', 'age']).first()
+            [(None, None, None, None)], schema).fillna(50, subset=['name', 'age']).first()
         self.assertEqual(row.name, None)
         self.assertEqual(row.age, 50)
         self.assertEqual(row.height, None)
+        self.assertEqual(row.spy, None)
 
-        # fillna with subset specified for numeric cols
+        # fillna with subset specified for string cols
         row = self.spark.createDataFrame(
-            [(None, None, None)], schema).fillna("haha", subset=['name', 'age']).first()
+            [(None, None, None, None)], schema).fillna("haha", subset=['name', 'age']).first()
         self.assertEqual(row.name, "haha")
         self.assertEqual(row.age, None)
         self.assertEqual(row.height, None)
+        self.assertEqual(row.spy, None)
+
+        # fillna with subset specified for bool cols
+        row = self.spark.createDataFrame(
+            [(None, None, None, None)], schema).fillna(True, subset=['name', 'spy']).first()
+        self.assertEqual(row.name, None)
+        self.assertEqual(row.age, None)
+        self.assertEqual(row.height, None)
+        self.assertEqual(row.spy, True)
 
         # fillna with dictionary for boolean types
         row = self.spark.createDataFrame([Row(a=None), Row(a=True)]).fillna({"a": True}).first()
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameNaFunctions.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameNaFunctions.scala
index 052d85ad33..ee949e78fa 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameNaFunctions.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameNaFunctions.scala
@@ -195,6 +195,30 @@ final class DataFrameNaFunctions private[sql](df: DataFrame) {
    */
   def fill(value: String, cols: Seq[String]): DataFrame = fillValue(value, cols)
 
+  /**
+   * Returns a new `DataFrame` that replaces null values in boolean columns with `value`.
+   *
+   * @since 2.3.0
+   */
+  def fill(value: Boolean): DataFrame = fill(value, df.columns)
+
+  /**
+   * (Scala-specific) Returns a new `DataFrame` that replaces null values in specified
+   * boolean columns. If a specified column is not a boolean column, it is ignored.
+   *
+   * @since 2.3.0
+   */
+  def fill(value: Boolean, cols: Seq[String]): DataFrame = fillValue(value, cols)
+
+  /**
+   * Returns a new `DataFrame` that replaces null values in specified boolean columns.
+   * If a specified column is not a boolean column, it is ignored.
+   *
+   * @since 2.3.0
+   */
+  def fill(value: Boolean, cols: Array[String]): DataFrame = fill(value, cols.toSeq)
+
+
   /**
    * Returns a new `DataFrame` that replaces null values.
    *
@@ -440,8 +464,8 @@ final class DataFrameNaFunctions private[sql](df: DataFrame) {
 
   /**
    * Returns a new `DataFrame` that replaces null or NaN values in specified
-   * numeric, string columns. If a specified column is not a numeric, string column,
-   * it is ignored.
+   * numeric, string or boolean columns. If a specified column is not a numeric,
+   * string or boolean column, it is ignored.
    */
   private def fillValue[T](value: T, cols: Seq[String]): DataFrame = {
     // the fill[T] which T is  Long/Double,
@@ -452,6 +476,7 @@ final class DataFrameNaFunctions private[sql](df: DataFrame) {
     val targetType = value match {
       case _: Double | _: Long => NumericType
       case _: String => StringType
+      case _: Boolean => BooleanType
       case _ => throw new IllegalArgumentException(
         s"Unsupported value type ${value.getClass.getName} ($value).")
     }
@@ -461,6 +486,7 @@ final class DataFrameNaFunctions private[sql](df: DataFrame) {
       val typeMatches = (targetType, f.dataType) match {
         case (NumericType, dt) => dt.isInstanceOf[NumericType]
         case (StringType, dt) => dt == StringType
+        case (BooleanType, dt) => dt == BooleanType
       }
       // Only fill if the column is part of the cols list.
       if (typeMatches && cols.exists(col => columnEquals(f.name, col))) {
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameNaFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameNaFunctionsSuite.scala
index aa237d0619..e63c5cb194 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameNaFunctionsSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameNaFunctionsSuite.scala
@@ -104,6 +104,13 @@ class DataFrameNaFunctionsSuite extends QueryTest with SharedSQLContext {
   test("fill") {
     val input = createDF()
 
+    val boolInput = Seq[(String, java.lang.Boolean)](
+      ("Bob", false),
+      ("Alice", null),
+      ("Mallory", true),
+      (null, null)
+    ).toDF("name", "spy")
+
     val fillNumeric = input.na.fill(50.6)
     checkAnswer(
       fillNumeric,
@@ -124,6 +131,12 @@ class DataFrameNaFunctionsSuite extends QueryTest with SharedSQLContext {
         Row("Nina") :: Row("Amy") :: Row("unknown") :: Nil)
     assert(input.na.fill("unknown").columns.toSeq === input.columns.toSeq)
 
+    // boolean
+    checkAnswer(
+      boolInput.na.fill(true).select("spy"),
+      Row(false) :: Row(true) :: Row(true) :: Row(true) :: Nil)
+    assert(boolInput.na.fill(true).columns.toSeq === boolInput.columns.toSeq)
+
     // fill double with subset columns
     checkAnswer(
       input.na.fill(50.6, "age" :: Nil).select("name", "age"),
@@ -134,6 +147,14 @@ class DataFrameNaFunctionsSuite extends QueryTest with SharedSQLContext {
         Row("Amy", 50) ::
         Row(null, 50) :: Nil)
 
+    // fill boolean with subset columns
+    checkAnswer(
+      boolInput.na.fill(true, "spy" :: Nil).select("name", "spy"),
+      Row("Bob", false) ::
+        Row("Alice", true) ::
+        Row("Mallory", true) ::
+        Row(null, true) :: Nil)
+
     // fill string with subset columns
     checkAnswer(
       Seq[(String, String)]((null, null)).toDF("col1", "col2").na.fill("test", "col1" :: Nil),
-- 
GitLab