]>
Commit | Line | Data |
---|---|---|
1d09f67e TL |
1 | // Licensed to the Apache Software Foundation (ASF) under one |
2 | // or more contributor license agreements. See the NOTICE file | |
3 | // distributed with this work for additional information | |
4 | // regarding copyright ownership. The ASF licenses this file | |
5 | // to you under the Apache License, Version 2.0 (the | |
6 | // "License"); you may not use this file except in compliance | |
7 | // with the License. You may obtain a copy of the License at | |
8 | // | |
9 | // http://www.apache.org/licenses/LICENSE-2.0 | |
10 | // | |
11 | // Unless required by applicable law or agreed to in writing, | |
12 | // software distributed under the License is distributed on an | |
13 | // "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY | |
14 | // KIND, either express or implied. See the License for the | |
15 | // specific language governing permissions and limitations | |
16 | // under the License. | |
17 | ||
18 | #include <algorithm> | |
19 | #include <cstdint> | |
20 | #include <cstdio> | |
21 | #include <functional> | |
22 | #include <iosfwd> | |
23 | #include <locale> | |
24 | #include <memory> | |
25 | #include <stdexcept> | |
26 | #include <string> | |
27 | #include <vector> | |
28 | ||
29 | #include <gtest/gtest.h> | |
30 | ||
31 | #include "arrow/array/array_base.h" | |
32 | #include "arrow/array/builder_binary.h" | |
33 | #include "arrow/array/builder_primitive.h" | |
34 | #include "arrow/chunked_array.h" | |
35 | #include "arrow/compute/api.h" | |
36 | #include "arrow/compute/kernels/test_util.h" | |
37 | #include "arrow/result.h" | |
38 | #include "arrow/status.h" | |
39 | #include "arrow/testing/gtest_compat.h" | |
40 | #include "arrow/testing/gtest_util.h" | |
41 | #include "arrow/type.h" | |
42 | #include "arrow/type_traits.h" | |
43 | ||
44 | namespace arrow { | |
45 | namespace compute { | |
46 | ||
47 | // ---------------------------------------------------------------------- | |
48 | // IsIn tests | |
49 | ||
50 | void CheckIsIn(const std::shared_ptr<DataType>& type, const std::string& input_json, | |
51 | const std::string& value_set_json, const std::string& expected_json, | |
52 | bool skip_nulls = false) { | |
53 | auto input = ArrayFromJSON(type, input_json); | |
54 | auto value_set = ArrayFromJSON(type, value_set_json); | |
55 | auto expected = ArrayFromJSON(boolean(), expected_json); | |
56 | ||
57 | ASSERT_OK_AND_ASSIGN(Datum actual_datum, | |
58 | IsIn(input, SetLookupOptions(value_set, skip_nulls))); | |
59 | std::shared_ptr<Array> actual = actual_datum.make_array(); | |
60 | ValidateOutput(actual_datum); | |
61 | AssertArraysEqual(*expected, *actual, /*verbose=*/true); | |
62 | } | |
63 | ||
64 | void CheckIsInChunked(const std::shared_ptr<ChunkedArray>& input, | |
65 | const std::shared_ptr<ChunkedArray>& value_set, | |
66 | const std::shared_ptr<ChunkedArray>& expected, | |
67 | bool skip_nulls = false) { | |
68 | ASSERT_OK_AND_ASSIGN(Datum actual_datum, | |
69 | IsIn(input, SetLookupOptions(value_set, skip_nulls))); | |
70 | auto actual = actual_datum.chunked_array(); | |
71 | ValidateOutput(actual_datum); | |
72 | AssertChunkedEqual(*expected, *actual); | |
73 | } | |
74 | ||
75 | void CheckIsInDictionary(const std::shared_ptr<DataType>& type, | |
76 | const std::shared_ptr<DataType>& index_type, | |
77 | const std::string& input_dictionary_json, | |
78 | const std::string& input_index_json, | |
79 | const std::string& value_set_json, | |
80 | const std::string& expected_json, bool skip_nulls = false) { | |
81 | auto dict_type = dictionary(index_type, type); | |
82 | auto indices = ArrayFromJSON(index_type, input_index_json); | |
83 | auto dict = ArrayFromJSON(type, input_dictionary_json); | |
84 | ||
85 | ASSERT_OK_AND_ASSIGN(auto input, DictionaryArray::FromArrays(dict_type, indices, dict)); | |
86 | auto value_set = ArrayFromJSON(type, value_set_json); | |
87 | auto expected = ArrayFromJSON(boolean(), expected_json); | |
88 | ||
89 | ASSERT_OK_AND_ASSIGN(Datum actual_datum, | |
90 | IsIn(input, SetLookupOptions(value_set, skip_nulls))); | |
91 | std::shared_ptr<Array> actual = actual_datum.make_array(); | |
92 | ValidateOutput(actual_datum); | |
93 | AssertArraysEqual(*expected, *actual, /*verbose=*/true); | |
94 | } | |
95 | ||
96 | class TestIsInKernel : public ::testing::Test {}; | |
97 | ||
98 | TEST_F(TestIsInKernel, CallBinary) { | |
99 | auto input = ArrayFromJSON(int8(), "[0, 1, 2, 3, 4, 5, 6, 7, 8]"); | |
100 | auto value_set = ArrayFromJSON(int8(), "[2, 3, 5, 7]"); | |
101 | ASSERT_RAISES(Invalid, CallFunction("is_in", {input, value_set})); | |
102 | ||
103 | ASSERT_OK_AND_ASSIGN(Datum out, CallFunction("is_in_meta_binary", {input, value_set})); | |
104 | auto expected = ArrayFromJSON(boolean(), ("[false, false, true, true, false," | |
105 | "true, false, true, false]")); | |
106 | AssertArraysEqual(*expected, *out.make_array()); | |
107 | } | |
108 | ||
109 | TEST_F(TestIsInKernel, ImplicitlyCastValueSet) { | |
110 | auto input = ArrayFromJSON(int8(), "[0, 1, 2, 3, 4, 5, 6, 7, 8]"); | |
111 | ||
112 | SetLookupOptions opts{ArrayFromJSON(int32(), "[2, 3, 5, 7]")}; | |
113 | ASSERT_OK_AND_ASSIGN(Datum out, CallFunction("is_in", {input}, &opts)); | |
114 | ||
115 | auto expected = ArrayFromJSON(boolean(), ("[false, false, true, true, false," | |
116 | "true, false, true, false]")); | |
117 | AssertArraysEqual(*expected, *out.make_array()); | |
118 | ||
119 | // fails; value_set cannot be cast to int8 | |
120 | opts = SetLookupOptions{ArrayFromJSON(float32(), "[2.5, 3.1, 5.0]")}; | |
121 | ASSERT_RAISES(Invalid, CallFunction("is_in", {input}, &opts)); | |
122 | } | |
123 | ||
124 | template <typename Type> | |
125 | class TestIsInKernelPrimitive : public ::testing::Test {}; | |
126 | ||
127 | template <typename Type> | |
128 | class TestIsInKernelBinary : public ::testing::Test {}; | |
129 | ||
130 | using PrimitiveTypes = ::testing::Types<Int8Type, UInt8Type, Int16Type, UInt16Type, | |
131 | Int32Type, UInt32Type, Int64Type, UInt64Type, | |
132 | FloatType, DoubleType, Date32Type, Date64Type>; | |
133 | ||
134 | TYPED_TEST_SUITE(TestIsInKernelPrimitive, PrimitiveTypes); | |
135 | ||
136 | TYPED_TEST(TestIsInKernelPrimitive, IsIn) { | |
137 | auto type = TypeTraits<TypeParam>::type_singleton(); | |
138 | ||
139 | // No Nulls | |
140 | CheckIsIn(type, "[0, 1, 2, 3, 2]", "[2, 1]", "[false, true, true, false, true]"); | |
141 | ||
142 | // Nulls in left array | |
143 | CheckIsIn(type, "[null, 1, 2, 3, 2]", "[2, 1]", "[false, true, true, false, true]", | |
144 | /*skip_nulls=*/false); | |
145 | CheckIsIn(type, "[null, 1, 2, 3, 2]", "[2, 1]", "[false, true, true, false, true]", | |
146 | /*skip_nulls=*/true); | |
147 | ||
148 | // Nulls in right array | |
149 | CheckIsIn(type, "[0, 1, 2, 3, 2]", "[2, null, 1]", "[false, true, true, false, true]", | |
150 | /*skip_nulls=*/false); | |
151 | CheckIsIn(type, "[0, 1, 2, 3, 2]", "[2, null, 1]", "[false, true, true, false, true]", | |
152 | /*skip_nulls=*/true); | |
153 | ||
154 | // Nulls in both the arrays | |
155 | CheckIsIn(type, "[null, 1, 2, 3, 2]", "[2, null, 1]", "[true, true, true, false, true]", | |
156 | /*skip_nulls=*/false); | |
157 | CheckIsIn(type, "[null, 1, 2, 3, 2]", "[2, null, 1]", | |
158 | "[false, true, true, false, true]", /*skip_nulls=*/true); | |
159 | ||
160 | // Duplicates in right array | |
161 | CheckIsIn(type, "[null, 1, 2, 3, 2]", "[null, 2, 2, null, 1, 1]", | |
162 | "[true, true, true, false, true]", | |
163 | /*skip_nulls=*/false); | |
164 | CheckIsIn(type, "[null, 1, 2, 3, 2]", "[null, 2, 2, null, 1, 1]", | |
165 | "[false, true, true, false, true]", /*skip_nulls=*/true); | |
166 | ||
167 | // Empty Arrays | |
168 | CheckIsIn(type, "[]", "[]", "[]"); | |
169 | } | |
170 | ||
171 | TEST_F(TestIsInKernel, NullType) { | |
172 | auto type = null(); | |
173 | ||
174 | CheckIsIn(type, "[null, null, null]", "[null]", "[true, true, true]"); | |
175 | CheckIsIn(type, "[null, null, null]", "[]", "[false, false, false]"); | |
176 | CheckIsIn(type, "[]", "[]", "[]"); | |
177 | ||
178 | CheckIsIn(type, "[null, null]", "[null]", "[false, false]", /*skip_nulls=*/true); | |
179 | CheckIsIn(type, "[null, null]", "[]", "[false, false]", /*skip_nulls=*/true); | |
180 | ||
181 | // Duplicates in right array | |
182 | CheckIsIn(type, "[null, null, null]", "[null, null]", "[true, true, true]"); | |
183 | CheckIsIn(type, "[null, null]", "[null, null]", "[false, false]", /*skip_nulls=*/true); | |
184 | } | |
185 | ||
186 | TEST_F(TestIsInKernel, TimeTimestamp) { | |
187 | for (const auto& type : | |
188 | {time32(TimeUnit::SECOND), time64(TimeUnit::NANO), timestamp(TimeUnit::MICRO)}) { | |
189 | CheckIsIn(type, "[1, null, 5, 1, 2]", "[2, 1, null]", | |
190 | "[true, true, false, true, true]", /*skip_nulls=*/false); | |
191 | CheckIsIn(type, "[1, null, 5, 1, 2]", "[2, 1, null]", | |
192 | "[true, false, false, true, true]", /*skip_nulls=*/true); | |
193 | ||
194 | // Duplicates in right array | |
195 | CheckIsIn(type, "[1, null, 5, 1, 2]", "[2, 1, 1, null, 2]", | |
196 | "[true, true, false, true, true]", /*skip_nulls=*/false); | |
197 | CheckIsIn(type, "[1, null, 5, 1, 2]", "[2, 1, 1, null, 2]", | |
198 | "[true, false, false, true, true]", /*skip_nulls=*/true); | |
199 | } | |
200 | } | |
201 | ||
202 | TEST_F(TestIsInKernel, Boolean) { | |
203 | auto type = boolean(); | |
204 | ||
205 | CheckIsIn(type, "[true, false, null, true, false]", "[false]", | |
206 | "[false, true, false, false, true]", /*skip_nulls=*/false); | |
207 | CheckIsIn(type, "[true, false, null, true, false]", "[false]", | |
208 | "[false, true, false, false, true]", /*skip_nulls=*/true); | |
209 | ||
210 | CheckIsIn(type, "[true, false, null, true, false]", "[false, null]", | |
211 | "[false, true, true, false, true]", /*skip_nulls=*/false); | |
212 | CheckIsIn(type, "[true, false, null, true, false]", "[false, null]", | |
213 | "[false, true, false, false, true]", /*skip_nulls=*/true); | |
214 | ||
215 | // Duplicates in right array | |
216 | CheckIsIn(type, "[true, false, null, true, false]", "[null, false, false, null]", | |
217 | "[false, true, true, false, true]", /*skip_nulls=*/false); | |
218 | CheckIsIn(type, "[true, false, null, true, false]", "[null, false, false, null]", | |
219 | "[false, true, false, false, true]", /*skip_nulls=*/true); | |
220 | } | |
221 | ||
222 | TYPED_TEST_SUITE(TestIsInKernelBinary, BinaryArrowTypes); | |
223 | ||
224 | TYPED_TEST(TestIsInKernelBinary, Binary) { | |
225 | auto type = TypeTraits<TypeParam>::type_singleton(); | |
226 | ||
227 | CheckIsIn(type, R"(["aaa", "", "cc", null, ""])", R"(["aaa", ""])", | |
228 | "[true, true, false, false, true]", | |
229 | /*skip_nulls=*/false); | |
230 | CheckIsIn(type, R"(["aaa", "", "cc", null, ""])", R"(["aaa", ""])", | |
231 | "[true, true, false, false, true]", | |
232 | /*skip_nulls=*/true); | |
233 | ||
234 | CheckIsIn(type, R"(["aaa", "", "cc", null, ""])", R"(["aaa", "", null])", | |
235 | "[true, true, false, true, true]", | |
236 | /*skip_nulls=*/false); | |
237 | CheckIsIn(type, R"(["aaa", "", "cc", null, ""])", R"(["aaa", "", null])", | |
238 | "[true, true, false, false, true]", | |
239 | /*skip_nulls=*/true); | |
240 | ||
241 | // Duplicates in right array | |
242 | CheckIsIn(type, R"(["aaa", "", "cc", null, ""])", | |
243 | R"([null, "aaa", "aaa", "", "", null])", "[true, true, false, true, true]", | |
244 | /*skip_nulls=*/false); | |
245 | CheckIsIn(type, R"(["aaa", "", "cc", null, ""])", | |
246 | R"([null, "aaa", "aaa", "", "", null])", "[true, true, false, false, true]", | |
247 | /*skip_nulls=*/true); | |
248 | } | |
249 | ||
250 | TEST_F(TestIsInKernel, FixedSizeBinary) { | |
251 | auto type = fixed_size_binary(3); | |
252 | ||
253 | CheckIsIn(type, R"(["aaa", "bbb", "ccc", null, "bbb"])", R"(["aaa", "bbb"])", | |
254 | "[true, true, false, false, true]", | |
255 | /*skip_nulls=*/false); | |
256 | CheckIsIn(type, R"(["aaa", "bbb", "ccc", null, "bbb"])", R"(["aaa", "bbb"])", | |
257 | "[true, true, false, false, true]", | |
258 | /*skip_nulls=*/true); | |
259 | ||
260 | CheckIsIn(type, R"(["aaa", "bbb", "ccc", null, "bbb"])", R"(["aaa", "bbb", null])", | |
261 | "[true, true, false, true, true]", | |
262 | /*skip_nulls=*/false); | |
263 | CheckIsIn(type, R"(["aaa", "bbb", "ccc", null, "bbb"])", R"(["aaa", "bbb", null])", | |
264 | "[true, true, false, false, true]", | |
265 | /*skip_nulls=*/true); | |
266 | ||
267 | // Duplicates in right array | |
268 | CheckIsIn(type, R"(["aaa", "bbb", "ccc", null, "bbb"])", | |
269 | R"(["aaa", null, "aaa", "bbb", "bbb", null])", | |
270 | "[true, true, false, true, true]", | |
271 | /*skip_nulls=*/false); | |
272 | CheckIsIn(type, R"(["aaa", "bbb", "ccc", null, "bbb"])", | |
273 | R"(["aaa", null, "aaa", "bbb", "bbb", null])", | |
274 | "[true, true, false, false, true]", | |
275 | /*skip_nulls=*/true); | |
276 | } | |
277 | ||
278 | TEST_F(TestIsInKernel, Decimal) { | |
279 | auto type = decimal(3, 1); | |
280 | ||
281 | CheckIsIn(type, R"(["12.3", "45.6", "78.9", null, "12.3"])", R"(["12.3", "78.9"])", | |
282 | "[true, false, true, false, true]", | |
283 | /*skip_nulls=*/false); | |
284 | CheckIsIn(type, R"(["12.3", "45.6", "78.9", null, "12.3"])", R"(["12.3", "78.9"])", | |
285 | "[true, false, true, false, true]", | |
286 | /*skip_nulls=*/true); | |
287 | ||
288 | CheckIsIn(type, R"(["12.3", "45.6", "78.9", null, "12.3"])", | |
289 | R"(["12.3", "78.9", null])", "[true, false, true, true, true]", | |
290 | /*skip_nulls=*/false); | |
291 | CheckIsIn(type, R"(["12.3", "45.6", "78.9", null, "12.3"])", | |
292 | R"(["12.3", "78.9", null])", "[true, false, true, false, true]", | |
293 | /*skip_nulls=*/true); | |
294 | ||
295 | // Duplicates in right array | |
296 | CheckIsIn(type, R"(["12.3", "45.6", "78.9", null, "12.3"])", | |
297 | R"([null, "12.3", "12.3", "78.9", "78.9", null])", | |
298 | "[true, false, true, true, true]", | |
299 | /*skip_nulls=*/false); | |
300 | CheckIsIn(type, R"(["12.3", "45.6", "78.9", null, "12.3"])", | |
301 | R"([null, "12.3", "12.3", "78.9", "78.9", null])", | |
302 | "[true, false, true, false, true]", | |
303 | /*skip_nulls=*/true); | |
304 | } | |
305 | ||
306 | TEST_F(TestIsInKernel, DictionaryArray) { | |
307 | for (auto index_ty : all_dictionary_index_types()) { | |
308 | CheckIsInDictionary(/*type=*/utf8(), | |
309 | /*index_type=*/index_ty, | |
310 | /*input_dictionary_json=*/R"(["A", "B", "C", "D"])", | |
311 | /*input_index_json=*/"[1, 2, null, 0]", | |
312 | /*value_set_json=*/R"(["A", "B", "C"])", | |
313 | /*expected_json=*/"[true, true, false, true]", | |
314 | /*skip_nulls=*/false); | |
315 | CheckIsInDictionary(/*type=*/float32(), | |
316 | /*index_type=*/index_ty, | |
317 | /*input_dictionary_json=*/"[4.1, -1.0, 42, 9.8]", | |
318 | /*input_index_json=*/"[1, 2, null, 0]", | |
319 | /*value_set_json=*/"[4.1, 42, -1.0]", | |
320 | /*expected_json=*/"[true, true, false, true]", | |
321 | /*skip_nulls=*/false); | |
322 | ||
323 | // With nulls and skip_nulls=false | |
324 | CheckIsInDictionary(/*type=*/utf8(), | |
325 | /*index_type=*/index_ty, | |
326 | /*input_dictionary_json=*/R"(["A", "B", "C", "D"])", | |
327 | /*input_index_json=*/"[1, 3, null, 0, 1]", | |
328 | /*value_set_json=*/R"(["C", "B", "A", null])", | |
329 | /*expected_json=*/"[true, false, true, true, true]", | |
330 | /*skip_nulls=*/false); | |
331 | CheckIsInDictionary(/*type=*/utf8(), | |
332 | /*index_type=*/index_ty, | |
333 | /*input_dictionary_json=*/R"(["A", null, "C", "D"])", | |
334 | /*input_index_json=*/"[1, 3, null, 0, 1]", | |
335 | /*value_set_json=*/R"(["C", "B", "A", null])", | |
336 | /*expected_json=*/"[true, false, true, true, true]", | |
337 | /*skip_nulls=*/false); | |
338 | CheckIsInDictionary(/*type=*/utf8(), | |
339 | /*index_type=*/index_ty, | |
340 | /*input_dictionary_json=*/R"(["A", null, "C", "D"])", | |
341 | /*input_index_json=*/"[1, 3, null, 0, 1]", | |
342 | /*value_set_json=*/R"(["C", "B", "A"])", | |
343 | /*expected_json=*/"[false, false, false, true, false]", | |
344 | /*skip_nulls=*/false); | |
345 | ||
346 | // With nulls and skip_nulls=true | |
347 | CheckIsInDictionary(/*type=*/utf8(), | |
348 | /*index_type=*/index_ty, | |
349 | /*input_dictionary_json=*/R"(["A", "B", "C", "D"])", | |
350 | /*input_index_json=*/"[1, 3, null, 0, 1]", | |
351 | /*value_set_json=*/R"(["C", "B", "A", null])", | |
352 | /*expected_json=*/"[true, false, false, true, true]", | |
353 | /*skip_nulls=*/true); | |
354 | CheckIsInDictionary(/*type=*/utf8(), | |
355 | /*index_type=*/index_ty, | |
356 | /*input_dictionary_json=*/R"(["A", null, "C", "D"])", | |
357 | /*input_index_json=*/"[1, 3, null, 0, 1]", | |
358 | /*value_set_json=*/R"(["C", "B", "A", null])", | |
359 | /*expected_json=*/"[false, false, false, true, false]", | |
360 | /*skip_nulls=*/true); | |
361 | CheckIsInDictionary(/*type=*/utf8(), | |
362 | /*index_type=*/index_ty, | |
363 | /*input_dictionary_json=*/R"(["A", null, "C", "D"])", | |
364 | /*input_index_json=*/"[1, 3, null, 0, 1]", | |
365 | /*value_set_json=*/R"(["C", "B", "A"])", | |
366 | /*expected_json=*/"[false, false, false, true, false]", | |
367 | /*skip_nulls=*/true); | |
368 | ||
369 | // With duplicates in value_set | |
370 | CheckIsInDictionary(/*type=*/utf8(), | |
371 | /*index_type=*/index_ty, | |
372 | /*input_dictionary_json=*/R"(["A", "B", "C", "D"])", | |
373 | /*input_index_json=*/"[1, 2, null, 0]", | |
374 | /*value_set_json=*/R"(["A", "A", "B", "A", "B", "C"])", | |
375 | /*expected_json=*/"[true, true, false, true]", | |
376 | /*skip_nulls=*/false); | |
377 | CheckIsInDictionary(/*type=*/utf8(), | |
378 | /*index_type=*/index_ty, | |
379 | /*input_dictionary_json=*/R"(["A", "B", "C", "D"])", | |
380 | /*input_index_json=*/"[1, 3, null, 0, 1]", | |
381 | /*value_set_json=*/R"(["C", "C", "B", "A", null, null, "B"])", | |
382 | /*expected_json=*/"[true, false, true, true, true]", | |
383 | /*skip_nulls=*/false); | |
384 | CheckIsInDictionary(/*type=*/utf8(), | |
385 | /*index_type=*/index_ty, | |
386 | /*input_dictionary_json=*/R"(["A", "B", "C", "D"])", | |
387 | /*input_index_json=*/"[1, 3, null, 0, 1]", | |
388 | /*value_set_json=*/R"(["C", "C", "B", "A", null, null, "B"])", | |
389 | /*expected_json=*/"[true, false, false, true, true]", | |
390 | /*skip_nulls=*/true); | |
391 | } | |
392 | } | |
393 | ||
394 | TEST_F(TestIsInKernel, ChunkedArrayInvoke) { | |
395 | auto input = ChunkedArrayFromJSON( | |
396 | utf8(), {R"(["abc", "def", "", "abc", "jkl"])", R"(["def", null, "abc", "zzz"])"}); | |
397 | // No null in value_set | |
398 | auto value_set = ChunkedArrayFromJSON(utf8(), {R"(["", "def"])", R"(["abc"])"}); | |
399 | auto expected = ChunkedArrayFromJSON( | |
400 | boolean(), {"[true, true, true, true, false]", "[true, false, true, false]"}); | |
401 | ||
402 | CheckIsInChunked(input, value_set, expected, /*skip_nulls=*/false); | |
403 | CheckIsInChunked(input, value_set, expected, /*skip_nulls=*/true); | |
404 | ||
405 | value_set = ChunkedArrayFromJSON(utf8(), {R"(["", "def"])", R"([null])"}); | |
406 | expected = ChunkedArrayFromJSON( | |
407 | boolean(), {"[false, true, true, false, false]", "[true, true, false, false]"}); | |
408 | CheckIsInChunked(input, value_set, expected, /*skip_nulls=*/false); | |
409 | expected = ChunkedArrayFromJSON( | |
410 | boolean(), {"[false, true, true, false, false]", "[true, false, false, false]"}); | |
411 | CheckIsInChunked(input, value_set, expected, /*skip_nulls=*/true); | |
412 | ||
413 | // Duplicates in value_set | |
414 | value_set = | |
415 | ChunkedArrayFromJSON(utf8(), {R"(["", null, "", "def"])", R"(["def", null])"}); | |
416 | expected = ChunkedArrayFromJSON( | |
417 | boolean(), {"[false, true, true, false, false]", "[true, true, false, false]"}); | |
418 | CheckIsInChunked(input, value_set, expected, /*skip_nulls=*/false); | |
419 | expected = ChunkedArrayFromJSON( | |
420 | boolean(), {"[false, true, true, false, false]", "[true, false, false, false]"}); | |
421 | CheckIsInChunked(input, value_set, expected, /*skip_nulls=*/true); | |
422 | } | |
423 | ||
424 | // ---------------------------------------------------------------------- | |
425 | // IndexIn tests | |
426 | ||
427 | class TestIndexInKernel : public ::testing::Test { | |
428 | public: | |
429 | void CheckIndexIn(const std::shared_ptr<DataType>& type, const std::string& input_json, | |
430 | const std::string& value_set_json, const std::string& expected_json, | |
431 | bool skip_nulls = false) { | |
432 | std::shared_ptr<Array> input = ArrayFromJSON(type, input_json); | |
433 | std::shared_ptr<Array> value_set = ArrayFromJSON(type, value_set_json); | |
434 | std::shared_ptr<Array> expected = ArrayFromJSON(int32(), expected_json); | |
435 | ||
436 | SetLookupOptions options(value_set, skip_nulls); | |
437 | ASSERT_OK_AND_ASSIGN(Datum actual_datum, IndexIn(input, options)); | |
438 | std::shared_ptr<Array> actual = actual_datum.make_array(); | |
439 | ValidateOutput(actual_datum); | |
440 | AssertArraysEqual(*expected, *actual, /*verbose=*/true); | |
441 | } | |
442 | ||
443 | void CheckIndexInChunked(const std::shared_ptr<ChunkedArray>& input, | |
444 | const std::shared_ptr<ChunkedArray>& value_set, | |
445 | const std::shared_ptr<ChunkedArray>& expected, | |
446 | bool skip_nulls) { | |
447 | ASSERT_OK_AND_ASSIGN(Datum actual, | |
448 | IndexIn(input, SetLookupOptions(value_set, skip_nulls))); | |
449 | ASSERT_EQ(Datum::CHUNKED_ARRAY, actual.kind()); | |
450 | ValidateOutput(actual); | |
451 | AssertChunkedEqual(*expected, *actual.chunked_array()); | |
452 | } | |
453 | ||
454 | void CheckIndexInDictionary(const std::shared_ptr<DataType>& type, | |
455 | const std::shared_ptr<DataType>& index_type, | |
456 | const std::string& input_dictionary_json, | |
457 | const std::string& input_index_json, | |
458 | const std::string& value_set_json, | |
459 | const std::string& expected_json, bool skip_nulls = false) { | |
460 | auto dict_type = dictionary(index_type, type); | |
461 | auto indices = ArrayFromJSON(index_type, input_index_json); | |
462 | auto dict = ArrayFromJSON(type, input_dictionary_json); | |
463 | ||
464 | ASSERT_OK_AND_ASSIGN(auto input, | |
465 | DictionaryArray::FromArrays(dict_type, indices, dict)); | |
466 | auto value_set = ArrayFromJSON(type, value_set_json); | |
467 | auto expected = ArrayFromJSON(int32(), expected_json); | |
468 | ||
469 | SetLookupOptions options(value_set, skip_nulls); | |
470 | ASSERT_OK_AND_ASSIGN(Datum actual_datum, IndexIn(input, options)); | |
471 | std::shared_ptr<Array> actual = actual_datum.make_array(); | |
472 | ValidateOutput(actual_datum); | |
473 | AssertArraysEqual(*expected, *actual, /*verbose=*/true); | |
474 | } | |
475 | }; | |
476 | ||
477 | TEST_F(TestIndexInKernel, CallBinary) { | |
478 | auto input = ArrayFromJSON(int8(), "[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10]"); | |
479 | auto value_set = ArrayFromJSON(int8(), "[2, 3, 5, 7]"); | |
480 | ASSERT_RAISES(Invalid, CallFunction("index_in", {input, value_set})); | |
481 | ||
482 | ASSERT_OK_AND_ASSIGN(Datum out, | |
483 | CallFunction("index_in_meta_binary", {input, value_set})); | |
484 | auto expected = ArrayFromJSON(int32(), ("[null, null, 0, 1, null, 2, null, 3, null," | |
485 | " null, null]")); | |
486 | AssertArraysEqual(*expected, *out.make_array()); | |
487 | } | |
488 | ||
489 | template <typename Type> | |
490 | class TestIndexInKernelPrimitive : public TestIndexInKernel {}; | |
491 | ||
492 | using PrimitiveDictionaries = | |
493 | ::testing::Types<Int8Type, UInt8Type, Int16Type, UInt16Type, Int32Type, UInt32Type, | |
494 | Int64Type, UInt64Type, FloatType, DoubleType, Date32Type, | |
495 | Date64Type>; | |
496 | ||
497 | TYPED_TEST_SUITE(TestIndexInKernelPrimitive, PrimitiveDictionaries); | |
498 | ||
499 | TYPED_TEST(TestIndexInKernelPrimitive, IndexIn) { | |
500 | auto type = TypeTraits<TypeParam>::type_singleton(); | |
501 | ||
502 | // No Nulls | |
503 | this->CheckIndexIn(type, | |
504 | /* input= */ "[2, 1, 2, 1, 2, 3]", | |
505 | /* value_set= */ "[2, 1, 3]", | |
506 | /* expected= */ "[0, 1, 0, 1, 0, 2]"); | |
507 | ||
508 | // Haystack array all null | |
509 | this->CheckIndexIn(type, | |
510 | /* input= */ "[null, null, null, null, null, null]", | |
511 | /* value_set= */ "[2, 1, 3]", | |
512 | /* expected= */ "[null, null, null, null, null, null]"); | |
513 | ||
514 | // Needles array all null | |
515 | this->CheckIndexIn(type, | |
516 | /* input= */ "[2, 1, 2, 1, 2, 3]", | |
517 | /* value_set= */ "[null]", | |
518 | /* expected= */ "[null, null, null, null, null, null]"); | |
519 | ||
520 | // Both arrays all null | |
521 | this->CheckIndexIn(type, | |
522 | /* input= */ "[null, null, null, null]", | |
523 | /* value_set= */ "[null]", | |
524 | /* expected= */ "[0, 0, 0, 0]"); | |
525 | ||
526 | // Duplicates in value_set | |
527 | this->CheckIndexIn(type, | |
528 | /* input= */ "[2, 1, 2, 1, 2, 3]", | |
529 | /* value_set= */ "[2, 2, 1, 1, 1, 3, 3]", | |
530 | /* expected= */ "[0, 2, 0, 2, 0, 5]"); | |
531 | ||
532 | // Duplicates and nulls in value_set | |
533 | this->CheckIndexIn(type, | |
534 | /* input= */ "[2, 1, 2, 1, 2, 3]", | |
535 | /* value_set= */ "[2, 2, null, null, 1, 1, 1, 3, 3]", | |
536 | /* expected= */ "[0, 4, 0, 4, 0, 7]"); | |
537 | ||
538 | // No Match | |
539 | this->CheckIndexIn(type, | |
540 | /* input= */ "[2, null, 7, 3, 8]", | |
541 | /* value_set= */ "[2, null, 6, 3]", | |
542 | /* expected= */ "[0, 1, null, 3, null]"); | |
543 | ||
544 | // Empty Arrays | |
545 | this->CheckIndexIn(type, "[]", "[]", "[]"); | |
546 | } | |
547 | ||
548 | TYPED_TEST(TestIndexInKernelPrimitive, SkipNulls) { | |
549 | auto type = TypeTraits<TypeParam>::type_singleton(); | |
550 | ||
551 | // No nulls in value_set | |
552 | this->CheckIndexIn(type, | |
553 | /*input=*/"[0, 1, 2, 3, null]", | |
554 | /*value_set=*/"[1, 3]", | |
555 | /*expected=*/"[null, 0, null, 1, null]", | |
556 | /*skip_nulls=*/false); | |
557 | this->CheckIndexIn(type, | |
558 | /*input=*/"[0, 1, 2, 3, null]", | |
559 | /*value_set=*/"[1, 3]", | |
560 | /*expected=*/"[null, 0, null, 1, null]", | |
561 | /*skip_nulls=*/true); | |
562 | // Same with duplicates in value_set | |
563 | this->CheckIndexIn(type, | |
564 | /*input=*/"[0, 1, 2, 3, null]", | |
565 | /*value_set=*/"[1, 1, 3, 3]", | |
566 | /*expected=*/"[null, 0, null, 2, null]", | |
567 | /*skip_nulls=*/false); | |
568 | this->CheckIndexIn(type, | |
569 | /*input=*/"[0, 1, 2, 3, null]", | |
570 | /*value_set=*/"[1, 1, 3, 3]", | |
571 | /*expected=*/"[null, 0, null, 2, null]", | |
572 | /*skip_nulls=*/true); | |
573 | ||
574 | // Nulls in value_set | |
575 | this->CheckIndexIn(type, | |
576 | /*input=*/"[0, 1, 2, 3, null]", | |
577 | /*value_set=*/"[1, null, 3]", | |
578 | /*expected=*/"[null, 0, null, 2, 1]", | |
579 | /*skip_nulls=*/false); | |
580 | this->CheckIndexIn(type, | |
581 | /*input=*/"[0, 1, 2, 3, null]", | |
582 | /*value_set=*/"[1, 1, null, null, 3, 3]", | |
583 | /*expected=*/"[null, 0, null, 4, null]", | |
584 | /*skip_nulls=*/true); | |
585 | // Same with duplicates in value_set | |
586 | this->CheckIndexIn(type, | |
587 | /*input=*/"[0, 1, 2, 3, null]", | |
588 | /*value_set=*/"[1, 1, null, null, 3, 3]", | |
589 | /*expected=*/"[null, 0, null, 4, 2]", | |
590 | /*skip_nulls=*/false); | |
591 | } | |
592 | ||
593 | TEST_F(TestIndexInKernel, NullType) { | |
594 | CheckIndexIn(null(), "[null, null, null]", "[null]", "[0, 0, 0]"); | |
595 | CheckIndexIn(null(), "[null, null, null]", "[]", "[null, null, null]"); | |
596 | CheckIndexIn(null(), "[]", "[null, null]", "[]"); | |
597 | CheckIndexIn(null(), "[]", "[]", "[]"); | |
598 | ||
599 | CheckIndexIn(null(), "[null, null]", "[null]", "[null, null]", /*skip_nulls=*/true); | |
600 | CheckIndexIn(null(), "[null, null]", "[]", "[null, null]", /*skip_nulls=*/true); | |
601 | } | |
602 | ||
603 | TEST_F(TestIndexInKernel, TimeTimestamp) { | |
604 | CheckIndexIn(time32(TimeUnit::SECOND), | |
605 | /* input= */ "[1, null, 5, 1, 2]", | |
606 | /* value_set= */ "[2, 1, null]", | |
607 | /* expected= */ "[1, 2, null, 1, 0]"); | |
608 | ||
609 | // Duplicates in value_set | |
610 | CheckIndexIn(time32(TimeUnit::SECOND), | |
611 | /* input= */ "[1, null, 5, 1, 2]", | |
612 | /* value_set= */ "[2, 2, 1, 1, null, null]", | |
613 | /* expected= */ "[2, 4, null, 2, 0]"); | |
614 | ||
615 | // Needles array has no nulls | |
616 | CheckIndexIn(time32(TimeUnit::SECOND), | |
617 | /* input= */ "[2, null, 5, 1]", | |
618 | /* value_set= */ "[2, 1]", | |
619 | /* expected= */ "[0, null, null, 1]"); | |
620 | ||
621 | // No match | |
622 | CheckIndexIn(time32(TimeUnit::SECOND), "[3, null, 5, 3]", "[2, 1]", | |
623 | "[null, null, null, null]"); | |
624 | ||
625 | // Empty arrays | |
626 | CheckIndexIn(time32(TimeUnit::SECOND), "[]", "[]", "[]"); | |
627 | ||
628 | CheckIndexIn(time64(TimeUnit::NANO), "[2, null, 2, 1]", "[2, null, 1]", "[0, 1, 0, 2]"); | |
629 | ||
630 | CheckIndexIn(timestamp(TimeUnit::NANO), "[2, null, 2, 1]", "[2, null, 1]", | |
631 | "[0, 1, 0, 2]"); | |
632 | ||
633 | // Empty input array | |
634 | CheckIndexIn(timestamp(TimeUnit::NANO), "[]", "[2, null, 1]", "[]"); | |
635 | ||
636 | // Empty value_set array | |
637 | CheckIndexIn(timestamp(TimeUnit::NANO), "[2, null, 1]", "[]", "[null, null, null]"); | |
638 | ||
639 | // Both array are all null | |
640 | CheckIndexIn(time32(TimeUnit::SECOND), "[null, null, null, null]", "[null]", | |
641 | "[0, 0, 0, 0]"); | |
642 | } | |
643 | ||
644 | TEST_F(TestIndexInKernel, Boolean) { | |
645 | CheckIndexIn(boolean(), | |
646 | /* input= */ "[false, null, false, true]", | |
647 | /* value_set= */ "[null, false, true]", | |
648 | /* expected= */ "[1, 0, 1, 2]"); | |
649 | ||
650 | CheckIndexIn(boolean(), "[false, null, false, true]", "[false, true, null]", | |
651 | "[0, 2, 0, 1]"); | |
652 | ||
653 | // Duplicates in value_set | |
654 | CheckIndexIn(boolean(), "[false, null, false, true]", | |
655 | "[false, false, true, true, null, null]", "[0, 4, 0, 2]"); | |
656 | ||
657 | // No Nulls | |
658 | CheckIndexIn(boolean(), "[true, true, false, true]", "[false, true]", "[1, 1, 0, 1]"); | |
659 | ||
660 | CheckIndexIn(boolean(), "[false, true, false, true]", "[true]", "[null, 0, null, 0]"); | |
661 | ||
662 | // No match | |
663 | CheckIndexIn(boolean(), "[true, true, true, true]", "[false]", | |
664 | "[null, null, null, null]"); | |
665 | ||
666 | // Nulls in input array | |
667 | CheckIndexIn(boolean(), "[null, null, null, null]", "[true]", | |
668 | "[null, null, null, null]"); | |
669 | ||
670 | // Nulls in value_set array | |
671 | CheckIndexIn(boolean(), "[true, true, false, true]", "[null]", | |
672 | "[null, null, null, null]"); | |
673 | ||
674 | // Both array have Nulls | |
675 | CheckIndexIn(boolean(), "[null, null, null, null]", "[null]", "[0, 0, 0, 0]"); | |
676 | } | |
677 | ||
678 | template <typename Type> | |
679 | class TestIndexInKernelBinary : public TestIndexInKernel {}; | |
680 | ||
681 | TYPED_TEST_SUITE(TestIndexInKernelBinary, BinaryArrowTypes); | |
682 | ||
683 | TYPED_TEST(TestIndexInKernelBinary, Binary) { | |
684 | auto type = TypeTraits<TypeParam>::type_singleton(); | |
685 | this->CheckIndexIn(type, R"(["foo", null, "bar", "foo"])", R"(["foo", null, "bar"])", | |
686 | R"([0, 1, 2, 0])"); | |
687 | ||
688 | // Duplicates in value_set | |
689 | this->CheckIndexIn(type, R"(["foo", null, "bar", "foo"])", | |
690 | R"(["foo", "foo", null, null, "bar", "bar"])", R"([0, 2, 4, 0])"); | |
691 | ||
692 | // No match | |
693 | this->CheckIndexIn(type, | |
694 | /* input= */ R"(["foo", null, "bar", "foo"])", | |
695 | /* value_set= */ R"(["baz", "bazzz"])", | |
696 | /* expected= */ R"([null, null, null, null])"); | |
697 | ||
698 | // Nulls in input array | |
699 | this->CheckIndexIn(type, | |
700 | /* input= */ R"([null, null, null, null])", | |
701 | /* value_set= */ R"(["foo", "bar"])", | |
702 | /* expected= */ R"([null, null, null, null])"); | |
703 | ||
704 | // Nulls in value_set array | |
705 | this->CheckIndexIn(type, R"(["foo", "bar", "foo"])", R"([null])", | |
706 | R"([null, null, null])"); | |
707 | ||
708 | // Both array have Nulls | |
709 | this->CheckIndexIn(type, | |
710 | /* input= */ R"([null, null, null, null])", | |
711 | /* value_set= */ R"([null])", | |
712 | /* expected= */ R"([0, 0, 0, 0])"); | |
713 | ||
714 | // Empty arrays | |
715 | this->CheckIndexIn(type, R"([])", R"([])", R"([])"); | |
716 | ||
717 | // Empty input array | |
718 | this->CheckIndexIn(type, R"([])", R"(["foo", null, "bar"])", "[]"); | |
719 | ||
720 | // Empty value_set array | |
721 | this->CheckIndexIn(type, R"(["foo", null, "bar", "foo"])", "[]", | |
722 | R"([null, null, null, null])"); | |
723 | } | |
724 | ||
725 | TEST_F(TestIndexInKernel, BinaryResizeTable) { | |
726 | const int32_t kTotalValues = 10000; | |
727 | #if !defined(ARROW_VALGRIND) | |
728 | const int32_t kRepeats = 10; | |
729 | #else | |
730 | // Mitigate Valgrind's slowness | |
731 | const int32_t kRepeats = 3; | |
732 | #endif | |
733 | ||
734 | const int32_t kBufSize = 20; | |
735 | ||
736 | Int32Builder expected_builder; | |
737 | StringBuilder input_builder; | |
738 | ASSERT_OK(expected_builder.Resize(kTotalValues * kRepeats)); | |
739 | ASSERT_OK(input_builder.Resize(kTotalValues * kRepeats)); | |
740 | ASSERT_OK(input_builder.ReserveData(kBufSize * kTotalValues * kRepeats)); | |
741 | ||
742 | for (int32_t i = 0; i < kTotalValues * kRepeats; i++) { | |
743 | int32_t index = i % kTotalValues; | |
744 | ||
745 | char buf[kBufSize] = "test"; | |
746 | ASSERT_GE(snprintf(buf + 4, sizeof(buf) - 4, "%d", index), 0); | |
747 | ||
748 | input_builder.UnsafeAppend(util::string_view(buf)); | |
749 | expected_builder.UnsafeAppend(index); | |
750 | } | |
751 | ||
752 | std::shared_ptr<Array> input, value_set, expected; | |
753 | ASSERT_OK(input_builder.Finish(&input)); | |
754 | value_set = input->Slice(0, kTotalValues); | |
755 | ASSERT_OK(expected_builder.Finish(&expected)); | |
756 | ||
757 | ASSERT_OK_AND_ASSIGN(Datum actual_datum, IndexIn(input, value_set)); | |
758 | std::shared_ptr<Array> actual = actual_datum.make_array(); | |
759 | ASSERT_ARRAYS_EQUAL(*expected, *actual); | |
760 | } | |
761 | ||
762 | TEST_F(TestIndexInKernel, FixedSizeBinary) { | |
763 | CheckIndexIn(fixed_size_binary(3), | |
764 | /*input=*/R"(["bbb", null, "ddd", "aaa", "ccc", "aaa"])", | |
765 | /*value_set=*/R"(["aaa", null, "bbb", "ccc"])", | |
766 | /*expected=*/R"([2, 1, null, 0, 3, 0])"); | |
767 | CheckIndexIn(fixed_size_binary(3), | |
768 | /*input=*/R"(["bbb", null, "ddd", "aaa", "ccc", "aaa"])", | |
769 | /*value_set=*/R"(["aaa", null, "bbb", "ccc"])", | |
770 | /*expected=*/R"([2, null, null, 0, 3, 0])", | |
771 | /*skip_nulls=*/true); | |
772 | ||
773 | CheckIndexIn(fixed_size_binary(3), | |
774 | /*input=*/R"(["bbb", null, "ddd", "aaa", "ccc", "aaa"])", | |
775 | /*value_set=*/R"(["aaa", "bbb", "ccc"])", | |
776 | /*expected=*/R"([1, null, null, 0, 2, 0])"); | |
777 | CheckIndexIn(fixed_size_binary(3), | |
778 | /*input=*/R"(["bbb", null, "ddd", "aaa", "ccc", "aaa"])", | |
779 | /*value_set=*/R"(["aaa", "bbb", "ccc"])", | |
780 | /*expected=*/R"([1, null, null, 0, 2, 0])", | |
781 | /*skip_nulls=*/true); | |
782 | ||
783 | // Duplicates in value_set | |
784 | CheckIndexIn(fixed_size_binary(3), | |
785 | /*input=*/R"(["bbb", null, "ddd", "aaa", "ccc", "aaa"])", | |
786 | /*value_set=*/R"(["aaa", "aaa", null, null, "bbb", "bbb", "ccc"])", | |
787 | /*expected=*/R"([4, 2, null, 0, 6, 0])"); | |
788 | CheckIndexIn(fixed_size_binary(3), | |
789 | /*input=*/R"(["bbb", null, "ddd", "aaa", "ccc", "aaa"])", | |
790 | /*value_set=*/R"(["aaa", "aaa", null, null, "bbb", "bbb", "ccc"])", | |
791 | /*expected=*/R"([4, null, null, 0, 6, 0])", | |
792 | /*skip_nulls=*/true); | |
793 | ||
794 | // Empty input array | |
795 | CheckIndexIn(fixed_size_binary(5), R"([])", R"(["bbbbb", null, "aaaaa", "ccccc"])", | |
796 | R"([])"); | |
797 | ||
798 | // Empty value_set array | |
799 | CheckIndexIn(fixed_size_binary(5), R"(["bbbbb", null, "bbbbb"])", R"([])", | |
800 | R"([null, null, null])"); | |
801 | ||
802 | // Empty arrays | |
803 | CheckIndexIn(fixed_size_binary(0), R"([])", R"([])", R"([])"); | |
804 | } | |
805 | ||
806 | TEST_F(TestIndexInKernel, MonthDayNanoInterval) { | |
807 | auto type = month_day_nano_interval(); | |
808 | ||
809 | CheckIndexIn(type, | |
810 | /*input=*/R"([[5, -1, 5], null, [4, 5, 6], [5, -1, 5], [1, 2, 3]])", | |
811 | /*value_set=*/R"([null, [4, 5, 6], [5, -1, 5]])", | |
812 | /*expected=*/R"([2, 0, 1, 2, null])", | |
813 | /*skip_nulls=*/false); | |
814 | ||
815 | // Duplicates in value_set | |
816 | CheckIndexIn( | |
817 | type, | |
818 | /*input=*/R"([[7, 8, 0], null, [0, 0, 0], [7, 8, 0], [0, 0, 1]])", | |
819 | /*value_set=*/R"([null, null, [0, 0, 0], [0, 0, 0], [7, 8, 0], [7, 8, 0]])", | |
820 | /*expected=*/R"([4, 0, 2, 4, null])", | |
821 | /*skip_nulls=*/false); | |
822 | } | |
823 | ||
824 | TEST_F(TestIndexInKernel, Decimal) { | |
825 | auto type = decimal(2, 0); | |
826 | ||
827 | CheckIndexIn(type, | |
828 | /*input=*/R"(["12", null, "11", "12", "13"])", | |
829 | /*value_set=*/R"([null, "11", "12"])", | |
830 | /*expected=*/R"([2, 0, 1, 2, null])", | |
831 | /*skip_nulls=*/false); | |
832 | CheckIndexIn(type, | |
833 | /*input=*/R"(["12", null, "11", "12", "13"])", | |
834 | /*value_set=*/R"([null, "11", "12"])", | |
835 | /*expected=*/R"([2, null, 1, 2, null])", | |
836 | /*skip_nulls=*/true); | |
837 | ||
838 | CheckIndexIn(type, | |
839 | /*input=*/R"(["12", null, "11", "12", "13"])", | |
840 | /*value_set=*/R"(["11", "12"])", | |
841 | /*expected=*/R"([1, null, 0, 1, null])", | |
842 | /*skip_nulls=*/false); | |
843 | CheckIndexIn(type, | |
844 | /*input=*/R"(["12", null, "11", "12", "13"])", | |
845 | /*value_set=*/R"(["11", "12"])", | |
846 | /*expected=*/R"([1, null, 0, 1, null])", | |
847 | /*skip_nulls=*/true); | |
848 | ||
849 | // Duplicates in value_set | |
850 | CheckIndexIn(type, | |
851 | /*input=*/R"(["12", null, "11", "12", "13"])", | |
852 | /*value_set=*/R"([null, null, "11", "11", "12", "12"])", | |
853 | /*expected=*/R"([4, 0, 2, 4, null])", | |
854 | /*skip_nulls=*/false); | |
855 | CheckIndexIn(type, | |
856 | /*input=*/R"(["12", null, "11", "12", "13"])", | |
857 | /*value_set=*/R"([null, null, "11", "11", "12", "12"])", | |
858 | /*expected=*/R"([4, null, 2, 4, null])", | |
859 | /*skip_nulls=*/true); | |
860 | } | |
861 | ||
862 | TEST_F(TestIndexInKernel, DictionaryArray) { | |
863 | for (auto index_ty : all_dictionary_index_types()) { | |
864 | CheckIndexInDictionary(/*type=*/utf8(), | |
865 | /*index_type=*/index_ty, | |
866 | /*input_dictionary_json=*/R"(["A", "B", "C", "D"])", | |
867 | /*input_index_json=*/"[1, 2, null, 0]", | |
868 | /*value_set_json=*/R"(["A", "B", "C"])", | |
869 | /*expected_json=*/"[1, 2, null, 0]", | |
870 | /*skip_nulls=*/false); | |
871 | CheckIndexInDictionary(/*type=*/float32(), | |
872 | /*index_type=*/index_ty, | |
873 | /*input_dictionary_json=*/"[4.1, -1.0, 42, 9.8]", | |
874 | /*input_index_json=*/"[1, 2, null, 0]", | |
875 | /*value_set_json=*/"[4.1, 42, -1.0]", | |
876 | /*expected_json=*/"[2, 1, null, 0]", | |
877 | /*skip_nulls=*/false); | |
878 | ||
879 | // With nulls and skip_nulls=false | |
880 | CheckIndexInDictionary(/*type=*/utf8(), | |
881 | /*index_type=*/index_ty, | |
882 | /*input_dictionary_json=*/R"(["A", "B", "C", "D"])", | |
883 | /*input_index_json=*/"[1, 3, null, 0, 1]", | |
884 | /*value_set_json=*/R"(["C", "B", "A", null])", | |
885 | /*expected_json=*/"[1, null, 3, 2, 1]", | |
886 | /*skip_nulls=*/false); | |
887 | CheckIndexInDictionary(/*type=*/utf8(), | |
888 | /*index_type=*/index_ty, | |
889 | /*input_dictionary_json=*/R"(["A", null, "C", "D"])", | |
890 | /*input_index_json=*/"[1, 3, null, 0, 1]", | |
891 | /*value_set_json=*/R"(["C", "B", "A", null])", | |
892 | /*expected_json=*/"[3, null, 3, 2, 3]", | |
893 | /*skip_nulls=*/false); | |
894 | CheckIndexInDictionary(/*type=*/utf8(), | |
895 | /*index_type=*/index_ty, | |
896 | /*input_dictionary_json=*/R"(["A", null, "C", "D"])", | |
897 | /*input_index_json=*/"[1, 3, null, 0, 1]", | |
898 | /*value_set_json=*/R"(["C", "B", "A"])", | |
899 | /*expected_json=*/"[null, null, null, 2, null]", | |
900 | /*skip_nulls=*/false); | |
901 | ||
902 | // With nulls and skip_nulls=true | |
903 | CheckIndexInDictionary(/*type=*/utf8(), | |
904 | /*index_type=*/index_ty, | |
905 | /*input_dictionary_json=*/R"(["A", "B", "C", "D"])", | |
906 | /*input_index_json=*/"[1, 3, null, 0, 1]", | |
907 | /*value_set_json=*/R"(["C", "B", "A", null])", | |
908 | /*expected_json=*/"[1, null, null, 2, 1]", | |
909 | /*skip_nulls=*/true); | |
910 | CheckIndexInDictionary(/*type=*/utf8(), | |
911 | /*index_type=*/index_ty, | |
912 | /*input_dictionary_json=*/R"(["A", null, "C", "D"])", | |
913 | /*input_index_json=*/"[1, 3, null, 0, 1]", | |
914 | /*value_set_json=*/R"(["C", "B", "A", null])", | |
915 | /*expected_json=*/"[null, null, null, 2, null]", | |
916 | /*skip_nulls=*/true); | |
917 | CheckIndexInDictionary(/*type=*/utf8(), | |
918 | /*index_type=*/index_ty, | |
919 | /*input_dictionary_json=*/R"(["A", null, "C", "D"])", | |
920 | /*input_index_json=*/"[1, 3, null, 0, 1]", | |
921 | /*value_set_json=*/R"(["C", "B", "A"])", | |
922 | /*expected_json=*/"[null, null, null, 2, null]", | |
923 | /*skip_nulls=*/true); | |
924 | ||
925 | // With duplicates in value_set | |
926 | CheckIndexInDictionary(/*type=*/utf8(), | |
927 | /*index_type=*/index_ty, | |
928 | /*input_dictionary_json=*/R"(["A", "B", "C", "D"])", | |
929 | /*input_index_json=*/"[1, 2, null, 0]", | |
930 | /*value_set_json=*/R"(["A", "A", "B", "B", "C", "C"])", | |
931 | /*expected_json=*/"[2, 4, null, 0]", | |
932 | /*skip_nulls=*/false); | |
933 | CheckIndexInDictionary(/*type=*/utf8(), | |
934 | /*index_type=*/index_ty, | |
935 | /*input_dictionary_json=*/R"(["A", null, "C", "D"])", | |
936 | /*input_index_json=*/"[1, 3, null, 0, 1]", | |
937 | /*value_set_json=*/R"(["C", "C", "B", "B", "A", "A", null])", | |
938 | /*expected_json=*/"[6, null, 6, 4, 6]", | |
939 | /*skip_nulls=*/false); | |
940 | CheckIndexInDictionary(/*type=*/utf8(), | |
941 | /*index_type=*/index_ty, | |
942 | /*input_dictionary_json=*/R"(["A", null, "C", "D"])", | |
943 | /*input_index_json=*/"[1, 3, null, 0, 1]", | |
944 | /*value_set_json=*/R"(["C", "C", "B", "B", "A", "A", null])", | |
945 | /*expected_json=*/"[null, null, null, 4, null]", | |
946 | /*skip_nulls=*/true); | |
947 | } | |
948 | } | |
949 | ||
950 | TEST_F(TestIndexInKernel, ChunkedArrayInvoke) { | |
951 | auto input = ChunkedArrayFromJSON(utf8(), {R"(["abc", "def", "ghi", "abc", "jkl"])", | |
952 | R"(["def", null, "abc", "zzz"])"}); | |
953 | // No null in value_set | |
954 | auto value_set = ChunkedArrayFromJSON(utf8(), {R"(["ghi", "def"])", R"(["abc"])"}); | |
955 | auto expected = | |
956 | ChunkedArrayFromJSON(int32(), {"[2, 1, 0, 2, null]", "[1, null, 2, null]"}); | |
957 | ||
958 | CheckIndexInChunked(input, value_set, expected, /*skip_nulls=*/false); | |
959 | CheckIndexInChunked(input, value_set, expected, /*skip_nulls=*/true); | |
960 | ||
961 | // Null in value_set | |
962 | value_set = ChunkedArrayFromJSON(utf8(), {R"(["ghi", "def"])", R"([null, "abc"])"}); | |
963 | expected = ChunkedArrayFromJSON(int32(), {"[3, 1, 0, 3, null]", "[1, 2, 3, null]"}); | |
964 | CheckIndexInChunked(input, value_set, expected, /*skip_nulls=*/false); | |
965 | expected = ChunkedArrayFromJSON(int32(), {"[3, 1, 0, 3, null]", "[1, null, 3, null]"}); | |
966 | CheckIndexInChunked(input, value_set, expected, /*skip_nulls=*/true); | |
967 | ||
968 | // Duplicates in value_set | |
969 | value_set = ChunkedArrayFromJSON( | |
970 | utf8(), {R"(["ghi", "ghi", "def"])", R"(["def", null, null, "abc"])"}); | |
971 | expected = ChunkedArrayFromJSON(int32(), {"[6, 2, 0, 6, null]", "[2, 4, 6, null]"}); | |
972 | CheckIndexInChunked(input, value_set, expected, /*skip_nulls=*/false); | |
973 | expected = ChunkedArrayFromJSON(int32(), {"[6, 2, 0, 6, null]", "[2, null, 6, null]"}); | |
974 | CheckIndexInChunked(input, value_set, expected, /*skip_nulls=*/true); | |
975 | } | |
976 | ||
977 | TEST(TestSetLookup, DispatchBest) { | |
978 | for (std::string name : {"is_in", "index_in"}) { | |
979 | CheckDispatchBest(name, {int32()}, {int32()}); | |
980 | CheckDispatchBest(name, {dictionary(int32(), utf8())}, {utf8()}); | |
981 | } | |
982 | } | |
983 | ||
984 | TEST(TestSetLookup, IsInWithImplicitCasts) { | |
985 | SetLookupOptions opts{ArrayFromJSON(utf8(), R"(["b", null])")}; | |
986 | CheckScalarUnary("is_in", | |
987 | ArrayFromJSON(dictionary(int32(), utf8()), R"(["a", "b", "c", null])"), | |
988 | ArrayFromJSON(boolean(), "[0, 1, 0, 1]"), &opts); | |
989 | } | |
990 | ||
991 | } // namespace compute | |
992 | } // namespace arrow |