From 6996bd2e81bf6597dcda499d9a9a80927a43e30f Mon Sep 17 00:00:00 2001
From: "zhichao.li" <zhichao.li@intel.com>
Date: Fri, 31 Jul 2015 21:18:01 -0700
Subject: [PATCH] [SPARK-8264][SQL]add substring_index function

This PR is based on #7533 , thanks to zhichao-li

Closes #7533

Author: zhichao.li <zhichao.li@intel.com>
Author: Davies Liu <davies@databricks.com>

Closes #7843 from davies/str_index and squashes the following commits:

391347b [Davies Liu] add python api
3ce7802 [Davies Liu] fix substringIndex
f2d29a1 [Davies Liu] Merge branch 'master' of github.com:apache/spark into str_index
515519b [zhichao.li] add foldable and remove null checking
9546991 [zhichao.li] scala style
67c253a [zhichao.li] hide some apis and clean code
b19b013 [zhichao.li] add codegen and clean code
ac863e9 [zhichao.li] reduce the calling of numChars
12e108f [zhichao.li] refine unittest
d92951b [zhichao.li] add lastIndexOf
52d7b03 [zhichao.li] add substring_index function
---
 python/pyspark/sql/functions.py               | 19 +++++
 .../catalyst/analysis/FunctionRegistry.scala  |  1 +
 .../expressions/stringOperations.scala        | 25 ++++++
 .../expressions/StringExpressionsSuite.scala  | 31 +++++++
 .../org/apache/spark/sql/functions.scala      | 12 ++-
 .../spark/sql/StringFunctionsSuite.scala      | 57 +++++++++++++
 .../apache/spark/unsafe/types/UTF8String.java | 80 ++++++++++++++++++-
 .../spark/unsafe/types/UTF8StringSuite.java   | 38 +++++++++
 8 files changed, 261 insertions(+), 2 deletions(-)

diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py
index bb9926ce8c..89a2a5ceaa 100644
--- a/python/pyspark/sql/functions.py
+++ b/python/pyspark/sql/functions.py
@@ -920,6 +920,25 @@ def trunc(date, format):
     return Column(sc._jvm.functions.trunc(_to_java_column(date), format))
 
 
+@since(1.5)
+@ignore_unicode_prefix
+def substring_index(str, delim, count):
+    """
+    Returns the substring from string str before count occurrences of the delimiter delim.
+    If count is positive, everything the left of the final delimiter (counting from left) is
+    returned. If count is negative, every to the right of the final delimiter (counting from the
+    right) is returned. substring_index performs a case-sensitive match when searching for delim.
+
+    >>> df = sqlContext.createDataFrame([('a.b.c.d',)], ['s'])
+    >>> df.select(substring_index(df.s, '.', 2).alias('s')).collect()
+    [Row(s=u'a.b')]
+    >>> df.select(substring_index(df.s, '.', -3).alias('s')).collect()
+    [Row(s=u'b.c.d')]
+    """
+    sc = SparkContext._active_spark_context
+    return Column(sc._jvm.functions.substring_index(_to_java_column(str), delim, count))
+
+
 @since(1.5)
 def size(col):
     """
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
index 3f61a9af1f..ee44cbcba6 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
@@ -199,6 +199,7 @@ object FunctionRegistry {
     expression[StringSplit]("split"),
     expression[Substring]("substr"),
     expression[Substring]("substring"),
+    expression[SubstringIndex]("substring_index"),
     expression[StringTrim]("trim"),
     expression[UnBase64]("unbase64"),
     expression[Upper]("ucase"),
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala
index 160e72f384..5dd387a418 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala
@@ -421,6 +421,31 @@ case class StringInstr(str: Expression, substr: Expression)
   }
 }
 
+/**
+ * Returns the substring from string str before count occurrences of the delimiter delim.
+ * If count is positive, everything the left of the final delimiter (counting from left) is
+ * returned. If count is negative, every to the right of the final delimiter (counting from the
+ * right) is returned. substring_index performs a case-sensitive match when searching for delim.
+ */
+case class SubstringIndex(strExpr: Expression, delimExpr: Expression, countExpr: Expression)
+ extends TernaryExpression with ImplicitCastInputTypes {
+
+  override def dataType: DataType = StringType
+  override def inputTypes: Seq[DataType] = Seq(StringType, StringType, IntegerType)
+  override def children: Seq[Expression] = Seq(strExpr, delimExpr, countExpr)
+  override def prettyName: String = "substring_index"
+
+  override def nullSafeEval(str: Any, delim: Any, count: Any): Any = {
+    str.asInstanceOf[UTF8String].subStringIndex(
+      delim.asInstanceOf[UTF8String],
+      count.asInstanceOf[Int])
+  }
+
+  override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = {
+    defineCodeGen(ctx, ev, (str, delim, count) => s"$str.subStringIndex($delim, $count)")
+  }
+}
+
 /**
  * A function that returns the position of the first occurrence of substr
  * in given string after position pos.
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringExpressionsSuite.scala
index fb72fe1714..ad87ab36fd 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringExpressionsSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringExpressionsSuite.scala
@@ -20,6 +20,7 @@ package org.apache.spark.sql.catalyst.expressions
 import org.apache.spark.SparkFunSuite
 import org.apache.spark.sql.catalyst.dsl.expressions._
 import org.apache.spark.sql.types._
+import org.apache.spark.unsafe.types.UTF8String
 
 
 class StringExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper {
@@ -187,6 +188,36 @@ class StringExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper {
     checkEvaluation(s.substring(0), "example", row)
   }
 
+  test("string substring_index function") {
+    checkEvaluation(
+      SubstringIndex(Literal("www.apache.org"), Literal("."), Literal(3)), "www.apache.org")
+    checkEvaluation(
+      SubstringIndex(Literal("www.apache.org"), Literal("."), Literal(2)), "www.apache")
+    checkEvaluation(
+      SubstringIndex(Literal("www.apache.org"), Literal("."), Literal(1)), "www")
+    checkEvaluation(
+      SubstringIndex(Literal("www.apache.org"), Literal("."), Literal(0)), "")
+    checkEvaluation(
+      SubstringIndex(Literal("www.apache.org"), Literal("."), Literal(-3)), "www.apache.org")
+    checkEvaluation(
+      SubstringIndex(Literal("www.apache.org"), Literal("."), Literal(-2)), "apache.org")
+    checkEvaluation(
+      SubstringIndex(Literal("www.apache.org"), Literal("."), Literal(-1)), "org")
+    checkEvaluation(
+      SubstringIndex(Literal(""), Literal("."), Literal(-2)), "")
+    checkEvaluation(
+      SubstringIndex(Literal.create(null, StringType), Literal("."), Literal(-2)), null)
+    checkEvaluation(SubstringIndex(
+        Literal("www.apache.org"), Literal.create(null, StringType), Literal(-2)), null)
+    // non ascii chars
+    // scalastyle:off
+    checkEvaluation(
+      SubstringIndex(Literal("大千世界大千世界"), Literal( "千"), Literal(2)), "大千世界大")
+    // scalastyle:on
+    checkEvaluation(
+      SubstringIndex(Literal("www||apache||org"), Literal( "||"), Literal(2)), "www||apache")
+  }
+
   test("LIKE literal Regular Expression") {
     checkEvaluation(Literal.create(null, StringType).like("a"), null)
     checkEvaluation(Literal.create("a", StringType).like(Literal.create(null, StringType)), null)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
index 89ffa9c50d..57bb00a741 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
@@ -1788,8 +1788,18 @@ object functions {
   def instr(str: Column, substring: String): Column = StringInstr(str.expr, lit(substring).expr)
 
   /**
-   * Locate the position of the first occurrence of substr in a string column.
+   * Returns the substring from string str before count occurrences of the delimiter delim.
+   * If count is positive, everything the left of the final delimiter (counting from left) is
+   * returned. If count is negative, every to the right of the final delimiter (counting from the
+   * right) is returned. substring_index performs a case-sensitive match when searching for delim.
    *
+   * @group string_funcs
+   */
+  def substring_index(str: Column, delim: String, count: Int): Column =
+    SubstringIndex(str.expr, lit(delim).expr, lit(count).expr)
+
+  /**
+   * Locate the position of the first occurrence of substr.
    * NOTE: The position is not zero based, but 1 based index, returns 0 if substr
    * could not be found in str.
    *
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/StringFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/StringFunctionsSuite.scala
index b7f073cccb..628da95298 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/StringFunctionsSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/StringFunctionsSuite.scala
@@ -163,6 +163,63 @@ class StringFunctionsSuite extends QueryTest {
       Row(1))
   }
 
+  test("string substring_index function") {
+    val df = Seq(("www.apache.org", ".", "zz")).toDF("a", "b", "c")
+    checkAnswer(
+      df.select(substring_index($"a", ".", 3)),
+      Row("www.apache.org"))
+    checkAnswer(
+      df.select(substring_index($"a", ".", 2)),
+      Row("www.apache"))
+    checkAnswer(
+      df.select(substring_index($"a", ".", 1)),
+      Row("www"))
+    checkAnswer(
+      df.select(substring_index($"a", ".", 0)),
+      Row(""))
+    checkAnswer(
+      df.select(substring_index(lit("www.apache.org"), ".", -1)),
+      Row("org"))
+    checkAnswer(
+      df.select(substring_index(lit("www.apache.org"), ".", -2)),
+      Row("apache.org"))
+    checkAnswer(
+      df.select(substring_index(lit("www.apache.org"), ".", -3)),
+      Row("www.apache.org"))
+    // str is empty string
+    checkAnswer(
+      df.select(substring_index(lit(""), ".", 1)),
+      Row(""))
+    // empty string delim
+    checkAnswer(
+      df.select(substring_index(lit("www.apache.org"), "", 1)),
+      Row(""))
+    // delim does not exist in str
+    checkAnswer(
+      df.select(substring_index(lit("www.apache.org"), "#", 1)),
+      Row("www.apache.org"))
+    // delim is 2 chars
+    checkAnswer(
+      df.select(substring_index(lit("www||apache||org"), "||", 2)),
+      Row("www||apache"))
+    checkAnswer(
+      df.select(substring_index(lit("www||apache||org"), "||", -2)),
+      Row("apache||org"))
+    // null
+    checkAnswer(
+      df.select(substring_index(lit(null), "||", 2)),
+      Row(null))
+    checkAnswer(
+      df.select(substring_index(lit("www.apache.org"), null, 2)),
+      Row(null))
+    // non ascii chars
+    // scalastyle:off
+    checkAnswer(
+      df.selectExpr("""substring_index("大千世界大千世界", "千", 2)"""),
+      Row("大千世界大"))
+    // scalastyle:on
+  }
+
   test("string locate function") {
     val df = Seq(("aaads", "aa", "zz", 1)).toDF("a", "b", "c", "d")
 
diff --git a/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java b/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java
index 9d4998fd48..2561c1c2a1 100644
--- a/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java
+++ b/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java
@@ -198,7 +198,7 @@ public final class UTF8String implements Comparable<UTF8String>, Serializable {
    */
   public UTF8String substring(final int start, final int until) {
     if (until <= start || start >= numBytes) {
-      return fromBytes(new byte[0]);
+      return UTF8String.EMPTY_UTF8;
     }
 
     int i = 0;
@@ -406,6 +406,84 @@ public final class UTF8String implements Comparable<UTF8String>, Serializable {
     return -1;
   }
 
+  /**
+   * Find the `str` from left to right.
+   */
+  private int find(UTF8String str, int start) {
+    assert (str.numBytes > 0);
+    while (start <= numBytes - str.numBytes) {
+      if (ByteArrayMethods.arrayEquals(base, offset + start, str.base, str.offset, str.numBytes)) {
+        return start;
+      }
+      start += 1;
+    }
+    return -1;
+  }
+
+  /**
+   * Find the `str` from right to left.
+   */
+  private int rfind(UTF8String str, int start) {
+    assert (str.numBytes > 0);
+    while (start >= 0) {
+      if (ByteArrayMethods.arrayEquals(base, offset + start, str.base, str.offset, str.numBytes)) {
+        return start;
+      }
+      start -= 1;
+    }
+    return -1;
+  }
+
+  /**
+   * Returns the substring from string str before count occurrences of the delimiter delim.
+   * If count is positive, everything the left of the final delimiter (counting from left) is
+   * returned. If count is negative, every to the right of the final delimiter (counting from the
+   * right) is returned. subStringIndex performs a case-sensitive match when searching for delim.
+   */
+  public UTF8String subStringIndex(UTF8String delim, int count) {
+    if (delim.numBytes == 0 || count == 0) {
+      return EMPTY_UTF8;
+    }
+    if (count > 0) {
+      int idx = -1;
+      while (count > 0) {
+        idx = find(delim, idx + 1);
+        if (idx >= 0) {
+          count --;
+        } else {
+          // can not find enough delim
+          return this;
+        }
+      }
+      if (idx == 0) {
+        return EMPTY_UTF8;
+      }
+      byte[] bytes = new byte[idx];
+      copyMemory(base, offset, bytes, BYTE_ARRAY_OFFSET, idx);
+      return fromBytes(bytes);
+
+    } else {
+      int idx = numBytes - delim.numBytes + 1;
+      count = -count;
+      while (count > 0) {
+        idx = rfind(delim, idx - 1);
+        if (idx >= 0) {
+          count --;
+        } else {
+          // can not find enough delim
+          return this;
+        }
+      }
+      if (idx + delim.numBytes == numBytes) {
+        return EMPTY_UTF8;
+      }
+      int size = numBytes - delim.numBytes - idx;
+      byte[] bytes = new byte[size];
+      copyMemory(base, offset + idx + delim.numBytes, bytes, BYTE_ARRAY_OFFSET, size);
+      return fromBytes(bytes);
+    }
+  }
+
   /**
    * Returns str, right-padded with pad to a length of len
    * For example:
diff --git a/unsafe/src/test/java/org/apache/spark/unsafe/types/UTF8StringSuite.java b/unsafe/src/test/java/org/apache/spark/unsafe/types/UTF8StringSuite.java
index c565210872..43eed70632 100644
--- a/unsafe/src/test/java/org/apache/spark/unsafe/types/UTF8StringSuite.java
+++ b/unsafe/src/test/java/org/apache/spark/unsafe/types/UTF8StringSuite.java
@@ -240,6 +240,44 @@ public class UTF8StringSuite {
     assertEquals(3, fromString("数据砖头").indexOf(fromString("头"), 0));
   }
 
+  @Test
+  public void substring_index() {
+    assertEquals(fromString("www.apache.org"),
+      fromString("www.apache.org").subStringIndex(fromString("."), 3));
+    assertEquals(fromString("www.apache"),
+      fromString("www.apache.org").subStringIndex(fromString("."), 2));
+    assertEquals(fromString("www"),
+      fromString("www.apache.org").subStringIndex(fromString("."), 1));
+    assertEquals(fromString(""),
+      fromString("www.apache.org").subStringIndex(fromString("."), 0));
+    assertEquals(fromString("org"),
+      fromString("www.apache.org").subStringIndex(fromString("."), -1));
+    assertEquals(fromString("apache.org"),
+      fromString("www.apache.org").subStringIndex(fromString("."), -2));
+    assertEquals(fromString("www.apache.org"),
+      fromString("www.apache.org").subStringIndex(fromString("."), -3));
+    // str is empty string
+    assertEquals(fromString(""),
+      fromString("").subStringIndex(fromString("."), 1));
+    // empty string delim
+    assertEquals(fromString(""),
+      fromString("www.apache.org").subStringIndex(fromString(""), 1));
+    // delim does not exist in str
+    assertEquals(fromString("www.apache.org"),
+      fromString("www.apache.org").subStringIndex(fromString("#"), 2));
+    // delim is 2 chars
+    assertEquals(fromString("www||apache"),
+      fromString("www||apache||org").subStringIndex(fromString("||"), 2));
+    assertEquals(fromString("apache||org"),
+      fromString("www||apache||org").subStringIndex(fromString("||"), -2));
+    // non ascii chars
+    assertEquals(fromString("大千世界大"),
+      fromString("大千世界大千世界").subStringIndex(fromString("千"), 2));
+    // overlapped delim
+    assertEquals(fromString("||"), fromString("||||||").subStringIndex(fromString("|||"), 3));
+    assertEquals(fromString("|||"), fromString("||||||").subStringIndex(fromString("|||"), -4));
+  }
+
   @Test
   public void reverse() {
     assertEquals(fromString("olleh"), fromString("hello").reverse());
-- 
GitLab