]>
Commit | Line | Data |
---|---|---|
1d09f67e TL |
1 | // Licensed to the Apache Software Foundation (ASF) under one |
2 | // or more contributor license agreements. See the NOTICE file | |
3 | // distributed with this work for additional information | |
4 | // regarding copyright ownership. The ASF licenses this file | |
5 | // to you under the Apache License, Version 2.0 (the | |
6 | // "License"); you may not use this file except in compliance | |
7 | // with the License. You may obtain a copy of the License at | |
8 | // | |
9 | // http://www.apache.org/licenses/LICENSE-2.0 | |
10 | // | |
11 | // Unless required by applicable law or agreed to in writing, | |
12 | // software distributed under the License is distributed on an | |
13 | // "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY | |
14 | // KIND, either express or implied. See the License for the | |
15 | // specific language governing permissions and limitations | |
16 | // under the License. | |
17 | ||
18 | #include <gtest/gtest.h> | |
19 | #include "arrow/memory_pool.h" | |
20 | #include "gandiva/filter.h" | |
21 | #include "gandiva/projector.h" | |
22 | #include "gandiva/selection_vector.h" | |
23 | #include "gandiva/tests/test_util.h" | |
24 | #include "gandiva/tree_expr_builder.h" | |
25 | ||
26 | namespace gandiva { | |
27 | ||
28 | using arrow::boolean; | |
29 | using arrow::float32; | |
30 | using arrow::int32; | |
31 | ||
32 | class TestFilterProject : public ::testing::Test { | |
33 | public: | |
34 | void SetUp() { pool_ = arrow::default_memory_pool(); } | |
35 | ||
36 | protected: | |
37 | arrow::MemoryPool* pool_; | |
38 | }; | |
39 | ||
40 | TEST_F(TestFilterProject, TestSimple16) { | |
41 | // schema for input fields | |
42 | auto field0 = field("f0", int32()); | |
43 | auto field1 = field("f1", int32()); | |
44 | auto field2 = field("f2", int32()); | |
45 | auto resultField = field("result", int32()); | |
46 | auto schema = arrow::schema({field0, field1, field2}); | |
47 | ||
48 | // Build condition f0 < f1 | |
49 | auto node_f0 = TreeExprBuilder::MakeField(field0); | |
50 | auto node_f1 = TreeExprBuilder::MakeField(field1); | |
51 | auto node_f2 = TreeExprBuilder::MakeField(field2); | |
52 | auto less_than_function = | |
53 | TreeExprBuilder::MakeFunction("less_than", {node_f0, node_f1}, arrow::boolean()); | |
54 | auto condition = TreeExprBuilder::MakeCondition(less_than_function); | |
55 | auto sum_expr = TreeExprBuilder::MakeExpression("add", {field1, field2}, resultField); | |
56 | ||
57 | auto configuration = TestConfiguration(); | |
58 | ||
59 | std::shared_ptr<Filter> filter; | |
60 | std::shared_ptr<Projector> projector; | |
61 | ||
62 | auto status = Filter::Make(schema, condition, configuration, &filter); | |
63 | EXPECT_TRUE(status.ok()); | |
64 | ||
65 | status = Projector::Make(schema, {sum_expr}, SelectionVector::MODE_UINT16, | |
66 | configuration, &projector); | |
67 | EXPECT_TRUE(status.ok()); | |
68 | ||
69 | // Create a row-batch with some sample data | |
70 | int num_records = 5; | |
71 | auto array0 = MakeArrowArrayInt32({1, 2, 6, 40, 3}, {true, true, true, true, true}); | |
72 | auto array1 = MakeArrowArrayInt32({5, 9, 3, 17, 6}, {true, true, true, true, true}); | |
73 | auto array2 = MakeArrowArrayInt32({1, 2, 6, 40, 3}, {true, true, true, true, false}); | |
74 | // expected output | |
75 | auto result = MakeArrowArrayInt32({6, 11, 0}, {true, true, false}); | |
76 | // prepare input record batch | |
77 | auto in_batch = arrow::RecordBatch::Make(schema, num_records, {array0, array1, array2}); | |
78 | ||
79 | std::shared_ptr<SelectionVector> selection_vector; | |
80 | status = SelectionVector::MakeInt16(num_records, pool_, &selection_vector); | |
81 | EXPECT_TRUE(status.ok()); | |
82 | // Evaluate expression | |
83 | status = filter->Evaluate(*in_batch, selection_vector); | |
84 | EXPECT_TRUE(status.ok()); | |
85 | ||
86 | // Evaluate expression | |
87 | arrow::ArrayVector outputs; | |
88 | ||
89 | status = projector->Evaluate(*in_batch, selection_vector.get(), pool_, &outputs); | |
90 | EXPECT_TRUE(status.ok()); | |
91 | ||
92 | // Validate results | |
93 | EXPECT_ARROW_ARRAY_EQUALS(result, outputs.at(0)); | |
94 | } | |
95 | ||
96 | TEST_F(TestFilterProject, TestSimple32) { | |
97 | // schema for input fields | |
98 | auto field0 = field("f0", int32()); | |
99 | auto field1 = field("f1", int32()); | |
100 | auto field2 = field("f2", int32()); | |
101 | auto resultField = field("result", int32()); | |
102 | auto schema = arrow::schema({field0, field1, field2}); | |
103 | ||
104 | // Build condition f0 < f1 | |
105 | auto node_f0 = TreeExprBuilder::MakeField(field0); | |
106 | auto node_f1 = TreeExprBuilder::MakeField(field1); | |
107 | auto node_f2 = TreeExprBuilder::MakeField(field2); | |
108 | auto less_than_function = | |
109 | TreeExprBuilder::MakeFunction("less_than", {node_f0, node_f1}, arrow::boolean()); | |
110 | auto condition = TreeExprBuilder::MakeCondition(less_than_function); | |
111 | auto sum_expr = TreeExprBuilder::MakeExpression("add", {field1, field2}, resultField); | |
112 | ||
113 | auto configuration = TestConfiguration(); | |
114 | ||
115 | std::shared_ptr<Filter> filter; | |
116 | std::shared_ptr<Projector> projector; | |
117 | ||
118 | auto status = Filter::Make(schema, condition, configuration, &filter); | |
119 | EXPECT_TRUE(status.ok()); | |
120 | ||
121 | status = Projector::Make(schema, {sum_expr}, SelectionVector::MODE_UINT32, | |
122 | configuration, &projector); | |
123 | EXPECT_TRUE(status.ok()); | |
124 | ||
125 | // Create a row-batch with some sample data | |
126 | int num_records = 5; | |
127 | auto array0 = MakeArrowArrayInt32({1, 2, 6, 40, 3}, {true, true, true, true, true}); | |
128 | auto array1 = MakeArrowArrayInt32({5, 9, 3, 17, 6}, {true, true, true, true, true}); | |
129 | auto array2 = MakeArrowArrayInt32({1, 2, 6, 40, 3}, {true, true, true, true, false}); | |
130 | // expected output | |
131 | auto result = MakeArrowArrayInt32({6, 11, 0}, {true, true, false}); | |
132 | // prepare input record batch | |
133 | auto in_batch = arrow::RecordBatch::Make(schema, num_records, {array0, array1, array2}); | |
134 | ||
135 | std::shared_ptr<SelectionVector> selection_vector; | |
136 | status = SelectionVector::MakeInt32(num_records, pool_, &selection_vector); | |
137 | EXPECT_TRUE(status.ok()); | |
138 | // Evaluate expression | |
139 | status = filter->Evaluate(*in_batch, selection_vector); | |
140 | EXPECT_TRUE(status.ok()); | |
141 | ||
142 | // Evaluate expression | |
143 | arrow::ArrayVector outputs; | |
144 | ||
145 | status = projector->Evaluate(*in_batch, selection_vector.get(), pool_, &outputs); | |
146 | ASSERT_OK(status); | |
147 | ||
148 | // Validate results | |
149 | EXPECT_ARROW_ARRAY_EQUALS(result, outputs.at(0)); | |
150 | } | |
151 | ||
152 | TEST_F(TestFilterProject, TestSimple64) { | |
153 | // schema for input fields | |
154 | auto field0 = field("f0", int32()); | |
155 | auto field1 = field("f1", int32()); | |
156 | auto field2 = field("f2", int32()); | |
157 | auto resultField = field("result", int32()); | |
158 | auto schema = arrow::schema({field0, field1, field2}); | |
159 | ||
160 | // Build condition f0 < f1 | |
161 | auto node_f0 = TreeExprBuilder::MakeField(field0); | |
162 | auto node_f1 = TreeExprBuilder::MakeField(field1); | |
163 | auto node_f2 = TreeExprBuilder::MakeField(field2); | |
164 | auto less_than_function = | |
165 | TreeExprBuilder::MakeFunction("less_than", {node_f0, node_f1}, arrow::boolean()); | |
166 | auto condition = TreeExprBuilder::MakeCondition(less_than_function); | |
167 | auto sum_expr = TreeExprBuilder::MakeExpression("add", {field1, field2}, resultField); | |
168 | ||
169 | auto configuration = TestConfiguration(); | |
170 | ||
171 | std::shared_ptr<Filter> filter; | |
172 | std::shared_ptr<Projector> projector; | |
173 | ||
174 | auto status = Filter::Make(schema, condition, configuration, &filter); | |
175 | EXPECT_TRUE(status.ok()); | |
176 | ||
177 | status = Projector::Make(schema, {sum_expr}, SelectionVector::MODE_UINT64, | |
178 | configuration, &projector); | |
179 | ASSERT_OK(status); | |
180 | ||
181 | // Create a row-batch with some sample data | |
182 | int num_records = 5; | |
183 | auto array0 = MakeArrowArrayInt32({1, 2, 6, 40, 3}, {true, true, true, true, true}); | |
184 | auto array1 = MakeArrowArrayInt32({5, 9, 3, 17, 6}, {true, true, true, true, true}); | |
185 | auto array2 = MakeArrowArrayInt32({1, 2, 6, 40, 3}, {true, true, true, true, false}); | |
186 | // expected output | |
187 | auto result = MakeArrowArrayInt32({6, 11, 0}, {true, true, false}); | |
188 | // prepare input record batch | |
189 | auto in_batch = arrow::RecordBatch::Make(schema, num_records, {array0, array1, array2}); | |
190 | ||
191 | std::shared_ptr<SelectionVector> selection_vector; | |
192 | status = SelectionVector::MakeInt64(num_records, pool_, &selection_vector); | |
193 | EXPECT_TRUE(status.ok()); | |
194 | // Evaluate expression | |
195 | status = filter->Evaluate(*in_batch, selection_vector); | |
196 | EXPECT_TRUE(status.ok()); | |
197 | ||
198 | // Evaluate expression | |
199 | arrow::ArrayVector outputs; | |
200 | ||
201 | status = projector->Evaluate(*in_batch, selection_vector.get(), pool_, &outputs); | |
202 | EXPECT_TRUE(status.ok()); | |
203 | ||
204 | // Validate results | |
205 | EXPECT_ARROW_ARRAY_EQUALS(result, outputs.at(0)); | |
206 | } | |
207 | ||
208 | TEST_F(TestFilterProject, TestSimpleIf) { | |
209 | // schema for input fields | |
210 | auto fielda = field("a", int32()); | |
211 | auto fieldb = field("b", int32()); | |
212 | auto fieldc = field("c", int32()); | |
213 | auto schema = arrow::schema({fielda, fieldb, fieldc}); | |
214 | ||
215 | // output fields | |
216 | auto field_result = field("res", int32()); | |
217 | ||
218 | auto node_a = TreeExprBuilder::MakeField(fielda); | |
219 | auto node_b = TreeExprBuilder::MakeField(fieldb); | |
220 | auto node_c = TreeExprBuilder::MakeField(fieldc); | |
221 | ||
222 | auto greater_than_function = | |
223 | TreeExprBuilder::MakeFunction("greater_than", {node_a, node_b}, boolean()); | |
224 | auto filter_condition = TreeExprBuilder::MakeCondition(greater_than_function); | |
225 | ||
226 | auto project_condition = | |
227 | TreeExprBuilder::MakeFunction("less_than", {node_b, node_c}, boolean()); | |
228 | auto if_node = TreeExprBuilder::MakeIf(project_condition, node_b, node_c, int32()); | |
229 | ||
230 | auto expr = TreeExprBuilder::MakeExpression(if_node, field_result); | |
231 | auto configuration = TestConfiguration(); | |
232 | ||
233 | // Build a filter for the expressions. | |
234 | std::shared_ptr<Filter> filter; | |
235 | auto status = Filter::Make(schema, filter_condition, configuration, &filter); | |
236 | EXPECT_TRUE(status.ok()); | |
237 | ||
238 | // Build a projector for the expressions. | |
239 | std::shared_ptr<Projector> projector; | |
240 | status = Projector::Make(schema, {expr}, SelectionVector::MODE_UINT32, configuration, | |
241 | &projector); | |
242 | ASSERT_OK(status); | |
243 | ||
244 | // Create a row-batch with some sample data | |
245 | int num_records = 6; | |
246 | auto array0 = | |
247 | MakeArrowArrayInt32({10, 12, -20, 5, 21, 29}, {true, true, true, true, true, true}); | |
248 | auto array1 = | |
249 | MakeArrowArrayInt32({5, 15, 15, 17, 12, 3}, {true, true, true, true, true, true}); | |
250 | auto array2 = MakeArrowArrayInt32({1, 25, 11, 30, -21, 30}, | |
251 | {true, true, true, true, true, false}); | |
252 | ||
253 | // Create a selection vector | |
254 | std::shared_ptr<SelectionVector> selection_vector; | |
255 | status = SelectionVector::MakeInt32(num_records, pool_, &selection_vector); | |
256 | EXPECT_TRUE(status.ok()); | |
257 | ||
258 | // expected output | |
259 | auto exp = MakeArrowArrayInt32({1, -21, 0}, {true, true, false}); | |
260 | ||
261 | // prepare input record batch | |
262 | auto in_batch = arrow::RecordBatch::Make(schema, num_records, {array0, array1, array2}); | |
263 | ||
264 | // Evaluate filter | |
265 | status = filter->Evaluate(*in_batch, selection_vector); | |
266 | EXPECT_TRUE(status.ok()); | |
267 | ||
268 | // Evaluate project | |
269 | arrow::ArrayVector outputs; | |
270 | status = projector->Evaluate(*in_batch, selection_vector.get(), pool_, &outputs); | |
271 | EXPECT_TRUE(status.ok()); | |
272 | ||
273 | // Validate results | |
274 | EXPECT_ARROW_ARRAY_EQUALS(exp, outputs.at(0)); | |
275 | } | |
276 | } // namespace gandiva |