diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregates.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregates.scala index dbc0c2965a805cb69de0baefc4104527ed315dbb..15560a2a933ad32c7a0db6f2ce3e9b731dc3fba1 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregates.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregates.scala @@ -105,17 +105,18 @@ case class Min(child: Expression) extends PartialAggregate with trees.UnaryNode[ case class MinFunction(expr: Expression, base: AggregateExpression) extends AggregateFunction { def this() = this(null, null) // Required for serialization. - var currentMin: Any = _ + val currentMin: MutableLiteral = MutableLiteral(null, expr.dataType) + val cmp = GreaterThan(currentMin, expr) override def update(input: Row): Unit = { - if (currentMin == null) { - currentMin = expr.eval(input) - } else if(GreaterThan(Literal(currentMin, expr.dataType), expr).eval(input) == true) { - currentMin = expr.eval(input) + if (currentMin.value == null) { + currentMin.value = expr.eval(input) + } else if(cmp.eval(input) == true) { + currentMin.value = expr.eval(input) } } - override def eval(input: Row): Any = currentMin + override def eval(input: Row): Any = currentMin.value } case class Max(child: Expression) extends PartialAggregate with trees.UnaryNode[Expression] { @@ -135,17 +136,18 @@ case class Max(child: Expression) extends PartialAggregate with trees.UnaryNode[ case class MaxFunction(expr: Expression, base: AggregateExpression) extends AggregateFunction { def this() = this(null, null) // Required for serialization. - var currentMax: Any = _ + val currentMax: MutableLiteral = MutableLiteral(null, expr.dataType) + val cmp = LessThan(currentMax, expr) override def update(input: Row): Unit = { - if (currentMax == null) { - currentMax = expr.eval(input) - } else if(LessThan(Literal(currentMax, expr.dataType), expr).eval(input) == true) { - currentMax = expr.eval(input) + if (currentMax.value == null) { + currentMax.value = expr.eval(input) + } else if(cmp.eval(input) == true) { + currentMax.value = expr.eval(input) } } - override def eval(input: Row): Any = currentMax + override def eval(input: Row): Any = currentMax.value } case class Count(child: Expression) extends PartialAggregate with trees.UnaryNode[Expression] { @@ -350,7 +352,7 @@ case class AverageFunction(expr: Expression, base: AggregateExpression) private val zero = Cast(Literal(0), expr.dataType) private var count: Long = _ - private val sum = MutableLiteral(zero.eval(EmptyRow)) + private val sum = MutableLiteral(zero.eval(null), expr.dataType) private val sumAsDouble = Cast(sum, DoubleType) private def addFunction(value: Any) = Add(sum, Literal(value)) @@ -423,7 +425,7 @@ case class SumFunction(expr: Expression, base: AggregateExpression) extends Aggr private val zero = Cast(Literal(0), expr.dataType) - private val sum = MutableLiteral(zero.eval(null)) + private val sum = MutableLiteral(zero.eval(null), expr.dataType) private val addFunction = Add(sum, Coalesce(Seq(expr, zero))) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala index a8c2396d62632cfe802858c674827fa49ba1864b..78a0c55e4bbe5c4d5e90d458c2a69692aff4adf0 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala @@ -61,11 +61,10 @@ case class Literal(value: Any, dataType: DataType) extends LeafExpression { } // TODO: Specialize -case class MutableLiteral(var value: Any, nullable: Boolean = true) extends LeafExpression { +case class MutableLiteral(var value: Any, dataType: DataType, nullable: Boolean = true) + extends LeafExpression { type EvaluatedType = Any - val dataType = Literal(value).dataType - def update(expression: Expression, input: Row) = { value = expression.eval(input) }