diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/hash.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/hash.scala
index 415ef4e4a37ec0d461bcf76f4106d77b90dd8b46..e14f0544c2b81ba6c70af6b9a7c3b3960108a55c 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/hash.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/hash.scala
@@ -268,15 +268,16 @@ abstract class HashExpression[E] extends Expression {
 
   override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
     ev.isNull = "false"
-    val childrenHash = children.map { child =>
+    val childrenHash = ctx.splitExpressions(ctx.INPUT_ROW, children.map { child =>
       val childGen = child.genCode(ctx)
       childGen.code + ctx.nullSafeExec(child.nullable, childGen.isNull) {
         computeHash(childGen.value, child.dataType, ev.value, ctx)
       }
-    }.mkString("\n")
+    })
 
+    ctx.addMutableState(ctx.javaType(dataType), ev.value, "")
     ev.copy(code = s"""
-      ${ctx.javaType(dataType)} ${ev.value} = $seed;
+      ${ev.value} = $seed;
       $childrenHash""")
   }
 
@@ -600,15 +601,18 @@ case class HiveHash(children: Seq[Expression]) extends HashExpression[Int] {
   override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
     ev.isNull = "false"
     val childHash = ctx.freshName("childHash")
-    val childrenHash = children.map { child =>
+    val childrenHash = ctx.splitExpressions(ctx.INPUT_ROW, children.map { child =>
       val childGen = child.genCode(ctx)
       childGen.code + ctx.nullSafeExec(child.nullable, childGen.isNull) {
         computeHash(childGen.value, child.dataType, childHash, ctx)
-      } + s"${ev.value} = (31 * ${ev.value}) + $childHash;"
-    }.mkString(s"int $childHash = 0;", s"\n$childHash = 0;\n", "")
+      } + s"${ev.value} = (31 * ${ev.value}) + $childHash;" +
+        s"\n$childHash = 0;"
+    })
 
+    ctx.addMutableState(ctx.javaType(dataType), ev.value, "")
+    ctx.addMutableState("int", childHash, s"$childHash = 0;")
     ev.copy(code = s"""
-      ${ctx.javaType(dataType)} ${ev.value} = $seed;
+      ${ev.value} = $seed;
       $childrenHash""")
   }
 
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/HashExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/HashExpressionsSuite.scala
index c714bc03dc0d5b8e149db3697bdb49b4f28dada2..032629265269a4f13a634899e57443a686b1d468 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/HashExpressionsSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/HashExpressionsSuite.scala
@@ -24,7 +24,9 @@ import org.apache.commons.codec.digest.DigestUtils
 import org.apache.spark.SparkFunSuite
 import org.apache.spark.sql.{RandomDataGenerator, Row}
 import org.apache.spark.sql.catalyst.encoders.{ExamplePointUDT, RowEncoder}
+import org.apache.spark.sql.catalyst.expressions.codegen.GenerateMutableProjection
 import org.apache.spark.sql.types._
+import org.apache.spark.unsafe.types.UTF8String
 
 class HashExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper {
 
@@ -124,6 +126,26 @@ class HashExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper {
         new StructType().add("array", arrayOfString).add("map", mapOfString))
       .add("structOfUDT", structOfUDT))
 
+  test("SPARK-18207: Compute hash for a lot of expressions") {
+    val N = 1000
+    val wideRow = new GenericInternalRow(
+      Seq.tabulate(N)(i => UTF8String.fromString(i.toString)).toArray[Any])
+    val schema = StructType((1 to N).map(i => StructField("", StringType)))
+
+    val exprs = schema.fields.zipWithIndex.map { case (f, i) =>
+      BoundReference(i, f.dataType, true)
+    }
+    val murmur3HashExpr = Murmur3Hash(exprs, 42)
+    val murmur3HashPlan = GenerateMutableProjection.generate(Seq(murmur3HashExpr))
+    val murmursHashEval = Murmur3Hash(exprs, 42).eval(wideRow)
+    assert(murmur3HashPlan(wideRow).getInt(0) == murmursHashEval)
+
+    val hiveHashExpr = HiveHash(exprs)
+    val hiveHashPlan = GenerateMutableProjection.generate(Seq(hiveHashExpr))
+    val hiveHashEval = HiveHash(exprs).eval(wideRow)
+    assert(hiveHashPlan(wideRow).getInt(0) == hiveHashEval)
+  }
+
   private def testHash(inputSchema: StructType): Unit = {
     val inputGenerator = RandomDataGenerator.forType(inputSchema, nullable = false).get
     val encoder = RowEncoder(inputSchema)