From 47735cdc2a878cfdbe76316d3ff8314a45dabf54 Mon Sep 17 00:00:00 2001 From: "Oscar D. Lara Yejas" <olarayej@mail.usf.edu> Date: Tue, 10 Nov 2015 11:07:57 -0800 Subject: [PATCH] [SPARK-10863][SPARKR] Method coltypes() (New version) This is a follow up on PR #8984, as the corresponding branch for such PR was damaged. Author: Oscar D. Lara Yejas <olarayej@mail.usf.edu> Closes #9579 from olarayej/SPARK-10863_NEW14. --- R/pkg/DESCRIPTION | 1 + R/pkg/NAMESPACE | 6 ++-- R/pkg/R/DataFrame.R | 49 ++++++++++++++++++++++++++++++++ R/pkg/R/generics.R | 4 +++ R/pkg/R/schema.R | 15 +--------- R/pkg/R/types.R | 43 ++++++++++++++++++++++++++++ R/pkg/inst/tests/test_sparkSQL.R | 24 +++++++++++++++- 7 files changed, 124 insertions(+), 18 deletions(-) create mode 100644 R/pkg/R/types.R diff --git a/R/pkg/DESCRIPTION b/R/pkg/DESCRIPTION index 3d6edb70ec..369714f7b9 100644 --- a/R/pkg/DESCRIPTION +++ b/R/pkg/DESCRIPTION @@ -34,4 +34,5 @@ Collate: 'serialize.R' 'sparkR.R' 'stats.R' + 'types.R' 'utils.R' diff --git a/R/pkg/NAMESPACE b/R/pkg/NAMESPACE index 56b8ed0bf2..52fd6c9f76 100644 --- a/R/pkg/NAMESPACE +++ b/R/pkg/NAMESPACE @@ -23,9 +23,11 @@ export("setJobGroup", exportClasses("DataFrame") exportMethods("arrange", + "as.data.frame", "attach", "cache", "collect", + "coltypes", "columns", "count", "cov", @@ -262,6 +264,4 @@ export("structField", "structType", "structType.jobj", "structType.structField", - "print.structType") - -export("as.data.frame") + "print.structType") \ No newline at end of file diff --git a/R/pkg/R/DataFrame.R b/R/pkg/R/DataFrame.R index e9013aa34a..cc868069d1 100644 --- a/R/pkg/R/DataFrame.R +++ b/R/pkg/R/DataFrame.R @@ -2152,3 +2152,52 @@ setMethod("with", newEnv <- assignNewEnv(data) eval(substitute(expr), envir = newEnv, enclos = newEnv) }) + +#' Returns the column types of a DataFrame. +#' +#' @name coltypes +#' @title Get column types of a DataFrame +#' @family dataframe_funcs +#' @param x (DataFrame) +#' @return value (character) A character vector with the column types of the given DataFrame +#' @rdname coltypes +#' @examples \dontrun{ +#' irisDF <- createDataFrame(sqlContext, iris) +#' coltypes(irisDF) +#' } +setMethod("coltypes", + signature(x = "DataFrame"), + function(x) { + # Get the data types of the DataFrame by invoking dtypes() function + types <- sapply(dtypes(x), function(x) {x[[2]]}) + + # Map Spark data types into R's data types using DATA_TYPES environment + rTypes <- sapply(types, USE.NAMES=F, FUN=function(x) { + + # Check for primitive types + type <- PRIMITIVE_TYPES[[x]] + + if (is.null(type)) { + # Check for complex types + for (t in names(COMPLEX_TYPES)) { + if (substring(x, 1, nchar(t)) == t) { + type <- COMPLEX_TYPES[[t]] + break + } + } + + if (is.null(type)) { + stop(paste("Unsupported data type: ", x)) + } + } + type + }) + + # Find which types don't have mapping to R + naIndices <- which(is.na(rTypes)) + + # Assign the original scala data types to the unmatched ones + rTypes[naIndices] <- types[naIndices] + + rTypes + }) \ No newline at end of file diff --git a/R/pkg/R/generics.R b/R/pkg/R/generics.R index efef7d66b5..89731affeb 100644 --- a/R/pkg/R/generics.R +++ b/R/pkg/R/generics.R @@ -1047,3 +1047,7 @@ setGeneric("attach") #' @rdname with #' @export setGeneric("with") + +#' @rdname coltypes +#' @export +setGeneric("coltypes", function(x) { standardGeneric("coltypes") }) \ No newline at end of file diff --git a/R/pkg/R/schema.R b/R/pkg/R/schema.R index 6f0e9a94e9..c6ddb56227 100644 --- a/R/pkg/R/schema.R +++ b/R/pkg/R/schema.R @@ -115,20 +115,7 @@ structField.jobj <- function(x) { } checkType <- function(type) { - primtiveTypes <- c("byte", - "integer", - "float", - "double", - "numeric", - "character", - "string", - "binary", - "raw", - "logical", - "boolean", - "timestamp", - "date") - if (type %in% primtiveTypes) { + if (!is.null(PRIMITIVE_TYPES[[type]])) { return() } else { # Check complex types diff --git a/R/pkg/R/types.R b/R/pkg/R/types.R new file mode 100644 index 0000000000..1828c23ab0 --- /dev/null +++ b/R/pkg/R/types.R @@ -0,0 +1,43 @@ +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# types.R. This file handles the data type mapping between Spark and R + +# The primitive data types, where names(PRIMITIVE_TYPES) are Scala types whereas +# values are equivalent R types. This is stored in an environment to allow for +# more efficient look up (environments use hashmaps). +PRIMITIVE_TYPES <- as.environment(list( + "byte"="integer", + "tinyint"="integer", + "smallint"="integer", + "integer"="integer", + "bigint"="numeric", + "float"="numeric", + "double"="numeric", + "decimal"="numeric", + "string"="character", + "binary"="raw", + "boolean"="logical", + "timestamp"="POSIXct", + "date"="Date")) + +# The complex data types. These do not have any direct mapping to R's types. +COMPLEX_TYPES <- list( + "map"=NA, + "array"=NA, + "struct"=NA) + +# The full list of data types. +DATA_TYPES <- as.environment(c(as.list(PRIMITIVE_TYPES), COMPLEX_TYPES)) diff --git a/R/pkg/inst/tests/test_sparkSQL.R b/R/pkg/inst/tests/test_sparkSQL.R index fbdb9a8f1e..06f52d021c 100644 --- a/R/pkg/inst/tests/test_sparkSQL.R +++ b/R/pkg/inst/tests/test_sparkSQL.R @@ -1467,8 +1467,9 @@ test_that("SQL error message is returned from JVM", { expect_equal(grepl("Table not found: blah", retError), TRUE) }) +irisDF <- createDataFrame(sqlContext, iris) + test_that("Method as.data.frame as a synonym for collect()", { - irisDF <- createDataFrame(sqlContext, iris) expect_equal(as.data.frame(irisDF), collect(irisDF)) irisDF2 <- irisDF[irisDF$Species == "setosa", ] expect_equal(as.data.frame(irisDF2), collect(irisDF2)) @@ -1503,6 +1504,27 @@ test_that("with() on a DataFrame", { expect_equal(nrow(sum2), 35) }) +test_that("Method coltypes() to get R's data types of a DataFrame", { + expect_equal(coltypes(irisDF), c(rep("numeric", 4), "character")) + + data <- data.frame(c1=c(1,2,3), + c2=c(T,F,T), + c3=c("2015/01/01 10:00:00", "2015/01/02 10:00:00", "2015/01/03 10:00:00")) + + schema <- structType(structField("c1", "byte"), + structField("c3", "boolean"), + structField("c4", "timestamp")) + + # Test primitive types + DF <- createDataFrame(sqlContext, data, schema) + expect_equal(coltypes(DF), c("integer", "logical", "POSIXct")) + + # Test complex types + x <- createDataFrame(sqlContext, list(list(as.environment( + list("a"="b", "c"="d", "e"="f"))))) + expect_equal(coltypes(x), "map<string,string>") +}) + unlink(parquetPath) unlink(jsonPath) unlink(jsonPathNa) -- GitLab