Commit 6550e5e6 authored by Matei Zaharia

Allow PySpark to launch worker.py directly on Windows

parent 3c520fea
PythonWorkerFactory.scala

@@ -17,8 +17,8 @@
package org.apache.spark.api.python
-import java.io.{File, DataInputStream, IOException}
+import java.io.{OutputStreamWriter, File, DataInputStream, IOException}
-import java.net.{Socket, SocketException, InetAddress}
+import java.net.{ServerSocket, Socket, SocketException, InetAddress}
import scala.collection.JavaConversions._
@@ -26,11 +26,30 @@ import org.apache.spark._
private[spark] class PythonWorkerFactory(pythonExec: String, envVars: Map[String, String])
extends Logging {
// Because forking processes from Java is expensive, we prefer to launch a single Python daemon
// (pyspark/daemon.py) and tell it to fork new workers for our tasks. The daemon currently only
// works on UNIX-based systems because it uses signals for child management, so on Windows we
// fall back to launching workers (pyspark/worker.py) directly.
val useDaemon = !System.getProperty("os.name").startsWith("Windows")
var daemon: Process = null
val daemonHost = InetAddress.getByAddress(Array(127, 0, 0, 1))
var daemonPort: Int = 0
def create(): Socket = {
if (useDaemon) {
createThroughDaemon()
} else {
createSimpleWorker()
}
}
/**
* Connect to a worker launched through pyspark/daemon.py, which forks python processes itself
* to avoid the high cost of forking from Java. This currently only works on UNIX-based systems.
*/
private def createThroughDaemon(): Socket = {
synchronized {
// Start the daemon if it hasn't been started
startDaemon()
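
The useDaemon flag above picks between the two paths: on UNIX-like systems the JVM connects out to the port published by the forking daemon (createThroughDaemon), while on Windows the direction is reversed and a directly launched worker.py connects back to the JVM (createSimpleWorker, added below). A minimal sketch of that dispatch, in Python with placeholder helpers rather than Spark's actual code:

# dispatch_sketch.py - illustrative only; the helper functions are placeholders.
import platform

def create_through_daemon():
    # On UNIX, Spark connects to a port published by pyspark/daemon.py,
    # which forks a worker per task.
    return "worker forked by the daemon"

def create_simple_worker():
    # On Windows, worker.py is launched directly and connects back to the JVM.
    return "worker launched directly"

def create():
    # Mirrors `val useDaemon = !System.getProperty("os.name").startsWith("Windows")`.
    if not platform.system().startswith("Windows"):
        return create_through_daemon()
    return create_simple_worker()

print(create())
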
@@ -50,6 +69,78 @@ private[spark] class PythonWorkerFactory(pythonExec: String, envVars: Map[String, String])
}
}
/**
* Launch a worker by executing worker.py directly and telling it to connect to us.
*/
private def createSimpleWorker(): Socket = {
var serverSocket: ServerSocket = null
try {
serverSocket = new ServerSocket(0, 1, InetAddress.getByAddress(Array(127, 0, 0, 1)))
// Create and start the worker
val sparkHome = new ProcessBuilder().environment().get("SPARK_HOME")
val pb = new ProcessBuilder(Seq(pythonExec, sparkHome + "/python/pyspark/worker.py"))
val workerEnv = pb.environment()
workerEnv.putAll(envVars)
val pythonPath = sparkHome + "/python/" + File.pathSeparator + workerEnv.get("PYTHONPATH")
workerEnv.put("PYTHONPATH", pythonPath)
val worker = pb.start()
// Redirect the worker's stderr to ours
new Thread("stderr reader for " + pythonExec) {
setDaemon(true)
override def run() {
scala.util.control.Exception.ignoring(classOf[IOException]) {
// FIXME: We copy the stream on the level of bytes to avoid encoding problems.
val in = worker.getErrorStream
val buf = new Array[Byte](1024)
var len = in.read(buf)
while (len != -1) {
System.err.write(buf, 0, len)
len = in.read(buf)
}
}
}
}.start()
// Redirect worker's stdout to our stderr
new Thread("stdout reader for " + pythonExec) {
setDaemon(true)
override def run() {
scala.util.control.Exception.ignoring(classOf[IOException]) {
// FIXME: We copy the stream on the level of bytes to avoid encoding problems.
val in = worker.getInputStream
val buf = new Array[Byte](1024)
var len = in.read(buf)
while (len != -1) {
System.err.write(buf, 0, len)
len = in.read(buf)
}
}
}
}.start()
// Tell the worker our port
val out = new OutputStreamWriter(worker.getOutputStream)
out.write(serverSocket.getLocalPort + "\n")
out.flush()
// Wait for it to connect to our socket
serverSocket.setSoTimeout(10000)
try {
return serverSocket.accept()
} catch {
case e: Exception =>
throw new SparkException("Python worker did not connect back in time", e)
}
} finally {
if (serverSocket != null) {
serverSocket.close()
}
}
null
}
def stop() {
stopDaemon()
}
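
createSimpleWorker above reverses the usual connection direction: the JVM binds a loopback ServerSocket on an ephemeral port with a backlog of 1, launches worker.py with PYTHONPATH pointing at SPARK_HOME/python, writes its port number to the worker's stdin, and waits up to 10 seconds for the worker to connect back, forwarding the worker's stdout and stderr to its own stderr in daemon threads. A minimal sketch of the same connect-back handshake, in Python for brevity; the script name toy_worker.py is a placeholder, not part of Spark:

# launcher_sketch.py - illustrative only; mirrors the connect-back handshake
# in createSimpleWorker. "toy_worker.py" is a placeholder child script.
import socket
import subprocess
import sys

def launch_worker():
    # Listen on an ephemeral loopback port with a backlog of 1,
    # like `new ServerSocket(0, 1, InetAddress.getByAddress(Array(127, 0, 0, 1)))`.
    server = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
    server.bind(("127.0.0.1", 0))
    server.listen(1)
    server.settimeout(10)  # like serverSocket.setSoTimeout(10000)

    # Launch the child and tell it our port on stdin; worker.py reads it with readline().
    child = subprocess.Popen([sys.executable, "toy_worker.py"], stdin=subprocess.PIPE)
    child.stdin.write(("%d\n" % server.getsockname()[1]).encode("ascii"))
    child.stdin.flush()

    try:
        conn, _ = server.accept()  # wait for the worker to connect back
        return conn
    except socket.timeout:
        raise RuntimeError("Python worker did not connect back in time")
    finally:
        server.close()

if __name__ == "__main__":
    sock = launch_worker()
    print("worker connected from", sock.getpeername())
    sock.close()

Closing the listening socket in the finally block matches the Scala code: once the worker has connected (or the accept has timed out), the server socket is no longer needed, while the accepted connection stays open.
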
@@ -73,12 +164,12 @@ private[spark] class PythonWorkerFactory(pythonExec: String, envVars: Map[String, String])
// Redirect the stderr to ours
new Thread("stderr reader for " + pythonExec) {
setDaemon(true)
override def run() {
scala.util.control.Exception.ignoring(classOf[IOException]) {
-// FIXME HACK: We copy the stream on the level of bytes to
-// attempt to dodge encoding problems.
+// FIXME: We copy the stream on the level of bytes to avoid encoding problems.
val in = daemon.getErrorStream
-var buf = new Array[Byte](1024)
+val buf = new Array[Byte](1024)
var len = in.read(buf)
while (len != -1) {
System.err.write(buf, 0, len)
@@ -93,11 +184,11 @@ private[spark] class PythonWorkerFactory(pythonExec: String, envVars: Map[String, String])
// Redirect further stdout output to our stderr
new Thread("stdout reader for " + pythonExec) {
setDaemon(true)
override def run() {
scala.util.control.Exception.ignoring(classOf[IOException]) {
-// FIXME HACK: We copy the stream on the level of bytes to
-// attempt to dodge encoding problems.
-var buf = new Array[Byte](1024)
+// FIXME: We copy the stream on the level of bytes to avoid encoding problems.
+val buf = new Array[Byte](1024)
var len = in.read(buf)
while (len != -1) {
System.err.write(buf, 0, len)
...
pyspark/worker.py

@@ -21,6 +21,7 @@ Worker that receives input from Piped RDD.
import os
import sys
import time
import socket
import traceback
from base64 import standard_b64decode
# CloudPickler needs to be imported so that depicklers are registered using the
@@ -94,7 +95,9 @@ def main(infile, outfile):
if __name__ == '__main__':
-# Redirect stdout to stderr so that users must return values from functions.
-old_stdout = os.fdopen(os.dup(1), 'w')
-os.dup2(2, 1)
-main(sys.stdin, old_stdout)
+# Read a local port to connect to from stdin
+java_port = int(sys.stdin.readline())
+sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
+sock.connect(("127.0.0.1", java_port))
+sock_file = sock.makefile("a+", 65536)
+main(sock_file, sock_file)
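
The toy counterpart to the launcher sketch above (a hypothetical toy_worker.py, not Spark code) shows the worker's half of the handshake: read the port from stdin, connect back over loopback, then use the same socket for both input and output, just as the new __main__ block passes sock_file to main() for both infile and outfile:

# toy_worker.py - illustrative counterpart to the launcher sketch above.
import socket
import sys

if __name__ == "__main__":
    # The launcher writes its listening port as a single text line on our stdin.
    java_port = int(sys.stdin.readline())

    sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
    sock.connect(("127.0.0.1", java_port))

    # One socket serves as both input and output, as worker.py does with
    # sock.makefile("a+", 65536); "rwb" is the Python 3 spelling.
    sock_file = sock.makefile("rwb")
    sock_file.write(b"hello from the worker\n")
    sock_file.flush()
    sock_file.close()
    sock.close()
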