Skip to content
Snippets Groups Projects
Commit 7c5ff733 authored by Jey Kottalam's avatar Jey Kottalam
Browse files

PySpark daemon: fix deadlock, improve error handling

parent edb18ca9
No related branches found
No related tags found
No related merge requests found
import os import os
import sys import sys
import multiprocessing import multiprocessing
from ctypes import c_bool
from errno import EINTR, ECHILD from errno import EINTR, ECHILD
from socket import socket, AF_INET, SOCK_STREAM, SOMAXCONN from socket import socket, AF_INET, SOCK_STREAM, SOMAXCONN
from signal import signal, SIGHUP, SIGTERM, SIGCHLD, SIG_DFL, SIG_IGN from signal import signal, SIGHUP, SIGTERM, SIGCHLD, SIG_DFL, SIG_IGN
...@@ -12,7 +13,12 @@ try: ...@@ -12,7 +13,12 @@ try:
except NotImplementedError: except NotImplementedError:
POOLSIZE = 4 POOLSIZE = 4
should_exit = multiprocessing.Event() exit_flag = multiprocessing.Value(c_bool, False)
def should_exit():
global exit_flag
return exit_flag.value
def worker(listen_sock): def worker(listen_sock):
...@@ -20,14 +26,29 @@ def worker(listen_sock): ...@@ -20,14 +26,29 @@ def worker(listen_sock):
os.dup2(2, 1) os.dup2(2, 1)
# Manager sends SIGHUP to request termination of workers in the pool # Manager sends SIGHUP to request termination of workers in the pool
def handle_sighup(signum, frame): def handle_sighup(*args):
assert should_exit.is_set() assert should_exit()
signal(SIGHUP, handle_sighup) signal(SIGHUP, handle_sighup)
while not should_exit.is_set(): # Cleanup zombie children
def handle_sigchld(*args):
pid = status = None
try:
while (pid, status) != (0, 0):
pid, status = os.waitpid(0, os.WNOHANG)
except EnvironmentError as err:
if err.errno == EINTR:
# retry
handle_sigchld()
elif err.errno != ECHILD:
raise
signal(SIGCHLD, handle_sigchld)
# Handle clients
while not should_exit():
# Wait until a client arrives or we have to exit # Wait until a client arrives or we have to exit
sock = None sock = None
while not should_exit.is_set() and sock is None: while not should_exit() and sock is None:
try: try:
sock, addr = listen_sock.accept() sock, addr = listen_sock.accept()
except EnvironmentError as err: except EnvironmentError as err:
...@@ -35,8 +56,10 @@ def worker(listen_sock): ...@@ -35,8 +56,10 @@ def worker(listen_sock):
raise raise
if sock is not None: if sock is not None:
# Fork to handle the client # Fork a child to handle the client.
if os.fork() != 0: # The client is handled in the child so that the manager
# never receives SIGCHLD unless a worker crashes.
if os.fork() == 0:
# Leave the worker pool # Leave the worker pool
signal(SIGHUP, SIG_DFL) signal(SIGHUP, SIG_DFL)
listen_sock.close() listen_sock.close()
...@@ -49,8 +72,18 @@ def worker(listen_sock): ...@@ -49,8 +72,18 @@ def worker(listen_sock):
else: else:
sock.close() sock.close()
assert should_exit.is_set()
os._exit(0) def launch_worker(listen_sock):
if os.fork() == 0:
try:
worker(listen_sock)
except Exception as err:
import traceback
traceback.print_exc()
os._exit(1)
else:
assert should_exit()
os._exit(0)
def manager(): def manager():
...@@ -66,23 +99,22 @@ def manager(): ...@@ -66,23 +99,22 @@ def manager():
# Launch initial worker pool # Launch initial worker pool
for idx in range(POOLSIZE): for idx in range(POOLSIZE):
if os.fork() == 0: launch_worker(listen_sock)
worker(listen_sock)
raise RuntimeError("worker() unexpectedly returned")
listen_sock.close() listen_sock.close()
def shutdown(): def shutdown():
should_exit.set() global exit_flag
exit_flag.value = True
# Gracefully exit on SIGTERM, don't die on SIGHUP # Gracefully exit on SIGTERM, don't die on SIGHUP
signal(SIGTERM, lambda signum, frame: shutdown()) signal(SIGTERM, lambda signum, frame: shutdown())
signal(SIGHUP, SIG_IGN) signal(SIGHUP, SIG_IGN)
# Cleanup zombie children # Cleanup zombie children
def handle_sigchld(signum, frame): def handle_sigchld(*args):
try: try:
pid, status = os.waitpid(0, os.WNOHANG) pid, status = os.waitpid(0, os.WNOHANG)
if status != 0 and not should_exit.is_set(): if status != 0 and not should_exit():
raise RuntimeError("worker crashed: %s, %s" % (pid, status)) raise RuntimeError("worker crashed: %s, %s" % (pid, status))
except EnvironmentError as err: except EnvironmentError as err:
if err.errno not in (ECHILD, EINTR): if err.errno not in (ECHILD, EINTR):
...@@ -92,7 +124,7 @@ def manager(): ...@@ -92,7 +124,7 @@ def manager():
# Initialization complete # Initialization complete
sys.stdout.close() sys.stdout.close()
try: try:
while not should_exit.is_set(): while not should_exit():
try: try:
# Spark tells us to exit by closing stdin # Spark tells us to exit by closing stdin
if os.read(0, 512) == '': if os.read(0, 512) == '':
...@@ -102,7 +134,8 @@ def manager(): ...@@ -102,7 +134,8 @@ def manager():
shutdown() shutdown()
raise raise
finally: finally:
should_exit.set() signal(SIGTERM, SIG_DFL)
exit_flag.value = True
# Send SIGHUP to notify workers of shutdown # Send SIGHUP to notify workers of shutdown
os.kill(0, SIGHUP) os.kill(0, SIGHUP)
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment