From 6b3574e68704d58ba41efe0ea4fe928cc166afcd Mon Sep 17 00:00:00 2001 From: Tarek Auel <tarek.auel@googlemail.com> Date: Sat, 4 Jul 2015 01:10:52 -0700 Subject: [PATCH] [SPARK-8270][SQL] levenshtein distance Jira: https://issues.apache.org/jira/browse/SPARK-8270 Info: I can not build the latest master, it stucks during the build process: `[INFO] Dependency-reduced POM written at: /Users/tarek/test/spark/bagel/dependency-reduced-pom.xml` Author: Tarek Auel <tarek.auel@googlemail.com> Closes #7214 from tarekauel/SPARK-8270 and squashes the following commits: ab348b9 [Tarek Auel] Merge branch 'master' into SPARK-8270 a2ad318 [Tarek Auel] [SPARK-8270] changed order of fields d91b12c [Tarek Auel] [SPARK-8270] python fix adbd075 [Tarek Auel] [SPARK-8270] fixed typo 23185c9 [Tarek Auel] [SPARK-8270] levenshtein distance --- python/pyspark/sql/functions.py | 14 ++++++++ .../catalyst/analysis/FunctionRegistry.scala | 1 + .../expressions/stringOperations.scala | 32 +++++++++++++++++++ .../expressions/StringFunctionsSuite.scala | 9 ++++++ .../org/apache/spark/sql/functions.scala | 23 ++++++++++--- .../spark/sql/DataFrameFunctionsSuite.scala | 6 ++++ 6 files changed, 81 insertions(+), 4 deletions(-) diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py index 69e563ef36..49dd0332af 100644 --- a/python/pyspark/sql/functions.py +++ b/python/pyspark/sql/functions.py @@ -323,6 +323,20 @@ def explode(col): return Column(jc) +@ignore_unicode_prefix +@since(1.5) +def levenshtein(left, right): + """Computes the Levenshtein distance of the two given strings. + + >>> df0 = sqlContext.createDataFrame([('kitten', 'sitting',)], ['l', 'r']) + >>> df0.select(levenshtein('l', 'r').alias('d')).collect() + [Row(d=3)] + """ + sc = SparkContext._active_spark_context + jc = sc._jvm.functions.levenshtein(_to_java_column(left), _to_java_column(right)) + return Column(jc) + + @ignore_unicode_prefix @since(1.5) def md5(col): diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala index e249b58927..92a50e7092 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala @@ -163,6 +163,7 @@ object FunctionRegistry { expression[Lower]("lcase"), expression[Lower]("lower"), expression[StringLength]("length"), + expression[Levenshtein]("levenshtein"), expression[Substring]("substr"), expression[Substring]("substring"), expression[UnBase64]("unbase64"), diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala index 154ac3508c..6de40629ff 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala @@ -19,6 +19,7 @@ package org.apache.spark.sql.catalyst.expressions import java.util.regex.Pattern +import org.apache.commons.lang3.StringUtils import org.apache.spark.sql.catalyst.analysis.UnresolvedException import org.apache.spark.sql.catalyst.expressions.codegen._ import org.apache.spark.sql.types._ @@ -299,6 +300,37 @@ case class StringLength(child: Expression) extends UnaryExpression with ExpectsI override def prettyName: String = "length" } +/** + * A function that return the Levenshtein distance between the two given strings. + */ +case class Levenshtein(left: Expression, right: Expression) extends BinaryExpression + with ExpectsInputTypes { + + override def inputTypes: Seq[AbstractDataType] = Seq(StringType, StringType) + + override def dataType: DataType = IntegerType + + override def eval(input: InternalRow): Any = { + val leftValue = left.eval(input) + if (leftValue == null) { + null + } else { + val rightValue = right.eval(input) + if(rightValue == null) { + null + } else { + StringUtils.getLevenshteinDistance(leftValue.toString, rightValue.toString) + } + } + } + + override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { + val stringUtils = classOf[StringUtils].getName + nullSafeCodeGen(ctx, ev, (res, left, right) => + s"$res = $stringUtils.getLevenshteinDistance($left.toString(), $right.toString());") + } +} + /** * Returns the numeric value of the first character of str. */ diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringFunctionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringFunctionsSuite.scala index 468df20442..1efbe1a245 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringFunctionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringFunctionsSuite.scala @@ -274,4 +274,13 @@ class StringFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper { checkEvaluation(Decode(Literal.create(null, BinaryType), Literal("utf-8")), null) checkEvaluation(Decode(b, Literal.create(null, StringType)), null, create_row(null)) } + + test("Levenshtein distance") { + checkEvaluation(Levenshtein(Literal.create(null, StringType), Literal("")), null) + checkEvaluation(Levenshtein(Literal(""), Literal.create(null, StringType)), null) + checkEvaluation(Levenshtein(Literal(""), Literal("")), 0) + checkEvaluation(Levenshtein(Literal("abc"), Literal("abc")), 0) + checkEvaluation(Levenshtein(Literal("kitten"), Literal("sitting")), 3) + checkEvaluation(Levenshtein(Literal("frog"), Literal("fog")), 1) + } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala index b63c6ee8ab..e4109da08e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala @@ -1580,21 +1580,36 @@ object functions { ////////////////////////////////////////////////////////////////////////////////////////////// /** - * Computes the length of a given string value - * + * Computes the length of a given string value. + * * @group string_funcs * @since 1.5.0 */ def strlen(e: Column): Column = StringLength(e.expr) /** - * Computes the length of a given string column - * + * Computes the length of a given string column. + * * @group string_funcs * @since 1.5.0 */ def strlen(columnName: String): Column = strlen(Column(columnName)) + /** + * Computes the Levenshtein distance of the two given strings. + * @group string_funcs + * @since 1.5.0 + */ + def levenshtein(l: Column, r: Column): Column = Levenshtein(l.expr, r.expr) + + /** + * Computes the Levenshtein distance of the two given strings. + * @group string_funcs + * @since 1.5.0 + */ + def levenshtein(leftColumnName: String, rightColumnName: String): Column = + levenshtein(Column(leftColumnName), Column(rightColumnName)) + /** * Computes the numeric value of the first character of the specified string value. * diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala index bd9fa400e5..bc455a922d 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala @@ -226,6 +226,12 @@ class DataFrameFunctionsSuite extends QueryTest { }) } + test("Levenshtein distance") { + val df = Seq(("kitten", "sitting"), ("frog", "fog")).toDF("l", "r") + checkAnswer(df.select(levenshtein("l", "r")), Seq(Row(3), Row(1))) + checkAnswer(df.selectExpr("levenshtein(l, r)"), Seq(Row(3), Row(1))) + } + test("string ascii function") { val df = Seq(("abc", "")).toDF("a", "b") checkAnswer( -- GitLab