From 71a138cd0e0a14e8426f97877e3b52a562bbd02c Mon Sep 17 00:00:00 2001
From: Sun Rui <rui.sun@intel.com>
Date: Tue, 25 Aug 2015 13:14:10 -0700
Subject: [PATCH] [SPARK-10048] [SPARKR] Support arbitrary nested Java array in serde.

This PR:
1. supports transferring arbitrary nested arrays from the JVM to the R side in SerDe;
2. based on 1, improves the collect() implementation, which can now collect data of complex types from a DataFrame.

Author: Sun Rui <rui.sun@intel.com>

Closes #8276 from sun-rui/SPARK-10048.
---
 R/pkg/R/DataFrame.R                            | 55 +++++++++---
 R/pkg/R/deserialize.R                          | 72 +++++++---------
 R/pkg/R/serialize.R                            | 10 +--
 R/pkg/inst/tests/test_Serde.R                  | 77 +++++++++++++++++
 R/pkg/inst/worker/worker.R                     |  4 +-
 .../apache/spark/api/r/RBackendHandler.scala   |  7 ++
 .../scala/org/apache/spark/api/r/SerDe.scala   | 86 +++++++++++--------
 .../org/apache/spark/sql/api/r/SQLUtils.scala  | 32 +------
 8 files changed, 216 insertions(+), 127 deletions(-)
 create mode 100644 R/pkg/inst/tests/test_Serde.R

diff --git a/R/pkg/R/DataFrame.R b/R/pkg/R/DataFrame.R
index 10f3c4ea59..ae1d912cf6 100644
--- a/R/pkg/R/DataFrame.R
+++ b/R/pkg/R/DataFrame.R
@@ -652,18 +652,49 @@ setMethod("dim",
 setMethod("collect",
           signature(x = "DataFrame"),
           function(x, stringsAsFactors = FALSE) {
-            # listCols is a list of raw vectors, one per column
-            listCols <- callJStatic("org.apache.spark.sql.api.r.SQLUtils", "dfToCols", x@sdf)
-            cols <- lapply(listCols, function(col) {
-              objRaw <- rawConnection(col)
-              numRows <- readInt(objRaw)
-              col <- readCol(objRaw, numRows)
-              close(objRaw)
-              col
-            })
-            names(cols) <- columns(x)
-            do.call(cbind.data.frame, list(cols, stringsAsFactors = stringsAsFactors))
-          })
+            names <- columns(x)
+            ncol <- length(names)
+            if (ncol <= 0) {
+              # empty data.frame with 0 columns and 0 rows
+              data.frame()
+            } else {
+              # listCols is a list of columns
+              listCols <- callJStatic("org.apache.spark.sql.api.r.SQLUtils", "dfToCols", x@sdf)
+              stopifnot(length(listCols) == ncol)
+
+              # An empty data.frame with 0 columns and number of rows as collected
+              nrow <- length(listCols[[1]])
+              if (nrow <= 0) {
+                df <- data.frame()
+              } else {
+                df <- data.frame(row.names = 1 : nrow)
+              }
+
+              # Append columns one by one
+              for (colIndex in 1 : ncol) {
+                # Note: a column is appended to the data.frame as a list so
+                # that data of complex type can be held. But getting a cell
+                # from a column of list type returns a list instead of a vector,
+                # so columns of non-complex type are appended as vectors.
+                col <- listCols[[colIndex]]
+                if (length(col) <= 0) {
+                  df[[names[colIndex]]] <- col
+                } else {
+                  # TODO: more robust check on column of primitive types
+                  vec <- do.call(c, col)
+                  if (class(vec) != "list") {
+                    df[[names[colIndex]]] <- vec
+                  } else {
+                    # Columns of complex type need careful access: getting such
+                    # a column returns a list, and getting a cell from it
+                    # returns a list instead of a vector.
+                    df[[names[colIndex]]] <- col
+                  }
+                }
+              }
+              df
+            }
+          })
 
 #' Limit
 #'
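
The vector-versus-list branch above hinges on how c() flattens a list: elements of a common primitive type collapse into an atomic vector, while nested lists stay lists. A minimal standalone sketch of that behavior in plain R (no Spark required; the variable names are illustrative):

    primCol <- list(1L, 2L, 3L)            # column of a primitive type
    cplxCol <- list(list(1, 2), list(3))   # column of a complex (nested) type

    class(do.call(c, primCol))   # "integer" -> appended to the data.frame as a vector
    class(do.call(c, cplxCol))   # "list"    -> the column is kept as a list column
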
diff --git a/R/pkg/R/deserialize.R b/R/pkg/R/deserialize.R
index 33bf13ec9e..6cf628e300 100644
--- a/R/pkg/R/deserialize.R
+++ b/R/pkg/R/deserialize.R
@@ -48,6 +48,7 @@ readTypedObject <- function(con, type) {
     "r" = readRaw(con),
     "D" = readDate(con),
     "t" = readTime(con),
+    "a" = readArray(con),
     "l" = readList(con),
     "n" = NULL,
     "j" = getJobj(readString(con)),
@@ -85,8 +86,7 @@ readTime <- function(con) {
   as.POSIXct(t, origin = "1970-01-01")
 }
 
-# We only support lists where all elements are of same type
-readList <- function(con) {
+readArray <- function(con) {
   type <- readType(con)
   len <- readInt(con)
   if (len > 0) {
@@ -100,6 +100,25 @@ readList <- function(con) {
   }
 }
 
+# Read a list. The types of the elements may differ;
+# null objects are read as NA.
+readList <- function(con) {
+  len <- readInt(con)
+  if (len > 0) {
+    l <- vector("list", len)
+    for (i in 1:len) {
+      elem <- readObject(con)
+      if (is.null(elem)) {
+        elem <- NA
+      }
+      l[[i]] <- elem
+    }
+    l
+  } else {
+    list()
+  }
+}
+
 readRaw <- function(con) {
   dataLen <- readInt(con)
   readBin(con, raw(), as.integer(dataLen), endian = "big")
@@ -132,18 +151,19 @@ readDeserialize <- function(con) {
   }
 }
 
-readDeserializeRows <- function(inputCon) {
-  # readDeserializeRows will deserialize a DataOutputStream composed of
-  # a list of lists. Since the DOS is one continuous stream and
-  # the number of rows varies, we put the readRow function in a while loop
-  # that termintates when the next row is empty.
+readMultipleObjects <- function(inputCon) {
+  # readMultipleObjects reads multiple consecutive objects from
+  # a DataOutputStream. There is no preceding field giving the
+  # count of the objects, so we keep reading objects in a loop
+  # until the end of the stream is reached.
   data <- list()
   while(TRUE) {
-    row <- readRow(inputCon)
-    if (length(row) == 0) {
+    # On reaching the end of the stream, readType() returns "".
+    type <- readType(inputCon)
+    if (type == "") {
       break
     }
-    data[[length(data) + 1L]] <- row
+    data[[length(data) + 1L]] <- readTypedObject(inputCon, type)
   }
   data # this is a list of named lists now
 }
@@ -155,35 +175,5 @@ readRowList <- function(obj) {
   # deserialize the row.
   rawObj <- rawConnection(obj, "r+")
   on.exit(close(rawObj))
-  readRow(rawObj)
-}
-
-readRow <- function(inputCon) {
-  numCols <- readInt(inputCon)
-  if (length(numCols) > 0 && numCols > 0) {
-    lapply(1:numCols, function(x) {
-      obj <- readObject(inputCon)
-      if (is.null(obj)) {
-        NA
-      } else {
-        obj
-      }
-    }) # each row is a list now
-  } else {
-    list()
-  }
-}
-
-# Take a single column as Array[Byte] and deserialize it into an atomic vector
-readCol <- function(inputCon, numRows) {
-  if (numRows > 0) {
-    # sapply can not work with POSIXlt
-    do.call(c, lapply(1:numRows, function(x) {
-      value <- readObject(inputCon)
-      # Replace NULL with NA so we can coerce to vectors
-      if (is.null(value)) NA else value
-    }))
-  } else {
-    vector()
-  }
+  readObject(rawObj)
 }
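
The new readList pairs with a simple wire format: a 4-byte big-endian element count followed by one typed object per element, where each object is a one-byte type tag plus its payload. A standalone sketch of that framing using base R connections (a simplification of SparkR's SerDe; the "i" integer tag is assumed here for illustration):

    # Write two integers with a length prefix and per-element type tags.
    buf <- rawConnection(raw(0), "wb")
    writeBin(2L, buf, endian = "big")        # element count
    for (v in c(42L, 7L)) {
      writeChar("i", buf, eos = NULL)        # type tag: integer (assumed)
      writeBin(v, buf, endian = "big")       # 4-byte big-endian payload
    }
    bytes <- rawConnectionValue(buf)
    close(buf)

    # Read the list back with the same framing.
    con <- rawConnection(bytes, "rb")
    n <- readBin(con, integer(), 1, endian = "big")
    vals <- lapply(seq_len(n), function(i) {
      stopifnot(readChar(con, 1) == "i")
      readBin(con, integer(), 1, endian = "big")
    })
    close(con)
    # vals is now list(42L, 7L)
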
diff --git a/R/pkg/R/serialize.R b/R/pkg/R/serialize.R
index 311021e5d8..e3676f57f9 100644
--- a/R/pkg/R/serialize.R
+++ b/R/pkg/R/serialize.R
@@ -110,18 +110,10 @@ writeRowSerialize <- function(outputCon, rows) {
 serializeRow <- function(row) {
   rawObj <- rawConnection(raw(0), "wb")
   on.exit(close(rawObj))
-  writeRow(rawObj, row)
+  writeGenericList(rawObj, row)
   rawConnectionValue(rawObj)
 }
 
-writeRow <- function(con, row) {
-  numCols <- length(row)
-  writeInt(con, numCols)
-  for (i in 1:numCols) {
-    writeObject(con, row[[i]])
-  }
-}
-
 writeRaw <- function(con, batch) {
   writeInt(con, length(batch))
   writeBin(batch, con, endian = "big")
diff --git a/R/pkg/inst/tests/test_Serde.R b/R/pkg/inst/tests/test_Serde.R
new file mode 100644
index 0000000000..009db85da2
--- /dev/null
+++ b/R/pkg/inst/tests/test_Serde.R
@@ -0,0 +1,77 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements.  See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License.  You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+context("SerDe functionality")
+
+sc <- sparkR.init()
+
+test_that("SerDe of primitive types", {
+  x <- callJStatic("SparkRHandler", "echo", 1L)
+  expect_equal(x, 1L)
+  expect_equal(class(x), "integer")
+
+  x <- callJStatic("SparkRHandler", "echo", 1)
+  expect_equal(x, 1)
+  expect_equal(class(x), "numeric")
+
+  x <- callJStatic("SparkRHandler", "echo", TRUE)
+  expect_true(x)
+  expect_equal(class(x), "logical")
+
+  x <- callJStatic("SparkRHandler", "echo", "abc")
+  expect_equal(x, "abc")
+  expect_equal(class(x), "character")
+})
+
+test_that("SerDe of list of primitive types", {
+  x <- list(1L, 2L, 3L)
+  y <- callJStatic("SparkRHandler", "echo", x)
+  expect_equal(x, y)
+  expect_equal(class(y[[1]]), "integer")
+
+  x <- list(1, 2, 3)
+  y <- callJStatic("SparkRHandler", "echo", x)
+  expect_equal(x, y)
+  expect_equal(class(y[[1]]), "numeric")
+
+  x <- list(TRUE, FALSE)
+  y <- callJStatic("SparkRHandler", "echo", x)
+  expect_equal(x, y)
+  expect_equal(class(y[[1]]), "logical")
+
+  x <- list("a", "b", "c")
+  y <- callJStatic("SparkRHandler", "echo", x)
+  expect_equal(x, y)
+  expect_equal(class(y[[1]]), "character")
+
+  # Empty list
+  x <- list()
+  y <- callJStatic("SparkRHandler", "echo", x)
+  expect_equal(x, y)
+})
+
+test_that("SerDe of list of lists", {
+  x <- list(list(1L, 2L, 3L), list(1, 2, 3),
+            list(TRUE, FALSE), list("a", "b", "c"))
+  y <- callJStatic("SparkRHandler", "echo", x)
+  expect_equal(x, y)
+
+  # List of empty lists
+  x <- list(list(), list())
+  y <- callJStatic("SparkRHandler", "echo", x)
+  expect_equal(x, y)
+})
diff --git a/R/pkg/inst/worker/worker.R b/R/pkg/inst/worker/worker.R
index 7e3b5fc403..0c3b0d1f4b 100644
--- a/R/pkg/inst/worker/worker.R
+++ b/R/pkg/inst/worker/worker.R
@@ -94,7 +94,7 @@ if (isEmpty != 0) {
     } else if (deserializer == "string") {
       data <- as.list(readLines(inputCon))
     } else if (deserializer == "row") {
-      data <- SparkR:::readDeserializeRows(inputCon)
+      data <- SparkR:::readMultipleObjects(inputCon)
     }
     # Timing reading input data for execution
     inputElap <- elapsedSecs()
@@ -120,7 +120,7 @@ if (isEmpty != 0) {
   } else if (deserializer == "string") {
     data <- readLines(inputCon)
   } else if (deserializer == "row") {
-    data <- SparkR:::readDeserializeRows(inputCon)
+    data <- SparkR:::readMultipleObjects(inputCon)
   }
   # Timing reading input data for execution
   inputElap <- elapsedSecs()
diff --git a/core/src/main/scala/org/apache/spark/api/r/RBackendHandler.scala b/core/src/main/scala/org/apache/spark/api/r/RBackendHandler.scala
index 6ce02e2ea3..bb82f3285f 100644
--- a/core/src/main/scala/org/apache/spark/api/r/RBackendHandler.scala
+++ b/core/src/main/scala/org/apache/spark/api/r/RBackendHandler.scala
@@ -53,6 +53,13 @@ private[r] class RBackendHandler(server: RBackend)
 
     if (objId == "SparkRHandler") {
       methodName match {
+        // This method is for testing purposes only.
+        case "echo" =>
+          val args = readArgs(numArgs, dis)
+          assert(numArgs == 1)
+
+          writeInt(dos, 0)
+          writeObject(dos, args(0))
         case "stopBackend" =>
           writeInt(dos, 0)
           writeType(dos, "void")
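
Because writeObject on the JVM side and readList on the R side both recurse into each element, nesting deeper than the tests above exercise should round-trip the same way. For example, from an R session with a running backend (callJStatic is internal, hence the ::: access; "echo" is the test-only method added above):

    x <- list(list(list(1L, 2L), list("a")), list())
    y <- SparkR:::callJStatic("SparkRHandler", "echo", x)
    identical(x, y)   # expected TRUE
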
diff --git a/core/src/main/scala/org/apache/spark/api/r/SerDe.scala b/core/src/main/scala/org/apache/spark/api/r/SerDe.scala
index dbbbcf40c1..26ad4f1d46 100644
--- a/core/src/main/scala/org/apache/spark/api/r/SerDe.scala
+++ b/core/src/main/scala/org/apache/spark/api/r/SerDe.scala
@@ -149,6 +149,10 @@ private[spark] object SerDe {
       case 'b' => readBooleanArr(dis)
       case 'j' => readStringArr(dis).map(x => JVMObjectTracker.getObject(x))
       case 'r' => readBytesArr(dis)
+      case 'l' => {
+        val len = readInt(dis)
+        (0 until len).map(_ => readList(dis)).toArray
+      }
       case _ => throw new IllegalArgumentException(s"Invalid array type $arrType")
     }
   }
@@ -200,6 +204,9 @@ private[spark] object SerDe {
       case "date" => dos.writeByte('D')
       case "time" => dos.writeByte('t')
       case "raw" => dos.writeByte('r')
+      // Array of primitive types
+      case "array" => dos.writeByte('a')
+      // Array of objects
       case "list" => dos.writeByte('l')
      case "jobj" => dos.writeByte('j')
       case _ => throw new IllegalArgumentException(s"Invalid type $typeStr")
@@ -211,26 +218,35 @@ private[spark] object SerDe {
       writeType(dos, "void")
     } else {
       value.getClass.getName match {
+        case "java.lang.Character" =>
+          writeType(dos, "character")
+          writeString(dos, value.asInstanceOf[Character].toString)
         case "java.lang.String" =>
           writeType(dos, "character")
           writeString(dos, value.asInstanceOf[String])
-        case "long" | "java.lang.Long" =>
+        case "java.lang.Long" =>
           writeType(dos, "double")
           writeDouble(dos, value.asInstanceOf[Long].toDouble)
-        case "float" | "java.lang.Float" =>
+        case "java.lang.Float" =>
           writeType(dos, "double")
           writeDouble(dos, value.asInstanceOf[Float].toDouble)
-        case "decimal" | "java.math.BigDecimal" =>
+        case "java.math.BigDecimal" =>
           writeType(dos, "double")
           val javaDecimal = value.asInstanceOf[java.math.BigDecimal]
           writeDouble(dos, scala.math.BigDecimal(javaDecimal).toDouble)
-        case "double" | "java.lang.Double" =>
+        case "java.lang.Double" =>
           writeType(dos, "double")
           writeDouble(dos, value.asInstanceOf[Double])
-        case "int" | "java.lang.Integer" =>
+        case "java.lang.Byte" =>
+          writeType(dos, "integer")
+          writeInt(dos, value.asInstanceOf[Byte].toInt)
+        case "java.lang.Short" =>
+          writeType(dos, "integer")
+          writeInt(dos, value.asInstanceOf[Short].toInt)
+        case "java.lang.Integer" =>
           writeType(dos, "integer")
           writeInt(dos, value.asInstanceOf[Int])
-        case "boolean" | "java.lang.Boolean" =>
+        case "java.lang.Boolean" =>
           writeType(dos, "logical")
           writeBoolean(dos, value.asInstanceOf[Boolean])
         case "java.sql.Date" =>
@@ -242,43 +258,48 @@ private[spark] object SerDe {
         case "java.sql.Timestamp" =>
           writeType(dos, "time")
           writeTime(dos, value.asInstanceOf[Timestamp])
+
+        // Handle arrays
+
+        // Array of primitive types
+
+        // Special handling for byte array
         case "[B" =>
           writeType(dos, "raw")
           writeBytes(dos, value.asInstanceOf[Array[Byte]])
-        // TODO: Types not handled right now include
-        // byte, char, short, float
-        // Handle arrays
-        case "[Ljava.lang.String;" =>
-          writeType(dos, "list")
-          writeStringArr(dos, value.asInstanceOf[Array[String]])
+        case "[C" =>
+          writeType(dos, "array")
+          writeStringArr(dos, value.asInstanceOf[Array[Char]].map(_.toString))
+        case "[S" =>
+          writeType(dos, "array")
+          writeIntArr(dos, value.asInstanceOf[Array[Short]].map(_.toInt))
         case "[I" =>
-          writeType(dos, "list")
+          writeType(dos, "array")
           writeIntArr(dos, value.asInstanceOf[Array[Int]])
         case "[J" =>
-          writeType(dos, "list")
+          writeType(dos, "array")
           writeDoubleArr(dos, value.asInstanceOf[Array[Long]].map(_.toDouble))
+        case "[F" =>
+          writeType(dos, "array")
+          writeDoubleArr(dos, value.asInstanceOf[Array[Float]].map(_.toDouble))
         case "[D" =>
-          writeType(dos, "list")
+          writeType(dos, "array")
          writeDoubleArr(dos, value.asInstanceOf[Array[Double]])
         case "[Z" =>
-          writeType(dos, "list")
+          writeType(dos, "array")
           writeBooleanArr(dos, value.asInstanceOf[Array[Boolean]])
-        case "[[B" =>
+
+        // Array of objects; null objects use the "void" type
+        case c if c.startsWith("[") =>
           writeType(dos, "list")
-          writeBytesArr(dos, value.asInstanceOf[Array[Array[Byte]]])
-        case otherName =>
-          // Handle array of objects
-          if (otherName.startsWith("[L")) {
-            val objArr = value.asInstanceOf[Array[Object]]
-            writeType(dos, "list")
-            writeType(dos, "jobj")
-            dos.writeInt(objArr.length)
-            objArr.foreach(o => writeJObj(dos, o))
-          } else {
-            writeType(dos, "jobj")
-            writeJObj(dos, value)
-          }
+          val array = value.asInstanceOf[Array[Object]]
+          writeInt(dos, array.length)
+          array.foreach(elem => writeObject(dos, elem))
+
+        case _ =>
+          writeType(dos, "jobj")
+          writeJObj(dos, value)
       }
     }
   }
@@ -350,11 +371,6 @@ private[spark] object SerDe {
     value.foreach(v => writeString(out, v))
   }
 
-  def writeBytesArr(out: DataOutputStream, value: Array[Array[Byte]]): Unit = {
-    writeType(out, "raw")
-    out.writeInt(value.length)
-    value.foreach(v => writeBytes(out, v))
-  }
 }
 
 private[r] object SerializationFormats {
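
Note the numeric widening above: R has no 64-bit integer type, so JVM longs (as well as floats and BigDecimals) are delivered to R as doubles. Long values beyond 2^53 therefore lose precision, which a quick standalone check in R makes visible:

    x <- 2^53
    (x + 1) == x   # TRUE: 2^53 + 1 is not representable as a double
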
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/api/r/SQLUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/api/r/SQLUtils.scala
index 92861ab038..7f3defec3d 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/api/r/SQLUtils.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/api/r/SQLUtils.scala
@@ -98,27 +98,17 @@ private[r] object SQLUtils {
     val bos = new ByteArrayOutputStream()
     val dos = new DataOutputStream(bos)
 
-    SerDe.writeInt(dos, row.length)
-    (0 until row.length).map { idx =>
-      val obj: Object = row(idx).asInstanceOf[Object]
-      SerDe.writeObject(dos, obj)
-    }
+    val cols = (0 until row.length).map(row(_).asInstanceOf[Object]).toArray
+    SerDe.writeObject(dos, cols)
     bos.toByteArray()
   }
 
-  def dfToCols(df: DataFrame): Array[Array[Byte]] = {
+  def dfToCols(df: DataFrame): Array[Array[Any]] = {
     // localDF is Array[Row]
     val localDF = df.collect()
     val numCols = df.columns.length
-    // dfCols is Array[Array[Any]]
-    val dfCols = convertRowsToColumns(localDF, numCols)
-
-    dfCols.map { col =>
-      colToRBytes(col)
-    }
-  }
 
-  def convertRowsToColumns(localDF: Array[Row], numCols: Int): Array[Array[Any]] = {
+    // result is Array[Array[Any]]
     (0 until numCols).map { colIdx =>
       localDF.map { row =>
         row(colIdx)
@@ -126,20 +116,6 @@ private[r] object SQLUtils {
     }.toArray
   }
 
-  def colToRBytes(col: Array[Any]): Array[Byte] = {
-    val numRows = col.length
-    val bos = new ByteArrayOutputStream()
-    val dos = new DataOutputStream(bos)
-
-    SerDe.writeInt(dos, numRows)
-
-    col.map { item =>
-      val obj: Object = item.asInstanceOf[Object]
-      SerDe.writeObject(dos, obj)
-    }
-    bos.toByteArray()
-  }
-
   def saveMode(mode: String): SaveMode = {
     mode match {
       case "append" => SaveMode.Append
--
GitLab