Commit f82aa824 authored by Sameer Agarwal, committed by Reynold Xin

[SPARK-14774][SQL] Write unscaled values in ColumnVector.putDecimal

## What changes were proposed in this pull request?

We recently made `ColumnarBatch.row` mutable and added a new `ColumnVector.putDecimal` method to support putting `Decimal` values in the `ColumnarBatch`. This unfortunately introduced a bug: for decimals that fit in an int or a long, `putDecimal` wrote the truncated integral value (`toInt`/`toLong`) instead of the unscaled value, so any decimal with a nonzero scale was stored incorrectly.
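For illustration (not part of this patch), a minimal sketch of the distinction, using the same `org.apache.spark.sql.types.Decimal` API the patch calls: `toInt()` truncates to the integral value, while the columnar format must store the unscaled value.

```java
import java.math.BigDecimal;
import org.apache.spark.sql.types.Decimal;

public class UnscaledVsTruncated {
  public static void main(String[] args) {
    // 1.50 with precision 4 and scale 2 has unscaled value 150.
    Decimal d = Decimal.apply(new BigDecimal("1.50"), 4, 2);
    System.out.println(d.toInt());          // 1   -- truncated integral value (what the bug stored)
    System.out.println(d.toUnscaledLong()); // 150 -- unscaled value (what the column must store)
  }
}
```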

## How was this patch tested?

This codepath is hit only when the vectorized aggregate hashmap is enabled. https://github.com/apache/spark/pull/12440 adds a number of regression tests/benchmarks that exercise this bugfix.

Author: Sameer Agarwal <sameer@databricks.com>

Closes #12541 from sameeragarwal/fix-bigdecimal.
parent 1a95397b
@@ -569,9 +569,9 @@ public abstract class ColumnVector implements AutoCloseable {
   public final void putDecimal(int rowId, Decimal value, int precision) {
     if (precision <= Decimal.MAX_INT_DIGITS()) {
-      putInt(rowId, value.toInt());
+      putInt(rowId, (int) value.toUnscaledLong());
     } else if (precision <= Decimal.MAX_LONG_DIGITS()) {
-      putLong(rowId, value.toLong());
+      putLong(rowId, value.toUnscaledLong());
     } else {
       BigInteger bigInteger = value.toJavaBigDecimal().unscaledValue();
       putByteArray(rowId, bigInteger.toByteArray());
......
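The read side rebuilds the `Decimal` from the stored unscaled value plus the column's precision and scale, which is why the write side above must store unscaled values. A minimal round-trip sketch (a standalone illustration, not Spark's actual read path):

```java
import java.math.BigDecimal;
import org.apache.spark.sql.types.Decimal;

public class UnscaledRoundTrip {
  public static void main(String[] args) {
    Decimal in = Decimal.apply(new BigDecimal("12.34"), 4, 2);

    // Write side: precision 4 <= Decimal.MAX_INT_DIGITS(), so an int column
    // holds the unscaled value (the fixed putDecimal behavior).
    int stored = (int) in.toUnscaledLong(); // 1234

    // Read side: reconstruct the decimal from unscaled value + precision/scale.
    Decimal out = Decimal.apply((long) stored, 4, 2);
    System.out.println(out);            // 12.34
    System.out.println(in.equals(out)); // true
  }
}
```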
@@ -142,9 +142,11 @@ public class ColumnVectorUtils {
         byte[] b =((String)o).getBytes(StandardCharsets.UTF_8);
         dst.appendByteArray(b, 0, b.length);
       } else if (t instanceof DecimalType) {
-        DecimalType dt = (DecimalType)t;
-        Decimal d = Decimal.apply((BigDecimal)o, dt.precision(), dt.scale());
-        if (dt.precision() <= Decimal.MAX_LONG_DIGITS()) {
+        DecimalType dt = (DecimalType) t;
+        Decimal d = Decimal.apply((BigDecimal) o, dt.precision(), dt.scale());
+        if (dt.precision() <= Decimal.MAX_INT_DIGITS()) {
+          dst.appendInt((int) d.toUnscaledLong());
+        } else if (dt.precision() <= Decimal.MAX_LONG_DIGITS()) {
           dst.appendLong(d.toUnscaledLong());
         } else {
           final BigInteger integer = d.toJavaBigDecimal().unscaledValue();
......
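For reference, `Decimal.MAX_INT_DIGITS()` is 9 and `Decimal.MAX_LONG_DIGITS()` is 18: an unscaled value of at most 9 decimal digits always fits in an int, one of at most 18 digits fits in a long, and anything wider falls back to the `BigInteger` byte-array encoding. A standalone sketch of the same dispatch (the `storageFor` helper is hypothetical, not Spark API); the widened test schemas further below exercise all three paths:

```java
import org.apache.spark.sql.types.Decimal;

public class DecimalStorageSketch {
  // Hypothetical helper mirroring the dispatch in ColumnVector/ColumnVectorUtils.
  static String storageFor(int precision) {
    if (precision <= Decimal.MAX_INT_DIGITS()) {         // 9
      return "int column (4-byte unscaled value)";
    } else if (precision <= Decimal.MAX_LONG_DIGITS()) { // 18
      return "long column (8-byte unscaled value)";
    } else {
      return "byte[] column (BigInteger.toByteArray of the unscaled value)";
    }
  }

  public static void main(String[] args) {
    System.out.println(storageFor(5));  // int column
    System.out.println(storageFor(12)); // long column
    System.out.println(storageFor(30)); // byte[] column
  }
}
```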
@@ -586,30 +586,31 @@ class ColumnarBatchSuite extends SparkFunSuite {
   }

   private def compareStruct(fields: Seq[StructField], r1: InternalRow, r2: Row, seed: Long) {
-    fields.zipWithIndex.foreach { v => {
-      assert(r1.isNullAt(v._2) == r2.isNullAt(v._2), "Seed = " + seed)
-      if (!r1.isNullAt(v._2)) {
-        v._1.dataType match {
-          case BooleanType => assert(r1.getBoolean(v._2) == r2.getBoolean(v._2), "Seed = " + seed)
-          case ByteType => assert(r1.getByte(v._2) == r2.getByte(v._2), "Seed = " + seed)
-          case ShortType => assert(r1.getShort(v._2) == r2.getShort(v._2), "Seed = " + seed)
-          case IntegerType => assert(r1.getInt(v._2) == r2.getInt(v._2), "Seed = " + seed)
-          case LongType => assert(r1.getLong(v._2) == r2.getLong(v._2), "Seed = " + seed)
-          case FloatType => assert(doubleEquals(r1.getFloat(v._2), r2.getFloat(v._2)),
+    fields.zipWithIndex.foreach { case (field: StructField, ordinal: Int) =>
+      assert(r1.isNullAt(ordinal) == r2.isNullAt(ordinal), "Seed = " + seed)
+      if (!r1.isNullAt(ordinal)) {
+        field.dataType match {
+          case BooleanType => assert(r1.getBoolean(ordinal) == r2.getBoolean(ordinal),
             "Seed = " + seed)
-          case DoubleType => assert(doubleEquals(r1.getDouble(v._2), r2.getDouble(v._2)),
+          case ByteType => assert(r1.getByte(ordinal) == r2.getByte(ordinal), "Seed = " + seed)
+          case ShortType => assert(r1.getShort(ordinal) == r2.getShort(ordinal), "Seed = " + seed)
+          case IntegerType => assert(r1.getInt(ordinal) == r2.getInt(ordinal), "Seed = " + seed)
+          case LongType => assert(r1.getLong(ordinal) == r2.getLong(ordinal), "Seed = " + seed)
+          case FloatType => assert(doubleEquals(r1.getFloat(ordinal), r2.getFloat(ordinal)),
+            "Seed = " + seed)
+          case DoubleType => assert(doubleEquals(r1.getDouble(ordinal), r2.getDouble(ordinal)),
             "Seed = " + seed)
           case t: DecimalType =>
-            val d1 = r1.getDecimal(v._2, t.precision, t.scale).toBigDecimal
-            val d2 = r2.getDecimal(v._2)
+            val d1 = r1.getDecimal(ordinal, t.precision, t.scale).toBigDecimal
+            val d2 = r2.getDecimal(ordinal)
             assert(d1.compare(d2) == 0, "Seed = " + seed)
           case StringType =>
-            assert(r1.getString(v._2) == r2.getString(v._2), "Seed = " + seed)
+            assert(r1.getString(ordinal) == r2.getString(ordinal), "Seed = " + seed)
           case CalendarIntervalType =>
-            assert(r1.getInterval(v._2) === r2.get(v._2).asInstanceOf[CalendarInterval])
+            assert(r1.getInterval(ordinal) === r2.get(ordinal).asInstanceOf[CalendarInterval])
           case ArrayType(childType, n) =>
-            val a1 = r1.getArray(v._2).array
-            val a2 = r2.getList(v._2).toArray
+            val a1 = r1.getArray(ordinal).array
+            val a2 = r2.getList(ordinal).toArray
             assert(a1.length == a2.length, "Seed = " + seed)
             childType match {
               case DoubleType =>
@@ -640,12 +641,13 @@ class ColumnarBatchSuite extends SparkFunSuite {
               case _ => assert(a1 === a2, "Seed = " + seed)
             }
           case StructType(childFields) =>
-            compareStruct(childFields, r1.getStruct(v._2, fields.length), r2.getStruct(v._2), seed)
+            compareStruct(childFields, r1.getStruct(ordinal, fields.length),
+              r2.getStruct(ordinal), seed)
           case _ =>
-            throw new NotImplementedError("Not implemented " + v._1.dataType)
+            throw new NotImplementedError("Not implemented " + field.dataType)
         }
       }
-    }}
+    }
   }

   test("Convert rows") {
@@ -678,9 +680,10 @@ class ColumnarBatchSuite extends SparkFunSuite {
   def testRandomRows(flatSchema: Boolean, numFields: Int) {
     // TODO: Figure out why StringType doesn't work on jenkins.
     val types = Array(
-      BooleanType, ByteType, FloatType, DoubleType,
-      IntegerType, LongType, ShortType, DecimalType.IntDecimal, new DecimalType(30, 10),
-      CalendarIntervalType)
+      BooleanType, ByteType, FloatType, DoubleType, IntegerType, LongType, ShortType,
+      DecimalType.ShortDecimal, DecimalType.IntDecimal, DecimalType.ByteDecimal,
+      DecimalType.FloatDecimal, DecimalType.LongDecimal, new DecimalType(5, 2),
+      new DecimalType(12, 2), new DecimalType(30, 10), CalendarIntervalType)
     val seed = System.nanoTime()
     val NUM_ROWS = 200
     val NUM_ITERS = 1000
@@ -756,8 +759,10 @@ class ColumnarBatchSuite extends SparkFunSuite {
   test("mutable ColumnarBatch rows") {
     val NUM_ITERS = 10
     val types = Array(
-      BooleanType, FloatType, DoubleType,
-      IntegerType, LongType, ShortType, DecimalType.IntDecimal, new DecimalType(30, 10))
+      BooleanType, FloatType, DoubleType, IntegerType, LongType, ShortType,
+      DecimalType.ShortDecimal, DecimalType.IntDecimal, DecimalType.ByteDecimal,
+      DecimalType.FloatDecimal, DecimalType.LongDecimal, new DecimalType(5, 2),
+      new DecimalType(12, 2), new DecimalType(30, 10))
     for (i <- 0 to NUM_ITERS) {
       val random = new Random(System.nanoTime())
       val schema = RandomDataGenerator.randomSchema(random, numFields = 20, types)
......