From 6cbc61d1070584ffbc34b1f53df352c9162f414a Mon Sep 17 00:00:00 2001 From: Ruben Berenguel Montoro <ruben@mostlymaths.net> Date: Sat, 3 Jun 2017 14:56:42 +0900 Subject: [PATCH] [SPARK-19732][SQL][PYSPARK] Add fill functions for nulls in bool fields of datasets ## What changes were proposed in this pull request? Allow fill/replace of NAs with booleans, both in Python and Scala ## How was this patch tested? Unit tests, doctests This PR is original work from me and I license this work to the Spark project Author: Ruben Berenguel Montoro <ruben@mostlymaths.net> Author: Ruben Berenguel <ruben@mostlymaths.net> Closes #18164 from rberenguel/SPARK-19732-fillna-bools. --- python/pyspark/sql/dataframe.py | 23 ++++++++++--- python/pyspark/sql/tests.py | 34 ++++++++++++++----- .../spark/sql/DataFrameNaFunctions.scala | 30 ++++++++++++++-- .../spark/sql/DataFrameNaFunctionsSuite.scala | 21 ++++++++++++ 4 files changed, 94 insertions(+), 14 deletions(-) diff --git a/python/pyspark/sql/dataframe.py b/python/pyspark/sql/dataframe.py index 8d8b938478..99abfcc556 100644 --- a/python/pyspark/sql/dataframe.py +++ b/python/pyspark/sql/dataframe.py @@ -1289,7 +1289,7 @@ class DataFrame(object): """Replace null values, alias for ``na.fill()``. :func:`DataFrame.fillna` and :func:`DataFrameNaFunctions.fill` are aliases of each other. - :param value: int, long, float, string, or dict. + :param value: int, long, float, string, bool or dict. Value to replace null values with. If the value is a dict, then `subset` is ignored and `value` must be a mapping from column name (string) to replacement value. The replacement value must be @@ -1309,6 +1309,15 @@ class DataFrame(object): | 50| 50| null| +---+------+-----+ + >>> df5.na.fill(False).show() + +----+-------+-----+ + | age| name| spy| + +----+-------+-----+ + | 10| Alice|false| + | 5| Bob|false| + |null|Mallory| true| + +----+-------+-----+ + >>> df4.na.fill({'age': 50, 'name': 'unknown'}).show() +---+------+-------+ |age|height| name| @@ -1319,10 +1328,13 @@ class DataFrame(object): | 50| null|unknown| +---+------+-------+ """ - if not isinstance(value, (float, int, long, basestring, dict)): - raise ValueError("value should be a float, int, long, string, or dict") + if not isinstance(value, (float, int, long, basestring, bool, dict)): + raise ValueError("value should be a float, int, long, string, bool or dict") + + # Note that bool validates isinstance(int), but we don't want to + # convert bools to floats - if isinstance(value, (int, long)): + if not isinstance(value, bool) and isinstance(value, (int, long)): value = float(value) if isinstance(value, dict): @@ -1819,6 +1831,9 @@ def _test(): Row(name='Bob', age=5, height=None), Row(name='Tom', age=None, height=None), Row(name=None, age=None, height=None)]).toDF() + globs['df5'] = sc.parallelize([Row(name='Alice', spy=False, age=10), + Row(name='Bob', spy=None, age=5), + Row(name='Mallory', spy=True, age=None)]).toDF() globs['sdf'] = sc.parallelize([Row(name='Tom', time=1479441846), Row(name='Bob', time=1479442946)]).toDF() diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py index acea9113ee..845e1c7619 100644 --- a/python/pyspark/sql/tests.py +++ b/python/pyspark/sql/tests.py @@ -1697,40 +1697,58 @@ class SQLTests(ReusedPySparkTestCase): schema = StructType([ StructField("name", StringType(), True), StructField("age", IntegerType(), True), - StructField("height", DoubleType(), True)]) + StructField("height", DoubleType(), True), + StructField("spy", BooleanType(), True)]) # fillna shouldn't change non-null values - row = self.spark.createDataFrame([(u'Alice', 10, 80.1)], schema).fillna(50).first() + row = self.spark.createDataFrame([(u'Alice', 10, 80.1, True)], schema).fillna(50).first() self.assertEqual(row.age, 10) # fillna with int - row = self.spark.createDataFrame([(u'Alice', None, None)], schema).fillna(50).first() + row = self.spark.createDataFrame([(u'Alice', None, None, None)], schema).fillna(50).first() self.assertEqual(row.age, 50) self.assertEqual(row.height, 50.0) # fillna with double - row = self.spark.createDataFrame([(u'Alice', None, None)], schema).fillna(50.1).first() + row = self.spark.createDataFrame( + [(u'Alice', None, None, None)], schema).fillna(50.1).first() self.assertEqual(row.age, 50) self.assertEqual(row.height, 50.1) + # fillna with bool + row = self.spark.createDataFrame( + [(u'Alice', None, None, None)], schema).fillna(True).first() + self.assertEqual(row.age, None) + self.assertEqual(row.spy, True) + # fillna with string - row = self.spark.createDataFrame([(None, None, None)], schema).fillna("hello").first() + row = self.spark.createDataFrame([(None, None, None, None)], schema).fillna("hello").first() self.assertEqual(row.name, u"hello") self.assertEqual(row.age, None) # fillna with subset specified for numeric cols row = self.spark.createDataFrame( - [(None, None, None)], schema).fillna(50, subset=['name', 'age']).first() + [(None, None, None, None)], schema).fillna(50, subset=['name', 'age']).first() self.assertEqual(row.name, None) self.assertEqual(row.age, 50) self.assertEqual(row.height, None) + self.assertEqual(row.spy, None) - # fillna with subset specified for numeric cols + # fillna with subset specified for string cols row = self.spark.createDataFrame( - [(None, None, None)], schema).fillna("haha", subset=['name', 'age']).first() + [(None, None, None, None)], schema).fillna("haha", subset=['name', 'age']).first() self.assertEqual(row.name, "haha") self.assertEqual(row.age, None) self.assertEqual(row.height, None) + self.assertEqual(row.spy, None) + + # fillna with subset specified for bool cols + row = self.spark.createDataFrame( + [(None, None, None, None)], schema).fillna(True, subset=['name', 'spy']).first() + self.assertEqual(row.name, None) + self.assertEqual(row.age, None) + self.assertEqual(row.height, None) + self.assertEqual(row.spy, True) # fillna with dictionary for boolean types row = self.spark.createDataFrame([Row(a=None), Row(a=True)]).fillna({"a": True}).first() diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameNaFunctions.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameNaFunctions.scala index 052d85ad33..ee949e78fa 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameNaFunctions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameNaFunctions.scala @@ -195,6 +195,30 @@ final class DataFrameNaFunctions private[sql](df: DataFrame) { */ def fill(value: String, cols: Seq[String]): DataFrame = fillValue(value, cols) + /** + * Returns a new `DataFrame` that replaces null values in boolean columns with `value`. + * + * @since 2.3.0 + */ + def fill(value: Boolean): DataFrame = fill(value, df.columns) + + /** + * (Scala-specific) Returns a new `DataFrame` that replaces null values in specified + * boolean columns. If a specified column is not a boolean column, it is ignored. + * + * @since 2.3.0 + */ + def fill(value: Boolean, cols: Seq[String]): DataFrame = fillValue(value, cols) + + /** + * Returns a new `DataFrame` that replaces null values in specified boolean columns. + * If a specified column is not a boolean column, it is ignored. + * + * @since 2.3.0 + */ + def fill(value: Boolean, cols: Array[String]): DataFrame = fill(value, cols.toSeq) + + /** * Returns a new `DataFrame` that replaces null values. * @@ -440,8 +464,8 @@ final class DataFrameNaFunctions private[sql](df: DataFrame) { /** * Returns a new `DataFrame` that replaces null or NaN values in specified - * numeric, string columns. If a specified column is not a numeric, string column, - * it is ignored. + * numeric, string columns. If a specified column is not a numeric, string + * or boolean column it is ignored. */ private def fillValue[T](value: T, cols: Seq[String]): DataFrame = { // the fill[T] which T is Long/Double, @@ -452,6 +476,7 @@ final class DataFrameNaFunctions private[sql](df: DataFrame) { val targetType = value match { case _: Double | _: Long => NumericType case _: String => StringType + case _: Boolean => BooleanType case _ => throw new IllegalArgumentException( s"Unsupported value type ${value.getClass.getName} ($value).") } @@ -461,6 +486,7 @@ final class DataFrameNaFunctions private[sql](df: DataFrame) { val typeMatches = (targetType, f.dataType) match { case (NumericType, dt) => dt.isInstanceOf[NumericType] case (StringType, dt) => dt == StringType + case (BooleanType, dt) => dt == BooleanType } // Only fill if the column is part of the cols list. if (typeMatches && cols.exists(col => columnEquals(f.name, col))) { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameNaFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameNaFunctionsSuite.scala index aa237d0619..e63c5cb194 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameNaFunctionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameNaFunctionsSuite.scala @@ -104,6 +104,13 @@ class DataFrameNaFunctionsSuite extends QueryTest with SharedSQLContext { test("fill") { val input = createDF() + val boolInput = Seq[(String, java.lang.Boolean)]( + ("Bob", false), + ("Alice", null), + ("Mallory", true), + (null, null) + ).toDF("name", "spy") + val fillNumeric = input.na.fill(50.6) checkAnswer( fillNumeric, @@ -124,6 +131,12 @@ class DataFrameNaFunctionsSuite extends QueryTest with SharedSQLContext { Row("Nina") :: Row("Amy") :: Row("unknown") :: Nil) assert(input.na.fill("unknown").columns.toSeq === input.columns.toSeq) + // boolean + checkAnswer( + boolInput.na.fill(true).select("spy"), + Row(false) :: Row(true) :: Row(true) :: Row(true) :: Nil) + assert(boolInput.na.fill(true).columns.toSeq === boolInput.columns.toSeq) + // fill double with subset columns checkAnswer( input.na.fill(50.6, "age" :: Nil).select("name", "age"), @@ -134,6 +147,14 @@ class DataFrameNaFunctionsSuite extends QueryTest with SharedSQLContext { Row("Amy", 50) :: Row(null, 50) :: Nil) + // fill boolean with subset columns + checkAnswer( + boolInput.na.fill(true, "spy" :: Nil).select("name", "spy"), + Row("Bob", false) :: + Row("Alice", true) :: + Row("Mallory", true) :: + Row(null, true) :: Nil) + // fill string with subset columns checkAnswer( Seq[(String, String)]((null, null)).toDF("col1", "col2").na.fill("test", "col1" :: Nil), -- GitLab