Skip to content
Snippets Groups Projects
Commit ec31e68d authored by root
Browse files

Fixed PySpark perf regression by not using socket.makefile(), and improved debuggability by letting "print" statements show up in the executor's stderr

Conflicts:
	core/src/main/scala/spark/api/python/PythonRDD.scala
parent 3296d132
No related branches found
No related tags found
No related merge requests found
......@@ -22,6 +22,8 @@ private[spark] class PythonRDD[T: ClassManifest](
accumulator: Accumulator[JList[Array[Byte]]])
extends RDD[Array[Byte]](parent) {
val bufferSize = System.getProperty("spark.buffer.size", "65536").toInt
// Similar to Runtime.exec(), if we are given a single string, split it into words
// using a standard StringTokenizer (i.e. by spaces)
def this(parent: RDD[T], command: String, envVars: JMap[String, String],
......@@ -45,7 +47,7 @@ private[spark] class PythonRDD[T: ClassManifest](
new Thread("stdin writer for " + pythonExec) {
override def run() {
SparkEnv.set(env)
val stream = new BufferedOutputStream(worker.getOutputStream)
val stream = new BufferedOutputStream(worker.getOutputStream, bufferSize)
val dataOut = new DataOutputStream(stream)
val printOut = new PrintWriter(stream)
// Partition index
......@@ -76,7 +78,7 @@ private[spark] class PythonRDD[T: ClassManifest](
}.start()
// Return an iterator that reads lines from the process's stdout
val stream = new DataInputStream(new BufferedInputStream(worker.getInputStream))
val stream = new DataInputStream(new BufferedInputStream(worker.getInputStream, bufferSize))
return new Iterator[Array[Byte]] {
def next(): Array[Byte] = {
val obj = _nextObj
......@@ -276,6 +278,8 @@ class PythonAccumulatorParam(@transient serverHost: String, serverPort: Int)
extends AccumulatorParam[JList[Array[Byte]]] {
Utils.checkHost(serverHost, "Expected hostname")
val bufferSize = System.getProperty("spark.buffer.size", "65536").toInt
override def zero(value: JList[Array[Byte]]): JList[Array[Byte]] = new JArrayList
......@@ -289,7 +293,7 @@ class PythonAccumulatorParam(@transient serverHost: String, serverPort: Int)
// This happens on the master, where we pass the updates to Python through a socket
val socket = new Socket(serverHost, serverPort)
val in = socket.getInputStream
val out = new DataOutputStream(new BufferedOutputStream(socket.getOutputStream))
val out = new DataOutputStream(new BufferedOutputStream(socket.getOutputStream, bufferSize))
out.writeInt(val2.size)
for (array <- val2) {
out.writeInt(array.length)
......
......@@ -51,7 +51,6 @@ private[spark] class PythonWorkerFactory(pythonExec: String, envVars: Map[String
val workerEnv = pb.environment()
workerEnv.putAll(envVars)
daemon = pb.start()
daemonPort = new DataInputStream(daemon.getInputStream).readInt()
// Redirect the stderr to ours
new Thread("stderr reader for " + pythonExec) {
......@@ -69,6 +68,25 @@ private[spark] class PythonWorkerFactory(pythonExec: String, envVars: Map[String
}
}
}.start()
val in = new DataInputStream(daemon.getInputStream)
daemonPort = in.readInt()
// Redirect further stdout output to our stderr
new Thread("stdout reader for " + pythonExec) {
// Pumps bytes from `in` to System.err until `in` reports end of stream.
// Any IOException raised while reading or writing is deliberately ignored,
// which simply ends the copy loop and lets the thread finish.
override def run() {
  scala.util.control.Exception.ignoring(classOf[IOException]) {
    // FIXME HACK: We copy the stream on the level of bytes to
    // attempt to dodge encoding problems.
    // The buffer itself is never reassigned, only refilled, so it is a val.
    val buf = new Array[Byte](1024)
    // read() returns the number of bytes placed in buf, or -1 at end of stream.
    var len = in.read(buf)
    while (len != -1) {
      System.err.write(buf, 0, len)
      len = in.read(buf)
    }
  }
}
}.start()
} catch {
case e => {
stopDaemon()
......
import os
import signal
import socket
import sys
import traceback
import multiprocessing
from ctypes import c_bool
from errno import EINTR, ECHILD
from socket import socket, AF_INET, SOCK_STREAM, SOMAXCONN
from signal import signal, SIGHUP, SIGTERM, SIGCHLD, SIG_DFL, SIG_IGN
from socket import AF_INET, SOCK_STREAM, SOMAXCONN
from signal import SIGHUP, SIGTERM, SIGCHLD, SIG_DFL, SIG_IGN
from pyspark.worker import main as worker_main
from pyspark.serializers import write_int
......@@ -33,11 +36,12 @@ def compute_real_exit_code(exit_code):
def worker(listen_sock):
# Redirect stdout to stderr
os.dup2(2, 1)
sys.stdout = sys.stderr # The sys.stdout object is different from file descriptor 1
# Manager sends SIGHUP to request termination of workers in the pool
def handle_sighup(*args):
assert should_exit()
signal(SIGHUP, handle_sighup)
signal.signal(SIGHUP, handle_sighup)
# Cleanup zombie children
def handle_sigchld(*args):
......@@ -51,7 +55,7 @@ def worker(listen_sock):
handle_sigchld()
elif err.errno != ECHILD:
raise
signal(SIGCHLD, handle_sigchld)
signal.signal(SIGCHLD, handle_sigchld)
# Handle clients
while not should_exit():
......@@ -70,19 +74,22 @@ def worker(listen_sock):
# never receives SIGCHLD unless a worker crashes.
if os.fork() == 0:
# Leave the worker pool
signal(SIGHUP, SIG_DFL)
signal.signal(SIGHUP, SIG_DFL)
listen_sock.close()
# Handle the client then exit
sockfile = sock.makefile()
# Read the socket using fdopen instead of socket.makefile() because the latter
# seems to be very slow; note that we need to dup() the file descriptor because
# otherwise writes also cause a seek that makes us miss data on the read side.
infile = os.fdopen(os.dup(sock.fileno()), "a+", 65536)
outfile = os.fdopen(os.dup(sock.fileno()), "a+", 65536)
exit_code = 0
try:
worker_main(sockfile, sockfile)
worker_main(infile, outfile)
except SystemExit as exc:
exit_code = exc.code
exit_code = exc.code
finally:
sockfile.close()
sock.close()
os._exit(compute_real_exit_code(exit_code))
outfile.flush()
sock.close()
os._exit(compute_real_exit_code(exit_code))
else:
sock.close()
......@@ -92,7 +99,6 @@ def launch_worker(listen_sock):
try:
worker(listen_sock)
except Exception as err:
import traceback
traceback.print_exc()
os._exit(1)
else:
......@@ -105,7 +111,7 @@ def manager():
os.setpgid(0, 0)
# Create a listening socket on the AF_INET loopback interface
listen_sock = socket(AF_INET, SOCK_STREAM)
listen_sock = socket.socket(AF_INET, SOCK_STREAM)
listen_sock.bind(('127.0.0.1', 0))
listen_sock.listen(max(1024, 2 * POOLSIZE, SOMAXCONN))
listen_host, listen_port = listen_sock.getsockname()
......@@ -121,8 +127,8 @@ def manager():
exit_flag.value = True
# Gracefully exit on SIGTERM, don't die on SIGHUP
signal(SIGTERM, lambda signum, frame: shutdown())
signal(SIGHUP, SIG_IGN)
signal.signal(SIGTERM, lambda signum, frame: shutdown())
signal.signal(SIGHUP, SIG_IGN)
# Cleanup zombie children
def handle_sigchld(*args):
......@@ -133,7 +139,7 @@ def manager():
except EnvironmentError as err:
if err.errno not in (ECHILD, EINTR):
raise
signal(SIGCHLD, handle_sigchld)
signal.signal(SIGCHLD, handle_sigchld)
# Initialization complete
sys.stdout.close()
......@@ -148,7 +154,7 @@ def manager():
shutdown()
raise
finally:
signal(SIGTERM, SIG_DFL)
signal.signal(SIGTERM, SIG_DFL)
exit_flag.value = True
# Send SIGHUP to notify workers of shutdown
os.kill(0, SIGHUP)
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment