]> git.proxmox.com Git - ceph.git/blame - ceph/src/arrow/java/gandiva/src/test/java/org/apache/arrow/gandiva/evaluator/ProjectorDecimalTest.java
import quincy 17.2.0
[ceph.git] / ceph / src / arrow / java / gandiva / src / test / java / org / apache / arrow / gandiva / evaluator / ProjectorDecimalTest.java
CommitLineData
1d09f67e
TL
1/*
2 * Licensed to the Apache Software Foundation (ASF) under one or more
3 * contributor license agreements. See the NOTICE file distributed with
4 * this work for additional information regarding copyright ownership.
5 * The ASF licenses this file to You under the Apache License, Version 2.0
6 * (the "License"); you may not use this file except in compliance with
7 * the License. You may obtain a copy of the License at
8 *
9 * http://www.apache.org/licenses/LICENSE-2.0
10 *
11 * Unless required by applicable law or agreed to in writing, software
12 * distributed under the License is distributed on an "AS IS" BASIS,
13 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 * See the License for the specific language governing permissions and
15 * limitations under the License.
16 */
17
18package org.apache.arrow.gandiva.evaluator;
19
20import static org.junit.Assert.assertEquals;
21import static org.junit.Assert.assertFalse;
22import static org.junit.Assert.assertTrue;
23
24import java.math.BigDecimal;
25import java.util.ArrayList;
26import java.util.Arrays;
27import java.util.List;
28
29import org.apache.arrow.gandiva.exceptions.GandivaException;
30import org.apache.arrow.gandiva.expression.ExpressionTree;
31import org.apache.arrow.gandiva.expression.TreeBuilder;
32import org.apache.arrow.gandiva.expression.TreeNode;
33import org.apache.arrow.vector.BigIntVector;
34import org.apache.arrow.vector.BitVector;
35import org.apache.arrow.vector.DecimalVector;
36import org.apache.arrow.vector.Float8Vector;
37import org.apache.arrow.vector.ValueVector;
38import org.apache.arrow.vector.VarCharVector;
39import org.apache.arrow.vector.ipc.message.ArrowFieldNode;
40import org.apache.arrow.vector.ipc.message.ArrowRecordBatch;
41import org.apache.arrow.vector.types.pojo.ArrowType;
42import org.apache.arrow.vector.types.pojo.ArrowType.Decimal;
43import org.apache.arrow.vector.types.pojo.Field;
44import org.apache.arrow.vector.types.pojo.Schema;
45import org.junit.Rule;
46import org.junit.Test;
47import org.junit.rules.ExpectedException;
48
49import com.google.common.collect.Lists;
50
51public class ProjectorDecimalTest extends org.apache.arrow.gandiva.evaluator.BaseEvaluatorTest {
52 @Rule
53 public ExpectedException exception = ExpectedException.none();
54
55 @Test
56 public void test_add() throws GandivaException {
57 int precision = 38;
58 int scale = 8;
59 ArrowType.Decimal decimal = new ArrowType.Decimal(precision, scale, 128);
60 Field a = Field.nullable("a", decimal);
61 Field b = Field.nullable("b", decimal);
62 List<Field> args = Lists.newArrayList(a, b);
63
64 ArrowType.Decimal outputType = DecimalTypeUtil.getResultTypeForOperation(DecimalTypeUtil
65 .OperationType.ADD, decimal, decimal);
66 Field retType = Field.nullable("c", outputType);
67 ExpressionTree root = TreeBuilder.makeExpression("add", args, retType);
68
69 List<ExpressionTree> exprs = Lists.newArrayList(root);
70
71 Schema schema = new Schema(args);
72 Projector eval = Projector.make(schema, exprs);
73
74 int numRows = 4;
75 byte[] validity = new byte[]{(byte) 255};
76 String[] aValues = new String[]{"1.12345678", "2.12345678", "3.12345678", "4.12345678"};
77 String[] bValues = new String[]{"2.12345678", "3.12345678", "4.12345678", "5.12345678"};
78
79 DecimalVector valuesa = decimalVector(aValues, precision, scale);
80 DecimalVector valuesb = decimalVector(bValues, precision, scale);
81 ArrowRecordBatch batch =
82 new ArrowRecordBatch(
83 numRows,
84 Lists.newArrayList(new ArrowFieldNode(numRows, 0), new ArrowFieldNode(numRows, 0)),
85 Lists.newArrayList(valuesa.getValidityBuffer(), valuesa.getDataBuffer(),
86 valuesb.getValidityBuffer(), valuesb.getDataBuffer()));
87
88 DecimalVector outVector = new DecimalVector("decimal_output", allocator, outputType.getPrecision(),
89 outputType.getScale());
90 outVector.allocateNew(numRows);
91
92 List<ValueVector> output = new ArrayList<ValueVector>();
93 output.add(outVector);
94 eval.evaluate(batch, output);
95
96 // should have scaled down.
97 BigDecimal[] expOutput = new BigDecimal[]{BigDecimal.valueOf(3.2469136),
98 BigDecimal.valueOf(5.2469136),
99 BigDecimal.valueOf(7.2469136),
100 BigDecimal.valueOf(9.2469136)};
101
102 for (int i = 0; i < 4; i++) {
103 assertFalse(outVector.isNull(i));
104 assertTrue("index : " + i + " failed compare", expOutput[i].compareTo(outVector.getObject(i)
105 ) == 0);
106 }
107
108 // free buffers
109 releaseRecordBatch(batch);
110 releaseValueVectors(output);
111 eval.close();
112 }
113
114 @Test
115 public void test_add_literal() throws GandivaException {
116 int precision = 2;
117 int scale = 0;
118 ArrowType.Decimal decimal = new ArrowType.Decimal(precision, scale, 128);
119 ArrowType.Decimal literalType = new ArrowType.Decimal(2, 1, 128);
120 Field a = Field.nullable("a", decimal);
121
122 ArrowType.Decimal outputType = DecimalTypeUtil.getResultTypeForOperation(DecimalTypeUtil
123 .OperationType.ADD, decimal, literalType);
124 Field retType = Field.nullable("c", outputType);
125 TreeNode field = TreeBuilder.makeField(a);
126 TreeNode literal = TreeBuilder.makeDecimalLiteral("6", 2, 1);
127 List<TreeNode> args = Lists.newArrayList(field, literal);
128 TreeNode root = TreeBuilder.makeFunction("add", args, outputType);
129 ExpressionTree tree = TreeBuilder.makeExpression(root, retType);
130
131 List<ExpressionTree> exprs = Lists.newArrayList(tree);
132
133 Schema schema = new Schema(Lists.newArrayList(a));
134 Projector eval = Projector.make(schema, exprs);
135
136 int numRows = 4;
137 String[] aValues = new String[]{"1", "2", "3", "4"};
138
139 DecimalVector valuesa = decimalVector(aValues, precision, scale);
140 ArrowRecordBatch batch =
141 new ArrowRecordBatch(
142 numRows,
143 Lists.newArrayList(new ArrowFieldNode(numRows, 0)),
144 Lists.newArrayList(valuesa.getValidityBuffer(), valuesa.getDataBuffer()));
145
146 DecimalVector outVector = new DecimalVector("decimal_output", allocator, outputType.getPrecision(),
147 outputType.getScale());
148 outVector.allocateNew(numRows);
149
150 List<ValueVector> output = new ArrayList<ValueVector>();
151 output.add(outVector);
152 eval.evaluate(batch, output);
153
154 BigDecimal[] expOutput = new BigDecimal[]{BigDecimal.valueOf(1.6), BigDecimal.valueOf(2.6),
155 BigDecimal.valueOf(3.6), BigDecimal.valueOf(4.6)};
156
157 for (int i = 0; i < 4; i++) {
158 assertFalse(outVector.isNull(i));
159 assertTrue(expOutput[i].compareTo(outVector.getObject(i)) == 0);
160 }
161
162 // free buffers
163 releaseRecordBatch(batch);
164 releaseValueVectors(output);
165 eval.close();
166 }
167
168 @Test
169 public void test_multiply() throws GandivaException {
170 int precision = 38;
171 int scale = 8;
172 ArrowType.Decimal decimal = new ArrowType.Decimal(precision, scale, 128);
173 Field a = Field.nullable("a", decimal);
174 Field b = Field.nullable("b", decimal);
175 List<Field> args = Lists.newArrayList(a, b);
176
177 ArrowType.Decimal outputType = DecimalTypeUtil.getResultTypeForOperation(DecimalTypeUtil
178 .OperationType.MULTIPLY, decimal, decimal);
179 Field retType = Field.nullable("c", outputType);
180 ExpressionTree root = TreeBuilder.makeExpression("multiply", args, retType);
181
182 List<ExpressionTree> exprs = Lists.newArrayList(root);
183
184 Schema schema = new Schema(args);
185 Projector eval = Projector.make(schema, exprs);
186
187 int numRows = 4;
188 byte[] validity = new byte[]{(byte) 255};
189 String[] aValues = new String[]{"1.12345678", "2.12345678", "3.12345678", "999999999999.99999999"};
190 String[] bValues = new String[]{"2.12345678", "3.12345678", "4.12345678", "999999999999.99999999"};
191
192 DecimalVector valuesa = decimalVector(aValues, precision, scale);
193 DecimalVector valuesb = decimalVector(bValues, precision, scale);
194 ArrowRecordBatch batch =
195 new ArrowRecordBatch(
196 numRows,
197 Lists.newArrayList(new ArrowFieldNode(numRows, 0), new ArrowFieldNode(numRows, 0)),
198 Lists.newArrayList(valuesa.getValidityBuffer(), valuesa.getDataBuffer(),
199 valuesb.getValidityBuffer(), valuesb.getDataBuffer()));
200
201 DecimalVector outVector = new DecimalVector("decimal_output", allocator, outputType.getPrecision(),
202 outputType.getScale());
203 outVector.allocateNew(numRows);
204
205 List<ValueVector> output = new ArrayList<ValueVector>();
206 output.add(outVector);
207 eval.evaluate(batch, output);
208
209 // should have scaled down.
210 BigDecimal[] expOutput = new BigDecimal[]{BigDecimal.valueOf(2.385612),
211 BigDecimal.valueOf(6.632525),
212 BigDecimal.valueOf(12.879439),
213 new BigDecimal("999999999999999999980000.000000")};
214
215 for (int i = 0; i < 4; i++) {
216 assertFalse(outVector.isNull(i));
217 assertTrue("index : " + i + " failed compare", expOutput[i].compareTo(outVector.getObject(i)
218 ) == 0);
219 }
220
221 // free buffers
222 releaseRecordBatch(batch);
223 releaseValueVectors(output);
224 eval.close();
225 }
226
227 @Test
228 public void testCompare() throws GandivaException {
229 Decimal aType = new Decimal(38, 3, 128);
230 Decimal bType = new Decimal(38, 2, 128);
231 Field a = Field.nullable("a", aType);
232 Field b = Field.nullable("b", bType);
233 List<Field> args = Lists.newArrayList(a, b);
234
235 List<ExpressionTree> exprs = new ArrayList<>(
236 Arrays.asList(
237 TreeBuilder.makeExpression("equal", args, Field.nullable("eq", boolType)),
238 TreeBuilder.makeExpression("not_equal", args, Field.nullable("ne", boolType)),
239 TreeBuilder.makeExpression("less_than", args, Field.nullable("lt", boolType)),
240 TreeBuilder.makeExpression("less_than_or_equal_to", args, Field.nullable("le", boolType)),
241 TreeBuilder.makeExpression("greater_than", args, Field.nullable("gt", boolType)),
242 TreeBuilder.makeExpression("greater_than_or_equal_to", args, Field.nullable("ge", boolType))
243 )
244 );
245
246 Schema schema = new Schema(args);
247 Projector eval = Projector.make(schema, exprs);
248
249 List<ValueVector> output = null;
250 ArrowRecordBatch batch = null;
251 try {
252 int numRows = 4;
253 String[] aValues = new String[]{"7.620", "2.380", "3.860", "-18.160"};
254 String[] bValues = new String[]{"7.62", "3.50", "1.90", "-1.45"};
255
256 DecimalVector valuesa = decimalVector(aValues, aType.getPrecision(), aType.getScale());
257 DecimalVector valuesb = decimalVector(bValues, bType.getPrecision(), bType.getScale());
258 batch =
259 new ArrowRecordBatch(
260 numRows,
261 Lists.newArrayList(new ArrowFieldNode(numRows, 0), new ArrowFieldNode(numRows, 0)),
262 Lists.newArrayList(valuesa.getValidityBuffer(), valuesa.getDataBuffer(),
263 valuesb.getValidityBuffer(), valuesb.getDataBuffer()));
264
265 // expected results.
266 boolean[][] expected = {
267 {true, false, false, false}, // eq
268 {false, true, true, true}, // ne
269 {false, true, false, true}, // lt
270 {true, true, false, true}, // le
271 {false, false, true, false}, // gt
272 {true, false, true, false}, // ge
273 };
274
275 // Allocate output vectors.
276 output = new ArrayList<>(
277 Arrays.asList(
278 new BitVector("eq", allocator),
279 new BitVector("ne", allocator),
280 new BitVector("lt", allocator),
281 new BitVector("le", allocator),
282 new BitVector("gt", allocator),
283 new BitVector("ge", allocator)
284 )
285 );
286 for (ValueVector v : output) {
287 v.allocateNew();
288 }
289
290 // evaluate expressions.
291 eval.evaluate(batch, output);
292
293 // compare the outputs.
294 for (int idx = 0; idx < output.size(); ++idx) {
295 boolean[] expectedArray = expected[idx];
296 BitVector resultVector = (BitVector) output.get(idx);
297
298 for (int i = 0; i < numRows; i++) {
299 assertFalse(resultVector.isNull(i));
300 assertEquals("mismatch in result for expr at idx " + idx + " for row " + i,
301 expectedArray[i], resultVector.getObject(i).booleanValue());
302 }
303 }
304 } finally {
305 // free buffers
306 if (batch != null) {
307 releaseRecordBatch(batch);
308 }
309 if (output != null) {
310 releaseValueVectors(output);
311 }
312 eval.close();
313 }
314 }
315
316 @Test
317 public void testRound() throws GandivaException {
318 Decimal aType = new Decimal(38, 2, 128);
319 Decimal aWithScaleZero = new Decimal(38, 0, 128);
320 Decimal aWithScaleOne = new Decimal(38, 1, 128);
321 Field a = Field.nullable("a", aType);
322 List<Field> args = Lists.newArrayList(a);
323
324 List<ExpressionTree> exprs = new ArrayList<>(
325 Arrays.asList(
326 TreeBuilder.makeExpression("abs", args, Field.nullable("abs", aType)),
327 TreeBuilder.makeExpression("ceil", args, Field.nullable("ceil", aWithScaleZero)),
328 TreeBuilder.makeExpression("floor", args, Field.nullable("floor", aWithScaleZero)),
329 TreeBuilder.makeExpression("round", args, Field.nullable("round", aWithScaleZero)),
330 TreeBuilder.makeExpression("truncate", args, Field.nullable("truncate", aWithScaleZero)),
331 TreeBuilder.makeExpression(
332 TreeBuilder.makeFunction("round",
333 Lists.newArrayList(TreeBuilder.makeField(a), TreeBuilder.makeLiteral(1)),
334 aWithScaleOne),
335 Field.nullable("round_scale_1", aWithScaleOne)),
336 TreeBuilder.makeExpression(
337 TreeBuilder.makeFunction("truncate",
338 Lists.newArrayList(TreeBuilder.makeField(a), TreeBuilder.makeLiteral(1)),
339 aWithScaleOne),
340 Field.nullable("truncate_scale_1", aWithScaleOne))
341 )
342 );
343
344 Schema schema = new Schema(args);
345 Projector eval = Projector.make(schema, exprs);
346
347 List<ValueVector> output = null;
348 ArrowRecordBatch batch = null;
349 try {
350 int numRows = 4;
351 String[] aValues = new String[]{"1.23", "1.58", "-1.23", "-1.58"};
352
353 DecimalVector valuesa = decimalVector(aValues, aType.getPrecision(), aType.getScale());
354 batch =
355 new ArrowRecordBatch(
356 numRows,
357 Lists.newArrayList(new ArrowFieldNode(numRows, 0)),
358 Lists.newArrayList(valuesa.getValidityBuffer(), valuesa.getDataBuffer()));
359
360 // expected results.
361 BigDecimal[][] expected = {
362 {BigDecimal.valueOf(1.23), BigDecimal.valueOf(1.58),
363 BigDecimal.valueOf(1.23), BigDecimal.valueOf(1.58)}, // abs
364 {BigDecimal.valueOf(2), BigDecimal.valueOf(2), BigDecimal.valueOf(-1), BigDecimal.valueOf(-1)}, // ceil
365 {BigDecimal.valueOf(1), BigDecimal.valueOf(1), BigDecimal.valueOf(-2), BigDecimal.valueOf(-2)}, // floor
366 {BigDecimal.valueOf(1), BigDecimal.valueOf(2), BigDecimal.valueOf(-1), BigDecimal.valueOf(-2)}, // round
367 {BigDecimal.valueOf(1), BigDecimal.valueOf(1), BigDecimal.valueOf(-1), BigDecimal.valueOf(-1)}, // truncate
368 {BigDecimal.valueOf(1.2), BigDecimal.valueOf(1.6),
369 BigDecimal.valueOf(-1.2), BigDecimal.valueOf(-1.6)}, // round-to-scale-1
370 {BigDecimal.valueOf(1.2), BigDecimal.valueOf(1.5),
371 BigDecimal.valueOf(-1.2), BigDecimal.valueOf(-1.5)}, // truncate-to-scale-1
372 };
373
374 // Allocate output vectors.
375 output = new ArrayList<>(
376 Arrays.asList(
377 new DecimalVector("abs", allocator, aType.getPrecision(), aType.getScale()),
378 new DecimalVector("ceil", allocator, aType.getPrecision(), 0),
379 new DecimalVector("floor", allocator, aType.getPrecision(), 0),
380 new DecimalVector("round", allocator, aType.getPrecision(), 0),
381 new DecimalVector("truncate", allocator, aType.getPrecision(), 0),
382 new DecimalVector("round_to_scale_1", allocator, aType.getPrecision(), 1),
383 new DecimalVector("truncate_to_scale_1", allocator, aType.getPrecision(), 1)
384 )
385 );
386 for (ValueVector v : output) {
387 v.allocateNew();
388 }
389
390 // evaluate expressions.
391 eval.evaluate(batch, output);
392
393 // compare the outputs.
394 for (int idx = 0; idx < output.size(); ++idx) {
395 BigDecimal[] expectedArray = expected[idx];
396 DecimalVector resultVector = (DecimalVector) output.get(idx);
397
398 for (int i = 0; i < numRows; i++) {
399 assertFalse(resultVector.isNull(i));
400 assertTrue("mismatch in result for " +
401 "field " + resultVector.getField().getName() +
402 " for row " + i +
403 " expected " + expectedArray[i] +
404 ", got " + resultVector.getObject(i),
405 expectedArray[i].compareTo(resultVector.getObject(i)) == 0);
406 }
407 }
408 } finally {
409 // free buffers
410 if (batch != null) {
411 releaseRecordBatch(batch);
412 }
413 if (output != null) {
414 releaseValueVectors(output);
415 }
416 eval.close();
417 }
418 }
419
420 @Test
421 public void testCastToDecimal() throws GandivaException {
422 Decimal decimalType = new Decimal(38, 2, 128);
423 Decimal decimalWithScaleOne = new Decimal(38, 1, 128);
424 Field dec = Field.nullable("dec", decimalType);
425 Field int64f = Field.nullable("int64", int64);
426 Field doublef = Field.nullable("float64", float64);
427
428 List<ExpressionTree> exprs = new ArrayList<>(
429 Arrays.asList(
430 TreeBuilder.makeExpression("castDECIMAL",
431 Lists.newArrayList(int64f),
432 Field.nullable("int64_to_dec", decimalType)),
433
434 TreeBuilder.makeExpression("castDECIMAL",
435 Lists.newArrayList(doublef),
436 Field.nullable("float64_to_dec", decimalType)),
437
438 TreeBuilder.makeExpression("castDECIMAL",
439 Lists.newArrayList(dec),
440 Field.nullable("dec_to_dec", decimalWithScaleOne))
441 )
442 );
443
444 Schema schema = new Schema(Lists.newArrayList(int64f, doublef, dec));
445 Projector eval = Projector.make(schema, exprs);
446
447 List<ValueVector> output = null;
448 ArrowRecordBatch batch = null;
449 try {
450 int numRows = 4;
451 String[] aValues = new String[]{"1.23", "1.58", "-1.23", "-1.58"};
452 DecimalVector valuesa = decimalVector(aValues, decimalType.getPrecision(), decimalType.getScale());
453 batch = new ArrowRecordBatch(
454 numRows,
455 Lists.newArrayList(
456 new ArrowFieldNode(numRows, 0),
457 new ArrowFieldNode(numRows, 0),
458 new ArrowFieldNode(numRows, 0)),
459 Lists.newArrayList(
460 arrowBufWithAllValid(4),
461 longBuf(new long[]{123, 158, -123, -158}),
462 arrowBufWithAllValid(4),
463 doubleBuf(new double[]{1.23, 1.58, -1.23, -1.58}),
464 valuesa.getValidityBuffer(),
465 valuesa.getDataBuffer())
466 );
467
468 // Allocate output vectors.
469 output = new ArrayList<>(
470 Arrays.asList(
471 new DecimalVector("int64_to_dec", allocator, decimalType.getPrecision(), decimalType.getScale()),
472 new DecimalVector("float64_to_dec", allocator, decimalType.getPrecision(), decimalType.getScale()),
473 new DecimalVector("dec_to_dec", allocator,
474 decimalWithScaleOne.getPrecision(), decimalWithScaleOne.getScale())
475 )
476 );
477 for (ValueVector v : output) {
478 v.allocateNew();
479 }
480
481 // evaluate expressions.
482 eval.evaluate(batch, output);
483
484 // compare the outputs.
485 BigDecimal[][] expected = {
486 { BigDecimal.valueOf(123), BigDecimal.valueOf(158),
487 BigDecimal.valueOf(-123), BigDecimal.valueOf(-158)},
488 { BigDecimal.valueOf(1.23), BigDecimal.valueOf(1.58),
489 BigDecimal.valueOf(-1.23), BigDecimal.valueOf(-1.58)},
490 { BigDecimal.valueOf(1.2), BigDecimal.valueOf(1.6),
491 BigDecimal.valueOf(-1.2), BigDecimal.valueOf(-1.6)}
492 };
493 for (int idx = 0; idx < output.size(); ++idx) {
494 BigDecimal[] expectedArray = expected[idx];
495 DecimalVector resultVector = (DecimalVector) output.get(idx);
496 for (int i = 0; i < numRows; i++) {
497 assertFalse(resultVector.isNull(i));
498 assertTrue("mismatch in result for " +
499 "field " + resultVector.getField().getName() +
500 " for row " + i +
501 " expected " + expectedArray[i] +
502 ", got " + resultVector.getObject(i),
503 expectedArray[i].compareTo(resultVector.getObject(i)) == 0);
504 }
505 }
506 } finally {
507 // free buffers
508 if (batch != null) {
509 releaseRecordBatch(batch);
510 }
511 if (output != null) {
512 releaseValueVectors(output);
513 }
514 eval.close();
515 }
516 }
517
518 @Test
519 public void testCastToLong() throws GandivaException {
520 Decimal decimalType = new Decimal(38, 2, 128);
521 Field dec = Field.nullable("dec", decimalType);
522
523 Schema schema = new Schema(Lists.newArrayList(dec));
524 Projector eval = Projector.make(schema,
525 Lists.newArrayList(
526 TreeBuilder.makeExpression("castBIGINT",
527 Lists.newArrayList(dec),
528 Field.nullable("dec_to_int64", int64)
529 )
530 )
531 );
532
533 List<ValueVector> output = null;
534 ArrowRecordBatch batch = null;
535 try {
536 int numRows = 5;
537 String[] aValues = new String[]{"1.23", "1.50", "98765.78", "-1.23", "-1.58"};
538 DecimalVector valuesa = decimalVector(aValues, decimalType.getPrecision(), decimalType.getScale());
539 batch = new ArrowRecordBatch(
540 numRows,
541 Lists.newArrayList(
542 new ArrowFieldNode(numRows, 0)
543 ),
544 Lists.newArrayList(
545 valuesa.getValidityBuffer(),
546 valuesa.getDataBuffer()
547 )
548 );
549
550 // Allocate output vectors.
551 BigIntVector resultVector = new BigIntVector("dec_to_int64", allocator);
552 resultVector.allocateNew();
553 output = new ArrayList<>(Arrays.asList(resultVector));
554
555 // evaluate expressions.
556 eval.evaluate(batch, output);
557
558 // compare the outputs.
559 long[] expected = {1, 2, 98766, -1, -2};
560 for (int i = 0; i < numRows; i++) {
561 assertFalse(resultVector.isNull(i));
562 assertEquals(expected[i], resultVector.get(i));
563 }
564 } finally {
565 // free buffers
566 if (batch != null) {
567 releaseRecordBatch(batch);
568 }
569 if (output != null) {
570 releaseValueVectors(output);
571 }
572 eval.close();
573 }
574 }
575
576 @Test
577 public void testCastToDouble() throws GandivaException {
578 Decimal decimalType = new Decimal(38, 2, 128);
579 Field dec = Field.nullable("dec", decimalType);
580
581 Schema schema = new Schema(Lists.newArrayList(dec));
582 Projector eval = Projector.make(schema,
583 Lists.newArrayList(
584 TreeBuilder.makeExpression("castFLOAT8",
585 Lists.newArrayList(dec),
586 Field.nullable("dec_to_float64", float64)
587 )
588 )
589 );
590
591 List<ValueVector> output = null;
592 ArrowRecordBatch batch = null;
593 try {
594 int numRows = 4;
595 String[] aValues = new String[]{"1.23", "1.58", "-1.23", "-1.58"};
596 DecimalVector valuesa = decimalVector(aValues, decimalType.getPrecision(), decimalType.getScale());
597 batch = new ArrowRecordBatch(
598 numRows,
599 Lists.newArrayList(
600 new ArrowFieldNode(numRows, 0)
601 ),
602 Lists.newArrayList(
603 valuesa.getValidityBuffer(),
604 valuesa.getDataBuffer()
605 )
606 );
607
608 // Allocate output vectors.
609 Float8Vector resultVector = new Float8Vector("dec_to_float64", allocator);
610 resultVector.allocateNew();
611 output = new ArrayList<>(Arrays.asList(resultVector));
612
613 // evaluate expressions.
614 eval.evaluate(batch, output);
615
616 // compare the outputs.
617 double[] expected = {1.23, 1.58, -1.23, -1.58};
618 for (int i = 0; i < numRows; i++) {
619 assertFalse(resultVector.isNull(i));
620 assertEquals(expected[i], resultVector.get(i), 0);
621 }
622 } finally {
623 // free buffers
624 if (batch != null) {
625 releaseRecordBatch(batch);
626 }
627 if (output != null) {
628 releaseValueVectors(output);
629 }
630 eval.close();
631 }
632 }
633
634 @Test
635 public void testCastToString() throws GandivaException {
636 Decimal decimalType = new Decimal(38, 2, 128);
637 Field dec = Field.nullable("dec", decimalType);
638 Field str = Field.nullable("str", new ArrowType.Utf8());
639 TreeNode field = TreeBuilder.makeField(dec);
640 TreeNode literal = TreeBuilder.makeLiteral(5L);
641 List<TreeNode> args = Lists.newArrayList(field, literal);
642 TreeNode cast = TreeBuilder.makeFunction("castVARCHAR", args, new ArrowType.Utf8());
643 TreeNode root = TreeBuilder.makeFunction("equal",
644 Lists.newArrayList(cast, TreeBuilder.makeField(str)), new ArrowType.Bool());
645 ExpressionTree tree = TreeBuilder.makeExpression(root, Field.nullable("are_equal", new ArrowType.Bool()));
646
647 Schema schema = new Schema(Lists.newArrayList(dec, str));
648 Projector eval = Projector.make(schema, Lists.newArrayList(tree)
649 );
650
651 List<ValueVector> output = null;
652 ArrowRecordBatch batch = null;
653 try {
654 int numRows = 4;
655 String[] aValues = new String[]{"10.51", "100.23", "-1000.23", "-0000.10"};
656 String[] expected = {"10.51", "100.2", "-1000", "-0.10"};
657 DecimalVector valuesa = decimalVector(aValues, decimalType.getPrecision(), decimalType.getScale());
658 VarCharVector result = varcharVector(expected);
659 batch = new ArrowRecordBatch(
660 numRows,
661 Lists.newArrayList(
662 new ArrowFieldNode(numRows, 0)
663 ),
664 Lists.newArrayList(
665 valuesa.getValidityBuffer(),
666 valuesa.getDataBuffer(),
667 result.getValidityBuffer(),
668 result.getOffsetBuffer(),
669 result.getDataBuffer()
670 )
671 );
672
673 BitVector resultVector = new BitVector("res", allocator);
674 resultVector.allocateNew();
675 output = new ArrayList<>(Arrays.asList(resultVector));
676
677 // evaluate expressions.
678 eval.evaluate(batch, output);
679
680 // compare the outputs.
681 for (int i = 0; i < numRows; i++) {
682 assertTrue(resultVector.getObject(i).booleanValue());
683 }
684 } finally {
685 // free buffers
686 if (batch != null) {
687 releaseRecordBatch(batch);
688 }
689 if (output != null) {
690 releaseValueVectors(output);
691 }
692 eval.close();
693 }
694 }
695
696 @Test
697 public void testCastStringToDecimal() throws GandivaException {
698 Decimal decimalType = new Decimal(4, 2, 128);
699 Field dec = Field.nullable("dec", decimalType);
700
701 Field str = Field.nullable("str", new ArrowType.Utf8());
702 TreeNode field = TreeBuilder.makeField(str);
703 List<TreeNode> args = Lists.newArrayList(field);
704 TreeNode cast = TreeBuilder.makeFunction("castDECIMAL", args, decimalType);
705 ExpressionTree tree = TreeBuilder.makeExpression(cast, Field.nullable("dec_str", decimalType));
706
707 Schema schema = new Schema(Lists.newArrayList(str));
708 Projector eval = Projector.make(schema, Lists.newArrayList(tree)
709 );
710
711 List<ValueVector> output = null;
712 ArrowRecordBatch batch = null;
713 try {
714 int numRows = 4;
715 String[] aValues = new String[]{"10.5134", "-0.1", "10.516", "-1000"};
716 VarCharVector valuesa = varcharVector(aValues);
717 batch = new ArrowRecordBatch(
718 numRows,
719 Lists.newArrayList(
720 new ArrowFieldNode(numRows, 0)
721 ),
722 Lists.newArrayList(
723 valuesa.getValidityBuffer(),
724 valuesa.getOffsetBuffer(),
725 valuesa.getDataBuffer()
726 )
727 );
728
729 DecimalVector resultVector = new DecimalVector("res", allocator,
730 decimalType.getPrecision(), decimalType.getScale());
731 resultVector.allocateNew();
732 output = new ArrayList<>(Arrays.asList(resultVector));
733
734 BigDecimal[] expected = {BigDecimal.valueOf(10.51), BigDecimal.valueOf(-0.10),
735 BigDecimal.valueOf(10.52), BigDecimal.valueOf(0.00)};
736 // evaluate expressions.
737 eval.evaluate(batch, output);
738
739 // compare the outputs.
740 for (int i = 0; i < numRows; i++) {
741 assertTrue("mismatch in result for " +
742 "field " + resultVector.getField().getName() +
743 " for row " + i +
744 " expected " + expected[i] +
745 ", got " + resultVector.getObject(i), expected[i].compareTo(resultVector.getObject(i)) == 0);
746 }
747 } finally {
748 // free buffers
749 if (batch != null) {
750 releaseRecordBatch(batch);
751 }
752 if (output != null) {
753 releaseValueVectors(output);
754 }
755 eval.close();
756 }
757 }
758
759 @Test
760 public void testInvalidDecimal() throws GandivaException {
761 exception.expect(IllegalArgumentException.class);
762 exception.expectMessage("Gandiva only supports decimals of upto 38 precision. Input precision" +
763 " : 0");
764 Decimal decimalType = new Decimal(0, 0, 128);
765 Field int64f = Field.nullable("int64", int64);
766
767 Schema schema = new Schema(Lists.newArrayList(int64f));
768 Projector eval = Projector.make(schema,
769 Lists.newArrayList(
770 TreeBuilder.makeExpression("castDECIMAL",
771 Lists.newArrayList(int64f),
772 Field.nullable("invalid_dec", decimalType)
773 )
774 )
775 );
776 }
777
778 @Test
779 public void testInvalidDecimalGt38() throws GandivaException {
780 exception.expect(IllegalArgumentException.class);
781 exception.expectMessage("Gandiva only supports decimals of upto 38 precision. Input precision" +
782 " : 42");
783 Decimal decimalType = new Decimal(42, 0, 128);
784 Field int64f = Field.nullable("int64", int64);
785
786 Schema schema = new Schema(Lists.newArrayList(int64f));
787 Projector eval = Projector.make(schema,
788 Lists.newArrayList(
789 TreeBuilder.makeExpression("castDECIMAL",
790 Lists.newArrayList(int64f),
791 Field.nullable("invalid_dec", decimalType)
792 )
793 )
794 );
795 }
796}
797