Skip to content
Snippets Groups Projects
Commit 6b3574e6 authored by Tarek Auel's avatar Tarek Auel Committed by Reynold Xin
Browse files

[SPARK-8270][SQL] levenshtein distance

Jira: https://issues.apache.org/jira/browse/SPARK-8270

Info: I cannot build the latest master; it gets stuck during the build process: `[INFO] Dependency-reduced POM written at: /Users/tarek/test/spark/bagel/dependency-reduced-pom.xml`

Author: Tarek Auel <tarek.auel@googlemail.com>

Closes #7214 from tarekauel/SPARK-8270 and squashes the following commits:

ab348b9 [Tarek Auel] Merge branch 'master' into SPARK-8270
a2ad318 [Tarek Auel] [SPARK-8270] changed order of fields
d91b12c [Tarek Auel] [SPARK-8270] python fix
adbd075 [Tarek Auel] [SPARK-8270] fixed typo
23185c9 [Tarek Auel] [SPARK-8270] levenshtein distance
parent f35b0c34
No related branches found
No related tags found
No related merge requests found
......@@ -323,6 +323,20 @@ def explode(col):
return Column(jc)
@ignore_unicode_prefix
@since(1.5)
def levenshtein(left, right):
    """Computes the Levenshtein distance of the two given strings.

    Both arguments are converted with ``_to_java_column`` and handed to the
    JVM-side ``functions.levenshtein``; the result is wrapped in a ``Column``.

    >>> df0 = sqlContext.createDataFrame([('kitten', 'sitting',)], ['l', 'r'])
    >>> df0.select(levenshtein('l', 'r').alias('d')).collect()
    [Row(d=3)]
    """
    # Resolve the JVM function first, then convert the arguments left-to-right.
    jvm_levenshtein = SparkContext._active_spark_context._jvm.functions.levenshtein
    left_jc = _to_java_column(left)
    right_jc = _to_java_column(right)
    return Column(jvm_levenshtein(left_jc, right_jc))
@ignore_unicode_prefix
@since(1.5)
def md5(col):
......
......@@ -163,6 +163,7 @@ object FunctionRegistry {
expression[Lower]("lcase"),
expression[Lower]("lower"),
expression[StringLength]("length"),
expression[Levenshtein]("levenshtein"),
expression[Substring]("substr"),
expression[Substring]("substring"),
expression[UnBase64]("unbase64"),
......
......@@ -19,6 +19,7 @@ package org.apache.spark.sql.catalyst.expressions
import java.util.regex.Pattern
import org.apache.commons.lang3.StringUtils
import org.apache.spark.sql.catalyst.analysis.UnresolvedException
import org.apache.spark.sql.catalyst.expressions.codegen._
import org.apache.spark.sql.types._
......@@ -299,6 +300,37 @@ case class StringLength(child: Expression) extends UnaryExpression with ExpectsI
override def prettyName: String = "length"
}
/**
 * An expression that returns the Levenshtein distance between the two given strings.
 *
 * Both children are declared as `StringType` inputs and the result is an `IntegerType`.
 * The result is null when either child evaluates to null; the right child is only
 * evaluated once the left child is known to be non-null.
 */
case class Levenshtein(left: Expression, right: Expression) extends BinaryExpression
  with ExpectsInputTypes {

  override def inputTypes: Seq[AbstractDataType] = Seq(StringType, StringType)

  override def dataType: DataType = IntegerType

  override def eval(input: InternalRow): Any = left.eval(input) match {
    case null => null
    case lhs =>
      // Left side is non-null, so it is now worth evaluating the right side.
      right.eval(input) match {
        case null => null
        case rhs => StringUtils.getLevenshteinDistance(lhs.toString, rhs.toString)
      }
  }

  override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = {
    // Use the fully-qualified class name so the generated snippet is unambiguous.
    val utilsClass = classOf[StringUtils].getName
    nullSafeCodeGen(ctx, ev, (result, lhs, rhs) =>
      s"$result = $utilsClass.getLevenshteinDistance($lhs.toString(), $rhs.toString());")
  }
}
/**
* Returns the numeric value of the first character of str.
*/
......
......@@ -274,4 +274,13 @@ class StringFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper {
checkEvaluation(Decode(Literal.create(null, BinaryType), Literal("utf-8")), null)
checkEvaluation(Decode(b, Literal.create(null, StringType)), null, create_row(null))
}
test("Levenshtein distance") {
  // A null on either side must yield a null result.
  val nullString = Literal.create(null, StringType)
  checkEvaluation(Levenshtein(nullString, Literal("")), null)
  checkEvaluation(Levenshtein(Literal(""), nullString), null)

  // (left, right, expected distance) for non-null inputs.
  val cases = Seq(
    ("", "", 0),
    ("abc", "abc", 0),
    ("kitten", "sitting", 3),
    ("frog", "fog", 1))
  cases.foreach { case (l, r, expected) =>
    checkEvaluation(Levenshtein(Literal(l), Literal(r)), expected)
  }
}
}
......@@ -1580,21 +1580,36 @@ object functions {
//////////////////////////////////////////////////////////////////////////////////////////////
/**
 * Computes the length of a given string value.
 *
 * @param e column whose underlying expression is wrapped in a `StringLength` expression
 * @group string_funcs
 * @since 1.5.0
 */
def strlen(e: Column): Column = StringLength(e.expr)
/**
 * Computes the length of a given string column.
 *
 * Delegates to the `Column`-based `strlen` overload after wrapping the name in a `Column`.
 *
 * @param columnName name of the string column
 * @group string_funcs
 * @since 1.5.0
 */
def strlen(columnName: String): Column = strlen(Column(columnName))
/**
 * Computes the Levenshtein distance of the two given strings.
 *
 * @param l column providing the left string
 * @param r column providing the right string
 * @group string_funcs
 * @since 1.5.0
 */
def levenshtein(l: Column, r: Column): Column = {
  val leftExpr = l.expr
  val rightExpr = r.expr
  Levenshtein(leftExpr, rightExpr)
}
/**
 * Computes the Levenshtein distance of the two given strings.
 *
 * Each name is wrapped in a `Column` and passed to the `Column`-based overload.
 *
 * @param leftColumnName name of the column providing the left string
 * @param rightColumnName name of the column providing the right string
 * @group string_funcs
 * @since 1.5.0
 */
def levenshtein(leftColumnName: String, rightColumnName: String): Column = {
  val leftColumn = Column(leftColumnName)
  val rightColumn = Column(rightColumnName)
  levenshtein(leftColumn, rightColumn)
}
/**
* Computes the numeric value of the first character of the specified string value.
*
......
......@@ -226,6 +226,12 @@ class DataFrameFunctionsSuite extends QueryTest {
})
}
test("Levenshtein distance") {
  // Both the DataFrame function and the SQL expression must produce the same rows.
  val expected = Seq(Row(3), Row(1))
  val df = Seq(("kitten", "sitting"), ("frog", "fog")).toDF("l", "r")
  checkAnswer(df.select(levenshtein("l", "r")), expected)
  checkAnswer(df.selectExpr("levenshtein(l, r)"), expected)
}
test("string ascii function") {
val df = Seq(("abc", "")).toDF("a", "b")
checkAnswer(
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment