diff --git a/R/pkg/NAMESPACE b/R/pkg/NAMESPACE
index 1432ab8a9d1ce3567627ed348c881ae4e1ba305d..239ad065d09ad67696997d19566a59778a20640d 100644
--- a/R/pkg/NAMESPACE
+++ b/R/pkg/NAMESPACE
@@ -47,6 +47,7 @@ exportMethods("arrange",
               "covar_pop",
               "crosstab",
               "dapply",
+              "dapplyCollect",
               "describe",
               "dim",
               "distinct",
diff --git a/R/pkg/R/DataFrame.R b/R/pkg/R/DataFrame.R
index 43c46b847446b17be2119b351c512e6990c1b863..0c2a194483b0f93b4a34e6e676b6b44c67b7a20e 100644
--- a/R/pkg/R/DataFrame.R
+++ b/R/pkg/R/DataFrame.R
@@ -1153,9 +1153,27 @@ setMethod("summarize",
             agg(x, ...)
           })
 
+dapplyInternal <- function(x, func, schema) {
+  packageNamesArr <- serialize(.sparkREnv[[".packages"]],
+                               connection = NULL)
+
+  broadcastArr <- lapply(ls(.broadcastNames),
+                         function(name) { get(name, .broadcastNames) })
+
+  sdf <- callJStatic(
+           "org.apache.spark.sql.api.r.SQLUtils",
+           "dapply",
+           x@sdf,
+           serialize(cleanClosure(func), connection = NULL),
+           packageNamesArr,
+           broadcastArr,
+           if (is.null(schema)) { schema } else { schema$jobj })
+  dataFrame(sdf)
+}
+
 #' dapply
 #'
-#' Apply a function to each partition of a DataFrame.
+#' Apply a function to each partition of a SparkDataFrame.
 #'
 #' @param x A SparkDataFrame
 #' @param func A function to be applied to each partition of the SparkDataFrame.
@@ -1197,21 +1215,57 @@ setMethod("summarize",
 setMethod("dapply",
           signature(x = "SparkDataFrame", func = "function", schema = "structType"),
           function(x, func, schema) {
-            packageNamesArr <- serialize(.sparkREnv[[".packages"]],
-                                         connection = NULL)
-
-            broadcastArr <- lapply(ls(.broadcastNames),
-                                   function(name) { get(name, .broadcastNames) })
-
-            sdf <- callJStatic(
-                     "org.apache.spark.sql.api.r.SQLUtils",
-                     "dapply",
-                     x@sdf,
-                     serialize(cleanClosure(func), connection = NULL),
-                     packageNamesArr,
-                     broadcastArr,
-                     schema$jobj)
-            dataFrame(sdf)
+            dapplyInternal(x, func, schema)
           })
+
+#' dapplyCollect
+#'
+#' Apply a function to each partition of a SparkDataFrame and collect the result back
+#' to R as a data.frame.
+#'
+#' @param x A SparkDataFrame
+#' @param func A function to be applied to each partition of the SparkDataFrame.
+#'             func should have only one parameter, to which a data.frame corresponding
+#'             to each partition will be passed.
+#'             The output of func should be a data.frame.
+#' @family SparkDataFrame functions
+#' @rdname dapply
+#' @name dapplyCollect
+#' @export
+#' @examples
+#' \dontrun{
+#'   df <- createDataFrame(sqlContext, iris)
+#'   ldf <- dapplyCollect(df, function(x) { x })
+#'
+#'   # filter and add a column
+#'   df <- createDataFrame(
+#'           sqlContext,
+#'           list(list(1L, 1, "1"), list(2L, 2, "2"), list(3L, 3, "3")),
+#'           c("a", "b", "c"))
+#'   ldf <- dapplyCollect(
+#'            df,
+#'            function(x) {
+#'              y <- x[x[1] > 1, ]
+#'              y <- cbind(y, y[1] + 1L)
+#'            })
+#'   # the result
+#'   # a b c d
+#'   # 2 2 2 3
+#'   # 3 3 3 4
+#' }
+setMethod("dapplyCollect",
+          signature(x = "SparkDataFrame", func = "function"),
+          function(x, func) {
+            df <- dapplyInternal(x, func, NULL)
+
+            content <- callJMethod(df@sdf, "collect")
+            # content is a list of items of struct type. Each item has a single field,
+            # which is a serialized data.frame corresponding to one partition of the
+            # SparkDataFrame.
+            ldfs <- lapply(content, function(x) { unserialize(x[[1]]) })
+            ldf <- do.call(rbind, ldfs)
+            row.names(ldf) <- NULL
+            ldf
+          })

 ############################## RDD Map Functions ##################################
diff --git a/R/pkg/R/generics.R b/R/pkg/R/generics.R
index 8563be1e649838379d4c56142bb26aa9517675f2..ed76ad6b73c8bfd321466d6439b2de7d365b43df 100644
--- a/R/pkg/R/generics.R
+++ b/R/pkg/R/generics.R
@@ -450,6 +450,10 @@ setGeneric("covar_pop", function(col1, col2) {standardGeneric("covar_pop") })
 #' @export
 setGeneric("dapply", function(x, func, schema) { standardGeneric("dapply") })

+#' @rdname dapply
+#' @export
+setGeneric("dapplyCollect", function(x, func) { standardGeneric("dapplyCollect") })
+
 #' @rdname summary
 #' @export
 setGeneric("describe", function(x, col, ...) { standardGeneric("describe") })
diff --git a/R/pkg/inst/tests/testthat/test_sparkSQL.R b/R/pkg/inst/tests/testthat/test_sparkSQL.R
index 0f67bc2e331d1e7f73d840b957badb9f2518496b..6a99b43e5aa5920df522e9c6fdf2f927f8df1497 100644
--- a/R/pkg/inst/tests/testthat/test_sparkSQL.R
+++ b/R/pkg/inst/tests/testthat/test_sparkSQL.R
@@ -2043,7 +2043,7 @@ test_that("Histogram", {
   expect_equal(histogram(df, "x")$counts, c(4, 0, 0, 0, 0, 0, 0, 0, 0, 1))
 })

-test_that("dapply() on a DataFrame", {
+test_that("dapply() and dapplyCollect() on a DataFrame", {
   df <- createDataFrame (
           sqlContext,
           list(list(1L, 1, "1"), list(2L, 2, "2"), list(3L, 3, "3")),
@@ -2053,6 +2053,8 @@ test_that("dapply() on a DataFrame", {
   result <- collect(df1)
   expect_identical(ldf, result)

+  result <- dapplyCollect(df, function(x) { x })
+  expect_identical(ldf, result)
   # Filter and add a column
   schema <- structType(structField("a", "integer"), structField("b", "double"),
                        structField("c", "string"), structField("d", "integer"))
@@ -2070,6 +2072,16 @@ test_that("dapply() on a DataFrame", {
   rownames(expected) <- NULL
   expect_identical(expected, result)

+  result <- dapplyCollect(
+              df,
+              function(x) {
+                y <- x[x$a > 1, ]
+                y <- cbind(y, y$a + 1L)
+              })
+  expected1 <- expected
+  names(expected1) <- names(result)
+  expect_identical(expected1, result)
+
   # Remove the added column
   df2 <- dapply(
            df1,
@@ -2080,6 +2092,13 @@ test_that("dapply() on a DataFrame", {
   result <- collect(df2)
   expected <- expected[, c("a", "b", "c")]
   expect_identical(expected, result)
+
+  result <- dapplyCollect(
+              df1,
+              function(x) {
+                x[, c("a", "b", "c")]
+              })
+  expect_identical(expected, result)
 })

 test_that("repartition by columns on DataFrame", {
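
Usage sketch (not part of the patch): the snippet below contrasts the existing dapply() with the new dapplyCollect(), using the same toy data as the tests above. It assumes a running SparkR session of this era in which `sqlContext` is already defined, as in the roxygen examples.

# Assumes a SparkR session with `sqlContext` available.
df <- createDataFrame(
        sqlContext,
        list(list(1L, 1, "1"), list(2L, 2, "2"), list(3L, 3, "3")),
        c("a", "b", "c"))

# dapply() keeps the result distributed as a SparkDataFrame, so the output
# schema must be declared up front and collect() called separately.
schema <- structType(structField("a", "integer"), structField("b", "double"),
                     structField("c", "string"), structField("d", "integer"))
df1 <- dapply(df, function(x) { cbind(x, x$a + 1L) }, schema)
ldf1 <- collect(df1)

# dapplyCollect() applies the same partition-wise function but brings the
# result straight back to the driver as a local data.frame; no schema needed.
ldf2 <- dapplyCollect(df, function(x) { cbind(x, x$a + 1L) })

Both paths go through dapplyInternal(); dapplyCollect() passes NULL for the schema, in which case each output partition comes back as a single serialized R data.frame, which the driver unserializes and combines with do.call(rbind, ...), resetting row names afterwards.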