]> git.proxmox.com Git - ceph.git/blob - ceph/src/arrow/java/gandiva/src/test/java/org/apache/arrow/gandiva/evaluator/FilterTest.java
import quincy 17.2.0
[ceph.git] / ceph / src / arrow / java / gandiva / src / test / java / org / apache / arrow / gandiva / evaluator / FilterTest.java
1 /*
2 * Licensed to the Apache Software Foundation (ASF) under one or more
3 * contributor license agreements. See the NOTICE file distributed with
4 * this work for additional information regarding copyright ownership.
5 * The ASF licenses this file to You under the Apache License, Version 2.0
6 * (the "License"); you may not use this file except in compliance with
7 * the License. You may obtain a copy of the License at
8 *
9 * http://www.apache.org/licenses/LICENSE-2.0
10 *
11 * Unless required by applicable law or agreed to in writing, software
12 * distributed under the License is distributed on an "AS IS" BASIS,
13 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 * See the License for the specific language governing permissions and
15 * limitations under the License.
16 */
17
18 package org.apache.arrow.gandiva.evaluator;
19
import java.nio.charset.Charset;
import java.nio.charset.StandardCharsets;
import java.util.Arrays;
import java.util.List;
import java.util.stream.IntStream;

import org.apache.arrow.gandiva.exceptions.GandivaException;
import org.apache.arrow.gandiva.expression.Condition;
import org.apache.arrow.gandiva.expression.TreeBuilder;
import org.apache.arrow.gandiva.expression.TreeNode;
import org.apache.arrow.memory.ArrowBuf;
import org.apache.arrow.vector.ipc.message.ArrowFieldNode;
import org.apache.arrow.vector.ipc.message.ArrowRecordBatch;
import org.apache.arrow.vector.types.pojo.ArrowType;
import org.apache.arrow.vector.types.pojo.Field;
import org.apache.arrow.vector.types.pojo.Schema;
import org.junit.Assert;
import org.junit.Test;

import com.google.common.collect.Lists;
import com.google.common.collect.Sets;
40
41 public class FilterTest extends BaseEvaluatorTest {
42
43 private int[] selectionVectorToArray(SelectionVector vector) {
44 int[] actual = new int[vector.getRecordCount()];
45 for (int i = 0; i < vector.getRecordCount(); ++i) {
46 actual[i] = vector.getIndex(i);
47 }
48 return actual;
49 }
50
51 private Charset utf8Charset = Charset.forName("UTF-8");
52 private Charset utf16Charset = Charset.forName("UTF-16");
53
54 List<ArrowBuf> varBufs(String[] strings, Charset charset) {
55 ArrowBuf offsetsBuffer = allocator.buffer((strings.length + 1) * 4);
56 ArrowBuf dataBuffer = allocator.buffer(strings.length * 8);
57
58 int startOffset = 0;
59 for (int i = 0; i < strings.length; i++) {
60 offsetsBuffer.writeInt(startOffset);
61
62 final byte[] bytes = strings[i].getBytes(charset);
63 dataBuffer = dataBuffer.reallocIfNeeded(dataBuffer.writerIndex() + bytes.length);
64 dataBuffer.setBytes(startOffset, bytes, 0, bytes.length);
65 startOffset += bytes.length;
66 }
67 offsetsBuffer.writeInt(startOffset); // offset for the last element
68
69 return Arrays.asList(offsetsBuffer, dataBuffer);
70 }
71
72 List<ArrowBuf> stringBufs(String[] strings) {
73 return varBufs(strings, utf8Charset);
74 }
75
76 @Test
77 public void testSimpleInString() throws GandivaException, Exception {
78 Field c1 = Field.nullable("c1", new ArrowType.Utf8());
79 TreeNode l1 = TreeBuilder.makeLiteral(1L);
80 TreeNode l2 = TreeBuilder.makeLiteral(3L);
81
82 List<Field> argsSchema = Lists.newArrayList(c1);
83 List<TreeNode> args = Lists.newArrayList(TreeBuilder.makeField(c1), l1, l2);
84 TreeNode substr = TreeBuilder.makeFunction("substr", args, new ArrowType.Utf8());
85 TreeNode inExpr =
86 TreeBuilder.makeInExpressionString(substr, Sets.newHashSet("one", "two", "thr", "fou"));
87
88 Condition condition = TreeBuilder.makeCondition(inExpr);
89
90 Schema schema = new Schema(argsSchema);
91 Filter filter = Filter.make(schema, condition);
92
93 int numRows = 16;
94 byte[] validity = new byte[] {(byte) 255, 0};
95 // second half is "undefined"
96 String[] c1Values = new String[]{"one", "two", "three", "four", "five", "six", "seven",
97 "eight", "nine", "ten", "eleven", "twelve", "thirteen", "fourteen", "fifteen",
98 "sixteen"};
99 int[] expected = {0, 1, 2, 3};
100 ArrowBuf c1Validity = buf(validity);
101 ArrowBuf c2Validity = buf(validity);
102 List<ArrowBuf> dataBufsX = stringBufs(c1Values);
103
104 ArrowFieldNode fieldNode = new ArrowFieldNode(numRows, 0);
105 ArrowRecordBatch batch =
106 new ArrowRecordBatch(
107 numRows,
108 Lists.newArrayList(fieldNode),
109 Lists.newArrayList(c1Validity, dataBufsX.get(0), dataBufsX.get(1), c2Validity));
110
111 ArrowBuf selectionBuffer = buf(numRows * 2);
112 SelectionVectorInt16 selectionVector = new SelectionVectorInt16(selectionBuffer);
113
114 filter.evaluate(batch, selectionVector);
115
116 int[] actual = selectionVectorToArray(selectionVector);
117 releaseRecordBatch(batch);
118 selectionBuffer.close();
119 filter.close();
120 Assert.assertArrayEquals(expected, actual);
121 }
122
123 @Test
124 public void testSimpleInInt() throws GandivaException, Exception {
125 Field c1 = Field.nullable("c1", int32);
126
127 List<Field> argsSchema = Lists.newArrayList(c1);
128 TreeNode inExpr =
129 TreeBuilder.makeInExpressionInt32(TreeBuilder.makeField(c1), Sets.newHashSet(1, 2, 3, 4));
130
131 Condition condition = TreeBuilder.makeCondition(inExpr);
132
133 Schema schema = new Schema(argsSchema);
134 Filter filter = Filter.make(schema, condition);
135
136 int numRows = 16;
137 byte[] validity = new byte[] {(byte) 255, 0};
138 // second half is "undefined"
139 int[] aValues = new int[] {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16};
140 int[] expected = {0, 1, 2, 3};
141
142 ArrowBuf validitya = buf(validity);
143 ArrowBuf validityb = buf(validity);
144 ArrowBuf valuesa = intBuf(aValues);
145
146 ArrowFieldNode fieldNode = new ArrowFieldNode(numRows, 0);
147 ArrowRecordBatch batch =
148 new ArrowRecordBatch(
149 numRows,
150 Lists.newArrayList(fieldNode),
151 Lists.newArrayList(validitya, valuesa, validityb));
152
153 ArrowBuf selectionBuffer = buf(numRows * 2);
154 SelectionVectorInt16 selectionVector = new SelectionVectorInt16(selectionBuffer);
155
156 filter.evaluate(batch, selectionVector);
157
158 // free buffers
159 int[] actual = selectionVectorToArray(selectionVector);
160 releaseRecordBatch(batch);
161 selectionBuffer.close();
162 filter.close();
163 Assert.assertArrayEquals(expected, actual);
164 }
165
166 @Test
167 public void testSimpleSV16() throws GandivaException, Exception {
168 Field a = Field.nullable("a", int32);
169 Field b = Field.nullable("b", int32);
170 List<Field> args = Lists.newArrayList(a, b);
171
172 Condition condition = TreeBuilder.makeCondition("less_than", args);
173
174 Schema schema = new Schema(args);
175 Filter filter = Filter.make(schema, condition);
176
177 int numRows = 16;
178 byte[] validity = new byte[] {(byte) 255, 0};
179 // second half is "undefined"
180 int[] aValues = new int[] {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16};
181 int[] bValues = new int[] {2, 1, 4, 3, 6, 5, 8, 7, 10, 9, 12, 11, 14, 13, 14, 15};
182 int[] expected = {0, 2, 4, 6};
183
184 verifyTestCase(filter, numRows, validity, aValues, bValues, expected);
185 }
186
187 @Test
188 public void testSimpleSV16_AllMatched() throws GandivaException, Exception {
189 Field a = Field.nullable("a", int32);
190 Field b = Field.nullable("b", int32);
191 List<Field> args = Lists.newArrayList(a, b);
192
193 Condition condition = TreeBuilder.makeCondition("less_than", args);
194
195 Schema schema = new Schema(args);
196 Filter filter = Filter.make(schema, condition);
197
198 int numRows = 32;
199
200 byte[] validity = new byte[numRows / 8];
201
202 IntStream.range(0, numRows / 8).forEach(i -> validity[i] = (byte) 255);
203
204 int[] aValues = new int[numRows];
205 IntStream.range(0, numRows).forEach(i -> aValues[i] = i);
206
207 int[] bValues = new int[numRows];
208 IntStream.range(0, numRows).forEach(i -> bValues[i] = i + 1);
209
210 int[] expected = new int[numRows];
211 IntStream.range(0, numRows).forEach(i -> expected[i] = i);
212
213 verifyTestCase(filter, numRows, validity, aValues, bValues, expected);
214 }
215
216 @Test
217 public void testSimpleSV16_GreaterThan64Recs() throws GandivaException, Exception {
218 Field a = Field.nullable("a", int32);
219 Field b = Field.nullable("b", int32);
220 List<Field> args = Lists.newArrayList(a, b);
221
222 Condition condition = TreeBuilder.makeCondition("greater_than", args);
223
224 Schema schema = new Schema(args);
225 Filter filter = Filter.make(schema, condition);
226
227 int numRows = 1000;
228
229 byte[] validity = new byte[numRows / 8];
230
231 IntStream.range(0, numRows / 8).forEach(i -> validity[i] = (byte) 255);
232
233 int[] aValues = new int[numRows];
234 IntStream.range(0, numRows).forEach(i -> aValues[i] = i);
235
236 int[] bValues = new int[numRows];
237 IntStream.range(0, numRows).forEach(i -> bValues[i] = i + 1);
238
239 aValues[0] = 5;
240 bValues[0] = 0;
241
242 int[] expected = {0};
243
244 verifyTestCase(filter, numRows, validity, aValues, bValues, expected);
245 }
246
247 @Test
248 public void testSimpleSV32() throws GandivaException, Exception {
249 Field a = Field.nullable("a", int32);
250 Field b = Field.nullable("b", int32);
251 List<Field> args = Lists.newArrayList(a, b);
252
253 Condition condition = TreeBuilder.makeCondition("less_than", args);
254
255 Schema schema = new Schema(args);
256 Filter filter = Filter.make(schema, condition);
257
258 int numRows = 16;
259 byte[] validity = new byte[] {(byte) 255, 0};
260 // second half is "undefined"
261 int[] aValues = new int[] {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16};
262 int[] bValues = new int[] {2, 1, 4, 3, 6, 5, 8, 7, 10, 9, 12, 11, 14, 13, 14, 15};
263 int[] expected = {0, 2, 4, 6};
264
265 verifyTestCase(filter, numRows, validity, aValues, bValues, expected);
266 }
267
268 @Test
269 public void testSimpleFilterWithNoOptimisation() throws GandivaException, Exception {
270 Field a = Field.nullable("a", int32);
271 Field b = Field.nullable("b", int32);
272 List<Field> args = Lists.newArrayList(a, b);
273
274 Condition condition = TreeBuilder.makeCondition("less_than", args);
275
276 Schema schema = new Schema(args);
277 Filter filter = Filter.make(schema, condition, false);
278
279 int numRows = 16;
280 byte[] validity = new byte[] {(byte) 255, 0};
281 // second half is "undefined"
282 int[] aValues = new int[] {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16};
283 int[] bValues = new int[] {2, 1, 4, 3, 6, 5, 8, 7, 10, 9, 12, 11, 14, 13, 14, 15};
284 int[] expected = {0, 2, 4, 6};
285
286 verifyTestCase(filter, numRows, validity, aValues, bValues, expected);
287 }
288
289 private void verifyTestCase(
290 Filter filter, int numRows, byte[] validity, int[] aValues, int[] bValues, int[] expected)
291 throws GandivaException {
292 ArrowBuf validitya = buf(validity);
293 ArrowBuf valuesa = intBuf(aValues);
294 ArrowBuf validityb = buf(validity);
295 ArrowBuf valuesb = intBuf(bValues);
296 ArrowRecordBatch batch =
297 new ArrowRecordBatch(
298 numRows,
299 Lists.newArrayList(new ArrowFieldNode(numRows, 0), new ArrowFieldNode(numRows, 0)),
300 Lists.newArrayList(validitya, valuesa, validityb, valuesb));
301
302 ArrowBuf selectionBuffer = buf(numRows * 2);
303 SelectionVectorInt16 selectionVector = new SelectionVectorInt16(selectionBuffer);
304
305 filter.evaluate(batch, selectionVector);
306
307 // free buffers
308 int[] actual = selectionVectorToArray(selectionVector);
309 releaseRecordBatch(batch);
310 selectionBuffer.close();
311 filter.close();
312
313 Assert.assertArrayEquals(expected, actual);
314 }
315 }