diff --git a/core/src/main/scala/spark/SparkEnv.scala b/core/src/main/scala/spark/SparkEnv.scala
index 5691e24c320f678e3ce23d79cc70ba3ddb7aede8..5b55d4521230b5fabf4f3bf43d84d3c3043e08ea 100644
--- a/core/src/main/scala/spark/SparkEnv.scala
+++ b/core/src/main/scala/spark/SparkEnv.scala
@@ -44,6 +44,7 @@ class SparkEnv (
   private val pythonWorkers = mutable.HashMap[(String, Map[String, String]), PythonWorker]()
 
   def stop() {
+    pythonWorkers.foreach { case(key, worker) => worker.stop() }
     httpFileServer.stop()
     mapOutputTracker.stop()
     shuffleFetcher.stop()
diff --git a/core/src/main/scala/spark/api/python/PythonWorker.scala b/core/src/main/scala/spark/api/python/PythonWorker.scala
index 8ee3c6884f98f0998bd44a5ad43042f79c00f073..74c8c6d37a811c7717519fd8fa9c6a2ab25509d1 100644
--- a/core/src/main/scala/spark/api/python/PythonWorker.scala
+++ b/core/src/main/scala/spark/api/python/PythonWorker.scala
@@ -33,6 +33,10 @@ private[spark] class PythonWorker(pythonExec: String, envVars: Map[String, Strin
     }
   }
 
+  def stop() {
+    stopDaemon
+  }
+
   private def startDaemon() {
     synchronized {
       // Is it already running?
diff --git a/python/pyspark/daemon.py b/python/pyspark/daemon.py
index 642f30b2b9644aff7f105230f57cb56a9639bc30..ab9c19df578f8a6053cc8764082e53ffa34c25de 100644
--- a/python/pyspark/daemon.py
+++ b/python/pyspark/daemon.py
@@ -12,7 +12,7 @@ try:
 except NotImplementedError:
     POOLSIZE = 4
 
-should_exit = False
+should_exit = multiprocessing.Event()
 
 
 def worker(listen_sock):
@@ -21,14 +21,13 @@
 
     # Manager sends SIGHUP to request termination of workers in the pool
     def handle_sighup(signum, frame):
-        global should_exit
-        should_exit = True
+        assert should_exit.is_set()
     signal(SIGHUP, handle_sighup)
 
-    while not should_exit:
+    while not should_exit.is_set():
         # Wait until a client arrives or we have to exit
         sock = None
-        while not should_exit and sock is None:
+        while not should_exit.is_set() and sock is None:
             try:
                 sock, addr = listen_sock.accept()
             except EnvironmentError as err:
@@ -36,8 +35,8 @@
                     raise
 
         if sock is not None:
-            # Fork a child to handle the client
-            if os.fork() == 0:
+            # Fork to handle the client
+            if os.fork() != 0:
                 # Leave the worker pool
                 signal(SIGHUP, SIG_DFL)
                 listen_sock.close()
@@ -50,7 +49,7 @@
             else:
                 sock.close()
 
-    assert should_exit
+    assert should_exit.is_set()
     os._exit(0)
 
 
@@ -73,9 +72,7 @@
     listen_sock.close()
 
     def shutdown():
-        global should_exit
-        os.kill(0, SIGHUP)
-        should_exit = True
+        should_exit.set()
 
     # Gracefully exit on SIGTERM, don't die on SIGHUP
     signal(SIGTERM, lambda signum, frame: shutdown())
@@ -85,8 +82,8 @@
     def handle_sigchld(signum, frame):
         try:
             pid, status = os.waitpid(0, os.WNOHANG)
-            if (pid, status) != (0, 0) and not should_exit:
-                raise RuntimeError("pool member crashed: %s, %s" % (pid, status))
+            if status != 0 and not should_exit.is_set():
+                raise RuntimeError("worker crashed: %s, %s" % (pid, status))
         except EnvironmentError as err:
             if err.errno not in (ECHILD, EINTR):
                 raise
@@ -94,15 +91,20 @@
 
     # Initialization complete
     sys.stdout.close()
-    while not should_exit:
-        try:
-            # Spark tells us to exit by closing stdin
-            if sys.stdin.read() == '':
-                shutdown()
-        except EnvironmentError as err:
-            if err.errno != EINTR:
-                shutdown()
-                raise
+    try:
+        while not should_exit.is_set():
+            try:
+                # Spark tells us to exit by closing stdin
+                if os.read(0, 512) == '':
+                    shutdown()
+            except EnvironmentError as err:
+                if err.errno != EINTR:
+                    shutdown()
+                    raise
+    finally:
+        should_exit.set()
+        # Send SIGHUP to notify workers of shutdown
+        os.kill(0, SIGHUP)
 
 
 if __name__ == '__main__':
diff --git a/python/pyspark/tests.py b/python/pyspark/tests.py
index 6a1962d26795dfe950049de5b3a07b3d7f12354a..1e34d473650bf03ba93fb5a509d49e831467ed4d 100644
--- a/python/pyspark/tests.py
+++ b/python/pyspark/tests.py
@@ -12,6 +12,7 @@ import unittest
 from pyspark.context import SparkContext
 from pyspark.files import SparkFiles
 from pyspark.java_gateway import SPARK_HOME
+from pyspark.serializers import read_int
 
 
 class PySparkTestCase(unittest.TestCase):
@@ -117,5 +118,47 @@ class TestIO(PySparkTestCase):
         self.sc.parallelize([1]).foreach(func)
 
 
+class TestDaemon(unittest.TestCase):
+    def connect(self, port):
+        from socket import socket, AF_INET, SOCK_STREAM
+        sock = socket(AF_INET, SOCK_STREAM)
+        sock.connect(('127.0.0.1', port))
+        # send a split index of -1 to shutdown the worker
+        sock.send("\xFF\xFF\xFF\xFF")
+        sock.close()
+        return True
+
+    def do_termination_test(self, terminator):
+        from subprocess import Popen, PIPE
+        from errno import ECONNREFUSED
+
+        # start daemon
+        daemon_path = os.path.join(os.path.dirname(__file__), "daemon.py")
+        daemon = Popen([sys.executable, daemon_path], stdin=PIPE, stdout=PIPE)
+
+        # read the port number
+        port = read_int(daemon.stdout)
+
+        # daemon should accept connections
+        self.assertTrue(self.connect(port))
+
+        # request shutdown
+        terminator(daemon)
+        time.sleep(1)
+
+        # daemon should no longer accept connections
+        with self.assertRaises(EnvironmentError) as trap:
+            self.connect(port)
+        self.assertEqual(trap.exception.errno, ECONNREFUSED)
+
+    def test_termination_stdin(self):
+        """Ensure that daemon and workers terminate when stdin is closed."""
+        self.do_termination_test(lambda daemon: daemon.stdin.close())
+
+    def test_termination_sigterm(self):
+        """Ensure that daemon and workers terminate on SIGTERM."""
+        from signal import SIGTERM
+        self.do_termination_test(lambda daemon: os.kill(daemon.pid, SIGTERM))
+
 if __name__ == "__main__":
     unittest.main()
diff --git a/python/pyspark/worker.py b/python/pyspark/worker.py
index 94d612ea6ed926c42f44f0c57f91078b3aa2afc3..f76ee3c236db5a23f45e1d1149ced239f1c8cac2 100644
--- a/python/pyspark/worker.py
+++ b/python/pyspark/worker.py
@@ -30,6 +30,8 @@ def report_times(outfile, boot, init, finish):
 def main(infile, outfile):
     boot_time = time.time()
     split_index = read_int(infile)
+    if split_index == -1:  # for unit tests
+        return
     spark_files_dir = load_pickle(read_with_length(infile))
     SparkFiles._root_directory = spark_files_dir
     SparkFiles._is_running_on_worker = True