Commit 1a52a623 authored by Liang-Chi Hsieh, committed by Nick Pentreath

[SPARK-20076][ML][PYSPARK] Add Python interface for ml.stats.Correlation

## What changes were proposed in this pull request?

DataFrame-based support for correlation statistics was added in #17108. This patch adds the Python interface for it.

## How was this patch tested?

Python unit test.
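
The expected output in the new doctest contains truncated values such as `0.0556...` and irregular column spacing, which only pass if the doctest runner is told to tolerate them. The following is a minimal sketch of that mechanism using the standard `doctest` module. It assumes the real test run enables `ELLIPSIS` and `NORMALIZE_WHITESPACE`, which this diff does not show, and the `EXAMPLE` string and `run_examples` helper are hypothetical names used only for illustration.

```python
import doctest

# Hypothetical docstring imitating the style of the doctest in this patch:
# truncated floats, a normalised NaN spelling, and loose column spacing.
EXAMPLE = """
>>> print(1.0 / 3.0)
0.3333...
>>> print(str(float('nan')).replace('nan', 'NaN'))
NaN
>>> print([1, 2, 3])
[1,    2, 3]
"""


def run_examples(docstring, flags):
    """Run the examples in `docstring` and return (failed, attempted)."""
    test = doctest.DocTestParser().get_doctest(docstring, {}, "example", None, 0)
    runner = doctest.DocTestRunner(verbose=False, optionflags=flags)
    # Swallow the failure report; only the counts are of interest here.
    result = runner.run(test, out=lambda text: None)
    return result.failed, result.attempted


if __name__ == "__main__":
    tolerant = doctest.ELLIPSIS | doctest.NORMALIZE_WHITESPACE
    print(run_examples(EXAMPLE, tolerant))
    print(run_examples(EXAMPLE, 0))
```

With both flags set, all three examples pass. Without them, the truncated float and the loosely spaced list fail, while the `NaN` example passes either way because it depends only on the explicit `replace` call. That call is presumably there to make the printed spelling of NaN independent of the environment.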


Author: Liang-Chi Hsieh <viirya@gmail.com>

Closes #17494 from viirya/correlation-python-api.
parent ad3cc131
@@ -38,7 +38,7 @@ object Correlation {
/**
* :: Experimental ::
- * Compute the correlation matrix for the input RDD of Vectors using the specified method.
+ * Compute the correlation matrix for the input Dataset of Vectors using the specified method.
* Methods currently supported: `pearson` (default), `spearman`.
*
* @param dataset A dataset or a dataframe
@@ -56,14 +56,14 @@ object Correlation {
* Here is how to access the correlation coefficient:
* {{{
* val data: Dataset[Vector] = ...
- * val Row(coeff: Matrix) = Statistics.corr(data, "value").head
+ * val Row(coeff: Matrix) = Correlation.corr(data, "value").head
* // coeff now contains the Pearson correlation matrix.
* }}}
*
* @note For Spearman, a rank correlation, we need to create an RDD[Double] for each column
* and sort it in order to retrieve the ranks and then join the columns back into an RDD[Vector],
- * which is fairly costly. Cache the input RDD before calling corr with `method = "spearman"` to
- * avoid recomputing the common lineage.
+ * which is fairly costly. Cache the input Dataset before calling corr with `method = "spearman"`
+ * to avoid recomputing the common lineage.
*/
@Since("2.2.0")
def corr(dataset: Dataset[_], column: String, method: String): DataFrame = {
......
@@ -71,6 +71,67 @@ class ChiSquareTest(object):
        return _java2py(sc, javaTestObj.test(*args))


class Correlation(object):
    """
    .. note:: Experimental

    Compute the correlation matrix for the input dataset of Vectors using the specified method.
    Methods currently supported: `pearson` (default), `spearman`.

    .. note:: For Spearman, a rank correlation, we need to create an RDD[Double] for each column
      and sort it in order to retrieve the ranks and then join the columns back into an RDD[Vector],
      which is fairly costly. Cache the input Dataset before calling corr with `method = 'spearman'`
      to avoid recomputing the common lineage.

    :param dataset:
      A dataset or a dataframe.
    :param column:
      The name of the column of vectors for which the correlation coefficient needs
      to be computed. This must be a column of the dataset, and it must contain
      Vector objects.
    :param method:
      String specifying the method to use for computing correlation.
      Supported: `pearson` (default), `spearman`.
    :return:
      A dataframe that contains the correlation matrix of the column of vectors. This
      dataframe contains a single row and a single column of name
      '$METHODNAME($COLUMN)'.

    >>> from pyspark.ml.linalg import Vectors
    >>> from pyspark.ml.stat import Correlation
    >>> dataset = [[Vectors.dense([1, 0, 0, -2])],
    ...            [Vectors.dense([4, 5, 0, 3])],
    ...            [Vectors.dense([6, 7, 0, 8])],
    ...            [Vectors.dense([9, 0, 0, 1])]]
    >>> dataset = spark.createDataFrame(dataset, ['features'])
    >>> pearsonCorr = Correlation.corr(dataset, 'features', 'pearson').collect()[0][0]
    >>> print(str(pearsonCorr).replace('nan', 'NaN'))
    DenseMatrix([[ 1. , 0.0556..., NaN, 0.4004...],
                 [ 0.0556..., 1. , NaN, 0.9135...],
                 [ NaN, NaN, 1. , NaN],
                 [ 0.4004..., 0.9135..., NaN, 1. ]])
    >>> spearmanCorr = Correlation.corr(dataset, 'features', method='spearman').collect()[0][0]
    >>> print(str(spearmanCorr).replace('nan', 'NaN'))
    DenseMatrix([[ 1. , 0.1054..., NaN, 0.4 ],
                 [ 0.1054..., 1. , NaN, 0.9486... ],
                 [ NaN, NaN, 1. , NaN],
                 [ 0.4 , 0.9486... , NaN, 1. ]])

    .. versionadded:: 2.2.0

    """
    @staticmethod
    @since("2.2.0")
    def corr(dataset, column, method="pearson"):
        """
        Compute the correlation matrix with specified method using dataset.
        """
        sc = SparkContext._active_spark_context
        javaCorrObj = _jvm().org.apache.spark.ml.stat.Correlation
        args = [_py2java(sc, arg) for arg in (dataset, column, method)]
        return _java2py(sc, javaCorrObj.corr(*args))


if __name__ == "__main__":
    import doctest
    import pyspark.ml.stat
......
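
To make the expected matrices in the doctest easier to read, here is a small pure-Python sketch that recomputes them from the same four vectors without Spark. It is not how Spark computes correlations. The helper names (`average_ranks`, `pearson`, `correlation_matrix`) are hypothetical, ties are assumed to receive the average of their ranks, and the diagonal is set to 1.0 by convention so that a constant column yields NaN only off the diagonal, as in the doctest output.

```python
import math

# The four input vectors from the doctest, one per row.
ROWS = [
    [1, 0, 0, -2],
    [4, 5, 0, 3],
    [6, 7, 0, 8],
    [9, 0, 0, 1],
]


def average_ranks(values):
    """Rank values starting at 1; tied values share the mean of their ranks."""
    order = sorted(range(len(values)), key=lambda i: values[i])
    ranks = [0.0] * len(values)
    start = 0
    while start < len(order):
        end = start
        # Extend the run while the next sorted value equals the current one.
        while end + 1 < len(order) and values[order[end + 1]] == values[order[start]]:
            end += 1
        mean_rank = (start + end) / 2.0 + 1.0
        for k in range(start, end + 1):
            ranks[order[k]] = mean_rank
        start = end + 1
    return ranks


def pearson(xs, ys):
    """Pearson correlation of two equal-length sequences; NaN if either is constant."""
    n = len(xs)
    mean_x = sum(xs) / n
    mean_y = sum(ys) / n
    cov = sum((x - mean_x) * (y - mean_y) for x, y in zip(xs, ys))
    var_x = sum((x - mean_x) ** 2 for x in xs)
    var_y = sum((y - mean_y) ** 2 for y in ys)
    denom = math.sqrt(var_x * var_y)
    return cov / denom if denom > 0 else float("nan")


def correlation_matrix(rows, method="pearson"):
    """Correlation matrix between the columns of `rows` (a list of vectors)."""
    columns = [list(col) for col in zip(*rows)]
    if method == "spearman":
        # Spearman is Pearson applied to the per-column ranks.
        columns = [average_ranks(col) for col in columns]
    elif method != "pearson":
        raise ValueError("unsupported method: %s" % method)
    size = len(columns)
    return [
        [1.0 if i == j else pearson(columns[i], columns[j]) for j in range(size)]
        for i in range(size)
    ]


if __name__ == "__main__":
    for name in ("pearson", "spearman"):
        print(name)
        for row in correlation_matrix(ROWS, name):
            print(row)
```

The third column is all zeros, so its variance is zero and every off-diagonal entry involving it is NaN. The remaining entries agree with the truncated values shown in the doctest (for example 0.0556..., 0.4004..., and 0.9135... for Pearson).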