diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala index bc0846646174aad977ff81c06ac7a256f9d52c70..6140d1b129c64bc28e63b1bd866782fde45c74e5 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala @@ -177,6 +177,7 @@ object FunctionRegistry { expression[ConcatWs]("concat_ws"), expression[Encode]("encode"), expression[Decode]("decode"), + expression[FindInSet]("find_in_set"), expression[FormatNumber]("format_number"), expression[InitCap]("initcap"), expression[Lower]("lcase"), diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala index 56225290cd6b1765cd3526ce101d904448d909a1..0cc785d9f3a49d6168bd89b69e82e900cab40ab4 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala @@ -18,8 +18,7 @@ package org.apache.spark.sql.catalyst.expressions import java.text.DecimalFormat -import java.util.Arrays -import java.util.Locale +import java.util.{Arrays, Locale} import java.util.regex.{MatchResult, Pattern} import org.apache.commons.lang3.StringEscapeUtils @@ -350,6 +349,28 @@ case class EndsWith(left: Expression, right: Expression) } } +/** + * A function that returns the index (1-based) of the given string (left) in the comma- + * delimited list (right). Returns 0, if the string wasn't found or if the given + * string (left) contains a comma. + */ +case class FindInSet(left: Expression, right: Expression) extends BinaryExpression + with ImplicitCastInputTypes { + + override def inputTypes: Seq[AbstractDataType] = Seq(StringType, StringType) + + override protected def nullSafeEval(word: Any, set: Any): Any = + set.asInstanceOf[UTF8String].findInSet(word.asInstanceOf[UTF8String]) + + override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { + nullSafeCodeGen(ctx, ev, (word, set) => + s"${ev.primitive} = $set.findInSet($word);" + ) + } + + override def dataType: DataType = IntegerType +} + /** * A function that trim the spaces from both ends for the specified string. */ diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringExpressionsSuite.scala index 906be701beed7fba99557f49f481ff8fa0025d9c..23f36ca43d663078c665fbbad9dff2dde1bdcabc 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringExpressionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringExpressionsSuite.scala @@ -675,4 +675,14 @@ class StringExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { checkEvaluation(FormatNumber(Literal.create(null, IntegerType), Literal(3)), null) checkEvaluation(FormatNumber(Literal.create(null, NullType), Literal(3)), null) } + + test("find in set") { + checkEvaluation( + FindInSet(Literal.create(null, StringType), Literal.create(null, StringType)), null) + checkEvaluation(FindInSet(Literal("ab"), Literal.create(null, StringType)), null) + checkEvaluation(FindInSet(Literal.create(null, StringType), Literal("abc,b,ab,c,def")), null) + checkEvaluation(FindInSet(Literal("ab"), Literal("abc,b,ab,c,def")), 3) + checkEvaluation(FindInSet(Literal("abf"), Literal("abc,b,ab,c,def")), 0) + checkEvaluation(FindInSet(Literal("ab,"), Literal("abc,b,ab,c,def")), 0) + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala index 431dcf7382f163f338f5fcf9ea99cbe7939e74db..6137527757f8576d7727818fe6303c34f3b25ce7 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala @@ -208,6 +208,14 @@ class DataFrameFunctionsSuite extends QueryTest { Row(2743272264L, 2180413220L)) } + test("string function find_in_set") { + val df = Seq(("abc,b,ab,c,def", "abc,b,ab,c,def")).toDF("a", "b") + + checkAnswer( + df.selectExpr("find_in_set('ab', a)", "find_in_set('x', b)"), + Row(3, 0)) + } + test("conditional function: least") { checkAnswer( testData2.select(least(lit(-1), lit(0), col("a"), col("b"))).limit(1), diff --git a/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java b/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java index d80bd57bd2048f688b1a008f8fbf9a65bd824ab3..febbe3d4e54d10edb9d5f1f2107eec7e3940909a 100644 --- a/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java +++ b/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java @@ -54,8 +54,9 @@ public final class UTF8String implements Comparable<UTF8String>, Serializable { 5, 5, 5, 5, 6, 6}; - private static ByteOrder byteOrder = ByteOrder.nativeOrder(); + private static boolean isLittleEndian = ByteOrder.nativeOrder() == ByteOrder.LITTLE_ENDIAN; + private static final UTF8String COMMA_UTF8 = UTF8String.fromString(","); public static final UTF8String EMPTY_UTF8 = UTF8String.fromString(""); /** @@ -179,7 +180,7 @@ public final class UTF8String implements Comparable<UTF8String>, Serializable { // After getting the data, we use a mask to mask out data that is not part of the string. long p; long mask = 0; - if (byteOrder == ByteOrder.LITTLE_ENDIAN) { + if (isLittleEndian) { if (numBytes >= 8) { p = PlatformDependent.UNSAFE.getLong(base, offset); } else if (numBytes > 4) { @@ -411,6 +412,36 @@ public final class UTF8String implements Comparable<UTF8String>, Serializable { return fromString(sb.toString()); } + /* + * Returns the index of the string `match` in this String. This string has to be a comma separated + * list. If `match` contains a comma 0 will be returned. If the `match` isn't part of this String, + * 0 will be returned, else the index of match (1-based index) + */ + public int findInSet(UTF8String match) { + if (match.contains(COMMA_UTF8)) { + return 0; + } + + int n = 1, lastComma = -1; + for (int i = 0; i < numBytes; i++) { + if (getByte(i) == (byte) ',') { + if (i - (lastComma + 1) == match.numBytes && + ByteArrayMethods.arrayEquals(base, offset + (lastComma + 1), match.base, match.offset, + match.numBytes)) { + return n; + } + lastComma = i; + n++; + } + } + if (numBytes - (lastComma + 1) == match.numBytes && + ByteArrayMethods.arrayEquals(base, offset + (lastComma + 1), match.base, match.offset, + match.numBytes)) { + return n; + } + return 0; + } + /** * Copy the bytes from the current UTF8String, and make a new UTF8String. * @param start the start position of the current UTF8String in bytes. diff --git a/unsafe/src/test/java/org/apache/spark/unsafe/types/UTF8StringSuite.java b/unsafe/src/test/java/org/apache/spark/unsafe/types/UTF8StringSuite.java index 9b3190f8f0c3fd545aa05a4b27162421e9ca20df..b30c94c1c1f80801be43a59b544eba0e933741e0 100644 --- a/unsafe/src/test/java/org/apache/spark/unsafe/types/UTF8StringSuite.java +++ b/unsafe/src/test/java/org/apache/spark/unsafe/types/UTF8StringSuite.java @@ -399,6 +399,18 @@ public class UTF8StringSuite { assertEquals(fromString(""), blankString(0)); } + @Test + public void findInSet() { + assertEquals(fromString("ab").findInSet(fromString("ab")), 1); + assertEquals(fromString("a,b").findInSet(fromString("b")), 2); + assertEquals(fromString("abc,b,ab,c,def").findInSet(fromString("ab")), 3); + assertEquals(fromString("ab,abc,b,ab,c,def").findInSet(fromString("ab")), 1); + assertEquals(fromString(",,,ab,abc,b,ab,c,def").findInSet(fromString("ab")), 4); + assertEquals(fromString(",ab,abc,b,ab,c,def").findInSet(fromString("")), 1); + assertEquals(fromString("æ•°æ®ç –头,abc,b,ab,c,def").findInSet(fromString("ab")), 4); + assertEquals(fromString("æ•°æ®ç –头,abc,b,ab,c,def").findInSet(fromString("def")), 6); + } + @Test public void soundex() { assertEquals(fromString("Robert").soundex(), fromString("R163"));