diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/math.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/math.scala index 0fc320fb0887688276bb6dabd01d8d301ee66ef8..45b7e4d3405c85d50427a852192829a4801b5c08 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/math.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/math.scala @@ -17,10 +17,8 @@ package org.apache.spark.sql.catalyst.expressions -import java.lang.{Long => JLong} -import java.util.Arrays +import java.{lang => jl} -import org.apache.spark.sql.catalyst.analysis.TypeCheckResult import org.apache.spark.sql.catalyst.expressions.codegen._ import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.UTF8String @@ -206,7 +204,7 @@ case class Factorial(child: Expression) extends UnaryExpression with ExpectsInpu if (evalE == null) { null } else { - val input = evalE.asInstanceOf[Integer] + val input = evalE.asInstanceOf[jl.Integer] if (input > 20 || input < 0) { null } else { @@ -290,7 +288,7 @@ case class Bin(child: Expression) if (evalE == null) { null } else { - UTF8String.fromString(JLong.toBinaryString(evalE.asInstanceOf[Long])) + UTF8String.fromString(jl.Long.toBinaryString(evalE.asInstanceOf[Long])) } } @@ -300,27 +298,18 @@ case class Bin(child: Expression) } } - /** * If the argument is an INT or binary, hex returns the number as a STRING in hexadecimal format. * Otherwise if the number is a STRING, it converts each character into its hex representation * and returns the resulting STRING. Negative numbers would be treated as two's complement. */ -case class Hex(child: Expression) extends UnaryExpression with Serializable { +case class Hex(child: Expression) extends UnaryExpression with ExpectsInputTypes { + // TODO: Create code-gen version. - override def dataType: DataType = StringType + override def inputTypes: Seq[AbstractDataType] = + Seq(TypeCollection(LongType, StringType, BinaryType)) - override def checkInputDataTypes(): TypeCheckResult = { - if (child.dataType.isInstanceOf[StringType] - || child.dataType.isInstanceOf[IntegerType] - || child.dataType.isInstanceOf[LongType] - || child.dataType.isInstanceOf[BinaryType] - || child.dataType == NullType) { - TypeCheckResult.TypeCheckSuccess - } else { - TypeCheckResult.TypeCheckFailure(s"hex doesn't accepts ${child.dataType} type") - } - } + override def dataType: DataType = StringType override def eval(input: InternalRow): Any = { val num = child.eval(input) @@ -329,7 +318,6 @@ case class Hex(child: Expression) extends UnaryExpression with Serializable { } else { child.dataType match { case LongType => hex(num.asInstanceOf[Long]) - case IntegerType => hex(num.asInstanceOf[Integer].toLong) case BinaryType => hex(num.asInstanceOf[Array[Byte]]) case StringType => hex(num.asInstanceOf[UTF8String]) } @@ -371,7 +359,55 @@ case class Hex(child: Expression) extends UnaryExpression with Serializable { Character.toUpperCase(Character.forDigit((numBuf & 0xF).toInt, 16)).toByte numBuf >>>= 4 } while (numBuf != 0) - UTF8String.fromBytes(Arrays.copyOfRange(value, value.length - len, value.length)) + UTF8String.fromBytes(java.util.Arrays.copyOfRange(value, value.length - len, value.length)) + } +} + + +/** + * Performs the inverse operation of HEX. + * Resulting characters are returned as a byte array. + */ +case class UnHex(child: Expression) extends UnaryExpression with ExpectsInputTypes { + // TODO: Create code-gen version. + + override def inputTypes: Seq[AbstractDataType] = Seq(StringType) + + override def dataType: DataType = BinaryType + + override def eval(input: InternalRow): Any = { + val num = child.eval(input) + if (num == null) { + null + } else { + unhex(num.asInstanceOf[UTF8String].getBytes) + } + } + + private val unhexDigits = { + val array = Array.fill[Byte](128)(-1) + (0 to 9).foreach(i => array('0' + i) = i.toByte) + (0 to 5).foreach(i => array('A' + i) = (i + 10).toByte) + (0 to 5).foreach(i => array('a' + i) = (i + 10).toByte) + array + } + + private def unhex(inputBytes: Array[Byte]): Array[Byte] = { + var bytes = inputBytes + if ((bytes.length & 0x01) != 0) { + bytes = '0'.toByte +: bytes + } + val out = new Array[Byte](bytes.length >> 1) + // two characters form the hex value. + var i = 0 + while (i < bytes.length) { + val first = unhexDigits(bytes(i)) + val second = unhexDigits(bytes(i + 1)) + if (first == -1 || second == -1) { return null} + out(i / 2) = (((first << 4) | second) & 0xFF).toByte + i += 2 + } + out } } @@ -423,22 +459,19 @@ case class Pow(left: Expression, right: Expression) } } -case class ShiftLeft(left: Expression, right: Expression) extends BinaryExpression { - override def checkInputDataTypes(): TypeCheckResult = { - (left.dataType, right.dataType) match { - case (NullType, _) | (_, NullType) => return TypeCheckResult.TypeCheckSuccess - case (_, IntegerType) => left.dataType match { - case LongType | IntegerType | ShortType | ByteType => - return TypeCheckResult.TypeCheckSuccess - case _ => // failed - } - case _ => // failed - } - TypeCheckResult.TypeCheckFailure( - s"ShiftLeft expects long, integer, short or byte value as first argument and an " + - s"integer value as second argument, not (${left.dataType}, ${right.dataType})") - } +/** + * Bitwise unsigned left shift. + * @param left the base number to shift. + * @param right number of bits to left shift. + */ +case class ShiftLeft(left: Expression, right: Expression) + extends BinaryExpression with ExpectsInputTypes { + + override def inputTypes: Seq[AbstractDataType] = + Seq(TypeCollection(IntegerType, LongType), IntegerType) + + override def dataType: DataType = left.dataType override def eval(input: InternalRow): Any = { val valueLeft = left.eval(input) @@ -446,10 +479,8 @@ case class ShiftLeft(left: Expression, right: Expression) extends BinaryExpressi val valueRight = right.eval(input) if (valueRight != null) { valueLeft match { - case l: Long => l << valueRight.asInstanceOf[Integer] - case i: Integer => i << valueRight.asInstanceOf[Integer] - case s: Short => s << valueRight.asInstanceOf[Integer] - case b: Byte => b << valueRight.asInstanceOf[Integer] + case l: jl.Long => l << valueRight.asInstanceOf[jl.Integer] + case i: jl.Integer => i << valueRight.asInstanceOf[jl.Integer] } } else { null @@ -459,35 +490,24 @@ case class ShiftLeft(left: Expression, right: Expression) extends BinaryExpressi } } - override def dataType: DataType = { - left.dataType match { - case LongType => LongType - case IntegerType | ShortType | ByteType => IntegerType - case _ => NullType - } - } - override protected def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { nullSafeCodeGen(ctx, ev, (result, left, right) => s"$result = $left << $right;") } } -case class ShiftRight(left: Expression, right: Expression) extends BinaryExpression { - override def checkInputDataTypes(): TypeCheckResult = { - (left.dataType, right.dataType) match { - case (NullType, _) | (_, NullType) => return TypeCheckResult.TypeCheckSuccess - case (_, IntegerType) => left.dataType match { - case LongType | IntegerType | ShortType | ByteType => - return TypeCheckResult.TypeCheckSuccess - case _ => // failed - } - case _ => // failed - } - TypeCheckResult.TypeCheckFailure( - s"ShiftRight expects long, integer, short or byte value as first argument and an " + - s"integer value as second argument, not (${left.dataType}, ${right.dataType})") - } +/** + * Bitwise unsigned left shift. + * @param left the base number to shift. + * @param right number of bits to left shift. + */ +case class ShiftRight(left: Expression, right: Expression) + extends BinaryExpression with ExpectsInputTypes { + + override def inputTypes: Seq[AbstractDataType] = + Seq(TypeCollection(IntegerType, LongType), IntegerType) + + override def dataType: DataType = left.dataType override def eval(input: InternalRow): Any = { val valueLeft = left.eval(input) @@ -495,10 +515,8 @@ case class ShiftRight(left: Expression, right: Expression) extends BinaryExpress val valueRight = right.eval(input) if (valueRight != null) { valueLeft match { - case l: Long => l >> valueRight.asInstanceOf[Integer] - case i: Integer => i >> valueRight.asInstanceOf[Integer] - case s: Short => s >> valueRight.asInstanceOf[Integer] - case b: Byte => b >> valueRight.asInstanceOf[Integer] + case l: jl.Long => l >> valueRight.asInstanceOf[jl.Integer] + case i: jl.Integer => i >> valueRight.asInstanceOf[jl.Integer] } } else { null @@ -508,35 +526,24 @@ case class ShiftRight(left: Expression, right: Expression) extends BinaryExpress } } - override def dataType: DataType = { - left.dataType match { - case LongType => LongType - case IntegerType | ShortType | ByteType => IntegerType - case _ => NullType - } - } - override protected def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { nullSafeCodeGen(ctx, ev, (result, left, right) => s"$result = $left >> $right;") } } -case class ShiftRightUnsigned(left: Expression, right: Expression) extends BinaryExpression { - override def checkInputDataTypes(): TypeCheckResult = { - (left.dataType, right.dataType) match { - case (NullType, _) | (_, NullType) => return TypeCheckResult.TypeCheckSuccess - case (_, IntegerType) => left.dataType match { - case LongType | IntegerType | ShortType | ByteType => - return TypeCheckResult.TypeCheckSuccess - case _ => // failed - } - case _ => // failed - } - TypeCheckResult.TypeCheckFailure( - s"ShiftRightUnsigned expects long, integer, short or byte value as first argument and an " + - s"integer value as second argument, not (${left.dataType}, ${right.dataType})") - } +/** + * Bitwise unsigned right shift, for integer and long data type. + * @param left the base number. + * @param right the number of bits to right shift. + */ +case class ShiftRightUnsigned(left: Expression, right: Expression) + extends BinaryExpression with ExpectsInputTypes { + + override def inputTypes: Seq[AbstractDataType] = + Seq(TypeCollection(IntegerType, LongType), IntegerType) + + override def dataType: DataType = left.dataType override def eval(input: InternalRow): Any = { val valueLeft = left.eval(input) @@ -544,10 +551,8 @@ case class ShiftRightUnsigned(left: Expression, right: Expression) extends Binar val valueRight = right.eval(input) if (valueRight != null) { valueLeft match { - case l: Long => l >>> valueRight.asInstanceOf[Integer] - case i: Integer => i >>> valueRight.asInstanceOf[Integer] - case s: Short => s >>> valueRight.asInstanceOf[Integer] - case b: Byte => b >>> valueRight.asInstanceOf[Integer] + case l: jl.Long => l >>> valueRight.asInstanceOf[jl.Integer] + case i: jl.Integer => i >>> valueRight.asInstanceOf[jl.Integer] } } else { null @@ -557,74 +562,21 @@ case class ShiftRightUnsigned(left: Expression, right: Expression) extends Binar } } - override def dataType: DataType = { - left.dataType match { - case LongType => LongType - case IntegerType | ShortType | ByteType => IntegerType - case _ => NullType - } - } - override protected def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { nullSafeCodeGen(ctx, ev, (result, left, right) => s"$result = $left >>> $right;") } } -/** - * Performs the inverse operation of HEX. - * Resulting characters are returned as a byte array. - */ -case class UnHex(child: Expression) extends UnaryExpression with Serializable { - - override def dataType: DataType = BinaryType - - override def checkInputDataTypes(): TypeCheckResult = { - if (child.dataType.isInstanceOf[StringType] || child.dataType == NullType) { - TypeCheckResult.TypeCheckSuccess - } else { - TypeCheckResult.TypeCheckFailure(s"unHex accepts String type, not ${child.dataType}") - } - } - - override def eval(input: InternalRow): Any = { - val num = child.eval(input) - if (num == null) { - null - } else { - unhex(num.asInstanceOf[UTF8String].getBytes) - } - } - - private val unhexDigits = { - val array = Array.fill[Byte](128)(-1) - (0 to 9).foreach(i => array('0' + i) = i.toByte) - (0 to 5).foreach(i => array('A' + i) = (i + 10).toByte) - (0 to 5).foreach(i => array('a' + i) = (i + 10).toByte) - array - } - - private def unhex(inputBytes: Array[Byte]): Array[Byte] = { - var bytes = inputBytes - if ((bytes.length & 0x01) != 0) { - bytes = '0'.toByte +: bytes - } - val out = new Array[Byte](bytes.length >> 1) - // two characters form the hex value. - var i = 0 - while (i < bytes.length) { - val first = unhexDigits(bytes(i)) - val second = unhexDigits(bytes(i + 1)) - if (first == -1 || second == -1) { return null} - out(i / 2) = (((first << 4) | second) & 0xFF).toByte - i += 2 - } - out - } -} case class Hypot(left: Expression, right: Expression) extends BinaryMathExpression(math.hypot, "HYPOT") + +/** + * Computes the logarithm of a number. + * @param left the logarithm base, default to e. + * @param right the number to compute the logarithm of. + */ case class Logarithm(left: Expression, right: Expression) extends BinaryMathExpression((c1, c2) => math.log(c2) / math.log(c1), "LOG") { @@ -642,7 +594,7 @@ case class Logarithm(left: Expression, right: Expression) defineCodeGen(ctx, ev, (c1, c2) => s"java.lang.Math.log($c2) / java.lang.Math.log($c1)") } logCode + s""" - if (Double.valueOf(${ev.primitive}).isNaN()) { + if (Double.isNaN(${ev.primitive})) { ${ev.isNull} = true; } """ diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MathFunctionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MathFunctionsSuite.scala index 20839c83d4fd09e531957bdda60984817ad89d4f..03d8400cf356bee40c969c2307d9294f34cd126a 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MathFunctionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MathFunctionsSuite.scala @@ -161,11 +161,10 @@ class MathFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper { } test("factorial") { - val dataLong = (0 to 20) - dataLong.foreach { value => + (0 to 20).foreach { value => checkEvaluation(Factorial(Literal(value)), LongMath.factorial(value), EmptyRow) } - checkEvaluation((Literal.create(null, IntegerType)), null, create_row(null)) + checkEvaluation(Literal.create(null, IntegerType), null, create_row(null)) checkEvaluation(Factorial(Literal(20)), 2432902008176640000L, EmptyRow) checkEvaluation(Factorial(Literal(21)), null, EmptyRow) } @@ -244,10 +243,8 @@ class MathFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper { checkEvaluation( ShiftLeft(Literal.create(null, IntegerType), Literal.create(null, IntegerType)), null) checkEvaluation(ShiftLeft(Literal(21), Literal(1)), 42) - checkEvaluation(ShiftLeft(Literal(21.toByte), Literal(1)), 42) - checkEvaluation(ShiftLeft(Literal(21.toShort), Literal(1)), 42) - checkEvaluation(ShiftLeft(Literal(21.toLong), Literal(1)), 42.toLong) + checkEvaluation(ShiftLeft(Literal(21.toLong), Literal(1)), 42.toLong) checkEvaluation(ShiftLeft(Literal(-21.toLong), Literal(1)), -42.toLong) } @@ -257,10 +254,8 @@ class MathFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper { checkEvaluation( ShiftRight(Literal.create(null, IntegerType), Literal.create(null, IntegerType)), null) checkEvaluation(ShiftRight(Literal(42), Literal(1)), 21) - checkEvaluation(ShiftRight(Literal(42.toByte), Literal(1)), 21) - checkEvaluation(ShiftRight(Literal(42.toShort), Literal(1)), 21) - checkEvaluation(ShiftRight(Literal(42.toLong), Literal(1)), 21.toLong) + checkEvaluation(ShiftRight(Literal(42.toLong), Literal(1)), 21.toLong) checkEvaluation(ShiftRight(Literal(-42.toLong), Literal(1)), -21.toLong) } @@ -270,16 +265,12 @@ class MathFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper { checkEvaluation( ShiftRight(Literal.create(null, IntegerType), Literal.create(null, IntegerType)), null) checkEvaluation(ShiftRightUnsigned(Literal(42), Literal(1)), 21) - checkEvaluation(ShiftRightUnsigned(Literal(42.toByte), Literal(1)), 21) - checkEvaluation(ShiftRightUnsigned(Literal(42.toShort), Literal(1)), 21) - checkEvaluation(ShiftRightUnsigned(Literal(42.toLong), Literal(1)), 21.toLong) + checkEvaluation(ShiftRightUnsigned(Literal(42.toLong), Literal(1)), 21.toLong) checkEvaluation(ShiftRightUnsigned(Literal(-42.toLong), Literal(1)), 9223372036854775787L) } test("hex") { - checkEvaluation(Hex(Literal(28)), "1C") - checkEvaluation(Hex(Literal(-28)), "FFFFFFFFFFFFFFE4") checkEvaluation(Hex(Literal(100800200404L)), "177828FED4") checkEvaluation(Hex(Literal(-100800200404L)), "FFFFFFE887D7012C") checkEvaluation(Hex(Literal("helloHex")), "68656C6C6F486578") @@ -313,6 +304,8 @@ class MathFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper { checkEvaluation(Logarithm(Literal(v2), Literal(v1)), f(v2 + 0.0, v1 + 0.0), EmptyRow) checkEvaluation(new Logarithm(Literal(v1)), f(math.E, v1 + 0.0), EmptyRow) } + + // null input should yield null output checkEvaluation( Logarithm(Literal.create(null, DoubleType), Literal(1.0)), null, @@ -321,5 +314,15 @@ class MathFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper { Logarithm(Literal(1.0), Literal.create(null, DoubleType)), null, create_row(null)) + + // negative input should yield null output + checkEvaluation( + Logarithm(Literal(-1.0), Literal(1.0)), + null, + create_row(null)) + checkEvaluation( + Logarithm(Literal(1.0), Literal(-1.0)), + null, + create_row(null)) } }