diff --git a/R/pkg/R/column.R b/R/pkg/R/column.R index 80e92d3105a36d23c34909abaf2b106b91fc2746..8e4b0f5bf1c4dc7366101b661d73bb238d3c9e47 100644 --- a/R/pkg/R/column.R +++ b/R/pkg/R/column.R @@ -210,6 +210,22 @@ setMethod("cast", } }) +#' Match a column with given values. +#' +#' @rdname column +#' @return a matched values as a result of comparing with given values. +#' \dontrun{ +#' filter(df, "age in (10, 30)") +#' where(df, df$age %in% c(10, 30)) +#' } +setMethod("%in%", + signature(x = "Column"), + function(x, table) { + table <- listToSeq(as.list(table)) + jc <- callJMethod(x@jc, "in", table) + return(column(jc)) + }) + #' Approx Count Distinct #' #' @rdname column diff --git a/R/pkg/inst/tests/test_sparkSQL.R b/R/pkg/inst/tests/test_sparkSQL.R index fc7f3f074b67ccef901489f12c678c0245add522..417153dc0985cbb43fb3a55d2f162a1510b05dbb 100644 --- a/R/pkg/inst/tests/test_sparkSQL.R +++ b/R/pkg/inst/tests/test_sparkSQL.R @@ -693,6 +693,16 @@ test_that("filter() on a DataFrame", { filtered2 <- where(df, df$name != "Michael") expect_true(count(filtered2) == 2) expect_true(collect(filtered2)$age[2] == 19) + + # test suites for %in% + filtered3 <- filter(df, "age in (19)") + expect_equal(count(filtered3), 1) + filtered4 <- filter(df, "age in (19, 30)") + expect_equal(count(filtered4), 2) + filtered5 <- where(df, df$age %in% c(19)) + expect_equal(count(filtered5), 1) + filtered6 <- where(df, df$age %in% c(19, 30)) + expect_equal(count(filtered6), 2) }) test_that("join() on a DataFrame", {