Commit 55349f9f authored by Davies Liu, committed by Josh Rosen

[SPARK-1740] [PySpark] kill the python worker

Kill only the Python workers related to cancelled tasks.

The daemon starts a background thread to monitor all of the open sockets for all workers. If a socket is closed by the JVM, this thread kills the corresponding worker.

When a task is cancelled, the socket to its worker is closed and the worker is then killed by the daemon (a standalone sketch of this kill path follows below).

Author: Davies Liu <davies.liu@gmail.com>

Closes #1643 from davies/kill and squashes the following commits:

8ffe9f3 [Davies Liu] kill worker by deamon, because runtime.exec() is too heavy
46ca150 [Davies Liu] address comment
acd751c [Davies Liu] kill the worker when task is canceled
parent e139e2be
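The changes below wire this up end to end: daemon.py now writes each forked worker's pid back over the socket (instead of a bare 0 acknowledgement) and SIGKILLs any pid it reads from its stdin, PythonWorkerFactory remembers the pid for each worker socket and writes it to the daemon's stdin in stopWorker, and SparkEnv.destroyPythonWorker / PythonRDD pass the specific worker socket through. What follows is a minimal, hypothetical sketch of that pid-over-a-pipe kill protocol, independent of Spark: plain pipes stand in for the daemon's stdin and the worker socket, struct stands in for pyspark.serializers, and the names daemon_loop, cmd_fd and report_fd are invented for illustration.

# Hypothetical standalone sketch (not PySpark code): the parent process plays
# the JVM side, a forked child plays daemon.py's manager loop.
import os
import select
import signal
import struct
import time

def write_int(value, fd):
    # Same wire format as pyspark.serializers.write_int: 4-byte big-endian int.
    os.write(fd, struct.pack("!i", value))

def read_int(fd):
    data = os.read(fd, 4)
    if len(data) < 4:
        raise EOFError
    return struct.unpack("!i", data)[0]

def daemon_loop(cmd_fd, report_fd):
    # Fork a "worker" that just sleeps, and report its pid (instead of the old
    # "0 = success" acknowledgement) so the other side can later ask to kill it.
    pid = os.fork()
    if pid == 0:
        time.sleep(3600)
        os._exit(0)
    write_int(pid, report_fd)

    # Monitor the command pipe: an int is a worker pid to kill, EOF means shut down.
    while True:
        ready, _, _ = select.select([cmd_fd], [], [])
        if cmd_fd in ready:
            try:
                worker_pid = read_int(cmd_fd)
            except EOFError:
                break  # the "JVM" closed the pipe: shut down
            try:
                os.kill(worker_pid, signal.SIGKILL)
            except OSError:
                pass  # worker already died

if __name__ == "__main__":
    cmd_r, cmd_w = os.pipe()   # "JVM" -> daemon commands (stands in for stdin)
    rep_r, rep_w = os.pipe()   # daemon -> "JVM" pid report (stands in for the socket)
    daemon_pid = os.fork()
    if daemon_pid == 0:
        os.close(cmd_w)        # close the JVM-side ends so EOF is observable
        os.close(rep_r)
        daemon_loop(cmd_r, rep_w)
        os._exit(0)
    os.close(cmd_r)
    os.close(rep_w)
    worker_pid = read_int(rep_r)   # JVM side learns the worker pid
    write_int(worker_pid, cmd_w)   # ...and asks the daemon to kill that worker
    os.close(cmd_w)                # EOF tells the daemon to exit
    os.waitpid(daemon_pid, 0)

Driving the kill through the already-running daemon, keyed by pid, avoids forking a new process (for example via Runtime.exec()) for every cancelled task; that trade-off is what the squashed commit "kill worker by deamon, because runtime.exec() is too heavy" refers to.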
@@ -18,6 +18,7 @@
package org.apache.spark
import java.io.File
import java.net.Socket
import scala.collection.JavaConversions._
import scala.collection.mutable
@@ -102,10 +103,10 @@ class SparkEnv (
}
private[spark]
def destroyPythonWorker(pythonExec: String, envVars: Map[String, String]) {
def destroyPythonWorker(pythonExec: String, envVars: Map[String, String], worker: Socket) {
synchronized {
val key = (pythonExec, envVars)
pythonWorkers(key).stop()
pythonWorkers.get(key).foreach(_.stopWorker(worker))
}
}
}
@@ -62,8 +62,8 @@ private[spark] class PythonRDD(
val env = SparkEnv.get
val localdir = env.blockManager.diskBlockManager.localDirs.map(
f => f.getPath()).mkString(",")
val worker: Socket = env.createPythonWorker(pythonExec,
envVars.toMap + ("SPARK_LOCAL_DIR" -> localdir))
envVars += ("SPARK_LOCAL_DIR" -> localdir) // it's also used in monitor thread
val worker: Socket = env.createPythonWorker(pythonExec, envVars.toMap)
// Start a thread to feed the process input from our parent's iterator
val writerThread = new WriterThread(env, worker, split, context)
@@ -241,7 +241,7 @@ private[spark] class PythonRDD(
if (!context.completed) {
try {
logWarning("Incomplete task interrupted: Attempting to kill Python Worker")
env.destroyPythonWorker(pythonExec, envVars.toMap)
env.destroyPythonWorker(pythonExec, envVars.toMap, worker)
} catch {
case e: Exception =>
logError("Exception when trying to kill worker", e)
@@ -685,9 +685,8 @@ private[spark] object PythonRDD extends Logging {
/**
* Convert an RDD of serialized Python dictionaries to Scala Maps (no recursive conversions).
* This function is outdated, PySpark does not use it anymore
*/
@deprecated
@deprecated("PySpark does not use it anymore", "1.1")
def pythonToJavaMap(pyRDD: JavaRDD[Array[Byte]]): JavaRDD[Map[String, _]] = {
pyRDD.rdd.mapPartitions { iter =>
val unpickle = new Unpickler
@@ -17,9 +17,11 @@
package org.apache.spark.api.python
import java.io.{DataInputStream, InputStream, OutputStreamWriter}
import java.lang.Runtime
import java.io.{DataOutputStream, DataInputStream, InputStream, OutputStreamWriter}
import java.net.{InetAddress, ServerSocket, Socket, SocketException}
import scala.collection.mutable
import scala.collection.JavaConversions._
import org.apache.spark._
@@ -39,6 +41,9 @@ private[spark] class PythonWorkerFactory(pythonExec: String, envVars: Map[String
var daemon: Process = null
val daemonHost = InetAddress.getByAddress(Array(127, 0, 0, 1))
var daemonPort: Int = 0
var daemonWorkers = new mutable.WeakHashMap[Socket, Int]()
var simpleWorkers = new mutable.WeakHashMap[Socket, Process]()
val pythonPath = PythonUtils.mergePythonPaths(
PythonUtils.sparkPythonPath,
@@ -58,25 +63,31 @@ private[spark] class PythonWorkerFactory(pythonExec: String, envVars: Map[String
* to avoid the high cost of forking from Java. This currently only works on UNIX-based systems.
*/
private def createThroughDaemon(): Socket = {
def createSocket(): Socket = {
val socket = new Socket(daemonHost, daemonPort)
val pid = new DataInputStream(socket.getInputStream).readInt()
if (pid < 0) {
throw new IllegalStateException("Python daemon failed to launch worker")
}
daemonWorkers.put(socket, pid)
socket
}
synchronized {
// Start the daemon if it hasn't been started
startDaemon()
// Attempt to connect, restart and retry once if it fails
try {
val socket = new Socket(daemonHost, daemonPort)
val launchStatus = new DataInputStream(socket.getInputStream).readInt()
if (launchStatus != 0) {
throw new IllegalStateException("Python daemon failed to launch worker")
}
socket
createSocket()
} catch {
case exc: SocketException =>
logWarning("Failed to open socket to Python daemon:", exc)
logWarning("Assuming that daemon unexpectedly quit, attempting to restart")
stopDaemon()
startDaemon()
new Socket(daemonHost, daemonPort)
createSocket()
}
}
}
@@ -107,7 +118,9 @@ private[spark] class PythonWorkerFactory(pythonExec: String, envVars: Map[String
// Wait for it to connect to our socket
serverSocket.setSoTimeout(10000)
try {
return serverSocket.accept()
val socket = serverSocket.accept()
simpleWorkers.put(socket, worker)
return socket
} catch {
case e: Exception =>
throw new SparkException("Python worker did not connect back in time", e)
@@ -189,19 +202,40 @@ private[spark] class PythonWorkerFactory(pythonExec: String, envVars: Map[String
private def stopDaemon() {
synchronized {
// Request shutdown of existing daemon by sending SIGTERM
if (daemon != null) {
daemon.destroy()
}
if (useDaemon) {
// Request shutdown of existing daemon by sending SIGTERM
if (daemon != null) {
daemon.destroy()
}
daemon = null
daemonPort = 0
daemon = null
daemonPort = 0
} else {
simpleWorkers.mapValues(_.destroy())
}
}
}
def stop() {
stopDaemon()
}
def stopWorker(worker: Socket) {
if (useDaemon) {
if (daemon != null) {
daemonWorkers.get(worker).foreach { pid =>
// tell daemon to kill worker by pid
val output = new DataOutputStream(daemon.getOutputStream)
output.writeInt(pid)
output.flush()
daemon.getOutputStream.flush()
}
}
} else {
simpleWorkers.get(worker).foreach(_.destroy())
}
worker.close()
}
}
private object PythonWorkerFactory {
@@ -26,7 +26,7 @@ from errno import EINTR, ECHILD
from socket import AF_INET, SOCK_STREAM, SOMAXCONN
from signal import SIGHUP, SIGTERM, SIGCHLD, SIG_DFL, SIG_IGN
from pyspark.worker import main as worker_main
from pyspark.serializers import write_int
from pyspark.serializers import read_int, write_int
def compute_real_exit_code(exit_code):
@@ -67,7 +67,8 @@ def worker(sock):
outfile = os.fdopen(os.dup(sock.fileno()), "a+", 65536)
exit_code = 0
try:
write_int(0, outfile) # Acknowledge that the fork was successful
# Acknowledge that the fork was successful
write_int(os.getpid(), outfile)
outfile.flush()
worker_main(infile, outfile)
except SystemExit as exc:
@@ -125,14 +126,23 @@ def manager():
else:
raise
if 0 in ready_fds:
# Spark told us to exit by closing stdin
shutdown(0)
try:
worker_pid = read_int(sys.stdin)
except EOFError:
# Spark told us to exit by closing stdin
shutdown(0)
try:
os.kill(worker_pid, signal.SIGKILL)
except OSError:
pass # process already died
if listen_sock in ready_fds:
sock, addr = listen_sock.accept()
# Launch a worker process
try:
fork_return_code = os.fork()
if fork_return_code == 0:
pid = os.fork()
if pid == 0:
listen_sock.close()
try:
worker(sock)
@@ -143,11 +153,13 @@ def manager():
os._exit(0)
else:
sock.close()
except OSError as e:
print >> sys.stderr, "Daemon failed to fork PySpark worker: %s" % e
outfile = os.fdopen(os.dup(sock.fileno()), "a+", 65536)
write_int(-1, outfile) # Signal that the fork failed
outfile.flush()
outfile.close()
sock.close()
finally:
shutdown(1)
@@ -790,6 +790,57 @@ class TestDaemon(unittest.TestCase):
self.do_termination_test(lambda daemon: os.kill(daemon.pid, SIGTERM))
class TestWorker(PySparkTestCase):
def test_cancel_task(self):
temp = tempfile.NamedTemporaryFile(delete=True)
temp.close()
path = temp.name
def sleep(x):
import os, time
with open(path, 'w') as f:
f.write("%d %d" % (os.getppid(), os.getpid()))
time.sleep(100)
# start job in background thread
def run():
self.sc.parallelize(range(1)).foreach(sleep)
import threading
t = threading.Thread(target=run)
t.daemon = True
t.start()
daemon_pid, worker_pid = 0, 0
while True:
if os.path.exists(path):
data = open(path).read().split(' ')
daemon_pid, worker_pid = map(int, data)
break
time.sleep(0.1)
# cancel jobs
self.sc.cancelAllJobs()
t.join()
for i in range(50):
try:
os.kill(worker_pid, 0)
time.sleep(0.1)
except OSError:
break # worker was killed
else:
self.fail("worker has not been killed after 5 seconds")
try:
os.kill(daemon_pid, 0)
except OSError:
self.fail("daemon had been killed")
def test_fd_leak(self):
N = 1100 # fd limit is 1024 by default
rdd = self.sc.parallelize(range(N), N)
self.assertEquals(N, rdd.count())
class TestSparkSubmit(unittest.TestCase):
def setUp(self):
self.programDir = tempfile.mkdtemp()