diff --git a/R/pkg/NAMESPACE b/R/pkg/NAMESPACE
index ff116cb1fbde26e9c4d16bb428cd8dc63b1da7d2..b2d92bdf4840e3808053c98392cfb2498cd5db74 100644
--- a/R/pkg/NAMESPACE
+++ b/R/pkg/NAMESPACE
@@ -46,6 +46,7 @@ exportMethods("arrange",
               "isLocal",
               "join",
               "limit",
+              "merge",
               "names",
               "ncol",
               "nrow",
@@ -69,6 +70,7 @@ exportMethods("arrange",
               "show",
               "showDF",
               "summarize",
+              "summary",
               "take",
               "unionAll",
               "unique",
diff --git a/R/pkg/R/DataFrame.R b/R/pkg/R/DataFrame.R
index b4065d2944bdc30533b3adf68cca1f8effcfde7a..895603235011eb319ebc4e2b20d3747fa7fe1491 100644
--- a/R/pkg/R/DataFrame.R
+++ b/R/pkg/R/DataFrame.R
@@ -1279,6 +1279,19 @@ setMethod("join",
             dataFrame(sdf)
           })
 
+#' @title Merge
+#'
+#' @description Synonym for join: passes x, y, joinExpr and joinType through to join
+#'
+#' @rdname merge
+#' @aliases join
+setMethod("merge",
+          signature(x = "DataFrame", y = "DataFrame"),
+          function(x, y, joinExpr = NULL, joinType = NULL, ...) {
+            join(x, y, joinExpr, joinType)
+          })
+
+
 #' UnionAll
 #'
 #' Return a new DataFrame containing the union of rows in this DataFrame
@@ -1524,6 +1537,19 @@ setMethod("describe",
             dataFrame(sdf)
           })
 
+#' @title Summary
+#'
+#' @description Computes statistics for numeric columns of the DataFrame
+#'
+#' @rdname summary
+#' @aliases describe
+setMethod("summary",
+          signature(x = "DataFrame"),
+          function(x) {
+            describe(x)
+          })
+
+
 #' dropna
 #'
 #' Returns a new DataFrame omitting rows with null values.
diff --git a/R/pkg/R/generics.R b/R/pkg/R/generics.R
index 71d1e348c4efb96048fddaa3d5ee9d4b3af83d26..c43b947129e87a500c0880b206784d06e99cd0eb 100644
--- a/R/pkg/R/generics.R
+++ b/R/pkg/R/generics.R
@@ -461,6 +461,10 @@ setGeneric("isLocal", function(x) { standardGeneric("isLocal") })
 #' @export
 setGeneric("limit", function(x, num) {standardGeneric("limit") })
 
+#' @rdname merge
+#' @export
+setGeneric("merge")
+
 #' @rdname withColumn
 #' @export
 setGeneric("mutate", function(x, ...) {standardGeneric("mutate") })
@@ -531,6 +535,10 @@ setGeneric("showDF", function(x,...) { standardGeneric("showDF") })
 #' @export
 setGeneric("summarize", function(x,...) { standardGeneric("summarize") })
 
+#' @rdname summary
+#' @export
+setGeneric("summary", function(x, ...) { standardGeneric("summary") })
+
 # @rdname tojson
 # @export
 setGeneric("toJSON", function(x) { standardGeneric("toJSON") })
diff --git a/R/pkg/R/mllib.R b/R/pkg/R/mllib.R
index efddcc1d8d71c587d4ddcb041ef5480c5729632d..b524d1fd874966f6b1c093e5b3cb7672d940490e 100644
--- a/R/pkg/R/mllib.R
+++ b/R/pkg/R/mllib.R
@@ -86,12 +86,12 @@ setMethod("predict", signature(object = "PipelineModel"),
 #' model <- glm(y ~ x, trainingData)
 #' summary(model)
 #'}
-setMethod("summary", signature(object = "PipelineModel"),
-          function(object) {
+setMethod("summary", signature(x = "PipelineModel"),
+          function(x, ...) {
             features <- callJStatic("org.apache.spark.ml.api.r.SparkRWrappers",
-                                   "getModelFeatures", object@model)
+                                   "getModelFeatures", x@model)
             weights <- callJStatic("org.apache.spark.ml.api.r.SparkRWrappers",
-                                  "getModelWeights", object@model)
+                                  "getModelWeights", x@model)
             coefficients <- as.matrix(unlist(weights))
             colnames(coefficients) <- c("Estimate")
             rownames(coefficients) <- unlist(features)
diff --git a/R/pkg/inst/tests/test_sparkSQL.R b/R/pkg/inst/tests/test_sparkSQL.R
index 9faee8d59c3afe63e0fc4e059be40973caf74b6a..7377fc8f1ca9c36b1dccc1bf2cc248f867b560f8 100644
--- a/R/pkg/inst/tests/test_sparkSQL.R
+++ b/R/pkg/inst/tests/test_sparkSQL.R
@@ -765,7 +765,7 @@ test_that("filter() on a DataFrame", {
   expect_equal(count(filtered6), 2)
 })
 
-test_that("join() on a DataFrame", {
+test_that("join() and merge() on a DataFrame", {
   df <- jsonFile(sqlContext, jsonPath)
 
   mockLines2 <- c("{\"name\":\"Michael\", \"test\": \"yes\"}",
@@ -794,6 +794,12 @@ test_that("join() on a DataFrame", {
   expect_equal(names(joined4), c("newAge", "name", "test"))
   expect_equal(count(joined4), 4)
   expect_equal(collect(orderBy(joined4, joined4$name))$newAge[3], 24)
+
+  merged <- select(merge(df, df2, df$name == df2$name, "outer"),
+                   alias(df$age + 5, "newAge"), df$name, df2$test)
+  expect_equal(names(merged), c("newAge", "name", "test"))
+  expect_equal(count(merged), 4)
+  expect_equal(collect(orderBy(merged, merged$name))$newAge[3], 24)
 })
 
 test_that("toJSON() returns an RDD of the correct values", {
@@ -899,7 +905,7 @@ test_that("parquetFile works with multiple input paths", {
   expect_equal(count(parquetDF), count(df) * 2)
 })
 
-test_that("describe() on a DataFrame", {
+test_that("describe() and summary() on a DataFrame", {
   df <- jsonFile(sqlContext, jsonPath)
   stats <- describe(df, "age")
   expect_equal(collect(stats)[1, "summary"], "count")
@@ -908,6 +914,10 @@ test_that("describe() on a DataFrame", {
   stats <- describe(df)
   expect_equal(collect(stats)[4, "name"], "Andy")
   expect_equal(collect(stats)[5, "age"], "30")
+
+  stats2 <- summary(df)
+  expect_equal(collect(stats2)[4, "name"], "Andy")
+  expect_equal(collect(stats2)[5, "age"], "30")
 })
 
 test_that("dropna() on a DataFrame", {