diff --git a/python/pyspark/serializers.py b/python/pyspark/serializers.py
index 887c702ad41962b9e4eb4e410be07b0c345c32dd..7c1fbadcb82beff5527c328356262134c1d1b880 100644
--- a/python/pyspark/serializers.py
+++ b/python/pyspark/serializers.py
@@ -216,9 +216,6 @@ class ArrowPandasSerializer(ArrowSerializer):
     Serializes Pandas.Series as Arrow data.
     """
 
-    def __init__(self):
-        super(ArrowPandasSerializer, self).__init__()
-
     def dumps(self, series):
         """
         Make an ArrowRecordBatch from a Pandas Series and serialize. Input is a single series or
@@ -245,16 +242,10 @@ class ArrowPandasSerializer(ArrowSerializer):
 
     def loads(self, obj):
         """
-        Deserialize an ArrowRecordBatch to an Arrow table and return as a list of pandas.Series
-        followed by a dictionary containing length of the loaded batches.
+        Deserialize an ArrowRecordBatch to an Arrow table and return as a list of pandas.Series.
         """
-        import pyarrow as pa
-        reader = pa.RecordBatchFileReader(pa.BufferReader(obj))
-        batches = [reader.get_batch(i) for i in xrange(reader.num_record_batches)]
-        # NOTE: a 0-parameter pandas_udf will produce an empty batch that can have num_rows set
-        num_rows = sum((batch.num_rows for batch in batches))
-        table = pa.Table.from_batches(batches)
-        return [c.to_pandas() for c in table.itercolumns()] + [{"length": num_rows}]
+        table = super(ArrowPandasSerializer, self).loads(obj)
+        return [c.to_pandas() for c in table.itercolumns()]
 
     def __repr__(self):
         return "ArrowPandasSerializer"
diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py
index 46e3a85e60d7b7a5da3a5c303da106211c04e569..63e9a830bbc9e5c5515fd27c65c85db4417ba8be 100644
--- a/python/pyspark/sql/functions.py
+++ b/python/pyspark/sql/functions.py
@@ -2127,6 +2127,10 @@ class UserDefinedFunction(object):
 
 def _create_udf(f, returnType, vectorized):
     def _udf(f, returnType=StringType(), vectorized=vectorized):
+        if vectorized:
+            import inspect
+            if len(inspect.getargspec(f).args) == 0:
+                raise NotImplementedError("0-parameter pandas_udfs are not currently supported")
         udf_obj = UserDefinedFunction(f, returnType, vectorized=vectorized)
         return udf_obj._wrapped()
 
@@ -2183,14 +2187,28 @@ def pandas_udf(f=None, returnType=StringType()):
     :param f: python function if used as a standalone function
     :param returnType: a :class:`pyspark.sql.types.DataType` object
 
-    # TODO: doctest
+    >>> from pyspark.sql.types import IntegerType, StringType
+    >>> slen = pandas_udf(lambda s: s.str.len(), IntegerType())
+    >>> @pandas_udf(returnType=StringType())
+    ... def to_upper(s):
+    ...     return s.str.upper()
+    ...
+    >>> @pandas_udf(returnType="integer")
+    ... def add_one(x):
+    ...     return x + 1
+    ...
+    >>> df = spark.createDataFrame([(1, "John Doe", 21)], ("id", "name", "age"))
+    >>> df.select(slen("name").alias("slen(name)"), to_upper("name"), add_one("age")) \\
+    ...     .show()  # doctest: +SKIP
+    +----------+--------------+------------+
+    |slen(name)|to_upper(name)|add_one(age)|
+    +----------+--------------+------------+
+    |         8|      JOHN DOE|          22|
+    +----------+--------------+------------+
     """
-    import inspect
-    # If function "f" does not define the optional kwargs, then wrap with a kwargs placeholder
-    if inspect.getargspec(f).keywords is None:
-        return _create_udf(lambda *a, **kwargs: f(*a), returnType=returnType, vectorized=True)
-    else:
-        return _create_udf(f, returnType=returnType, vectorized=True)
+    wrapped_udf = _create_udf(f, returnType=returnType, vectorized=True)
+
+    return wrapped_udf
 
 
 blacklist = ['map', 'since', 'ignore_unicode_prefix']
diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py
index 3db8bee203469241e8e8c1faf0a6736ca7bc6bd0..1b3af42c47ad2aa5b68399233e3ed15ede678da2 100644
--- a/python/pyspark/sql/tests.py
+++ b/python/pyspark/sql/tests.py
@@ -3256,11 +3256,20 @@ class VectorizedUDFTests(ReusedPySparkTestCase):
 
     def test_vectorized_udf_zero_parameter(self):
         from pyspark.sql.functions import pandas_udf
-        import pandas as pd
-        df = self.spark.range(10)
-        f0 = pandas_udf(lambda **kwargs: pd.Series(1).repeat(kwargs['length']), LongType())
-        res = df.select(f0())
-        self.assertEquals(df.select(lit(1)).collect(), res.collect())
+        error_str = '0-parameter pandas_udfs.*not.*supported'
+        with QuietTest(self.sc):
+            with self.assertRaisesRegexp(NotImplementedError, error_str):
+                pandas_udf(lambda: 1, LongType())
+
+            with self.assertRaisesRegexp(NotImplementedError, error_str):
+                @pandas_udf
+                def zero_no_type():
+                    return 1
+
+            with self.assertRaisesRegexp(NotImplementedError, error_str):
+                @pandas_udf(LongType())
+                def zero_with_type():
+                    return 1
 
     def test_vectorized_udf_datatype_string(self):
         from pyspark.sql.functions import pandas_udf, col
@@ -3308,12 +3317,12 @@ class VectorizedUDFTests(ReusedPySparkTestCase):
         from pyspark.sql.functions import pandas_udf, col
         import pandas as pd
         df = self.spark.range(10)
-        raise_exception = pandas_udf(lambda: pd.Series(1), LongType())
+        raise_exception = pandas_udf(lambda _: pd.Series(1), LongType())
         with QuietTest(self.sc):
             with self.assertRaisesRegexp(
                     Exception,
                     'Result vector from pandas_udf was not the required length'):
-                df.select(raise_exception()).collect()
+                df.select(raise_exception(col('id'))).collect()
 
     def test_vectorized_udf_mix_udf(self):
         from pyspark.sql.functions import pandas_udf, udf, col
@@ -3328,22 +3337,44 @@ class VectorizedUDFTests(ReusedPySparkTestCase):
 
     def test_vectorized_udf_chained(self):
         from pyspark.sql.functions import pandas_udf, col
-        df = self.spark.range(10).toDF('x')
+        df = self.spark.range(10)
         f = pandas_udf(lambda x: x + 1, LongType())
         g = pandas_udf(lambda x: x - 1, LongType())
-        res = df.select(g(f(col('x'))))
+        res = df.select(g(f(col('id'))))
         self.assertEquals(df.collect(), res.collect())
 
     def test_vectorized_udf_wrong_return_type(self):
         from pyspark.sql.functions import pandas_udf, col
-        df = self.spark.range(10).toDF('x')
+        df = self.spark.range(10)
         f = pandas_udf(lambda x: x * 1.0, StringType())
         with QuietTest(self.sc):
-            with self.assertRaisesRegexp(
-                    Exception,
-                    'Invalid.*type.*string'):
-                df.select(f(col('x'))).collect()
+            with self.assertRaisesRegexp(Exception, 'Invalid.*type.*string'):
+                df.select(f(col('id'))).collect()
+
+    def test_vectorized_udf_return_scalar(self):
+        from pyspark.sql.functions import pandas_udf, col
+        df = self.spark.range(10)
+        f = pandas_udf(lambda x: 1.0, DoubleType())
+        with QuietTest(self.sc):
+            with self.assertRaisesRegexp(Exception, 'Return.*type.*pandas_udf.*Series'):
+                df.select(f(col('id'))).collect()
+
+    def test_vectorized_udf_decorator(self):
+        from pyspark.sql.functions import pandas_udf, col
+        df = self.spark.range(10)
+        @pandas_udf(returnType=LongType())
+        def identity(x):
+            return x
+        res = df.select(identity(col('id')))
+        self.assertEquals(df.collect(), res.collect())
+
+    def test_vectorized_udf_empty_partition(self):
+        from pyspark.sql.functions import pandas_udf, col
+        df = self.spark.createDataFrame(self.sc.parallelize([Row(id=1)], 2))
+        f = pandas_udf(lambda x: x, LongType())
+        res = df.select(f(col('id')))
+        self.assertEquals(df.collect(), res.collect())
 
 
 if __name__ == "__main__":
     from pyspark.sql.tests import *
diff --git a/python/pyspark/worker.py b/python/pyspark/worker.py
index 0e35cf7be6240d5dacd26651454e27aef2792b31..fd917c400c8720643702c8dc0972ff22a1c7237d 100644
--- a/python/pyspark/worker.py
+++ b/python/pyspark/worker.py
@@ -60,12 +60,9 @@ def read_command(serializer, file):
     return command
 
 
-def chain(f, g, eval_type):
+def chain(f, g):
     """chain two functions together """
-    if eval_type == PythonEvalType.SQL_PANDAS_UDF:
-        return lambda *a, **kwargs: g(f(*a, **kwargs), **kwargs)
-    else:
-        return lambda *a: g(f(*a))
+    return lambda *a: g(f(*a))
 
 
 def wrap_udf(f, return_type):
@@ -80,14 +77,14 @@ def wrap_pandas_udf(f, return_type):
     arrow_return_type = toArrowType(return_type)
 
     def verify_result_length(*a):
-        kwargs = a[-1]
-        result = f(*a[:-1], **kwargs)
-        if len(result) != kwargs["length"]:
+        result = f(*a)
+        if not hasattr(result, "__len__"):
+            raise TypeError("Return type of pandas_udf should be a Pandas.Series")
+        if len(result) != len(a[0]):
             raise RuntimeError("Result vector from pandas_udf was not the required length: "
-                               "expected %d, got %d\nUse input vector length or kwargs['length']"
-                               % (kwargs["length"], len(result)))
-        return result, arrow_return_type
-    return lambda *a: verify_result_length(*a)
+                               "expected %d, got %d" % (len(a[0]), len(result)))
+        return result
+    return lambda *a: (verify_result_length(*a), arrow_return_type)
 
 
 def read_single_udf(pickleSer, infile, eval_type):
@@ -99,11 +96,9 @@ def read_single_udf(pickleSer, infile, eval_type):
         if row_func is None:
             row_func = f
         else:
-            row_func = chain(row_func, f, eval_type)
+            row_func = chain(row_func, f)
     # the last returnType will be the return type of UDF
     if eval_type == PythonEvalType.SQL_PANDAS_UDF:
-        # A pandas_udf will take kwargs as the last argument
-        arg_offsets = arg_offsets + [-1]
         return arg_offsets, wrap_pandas_udf(row_func, return_type)
     else:
         return arg_offsets, wrap_udf(row_func, return_type)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowEvalPythonExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowEvalPythonExec.scala
index f8bdc1e14eebc6d465ac321e199a9b028dd430b1..5e72cd255873aaff782c796f20262d6dcdd38f44 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowEvalPythonExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowEvalPythonExec.scala
@@ -51,10 +51,12 @@ case class ArrowEvalPythonExec(udfs: Seq[PythonUDF], output: Seq[Attribute], chi
       outputIterator.map(new ArrowPayload(_)), context)
 
     // Verify that the output schema is correct
-    val schemaOut = StructType.fromAttributes(output.drop(child.output.length).zipWithIndex
-      .map { case (attr, i) => attr.withName(s"_$i") })
-    assert(schemaOut.equals(outputRowIterator.schema),
-      s"Invalid schema from pandas_udf: expected $schemaOut, got ${outputRowIterator.schema}")
+    if (outputRowIterator.hasNext) {
+      val schemaOut = StructType.fromAttributes(output.drop(child.output.length).zipWithIndex
+        .map { case (attr, i) => attr.withName(s"_$i") })
+      assert(schemaOut.equals(outputRowIterator.schema),
+        s"Invalid schema from pandas_udf: expected $schemaOut, got ${outputRowIterator.schema}")
+    }
 
     outputRowIterator
   }
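
For reference, the user-facing contract this patch enforces can be exercised with a short PySpark session. The sketch below is illustrative only and not part of the patch: it assumes a local SparkSession can be started with pandas and pyarrow installed, it runs against the PySpark code as modified here, and the names spark and plus_one are arbitrary.

import pandas as pd
from pyspark.sql import SparkSession
from pyspark.sql.functions import col, pandas_udf
from pyspark.sql.types import LongType

spark = SparkSession.builder.master("local[1]").appName("pandas_udf_contract").getOrCreate()
df = spark.range(10)

# A vectorized UDF receives pandas.Series arguments and must return a pandas.Series
# of the same length; the worker now checks len(result) against len(a[0]).
plus_one = pandas_udf(lambda s: s + 1, LongType())
df.select(plus_one(col("id"))).show()

# A 0-parameter pandas_udf is rejected eagerly when the UDF is created, instead of
# relying on the removed kwargs["length"] mechanism at execution time.
try:
    pandas_udf(lambda: pd.Series([1]), LongType())
except NotImplementedError as exc:
    print(exc)  # 0-parameter pandas_udfs are not currently supported

spark.stop()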
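
The serializer change can likewise be pictured as a plain pyarrow round trip: dumps writes the pandas data as one Arrow record batch in file format, and loads (now delegated to ArrowSerializer) reads the batches back into a Table and returns one pandas.Series per column, with no trailing length dictionary. The standalone sketch below mirrors that flow using the same pyarrow calls as the code touched here; the exact spellings follow the pyarrow generation this patch targets and may differ in newer releases.

import pandas as pd
import pyarrow as pa

# Roughly what ArrowPandasSerializer.dumps does: pack a list of pandas.Series
# into a single record batch and write it in Arrow file format.
series = [pd.Series([1, 2, 3]), pd.Series([4.0, 5.0, 6.0])]
arrs = [pa.Array.from_pandas(s) for s in series]
batch = pa.RecordBatch.from_arrays(arrs, ["_0", "_1"])

sink = pa.BufferOutputStream()
writer = pa.RecordBatchFileWriter(sink, batch.schema)
writer.write_batch(batch)
writer.close()
payload = sink.getvalue()

# Roughly what loads does after this change: rebuild an Arrow table and hand back
# plain pandas.Series, one per column, and nothing else.
reader = pa.RecordBatchFileReader(pa.BufferReader(payload))
batches = [reader.get_batch(i) for i in range(reader.num_record_batches)]
table = pa.Table.from_batches(batches)
columns = [c.to_pandas() for c in table.itercolumns()]
print([c.tolist() for c in columns])  # [[1, 2, 3], [4.0, 5.0, 6.0]]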