From 6a5a7254dc37952505989e9e580a14543adb730c Mon Sep 17 00:00:00 2001
From: Liang-Chi Hsieh <viirya@gmail.com>
Date: Thu, 8 Dec 2016 23:22:18 +0800
Subject: [PATCH] [SPARK-18667][PYSPARK][SQL] Change the way to group rows in
 BatchEvalPythonExec so the input_file_name function can work with UDFs in
 PySpark

## What changes were proposed in this pull request?

`input_file_name` doesn't return the filename when used with a UDF in PySpark. An example shows the problem:

    from pyspark.sql.functions import *
    from pyspark.sql.types import *

    def filename(path):
        return path

    sourceFile = udf(filename, StringType())
    spark.read.json("tmp.json").select(sourceFile(input_file_name())).show()

    +---------------------------+
    |filename(input_file_name())|
    +---------------------------+
    |                           |
    +---------------------------+

The cause of this issue is that `BatchEvalPythonExec` groups rows into batches for Python UDF evaluation. Currently we group the rows first and only then evaluate the expressions on them. If there are fewer rows than the batch size, the input iterator is consumed to the end before any evaluation happens. Once the iterator reaches the end, the input filename is unset, so the `input_file_name` expression can no longer return the correct filename.

This patch changes how the batches of rows are built: the expressions are evaluated first, and the evaluated results are then grouped into batches.

## How was this patch tested?

Added a unit test to PySpark.

Please review http://spark.apache.org/contributing.html before opening a pull request.

Author: Liang-Chi Hsieh <viirya@gmail.com>

Closes #16115 from viirya/fix-py-udf-input-filename.
---
 python/pyspark/sql/tests.py                        |  8 +++++
 .../python/BatchEvalPythonExec.scala               | 35 +++++++++----------
 2 files changed, 24 insertions(+), 19 deletions(-)

diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py
index 50df68b144..66320bd050 100644
--- a/python/pyspark/sql/tests.py
+++ b/python/pyspark/sql/tests.py
@@ -412,6 +412,14 @@ class SQLTests(ReusedPySparkTestCase):
         res.explain(True)
         self.assertEqual(res.collect(), [Row(id=0, copy=0)])
 
+    def test_udf_with_input_file_name(self):
+        from pyspark.sql.functions import udf, input_file_name
+        from pyspark.sql.types import StringType
+        sourceFile = udf(lambda path: path, StringType())
+        filePath = "python/test_support/sql/people1.json"
+        row = self.spark.read.json(filePath).select(sourceFile(input_file_name())).first()
+        self.assertTrue(row[0].find("people1.json") != -1)
+
     def test_basic_functions(self):
         rdd = self.sc.parallelize(['{"foo":"bar"}', '{"foo":"baz"}'])
         df = self.spark.read.json(rdd)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/BatchEvalPythonExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/BatchEvalPythonExec.scala
index dcaf2c76d4..7a5ac48f1b 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/BatchEvalPythonExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/BatchEvalPythonExec.scala
@@ -119,26 +119,23 @@ case class BatchEvalPythonExec(udfs: Seq[PythonUDF], output: Seq[Attribute], chi
       val pickle = new Pickler(needConversion)
       // Input iterator to Python: input rows are grouped so we send them in batches to Python.
       // For each row, add it to the queue.
-      val inputIterator = iter.grouped(100).map { inputRows =>
-        val toBePickled = inputRows.map { inputRow =>
-          queue.add(inputRow.asInstanceOf[UnsafeRow])
-          val row = projection(inputRow)
-          if (needConversion) {
-            EvaluatePython.toJava(row, schema)
-          } else {
-            // fast path for these types that does not need conversion in Python
-            val fields = new Array[Any](row.numFields)
-            var i = 0
-            while (i < row.numFields) {
-              val dt = dataTypes(i)
-              fields(i) = EvaluatePython.toJava(row.get(i, dt), dt)
-              i += 1
-            }
-            fields
+      val inputIterator = iter.map { inputRow =>
+        queue.add(inputRow.asInstanceOf[UnsafeRow])
+        val row = projection(inputRow)
+        if (needConversion) {
+          EvaluatePython.toJava(row, schema)
+        } else {
+          // fast path for these types that does not need conversion in Python
+          val fields = new Array[Any](row.numFields)
+          var i = 0
+          while (i < row.numFields) {
+            val dt = dataTypes(i)
+            fields(i) = EvaluatePython.toJava(row.get(i, dt), dt)
+            i += 1
           }
-        }.toArray
-        pickle.dumps(toBePickled)
-      }
+          fields
+        }
+      }.grouped(100).map(x => pickle.dumps(x.toArray))
 
       val context = TaskContext.get()
 
-- 
GitLab
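
Note (editor's illustration, not part of the patch): the fix works because `input_file_name` reads task-local state that the file-scan iterator clears once it is fully consumed. The sketch below is a minimal, standalone Scala analogue written under that assumption; `GroupingOrderSketch`, `rows`, and `currentFile` are made-up names, not Spark APIs (`currentFile` plays roughly the role of the holder that `input_file_name` reads). It contrasts grouping before evaluation, which loses the filename, with evaluating before grouping, which preserves it.

    // Standalone sketch: why evaluating per row before grouping preserves
    // per-row task state such as the current input file name.
    object GroupingOrderSketch {
      def main(args: Array[String]): Unit = {
        val files = Seq("a.json", "b.json")
        var currentFile = "" // stand-in for the task-local "current file" state

        // A source iterator that records the "current file" as it emits each row
        // and clears it once it is fully drained, mimicking the file-scan iterator.
        def rows: Iterator[Int] = new Iterator[Int] {
          private val underlying = files.indices.iterator
          def hasNext: Boolean = {
            val more = underlying.hasNext
            if (!more) currentFile = "" // unset when the iterator is exhausted
            more
          }
          def next(): Int = {
            val i = underlying.next()
            currentFile = files(i)
            i
          }
        }

        // Old order: group first, evaluate later. grouped(100) drains the whole
        // iterator to fill one batch, so currentFile is already cleared by the
        // time the batch is evaluated.
        val before = rows.grouped(100).flatMap(batch => batch.map(_ => currentFile)).toList
        println(before) // prints List(, ) -- the filenames are lost

        // New order (what the patch does): evaluate per row, then group the results.
        val after = rows.map(_ => currentFile).grouped(100).flatMap(identity).toList
        println(after) // prints List(a.json, b.json)
      }
    }

In the patched operator the same reordering is expressed as `iter.map { ... }.grouped(100).map(x => pickle.dumps(x.toArray))`, so each row is projected while its input filename is still set, and only the already-evaluated results are batched for pickling.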