Commit 8e7f098a authored by Matei Zaharia

Added accumulators to PySpark

parent 54c0f9f1
 package spark.api.python
 
 import java.io._
-import java.util.{List => JList}
+import java.net._
+import java.util.{List => JList, ArrayList => JArrayList, Collections}
 
 import scala.collection.JavaConversions._
 import scala.io.Source
@@ -10,25 +11,26 @@ import spark.api.java.{JavaSparkContext, JavaPairRDD, JavaRDD}
 import spark.broadcast.Broadcast
 import spark._
 import spark.rdd.PipedRDD
-import java.util
 
 private[spark] class PythonRDD[T: ClassManifest](
     parent: RDD[T],
     command: Seq[String],
     envVars: java.util.Map[String, String],
     preservePartitoning: Boolean,
     pythonExec: String,
-    broadcastVars: java.util.List[Broadcast[Array[Byte]]])
+    broadcastVars: JList[Broadcast[Array[Byte]]],
+    accumulator: Accumulator[JList[Array[Byte]]])
   extends RDD[Array[Byte]](parent.context) {
 
   // Similar to Runtime.exec(), if we are given a single string, split it into words
   // using a standard StringTokenizer (i.e. by spaces)
   def this(parent: RDD[T], command: String, envVars: java.util.Map[String, String],
       preservePartitoning: Boolean, pythonExec: String,
-      broadcastVars: java.util.List[Broadcast[Array[Byte]]]) =
+      broadcastVars: JList[Broadcast[Array[Byte]]],
+      accumulator: Accumulator[JList[Array[Byte]]]) =
     this(parent, PipedRDD.tokenize(command), envVars, preservePartitoning, pythonExec,
-      broadcastVars)
+      broadcastVars, accumulator)
 
   override def splits = parent.splits
@@ -93,18 +95,30 @@ private[spark] class PythonRDD[T: ClassManifest](
     // Return an iterator that read lines from the process's stdout
     val stream = new DataInputStream(proc.getInputStream)
     return new Iterator[Array[Byte]] {
-      def next() = {
+      def next(): Array[Byte] = {
         val obj = _nextObj
         _nextObj = read()
         obj
       }
 
-      private def read() = {
+      private def read(): Array[Byte] = {
         try {
           val length = stream.readInt()
-          val obj = new Array[Byte](length)
-          stream.readFully(obj)
-          obj
+          if (length != -1) {
+            val obj = new Array[Byte](length)
+            stream.readFully(obj)
+            obj
+          } else {
+            // We've finished the data section of the output, but we can still read some
+            // accumulator updates; let's do that, breaking when we get EOFException
+            while (true) {
+              val len2 = stream.readInt()
+              val update = new Array[Byte](len2)
+              stream.readFully(update)
+              accumulator += Collections.singletonList(update)
+            }
+            new Array[Byte](0)
+          }
         } catch {
           case eof: EOFException => {
             val exitStatus = proc.waitFor()
@@ -246,3 +260,40 @@ private class ExtractValue extends spark.api.java.function.Function[(Array[Byte]
 private class BytesToString extends spark.api.java.function.Function[Array[Byte], String] {
   override def call(arr: Array[Byte]) : String = new String(arr, "UTF-8")
 }
+
+/**
+ * Internal class that acts as an `AccumulatorParam` for Python accumulators. Inside, it
+ * collects a list of pickled strings that we pass to Python through a socket.
+ */
+class PythonAccumulatorParam(@transient serverHost: String, serverPort: Int)
+  extends AccumulatorParam[JList[Array[Byte]]] {
+
+  override def zero(value: JList[Array[Byte]]): JList[Array[Byte]] = new JArrayList
+
+  override def addInPlace(val1: JList[Array[Byte]], val2: JList[Array[Byte]])
+      : JList[Array[Byte]] = {
+    if (serverHost == null) {
+      // This happens on the worker node, where we just want to remember all the updates
+      val1.addAll(val2)
+      val1
+    } else {
+      // This happens on the master, where we pass the updates to Python through a socket
+      val socket = new Socket(serverHost, serverPort)
+      val in = socket.getInputStream
+      val out = new DataOutputStream(socket.getOutputStream)
+      out.writeInt(val2.size)
+      for (array <- val2) {
+        out.writeInt(array.length)
+        out.write(array)
+      }
+      out.flush()
+      // Wait for a byte from the Python side as an acknowledgement
+      val byteRead = in.read()
+      if (byteRead == -1) {
+        throw new SparkException("EOF reached before Python server acknowledged")
+      }
+      socket.close()
+      null
+    }
+  }
+}
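
The `PythonAccumulatorParam` above is the JVM half of the accumulator plumbing: on the master, `addInPlace` opens a socket to the Python driver, writes the number of updates, then each update as a 4-byte length followed by the pickled bytes, and finally waits for a one-byte acknowledgement. As a rough sketch (not part of this commit), the same framing can be written from Python; `send_updates` is a hypothetical helper name, and it assumes the listening end is the `_UpdateRequestHandler`-based TCP server that `pyspark/accumulators.py` starts later in this diff.

```python
# Illustration only (not in this commit): a client that speaks the same byte format
# as PythonAccumulatorParam.addInPlace -- an update count, then each update as a
# 4-byte big-endian length plus a pickled (aid, value) blob, answered by one ack byte.
import socket
import struct
import cPickle


def send_updates(host, port, updates):
    """Hypothetical helper; `updates` is a list of (accumulator_id, value) pairs."""
    sock = socket.socket()
    sock.connect((host, port))
    message = struct.pack("!i", len(updates))        # number of updates (write_int format)
    for aid, value in updates:
        pickled = cPickle.dumps((aid, value), 2)     # analogous to dump_pickle((aid, value))
        message += struct.pack("!i", len(pickled))   # length prefix
        message += pickled                           # pickled (aid, update) blob
    sock.sendall(message)
    ack = sock.recv(1)                               # one byte written back by the handler
    sock.close()
    return ack
```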
@@ -7,6 +7,10 @@ Public classes:
         Main entry point for Spark functionality.
     - L{RDD<pyspark.rdd.RDD>}
         A Resilient Distributed Dataset (RDD), the basic abstraction in Spark.
+    - L{Broadcast<pyspark.broadcast.Broadcast>}
+        A broadcast variable that gets reused across tasks.
+    - L{Accumulator<pyspark.accumulators.Accumulator>}
+        An "add-only" shared variable that tasks can only add values to.
 """
 import sys
 import os
...
"""
>>> from pyspark.context import SparkContext
>>> sc = SparkContext('local', 'test')
>>> a = sc.accumulator(1)
>>> a.value
1
>>> a.value = 2
>>> a.value
2
>>> a += 5
>>> a.value
7
>>> rdd = sc.parallelize([1,2,3])
>>> def f(x):
... global a
... a += x
>>> rdd.foreach(f)
>>> a.value
13
>>> class VectorAccumulatorParam(object):
... def zero(self, value):
... return [0.0] * len(value)
... def addInPlace(self, val1, val2):
... for i in xrange(len(val1)):
... val1[i] += val2[i]
... return val1
>>> va = sc.accumulator([1.0, 2.0, 3.0], VectorAccumulatorParam())
>>> va.value
[1.0, 2.0, 3.0]
>>> def g(x):
... global va
... va += [x] * 3
>>> rdd.foreach(g)
>>> va.value
[7.0, 8.0, 9.0]
>>> rdd.map(lambda x: a.value).collect() # doctest: +IGNORE_EXCEPTION_DETAIL
Traceback (most recent call last):
...
Py4JJavaError:...
>>> def h(x):
... global a
... a.value = 7
>>> rdd.foreach(h) # doctest: +IGNORE_EXCEPTION_DETAIL
Traceback (most recent call last):
...
Py4JJavaError:...
>>> sc.accumulator([1.0, 2.0, 3.0]) # doctest: +IGNORE_EXCEPTION_DETAIL
Traceback (most recent call last):
...
Exception:...
"""
import struct
import SocketServer
import threading
from pyspark.cloudpickle import CloudPickler
from pyspark.serializers import read_int, read_with_length, load_pickle
# Holds accumulators registered on the current machine, keyed by ID. This is then used to send
# the local accumulator updates back to the driver program at the end of a task.
_accumulatorRegistry = {}
def _deserialize_accumulator(aid, zero_value, accum_param):
from pyspark.accumulators import _accumulatorRegistry
accum = Accumulator(aid, zero_value, accum_param)
accum._deserialized = True
_accumulatorRegistry[aid] = accum
return accum
class Accumulator(object):
def __init__(self, aid, value, accum_param):
"""Create a new Accumulator with a given initial value and AccumulatorParam object"""
from pyspark.accumulators import _accumulatorRegistry
self.aid = aid
self.accum_param = accum_param
self._value = value
self._deserialized = False
_accumulatorRegistry[aid] = self
def __reduce__(self):
"""Custom serialization; saves the zero value from our AccumulatorParam"""
param = self.accum_param
return (_deserialize_accumulator, (self.aid, param.zero(self._value), param))
@property
def value(self):
"""Get the accumulator's value; only usable in driver program"""
if self._deserialized:
raise Exception("Accumulator.value cannot be accessed inside tasks")
return self._value
@value.setter
def value(self, value):
"""Sets the accumulator's value; only usable in driver program"""
if self._deserialized:
raise Exception("Accumulator.value cannot be accessed inside tasks")
self._value = value
def __iadd__(self, term):
"""The += operator; adds a term to this accumulator's value"""
self._value = self.accum_param.addInPlace(self._value, term)
return self
def __str__(self):
return str(self._value)
class AddingAccumulatorParam(object):
"""
An AccumulatorParam that uses the + operators to add values. Designed for simple types
such as integers, floats, and lists. Requires the zero value for the underlying type
as a parameter.
"""
def __init__(self, zero_value):
self.zero_value = zero_value
def zero(self, value):
return self.zero_value
def addInPlace(self, value1, value2):
value1 += value2
return value1
# Singleton accumulator params for some standard types
INT_ACCUMULATOR_PARAM = AddingAccumulatorParam(0)
DOUBLE_ACCUMULATOR_PARAM = AddingAccumulatorParam(0.0)
COMPLEX_ACCUMULATOR_PARAM = AddingAccumulatorParam(0.0j)
class _UpdateRequestHandler(SocketServer.StreamRequestHandler):
def handle(self):
from pyspark.accumulators import _accumulatorRegistry
num_updates = read_int(self.rfile)
for _ in range(num_updates):
(aid, update) = load_pickle(read_with_length(self.rfile))
_accumulatorRegistry[aid] += update
# Write a byte in acknowledgement
self.wfile.write(struct.pack("!b", 1))
def _start_update_server():
"""Start a TCP server to receive accumulator updates in a daemon thread, and returns it"""
server = SocketServer.TCPServer(("localhost", 0), _UpdateRequestHandler)
thread = threading.Thread(target=server.serve_forever)
thread.daemon = True
thread.start()
return server
def _test():
import doctest
doctest.testmod()
if __name__ == "__main__":
_test()
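
`AddingAccumulatorParam` above illustrates the informal contract that any accumulator param must satisfy: a `zero(value)` method and an `addInPlace(val1, val2)` method. A minimal sketch of another object following the same contract (not part of this commit; `DictAccumulatorParam` is a hypothetical name) that merges per-key counts:

```python
# Illustration only (not in this commit): another duck-typed AccumulatorParam.
class DictAccumulatorParam(object):
    """Accumulates per-key counts by merging dictionaries."""
    def zero(self, value):
        # An empty dict is a valid zero regardless of the initial value's contents.
        return {}

    def addInPlace(self, val1, val2):
        # Merge val2 into val1 in place and return it, as the contract allows.
        for k, v in val2.items():
            val1[k] = val1.get(k, 0) + v
        return val1

# Hypothetical driver-side usage (sc is an existing SparkContext):
#   counts = sc.accumulator({}, DictAccumulatorParam())
#   def count(word):
#       global counts
#       counts += {word: 1}
#   sc.parallelize(["a", "b", "a"]).foreach(count)
#   counts.value   ->   {'a': 2, 'b': 1}
```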
@@ -2,6 +2,8 @@ import os
 import atexit
 from tempfile import NamedTemporaryFile
 
+from pyspark import accumulators
+from pyspark.accumulators import Accumulator
 from pyspark.broadcast import Broadcast
 from pyspark.java_gateway import launch_gateway
 from pyspark.serializers import dump_pickle, write_with_length, batched
@@ -22,6 +24,7 @@ class SparkContext(object):
     _readRDDFromPickleFile = jvm.PythonRDD.readRDDFromPickleFile
     _writeIteratorToPickleFile = jvm.PythonRDD.writeIteratorToPickleFile
     _takePartition = jvm.PythonRDD.takePartition
+    _next_accum_id = 0
 
     def __init__(self, master, jobName, sparkHome=None, pyFiles=None,
             environment=None, batchSize=1024):
@@ -52,6 +55,14 @@ class SparkContext(object):
         self._jsc = self.jvm.JavaSparkContext(master, jobName, sparkHome,
                                               empty_string_array)
 
+        # Create a single Accumulator in Java that we'll send all our updates through;
+        # they will be passed back to us through a TCP server
+        self._accumulatorServer = accumulators._start_update_server()
+        (host, port) = self._accumulatorServer.server_address
+        self._javaAccumulator = self._jsc.accumulator(
+                self.jvm.java.util.ArrayList(),
+                self.jvm.PythonAccumulatorParam(host, port))
+
         self.pythonExec = os.environ.get("PYSPARK_PYTHON", 'python')
         # Broadcast's __reduce__ method stores Broadcast instances here.
         # This allows other code to determine which Broadcast instances have
@@ -74,6 +85,8 @@ class SparkContext(object):
     def __del__(self):
         if self._jsc:
             self._jsc.stop()
+        if self._accumulatorServer:
+            self._accumulatorServer.shutdown()
 
     def stop(self):
         """
@@ -129,6 +142,31 @@ class SparkContext(object):
         return Broadcast(jbroadcast.id(), value, jbroadcast,
                          self._pickled_broadcast_vars)
+    def accumulator(self, value, accum_param=None):
+        """
+        Create an C{Accumulator} with the given initial value, using a given
+        AccumulatorParam helper object to define how to add values of the data
+        type if provided. Default AccumulatorParams are used for integers and
+        floating-point numbers if you do not provide one. For other types, the
+        AccumulatorParam must implement two methods:
+        - C{zero(value)}: provide a "zero value" for the type, compatible in
+          dimensions with the provided C{value} (e.g., a zero vector).
+        - C{addInPlace(val1, val2)}: add two values of the accumulator's data
+          type, returning a new value; for efficiency, can also update C{val1}
+          in place and return it.
+        """
+        if accum_param is None:
+            if isinstance(value, int):
+                accum_param = accumulators.INT_ACCUMULATOR_PARAM
+            elif isinstance(value, float):
+                accum_param = accumulators.DOUBLE_ACCUMULATOR_PARAM
+            elif isinstance(value, complex):
+                accum_param = accumulators.COMPLEX_ACCUMULATOR_PARAM
+            else:
+                raise Exception("No default accumulator param for type %s" % type(value))
+        SparkContext._next_accum_id += 1
+        return Accumulator(SparkContext._next_accum_id - 1, value, accum_param)
+
     def addFile(self, path):
         """
         Add a file to be downloaded into the working directory of this Spark
...
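
`SparkContext.accumulator` hands out increasing IDs from `_next_accum_id` and picks a default `AccumulatorParam` by the type of the initial value; all updates are shipped back through the single `_javaAccumulator` created in `__init__`. A short usage sketch of this API (not part of the commit; it assumes an existing SparkContext `sc`, and `hits`, `errs`, and `check` are illustrative names):

```python
# Illustration only (not in this commit): driver-side use of the new accumulator API.
hits = sc.accumulator(0)        # int initial value -> INT_ACCUMULATOR_PARAM
errs = sc.accumulator(0.0)      # float initial value -> DOUBLE_ACCUMULATOR_PARAM


def check(x):
    # Workers get a copy holding the zero value; their updates flow back through
    # the -1 marker in worker.py and the PythonAccumulatorParam socket protocol.
    global hits, errs
    if x > 0:
        hits += 1
    else:
        errs += abs(x)


sc.parallelize([1, -2, 3, -4]).foreach(check)
print hits.value, errs.value    # expected: 2 6.0
```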
@@ -703,7 +703,7 @@ class PipelinedRDD(RDD):
         env = MapConverter().convert(env, self.ctx.gateway._gateway_client)
         python_rdd = self.ctx.jvm.PythonRDD(self._prev_jrdd.rdd(),
             pipe_command, env, self.preservesPartitioning, self.ctx.pythonExec,
-            broadcast_vars, class_manifest)
+            broadcast_vars, self.ctx._javaAccumulator, class_manifest)
         self._jrdd_val = python_rdd.asJavaRDD()
         return self._jrdd_val
...
@@ -52,8 +52,13 @@ def read_int(stream):
         raise EOFError
     return struct.unpack("!i", length)[0]
 
 
+def write_int(value, stream):
+    stream.write(struct.pack("!i", value))
+
+
 def write_with_length(obj, stream):
-    stream.write(struct.pack("!i", len(obj)))
+    write_int(len(obj), stream)
     stream.write(obj)
...
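
`write_int` is the mirror image of the existing `read_int`: both sides frame integers as 4-byte big-endian values, and `write_with_length` is now that frame followed by the payload. A tiny self-contained sketch of the framing (not part of the commit), using `struct` directly instead of the pyspark helpers:

```python
# Illustration only (not in this commit): the bytes produced by write_int /
# write_with_length and recovered by read_int / read_with_length, on an in-memory buffer.
from cStringIO import StringIO
import struct

buf = StringIO()
payload = "hello"
buf.write(struct.pack("!i", len(payload)))    # what write_int(len(obj), stream) emits
buf.write(payload)                            # what write_with_length appends
buf.seek(0)
length = struct.unpack("!i", buf.read(4))[0]  # what read_int recovers
assert buf.read(length) == "hello"
```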
""" """
An interactive shell. An interactive shell.
This fle is designed to be launched as a PYTHONSTARTUP script. This file is designed to be launched as a PYTHONSTARTUP script.
""" """
import os import os
from pyspark.context import SparkContext from pyspark.context import SparkContext
...@@ -14,4 +14,4 @@ print "Spark context avaiable as sc." ...@@ -14,4 +14,4 @@ print "Spark context avaiable as sc."
# which allows us to execute the user's PYTHONSTARTUP file: # which allows us to execute the user's PYTHONSTARTUP file:
_pythonstartup = os.environ.get('OLD_PYTHONSTARTUP') _pythonstartup = os.environ.get('OLD_PYTHONSTARTUP')
if _pythonstartup and os.path.isfile(_pythonstartup): if _pythonstartup and os.path.isfile(_pythonstartup):
execfile(_pythonstartup) execfile(_pythonstartup)
@@ -5,9 +5,10 @@ import sys
 from base64 import standard_b64decode
 # CloudPickler needs to be imported so that depicklers are registered using the
 # copy_reg module.
+from pyspark.accumulators import _accumulatorRegistry
 from pyspark.broadcast import Broadcast, _broadcastRegistry
 from pyspark.cloudpickle import CloudPickler
-from pyspark.serializers import write_with_length, read_with_length, \
+from pyspark.serializers import write_with_length, read_with_length, write_int, \
     read_long, read_int, dump_pickle, load_pickle, read_from_pickle_file
@@ -36,6 +37,10 @@ def main():
     iterator = read_from_pickle_file(sys.stdin)
     for obj in func(split_index, iterator):
         write_with_length(dumps(obj), old_stdout)
+    # Mark the beginning of the accumulators section of the output
+    write_int(-1, old_stdout)
+    for aid, accum in _accumulatorRegistry.items():
+        write_with_length(dump_pickle((aid, accum._value)), old_stdout)
 
 
 if __name__ == '__main__':
...
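
With this change the worker's stdout stream has two sections: length-prefixed pickled records, then a `-1` marker, then one length-prefixed `dump_pickle((aid, value))` blob per registered accumulator, read until EOF. The Scala `read()` method earlier in this commit consumes exactly that layout; below is a rough Python rendering of the same reader logic (not part of the commit; `read_output` is a hypothetical name):

```python
# Illustration only (not in this commit): a Python version of what PythonRDD's reader
# does -- collect data records until the -1 marker, then consume accumulator updates
# until the stream ends.
import struct
import cPickle


def read_output(stream):
    records, updates = [], []
    while True:
        header = stream.read(4)
        if len(header) < 4:
            break                                # EOF: the worker exited
        length = struct.unpack("!i", header)[0]
        if length == -1:
            # End of the data section; the rest is accumulator updates.
            while True:
                header = stream.read(4)
                if len(header) < 4:
                    break
                n = struct.unpack("!i", header)[0]
                updates.append(cPickle.loads(stream.read(n)))  # (aid, value) pairs
            break
        records.append(stream.read(length))      # one pickled data record
    return records, updates
```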