Skip to content
Snippets Groups Projects
Commit 50da9e89 authored by qhuang's avatar qhuang Committed by Shivaram Venkataraman
Browse files

[SPARK-7226] [SPARKR] Support math functions in R DataFrame

Author: qhuang <qian.huang@intel.com>

Closes #6170 from hqzizania/master and squashes the following commits:

f20c39f [qhuang] add tests units and fixes
2a7d121 [qhuang] use a function name more familiar to R users
07aa72e [qhuang] Support math functions in R DataFrame
parent 9b6cf285
No related branches found
No related tags found
No related merge requests found
...@@ -59,33 +59,56 @@ exportMethods("arrange", ...@@ -59,33 +59,56 @@ exportMethods("arrange",
exportClasses("Column") exportClasses("Column")
exportMethods("abs", exportMethods("abs",
"acos",
"alias", "alias",
"approxCountDistinct", "approxCountDistinct",
"asc", "asc",
"asin",
"atan",
"atan2",
"avg", "avg",
"cast", "cast",
"cbrt",
"ceiling",
"contains", "contains",
"cos",
"cosh",
"countDistinct", "countDistinct",
"desc", "desc",
"endsWith", "endsWith",
"exp",
"expm1",
"floor",
"getField", "getField",
"getItem", "getItem",
"hypot",
"isNotNull", "isNotNull",
"isNull", "isNull",
"last", "last",
"like", "like",
"log",
"log10",
"log1p",
"lower", "lower",
"max", "max",
"mean", "mean",
"min", "min",
"n", "n",
"n_distinct", "n_distinct",
"rint",
"rlike", "rlike",
"sign",
"sin",
"sinh",
"sqrt", "sqrt",
"startsWith", "startsWith",
"substr", "substr",
"sum", "sum",
"sumDistinct", "sumDistinct",
"tan",
"tanh",
"toDegrees",
"toRadians",
"upper") "upper")
exportClasses("GroupedData") exportClasses("GroupedData")
......
...@@ -55,12 +55,17 @@ operators <- list( ...@@ -55,12 +55,17 @@ operators <- list(
"+" = "plus", "-" = "minus", "*" = "multiply", "/" = "divide", "%%" = "mod", "+" = "plus", "-" = "minus", "*" = "multiply", "/" = "divide", "%%" = "mod",
"==" = "equalTo", ">" = "gt", "<" = "lt", "!=" = "notEqual", "<=" = "leq", ">=" = "geq", "==" = "equalTo", ">" = "gt", "<" = "lt", "!=" = "notEqual", "<=" = "leq", ">=" = "geq",
# we can not override `&&` and `||`, so use `&` and `|` instead # we can not override `&&` and `||`, so use `&` and `|` instead
"&" = "and", "|" = "or" #, "!" = "unary_$bang" "&" = "and", "|" = "or", #, "!" = "unary_$bang"
"^" = "pow"
) )
column_functions1 <- c("asc", "desc", "isNull", "isNotNull") column_functions1 <- c("asc", "desc", "isNull", "isNotNull")
column_functions2 <- c("like", "rlike", "startsWith", "endsWith", "getField", "getItem", "contains") column_functions2 <- c("like", "rlike", "startsWith", "endsWith", "getField", "getItem", "contains")
functions <- c("min", "max", "sum", "avg", "mean", "count", "abs", "sqrt", functions <- c("min", "max", "sum", "avg", "mean", "count", "abs", "sqrt",
"first", "last", "lower", "upper", "sumDistinct") "first", "last", "lower", "upper", "sumDistinct",
"acos", "asin", "atan", "cbrt", "ceiling", "cos", "cosh", "exp",
"expm1", "floor", "log", "log10", "log1p", "rint", "sign",
"sin", "sinh", "tan", "tanh", "toDegrees", "toRadians")
binary_mathfunctions<- c("atan2", "hypot")
createOperator <- function(op) { createOperator <- function(op) {
setMethod(op, setMethod(op,
...@@ -76,7 +81,11 @@ createOperator <- function(op) { ...@@ -76,7 +81,11 @@ createOperator <- function(op) {
if (class(e2) == "Column") { if (class(e2) == "Column") {
e2 <- e2@jc e2 <- e2@jc
} }
callJMethod(e1@jc, operators[[op]], e2) if (op == "^") {
jc <- callJStatic("org.apache.spark.sql.functions", operators[[op]], e1@jc, e2)
} else {
callJMethod(e1@jc, operators[[op]], e2)
}
} }
column(jc) column(jc)
}) })
...@@ -106,11 +115,29 @@ createStaticFunction <- function(name) { ...@@ -106,11 +115,29 @@ createStaticFunction <- function(name) {
setMethod(name, setMethod(name,
signature(x = "Column"), signature(x = "Column"),
function(x) { function(x) {
if (name == "ceiling") {
name <- "ceil"
}
if (name == "sign") {
name <- "signum"
}
jc <- callJStatic("org.apache.spark.sql.functions", name, x@jc) jc <- callJStatic("org.apache.spark.sql.functions", name, x@jc)
column(jc) column(jc)
}) })
} }
createBinaryMathfunctions <- function(name) {
setMethod(name,
signature(y = "Column"),
function(y, x) {
if (class(x) == "Column") {
x <- x@jc
}
jc <- callJStatic("org.apache.spark.sql.functions", name, y@jc, x)
column(jc)
})
}
createMethods <- function() { createMethods <- function() {
for (op in names(operators)) { for (op in names(operators)) {
createOperator(op) createOperator(op)
...@@ -124,6 +151,9 @@ createMethods <- function() { ...@@ -124,6 +151,9 @@ createMethods <- function() {
for (x in functions) { for (x in functions) {
createStaticFunction(x) createStaticFunction(x)
} }
for (name in binary_mathfunctions) {
createBinaryMathfunctions(name)
}
} }
createMethods() createMethods()
......
...@@ -552,6 +552,10 @@ setGeneric("avg", function(x, ...) { standardGeneric("avg") }) ...@@ -552,6 +552,10 @@ setGeneric("avg", function(x, ...) { standardGeneric("avg") })
#' @export #' @export
setGeneric("cast", function(x, dataType) { standardGeneric("cast") }) setGeneric("cast", function(x, dataType) { standardGeneric("cast") })
#' @rdname column
#' @export
setGeneric("cbrt", function(x) { standardGeneric("cbrt") })
#' @rdname column #' @rdname column
#' @export #' @export
setGeneric("contains", function(x, ...) { standardGeneric("contains") }) setGeneric("contains", function(x, ...) { standardGeneric("contains") })
...@@ -575,6 +579,10 @@ setGeneric("getField", function(x, ...) { standardGeneric("getField") }) ...@@ -575,6 +579,10 @@ setGeneric("getField", function(x, ...) { standardGeneric("getField") })
#' @export #' @export
setGeneric("getItem", function(x, ...) { standardGeneric("getItem") }) setGeneric("getItem", function(x, ...) { standardGeneric("getItem") })
#' @rdname column
#' @export
setGeneric("hypot", function(y, x) { standardGeneric("hypot") })
#' @rdname column #' @rdname column
#' @export #' @export
setGeneric("isNull", function(x) { standardGeneric("isNull") }) setGeneric("isNull", function(x) { standardGeneric("isNull") })
...@@ -603,6 +611,10 @@ setGeneric("n", function(x) { standardGeneric("n") }) ...@@ -603,6 +611,10 @@ setGeneric("n", function(x) { standardGeneric("n") })
#' @export #' @export
setGeneric("n_distinct", function(x, ...) { standardGeneric("n_distinct") }) setGeneric("n_distinct", function(x, ...) { standardGeneric("n_distinct") })
#' @rdname column
#' @export
setGeneric("rint", function(x, ...) { standardGeneric("rint") })
#' @rdname column #' @rdname column
#' @export #' @export
setGeneric("rlike", function(x, ...) { standardGeneric("rlike") }) setGeneric("rlike", function(x, ...) { standardGeneric("rlike") })
...@@ -615,6 +627,14 @@ setGeneric("startsWith", function(x, ...) { standardGeneric("startsWith") }) ...@@ -615,6 +627,14 @@ setGeneric("startsWith", function(x, ...) { standardGeneric("startsWith") })
#' @export #' @export
setGeneric("sumDistinct", function(x) { standardGeneric("sumDistinct") }) setGeneric("sumDistinct", function(x) { standardGeneric("sumDistinct") })
#' @rdname column
#' @export
setGeneric("toDegrees", function(x) { standardGeneric("toDegrees") })
#' @rdname column
#' @export
setGeneric("toRadians", function(x) { standardGeneric("toRadians") })
#' @rdname column #' @rdname column
#' @export #' @export
setGeneric("upper", function(x) { standardGeneric("upper") }) setGeneric("upper", function(x) { standardGeneric("upper") })
......
...@@ -530,6 +530,7 @@ test_that("column operators", { ...@@ -530,6 +530,7 @@ test_that("column operators", {
c2 <- (- c + 1 - 2) * 3 / 4.0 c2 <- (- c + 1 - 2) * 3 / 4.0
c3 <- (c + c2 - c2) * c2 %% c2 c3 <- (c + c2 - c2) * c2 %% c2
c4 <- (c > c2) & (c2 <= c3) | (c == c2) & (c2 != c3) c4 <- (c > c2) & (c2 <= c3) | (c == c2) & (c2 != c3)
c5 <- c2 ^ c3 ^ c4
}) })
test_that("column functions", { test_that("column functions", {
...@@ -538,6 +539,29 @@ test_that("column functions", { ...@@ -538,6 +539,29 @@ test_that("column functions", {
c3 <- lower(c) + upper(c) + first(c) + last(c) c3 <- lower(c) + upper(c) + first(c) + last(c)
c4 <- approxCountDistinct(c) + countDistinct(c) + cast(c, "string") c4 <- approxCountDistinct(c) + countDistinct(c) + cast(c, "string")
c5 <- n(c) + n_distinct(c) c5 <- n(c) + n_distinct(c)
c5 <- acos(c) + asin(c) + atan(c) + cbrt(c)
c6 <- ceiling(c) + cos(c) + cosh(c) + exp(c) + expm1(c)
c7 <- floor(c) + log(c) + log10(c) + log1p(c) + rint(c)
c8 <- sign(c) + sin(c) + sinh(c) + tan(c) + tanh(c)
c9 <- toDegrees(c) + toRadians(c)
})
test_that("column binary mathfunctions", {
lines <- c("{\"a\":1, \"b\":5}",
"{\"a\":2, \"b\":6}",
"{\"a\":3, \"b\":7}",
"{\"a\":4, \"b\":8}")
jsonPathWithDup <- tempfile(pattern="sparkr-test", fileext=".tmp")
writeLines(lines, jsonPathWithDup)
df <- jsonFile(sqlCtx, jsonPathWithDup)
expect_equal(collect(select(df, atan2(df$a, df$b)))[1, "ATAN2(a, b)"], atan2(1, 5))
expect_equal(collect(select(df, atan2(df$a, df$b)))[2, "ATAN2(a, b)"], atan2(2, 6))
expect_equal(collect(select(df, atan2(df$a, df$b)))[3, "ATAN2(a, b)"], atan2(3, 7))
expect_equal(collect(select(df, atan2(df$a, df$b)))[4, "ATAN2(a, b)"], atan2(4, 8))
expect_equal(collect(select(df, hypot(df$a, df$b)))[1, "HYPOT(a, b)"], sqrt(1^2 + 5^2))
expect_equal(collect(select(df, hypot(df$a, df$b)))[2, "HYPOT(a, b)"], sqrt(2^2 + 6^2))
expect_equal(collect(select(df, hypot(df$a, df$b)))[3, "HYPOT(a, b)"], sqrt(3^2 + 7^2))
expect_equal(collect(select(df, hypot(df$a, df$b)))[4, "HYPOT(a, b)"], sqrt(4^2 + 8^2))
}) })
test_that("string operators", { test_that("string operators", {
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment