diff --git a/core/src/main/scala/org/apache/spark/api/python/PythonRunner.scala b/core/src/main/scala/org/apache/spark/api/python/PythonRunner.scala index 9989f68f8508c247dab4e47e653a69328eaefdf7..f524de68fbce027b28f814d5d9f27de0c09132ab 100644 --- a/core/src/main/scala/org/apache/spark/api/python/PythonRunner.scala +++ b/core/src/main/scala/org/apache/spark/api/python/PythonRunner.scala @@ -34,9 +34,11 @@ import org.apache.spark.util._ */ private[spark] object PythonEvalType { val NON_UDF = 0 - val SQL_BATCHED_UDF = 1 - val SQL_PANDAS_UDF = 2 - val SQL_PANDAS_GROUPED_UDF = 3 + + val SQL_BATCHED_UDF = 100 + + val SQL_PANDAS_SCALAR_UDF = 200 + val SQL_PANDAS_GROUP_MAP_UDF = 201 } /** diff --git a/python/pyspark/rdd.py b/python/pyspark/rdd.py index ea993c572fafd2a5c07e064a527143816fb7d2be..340bc3a6b74704727ba5c334dd6353c7d699bb98 100644 --- a/python/pyspark/rdd.py +++ b/python/pyspark/rdd.py @@ -56,6 +56,22 @@ from pyspark.traceback_utils import SCCallSiteSync __all__ = ["RDD"] +class PythonEvalType(object): + """ + Evaluation type of python rdd. + + These values are internal to PySpark. + + These values should match values in org.apache.spark.api.python.PythonEvalType. + """ + NON_UDF = 0 + + SQL_BATCHED_UDF = 100 + + SQL_PANDAS_SCALAR_UDF = 200 + SQL_PANDAS_GROUP_MAP_UDF = 201 + + def portable_hash(x): """ This function returns consistent hash code for builtin types, especially diff --git a/python/pyspark/serializers.py b/python/pyspark/serializers.py index e0afdafbfcd625ca07c2f6b1e351c08c551dd6a0..b95de2c804394e6f6046f399dced475a61017a1a 100644 --- a/python/pyspark/serializers.py +++ b/python/pyspark/serializers.py @@ -82,13 +82,6 @@ class SpecialLengths(object): START_ARROW_STREAM = -6 -class PythonEvalType(object): - NON_UDF = 0 - SQL_BATCHED_UDF = 1 - SQL_PANDAS_UDF = 2 - SQL_PANDAS_GROUPED_UDF = 3 - - class Serializer(object): def dump_stream(self, iterator, stream): diff --git a/python/pyspark/sql/catalog.py b/python/pyspark/sql/catalog.py index 5f25dce161963f971e95f9c25ea50f193b9042ed..659bc65701a0cbcc9b2aaa27c19b2883e24bab6f 100644 --- a/python/pyspark/sql/catalog.py +++ b/python/pyspark/sql/catalog.py @@ -19,9 +19,9 @@ import warnings from collections import namedtuple from pyspark import since -from pyspark.rdd import ignore_unicode_prefix +from pyspark.rdd import ignore_unicode_prefix, PythonEvalType from pyspark.sql.dataframe import DataFrame -from pyspark.sql.functions import UserDefinedFunction +from pyspark.sql.udf import UserDefinedFunction from pyspark.sql.types import IntegerType, StringType, StructType @@ -256,7 +256,8 @@ class Catalog(object): >>> spark.sql("SELECT stringLengthInt('test')").collect() [Row(stringLengthInt(test)=4)] """ - udf = UserDefinedFunction(f, returnType, name) + udf = UserDefinedFunction(f, returnType=returnType, name=name, + evalType=PythonEvalType.SQL_BATCHED_UDF) self._jsparkSession.udf().registerPython(name, udf._judf) return udf._wrapped() diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py index 087ce7caa89c809748b9060f45808d470ce19c25..b631e2041706f94fc48165331ed8b4a047cf21dd 100644 --- a/python/pyspark/sql/functions.py +++ b/python/pyspark/sql/functions.py @@ -27,11 +27,12 @@ if sys.version < "3": from itertools import imap as map from pyspark import since, SparkContext -from pyspark.rdd import _prepare_for_python_RDD, ignore_unicode_prefix +from pyspark.rdd import ignore_unicode_prefix, PythonEvalType from pyspark.serializers import PickleSerializer, AutoBatchedSerializer -from pyspark.sql.types import 
StringType, DataType, _parse_datatype_string from pyspark.sql.column import Column, _to_java_column, _to_seq from pyspark.sql.dataframe import DataFrame +from pyspark.sql.types import StringType, DataType +from pyspark.sql.udf import UserDefinedFunction, _create_udf def _create_function(name, doc=""): @@ -2062,132 +2063,12 @@ def map_values(col): # ---------------------------- User Defined Function ---------------------------------- -def _wrap_function(sc, func, returnType): - command = (func, returnType) - pickled_command, broadcast_vars, env, includes = _prepare_for_python_RDD(sc, command) - return sc._jvm.PythonFunction(bytearray(pickled_command), env, includes, sc.pythonExec, - sc.pythonVer, broadcast_vars, sc._javaAccumulator) - - -class PythonUdfType(object): - # row-at-a-time UDFs - NORMAL_UDF = 0 - # scalar vectorized UDFs - PANDAS_UDF = 1 - # grouped vectorized UDFs - PANDAS_GROUPED_UDF = 2 - - -class UserDefinedFunction(object): - """ - User defined function in Python - - .. versionadded:: 1.3 - """ - def __init__(self, func, returnType, name=None, pythonUdfType=PythonUdfType.NORMAL_UDF): - if not callable(func): - raise TypeError( - "Not a function or callable (__call__ is not defined): " - "{0}".format(type(func))) - - self.func = func - self._returnType = returnType - # Stores UserDefinedPythonFunctions jobj, once initialized - self._returnType_placeholder = None - self._judf_placeholder = None - self._name = name or ( - func.__name__ if hasattr(func, '__name__') - else func.__class__.__name__) - self.pythonUdfType = pythonUdfType - - @property - def returnType(self): - # This makes sure this is called after SparkContext is initialized. - # ``_parse_datatype_string`` accesses to JVM for parsing a DDL formatted string. - if self._returnType_placeholder is None: - if isinstance(self._returnType, DataType): - self._returnType_placeholder = self._returnType - else: - self._returnType_placeholder = _parse_datatype_string(self._returnType) - return self._returnType_placeholder - - @property - def _judf(self): - # It is possible that concurrent access, to newly created UDF, - # will initialize multiple UserDefinedPythonFunctions. - # This is unlikely, doesn't affect correctness, - # and should have a minimal performance impact. - if self._judf_placeholder is None: - self._judf_placeholder = self._create_judf() - return self._judf_placeholder - - def _create_judf(self): - from pyspark.sql import SparkSession - - spark = SparkSession.builder.getOrCreate() - sc = spark.sparkContext - - wrapped_func = _wrap_function(sc, self.func, self.returnType) - jdt = spark._jsparkSession.parseDataType(self.returnType.json()) - judf = sc._jvm.org.apache.spark.sql.execution.python.UserDefinedPythonFunction( - self._name, wrapped_func, jdt, self.pythonUdfType) - return judf - - def __call__(self, *cols): - judf = self._judf - sc = SparkContext._active_spark_context - return Column(judf.apply(_to_seq(sc, cols, _to_java_column))) - - def _wrapped(self): - """ - Wrap this udf with a function and attach docstring from func - """ - - # It is possible for a callable instance without __name__ attribute or/and - # __module__ attribute to be wrapped here. For example, functools.partial. In this case, - # we should avoid wrapping the attributes from the wrapped function to the wrapper - # function. So, we take out these attribute names from the default names to set and - # then manually assign it after being wrapped. 
- assignments = tuple( - a for a in functools.WRAPPER_ASSIGNMENTS if a != '__name__' and a != '__module__') - - @functools.wraps(self.func, assigned=assignments) - def wrapper(*args): - return self(*args) - - wrapper.__name__ = self._name - wrapper.__module__ = (self.func.__module__ if hasattr(self.func, '__module__') - else self.func.__class__.__module__) - - wrapper.func = self.func - wrapper.returnType = self.returnType - wrapper.pythonUdfType = self.pythonUdfType - - return wrapper - - -def _create_udf(f, returnType, pythonUdfType): - - def _udf(f, returnType=StringType(), pythonUdfType=pythonUdfType): - if pythonUdfType == PythonUdfType.PANDAS_UDF: - import inspect - argspec = inspect.getargspec(f) - if len(argspec.args) == 0 and argspec.varargs is None: - raise ValueError( - "0-arg pandas_udfs are not supported. " - "Instead, create a 1-arg pandas_udf and ignore the arg in your function." - ) - udf_obj = UserDefinedFunction(f, returnType, pythonUdfType=pythonUdfType) - return udf_obj._wrapped() - - # decorator @udf, @udf(), @udf(dataType()), or similar with @pandas_udf - if f is None or isinstance(f, (str, DataType)): - # If DataType has been passed as a positional argument - # for decorator use it as a returnType - return_type = f or returnType - return functools.partial(_udf, returnType=return_type, pythonUdfType=pythonUdfType) - else: - return _udf(f=f, returnType=returnType, pythonUdfType=pythonUdfType) +class PandasUDFType(object): + """Pandas UDF Types. See :meth:`pyspark.sql.functions.pandas_udf`. + """ + SCALAR = PythonEvalType.SQL_PANDAS_SCALAR_UDF + + GROUP_MAP = PythonEvalType.SQL_PANDAS_GROUP_MAP_UDF @since(1.3) @@ -2228,33 +2109,47 @@ def udf(f=None, returnType=StringType()): | 8| JOHN DOE| 22| +----------+--------------+------------+ """ - return _create_udf(f, returnType=returnType, pythonUdfType=PythonUdfType.NORMAL_UDF) + # decorator @udf, @udf(), @udf(dataType()) + if f is None or isinstance(f, (str, DataType)): + # If DataType has been passed as a positional argument + # for decorator use it as a returnType + return_type = f or returnType + return functools.partial(_create_udf, returnType=return_type, + evalType=PythonEvalType.SQL_BATCHED_UDF) + else: + return _create_udf(f=f, returnType=returnType, + evalType=PythonEvalType.SQL_BATCHED_UDF) @since(2.3) -def pandas_udf(f=None, returnType=StringType()): +def pandas_udf(f=None, returnType=None, functionType=None): """ Creates a vectorized user defined function (UDF). :param f: user-defined function. A python function if used as a standalone function :param returnType: a :class:`pyspark.sql.types.DataType` object + :param functionType: an enum value in :class:`pyspark.sql.functions.PandasUDFType`. + Default: SCALAR. - The user-defined function can define one of the following transformations: + The function type of the UDF can be one of the following: - 1. One or more `pandas.Series` -> A `pandas.Series` + 1. SCALAR - This udf is used with :meth:`pyspark.sql.DataFrame.withColumn` and - :meth:`pyspark.sql.DataFrame.select`. + A scalar UDF defines a transformation: One or more `pandas.Series` -> A `pandas.Series`. The returnType should be a primitive data type, e.g., `DoubleType()`. The length of the returned `pandas.Series` must be of the same as the input `pandas.Series`. + Scalar UDFs are used with :meth:`pyspark.sql.DataFrame.withColumn` and + :meth:`pyspark.sql.DataFrame.select`. 
+ + >>> from pyspark.sql.functions import pandas_udf, PandasUDFType >>> from pyspark.sql.types import IntegerType, StringType >>> slen = pandas_udf(lambda s: s.str.len(), IntegerType()) - >>> @pandas_udf(returnType=StringType()) + >>> @pandas_udf(StringType()) ... def to_upper(s): ... return s.str.upper() ... - >>> @pandas_udf(returnType="integer") + >>> @pandas_udf("integer", PandasUDFType.SCALAR) ... def add_one(x): ... return x + 1 ... @@ -2267,20 +2162,24 @@ def pandas_udf(f=None, returnType=StringType()): | 8| JOHN DOE| 22| +----------+--------------+------------+ - 2. A `pandas.DataFrame` -> A `pandas.DataFrame` + 2. GROUP_MAP - This udf is only used with :meth:`pyspark.sql.GroupedData.apply`. + A group map UDF defines transformation: A `pandas.DataFrame` -> A `pandas.DataFrame` The returnType should be a :class:`StructType` describing the schema of the returned `pandas.DataFrame`. + The length of the returned `pandas.DataFrame` can be arbitrary. + + Group map UDFs are used with :meth:`pyspark.sql.GroupedData.apply`. + >>> from pyspark.sql.functions import pandas_udf, PandasUDFType >>> df = spark.createDataFrame( ... [(1, 1.0), (1, 2.0), (2, 3.0), (2, 5.0), (2, 10.0)], ... ("id", "v")) - >>> @pandas_udf(returnType=df.schema) + >>> @pandas_udf("id long, v double", PandasUDFType.GROUP_MAP) ... def normalize(pdf): ... v = pdf.v ... return pdf.assign(v=(v - v.mean()) / v.std()) - >>> df.groupby('id').apply(normalize).show() # doctest: +SKIP + >>> df.groupby("id").apply(normalize).show() # doctest: +SKIP +---+-------------------+ | id| v| +---+-------------------+ @@ -2291,10 +2190,6 @@ def pandas_udf(f=None, returnType=StringType()): | 2| 1.1094003924504583| +---+-------------------+ - .. note:: This type of udf cannot be used with functions such as `withColumn` or `select` - because it defines a `DataFrame` transformation rather than a `Column` - transformation. - .. seealso:: :meth:`pyspark.sql.GroupedData.apply` .. note:: The user-defined function must be deterministic. @@ -2306,7 +2201,44 @@ def pandas_udf(f=None, returnType=StringType()): rows that do not satisfy the conditions, the suggested workaround is to incorporate the condition logic into the functions. 
""" - return _create_udf(f, returnType=returnType, pythonUdfType=PythonUdfType.PANDAS_UDF) + # decorator @pandas_udf(returnType, functionType) + is_decorator = f is None or isinstance(f, (str, DataType)) + + if is_decorator: + # If DataType has been passed as a positional argument + # for decorator use it as a returnType + return_type = f or returnType + + if functionType is not None: + # @pandas_udf(dataType, functionType=functionType) + # @pandas_udf(returnType=dataType, functionType=functionType) + eval_type = functionType + elif returnType is not None and isinstance(returnType, int): + # @pandas_udf(dataType, functionType) + eval_type = returnType + else: + # @pandas_udf(dataType) or @pandas_udf(returnType=dataType) + eval_type = PythonEvalType.SQL_PANDAS_SCALAR_UDF + else: + return_type = returnType + + if functionType is not None: + eval_type = functionType + else: + eval_type = PythonEvalType.SQL_PANDAS_SCALAR_UDF + + if return_type is None: + raise ValueError("Invalid returnType: returnType can not be None") + + if eval_type not in [PythonEvalType.SQL_PANDAS_SCALAR_UDF, + PythonEvalType.SQL_PANDAS_GROUP_MAP_UDF]: + raise ValueError("Invalid functionType: " + "functionType must be one the values from PandasUDFType") + + if is_decorator: + return functools.partial(_create_udf, returnType=return_type, evalType=eval_type) + else: + return _create_udf(f=f, returnType=return_type, evalType=eval_type) blacklist = ['map', 'since', 'ignore_unicode_prefix'] diff --git a/python/pyspark/sql/group.py b/python/pyspark/sql/group.py index e11388d6043123bca19c53a817e893dd4a9c83a9..4d47dd6a3e878e1cba5113f84c05b47ba93e95b9 100644 --- a/python/pyspark/sql/group.py +++ b/python/pyspark/sql/group.py @@ -16,10 +16,10 @@ # from pyspark import since -from pyspark.rdd import ignore_unicode_prefix +from pyspark.rdd import ignore_unicode_prefix, PythonEvalType from pyspark.sql.column import Column, _to_seq, _to_java_column, _create_column_from_literal from pyspark.sql.dataframe import DataFrame -from pyspark.sql.functions import PythonUdfType, UserDefinedFunction +from pyspark.sql.udf import UserDefinedFunction from pyspark.sql.types import * __all__ = ["GroupedData"] @@ -214,15 +214,15 @@ class GroupedData(object): :param udf: A function object returned by :meth:`pyspark.sql.functions.pandas_udf` - >>> from pyspark.sql.functions import pandas_udf + >>> from pyspark.sql.functions import pandas_udf, PandasUDFType >>> df = spark.createDataFrame( ... [(1, 1.0), (1, 2.0), (2, 3.0), (2, 5.0), (2, 10.0)], ... ("id", "v")) - >>> @pandas_udf(returnType=df.schema) + >>> @pandas_udf("id long, v double", PandasUDFType.GROUP_MAP) ... def normalize(pdf): ... v = pdf.v ... return pdf.assign(v=(v - v.mean()) / v.std()) - >>> df.groupby('id').apply(normalize).show() # doctest: +SKIP + >>> df.groupby("id").apply(normalize).show() # doctest: +SKIP +---+-------------------+ | id| v| +---+-------------------+ @@ -236,44 +236,13 @@ class GroupedData(object): .. 
seealso:: :meth:`pyspark.sql.functions.pandas_udf` """ - import inspect - # Columns are special because hasattr always return True if isinstance(udf, Column) or not hasattr(udf, 'func') \ - or udf.pythonUdfType != PythonUdfType.PANDAS_UDF \ - or len(inspect.getargspec(udf.func).args) != 1: - raise ValueError("The argument to apply must be a 1-arg pandas_udf") - if not isinstance(udf.returnType, StructType): - raise ValueError("The returnType of the pandas_udf must be a StructType") - + or udf.evalType != PythonEvalType.SQL_PANDAS_GROUP_MAP_UDF: + raise ValueError("Invalid udf: the udf argument must be a pandas_udf of type " + "GROUP_MAP.") df = self._df - func = udf.func - returnType = udf.returnType - - # The python executors expects the function to use pd.Series as input and output - # So we to create a wrapper function that turns that to a pd.DataFrame before passing - # down to the user function, then turn the result pd.DataFrame back into pd.Series - columns = df.columns - - def wrapped(*cols): - from pyspark.sql.types import to_arrow_type - import pandas as pd - result = func(pd.concat(cols, axis=1, keys=columns)) - if not isinstance(result, pd.DataFrame): - raise TypeError("Return type of the user-defined function should be " - "Pandas.DataFrame, but is {}".format(type(result))) - if not len(result.columns) == len(returnType): - raise RuntimeError( - "Number of columns of the returned Pandas.DataFrame " - "doesn't match specified schema. " - "Expected: {} Actual: {}".format(len(returnType), len(result.columns))) - arrow_return_types = (to_arrow_type(field.dataType) for field in returnType) - return [(result[result.columns[i]], arrow_type) - for i, arrow_type in enumerate(arrow_return_types)] - - udf_obj = UserDefinedFunction( - wrapped, returnType, name=udf.__name__, pythonUdfType=PythonUdfType.PANDAS_GROUPED_UDF) - udf_column = udf_obj(*[df[col] for col in df.columns]) + udf_column = udf(*[df[col] for col in df.columns]) jdf = self._jgd.flatMapGroupsInPandas(udf_column._jc.expr()) return DataFrame(jdf, self.sql_ctx) diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py index ef592c2356a8c3421a85f24288e954dbcc10b184..762afe0d730f313f5e3d6caebd274ea9f14000b7 100644 --- a/python/pyspark/sql/tests.py +++ b/python/pyspark/sql/tests.py @@ -3259,6 +3259,129 @@ class ArrowTests(ReusedSQLTestCase): self.assertEquals(self.schema, schema_rt) +class PandasUDFTests(ReusedSQLTestCase): + def test_pandas_udf_basic(self): + from pyspark.rdd import PythonEvalType + from pyspark.sql.functions import pandas_udf, PandasUDFType + + udf = pandas_udf(lambda x: x, DoubleType()) + self.assertEqual(udf.returnType, DoubleType()) + self.assertEqual(udf.evalType, PythonEvalType.SQL_PANDAS_SCALAR_UDF) + + udf = pandas_udf(lambda x: x, DoubleType(), PandasUDFType.SCALAR) + self.assertEqual(udf.returnType, DoubleType()) + self.assertEqual(udf.evalType, PythonEvalType.SQL_PANDAS_SCALAR_UDF) + + udf = pandas_udf(lambda x: x, 'double', PandasUDFType.SCALAR) + self.assertEqual(udf.returnType, DoubleType()) + self.assertEqual(udf.evalType, PythonEvalType.SQL_PANDAS_SCALAR_UDF) + + udf = pandas_udf(lambda x: x, StructType([StructField("v", DoubleType())]), + PandasUDFType.GROUP_MAP) + self.assertEqual(udf.returnType, StructType([StructField("v", DoubleType())])) + self.assertEqual(udf.evalType, PythonEvalType.SQL_PANDAS_GROUP_MAP_UDF) + + udf = pandas_udf(lambda x: x, 'v double', PandasUDFType.GROUP_MAP) + self.assertEqual(udf.returnType, StructType([StructField("v", DoubleType())])) + 
self.assertEqual(udf.evalType, PythonEvalType.SQL_PANDAS_GROUP_MAP_UDF) + + udf = pandas_udf(lambda x: x, 'v double', + functionType=PandasUDFType.GROUP_MAP) + self.assertEqual(udf.returnType, StructType([StructField("v", DoubleType())])) + self.assertEqual(udf.evalType, PythonEvalType.SQL_PANDAS_GROUP_MAP_UDF) + + udf = pandas_udf(lambda x: x, returnType='v double', + functionType=PandasUDFType.GROUP_MAP) + self.assertEqual(udf.returnType, StructType([StructField("v", DoubleType())])) + self.assertEqual(udf.evalType, PythonEvalType.SQL_PANDAS_GROUP_MAP_UDF) + + def test_pandas_udf_decorator(self): + from pyspark.rdd import PythonEvalType + from pyspark.sql.functions import pandas_udf, PandasUDFType + from pyspark.sql.types import StructType, StructField, DoubleType + + @pandas_udf(DoubleType()) + def foo(x): + return x + self.assertEqual(foo.returnType, DoubleType()) + self.assertEqual(foo.evalType, PythonEvalType.SQL_PANDAS_SCALAR_UDF) + + @pandas_udf(returnType=DoubleType()) + def foo(x): + return x + self.assertEqual(foo.returnType, DoubleType()) + self.assertEqual(foo.evalType, PythonEvalType.SQL_PANDAS_SCALAR_UDF) + + schema = StructType([StructField("v", DoubleType())]) + + @pandas_udf(schema, PandasUDFType.GROUP_MAP) + def foo(x): + return x + self.assertEqual(foo.returnType, schema) + self.assertEqual(foo.evalType, PythonEvalType.SQL_PANDAS_GROUP_MAP_UDF) + + @pandas_udf('v double', PandasUDFType.GROUP_MAP) + def foo(x): + return x + self.assertEqual(foo.returnType, schema) + self.assertEqual(foo.evalType, PythonEvalType.SQL_PANDAS_GROUP_MAP_UDF) + + @pandas_udf(schema, functionType=PandasUDFType.GROUP_MAP) + def foo(x): + return x + self.assertEqual(foo.returnType, schema) + self.assertEqual(foo.evalType, PythonEvalType.SQL_PANDAS_GROUP_MAP_UDF) + + @pandas_udf(returnType='v double', functionType=PandasUDFType.SCALAR) + def foo(x): + return x + self.assertEqual(foo.returnType, schema) + self.assertEqual(foo.evalType, PythonEvalType.SQL_PANDAS_SCALAR_UDF) + + @pandas_udf(returnType=schema, functionType=PandasUDFType.GROUP_MAP) + def foo(x): + return x + self.assertEqual(foo.returnType, schema) + self.assertEqual(foo.evalType, PythonEvalType.SQL_PANDAS_GROUP_MAP_UDF) + + def test_udf_wrong_arg(self): + from pyspark.sql.functions import pandas_udf, PandasUDFType + + with QuietTest(self.sc): + with self.assertRaises(ParseException): + @pandas_udf('blah') + def foo(x): + return x + with self.assertRaisesRegexp(ValueError, 'Invalid returnType.*None'): + @pandas_udf(functionType=PandasUDFType.SCALAR) + def foo(x): + return x + with self.assertRaisesRegexp(ValueError, 'Invalid functionType'): + @pandas_udf('double', 100) + def foo(x): + return x + + with self.assertRaisesRegexp(ValueError, '0-arg pandas_udfs.*not.*supported'): + pandas_udf(lambda: 1, LongType(), PandasUDFType.SCALAR) + with self.assertRaisesRegexp(ValueError, '0-arg pandas_udfs.*not.*supported'): + @pandas_udf(LongType(), PandasUDFType.SCALAR) + def zero_with_type(): + return 1 + + with self.assertRaisesRegexp(TypeError, 'Invalid returnType'): + @pandas_udf(returnType=PandasUDFType.GROUP_MAP) + def foo(df): + return df + with self.assertRaisesRegexp(ValueError, 'Invalid returnType'): + @pandas_udf(returnType='double', functionType=PandasUDFType.GROUP_MAP) + def foo(df): + return df + with self.assertRaisesRegexp(ValueError, 'Invalid function'): + @pandas_udf(returnType='k int, v double', functionType=PandasUDFType.GROUP_MAP) + def foo(k, v): + return k + + @unittest.skipIf(not _have_pandas or not _have_arrow, "Pandas or 
Arrow not installed") class VectorizedUDFTests(ReusedSQLTestCase): @@ -3355,23 +3478,6 @@ class VectorizedUDFTests(ReusedSQLTestCase): res = df.select(str_f(col('str'))) self.assertEquals(df.collect(), res.collect()) - def test_vectorized_udf_zero_parameter(self): - from pyspark.sql.functions import pandas_udf - error_str = '0-arg pandas_udfs.*not.*supported' - with QuietTest(self.sc): - with self.assertRaisesRegexp(ValueError, error_str): - pandas_udf(lambda: 1, LongType()) - - with self.assertRaisesRegexp(ValueError, error_str): - @pandas_udf - def zero_no_type(): - return 1 - - with self.assertRaisesRegexp(ValueError, error_str): - @pandas_udf(LongType()) - def zero_with_type(): - return 1 - def test_vectorized_udf_datatype_string(self): from pyspark.sql.functions import pandas_udf, col df = self.spark.range(10).select( @@ -3570,7 +3676,7 @@ class VectorizedUDFTests(ReusedSQLTestCase): return x result = df.select(check_records_per_batch(col("id"))) - self.assertEquals(df.collect(), result.collect()) + self.assertEqual(df.collect(), result.collect()) finally: if orig_value is None: self.spark.conf.unset("spark.sql.execution.arrow.maxRecordsPerBatch") @@ -3595,7 +3701,7 @@ class GroupbyApplyTests(ReusedSQLTestCase): .withColumn("v", explode(col('vs'))).drop('vs') def test_simple(self): - from pyspark.sql.functions import pandas_udf + from pyspark.sql.functions import pandas_udf, PandasUDFType df = self.data foo_udf = pandas_udf( @@ -3604,21 +3710,22 @@ class GroupbyApplyTests(ReusedSQLTestCase): [StructField('id', LongType()), StructField('v', IntegerType()), StructField('v1', DoubleType()), - StructField('v2', LongType())])) + StructField('v2', LongType())]), + PandasUDFType.GROUP_MAP + ) result = df.groupby('id').apply(foo_udf).sort('id').toPandas() expected = df.toPandas().groupby('id').apply(foo_udf.func).reset_index(drop=True) self.assertFramesEqual(expected, result) def test_decorator(self): - from pyspark.sql.functions import pandas_udf + from pyspark.sql.functions import pandas_udf, PandasUDFType df = self.data - @pandas_udf(StructType( - [StructField('id', LongType()), - StructField('v', IntegerType()), - StructField('v1', DoubleType()), - StructField('v2', LongType())])) + @pandas_udf( + 'id long, v int, v1 double, v2 long', + PandasUDFType.GROUP_MAP + ) def foo(pdf): return pdf.assign(v1=pdf.v * pdf.id * 1.0, v2=pdf.v + pdf.id) @@ -3627,12 +3734,14 @@ class GroupbyApplyTests(ReusedSQLTestCase): self.assertFramesEqual(expected, result) def test_coerce(self): - from pyspark.sql.functions import pandas_udf + from pyspark.sql.functions import pandas_udf, PandasUDFType df = self.data foo = pandas_udf( lambda pdf: pdf, - StructType([StructField('id', LongType()), StructField('v', DoubleType())])) + 'id long, v double', + PandasUDFType.GROUP_MAP + ) result = df.groupby('id').apply(foo).sort('id').toPandas() expected = df.toPandas().groupby('id').apply(foo.func).reset_index(drop=True) @@ -3640,13 +3749,13 @@ class GroupbyApplyTests(ReusedSQLTestCase): self.assertFramesEqual(expected, result) def test_complex_groupby(self): - from pyspark.sql.functions import pandas_udf, col + from pyspark.sql.functions import pandas_udf, col, PandasUDFType df = self.data - @pandas_udf(StructType( - [StructField('id', LongType()), - StructField('v', IntegerType()), - StructField('norm', DoubleType())])) + @pandas_udf( + 'id long, v int, norm double', + PandasUDFType.GROUP_MAP + ) def normalize(pdf): v = pdf.v return pdf.assign(norm=(v - v.mean()) / v.std()) @@ -3659,13 +3768,13 @@ class 
GroupbyApplyTests(ReusedSQLTestCase): self.assertFramesEqual(expected, result) def test_empty_groupby(self): - from pyspark.sql.functions import pandas_udf, col + from pyspark.sql.functions import pandas_udf, col, PandasUDFType df = self.data - @pandas_udf(StructType( - [StructField('id', LongType()), - StructField('v', IntegerType()), - StructField('norm', DoubleType())])) + @pandas_udf( + 'id long, v int, norm double', + PandasUDFType.GROUP_MAP + ) def normalize(pdf): v = pdf.v return pdf.assign(norm=(v - v.mean()) / v.std()) @@ -3678,57 +3787,63 @@ class GroupbyApplyTests(ReusedSQLTestCase): self.assertFramesEqual(expected, result) def test_datatype_string(self): - from pyspark.sql.functions import pandas_udf + from pyspark.sql.functions import pandas_udf, PandasUDFType df = self.data foo_udf = pandas_udf( lambda pdf: pdf.assign(v1=pdf.v * pdf.id * 1.0, v2=pdf.v + pdf.id), - "id long, v int, v1 double, v2 long") + 'id long, v int, v1 double, v2 long', + PandasUDFType.GROUP_MAP + ) result = df.groupby('id').apply(foo_udf).sort('id').toPandas() expected = df.toPandas().groupby('id').apply(foo_udf.func).reset_index(drop=True) self.assertFramesEqual(expected, result) def test_wrong_return_type(self): - from pyspark.sql.functions import pandas_udf + from pyspark.sql.functions import pandas_udf, PandasUDFType df = self.data foo = pandas_udf( lambda pdf: pdf, - StructType([StructField('id', LongType()), StructField('v', StringType())])) + 'id long, v string', + PandasUDFType.GROUP_MAP + ) with QuietTest(self.sc): with self.assertRaisesRegexp(Exception, 'Invalid.*type'): df.groupby('id').apply(foo).sort('id').toPandas() def test_wrong_args(self): - from pyspark.sql.functions import udf, pandas_udf, sum + from pyspark.sql.functions import udf, pandas_udf, sum, PandasUDFType df = self.data with QuietTest(self.sc): - with self.assertRaisesRegexp(ValueError, 'pandas_udf'): + with self.assertRaisesRegexp(ValueError, 'Invalid udf'): df.groupby('id').apply(lambda x: x) - with self.assertRaisesRegexp(ValueError, 'pandas_udf'): + with self.assertRaisesRegexp(ValueError, 'Invalid udf'): df.groupby('id').apply(udf(lambda x: x, DoubleType())) - with self.assertRaisesRegexp(ValueError, 'pandas_udf'): + with self.assertRaisesRegexp(ValueError, 'Invalid udf'): df.groupby('id').apply(sum(df.v)) - with self.assertRaisesRegexp(ValueError, 'pandas_udf'): + with self.assertRaisesRegexp(ValueError, 'Invalid udf'): df.groupby('id').apply(df.v + 1) - with self.assertRaisesRegexp(ValueError, 'pandas_udf'): + with self.assertRaisesRegexp(ValueError, 'Invalid function'): df.groupby('id').apply( pandas_udf(lambda: 1, StructType([StructField("d", DoubleType())]))) - with self.assertRaisesRegexp(ValueError, 'pandas_udf'): + with self.assertRaisesRegexp(ValueError, 'Invalid udf'): df.groupby('id').apply( pandas_udf(lambda x, y: x, StructType([StructField("d", DoubleType())]))) - with self.assertRaisesRegexp(ValueError, 'returnType'): - df.groupby('id').apply(pandas_udf(lambda x: x, DoubleType())) + with self.assertRaisesRegexp(ValueError, 'Invalid udf.*GROUP_MAP'): + df.groupby('id').apply( + pandas_udf(lambda x, y: x, StructType([StructField("d", DoubleType())]), + PandasUDFType.SCALAR)) def test_unsupported_types(self): - from pyspark.sql.functions import pandas_udf, col + from pyspark.sql.functions import pandas_udf, col, PandasUDFType schema = StructType( [StructField("id", LongType(), True), StructField("dt", DecimalType(), True)]) df = self.spark.createDataFrame([(1, None,)], schema=schema) - f = pandas_udf(lambda x: 
x, df.schema) + f = pandas_udf(lambda x: x, df.schema, PandasUDFType.GROUP_MAP) with QuietTest(self.sc): with self.assertRaisesRegexp(Exception, 'Unsupported data type'): df.groupby('id').apply(f).collect() diff --git a/python/pyspark/sql/udf.py b/python/pyspark/sql/udf.py new file mode 100644 index 0000000000000000000000000000000000000000..c3301a41ccd5aa49acca35d157314dad663c1210 --- /dev/null +++ b/python/pyspark/sql/udf.py @@ -0,0 +1,161 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +""" +User-defined function related classes and functions +""" +import functools + +from pyspark import SparkContext +from pyspark.rdd import _prepare_for_python_RDD, PythonEvalType +from pyspark.sql.column import Column, _to_java_column, _to_seq +from pyspark.sql.types import StringType, DataType, StructType, _parse_datatype_string + + +def _wrap_function(sc, func, returnType): + command = (func, returnType) + pickled_command, broadcast_vars, env, includes = _prepare_for_python_RDD(sc, command) + return sc._jvm.PythonFunction(bytearray(pickled_command), env, includes, sc.pythonExec, + sc.pythonVer, broadcast_vars, sc._javaAccumulator) + + +def _create_udf(f, returnType, evalType): + if evalType == PythonEvalType.SQL_PANDAS_SCALAR_UDF: + import inspect + argspec = inspect.getargspec(f) + if len(argspec.args) == 0 and argspec.varargs is None: + raise ValueError( + "Invalid function: 0-arg pandas_udfs are not supported. " + "Instead, create a 1-arg pandas_udf and ignore the arg in your function." + ) + + elif evalType == PythonEvalType.SQL_PANDAS_GROUP_MAP_UDF: + import inspect + argspec = inspect.getargspec(f) + if len(argspec.args) != 1: + raise ValueError( + "Invalid function: pandas_udfs with function type GROUP_MAP " + "must take a single arg that is a pandas DataFrame." + ) + + # Set the name of the UserDefinedFunction object to be the name of function f + udf_obj = UserDefinedFunction(f, returnType=returnType, name=None, evalType=evalType) + return udf_obj._wrapped() + + +class UserDefinedFunction(object): + """ + User defined function in Python + + .. 
versionadded:: 1.3 + """ + def __init__(self, func, + returnType=StringType(), name=None, + evalType=PythonEvalType.SQL_BATCHED_UDF): + if not callable(func): + raise TypeError( + "Invalid function: not a function or callable (__call__ is not defined): " + "{0}".format(type(func))) + + if not isinstance(returnType, (DataType, str)): + raise TypeError( + "Invalid returnType: returnType should be DataType or str " + "but is {}".format(returnType)) + + if not isinstance(evalType, int): + raise TypeError( + "Invalid evalType: evalType should be an int but is {}".format(evalType)) + + self.func = func + self._returnType = returnType + # Stores UserDefinedPythonFunctions jobj, once initialized + self._returnType_placeholder = None + self._judf_placeholder = None + self._name = name or ( + func.__name__ if hasattr(func, '__name__') + else func.__class__.__name__) + self.evalType = evalType + + @property + def returnType(self): + # This makes sure this is called after SparkContext is initialized. + # ``_parse_datatype_string`` accesses to JVM for parsing a DDL formatted string. + if self._returnType_placeholder is None: + if isinstance(self._returnType, DataType): + self._returnType_placeholder = self._returnType + else: + self._returnType_placeholder = _parse_datatype_string(self._returnType) + + if self.evalType == PythonEvalType.SQL_PANDAS_GROUP_MAP_UDF \ + and not isinstance(self._returnType_placeholder, StructType): + raise ValueError("Invalid returnType: returnType must be a StructType for " + "pandas_udf with function type GROUP_MAP") + + return self._returnType_placeholder + + @property + def _judf(self): + # It is possible that concurrent access, to newly created UDF, + # will initialize multiple UserDefinedPythonFunctions. + # This is unlikely, doesn't affect correctness, + # and should have a minimal performance impact. + if self._judf_placeholder is None: + self._judf_placeholder = self._create_judf() + return self._judf_placeholder + + def _create_judf(self): + from pyspark.sql import SparkSession + + spark = SparkSession.builder.getOrCreate() + sc = spark.sparkContext + + wrapped_func = _wrap_function(sc, self.func, self.returnType) + jdt = spark._jsparkSession.parseDataType(self.returnType.json()) + judf = sc._jvm.org.apache.spark.sql.execution.python.UserDefinedPythonFunction( + self._name, wrapped_func, jdt, self.evalType) + return judf + + def __call__(self, *cols): + judf = self._judf + sc = SparkContext._active_spark_context + return Column(judf.apply(_to_seq(sc, cols, _to_java_column))) + + def _wrapped(self): + """ + Wrap this udf with a function and attach docstring from func + """ + + # It is possible for a callable instance without __name__ attribute or/and + # __module__ attribute to be wrapped here. For example, functools.partial. In this case, + # we should avoid wrapping the attributes from the wrapped function to the wrapper + # function. So, we take out these attribute names from the default names to set and + # then manually assign it after being wrapped. 
+ assignments = tuple( + a for a in functools.WRAPPER_ASSIGNMENTS if a != '__name__' and a != '__module__') + + @functools.wraps(self.func, assigned=assignments) + def wrapper(*args): + return self(*args) + + wrapper.__name__ = self._name + wrapper.__module__ = (self.func.__module__ if hasattr(self.func, '__module__') + else self.func.__class__.__module__) + + wrapper.func = self.func + wrapper.returnType = self.returnType + wrapper.evalType = self.evalType + + return wrapper diff --git a/python/pyspark/worker.py b/python/pyspark/worker.py index 5e100e0a9a95d827ce4aeb59edeecba5bbc41908..939643071943ab91d470c66e8c6a9c5d520fac1d 100644 --- a/python/pyspark/worker.py +++ b/python/pyspark/worker.py @@ -29,8 +29,9 @@ from pyspark.accumulators import _accumulatorRegistry from pyspark.broadcast import Broadcast, _broadcastRegistry from pyspark.taskcontext import TaskContext from pyspark.files import SparkFiles +from pyspark.rdd import PythonEvalType from pyspark.serializers import write_with_length, write_int, read_long, \ - write_long, read_int, SpecialLengths, PythonEvalType, UTF8Deserializer, PickleSerializer, \ + write_long, read_int, SpecialLengths, UTF8Deserializer, PickleSerializer, \ BatchedSerializer, ArrowStreamPandasSerializer from pyspark.sql.types import to_arrow_type from pyspark import shuffle @@ -73,7 +74,7 @@ def wrap_udf(f, return_type): return lambda *a: f(*a) -def wrap_pandas_udf(f, return_type): +def wrap_pandas_scalar_udf(f, return_type): arrow_return_type = to_arrow_type(return_type) def verify_result_length(*a): @@ -89,6 +90,26 @@ def wrap_pandas_udf(f, return_type): return lambda *a: (verify_result_length(*a), arrow_return_type) +def wrap_pandas_group_map_udf(f, return_type): + def wrapped(*series): + import pandas as pd + + result = f(pd.concat(series, axis=1)) + if not isinstance(result, pd.DataFrame): + raise TypeError("Return type of the user-defined function should be " + "pandas.DataFrame, but is {}".format(type(result))) + if not len(result.columns) == len(return_type): + raise RuntimeError( + "Number of columns of the returned pandas.DataFrame " + "doesn't match specified schema. 
" + "Expected: {} Actual: {}".format(len(return_type), len(result.columns))) + arrow_return_types = (to_arrow_type(field.dataType) for field in return_type) + return [(result[result.columns[i]], arrow_type) + for i, arrow_type in enumerate(arrow_return_types)] + + return wrapped + + def read_single_udf(pickleSer, infile, eval_type): num_arg = read_int(infile) arg_offsets = [read_int(infile) for i in range(num_arg)] @@ -99,12 +120,12 @@ def read_single_udf(pickleSer, infile, eval_type): row_func = f else: row_func = chain(row_func, f) + # the last returnType will be the return type of UDF - if eval_type == PythonEvalType.SQL_PANDAS_UDF: - return arg_offsets, wrap_pandas_udf(row_func, return_type) - elif eval_type == PythonEvalType.SQL_PANDAS_GROUPED_UDF: - # a groupby apply udf has already been wrapped under apply() - return arg_offsets, row_func + if eval_type == PythonEvalType.SQL_PANDAS_SCALAR_UDF: + return arg_offsets, wrap_pandas_scalar_udf(row_func, return_type) + elif eval_type == PythonEvalType.SQL_PANDAS_GROUP_MAP_UDF: + return arg_offsets, wrap_pandas_group_map_udf(row_func, return_type) else: return arg_offsets, wrap_udf(row_func, return_type) @@ -127,8 +148,8 @@ def read_udfs(pickleSer, infile, eval_type): func = lambda _, it: map(mapper, it) - if eval_type == PythonEvalType.SQL_PANDAS_UDF \ - or eval_type == PythonEvalType.SQL_PANDAS_GROUPED_UDF: + if eval_type == PythonEvalType.SQL_PANDAS_SCALAR_UDF \ + or eval_type == PythonEvalType.SQL_PANDAS_GROUP_MAP_UDF: ser = ArrowStreamPandasSerializer() else: ser = BatchedSerializer(PickleSerializer(), 100) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/RelationalGroupedDataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/RelationalGroupedDataset.scala index 3e4edd4ea8cf39bb2f87d42ca527e603048cba66..a009c00b0abc5998af7cf4912d0f8ecde4d75d1d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/RelationalGroupedDataset.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/RelationalGroupedDataset.scala @@ -23,6 +23,7 @@ import scala.collection.JavaConverters._ import scala.language.implicitConversions import org.apache.spark.annotation.InterfaceStability +import org.apache.spark.api.python.PythonEvalType import org.apache.spark.broadcast.Broadcast import org.apache.spark.sql.catalyst.analysis.{Star, UnresolvedAlias, UnresolvedAttribute, UnresolvedFunction} import org.apache.spark.sql.catalyst.expressions._ @@ -30,7 +31,7 @@ import org.apache.spark.sql.catalyst.expressions.aggregate._ import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.util.toPrettySQL import org.apache.spark.sql.execution.aggregate.TypedAggregateExpression -import org.apache.spark.sql.execution.python.{PythonUDF, PythonUdfType} +import org.apache.spark.sql.execution.python.PythonUDF import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types.{NumericType, StructType} @@ -449,10 +450,10 @@ class RelationalGroupedDataset protected[sql]( * workers. 
*/ private[sql] def flatMapGroupsInPandas(expr: PythonUDF): DataFrame = { - require(expr.pythonUdfType == PythonUdfType.PANDAS_GROUPED_UDF, - "Must pass a grouped vectorized python udf") + require(expr.evalType == PythonEvalType.SQL_PANDAS_GROUP_MAP_UDF, + "Must pass a group map udf") require(expr.dataType.isInstanceOf[StructType], - "The returnType of the vectorized python udf must be a StructType") + "The returnType of the udf must be a StructType") val groupingNamedExpressions = groupingExprs.map { case ne: NamedExpression => ne diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowEvalPythonExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowEvalPythonExec.scala index bcda2dae92e53617934f94f192f2a8b162e29b0b..e27210117a1e7baf920ed70b22349ae2738b6024 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowEvalPythonExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowEvalPythonExec.scala @@ -81,7 +81,7 @@ case class ArrowEvalPythonExec(udfs: Seq[PythonUDF], output: Seq[Attribute], chi val columnarBatchIter = new ArrowPythonRunner( funcs, bufferSize, reuseWorker, - PythonEvalType.SQL_PANDAS_UDF, argOffsets, schema, sessionLocalTimeZone) + PythonEvalType.SQL_PANDAS_SCALAR_UDF, argOffsets, schema, sessionLocalTimeZone) .compute(batchIter, context.partitionId(), context) new Iterator[InternalRow] { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ExtractPythonUDFs.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ExtractPythonUDFs.scala index e15e760136e81d95d0061df2aefa8ad203773098..f5a4cbc4793e30e664cda01681b09cd9600c8f82 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ExtractPythonUDFs.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ExtractPythonUDFs.scala @@ -20,6 +20,7 @@ package org.apache.spark.sql.execution.python import scala.collection.mutable import scala.collection.mutable.ArrayBuffer +import org.apache.spark.api.python.PythonEvalType import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, LogicalPlan, Project} @@ -148,15 +149,18 @@ object ExtractPythonUDFs extends Rule[SparkPlan] with PredicateHelper { udf.references.subsetOf(child.outputSet) } if (validUdfs.nonEmpty) { - if (validUdfs.exists(_.pythonUdfType == PythonUdfType.PANDAS_GROUPED_UDF)) { - throw new IllegalArgumentException("Can not use grouped vectorized UDFs") - } + require(validUdfs.forall(udf => + udf.evalType == PythonEvalType.SQL_BATCHED_UDF || + udf.evalType == PythonEvalType.SQL_PANDAS_SCALAR_UDF + ), "Can only extract scalar vectorized udf or sql batch udf") val resultAttrs = udfs.zipWithIndex.map { case (u, i) => AttributeReference(s"pythonUDF$i", u.dataType)() } - val evaluation = validUdfs.partition(_.pythonUdfType == PythonUdfType.PANDAS_UDF) match { + val evaluation = validUdfs.partition( + _.evalType == PythonEvalType.SQL_PANDAS_SCALAR_UDF + ) match { case (vectorizedUdfs, plainUdfs) if plainUdfs.isEmpty => ArrowEvalPythonExec(vectorizedUdfs, child.output ++ resultAttrs, child) case (vectorizedUdfs, plainUdfs) if vectorizedUdfs.isEmpty => diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/FlatMapGroupsInPandasExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/FlatMapGroupsInPandasExec.scala index 
e1e04e34e0c7152d8bff401ca4d76013f649d967..ee495814b82555b8ba769b105d422e060dd0d5cb 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/FlatMapGroupsInPandasExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/FlatMapGroupsInPandasExec.scala @@ -95,7 +95,7 @@ case class FlatMapGroupsInPandasExec( val columnarBatchIter = new ArrowPythonRunner( chainedFunc, bufferSize, reuseWorker, - PythonEvalType.SQL_PANDAS_GROUPED_UDF, argOffsets, schema, sessionLocalTimeZone) + PythonEvalType.SQL_PANDAS_GROUP_MAP_UDF, argOffsets, schema, sessionLocalTimeZone) .compute(grouped, context.partitionId(), context) columnarBatchIter.flatMap(_.rowIterator.asScala).map(UnsafeProjection.create(output, output)) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonUDF.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonUDF.scala index 9c07c7638de5756994bf9344a1bc9605525168a0..ef27fbc2db7d92cd4470dc645661cdd280350da3 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonUDF.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonUDF.scala @@ -29,7 +29,7 @@ case class PythonUDF( func: PythonFunction, dataType: DataType, children: Seq[Expression], - pythonUdfType: Int) + evalType: Int) extends Expression with Unevaluable with NonSQLExpression with UserDefinedExpression { override def toString: String = s"$name(${children.mkString(", ")})" diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/UserDefinedPythonFunction.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/UserDefinedPythonFunction.scala index b2fe6c300846a175fae7a3dcbf6f5ce230f1f53e..348e49e473ed3a03486b45b894f1ea64cbd00d5c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/UserDefinedPythonFunction.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/UserDefinedPythonFunction.scala @@ -22,15 +22,6 @@ import org.apache.spark.sql.Column import org.apache.spark.sql.catalyst.expressions.Expression import org.apache.spark.sql.types.DataType -private[spark] object PythonUdfType { - // row-at-a-time UDFs - val NORMAL_UDF = 0 - // scalar vectorized UDFs - val PANDAS_UDF = 1 - // grouped vectorized UDFs - val PANDAS_GROUPED_UDF = 2 -} - /** * A user-defined Python function. This is used by the Python API. */ @@ -38,10 +29,10 @@ case class UserDefinedPythonFunction( name: String, func: PythonFunction, dataType: DataType, - pythonUdfType: Int) { + pythonEvalType: Int) { def builder(e: Seq[Expression]): PythonUDF = { - PythonUDF(name, func, dataType, e, pythonUdfType) + PythonUDF(name, func, dataType, e, pythonEvalType) } /** Returns a [[Column]] that will evaluate to calling this UDF with the given input. 
*/ diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/python/BatchEvalPythonExecSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/python/BatchEvalPythonExecSuite.scala index 95b21fc9f16ae43947590c8d8690519b337b5759..53d3f34567518337eab001e033285ac01c2a7f01 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/python/BatchEvalPythonExecSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/python/BatchEvalPythonExecSuite.scala @@ -20,7 +20,7 @@ package org.apache.spark.sql.execution.python import scala.collection.JavaConverters._ import scala.collection.mutable.ArrayBuffer -import org.apache.spark.api.python.PythonFunction +import org.apache.spark.api.python.{PythonEvalType, PythonFunction} import org.apache.spark.sql.catalyst.FunctionIdentifier import org.apache.spark.sql.catalyst.expressions.{And, AttributeReference, GreaterThan, In} import org.apache.spark.sql.execution.{FilterExec, InputAdapter, SparkPlanTest, WholeStageCodegenExec} @@ -109,4 +109,4 @@ class MyDummyPythonUDF extends UserDefinedPythonFunction( name = "dummyUDF", func = new DummyUDF, dataType = BooleanType, - pythonUdfType = PythonUdfType.NORMAL_UDF) + pythonEvalType = PythonEvalType.SQL_BATCHED_UDF)
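
For context, below is a minimal usage sketch of the Python API this patch settles on, mirroring the doctests added in functions.py, group.py, and sql/tests.py. It assumes a local SparkSession with pandas and PyArrow installed; the function and column names (plus_one, normalize, id, v) are illustrative only and not part of the patch itself.

# Usage sketch of pandas_udf after this change (assumes pandas/PyArrow available).
from pyspark.sql import SparkSession
from pyspark.sql.functions import pandas_udf, PandasUDFType, col

spark = SparkSession.builder.master("local[2]").getOrCreate()
df = spark.createDataFrame(
    [(1, 1.0), (1, 2.0), (2, 3.0), (2, 5.0), (2, 10.0)], ("id", "v"))

# SCALAR (the default functionType): one or more pandas.Series -> pandas.Series,
# carried as evalType SQL_PANDAS_SCALAR_UDF (200) and run via ArrowEvalPythonExec.
@pandas_udf("double")  # equivalent to @pandas_udf("double", PandasUDFType.SCALAR)
def plus_one(v):
    return v + 1

df.select(plus_one(col("v"))).show()

# GROUP_MAP: pandas.DataFrame -> pandas.DataFrame, only valid with
# GroupedData.apply(); the returnType must resolve to a StructType (here given
# as a DDL string) and the UDF carries evalType SQL_PANDAS_GROUP_MAP_UDF (201).
@pandas_udf("id long, v double", PandasUDFType.GROUP_MAP)
def normalize(pdf):
    return pdf.assign(v=(pdf.v - pdf.v.mean()) / pdf.v.std())

df.groupby("id").apply(normalize).show()

# Row-at-a-time UDFs keep the plain udf() entry point; they are now tagged with
# evalType SQL_BATCHED_UDF (100) instead of the removed PythonUdfType constants.

As the patch's tests show, the function type can be passed positionally, by keyword (functionType=...), or omitted for SCALAR; the worker distinguishes the two pandas variants by eval type rather than by a pre-wrapped function as before.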