Skip to content
Snippets Groups Projects
Commit ab535b9a authored by zhichao.li, committed by Davies Liu
Browse files

[SPARK-8226] [SQL] Add function shiftrightunsigned

Author: zhichao.li <zhichao.li@intel.com>

Closes #7035 from zhichao-li/shiftRightUnsigned and squashes the following commits:

6bcca5a [zhichao.li] change coding style
3e9f5ae [zhichao.li] python style
d85ae0b [zhichao.li] add shiftrightunsigned
parent 2848f4da
No related branches found
No related tags found
No related merge requests found
......@@ -436,6 +436,19 @@ def shiftRight(col, numBits):
return Column(jc)
@since(1.5)
def shiftRightUnsigned(col, numBits):
    """Unsigned shift the given value numBits right.

    >>> sqlContext.createDataFrame([(-42,)], ['a']).select(shiftRightUnsigned('a', 1).alias('r'))\
    .collect()
    [Row(r=9223372036854775787)]
    """
    # Delegate to the JVM-side `functions.shiftRightUnsigned` through the active context.
    jvm_functions = SparkContext._active_spark_context._jvm.functions
    shifted = jvm_functions.shiftRightUnsigned(_to_java_column(col), numBits)
    return Column(shifted)
@since(1.4)
def sparkPartitionId():
"""A column for partition ID of the Spark task.
......
......@@ -129,6 +129,7 @@ object FunctionRegistry {
expression[Rint]("rint"),
expression[ShiftLeft]("shiftleft"),
expression[ShiftRight]("shiftright"),
expression[ShiftRightUnsigned]("shiftrightunsigned"),
expression[Signum]("sign"),
expression[Signum]("signum"),
expression[Sin]("sin"),
......
......@@ -521,6 +521,55 @@ case class ShiftRight(left: Expression, right: Expression) extends BinaryExpress
}
}
/**
 * Bitwise unsigned right shift (`>>>`) of `left` by `right` bits.
 *
 * `left` must be a long, integer, short or byte expression and `right` an integer expression
 * (a NullType operand on either side is also accepted by the type check). The result is a long
 * when `left` is a long and an integer for integer, short and byte inputs. Evaluation returns
 * null when either operand evaluates to null.
 */
case class ShiftRightUnsigned(left: Expression, right: Expression) extends BinaryExpression {

  /** Accepts (long | integer | short | byte, integer) operands, or a NullType on either side. */
  override def checkInputDataTypes(): TypeCheckResult = {
    (left.dataType, right.dataType) match {
      case (NullType, _) | (_, NullType) =>
        TypeCheckResult.TypeCheckSuccess
      case (LongType | IntegerType | ShortType | ByteType, IntegerType) =>
        TypeCheckResult.TypeCheckSuccess
      case _ =>
        TypeCheckResult.TypeCheckFailure(
          "ShiftRightUnsigned expects long, integer, short or byte value as first argument " +
          s"and an integer value as second argument, not (${left.dataType}, ${right.dataType})")
    }
  }

  /**
   * Interpreted evaluation. The right operand is only evaluated when the left one is non-null.
   */
  override def eval(input: InternalRow): Any = {
    val valueLeft = left.eval(input)
    if (valueLeft == null) {
      null
    } else {
      val valueRight = right.eval(input)
      if (valueRight == null) {
        null
      } else {
        // The type check guarantees an integer shift amount.
        val numBits = valueRight.asInstanceOf[Int]
        // Short and byte values are promoted to int before shifting, so those branches
        // produce an Int, matching `dataType`.
        valueLeft match {
          case l: Long => l >>> numBits
          case i: Int => i >>> numBits
          case s: Short => s >>> numBits
          case b: Byte => b >>> numBits
        }
      }
    }
  }

  /** LongType for long input, IntegerType for integer/short/byte input, NullType otherwise. */
  override def dataType: DataType = left.dataType match {
    case LongType => LongType
    case IntegerType | ShortType | ByteType => IntegerType
    case _ => NullType
  }

  override protected def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = {
    // Lambda parameters are named so they do not shadow the `left`/`right` child expressions.
    nullSafeCodeGen(ctx, ev, (result, value, numBits) => s"$result = $value >>> $numBits;")
  }
}
/**
* Performs the inverse operation of HEX.
* Resulting characters are returned as a byte array.
......
......@@ -264,6 +264,19 @@ class MathFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper {
checkEvaluation(ShiftRight(Literal(-42.toLong), Literal(1)), -21.toLong)
}
test("shift right unsigned") {
  // A null operand on either side, or on both sides, must evaluate to null.
  checkEvaluation(ShiftRightUnsigned(Literal.create(null, IntegerType), Literal(1)), null)
  checkEvaluation(ShiftRightUnsigned(Literal(42), Literal.create(null, IntegerType)), null)
  // This case previously exercised ShiftRight by mistake, leaving the both-null path of
  // ShiftRightUnsigned untested.
  checkEvaluation(
    ShiftRightUnsigned(Literal.create(null, IntegerType), Literal.create(null, IntegerType)),
    null)

  // Integer, byte and short inputs produce an integer result.
  checkEvaluation(ShiftRightUnsigned(Literal(42), Literal(1)), 21)
  checkEvaluation(ShiftRightUnsigned(Literal(42.toByte), Literal(1)), 21)
  checkEvaluation(ShiftRightUnsigned(Literal(42.toShort), Literal(1)), 21)

  // Long inputs produce a long result; a negative value is zero-filled from the left.
  checkEvaluation(ShiftRightUnsigned(Literal(42.toLong), Literal(1)), 21.toLong)
  checkEvaluation(ShiftRightUnsigned(Literal(-42.toLong), Literal(1)), 9223372036854775787L)
}
test("hex") {
checkEvaluation(Hex(Literal(28)), "1C")
checkEvaluation(Hex(Literal(-28)), "FFFFFFFFFFFFFFE4")
......
......@@ -1343,6 +1343,26 @@ object functions {
*/
def shiftRight(e: Column, numBits: Int): Column = {
  // Wrap the shift amount as a literal expression and build the ShiftRight expression.
  val valueExpr = e.expr
  val shiftExpr = lit(numBits).expr
  ShiftRight(valueExpr, shiftExpr)
}
/**
 * Unsigned shift the given value numBits right. If the given value is a long value,
 * it will return a long value else it will return an integer value.
 *
 * @group math_funcs
 * @since 1.5.0
 */
def shiftRightUnsigned(columnName: String, numBits: Int): Column = {
  // Resolve the column by name, then defer to the Column-based overload.
  val column = Column(columnName)
  shiftRightUnsigned(column, numBits)
}
/**
 * Unsigned shift the given value numBits right. If the given value is a long value,
 * it will return a long value else it will return an integer value.
 *
 * @group math_funcs
 * @since 1.5.0
 */
def shiftRightUnsigned(e: Column, numBits: Int): Column = {
  // Wrap the shift amount as a literal expression and build the ShiftRightUnsigned expression.
  val valueExpr = e.expr
  val shiftExpr = lit(numBits).expr
  ShiftRightUnsigned(valueExpr, shiftExpr)
}
/**
* Shift the given value numBits right. If the given value is a long value, it will return
* a long value else it will return an integer value.
......
......@@ -304,6 +304,23 @@ class MathExpressionsSuite extends QueryTest {
Row(21.toLong, 21, 21.toShort, 21.toByte, null))
}
test("shift right unsigned") {
  val df = Seq[(Long, Integer, Short, Byte, Integer, Integer)]((-42, 42, 42, 42, 42, null))
    .toDF("a", "b", "c", "d", "e", "f")

  // The DataFrame function API and the SQL expression form must give the same single row.
  // Column "f" holds null, so its shifted value is null as well.
  val expected = Row(9223372036854775787L, 21, 21.toShort, 21.toByte, null)

  val viaFunctions = df.select(
    shiftRightUnsigned('a, 1), shiftRightUnsigned('b, 1), shiftRightUnsigned('c, 1),
    shiftRightUnsigned('d, 1), shiftRightUnsigned('f, 1))
  checkAnswer(viaFunctions, expected)

  val viaExpressions = df.selectExpr(
    "shiftRightUnsigned(a, 1)", "shiftRightUnsigned(b, 1)", "shiftRightUnsigned(c, 1)",
    "shiftRightUnsigned(d, 1)", "shiftRightUnsigned(f, 1)")
  checkAnswer(viaExpressions, expected)
}
test("binary log") {
val df = Seq[(Integer, Integer)]((123, null)).toDF("a", "b")
checkAnswer(
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment