From 59e206deb7346148412bbf5ba4ab626718fadf18 Mon Sep 17 00:00:00 2001 From: cafreeman <cfreeman@alteryx.com> Date: Fri, 17 Apr 2015 13:42:19 -0700 Subject: [PATCH] [SPARK-6807] [SparkR] Merge recent SparkR-pkg changes This PR pulls in recent changes in SparkR-pkg, including cartesian, intersection, sampleByKey, subtract, subtractByKey, except, and some API for StructType and StructField. Author: cafreeman <cfreeman@alteryx.com> Author: Davies Liu <davies@databricks.com> Author: Zongheng Yang <zongheng.y@gmail.com> Author: Shivaram Venkataraman <shivaram.venkataraman@gmail.com> Author: Shivaram Venkataraman <shivaram@cs.berkeley.edu> Author: Sun Rui <rui.sun@intel.com> Closes #5436 from davies/R3 and squashes the following commits: c2b09be [Davies Liu] SQLTypes -> schema a5a02f2 [Davies Liu] Merge branch 'master' of github.com:apache/spark into R3 168b7fe [Davies Liu] sort generics b1fe460 [Davies Liu] fix conflict in README.md e74c04e [Davies Liu] fix schema.R 4f5ac09 [Davies Liu] Merge branch 'master' of github.com:apache/spark into R5 41f8184 [Davies Liu] rm man ae78312 [Davies Liu] Merge pull request #237 from sun-rui/SPARKR-154_3 1bdcb63 [Zongheng Yang] Updates to README.md. 5a553e7 [cafreeman] Use object attribute instead of argument 71372d9 [cafreeman] Update docs and examples 8526d2e71 [cafreeman] Remove `tojson` functions 6ef5f2d [cafreeman] Fix spacing 7741d66 [cafreeman] Rename the SQL DataType function 141efd8 [Shivaram Venkataraman] Merge pull request #245 from hqzizania/upstream 9387402 [Davies Liu] fix style 40199eb [Shivaram Venkataraman] Move except into sorted position 07d0dbc [Sun Rui] [SPARKR-244] Fix test failure after integration of subtract() and subtractByKey() for RDD. 7e8caa3 [Shivaram Venkataraman] Merge pull request #246 from hlin09/fixCombineByKey ed66c81 [cafreeman] Update `subtract` to work with `generics.R` f3ba785 [cafreeman] Fixed duplicate export 275deb4 [cafreeman] Update `NAMESPACE` and tests 1a3b63d [cafreeman] new version of `CreateDF` 836c4bf [cafreeman] Update `createDataFrame` and `toDF` be5d5c1 [cafreeman] refactor schema functions 40338a4 [Zongheng Yang] Merge pull request #244 from sun-rui/SPARKR-154_5 20b97a6 [Zongheng Yang] Merge pull request #234 from hqzizania/assist ba54e34 [Shivaram Venkataraman] Merge pull request #238 from sun-rui/SPARKR-154_4 c9497a3 [Shivaram Venkataraman] Merge pull request #208 from lythesia/master b317aa7 [Zongheng Yang] Merge pull request #243 from hqzizania/master 136a07e [Zongheng Yang] Merge pull request #242 from hqzizania/stats cd66603 [cafreeman] new line at EOF 8b76e81 [Shivaram Venkataraman] Merge pull request #233 from redbaron/fail-early-on-missing-dep 7dd81b7 [cafreeman] Documentation 0e2a94f [cafreeman] Define functions for schema and fields --- R/pkg/DESCRIPTION | 2 +- R/pkg/NAMESPACE | 20 +- R/pkg/R/DataFrame.R | 18 +- R/pkg/R/RDD.R | 205 ++++++++++++------ R/pkg/R/SQLContext.R | 44 +--- R/pkg/R/SQLTypes.R | 64 ------ R/pkg/R/column.R | 2 +- R/pkg/R/generics.R | 46 +++- R/pkg/R/group.R | 2 +- R/pkg/R/pairRDD.R | 192 +++++++++++++--- R/pkg/R/schema.R | 162 ++++++++++++++ R/pkg/R/serialize.R | 9 +- R/pkg/R/utils.R | 80 +++++++ R/pkg/inst/tests/test_rdd.R | 193 ++++++++++++++--- R/pkg/inst/tests/test_shuffle.R | 12 + R/pkg/inst/tests/test_sparkSQL.R | 35 +-- R/pkg/inst/worker/worker.R | 59 ++++- .../scala/org/apache/spark/api/r/RRDD.scala | 131 +++++------ .../scala/org/apache/spark/api/r/SerDe.scala | 14 +- .../org/apache/spark/sql/api/r/SQLUtils.scala | 32 ++- 20 files changed, 971 insertions(+), 351 deletions(-) delete mode 100644 R/pkg/R/SQLTypes.R create mode 100644 R/pkg/R/schema.R diff --git a/R/pkg/DESCRIPTION b/R/pkg/DESCRIPTION index 052f68c6c2..1c1779a763 100644 --- a/R/pkg/DESCRIPTION +++ b/R/pkg/DESCRIPTION @@ -19,7 +19,7 @@ Collate: 'jobj.R' 'RDD.R' 'pairRDD.R' - 'SQLTypes.R' + 'schema.R' 'column.R' 'group.R' 'DataFrame.R' diff --git a/R/pkg/NAMESPACE b/R/pkg/NAMESPACE index a354cdce74..8028364386 100644 --- a/R/pkg/NAMESPACE +++ b/R/pkg/NAMESPACE @@ -5,6 +5,7 @@ exportMethods( "aggregateByKey", "aggregateRDD", "cache", + "cartesian", "checkpoint", "coalesce", "cogroup", @@ -28,6 +29,7 @@ exportMethods( "fullOuterJoin", "glom", "groupByKey", + "intersection", "join", "keyBy", "keys", @@ -52,11 +54,14 @@ exportMethods( "reduceByKeyLocally", "repartition", "rightOuterJoin", + "sampleByKey", "sampleRDD", "saveAsTextFile", "saveAsObjectFile", "sortBy", "sortByKey", + "subtract", + "subtractByKey", "sumRDD", "take", "takeOrdered", @@ -95,6 +100,7 @@ exportClasses("DataFrame") exportMethods("columns", "distinct", "dtypes", + "except", "explain", "filter", "groupBy", @@ -118,7 +124,6 @@ exportMethods("columns", "show", "showDF", "sortDF", - "subtract", "toJSON", "toRDD", "unionAll", @@ -178,5 +183,14 @@ export("cacheTable", "toDF", "uncacheTable") -export("print.structType", - "print.structField") +export("sparkRSQL.init", + "sparkRHive.init") + +export("structField", + "structField.jobj", + "structField.character", + "print.structField", + "structType", + "structType.jobj", + "structType.structField", + "print.structType") diff --git a/R/pkg/R/DataFrame.R b/R/pkg/R/DataFrame.R index 044fdb4d01..861fe1c78b 100644 --- a/R/pkg/R/DataFrame.R +++ b/R/pkg/R/DataFrame.R @@ -17,7 +17,7 @@ # DataFrame.R - DataFrame class and methods implemented in S4 OO classes -#' @include generics.R jobj.R SQLTypes.R RDD.R pairRDD.R column.R group.R +#' @include generics.R jobj.R schema.R RDD.R pairRDD.R column.R group.R NULL setOldClass("jobj") @@ -1141,15 +1141,15 @@ setMethod("intersect", dataFrame(intersected) }) -#' Subtract +#' except #' #' Return a new DataFrame containing rows in this DataFrame #' but not in another DataFrame. This is equivalent to `EXCEPT` in SQL. #' #' @param x A Spark DataFrame #' @param y A Spark DataFrame -#' @return A DataFrame containing the result of the subtract operation. -#' @rdname subtract +#' @return A DataFrame containing the result of the except operation. +#' @rdname except #' @export #' @examples #'\dontrun{ @@ -1157,13 +1157,15 @@ setMethod("intersect", #' sqlCtx <- sparkRSQL.init(sc) #' df1 <- jsonFile(sqlCtx, path) #' df2 <- jsonFile(sqlCtx, path2) -#' subtractDF <- subtract(df, df2) +#' exceptDF <- except(df, df2) #' } -setMethod("subtract", +#' @rdname except +#' @export +setMethod("except", signature(x = "DataFrame", y = "DataFrame"), function(x, y) { - subtracted <- callJMethod(x@sdf, "except", y@sdf) - dataFrame(subtracted) + excepted <- callJMethod(x@sdf, "except", y@sdf) + dataFrame(excepted) }) #' Save the contents of the DataFrame to a data source diff --git a/R/pkg/R/RDD.R b/R/pkg/R/RDD.R index 820027ef67..128431334c 100644 --- a/R/pkg/R/RDD.R +++ b/R/pkg/R/RDD.R @@ -730,6 +730,7 @@ setMethod("take", index <- -1 jrdd <- getJRDD(x) numPartitions <- numPartitions(x) + serializedModeRDD <- getSerializedMode(x) # TODO(shivaram): Collect more than one partition based on size # estimates similar to the scala version of `take`. @@ -748,13 +749,14 @@ setMethod("take", elems <- convertJListToRList(partition, flatten = TRUE, logicalUpperBound = size, - serializedMode = getSerializedMode(x)) - # TODO: Check if this append is O(n^2)? + serializedMode = serializedModeRDD) + resList <- append(resList, elems) } resList }) + #' First #' #' Return the first element of an RDD @@ -1092,21 +1094,42 @@ takeOrderedElem <- function(x, num, ascending = TRUE) { if (num < length(part)) { # R limitation: order works only on primitive types! ord <- order(unlist(part, recursive = FALSE), decreasing = !ascending) - list(part[ord[1:num]]) + part[ord[1:num]] } else { - list(part) + part } } - reduceFunc <- function(elems, part) { - newElems <- append(elems, part) - # R limitation: order works only on primitive types! - ord <- order(unlist(newElems, recursive = FALSE), decreasing = !ascending) - newElems[ord[1:num]] - } - newRdd <- mapPartitions(x, partitionFunc) - reduce(newRdd, reduceFunc) + + resList <- list() + index <- -1 + jrdd <- getJRDD(newRdd) + numPartitions <- numPartitions(newRdd) + serializedModeRDD <- getSerializedMode(newRdd) + + while (TRUE) { + index <- index + 1 + + if (index >= numPartitions) { + ord <- order(unlist(resList, recursive = FALSE), decreasing = !ascending) + resList <- resList[ord[1:num]] + break + } + + # a JList of byte arrays + partitionArr <- callJMethod(jrdd, "collectPartitions", as.list(as.integer(index))) + partition <- partitionArr[[1]] + + # elems is capped to have at most `num` elements + elems <- convertJListToRList(partition, + flatten = TRUE, + logicalUpperBound = num, + serializedMode = serializedModeRDD) + + resList <- append(resList, elems) + } + resList } #' Returns the first N elements from an RDD in ascending order. @@ -1465,67 +1488,105 @@ setMethod("zipRDD", stop("Can only zip RDDs which have the same number of partitions.") } - if (getSerializedMode(x) != getSerializedMode(other) || - getSerializedMode(x) == "byte") { - # Append the number of elements in each partition to that partition so that we can later - # check if corresponding partitions of both RDDs have the same number of elements. - # - # Note that this appending also serves the purpose of reserialization, because even if - # any RDD is serialized, we need to reserialize it to make sure its partitions are encoded - # as a single byte array. For example, partitions of an RDD generated from partitionBy() - # may be encoded as multiple byte arrays. - appendLength <- function(part) { - part[[length(part) + 1]] <- length(part) + 1 - part - } - x <- lapplyPartition(x, appendLength) - other <- lapplyPartition(other, appendLength) - } + rdds <- appendPartitionLengths(x, other) + jrdd <- callJMethod(getJRDD(rdds[[1]]), "zip", getJRDD(rdds[[2]])) + # The jrdd's elements are of scala Tuple2 type. The serialized + # flag here is used for the elements inside the tuples. + rdd <- RDD(jrdd, getSerializedMode(rdds[[1]])) - zippedJRDD <- callJMethod(getJRDD(x), "zip", getJRDD(other)) - # The zippedRDD's elements are of scala Tuple2 type. The serialized - # flag Here is used for the elements inside the tuples. - serializerMode <- getSerializedMode(x) - zippedRDD <- RDD(zippedJRDD, serializerMode) + mergePartitions(rdd, TRUE) + }) + +#' Cartesian product of this RDD and another one. +#' +#' Return the Cartesian product of this RDD and another one, +#' that is, the RDD of all pairs of elements (a, b) where a +#' is in this and b is in other. +#' +#' @param x An RDD. +#' @param other An RDD. +#' @return A new RDD which is the Cartesian product of these two RDDs. +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' rdd <- parallelize(sc, 1:2) +#' sortByKey(cartesian(rdd, rdd)) +#' # list(list(1, 1), list(1, 2), list(2, 1), list(2, 2)) +#'} +#' @rdname cartesian +#' @aliases cartesian,RDD,RDD-method +setMethod("cartesian", + signature(x = "RDD", other = "RDD"), + function(x, other) { + rdds <- appendPartitionLengths(x, other) + jrdd <- callJMethod(getJRDD(rdds[[1]]), "cartesian", getJRDD(rdds[[2]])) + # The jrdd's elements are of scala Tuple2 type. The serialized + # flag here is used for the elements inside the tuples. + rdd <- RDD(jrdd, getSerializedMode(rdds[[1]])) - partitionFunc <- function(split, part) { - len <- length(part) - if (len > 0) { - if (serializerMode == "byte") { - lengthOfValues <- part[[len]] - lengthOfKeys <- part[[len - lengthOfValues]] - stopifnot(len == lengthOfKeys + lengthOfValues) - - # check if corresponding partitions of both RDDs have the same number of elements. - if (lengthOfKeys != lengthOfValues) { - stop("Can only zip RDDs with same number of elements in each pair of corresponding partitions.") - } - - if (lengthOfKeys > 1) { - keys <- part[1 : (lengthOfKeys - 1)] - values <- part[(lengthOfKeys + 1) : (len - 1)] - } else { - keys <- list() - values <- list() - } - } else { - # Keys, values must have same length here, because this has - # been validated inside the JavaRDD.zip() function. - keys <- part[c(TRUE, FALSE)] - values <- part[c(FALSE, TRUE)] - } - mapply( - function(k, v) { - list(k, v) - }, - keys, - values, - SIMPLIFY = FALSE, - USE.NAMES = FALSE) - } else { - part - } + mergePartitions(rdd, FALSE) + }) + +#' Subtract an RDD with another RDD. +#' +#' Return an RDD with the elements from this that are not in other. +#' +#' @param x An RDD. +#' @param other An RDD. +#' @param numPartitions Number of the partitions in the result RDD. +#' @return An RDD with the elements from this that are not in other. +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' rdd1 <- parallelize(sc, list(1, 1, 2, 2, 3, 4)) +#' rdd2 <- parallelize(sc, list(2, 4)) +#' collect(subtract(rdd1, rdd2)) +#' # list(1, 1, 3) +#'} +#' @rdname subtract +#' @aliases subtract,RDD +setMethod("subtract", + signature(x = "RDD", other = "RDD"), + function(x, other, numPartitions = SparkR::numPartitions(x)) { + mapFunction <- function(e) { list(e, NA) } + rdd1 <- map(x, mapFunction) + rdd2 <- map(other, mapFunction) + keys(subtractByKey(rdd1, rdd2, numPartitions)) + }) + +#' Intersection of this RDD and another one. +#' +#' Return the intersection of this RDD and another one. +#' The output will not contain any duplicate elements, +#' even if the input RDDs did. Performs a hash partition +#' across the cluster. +#' Note that this method performs a shuffle internally. +#' +#' @param x An RDD. +#' @param other An RDD. +#' @param numPartitions The number of partitions in the result RDD. +#' @return An RDD which is the intersection of these two RDDs. +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' rdd1 <- parallelize(sc, list(1, 10, 2, 3, 4, 5)) +#' rdd2 <- parallelize(sc, list(1, 6, 2, 3, 7, 8)) +#' collect(sortBy(intersection(rdd1, rdd2), function(x) { x })) +#' # list(1, 2, 3) +#'} +#' @rdname intersection +#' @aliases intersection,RDD +setMethod("intersection", + signature(x = "RDD", other = "RDD"), + function(x, other, numPartitions = SparkR::numPartitions(x)) { + rdd1 <- map(x, function(v) { list(v, NA) }) + rdd2 <- map(other, function(v) { list(v, NA) }) + + filterFunction <- function(elem) { + iters <- elem[[2]] + all(as.vector( + lapply(iters, function(iter) { length(iter) > 0 }), mode = "logical")) } - - PipelinedRDD(zippedRDD, partitionFunc) + + keys(filterRDD(cogroup(rdd1, rdd2, numPartitions = numPartitions), filterFunction)) }) diff --git a/R/pkg/R/SQLContext.R b/R/pkg/R/SQLContext.R index 930ada22f4..4f05ba524a 100644 --- a/R/pkg/R/SQLContext.R +++ b/R/pkg/R/SQLContext.R @@ -54,9 +54,9 @@ infer_type <- function(x) { # StructType types <- lapply(x, infer_type) fields <- lapply(1:length(x), function(i) { - list(name = names[[i]], type = types[[i]], nullable = TRUE) + structField(names[[i]], types[[i]], TRUE) }) - list(type = "struct", fields = fields) + do.call(structType, fields) } } else if (length(x) > 1) { list(type = "array", elementType = type, containsNull = TRUE) @@ -65,30 +65,6 @@ infer_type <- function(x) { } } -#' dump the schema into JSON string -tojson <- function(x) { - if (is.list(x)) { - names <- names(x) - if (!is.null(names)) { - items <- lapply(names, function(n) { - safe_n <- gsub('"', '\\"', n) - paste(tojson(safe_n), ':', tojson(x[[n]]), sep = '') - }) - d <- paste(items, collapse = ', ') - paste('{', d, '}', sep = '') - } else { - l <- paste(lapply(x, tojson), collapse = ', ') - paste('[', l, ']', sep = '') - } - } else if (is.character(x)) { - paste('"', x, '"', sep = '') - } else if (is.logical(x)) { - if (x) "true" else "false" - } else { - stop(paste("unexpected type:", class(x))) - } -} - #' Create a DataFrame from an RDD #' #' Converts an RDD to a DataFrame by infer the types. @@ -134,7 +110,7 @@ createDataFrame <- function(sqlCtx, data, schema = NULL, samplingRatio = 1.0) { stop(paste("unexpected type:", class(data))) } - if (is.null(schema) || is.null(names(schema))) { + if (is.null(schema) || (!inherits(schema, "structType") && is.null(names(schema)))) { row <- first(rdd) names <- if (is.null(schema)) { names(row) @@ -143,7 +119,7 @@ createDataFrame <- function(sqlCtx, data, schema = NULL, samplingRatio = 1.0) { } if (is.null(names)) { names <- lapply(1:length(row), function(x) { - paste("_", as.character(x), sep = "") + paste("_", as.character(x), sep = "") }) } @@ -159,20 +135,18 @@ createDataFrame <- function(sqlCtx, data, schema = NULL, samplingRatio = 1.0) { types <- lapply(row, infer_type) fields <- lapply(1:length(row), function(i) { - list(name = names[[i]], type = types[[i]], nullable = TRUE) + structField(names[[i]], types[[i]], TRUE) }) - schema <- list(type = "struct", fields = fields) + schema <- do.call(structType, fields) } - stopifnot(class(schema) == "list") - stopifnot(schema$type == "struct") - stopifnot(class(schema$fields) == "list") - schemaString <- tojson(schema) + stopifnot(class(schema) == "structType") + # schemaString <- tojson(schema) jrdd <- getJRDD(lapply(rdd, function(x) x), "row") srdd <- callJMethod(jrdd, "rdd") sdf <- callJStatic("org.apache.spark.sql.api.r.SQLUtils", "createDF", - srdd, schemaString, sqlCtx) + srdd, schema$jobj, sqlCtx) dataFrame(sdf) } diff --git a/R/pkg/R/SQLTypes.R b/R/pkg/R/SQLTypes.R deleted file mode 100644 index 962fba5b3c..0000000000 --- a/R/pkg/R/SQLTypes.R +++ /dev/null @@ -1,64 +0,0 @@ -# -# Licensed to the Apache Software Foundation (ASF) under one or more -# contributor license agreements. See the NOTICE file distributed with -# this work for additional information regarding copyright ownership. -# The ASF licenses this file to You under the Apache License, Version 2.0 -# (the "License"); you may not use this file except in compliance with -# the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# - -# Utility functions for handling SparkSQL DataTypes. - -# Handler for StructType -structType <- function(st) { - obj <- structure(new.env(parent = emptyenv()), class = "structType") - obj$jobj <- st - obj$fields <- function() { lapply(callJMethod(st, "fields"), structField) } - obj -} - -#' Print a Spark StructType. -#' -#' This function prints the contents of a StructType returned from the -#' SparkR JVM backend. -#' -#' @param x A StructType object -#' @param ... further arguments passed to or from other methods -print.structType <- function(x, ...) { - fieldsList <- lapply(x$fields(), function(i) { i$print() }) - print(fieldsList) -} - -# Handler for StructField -structField <- function(sf) { - obj <- structure(new.env(parent = emptyenv()), class = "structField") - obj$jobj <- sf - obj$name <- function() { callJMethod(sf, "name") } - obj$dataType <- function() { callJMethod(sf, "dataType") } - obj$dataType.toString <- function() { callJMethod(obj$dataType(), "toString") } - obj$dataType.simpleString <- function() { callJMethod(obj$dataType(), "simpleString") } - obj$nullable <- function() { callJMethod(sf, "nullable") } - obj$print <- function() { paste("StructField(", - paste(obj$name(), obj$dataType.toString(), obj$nullable(), sep = ", "), - ")", sep = "") } - obj -} - -#' Print a Spark StructField. -#' -#' This function prints the contents of a StructField returned from the -#' SparkR JVM backend. -#' -#' @param x A StructField object -#' @param ... further arguments passed to or from other methods -print.structField <- function(x, ...) { - cat(x$print()) -} diff --git a/R/pkg/R/column.R b/R/pkg/R/column.R index b282001d8b..95fb9ff088 100644 --- a/R/pkg/R/column.R +++ b/R/pkg/R/column.R @@ -17,7 +17,7 @@ # Column Class -#' @include generics.R jobj.R SQLTypes.R +#' @include generics.R jobj.R schema.R NULL setOldClass("jobj") diff --git a/R/pkg/R/generics.R b/R/pkg/R/generics.R index 5fb1ccaa84..6c62333901 100644 --- a/R/pkg/R/generics.R +++ b/R/pkg/R/generics.R @@ -230,6 +230,10 @@ setGeneric("zipWithUniqueId", function(x) { standardGeneric("zipWithUniqueId") } ############ Binary Functions ############# +#' @rdname cartesian +#' @export +setGeneric("cartesian", function(x, other) { standardGeneric("cartesian") }) + #' @rdname countByKey #' @export setGeneric("countByKey", function(x) { standardGeneric("countByKey") }) @@ -238,6 +242,11 @@ setGeneric("countByKey", function(x) { standardGeneric("countByKey") }) #' @export setGeneric("flatMapValues", function(X, FUN) { standardGeneric("flatMapValues") }) +#' @rdname intersection +#' @export +setGeneric("intersection", function(x, other, numPartitions = 1L) { + standardGeneric("intersection") }) + #' @rdname keys #' @export setGeneric("keys", function(x) { standardGeneric("keys") }) @@ -250,12 +259,18 @@ setGeneric("lookup", function(x, key) { standardGeneric("lookup") }) #' @export setGeneric("mapValues", function(X, FUN) { standardGeneric("mapValues") }) +#' @rdname sampleByKey +#' @export +setGeneric("sampleByKey", + function(x, withReplacement, fractions, seed) { + standardGeneric("sampleByKey") + }) + #' @rdname values #' @export setGeneric("values", function(x) { standardGeneric("values") }) - ############ Shuffle Functions ############ #' @rdname aggregateByKey @@ -330,9 +345,24 @@ setGeneric("rightOuterJoin", function(x, y, numPartitions) { standardGeneric("ri #' @rdname sortByKey #' @export -setGeneric("sortByKey", function(x, ascending = TRUE, numPartitions = 1L) { - standardGeneric("sortByKey") -}) +setGeneric("sortByKey", + function(x, ascending = TRUE, numPartitions = 1L) { + standardGeneric("sortByKey") + }) + +#' @rdname subtract +#' @export +setGeneric("subtract", + function(x, other, numPartitions = 1L) { + standardGeneric("subtract") + }) + +#' @rdname subtractByKey +#' @export +setGeneric("subtractByKey", + function(x, other, numPartitions = 1L) { + standardGeneric("subtractByKey") + }) ################### Broadcast Variable Methods ################# @@ -357,6 +387,10 @@ setGeneric("dtypes", function(x) { standardGeneric("dtypes") }) #' @export setGeneric("explain", function(x, ...) { standardGeneric("explain") }) +#' @rdname except +#' @export +setGeneric("except", function(x, y) { standardGeneric("except") }) + #' @rdname filter #' @export setGeneric("filter", function(x, condition) { standardGeneric("filter") }) @@ -434,10 +468,6 @@ setGeneric("showDF", function(x,...) { standardGeneric("showDF") }) #' @export setGeneric("sortDF", function(x, col, ...) { standardGeneric("sortDF") }) -#' @rdname subtract -#' @export -setGeneric("subtract", function(x, y) { standardGeneric("subtract") }) - #' @rdname tojson #' @export setGeneric("toJSON", function(x) { standardGeneric("toJSON") }) diff --git a/R/pkg/R/group.R b/R/pkg/R/group.R index 855fbdfc7c..02237b3672 100644 --- a/R/pkg/R/group.R +++ b/R/pkg/R/group.R @@ -17,7 +17,7 @@ # group.R - GroupedData class and methods implemented in S4 OO classes -#' @include generics.R jobj.R SQLTypes.R column.R +#' @include generics.R jobj.R schema.R column.R NULL setOldClass("jobj") diff --git a/R/pkg/R/pairRDD.R b/R/pkg/R/pairRDD.R index 5d64822859..13efebc11c 100644 --- a/R/pkg/R/pairRDD.R +++ b/R/pkg/R/pairRDD.R @@ -430,7 +430,7 @@ setMethod("combineByKey", pred <- function(item) exists(item$hash, keys) lapply(part, function(item) { - item$hash <- as.character(item[[1]]) + item$hash <- as.character(hashCode(item[[1]])) updateOrCreatePair(item, keys, combiners, pred, mergeValue, createCombiner) }) convertEnvsToList(keys, combiners) @@ -443,7 +443,7 @@ setMethod("combineByKey", pred <- function(item) exists(item$hash, keys) lapply(part, function(item) { - item$hash <- as.character(item[[1]]) + item$hash <- as.character(hashCode(item[[1]])) updateOrCreatePair(item, keys, combiners, pred, mergeCombiners, identity) }) convertEnvsToList(keys, combiners) @@ -452,19 +452,19 @@ setMethod("combineByKey", }) #' Aggregate a pair RDD by each key. -#' +#' #' Aggregate the values of each key in an RDD, using given combine functions #' and a neutral "zero value". This function can return a different result type, #' U, than the type of the values in this RDD, V. Thus, we need one operation -#' for merging a V into a U and one operation for merging two U's, The former -#' operation is used for merging values within a partition, and the latter is -#' used for merging values between partitions. To avoid memory allocation, both -#' of these functions are allowed to modify and return their first argument +#' for merging a V into a U and one operation for merging two U's, The former +#' operation is used for merging values within a partition, and the latter is +#' used for merging values between partitions. To avoid memory allocation, both +#' of these functions are allowed to modify and return their first argument #' instead of creating a new U. -#' +#' #' @param x An RDD. #' @param zeroValue A neutral "zero value". -#' @param seqOp A function to aggregate the values of each key. It may return +#' @param seqOp A function to aggregate the values of each key. It may return #' a different result type from the type of the values. #' @param combOp A function to aggregate results of seqOp. #' @return An RDD containing the aggregation result. @@ -476,7 +476,7 @@ setMethod("combineByKey", #' zeroValue <- list(0, 0) #' seqOp <- function(x, y) { list(x[[1]] + y, x[[2]] + 1) } #' combOp <- function(x, y) { list(x[[1]] + y[[1]], x[[2]] + y[[2]]) } -#' aggregateByKey(rdd, zeroValue, seqOp, combOp, 2L) +#' aggregateByKey(rdd, zeroValue, seqOp, combOp, 2L) #' # list(list(1, list(3, 2)), list(2, list(7, 2))) #'} #' @rdname aggregateByKey @@ -493,12 +493,12 @@ setMethod("aggregateByKey", }) #' Fold a pair RDD by each key. -#' +#' #' Aggregate the values of each key in an RDD, using an associative function "func" -#' and a neutral "zero value" which may be added to the result an arbitrary -#' number of times, and must not change the result (e.g., 0 for addition, or +#' and a neutral "zero value" which may be added to the result an arbitrary +#' number of times, and must not change the result (e.g., 0 for addition, or #' 1 for multiplication.). -#' +#' #' @param x An RDD. #' @param zeroValue A neutral "zero value". #' @param func An associative function for folding values of each key. @@ -548,11 +548,11 @@ setMethod("join", function(x, y, numPartitions) { xTagged <- lapply(x, function(i) { list(i[[1]], list(1L, i[[2]])) }) yTagged <- lapply(y, function(i) { list(i[[1]], list(2L, i[[2]])) }) - + doJoin <- function(v) { joinTaggedList(v, list(FALSE, FALSE)) } - + joined <- flatMapValues(groupByKey(unionRDD(xTagged, yTagged), numToInt(numPartitions)), doJoin) }) @@ -568,8 +568,8 @@ setMethod("join", #' @param y An RDD to be joined. Should be an RDD where each element is #' list(K, V). #' @param numPartitions Number of partitions to create. -#' @return For each element (k, v) in x, the resulting RDD will either contain -#' all pairs (k, (v, w)) for (k, w) in rdd2, or the pair (k, (v, NULL)) +#' @return For each element (k, v) in x, the resulting RDD will either contain +#' all pairs (k, (v, w)) for (k, w) in rdd2, or the pair (k, (v, NULL)) #' if no elements in rdd2 have key k. #' @examples #'\dontrun{ @@ -586,11 +586,11 @@ setMethod("leftOuterJoin", function(x, y, numPartitions) { xTagged <- lapply(x, function(i) { list(i[[1]], list(1L, i[[2]])) }) yTagged <- lapply(y, function(i) { list(i[[1]], list(2L, i[[2]])) }) - + doJoin <- function(v) { joinTaggedList(v, list(FALSE, TRUE)) } - + joined <- flatMapValues(groupByKey(unionRDD(xTagged, yTagged), numPartitions), doJoin) }) @@ -623,18 +623,18 @@ setMethod("rightOuterJoin", function(x, y, numPartitions) { xTagged <- lapply(x, function(i) { list(i[[1]], list(1L, i[[2]])) }) yTagged <- lapply(y, function(i) { list(i[[1]], list(2L, i[[2]])) }) - + doJoin <- function(v) { joinTaggedList(v, list(TRUE, FALSE)) } - + joined <- flatMapValues(groupByKey(unionRDD(xTagged, yTagged), numPartitions), doJoin) }) #' Full outer join two RDDs #' #' @description -#' \code{fullouterjoin} This function full-outer-joins two RDDs where every element is of the form list(K, V). +#' \code{fullouterjoin} This function full-outer-joins two RDDs where every element is of the form list(K, V). #' The key types of the two RDDs should be the same. #' #' @param x An RDD to be joined. Should be an RDD where each element is @@ -644,7 +644,7 @@ setMethod("rightOuterJoin", #' @param numPartitions Number of partitions to create. #' @return For each element (k, v) in x and (k, w) in y, the resulting RDD #' will contain all pairs (k, (v, w)) for both (k, v) in x and -#' (k, w) in y, or the pair (k, (NULL, w))/(k, (v, NULL)) if no elements +#' (k, w) in y, or the pair (k, (NULL, w))/(k, (v, NULL)) if no elements #' in x/y have key k. #' @examples #'\dontrun{ @@ -683,7 +683,7 @@ setMethod("fullOuterJoin", #' sc <- sparkR.init() #' rdd1 <- parallelize(sc, list(list(1, 1), list(2, 4))) #' rdd2 <- parallelize(sc, list(list(1, 2), list(1, 3))) -#' cogroup(rdd1, rdd2, numPartitions = 2L) +#' cogroup(rdd1, rdd2, numPartitions = 2L) #' # list(list(1, list(1, list(2, 3))), list(2, list(list(4), list())) #'} #' @rdname cogroup @@ -694,7 +694,7 @@ setMethod("cogroup", rdds <- list(...) rddsLen <- length(rdds) for (i in 1:rddsLen) { - rdds[[i]] <- lapply(rdds[[i]], + rdds[[i]] <- lapply(rdds[[i]], function(x) { list(x[[1]], list(i, x[[2]])) }) } union.rdd <- Reduce(unionRDD, rdds) @@ -719,7 +719,7 @@ setMethod("cogroup", } }) } - cogroup.rdd <- mapValues(groupByKey(union.rdd, numPartitions), + cogroup.rdd <- mapValues(groupByKey(union.rdd, numPartitions), group.func) }) @@ -741,18 +741,18 @@ setMethod("sortByKey", signature(x = "RDD"), function(x, ascending = TRUE, numPartitions = SparkR::numPartitions(x)) { rangeBounds <- list() - + if (numPartitions > 1) { rddSize <- count(x) # constant from Spark's RangePartitioner maxSampleSize <- numPartitions * 20 fraction <- min(maxSampleSize / max(rddSize, 1), 1.0) - + samples <- collect(keys(sampleRDD(x, FALSE, fraction, 1L))) - + # Note: the built-in R sort() function only works on atomic vectors samples <- sort(unlist(samples, recursive = FALSE), decreasing = !ascending) - + if (length(samples) > 0) { rangeBounds <- lapply(seq_len(numPartitions - 1), function(i) { @@ -764,24 +764,146 @@ setMethod("sortByKey", rangePartitionFunc <- function(key) { partition <- 0 - + # TODO: Use binary search instead of linear search, similar with Spark while (partition < length(rangeBounds) && key > rangeBounds[[partition + 1]]) { partition <- partition + 1 } - + if (ascending) { partition } else { numPartitions - partition - 1 } } - + partitionFunc <- function(part) { sortKeyValueList(part, decreasing = !ascending) } - + newRDD <- partitionBy(x, numPartitions, rangePartitionFunc) lapplyPartition(newRDD, partitionFunc) }) +#' Subtract a pair RDD with another pair RDD. +#' +#' Return an RDD with the pairs from x whose keys are not in other. +#' +#' @param x An RDD. +#' @param other An RDD. +#' @param numPartitions Number of the partitions in the result RDD. +#' @return An RDD with the pairs from x whose keys are not in other. +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' rdd1 <- parallelize(sc, list(list("a", 1), list("b", 4), +#' list("b", 5), list("a", 2))) +#' rdd2 <- parallelize(sc, list(list("a", 3), list("c", 1))) +#' collect(subtractByKey(rdd1, rdd2)) +#' # list(list("b", 4), list("b", 5)) +#'} +#' @rdname subtractByKey +#' @aliases subtractByKey,RDD +setMethod("subtractByKey", + signature(x = "RDD", other = "RDD"), + function(x, other, numPartitions = SparkR::numPartitions(x)) { + filterFunction <- function(elem) { + iters <- elem[[2]] + (length(iters[[1]]) > 0) && (length(iters[[2]]) == 0) + } + + flatMapValues(filterRDD(cogroup(x, + other, + numPartitions = numPartitions), + filterFunction), + function (v) { v[[1]] }) + }) + +#' Return a subset of this RDD sampled by key. +#' +#' @description +#' \code{sampleByKey} Create a sample of this RDD using variable sampling rates +#' for different keys as specified by fractions, a key to sampling rate map. +#' +#' @param x The RDD to sample elements by key, where each element is +#' list(K, V) or c(K, V). +#' @param withReplacement Sampling with replacement or not +#' @param fraction The (rough) sample target fraction +#' @param seed Randomness seed value +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' rdd <- parallelize(sc, 1:3000) +#' pairs <- lapply(rdd, function(x) { if (x %% 3 == 0) list("a", x) +#' else { if (x %% 3 == 1) list("b", x) else list("c", x) }}) +#' fractions <- list(a = 0.2, b = 0.1, c = 0.3) +#' sample <- sampleByKey(pairs, FALSE, fractions, 1618L) +#' 100 < length(lookup(sample, "a")) && 300 > length(lookup(sample, "a")) # TRUE +#' 50 < length(lookup(sample, "b")) && 150 > length(lookup(sample, "b")) # TRUE +#' 200 < length(lookup(sample, "c")) && 400 > length(lookup(sample, "c")) # TRUE +#' lookup(sample, "a")[which.min(lookup(sample, "a"))] >= 0 # TRUE +#' lookup(sample, "a")[which.max(lookup(sample, "a"))] <= 2000 # TRUE +#' lookup(sample, "b")[which.min(lookup(sample, "b"))] >= 0 # TRUE +#' lookup(sample, "b")[which.max(lookup(sample, "b"))] <= 2000 # TRUE +#' lookup(sample, "c")[which.min(lookup(sample, "c"))] >= 0 # TRUE +#' lookup(sample, "c")[which.max(lookup(sample, "c"))] <= 2000 # TRUE +#' fractions <- list(a = 0.2, b = 0.1, c = 0.3, d = 0.4) +#' sample <- sampleByKey(pairs, FALSE, fractions, 1618L) # Key "d" will be ignored +#' fractions <- list(a = 0.2, b = 0.1) +#' sample <- sampleByKey(pairs, FALSE, fractions, 1618L) # KeyError: "c" +#'} +#' @rdname sampleByKey +#' @aliases sampleByKey,RDD-method +setMethod("sampleByKey", + signature(x = "RDD", withReplacement = "logical", + fractions = "vector", seed = "integer"), + function(x, withReplacement, fractions, seed) { + + for (elem in fractions) { + if (elem < 0.0) { + stop(paste("Negative fraction value ", fractions[which(fractions == elem)])) + } + } + + # The sampler: takes a partition and returns its sampled version. + samplingFunc <- function(split, part) { + set.seed(bitwXor(seed, split)) + res <- vector("list", length(part)) + len <- 0 + + # mixing because the initial seeds are close to each other + runif(10) + + for (elem in part) { + if (elem[[1]] %in% names(fractions)) { + frac <- as.numeric(fractions[which(elem[[1]] == names(fractions))]) + if (withReplacement) { + count <- rpois(1, frac) + if (count > 0) { + res[(len + 1):(len + count)] <- rep(list(elem), count) + len <- len + count + } + } else { + if (runif(1) < frac) { + len <- len + 1 + res[[len]] <- elem + } + } + } else { + stop("KeyError: \"", elem[[1]], "\"") + } + } + + # TODO(zongheng): look into the performance of the current + # implementation. Look into some iterator package? Note that + # Scala avoids many calls to creating an empty list and PySpark + # similarly achieves this using `yield'. (duplicated from sampleRDD) + if (len > 0) { + res[1:len] + } else { + list() + } + } + + lapplyPartitionsWithIndex(x, samplingFunc) + }) diff --git a/R/pkg/R/schema.R b/R/pkg/R/schema.R new file mode 100644 index 0000000000..e442119086 --- /dev/null +++ b/R/pkg/R/schema.R @@ -0,0 +1,162 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +# A set of S3 classes and methods that support the SparkSQL `StructType` and `StructField +# datatypes. These are used to create and interact with DataFrame schemas. + +#' structType +#' +#' Create a structType object that contains the metadata for a DataFrame. Intended for +#' use with createDataFrame and toDF. +#' +#' @param x a structField object (created with the field() function) +#' @param ... additional structField objects +#' @return a structType object +#' @export +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' sqlCtx <- sparkRSQL.init(sc) +#' rdd <- lapply(parallelize(sc, 1:10), function(x) { list(x, as.character(x)) }) +#' schema <- structType(structField("a", "integer"), structField("b", "string")) +#' df <- createDataFrame(sqlCtx, rdd, schema) +#' } +structType <- function(x, ...) { + UseMethod("structType", x) +} + +structType.jobj <- function(x) { + obj <- structure(list(), class = "structType") + obj$jobj <- x + obj$fields <- function() { lapply(callJMethod(obj$jobj, "fields"), structField) } + obj +} + +structType.structField <- function(x, ...) { + fields <- list(x, ...) + if (!all(sapply(fields, inherits, "structField"))) { + stop("All arguments must be structField objects.") + } + sfObjList <- lapply(fields, function(field) { + field$jobj + }) + stObj <- callJStatic("org.apache.spark.sql.api.r.SQLUtils", + "createStructType", + listToSeq(sfObjList)) + structType(stObj) +} + +#' Print a Spark StructType. +#' +#' This function prints the contents of a StructType returned from the +#' SparkR JVM backend. +#' +#' @param x A StructType object +#' @param ... further arguments passed to or from other methods +print.structType <- function(x, ...) { + cat("StructType\n", + sapply(x$fields(), function(field) { paste("|-", "name = \"", field$name(), + "\", type = \"", field$dataType.toString(), + "\", nullable = ", field$nullable(), "\n", + sep = "") }) + , sep = "") +} + +#' structField +#' +#' Create a structField object that contains the metadata for a single field in a schema. +#' +#' @param x The name of the field +#' @param type The data type of the field +#' @param nullable A logical vector indicating whether or not the field is nullable +#' @return a structField object +#' @export +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' sqlCtx <- sparkRSQL.init(sc) +#' rdd <- lapply(parallelize(sc, 1:10), function(x) { list(x, as.character(x)) }) +#' field1 <- structField("a", "integer", TRUE) +#' field2 <- structField("b", "string", TRUE) +#' schema <- structType(field1, field2) +#' df <- createDataFrame(sqlCtx, rdd, schema) +#' } + +structField <- function(x, ...) { + UseMethod("structField", x) +} + +structField.jobj <- function(x) { + obj <- structure(list(), class = "structField") + obj$jobj <- x + obj$name <- function() { callJMethod(x, "name") } + obj$dataType <- function() { callJMethod(x, "dataType") } + obj$dataType.toString <- function() { callJMethod(obj$dataType(), "toString") } + obj$dataType.simpleString <- function() { callJMethod(obj$dataType(), "simpleString") } + obj$nullable <- function() { callJMethod(x, "nullable") } + obj +} + +structField.character <- function(x, type, nullable = TRUE) { + if (class(x) != "character") { + stop("Field name must be a string.") + } + if (class(type) != "character") { + stop("Field type must be a string.") + } + if (class(nullable) != "logical") { + stop("nullable must be either TRUE or FALSE") + } + options <- c("byte", + "integer", + "double", + "numeric", + "character", + "string", + "binary", + "raw", + "logical", + "boolean", + "timestamp", + "date") + dataType <- if (type %in% options) { + type + } else { + stop(paste("Unsupported type for Dataframe:", type)) + } + sfObj <- callJStatic("org.apache.spark.sql.api.r.SQLUtils", + "createStructField", + x, + dataType, + nullable) + structField(sfObj) +} + +#' Print a Spark StructField. +#' +#' This function prints the contents of a StructField returned from the +#' SparkR JVM backend. +#' +#' @param x A StructField object +#' @param ... further arguments passed to or from other methods +print.structField <- function(x, ...) { + cat("StructField(name = \"", x$name(), + "\", type = \"", x$dataType.toString(), + "\", nullable = ", x$nullable(), + ")", + sep = "") +} diff --git a/R/pkg/R/serialize.R b/R/pkg/R/serialize.R index 8a9c0c652c..c53d0a9610 100644 --- a/R/pkg/R/serialize.R +++ b/R/pkg/R/serialize.R @@ -69,8 +69,9 @@ writeJobj <- function(con, value) { } writeString <- function(con, value) { - writeInt(con, as.integer(nchar(value) + 1)) - writeBin(value, con, endian = "big") + utfVal <- enc2utf8(value) + writeInt(con, as.integer(nchar(utfVal, type = "bytes") + 1)) + writeBin(utfVal, con, endian = "big") } writeInt <- function(con, value) { @@ -189,7 +190,3 @@ writeArgs <- function(con, args) { } } } - -writeStrings <- function(con, stringList) { - writeLines(unlist(stringList), con) -} diff --git a/R/pkg/R/utils.R b/R/pkg/R/utils.R index c337fb0751..23305d3c67 100644 --- a/R/pkg/R/utils.R +++ b/R/pkg/R/utils.R @@ -465,3 +465,83 @@ cleanClosure <- function(func, checkedFuncs = new.env()) { } func } + +# Append partition lengths to each partition in two input RDDs if needed. +# param +# x An RDD. +# Other An RDD. +# return value +# A list of two result RDDs. +appendPartitionLengths <- function(x, other) { + if (getSerializedMode(x) != getSerializedMode(other) || + getSerializedMode(x) == "byte") { + # Append the number of elements in each partition to that partition so that we can later + # know the boundary of elements from x and other. + # + # Note that this appending also serves the purpose of reserialization, because even if + # any RDD is serialized, we need to reserialize it to make sure its partitions are encoded + # as a single byte array. For example, partitions of an RDD generated from partitionBy() + # may be encoded as multiple byte arrays. + appendLength <- function(part) { + len <- length(part) + part[[len + 1]] <- len + 1 + part + } + x <- lapplyPartition(x, appendLength) + other <- lapplyPartition(other, appendLength) + } + list (x, other) +} + +# Perform zip or cartesian between elements from two RDDs in each partition +# param +# rdd An RDD. +# zip A boolean flag indicating this call is for zip operation or not. +# return value +# A result RDD. +mergePartitions <- function(rdd, zip) { + serializerMode <- getSerializedMode(rdd) + partitionFunc <- function(split, part) { + len <- length(part) + if (len > 0) { + if (serializerMode == "byte") { + lengthOfValues <- part[[len]] + lengthOfKeys <- part[[len - lengthOfValues]] + stopifnot(len == lengthOfKeys + lengthOfValues) + + # For zip operation, check if corresponding partitions of both RDDs have the same number of elements. + if (zip && lengthOfKeys != lengthOfValues) { + stop("Can only zip RDDs with same number of elements in each pair of corresponding partitions.") + } + + if (lengthOfKeys > 1) { + keys <- part[1 : (lengthOfKeys - 1)] + } else { + keys <- list() + } + if (lengthOfValues > 1) { + values <- part[(lengthOfKeys + 1) : (len - 1)] + } else { + values <- list() + } + + if (!zip) { + return(mergeCompactLists(keys, values)) + } + } else { + keys <- part[c(TRUE, FALSE)] + values <- part[c(FALSE, TRUE)] + } + mapply( + function(k, v) { list(k, v) }, + keys, + values, + SIMPLIFY = FALSE, + USE.NAMES = FALSE) + } else { + part + } + } + + PipelinedRDD(rdd, partitionFunc) +} diff --git a/R/pkg/inst/tests/test_rdd.R b/R/pkg/inst/tests/test_rdd.R index b76e4db03e..3ba7d17163 100644 --- a/R/pkg/inst/tests/test_rdd.R +++ b/R/pkg/inst/tests/test_rdd.R @@ -35,7 +35,7 @@ test_that("get number of partitions in RDD", { test_that("first on RDD", { expect_true(first(rdd) == 1) newrdd <- lapply(rdd, function(x) x + 1) - expect_true(first(newrdd) == 2) + expect_true(first(newrdd) == 2) }) test_that("count and length on RDD", { @@ -48,7 +48,7 @@ test_that("count by values and keys", { actual <- countByValue(mods) expected <- list(list(0, 3L), list(1, 4L), list(2, 3L)) expect_equal(sortKeyValueList(actual), sortKeyValueList(expected)) - + actual <- countByKey(intRdd) expected <- list(list(2L, 2L), list(1L, 2L)) expect_equal(sortKeyValueList(actual), sortKeyValueList(expected)) @@ -82,11 +82,11 @@ test_that("filterRDD on RDD", { filtered.rdd <- filterRDD(rdd, function(x) { x %% 2 == 0 }) actual <- collect(filtered.rdd) expect_equal(actual, list(2, 4, 6, 8, 10)) - + filtered.rdd <- Filter(function(x) { x[[2]] < 0 }, intRdd) actual <- collect(filtered.rdd) expect_equal(actual, list(list(1L, -1))) - + # Filter out all elements. filtered.rdd <- filterRDD(rdd, function(x) { x > 10 }) actual <- collect(filtered.rdd) @@ -96,7 +96,7 @@ test_that("filterRDD on RDD", { test_that("lookup on RDD", { vals <- lookup(intRdd, 1L) expect_equal(vals, list(-1, 200)) - + vals <- lookup(intRdd, 3L) expect_equal(vals, list()) }) @@ -110,7 +110,7 @@ test_that("several transformations on RDD (a benchmark on PipelinedRDD)", { }) rdd2 <- lapply(rdd2, function(x) x + x) actual <- collect(rdd2) - expected <- list(24, 24, 24, 24, 24, + expected <- list(24, 24, 24, 24, 24, 168, 170, 172, 174, 176) expect_equal(actual, expected) }) @@ -248,10 +248,10 @@ test_that("flatMapValues() on pairwise RDDs", { l <- parallelize(sc, list(list(1, c(1,2)), list(2, c(3,4)))) actual <- collect(flatMapValues(l, function(x) { x })) expect_equal(actual, list(list(1,1), list(1,2), list(2,3), list(2,4))) - + # Generate x to x+1 for every value actual <- collect(flatMapValues(intRdd, function(x) { x:(x + 1) })) - expect_equal(actual, + expect_equal(actual, list(list(1L, -1), list(1L, 0), list(2L, 100), list(2L, 101), list(2L, 1), list(2L, 2), list(1L, 200), list(1L, 201))) }) @@ -348,7 +348,7 @@ test_that("top() on RDDs", { rdd <- parallelize(sc, l) actual <- top(rdd, 6L) expect_equal(actual, as.list(sort(unlist(l), decreasing = TRUE))[1:6]) - + l <- list("e", "d", "c", "d", "a") rdd <- parallelize(sc, l) actual <- top(rdd, 3L) @@ -358,7 +358,7 @@ test_that("top() on RDDs", { test_that("fold() on RDDs", { actual <- fold(rdd, 0, "+") expect_equal(actual, Reduce("+", nums, 0)) - + rdd <- parallelize(sc, list()) actual <- fold(rdd, 0, "+") expect_equal(actual, 0) @@ -371,7 +371,7 @@ test_that("aggregateRDD() on RDDs", { combOp <- function(x, y) { list(x[[1]] + y[[1]], x[[2]] + y[[2]]) } actual <- aggregateRDD(rdd, zeroValue, seqOp, combOp) expect_equal(actual, list(10, 4)) - + rdd <- parallelize(sc, list()) actual <- aggregateRDD(rdd, zeroValue, seqOp, combOp) expect_equal(actual, list(0, 0)) @@ -380,13 +380,13 @@ test_that("aggregateRDD() on RDDs", { test_that("zipWithUniqueId() on RDDs", { rdd <- parallelize(sc, list("a", "b", "c", "d", "e"), 3L) actual <- collect(zipWithUniqueId(rdd)) - expected <- list(list("a", 0), list("b", 3), list("c", 1), + expected <- list(list("a", 0), list("b", 3), list("c", 1), list("d", 4), list("e", 2)) expect_equal(actual, expected) - + rdd <- parallelize(sc, list("a", "b", "c", "d", "e"), 1L) actual <- collect(zipWithUniqueId(rdd)) - expected <- list(list("a", 0), list("b", 1), list("c", 2), + expected <- list(list("a", 0), list("b", 1), list("c", 2), list("d", 3), list("e", 4)) expect_equal(actual, expected) }) @@ -394,13 +394,13 @@ test_that("zipWithUniqueId() on RDDs", { test_that("zipWithIndex() on RDDs", { rdd <- parallelize(sc, list("a", "b", "c", "d", "e"), 3L) actual <- collect(zipWithIndex(rdd)) - expected <- list(list("a", 0), list("b", 1), list("c", 2), + expected <- list(list("a", 0), list("b", 1), list("c", 2), list("d", 3), list("e", 4)) expect_equal(actual, expected) - + rdd <- parallelize(sc, list("a", "b", "c", "d", "e"), 1L) actual <- collect(zipWithIndex(rdd)) - expected <- list(list("a", 0), list("b", 1), list("c", 2), + expected <- list(list("a", 0), list("b", 1), list("c", 2), list("d", 3), list("e", 4)) expect_equal(actual, expected) }) @@ -427,12 +427,12 @@ test_that("pipeRDD() on RDDs", { actual <- collect(pipeRDD(rdd, "more")) expected <- as.list(as.character(1:10)) expect_equal(actual, expected) - + trailed.rdd <- parallelize(sc, c("1", "", "2\n", "3\n\r\n")) actual <- collect(pipeRDD(trailed.rdd, "sort")) expected <- list("", "1", "2", "3") expect_equal(actual, expected) - + rev.nums <- 9:0 rev.rdd <- parallelize(sc, rev.nums, 2L) actual <- collect(pipeRDD(rev.rdd, "sort")) @@ -446,11 +446,11 @@ test_that("zipRDD() on RDDs", { actual <- collect(zipRDD(rdd1, rdd2)) expect_equal(actual, list(list(0, 1000), list(1, 1001), list(2, 1002), list(3, 1003), list(4, 1004))) - + mockFile = c("Spark is pretty.", "Spark is awesome.") fileName <- tempfile(pattern="spark-test", fileext=".tmp") writeLines(mockFile, fileName) - + rdd <- textFile(sc, fileName, 1) actual <- collect(zipRDD(rdd, rdd)) expected <- lapply(mockFile, function(x) { list(x ,x) }) @@ -465,10 +465,125 @@ test_that("zipRDD() on RDDs", { actual <- collect(zipRDD(rdd, rdd1)) expected <- lapply(mockFile, function(x) { list(x, x) }) expect_equal(actual, expected) - + + unlink(fileName) +}) + +test_that("cartesian() on RDDs", { + rdd <- parallelize(sc, 1:3) + actual <- collect(cartesian(rdd, rdd)) + expect_equal(sortKeyValueList(actual), + list( + list(1, 1), list(1, 2), list(1, 3), + list(2, 1), list(2, 2), list(2, 3), + list(3, 1), list(3, 2), list(3, 3))) + + # test case where one RDD is empty + emptyRdd <- parallelize(sc, list()) + actual <- collect(cartesian(rdd, emptyRdd)) + expect_equal(actual, list()) + + mockFile = c("Spark is pretty.", "Spark is awesome.") + fileName <- tempfile(pattern="spark-test", fileext=".tmp") + writeLines(mockFile, fileName) + + rdd <- textFile(sc, fileName) + actual <- collect(cartesian(rdd, rdd)) + expected <- list( + list("Spark is awesome.", "Spark is pretty."), + list("Spark is awesome.", "Spark is awesome."), + list("Spark is pretty.", "Spark is pretty."), + list("Spark is pretty.", "Spark is awesome.")) + expect_equal(sortKeyValueList(actual), expected) + + rdd1 <- parallelize(sc, 0:1) + actual <- collect(cartesian(rdd1, rdd)) + expect_equal(sortKeyValueList(actual), + list( + list(0, "Spark is pretty."), + list(0, "Spark is awesome."), + list(1, "Spark is pretty."), + list(1, "Spark is awesome."))) + + rdd1 <- map(rdd, function(x) { x }) + actual <- collect(cartesian(rdd, rdd1)) + expect_equal(sortKeyValueList(actual), expected) + unlink(fileName) }) +test_that("subtract() on RDDs", { + l <- list(1, 1, 2, 2, 3, 4) + rdd1 <- parallelize(sc, l) + + # subtract by itself + actual <- collect(subtract(rdd1, rdd1)) + expect_equal(actual, list()) + + # subtract by an empty RDD + rdd2 <- parallelize(sc, list()) + actual <- collect(subtract(rdd1, rdd2)) + expect_equal(as.list(sort(as.vector(actual, mode="integer"))), + l) + + rdd2 <- parallelize(sc, list(2, 4)) + actual <- collect(subtract(rdd1, rdd2)) + expect_equal(as.list(sort(as.vector(actual, mode="integer"))), + list(1, 1, 3)) + + l <- list("a", "a", "b", "b", "c", "d") + rdd1 <- parallelize(sc, l) + rdd2 <- parallelize(sc, list("b", "d")) + actual <- collect(subtract(rdd1, rdd2)) + expect_equal(as.list(sort(as.vector(actual, mode="character"))), + list("a", "a", "c")) +}) + +test_that("subtractByKey() on pairwise RDDs", { + l <- list(list("a", 1), list("b", 4), + list("b", 5), list("a", 2)) + rdd1 <- parallelize(sc, l) + + # subtractByKey by itself + actual <- collect(subtractByKey(rdd1, rdd1)) + expect_equal(actual, list()) + + # subtractByKey by an empty RDD + rdd2 <- parallelize(sc, list()) + actual <- collect(subtractByKey(rdd1, rdd2)) + expect_equal(sortKeyValueList(actual), + sortKeyValueList(l)) + + rdd2 <- parallelize(sc, list(list("a", 3), list("c", 1))) + actual <- collect(subtractByKey(rdd1, rdd2)) + expect_equal(actual, + list(list("b", 4), list("b", 5))) + + l <- list(list(1, 1), list(2, 4), + list(2, 5), list(1, 2)) + rdd1 <- parallelize(sc, l) + rdd2 <- parallelize(sc, list(list(1, 3), list(3, 1))) + actual <- collect(subtractByKey(rdd1, rdd2)) + expect_equal(actual, + list(list(2, 4), list(2, 5))) +}) + +test_that("intersection() on RDDs", { + # intersection with self + actual <- collect(intersection(rdd, rdd)) + expect_equal(sort(as.integer(actual)), nums) + + # intersection with an empty RDD + emptyRdd <- parallelize(sc, list()) + actual <- collect(intersection(rdd, emptyRdd)) + expect_equal(actual, list()) + + rdd1 <- parallelize(sc, list(1, 10, 2, 3, 4, 5)) + rdd2 <- parallelize(sc, list(1, 6, 2, 3, 7, 8)) + actual <- collect(intersection(rdd1, rdd2)) + expect_equal(sort(as.integer(actual)), 1:3) +}) + test_that("join() on pairwise RDDs", { rdd1 <- parallelize(sc, list(list(1,1), list(2,4))) rdd2 <- parallelize(sc, list(list(1,2), list(1,3))) @@ -596,9 +711,9 @@ test_that("sortByKey() on pairwise RDDs", { sortedRdd3 <- sortByKey(rdd3) actual <- collect(sortedRdd3) expect_equal(actual, list(list("1", 3), list("2", 5), list("a", 1), list("b", 2), list("d", 4))) - + # test on the boundary cases - + # boundary case 1: the RDD to be sorted has only 1 partition rdd4 <- parallelize(sc, l, 1L) sortedRdd4 <- sortByKey(rdd4) @@ -623,7 +738,7 @@ test_that("sortByKey() on pairwise RDDs", { rdd7 <- parallelize(sc, l3, 2L) sortedRdd7 <- sortByKey(rdd7) actual <- collect(sortedRdd7) - expect_equal(actual, l3) + expect_equal(actual, l3) }) test_that("collectAsMap() on a pairwise RDD", { @@ -634,12 +749,36 @@ test_that("collectAsMap() on a pairwise RDD", { rdd <- parallelize(sc, list(list("a", 1), list("b", 2))) vals <- collectAsMap(rdd) expect_equal(vals, list(a = 1, b = 2)) - + rdd <- parallelize(sc, list(list(1.1, 2.2), list(1.2, 2.4))) vals <- collectAsMap(rdd) expect_equal(vals, list(`1.1` = 2.2, `1.2` = 2.4)) - + rdd <- parallelize(sc, list(list(1, "a"), list(2, "b"))) vals <- collectAsMap(rdd) expect_equal(vals, list(`1` = "a", `2` = "b")) }) + +test_that("sampleByKey() on pairwise RDDs", { + rdd <- parallelize(sc, 1:2000) + pairsRDD <- lapply(rdd, function(x) { if (x %% 2 == 0) list("a", x) else list("b", x) }) + fractions <- list(a = 0.2, b = 0.1) + sample <- sampleByKey(pairsRDD, FALSE, fractions, 1618L) + expect_equal(100 < length(lookup(sample, "a")) && 300 > length(lookup(sample, "a")), TRUE) + expect_equal(50 < length(lookup(sample, "b")) && 150 > length(lookup(sample, "b")), TRUE) + expect_equal(lookup(sample, "a")[which.min(lookup(sample, "a"))] >= 0, TRUE) + expect_equal(lookup(sample, "a")[which.max(lookup(sample, "a"))] <= 2000, TRUE) + expect_equal(lookup(sample, "b")[which.min(lookup(sample, "b"))] >= 0, TRUE) + expect_equal(lookup(sample, "b")[which.max(lookup(sample, "b"))] <= 2000, TRUE) + + rdd <- parallelize(sc, 1:2000) + pairsRDD <- lapply(rdd, function(x) { if (x %% 2 == 0) list(2, x) else list(3, x) }) + fractions <- list(`2` = 0.2, `3` = 0.1) + sample <- sampleByKey(pairsRDD, TRUE, fractions, 1618L) + expect_equal(100 < length(lookup(sample, 2)) && 300 > length(lookup(sample, 2)), TRUE) + expect_equal(50 < length(lookup(sample, 3)) && 150 > length(lookup(sample, 3)), TRUE) + expect_equal(lookup(sample, 2)[which.min(lookup(sample, 2))] >= 0, TRUE) + expect_equal(lookup(sample, 2)[which.max(lookup(sample, 2))] <= 2000, TRUE) + expect_equal(lookup(sample, 3)[which.min(lookup(sample, 3))] >= 0, TRUE) + expect_equal(lookup(sample, 3)[which.max(lookup(sample, 3))] <= 2000, TRUE) +}) diff --git a/R/pkg/inst/tests/test_shuffle.R b/R/pkg/inst/tests/test_shuffle.R index d1da8232ae..d7dedda553 100644 --- a/R/pkg/inst/tests/test_shuffle.R +++ b/R/pkg/inst/tests/test_shuffle.R @@ -87,6 +87,18 @@ test_that("combineByKey for doubles", { expect_equal(sortKeyValueList(actual), sortKeyValueList(expected)) }) +test_that("combineByKey for characters", { + stringKeyRDD <- parallelize(sc, + list(list("max", 1L), list("min", 2L), + list("other", 3L), list("max", 4L)), 2L) + reduced <- combineByKey(stringKeyRDD, + function(x) { x }, "+", "+", 2L) + actual <- collect(reduced) + + expected <- list(list("max", 5L), list("min", 2L), list("other", 3L)) + expect_equal(sortKeyValueList(actual), sortKeyValueList(expected)) +}) + test_that("aggregateByKey", { # test aggregateByKey for int keys rdd <- parallelize(sc, list(list(1, 1), list(1, 2), list(2, 3), list(2, 4))) diff --git a/R/pkg/inst/tests/test_sparkSQL.R b/R/pkg/inst/tests/test_sparkSQL.R index cf5cf6d169..25831ae2d9 100644 --- a/R/pkg/inst/tests/test_sparkSQL.R +++ b/R/pkg/inst/tests/test_sparkSQL.R @@ -44,9 +44,8 @@ test_that("infer types", { expect_equal(infer_type(list(1L, 2L)), list(type = 'array', elementType = "integer", containsNull = TRUE)) expect_equal(infer_type(list(a = 1L, b = "2")), - list(type = "struct", - fields = list(list(name = "a", type = "integer", nullable = TRUE), - list(name = "b", type = "string", nullable = TRUE)))) + structType(structField(x = "a", type = "integer", nullable = TRUE), + structField(x = "b", type = "string", nullable = TRUE))) e <- new.env() assign("a", 1L, envir = e) expect_equal(infer_type(e), @@ -54,6 +53,18 @@ test_that("infer types", { valueContainsNull = TRUE)) }) +test_that("structType and structField", { + testField <- structField("a", "string") + expect_true(inherits(testField, "structField")) + expect_true(testField$name() == "a") + expect_true(testField$nullable()) + + testSchema <- structType(testField, structField("b", "integer")) + expect_true(inherits(testSchema, "structType")) + expect_true(inherits(testSchema$fields()[[2]], "structField")) + expect_true(testSchema$fields()[[1]]$dataType.toString() == "StringType") +}) + test_that("create DataFrame from RDD", { rdd <- lapply(parallelize(sc, 1:10), function(x) { list(x, as.character(x)) }) df <- createDataFrame(sqlCtx, rdd, list("a", "b")) @@ -66,9 +77,8 @@ test_that("create DataFrame from RDD", { expect_true(inherits(df, "DataFrame")) expect_equal(columns(df), c("_1", "_2")) - fields <- list(list(name = "a", type = "integer", nullable = TRUE), - list(name = "b", type = "string", nullable = TRUE)) - schema <- list(type = "struct", fields = fields) + schema <- structType(structField(x = "a", type = "integer", nullable = TRUE), + structField(x = "b", type = "string", nullable = TRUE)) df <- createDataFrame(sqlCtx, rdd, schema) expect_true(inherits(df, "DataFrame")) expect_equal(columns(df), c("a", "b")) @@ -94,9 +104,8 @@ test_that("toDF", { expect_true(inherits(df, "DataFrame")) expect_equal(columns(df), c("_1", "_2")) - fields <- list(list(name = "a", type = "integer", nullable = TRUE), - list(name = "b", type = "string", nullable = TRUE)) - schema <- list(type = "struct", fields = fields) + schema <- structType(structField(x = "a", type = "integer", nullable = TRUE), + structField(x = "b", type = "string", nullable = TRUE)) df <- toDF(rdd, schema) expect_true(inherits(df, "DataFrame")) expect_equal(columns(df), c("a", "b")) @@ -635,7 +644,7 @@ test_that("isLocal()", { expect_false(isLocal(df)) }) -test_that("unionAll(), subtract(), and intersect() on a DataFrame", { +test_that("unionAll(), except(), and intersect() on a DataFrame", { df <- jsonFile(sqlCtx, jsonPath) lines <- c("{\"name\":\"Bob\", \"age\":24}", @@ -650,10 +659,10 @@ test_that("unionAll(), subtract(), and intersect() on a DataFrame", { expect_true(count(unioned) == 6) expect_true(first(unioned)$name == "Michael") - subtracted <- sortDF(subtract(df, df2), desc(df$age)) + excepted <- sortDF(except(df, df2), desc(df$age)) expect_true(inherits(unioned, "DataFrame")) - expect_true(count(subtracted) == 2) - expect_true(first(subtracted)$name == "Justin") + expect_true(count(excepted) == 2) + expect_true(first(excepted)$name == "Justin") intersected <- sortDF(intersect(df, df2), df$age) expect_true(inherits(unioned, "DataFrame")) diff --git a/R/pkg/inst/worker/worker.R b/R/pkg/inst/worker/worker.R index c6542928e8..014bf7bd7b 100644 --- a/R/pkg/inst/worker/worker.R +++ b/R/pkg/inst/worker/worker.R @@ -17,6 +17,23 @@ # Worker class +# Get current system time +currentTimeSecs <- function() { + as.numeric(Sys.time()) +} + +# Get elapsed time +elapsedSecs <- function() { + proc.time()[3] +} + +# Constants +specialLengths <- list(END_OF_STERAM = 0L, TIMING_DATA = -1L) + +# Timing R process boot +bootTime <- currentTimeSecs() +bootElap <- elapsedSecs() + rLibDir <- Sys.getenv("SPARKR_RLIBDIR") # Set libPaths to include SparkR package as loadNamespace needs this # TODO: Figure out if we can avoid this by not loading any objects that require @@ -37,7 +54,7 @@ serializer <- SparkR:::readString(inputCon) # Include packages as required packageNames <- unserialize(SparkR:::readRaw(inputCon)) for (pkg in packageNames) { - suppressPackageStartupMessages(require(as.character(pkg), character.only=TRUE)) + suppressPackageStartupMessages(library(as.character(pkg), character.only=TRUE)) } # read function dependencies @@ -46,6 +63,9 @@ computeFunc <- unserialize(SparkR:::readRawLen(inputCon, funcLen)) env <- environment(computeFunc) parent.env(env) <- .GlobalEnv # Attach under global environment. +# Timing init envs for computing +initElap <- elapsedSecs() + # Read and set broadcast variables numBroadcastVars <- SparkR:::readInt(inputCon) if (numBroadcastVars > 0) { @@ -56,6 +76,9 @@ if (numBroadcastVars > 0) { } } +# Timing broadcast +broadcastElap <- elapsedSecs() + # If -1: read as normal RDD; if >= 0, treat as pairwise RDD and treat the int # as number of partitions to create. numPartitions <- SparkR:::readInt(inputCon) @@ -73,14 +96,23 @@ if (isEmpty != 0) { } else if (deserializer == "row") { data <- SparkR:::readDeserializeRows(inputCon) } + # Timing reading input data for execution + inputElap <- elapsedSecs() + output <- computeFunc(partition, data) + # Timing computing + computeElap <- elapsedSecs() + if (serializer == "byte") { SparkR:::writeRawSerialize(outputCon, output) } else if (serializer == "row") { SparkR:::writeRowSerialize(outputCon, output) } else { - SparkR:::writeStrings(outputCon, output) + # write lines one-by-one with flag + lapply(output, function(line) SparkR:::writeString(outputCon, line)) } + # Timing output + outputElap <- elapsedSecs() } else { if (deserializer == "byte") { # Now read as many characters as described in funcLen @@ -90,6 +122,8 @@ if (isEmpty != 0) { } else if (deserializer == "row") { data <- SparkR:::readDeserializeRows(inputCon) } + # Timing reading input data for execution + inputElap <- elapsedSecs() res <- new.env() @@ -107,6 +141,8 @@ if (isEmpty != 0) { res[[bucket]] <- acc } invisible(lapply(data, hashTupleToEnvir)) + # Timing computing + computeElap <- elapsedSecs() # Step 2: write out all of the environment as key-value pairs. for (name in ls(res)) { @@ -116,13 +152,26 @@ if (isEmpty != 0) { length(res[[name]]$data) <- res[[name]]$counter SparkR:::writeRawSerialize(outputCon, res[[name]]$data) } + # Timing output + outputElap <- elapsedSecs() } +} else { + inputElap <- broadcastElap + computeElap <- broadcastElap + outputElap <- broadcastElap } +# Report timing +SparkR:::writeInt(outputCon, specialLengths$TIMING_DATA) +SparkR:::writeDouble(outputCon, bootTime) +SparkR:::writeDouble(outputCon, initElap - bootElap) # init +SparkR:::writeDouble(outputCon, broadcastElap - initElap) # broadcast +SparkR:::writeDouble(outputCon, inputElap - broadcastElap) # input +SparkR:::writeDouble(outputCon, computeElap - inputElap) # compute +SparkR:::writeDouble(outputCon, outputElap - computeElap) # output + # End of output -if (serializer %in% c("byte", "row")) { - SparkR:::writeInt(outputCon, 0L) -} +SparkR:::writeInt(outputCon, specialLengths$END_OF_STERAM) close(outputCon) close(inputCon) diff --git a/core/src/main/scala/org/apache/spark/api/r/RRDD.scala b/core/src/main/scala/org/apache/spark/api/r/RRDD.scala index 5fa4d483b8..6fea5e1144 100644 --- a/core/src/main/scala/org/apache/spark/api/r/RRDD.scala +++ b/core/src/main/scala/org/apache/spark/api/r/RRDD.scala @@ -42,10 +42,15 @@ private abstract class BaseRRDD[T: ClassTag, U: ClassTag]( rLibDir: String, broadcastVars: Array[Broadcast[Object]]) extends RDD[U](parent) with Logging { + protected var dataStream: DataInputStream = _ + private var bootTime: Double = _ override def getPartitions: Array[Partition] = parent.partitions override def compute(partition: Partition, context: TaskContext): Iterator[U] = { + // Timing start + bootTime = System.currentTimeMillis / 1000.0 + // The parent may be also an RRDD, so we should launch it first. val parentIterator = firstParent[T].iterator(partition, context) @@ -69,7 +74,7 @@ private abstract class BaseRRDD[T: ClassTag, U: ClassTag]( // the socket used to receive the output of task val outSocket = serverSocket.accept() val inputStream = new BufferedInputStream(outSocket.getInputStream) - val dataStream = openDataStream(inputStream) + dataStream = new DataInputStream(inputStream) serverSocket.close() try { @@ -155,6 +160,7 @@ private abstract class BaseRRDD[T: ClassTag, U: ClassTag]( } else if (deserializer == SerializationFormats.ROW) { dataOut.write(elem.asInstanceOf[Array[Byte]]) } else if (deserializer == SerializationFormats.STRING) { + // write string(for StringRRDD) printOut.println(elem) } } @@ -180,9 +186,41 @@ private abstract class BaseRRDD[T: ClassTag, U: ClassTag]( }.start() } - protected def openDataStream(input: InputStream): Closeable + protected def readData(length: Int): U - protected def read(): U + protected def read(): U = { + try { + val length = dataStream.readInt() + + length match { + case SpecialLengths.TIMING_DATA => + // Timing data from R worker + val boot = dataStream.readDouble - bootTime + val init = dataStream.readDouble + val broadcast = dataStream.readDouble + val input = dataStream.readDouble + val compute = dataStream.readDouble + val output = dataStream.readDouble + logInfo( + ("Times: boot = %.3f s, init = %.3f s, broadcast = %.3f s, " + + "read-input = %.3f s, compute = %.3f s, write-output = %.3f s, " + + "total = %.3f s").format( + boot, + init, + broadcast, + input, + compute, + output, + boot + init + broadcast + input + compute + output)) + read() + case length if length >= 0 => + readData(length) + } + } catch { + case eof: EOFException => + throw new SparkException("R worker exited unexpectedly (cranshed)", eof) + } + } } /** @@ -202,31 +240,16 @@ private class PairwiseRRDD[T: ClassTag]( SerializationFormats.BYTE, packageNames, rLibDir, broadcastVars.map(x => x.asInstanceOf[Broadcast[Object]])) { - private var dataStream: DataInputStream = _ - - override protected def openDataStream(input: InputStream): Closeable = { - dataStream = new DataInputStream(input) - dataStream - } - - override protected def read(): (Int, Array[Byte]) = { - try { - val length = dataStream.readInt() - - length match { - case length if length == 2 => - val hashedKey = dataStream.readInt() - val contentPairsLength = dataStream.readInt() - val contentPairs = new Array[Byte](contentPairsLength) - dataStream.readFully(contentPairs) - (hashedKey, contentPairs) - case _ => null // End of input - } - } catch { - case eof: EOFException => { - throw new SparkException("R worker exited unexpectedly (crashed)", eof) - } - } + override protected def readData(length: Int): (Int, Array[Byte]) = { + length match { + case length if length == 2 => + val hashedKey = dataStream.readInt() + val contentPairsLength = dataStream.readInt() + val contentPairs = new Array[Byte](contentPairsLength) + dataStream.readFully(contentPairs) + (hashedKey, contentPairs) + case _ => null + } } lazy val asJavaPairRDD : JavaPairRDD[Int, Array[Byte]] = JavaPairRDD.fromRDD(this) @@ -247,28 +270,13 @@ private class RRDD[T: ClassTag]( parent, -1, func, deserializer, serializer, packageNames, rLibDir, broadcastVars.map(x => x.asInstanceOf[Broadcast[Object]])) { - private var dataStream: DataInputStream = _ - - override protected def openDataStream(input: InputStream): Closeable = { - dataStream = new DataInputStream(input) - dataStream - } - - override protected def read(): Array[Byte] = { - try { - val length = dataStream.readInt() - - length match { - case length if length > 0 => - val obj = new Array[Byte](length) - dataStream.readFully(obj, 0, length) - obj - case _ => null - } - } catch { - case eof: EOFException => { - throw new SparkException("R worker exited unexpectedly (crashed)", eof) - } + override protected def readData(length: Int): Array[Byte] = { + length match { + case length if length > 0 => + val obj = new Array[Byte](length) + dataStream.readFully(obj) + obj + case _ => null } } @@ -289,26 +297,21 @@ private class StringRRDD[T: ClassTag]( parent, -1, func, deserializer, SerializationFormats.STRING, packageNames, rLibDir, broadcastVars.map(x => x.asInstanceOf[Broadcast[Object]])) { - private var dataStream: BufferedReader = _ - - override protected def openDataStream(input: InputStream): Closeable = { - dataStream = new BufferedReader(new InputStreamReader(input)) - dataStream - } - - override protected def read(): String = { - try { - dataStream.readLine() - } catch { - case e: IOException => { - throw new SparkException("R worker exited unexpectedly (crashed)", e) - } + override protected def readData(length: Int): String = { + length match { + case length if length > 0 => + SerDe.readStringBytes(dataStream, length) + case _ => null } } lazy val asJavaRDD : JavaRDD[String] = JavaRDD.fromRDD(this) } +private object SpecialLengths { + val TIMING_DATA = -1 +} + private[r] class BufferedStreamThread( in: InputStream, name: String, diff --git a/core/src/main/scala/org/apache/spark/api/r/SerDe.scala b/core/src/main/scala/org/apache/spark/api/r/SerDe.scala index ccb2a371f4..371dfe454d 100644 --- a/core/src/main/scala/org/apache/spark/api/r/SerDe.scala +++ b/core/src/main/scala/org/apache/spark/api/r/SerDe.scala @@ -85,13 +85,17 @@ private[spark] object SerDe { in.readDouble() } + def readStringBytes(in: DataInputStream, len: Int): String = { + val bytes = new Array[Byte](len) + in.readFully(bytes) + assert(bytes(len - 1) == 0) + val str = new String(bytes.dropRight(1), "UTF-8") + str + } + def readString(in: DataInputStream): String = { val len = in.readInt() - val asciiBytes = new Array[Byte](len) - in.readFully(asciiBytes) - assert(asciiBytes(len - 1) == 0) - val str = new String(asciiBytes.dropRight(1).map(_.toChar)) - str + readStringBytes(in, len) } def readBoolean(in: DataInputStream): Boolean = { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/api/r/SQLUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/api/r/SQLUtils.scala index d1ea7cc3e9..ae77f72998 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/api/r/SQLUtils.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/api/r/SQLUtils.scala @@ -23,7 +23,7 @@ import org.apache.spark.api.java.{JavaRDD, JavaSparkContext} import org.apache.spark.api.r.SerDe import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.expressions.{Alias, Expression, NamedExpression} -import org.apache.spark.sql.types.{DataType, StructType} +import org.apache.spark.sql.types._ import org.apache.spark.sql.{Column, DataFrame, GroupedData, Row, SQLContext, SaveMode} private[r] object SQLUtils { @@ -39,8 +39,34 @@ private[r] object SQLUtils { arr.toSeq } - def createDF(rdd: RDD[Array[Byte]], schemaString: String, sqlContext: SQLContext): DataFrame = { - val schema = DataType.fromJson(schemaString).asInstanceOf[StructType] + def createStructType(fields : Seq[StructField]): StructType = { + StructType(fields) + } + + def getSQLDataType(dataType: String): DataType = { + dataType match { + case "byte" => org.apache.spark.sql.types.ByteType + case "integer" => org.apache.spark.sql.types.IntegerType + case "double" => org.apache.spark.sql.types.DoubleType + case "numeric" => org.apache.spark.sql.types.DoubleType + case "character" => org.apache.spark.sql.types.StringType + case "string" => org.apache.spark.sql.types.StringType + case "binary" => org.apache.spark.sql.types.BinaryType + case "raw" => org.apache.spark.sql.types.BinaryType + case "logical" => org.apache.spark.sql.types.BooleanType + case "boolean" => org.apache.spark.sql.types.BooleanType + case "timestamp" => org.apache.spark.sql.types.TimestampType + case "date" => org.apache.spark.sql.types.DateType + case _ => throw new IllegalArgumentException(s"Invaid type $dataType") + } + } + + def createStructField(name: String, dataType: String, nullable: Boolean): StructField = { + val dtObj = getSQLDataType(dataType) + StructField(name, dtObj, nullable) + } + + def createDF(rdd: RDD[Array[Byte]], schema: StructType, sqlContext: SQLContext): DataFrame = { val num = schema.fields.size val rowRDD = rdd.map(bytesToRow) sqlContext.createDataFrame(rowRDD, schema) -- GitLab