diff --git a/R/pkg/R/SQLContext.R b/R/pkg/R/SQLContext.R
index e3528bc7c31356ceaafc7133faa9b505624000c3..3b7f71bbbffb8cf5b7b619f257b3089dbb772b9e 100644
--- a/R/pkg/R/SQLContext.R
+++ b/R/pkg/R/SQLContext.R
@@ -584,7 +584,7 @@ tableToDF <- function(tableName) {
 #'
 #' @param path The path of files to load
 #' @param source The name of external data source
-#' @param schema The data schema defined in structType
+#' @param schema The data schema defined in structType or a DDL-formatted string.
 #' @param na.strings Default string value for NA when source is "csv"
 #' @param ... additional external data source specific named properties.
 #' @return SparkDataFrame
@@ -600,6 +600,8 @@ tableToDF <- function(tableName) {
 #'                      structField("info", "map<string,double>"))
 #' df2 <- read.df(mapTypeJsonPath, "json", schema, multiLine = TRUE)
 #' df3 <- loadDF("data/test_table", "parquet", mergeSchema = "true")
+#' stringSchema <- "name STRING, info MAP<STRING, DOUBLE>"
+#' df4 <- read.df(mapTypeJsonPath, "json", stringSchema, multiLine = TRUE)
 #' }
 #' @name read.df
 #' @method read.df default
@@ -623,14 +625,19 @@ read.df.default <- function(path = NULL, source = NULL, schema = NULL, na.string
   if (source == "csv" && is.null(options[["nullValue"]])) {
     options[["nullValue"]] <- na.strings
   }
+  read <- callJMethod(sparkSession, "read")
+  read <- callJMethod(read, "format", source)
   if (!is.null(schema)) {
-    stopifnot(class(schema) == "structType")
-    sdf <- handledCallJStatic("org.apache.spark.sql.api.r.SQLUtils", "loadDF", sparkSession,
-                              source, schema$jobj, options)
-  } else {
-    sdf <- handledCallJStatic("org.apache.spark.sql.api.r.SQLUtils", "loadDF", sparkSession,
-                              source, options)
+    if (class(schema) == "structType") {
+      read <- callJMethod(read, "schema", schema$jobj)
+    } else if (is.character(schema)) {
+      read <- callJMethod(read, "schema", schema)
+    } else {
+      stop("schema should be structType or character.")
+    }
   }
+  read <- callJMethod(read, "options", options)
+  sdf <- handledCallJMethod(read, "load")
   dataFrame(sdf)
 }
 
@@ -717,8 +724,8 @@ read.jdbc <- function(url, tableName,
 #' "spark.sql.sources.default" will be used.
 #'
 #' @param source The name of external data source
-#' @param schema The data schema defined in structType, this is required for file-based streaming
-#'               data source
+#' @param schema The data schema defined in structType or a DDL-formatted string, this is
+#'               required for file-based streaming data source
 #' @param ... additional external data source specific named options, for instance \code{path} for
 #'        file-based streaming data source
 #' @return SparkDataFrame
@@ -733,6 +740,8 @@ read.jdbc <- function(url, tableName,
 #' q <- write.stream(df, "text", path = "/home/user/out", checkpointLocation = "/home/user/cp")
 #'
 #' df <- read.stream("json", path = jsonDir, schema = schema, maxFilesPerTrigger = 1)
+#' stringSchema <- "name STRING, info MAP<STRING, DOUBLE>"
+#' df1 <- read.stream("json", path = jsonDir, schema = stringSchema, maxFilesPerTrigger = 1)
 #' }
 #' @name read.stream
 #' @note read.stream since 2.2.0
@@ -750,10 +759,15 @@ read.stream <- function(source = NULL, schema = NULL, ...) {
   read <- callJMethod(sparkSession, "readStream")
   read <- callJMethod(read, "format", source)
   if (!is.null(schema)) {
-    stopifnot(class(schema) == "structType")
-    read <- callJMethod(read, "schema", schema$jobj)
+    if (class(schema) == "structType") {
+      read <- callJMethod(read, "schema", schema$jobj)
+    } else if (is.character(schema)) {
+      read <- callJMethod(read, "schema", schema)
+    } else {
+      stop("schema should be structType or character.")
+    }
   }
   read <- callJMethod(read, "options", options)
   sdf <- handledCallJMethod(read, "load")
-  dataFrame(callJMethod(sdf, "toDF"))
+  dataFrame(sdf)
 }
diff --git a/R/pkg/tests/fulltests/test_sparkSQL.R b/R/pkg/tests/fulltests/test_sparkSQL.R
index 911b73b9ee5512a7a0fcbc960b95713a45e2fae3..a2bcb5aefe16de65303ba3691788252ada990b17 100644
--- a/R/pkg/tests/fulltests/test_sparkSQL.R
+++ b/R/pkg/tests/fulltests/test_sparkSQL.R
@@ -3248,9 +3248,9 @@ test_that("Call DataFrameWriter.load() API in Java without path and check argume
   # It makes sure that we can omit path argument in read.df API and then it calls
   # DataFrameWriter.load() without path.
   expect_error(read.df(source = "json"),
-               paste("Error in loadDF : analysis error - Unable to infer schema for JSON.",
+               paste("Error in load : analysis error - Unable to infer schema for JSON.",
                      "It must be specified manually"))
-  expect_error(read.df("arbitrary_path"), "Error in loadDF : analysis error - Path does not exist")
+  expect_error(read.df("arbitrary_path"), "Error in load : analysis error - Path does not exist")
   expect_error(read.json("arbitrary_path"), "Error in json : analysis error - Path does not exist")
   expect_error(read.text("arbitrary_path"), "Error in text : analysis error - Path does not exist")
   expect_error(read.orc("arbitrary_path"), "Error in orc : analysis error - Path does not exist")
@@ -3268,6 +3268,22 @@ test_that("Call DataFrameWriter.load() API in Java without path and check argume
                "Unnamed arguments ignored: 2, 3, a.")
 })
 
+test_that("Specify a schema by using a DDL-formatted string when reading", {
+  # Test read.df with a user defined schema in a DDL-formatted string.
+  df1 <- read.df(jsonPath, "json", "name STRING, age DOUBLE")
+  expect_is(df1, "SparkDataFrame")
+  expect_equal(dtypes(df1), list(c("name", "string"), c("age", "double")))
+
+  expect_error(read.df(jsonPath, "json", "name stri"), "DataType stri is not supported.")
+
+  # Test loadDF with a user defined schema in a DDL-formatted string.
+  df2 <- loadDF(jsonPath, "json", "name STRING, age DOUBLE")
+  expect_is(df2, "SparkDataFrame")
+  expect_equal(dtypes(df2), list(c("name", "string"), c("age", "double")))
+
+  expect_error(loadDF(jsonPath, "json", "name stri"), "DataType stri is not supported.")
+})
+
 test_that("Collect on DataFrame when NAs exists at the top of a timestamp column", {
   ldf <- data.frame(col1 = c(0, 1, 2),
                     col2 = c(as.POSIXct("2017-01-01 00:00:01"),
diff --git a/R/pkg/tests/fulltests/test_streaming.R b/R/pkg/tests/fulltests/test_streaming.R
index d691de7cd725dd90471b6e306984105a0a1db961..54f40bbd5f517f3c253cfc15b5da3cef60a979f1 100644
--- a/R/pkg/tests/fulltests/test_streaming.R
+++ b/R/pkg/tests/fulltests/test_streaming.R
@@ -46,6 +46,8 @@ schema <- structType(structField("name", "string"),
                      structField("age", "integer"),
                      structField("count", "double"))
 
+stringSchema <- "name STRING, age INTEGER, count DOUBLE"
+
 test_that("read.stream, write.stream, awaitTermination, stopQuery", {
   df <- read.stream("json", path = jsonDir, schema = schema, maxFilesPerTrigger = 1)
   expect_true(isStreaming(df))
@@ -111,6 +113,27 @@ test_that("Stream other format", {
   unlink(parquetPath)
 })
 
+test_that("Specify a schema by using a DDL-formatted string when reading", {
+  # Test read.stream with a user defined schema in a DDL-formatted string.
+  parquetPath <- tempfile(pattern = "sparkr-test", fileext = ".parquet")
+  df <- read.df(jsonPath, "json", schema)
+  write.df(df, parquetPath, "parquet", "overwrite")
+
+  df <- read.stream(path = parquetPath, schema = stringSchema)
+  expect_true(isStreaming(df))
+  counts <- count(group_by(df, "name"))
+  q <- write.stream(counts, "memory", queryName = "people3", outputMode = "complete")
+
+  expect_false(awaitTermination(q, 5 * 1000))
+  callJMethod(q@ssq, "processAllAvailable")
+  expect_equal(head(sql("SELECT count(*) FROM people3"))[[1]], 3)
+
+  expect_error(read.stream(path = parquetPath, schema = "name stri"),
+               "DataType stri is not supported.")
+
+  unlink(parquetPath)
+})
+
 test_that("Non-streaming DataFrame", {
   c <- as.DataFrame(cars)
   expect_false(isStreaming(c))
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/api/r/SQLUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/api/r/SQLUtils.scala
index d94e528a3ad4715e5e37783b23f4ff61a325db67..9bd2987057dbcf1ffb4a2161df5150b92a462513 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/api/r/SQLUtils.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/api/r/SQLUtils.scala
@@ -193,21 +193,6 @@ private[sql] object SQLUtils extends Logging {
     }
   }
 
-  def loadDF(
-      sparkSession: SparkSession,
-      source: String,
-      options: java.util.Map[String, String]): DataFrame = {
-    sparkSession.read.format(source).options(options).load()
-  }
-
-  def loadDF(
-      sparkSession: SparkSession,
-      source: String,
-      schema: StructType,
-      options: java.util.Map[String, String]): DataFrame = {
-    sparkSession.read.format(source).schema(schema).options(options).load()
-  }
-
   def readSqlObject(dis: DataInputStream, dataType: Char): Object = {
     dataType match {
       case 's' =>
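For reference, here is a minimal SparkR sketch of the behavior this patch enables: read.df and read.stream accept either a structType or a DDL-formatted string as the schema argument. The file paths below ("people.json", "/tmp/json-in") are hypothetical placeholders, not taken from the patch.

# Sketch only; paths are placeholders.
schema <- structType(structField("name", "string"), structField("age", "double"))
df_a <- read.df("people.json", "json", schema)                      # structType, as before
df_b <- read.df("people.json", "json", "name STRING, age DOUBLE")   # DDL-formatted string, new

# File-based streaming sources accept the same string form:
sdf <- read.stream("json", path = "/tmp/json-in", schema = "name STRING, age DOUBLE")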