diff --git a/R/pkg/NAMESPACE b/R/pkg/NAMESPACE index 10b9d16279308f5719520bd1fc22c707e98b6957..667fff7192b598265f12123703d9689e75a85524 100644 --- a/R/pkg/NAMESPACE +++ b/R/pkg/NAMESPACE @@ -126,6 +126,7 @@ exportMethods("%in%", "between", "bin", "bitwiseNOT", + "bround", "cast", "cbrt", "ceil", diff --git a/R/pkg/R/functions.R b/R/pkg/R/functions.R index db877b2d63d3004883318aed9f88553b18dea810..54234b0455eab39a68a978501ab5e473992b6e4c 100644 --- a/R/pkg/R/functions.R +++ b/R/pkg/R/functions.R @@ -994,7 +994,7 @@ setMethod("rint", #' round #' -#' Returns the value of the column `e` rounded to 0 decimal places. +#' Returns the value of the column `e` rounded to 0 decimal places using HALF_UP rounding mode. #' #' @rdname round #' @name round @@ -1008,6 +1008,26 @@ setMethod("round", column(jc) }) +#' bround +#' +#' Returns the value of the column `e` rounded to `scale` decimal places using HALF_EVEN rounding +#' mode if `scale` >= 0 or at integral part when `scale` < 0. +#' Also known as Gaussian rounding or bankers' rounding that rounds to the nearest even number. +#' bround(2.5, 0) = 2, bround(3.5, 0) = 4. +#' +#' @rdname bround +#' @name bround +#' @family math_funcs +#' @export +#' @examples \dontrun{bround(df$c, 0)} +setMethod("bround", + signature(x = "Column"), + function(x, scale = 0) { + jc <- callJStatic("org.apache.spark.sql.functions", "bround", x@jc, as.integer(scale)) + column(jc) + }) + + #' rtrim #' #' Trim the spaces from right end for the specified string value. diff --git a/R/pkg/R/generics.R b/R/pkg/R/generics.R index a71be55bcae81c281580ec1bc8891a845d8588f3..6b67258d77e6c7f4f16b89dd9b81dd67725e7210 100644 --- a/R/pkg/R/generics.R +++ b/R/pkg/R/generics.R @@ -760,6 +760,10 @@ setGeneric("bin", function(x) { standardGeneric("bin") }) #' @export setGeneric("bitwiseNOT", function(x) { standardGeneric("bitwiseNOT") }) +#' @rdname bround +#' @export +setGeneric("bround", function(x, ...) { standardGeneric("bround") }) + #' @rdname cbrt #' @export setGeneric("cbrt", function(x) { standardGeneric("cbrt") }) diff --git a/R/pkg/inst/tests/testthat/test_sparkSQL.R b/R/pkg/inst/tests/testthat/test_sparkSQL.R index 2f65484fcbdd86f1a9af77fef8cfe085b427b805..b923ccf6bb1aeef1a46595808499e338211e5711 100644 --- a/R/pkg/inst/tests/testthat/test_sparkSQL.R +++ b/R/pkg/inst/tests/testthat/test_sparkSQL.R @@ -1087,6 +1087,11 @@ test_that("column functions", { expect_equal(collect(select(df, last(df$age, TRUE)))[[1]], 19) expect_equal(collect(select(df, last("age")))[[1]], 19) expect_equal(collect(select(df, last("age", TRUE)))[[1]], 19) + + # Test bround() + df <- createDataFrame(sqlContext, data.frame(x = c(2.5, 3.5))) + expect_equal(collect(select(df, bround(df$x, 0)))[[1]][1], 2) + expect_equal(collect(select(df, bround(df$x, 0)))[[1]][2], 4) }) test_that("column binary mathfunctions", { diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py index 5017ab5b3646dcd00edce79bc33a3367187c99b4..dac842c0ce8c026ae7b35d15d1ea29fdeb94e0f2 100644 --- a/python/pyspark/sql/functions.py +++ b/python/pyspark/sql/functions.py @@ -467,16 +467,29 @@ def randn(seed=None): @since(1.5) def round(col, scale=0): """ - Round the value of `e` to `scale` decimal places if `scale` >= 0 + Round the given value to `scale` decimal places using HALF_UP rounding mode if `scale` >= 0 or at integral part when `scale` < 0. - >>> sqlContext.createDataFrame([(2.546,)], ['a']).select(round('a', 1).alias('r')).collect() - [Row(r=2.5)] + >>> sqlContext.createDataFrame([(2.5,)], ['a']).select(round('a', 0).alias('r')).collect() + [Row(r=3.0)] """ sc = SparkContext._active_spark_context return Column(sc._jvm.functions.round(_to_java_column(col), scale)) +@since(2.0) +def bround(col, scale=0): + """ + Round the given value to `scale` decimal places using HALF_EVEN rounding mode if `scale` >= 0 + or at integral part when `scale` < 0. + + >>> sqlContext.createDataFrame([(2.5,)], ['a']).select(bround('a', 0).alias('r')).collect() + [Row(r=2.0)] + """ + sc = SparkContext._active_spark_context + return Column(sc._jvm.functions.bround(_to_java_column(col), scale)) + + @since(1.5) def shiftLeft(col, numBits): """Shift the given value numBits left.