From 40984f67065eeaea731940008e6677c2323dda3e Mon Sep 17 00:00:00 2001
From: Sun Rui <rui.sun@intel.com>
Date: Mon, 28 Mar 2016 10:14:28 -0700
Subject: [PATCH] [SPARK-12792] [SPARKR] Refactor RRDD to support R UDF.

Refactor RRDD by separating the common logic for interacting with the R
worker into a new class, RRunner, which can be used to evaluate R UDFs.

Now RRDD relies on RRunner for RDD computation, and RRDD could be removed
later if we decide to remove the RDD API from SparkR.

Author: Sun Rui <rui.sun@intel.com>

Closes #10947 from sun-rui/SPARK-12792.
---
 R/pkg/inst/tests/testthat/test_rdd.R          |   8 +
 .../scala/org/apache/spark/api/r/RRDD.scala   | 328 +---------------
 .../org/apache/spark/api/r/RRunner.scala      | 367 ++++++++++++++++++
 3 files changed, 379 insertions(+), 324 deletions(-)
 create mode 100644 core/src/main/scala/org/apache/spark/api/r/RRunner.scala

diff --git a/R/pkg/inst/tests/testthat/test_rdd.R b/R/pkg/inst/tests/testthat/test_rdd.R
index 3b0c16be5a..b6c8e1dc6c 100644
--- a/R/pkg/inst/tests/testthat/test_rdd.R
+++ b/R/pkg/inst/tests/testthat/test_rdd.R
@@ -791,3 +791,11 @@ test_that("sampleByKey() on pairwise RDDs", {
   expect_equal(lookup(sample, 3)[which.min(lookup(sample, 3))] >= 0, TRUE)
   expect_equal(lookup(sample, 3)[which.max(lookup(sample, 3))] <= 2000, TRUE)
 })
+
+test_that("Test correct concurrency of RRDD.compute()", {
+  rdd <- parallelize(sc, 1:1000, 100)
+  jrdd <- getJRDD(lapply(rdd, function(x) { x }), "row")
+  zrdd <- callJMethod(jrdd, "zip", jrdd)
+  count <- callJMethod(zrdd, "count")
+  expect_equal(count, 1000)
+})
diff --git a/core/src/main/scala/org/apache/spark/api/r/RRDD.scala b/core/src/main/scala/org/apache/spark/api/r/RRDD.scala
index 588a57e65f..606ba6ef86 100644
--- a/core/src/main/scala/org/apache/spark/api/r/RRDD.scala
+++ b/core/src/main/scala/org/apache/spark/api/r/RRDD.scala
@@ -17,21 +17,16 @@
 
 package org.apache.spark.api.r
 
-import java.io._
-import java.net.{InetAddress, ServerSocket}
-import java.util.{Arrays, Map => JMap}
+import java.util.{Map => JMap}
 
 import scala.collection.JavaConverters._
-import scala.io.Source
 import scala.reflect.ClassTag
-import scala.util.Try
 
 import org.apache.spark._
 import org.apache.spark.api.java.{JavaPairRDD, JavaRDD, JavaSparkContext}
 import org.apache.spark.broadcast.Broadcast
 import org.apache.spark.internal.Logging
 import org.apache.spark.rdd.RDD
-import org.apache.spark.util.Utils
 
 private abstract class BaseRRDD[T: ClassTag, U: ClassTag](
     parent: RDD[T],
@@ -42,188 +37,16 @@ private abstract class BaseRRDD[T: ClassTag, U: ClassTag](
     packageNames: Array[Byte],
     broadcastVars: Array[Broadcast[Object]])
   extends RDD[U](parent) with Logging {
-  protected var dataStream: DataInputStream = _
-  private var bootTime: Double = _
   override def getPartitions: Array[Partition] = parent.partitions
 
   override def compute(partition: Partition, context: TaskContext): Iterator[U] = {
-
-    // Timing start
-    bootTime = System.currentTimeMillis / 1000.0
+    val runner = new RRunner[U](
+      func, deserializer, serializer, packageNames, broadcastVars, numPartitions)
 
     // The parent may be also an RRDD, so we should launch it first.
     val parentIterator = firstParent[T].iterator(partition, context)
 
-    // we expect two connections
-    val serverSocket = new ServerSocket(0, 2, InetAddress.getByName("localhost"))
-    val listenPort = serverSocket.getLocalPort()
-
-    // The stdout/stderr is shared by multiple tasks, because we use one daemon
-    // to launch child process as worker.
- val errThread = RRDD.createRWorker(listenPort) - - // We use two sockets to separate input and output, then it's easy to manage - // the lifecycle of them to avoid deadlock. - // TODO: optimize it to use one socket - - // the socket used to send out the input of task - serverSocket.setSoTimeout(10000) - val inSocket = serverSocket.accept() - startStdinThread(inSocket.getOutputStream(), parentIterator, partition.index) - - // the socket used to receive the output of task - val outSocket = serverSocket.accept() - val inputStream = new BufferedInputStream(outSocket.getInputStream) - dataStream = new DataInputStream(inputStream) - serverSocket.close() - - try { - - return new Iterator[U] { - def next(): U = { - val obj = _nextObj - if (hasNext) { - _nextObj = read() - } - obj - } - - var _nextObj = read() - - def hasNext(): Boolean = { - val hasMore = (_nextObj != null) - if (!hasMore) { - dataStream.close() - } - hasMore - } - } - } catch { - case e: Exception => - throw new SparkException("R computation failed with\n " + errThread.getLines()) - } - } - - /** - * Start a thread to write RDD data to the R process. - */ - private def startStdinThread[T]( - output: OutputStream, - iter: Iterator[T], - partition: Int): Unit = { - - val env = SparkEnv.get - val taskContext = TaskContext.get() - val bufferSize = System.getProperty("spark.buffer.size", "65536").toInt - val stream = new BufferedOutputStream(output, bufferSize) - - new Thread("writer for R") { - override def run(): Unit = { - try { - SparkEnv.set(env) - TaskContext.setTaskContext(taskContext) - val dataOut = new DataOutputStream(stream) - dataOut.writeInt(partition) - - SerDe.writeString(dataOut, deserializer) - SerDe.writeString(dataOut, serializer) - - dataOut.writeInt(packageNames.length) - dataOut.write(packageNames) - - dataOut.writeInt(func.length) - dataOut.write(func) - - dataOut.writeInt(broadcastVars.length) - broadcastVars.foreach { broadcast => - // TODO(shivaram): Read a Long in R to avoid this cast - dataOut.writeInt(broadcast.id.toInt) - // TODO: Pass a byte array from R to avoid this cast ? 
- val broadcastByteArr = broadcast.value.asInstanceOf[Array[Byte]] - dataOut.writeInt(broadcastByteArr.length) - dataOut.write(broadcastByteArr) - } - - dataOut.writeInt(numPartitions) - - if (!iter.hasNext) { - dataOut.writeInt(0) - } else { - dataOut.writeInt(1) - } - - val printOut = new PrintStream(stream) - - def writeElem(elem: Any): Unit = { - if (deserializer == SerializationFormats.BYTE) { - val elemArr = elem.asInstanceOf[Array[Byte]] - dataOut.writeInt(elemArr.length) - dataOut.write(elemArr) - } else if (deserializer == SerializationFormats.ROW) { - dataOut.write(elem.asInstanceOf[Array[Byte]]) - } else if (deserializer == SerializationFormats.STRING) { - // write string(for StringRRDD) - // scalastyle:off println - printOut.println(elem) - // scalastyle:on println - } - } - - for (elem <- iter) { - elem match { - case (key, value) => - writeElem(key) - writeElem(value) - case _ => - writeElem(elem) - } - } - stream.flush() - } catch { - // TODO: We should propogate this error to the task thread - case e: Exception => - logError("R Writer thread got an exception", e) - } finally { - Try(output.close()) - } - } - }.start() - } - - protected def readData(length: Int): U - - protected def read(): U = { - try { - val length = dataStream.readInt() - - length match { - case SpecialLengths.TIMING_DATA => - // Timing data from R worker - val boot = dataStream.readDouble - bootTime - val init = dataStream.readDouble - val broadcast = dataStream.readDouble - val input = dataStream.readDouble - val compute = dataStream.readDouble - val output = dataStream.readDouble - logInfo( - ("Times: boot = %.3f s, init = %.3f s, broadcast = %.3f s, " + - "read-input = %.3f s, compute = %.3f s, write-output = %.3f s, " + - "total = %.3f s").format( - boot, - init, - broadcast, - input, - compute, - output, - boot + init + broadcast + input + compute + output)) - read() - case length if length >= 0 => - readData(length) - } - } catch { - case eof: EOFException => - throw new SparkException("R worker exited unexpectedly (cranshed)", eof) - } + runner.compute(parentIterator, partition.index, context) } } @@ -242,19 +65,6 @@ private class PairwiseRRDD[T: ClassTag]( parent, numPartitions, hashFunc, deserializer, SerializationFormats.BYTE, packageNames, broadcastVars.map(x => x.asInstanceOf[Broadcast[Object]])) { - - override protected def readData(length: Int): (Int, Array[Byte]) = { - length match { - case length if length == 2 => - val hashedKey = dataStream.readInt() - val contentPairsLength = dataStream.readInt() - val contentPairs = new Array[Byte](contentPairsLength) - dataStream.readFully(contentPairs) - (hashedKey, contentPairs) - case _ => null - } - } - lazy val asJavaPairRDD : JavaPairRDD[Int, Array[Byte]] = JavaPairRDD.fromRDD(this) } @@ -271,17 +81,6 @@ private class RRDD[T: ClassTag]( extends BaseRRDD[T, Array[Byte]]( parent, -1, func, deserializer, serializer, packageNames, broadcastVars.map(x => x.asInstanceOf[Broadcast[Object]])) { - - override protected def readData(length: Int): Array[Byte] = { - length match { - case length if length > 0 => - val obj = new Array[Byte](length) - dataStream.readFully(obj) - obj - case _ => null - } - } - lazy val asJavaRDD : JavaRDD[Array[Byte]] = JavaRDD.fromRDD(this) } @@ -297,55 +96,10 @@ private class StringRRDD[T: ClassTag]( extends BaseRRDD[T, String]( parent, -1, func, deserializer, SerializationFormats.STRING, packageNames, broadcastVars.map(x => x.asInstanceOf[Broadcast[Object]])) { - - override protected def readData(length: Int): String = { - 
length match { - case length if length > 0 => - SerDe.readStringBytes(dataStream, length) - case _ => null - } - } - lazy val asJavaRDD : JavaRDD[String] = JavaRDD.fromRDD(this) } -private object SpecialLengths { - val TIMING_DATA = -1 -} - -private[r] class BufferedStreamThread( - in: InputStream, - name: String, - errBufferSize: Int) extends Thread(name) with Logging { - val lines = new Array[String](errBufferSize) - var lineIdx = 0 - override def run() { - for (line <- Source.fromInputStream(in).getLines) { - synchronized { - lines(lineIdx) = line - lineIdx = (lineIdx + 1) % errBufferSize - } - logInfo(line) - } - } - - def getLines(): String = synchronized { - (0 until errBufferSize).filter { x => - lines((x + lineIdx) % errBufferSize) != null - }.map { x => - lines((x + lineIdx) % errBufferSize) - }.mkString("\n") - } -} - private[r] object RRDD { - // Because forking processes from Java is expensive, we prefer to launch - // a single R daemon (daemon.R) and tell it to fork new workers for our tasks. - // This daemon currently only works on UNIX-based systems now, so we should - // also fall back to launching workers (worker.R) directly. - private[this] var errThread: BufferedStreamThread = _ - private[this] var daemonChannel: DataOutputStream = _ - def createSparkContext( master: String, appName: String, @@ -353,7 +107,6 @@ private[r] object RRDD { jars: Array[String], sparkEnvirMap: JMap[Object, Object], sparkExecutorEnvMap: JMap[Object, Object]): JavaSparkContext = { - val sparkConf = new SparkConf().setAppName(appName) .setSparkHome(sparkHome) @@ -380,78 +133,6 @@ private[r] object RRDD { jsc } - /** - * Start a thread to print the process's stderr to ours - */ - private def startStdoutThread(proc: Process): BufferedStreamThread = { - val BUFFER_SIZE = 100 - val thread = new BufferedStreamThread(proc.getInputStream, "stdout reader for R", BUFFER_SIZE) - thread.setDaemon(true) - thread.start() - thread - } - - private def createRProcess(port: Int, script: String): BufferedStreamThread = { - // "spark.sparkr.r.command" is deprecated and replaced by "spark.r.command", - // but kept here for backward compatibility. - val sparkConf = SparkEnv.get.conf - var rCommand = sparkConf.get("spark.sparkr.r.command", "Rscript") - rCommand = sparkConf.get("spark.r.command", rCommand) - - val rOptions = "--vanilla" - val rLibDir = RUtils.sparkRPackagePath(isDriver = false) - val rExecScript = rLibDir(0) + "/SparkR/worker/" + script - val pb = new ProcessBuilder(Arrays.asList(rCommand, rOptions, rExecScript)) - // Unset the R_TESTS environment variable for workers. - // This is set by R CMD check as startup.Rs - // (http://svn.r-project.org/R/trunk/src/library/tools/R/testing.R) - // and confuses worker script which tries to load a non-existent file - pb.environment().put("R_TESTS", "") - pb.environment().put("SPARKR_RLIBDIR", rLibDir.mkString(",")) - pb.environment().put("SPARKR_WORKER_PORT", port.toString) - pb.redirectErrorStream(true) // redirect stderr into stdout - val proc = pb.start() - val errThread = startStdoutThread(proc) - errThread - } - - /** - * ProcessBuilder used to launch worker R processes. 
- */ - def createRWorker(port: Int): BufferedStreamThread = { - val useDaemon = SparkEnv.get.conf.getBoolean("spark.sparkr.use.daemon", true) - if (!Utils.isWindows && useDaemon) { - synchronized { - if (daemonChannel == null) { - // we expect one connections - val serverSocket = new ServerSocket(0, 1, InetAddress.getByName("localhost")) - val daemonPort = serverSocket.getLocalPort - errThread = createRProcess(daemonPort, "daemon.R") - // the socket used to send out the input of task - serverSocket.setSoTimeout(10000) - val sock = serverSocket.accept() - daemonChannel = new DataOutputStream(new BufferedOutputStream(sock.getOutputStream)) - serverSocket.close() - } - try { - daemonChannel.writeInt(port) - daemonChannel.flush() - } catch { - case e: IOException => - // daemon process died - daemonChannel.close() - daemonChannel = null - errThread = null - // fail the current task, retry by scheduler - throw e - } - errThread - } - } else { - createRProcess(port, "worker.R") - } - } - /** * Create an RRDD given a sequence of byte arrays. Used to create RRDD when `parallelize` is * called from R. @@ -459,5 +140,4 @@ private[r] object RRDD { def createRDDFromArray(jsc: JavaSparkContext, arr: Array[Array[Byte]]): JavaRDD[Array[Byte]] = { JavaRDD.fromRDD(jsc.sc.parallelize(arr, arr.length)) } - } diff --git a/core/src/main/scala/org/apache/spark/api/r/RRunner.scala b/core/src/main/scala/org/apache/spark/api/r/RRunner.scala new file mode 100644 index 0000000000..e8fcada453 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/api/r/RRunner.scala @@ -0,0 +1,367 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.api.r + +import java.io._ +import java.net.{InetAddress, ServerSocket} +import java.util.Arrays + +import scala.io.Source +import scala.util.Try + +import org.apache.spark._ +import org.apache.spark.broadcast.Broadcast +import org.apache.spark.util.Utils + +/** + * A helper class to run R UDFs in Spark. 
+ */ +private[spark] class RRunner[U]( + func: Array[Byte], + deserializer: String, + serializer: String, + packageNames: Array[Byte], + broadcastVars: Array[Broadcast[Object]], + numPartitions: Int = -1) + extends Logging { + private var bootTime: Double = _ + private var dataStream: DataInputStream = _ + val readData = numPartitions match { + case -1 => + serializer match { + case SerializationFormats.STRING => readStringData _ + case _ => readByteArrayData _ + } + case _ => readShuffledData _ + } + + def compute( + inputIterator: Iterator[_], + partitionIndex: Int, + context: TaskContext): Iterator[U] = { + // Timing start + bootTime = System.currentTimeMillis / 1000.0 + + // we expect two connections + val serverSocket = new ServerSocket(0, 2, InetAddress.getByName("localhost")) + val listenPort = serverSocket.getLocalPort() + + // The stdout/stderr is shared by multiple tasks, because we use one daemon + // to launch child process as worker. + val errThread = RRunner.createRWorker(listenPort) + + // We use two sockets to separate input and output, then it's easy to manage + // the lifecycle of them to avoid deadlock. + // TODO: optimize it to use one socket + + // the socket used to send out the input of task + serverSocket.setSoTimeout(10000) + val inSocket = serverSocket.accept() + startStdinThread(inSocket.getOutputStream(), inputIterator, partitionIndex) + + // the socket used to receive the output of task + val outSocket = serverSocket.accept() + val inputStream = new BufferedInputStream(outSocket.getInputStream) + dataStream = new DataInputStream(inputStream) + serverSocket.close() + + try { + return new Iterator[U] { + def next(): U = { + val obj = _nextObj + if (hasNext) { + _nextObj = read() + } + obj + } + + var _nextObj = read() + + def hasNext(): Boolean = { + val hasMore = (_nextObj != null) + if (!hasMore) { + dataStream.close() + } + hasMore + } + } + } catch { + case e: Exception => + throw new SparkException("R computation failed with\n " + errThread.getLines()) + } + } + + /** + * Start a thread to write RDD data to the R process. + */ + private def startStdinThread( + output: OutputStream, + iter: Iterator[_], + partitionIndex: Int): Unit = { + val env = SparkEnv.get + val taskContext = TaskContext.get() + val bufferSize = System.getProperty("spark.buffer.size", "65536").toInt + val stream = new BufferedOutputStream(output, bufferSize) + + new Thread("writer for R") { + override def run(): Unit = { + try { + SparkEnv.set(env) + TaskContext.setTaskContext(taskContext) + val dataOut = new DataOutputStream(stream) + dataOut.writeInt(partitionIndex) + + SerDe.writeString(dataOut, deserializer) + SerDe.writeString(dataOut, serializer) + + dataOut.writeInt(packageNames.length) + dataOut.write(packageNames) + + dataOut.writeInt(func.length) + dataOut.write(func) + + dataOut.writeInt(broadcastVars.length) + broadcastVars.foreach { broadcast => + // TODO(shivaram): Read a Long in R to avoid this cast + dataOut.writeInt(broadcast.id.toInt) + // TODO: Pass a byte array from R to avoid this cast ? 
+ val broadcastByteArr = broadcast.value.asInstanceOf[Array[Byte]] + dataOut.writeInt(broadcastByteArr.length) + dataOut.write(broadcastByteArr) + } + + dataOut.writeInt(numPartitions) + + if (!iter.hasNext) { + dataOut.writeInt(0) + } else { + dataOut.writeInt(1) + } + + val printOut = new PrintStream(stream) + + def writeElem(elem: Any): Unit = { + if (deserializer == SerializationFormats.BYTE) { + val elemArr = elem.asInstanceOf[Array[Byte]] + dataOut.writeInt(elemArr.length) + dataOut.write(elemArr) + } else if (deserializer == SerializationFormats.ROW) { + dataOut.write(elem.asInstanceOf[Array[Byte]]) + } else if (deserializer == SerializationFormats.STRING) { + // write string(for StringRRDD) + // scalastyle:off println + printOut.println(elem) + // scalastyle:on println + } + } + + for (elem <- iter) { + elem match { + case (key, value) => + writeElem(key) + writeElem(value) + case _ => + writeElem(elem) + } + } + stream.flush() + } catch { + // TODO: We should propogate this error to the task thread + case e: Exception => + logError("R Writer thread got an exception", e) + } finally { + Try(output.close()) + } + } + }.start() + } + + private def read(): U = { + try { + val length = dataStream.readInt() + + length match { + case SpecialLengths.TIMING_DATA => + // Timing data from R worker + val boot = dataStream.readDouble - bootTime + val init = dataStream.readDouble + val broadcast = dataStream.readDouble + val input = dataStream.readDouble + val compute = dataStream.readDouble + val output = dataStream.readDouble + logInfo( + ("Times: boot = %.3f s, init = %.3f s, broadcast = %.3f s, " + + "read-input = %.3f s, compute = %.3f s, write-output = %.3f s, " + + "total = %.3f s").format( + boot, + init, + broadcast, + input, + compute, + output, + boot + init + broadcast + input + compute + output)) + read() + case length if length >= 0 => + readData(length).asInstanceOf[U] + } + } catch { + case eof: EOFException => + throw new SparkException("R worker exited unexpectedly (cranshed)", eof) + } + } + + private def readShuffledData(length: Int): (Int, Array[Byte]) = { + length match { + case length if length == 2 => + val hashedKey = dataStream.readInt() + val contentPairsLength = dataStream.readInt() + val contentPairs = new Array[Byte](contentPairsLength) + dataStream.readFully(contentPairs) + (hashedKey, contentPairs) + case _ => null + } + } + + private def readByteArrayData(length: Int): Array[Byte] = { + length match { + case length if length > 0 => + val obj = new Array[Byte](length) + dataStream.readFully(obj) + obj + case _ => null + } + } + + private def readStringData(length: Int): String = { + length match { + case length if length > 0 => + SerDe.readStringBytes(dataStream, length) + case _ => null + } + } +} + +private object SpecialLengths { + val TIMING_DATA = -1 +} + +private[r] class BufferedStreamThread( + in: InputStream, + name: String, + errBufferSize: Int) extends Thread(name) with Logging { + val lines = new Array[String](errBufferSize) + var lineIdx = 0 + override def run() { + for (line <- Source.fromInputStream(in).getLines) { + synchronized { + lines(lineIdx) = line + lineIdx = (lineIdx + 1) % errBufferSize + } + logInfo(line) + } + } + + def getLines(): String = synchronized { + (0 until errBufferSize).filter { x => + lines((x + lineIdx) % errBufferSize) != null + }.map { x => + lines((x + lineIdx) % errBufferSize) + }.mkString("\n") + } +} + +private[r] object RRunner { + // Because forking processes from Java is expensive, we prefer to launch + // a single 
R daemon (daemon.R) and tell it to fork new workers for our tasks. + // This daemon currently only works on UNIX-based systems now, so we should + // also fall back to launching workers (worker.R) directly. + private[this] var errThread: BufferedStreamThread = _ + private[this] var daemonChannel: DataOutputStream = _ + + /** + * Start a thread to print the process's stderr to ours + */ + private def startStdoutThread(proc: Process): BufferedStreamThread = { + val BUFFER_SIZE = 100 + val thread = new BufferedStreamThread(proc.getInputStream, "stdout reader for R", BUFFER_SIZE) + thread.setDaemon(true) + thread.start() + thread + } + + private def createRProcess(port: Int, script: String): BufferedStreamThread = { + // "spark.sparkr.r.command" is deprecated and replaced by "spark.r.command", + // but kept here for backward compatibility. + val sparkConf = SparkEnv.get.conf + var rCommand = sparkConf.get("spark.sparkr.r.command", "Rscript") + rCommand = sparkConf.get("spark.r.command", rCommand) + + val rOptions = "--vanilla" + val rLibDir = RUtils.sparkRPackagePath(isDriver = false) + val rExecScript = rLibDir(0) + "/SparkR/worker/" + script + val pb = new ProcessBuilder(Arrays.asList(rCommand, rOptions, rExecScript)) + // Unset the R_TESTS environment variable for workers. + // This is set by R CMD check as startup.Rs + // (http://svn.r-project.org/R/trunk/src/library/tools/R/testing.R) + // and confuses worker script which tries to load a non-existent file + pb.environment().put("R_TESTS", "") + pb.environment().put("SPARKR_RLIBDIR", rLibDir.mkString(",")) + pb.environment().put("SPARKR_WORKER_PORT", port.toString) + pb.redirectErrorStream(true) // redirect stderr into stdout + val proc = pb.start() + val errThread = startStdoutThread(proc) + errThread + } + + /** + * ProcessBuilder used to launch worker R processes. + */ + def createRWorker(port: Int): BufferedStreamThread = { + val useDaemon = SparkEnv.get.conf.getBoolean("spark.sparkr.use.daemon", true) + if (!Utils.isWindows && useDaemon) { + synchronized { + if (daemonChannel == null) { + // we expect one connections + val serverSocket = new ServerSocket(0, 1, InetAddress.getByName("localhost")) + val daemonPort = serverSocket.getLocalPort + errThread = createRProcess(daemonPort, "daemon.R") + // the socket used to send out the input of task + serverSocket.setSoTimeout(10000) + val sock = serverSocket.accept() + daemonChannel = new DataOutputStream(new BufferedOutputStream(sock.getOutputStream)) + serverSocket.close() + } + try { + daemonChannel.writeInt(port) + daemonChannel.flush() + } catch { + case e: IOException => + // daemon process died + daemonChannel.close() + daemonChannel = null + errThread = null + // fail the current task, retry by scheduler + throw e + } + errThread + } + } else { + createRProcess(port, "worker.R") + } + } +} -- GitLab
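
For orientation, a minimal sketch (not part of the patch) of how the extracted
RRunner is meant to be driven, mirroring the call sequence the new
BaseRRDD.compute uses above. The wrapper name computeWithR and its parameter
list are illustrative only; RRunner and SerializationFormats are private[spark],
so a real caller has to live under the org.apache.spark package, as RRDD does.

    import org.apache.spark.TaskContext
    import org.apache.spark.api.r.{RRunner, SerializationFormats}
    import org.apache.spark.broadcast.Broadcast

    // Hypothetical helper: push one partition's data through the R worker.
    // func and packageNames are the serialized R closure and package list that
    // SparkR ships from the driver; broadcastVars are the serialized broadcasts.
    def computeWithR(
        func: Array[Byte],
        packageNames: Array[Byte],
        broadcastVars: Array[Broadcast[Object]],
        input: Iterator[Array[Byte]],
        partitionIndex: Int,
        context: TaskContext): Iterator[Array[Byte]] = {
      val runner = new RRunner[Array[Byte]](
        func,
        SerializationFormats.BYTE,  // format of data written to the R worker
        SerializationFormats.BYTE,  // format of data read back from the R worker
        packageNames,
        broadcastVars,
        numPartitions = -1)         // -1 selects the plain byte-array read path
      // compute() launches (or reuses) an R worker, streams `input` to it on one
      // socket, and returns an iterator over results read from a second socket.
      runner.compute(input, partitionIndex, context)
    }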