]> git.proxmox.com Git - ceph.git/blame - ceph/src/arrow/cpp/src/arrow/compute/kernels/test_util.cc
import quincy 17.2.0
[ceph.git] / ceph / src / arrow / cpp / src / arrow / compute / kernels / test_util.cc
CommitLineData
1d09f67e
TL
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.
17
18#include "arrow/compute/kernels/test_util.h"
19
20#include <cstdint>
21#include <memory>
22#include <string>
23
24#include "arrow/array.h"
25#include "arrow/array/validate.h"
26#include "arrow/chunked_array.h"
27#include "arrow/compute/cast.h"
28#include "arrow/compute/exec.h"
29#include "arrow/compute/function.h"
30#include "arrow/compute/registry.h"
31#include "arrow/datum.h"
32#include "arrow/result.h"
33#include "arrow/table.h"
34#include "arrow/testing/gtest_util.h"
35
36namespace arrow {
37namespace compute {
38
39namespace {
40
41template <typename T>
42DatumVector GetDatums(const std::vector<T>& inputs) {
43 std::vector<Datum> datums;
44 for (const auto& input : inputs) {
45 datums.emplace_back(input);
46 }
47 return datums;
48}
49
50template <typename... SliceArgs>
51DatumVector SliceArrays(const DatumVector& inputs, SliceArgs... slice_args) {
52 DatumVector sliced;
53 for (const auto& input : inputs) {
54 if (input.is_array()) {
55 sliced.push_back(*input.make_array()->Slice(slice_args...));
56 } else {
57 sliced.push_back(input);
58 }
59 }
60 return sliced;
61}
62
63ScalarVector GetScalars(const DatumVector& inputs, int64_t index) {
64 ScalarVector scalars;
65 for (const auto& input : inputs) {
66 if (input.is_array()) {
67 scalars.push_back(*input.make_array()->GetScalar(index));
68 } else {
69 scalars.push_back(input.scalar());
70 }
71 }
72 return scalars;
73}
74
75} // namespace
76
77void CheckScalarNonRecursive(const std::string& func_name, const DatumVector& inputs,
78 const Datum& expected, const FunctionOptions* options) {
79 ASSERT_OK_AND_ASSIGN(Datum out, CallFunction(func_name, inputs, options));
80 ValidateOutput(out);
81 AssertDatumsEqual(expected, out, /*verbose=*/true);
82}
83
84void CheckScalar(std::string func_name, const ScalarVector& inputs,
85 std::shared_ptr<Scalar> expected, const FunctionOptions* options) {
86 ASSERT_OK_AND_ASSIGN(Datum out, CallFunction(func_name, GetDatums(inputs), options));
87 ValidateOutput(out);
88 if (!out.scalar()->Equals(expected)) {
89 std::string summary = func_name + "(";
90 for (const auto& input : inputs) {
91 summary += input->ToString() + ",";
92 }
93 summary.back() = ')';
94
95 summary += " = " + out.scalar()->ToString() + " != " + expected->ToString();
96
97 if (!out.type()->Equals(expected->type)) {
98 summary += " (types differed: " + out.type()->ToString() + " vs " +
99 expected->type->ToString() + ")";
100 }
101
102 FAIL() << summary;
103 }
104}
105
106void CheckScalar(std::string func_name, const DatumVector& inputs, Datum expected_datum,
107 const FunctionOptions* options) {
108 CheckScalarNonRecursive(func_name, inputs, expected_datum, options);
109
110 if (expected_datum.is_scalar()) return;
111 ASSERT_TRUE(expected_datum.is_array())
112 << "CheckScalar is only implemented for scalar/array expected values";
113 auto expected = expected_datum.make_array();
114
115 // check for at least 1 array, and make sure the others are of equal length
116 bool has_array = false;
117 for (const auto& input : inputs) {
118 if (input.is_array()) {
119 ASSERT_EQ(input.array()->length, expected->length());
120 has_array = true;
121 }
122 }
123 ASSERT_TRUE(has_array) << "Must have at least 1 array input to have an array output";
124
125 // Check all the input scalars
126 for (int64_t i = 0; i < expected->length(); ++i) {
127 CheckScalar(func_name, GetScalars(inputs, i), *expected->GetScalar(i), options);
128 }
129
130 // Since it's a scalar function, calling it on sliced inputs should
131 // result in the sliced expected output.
132 const auto slice_length = expected->length() / 3;
133 if (slice_length > 0) {
134 CheckScalarNonRecursive(func_name, SliceArrays(inputs, 0, slice_length),
135 expected->Slice(0, slice_length), options);
136
137 CheckScalarNonRecursive(func_name, SliceArrays(inputs, slice_length, slice_length),
138 expected->Slice(slice_length, slice_length), options);
139
140 CheckScalarNonRecursive(func_name, SliceArrays(inputs, 2 * slice_length),
141 expected->Slice(2 * slice_length), options);
142 }
143
144 // Should also work with an empty slice
145 CheckScalarNonRecursive(func_name, SliceArrays(inputs, 0, 0), expected->Slice(0, 0),
146 options);
147
148 // Ditto with ChunkedArray inputs
149 if (slice_length > 0) {
150 DatumVector chunked_inputs;
151 chunked_inputs.reserve(inputs.size());
152 for (const auto& input : inputs) {
153 if (input.is_array()) {
154 auto ar = input.make_array();
155 auto ar_chunked = std::make_shared<ChunkedArray>(
156 ArrayVector{ar->Slice(0, slice_length), ar->Slice(slice_length)});
157 chunked_inputs.push_back(ar_chunked);
158 } else {
159 chunked_inputs.push_back(input.scalar());
160 }
161 }
162 ArrayVector expected_chunks{expected->Slice(0, slice_length),
163 expected->Slice(slice_length)};
164
165 ASSERT_OK_AND_ASSIGN(Datum out,
166 CallFunction(func_name, GetDatums(chunked_inputs), options));
167 ValidateOutput(out);
168 auto chunked = out.chunked_array();
169 (void)chunked;
170 AssertDatumsEqual(std::make_shared<ChunkedArray>(expected_chunks), out);
171 }
172}
173
174Datum CheckDictionaryNonRecursive(const std::string& func_name, const DatumVector& args,
175 bool result_is_encoded) {
176 EXPECT_OK_AND_ASSIGN(Datum actual, CallFunction(func_name, args));
177 ValidateOutput(actual);
178
179 DatumVector decoded_args;
180 decoded_args.reserve(args.size());
181 for (const auto& arg : args) {
182 if (arg.type()->id() == Type::DICTIONARY) {
183 const auto& to_type = checked_cast<const DictionaryType&>(*arg.type()).value_type();
184 EXPECT_OK_AND_ASSIGN(auto decoded, Cast(arg, to_type));
185 decoded_args.push_back(decoded);
186 } else {
187 decoded_args.push_back(arg);
188 }
189 }
190 EXPECT_OK_AND_ASSIGN(Datum expected, CallFunction(func_name, decoded_args));
191
192 if (result_is_encoded) {
193 EXPECT_EQ(Type::DICTIONARY, actual.type()->id())
194 << "Result should have been dictionary-encoded";
195 // Decode before comparison - we care about equivalent not identical results
196 const auto& to_type =
197 checked_cast<const DictionaryType&>(*actual.type()).value_type();
198 EXPECT_OK_AND_ASSIGN(auto decoded, Cast(actual, to_type));
199 AssertDatumsApproxEqual(expected, decoded, /*verbose=*/true);
200 } else {
201 AssertDatumsApproxEqual(expected, actual, /*verbose=*/true);
202 }
203 return actual;
204}
205
206void CheckDictionary(const std::string& func_name, const DatumVector& args,
207 bool result_is_encoded) {
208 auto actual = CheckDictionaryNonRecursive(func_name, args, result_is_encoded);
209
210 if (actual.is_scalar()) return;
211 ASSERT_TRUE(actual.is_array());
212 ASSERT_GE(actual.length(), 0);
213
214 // Check all scalars
215 for (int64_t i = 0; i < actual.length(); i++) {
216 CheckDictionaryNonRecursive(func_name, GetDatums(GetScalars(args, i)),
217 result_is_encoded);
218 }
219
220 // Check slices of the input
221 const auto slice_length = actual.length() / 3;
222 if (slice_length > 0) {
223 CheckDictionaryNonRecursive(func_name, SliceArrays(args, 0, slice_length),
224 result_is_encoded);
225 CheckDictionaryNonRecursive(func_name, SliceArrays(args, slice_length, slice_length),
226 result_is_encoded);
227 CheckDictionaryNonRecursive(func_name, SliceArrays(args, 2 * slice_length),
228 result_is_encoded);
229 }
230
231 // Check empty slice
232 CheckDictionaryNonRecursive(func_name, SliceArrays(args, 0, 0), result_is_encoded);
233
234 // Check chunked arrays
235 if (slice_length > 0) {
236 DatumVector chunked_args;
237 chunked_args.reserve(args.size());
238 for (const auto& arg : args) {
239 if (arg.is_array()) {
240 auto arr = arg.make_array();
241 ArrayVector chunks{arr->Slice(0, slice_length), arr->Slice(slice_length)};
242 chunked_args.push_back(std::make_shared<ChunkedArray>(std::move(chunks)));
243 } else {
244 chunked_args.push_back(arg);
245 }
246 }
247 CheckDictionaryNonRecursive(func_name, chunked_args, result_is_encoded);
248 }
249}
250
251void CheckScalarUnary(std::string func_name, Datum input, Datum expected,
252 const FunctionOptions* options) {
253 std::vector<Datum> input_vector = {std::move(input)};
254 CheckScalar(std::move(func_name), input_vector, expected, options);
255}
256
257void CheckScalarUnary(std::string func_name, std::shared_ptr<DataType> in_ty,
258 std::string json_input, std::shared_ptr<DataType> out_ty,
259 std::string json_expected, const FunctionOptions* options) {
260 CheckScalarUnary(std::move(func_name), ArrayFromJSON(in_ty, json_input),
261 ArrayFromJSON(out_ty, json_expected), options);
262}
263
264void CheckVectorUnary(std::string func_name, Datum input, Datum expected,
265 const FunctionOptions* options) {
266 ASSERT_OK_AND_ASSIGN(Datum actual, CallFunction(func_name, {input}, options));
267 ValidateOutput(actual);
268 AssertDatumsEqual(expected, actual, /*verbose=*/true);
269}
270
271void CheckScalarBinary(std::string func_name, Datum left_input, Datum right_input,
272 Datum expected, const FunctionOptions* options) {
273 CheckScalar(std::move(func_name), {left_input, right_input}, expected, options);
274}
275
276namespace {
277
278void ValidateOutput(const ArrayData& output) {
279 ASSERT_OK(::arrow::internal::ValidateArrayFull(output));
280 TestInitialized(output);
281}
282
283void ValidateOutput(const ChunkedArray& output) {
284 ASSERT_OK(output.ValidateFull());
285 for (const auto& chunk : output.chunks()) {
286 TestInitialized(*chunk);
287 }
288}
289
290void ValidateOutput(const RecordBatch& output) {
291 ASSERT_OK(output.ValidateFull());
292 for (const auto& column : output.column_data()) {
293 TestInitialized(*column);
294 }
295}
296
297void ValidateOutput(const Table& output) {
298 ASSERT_OK(output.ValidateFull());
299 for (const auto& column : output.columns()) {
300 for (const auto& chunk : column->chunks()) {
301 TestInitialized(*chunk);
302 }
303 }
304}
305
306void ValidateOutput(const Scalar& output) { ASSERT_OK(output.ValidateFull()); }
307
308} // namespace
309
310void ValidateOutput(const Datum& output) {
311 switch (output.kind()) {
312 case Datum::ARRAY:
313 ValidateOutput(*output.array());
314 break;
315 case Datum::CHUNKED_ARRAY:
316 ValidateOutput(*output.chunked_array());
317 break;
318 case Datum::RECORD_BATCH:
319 ValidateOutput(*output.record_batch());
320 break;
321 case Datum::TABLE:
322 ValidateOutput(*output.table());
323 break;
324 case Datum::SCALAR:
325 ValidateOutput(*output.scalar());
326 break;
327 default:
328 break;
329 }
330}
331
332void CheckDispatchBest(std::string func_name, std::vector<ValueDescr> original_values,
333 std::vector<ValueDescr> expected_equivalent_values) {
334 ASSERT_OK_AND_ASSIGN(auto function, GetFunctionRegistry()->GetFunction(func_name));
335
336 auto values = original_values;
337 ASSERT_OK_AND_ASSIGN(auto actual_kernel, function->DispatchBest(&values));
338
339 ASSERT_OK_AND_ASSIGN(auto expected_kernel,
340 function->DispatchExact(expected_equivalent_values));
341
342 EXPECT_EQ(actual_kernel, expected_kernel)
343 << " DispatchBest" << ValueDescr::ToString(original_values) << " => "
344 << actual_kernel->signature->ToString() << "\n"
345 << " DispatchExact" << ValueDescr::ToString(expected_equivalent_values) << " => "
346 << expected_kernel->signature->ToString();
347 EXPECT_EQ(values.size(), expected_equivalent_values.size());
348 for (size_t i = 0; i < values.size(); i++) {
349 EXPECT_EQ(values[i].shape, expected_equivalent_values[i].shape)
350 << "Argument " << i << " should have the same shape";
351 AssertTypeEqual(values[i].type, expected_equivalent_values[i].type);
352 }
353}
354
355void CheckDispatchFails(std::string func_name, std::vector<ValueDescr> values) {
356 ASSERT_OK_AND_ASSIGN(auto function, GetFunctionRegistry()->GetFunction(func_name));
357 ASSERT_NOT_OK(function->DispatchBest(&values));
358 ASSERT_NOT_OK(function->DispatchExact(values));
359}
360
361} // namespace compute
362} // namespace arrow