]>
Commit | Line | Data |
---|---|---|
1d09f67e TL |
1 | /* |
2 | * Licensed to the Apache Software Foundation (ASF) under one or more | |
3 | * contributor license agreements. See the NOTICE file distributed with | |
4 | * this work for additional information regarding copyright ownership. | |
5 | * The ASF licenses this file to You under the Apache License, Version 2.0 | |
6 | * (the "License"); you may not use this file except in compliance with | |
7 | * the License. You may obtain a copy of the License at | |
8 | * | |
9 | * http://www.apache.org/licenses/LICENSE-2.0 | |
10 | * | |
11 | * Unless required by applicable law or agreed to in writing, software | |
12 | * distributed under the License is distributed on an "AS IS" BASIS, | |
13 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
14 | * See the License for the specific language governing permissions and | |
15 | * limitations under the License. | |
16 | */ | |
17 | ||
18 | package org.apache.arrow.gandiva.evaluator; | |
19 | ||
20 | import static org.junit.Assert.assertEquals; | |
21 | import static org.junit.Assert.assertFalse; | |
22 | import static org.junit.Assert.assertTrue; | |
23 | ||
24 | import java.math.BigDecimal; | |
25 | import java.util.ArrayList; | |
26 | import java.util.Arrays; | |
27 | import java.util.List; | |
28 | ||
29 | import org.apache.arrow.gandiva.exceptions.GandivaException; | |
30 | import org.apache.arrow.gandiva.expression.ExpressionTree; | |
31 | import org.apache.arrow.gandiva.expression.TreeBuilder; | |
32 | import org.apache.arrow.gandiva.expression.TreeNode; | |
33 | import org.apache.arrow.vector.BigIntVector; | |
34 | import org.apache.arrow.vector.BitVector; | |
35 | import org.apache.arrow.vector.DecimalVector; | |
36 | import org.apache.arrow.vector.Float8Vector; | |
37 | import org.apache.arrow.vector.ValueVector; | |
38 | import org.apache.arrow.vector.VarCharVector; | |
39 | import org.apache.arrow.vector.ipc.message.ArrowFieldNode; | |
40 | import org.apache.arrow.vector.ipc.message.ArrowRecordBatch; | |
41 | import org.apache.arrow.vector.types.pojo.ArrowType; | |
42 | import org.apache.arrow.vector.types.pojo.ArrowType.Decimal; | |
43 | import org.apache.arrow.vector.types.pojo.Field; | |
44 | import org.apache.arrow.vector.types.pojo.Schema; | |
45 | import org.junit.Rule; | |
46 | import org.junit.Test; | |
47 | import org.junit.rules.ExpectedException; | |
48 | ||
49 | import com.google.common.collect.Lists; | |
50 | ||
51 | public class ProjectorDecimalTest extends org.apache.arrow.gandiva.evaluator.BaseEvaluatorTest { | |
52 | @Rule | |
53 | public ExpectedException exception = ExpectedException.none(); | |
54 | ||
55 | @Test | |
56 | public void test_add() throws GandivaException { | |
57 | int precision = 38; | |
58 | int scale = 8; | |
59 | ArrowType.Decimal decimal = new ArrowType.Decimal(precision, scale, 128); | |
60 | Field a = Field.nullable("a", decimal); | |
61 | Field b = Field.nullable("b", decimal); | |
62 | List<Field> args = Lists.newArrayList(a, b); | |
63 | ||
64 | ArrowType.Decimal outputType = DecimalTypeUtil.getResultTypeForOperation(DecimalTypeUtil | |
65 | .OperationType.ADD, decimal, decimal); | |
66 | Field retType = Field.nullable("c", outputType); | |
67 | ExpressionTree root = TreeBuilder.makeExpression("add", args, retType); | |
68 | ||
69 | List<ExpressionTree> exprs = Lists.newArrayList(root); | |
70 | ||
71 | Schema schema = new Schema(args); | |
72 | Projector eval = Projector.make(schema, exprs); | |
73 | ||
74 | int numRows = 4; | |
75 | byte[] validity = new byte[]{(byte) 255}; | |
76 | String[] aValues = new String[]{"1.12345678", "2.12345678", "3.12345678", "4.12345678"}; | |
77 | String[] bValues = new String[]{"2.12345678", "3.12345678", "4.12345678", "5.12345678"}; | |
78 | ||
79 | DecimalVector valuesa = decimalVector(aValues, precision, scale); | |
80 | DecimalVector valuesb = decimalVector(bValues, precision, scale); | |
81 | ArrowRecordBatch batch = | |
82 | new ArrowRecordBatch( | |
83 | numRows, | |
84 | Lists.newArrayList(new ArrowFieldNode(numRows, 0), new ArrowFieldNode(numRows, 0)), | |
85 | Lists.newArrayList(valuesa.getValidityBuffer(), valuesa.getDataBuffer(), | |
86 | valuesb.getValidityBuffer(), valuesb.getDataBuffer())); | |
87 | ||
88 | DecimalVector outVector = new DecimalVector("decimal_output", allocator, outputType.getPrecision(), | |
89 | outputType.getScale()); | |
90 | outVector.allocateNew(numRows); | |
91 | ||
92 | List<ValueVector> output = new ArrayList<ValueVector>(); | |
93 | output.add(outVector); | |
94 | eval.evaluate(batch, output); | |
95 | ||
96 | // should have scaled down. | |
97 | BigDecimal[] expOutput = new BigDecimal[]{BigDecimal.valueOf(3.2469136), | |
98 | BigDecimal.valueOf(5.2469136), | |
99 | BigDecimal.valueOf(7.2469136), | |
100 | BigDecimal.valueOf(9.2469136)}; | |
101 | ||
102 | for (int i = 0; i < 4; i++) { | |
103 | assertFalse(outVector.isNull(i)); | |
104 | assertTrue("index : " + i + " failed compare", expOutput[i].compareTo(outVector.getObject(i) | |
105 | ) == 0); | |
106 | } | |
107 | ||
108 | // free buffers | |
109 | releaseRecordBatch(batch); | |
110 | releaseValueVectors(output); | |
111 | eval.close(); | |
112 | } | |
113 | ||
114 | @Test | |
115 | public void test_add_literal() throws GandivaException { | |
116 | int precision = 2; | |
117 | int scale = 0; | |
118 | ArrowType.Decimal decimal = new ArrowType.Decimal(precision, scale, 128); | |
119 | ArrowType.Decimal literalType = new ArrowType.Decimal(2, 1, 128); | |
120 | Field a = Field.nullable("a", decimal); | |
121 | ||
122 | ArrowType.Decimal outputType = DecimalTypeUtil.getResultTypeForOperation(DecimalTypeUtil | |
123 | .OperationType.ADD, decimal, literalType); | |
124 | Field retType = Field.nullable("c", outputType); | |
125 | TreeNode field = TreeBuilder.makeField(a); | |
126 | TreeNode literal = TreeBuilder.makeDecimalLiteral("6", 2, 1); | |
127 | List<TreeNode> args = Lists.newArrayList(field, literal); | |
128 | TreeNode root = TreeBuilder.makeFunction("add", args, outputType); | |
129 | ExpressionTree tree = TreeBuilder.makeExpression(root, retType); | |
130 | ||
131 | List<ExpressionTree> exprs = Lists.newArrayList(tree); | |
132 | ||
133 | Schema schema = new Schema(Lists.newArrayList(a)); | |
134 | Projector eval = Projector.make(schema, exprs); | |
135 | ||
136 | int numRows = 4; | |
137 | String[] aValues = new String[]{"1", "2", "3", "4"}; | |
138 | ||
139 | DecimalVector valuesa = decimalVector(aValues, precision, scale); | |
140 | ArrowRecordBatch batch = | |
141 | new ArrowRecordBatch( | |
142 | numRows, | |
143 | Lists.newArrayList(new ArrowFieldNode(numRows, 0)), | |
144 | Lists.newArrayList(valuesa.getValidityBuffer(), valuesa.getDataBuffer())); | |
145 | ||
146 | DecimalVector outVector = new DecimalVector("decimal_output", allocator, outputType.getPrecision(), | |
147 | outputType.getScale()); | |
148 | outVector.allocateNew(numRows); | |
149 | ||
150 | List<ValueVector> output = new ArrayList<ValueVector>(); | |
151 | output.add(outVector); | |
152 | eval.evaluate(batch, output); | |
153 | ||
154 | BigDecimal[] expOutput = new BigDecimal[]{BigDecimal.valueOf(1.6), BigDecimal.valueOf(2.6), | |
155 | BigDecimal.valueOf(3.6), BigDecimal.valueOf(4.6)}; | |
156 | ||
157 | for (int i = 0; i < 4; i++) { | |
158 | assertFalse(outVector.isNull(i)); | |
159 | assertTrue(expOutput[i].compareTo(outVector.getObject(i)) == 0); | |
160 | } | |
161 | ||
162 | // free buffers | |
163 | releaseRecordBatch(batch); | |
164 | releaseValueVectors(output); | |
165 | eval.close(); | |
166 | } | |
167 | ||
168 | @Test | |
169 | public void test_multiply() throws GandivaException { | |
170 | int precision = 38; | |
171 | int scale = 8; | |
172 | ArrowType.Decimal decimal = new ArrowType.Decimal(precision, scale, 128); | |
173 | Field a = Field.nullable("a", decimal); | |
174 | Field b = Field.nullable("b", decimal); | |
175 | List<Field> args = Lists.newArrayList(a, b); | |
176 | ||
177 | ArrowType.Decimal outputType = DecimalTypeUtil.getResultTypeForOperation(DecimalTypeUtil | |
178 | .OperationType.MULTIPLY, decimal, decimal); | |
179 | Field retType = Field.nullable("c", outputType); | |
180 | ExpressionTree root = TreeBuilder.makeExpression("multiply", args, retType); | |
181 | ||
182 | List<ExpressionTree> exprs = Lists.newArrayList(root); | |
183 | ||
184 | Schema schema = new Schema(args); | |
185 | Projector eval = Projector.make(schema, exprs); | |
186 | ||
187 | int numRows = 4; | |
188 | byte[] validity = new byte[]{(byte) 255}; | |
189 | String[] aValues = new String[]{"1.12345678", "2.12345678", "3.12345678", "999999999999.99999999"}; | |
190 | String[] bValues = new String[]{"2.12345678", "3.12345678", "4.12345678", "999999999999.99999999"}; | |
191 | ||
192 | DecimalVector valuesa = decimalVector(aValues, precision, scale); | |
193 | DecimalVector valuesb = decimalVector(bValues, precision, scale); | |
194 | ArrowRecordBatch batch = | |
195 | new ArrowRecordBatch( | |
196 | numRows, | |
197 | Lists.newArrayList(new ArrowFieldNode(numRows, 0), new ArrowFieldNode(numRows, 0)), | |
198 | Lists.newArrayList(valuesa.getValidityBuffer(), valuesa.getDataBuffer(), | |
199 | valuesb.getValidityBuffer(), valuesb.getDataBuffer())); | |
200 | ||
201 | DecimalVector outVector = new DecimalVector("decimal_output", allocator, outputType.getPrecision(), | |
202 | outputType.getScale()); | |
203 | outVector.allocateNew(numRows); | |
204 | ||
205 | List<ValueVector> output = new ArrayList<ValueVector>(); | |
206 | output.add(outVector); | |
207 | eval.evaluate(batch, output); | |
208 | ||
209 | // should have scaled down. | |
210 | BigDecimal[] expOutput = new BigDecimal[]{BigDecimal.valueOf(2.385612), | |
211 | BigDecimal.valueOf(6.632525), | |
212 | BigDecimal.valueOf(12.879439), | |
213 | new BigDecimal("999999999999999999980000.000000")}; | |
214 | ||
215 | for (int i = 0; i < 4; i++) { | |
216 | assertFalse(outVector.isNull(i)); | |
217 | assertTrue("index : " + i + " failed compare", expOutput[i].compareTo(outVector.getObject(i) | |
218 | ) == 0); | |
219 | } | |
220 | ||
221 | // free buffers | |
222 | releaseRecordBatch(batch); | |
223 | releaseValueVectors(output); | |
224 | eval.close(); | |
225 | } | |
226 | ||
227 | @Test | |
228 | public void testCompare() throws GandivaException { | |
229 | Decimal aType = new Decimal(38, 3, 128); | |
230 | Decimal bType = new Decimal(38, 2, 128); | |
231 | Field a = Field.nullable("a", aType); | |
232 | Field b = Field.nullable("b", bType); | |
233 | List<Field> args = Lists.newArrayList(a, b); | |
234 | ||
235 | List<ExpressionTree> exprs = new ArrayList<>( | |
236 | Arrays.asList( | |
237 | TreeBuilder.makeExpression("equal", args, Field.nullable("eq", boolType)), | |
238 | TreeBuilder.makeExpression("not_equal", args, Field.nullable("ne", boolType)), | |
239 | TreeBuilder.makeExpression("less_than", args, Field.nullable("lt", boolType)), | |
240 | TreeBuilder.makeExpression("less_than_or_equal_to", args, Field.nullable("le", boolType)), | |
241 | TreeBuilder.makeExpression("greater_than", args, Field.nullable("gt", boolType)), | |
242 | TreeBuilder.makeExpression("greater_than_or_equal_to", args, Field.nullable("ge", boolType)) | |
243 | ) | |
244 | ); | |
245 | ||
246 | Schema schema = new Schema(args); | |
247 | Projector eval = Projector.make(schema, exprs); | |
248 | ||
249 | List<ValueVector> output = null; | |
250 | ArrowRecordBatch batch = null; | |
251 | try { | |
252 | int numRows = 4; | |
253 | String[] aValues = new String[]{"7.620", "2.380", "3.860", "-18.160"}; | |
254 | String[] bValues = new String[]{"7.62", "3.50", "1.90", "-1.45"}; | |
255 | ||
256 | DecimalVector valuesa = decimalVector(aValues, aType.getPrecision(), aType.getScale()); | |
257 | DecimalVector valuesb = decimalVector(bValues, bType.getPrecision(), bType.getScale()); | |
258 | batch = | |
259 | new ArrowRecordBatch( | |
260 | numRows, | |
261 | Lists.newArrayList(new ArrowFieldNode(numRows, 0), new ArrowFieldNode(numRows, 0)), | |
262 | Lists.newArrayList(valuesa.getValidityBuffer(), valuesa.getDataBuffer(), | |
263 | valuesb.getValidityBuffer(), valuesb.getDataBuffer())); | |
264 | ||
265 | // expected results. | |
266 | boolean[][] expected = { | |
267 | {true, false, false, false}, // eq | |
268 | {false, true, true, true}, // ne | |
269 | {false, true, false, true}, // lt | |
270 | {true, true, false, true}, // le | |
271 | {false, false, true, false}, // gt | |
272 | {true, false, true, false}, // ge | |
273 | }; | |
274 | ||
275 | // Allocate output vectors. | |
276 | output = new ArrayList<>( | |
277 | Arrays.asList( | |
278 | new BitVector("eq", allocator), | |
279 | new BitVector("ne", allocator), | |
280 | new BitVector("lt", allocator), | |
281 | new BitVector("le", allocator), | |
282 | new BitVector("gt", allocator), | |
283 | new BitVector("ge", allocator) | |
284 | ) | |
285 | ); | |
286 | for (ValueVector v : output) { | |
287 | v.allocateNew(); | |
288 | } | |
289 | ||
290 | // evaluate expressions. | |
291 | eval.evaluate(batch, output); | |
292 | ||
293 | // compare the outputs. | |
294 | for (int idx = 0; idx < output.size(); ++idx) { | |
295 | boolean[] expectedArray = expected[idx]; | |
296 | BitVector resultVector = (BitVector) output.get(idx); | |
297 | ||
298 | for (int i = 0; i < numRows; i++) { | |
299 | assertFalse(resultVector.isNull(i)); | |
300 | assertEquals("mismatch in result for expr at idx " + idx + " for row " + i, | |
301 | expectedArray[i], resultVector.getObject(i).booleanValue()); | |
302 | } | |
303 | } | |
304 | } finally { | |
305 | // free buffers | |
306 | if (batch != null) { | |
307 | releaseRecordBatch(batch); | |
308 | } | |
309 | if (output != null) { | |
310 | releaseValueVectors(output); | |
311 | } | |
312 | eval.close(); | |
313 | } | |
314 | } | |
315 | ||
316 | @Test | |
317 | public void testRound() throws GandivaException { | |
318 | Decimal aType = new Decimal(38, 2, 128); | |
319 | Decimal aWithScaleZero = new Decimal(38, 0, 128); | |
320 | Decimal aWithScaleOne = new Decimal(38, 1, 128); | |
321 | Field a = Field.nullable("a", aType); | |
322 | List<Field> args = Lists.newArrayList(a); | |
323 | ||
324 | List<ExpressionTree> exprs = new ArrayList<>( | |
325 | Arrays.asList( | |
326 | TreeBuilder.makeExpression("abs", args, Field.nullable("abs", aType)), | |
327 | TreeBuilder.makeExpression("ceil", args, Field.nullable("ceil", aWithScaleZero)), | |
328 | TreeBuilder.makeExpression("floor", args, Field.nullable("floor", aWithScaleZero)), | |
329 | TreeBuilder.makeExpression("round", args, Field.nullable("round", aWithScaleZero)), | |
330 | TreeBuilder.makeExpression("truncate", args, Field.nullable("truncate", aWithScaleZero)), | |
331 | TreeBuilder.makeExpression( | |
332 | TreeBuilder.makeFunction("round", | |
333 | Lists.newArrayList(TreeBuilder.makeField(a), TreeBuilder.makeLiteral(1)), | |
334 | aWithScaleOne), | |
335 | Field.nullable("round_scale_1", aWithScaleOne)), | |
336 | TreeBuilder.makeExpression( | |
337 | TreeBuilder.makeFunction("truncate", | |
338 | Lists.newArrayList(TreeBuilder.makeField(a), TreeBuilder.makeLiteral(1)), | |
339 | aWithScaleOne), | |
340 | Field.nullable("truncate_scale_1", aWithScaleOne)) | |
341 | ) | |
342 | ); | |
343 | ||
344 | Schema schema = new Schema(args); | |
345 | Projector eval = Projector.make(schema, exprs); | |
346 | ||
347 | List<ValueVector> output = null; | |
348 | ArrowRecordBatch batch = null; | |
349 | try { | |
350 | int numRows = 4; | |
351 | String[] aValues = new String[]{"1.23", "1.58", "-1.23", "-1.58"}; | |
352 | ||
353 | DecimalVector valuesa = decimalVector(aValues, aType.getPrecision(), aType.getScale()); | |
354 | batch = | |
355 | new ArrowRecordBatch( | |
356 | numRows, | |
357 | Lists.newArrayList(new ArrowFieldNode(numRows, 0)), | |
358 | Lists.newArrayList(valuesa.getValidityBuffer(), valuesa.getDataBuffer())); | |
359 | ||
360 | // expected results. | |
361 | BigDecimal[][] expected = { | |
362 | {BigDecimal.valueOf(1.23), BigDecimal.valueOf(1.58), | |
363 | BigDecimal.valueOf(1.23), BigDecimal.valueOf(1.58)}, // abs | |
364 | {BigDecimal.valueOf(2), BigDecimal.valueOf(2), BigDecimal.valueOf(-1), BigDecimal.valueOf(-1)}, // ceil | |
365 | {BigDecimal.valueOf(1), BigDecimal.valueOf(1), BigDecimal.valueOf(-2), BigDecimal.valueOf(-2)}, // floor | |
366 | {BigDecimal.valueOf(1), BigDecimal.valueOf(2), BigDecimal.valueOf(-1), BigDecimal.valueOf(-2)}, // round | |
367 | {BigDecimal.valueOf(1), BigDecimal.valueOf(1), BigDecimal.valueOf(-1), BigDecimal.valueOf(-1)}, // truncate | |
368 | {BigDecimal.valueOf(1.2), BigDecimal.valueOf(1.6), | |
369 | BigDecimal.valueOf(-1.2), BigDecimal.valueOf(-1.6)}, // round-to-scale-1 | |
370 | {BigDecimal.valueOf(1.2), BigDecimal.valueOf(1.5), | |
371 | BigDecimal.valueOf(-1.2), BigDecimal.valueOf(-1.5)}, // truncate-to-scale-1 | |
372 | }; | |
373 | ||
374 | // Allocate output vectors. | |
375 | output = new ArrayList<>( | |
376 | Arrays.asList( | |
377 | new DecimalVector("abs", allocator, aType.getPrecision(), aType.getScale()), | |
378 | new DecimalVector("ceil", allocator, aType.getPrecision(), 0), | |
379 | new DecimalVector("floor", allocator, aType.getPrecision(), 0), | |
380 | new DecimalVector("round", allocator, aType.getPrecision(), 0), | |
381 | new DecimalVector("truncate", allocator, aType.getPrecision(), 0), | |
382 | new DecimalVector("round_to_scale_1", allocator, aType.getPrecision(), 1), | |
383 | new DecimalVector("truncate_to_scale_1", allocator, aType.getPrecision(), 1) | |
384 | ) | |
385 | ); | |
386 | for (ValueVector v : output) { | |
387 | v.allocateNew(); | |
388 | } | |
389 | ||
390 | // evaluate expressions. | |
391 | eval.evaluate(batch, output); | |
392 | ||
393 | // compare the outputs. | |
394 | for (int idx = 0; idx < output.size(); ++idx) { | |
395 | BigDecimal[] expectedArray = expected[idx]; | |
396 | DecimalVector resultVector = (DecimalVector) output.get(idx); | |
397 | ||
398 | for (int i = 0; i < numRows; i++) { | |
399 | assertFalse(resultVector.isNull(i)); | |
400 | assertTrue("mismatch in result for " + | |
401 | "field " + resultVector.getField().getName() + | |
402 | " for row " + i + | |
403 | " expected " + expectedArray[i] + | |
404 | ", got " + resultVector.getObject(i), | |
405 | expectedArray[i].compareTo(resultVector.getObject(i)) == 0); | |
406 | } | |
407 | } | |
408 | } finally { | |
409 | // free buffers | |
410 | if (batch != null) { | |
411 | releaseRecordBatch(batch); | |
412 | } | |
413 | if (output != null) { | |
414 | releaseValueVectors(output); | |
415 | } | |
416 | eval.close(); | |
417 | } | |
418 | } | |
419 | ||
420 | @Test | |
421 | public void testCastToDecimal() throws GandivaException { | |
422 | Decimal decimalType = new Decimal(38, 2, 128); | |
423 | Decimal decimalWithScaleOne = new Decimal(38, 1, 128); | |
424 | Field dec = Field.nullable("dec", decimalType); | |
425 | Field int64f = Field.nullable("int64", int64); | |
426 | Field doublef = Field.nullable("float64", float64); | |
427 | ||
428 | List<ExpressionTree> exprs = new ArrayList<>( | |
429 | Arrays.asList( | |
430 | TreeBuilder.makeExpression("castDECIMAL", | |
431 | Lists.newArrayList(int64f), | |
432 | Field.nullable("int64_to_dec", decimalType)), | |
433 | ||
434 | TreeBuilder.makeExpression("castDECIMAL", | |
435 | Lists.newArrayList(doublef), | |
436 | Field.nullable("float64_to_dec", decimalType)), | |
437 | ||
438 | TreeBuilder.makeExpression("castDECIMAL", | |
439 | Lists.newArrayList(dec), | |
440 | Field.nullable("dec_to_dec", decimalWithScaleOne)) | |
441 | ) | |
442 | ); | |
443 | ||
444 | Schema schema = new Schema(Lists.newArrayList(int64f, doublef, dec)); | |
445 | Projector eval = Projector.make(schema, exprs); | |
446 | ||
447 | List<ValueVector> output = null; | |
448 | ArrowRecordBatch batch = null; | |
449 | try { | |
450 | int numRows = 4; | |
451 | String[] aValues = new String[]{"1.23", "1.58", "-1.23", "-1.58"}; | |
452 | DecimalVector valuesa = decimalVector(aValues, decimalType.getPrecision(), decimalType.getScale()); | |
453 | batch = new ArrowRecordBatch( | |
454 | numRows, | |
455 | Lists.newArrayList( | |
456 | new ArrowFieldNode(numRows, 0), | |
457 | new ArrowFieldNode(numRows, 0), | |
458 | new ArrowFieldNode(numRows, 0)), | |
459 | Lists.newArrayList( | |
460 | arrowBufWithAllValid(4), | |
461 | longBuf(new long[]{123, 158, -123, -158}), | |
462 | arrowBufWithAllValid(4), | |
463 | doubleBuf(new double[]{1.23, 1.58, -1.23, -1.58}), | |
464 | valuesa.getValidityBuffer(), | |
465 | valuesa.getDataBuffer()) | |
466 | ); | |
467 | ||
468 | // Allocate output vectors. | |
469 | output = new ArrayList<>( | |
470 | Arrays.asList( | |
471 | new DecimalVector("int64_to_dec", allocator, decimalType.getPrecision(), decimalType.getScale()), | |
472 | new DecimalVector("float64_to_dec", allocator, decimalType.getPrecision(), decimalType.getScale()), | |
473 | new DecimalVector("dec_to_dec", allocator, | |
474 | decimalWithScaleOne.getPrecision(), decimalWithScaleOne.getScale()) | |
475 | ) | |
476 | ); | |
477 | for (ValueVector v : output) { | |
478 | v.allocateNew(); | |
479 | } | |
480 | ||
481 | // evaluate expressions. | |
482 | eval.evaluate(batch, output); | |
483 | ||
484 | // compare the outputs. | |
485 | BigDecimal[][] expected = { | |
486 | { BigDecimal.valueOf(123), BigDecimal.valueOf(158), | |
487 | BigDecimal.valueOf(-123), BigDecimal.valueOf(-158)}, | |
488 | { BigDecimal.valueOf(1.23), BigDecimal.valueOf(1.58), | |
489 | BigDecimal.valueOf(-1.23), BigDecimal.valueOf(-1.58)}, | |
490 | { BigDecimal.valueOf(1.2), BigDecimal.valueOf(1.6), | |
491 | BigDecimal.valueOf(-1.2), BigDecimal.valueOf(-1.6)} | |
492 | }; | |
493 | for (int idx = 0; idx < output.size(); ++idx) { | |
494 | BigDecimal[] expectedArray = expected[idx]; | |
495 | DecimalVector resultVector = (DecimalVector) output.get(idx); | |
496 | for (int i = 0; i < numRows; i++) { | |
497 | assertFalse(resultVector.isNull(i)); | |
498 | assertTrue("mismatch in result for " + | |
499 | "field " + resultVector.getField().getName() + | |
500 | " for row " + i + | |
501 | " expected " + expectedArray[i] + | |
502 | ", got " + resultVector.getObject(i), | |
503 | expectedArray[i].compareTo(resultVector.getObject(i)) == 0); | |
504 | } | |
505 | } | |
506 | } finally { | |
507 | // free buffers | |
508 | if (batch != null) { | |
509 | releaseRecordBatch(batch); | |
510 | } | |
511 | if (output != null) { | |
512 | releaseValueVectors(output); | |
513 | } | |
514 | eval.close(); | |
515 | } | |
516 | } | |
517 | ||
518 | @Test | |
519 | public void testCastToLong() throws GandivaException { | |
520 | Decimal decimalType = new Decimal(38, 2, 128); | |
521 | Field dec = Field.nullable("dec", decimalType); | |
522 | ||
523 | Schema schema = new Schema(Lists.newArrayList(dec)); | |
524 | Projector eval = Projector.make(schema, | |
525 | Lists.newArrayList( | |
526 | TreeBuilder.makeExpression("castBIGINT", | |
527 | Lists.newArrayList(dec), | |
528 | Field.nullable("dec_to_int64", int64) | |
529 | ) | |
530 | ) | |
531 | ); | |
532 | ||
533 | List<ValueVector> output = null; | |
534 | ArrowRecordBatch batch = null; | |
535 | try { | |
536 | int numRows = 5; | |
537 | String[] aValues = new String[]{"1.23", "1.50", "98765.78", "-1.23", "-1.58"}; | |
538 | DecimalVector valuesa = decimalVector(aValues, decimalType.getPrecision(), decimalType.getScale()); | |
539 | batch = new ArrowRecordBatch( | |
540 | numRows, | |
541 | Lists.newArrayList( | |
542 | new ArrowFieldNode(numRows, 0) | |
543 | ), | |
544 | Lists.newArrayList( | |
545 | valuesa.getValidityBuffer(), | |
546 | valuesa.getDataBuffer() | |
547 | ) | |
548 | ); | |
549 | ||
550 | // Allocate output vectors. | |
551 | BigIntVector resultVector = new BigIntVector("dec_to_int64", allocator); | |
552 | resultVector.allocateNew(); | |
553 | output = new ArrayList<>(Arrays.asList(resultVector)); | |
554 | ||
555 | // evaluate expressions. | |
556 | eval.evaluate(batch, output); | |
557 | ||
558 | // compare the outputs. | |
559 | long[] expected = {1, 2, 98766, -1, -2}; | |
560 | for (int i = 0; i < numRows; i++) { | |
561 | assertFalse(resultVector.isNull(i)); | |
562 | assertEquals(expected[i], resultVector.get(i)); | |
563 | } | |
564 | } finally { | |
565 | // free buffers | |
566 | if (batch != null) { | |
567 | releaseRecordBatch(batch); | |
568 | } | |
569 | if (output != null) { | |
570 | releaseValueVectors(output); | |
571 | } | |
572 | eval.close(); | |
573 | } | |
574 | } | |
575 | ||
576 | @Test | |
577 | public void testCastToDouble() throws GandivaException { | |
578 | Decimal decimalType = new Decimal(38, 2, 128); | |
579 | Field dec = Field.nullable("dec", decimalType); | |
580 | ||
581 | Schema schema = new Schema(Lists.newArrayList(dec)); | |
582 | Projector eval = Projector.make(schema, | |
583 | Lists.newArrayList( | |
584 | TreeBuilder.makeExpression("castFLOAT8", | |
585 | Lists.newArrayList(dec), | |
586 | Field.nullable("dec_to_float64", float64) | |
587 | ) | |
588 | ) | |
589 | ); | |
590 | ||
591 | List<ValueVector> output = null; | |
592 | ArrowRecordBatch batch = null; | |
593 | try { | |
594 | int numRows = 4; | |
595 | String[] aValues = new String[]{"1.23", "1.58", "-1.23", "-1.58"}; | |
596 | DecimalVector valuesa = decimalVector(aValues, decimalType.getPrecision(), decimalType.getScale()); | |
597 | batch = new ArrowRecordBatch( | |
598 | numRows, | |
599 | Lists.newArrayList( | |
600 | new ArrowFieldNode(numRows, 0) | |
601 | ), | |
602 | Lists.newArrayList( | |
603 | valuesa.getValidityBuffer(), | |
604 | valuesa.getDataBuffer() | |
605 | ) | |
606 | ); | |
607 | ||
608 | // Allocate output vectors. | |
609 | Float8Vector resultVector = new Float8Vector("dec_to_float64", allocator); | |
610 | resultVector.allocateNew(); | |
611 | output = new ArrayList<>(Arrays.asList(resultVector)); | |
612 | ||
613 | // evaluate expressions. | |
614 | eval.evaluate(batch, output); | |
615 | ||
616 | // compare the outputs. | |
617 | double[] expected = {1.23, 1.58, -1.23, -1.58}; | |
618 | for (int i = 0; i < numRows; i++) { | |
619 | assertFalse(resultVector.isNull(i)); | |
620 | assertEquals(expected[i], resultVector.get(i), 0); | |
621 | } | |
622 | } finally { | |
623 | // free buffers | |
624 | if (batch != null) { | |
625 | releaseRecordBatch(batch); | |
626 | } | |
627 | if (output != null) { | |
628 | releaseValueVectors(output); | |
629 | } | |
630 | eval.close(); | |
631 | } | |
632 | } | |
633 | ||
634 | @Test | |
635 | public void testCastToString() throws GandivaException { | |
636 | Decimal decimalType = new Decimal(38, 2, 128); | |
637 | Field dec = Field.nullable("dec", decimalType); | |
638 | Field str = Field.nullable("str", new ArrowType.Utf8()); | |
639 | TreeNode field = TreeBuilder.makeField(dec); | |
640 | TreeNode literal = TreeBuilder.makeLiteral(5L); | |
641 | List<TreeNode> args = Lists.newArrayList(field, literal); | |
642 | TreeNode cast = TreeBuilder.makeFunction("castVARCHAR", args, new ArrowType.Utf8()); | |
643 | TreeNode root = TreeBuilder.makeFunction("equal", | |
644 | Lists.newArrayList(cast, TreeBuilder.makeField(str)), new ArrowType.Bool()); | |
645 | ExpressionTree tree = TreeBuilder.makeExpression(root, Field.nullable("are_equal", new ArrowType.Bool())); | |
646 | ||
647 | Schema schema = new Schema(Lists.newArrayList(dec, str)); | |
648 | Projector eval = Projector.make(schema, Lists.newArrayList(tree) | |
649 | ); | |
650 | ||
651 | List<ValueVector> output = null; | |
652 | ArrowRecordBatch batch = null; | |
653 | try { | |
654 | int numRows = 4; | |
655 | String[] aValues = new String[]{"10.51", "100.23", "-1000.23", "-0000.10"}; | |
656 | String[] expected = {"10.51", "100.2", "-1000", "-0.10"}; | |
657 | DecimalVector valuesa = decimalVector(aValues, decimalType.getPrecision(), decimalType.getScale()); | |
658 | VarCharVector result = varcharVector(expected); | |
659 | batch = new ArrowRecordBatch( | |
660 | numRows, | |
661 | Lists.newArrayList( | |
662 | new ArrowFieldNode(numRows, 0) | |
663 | ), | |
664 | Lists.newArrayList( | |
665 | valuesa.getValidityBuffer(), | |
666 | valuesa.getDataBuffer(), | |
667 | result.getValidityBuffer(), | |
668 | result.getOffsetBuffer(), | |
669 | result.getDataBuffer() | |
670 | ) | |
671 | ); | |
672 | ||
673 | BitVector resultVector = new BitVector("res", allocator); | |
674 | resultVector.allocateNew(); | |
675 | output = new ArrayList<>(Arrays.asList(resultVector)); | |
676 | ||
677 | // evaluate expressions. | |
678 | eval.evaluate(batch, output); | |
679 | ||
680 | // compare the outputs. | |
681 | for (int i = 0; i < numRows; i++) { | |
682 | assertTrue(resultVector.getObject(i).booleanValue()); | |
683 | } | |
684 | } finally { | |
685 | // free buffers | |
686 | if (batch != null) { | |
687 | releaseRecordBatch(batch); | |
688 | } | |
689 | if (output != null) { | |
690 | releaseValueVectors(output); | |
691 | } | |
692 | eval.close(); | |
693 | } | |
694 | } | |
695 | ||
696 | @Test | |
697 | public void testCastStringToDecimal() throws GandivaException { | |
698 | Decimal decimalType = new Decimal(4, 2, 128); | |
699 | Field dec = Field.nullable("dec", decimalType); | |
700 | ||
701 | Field str = Field.nullable("str", new ArrowType.Utf8()); | |
702 | TreeNode field = TreeBuilder.makeField(str); | |
703 | List<TreeNode> args = Lists.newArrayList(field); | |
704 | TreeNode cast = TreeBuilder.makeFunction("castDECIMAL", args, decimalType); | |
705 | ExpressionTree tree = TreeBuilder.makeExpression(cast, Field.nullable("dec_str", decimalType)); | |
706 | ||
707 | Schema schema = new Schema(Lists.newArrayList(str)); | |
708 | Projector eval = Projector.make(schema, Lists.newArrayList(tree) | |
709 | ); | |
710 | ||
711 | List<ValueVector> output = null; | |
712 | ArrowRecordBatch batch = null; | |
713 | try { | |
714 | int numRows = 4; | |
715 | String[] aValues = new String[]{"10.5134", "-0.1", "10.516", "-1000"}; | |
716 | VarCharVector valuesa = varcharVector(aValues); | |
717 | batch = new ArrowRecordBatch( | |
718 | numRows, | |
719 | Lists.newArrayList( | |
720 | new ArrowFieldNode(numRows, 0) | |
721 | ), | |
722 | Lists.newArrayList( | |
723 | valuesa.getValidityBuffer(), | |
724 | valuesa.getOffsetBuffer(), | |
725 | valuesa.getDataBuffer() | |
726 | ) | |
727 | ); | |
728 | ||
729 | DecimalVector resultVector = new DecimalVector("res", allocator, | |
730 | decimalType.getPrecision(), decimalType.getScale()); | |
731 | resultVector.allocateNew(); | |
732 | output = new ArrayList<>(Arrays.asList(resultVector)); | |
733 | ||
734 | BigDecimal[] expected = {BigDecimal.valueOf(10.51), BigDecimal.valueOf(-0.10), | |
735 | BigDecimal.valueOf(10.52), BigDecimal.valueOf(0.00)}; | |
736 | // evaluate expressions. | |
737 | eval.evaluate(batch, output); | |
738 | ||
739 | // compare the outputs. | |
740 | for (int i = 0; i < numRows; i++) { | |
741 | assertTrue("mismatch in result for " + | |
742 | "field " + resultVector.getField().getName() + | |
743 | " for row " + i + | |
744 | " expected " + expected[i] + | |
745 | ", got " + resultVector.getObject(i), expected[i].compareTo(resultVector.getObject(i)) == 0); | |
746 | } | |
747 | } finally { | |
748 | // free buffers | |
749 | if (batch != null) { | |
750 | releaseRecordBatch(batch); | |
751 | } | |
752 | if (output != null) { | |
753 | releaseValueVectors(output); | |
754 | } | |
755 | eval.close(); | |
756 | } | |
757 | } | |
758 | ||
759 | @Test | |
760 | public void testInvalidDecimal() throws GandivaException { | |
761 | exception.expect(IllegalArgumentException.class); | |
762 | exception.expectMessage("Gandiva only supports decimals of upto 38 precision. Input precision" + | |
763 | " : 0"); | |
764 | Decimal decimalType = new Decimal(0, 0, 128); | |
765 | Field int64f = Field.nullable("int64", int64); | |
766 | ||
767 | Schema schema = new Schema(Lists.newArrayList(int64f)); | |
768 | Projector eval = Projector.make(schema, | |
769 | Lists.newArrayList( | |
770 | TreeBuilder.makeExpression("castDECIMAL", | |
771 | Lists.newArrayList(int64f), | |
772 | Field.nullable("invalid_dec", decimalType) | |
773 | ) | |
774 | ) | |
775 | ); | |
776 | } | |
777 | ||
778 | @Test | |
779 | public void testInvalidDecimalGt38() throws GandivaException { | |
780 | exception.expect(IllegalArgumentException.class); | |
781 | exception.expectMessage("Gandiva only supports decimals of upto 38 precision. Input precision" + | |
782 | " : 42"); | |
783 | Decimal decimalType = new Decimal(42, 0, 128); | |
784 | Field int64f = Field.nullable("int64", int64); | |
785 | ||
786 | Schema schema = new Schema(Lists.newArrayList(int64f)); | |
787 | Projector eval = Projector.make(schema, | |
788 | Lists.newArrayList( | |
789 | TreeBuilder.makeExpression("castDECIMAL", | |
790 | Lists.newArrayList(int64f), | |
791 | Field.nullable("invalid_dec", decimalType) | |
792 | ) | |
793 | ) | |
794 | ); | |
795 | } | |
796 | } | |
797 |