diff --git a/core/src/main/scala/org/apache/spark/api/python/PythonRunner.scala b/core/src/main/scala/org/apache/spark/api/python/PythonRunner.scala index 9989f68f8508c247dab4e47e653a69328eaefdf7..f524de68fbce027b28f814d5d9f27de0c09132ab 100644 --- a/core/src/main/scala/org/apache/spark/api/python/PythonRunner.scala +++ b/core/src/main/scala/org/apache/spark/api/python/PythonRunner.scala @@ -34,9 +34,11 @@ import org.apache.spark.util._ */ private[spark] object PythonEvalType { val NON_UDF = 0 - val SQL_BATCHED_UDF = 1 - val SQL_PANDAS_UDF = 2 - val SQL_PANDAS_GROUPED_UDF = 3 + + val SQL_BATCHED_UDF = 100 + + val SQL_PANDAS_SCALAR_UDF = 200 + val SQL_PANDAS_GROUP_MAP_UDF = 201 } /** diff --git a/python/pyspark/rdd.py b/python/pyspark/rdd.py index ea993c572fafd2a5c07e064a527143816fb7d2be..340bc3a6b74704727ba5c334dd6353c7d699bb98 100644 --- a/python/pyspark/rdd.py +++ b/python/pyspark/rdd.py @@ -56,6 +56,22 @@ from pyspark.traceback_utils import SCCallSiteSync __all__ = ["RDD"] +class PythonEvalType(object): + """ + Evaluation type of python rdd. + + These values are internal to PySpark. + + These values should match values in org.apache.spark.api.python.PythonEvalType. + """ + NON_UDF = 0 + + SQL_BATCHED_UDF = 100 + + SQL_PANDAS_SCALAR_UDF = 200 + SQL_PANDAS_GROUP_MAP_UDF = 201 + + def portable_hash(x): """ This function returns consistent hash code for builtin types, especially diff --git a/python/pyspark/serializers.py b/python/pyspark/serializers.py index e0afdafbfcd625ca07c2f6b1e351c08c551dd6a0..b95de2c804394e6f6046f399dced475a61017a1a 100644 --- a/python/pyspark/serializers.py +++ b/python/pyspark/serializers.py @@ -82,13 +82,6 @@ class SpecialLengths(object): START_ARROW_STREAM = -6 -class PythonEvalType(object): - NON_UDF = 0 - SQL_BATCHED_UDF = 1 - SQL_PANDAS_UDF = 2 - SQL_PANDAS_GROUPED_UDF = 3 - - class Serializer(object): def dump_stream(self, iterator, stream): diff --git a/python/pyspark/sql/catalog.py b/python/pyspark/sql/catalog.py index 5f25dce161963f971e95f9c25ea50f193b9042ed..659bc65701a0cbcc9b2aaa27c19b2883e24bab6f 100644 --- a/python/pyspark/sql/catalog.py +++ b/python/pyspark/sql/catalog.py @@ -19,9 +19,9 @@ import warnings from collections import namedtuple from pyspark import since -from pyspark.rdd import ignore_unicode_prefix +from pyspark.rdd import ignore_unicode_prefix, PythonEvalType from pyspark.sql.dataframe import DataFrame -from pyspark.sql.functions import UserDefinedFunction +from pyspark.sql.udf import UserDefinedFunction from pyspark.sql.types import IntegerType, StringType, StructType @@ -256,7 +256,8 @@ class Catalog(object): >>> spark.sql("SELECT stringLengthInt('test')").collect() [Row(stringLengthInt(test)=4)] """ - udf = UserDefinedFunction(f, returnType, name) + udf = UserDefinedFunction(f, returnType=returnType, name=name, + evalType=PythonEvalType.SQL_BATCHED_UDF) self._jsparkSession.udf().registerPython(name, udf._judf) return udf._wrapped() diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py index 087ce7caa89c809748b9060f45808d470ce19c25..b631e2041706f94fc48165331ed8b4a047cf21dd 100644 --- a/python/pyspark/sql/functions.py +++ b/python/pyspark/sql/functions.py @@ -27,11 +27,12 @@ if sys.version < "3": from itertools import imap as map from pyspark import since, SparkContext -from pyspark.rdd import _prepare_for_python_RDD, ignore_unicode_prefix +from pyspark.rdd import ignore_unicode_prefix, PythonEvalType from pyspark.serializers import PickleSerializer, AutoBatchedSerializer -from pyspark.sql.types import 
StringType, DataType, _parse_datatype_string from pyspark.sql.column import Column, _to_java_column, _to_seq from pyspark.sql.dataframe import DataFrame +from pyspark.sql.types import StringType, DataType +from pyspark.sql.udf import UserDefinedFunction, _create_udf def _create_function(name, doc=""): @@ -2062,132 +2063,12 @@ def map_values(col): # ---------------------------- User Defined Function ---------------------------------- -def _wrap_function(sc, func, returnType): - command = (func, returnType) - pickled_command, broadcast_vars, env, includes = _prepare_for_python_RDD(sc, command) - return sc._jvm.PythonFunction(bytearray(pickled_command), env, includes, sc.pythonExec, - sc.pythonVer, broadcast_vars, sc._javaAccumulator) - - -class PythonUdfType(object): - # row-at-a-time UDFs - NORMAL_UDF = 0 - # scalar vectorized UDFs - PANDAS_UDF = 1 - # grouped vectorized UDFs - PANDAS_GROUPED_UDF = 2 - - -class UserDefinedFunction(object): - """ - User defined function in Python - - .. versionadded:: 1.3 - """ - def __init__(self, func, returnType, name=None, pythonUdfType=PythonUdfType.NORMAL_UDF): - if not callable(func): - raise TypeError( - "Not a function or callable (__call__ is not defined): " - "{0}".format(type(func))) - - self.func = func - self._returnType = returnType - # Stores UserDefinedPythonFunctions jobj, once initialized - self._returnType_placeholder = None - self._judf_placeholder = None - self._name = name or ( - func.__name__ if hasattr(func, '__name__') - else func.__class__.__name__) - self.pythonUdfType = pythonUdfType - - @property - def returnType(self): - # This makes sure this is called after SparkContext is initialized. - # ``_parse_datatype_string`` accesses to JVM for parsing a DDL formatted string. - if self._returnType_placeholder is None: - if isinstance(self._returnType, DataType): - self._returnType_placeholder = self._returnType - else: - self._returnType_placeholder = _parse_datatype_string(self._returnType) - return self._returnType_placeholder - - @property - def _judf(self): - # It is possible that concurrent access, to newly created UDF, - # will initialize multiple UserDefinedPythonFunctions. - # This is unlikely, doesn't affect correctness, - # and should have a minimal performance impact. - if self._judf_placeholder is None: - self._judf_placeholder = self._create_judf() - return self._judf_placeholder - - def _create_judf(self): - from pyspark.sql import SparkSession - - spark = SparkSession.builder.getOrCreate() - sc = spark.sparkContext - - wrapped_func = _wrap_function(sc, self.func, self.returnType) - jdt = spark._jsparkSession.parseDataType(self.returnType.json()) - judf = sc._jvm.org.apache.spark.sql.execution.python.UserDefinedPythonFunction( - self._name, wrapped_func, jdt, self.pythonUdfType) - return judf - - def __call__(self, *cols): - judf = self._judf - sc = SparkContext._active_spark_context - return Column(judf.apply(_to_seq(sc, cols, _to_java_column))) - - def _wrapped(self): - """ - Wrap this udf with a function and attach docstring from func - """ - - # It is possible for a callable instance without __name__ attribute or/and - # __module__ attribute to be wrapped here. For example, functools.partial. In this case, - # we should avoid wrapping the attributes from the wrapped function to the wrapper - # function. So, we take out these attribute names from the default names to set and - # then manually assign it after being wrapped. 
- assignments = tuple( - a for a in functools.WRAPPER_ASSIGNMENTS if a != '__name__' and a != '__module__') - - @functools.wraps(self.func, assigned=assignments) - def wrapper(*args): - return self(*args) - - wrapper.__name__ = self._name - wrapper.__module__ = (self.func.__module__ if hasattr(self.func, '__module__') - else self.func.__class__.__module__) - - wrapper.func = self.func - wrapper.returnType = self.returnType - wrapper.pythonUdfType = self.pythonUdfType - - return wrapper - - -def _create_udf(f, returnType, pythonUdfType): - - def _udf(f, returnType=StringType(), pythonUdfType=pythonUdfType): - if pythonUdfType == PythonUdfType.PANDAS_UDF: - import inspect - argspec = inspect.getargspec(f) - if len(argspec.args) == 0 and argspec.varargs is None: - raise ValueError( - "0-arg pandas_udfs are not supported. " - "Instead, create a 1-arg pandas_udf and ignore the arg in your function." - ) - udf_obj = UserDefinedFunction(f, returnType, pythonUdfType=pythonUdfType) - return udf_obj._wrapped() - - # decorator @udf, @udf(), @udf(dataType()), or similar with @pandas_udf - if f is None or isinstance(f, (str, DataType)): - # If DataType has been passed as a positional argument - # for decorator use it as a returnType - return_type = f or returnType - return functools.partial(_udf, returnType=return_type, pythonUdfType=pythonUdfType) - else: - return _udf(f=f, returnType=returnType, pythonUdfType=pythonUdfType) +class PandasUDFType(object): + """Pandas UDF Types. See :meth:`pyspark.sql.functions.pandas_udf`. + """ + SCALAR = PythonEvalType.SQL_PANDAS_SCALAR_UDF + + GROUP_MAP = PythonEvalType.SQL_PANDAS_GROUP_MAP_UDF @since(1.3) @@ -2228,33 +2109,47 @@ def udf(f=None, returnType=StringType()): | 8| JOHN DOE| 22| +----------+--------------+------------+ """ - return _create_udf(f, returnType=returnType, pythonUdfType=PythonUdfType.NORMAL_UDF) + # decorator @udf, @udf(), @udf(dataType()) + if f is None or isinstance(f, (str, DataType)): + # If DataType has been passed as a positional argument + # for decorator use it as a returnType + return_type = f or returnType + return functools.partial(_create_udf, returnType=return_type, + evalType=PythonEvalType.SQL_BATCHED_UDF) + else: + return _create_udf(f=f, returnType=returnType, + evalType=PythonEvalType.SQL_BATCHED_UDF) @since(2.3) -def pandas_udf(f=None, returnType=StringType()): +def pandas_udf(f=None, returnType=None, functionType=None): """ Creates a vectorized user defined function (UDF). :param f: user-defined function. A python function if used as a standalone function :param returnType: a :class:`pyspark.sql.types.DataType` object + :param functionType: an enum value in :class:`pyspark.sql.functions.PandasUDFType`. + Default: SCALAR. - The user-defined function can define one of the following transformations: + The function type of the UDF can be one of the following: - 1. One or more `pandas.Series` -> A `pandas.Series` + 1. SCALAR - This udf is used with :meth:`pyspark.sql.DataFrame.withColumn` and - :meth:`pyspark.sql.DataFrame.select`. + A scalar UDF defines a transformation: One or more `pandas.Series` -> A `pandas.Series`. The returnType should be a primitive data type, e.g., `DoubleType()`. The length of the returned `pandas.Series` must be of the same as the input `pandas.Series`. + Scalar UDFs are used with :meth:`pyspark.sql.DataFrame.withColumn` and + :meth:`pyspark.sql.DataFrame.select`. 
+ + >>> from pyspark.sql.functions import pandas_udf, PandasUDFType >>> from pyspark.sql.types import IntegerType, StringType >>> slen = pandas_udf(lambda s: s.str.len(), IntegerType()) - >>> @pandas_udf(returnType=StringType()) + >>> @pandas_udf(StringType()) ... def to_upper(s): ... return s.str.upper() ... - >>> @pandas_udf(returnType="integer") + >>> @pandas_udf("integer", PandasUDFType.SCALAR) ... def add_one(x): ... return x + 1 ... @@ -2267,20 +2162,24 @@ def pandas_udf(f=None, returnType=StringType()): | 8| JOHN DOE| 22| +----------+--------------+------------+ - 2. A `pandas.DataFrame` -> A `pandas.DataFrame` + 2. GROUP_MAP - This udf is only used with :meth:`pyspark.sql.GroupedData.apply`. + A group map UDF defines transformation: A `pandas.DataFrame` -> A `pandas.DataFrame` The returnType should be a :class:`StructType` describing the schema of the returned `pandas.DataFrame`. + The length of the returned `pandas.DataFrame` can be arbitrary. + + Group map UDFs are used with :meth:`pyspark.sql.GroupedData.apply`. + >>> from pyspark.sql.functions import pandas_udf, PandasUDFType >>> df = spark.createDataFrame( ... [(1, 1.0), (1, 2.0), (2, 3.0), (2, 5.0), (2, 10.0)], ... ("id", "v")) - >>> @pandas_udf(returnType=df.schema) + >>> @pandas_udf("id long, v double", PandasUDFType.GROUP_MAP) ... def normalize(pdf): ... v = pdf.v ... return pdf.assign(v=(v - v.mean()) / v.std()) - >>> df.groupby('id').apply(normalize).show() # doctest: +SKIP + >>> df.groupby("id").apply(normalize).show() # doctest: +SKIP +---+-------------------+ | id| v| +---+-------------------+ @@ -2291,10 +2190,6 @@ def pandas_udf(f=None, returnType=StringType()): | 2| 1.1094003924504583| +---+-------------------+ - .. note:: This type of udf cannot be used with functions such as `withColumn` or `select` - because it defines a `DataFrame` transformation rather than a `Column` - transformation. - .. seealso:: :meth:`pyspark.sql.GroupedData.apply` .. note:: The user-defined function must be deterministic. @@ -2306,7 +2201,44 @@ def pandas_udf(f=None, returnType=StringType()): rows that do not satisfy the conditions, the suggested workaround is to incorporate the condition logic into the functions. 
""" - return _create_udf(f, returnType=returnType, pythonUdfType=PythonUdfType.PANDAS_UDF) + # decorator @pandas_udf(returnType, functionType) + is_decorator = f is None or isinstance(f, (str, DataType)) + + if is_decorator: + # If DataType has been passed as a positional argument + # for decorator use it as a returnType + return_type = f or returnType + + if functionType is not None: + # @pandas_udf(dataType, functionType=functionType) + # @pandas_udf(returnType=dataType, functionType=functionType) + eval_type = functionType + elif returnType is not None and isinstance(returnType, int): + # @pandas_udf(dataType, functionType) + eval_type = returnType + else: + # @pandas_udf(dataType) or @pandas_udf(returnType=dataType) + eval_type = PythonEvalType.SQL_PANDAS_SCALAR_UDF + else: + return_type = returnType + + if functionType is not None: + eval_type = functionType + else: + eval_type = PythonEvalType.SQL_PANDAS_SCALAR_UDF + + if return_type is None: + raise ValueError("Invalid returnType: returnType can not be None") + + if eval_type not in [PythonEvalType.SQL_PANDAS_SCALAR_UDF, + PythonEvalType.SQL_PANDAS_GROUP_MAP_UDF]: + raise ValueError("Invalid functionType: " + "functionType must be one the values from PandasUDFType") + + if is_decorator: + return functools.partial(_create_udf, returnType=return_type, evalType=eval_type) + else: + return _create_udf(f=f, returnType=return_type, evalType=eval_type) blacklist = ['map', 'since', 'ignore_unicode_prefix'] diff --git a/python/pyspark/sql/group.py b/python/pyspark/sql/group.py index e11388d6043123bca19c53a817e893dd4a9c83a9..4d47dd6a3e878e1cba5113f84c05b47ba93e95b9 100644 --- a/python/pyspark/sql/group.py +++ b/python/pyspark/sql/group.py @@ -16,10 +16,10 @@ # from pyspark import since -from pyspark.rdd import ignore_unicode_prefix +from pyspark.rdd import ignore_unicode_prefix, PythonEvalType from pyspark.sql.column import Column, _to_seq, _to_java_column, _create_column_from_literal from pyspark.sql.dataframe import DataFrame -from pyspark.sql.functions import PythonUdfType, UserDefinedFunction +from pyspark.sql.udf import UserDefinedFunction from pyspark.sql.types import * __all__ = ["GroupedData"] @@ -214,15 +214,15 @@ class GroupedData(object): :param udf: A function object returned by :meth:`pyspark.sql.functions.pandas_udf` - >>> from pyspark.sql.functions import pandas_udf + >>> from pyspark.sql.functions import pandas_udf, PandasUDFType >>> df = spark.createDataFrame( ... [(1, 1.0), (1, 2.0), (2, 3.0), (2, 5.0), (2, 10.0)], ... ("id", "v")) - >>> @pandas_udf(returnType=df.schema) + >>> @pandas_udf("id long, v double", PandasUDFType.GROUP_MAP) ... def normalize(pdf): ... v = pdf.v ... return pdf.assign(v=(v - v.mean()) / v.std()) - >>> df.groupby('id').apply(normalize).show() # doctest: +SKIP + >>> df.groupby("id").apply(normalize).show() # doctest: +SKIP +---+-------------------+ | id| v| +---+-------------------+ @@ -236,44 +236,13 @@ class GroupedData(object): .. 
seealso:: :meth:`pyspark.sql.functions.pandas_udf` """ - import inspect - # Columns are special because hasattr always return True if isinstance(udf, Column) or not hasattr(udf, 'func') \ - or udf.pythonUdfType != PythonUdfType.PANDAS_UDF \ - or len(inspect.getargspec(udf.func).args) != 1: - raise ValueError("The argument to apply must be a 1-arg pandas_udf") - if not isinstance(udf.returnType, StructType): - raise ValueError("The returnType of the pandas_udf must be a StructType") - + or udf.evalType != PythonEvalType.SQL_PANDAS_GROUP_MAP_UDF: + raise ValueError("Invalid udf: the udf argument must be a pandas_udf of type " + "GROUP_MAP.") df = self._df - func = udf.func - returnType = udf.returnType - - # The python executors expects the function to use pd.Series as input and output - # So we to create a wrapper function that turns that to a pd.DataFrame before passing - # down to the user function, then turn the result pd.DataFrame back into pd.Series - columns = df.columns - - def wrapped(*cols): - from pyspark.sql.types import to_arrow_type - import pandas as pd - result = func(pd.concat(cols, axis=1, keys=columns)) - if not isinstance(result, pd.DataFrame): - raise TypeError("Return type of the user-defined function should be " - "Pandas.DataFrame, but is {}".format(type(result))) - if not len(result.columns) == len(returnType): - raise RuntimeError( - "Number of columns of the returned Pandas.DataFrame " - "doesn't match specified schema. " - "Expected: {} Actual: {}".format(len(returnType), len(result.columns))) - arrow_return_types = (to_arrow_type(field.dataType) for field in returnType) - return [(result[result.columns[i]], arrow_type) - for i, arrow_type in enumerate(arrow_return_types)] - - udf_obj = UserDefinedFunction( - wrapped, returnType, name=udf.__name__, pythonUdfType=PythonUdfType.PANDAS_GROUPED_UDF) - udf_column = udf_obj(*[df[col] for col in df.columns]) + udf_column = udf(*[df[col] for col in df.columns]) jdf = self._jgd.flatMapGroupsInPandas(udf_column._jc.expr()) return DataFrame(jdf, self.sql_ctx) diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py index ef592c2356a8c3421a85f24288e954dbcc10b184..762afe0d730f313f5e3d6caebd274ea9f14000b7 100644 --- a/python/pyspark/sql/tests.py +++ b/python/pyspark/sql/tests.py @@ -3259,6 +3259,129 @@ class ArrowTests(ReusedSQLTestCase): self.assertEquals(self.schema, schema_rt) +class PandasUDFTests(ReusedSQLTestCase): + def test_pandas_udf_basic(self): + from pyspark.rdd import PythonEvalType + from pyspark.sql.functions import pandas_udf, PandasUDFType + + udf = pandas_udf(lambda x: x, DoubleType()) + self.assertEqual(udf.returnType, DoubleType()) + self.assertEqual(udf.evalType, PythonEvalType.SQL_PANDAS_SCALAR_UDF) + + udf = pandas_udf(lambda x: x, DoubleType(), PandasUDFType.SCALAR) + self.assertEqual(udf.returnType, DoubleType()) + self.assertEqual(udf.evalType, PythonEvalType.SQL_PANDAS_SCALAR_UDF) + + udf = pandas_udf(lambda x: x, 'double', PandasUDFType.SCALAR) + self.assertEqual(udf.returnType, DoubleType()) + self.assertEqual(udf.evalType, PythonEvalType.SQL_PANDAS_SCALAR_UDF) + + udf = pandas_udf(lambda x: x, StructType([StructField("v", DoubleType())]), + PandasUDFType.GROUP_MAP) + self.assertEqual(udf.returnType, StructType([StructField("v", DoubleType())])) + self.assertEqual(udf.evalType, PythonEvalType.SQL_PANDAS_GROUP_MAP_UDF) + + udf = pandas_udf(lambda x: x, 'v double', PandasUDFType.GROUP_MAP) + self.assertEqual(udf.returnType, StructType([StructField("v", DoubleType())])) + 
self.assertEqual(udf.evalType, PythonEvalType.SQL_PANDAS_GROUP_MAP_UDF) + + udf = pandas_udf(lambda x: x, 'v double', + functionType=PandasUDFType.GROUP_MAP) + self.assertEqual(udf.returnType, StructType([StructField("v", DoubleType())])) + self.assertEqual(udf.evalType, PythonEvalType.SQL_PANDAS_GROUP_MAP_UDF) + + udf = pandas_udf(lambda x: x, returnType='v double', + functionType=PandasUDFType.GROUP_MAP) + self.assertEqual(udf.returnType, StructType([StructField("v", DoubleType())])) + self.assertEqual(udf.evalType, PythonEvalType.SQL_PANDAS_GROUP_MAP_UDF) + + def test_pandas_udf_decorator(self): + from pyspark.rdd import PythonEvalType + from pyspark.sql.functions import pandas_udf, PandasUDFType + from pyspark.sql.types import StructType, StructField, DoubleType + + @pandas_udf(DoubleType()) + def foo(x): + return x + self.assertEqual(foo.returnType, DoubleType()) + self.assertEqual(foo.evalType, PythonEvalType.SQL_PANDAS_SCALAR_UDF) + + @pandas_udf(returnType=DoubleType()) + def foo(x): + return x + self.assertEqual(foo.returnType, DoubleType()) + self.assertEqual(foo.evalType, PythonEvalType.SQL_PANDAS_SCALAR_UDF) + + schema = StructType([StructField("v", DoubleType())]) + + @pandas_udf(schema, PandasUDFType.GROUP_MAP) + def foo(x): + return x + self.assertEqual(foo.returnType, schema) + self.assertEqual(foo.evalType, PythonEvalType.SQL_PANDAS_GROUP_MAP_UDF) + + @pandas_udf('v double', PandasUDFType.GROUP_MAP) + def foo(x): + return x + self.assertEqual(foo.returnType, schema) + self.assertEqual(foo.evalType, PythonEvalType.SQL_PANDAS_GROUP_MAP_UDF) + + @pandas_udf(schema, functionType=PandasUDFType.GROUP_MAP) + def foo(x): + return x + self.assertEqual(foo.returnType, schema) + self.assertEqual(foo.evalType, PythonEvalType.SQL_PANDAS_GROUP_MAP_UDF) + + @pandas_udf(returnType='v double', functionType=PandasUDFType.SCALAR) + def foo(x): + return x + self.assertEqual(foo.returnType, schema) + self.assertEqual(foo.evalType, PythonEvalType.SQL_PANDAS_SCALAR_UDF) + + @pandas_udf(returnType=schema, functionType=PandasUDFType.GROUP_MAP) + def foo(x): + return x + self.assertEqual(foo.returnType, schema) + self.assertEqual(foo.evalType, PythonEvalType.SQL_PANDAS_GROUP_MAP_UDF) + + def test_udf_wrong_arg(self): + from pyspark.sql.functions import pandas_udf, PandasUDFType + + with QuietTest(self.sc): + with self.assertRaises(ParseException): + @pandas_udf('blah') + def foo(x): + return x + with self.assertRaisesRegexp(ValueError, 'Invalid returnType.*None'): + @pandas_udf(functionType=PandasUDFType.SCALAR) + def foo(x): + return x + with self.assertRaisesRegexp(ValueError, 'Invalid functionType'): + @pandas_udf('double', 100) + def foo(x): + return x + + with self.assertRaisesRegexp(ValueError, '0-arg pandas_udfs.*not.*supported'): + pandas_udf(lambda: 1, LongType(), PandasUDFType.SCALAR) + with self.assertRaisesRegexp(ValueError, '0-arg pandas_udfs.*not.*supported'): + @pandas_udf(LongType(), PandasUDFType.SCALAR) + def zero_with_type(): + return 1 + + with self.assertRaisesRegexp(TypeError, 'Invalid returnType'): + @pandas_udf(returnType=PandasUDFType.GROUP_MAP) + def foo(df): + return df + with self.assertRaisesRegexp(ValueError, 'Invalid returnType'): + @pandas_udf(returnType='double', functionType=PandasUDFType.GROUP_MAP) + def foo(df): + return df + with self.assertRaisesRegexp(ValueError, 'Invalid function'): + @pandas_udf(returnType='k int, v double', functionType=PandasUDFType.GROUP_MAP) + def foo(k, v): + return k + + @unittest.skipIf(not _have_pandas or not _have_arrow, "Pandas or 
Arrow not installed") class VectorizedUDFTests(ReusedSQLTestCase): @@ -3355,23 +3478,6 @@ class VectorizedUDFTests(ReusedSQLTestCase): res = df.select(str_f(col('str'))) self.assertEquals(df.collect(), res.collect()) - def test_vectorized_udf_zero_parameter(self): - from pyspark.sql.functions import pandas_udf - error_str = '0-arg pandas_udfs.*not.*supported' - with QuietTest(self.sc): - with self.assertRaisesRegexp(ValueError, error_str): - pandas_udf(lambda: 1, LongType()) - - with self.assertRaisesRegexp(ValueError, error_str): - @pandas_udf - def zero_no_type(): - return 1 - - with self.assertRaisesRegexp(ValueError, error_str): - @pandas_udf(LongType()) - def zero_with_type(): - return 1 - def test_vectorized_udf_datatype_string(self): from pyspark.sql.functions import pandas_udf, col df = self.spark.range(10).select( @@ -3570,7 +3676,7 @@ class VectorizedUDFTests(ReusedSQLTestCase): return x result = df.select(check_records_per_batch(col("id"))) - self.assertEquals(df.collect(), result.collect()) + self.assertEqual(df.collect(), result.collect()) finally: if orig_value is None: self.spark.conf.unset("spark.sql.execution.arrow.maxRecordsPerBatch") @@ -3595,7 +3701,7 @@ class GroupbyApplyTests(ReusedSQLTestCase): .withColumn("v", explode(col('vs'))).drop('vs') def test_simple(self): - from pyspark.sql.functions import pandas_udf + from pyspark.sql.functions import pandas_udf, PandasUDFType df = self.data foo_udf = pandas_udf( @@ -3604,21 +3710,22 @@ class GroupbyApplyTests(ReusedSQLTestCase): [StructField('id', LongType()), StructField('v', IntegerType()), StructField('v1', DoubleType()), - StructField('v2', LongType())])) + StructField('v2', LongType())]), + PandasUDFType.GROUP_MAP + ) result = df.groupby('id').apply(foo_udf).sort('id').toPandas() expected = df.toPandas().groupby('id').apply(foo_udf.func).reset_index(drop=True) self.assertFramesEqual(expected, result) def test_decorator(self): - from pyspark.sql.functions import pandas_udf + from pyspark.sql.functions import pandas_udf, PandasUDFType df = self.data - @pandas_udf(StructType( - [StructField('id', LongType()), - StructField('v', IntegerType()), - StructField('v1', DoubleType()), - StructField('v2', LongType())])) + @pandas_udf( + 'id long, v int, v1 double, v2 long', + PandasUDFType.GROUP_MAP + ) def foo(pdf): return pdf.assign(v1=pdf.v * pdf.id * 1.0, v2=pdf.v + pdf.id) @@ -3627,12 +3734,14 @@ class GroupbyApplyTests(ReusedSQLTestCase): self.assertFramesEqual(expected, result) def test_coerce(self): - from pyspark.sql.functions import pandas_udf + from pyspark.sql.functions import pandas_udf, PandasUDFType df = self.data foo = pandas_udf( lambda pdf: pdf, - StructType([StructField('id', LongType()), StructField('v', DoubleType())])) + 'id long, v double', + PandasUDFType.GROUP_MAP + ) result = df.groupby('id').apply(foo).sort('id').toPandas() expected = df.toPandas().groupby('id').apply(foo.func).reset_index(drop=True) @@ -3640,13 +3749,13 @@ class GroupbyApplyTests(ReusedSQLTestCase): self.assertFramesEqual(expected, result) def test_complex_groupby(self): - from pyspark.sql.functions import pandas_udf, col + from pyspark.sql.functions import pandas_udf, col, PandasUDFType df = self.data - @pandas_udf(StructType( - [StructField('id', LongType()), - StructField('v', IntegerType()), - StructField('norm', DoubleType())])) + @pandas_udf( + 'id long, v int, norm double', + PandasUDFType.GROUP_MAP + ) def normalize(pdf): v = pdf.v return pdf.assign(norm=(v - v.mean()) / v.std()) @@ -3659,13 +3768,13 @@ class 
GroupbyApplyTests(ReusedSQLTestCase): self.assertFramesEqual(expected, result) def test_empty_groupby(self): - from pyspark.sql.functions import pandas_udf, col + from pyspark.sql.functions import pandas_udf, col, PandasUDFType df = self.data - @pandas_udf(StructType( - [StructField('id', LongType()), - StructField('v', IntegerType()), - StructField('norm', DoubleType())])) + @pandas_udf( + 'id long, v int, norm double', + PandasUDFType.GROUP_MAP + ) def normalize(pdf): v = pdf.v return pdf.assign(norm=(v - v.mean()) / v.std()) @@ -3678,57 +3787,63 @@ class GroupbyApplyTests(ReusedSQLTestCase): self.assertFramesEqual(expected, result) def test_datatype_string(self): - from pyspark.sql.functions import pandas_udf + from pyspark.sql.functions import pandas_udf, PandasUDFType df = self.data foo_udf = pandas_udf( lambda pdf: pdf.assign(v1=pdf.v * pdf.id * 1.0, v2=pdf.v + pdf.id), - "id long, v int, v1 double, v2 long") + 'id long, v int, v1 double, v2 long', + PandasUDFType.GROUP_MAP + ) result = df.groupby('id').apply(foo_udf).sort('id').toPandas() expected = df.toPandas().groupby('id').apply(foo_udf.func).reset_index(drop=True) self.assertFramesEqual(expected, result) def test_wrong_return_type(self): - from pyspark.sql.functions import pandas_udf + from pyspark.sql.functions import pandas_udf, PandasUDFType df = self.data foo = pandas_udf( lambda pdf: pdf, - StructType([StructField('id', LongType()), StructField('v', StringType())])) + 'id long, v string', + PandasUDFType.GROUP_MAP + ) with QuietTest(self.sc): with self.assertRaisesRegexp(Exception, 'Invalid.*type'): df.groupby('id').apply(foo).sort('id').toPandas() def test_wrong_args(self): - from pyspark.sql.functions import udf, pandas_udf, sum + from pyspark.sql.functions import udf, pandas_udf, sum, PandasUDFType df = self.data with QuietTest(self.sc): - with self.assertRaisesRegexp(ValueError, 'pandas_udf'): + with self.assertRaisesRegexp(ValueError, 'Invalid udf'): df.groupby('id').apply(lambda x: x) - with self.assertRaisesRegexp(ValueError, 'pandas_udf'): + with self.assertRaisesRegexp(ValueError, 'Invalid udf'): df.groupby('id').apply(udf(lambda x: x, DoubleType())) - with self.assertRaisesRegexp(ValueError, 'pandas_udf'): + with self.assertRaisesRegexp(ValueError, 'Invalid udf'): df.groupby('id').apply(sum(df.v)) - with self.assertRaisesRegexp(ValueError, 'pandas_udf'): + with self.assertRaisesRegexp(ValueError, 'Invalid udf'): df.groupby('id').apply(df.v + 1) - with self.assertRaisesRegexp(ValueError, 'pandas_udf'): + with self.assertRaisesRegexp(ValueError, 'Invalid function'): df.groupby('id').apply( pandas_udf(lambda: 1, StructType([StructField("d", DoubleType())]))) - with self.assertRaisesRegexp(ValueError, 'pandas_udf'): + with self.assertRaisesRegexp(ValueError, 'Invalid udf'): df.groupby('id').apply( pandas_udf(lambda x, y: x, StructType([StructField("d", DoubleType())]))) - with self.assertRaisesRegexp(ValueError, 'returnType'): - df.groupby('id').apply(pandas_udf(lambda x: x, DoubleType())) + with self.assertRaisesRegexp(ValueError, 'Invalid udf.*GROUP_MAP'): + df.groupby('id').apply( + pandas_udf(lambda x, y: x, StructType([StructField("d", DoubleType())]), + PandasUDFType.SCALAR)) def test_unsupported_types(self): - from pyspark.sql.functions import pandas_udf, col + from pyspark.sql.functions import pandas_udf, col, PandasUDFType schema = StructType( [StructField("id", LongType(), True), StructField("dt", DecimalType(), True)]) df = self.spark.createDataFrame([(1, None,)], schema=schema) - f = pandas_udf(lambda x: 
x, df.schema) + f = pandas_udf(lambda x: x, df.schema, PandasUDFType.GROUP_MAP) with QuietTest(self.sc): with self.assertRaisesRegexp(Exception, 'Unsupported data type'): df.groupby('id').apply(f).collect() diff --git a/python/pyspark/sql/udf.py b/python/pyspark/sql/udf.py new file mode 100644 index 0000000000000000000000000000000000000000..c3301a41ccd5aa49acca35d157314dad663c1210 --- /dev/null +++ b/python/pyspark/sql/udf.py @@ -0,0 +1,161 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +""" +User-defined function related classes and functions +""" +import functools + +from pyspark import SparkContext +from pyspark.rdd import _prepare_for_python_RDD, PythonEvalType +from pyspark.sql.column import Column, _to_java_column, _to_seq +from pyspark.sql.types import StringType, DataType, StructType, _parse_datatype_string + + +def _wrap_function(sc, func, returnType): + command = (func, returnType) + pickled_command, broadcast_vars, env, includes = _prepare_for_python_RDD(sc, command) + return sc._jvm.PythonFunction(bytearray(pickled_command), env, includes, sc.pythonExec, + sc.pythonVer, broadcast_vars, sc._javaAccumulator) + + +def _create_udf(f, returnType, evalType): + if evalType == PythonEvalType.SQL_PANDAS_SCALAR_UDF: + import inspect + argspec = inspect.getargspec(f) + if len(argspec.args) == 0 and argspec.varargs is None: + raise ValueError( + "Invalid function: 0-arg pandas_udfs are not supported. " + "Instead, create a 1-arg pandas_udf and ignore the arg in your function." + ) + + elif evalType == PythonEvalType.SQL_PANDAS_GROUP_MAP_UDF: + import inspect + argspec = inspect.getargspec(f) + if len(argspec.args) != 1: + raise ValueError( + "Invalid function: pandas_udfs with function type GROUP_MAP " + "must take a single arg that is a pandas DataFrame." + ) + + # Set the name of the UserDefinedFunction object to be the name of function f + udf_obj = UserDefinedFunction(f, returnType=returnType, name=None, evalType=evalType) + return udf_obj._wrapped() + + +class UserDefinedFunction(object): + """ + User defined function in Python + + .. 
versionadded:: 1.3 + """ + def __init__(self, func, + returnType=StringType(), name=None, + evalType=PythonEvalType.SQL_BATCHED_UDF): + if not callable(func): + raise TypeError( + "Invalid function: not a function or callable (__call__ is not defined): " + "{0}".format(type(func))) + + if not isinstance(returnType, (DataType, str)): + raise TypeError( + "Invalid returnType: returnType should be DataType or str " + "but is {}".format(returnType)) + + if not isinstance(evalType, int): + raise TypeError( + "Invalid evalType: evalType should be an int but is {}".format(evalType)) + + self.func = func + self._returnType = returnType + # Stores UserDefinedPythonFunctions jobj, once initialized + self._returnType_placeholder = None + self._judf_placeholder = None + self._name = name or ( + func.__name__ if hasattr(func, '__name__') + else func.__class__.__name__) + self.evalType = evalType + + @property + def returnType(self): + # This makes sure this is called after SparkContext is initialized. + # ``_parse_datatype_string`` accesses to JVM for parsing a DDL formatted string. + if self._returnType_placeholder is None: + if isinstance(self._returnType, DataType): + self._returnType_placeholder = self._returnType + else: + self._returnType_placeholder = _parse_datatype_string(self._returnType) + + if self.evalType == PythonEvalType.SQL_PANDAS_GROUP_MAP_UDF \ + and not isinstance(self._returnType_placeholder, StructType): + raise ValueError("Invalid returnType: returnType must be a StructType for " + "pandas_udf with function type GROUP_MAP") + + return self._returnType_placeholder + + @property + def _judf(self): + # It is possible that concurrent access, to newly created UDF, + # will initialize multiple UserDefinedPythonFunctions. + # This is unlikely, doesn't affect correctness, + # and should have a minimal performance impact. + if self._judf_placeholder is None: + self._judf_placeholder = self._create_judf() + return self._judf_placeholder + + def _create_judf(self): + from pyspark.sql import SparkSession + + spark = SparkSession.builder.getOrCreate() + sc = spark.sparkContext + + wrapped_func = _wrap_function(sc, self.func, self.returnType) + jdt = spark._jsparkSession.parseDataType(self.returnType.json()) + judf = sc._jvm.org.apache.spark.sql.execution.python.UserDefinedPythonFunction( + self._name, wrapped_func, jdt, self.evalType) + return judf + + def __call__(self, *cols): + judf = self._judf + sc = SparkContext._active_spark_context + return Column(judf.apply(_to_seq(sc, cols, _to_java_column))) + + def _wrapped(self): + """ + Wrap this udf with a function and attach docstring from func + """ + + # It is possible for a callable instance without __name__ attribute or/and + # __module__ attribute to be wrapped here. For example, functools.partial. In this case, + # we should avoid wrapping the attributes from the wrapped function to the wrapper + # function. So, we take out these attribute names from the default names to set and + # then manually assign it after being wrapped. 
+ assignments = tuple( + a for a in functools.WRAPPER_ASSIGNMENTS if a != '__name__' and a != '__module__') + + @functools.wraps(self.func, assigned=assignments) + def wrapper(*args): + return self(*args) + + wrapper.__name__ = self._name + wrapper.__module__ = (self.func.__module__ if hasattr(self.func, '__module__') + else self.func.__class__.__module__) + + wrapper.func = self.func + wrapper.returnType = self.returnType + wrapper.evalType = self.evalType + + return wrapper diff --git a/python/pyspark/worker.py b/python/pyspark/worker.py index 5e100e0a9a95d827ce4aeb59edeecba5bbc41908..939643071943ab91d470c66e8c6a9c5d520fac1d 100644 --- a/python/pyspark/worker.py +++ b/python/pyspark/worker.py @@ -29,8 +29,9 @@ from pyspark.accumulators import _accumulatorRegistry from pyspark.broadcast import Broadcast, _broadcastRegistry from pyspark.taskcontext import TaskContext from pyspark.files import SparkFiles +from pyspark.rdd import PythonEvalType from pyspark.serializers import write_with_length, write_int, read_long, \ - write_long, read_int, SpecialLengths, PythonEvalType, UTF8Deserializer, PickleSerializer, \ + write_long, read_int, SpecialLengths, UTF8Deserializer, PickleSerializer, \ BatchedSerializer, ArrowStreamPandasSerializer from pyspark.sql.types import to_arrow_type from pyspark import shuffle @@ -73,7 +74,7 @@ def wrap_udf(f, return_type): return lambda *a: f(*a) -def wrap_pandas_udf(f, return_type): +def wrap_pandas_scalar_udf(f, return_type): arrow_return_type = to_arrow_type(return_type) def verify_result_length(*a): @@ -89,6 +90,26 @@ def wrap_pandas_udf(f, return_type): return lambda *a: (verify_result_length(*a), arrow_return_type) +def wrap_pandas_group_map_udf(f, return_type): + def wrapped(*series): + import pandas as pd + + result = f(pd.concat(series, axis=1)) + if not isinstance(result, pd.DataFrame): + raise TypeError("Return type of the user-defined function should be " + "pandas.DataFrame, but is {}".format(type(result))) + if not len(result.columns) == len(return_type): + raise RuntimeError( + "Number of columns of the returned pandas.DataFrame " + "doesn't match specified schema. 
" + "Expected: {} Actual: {}".format(len(return_type), len(result.columns))) + arrow_return_types = (to_arrow_type(field.dataType) for field in return_type) + return [(result[result.columns[i]], arrow_type) + for i, arrow_type in enumerate(arrow_return_types)] + + return wrapped + + def read_single_udf(pickleSer, infile, eval_type): num_arg = read_int(infile) arg_offsets = [read_int(infile) for i in range(num_arg)] @@ -99,12 +120,12 @@ def read_single_udf(pickleSer, infile, eval_type): row_func = f else: row_func = chain(row_func, f) + # the last returnType will be the return type of UDF - if eval_type == PythonEvalType.SQL_PANDAS_UDF: - return arg_offsets, wrap_pandas_udf(row_func, return_type) - elif eval_type == PythonEvalType.SQL_PANDAS_GROUPED_UDF: - # a groupby apply udf has already been wrapped under apply() - return arg_offsets, row_func + if eval_type == PythonEvalType.SQL_PANDAS_SCALAR_UDF: + return arg_offsets, wrap_pandas_scalar_udf(row_func, return_type) + elif eval_type == PythonEvalType.SQL_PANDAS_GROUP_MAP_UDF: + return arg_offsets, wrap_pandas_group_map_udf(row_func, return_type) else: return arg_offsets, wrap_udf(row_func, return_type) @@ -127,8 +148,8 @@ def read_udfs(pickleSer, infile, eval_type): func = lambda _, it: map(mapper, it) - if eval_type == PythonEvalType.SQL_PANDAS_UDF \ - or eval_type == PythonEvalType.SQL_PANDAS_GROUPED_UDF: + if eval_type == PythonEvalType.SQL_PANDAS_SCALAR_UDF \ + or eval_type == PythonEvalType.SQL_PANDAS_GROUP_MAP_UDF: ser = ArrowStreamPandasSerializer() else: ser = BatchedSerializer(PickleSerializer(), 100) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/RelationalGroupedDataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/RelationalGroupedDataset.scala index 3e4edd4ea8cf39bb2f87d42ca527e603048cba66..a009c00b0abc5998af7cf4912d0f8ecde4d75d1d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/RelationalGroupedDataset.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/RelationalGroupedDataset.scala @@ -23,6 +23,7 @@ import scala.collection.JavaConverters._ import scala.language.implicitConversions import org.apache.spark.annotation.InterfaceStability +import org.apache.spark.api.python.PythonEvalType import org.apache.spark.broadcast.Broadcast import org.apache.spark.sql.catalyst.analysis.{Star, UnresolvedAlias, UnresolvedAttribute, UnresolvedFunction} import org.apache.spark.sql.catalyst.expressions._ @@ -30,7 +31,7 @@ import org.apache.spark.sql.catalyst.expressions.aggregate._ import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.util.toPrettySQL import org.apache.spark.sql.execution.aggregate.TypedAggregateExpression -import org.apache.spark.sql.execution.python.{PythonUDF, PythonUdfType} +import org.apache.spark.sql.execution.python.PythonUDF import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types.{NumericType, StructType} @@ -449,10 +450,10 @@ class RelationalGroupedDataset protected[sql]( * workers. 
*/ private[sql] def flatMapGroupsInPandas(expr: PythonUDF): DataFrame = { - require(expr.pythonUdfType == PythonUdfType.PANDAS_GROUPED_UDF, - "Must pass a grouped vectorized python udf") + require(expr.evalType == PythonEvalType.SQL_PANDAS_GROUP_MAP_UDF, + "Must pass a group map udf") require(expr.dataType.isInstanceOf[StructType], - "The returnType of the vectorized python udf must be a StructType") + "The returnType of the udf must be a StructType") val groupingNamedExpressions = groupingExprs.map { case ne: NamedExpression => ne diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowEvalPythonExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowEvalPythonExec.scala index bcda2dae92e53617934f94f192f2a8b162e29b0b..e27210117a1e7baf920ed70b22349ae2738b6024 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowEvalPythonExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowEvalPythonExec.scala @@ -81,7 +81,7 @@ case class ArrowEvalPythonExec(udfs: Seq[PythonUDF], output: Seq[Attribute], chi val columnarBatchIter = new ArrowPythonRunner( funcs, bufferSize, reuseWorker, - PythonEvalType.SQL_PANDAS_UDF, argOffsets, schema, sessionLocalTimeZone) + PythonEvalType.SQL_PANDAS_SCALAR_UDF, argOffsets, schema, sessionLocalTimeZone) .compute(batchIter, context.partitionId(), context) new Iterator[InternalRow] { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ExtractPythonUDFs.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ExtractPythonUDFs.scala index e15e760136e81d95d0061df2aefa8ad203773098..f5a4cbc4793e30e664cda01681b09cd9600c8f82 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ExtractPythonUDFs.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ExtractPythonUDFs.scala @@ -20,6 +20,7 @@ package org.apache.spark.sql.execution.python import scala.collection.mutable import scala.collection.mutable.ArrayBuffer +import org.apache.spark.api.python.PythonEvalType import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, LogicalPlan, Project} @@ -148,15 +149,18 @@ object ExtractPythonUDFs extends Rule[SparkPlan] with PredicateHelper { udf.references.subsetOf(child.outputSet) } if (validUdfs.nonEmpty) { - if (validUdfs.exists(_.pythonUdfType == PythonUdfType.PANDAS_GROUPED_UDF)) { - throw new IllegalArgumentException("Can not use grouped vectorized UDFs") - } + require(validUdfs.forall(udf => + udf.evalType == PythonEvalType.SQL_BATCHED_UDF || + udf.evalType == PythonEvalType.SQL_PANDAS_SCALAR_UDF + ), "Can only extract scalar vectorized udf or sql batch udf") val resultAttrs = udfs.zipWithIndex.map { case (u, i) => AttributeReference(s"pythonUDF$i", u.dataType)() } - val evaluation = validUdfs.partition(_.pythonUdfType == PythonUdfType.PANDAS_UDF) match { + val evaluation = validUdfs.partition( + _.evalType == PythonEvalType.SQL_PANDAS_SCALAR_UDF + ) match { case (vectorizedUdfs, plainUdfs) if plainUdfs.isEmpty => ArrowEvalPythonExec(vectorizedUdfs, child.output ++ resultAttrs, child) case (vectorizedUdfs, plainUdfs) if vectorizedUdfs.isEmpty => diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/FlatMapGroupsInPandasExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/FlatMapGroupsInPandasExec.scala index 
e1e04e34e0c7152d8bff401ca4d76013f649d967..ee495814b82555b8ba769b105d422e060dd0d5cb 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/FlatMapGroupsInPandasExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/FlatMapGroupsInPandasExec.scala @@ -95,7 +95,7 @@ case class FlatMapGroupsInPandasExec( val columnarBatchIter = new ArrowPythonRunner( chainedFunc, bufferSize, reuseWorker, - PythonEvalType.SQL_PANDAS_GROUPED_UDF, argOffsets, schema, sessionLocalTimeZone) + PythonEvalType.SQL_PANDAS_GROUP_MAP_UDF, argOffsets, schema, sessionLocalTimeZone) .compute(grouped, context.partitionId(), context) columnarBatchIter.flatMap(_.rowIterator.asScala).map(UnsafeProjection.create(output, output)) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonUDF.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonUDF.scala index 9c07c7638de5756994bf9344a1bc9605525168a0..ef27fbc2db7d92cd4470dc645661cdd280350da3 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonUDF.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonUDF.scala @@ -29,7 +29,7 @@ case class PythonUDF( func: PythonFunction, dataType: DataType, children: Seq[Expression], - pythonUdfType: Int) + evalType: Int) extends Expression with Unevaluable with NonSQLExpression with UserDefinedExpression { override def toString: String = s"$name(${children.mkString(", ")})" diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/UserDefinedPythonFunction.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/UserDefinedPythonFunction.scala index b2fe6c300846a175fae7a3dcbf6f5ce230f1f53e..348e49e473ed3a03486b45b894f1ea64cbd00d5c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/UserDefinedPythonFunction.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/UserDefinedPythonFunction.scala @@ -22,15 +22,6 @@ import org.apache.spark.sql.Column import org.apache.spark.sql.catalyst.expressions.Expression import org.apache.spark.sql.types.DataType -private[spark] object PythonUdfType { - // row-at-a-time UDFs - val NORMAL_UDF = 0 - // scalar vectorized UDFs - val PANDAS_UDF = 1 - // grouped vectorized UDFs - val PANDAS_GROUPED_UDF = 2 -} - /** * A user-defined Python function. This is used by the Python API. */ @@ -38,10 +29,10 @@ case class UserDefinedPythonFunction( name: String, func: PythonFunction, dataType: DataType, - pythonUdfType: Int) { + pythonEvalType: Int) { def builder(e: Seq[Expression]): PythonUDF = { - PythonUDF(name, func, dataType, e, pythonUdfType) + PythonUDF(name, func, dataType, e, pythonEvalType) } /** Returns a [[Column]] that will evaluate to calling this UDF with the given input. 
*/ diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/python/BatchEvalPythonExecSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/python/BatchEvalPythonExecSuite.scala index 95b21fc9f16ae43947590c8d8690519b337b5759..53d3f34567518337eab001e033285ac01c2a7f01 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/python/BatchEvalPythonExecSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/python/BatchEvalPythonExecSuite.scala @@ -20,7 +20,7 @@ package org.apache.spark.sql.execution.python import scala.collection.JavaConverters._ import scala.collection.mutable.ArrayBuffer -import org.apache.spark.api.python.PythonFunction +import org.apache.spark.api.python.{PythonEvalType, PythonFunction} import org.apache.spark.sql.catalyst.FunctionIdentifier import org.apache.spark.sql.catalyst.expressions.{And, AttributeReference, GreaterThan, In} import org.apache.spark.sql.execution.{FilterExec, InputAdapter, SparkPlanTest, WholeStageCodegenExec} @@ -109,4 +109,4 @@ class MyDummyPythonUDF extends UserDefinedPythonFunction( name = "dummyUDF", func = new DummyUDF, dataType = BooleanType, - pythonUdfType = PythonUdfType.NORMAL_UDF) + pythonEvalType = PythonEvalType.SQL_BATCHED_UDF)
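
For context, below is a minimal usage sketch of the Python API this patch settles on, mirroring the doctests added in functions.py, group.py, and sql/tests.py. It assumes a local SparkSession with pandas and PyArrow installed; the function and column names (plus_one, normalize, id, v) are illustrative only and not part of the patch itself.

# Usage sketch of pandas_udf after this change (assumes pandas/PyArrow available).
from pyspark.sql import SparkSession
from pyspark.sql.functions import pandas_udf, PandasUDFType, col

spark = SparkSession.builder.master("local[2]").getOrCreate()
df = spark.createDataFrame(
    [(1, 1.0), (1, 2.0), (2, 3.0), (2, 5.0), (2, 10.0)], ("id", "v"))

# SCALAR (the default functionType): one or more pandas.Series -> pandas.Series,
# carried as evalType SQL_PANDAS_SCALAR_UDF (200) and run via ArrowEvalPythonExec.
@pandas_udf("double")  # equivalent to @pandas_udf("double", PandasUDFType.SCALAR)
def plus_one(v):
    return v + 1

df.select(plus_one(col("v"))).show()

# GROUP_MAP: pandas.DataFrame -> pandas.DataFrame, only valid with
# GroupedData.apply(); the returnType must resolve to a StructType (here given
# as a DDL string) and the UDF carries evalType SQL_PANDAS_GROUP_MAP_UDF (201).
@pandas_udf("id long, v double", PandasUDFType.GROUP_MAP)
def normalize(pdf):
    return pdf.assign(v=(pdf.v - pdf.v.mean()) / pdf.v.std())

df.groupby("id").apply(normalize).show()

# Row-at-a-time UDFs keep the plain udf() entry point; they are now tagged with
# evalType SQL_BATCHED_UDF (100) instead of the removed PythonUdfType constants.

As the patch's tests show, the function type can be passed positionally, by keyword (functionType=...), or omitted for SCALAR; the worker distinguishes the two pandas variants by eval type rather than by a pre-wrapped function as before.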