Commit a2447799 authored by Andrew Ray, committed by Yin Huai

[SPARK-11690][PYSPARK] Add pivot to python api

This PR adds pivot to the Python API of GroupedData, with the same syntax as Scala/Java.

Author: Andrew Ray <ray.andrew@gmail.com>

Closes #9653 from aray/sql-pivot-python.
parent 99693fef
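For context, a minimal end-to-end sketch of how the new method is meant to be called from PySpark. This is not part of the commit; it assumes a local SparkContext and a Spark 1.6-style SQLContext, and uses the same sample earnings data as the doctest fixture added below.

    from pyspark import SparkContext
    from pyspark.sql import Row, SQLContext

    # Hypothetical local context for the sketch only.
    sc = SparkContext("local[2]", "pivot-example")
    sqlContext = SQLContext(sc)

    # Sample data mirroring the df4 doctest fixture added in this commit.
    df = sqlContext.createDataFrame([
        Row(course="dotNET", year=2012, earnings=10000),
        Row(course="Java", year=2012, earnings=20000),
        Row(course="dotNET", year=2012, earnings=5000),
        Row(course="dotNET", year=2013, earnings=48000),
        Row(course="Java", year=2013, earnings=30000),
    ])

    # One output column per distinct value of "course".
    df.groupBy("year").pivot("course").sum("earnings").show()

    # Restrict (and order) the output columns by passing the values explicitly.
    df.groupBy("year").pivot("course", "dotNET", "Java").sum("earnings").show()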
python/pyspark/sql/group.py

@@ -17,7 +17,7 @@
 from pyspark import since
 from pyspark.rdd import ignore_unicode_prefix
-from pyspark.sql.column import Column, _to_seq
+from pyspark.sql.column import Column, _to_seq, _to_java_column, _create_column_from_literal
 from pyspark.sql.dataframe import DataFrame
 from pyspark.sql.types import *

@@ -167,6 +167,23 @@ class GroupedData(object):
         [Row(sum(age)=7, sum(height)=165)]
         """
+
+    @since(1.6)
+    def pivot(self, pivot_col, *values):
+        """Pivots a column of the current DataFrame and performs the specified aggregation.
+
+        :param pivot_col: Column to pivot.
+        :param values: Optional list of values of the pivot column that will be translated to
+            columns in the output DataFrame. If values are not provided, the method will make an
+            immediate call to .distinct() on the pivot column.
+
+        >>> df4.groupBy("year").pivot("course", "dotNET", "Java").sum("earnings").collect()
+        [Row(year=2012, dotNET=15000, Java=20000), Row(year=2013, dotNET=48000, Java=30000)]
+
+        >>> df4.groupBy("year").pivot("course").sum("earnings").collect()
+        [Row(year=2012, Java=20000, dotNET=15000), Row(year=2013, Java=30000, dotNET=48000)]
+        """
+        jgd = self._jdf.pivot(_to_java_column(pivot_col),
+                              _to_seq(self.sql_ctx._sc, values, _create_column_from_literal))
+        return GroupedData(jgd, self.sql_ctx)
+

 def _test():
     import doctest

@@ -182,6 +199,11 @@ def _test():
                           StructField('name', StringType())]))
     globs['df3'] = sc.parallelize([Row(name='Alice', age=2, height=80),
                                    Row(name='Bob', age=5, height=85)]).toDF()
+    globs['df4'] = sc.parallelize([Row(course="dotNET", year=2012, earnings=10000),
+                                   Row(course="Java", year=2012, earnings=20000),
+                                   Row(course="dotNET", year=2012, earnings=5000),
+                                   Row(course="dotNET", year=2013, earnings=48000),
+                                   Row(course="Java", year=2013, earnings=30000)]).toDF()
     (failure_count, test_count) = doctest.testmod(
         pyspark.sql.group, globs=globs,
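The docstring notes that when values are omitted, pivot() makes an immediate call to .distinct() on the pivot column. Where that extra job matters, a caller can compute the distinct values once and pass them explicitly. A minimal sketch, reusing the df4 fixture that this commit adds to _test() above:

    # Pre-compute the pivot values once and pass them explicitly,
    # instead of letting pivot() call .distinct() on every invocation.
    courses = [r.course for r in df4.select("course").distinct().collect()]
    df4.groupBy("year").pivot("course", *courses).sum("earnings").show()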