Skip to content
Snippets Groups Projects
Commit a1964e9d authored by Tarek Auel's avatar Tarek Auel Committed by Davies Liu
Browse files

[SPARK-8830] [SQL] native levenshtein distance

Jira: https://issues.apache.org/jira/browse/SPARK-8830

rxin and HuJiayin, could you take a look at it?

Author: Tarek Auel <tarek.auel@googlemail.com>

Closes #7236 from tarekauel/native-levenshtein-distance and squashes the following commits:

ee4c4de [Tarek Auel] [SPARK-8830] implemented improvement proposals
c252e71 [Tarek Auel] [SPARK-8830] removed chartAt; use unsafe method for byte array comparison
ddf2222 [Tarek Auel] Merge branch 'master' into native-levenshtein-distance
179920a [Tarek Auel] [SPARK-8830] added description
5e9ed54 [Tarek Auel] [SPARK-8830] removed StringUtils import
dce4308 [Tarek Auel] [SPARK-8830] native levenshtein distance
parent 23448a9e
No related branches found
No related tags found
No related merge requests found
......@@ -284,13 +284,12 @@ case class Levenshtein(left: Expression, right: Expression) extends BinaryExpres
// The result of this expression is reported as IntegerType.
override def dataType: DataType = IntegerType
// Computes the distance by converting both operands with toString and delegating to
// StringUtils.getLevenshteinDistance.
// NOTE(review): the +/- diff markers are missing in this view; this looks like the pre-commit
// implementation that the UTF8String-based nullSafeEval replaces -- confirm against the commit.
protected override def nullSafeEval(input1: Any, input2: Any): Any =
StringUtils.getLevenshteinDistance(input1.toString, input2.toString)
/**
 * Computes the Levenshtein distance between the two operands. Both values are cast to
 * UTF8String and the computation is delegated to UTF8String.levenshteinDistance.
 */
