diff --git a/R/pkg/inst/tests/testthat/test_rdd.R b/R/pkg/inst/tests/testthat/test_rdd.R index b6c8e1dc6c1b7a665c0facc66d606626b917e3f8..3b0c16be5a754e32a2e31320fd52af7d050957c2 100644 --- a/R/pkg/inst/tests/testthat/test_rdd.R +++ b/R/pkg/inst/tests/testthat/test_rdd.R @@ -791,11 +791,3 @@ test_that("sampleByKey() on pairwise RDDs", { expect_equal(lookup(sample, 3)[which.min(lookup(sample, 3))] >= 0, TRUE) expect_equal(lookup(sample, 3)[which.max(lookup(sample, 3))] <= 2000, TRUE) }) - -test_that("Test correct concurrency of RRDD.compute()", { - rdd <- parallelize(sc, 1:1000, 100) - jrdd <- getJRDD(lapply(rdd, function(x) { x }), "row") - zrdd <- callJMethod(jrdd, "zip", jrdd) - count <- callJMethod(zrdd, "count") - expect_equal(count, 1000) -}) diff --git a/core/src/main/scala/org/apache/spark/api/r/RRDD.scala b/core/src/main/scala/org/apache/spark/api/r/RRDD.scala index 606ba6ef867a15fddf6d801ff47e7bd16369cab8..588a57e65f5534dd952e46cad8128c7bdf2b9d89 100644 --- a/core/src/main/scala/org/apache/spark/api/r/RRDD.scala +++ b/core/src/main/scala/org/apache/spark/api/r/RRDD.scala @@ -17,16 +17,21 @@ package org.apache.spark.api.r -import java.util.{Map => JMap} +import java.io._ +import java.net.{InetAddress, ServerSocket} +import java.util.{Arrays, Map => JMap} import scala.collection.JavaConverters._ +import scala.io.Source import scala.reflect.ClassTag +import scala.util.Try import org.apache.spark._ import org.apache.spark.api.java.{JavaPairRDD, JavaRDD, JavaSparkContext} import org.apache.spark.broadcast.Broadcast import org.apache.spark.internal.Logging import org.apache.spark.rdd.RDD +import org.apache.spark.util.Utils private abstract class BaseRRDD[T: ClassTag, U: ClassTag]( parent: RDD[T], @@ -37,16 +42,188 @@ private abstract class BaseRRDD[T: ClassTag, U: ClassTag]( packageNames: Array[Byte], broadcastVars: Array[Broadcast[Object]]) extends RDD[U](parent) with Logging { + protected var dataStream: DataInputStream = _ + private var bootTime: Double = _ override def getPartitions: Array[Partition] = parent.partitions override def compute(partition: Partition, context: TaskContext): Iterator[U] = { - val runner = new RRunner[U]( - func, deserializer, serializer, packageNames, broadcastVars, numPartitions) + + // Timing start + bootTime = System.currentTimeMillis / 1000.0 // The parent may be also an RRDD, so we should launch it first. val parentIterator = firstParent[T].iterator(partition, context) - runner.compute(parentIterator, partition.index, context) + // we expect two connections + val serverSocket = new ServerSocket(0, 2, InetAddress.getByName("localhost")) + val listenPort = serverSocket.getLocalPort() + + // The stdout/stderr is shared by multiple tasks, because we use one daemon + // to launch child process as worker. + val errThread = RRDD.createRWorker(listenPort) + + // We use two sockets to separate input and output, then it's easy to manage + // the lifecycle of them to avoid deadlock. + // TODO: optimize it to use one socket + + // the socket used to send out the input of task + serverSocket.setSoTimeout(10000) + val inSocket = serverSocket.accept() + startStdinThread(inSocket.getOutputStream(), parentIterator, partition.index) + + // the socket used to receive the output of task + val outSocket = serverSocket.accept() + val inputStream = new BufferedInputStream(outSocket.getInputStream) + dataStream = new DataInputStream(inputStream) + serverSocket.close() + + try { + + return new Iterator[U] { + def next(): U = { + val obj = _nextObj + if (hasNext) { + _nextObj = read() + } + obj + } + + var _nextObj = read() + + def hasNext(): Boolean = { + val hasMore = (_nextObj != null) + if (!hasMore) { + dataStream.close() + } + hasMore + } + } + } catch { + case e: Exception => + throw new SparkException("R computation failed with\n " + errThread.getLines()) + } + } + + /** + * Start a thread to write RDD data to the R process. + */ + private def startStdinThread[T]( + output: OutputStream, + iter: Iterator[T], + partition: Int): Unit = { + + val env = SparkEnv.get + val taskContext = TaskContext.get() + val bufferSize = System.getProperty("spark.buffer.size", "65536").toInt + val stream = new BufferedOutputStream(output, bufferSize) + + new Thread("writer for R") { + override def run(): Unit = { + try { + SparkEnv.set(env) + TaskContext.setTaskContext(taskContext) + val dataOut = new DataOutputStream(stream) + dataOut.writeInt(partition) + + SerDe.writeString(dataOut, deserializer) + SerDe.writeString(dataOut, serializer) + + dataOut.writeInt(packageNames.length) + dataOut.write(packageNames) + + dataOut.writeInt(func.length) + dataOut.write(func) + + dataOut.writeInt(broadcastVars.length) + broadcastVars.foreach { broadcast => + // TODO(shivaram): Read a Long in R to avoid this cast + dataOut.writeInt(broadcast.id.toInt) + // TODO: Pass a byte array from R to avoid this cast ? + val broadcastByteArr = broadcast.value.asInstanceOf[Array[Byte]] + dataOut.writeInt(broadcastByteArr.length) + dataOut.write(broadcastByteArr) + } + + dataOut.writeInt(numPartitions) + + if (!iter.hasNext) { + dataOut.writeInt(0) + } else { + dataOut.writeInt(1) + } + + val printOut = new PrintStream(stream) + + def writeElem(elem: Any): Unit = { + if (deserializer == SerializationFormats.BYTE) { + val elemArr = elem.asInstanceOf[Array[Byte]] + dataOut.writeInt(elemArr.length) + dataOut.write(elemArr) + } else if (deserializer == SerializationFormats.ROW) { + dataOut.write(elem.asInstanceOf[Array[Byte]]) + } else if (deserializer == SerializationFormats.STRING) { + // write string(for StringRRDD) + // scalastyle:off println + printOut.println(elem) + // scalastyle:on println + } + } + + for (elem <- iter) { + elem match { + case (key, value) => + writeElem(key) + writeElem(value) + case _ => + writeElem(elem) + } + } + stream.flush() + } catch { + // TODO: We should propogate this error to the task thread + case e: Exception => + logError("R Writer thread got an exception", e) + } finally { + Try(output.close()) + } + } + }.start() + } + + protected def readData(length: Int): U + + protected def read(): U = { + try { + val length = dataStream.readInt() + + length match { + case SpecialLengths.TIMING_DATA => + // Timing data from R worker + val boot = dataStream.readDouble - bootTime + val init = dataStream.readDouble + val broadcast = dataStream.readDouble + val input = dataStream.readDouble + val compute = dataStream.readDouble + val output = dataStream.readDouble + logInfo( + ("Times: boot = %.3f s, init = %.3f s, broadcast = %.3f s, " + + "read-input = %.3f s, compute = %.3f s, write-output = %.3f s, " + + "total = %.3f s").format( + boot, + init, + broadcast, + input, + compute, + output, + boot + init + broadcast + input + compute + output)) + read() + case length if length >= 0 => + readData(length) + } + } catch { + case eof: EOFException => + throw new SparkException("R worker exited unexpectedly (cranshed)", eof) + } } } @@ -65,6 +242,19 @@ private class PairwiseRRDD[T: ClassTag]( parent, numPartitions, hashFunc, deserializer, SerializationFormats.BYTE, packageNames, broadcastVars.map(x => x.asInstanceOf[Broadcast[Object]])) { + + override protected def readData(length: Int): (Int, Array[Byte]) = { + length match { + case length if length == 2 => + val hashedKey = dataStream.readInt() + val contentPairsLength = dataStream.readInt() + val contentPairs = new Array[Byte](contentPairsLength) + dataStream.readFully(contentPairs) + (hashedKey, contentPairs) + case _ => null + } + } + lazy val asJavaPairRDD : JavaPairRDD[Int, Array[Byte]] = JavaPairRDD.fromRDD(this) } @@ -81,6 +271,17 @@ private class RRDD[T: ClassTag]( extends BaseRRDD[T, Array[Byte]]( parent, -1, func, deserializer, serializer, packageNames, broadcastVars.map(x => x.asInstanceOf[Broadcast[Object]])) { + + override protected def readData(length: Int): Array[Byte] = { + length match { + case length if length > 0 => + val obj = new Array[Byte](length) + dataStream.readFully(obj) + obj + case _ => null + } + } + lazy val asJavaRDD : JavaRDD[Array[Byte]] = JavaRDD.fromRDD(this) } @@ -96,10 +297,55 @@ private class StringRRDD[T: ClassTag]( extends BaseRRDD[T, String]( parent, -1, func, deserializer, SerializationFormats.STRING, packageNames, broadcastVars.map(x => x.asInstanceOf[Broadcast[Object]])) { + + override protected def readData(length: Int): String = { + length match { + case length if length > 0 => + SerDe.readStringBytes(dataStream, length) + case _ => null + } + } + lazy val asJavaRDD : JavaRDD[String] = JavaRDD.fromRDD(this) } +private object SpecialLengths { + val TIMING_DATA = -1 +} + +private[r] class BufferedStreamThread( + in: InputStream, + name: String, + errBufferSize: Int) extends Thread(name) with Logging { + val lines = new Array[String](errBufferSize) + var lineIdx = 0 + override def run() { + for (line <- Source.fromInputStream(in).getLines) { + synchronized { + lines(lineIdx) = line + lineIdx = (lineIdx + 1) % errBufferSize + } + logInfo(line) + } + } + + def getLines(): String = synchronized { + (0 until errBufferSize).filter { x => + lines((x + lineIdx) % errBufferSize) != null + }.map { x => + lines((x + lineIdx) % errBufferSize) + }.mkString("\n") + } +} + private[r] object RRDD { + // Because forking processes from Java is expensive, we prefer to launch + // a single R daemon (daemon.R) and tell it to fork new workers for our tasks. + // This daemon currently only works on UNIX-based systems now, so we should + // also fall back to launching workers (worker.R) directly. + private[this] var errThread: BufferedStreamThread = _ + private[this] var daemonChannel: DataOutputStream = _ + def createSparkContext( master: String, appName: String, @@ -107,6 +353,7 @@ private[r] object RRDD { jars: Array[String], sparkEnvirMap: JMap[Object, Object], sparkExecutorEnvMap: JMap[Object, Object]): JavaSparkContext = { + val sparkConf = new SparkConf().setAppName(appName) .setSparkHome(sparkHome) @@ -133,6 +380,78 @@ private[r] object RRDD { jsc } + /** + * Start a thread to print the process's stderr to ours + */ + private def startStdoutThread(proc: Process): BufferedStreamThread = { + val BUFFER_SIZE = 100 + val thread = new BufferedStreamThread(proc.getInputStream, "stdout reader for R", BUFFER_SIZE) + thread.setDaemon(true) + thread.start() + thread + } + + private def createRProcess(port: Int, script: String): BufferedStreamThread = { + // "spark.sparkr.r.command" is deprecated and replaced by "spark.r.command", + // but kept here for backward compatibility. + val sparkConf = SparkEnv.get.conf + var rCommand = sparkConf.get("spark.sparkr.r.command", "Rscript") + rCommand = sparkConf.get("spark.r.command", rCommand) + + val rOptions = "--vanilla" + val rLibDir = RUtils.sparkRPackagePath(isDriver = false) + val rExecScript = rLibDir(0) + "/SparkR/worker/" + script + val pb = new ProcessBuilder(Arrays.asList(rCommand, rOptions, rExecScript)) + // Unset the R_TESTS environment variable for workers. + // This is set by R CMD check as startup.Rs + // (http://svn.r-project.org/R/trunk/src/library/tools/R/testing.R) + // and confuses worker script which tries to load a non-existent file + pb.environment().put("R_TESTS", "") + pb.environment().put("SPARKR_RLIBDIR", rLibDir.mkString(",")) + pb.environment().put("SPARKR_WORKER_PORT", port.toString) + pb.redirectErrorStream(true) // redirect stderr into stdout + val proc = pb.start() + val errThread = startStdoutThread(proc) + errThread + } + + /** + * ProcessBuilder used to launch worker R processes. + */ + def createRWorker(port: Int): BufferedStreamThread = { + val useDaemon = SparkEnv.get.conf.getBoolean("spark.sparkr.use.daemon", true) + if (!Utils.isWindows && useDaemon) { + synchronized { + if (daemonChannel == null) { + // we expect one connections + val serverSocket = new ServerSocket(0, 1, InetAddress.getByName("localhost")) + val daemonPort = serverSocket.getLocalPort + errThread = createRProcess(daemonPort, "daemon.R") + // the socket used to send out the input of task + serverSocket.setSoTimeout(10000) + val sock = serverSocket.accept() + daemonChannel = new DataOutputStream(new BufferedOutputStream(sock.getOutputStream)) + serverSocket.close() + } + try { + daemonChannel.writeInt(port) + daemonChannel.flush() + } catch { + case e: IOException => + // daemon process died + daemonChannel.close() + daemonChannel = null + errThread = null + // fail the current task, retry by scheduler + throw e + } + errThread + } + } else { + createRProcess(port, "worker.R") + } + } + /** * Create an RRDD given a sequence of byte arrays. Used to create RRDD when `parallelize` is * called from R. @@ -140,4 +459,5 @@ private[r] object RRDD { def createRDDFromArray(jsc: JavaSparkContext, arr: Array[Array[Byte]]): JavaRDD[Array[Byte]] = { JavaRDD.fromRDD(jsc.sc.parallelize(arr, arr.length)) } + } diff --git a/core/src/main/scala/org/apache/spark/api/r/RRunner.scala b/core/src/main/scala/org/apache/spark/api/r/RRunner.scala deleted file mode 100644 index e8fcada4532606687ea748219ccd498600f16985..0000000000000000000000000000000000000000 --- a/core/src/main/scala/org/apache/spark/api/r/RRunner.scala +++ /dev/null @@ -1,367 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.api.r - -import java.io._ -import java.net.{InetAddress, ServerSocket} -import java.util.Arrays - -import scala.io.Source -import scala.util.Try - -import org.apache.spark._ -import org.apache.spark.broadcast.Broadcast -import org.apache.spark.util.Utils - -/** - * A helper class to run R UDFs in Spark. - */ -private[spark] class RRunner[U]( - func: Array[Byte], - deserializer: String, - serializer: String, - packageNames: Array[Byte], - broadcastVars: Array[Broadcast[Object]], - numPartitions: Int = -1) - extends Logging { - private var bootTime: Double = _ - private var dataStream: DataInputStream = _ - val readData = numPartitions match { - case -1 => - serializer match { - case SerializationFormats.STRING => readStringData _ - case _ => readByteArrayData _ - } - case _ => readShuffledData _ - } - - def compute( - inputIterator: Iterator[_], - partitionIndex: Int, - context: TaskContext): Iterator[U] = { - // Timing start - bootTime = System.currentTimeMillis / 1000.0 - - // we expect two connections - val serverSocket = new ServerSocket(0, 2, InetAddress.getByName("localhost")) - val listenPort = serverSocket.getLocalPort() - - // The stdout/stderr is shared by multiple tasks, because we use one daemon - // to launch child process as worker. - val errThread = RRunner.createRWorker(listenPort) - - // We use two sockets to separate input and output, then it's easy to manage - // the lifecycle of them to avoid deadlock. - // TODO: optimize it to use one socket - - // the socket used to send out the input of task - serverSocket.setSoTimeout(10000) - val inSocket = serverSocket.accept() - startStdinThread(inSocket.getOutputStream(), inputIterator, partitionIndex) - - // the socket used to receive the output of task - val outSocket = serverSocket.accept() - val inputStream = new BufferedInputStream(outSocket.getInputStream) - dataStream = new DataInputStream(inputStream) - serverSocket.close() - - try { - return new Iterator[U] { - def next(): U = { - val obj = _nextObj - if (hasNext) { - _nextObj = read() - } - obj - } - - var _nextObj = read() - - def hasNext(): Boolean = { - val hasMore = (_nextObj != null) - if (!hasMore) { - dataStream.close() - } - hasMore - } - } - } catch { - case e: Exception => - throw new SparkException("R computation failed with\n " + errThread.getLines()) - } - } - - /** - * Start a thread to write RDD data to the R process. - */ - private def startStdinThread( - output: OutputStream, - iter: Iterator[_], - partitionIndex: Int): Unit = { - val env = SparkEnv.get - val taskContext = TaskContext.get() - val bufferSize = System.getProperty("spark.buffer.size", "65536").toInt - val stream = new BufferedOutputStream(output, bufferSize) - - new Thread("writer for R") { - override def run(): Unit = { - try { - SparkEnv.set(env) - TaskContext.setTaskContext(taskContext) - val dataOut = new DataOutputStream(stream) - dataOut.writeInt(partitionIndex) - - SerDe.writeString(dataOut, deserializer) - SerDe.writeString(dataOut, serializer) - - dataOut.writeInt(packageNames.length) - dataOut.write(packageNames) - - dataOut.writeInt(func.length) - dataOut.write(func) - - dataOut.writeInt(broadcastVars.length) - broadcastVars.foreach { broadcast => - // TODO(shivaram): Read a Long in R to avoid this cast - dataOut.writeInt(broadcast.id.toInt) - // TODO: Pass a byte array from R to avoid this cast ? - val broadcastByteArr = broadcast.value.asInstanceOf[Array[Byte]] - dataOut.writeInt(broadcastByteArr.length) - dataOut.write(broadcastByteArr) - } - - dataOut.writeInt(numPartitions) - - if (!iter.hasNext) { - dataOut.writeInt(0) - } else { - dataOut.writeInt(1) - } - - val printOut = new PrintStream(stream) - - def writeElem(elem: Any): Unit = { - if (deserializer == SerializationFormats.BYTE) { - val elemArr = elem.asInstanceOf[Array[Byte]] - dataOut.writeInt(elemArr.length) - dataOut.write(elemArr) - } else if (deserializer == SerializationFormats.ROW) { - dataOut.write(elem.asInstanceOf[Array[Byte]]) - } else if (deserializer == SerializationFormats.STRING) { - // write string(for StringRRDD) - // scalastyle:off println - printOut.println(elem) - // scalastyle:on println - } - } - - for (elem <- iter) { - elem match { - case (key, value) => - writeElem(key) - writeElem(value) - case _ => - writeElem(elem) - } - } - stream.flush() - } catch { - // TODO: We should propogate this error to the task thread - case e: Exception => - logError("R Writer thread got an exception", e) - } finally { - Try(output.close()) - } - } - }.start() - } - - private def read(): U = { - try { - val length = dataStream.readInt() - - length match { - case SpecialLengths.TIMING_DATA => - // Timing data from R worker - val boot = dataStream.readDouble - bootTime - val init = dataStream.readDouble - val broadcast = dataStream.readDouble - val input = dataStream.readDouble - val compute = dataStream.readDouble - val output = dataStream.readDouble - logInfo( - ("Times: boot = %.3f s, init = %.3f s, broadcast = %.3f s, " + - "read-input = %.3f s, compute = %.3f s, write-output = %.3f s, " + - "total = %.3f s").format( - boot, - init, - broadcast, - input, - compute, - output, - boot + init + broadcast + input + compute + output)) - read() - case length if length >= 0 => - readData(length).asInstanceOf[U] - } - } catch { - case eof: EOFException => - throw new SparkException("R worker exited unexpectedly (cranshed)", eof) - } - } - - private def readShuffledData(length: Int): (Int, Array[Byte]) = { - length match { - case length if length == 2 => - val hashedKey = dataStream.readInt() - val contentPairsLength = dataStream.readInt() - val contentPairs = new Array[Byte](contentPairsLength) - dataStream.readFully(contentPairs) - (hashedKey, contentPairs) - case _ => null - } - } - - private def readByteArrayData(length: Int): Array[Byte] = { - length match { - case length if length > 0 => - val obj = new Array[Byte](length) - dataStream.readFully(obj) - obj - case _ => null - } - } - - private def readStringData(length: Int): String = { - length match { - case length if length > 0 => - SerDe.readStringBytes(dataStream, length) - case _ => null - } - } -} - -private object SpecialLengths { - val TIMING_DATA = -1 -} - -private[r] class BufferedStreamThread( - in: InputStream, - name: String, - errBufferSize: Int) extends Thread(name) with Logging { - val lines = new Array[String](errBufferSize) - var lineIdx = 0 - override def run() { - for (line <- Source.fromInputStream(in).getLines) { - synchronized { - lines(lineIdx) = line - lineIdx = (lineIdx + 1) % errBufferSize - } - logInfo(line) - } - } - - def getLines(): String = synchronized { - (0 until errBufferSize).filter { x => - lines((x + lineIdx) % errBufferSize) != null - }.map { x => - lines((x + lineIdx) % errBufferSize) - }.mkString("\n") - } -} - -private[r] object RRunner { - // Because forking processes from Java is expensive, we prefer to launch - // a single R daemon (daemon.R) and tell it to fork new workers for our tasks. - // This daemon currently only works on UNIX-based systems now, so we should - // also fall back to launching workers (worker.R) directly. - private[this] var errThread: BufferedStreamThread = _ - private[this] var daemonChannel: DataOutputStream = _ - - /** - * Start a thread to print the process's stderr to ours - */ - private def startStdoutThread(proc: Process): BufferedStreamThread = { - val BUFFER_SIZE = 100 - val thread = new BufferedStreamThread(proc.getInputStream, "stdout reader for R", BUFFER_SIZE) - thread.setDaemon(true) - thread.start() - thread - } - - private def createRProcess(port: Int, script: String): BufferedStreamThread = { - // "spark.sparkr.r.command" is deprecated and replaced by "spark.r.command", - // but kept here for backward compatibility. - val sparkConf = SparkEnv.get.conf - var rCommand = sparkConf.get("spark.sparkr.r.command", "Rscript") - rCommand = sparkConf.get("spark.r.command", rCommand) - - val rOptions = "--vanilla" - val rLibDir = RUtils.sparkRPackagePath(isDriver = false) - val rExecScript = rLibDir(0) + "/SparkR/worker/" + script - val pb = new ProcessBuilder(Arrays.asList(rCommand, rOptions, rExecScript)) - // Unset the R_TESTS environment variable for workers. - // This is set by R CMD check as startup.Rs - // (http://svn.r-project.org/R/trunk/src/library/tools/R/testing.R) - // and confuses worker script which tries to load a non-existent file - pb.environment().put("R_TESTS", "") - pb.environment().put("SPARKR_RLIBDIR", rLibDir.mkString(",")) - pb.environment().put("SPARKR_WORKER_PORT", port.toString) - pb.redirectErrorStream(true) // redirect stderr into stdout - val proc = pb.start() - val errThread = startStdoutThread(proc) - errThread - } - - /** - * ProcessBuilder used to launch worker R processes. - */ - def createRWorker(port: Int): BufferedStreamThread = { - val useDaemon = SparkEnv.get.conf.getBoolean("spark.sparkr.use.daemon", true) - if (!Utils.isWindows && useDaemon) { - synchronized { - if (daemonChannel == null) { - // we expect one connections - val serverSocket = new ServerSocket(0, 1, InetAddress.getByName("localhost")) - val daemonPort = serverSocket.getLocalPort - errThread = createRProcess(daemonPort, "daemon.R") - // the socket used to send out the input of task - serverSocket.setSoTimeout(10000) - val sock = serverSocket.accept() - daemonChannel = new DataOutputStream(new BufferedOutputStream(sock.getOutputStream)) - serverSocket.close() - } - try { - daemonChannel.writeInt(port) - daemonChannel.flush() - } catch { - case e: IOException => - // daemon process died - daemonChannel.close() - daemonChannel = null - errThread = null - // fail the current task, retry by scheduler - throw e - } - errThread - } - } else { - createRProcess(port, "worker.R") - } - } -}