diff --git a/python/pyspark/rdd.py b/python/pyspark/rdd.py
index 99f5967a8edc932824ca6413ba111a8db03895a0..1e9b3bb5c03480a2e6e10c8df85fa8087f44c76e 100644
--- a/python/pyspark/rdd.py
+++ b/python/pyspark/rdd.py
@@ -31,6 +31,7 @@
 from pyspark.serializers import batched, Batch, dump_pickle, load_pickle, \
     read_from_pickle_file
 from pyspark.join import python_join, python_left_outer_join, \
     python_right_outer_join, python_cogroup
+from pyspark.statcounter import StatCounter
 
 from py4j.java_collections import ListConverter, MapConverter
@@ -357,6 +358,63 @@ class RDD(object):
         3
         """
         return self.mapPartitions(lambda i: [sum(1 for _ in i)]).sum()
+
+    def stats(self):
+        """
+        Return a L{StatCounter} object that captures the mean, variance
+        and count of the RDD's elements in one operation.
+        """
+        def redFunc(left_counter, right_counter):
+            return left_counter.mergeStats(right_counter)
+
+        return self.mapPartitions(lambda i: [StatCounter(i)]).reduce(redFunc)
+
+    def mean(self):
+        """
+        Compute the mean of this RDD's elements.
+
+        >>> sc.parallelize([1, 2, 3]).mean()
+        2.0
+        """
+        return self.stats().mean()
+
+    def variance(self):
+        """
+        Compute the variance of this RDD's elements.
+
+        >>> sc.parallelize([1, 2, 3]).variance()
+        0.666...
+        """
+        return self.stats().variance()
+
+    def stdev(self):
+        """
+        Compute the standard deviation of this RDD's elements.
+
+        >>> sc.parallelize([1, 2, 3]).stdev()
+        0.816...
+        """
+        return self.stats().stdev()
+
+    def sampleStdev(self):
+        """
+        Compute the sample standard deviation of this RDD's elements (which corrects for bias in
+        estimating the standard deviation by dividing by N-1 instead of N).
+
+        >>> sc.parallelize([1, 2, 3]).sampleStdev()
+        1.0
+        """
+        return self.stats().sampleStdev()
+
+    def sampleVariance(self):
+        """
+        Compute the sample variance of this RDD's elements (which corrects for bias in
+        estimating the variance by dividing by N-1 instead of N).
+
+        >>> sc.parallelize([1, 2, 3]).sampleVariance()
+        1.0
+        """
+        return self.stats().sampleVariance()
 
     def countByValue(self):
         """
@@ -777,7 +835,7 @@ def _test():
     # The small batch size here ensures that we see multiple batches,
     # even in these small test examples:
     globs['sc'] = SparkContext('local[4]', 'PythonTest', batchSize=2)
-    (failure_count, test_count) = doctest.testmod(globs=globs)
+    (failure_count, test_count) = doctest.testmod(globs=globs, optionflags=doctest.ELLIPSIS)
     globs['sc'].stop()
     if failure_count:
         exit(-1)
diff --git a/python/pyspark/statcounter.py b/python/pyspark/statcounter.py
new file mode 100644
index 0000000000000000000000000000000000000000..8e1cbd4ad9856bacbfc529aa55f7c34373eae758
--- /dev/null
+++ b/python/pyspark/statcounter.py
@@ -0,0 +1,109 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements.  See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License.  You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+# This file is ported from spark/util/StatCounter.scala
+
+import copy
+import math
+
+class StatCounter(object):
+
+    def __init__(self, values=[]):
+        self.n = 0L     # Running count of our values
+        self.mu = 0.0   # Running mean of our values
+        self.m2 = 0.0   # Running variance numerator (sum of (x - mean)^2)
+
+        for v in values:
+            self.merge(v)
+
+    # Add a value into this StatCounter, updating the internal statistics.
+    def merge(self, value):
+        delta = value - self.mu
+        self.n += 1
+        self.mu += delta / self.n
+        self.m2 += delta * (value - self.mu)
+        return self
+
+    # Merge another StatCounter into this one, adding up the internal statistics.
+    def mergeStats(self, other):
+        if not isinstance(other, StatCounter):
+            raise Exception("Can only merge StatCounters!")
+
+        if other is self:  # reference equality holds
+            self.mergeStats(copy.deepcopy(other))  # Avoid overwriting fields in a weird order
+        else:
+            if self.n == 0:
+                self.mu = other.mu
+                self.m2 = other.m2
+                self.n = other.n
+            elif other.n != 0:
+                delta = other.mu - self.mu
+                if other.n * 10 < self.n:
+                    self.mu = self.mu + (delta * other.n) / (self.n + other.n)
+                elif self.n * 10 < other.n:
+                    self.mu = other.mu - (delta * self.n) / (self.n + other.n)
+                else:
+                    self.mu = (self.mu * self.n + other.mu * other.n) / (self.n + other.n)
+
+                self.m2 += other.m2 + (delta * delta * self.n * other.n) / (self.n + other.n)
+                self.n += other.n
+        return self
+
+    # Clone this StatCounter
+    def copy(self):
+        return copy.deepcopy(self)
+
+    def count(self):
+        return self.n
+
+    def mean(self):
+        return self.mu
+
+    def sum(self):
+        return self.n * self.mu
+
+    # Return the variance of the values.
+    def variance(self):
+        if self.n == 0:
+            return float('nan')
+        else:
+            return self.m2 / self.n
+
+    #
+    # Return the sample variance, which corrects for bias in estimating the variance by dividing
+    # by N-1 instead of N.
+    #
+    def sampleVariance(self):
+        if self.n <= 1:
+            return float('nan')
+        else:
+            return self.m2 / (self.n - 1)
+
+    # Return the standard deviation of the values.
+    def stdev(self):
+        return math.sqrt(self.variance())
+
+    #
+    # Return the sample standard deviation of the values, which corrects for bias in estimating the
+    # variance by dividing by N-1 instead of N.
+    #
+    def sampleStdev(self):
+        return math.sqrt(self.sampleVariance())
+
+    def __repr__(self):
+        return "(count: %s, mean: %s, stdev: %s)" % (self.count(), self.mean(), self.stdev())
+
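
Usage note: with this patch applied, summary statistics can be read either from a single StatCounter or via the per-statistic RDD methods. A minimal driver-side sketch, assuming a running SparkContext bound to `sc` as in the doctests above (the input values are illustrative only):

    rdd = sc.parallelize([1.0, 2.0, 3.0, 4.0])

    # One pass over the data computes count, mean, and variance together.
    s = rdd.stats()
    print s.count(), s.mean(), s.stdev()   # 4 2.5 1.118...

    # Each convenience method calls stats() internally, so the form above
    # avoids recomputing the RDD when more than one statistic is needed.
    print rdd.sampleVariance()             # 1.666..., divides by N-1 instead of N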
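
Review note on mergeStats: the pairwise update is the standard parallel-merge formula for running means and sums of squared deviations (Chan et al.), and the two `* 10` branches only select the more numerically stable expression for the new mean when one side is much larger than the other. A standalone sketch, independent of Spark and with made-up sample data, that checks the merged summary against a direct computation over the concatenated values:

    def direct(values):
        # (n, mu, m2) computed the obvious way in two passes.
        n = len(values)
        mu = sum(values) / float(n)
        m2 = sum((v - mu) ** 2 for v in values)
        return n, mu, m2

    def merge(a, b):
        # Chan et al. merge of two (n, mu, m2) summaries, mirroring mergeStats.
        (na, mua, m2a), (nb, mub, m2b) = a, b
        n = na + nb
        delta = mub - mua
        mu = (mua * na + mub * nb) / float(n)
        m2 = m2a + m2b + delta * delta * na * nb / float(n)
        return n, mu, m2

    data = [1.0, 4.0, 2.0, 8.0, 5.0, 7.0]
    merged = merge(direct(data[:2]), direct(data[2:]))
    whole = direct(data)
    assert all(abs(x - y) < 1e-9 for x, y in zip(merged, whole))
    print merged[2] / merged[0]   # population variance: 6.25 either way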