protected override def nullSafeEval(leftValue: Any, rightValue: Any): Any = {
  val source = leftValue.asInstanceOf[UTF8String]
  val target = rightValue.asInstanceOf[UTF8String]
  source.levenshteinDistance(target)
}
// Produces the generated-code counterpart of nullSafeEval for this expression.
// NOTE(review): the +/- diff markers are missing in this view, so two alternative bodies appear
// back to back inside the same braces -- confirm against the commit which one is live.
override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = {
// Variant 1 (presumably the removed side): emits a call to StringUtils.getLevenshteinDistance
// on the toString() of both operands.
val stringUtils = classOf[StringUtils].getName
defineCodeGen(ctx, ev, (left, right) =>
s"$stringUtils.getLevenshteinDistance($left.toString(), $right.toString())")
// Variant 2 (presumably the added side): emits a direct call to levenshteinDistance on the
// left operand and assigns the result to ev.primitive.
nullSafeCodeGen(ctx, ev, (left, right) =>
s"${ev.primitive} = $left.levenshteinDistance($right);")
}
}
......
......@@ -282,5 +282,10 @@ class StringFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper {
checkEvaluation(Levenshtein(Literal("abc"), Literal("abc")), 0)
checkEvaluation(Levenshtein(Literal("kitten"), Literal("sitting")), 3)
checkEvaluation(Levenshtein(Literal("frog"), Literal("fog")), 1)
// scalastyle:off
// Non-ASCII characters are not allowed in the code, so we disable scalastyle here.
checkEvaluation(Levenshtein(Literal("千世"), Literal("fog")), 3)
checkEvaluation(Levenshtein(Literal("世界千世"), Literal("大a界b")), 4)
// scalastyle:on
}
}
......@@ -99,8 +99,6 @@ public final class UTF8String implements Comparable<UTF8String>, Serializable {
/**
* Returns the number of code points in it.
*
* This is only used by Substring() when `start` is negative.
*/
public int numChars() {
int len = 0;
......@@ -254,6 +252,70 @@ public final class UTF8String implements Comparable<UTF8String>, Serializable {
}
}
/**
 * Computes the Levenshtein distance between this string and {@code other}: the smallest number
 * of single-character edits (insertions, deletions or substitutions) needed to turn one of the
 * two strings into the other. String lengths are taken from {@code numChars()} and characters
 * are compared through their encoded bytes.
 */
public int levenshteinDistance(UTF8String other) {
  // Adapted from org.apache.commons.lang3.StringUtils.getLevenshteinDistance
  final int thisLen = numChars();
  final int otherLen = other.numChars();

  // If either string is empty, the distance is the length of the other one.
  if (thisLen == 0) {
    return otherLen;
  }
  if (otherLen == 0) {
    return thisLen;
  }

  // The two row buffers are sized by the shorter string.
  final boolean thisIsShorter = thisLen <= otherLen;
  final UTF8String shorter = thisIsShorter ? this : other;
  final UTF8String longer = thisIsShorter ? other : this;
  final int shortLen = thisIsShorter ? thisLen : otherLen;
  final int longLen = thisIsShorter ? otherLen : thisLen;

  // prevRow holds the distances for the previous character of `longer`, currRow the ones
  // being computed for the current character.
  int[] prevRow = new int[shortLen + 1];
  int[] currRow = new int[shortLen + 1];
  for (int k = 0; k <= shortLen; k++) {
    prevRow[k] = k;
  }

  // longPos / shortPos are byte positions; col / row are character positions.
  int longPos = 0;
  for (int col = 0; col < longLen; col++) {
    final int longWidth = numBytesForFirstByte(longer.getByte(longPos));
    currRow[0] = col + 1;

    int shortPos = 0;
    for (int row = 0; row < shortLen; row++) {
      final int shortWidth = numBytesForFirstByte(shorter.getByte(shortPos));

      // Cheap checks first (leading byte, encoded width); only then compare all the bytes.
      final boolean sameChar =
        shorter.getByte(shortPos) == longer.getByte(longPos) &&
        longWidth == shortWidth &&
        ByteArrayMethods.arrayEquals(longer.base, longer.offset + longPos, shorter.base,
          shorter.offset + shortPos, longWidth);
      final int cost = sameChar ? 0 : 1;

      final int viaCurrentRow = currRow[row] + 1;
      final int viaPreviousRow = prevRow[row + 1] + 1;
      final int viaDiagonal = prevRow[row] + cost;
      currRow[row + 1] = Math.min(Math.min(viaCurrentRow, viaPreviousRow), viaDiagonal);

      shortPos += shortWidth;
    }

    // The row just filled becomes the previous row for the next character of `longer`.
    final int[] filled = currRow;
    currRow = prevRow;
    prevRow = filled;

    longPos += longWidth;
  }

  return prevRow[shortLen];
}
@Override
public int hashCode() {
int result = 1;
......
......@@ -128,4 +128,28 @@ public class UTF8StringSuite {
assertEquals(fromString("数据砖头").substring(3, 5), fromString("头"));
assertEquals(fromString("ߵ梷").substring(0, 2), fromString("ߵ梷"));
}
/**
 * Checks UTF8String.levenshteinDistance on empty strings, ASCII strings of equal and different
 * lengths (in both argument orders) and multi-byte characters.
 */
@Test
public void levenshteinDistance() {
  // JUnit's assertEquals takes (expected, actual); with that order a failure message reports
  // the literal below as the expected value and the computed distance as the actual one.
  assertEquals(0,
    UTF8String.fromString("").levenshteinDistance(UTF8String.fromString("")));
  assertEquals(1,
    UTF8String.fromString("").levenshteinDistance(UTF8String.fromString("a")));
  assertEquals(7,
    UTF8String.fromString("aaapppp").levenshteinDistance(UTF8String.fromString("")));
  assertEquals(1,
    UTF8String.fromString("frog").levenshteinDistance(UTF8String.fromString("fog")));
  assertEquals(3,
    UTF8String.fromString("fly").levenshteinDistance(UTF8String.fromString("ant")));
  // Same pair in both orders.
  assertEquals(7,
    UTF8String.fromString("elephant").levenshteinDistance(UTF8String.fromString("hippo")));
  assertEquals(7,
    UTF8String.fromString("hippo").levenshteinDistance(UTF8String.fromString("elephant")));
  assertEquals(8,
    UTF8String.fromString("hippo").levenshteinDistance(UTF8String.fromString("zzzzzzzz")));
  assertEquals(1,
    UTF8String.fromString("hello").levenshteinDistance(UTF8String.fromString("hallo")));
  // Multi-byte characters mixed with single-byte ones.
  assertEquals(4,
    UTF8String.fromString("世界千世").levenshteinDistance(UTF8String.fromString("千a世b")));
}
}
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment