/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements. See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License. You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
18 package org
.apache
.arrow
.gandiva
.evaluator
;
import java.nio.charset.Charset;
import java.nio.charset.StandardCharsets;
import java.util.Arrays;
import java.util.List;
import java.util.stream.IntStream;

import org.apache.arrow.gandiva.exceptions.GandivaException;
import org.apache.arrow.gandiva.expression.Condition;
import org.apache.arrow.gandiva.expression.TreeBuilder;
import org.apache.arrow.gandiva.expression.TreeNode;
import org.apache.arrow.memory.ArrowBuf;
import org.apache.arrow.vector.ipc.message.ArrowFieldNode;
import org.apache.arrow.vector.ipc.message.ArrowRecordBatch;
import org.apache.arrow.vector.types.pojo.ArrowType;
import org.apache.arrow.vector.types.pojo.Field;
import org.apache.arrow.vector.types.pojo.Schema;
import org.junit.Assert;
import org.junit.Test;

import com.google.common.collect.Lists;
import com.google.common.collect.Sets;
41 public class FilterTest
extends BaseEvaluatorTest
{
43 private int[] selectionVectorToArray(SelectionVector vector
) {
44 int[] actual
= new int[vector
.getRecordCount()];
45 for (int i
= 0; i
< vector
.getRecordCount(); ++i
) {
46 actual
[i
] = vector
.getIndex(i
);
51 private Charset utf8Charset
= Charset
.forName("UTF-8");
52 private Charset utf16Charset
= Charset
.forName("UTF-16");
54 List
<ArrowBuf
> varBufs(String
[] strings
, Charset charset
) {
55 ArrowBuf offsetsBuffer
= allocator
.buffer((strings
.length
+ 1) * 4);
56 ArrowBuf dataBuffer
= allocator
.buffer(strings
.length
* 8);
59 for (int i
= 0; i
< strings
.length
; i
++) {
60 offsetsBuffer
.writeInt(startOffset
);
62 final byte[] bytes
= strings
[i
].getBytes(charset
);
63 dataBuffer
= dataBuffer
.reallocIfNeeded(dataBuffer
.writerIndex() + bytes
.length
);
64 dataBuffer
.setBytes(startOffset
, bytes
, 0, bytes
.length
);
65 startOffset
+= bytes
.length
;
67 offsetsBuffer
.writeInt(startOffset
); // offset for the last element
69 return Arrays
.asList(offsetsBuffer
, dataBuffer
);
72 List
<ArrowBuf
> stringBufs(String
[] strings
) {
73 return varBufs(strings
, utf8Charset
);
77 public void testSimpleInString() throws GandivaException
, Exception
{
78 Field c1
= Field
.nullable("c1", new ArrowType
.Utf8());
79 TreeNode l1
= TreeBuilder
.makeLiteral(1L);
80 TreeNode l2
= TreeBuilder
.makeLiteral(3L);
82 List
<Field
> argsSchema
= Lists
.newArrayList(c1
);
83 List
<TreeNode
> args
= Lists
.newArrayList(TreeBuilder
.makeField(c1
), l1
, l2
);
84 TreeNode substr
= TreeBuilder
.makeFunction("substr", args
, new ArrowType
.Utf8());
86 TreeBuilder
.makeInExpressionString(substr
, Sets
.newHashSet("one", "two", "thr", "fou"));
88 Condition condition
= TreeBuilder
.makeCondition(inExpr
);
90 Schema schema
= new Schema(argsSchema
);
91 Filter filter
= Filter
.make(schema
, condition
);
94 byte[] validity
= new byte[] {(byte) 255, 0};
95 // second half is "undefined"
96 String
[] c1Values
= new String
[]{"one", "two", "three", "four", "five", "six", "seven",
97 "eight", "nine", "ten", "eleven", "twelve", "thirteen", "fourteen", "fifteen",
99 int[] expected
= {0, 1, 2, 3};
100 ArrowBuf c1Validity
= buf(validity
);
101 ArrowBuf c2Validity
= buf(validity
);
102 List
<ArrowBuf
> dataBufsX
= stringBufs(c1Values
);
104 ArrowFieldNode fieldNode
= new ArrowFieldNode(numRows
, 0);
105 ArrowRecordBatch batch
=
106 new ArrowRecordBatch(
108 Lists
.newArrayList(fieldNode
),
109 Lists
.newArrayList(c1Validity
, dataBufsX
.get(0), dataBufsX
.get(1), c2Validity
));
111 ArrowBuf selectionBuffer
= buf(numRows
* 2);
112 SelectionVectorInt16 selectionVector
= new SelectionVectorInt16(selectionBuffer
);
114 filter
.evaluate(batch
, selectionVector
);
116 int[] actual
= selectionVectorToArray(selectionVector
);
117 releaseRecordBatch(batch
);
118 selectionBuffer
.close();
120 Assert
.assertArrayEquals(expected
, actual
);
124 public void testSimpleInInt() throws GandivaException
, Exception
{
125 Field c1
= Field
.nullable("c1", int32
);
127 List
<Field
> argsSchema
= Lists
.newArrayList(c1
);
129 TreeBuilder
.makeInExpressionInt32(TreeBuilder
.makeField(c1
), Sets
.newHashSet(1, 2, 3, 4));
131 Condition condition
= TreeBuilder
.makeCondition(inExpr
);
133 Schema schema
= new Schema(argsSchema
);
134 Filter filter
= Filter
.make(schema
, condition
);
137 byte[] validity
= new byte[] {(byte) 255, 0};
138 // second half is "undefined"
139 int[] aValues
= new int[] {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16};
140 int[] expected
= {0, 1, 2, 3};
142 ArrowBuf validitya
= buf(validity
);
143 ArrowBuf validityb
= buf(validity
);
144 ArrowBuf valuesa
= intBuf(aValues
);
146 ArrowFieldNode fieldNode
= new ArrowFieldNode(numRows
, 0);
147 ArrowRecordBatch batch
=
148 new ArrowRecordBatch(
150 Lists
.newArrayList(fieldNode
),
151 Lists
.newArrayList(validitya
, valuesa
, validityb
));
153 ArrowBuf selectionBuffer
= buf(numRows
* 2);
154 SelectionVectorInt16 selectionVector
= new SelectionVectorInt16(selectionBuffer
);
156 filter
.evaluate(batch
, selectionVector
);
159 int[] actual
= selectionVectorToArray(selectionVector
);
160 releaseRecordBatch(batch
);
161 selectionBuffer
.close();
163 Assert
.assertArrayEquals(expected
, actual
);
167 public void testSimpleSV16() throws GandivaException
, Exception
{
168 Field a
= Field
.nullable("a", int32
);
169 Field b
= Field
.nullable("b", int32
);
170 List
<Field
> args
= Lists
.newArrayList(a
, b
);
172 Condition condition
= TreeBuilder
.makeCondition("less_than", args
);
174 Schema schema
= new Schema(args
);
175 Filter filter
= Filter
.make(schema
, condition
);
178 byte[] validity
= new byte[] {(byte) 255, 0};
179 // second half is "undefined"
180 int[] aValues
= new int[] {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16};
181 int[] bValues
= new int[] {2, 1, 4, 3, 6, 5, 8, 7, 10, 9, 12, 11, 14, 13, 14, 15};
182 int[] expected
= {0, 2, 4, 6};
184 verifyTestCase(filter
, numRows
, validity
, aValues
, bValues
, expected
);
188 public void testSimpleSV16_AllMatched() throws GandivaException
, Exception
{
189 Field a
= Field
.nullable("a", int32
);
190 Field b
= Field
.nullable("b", int32
);
191 List
<Field
> args
= Lists
.newArrayList(a
, b
);
193 Condition condition
= TreeBuilder
.makeCondition("less_than", args
);
195 Schema schema
= new Schema(args
);
196 Filter filter
= Filter
.make(schema
, condition
);
200 byte[] validity
= new byte[numRows
/ 8];
202 IntStream
.range(0, numRows
/ 8).forEach(i
-> validity
[i
] = (byte) 255);
204 int[] aValues
= new int[numRows
];
205 IntStream
.range(0, numRows
).forEach(i
-> aValues
[i
] = i
);
207 int[] bValues
= new int[numRows
];
208 IntStream
.range(0, numRows
).forEach(i
-> bValues
[i
] = i
+ 1);
210 int[] expected
= new int[numRows
];
211 IntStream
.range(0, numRows
).forEach(i
-> expected
[i
] = i
);
213 verifyTestCase(filter
, numRows
, validity
, aValues
, bValues
, expected
);
217 public void testSimpleSV16_GreaterThan64Recs() throws GandivaException
, Exception
{
218 Field a
= Field
.nullable("a", int32
);
219 Field b
= Field
.nullable("b", int32
);
220 List
<Field
> args
= Lists
.newArrayList(a
, b
);
222 Condition condition
= TreeBuilder
.makeCondition("greater_than", args
);
224 Schema schema
= new Schema(args
);
225 Filter filter
= Filter
.make(schema
, condition
);
229 byte[] validity
= new byte[numRows
/ 8];
231 IntStream
.range(0, numRows
/ 8).forEach(i
-> validity
[i
] = (byte) 255);
233 int[] aValues
= new int[numRows
];
234 IntStream
.range(0, numRows
).forEach(i
-> aValues
[i
] = i
);
236 int[] bValues
= new int[numRows
];
237 IntStream
.range(0, numRows
).forEach(i
-> bValues
[i
] = i
+ 1);
242 int[] expected
= {0};
244 verifyTestCase(filter
, numRows
, validity
, aValues
, bValues
, expected
);
248 public void testSimpleSV32() throws GandivaException
, Exception
{
249 Field a
= Field
.nullable("a", int32
);
250 Field b
= Field
.nullable("b", int32
);
251 List
<Field
> args
= Lists
.newArrayList(a
, b
);
253 Condition condition
= TreeBuilder
.makeCondition("less_than", args
);
255 Schema schema
= new Schema(args
);
256 Filter filter
= Filter
.make(schema
, condition
);
259 byte[] validity
= new byte[] {(byte) 255, 0};
260 // second half is "undefined"
261 int[] aValues
= new int[] {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16};
262 int[] bValues
= new int[] {2, 1, 4, 3, 6, 5, 8, 7, 10, 9, 12, 11, 14, 13, 14, 15};
263 int[] expected
= {0, 2, 4, 6};
265 verifyTestCase(filter
, numRows
, validity
, aValues
, bValues
, expected
);
269 public void testSimpleFilterWithNoOptimisation() throws GandivaException
, Exception
{
270 Field a
= Field
.nullable("a", int32
);
271 Field b
= Field
.nullable("b", int32
);
272 List
<Field
> args
= Lists
.newArrayList(a
, b
);
274 Condition condition
= TreeBuilder
.makeCondition("less_than", args
);
276 Schema schema
= new Schema(args
);
277 Filter filter
= Filter
.make(schema
, condition
, false);
280 byte[] validity
= new byte[] {(byte) 255, 0};
281 // second half is "undefined"
282 int[] aValues
= new int[] {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16};
283 int[] bValues
= new int[] {2, 1, 4, 3, 6, 5, 8, 7, 10, 9, 12, 11, 14, 13, 14, 15};
284 int[] expected
= {0, 2, 4, 6};
286 verifyTestCase(filter
, numRows
, validity
, aValues
, bValues
, expected
);
289 private void verifyTestCase(
290 Filter filter
, int numRows
, byte[] validity
, int[] aValues
, int[] bValues
, int[] expected
)
291 throws GandivaException
{
292 ArrowBuf validitya
= buf(validity
);
293 ArrowBuf valuesa
= intBuf(aValues
);
294 ArrowBuf validityb
= buf(validity
);
295 ArrowBuf valuesb
= intBuf(bValues
);
296 ArrowRecordBatch batch
=
297 new ArrowRecordBatch(
299 Lists
.newArrayList(new ArrowFieldNode(numRows
, 0), new ArrowFieldNode(numRows
, 0)),
300 Lists
.newArrayList(validitya
, valuesa
, validityb
, valuesb
));
302 ArrowBuf selectionBuffer
= buf(numRows
* 2);
303 SelectionVectorInt16 selectionVector
= new SelectionVectorInt16(selectionBuffer
);
305 filter
.evaluate(batch
, selectionVector
);
308 int[] actual
= selectionVectorToArray(selectionVector
);
309 releaseRecordBatch(batch
);
310 selectionBuffer
.close();
313 Assert
.assertArrayEquals(expected
, actual
);