diff --git a/core/src/main/scala/org/apache/spark/util/collection/OpenHashSet.scala b/core/src/main/scala/org/apache/spark/util/collection/OpenHashSet.scala index 1501111a06655d3c4f112958851358dfc6616cb6..64e7102e3654cec2ad51424124cf9c7e9a620c5f 100644 --- a/core/src/main/scala/org/apache/spark/util/collection/OpenHashSet.scala +++ b/core/src/main/scala/org/apache/spark/util/collection/OpenHashSet.scala @@ -20,6 +20,8 @@ package org.apache.spark.util.collection import scala.reflect._ import com.google.common.hash.Hashing +import org.apache.spark.annotation.Private + /** * A simple, fast hash set optimized for non-null insertion-only use case, where keys are never * removed. @@ -37,7 +39,7 @@ import com.google.common.hash.Hashing * It uses quadratic probing with a power-of-2 hash table size, which is guaranteed * to explore all spaces for each key (see http://en.wikipedia.org/wiki/Quadratic_probing). */ -private[spark] +@Private class OpenHashSet[@specialized(Long, Int) T: ClassTag]( initialCapacity: Int, loadFactor: Double) @@ -110,6 +112,14 @@ class OpenHashSet[@specialized(Long, Int) T: ClassTag]( rehashIfNeeded(k, grow, move) } + def union(other: OpenHashSet[T]): OpenHashSet[T] = { + val iterator = other.iterator + while (iterator.hasNext) { + add(iterator.next()) + } + this + } + /** * Add an element to the set. This one differs from add in that it doesn't trigger rehashing. * The caller is responsible for calling rehashIfNeeded. diff --git a/pom.xml b/pom.xml index d03d33bf02468b3dc24d71b325d4848e8c2fe23b..bcb6ef96a1206f54ef24c60096d0d5880c65f4b2 100644 --- a/pom.xml +++ b/pom.xml @@ -118,7 +118,6 @@ <akka.version>2.3.4-spark</akka.version> <java.version>1.6</java.version> <sbt.project.name>spark</sbt.project.name> - <scala.macros.version>2.0.1</scala.macros.version> <mesos.version>0.21.1</mesos.version> <mesos.classifier>shaded-protobuf</mesos.classifier> <slf4j.version>1.7.10</slf4j.version> @@ -1217,15 +1216,6 @@ <javacArg>-target</javacArg> <javacArg>${java.version}</javacArg> </javacArgs> - <!-- The following plugin is required to use quasiquotes in Scala 2.10 and is used - by Spark SQL for code generation. --> - <compilerPlugins> - <compilerPlugin> - <groupId>org.scalamacros</groupId> - <artifactId>paradise_${scala.version}</artifactId> - <version>${scala.macros.version}</version> - </compilerPlugin> - </compilerPlugins> </configuration> </plugin> <plugin> diff --git a/project/SparkBuild.scala b/project/SparkBuild.scala index 9a849639233bc9c648492585b60fbe3992fe5714..f65031fe25ac247801d13e230df2637d7dbb764b 100644 --- a/project/SparkBuild.scala +++ b/project/SparkBuild.scala @@ -178,9 +178,6 @@ object SparkBuild extends PomBuild { /* Enable unidoc only for the root spark project */ enable(Unidoc.settings)(spark) - /* Catalyst macro settings */ - enable(Catalyst.settings)(catalyst) - /* Spark SQL Core console settings */ enable(SQL.settings)(sql) @@ -275,14 +272,6 @@ object OldDeps { ) } -object Catalyst { - lazy val settings = Seq( - addCompilerPlugin("org.scalamacros" % "paradise" % "2.0.1" cross CrossVersion.full), - // Quasiquotes break compiling scala doc... - // TODO: Investigate fixing this. - sources in (Compile, doc) ~= (_ filter (_.getName contains "codegen"))) -} - object SQL { lazy val settings = Seq( initialCommands in console := diff --git a/sql/catalyst/pom.xml b/sql/catalyst/pom.xml index bf0a7327a58a292c8a1be9880f6836d8464698cc..f4b1cc3a4ffe7ecc4eba9993aa09cb799ca0fabd 100644 --- a/sql/catalyst/pom.xml +++ b/sql/catalyst/pom.xml @@ -36,10 +36,6 @@ </properties> <dependencies> - <dependency> - <groupId>org.scala-lang</groupId> - <artifactId>scala-compiler</artifactId> - </dependency> <dependency> <groupId>org.scala-lang</groupId> <artifactId>scala-reflect</artifactId> @@ -67,6 +63,11 @@ <artifactId>scalacheck_${scala.binary.version}</artifactId> <scope>test</scope> </dependency> + <dependency> + <groupId>org.codehaus.janino</groupId> + <artifactId>janino</artifactId> + <version>2.7.8</version> + </dependency> </dependencies> <build> <outputDirectory>target/scala-${scala.binary.version}/classes</outputDirectory> @@ -108,13 +109,6 @@ <activation> <property><name>!scala-2.11</name></property> </activation> - <dependencies> - <dependency> - <groupId>org.scalamacros</groupId> - <artifactId>quasiquotes_${scala.binary.version}</artifactId> - <version>${scala.macros.version}</version> - </dependency> - </dependencies> </profile> </profiles> </project> diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java index bb546b3086b335b20054a58fadcf5b7ac2e79d60..ec97fe603c44ffeab5ac7a33f77065f41538dd21 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java @@ -17,23 +17,25 @@ package org.apache.spark.sql.catalyst.expressions; -import scala.collection.Map; +import javax.annotation.Nullable; +import java.util.Arrays; +import java.util.Collections; +import java.util.HashSet; +import java.util.Set; + import scala.collection.Seq; import scala.collection.mutable.ArraySeq; -import javax.annotation.Nullable; -import java.math.BigDecimal; -import java.sql.Date; -import java.util.*; - import org.apache.spark.sql.Row; +import org.apache.spark.sql.BaseMutableRow; import org.apache.spark.sql.types.DataType; -import static org.apache.spark.sql.types.DataTypes.*; import org.apache.spark.sql.types.StructType; import org.apache.spark.sql.types.UTF8String; import org.apache.spark.unsafe.PlatformDependent; import org.apache.spark.unsafe.bitset.BitSetMethods; +import static org.apache.spark.sql.types.DataTypes.*; + /** * An Unsafe implementation of Row which is backed by raw memory instead of Java objects. * @@ -49,7 +51,7 @@ import org.apache.spark.unsafe.bitset.BitSetMethods; * * Instances of `UnsafeRow` act as pointers to row data stored in this format. */ -public final class UnsafeRow implements MutableRow { +public final class UnsafeRow extends BaseMutableRow { private Object baseObject; private long baseOffset; @@ -227,21 +229,11 @@ public final class UnsafeRow implements MutableRow { return numFields; } - @Override - public int length() { - return size(); - } - @Override public StructType schema() { return schema; } - @Override - public Object apply(int i) { - return get(i); - } - @Override public Object get(int i) { assertIndexIsValid(i); @@ -339,60 +331,7 @@ public final class UnsafeRow implements MutableRow { return getUTF8String(i).toString(); } - @Override - public BigDecimal getDecimal(int i) { - throw new UnsupportedOperationException(); - } - - @Override - public Date getDate(int i) { - throw new UnsupportedOperationException(); - } - @Override - public <T> Seq<T> getSeq(int i) { - throw new UnsupportedOperationException(); - } - - @Override - public <T> List<T> getList(int i) { - throw new UnsupportedOperationException(); - } - - @Override - public <K, V> Map<K, V> getMap(int i) { - throw new UnsupportedOperationException(); - } - - @Override - public <T> scala.collection.immutable.Map<String, T> getValuesMap(Seq<String> fieldNames) { - throw new UnsupportedOperationException(); - } - - @Override - public <K, V> java.util.Map<K, V> getJavaMap(int i) { - throw new UnsupportedOperationException(); - } - - @Override - public Row getStruct(int i) { - throw new UnsupportedOperationException(); - } - - @Override - public <T> T getAs(int i) { - throw new UnsupportedOperationException(); - } - - @Override - public <T> T getAs(String fieldName) { - throw new UnsupportedOperationException(); - } - - @Override - public int fieldIndex(String name) { - throw new UnsupportedOperationException(); - } @Override public Row copy() { @@ -412,24 +351,4 @@ public final class UnsafeRow implements MutableRow { } return values; } - - @Override - public String toString() { - return mkString("[", ",", "]"); - } - - @Override - public String mkString() { - return toSeq().mkString(); - } - - @Override - public String mkString(String sep) { - return toSeq().mkString(sep); - } - - @Override - public String mkString(String start, String sep, String end) { - return toSeq().mkString(start, sep, end); - } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/BaseMutableRow.java b/sql/catalyst/src/main/scala/org/apache/spark/sql/BaseMutableRow.java new file mode 100644 index 0000000000000000000000000000000000000000..acec2bf4520f236c6722f7d171a759386345fa69 --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/BaseMutableRow.java @@ -0,0 +1,68 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql; + +import org.apache.spark.sql.catalyst.expressions.MutableRow; + +public abstract class BaseMutableRow extends BaseRow implements MutableRow { + + @Override + public void update(int ordinal, Object value) { + throw new UnsupportedOperationException(); + } + + @Override + public void setInt(int ordinal, int value) { + throw new UnsupportedOperationException(); + } + + @Override + public void setLong(int ordinal, long value) { + throw new UnsupportedOperationException(); + } + + @Override + public void setDouble(int ordinal, double value) { + throw new UnsupportedOperationException(); + } + + @Override + public void setBoolean(int ordinal, boolean value) { + throw new UnsupportedOperationException(); + } + + @Override + public void setShort(int ordinal, short value) { + throw new UnsupportedOperationException(); + } + + @Override + public void setByte(int ordinal, byte value) { + throw new UnsupportedOperationException(); + } + + @Override + public void setFloat(int ordinal, float value) { + throw new UnsupportedOperationException(); + } + + @Override + public void setString(int ordinal, String value) { + throw new UnsupportedOperationException(); + } +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/BaseRow.java b/sql/catalyst/src/main/scala/org/apache/spark/sql/BaseRow.java new file mode 100644 index 0000000000000000000000000000000000000000..d138b43a3482b53b3a09c4682b88b2bf60682216 --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/BaseRow.java @@ -0,0 +1,190 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql; + +import java.math.BigDecimal; +import java.sql.Date; +import java.util.List; + +import scala.collection.Seq; +import scala.collection.mutable.ArraySeq; + +import org.apache.spark.sql.catalyst.expressions.GenericRow; +import org.apache.spark.sql.types.StructType; + +public abstract class BaseRow implements Row { + + @Override + final public int length() { + return size(); + } + + @Override + public boolean anyNull() { + final int n = size(); + for (int i=0; i < n; i++) { + if (isNullAt(i)) { + return true; + } + } + return false; + } + + @Override + public StructType schema() { throw new UnsupportedOperationException(); } + + @Override + final public Object apply(int i) { + return get(i); + } + + @Override + public int getInt(int i) { + throw new UnsupportedOperationException(); + } + + @Override + public long getLong(int i) { + throw new UnsupportedOperationException(); + } + + @Override + public float getFloat(int i) { + throw new UnsupportedOperationException(); + } + + @Override + public double getDouble(int i) { + throw new UnsupportedOperationException(); + } + + @Override + public byte getByte(int i) { + throw new UnsupportedOperationException(); + } + + @Override + public short getShort(int i) { + throw new UnsupportedOperationException(); + } + + @Override + public boolean getBoolean(int i) { + throw new UnsupportedOperationException(); + } + + @Override + public String getString(int i) { + throw new UnsupportedOperationException(); + } + + @Override + public BigDecimal getDecimal(int i) { + throw new UnsupportedOperationException(); + } + + @Override + public Date getDate(int i) { + throw new UnsupportedOperationException(); + } + + @Override + public <T> Seq<T> getSeq(int i) { + throw new UnsupportedOperationException(); + } + + @Override + public <T> List<T> getList(int i) { + throw new UnsupportedOperationException(); + } + + @Override + public <K, V> scala.collection.Map<K, V> getMap(int i) { + throw new UnsupportedOperationException(); + } + + @Override + public <T> scala.collection.immutable.Map<String, T> getValuesMap(Seq<String> fieldNames) { + throw new UnsupportedOperationException(); + } + + @Override + public <K, V> java.util.Map<K, V> getJavaMap(int i) { + throw new UnsupportedOperationException(); + } + + @Override + public Row getStruct(int i) { + throw new UnsupportedOperationException(); + } + + @Override + public <T> T getAs(int i) { + throw new UnsupportedOperationException(); + } + + @Override + public <T> T getAs(String fieldName) { + throw new UnsupportedOperationException(); + } + + @Override + public int fieldIndex(String name) { + throw new UnsupportedOperationException(); + } + + @Override + public Row copy() { + final int n = size(); + Object[] arr = new Object[n]; + for (int i = 0; i < n; i++) { + arr[i] = get(i); + } + return new GenericRow(arr); + } + + @Override + public Seq<Object> toSeq() { + final int n = size(); + final ArraySeq<Object> values = new ArraySeq<Object>(n); + for (int i = 0; i < n; i++) { + values.update(i, get(i)); + } + return values; + } + + @Override + public String toString() { + return mkString("[", ",", "]"); + } + + @Override + public String mkString() { + return toSeq().mkString(); + } + + @Override + public String mkString(String sep) { + return toSeq().mkString(sep); + } + + @Override + public String mkString(String start, String sep, String end) { + return toSeq().mkString(start, sep, end); + } +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala index 36964af68dd8d2459232f6bf262456d7984fbb8c..cd604121b7dd9c889e96389ec455acaaa29a2871 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala @@ -17,10 +17,12 @@ package org.apache.spark.sql.catalyst.expressions.codegen -import com.google.common.cache.{CacheLoader, CacheBuilder} - +import scala.collection.mutable import scala.language.existentials +import com.google.common.cache.{CacheBuilder, CacheLoader} +import org.codehaus.janino.ClassBodyEvaluator + import org.apache.spark.Logging import org.apache.spark.sql.catalyst.expressions import org.apache.spark.sql.catalyst.expressions._ @@ -36,23 +38,15 @@ class LongHashSet extends org.apache.spark.util.collection.OpenHashSet[Long] * expressions. */ abstract class CodeGenerator[InType <: AnyRef, OutType <: AnyRef] extends Logging { - import scala.reflect.runtime.{universe => ru} - import scala.reflect.runtime.universe._ - - import scala.tools.reflect.ToolBox - - protected val toolBox = runtimeMirror(getClass.getClassLoader).mkToolBox() - protected val rowType = typeOf[Row] - protected val mutableRowType = typeOf[MutableRow] - protected val genericRowType = typeOf[GenericRow] - protected val genericMutableRowType = typeOf[GenericMutableRow] - - protected val projectionType = typeOf[Projection] - protected val mutableProjectionType = typeOf[MutableProjection] + protected val rowType = classOf[Row].getName + protected val stringType = classOf[UTF8String].getName + protected val decimalType = classOf[Decimal].getName + protected val exprType = classOf[Expression].getName + protected val mutableRowType = classOf[MutableRow].getName + protected val genericMutableRowType = classOf[GenericMutableRow].getName private val curId = new java.util.concurrent.atomic.AtomicInteger() - private val javaSeparator = "$" /** * Can be flipped on manually in the console to add (expensive) expression evaluation trace code. @@ -74,6 +68,20 @@ abstract class CodeGenerator[InType <: AnyRef, OutType <: AnyRef] extends Loggin /** Binds an input expression to a given input schema */ protected def bind(in: InType, inputSchema: Seq[Attribute]): InType + /** + * Compile the Java source code into a Java class, using Janino. + * + * It will track the time used to compile + */ + protected def compile(code: String): Class[_] = { + val startTime = System.nanoTime() + val clazz = new ClassBodyEvaluator(code).getClazz() + val endTime = System.nanoTime() + def timeMs: Double = (endTime - startTime).toDouble / 1000000 + logDebug(s"Compiled Java code (${code.size} bytes) in $timeMs ms") + clazz + } + /** * A cache of generated classes. * @@ -87,7 +95,7 @@ abstract class CodeGenerator[InType <: AnyRef, OutType <: AnyRef] extends Loggin .maximumSize(1000) .build( new CacheLoader[InType, OutType]() { - override def load(in: InType): OutType = globalLock.synchronized { + override def load(in: InType): OutType = { val startTime = System.nanoTime() val result = create(in) val endTime = System.nanoTime() @@ -110,8 +118,8 @@ abstract class CodeGenerator[InType <: AnyRef, OutType <: AnyRef] extends Loggin * (Since we aren't in a macro context we do not seem to have access to the built in `freshName` * function.) */ - protected def freshName(prefix: String): TermName = { - newTermName(s"$prefix$javaSeparator${curId.getAndIncrement}") + protected def freshName(prefix: String): String = { + s"$prefix${curId.getAndIncrement}" } /** @@ -125,32 +133,51 @@ abstract class CodeGenerator[InType <: AnyRef, OutType <: AnyRef] extends Loggin * @param objectTerm A possibly boxed version of the result of evaluating this expression. */ protected case class EvaluatedExpression( - code: Seq[Tree], - nullTerm: TermName, - primitiveTerm: TermName, - objectTerm: TermName) + code: String, + nullTerm: String, + primitiveTerm: String, + objectTerm: String) + + /** + * A context for codegen, which is used to bookkeeping the expressions those are not supported + * by codegen, then they are evaluated directly. The unsupported expression is appended at the + * end of `references`, the position of it is kept in the code, used to access and evaluate it. + */ + protected class CodeGenContext { + /** + * Holding all the expressions those do not support codegen, will be evaluated directly. + */ + val references: mutable.ArrayBuffer[Expression] = new mutable.ArrayBuffer[Expression]() + } + + /** + * Create a new codegen context for expression evaluator, used to store those + * expressions that don't support codegen + */ + def newCodeGenContext(): CodeGenContext = { + new CodeGenContext() + } /** * Given an expression tree returns an [[EvaluatedExpression]], which contains Scala trees that * can be used to determine the result of evaluating the expression on an input row. */ - def expressionEvaluator(e: Expression): EvaluatedExpression = { + def expressionEvaluator(e: Expression, ctx: CodeGenContext): EvaluatedExpression = { val primitiveTerm = freshName("primitiveTerm") val nullTerm = freshName("nullTerm") val objectTerm = freshName("objectTerm") implicit class Evaluate1(e: Expression) { - def castOrNull(f: TermName => Tree, dataType: DataType): Seq[Tree] = { - val eval = expressionEvaluator(e) - eval.code ++ - q""" - val $nullTerm = ${eval.nullTerm} - val $primitiveTerm = - if($nullTerm) - ${defaultPrimitive(dataType)} - else - ${f(eval.primitiveTerm)} - """.children + def castOrNull(f: String => String, dataType: DataType): String = { + val eval = expressionEvaluator(e, ctx) + eval.code + + s""" + boolean $nullTerm = ${eval.nullTerm}; + ${primitiveForType(dataType)} $primitiveTerm = ${defaultPrimitive(dataType)}; + if (!$nullTerm) { + $primitiveTerm = ${f(eval.primitiveTerm)}; + } + """ } } @@ -163,529 +190,505 @@ abstract class CodeGenerator[InType <: AnyRef, OutType <: AnyRef] extends Loggin * * @param f a function from two primitive term names to a tree that evaluates them. */ - def evaluate(f: (TermName, TermName) => Tree): Seq[Tree] = + def evaluate(f: (String, String) => String): String = evaluateAs(expressions._1.dataType)(f) - def evaluateAs(resultType: DataType)(f: (TermName, TermName) => Tree): Seq[Tree] = { + def evaluateAs(resultType: DataType)(f: (String, String) => String): String = { // TODO: Right now some timestamp tests fail if we enforce this... if (expressions._1.dataType != expressions._2.dataType) { log.warn(s"${expressions._1.dataType} != ${expressions._2.dataType}") } - val eval1 = expressionEvaluator(expressions._1) - val eval2 = expressionEvaluator(expressions._2) + val eval1 = expressionEvaluator(expressions._1, ctx) + val eval2 = expressionEvaluator(expressions._2, ctx) val resultCode = f(eval1.primitiveTerm, eval2.primitiveTerm) - eval1.code ++ eval2.code ++ - q""" - val $nullTerm = ${eval1.nullTerm} || ${eval2.nullTerm} - val $primitiveTerm: ${termForType(resultType)} = - if($nullTerm) { - ${defaultPrimitive(resultType)} - } else { - $resultCode.asInstanceOf[${termForType(resultType)}] - } - """.children : Seq[Tree] + eval1.code + eval2.code + + s""" + boolean $nullTerm = ${eval1.nullTerm} || ${eval2.nullTerm}; + ${primitiveForType(resultType)} $primitiveTerm = ${defaultPrimitive(resultType)}; + if(!$nullTerm) { + $primitiveTerm = (${primitiveForType(resultType)})($resultCode); + } + """ } } - val inputTuple = newTermName(s"i") + val inputTuple = "i" // TODO: Skip generation of null handling code when expression are not nullable. - val primitiveEvaluation: PartialFunction[Expression, Seq[Tree]] = { + val primitiveEvaluation: PartialFunction[Expression, String] = { case b @ BoundReference(ordinal, dataType, nullable) => - val nullValue = q"$inputTuple.isNullAt($ordinal)" - q""" - val $nullTerm: Boolean = $nullValue - val $primitiveTerm: ${termForType(dataType)} = - if($nullTerm) - ${defaultPrimitive(dataType)} - else - ${getColumn(inputTuple, dataType, ordinal)} - """.children + s""" + final boolean $nullTerm = $inputTuple.isNullAt($ordinal); + final ${primitiveForType(dataType)} $primitiveTerm = $nullTerm ? + ${defaultPrimitive(dataType)} : (${getColumn(inputTuple, dataType, ordinal)}); + """ case expressions.Literal(null, dataType) => - q""" - val $nullTerm = true - val $primitiveTerm: ${termForType(dataType)} = null.asInstanceOf[${termForType(dataType)}] - """.children - - case expressions.Literal(value: Boolean, dataType) => - q""" - val $nullTerm = ${value == null} - val $primitiveTerm: ${termForType(dataType)} = $value - """.children - - case expressions.Literal(value: UTF8String, dataType) => - q""" - val $nullTerm = ${value == null} - val $primitiveTerm: ${termForType(dataType)} = - org.apache.spark.sql.types.UTF8String(${value.getBytes}) - """.children - - case expressions.Literal(value: Int, dataType) => - q""" - val $nullTerm = ${value == null} - val $primitiveTerm: ${termForType(dataType)} = $value - """.children - - case expressions.Literal(value: Long, dataType) => - q""" - val $nullTerm = ${value == null} - val $primitiveTerm: ${termForType(dataType)} = $value - """.children - - case Cast(e @ BinaryType(), StringType) => - val eval = expressionEvaluator(e) - eval.code ++ - q""" - val $nullTerm = ${eval.nullTerm} - val $primitiveTerm = - if($nullTerm) - ${defaultPrimitive(StringType)} - else - org.apache.spark.sql.types.UTF8String(${eval.primitiveTerm}.asInstanceOf[Array[Byte]]) - """.children + s""" + final boolean $nullTerm = true; + ${primitiveForType(dataType)} $primitiveTerm = ${defaultPrimitive(dataType)}; + """ + + case expressions.Literal(value: UTF8String, StringType) => + val arr = s"new byte[]{${value.getBytes.map(_.toString).mkString(", ")}}" + s""" + final boolean $nullTerm = false; + ${stringType} $primitiveTerm = + new ${stringType}().set(${arr}); + """ + + case expressions.Literal(value, FloatType) => + s""" + final boolean $nullTerm = false; + float $primitiveTerm = ${value}f; + """ + + case expressions.Literal(value, dt @ DecimalType()) => + s""" + final boolean $nullTerm = false; + ${primitiveForType(dt)} $primitiveTerm = new ${primitiveForType(dt)}().set($value); + """ + + case expressions.Literal(value, dataType) => + s""" + final boolean $nullTerm = false; + ${primitiveForType(dataType)} $primitiveTerm = $value; + """ + + case Cast(child @ BinaryType(), StringType) => + child.castOrNull(c => + s"new ${stringType}().set($c)", + StringType) case Cast(child @ DateType(), StringType) => child.castOrNull(c => - q"""org.apache.spark.sql.types.UTF8String( + s"""new ${stringType}().set( org.apache.spark.sql.catalyst.util.DateUtils.toString($c))""", StringType) - case Cast(child @ NumericType(), IntegerType) => - child.castOrNull(c => q"$c.toInt", IntegerType) + case Cast(child @ BooleanType(), dt: NumericType) if !dt.isInstanceOf[DecimalType] => + child.castOrNull(c => s"(${primitiveForType(dt)})($c?1:0)", dt) - case Cast(child @ NumericType(), LongType) => - child.castOrNull(c => q"$c.toLong", LongType) + case Cast(child @ DecimalType(), IntegerType) => + child.castOrNull(c => s"($c).toInt()", IntegerType) - case Cast(child @ NumericType(), DoubleType) => - child.castOrNull(c => q"$c.toDouble", DoubleType) + case Cast(child @ DecimalType(), dt: NumericType) if !dt.isInstanceOf[DecimalType] => + child.castOrNull(c => s"($c).to${termForType(dt)}()", dt) - case Cast(child @ NumericType(), FloatType) => - child.castOrNull(c => q"$c.toFloat", FloatType) + case Cast(child @ NumericType(), dt: NumericType) if !dt.isInstanceOf[DecimalType] => + child.castOrNull(c => s"(${primitiveForType(dt)})($c)", dt) // Special handling required for timestamps in hive test cases since the toString function // does not match the expected output. case Cast(e, StringType) if e.dataType != TimestampType => - val eval = expressionEvaluator(e) - eval.code ++ - q""" - val $nullTerm = ${eval.nullTerm} - val $primitiveTerm = - if($nullTerm) - ${defaultPrimitive(StringType)} - else - org.apache.spark.sql.types.UTF8String(${eval.primitiveTerm}.toString) - """.children + e.castOrNull(c => + s"new ${stringType}().set(String.valueOf($c))", + StringType) case EqualTo(e1 @ BinaryType(), e2 @ BinaryType()) => (e1, e2).evaluateAs (BooleanType) { case (eval1, eval2) => - q""" - java.util.Arrays.equals($eval1.asInstanceOf[Array[Byte]], - $eval2.asInstanceOf[Array[Byte]]) - """ + s"java.util.Arrays.equals((byte[])$eval1, (byte[])$eval2)" } case EqualTo(e1, e2) => - (e1, e2).evaluateAs (BooleanType) { case (eval1, eval2) => q"$eval1 == $eval2" } - - /* TODO: Fix null semantics. - case In(e1, list) if !list.exists(!_.isInstanceOf[expressions.Literal]) => - val eval = expressionEvaluator(e1) - - val checks = list.map { - case expressions.Literal(v: String, dataType) => - q"if(${eval.primitiveTerm} == $v) return true" - case expressions.Literal(v: Int, dataType) => - q"if(${eval.primitiveTerm} == $v) return true" - } - - val funcName = newTermName(s"isIn${curId.getAndIncrement()}") - - q""" - def $funcName: Boolean = { - ..${eval.code} - if(${eval.nullTerm}) return false - ..$checks - return false - } - val $nullTerm = false - val $primitiveTerm = $funcName - """.children - */ + (e1, e2).evaluateAs (BooleanType) { case (eval1, eval2) => s"$eval1 == $eval2" } case GreaterThan(e1 @ NumericType(), e2 @ NumericType()) => - (e1, e2).evaluateAs (BooleanType) { case (eval1, eval2) => q"$eval1 > $eval2" } + (e1, e2).evaluateAs (BooleanType) { case (eval1, eval2) => s"$eval1 > $eval2" } case GreaterThanOrEqual(e1 @ NumericType(), e2 @ NumericType()) => - (e1, e2).evaluateAs (BooleanType) { case (eval1, eval2) => q"$eval1 >= $eval2" } + (e1, e2).evaluateAs (BooleanType) { case (eval1, eval2) => s"$eval1 >= $eval2" } case LessThan(e1 @ NumericType(), e2 @ NumericType()) => - (e1, e2).evaluateAs (BooleanType) { case (eval1, eval2) => q"$eval1 < $eval2" } + (e1, e2).evaluateAs (BooleanType) { case (eval1, eval2) => s"$eval1 < $eval2" } case LessThanOrEqual(e1 @ NumericType(), e2 @ NumericType()) => - (e1, e2).evaluateAs (BooleanType) { case (eval1, eval2) => q"$eval1 <= $eval2" } + (e1, e2).evaluateAs (BooleanType) { case (eval1, eval2) => s"$eval1 <= $eval2" } case And(e1, e2) => - val eval1 = expressionEvaluator(e1) - val eval2 = expressionEvaluator(e2) - - q""" - ..${eval1.code} - var $nullTerm = false - var $primitiveTerm: ${termForType(BooleanType)} = false - - if (!${eval1.nullTerm} && ${eval1.primitiveTerm} == false) { + val eval1 = expressionEvaluator(e1, ctx) + val eval2 = expressionEvaluator(e2, ctx) + s""" + ${eval1.code} + boolean $nullTerm = false; + boolean $primitiveTerm = false; + + if (!${eval1.nullTerm} && !${eval1.primitiveTerm}) { } else { - ..${eval2.code} - if (!${eval2.nullTerm} && ${eval2.primitiveTerm} == false) { + ${eval2.code} + if (!${eval2.nullTerm} && !${eval2.primitiveTerm}) { } else if (!${eval1.nullTerm} && !${eval2.nullTerm}) { - $primitiveTerm = true + $primitiveTerm = true; } else { - $nullTerm = true + $nullTerm = true; } } - """.children + """ case Or(e1, e2) => - val eval1 = expressionEvaluator(e1) - val eval2 = expressionEvaluator(e2) + val eval1 = expressionEvaluator(e1, ctx) + val eval2 = expressionEvaluator(e2, ctx) - q""" - ..${eval1.code} - var $nullTerm = false - var $primitiveTerm: ${termForType(BooleanType)} = false + s""" + ${eval1.code} + boolean $nullTerm = false; + boolean $primitiveTerm = false; if (!${eval1.nullTerm} && ${eval1.primitiveTerm}) { - $primitiveTerm = true + $primitiveTerm = true; } else { - ..${eval2.code} + ${eval2.code} if (!${eval2.nullTerm} && ${eval2.primitiveTerm}) { - $primitiveTerm = true + $primitiveTerm = true; } else if (!${eval1.nullTerm} && !${eval2.nullTerm}) { - $primitiveTerm = false + $primitiveTerm = false; } else { - $nullTerm = true + $nullTerm = true; } } - """.children + """ case Not(child) => // Uh, bad function name... - child.castOrNull(c => q"!$c", BooleanType) - - case Add(e1, e2) => (e1, e2) evaluate { case (eval1, eval2) => q"$eval1 + $eval2" } - case Subtract(e1, e2) => (e1, e2) evaluate { case (eval1, eval2) => q"$eval1 - $eval2" } - case Multiply(e1, e2) => (e1, e2) evaluate { case (eval1, eval2) => q"$eval1 * $eval2" } + child.castOrNull(c => s"!$c", BooleanType) + + case Add(e1 @ DecimalType(), e2 @ DecimalType()) => + (e1, e2) evaluate { case (eval1, eval2) => s"$eval1.$$plus($eval2)" } + case Subtract(e1 @ DecimalType(), e2 @ DecimalType()) => + (e1, e2) evaluate { case (eval1, eval2) => s"$eval1.$$minus($eval2)" } + case Multiply(e1 @ DecimalType(), e2 @ DecimalType()) => + (e1, e2) evaluate { case (eval1, eval2) => s"$eval1.$$times($eval2)" } + case Divide(e1 @ DecimalType(), e2 @ DecimalType()) => + val eval1 = expressionEvaluator(e1, ctx) + val eval2 = expressionEvaluator(e2, ctx) + eval1.code + eval2.code + + s""" + boolean $nullTerm = false; + ${primitiveForType(e1.dataType)} $primitiveTerm = null; + if (${eval1.nullTerm} || ${eval2.nullTerm} || ${eval2.primitiveTerm}.isZero()) { + $nullTerm = true; + } else { + $primitiveTerm = ${eval1.primitiveTerm}.$$div${eval2.primitiveTerm}); + } + """ + case Remainder(e1 @ DecimalType(), e2 @ DecimalType()) => + val eval1 = expressionEvaluator(e1, ctx) + val eval2 = expressionEvaluator(e2, ctx) + eval1.code + eval2.code + + s""" + boolean $nullTerm = false; + ${primitiveForType(e1.dataType)} $primitiveTerm = 0; + if (${eval1.nullTerm} || ${eval2.nullTerm} || ${eval2.primitiveTerm}.isZero()) { + $nullTerm = true; + } else { + $primitiveTerm = ${eval1.primitiveTerm}.remainder(${eval2.primitiveTerm}); + } + """ + + case Add(e1, e2) => + (e1, e2) evaluate { case (eval1, eval2) => s"$eval1 + $eval2" } + case Subtract(e1, e2) => + (e1, e2) evaluate { case (eval1, eval2) => s"$eval1 - $eval2" } + case Multiply(e1, e2) => + (e1, e2) evaluate { case (eval1, eval2) => s"$eval1 * $eval2" } case Divide(e1, e2) => - val eval1 = expressionEvaluator(e1) - val eval2 = expressionEvaluator(e2) - - eval1.code ++ eval2.code ++ - q""" - var $nullTerm = false - var $primitiveTerm: ${termForType(e1.dataType)} = 0 - - if (${eval1.nullTerm} || ${eval2.nullTerm} ) { - $nullTerm = true - } else if (${eval2.primitiveTerm} == 0) - $nullTerm = true - else { - $primitiveTerm = ${eval1.primitiveTerm} / ${eval2.primitiveTerm} + val eval1 = expressionEvaluator(e1, ctx) + val eval2 = expressionEvaluator(e2, ctx) + eval1.code + eval2.code + + s""" + boolean $nullTerm = false; + ${primitiveForType(e1.dataType)} $primitiveTerm = 0; + if (${eval1.nullTerm} || ${eval2.nullTerm} || ${eval2.primitiveTerm} == 0) { + $nullTerm = true; + } else { + $primitiveTerm = ${eval1.primitiveTerm} / ${eval2.primitiveTerm}; } - """.children - + """ case Remainder(e1, e2) => - val eval1 = expressionEvaluator(e1) - val eval2 = expressionEvaluator(e2) - - eval1.code ++ eval2.code ++ - q""" - var $nullTerm = false - var $primitiveTerm: ${termForType(e1.dataType)} = 0 - - if (${eval1.nullTerm} || ${eval2.nullTerm} ) { - $nullTerm = true - } else if (${eval2.primitiveTerm} == 0) - $nullTerm = true - else { - $nullTerm = false - $primitiveTerm = ${eval1.primitiveTerm} % ${eval2.primitiveTerm} + val eval1 = expressionEvaluator(e1, ctx) + val eval2 = expressionEvaluator(e2, ctx) + eval1.code + eval2.code + + s""" + boolean $nullTerm = false; + ${primitiveForType(e1.dataType)} $primitiveTerm = 0; + if (${eval1.nullTerm} || ${eval2.nullTerm} || ${eval2.primitiveTerm} == 0) { + $nullTerm = true; + } else { + $primitiveTerm = ${eval1.primitiveTerm} % ${eval2.primitiveTerm}; } - """.children + """ case IsNotNull(e) => - val eval = expressionEvaluator(e) - q""" - ..${eval.code} - var $nullTerm = false - var $primitiveTerm: ${termForType(BooleanType)} = !${eval.nullTerm} - """.children + val eval = expressionEvaluator(e, ctx) + s""" + ${eval.code} + boolean $nullTerm = false; + boolean $primitiveTerm = !${eval.nullTerm}; + """ case IsNull(e) => - val eval = expressionEvaluator(e) - q""" - ..${eval.code} - var $nullTerm = false - var $primitiveTerm: ${termForType(BooleanType)} = ${eval.nullTerm} - """.children - - case c @ Coalesce(children) => - q""" - var $nullTerm = true - var $primitiveTerm: ${termForType(c.dataType)} = ${defaultPrimitive(c.dataType)} - """.children ++ + val eval = expressionEvaluator(e, ctx) + s""" + ${eval.code} + boolean $nullTerm = false; + boolean $primitiveTerm = ${eval.nullTerm}; + """ + + case e @ Coalesce(children) => + s""" + boolean $nullTerm = true; + ${primitiveForType(e.dataType)} $primitiveTerm = ${defaultPrimitive(e.dataType)}; + """ + children.map { c => - val eval = expressionEvaluator(c) - q""" + val eval = expressionEvaluator(c, ctx) + s""" if($nullTerm) { - ..${eval.code} + ${eval.code} if(!${eval.nullTerm}) { - $nullTerm = false - $primitiveTerm = ${eval.primitiveTerm} + $nullTerm = false; + $primitiveTerm = ${eval.primitiveTerm}; } } """ - } + }.mkString("\n") - case i @ expressions.If(condition, trueValue, falseValue) => - val condEval = expressionEvaluator(condition) - val trueEval = expressionEvaluator(trueValue) - val falseEval = expressionEvaluator(falseValue) + case e @ expressions.If(condition, trueValue, falseValue) => + val condEval = expressionEvaluator(condition, ctx) + val trueEval = expressionEvaluator(trueValue, ctx) + val falseEval = expressionEvaluator(falseValue, ctx) - q""" - var $nullTerm = false - var $primitiveTerm: ${termForType(i.dataType)} = ${defaultPrimitive(i.dataType)} - ..${condEval.code} + s""" + boolean $nullTerm = false; + ${primitiveForType(e.dataType)} $primitiveTerm = ${defaultPrimitive(e.dataType)}; + ${condEval.code} if(!${condEval.nullTerm} && ${condEval.primitiveTerm}) { - ..${trueEval.code} - $nullTerm = ${trueEval.nullTerm} - $primitiveTerm = ${trueEval.primitiveTerm} + ${trueEval.code} + $nullTerm = ${trueEval.nullTerm}; + $primitiveTerm = ${trueEval.primitiveTerm}; } else { - ..${falseEval.code} - $nullTerm = ${falseEval.nullTerm} - $primitiveTerm = ${falseEval.primitiveTerm} + ${falseEval.code} + $nullTerm = ${falseEval.nullTerm}; + $primitiveTerm = ${falseEval.primitiveTerm}; } - """.children + """ case NewSet(elementType) => - q""" - val $nullTerm = false - val $primitiveTerm = new ${hashSetForType(elementType)}() - """.children + s""" + boolean $nullTerm = false; + ${hashSetForType(elementType)} $primitiveTerm = new ${hashSetForType(elementType)}(); + """ case AddItemToSet(item, set) => - val itemEval = expressionEvaluator(item) - val setEval = expressionEvaluator(set) + val itemEval = expressionEvaluator(item, ctx) + val setEval = expressionEvaluator(set, ctx) val elementType = set.dataType.asInstanceOf[OpenHashSetUDT].elementType + val htype = hashSetForType(elementType) - itemEval.code ++ setEval.code ++ - q""" - if (!${itemEval.nullTerm}) { - ${setEval.primitiveTerm} - .asInstanceOf[${hashSetForType(elementType)}] - .add(${itemEval.primitiveTerm}) + itemEval.code + setEval.code + + s""" + if (!${itemEval.nullTerm} && !${setEval.nullTerm}) { + (($htype)${setEval.primitiveTerm}).add(${itemEval.primitiveTerm}); } - - val $nullTerm = false - val $primitiveTerm = ${setEval.primitiveTerm} - """.children + boolean $nullTerm = false; + ${htype} $primitiveTerm = ($htype)${setEval.primitiveTerm}; + """ case CombineSets(left, right) => - val leftEval = expressionEvaluator(left) - val rightEval = expressionEvaluator(right) + val leftEval = expressionEvaluator(left, ctx) + val rightEval = expressionEvaluator(right, ctx) val elementType = left.dataType.asInstanceOf[OpenHashSetUDT].elementType + val htype = hashSetForType(elementType) - leftEval.code ++ rightEval.code ++ - q""" - val $nullTerm = false - var $primitiveTerm: ${hashSetForType(elementType)} = null - - { - val leftSet = ${leftEval.primitiveTerm}.asInstanceOf[${hashSetForType(elementType)}] - val rightSet = ${rightEval.primitiveTerm}.asInstanceOf[${hashSetForType(elementType)}] - val iterator = rightSet.iterator - while (iterator.hasNext) { - leftSet.add(iterator.next()) - } - $primitiveTerm = leftSet - } - """.children + leftEval.code + rightEval.code + + s""" + boolean $nullTerm = false; + ${htype} $primitiveTerm = + (${htype})${leftEval.primitiveTerm}; + $primitiveTerm.union((${htype})${rightEval.primitiveTerm}); + """ - case MaxOf(e1, e2) => - val eval1 = expressionEvaluator(e1) - val eval2 = expressionEvaluator(e2) + case MaxOf(e1, e2) if !e1.dataType.isInstanceOf[DecimalType] => + val eval1 = expressionEvaluator(e1, ctx) + val eval2 = expressionEvaluator(e2, ctx) - eval1.code ++ eval2.code ++ - q""" - var $nullTerm = false - var $primitiveTerm: ${termForType(e1.dataType)} = ${defaultPrimitive(e1.dataType)} + eval1.code + eval2.code + + s""" + boolean $nullTerm = false; + ${primitiveForType(e1.dataType)} $primitiveTerm = ${defaultPrimitive(e1.dataType)}; if (${eval1.nullTerm}) { - $nullTerm = ${eval2.nullTerm} - $primitiveTerm = ${eval2.primitiveTerm} + $nullTerm = ${eval2.nullTerm}; + $primitiveTerm = ${eval2.primitiveTerm}; } else if (${eval2.nullTerm}) { - $nullTerm = ${eval1.nullTerm} - $primitiveTerm = ${eval1.primitiveTerm} + $nullTerm = ${eval1.nullTerm}; + $primitiveTerm = ${eval1.primitiveTerm}; } else { if (${eval1.primitiveTerm} > ${eval2.primitiveTerm}) { - $primitiveTerm = ${eval1.primitiveTerm} + $primitiveTerm = ${eval1.primitiveTerm}; } else { - $primitiveTerm = ${eval2.primitiveTerm} + $primitiveTerm = ${eval2.primitiveTerm}; } } - """.children + """ - case MinOf(e1, e2) => - val eval1 = expressionEvaluator(e1) - val eval2 = expressionEvaluator(e2) + case MinOf(e1, e2) if !e1.dataType.isInstanceOf[DecimalType] => + val eval1 = expressionEvaluator(e1, ctx) + val eval2 = expressionEvaluator(e2, ctx) - eval1.code ++ eval2.code ++ - q""" - var $nullTerm = false - var $primitiveTerm: ${termForType(e1.dataType)} = ${defaultPrimitive(e1.dataType)} + eval1.code + eval2.code + + s""" + boolean $nullTerm = false; + ${primitiveForType(e1.dataType)} $primitiveTerm = ${defaultPrimitive(e1.dataType)}; if (${eval1.nullTerm}) { - $nullTerm = ${eval2.nullTerm} - $primitiveTerm = ${eval2.primitiveTerm} + $nullTerm = ${eval2.nullTerm}; + $primitiveTerm = ${eval2.primitiveTerm}; } else if (${eval2.nullTerm}) { - $nullTerm = ${eval1.nullTerm} - $primitiveTerm = ${eval1.primitiveTerm} + $nullTerm = ${eval1.nullTerm}; + $primitiveTerm = ${eval1.primitiveTerm}; } else { if (${eval1.primitiveTerm} < ${eval2.primitiveTerm}) { - $primitiveTerm = ${eval1.primitiveTerm} + $primitiveTerm = ${eval1.primitiveTerm}; } else { - $primitiveTerm = ${eval2.primitiveTerm} + $primitiveTerm = ${eval2.primitiveTerm}; } } - """.children + """ case UnscaledValue(child) => - val childEval = expressionEvaluator(child) - - childEval.code ++ - q""" - var $nullTerm = ${childEval.nullTerm} - var $primitiveTerm: Long = if (!$nullTerm) { - ${childEval.primitiveTerm}.toUnscaledLong - } else { - ${defaultPrimitive(LongType)} - } - """.children + val childEval = expressionEvaluator(child, ctx) + + childEval.code + + s""" + boolean $nullTerm = ${childEval.nullTerm}; + long $primitiveTerm = $nullTerm ? -1 : ${childEval.primitiveTerm}.toUnscaledLong(); + """ case MakeDecimal(child, precision, scale) => - val childEval = expressionEvaluator(child) + val eval = expressionEvaluator(child, ctx) - childEval.code ++ - q""" - var $nullTerm = ${childEval.nullTerm} - var $primitiveTerm: org.apache.spark.sql.types.Decimal = - ${defaultPrimitive(DecimalType())} + eval.code + + s""" + boolean $nullTerm = ${eval.nullTerm}; + org.apache.spark.sql.types.Decimal $primitiveTerm = ${defaultPrimitive(DecimalType())}; if (!$nullTerm) { - $primitiveTerm = new org.apache.spark.sql.types.Decimal() - $primitiveTerm = $primitiveTerm.setOrNull(${childEval.primitiveTerm}, $precision, $scale) - $nullTerm = $primitiveTerm == null + $primitiveTerm = new org.apache.spark.sql.types.Decimal(); + $primitiveTerm = $primitiveTerm.setOrNull(${eval.primitiveTerm}, $precision, $scale); + $nullTerm = $primitiveTerm == null; } - """.children + """ } // If there was no match in the partial function above, we fall back on calling the interpreted // expression evaluator. - val code: Seq[Tree] = + val code: String = primitiveEvaluation.lift.apply(e).getOrElse { - log.debug(s"No rules to generate $e") - val tree = reify { e } - q""" - val $objectTerm = $tree.eval(i) - val $nullTerm = $objectTerm == null - val $primitiveTerm = $objectTerm.asInstanceOf[${termForType(e.dataType)}] - """.children - } - - // Only inject debugging code if debugging is turned on. - val debugCode = - if (debugLogging) { - val localLogger = log - val localLoggerTree = reify { localLogger } - q""" - $localLoggerTree.debug( - ${e.toString} + ": " + (if ($nullTerm) "null" else $primitiveTerm.toString)) - """ :: Nil - } else { - Nil + logError(s"No rules to generate $e") + ctx.references += e + s""" + /* expression: ${e} */ + Object $objectTerm = expressions[${ctx.references.size - 1}].eval(i); + boolean $nullTerm = $objectTerm == null; + ${primitiveForType(e.dataType)} $primitiveTerm = ${defaultPrimitive(e.dataType)}; + if (!$nullTerm) $primitiveTerm = (${termForType(e.dataType)})$objectTerm; + """ } - EvaluatedExpression(code ++ debugCode, nullTerm, primitiveTerm, objectTerm) + EvaluatedExpression(code, nullTerm, primitiveTerm, objectTerm) } - protected def getColumn(inputRow: TermName, dataType: DataType, ordinal: Int) = { + protected def getColumn(inputRow: String, dataType: DataType, ordinal: Int) = { dataType match { - case StringType => q"$inputRow($ordinal).asInstanceOf[org.apache.spark.sql.types.UTF8String]" - case dt: DataType if isNativeType(dt) => q"$inputRow.${accessorForType(dt)}($ordinal)" - case _ => q"$inputRow.apply($ordinal).asInstanceOf[${termForType(dataType)}]" + case StringType => s"(${stringType})$inputRow.apply($ordinal)" + case dt: DataType if isNativeType(dt) => s"$inputRow.${accessorForType(dt)}($ordinal)" + case _ => s"(${termForType(dataType)})$inputRow.apply($ordinal)" } } protected def setColumn( - destinationRow: TermName, + destinationRow: String, dataType: DataType, ordinal: Int, - value: TermName) = { + value: String): String = { dataType match { - case StringType => q"$destinationRow.update($ordinal, $value)" + case StringType => s"$destinationRow.update($ordinal, $value)" case dt: DataType if isNativeType(dt) => - q"$destinationRow.${mutatorForType(dt)}($ordinal, $value)" - case _ => q"$destinationRow.update($ordinal, $value)" + s"$destinationRow.${mutatorForType(dt)}($ordinal, $value)" + case _ => s"$destinationRow.update($ordinal, $value)" } } - protected def accessorForType(dt: DataType) = newTermName(s"get${primitiveForType(dt)}") - protected def mutatorForType(dt: DataType) = newTermName(s"set${primitiveForType(dt)}") + protected def accessorForType(dt: DataType) = dt match { + case IntegerType => "getInt" + case other => s"get${termForType(dt)}" + } + + protected def mutatorForType(dt: DataType) = dt match { + case IntegerType => "setInt" + case other => s"set${termForType(dt)}" + } - protected def hashSetForType(dt: DataType) = dt match { - case IntegerType => typeOf[IntegerHashSet] - case LongType => typeOf[LongHashSet] + protected def hashSetForType(dt: DataType): String = dt match { + case IntegerType => classOf[IntegerHashSet].getName + case LongType => classOf[LongHashSet].getName case unsupportedType => sys.error(s"Code generation not support for hashset of type $unsupportedType") } - protected def primitiveForType(dt: DataType) = dt match { - case IntegerType => "Int" + protected def primitiveForType(dt: DataType): String = dt match { + case IntegerType => "int" + case LongType => "long" + case ShortType => "short" + case ByteType => "byte" + case DoubleType => "double" + case FloatType => "float" + case BooleanType => "boolean" + case dt: DecimalType => decimalType + case BinaryType => "byte[]" + case StringType => stringType + case DateType => "int" + case TimestampType => "java.sql.Timestamp" + case _ => "Object" + } + + protected def defaultPrimitive(dt: DataType): String = dt match { + case BooleanType => "false" + case FloatType => "-1.0f" + case ShortType => "-1" + case LongType => "-1" + case ByteType => "-1" + case DoubleType => "-1.0" + case IntegerType => "-1" + case DateType => "-1" + case dt: DecimalType => "null" + case StringType => "null" + case _ => "null" + } + + protected def termForType(dt: DataType): String = dt match { + case IntegerType => "Integer" case LongType => "Long" case ShortType => "Short" case ByteType => "Byte" case DoubleType => "Double" case FloatType => "Float" case BooleanType => "Boolean" - case StringType => "org.apache.spark.sql.types.UTF8String" - } - - protected def defaultPrimitive(dt: DataType) = dt match { - case BooleanType => ru.Literal(Constant(false)) - case FloatType => ru.Literal(Constant(-1.0.toFloat)) - case StringType => q"""org.apache.spark.sql.types.UTF8String("<uninit>")""" - case ShortType => ru.Literal(Constant(-1.toShort)) - case LongType => ru.Literal(Constant(-1L)) - case ByteType => ru.Literal(Constant(-1.toByte)) - case DoubleType => ru.Literal(Constant(-1.toDouble)) - case DecimalType() => q"org.apache.spark.sql.types.Decimal(-1)" - case IntegerType => ru.Literal(Constant(-1)) - case DateType => ru.Literal(Constant(-1)) - case _ => ru.Literal(Constant(null)) - } - - protected def termForType(dt: DataType) = dt match { - case n: AtomicType => n.tag - case _ => typeTag[Any] + case dt: DecimalType => decimalType + case BinaryType => "byte[]" + case StringType => stringType + case DateType => "Integer" + case TimestampType => "java.sql.Timestamp" + case _ => "Object" } /** * List of data types that have special accessors and setters in [[Row]]. */ protected val nativeTypes = - Seq(IntegerType, BooleanType, LongType, DoubleType, FloatType, ShortType, ByteType, StringType) + Seq(IntegerType, BooleanType, LongType, DoubleType, FloatType, ShortType, ByteType) /** * Returns true if the data type has a special accessor and setter in [[Row]]. diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala index 840260703ab74b69ac6f1e0af51870d544bee858..638b53fe0fe2f17ecca834da6ffa2a9d428b66bc 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala @@ -19,15 +19,14 @@ package org.apache.spark.sql.catalyst.expressions.codegen import org.apache.spark.sql.catalyst.expressions._ +// MutableProjection is not accessible in Java +abstract class BaseMutableProjection extends MutableProjection {} + /** * Generates byte code that produces a [[MutableRow]] object that can update itself based on a new * input [[Row]] for a fixed set of [[Expression Expressions]]. */ object GenerateMutableProjection extends CodeGenerator[Seq[Expression], () => MutableProjection] { - import scala.reflect.runtime.{universe => ru} - import scala.reflect.runtime.universe._ - - val mutableRowName = newTermName("mutableRow") protected def canonicalize(in: Seq[Expression]): Seq[Expression] = in.map(ExpressionCanonicalizer.execute) @@ -36,41 +35,61 @@ object GenerateMutableProjection extends CodeGenerator[Seq[Expression], () => Mu in.map(BindReferences.bindReference(_, inputSchema)) protected def create(expressions: Seq[Expression]): (() => MutableProjection) = { - val projectionCode = expressions.zipWithIndex.flatMap { case (e, i) => - val evaluationCode = expressionEvaluator(e) - - evaluationCode.code :+ - q""" - if(${evaluationCode.nullTerm}) - mutableRow.setNullAt($i) - else - ${setColumn(mutableRowName, e.dataType, i, evaluationCode.primitiveTerm)} - """ - } + val ctx = newCodeGenContext() + val projectionCode = expressions.zipWithIndex.map { case (e, i) => + val evaluationCode = expressionEvaluator(e, ctx) + evaluationCode.code + + s""" + if(${evaluationCode.nullTerm}) + mutableRow.setNullAt($i); + else + ${setColumn("mutableRow", e.dataType, i, evaluationCode.primitiveTerm)}; + """ + }.mkString("\n") + val code = s""" + import org.apache.spark.sql.Row; + + public SpecificProjection generate($exprType[] expr) { + return new SpecificProjection(expr); + } + + class SpecificProjection extends ${classOf[BaseMutableProjection].getName} { - val code = - q""" - () => { new $mutableProjectionType { + private $exprType[] expressions = null; + private $mutableRowType mutableRow = null; - private[this] var $mutableRowName: $mutableRowType = - new $genericMutableRowType(${expressions.size}) + public SpecificProjection($exprType[] expr) { + expressions = expr; + mutableRow = new $genericMutableRowType(${expressions.size}); + } - def target(row: $mutableRowType): $mutableProjectionType = { - $mutableRowName = row - this - } + public ${classOf[BaseMutableProjection].getName} target($mutableRowType row) { + mutableRow = row; + return this; + } - /* Provide immutable access to the last projected row. */ - def currentValue: $rowType = mutableRow + /* Provide immutable access to the last projected row. */ + public Row currentValue() { + return mutableRow; + } - def apply(i: $rowType): $rowType = { - ..$projectionCode - mutableRow - } - } } - """ + public Object apply(Object _i) { + Row i = (Row) _i; + $projectionCode - log.debug(s"code for ${expressions.mkString(",")}:\n$code") - toolBox.eval(code).asInstanceOf[() => MutableProjection] + return mutableRow; + } + } + """ + + + logDebug(s"code for ${expressions.mkString(",")}:\n$code") + + val c = compile(code) + // fetch the only one method `generate(Expression[])` + val m = c.getDeclaredMethods()(0) + () => { + m.invoke(c.newInstance(), ctx.references.toArray).asInstanceOf[BaseMutableProjection] + } } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateOrdering.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateOrdering.scala index b129c0d898bb7d2928f01045dc76d2e7ccb26791..0ff840dab393c2a7ee68d0a405178d1cfbae146f 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateOrdering.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateOrdering.scala @@ -18,18 +18,29 @@ package org.apache.spark.sql.catalyst.expressions.codegen import org.apache.spark.Logging +import org.apache.spark.annotation.Private +import org.apache.spark.sql.Row import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.types.{BinaryType, StringType, NumericType} +import org.apache.spark.sql.types.{BinaryType, NumericType} + +/** + * Inherits some default implementation for Java from `Ordering[Row]` + */ +@Private +class BaseOrdering extends Ordering[Row] { + def compare(a: Row, b: Row): Int = { + throw new UnsupportedOperationException + } +} /** * Generates bytecode for an [[Ordering]] of [[Row Rows]] for a given set of * [[Expression Expressions]]. */ object GenerateOrdering extends CodeGenerator[Seq[SortOrder], Ordering[Row]] with Logging { - import scala.reflect.runtime.{universe => ru} import scala.reflect.runtime.universe._ - protected def canonicalize(in: Seq[SortOrder]): Seq[SortOrder] = + protected def canonicalize(in: Seq[SortOrder]): Seq[SortOrder] = in.map(ExpressionCanonicalizer.execute(_).asInstanceOf[SortOrder]) protected def bind(in: Seq[SortOrder], inputSchema: Seq[Attribute]): Seq[SortOrder] = @@ -38,73 +49,90 @@ object GenerateOrdering extends CodeGenerator[Seq[SortOrder], Ordering[Row]] wit protected def create(ordering: Seq[SortOrder]): Ordering[Row] = { val a = newTermName("a") val b = newTermName("b") - val comparisons = ordering.zipWithIndex.map { case (order, i) => - val evalA = expressionEvaluator(order.child) - val evalB = expressionEvaluator(order.child) + val ctx = newCodeGenContext() + val comparisons = ordering.zipWithIndex.map { case (order, i) => + val evalA = expressionEvaluator(order.child, ctx) + val evalB = expressionEvaluator(order.child, ctx) + val asc = order.direction == Ascending val compare = order.child.dataType match { case BinaryType => - q""" - val x = ${if (order.direction == Ascending) evalA.primitiveTerm else evalB.primitiveTerm} - val y = ${if (order.direction != Ascending) evalB.primitiveTerm else evalA.primitiveTerm} - var i = 0 - while (i < x.length && i < y.length) { - val res = x(i).compareTo(y(i)) - if (res != 0) return res - i = i+1 - } - return x.length - y.length - """ + s""" + { + byte[] x = ${if (asc) evalA.primitiveTerm else evalB.primitiveTerm}; + byte[] y = ${if (!asc) evalB.primitiveTerm else evalA.primitiveTerm}; + int j = 0; + while (j < x.length && j < y.length) { + if (x[j] != y[j]) return x[j] - y[j]; + j = j + 1; + } + int d = x.length - y.length; + if (d != 0) { + return d; + } + }""" case _: NumericType => - q""" - val comp = ${evalA.primitiveTerm} - ${evalB.primitiveTerm} - if(comp != 0) { - return ${if (order.direction == Ascending) q"comp.toInt" else q"-comp.toInt"} - } - """ - case StringType => - if (order.direction == Ascending) { - q"""return ${evalA.primitiveTerm}.compare(${evalB.primitiveTerm})""" + s""" + if (${evalA.primitiveTerm} != ${evalB.primitiveTerm}) { + if (${evalA.primitiveTerm} > ${evalB.primitiveTerm}) { + return ${if (asc) "1" else "-1"}; + } else { + return ${if (asc) "-1" else "1"}; + } + }""" + case _ => + s""" + int comp = ${evalA.primitiveTerm}.compare(${evalB.primitiveTerm}); + if (comp != 0) { + return ${if (asc) "comp" else "-comp"}; + }""" + } + + s""" + i = $a; + ${evalA.code} + i = $b; + ${evalB.code} + if (${evalA.nullTerm} && ${evalB.nullTerm}) { + // Nothing + } else if (${evalA.nullTerm}) { + return ${if (order.direction == Ascending) "-1" else "1"}; + } else if (${evalB.nullTerm}) { + return ${if (order.direction == Ascending) "1" else "-1"}; } else { - q"""return ${evalB.primitiveTerm}.compare(${evalA.primitiveTerm})""" + $compare } + """ + }.mkString("\n") + + val code = s""" + import org.apache.spark.sql.Row; + + public SpecificOrdering generate($exprType[] expr) { + return new SpecificOrdering(expr); } - q""" - i = $a - ..${evalA.code} - i = $b - ..${evalB.code} - if (${evalA.nullTerm} && ${evalB.nullTerm}) { - // Nothing - } else if (${evalA.nullTerm}) { - return ${if (order.direction == Ascending) q"-1" else q"1"} - } else if (${evalB.nullTerm}) { - return ${if (order.direction == Ascending) q"1" else q"-1"} - } else { - $compare + class SpecificOrdering extends ${typeOf[BaseOrdering]} { + + private $exprType[] expressions = null; + + public SpecificOrdering($exprType[] expr) { + expressions = expr; } - """ - } - val q"class $orderingName extends $orderingType { ..$body }" = reify { - class SpecificOrdering extends Ordering[Row] { - val o = ordering - } - }.tree.children.head - - val code = q""" - class $orderingName extends $orderingType { - ..$body - def compare(a: $rowType, b: $rowType): Int = { - var i: $rowType = null // Holds current row being evaluated. - ..$comparisons - return 0 + @Override + public int compare(Row a, Row b) { + Row i = null; // Holds current row being evaluated. + $comparisons + return 0; } - } - new $orderingName() - """ + }""" + logDebug(s"Generated Ordering: $code") - toolBox.eval(code).asInstanceOf[Ordering[Row]] + + val c = compile(code) + // fetch the only one method `generate(Expression[])` + val m = c.getDeclaredMethods()(0) + m.invoke(c.newInstance(), ctx.references.toArray).asInstanceOf[BaseOrdering] } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GeneratePredicate.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GeneratePredicate.scala index 40e163024360e0585d356ff6b536cb0077d247d9..fb18769f00da3dadf23fc8cb174fe8d16f56e844 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GeneratePredicate.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GeneratePredicate.scala @@ -19,12 +19,17 @@ package org.apache.spark.sql.catalyst.expressions.codegen import org.apache.spark.sql.catalyst.expressions._ +/** + * Interface for generated predicate + */ +abstract class Predicate { + def eval(r: Row): Boolean +} + /** * Generates bytecode that evaluates a boolean [[Expression]] on a given input [[Row]]. */ object GeneratePredicate extends CodeGenerator[Expression, (Row) => Boolean] { - import scala.reflect.runtime.{universe => ru} - import scala.reflect.runtime.universe._ protected def canonicalize(in: Expression): Expression = ExpressionCanonicalizer.execute(in) @@ -32,17 +37,34 @@ object GeneratePredicate extends CodeGenerator[Expression, (Row) => Boolean] { BindReferences.bindReference(in, inputSchema) protected def create(predicate: Expression): ((Row) => Boolean) = { - val cEval = expressionEvaluator(predicate) + val ctx = newCodeGenContext() + val eval = expressionEvaluator(predicate, ctx) + val code = s""" + import org.apache.spark.sql.Row; - val code = - q""" - (i: $rowType) => { - ..${cEval.code} - if (${cEval.nullTerm}) false else ${cEval.primitiveTerm} + public SpecificPredicate generate($exprType[] expr) { + return new SpecificPredicate(expr); + } + + class SpecificPredicate extends ${classOf[Predicate].getName} { + private final $exprType[] expressions; + public SpecificPredicate($exprType[] expr) { + expressions = expr; + } + + @Override + public boolean eval(Row i) { + ${eval.code} + return !${eval.nullTerm} && ${eval.primitiveTerm}; } - """ + }""" + + logDebug(s"Generated predicate '$predicate':\n$code") - log.debug(s"Generated predicate '$predicate':\n$code") - toolBox.eval(code).asInstanceOf[Row => Boolean] + val c = compile(code) + // fetch the only one method `generate(Expression[])` + val m = c.getDeclaredMethods()(0) + val p = m.invoke(c.newInstance(), ctx.references.toArray).asInstanceOf[Predicate] + (r: Row) => p.eval(r) } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateProjection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateProjection.scala index 31c63a79ebc8c7b6ab7da3d4e42357d8bcc3441e..d5be1fc12e0f0b3bf33d54dfe3024b3a7fd40476 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateProjection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateProjection.scala @@ -17,9 +17,14 @@ package org.apache.spark.sql.catalyst.expressions.codegen +import org.apache.spark.sql.BaseMutableRow import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.types._ +/** + * Java can not access Projection (in package object) + */ +abstract class BaseProject extends Projection {} /** * Generates bytecode that produces a new [[Row]] object based on a fixed set of input @@ -27,7 +32,6 @@ import org.apache.spark.sql.types._ * generated based on the output types of the [[Expression]] to avoid boxing of primitive values. */ object GenerateProjection extends CodeGenerator[Seq[Expression], Projection] { - import scala.reflect.runtime.{universe => ru} import scala.reflect.runtime.universe._ protected def canonicalize(in: Seq[Expression]): Seq[Expression] = @@ -38,201 +42,183 @@ object GenerateProjection extends CodeGenerator[Seq[Expression], Projection] { // Make Mutablility optional... protected def create(expressions: Seq[Expression]): Projection = { - val tupleLength = ru.Literal(Constant(expressions.length)) - val lengthDef = q"final val length = $tupleLength" - - /* TODO: Configurable... - val nullFunctions = - q""" - private final val nullSet = new org.apache.spark.util.collection.BitSet(length) - final def setNullAt(i: Int) = nullSet.set(i) - final def isNullAt(i: Int) = nullSet.get(i) - """ - */ - - val nullFunctions = - q""" - private[this] var nullBits = new Array[Boolean](${expressions.size}) - override def setNullAt(i: Int) = { nullBits(i) = true } - override def isNullAt(i: Int) = nullBits(i) - """.children - - val tupleElements = expressions.zipWithIndex.flatMap { + val ctx = newCodeGenContext() + val columns = expressions.zipWithIndex.map { case (e, i) => - val elementName = newTermName(s"c$i") - val evaluatedExpression = expressionEvaluator(e) - val iLit = ru.Literal(Constant(i)) + s"private ${primitiveForType(e.dataType)} c$i = ${defaultPrimitive(e.dataType)};\n" + }.mkString("\n ") - q""" - var ${newTermName(s"c$i")}: ${termForType(e.dataType)} = _ + val initColumns = expressions.zipWithIndex.map { + case (e, i) => + val eval = expressionEvaluator(e, ctx) + s""" { - ..${evaluatedExpression.code} - if(${evaluatedExpression.nullTerm}) - setNullAt($iLit) - else { - nullBits($iLit) = false - $elementName = ${evaluatedExpression.primitiveTerm} + // column$i + ${eval.code} + nullBits[$i] = ${eval.nullTerm}; + if(!${eval.nullTerm}) { + c$i = ${eval.primitiveTerm}; } } - """.children : Seq[Tree] - } + """ + }.mkString("\n") - val accessorFailure = q"""scala.sys.error("Invalid ordinal:" + i)""" - val applyFunction = { - val cases = (0 until expressions.size).map { i => - val ordinal = ru.Literal(Constant(i)) - val elementName = newTermName(s"c$i") - val iLit = ru.Literal(Constant(i)) + val getCases = (0 until expressions.size).map { i => + s"case $i: return c$i;" + }.mkString("\n ") - q"if(i == $ordinal) { if(isNullAt($i)) return null else return $elementName }" - } - q"override def apply(i: Int): Any = { ..$cases; $accessorFailure }" - } - - val updateFunction = { - val cases = expressions.zipWithIndex.map {case (e, i) => - val ordinal = ru.Literal(Constant(i)) - val elementName = newTermName(s"c$i") - val iLit = ru.Literal(Constant(i)) - - q""" - if(i == $ordinal) { - if(value == null) { - setNullAt(i) - } else { - nullBits(i) = false - $elementName = value.asInstanceOf[${termForType(e.dataType)}] - } - return - }""" - } - q"override def update(i: Int, value: Any): Unit = { ..$cases; $accessorFailure }" - } + val updateCases = expressions.zipWithIndex.map { case (e, i) => + s"case $i: { c$i = (${termForType(e.dataType)})value; return;}" + }.mkString("\n ") val specificAccessorFunctions = nativeTypes.map { dataType => - val ifStatements = expressions.zipWithIndex.flatMap { - // getString() is not used by expressions - case (e, i) if e.dataType == dataType && dataType != StringType => - val elementName = newTermName(s"c$i") - // TODO: The string of ifs gets pretty inefficient as the row grows in size. - // TODO: Optional null checks? - q"if(i == $i) return $elementName" :: Nil - case _ => Nil - } - dataType match { - // Row() need this interface to compile - case StringType => - q""" - override def getString(i: Int): String = { - $accessorFailure - }""" - case other => - q""" - override def ${accessorForType(dataType)}(i: Int): ${termForType(dataType)} = { - ..$ifStatements; - $accessorFailure - }""" + val cases = expressions.zipWithIndex.map { + case (e, i) if e.dataType == dataType => + s"case $i: return c$i;" + case _ => "" + }.mkString("\n ") + if (cases.count(_ != '\n') > 0) { + s""" + @Override + public ${primitiveForType(dataType)} ${accessorForType(dataType)}(int i) { + if (isNullAt(i)) { + return ${defaultPrimitive(dataType)}; + } + switch (i) { + $cases + } + return ${defaultPrimitive(dataType)}; + }""" + } else { + "" } - } + }.mkString("\n") val specificMutatorFunctions = nativeTypes.map { dataType => - val ifStatements = expressions.zipWithIndex.flatMap { - // setString() is not used by expressions - case (e, i) if e.dataType == dataType && dataType != StringType => - val elementName = newTermName(s"c$i") - // TODO: The string of ifs gets pretty inefficient as the row grows in size. - // TODO: Optional null checks? - q"if(i == $i) { nullBits($i) = false; $elementName = value; return }" :: Nil - case _ => Nil - } - dataType match { - case StringType => - // MutableRow() need this interface to compile - q""" - override def setString(i: Int, value: String) { - $accessorFailure - }""" - case other => - q""" - override def ${mutatorForType(dataType)}(i: Int, value: ${termForType(dataType)}) { - ..$ifStatements; - $accessorFailure - }""" + val cases = expressions.zipWithIndex.map { + case (e, i) if e.dataType == dataType => + s"case $i: { c$i = value; return; }" + case _ => "" + }.mkString("\n") + if (cases.count(_ != '\n') > 0) { + s""" + @Override + public void ${mutatorForType(dataType)}(int i, ${primitiveForType(dataType)} value) { + nullBits[i] = false; + switch (i) { + $cases + } + }""" + } else { + "" } - } + }.mkString("\n") val hashValues = expressions.zipWithIndex.map { case (e, i) => - val elementName = newTermName(s"c$i") + val col = newTermName(s"c$i") val nonNull = e.dataType match { - case BooleanType => q"if ($elementName) 0 else 1" - case ByteType | ShortType | IntegerType => q"$elementName.toInt" - case LongType => q"($elementName ^ ($elementName >>> 32)).toInt" - case FloatType => q"java.lang.Float.floatToIntBits($elementName)" + case BooleanType => s"$col ? 0 : 1" + case ByteType | ShortType | IntegerType | DateType => s"$col" + case LongType => s"$col ^ ($col >>> 32)" + case FloatType => s"Float.floatToIntBits($col)" case DoubleType => - q"{ val b = java.lang.Double.doubleToLongBits($elementName); (b ^ (b >>>32)).toInt }" - case _ => q"$elementName.hashCode" + s"Double.doubleToLongBits($col) ^ (Double.doubleToLongBits($col) >>> 32)" + case _ => s"$col.hashCode()" } - q"if (isNullAt($i)) 0 else $nonNull" + s"isNullAt($i) ? 0 : ($nonNull)" } - val hashUpdates: Seq[Tree] = hashValues.map(v => q"""result = 37 * result + $v""": Tree) + val hashUpdates: String = hashValues.map( v => + s""" + result *= 37; result += $v;""" + ).mkString("\n") - val hashCodeFunction = - q""" - override def hashCode(): Int = { - var result: Int = 37 - ..$hashUpdates - result - } + val columnChecks = expressions.zipWithIndex.map { case (e, i) => + s""" + if (isNullAt($i) != row.isNullAt($i) || !isNullAt($i) && !get($i).equals(row.get($i))) { + return false; + } """ + }.mkString("\n") - val columnChecks = (0 until expressions.size).map { i => - val elementName = newTermName(s"c$i") - q"if (this.$elementName != specificType.$elementName) return false" + val code = s""" + import org.apache.spark.sql.Row; + + public SpecificProjection generate($exprType[] expr) { + return new SpecificProjection(expr); } - val equalsFunction = - q""" - override def equals(other: Any): Boolean = other match { - case specificType: SpecificRow => - ..$columnChecks - return true - case other => super.equals(other) - } - """ + class SpecificProjection extends ${typeOf[BaseProject]} { + private $exprType[] expressions = null; + + public SpecificProjection($exprType[] expr) { + expressions = expr; + } - val allColumns = (0 until expressions.size).map { i => - val iLit = ru.Literal(Constant(i)) - q"if(isNullAt($iLit)) { null } else { ${newTermName(s"c$i")} }" + @Override + public Object apply(Object r) { + return new SpecificRow(expressions, (Row) r); + } } - val copyFunction = - q"override def copy() = new $genericRowType(Array[Any](..$allColumns))" - - val toSeqFunction = - q"override def toSeq: Seq[Any] = Seq(..$allColumns)" - - val classBody = - nullFunctions ++ ( - lengthDef +: - applyFunction +: - updateFunction +: - equalsFunction +: - hashCodeFunction +: - copyFunction +: - toSeqFunction +: - (tupleElements ++ specificAccessorFunctions ++ specificMutatorFunctions)) - - val code = q""" - final class SpecificRow(i: $rowType) extends $mutableRowType { - ..$classBody + final class SpecificRow extends ${typeOf[BaseMutableRow]} { + + $columns + + public SpecificRow($exprType[] expressions, Row i) { + $initColumns + } + + public int size() { return ${expressions.length};} + private boolean[] nullBits = new boolean[${expressions.length}]; + public void setNullAt(int i) { nullBits[i] = true; } + public boolean isNullAt(int i) { return nullBits[i]; } + + public Object get(int i) { + if (isNullAt(i)) return null; + switch (i) { + $getCases + } + return null; + } + public void update(int i, Object value) { + if (value == null) { + setNullAt(i); + return; + } + nullBits[i] = false; + switch (i) { + $updateCases + } + } + $specificAccessorFunctions + $specificMutatorFunctions + + @Override + public int hashCode() { + int result = 37; + $hashUpdates + return result; } - new $projectionType { def apply(r: $rowType) = new SpecificRow(r) } + @Override + public boolean equals(Object other) { + if (other instanceof Row) { + Row row = (Row) other; + if (row.length() != size()) return false; + $columnChecks + return true; + } + return super.equals(other); + } + } """ - log.debug( - s"MutableRow, initExprs: ${expressions.mkString(",")} code:\n${toolBox.typeCheck(code)}") - toolBox.eval(code).asInstanceOf[Projection] + logDebug(s"MutableRow, initExprs: ${expressions.mkString(",")} code:\n${code}") + + val c = compile(code) + // fetch the only one method `generate(Expression[])` + val m = c.getDeclaredMethods()(0) + m.invoke(c.newInstance(), ctx.references.toArray).asInstanceOf[Projection] } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/package.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/package.scala index 528e38a50a7405fbdc7bc0f943485014e1a85589..7f1b12cdd580035e5f2fa6f8df6bbb16fa31d59f 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/package.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/package.scala @@ -27,12 +27,6 @@ import org.apache.spark.util.Utils */ package object codegen { - /** - * A lock to protect invoking the scala compiler at runtime, since it is not thread safe in Scala - * 2.10. - */ - protected[codegen] val globalLock = org.apache.spark.sql.catalyst.ScalaReflectionLock - /** Canonicalizes an expression so those that differ only by names can reuse the same code. */ object ExpressionCanonicalizer extends rules.RuleExecutor[Expression] { val batches = diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvaluationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvaluationSuite.scala index b6927485f42bf21763e8ed5ac609ae9dd89bb168..5df528770ca6e70b16a7020a8d92081e580162e5 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvaluationSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvaluationSuite.scala @@ -344,7 +344,7 @@ class ExpressionEvaluationSuite extends ExpressionEvaluationBaseSuite { checkEvaluation("abdef" cast TimestampType, null) checkEvaluation("12.65" cast DecimalType.Unlimited, Decimal(12.65)) - checkEvaluation(Literal(1) cast LongType, 1) + checkEvaluation(Literal(1) cast LongType, 1.toLong) checkEvaluation(Cast(Literal(1000) cast TimestampType, LongType), 1.toLong) checkEvaluation(Cast(Literal(-1200) cast TimestampType, LongType), -2.toLong) checkEvaluation(Cast(Literal(1.toDouble) cast TimestampType, DoubleType), 1.toDouble) @@ -363,13 +363,16 @@ class ExpressionEvaluationSuite extends ExpressionEvaluationBaseSuite { checkEvaluation(Cast("abdef" cast BinaryType, StringType), "abdef") checkEvaluation(Cast(Cast(Cast(Cast( - Cast("5" cast ByteType, ShortType), IntegerType), FloatType), DoubleType), LongType), 5) + Cast("5" cast ByteType, ShortType), IntegerType), FloatType), DoubleType), LongType), + 5.toLong) checkEvaluation(Cast(Cast(Cast(Cast(Cast("5" cast - ByteType, TimestampType), DecimalType.Unlimited), LongType), StringType), ShortType), 0) + ByteType, TimestampType), DecimalType.Unlimited), LongType), StringType), ShortType), + 0.toShort) checkEvaluation(Cast(Cast(Cast(Cast(Cast("5" cast TimestampType, ByteType), DecimalType.Unlimited), LongType), StringType), ShortType), null) checkEvaluation(Cast(Cast(Cast(Cast(Cast("5" cast - DecimalType.Unlimited, ByteType), TimestampType), LongType), StringType), ShortType), 0) + DecimalType.Unlimited, ByteType), TimestampType), LongType), StringType), ShortType), + 0.toShort) checkEvaluation(Literal(true) cast IntegerType, 1) checkEvaluation(Literal(false) cast IntegerType, 0) checkEvaluation(Literal(true) cast StringType, "true") @@ -509,9 +512,9 @@ class ExpressionEvaluationSuite extends ExpressionEvaluationBaseSuite { val seconds = millis * 1000 + 2 val ts = new Timestamp(millis) val tss = new Timestamp(seconds) - checkEvaluation(Cast(ts, ShortType), 15) + checkEvaluation(Cast(ts, ShortType), 15.toShort) checkEvaluation(Cast(ts, IntegerType), 15) - checkEvaluation(Cast(ts, LongType), 15) + checkEvaluation(Cast(ts, LongType), 15.toLong) checkEvaluation(Cast(ts, FloatType), 15.002f) checkEvaluation(Cast(ts, DoubleType), 15.002) checkEvaluation(Cast(Cast(tss, ShortType), TimestampType), ts) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/GeneratedEvaluationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/GeneratedEvaluationSuite.scala index d7c437095e395c87b78320bd36b2e37b688fd8e5..8cfd853afa35f030425b012db614d0b69b9a4d51 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/GeneratedEvaluationSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/GeneratedEvaluationSuite.scala @@ -32,11 +32,12 @@ class GeneratedEvaluationSuite extends ExpressionEvaluationSuite { GenerateMutableProjection.generate(Alias(expression, s"Optimized($expression)")() :: Nil)() } catch { case e: Throwable => - val evaluated = GenerateProjection.expressionEvaluator(expression) + val ctx = GenerateProjection.newCodeGenContext() + val evaluated = GenerateProjection.expressionEvaluator(expression, ctx) fail( s""" |Code generation of $expression failed: - |${evaluated.code.mkString("\n")} + |${evaluated.code} |$e """.stripMargin) } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/GeneratedMutableEvaluationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/GeneratedMutableEvaluationSuite.scala index a40324b008e162d99ec25ec11f3ee443c8d2cd3f..9ab1f7d7ad0dbfffb6a5a4877ea4ffc24ad4bace 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/GeneratedMutableEvaluationSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/GeneratedMutableEvaluationSuite.scala @@ -28,7 +28,8 @@ class GeneratedMutableEvaluationSuite extends ExpressionEvaluationSuite { expression: Expression, expected: Any, inputRow: Row = EmptyRow): Unit = { - lazy val evaluated = GenerateProjection.expressionEvaluator(expression) + val ctx = GenerateProjection.newCodeGenContext() + lazy val evaluated = GenerateProjection.expressionEvaluator(expression, ctx) val plan = try { GenerateProjection.generate(Alias(expression, s"Optimized($expression)")() :: Nil) @@ -37,7 +38,7 @@ class GeneratedMutableEvaluationSuite extends ExpressionEvaluationSuite { fail( s""" |Code generation of $expression failed: - |${evaluated.code.mkString("\n")} + |${evaluated.code} |$e """.stripMargin) } @@ -49,7 +50,7 @@ class GeneratedMutableEvaluationSuite extends ExpressionEvaluationSuite { s""" |Mismatched hashCodes for values: $actual, $expectedRow |Hash Codes: ${actual.hashCode()} != ${expectedRow.hashCode()} - |${evaluated.code.mkString("\n")} + |${evaluated.code} """.stripMargin) } if (actual != expectedRow) { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala index 9aaec2b064d769ceb5c732d81abd953239eba5bc..b41b1b77d049ebaaef207535e552be4d274f166c 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala @@ -451,10 +451,13 @@ class DataFrameSuite extends QueryTest { test("SPARK-6899") { val originalValue = TestSQLContext.conf.codegenEnabled TestSQLContext.setConf(SQLConf.CODEGEN_ENABLED, "true") - checkAnswer( - decimalData.agg(avg('a)), - Row(new java.math.BigDecimal(2.0))) - TestSQLContext.setConf(SQLConf.CODEGEN_ENABLED, originalValue.toString) + try{ + checkAnswer( + decimalData.agg(avg('a)), + Row(new java.math.BigDecimal(2.0))) + } finally { + TestSQLContext.setConf(SQLConf.CODEGEN_ENABLED, originalValue.toString) + } } test("SPARK-7133: Implement struct, array, and map field accessor") { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala index 63f7d314fb699c01454680f8437c08d84f5151ea..55b68d8e2283c6bc084e32b6f3941289137e33a0 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala @@ -184,77 +184,79 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll with SQLTestUtils { checkAnswer(df, expectedResults) } - // Just to group rows. - testCodeGen( - "SELECT key FROM testData3x GROUP BY key", - (1 to 100).map(Row(_))) - // COUNT - testCodeGen( - "SELECT key, count(value) FROM testData3x GROUP BY key", - (1 to 100).map(i => Row(i, 3))) - testCodeGen( - "SELECT count(key) FROM testData3x", - Row(300) :: Nil) - // COUNT DISTINCT ON int - testCodeGen( - "SELECT value, count(distinct key) FROM testData3x GROUP BY value", - (1 to 100).map(i => Row(i.toString, 1))) - testCodeGen( - "SELECT count(distinct key) FROM testData3x", - Row(100) :: Nil) - // SUM - testCodeGen( - "SELECT value, sum(key) FROM testData3x GROUP BY value", - (1 to 100).map(i => Row(i.toString, 3 * i))) - testCodeGen( - "SELECT sum(key), SUM(CAST(key as Double)) FROM testData3x", - Row(5050 * 3, 5050 * 3.0) :: Nil) - // AVERAGE - testCodeGen( - "SELECT value, avg(key) FROM testData3x GROUP BY value", - (1 to 100).map(i => Row(i.toString, i))) - testCodeGen( - "SELECT avg(key) FROM testData3x", - Row(50.5) :: Nil) - // MAX - testCodeGen( - "SELECT value, max(key) FROM testData3x GROUP BY value", - (1 to 100).map(i => Row(i.toString, i))) - testCodeGen( - "SELECT max(key) FROM testData3x", - Row(100) :: Nil) - // MIN - testCodeGen( - "SELECT value, min(key) FROM testData3x GROUP BY value", - (1 to 100).map(i => Row(i.toString, i))) - testCodeGen( - "SELECT min(key) FROM testData3x", - Row(1) :: Nil) - // Some combinations. - testCodeGen( - """ - |SELECT - | value, - | sum(key), - | max(key), - | min(key), - | avg(key), - | count(key), - | count(distinct key) - |FROM testData3x - |GROUP BY value - """.stripMargin, - (1 to 100).map(i => Row(i.toString, i*3, i, i, i, 3, 1))) - testCodeGen( - "SELECT max(key), min(key), avg(key), count(key), count(distinct key) FROM testData3x", - Row(100, 1, 50.5, 300, 100) :: Nil) - // Aggregate with Code generation handling all null values - testCodeGen( - "SELECT sum('a'), avg('a'), count(null) FROM testData", - Row(0, null, 0) :: Nil) - - dropTempTable("testData3x") - setConf(SQLConf.CODEGEN_ENABLED, originalValue.toString) + try { + // Just to group rows. + testCodeGen( + "SELECT key FROM testData3x GROUP BY key", + (1 to 100).map(Row(_))) + // COUNT + testCodeGen( + "SELECT key, count(value) FROM testData3x GROUP BY key", + (1 to 100).map(i => Row(i, 3))) + testCodeGen( + "SELECT count(key) FROM testData3x", + Row(300) :: Nil) + // COUNT DISTINCT ON int + testCodeGen( + "SELECT value, count(distinct key) FROM testData3x GROUP BY value", + (1 to 100).map(i => Row(i.toString, 1))) + testCodeGen( + "SELECT count(distinct key) FROM testData3x", + Row(100) :: Nil) + // SUM + testCodeGen( + "SELECT value, sum(key) FROM testData3x GROUP BY value", + (1 to 100).map(i => Row(i.toString, 3 * i))) + testCodeGen( + "SELECT sum(key), SUM(CAST(key as Double)) FROM testData3x", + Row(5050 * 3, 5050 * 3.0) :: Nil) + // AVERAGE + testCodeGen( + "SELECT value, avg(key) FROM testData3x GROUP BY value", + (1 to 100).map(i => Row(i.toString, i))) + testCodeGen( + "SELECT avg(key) FROM testData3x", + Row(50.5) :: Nil) + // MAX + testCodeGen( + "SELECT value, max(key) FROM testData3x GROUP BY value", + (1 to 100).map(i => Row(i.toString, i))) + testCodeGen( + "SELECT max(key) FROM testData3x", + Row(100) :: Nil) + // MIN + testCodeGen( + "SELECT value, min(key) FROM testData3x GROUP BY value", + (1 to 100).map(i => Row(i.toString, i))) + testCodeGen( + "SELECT min(key) FROM testData3x", + Row(1) :: Nil) + // Some combinations. + testCodeGen( + """ + |SELECT + | value, + | sum(key), + | max(key), + | min(key), + | avg(key), + | count(key), + | count(distinct key) + |FROM testData3x + |GROUP BY value + """.stripMargin, + (1 to 100).map(i => Row(i.toString, i*3, i, i, i, 3, 1))) + testCodeGen( + "SELECT max(key), min(key), avg(key), count(key), count(distinct key) FROM testData3x", + Row(100, 1, 50.5, 300, 100) :: Nil) + // Aggregate with Code generation handling all null values + testCodeGen( + "SELECT sum('a'), avg('a'), count(null) FROM testData", + Row(0, null, 0) :: Nil) + } finally { + dropTempTable("testData3x") + setConf(SQLConf.CODEGEN_ENABLED, originalValue.toString) + } } test("Add Parser of SQL COALESCE()") { @@ -463,9 +465,12 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll with SQLTestUtils { val codegenbefore = conf.codegenEnabled setConf(SQLConf.EXTERNAL_SORT, "false") setConf(SQLConf.CODEGEN_ENABLED, "true") - sortTest() - setConf(SQLConf.EXTERNAL_SORT, externalbefore.toString) - setConf(SQLConf.CODEGEN_ENABLED, codegenbefore.toString) + try{ + sortTest() + } finally { + setConf(SQLConf.EXTERNAL_SORT, externalbefore.toString) + setConf(SQLConf.CODEGEN_ENABLED, codegenbefore.toString) + } } test("SPARK-6927 external sorting with codegen on") { @@ -473,9 +478,12 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll with SQLTestUtils { val codegenbefore = conf.codegenEnabled setConf(SQLConf.CODEGEN_ENABLED, "true") setConf(SQLConf.EXTERNAL_SORT, "true") - sortTest() - setConf(SQLConf.EXTERNAL_SORT, externalbefore.toString) - setConf(SQLConf.CODEGEN_ENABLED, codegenbefore.toString) + try { + sortTest() + } finally { + setConf(SQLConf.EXTERNAL_SORT, externalbefore.toString) + setConf(SQLConf.CODEGEN_ENABLED, codegenbefore.toString) + } } test("limit") {