Commit a5c2961c authored by Tarek Auel, committed by Davies Liu

[SPARK-8235] [SQL] misc function sha / sha1

Jira: https://issues.apache.org/jira/browse/SPARK-8235

I added support for sha1. If I understood rxin correctly, sha and sha1 should execute the same algorithm, shouldn't they?

Please take a close look at the Python part. It is adapted from #6934.

Author: Tarek Auel <tarek.auel@gmail.com>
Author: Tarek Auel <tarek.auel@googlemail.com>

Closes #6963 from tarekauel/SPARK-8235 and squashes the following commits:

f064563 [Tarek Auel] change to shaHex
7ce3cdc [Tarek Auel] rely on automatic cast
a1251d6 [Tarek Auel] Merge remote-tracking branch 'upstream/master' into SPARK-8235
68eb043 [Tarek Auel] added docstring
be5aff1 [Tarek Auel] improved error message
7336c96 [Tarek Auel] added type check
cf23a80 [Tarek Auel] simplified example
ebf75ef [Tarek Auel] [SPARK-8301] updated the python documentation. Removed sha in python and scala
6d6ff0d [Tarek Auel] [SPARK-8233] added docstring
ea191a9 [Tarek Auel] [SPARK-8233] fixed signature of python function. Added expected type to misc
e3fd7c3 [Tarek Auel] SPARK[8235] added sha to the list of __all__
e5dad4e [Tarek Auel] SPARK[8235] sha / sha1
parent 3664ee25
@@ -42,6 +42,7 @@ __all__ = [
     'monotonicallyIncreasingId',
     'rand',
     'randn',
+    'sha1',
     'sha2',
     'sparkPartitionId',
     'struct',
@@ -382,6 +383,19 @@ def sha2(col, numBits):
     return Column(jc)
 
 
+@ignore_unicode_prefix
+@since(1.5)
+def sha1(col):
+    """Returns the hex string result of SHA-1.
+
+    >>> sqlContext.createDataFrame([('ABC',)], ['a']).select(sha1('a').alias('hash')).collect()
+    [Row(hash=u'3c01bdbb26f358bab27f267924aa2c9a03fcfdb8')]
+    """
+    sc = SparkContext._active_spark_context
+    jc = sc._jvm.functions.sha1(_to_java_column(col))
+    return Column(jc)
+
+
 @since(1.4)
 def sparkPartitionId():
     """A column for partition ID of the Spark task.
@@ -136,6 +136,8 @@ object FunctionRegistry {
     // misc functions
     expression[Md5]("md5"),
     expression[Sha2]("sha2"),
+    expression[Sha1]("sha1"),
+    expression[Sha1]("sha"),
 
     // aggregate functions
     expression[Average]("avg"),
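Both SQL names are registered to the same Sha1 expression, so they are interchangeable in SQL. A minimal local sketch of how this could be exercised (not part of the commit; it assumes a Spark 1.5-style SQLContext, and the object/app names are only illustrative):

import org.apache.spark.{SparkConf, SparkContext}
import org.apache.spark.sql.SQLContext

// Illustrative sketch only, not part of this commit.
object ShaAliasCheck {
  def main(args: Array[String]): Unit = {
    val sc = new SparkContext(new SparkConf().setAppName("sha-alias-check").setMaster("local[2]"))
    val sqlContext = new SQLContext(sc)

    // A one-row DataFrame with a single string column; the string argument is
    // implicitly cast to binary before hashing ("rely on automatic cast" above).
    val df = sqlContext.createDataFrame(Seq(Tuple1("ABC"))).toDF("a")

    // Both SQL names resolve to the Sha1 expression registered above,
    // so they return identical 40-character hex digests.
    df.selectExpr("sha(a)", "sha1(a)").collect().foreach(println)
    // [3c01bdbb26f358bab27f267924aa2c9a03fcfdb8,3c01bdbb26f358bab27f267924aa2c9a03fcfdb8]

    sc.stop()
  }
}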
@@ -21,8 +21,9 @@ import java.security.MessageDigest
 import java.security.NoSuchAlgorithmException
 
 import org.apache.commons.codec.digest.DigestUtils
+import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
 import org.apache.spark.sql.catalyst.expressions.codegen._
-import org.apache.spark.sql.types.{BinaryType, IntegerType, StringType, DataType}
+import org.apache.spark.sql.types._
 import org.apache.spark.unsafe.types.UTF8String
 
 /**
@@ -140,3 +141,30 @@ case class Sha2(left: Expression, right: Expression)
     """
   }
 }
+
+/**
+ * A function that calculates a sha1 hash value and returns it as a hex string
+ * For input of type [[BinaryType]] or [[StringType]]
+ */
+case class Sha1(child: Expression) extends UnaryExpression with ExpectsInputTypes {
+
+  override def dataType: DataType = StringType
+
+  override def expectedChildTypes: Seq[DataType] = Seq(BinaryType)
+
+  override def eval(input: InternalRow): Any = {
+    val value = child.eval(input)
+    if (value == null) {
+      null
+    } else {
+      UTF8String.fromString(DigestUtils.shaHex(value.asInstanceOf[Array[Byte]]))
+    }
+  }
+
+  override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = {
+    defineCodeGen(ctx, ev, c =>
+      "org.apache.spark.unsafe.types.UTF8String.fromString" +
+        s"(org.apache.commons.codec.digest.DigestUtils.shaHex($c))"
+    )
+  }
+}
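For reference, Sha1.eval boils down to a single Commons Codec call. A standalone sketch of the same computation (not part of the commit; it assumes commons-codec on the classpath, and the object name is illustrative) reproduces the digests used in the tests below:

import org.apache.commons.codec.digest.DigestUtils

// Illustrative sketch only, not part of this commit.
object ShaHexCheck {
  def main(args: Array[String]): Unit = {
    // Sha1.eval wraps exactly this call in a UTF8String.
    // SHA-1("ABC") as a 40-character lowercase hex string:
    println(DigestUtils.shaHex("ABC".getBytes("UTF-8")))  // 3c01bdbb26f358bab27f267924aa2c9a03fcfdb8
    // SHA-1 of empty input:
    println(DigestUtils.shaHex(Array.empty[Byte]))        // da39a3ee5e6b4b0d3255bfef95601890afd80709
  }
}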
@@ -31,6 +31,14 @@ class MiscFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper {
     checkEvaluation(Md5(Literal.create(null, BinaryType)), null)
   }
 
+  test("sha1") {
+    checkEvaluation(Sha1(Literal("ABC".getBytes)), "3c01bdbb26f358bab27f267924aa2c9a03fcfdb8")
+    checkEvaluation(Sha1(Literal.create(Array[Byte](1, 2, 3, 4, 5, 6), BinaryType)),
+      "5d211bad8f4ee70e16c7d343a838fc344a1ed961")
+    checkEvaluation(Sha1(Literal.create(null, BinaryType)), null)
+    checkEvaluation(Sha1(Literal("".getBytes)), "da39a3ee5e6b4b0d3255bfef95601890afd80709")
+  }
+
   test("sha2") {
     checkEvaluation(Sha2(Literal("ABC".getBytes), Literal(256)), DigestUtils.sha256Hex("ABC"))
     checkEvaluation(Sha2(Literal.create(Array[Byte](1, 2, 3, 4, 5, 6), BinaryType), Literal(384)),
@@ -1414,6 +1414,22 @@ object functions {
    */
   def md5(columnName: String): Column = md5(Column(columnName))
 
+  /**
+   * Calculates the SHA-1 digest and returns the value as a 40 character hex string.
+   *
+   * @group misc_funcs
+   * @since 1.5.0
+   */
+  def sha1(e: Column): Column = Sha1(e.expr)
+
+  /**
+   * Calculates the SHA-1 digest and returns the value as a 40 character hex string.
+   *
+   * @group misc_funcs
+   * @since 1.5.0
+   */
+  def sha1(columnName: String): Column = sha1(Column(columnName))
+
   /**
    * Calculates the SHA-2 family of hash functions and returns the value as a hex string.
    *
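The suite below exercises exactly these two overloads; for orientation, a self-contained local sketch of the DataFrame API usage (not part of the commit; a Spark 1.5-style SQLContext is assumed, and the object/app names are illustrative):

import org.apache.spark.{SparkConf, SparkContext}
import org.apache.spark.sql.SQLContext
import org.apache.spark.sql.functions.sha1

// Illustrative sketch only, not part of this commit.
object Sha1DataFrameExample {
  def main(args: Array[String]): Unit = {
    val sc = new SparkContext(new SparkConf().setAppName("sha1-example").setMaster("local[2]"))
    val sqlContext = new SQLContext(sc)
    import sqlContext.implicits._

    // A string column and a binary column holding the same bytes.
    val df = Seq(("ABC", "ABC".getBytes("UTF-8"))).toDF("a", "b")

    // Column-based and column-name-based overloads return the same digest.
    df.select(sha1($"a").as("h1"), sha1("b").as("h2")).collect().foreach(println)
    // [3c01bdbb26f358bab27f267924aa2c9a03fcfdb8,3c01bdbb26f358bab27f267924aa2c9a03fcfdb8]

    sc.stop()
  }
}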
@@ -144,6 +144,18 @@ class DataFrameFunctionsSuite extends QueryTest {
       Row("902fbdd2b1df0c4f70b4a5d23525e932", "6ac1e56bc78f031059be7be854522c4c"))
   }
 
+  test("misc sha1 function") {
+    val df = Seq(("ABC", "ABC".getBytes)).toDF("a", "b")
+    checkAnswer(
+      df.select(sha1($"a"), sha1("b")),
+      Row("3c01bdbb26f358bab27f267924aa2c9a03fcfdb8", "3c01bdbb26f358bab27f267924aa2c9a03fcfdb8"))
+
+    val dfEmpty = Seq(("", "".getBytes)).toDF("a", "b")
+    checkAnswer(
+      dfEmpty.selectExpr("sha1(a)", "sha1(b)"),
+      Row("da39a3ee5e6b4b0d3255bfef95601890afd80709", "da39a3ee5e6b4b0d3255bfef95601890afd80709"))
+  }
+
   test("misc sha2 function") {
     val df = Seq(("ABC", Array[Byte](1, 2, 3, 4, 5, 6))).toDF("a", "b")
     checkAnswer(