]>
Commit | Line | Data |
---|---|---|
1d09f67e TL |
1 | // Licensed to the Apache Software Foundation (ASF) under one |
2 | // or more contributor license agreements. See the NOTICE file | |
3 | // distributed with this work for additional information | |
4 | // regarding copyright ownership. The ASF licenses this file | |
5 | // to you under the Apache License, Version 2.0 (the | |
6 | // "License"); you may not use this file except in compliance | |
7 | // with the License. You may obtain a copy of the License at | |
8 | // | |
9 | // http://www.apache.org/licenses/LICENSE-2.0 | |
10 | // | |
11 | // Unless required by applicable law or agreed to in writing, | |
12 | // software distributed under the License is distributed on an | |
13 | // "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY | |
14 | // KIND, either express or implied. See the License for the | |
15 | // specific language governing permissions and limitations | |
16 | // under the License. | |
17 | ||
18 | #include "arrow/compute/kernels/test_util.h" | |
19 | ||
20 | #include <cstdint> | |
21 | #include <memory> | |
22 | #include <string> | |
23 | ||
24 | #include "arrow/array.h" | |
25 | #include "arrow/array/validate.h" | |
26 | #include "arrow/chunked_array.h" | |
27 | #include "arrow/compute/cast.h" | |
28 | #include "arrow/compute/exec.h" | |
29 | #include "arrow/compute/function.h" | |
30 | #include "arrow/compute/registry.h" | |
31 | #include "arrow/datum.h" | |
32 | #include "arrow/result.h" | |
33 | #include "arrow/table.h" | |
34 | #include "arrow/testing/gtest_util.h" | |
35 | ||
36 | namespace arrow { | |
37 | namespace compute { | |
38 | ||
39 | namespace { | |
40 | ||
41 | template <typename T> | |
42 | DatumVector GetDatums(const std::vector<T>& inputs) { | |
43 | std::vector<Datum> datums; | |
44 | for (const auto& input : inputs) { | |
45 | datums.emplace_back(input); | |
46 | } | |
47 | return datums; | |
48 | } | |
49 | ||
50 | template <typename... SliceArgs> | |
51 | DatumVector SliceArrays(const DatumVector& inputs, SliceArgs... slice_args) { | |
52 | DatumVector sliced; | |
53 | for (const auto& input : inputs) { | |
54 | if (input.is_array()) { | |
55 | sliced.push_back(*input.make_array()->Slice(slice_args...)); | |
56 | } else { | |
57 | sliced.push_back(input); | |
58 | } | |
59 | } | |
60 | return sliced; | |
61 | } | |
62 | ||
63 | ScalarVector GetScalars(const DatumVector& inputs, int64_t index) { | |
64 | ScalarVector scalars; | |
65 | for (const auto& input : inputs) { | |
66 | if (input.is_array()) { | |
67 | scalars.push_back(*input.make_array()->GetScalar(index)); | |
68 | } else { | |
69 | scalars.push_back(input.scalar()); | |
70 | } | |
71 | } | |
72 | return scalars; | |
73 | } | |
74 | ||
75 | } // namespace | |
76 | ||
77 | void CheckScalarNonRecursive(const std::string& func_name, const DatumVector& inputs, | |
78 | const Datum& expected, const FunctionOptions* options) { | |
79 | ASSERT_OK_AND_ASSIGN(Datum out, CallFunction(func_name, inputs, options)); | |
80 | ValidateOutput(out); | |
81 | AssertDatumsEqual(expected, out, /*verbose=*/true); | |
82 | } | |
83 | ||
84 | void CheckScalar(std::string func_name, const ScalarVector& inputs, | |
85 | std::shared_ptr<Scalar> expected, const FunctionOptions* options) { | |
86 | ASSERT_OK_AND_ASSIGN(Datum out, CallFunction(func_name, GetDatums(inputs), options)); | |
87 | ValidateOutput(out); | |
88 | if (!out.scalar()->Equals(expected)) { | |
89 | std::string summary = func_name + "("; | |
90 | for (const auto& input : inputs) { | |
91 | summary += input->ToString() + ","; | |
92 | } | |
93 | summary.back() = ')'; | |
94 | ||
95 | summary += " = " + out.scalar()->ToString() + " != " + expected->ToString(); | |
96 | ||
97 | if (!out.type()->Equals(expected->type)) { | |
98 | summary += " (types differed: " + out.type()->ToString() + " vs " + | |
99 | expected->type->ToString() + ")"; | |
100 | } | |
101 | ||
102 | FAIL() << summary; | |
103 | } | |
104 | } | |
105 | ||
106 | void CheckScalar(std::string func_name, const DatumVector& inputs, Datum expected_datum, | |
107 | const FunctionOptions* options) { | |
108 | CheckScalarNonRecursive(func_name, inputs, expected_datum, options); | |
109 | ||
110 | if (expected_datum.is_scalar()) return; | |
111 | ASSERT_TRUE(expected_datum.is_array()) | |
112 | << "CheckScalar is only implemented for scalar/array expected values"; | |
113 | auto expected = expected_datum.make_array(); | |
114 | ||
115 | // check for at least 1 array, and make sure the others are of equal length | |
116 | bool has_array = false; | |
117 | for (const auto& input : inputs) { | |
118 | if (input.is_array()) { | |
119 | ASSERT_EQ(input.array()->length, expected->length()); | |
120 | has_array = true; | |
121 | } | |
122 | } | |
123 | ASSERT_TRUE(has_array) << "Must have at least 1 array input to have an array output"; | |
124 | ||
125 | // Check all the input scalars | |
126 | for (int64_t i = 0; i < expected->length(); ++i) { | |
127 | CheckScalar(func_name, GetScalars(inputs, i), *expected->GetScalar(i), options); | |
128 | } | |
129 | ||
130 | // Since it's a scalar function, calling it on sliced inputs should | |
131 | // result in the sliced expected output. | |
132 | const auto slice_length = expected->length() / 3; | |
133 | if (slice_length > 0) { | |
134 | CheckScalarNonRecursive(func_name, SliceArrays(inputs, 0, slice_length), | |
135 | expected->Slice(0, slice_length), options); | |
136 | ||
137 | CheckScalarNonRecursive(func_name, SliceArrays(inputs, slice_length, slice_length), | |
138 | expected->Slice(slice_length, slice_length), options); | |
139 | ||
140 | CheckScalarNonRecursive(func_name, SliceArrays(inputs, 2 * slice_length), | |
141 | expected->Slice(2 * slice_length), options); | |
142 | } | |
143 | ||
144 | // Should also work with an empty slice | |
145 | CheckScalarNonRecursive(func_name, SliceArrays(inputs, 0, 0), expected->Slice(0, 0), | |
146 | options); | |
147 | ||
148 | // Ditto with ChunkedArray inputs | |
149 | if (slice_length > 0) { | |
150 | DatumVector chunked_inputs; | |
151 | chunked_inputs.reserve(inputs.size()); | |
152 | for (const auto& input : inputs) { | |
153 | if (input.is_array()) { | |
154 | auto ar = input.make_array(); | |
155 | auto ar_chunked = std::make_shared<ChunkedArray>( | |
156 | ArrayVector{ar->Slice(0, slice_length), ar->Slice(slice_length)}); | |
157 | chunked_inputs.push_back(ar_chunked); | |
158 | } else { | |
159 | chunked_inputs.push_back(input.scalar()); | |
160 | } | |
161 | } | |
162 | ArrayVector expected_chunks{expected->Slice(0, slice_length), | |
163 | expected->Slice(slice_length)}; | |
164 | ||
165 | ASSERT_OK_AND_ASSIGN(Datum out, | |
166 | CallFunction(func_name, GetDatums(chunked_inputs), options)); | |
167 | ValidateOutput(out); | |
168 | auto chunked = out.chunked_array(); | |
169 | (void)chunked; | |
170 | AssertDatumsEqual(std::make_shared<ChunkedArray>(expected_chunks), out); | |
171 | } | |
172 | } | |
173 | ||
174 | Datum CheckDictionaryNonRecursive(const std::string& func_name, const DatumVector& args, | |
175 | bool result_is_encoded) { | |
176 | EXPECT_OK_AND_ASSIGN(Datum actual, CallFunction(func_name, args)); | |
177 | ValidateOutput(actual); | |
178 | ||
179 | DatumVector decoded_args; | |
180 | decoded_args.reserve(args.size()); | |
181 | for (const auto& arg : args) { | |
182 | if (arg.type()->id() == Type::DICTIONARY) { | |
183 | const auto& to_type = checked_cast<const DictionaryType&>(*arg.type()).value_type(); | |
184 | EXPECT_OK_AND_ASSIGN(auto decoded, Cast(arg, to_type)); | |
185 | decoded_args.push_back(decoded); | |
186 | } else { | |
187 | decoded_args.push_back(arg); | |
188 | } | |
189 | } | |
190 | EXPECT_OK_AND_ASSIGN(Datum expected, CallFunction(func_name, decoded_args)); | |
191 | ||
192 | if (result_is_encoded) { | |
193 | EXPECT_EQ(Type::DICTIONARY, actual.type()->id()) | |
194 | << "Result should have been dictionary-encoded"; | |
195 | // Decode before comparison - we care about equivalent not identical results | |
196 | const auto& to_type = | |
197 | checked_cast<const DictionaryType&>(*actual.type()).value_type(); | |
198 | EXPECT_OK_AND_ASSIGN(auto decoded, Cast(actual, to_type)); | |
199 | AssertDatumsApproxEqual(expected, decoded, /*verbose=*/true); | |
200 | } else { | |
201 | AssertDatumsApproxEqual(expected, actual, /*verbose=*/true); | |
202 | } | |
203 | return actual; | |
204 | } | |
205 | ||
206 | void CheckDictionary(const std::string& func_name, const DatumVector& args, | |
207 | bool result_is_encoded) { | |
208 | auto actual = CheckDictionaryNonRecursive(func_name, args, result_is_encoded); | |
209 | ||
210 | if (actual.is_scalar()) return; | |
211 | ASSERT_TRUE(actual.is_array()); | |
212 | ASSERT_GE(actual.length(), 0); | |
213 | ||
214 | // Check all scalars | |
215 | for (int64_t i = 0; i < actual.length(); i++) { | |
216 | CheckDictionaryNonRecursive(func_name, GetDatums(GetScalars(args, i)), | |
217 | result_is_encoded); | |
218 | } | |
219 | ||
220 | // Check slices of the input | |
221 | const auto slice_length = actual.length() / 3; | |
222 | if (slice_length > 0) { | |
223 | CheckDictionaryNonRecursive(func_name, SliceArrays(args, 0, slice_length), | |
224 | result_is_encoded); | |
225 | CheckDictionaryNonRecursive(func_name, SliceArrays(args, slice_length, slice_length), | |
226 | result_is_encoded); | |
227 | CheckDictionaryNonRecursive(func_name, SliceArrays(args, 2 * slice_length), | |
228 | result_is_encoded); | |
229 | } | |
230 | ||
231 | // Check empty slice | |
232 | CheckDictionaryNonRecursive(func_name, SliceArrays(args, 0, 0), result_is_encoded); | |
233 | ||
234 | // Check chunked arrays | |
235 | if (slice_length > 0) { | |
236 | DatumVector chunked_args; | |
237 | chunked_args.reserve(args.size()); | |
238 | for (const auto& arg : args) { | |
239 | if (arg.is_array()) { | |
240 | auto arr = arg.make_array(); | |
241 | ArrayVector chunks{arr->Slice(0, slice_length), arr->Slice(slice_length)}; | |
242 | chunked_args.push_back(std::make_shared<ChunkedArray>(std::move(chunks))); | |
243 | } else { | |
244 | chunked_args.push_back(arg); | |
245 | } | |
246 | } | |
247 | CheckDictionaryNonRecursive(func_name, chunked_args, result_is_encoded); | |
248 | } | |
249 | } | |
250 | ||
251 | void CheckScalarUnary(std::string func_name, Datum input, Datum expected, | |
252 | const FunctionOptions* options) { | |
253 | std::vector<Datum> input_vector = {std::move(input)}; | |
254 | CheckScalar(std::move(func_name), input_vector, expected, options); | |
255 | } | |
256 | ||
257 | void CheckScalarUnary(std::string func_name, std::shared_ptr<DataType> in_ty, | |
258 | std::string json_input, std::shared_ptr<DataType> out_ty, | |
259 | std::string json_expected, const FunctionOptions* options) { | |
260 | CheckScalarUnary(std::move(func_name), ArrayFromJSON(in_ty, json_input), | |
261 | ArrayFromJSON(out_ty, json_expected), options); | |
262 | } | |
263 | ||
264 | void CheckVectorUnary(std::string func_name, Datum input, Datum expected, | |
265 | const FunctionOptions* options) { | |
266 | ASSERT_OK_AND_ASSIGN(Datum actual, CallFunction(func_name, {input}, options)); | |
267 | ValidateOutput(actual); | |
268 | AssertDatumsEqual(expected, actual, /*verbose=*/true); | |
269 | } | |
270 | ||
271 | void CheckScalarBinary(std::string func_name, Datum left_input, Datum right_input, | |
272 | Datum expected, const FunctionOptions* options) { | |
273 | CheckScalar(std::move(func_name), {left_input, right_input}, expected, options); | |
274 | } | |
275 | ||
276 | namespace { | |
277 | ||
278 | void ValidateOutput(const ArrayData& output) { | |
279 | ASSERT_OK(::arrow::internal::ValidateArrayFull(output)); | |
280 | TestInitialized(output); | |
281 | } | |
282 | ||
283 | void ValidateOutput(const ChunkedArray& output) { | |
284 | ASSERT_OK(output.ValidateFull()); | |
285 | for (const auto& chunk : output.chunks()) { | |
286 | TestInitialized(*chunk); | |
287 | } | |
288 | } | |
289 | ||
290 | void ValidateOutput(const RecordBatch& output) { | |
291 | ASSERT_OK(output.ValidateFull()); | |
292 | for (const auto& column : output.column_data()) { | |
293 | TestInitialized(*column); | |
294 | } | |
295 | } | |
296 | ||
297 | void ValidateOutput(const Table& output) { | |
298 | ASSERT_OK(output.ValidateFull()); | |
299 | for (const auto& column : output.columns()) { | |
300 | for (const auto& chunk : column->chunks()) { | |
301 | TestInitialized(*chunk); | |
302 | } | |
303 | } | |
304 | } | |
305 | ||
306 | void ValidateOutput(const Scalar& output) { ASSERT_OK(output.ValidateFull()); } | |
307 | ||
308 | } // namespace | |
309 | ||
310 | void ValidateOutput(const Datum& output) { | |
311 | switch (output.kind()) { | |
312 | case Datum::ARRAY: | |
313 | ValidateOutput(*output.array()); | |
314 | break; | |
315 | case Datum::CHUNKED_ARRAY: | |
316 | ValidateOutput(*output.chunked_array()); | |
317 | break; | |
318 | case Datum::RECORD_BATCH: | |
319 | ValidateOutput(*output.record_batch()); | |
320 | break; | |
321 | case Datum::TABLE: | |
322 | ValidateOutput(*output.table()); | |
323 | break; | |
324 | case Datum::SCALAR: | |
325 | ValidateOutput(*output.scalar()); | |
326 | break; | |
327 | default: | |
328 | break; | |
329 | } | |
330 | } | |
331 | ||
332 | void CheckDispatchBest(std::string func_name, std::vector<ValueDescr> original_values, | |
333 | std::vector<ValueDescr> expected_equivalent_values) { | |
334 | ASSERT_OK_AND_ASSIGN(auto function, GetFunctionRegistry()->GetFunction(func_name)); | |
335 | ||
336 | auto values = original_values; | |
337 | ASSERT_OK_AND_ASSIGN(auto actual_kernel, function->DispatchBest(&values)); | |
338 | ||
339 | ASSERT_OK_AND_ASSIGN(auto expected_kernel, | |
340 | function->DispatchExact(expected_equivalent_values)); | |
341 | ||
342 | EXPECT_EQ(actual_kernel, expected_kernel) | |
343 | << " DispatchBest" << ValueDescr::ToString(original_values) << " => " | |
344 | << actual_kernel->signature->ToString() << "\n" | |
345 | << " DispatchExact" << ValueDescr::ToString(expected_equivalent_values) << " => " | |
346 | << expected_kernel->signature->ToString(); | |
347 | EXPECT_EQ(values.size(), expected_equivalent_values.size()); | |
348 | for (size_t i = 0; i < values.size(); i++) { | |
349 | EXPECT_EQ(values[i].shape, expected_equivalent_values[i].shape) | |
350 | << "Argument " << i << " should have the same shape"; | |
351 | AssertTypeEqual(values[i].type, expected_equivalent_values[i].type); | |
352 | } | |
353 | } | |
354 | ||
// Check that `func_name` exists in the function registry but that neither
// DispatchBest nor DispatchExact succeeds for `values`.
void CheckDispatchFails(std::string func_name, std::vector<ValueDescr> values) {
  ASSERT_OK_AND_ASSIGN(auto function, GetFunctionRegistry()->GetFunction(func_name));
  // DispatchBest is given a mutable pointer to `values`, so it may have
  // rewritten them by the time DispatchExact runs; keep this order.
  ASSERT_NOT_OK(function->DispatchBest(&values));
  ASSERT_NOT_OK(function->DispatchExact(values));
}
360 | ||
361 | } // namespace compute | |
362 | } // namespace arrow |